1use rand::Rng;
59use rig::agent::{Agent};
60use rig::client::builder::BoxAgent;
61use rig::completion::{Message, Prompt, PromptError};
62use rig::client::completion::CompletionModelHandle;
63use tokio::sync::Mutex;
64
65
66pub struct AgentState<'a> {
68 id: i32,
69 agent: BoxAgent<'a>,
70 provider: String,
71 model: String,
72 failure_count: u32,
73 max_failures: u32,
74}
75
76impl<'a> AgentState<'a> {
77 fn new(agent: BoxAgent<'a>,id:i32, provider: String, model: String, max_failures: u32) -> Self {
78 Self {
79 id,
80 agent,
81 provider,
82 model,
83 failure_count: 0,
84 max_failures,
85 }
86 }
87
88 fn is_valid(&self) -> bool {
89 self.failure_count < self.max_failures
90 }
91
92 fn record_failure(&mut self) {
93 self.failure_count += 1;
94 }
95
96 fn record_success(&mut self) {
97 self.failure_count = 0;
98 }
99}
100
/// Optional hook called with an agent's `id` when that agent reaches its failure limit.
///
/// The boxed closure has the default `'static` object lifetime and must be `Send + Sync`.
pub type OnRandAgentInvalidCallback = Option<Box<dyn Fn(i32) + Send + Sync>>;
103
104pub struct RandAgent<'a> {
106 agents: Mutex<Vec<AgentState<'a>>>,
107 on_agent_invalid: OnRandAgentInvalidCallback,
108}
109
110impl Prompt for RandAgent<'_> {
111 #[allow(refining_impl_trait)]
112 async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
113 let mut agents = self.agents.lock().await;
114 let agent_state = Self::get_random_valid_agent(&mut agents)
115 .await
116 .ok_or(PromptError::MaxDepthError {
117 max_depth: 0,
118 chat_history: vec![],
119 prompt: "没有有效agent".into(),
120 })?;
121
122 tracing::info!("Using provider: {}, model: {}", agent_state.provider, agent_state.model);
123 match agent_state.agent.prompt(prompt).await {
124 Ok(content) => {
125 agent_state.record_success();
126 Ok(content)
127 }
128 Err(e) => {
129 agent_state.record_failure();
130 if !agent_state.is_valid() {
131 if let Some(cb) = &self.on_agent_invalid {
132 cb(agent_state.id);
133 }
134 }
135 Err(e)
136 }
137 }
138 }
139}
140
141impl<'a> RandAgent<'a> {
142 pub fn new(agents: Vec<(BoxAgent<'a>, i32, String, String)>) -> Self {
144 Self::with_max_failures_and_callback(agents, 3, None)
145 }
146
147 pub fn with_max_failures_and_callback(
149 agents: Vec<(BoxAgent<'a>, i32, String, String)>,
150 max_failures: u32,
151 on_agent_invalid: OnRandAgentInvalidCallback,
152 ) -> Self {
153 let agent_states = agents
154 .into_iter()
155 .map(|(agent, id, provider, model)| AgentState::new(agent, id, provider, model, max_failures))
156 .collect();
157 Self {
158 agents: Mutex::new(agent_states),
159 on_agent_invalid,
160 }
161 }
162
163 pub fn with_max_failures(agents: Vec<(BoxAgent<'a>, i32,String, String)>, max_failures: u32) -> Self {
165 Self::with_max_failures_and_callback(agents, max_failures, None)
166 }
167
168 pub fn set_on_agent_invalid<F>(&mut self, callback: F)
170 where
171 F: Fn(i32) + Send + Sync + 'static,
172 {
173 self.on_agent_invalid = Some(Box::new(callback));
174 }
175
176
177 pub async fn add_agent(&self, agent: BoxAgent<'a>, id: i32, provider: String, model: String) {
179 self.agents.lock().await.push(AgentState::new(agent, id, provider, model, 3)); }
181
182 pub async fn add_agent_with_max_failures(&self, agent: BoxAgent<'a>, id: i32, provider: String, model: String, max_failures: u32) {
184 self.agents.lock().await.push(AgentState::new(agent, id, provider, model, max_failures));
185 }
186
187 pub async fn len(&self) -> usize {
189 self.agents.lock().await.iter().filter(|state| state.is_valid()).count()
190 }
191
192 pub async fn total_len(&self) -> usize {
194 self.agents.lock().await.len()
195 }
196
197 pub async fn is_empty(&self) -> bool {
199 self.len().await == 0
200 }
201
202 async fn get_random_valid_agent<'b>(agents: &'b mut Vec<AgentState<'a>>) -> Option<&'b mut AgentState<'a>> {
204 let valid_indices: Vec<usize> = agents
205 .iter()
206 .enumerate()
207 .filter(|(_, state)| state.is_valid())
208 .map(|(i, _)| i)
209 .collect();
210
211 if valid_indices.is_empty() {
212 return None;
213 }
214
215 let mut rng = rand::rng();
216 let random_index = rng.random_range(0..valid_indices.len());
217 let agent_index = valid_indices[random_index];
218 agents.get_mut(agent_index)
219 }
220
221
222 pub async fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
224 self.agents
225 .lock()
226 .await
227 .iter()
228 .enumerate()
229 .map(|(i, state)| (i, state.failure_count, state.max_failures))
230 .collect()
231 }
232
233 pub async fn reset_failures(&self) {
235 for state in self.agents.lock().await.iter_mut() {
236 state.failure_count = 0;
237 }
238 }
239}
240
241
242
243pub struct RandAgentBuilder<'a> {
245 agents: Vec<(BoxAgent<'a>, i32, String, String)>,
246 max_failures: u32,
247 on_agent_invalid: Option<Box<dyn Fn(i32) + Send + Sync + 'static>>,
248}
249
250impl<'a> RandAgentBuilder<'a> {
251 pub fn new() -> Self {
253 Self {
254 agents: Vec::new(),
255 max_failures: 3, on_agent_invalid: None,
257 }
258 }
259
260 pub fn max_failures(mut self, max_failures: u32) -> Self {
262 self.max_failures = max_failures;
263 self
264 }
265
266 pub fn on_agent_invalid<F>(mut self, callback: F) -> Self
268 where
269 F: Fn(i32) + Send + Sync + 'static,
270 {
271 self.on_agent_invalid = Some(Box::new(callback));
272 self
273 }
274
275 pub fn add_agent(mut self, agent: BoxAgent<'a>, id: i32, provider_name: String, model_name: String) -> Self {
282 self.agents.push((agent, id, provider_name, model_name));
283 self
284 }
285
286 pub fn add_builder(mut self, builder: Agent<CompletionModelHandle<'a>>, id: i32, provider_name: &str, model_name: &str) -> Self {
295 self.agents.push((builder, id, provider_name.to_string(), model_name.to_string()));
296 self
297 }
298
299 pub fn build(self) -> RandAgent<'a> {
301 RandAgent::with_max_failures_and_callback(self.agents, self.max_failures, self.on_agent_invalid)
302 }
303}
304
305impl<'a> Default for RandAgentBuilder<'a> {
306 fn default() -> Self {
307 Self::new()
308 }
309}