1use crate::AgentInfo;
51use crate::error::RandAgentError;
52use backon::{ExponentialBuilder, Retryable};
53use rand::Rng;
54use rig::agent::Agent;
55use rig::client::builder::BoxAgent;
56use rig::client::completion::CompletionModelHandle;
57use rig::completion::{Message, Prompt, PromptError};
58use std::sync::Arc;
59use std::time::Duration;
60use tokio::sync::Mutex;
61
/// Optional hook invoked with the `id` of an agent once a failed prompt has
/// brought that agent to its failure limit (i.e. it is no longer selectable).
///
/// The closure sits behind an `Arc` so that cloned pools share one callback.
pub type OnAgentInvalidCallback =
    Option<Arc<Box<dyn Fn(i32) + Send + Sync + 'static>>>;
64
65#[derive(Clone)]
69pub struct RandAgent {
70 agents: Arc<Mutex<Vec<AgentState>>>,
71 on_agent_invalid: OnAgentInvalidCallback,
72}
73
74#[derive(Clone)]
76pub struct AgentState {
77 pub id: i32,
78 pub agent: Arc<BoxAgent<'static>>,
79 pub info: AgentInfo,
80}
81
82impl Prompt for RandAgent {
83 #[allow(refining_impl_trait)]
84 async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
85 let agent_index =
87 self.get_random_valid_agent_index()
88 .await
89 .ok_or(PromptError::MaxDepthError {
90 max_depth: 0,
91 chat_history: Box::new(vec![]),
92 prompt: "没有有效agent".into(),
93 })?;
94
95 let mut agents = self.agents.lock().await;
97 let agent_state = &mut agents[agent_index];
98
99 tracing::info!(
100 "Using provider: {}, model: {},id: {}",
101 agent_state.info.provider,
102 agent_state.info.model,
103 agent_state.info.id
104 );
105 match agent_state.agent.prompt(prompt).await {
106 Ok(content) => {
107 agent_state.record_success();
108 Ok(content)
109 }
110 Err(e) => {
111 agent_state.record_failure();
112 if !agent_state.is_valid() {
113 if let Some(cb) = &self.on_agent_invalid {
114 cb(agent_state.id);
115 }
116 }
117 Err(e)
118 }
119 }
120 }
121}
122
123impl AgentState {
124 fn new(
125 agent: BoxAgent<'static>,
126 id: i32,
127 provider: String,
128 model: String,
129 max_failures: u32,
130 ) -> Self {
131 Self {
132 id,
133 agent: Arc::new(agent),
134 info: AgentInfo {
135 id,
136 provider,
137 model,
138 failure_count: 0,
139 max_failures,
140 },
141 }
142 }
143
144 fn is_valid(&self) -> bool {
145 self.info.failure_count < self.info.max_failures
146 }
147
148 fn record_failure(&mut self) {
149 self.info.failure_count += 1;
150 }
151
152 fn record_success(&mut self) {
153 self.info.failure_count = 0;
154 }
155}
156
157impl RandAgent {
158 pub fn new(agents: Vec<(BoxAgent<'static>, i32, String, String)>) -> Self {
160 Self::with_max_failures_and_callback(agents, 3, None)
161 }
162
163 pub fn with_max_failures_and_callback(
165 agents: Vec<(BoxAgent<'static>, i32, String, String)>,
166 max_failures: u32,
167 on_agent_invalid: OnAgentInvalidCallback,
168 ) -> Self {
169 let agent_states = agents
170 .into_iter()
171 .map(|(agent, id, provider, model)| {
172 AgentState::new(agent, id, provider, model, max_failures)
173 })
174 .collect();
175 Self {
176 agents: Arc::new(Mutex::new(agent_states)),
177 on_agent_invalid,
178 }
179 }
180
181 pub fn with_max_failures(
183 agents: Vec<(BoxAgent<'static>, i32, String, String)>,
184 max_failures: u32,
185 ) -> Self {
186 Self::with_max_failures_and_callback(agents, max_failures, None)
187 }
188
189 pub fn set_on_agent_invalid<F>(&mut self, callback: F)
191 where
192 F: Fn(i32) + Send + Sync + 'static,
193 {
194 self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
195 }
196
197 pub async fn add_agent(
199 &self,
200 agent: BoxAgent<'static>,
201 id: i32,
202 provider: String,
203 model: String,
204 ) {
205 let mut agents = self.agents.lock().await;
206 agents.push(AgentState::new(agent, id, provider, model, 3));
207 }
208
209 pub async fn add_agent_with_max_failures(
211 &self,
212 agent: BoxAgent<'static>,
213 id: i32,
214 provider: String,
215 model: String,
216 max_failures: u32,
217 ) {
218 let mut agents = self.agents.lock().await;
219 agents.push(AgentState::new(agent, id, provider, model, max_failures));
220 }
221
222 pub async fn len(&self) -> usize {
224 let agents = self.agents.lock().await;
225 agents.iter().filter(|state| state.is_valid()).count()
226 }
227
228 pub async fn get_random_valid_agent_index(&self) -> Option<usize> {
230 let agents = self.agents.lock().await;
231 let valid_indices: Vec<usize> = agents
232 .iter()
233 .enumerate()
234 .filter(|(_, state)| state.is_valid())
235 .map(|(i, _)| i)
236 .collect();
237
238 if valid_indices.is_empty() {
239 return None;
240 }
241
242 let mut rng = rand::rng();
243 let random_index = rng.random_range(0..valid_indices.len());
244 Some(valid_indices[random_index])
245 }
246
247 pub async fn get_random_valid_agent_state(&self) -> Option<AgentState> {
250 let mut agents = self.agents.lock().await;
251
252 let valid_indices: Vec<usize> = agents
253 .iter()
254 .enumerate()
255 .filter(|(_, state)| state.is_valid())
256 .map(|(i, _)| i)
257 .collect();
258
259 if valid_indices.is_empty() {
260 return None;
261 }
262
263 let mut rng = rand::rng();
264 let random_index = rng.random_range(0..valid_indices.len());
265 let agent_index = valid_indices[random_index];
266 agents.get_mut(agent_index).cloned()
267 }
268
269 pub async fn total_len(&self) -> usize {
271 let agents = self.agents.lock().await;
272 agents.len()
273 }
274
275 pub async fn is_empty(&self) -> bool {
277 self.len().await == 0
278 }
279
280 pub async fn get_agents_info(&self) -> Vec<AgentInfo> {
282 let agents = self.agents.lock().await;
283 let agent_infos = agents.iter().map(|agent| agent.info.clone()).collect::<_>();
284 tracing::info!("agents info: {:?}", agent_infos);
285 agent_infos
286 }
287
288 pub async fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
290 let agents = self.agents.lock().await;
291 agents
292 .iter()
293 .enumerate()
294 .map(|(i, state)| (i, state.info.failure_count, state.info.max_failures))
295 .collect()
296 }
297
298 pub async fn reset_failures(&self) {
300 let mut agents = self.agents.lock().await;
301 for state in agents.iter_mut() {
302 state.info.failure_count = 0;
303 }
304 }
305
306 pub async fn get_agent_by_name(
308 &self,
309 provider_name: &str,
310 model_name: &str,
311 ) -> Option<AgentState> {
312 let mut agents = self.agents.lock().await;
313
314 for agent in agents.iter_mut() {
315 if agent.info.provider == provider_name && agent.info.model == model_name {
316 return Some(agent.clone());
317 }
318 }
319
320 None
321 }
322
323 pub async fn get_agent_by_id(&self, id: i32) -> Option<AgentState> {
325 let mut agents = self.agents.lock().await;
326
327 for agent in agents.iter_mut() {
328 if agent.info.id == id {
329 return Some(agent.clone());
330 }
331 }
332
333 None
334 }
335
336 pub async fn try_invoke_with_retry(
338 &self,
339 info: Message,
340 retry_num: Option<usize>,
341 ) -> Result<String, RandAgentError> {
342 let mut config = ExponentialBuilder::default();
343 if let Some(retry_num) = retry_num {
344 config = config.with_max_times(retry_num)
345 }
346
347 let info = Arc::new(info);
348
349 let content = (|| {
350 let agent = self.clone();
351 let prompt = info.clone();
352 async move { agent.prompt((*prompt).clone()).await }
353 })
354 .retry(config)
355 .sleep(tokio::time::sleep)
356 .notify(|err: &PromptError, dur: Duration| {
357 println!("retrying {err:?} after {dur:?}");
358 })
359 .await?;
360 Ok(content)
361 }
362
363 #[allow(refining_impl_trait)]
364 pub async fn prompt_with_info(
365 &self,
366 prompt: impl Into<Message> + Send,
367 ) -> Result<(String, AgentInfo), PromptError> {
368 let agent_index =
370 self.get_random_valid_agent_index()
371 .await
372 .ok_or(PromptError::MaxDepthError {
373 max_depth: 0,
374 chat_history: Box::new(vec![]),
375 prompt: "没有有效agent".into(),
376 })?;
377
378 let mut agents = self.agents.lock().await;
380 let agent_state = &mut agents[agent_index];
381
382 let agent_info = agent_state.info.clone();
383
384 tracing::info!(
385 "prompt_with_info Using provider: {}, model: {},id: {}",
386 agent_state.info.provider,
387 agent_state.info.model,
388 agent_state.info.id
389 );
390 match agent_state.agent.prompt(prompt).await {
391 Ok(content) => {
392 agent_state.record_success();
393 Ok((content, agent_info))
394 }
395 Err(e) => {
396 agent_state.record_failure();
397 if !agent_state.is_valid() {
398 if let Some(cb) = &self.on_agent_invalid {
399 cb(agent_state.id);
400 }
401 }
402 Err(e)
403 }
404 }
405 }
406
407 pub async fn try_invoke_with_info_retry(
409 &self,
410 info: Message,
411 retry_num: Option<usize>,
412 ) -> Result<(String, AgentInfo), RandAgentError> {
413 let mut config = ExponentialBuilder::default();
414 if let Some(retry_num) = retry_num {
415 config = config.with_max_times(retry_num)
416 }
417
418 let info = Arc::new(info);
419
420 let content = (|| {
421 let agent = self.clone();
422 let prompt = info.clone();
423 async move { agent.prompt_with_info((*prompt).clone()).await }
424 })
425 .retry(config)
426 .sleep(tokio::time::sleep)
427 .notify(|err: &PromptError, dur: Duration| {
428 println!("retrying {err:?} after {dur:?}");
429 })
430 .await?;
431 Ok(content)
432 }
433}
434
435pub struct RandAgentBuilder {
437 pub(crate) agents: Vec<(BoxAgent<'static>, i32, String, String)>,
438 max_failures: u32,
439 on_agent_invalid: OnAgentInvalidCallback,
440}
441
442impl RandAgentBuilder {
443 pub fn new() -> Self {
445 Self {
446 agents: Vec::new(),
447 max_failures: 3, on_agent_invalid: None,
449 }
450 }
451
452 pub fn max_failures(mut self, max_failures: u32) -> Self {
454 self.max_failures = max_failures;
455 self
456 }
457
458 pub fn on_agent_invalid<F>(mut self, callback: F) -> Self
460 where
461 F: Fn(i32) + Send + Sync + 'static,
462 {
463 self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
464 self
465 }
466
467 pub fn add_agent(
474 mut self,
475 agent: BoxAgent<'static>,
476 id: i32,
477 provider_name: String,
478 model_name: String,
479 ) -> Self {
480 self.agents.push((agent, id, provider_name, model_name));
481 self
482 }
483
484 pub fn add_builder(
491 mut self,
492 builder: Agent<CompletionModelHandle<'static>>,
493 id: i32,
494 provider_name: &str,
495 model_name: &str,
496 ) -> Self {
497 self.agents.push((
498 builder,
499 id,
500 provider_name.to_string(),
501 model_name.to_string(),
502 ));
503 self
504 }
505
506 pub fn build(self) -> RandAgent {
508 RandAgent::with_max_failures_and_callback(
509 self.agents,
510 self.max_failures,
511 self.on_agent_invalid,
512 )
513 }
514}
515
516impl Default for RandAgentBuilder {
517 fn default() -> Self {
518 Self::new()
519 }
520}