1use crate::AgentInfo;
51use crate::agent_variant::AgentVariant;
52use crate::error::RandAgentError;
53use backon::{ExponentialBuilder, Retryable};
54use rand::Rng;
55use rig::completion::{Message, Prompt, PromptError};
56use std::sync::Arc;
57use std::time::Duration;
58use tokio::sync::Mutex;
59
/// Optional hook called with an agent's `id` when a failed prompt leaves that agent at
/// (or past) its failure limit.
///
/// It is invoked synchronously from the async prompt path, so it should return quickly
/// and should not block.
///
/// NOTE(review): `Arc<dyn Fn(i32) + Send + Sync>` would avoid the extra `Box` indirection,
/// but changing this public alias would break callers that build the value directly.
pub type OnAgentInvalidCallback = Option<Arc<Box<dyn Fn(i32) + Send + Sync + 'static>>>;
62
/// A pool of agents that routes each prompt to a uniformly random agent that is still
/// below its consecutive-failure limit.
///
/// Cloning is cheap: clones share the same agent list and failure counters (held behind
/// `Arc<Mutex<…>>`), while each clone carries its own copy of the callback handle.
#[derive(Clone)]
pub struct RandAgent {
    /// Every registered agent, including ones that have been invalidated.
    agents: Arc<Mutex<Vec<AgentState>>>,
    /// Optional hook invoked with an agent's id when a failure invalidates it.
    on_agent_invalid: OnAgentInvalidCallback,
}
71
/// One pooled agent together with its selection bookkeeping.
#[derive(Clone)]
pub struct AgentState {
    /// Caller-supplied identifier; `AgentState::new` sets it to the same value as `info.id`.
    pub id: i32,
    /// Shared handle to the underlying agent; cloning the state shares (not copies) it.
    pub agent: Arc<AgentVariant>,
    /// Provider/model labels plus the failure counter and limit used for selection.
    pub info: AgentInfo,
}
79
80impl Prompt for RandAgent {
81 #[allow(refining_impl_trait)]
82 async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
83 let agent_index =
85 self.get_random_valid_agent_index()
86 .await
87 .ok_or(PromptError::MaxDepthError {
88 max_depth: 0,
89 chat_history: Box::new(vec![]),
90 prompt: "没有有效agent".into(),
91 })?;
92
93 let mut agents = self.agents.lock().await;
95 let agent_state = &mut agents[agent_index];
96
97 tracing::info!(
98 "Using provider: {}, model: {},id: {}",
99 agent_state.info.provider,
100 agent_state.info.model,
101 agent_state.info.id
102 );
103 match agent_state.agent.prompt(prompt).await {
104 Ok(content) => {
105 agent_state.record_success();
106 Ok(content)
107 }
108 Err(e) => {
109 agent_state.record_failure();
110 if !agent_state.is_valid()
111 && let Some(cb) = &self.on_agent_invalid
112 {
113 cb(agent_state.id);
114 }
115 Err(e)
116 }
117 }
118 }
119}
120
121impl AgentState {
122 fn new(
123 agent: AgentVariant,
124 id: i32,
125 provider: String,
126 model: String,
127 max_failures: u32,
128 ) -> Self {
129 Self {
130 id,
131 agent: Arc::new(agent),
132 info: AgentInfo {
133 id,
134 provider,
135 model,
136 failure_count: 0,
137 max_failures,
138 },
139 }
140 }
141
142 fn is_valid(&self) -> bool {
143 self.info.failure_count < self.info.max_failures
144 }
145
146 fn record_failure(&mut self) {
147 self.info.failure_count += 1;
148 }
149
150 fn record_success(&mut self) {
151 self.info.failure_count = 0;
152 }
153}
154
155impl RandAgent {
156 pub fn new(agents: Vec<(AgentVariant, i32, String, String)>) -> Self {
158 Self::with_max_failures_and_callback(agents, 3, None)
159 }
160
161 pub fn with_max_failures_and_callback(
163 agents: Vec<(AgentVariant, i32, String, String)>,
164 max_failures: u32,
165 on_agent_invalid: OnAgentInvalidCallback,
166 ) -> Self {
167 let agent_states = agents
168 .into_iter()
169 .map(|(agent, id, provider, model)| {
170 AgentState::new(agent, id, provider, model, max_failures)
171 })
172 .collect();
173 Self {
174 agents: Arc::new(Mutex::new(agent_states)),
175 on_agent_invalid,
176 }
177 }
178
179 pub fn with_max_failures(
181 agents: Vec<(AgentVariant, i32, String, String)>,
182 max_failures: u32,
183 ) -> Self {
184 Self::with_max_failures_and_callback(agents, max_failures, None)
185 }
186
187 pub fn set_on_agent_invalid<F>(&mut self, callback: F)
189 where
190 F: Fn(i32) + Send + Sync + 'static,
191 {
192 self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
193 }
194
195 pub async fn add_agent(&self, agent: AgentVariant, id: i32, provider: String, model: String) {
197 let mut agents = self.agents.lock().await;
198 agents.push(AgentState::new(agent, id, provider, model, 3));
199 }
200
201 pub async fn add_agent_with_max_failures(
203 &self,
204 agent: AgentVariant,
205 id: i32,
206 provider: String,
207 model: String,
208 max_failures: u32,
209 ) {
210 let mut agents = self.agents.lock().await;
211 agents.push(AgentState::new(agent, id, provider, model, max_failures));
212 }
213
214 pub async fn len(&self) -> usize {
216 let agents = self.agents.lock().await;
217 agents.iter().filter(|state| state.is_valid()).count()
218 }
219
220 pub async fn get_random_valid_agent_index(&self) -> Option<usize> {
222 let agents = self.agents.lock().await;
223 let valid_indices: Vec<usize> = agents
224 .iter()
225 .enumerate()
226 .filter(|(_, state)| state.is_valid())
227 .map(|(i, _)| i)
228 .collect();
229
230 if valid_indices.is_empty() {
231 return None;
232 }
233
234 let mut rng = rand::rng();
235 let random_index = rng.random_range(0..valid_indices.len());
236 Some(valid_indices[random_index])
237 }
238
239 pub async fn get_random_valid_agent_state(&self) -> Option<AgentState> {
242 let mut agents = self.agents.lock().await;
243
244 let valid_indices: Vec<usize> = agents
245 .iter()
246 .enumerate()
247 .filter(|(_, state)| state.is_valid())
248 .map(|(i, _)| i)
249 .collect();
250
251 if valid_indices.is_empty() {
252 return None;
253 }
254
255 let mut rng = rand::rng();
256 let random_index = rng.random_range(0..valid_indices.len());
257 let agent_index = valid_indices[random_index];
258 agents.get_mut(agent_index).cloned()
259 }
260
261 pub async fn total_len(&self) -> usize {
263 let agents = self.agents.lock().await;
264 agents.len()
265 }
266
267 pub async fn is_empty(&self) -> bool {
269 self.len().await == 0
270 }
271
272 pub async fn get_agents_info(&self) -> Vec<AgentInfo> {
274 let agents = self.agents.lock().await;
275 let agent_infos = agents.iter().map(|agent| agent.info.clone()).collect::<_>();
276 tracing::info!("agents info: {:?}", agent_infos);
277 agent_infos
278 }
279
280 pub async fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
282 let agents = self.agents.lock().await;
283 agents
284 .iter()
285 .enumerate()
286 .map(|(i, state)| (i, state.info.failure_count, state.info.max_failures))
287 .collect()
288 }
289
290 pub async fn reset_failures(&self) {
292 let mut agents = self.agents.lock().await;
293 for state in agents.iter_mut() {
294 state.info.failure_count = 0;
295 }
296 }
297
298 pub async fn get_agent_by_name(
300 &self,
301 provider_name: &str,
302 model_name: &str,
303 ) -> Option<AgentState> {
304 let mut agents = self.agents.lock().await;
305
306 for agent in agents.iter_mut() {
307 if agent.info.provider == provider_name && agent.info.model == model_name {
308 return Some(agent.clone());
309 }
310 }
311
312 None
313 }
314
315 pub async fn get_agent_by_id(&self, id: i32) -> Option<AgentState> {
317 let mut agents = self.agents.lock().await;
318
319 for agent in agents.iter_mut() {
320 if agent.info.id == id {
321 return Some(agent.clone());
322 }
323 }
324
325 None
326 }
327
328 pub async fn try_invoke_with_retry(
330 &self,
331 info: Message,
332 retry_num: Option<usize>,
333 ) -> Result<String, RandAgentError> {
334 let mut config = ExponentialBuilder::default();
335 if let Some(retry_num) = retry_num {
336 config = config.with_max_times(retry_num)
337 }
338
339 let info = Arc::new(info);
340
341 let content = (|| {
342 let agent = self.clone();
343 let prompt = info.clone();
344 async move { agent.prompt((*prompt).clone()).await }
345 })
346 .retry(config)
347 .sleep(tokio::time::sleep)
348 .notify(|err: &PromptError, dur: Duration| {
349 println!("retrying {err:?} after {dur:?}");
350 })
351 .await?;
352 Ok(content)
353 }
354
355 #[allow(refining_impl_trait)]
356 pub async fn prompt_with_info(
357 &self,
358 prompt: impl Into<Message> + Send,
359 ) -> Result<(String, AgentInfo), PromptError> {
360 let agent_index =
362 self.get_random_valid_agent_index()
363 .await
364 .ok_or(PromptError::MaxDepthError {
365 max_depth: 0,
366 chat_history: Box::new(vec![]),
367 prompt: "没有有效agent".into(),
368 })?;
369
370 let mut agents = self.agents.lock().await;
372 let agent_state = &mut agents[agent_index];
373
374 let agent_info = agent_state.info.clone();
375
376 tracing::info!(
377 "prompt_with_info Using provider: {}, model: {},id: {}",
378 agent_state.info.provider,
379 agent_state.info.model,
380 agent_state.info.id
381 );
382 match agent_state.agent.prompt(prompt).await {
383 Ok(content) => {
384 agent_state.record_success();
385 Ok((content, agent_info))
386 }
387 Err(e) => {
388 agent_state.record_failure();
389 if !agent_state.is_valid()
390 && let Some(cb) = &self.on_agent_invalid
391 {
392 cb(agent_state.id);
393 }
394 Err(e)
395 }
396 }
397 }
398
399 pub async fn try_invoke_with_info_retry(
401 &self,
402 info: Message,
403 retry_num: Option<usize>,
404 ) -> Result<(String, AgentInfo), RandAgentError> {
405 let mut config = ExponentialBuilder::default();
406 if let Some(retry_num) = retry_num {
407 config = config.with_max_times(retry_num)
408 }
409
410 let info = Arc::new(info);
411
412 let content = (|| {
413 let agent = self.clone();
414 let prompt = info.clone();
415 async move { agent.prompt_with_info((*prompt).clone()).await }
416 })
417 .retry(config)
418 .sleep(tokio::time::sleep)
419 .notify(|err: &PromptError, dur: Duration| {
420 println!("retrying {err:?} after {dur:?}");
421 })
422 .await?;
423 Ok(content)
424 }
425}
426
/// Fluent builder for [`RandAgent`].
///
/// Collects `(agent, id, provider, model)` entries plus a shared failure limit and an
/// optional invalidation callback, and hands them to
/// `RandAgent::with_max_failures_and_callback` in `build`.
pub struct RandAgentBuilder {
    /// Queued `(agent, id, provider_name, model_name)` entries, in insertion order.
    pub(crate) agents: Vec<(AgentVariant, i32, String, String)>,
    /// Consecutive-failure limit applied to every queued agent.
    max_failures: u32,
    /// Callback forwarded to the built pool.
    on_agent_invalid: OnAgentInvalidCallback,
}
433
434impl RandAgentBuilder {
435 pub fn new() -> Self {
437 Self {
438 agents: Vec::new(),
439 max_failures: 3, on_agent_invalid: None,
441 }
442 }
443
444 pub fn max_failures(mut self, max_failures: u32) -> Self {
446 self.max_failures = max_failures;
447 self
448 }
449
450 pub fn on_agent_invalid<F>(mut self, callback: F) -> Self
452 where
453 F: Fn(i32) + Send + Sync + 'static,
454 {
455 self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
456 self
457 }
458
459 pub fn add_agent(
466 mut self,
467 agent: AgentVariant,
468 id: i32,
469 provider_name: String,
470 model_name: String,
471 ) -> Self {
472 self.agents.push((agent, id, provider_name, model_name));
473 self
474 }
475
476 pub fn build(self) -> RandAgent {
478 RandAgent::with_max_failures_and_callback(
479 self.agents,
480 self.max_failures,
481 self.on_agent_invalid,
482 )
483 }
484}
485
486impl Default for RandAgentBuilder {
487 fn default() -> Self {
488 Self::new()
489 }
490}