1use crate::AgentInfo;
51use crate::agent_variant::AgentVariant;
52use crate::error::RandAgentError;
53use backon::{ExponentialBuilder, Retryable};
54use rand::Rng;
55use rig::completion::{CompletionError, Message, Prompt, PromptError};
56use std::sync::Arc;
57use std::time::Duration;
58use tokio::sync::Mutex;
59
/// Optional, shareable hook invoked with an agent's id when that agent is marked invalid.
///
/// `None` means no notification is sent.
pub type OnAgentInvalidCallback = Option<Arc<Box<dyn Fn(i32) + Sync + Send + 'static>>>;
62
63#[derive(Clone)]
67pub struct RandAgent {
68 agents: Arc<Mutex<Vec<AgentState>>>,
69 on_agent_invalid: OnAgentInvalidCallback,
70}
71
72#[derive(Clone)]
74pub struct AgentState {
75 pub id: i32,
76 pub agent: Arc<AgentVariant>,
77 pub info: AgentInfo,
78}
79
80impl Prompt for RandAgent {
81 #[allow(refining_impl_trait)]
82 async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
83 let agent_index = self
85 .get_random_valid_agent_index()
86 .await
87 .ok_or(CompletionError::ProviderError("没有有效agent".to_string()))?;
88
89 let mut agents = self.agents.lock().await;
91 let agent_state = &mut agents[agent_index];
92
93 tracing::info!(
94 "Using provider: {}, model: {},id: {}",
95 agent_state.info.provider,
96 agent_state.info.model,
97 agent_state.info.id
98 );
99 match agent_state.agent.prompt(prompt).await {
100 Ok(content) => {
101 agent_state.record_success();
102 Ok(content)
103 }
104 Err(e) => {
105 agent_state.record_failure();
106 if !agent_state.is_valid()
107 && let Some(cb) = &self.on_agent_invalid
108 {
109 cb(agent_state.id);
110 }
111 Err(e)
112 }
113 }
114 }
115}
116
117impl AgentState {
118 fn new(
119 agent: AgentVariant,
120 id: i32,
121 provider: String,
122 model: String,
123 max_failures: u32,
124 ) -> Self {
125 Self {
126 id,
127 agent: Arc::new(agent),
128 info: AgentInfo {
129 id,
130 provider,
131 model,
132 failure_count: 0,
133 max_failures,
134 },
135 }
136 }
137
138 fn is_valid(&self) -> bool {
139 self.info.failure_count < self.info.max_failures
140 }
141
142 fn record_failure(&mut self) {
143 self.info.failure_count += 1;
144 }
145
146 fn record_success(&mut self) {
147 self.info.failure_count = 0;
148 }
149}
150
151impl RandAgent {
152 pub fn new(agents: Vec<(AgentVariant, i32, String, String)>) -> Self {
154 Self::with_max_failures_and_callback(agents, 3, None)
155 }
156
157 pub fn with_max_failures_and_callback(
159 agents: Vec<(AgentVariant, i32, String, String)>,
160 max_failures: u32,
161 on_agent_invalid: OnAgentInvalidCallback,
162 ) -> Self {
163 let agent_states = agents
164 .into_iter()
165 .map(|(agent, id, provider, model)| {
166 AgentState::new(agent, id, provider, model, max_failures)
167 })
168 .collect();
169 Self {
170 agents: Arc::new(Mutex::new(agent_states)),
171 on_agent_invalid,
172 }
173 }
174
175 pub fn with_max_failures(
177 agents: Vec<(AgentVariant, i32, String, String)>,
178 max_failures: u32,
179 ) -> Self {
180 Self::with_max_failures_and_callback(agents, max_failures, None)
181 }
182
183 pub fn set_on_agent_invalid<F>(&mut self, callback: F)
185 where
186 F: Fn(i32) + Send + Sync + 'static,
187 {
188 self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
189 }
190
191 pub async fn add_agent(&self, agent: AgentVariant, id: i32, provider: String, model: String) {
193 let mut agents = self.agents.lock().await;
194 agents.push(AgentState::new(agent, id, provider, model, 3));
195 }
196
197 pub async fn add_agent_with_max_failures(
199 &self,
200 agent: AgentVariant,
201 id: i32,
202 provider: String,
203 model: String,
204 max_failures: u32,
205 ) {
206 let mut agents = self.agents.lock().await;
207 agents.push(AgentState::new(agent, id, provider, model, max_failures));
208 }
209
210 pub async fn len(&self) -> usize {
212 let agents = self.agents.lock().await;
213 agents.iter().filter(|state| state.is_valid()).count()
214 }
215
216 pub async fn get_random_valid_agent_index(&self) -> Option<usize> {
218 let agents = self.agents.lock().await;
219 let valid_indices: Vec<usize> = agents
220 .iter()
221 .enumerate()
222 .filter(|(_, state)| state.is_valid())
223 .map(|(i, _)| i)
224 .collect();
225
226 if valid_indices.is_empty() {
227 return None;
228 }
229
230 let mut rng = rand::rng();
231 let random_index = rng.random_range(0..valid_indices.len());
232 Some(valid_indices[random_index])
233 }
234
235 pub async fn get_random_valid_agent_state(&self) -> Option<AgentState> {
238 let mut agents = self.agents.lock().await;
239
240 let valid_indices: Vec<usize> = agents
241 .iter()
242 .enumerate()
243 .filter(|(_, state)| state.is_valid())
244 .map(|(i, _)| i)
245 .collect();
246
247 if valid_indices.is_empty() {
248 return None;
249 }
250
251 let mut rng = rand::rng();
252 let random_index = rng.random_range(0..valid_indices.len());
253 let agent_index = valid_indices[random_index];
254 agents.get_mut(agent_index).cloned()
255 }
256
257 pub async fn total_len(&self) -> usize {
259 let agents = self.agents.lock().await;
260 agents.len()
261 }
262
263 pub async fn is_empty(&self) -> bool {
265 self.len().await == 0
266 }
267
268 pub async fn get_agents_info(&self) -> Vec<AgentInfo> {
270 let agents = self.agents.lock().await;
271 let agent_infos = agents.iter().map(|agent| agent.info.clone()).collect::<_>();
272 tracing::info!("agents info: {:?}", agent_infos);
273 agent_infos
274 }
275
276 pub async fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
278 let agents = self.agents.lock().await;
279 agents
280 .iter()
281 .enumerate()
282 .map(|(i, state)| (i, state.info.failure_count, state.info.max_failures))
283 .collect()
284 }
285
286 pub async fn reset_failures(&self) {
288 let mut agents = self.agents.lock().await;
289 for state in agents.iter_mut() {
290 state.info.failure_count = 0;
291 }
292 }
293
294 pub async fn get_agent_by_name(
296 &self,
297 provider_name: &str,
298 model_name: &str,
299 ) -> Option<AgentState> {
300 let mut agents = self.agents.lock().await;
301
302 for agent in agents.iter_mut() {
303 if agent.info.provider == provider_name && agent.info.model == model_name {
304 return Some(agent.clone());
305 }
306 }
307
308 None
309 }
310
311 pub async fn get_agent_by_id(&self, id: i32) -> Option<AgentState> {
313 let mut agents = self.agents.lock().await;
314
315 for agent in agents.iter_mut() {
316 if agent.info.id == id {
317 return Some(agent.clone());
318 }
319 }
320
321 None
322 }
323
324 pub async fn try_invoke_with_retry(
326 &self,
327 info: Message,
328 retry_num: Option<usize>,
329 ) -> Result<String, RandAgentError> {
330 let mut config = ExponentialBuilder::default();
331 if let Some(retry_num) = retry_num {
332 config = config.with_max_times(retry_num)
333 }
334
335 let info = Arc::new(info);
336
337 let content = (|| {
338 let agent = self.clone();
339 let prompt = info.clone();
340 async move { agent.prompt((*prompt).clone()).await }
341 })
342 .retry(config)
343 .sleep(tokio::time::sleep)
344 .notify(|err: &PromptError, dur: Duration| {
345 println!("retrying {err:?} after {dur:?}");
346 })
347 .await?;
348 Ok(content)
349 }
350
351 #[allow(refining_impl_trait)]
352 pub async fn prompt_with_info(
353 &self,
354 prompt: impl Into<Message> + Send,
355 ) -> Result<(String, AgentInfo), PromptError> {
356 let agent_index = self
358 .get_random_valid_agent_index()
359 .await
360 .ok_or(CompletionError::ProviderError("没有有效agent".to_string()))?;
361
362 let mut agents = self.agents.lock().await;
364 let agent_state = &mut agents[agent_index];
365
366 let agent_info = agent_state.info.clone();
367
368 tracing::info!(
369 "prompt_with_info Using provider: {}, model: {},id: {}",
370 agent_state.info.provider,
371 agent_state.info.model,
372 agent_state.info.id
373 );
374 match agent_state.agent.prompt(prompt).await {
375 Ok(content) => {
376 agent_state.record_success();
377 Ok((content, agent_info))
378 }
379 Err(e) => {
380 agent_state.record_failure();
381 if !agent_state.is_valid()
382 && let Some(cb) = &self.on_agent_invalid
383 {
384 cb(agent_state.id);
385 }
386 Err(e)
387 }
388 }
389 }
390
391 pub async fn try_invoke_with_info_retry(
393 &self,
394 info: Message,
395 retry_num: Option<usize>,
396 ) -> Result<(String, AgentInfo), RandAgentError> {
397 let mut config = ExponentialBuilder::default();
398 if let Some(retry_num) = retry_num {
399 config = config.with_max_times(retry_num)
400 }
401
402 let info = Arc::new(info);
403
404 let content = (|| {
405 let agent = self.clone();
406 let prompt = info.clone();
407 async move { agent.prompt_with_info((*prompt).clone()).await }
408 })
409 .retry(config)
410 .sleep(tokio::time::sleep)
411 .notify(|err: &PromptError, dur: Duration| {
412 println!("retrying {err:?} after {dur:?}");
413 })
414 .await?;
415 Ok(content)
416 }
417}
418
419pub struct RandAgentBuilder {
421 pub(crate) agents: Vec<(AgentVariant, i32, String, String)>,
422 max_failures: u32,
423 on_agent_invalid: OnAgentInvalidCallback,
424}
425
426impl RandAgentBuilder {
427 pub fn new() -> Self {
429 Self {
430 agents: Vec::new(),
431 max_failures: 3, on_agent_invalid: None,
433 }
434 }
435
436 pub fn max_failures(mut self, max_failures: u32) -> Self {
438 self.max_failures = max_failures;
439 self
440 }
441
442 pub fn on_agent_invalid<F>(mut self, callback: F) -> Self
444 where
445 F: Fn(i32) + Send + Sync + 'static,
446 {
447 self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
448 self
449 }
450
451 pub fn add_agent(
458 mut self,
459 agent: AgentVariant,
460 id: i32,
461 provider_name: String,
462 model_name: String,
463 ) -> Self {
464 self.agents.push((agent, id, provider_name, model_name));
465 self
466 }
467
468 pub fn build(self) -> RandAgent {
470 RandAgent::with_max_failures_and_callback(
471 self.agents,
472 self.max_failures,
473 self.on_agent_invalid,
474 )
475 }
476}
477
478impl Default for RandAgentBuilder {
479 fn default() -> Self {
480 Self::new()
481 }
482}