1use std::sync::Arc;
51use std::time::Duration;
52use backon::{Retryable, ExponentialBuilder};
53use rand::Rng;
54use rig::agent::Agent;
55use rig::client::builder::BoxAgent;
56use rig::client::completion::CompletionModelHandle;
57use rig::completion::{Message, Prompt, PromptError};
58use tokio::sync::Mutex;
59use crate::AgentInfo;
60use crate::error::RandAgentError;
61
/// Optional, shareable hook that receives the id of an agent when a failed
/// prompt leaves that agent invalid (its failure count has reached its limit).
///
/// `None` means no notification is sent.
pub type OnAgentInvalidCallback =
    Option<Arc<Box<dyn Fn(i32) + Send + Sync + 'static>>>;
64
65#[derive(Clone)]
69pub struct RandAgent {
70 agents: Arc<Mutex<Vec<AgentState>>>,
71 on_agent_invalid: OnAgentInvalidCallback,
72}
73
74#[derive(Clone)]
76pub struct AgentState {
77 pub id: i32,
78 pub agent: Arc<BoxAgent<'static>>,
79 pub info: AgentInfo,
80}
81
82impl Prompt for RandAgent {
83 #[allow(refining_impl_trait)]
84 async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
85 let agent_index = self.get_random_valid_agent_index().await
87 .ok_or(PromptError::MaxDepthError {
88 max_depth: 0,
89 chat_history: vec![],
90 prompt: "没有有效agent".into(),
91 })?;
92
93 let mut agents = self.agents.lock().await;
95 let agent_state = &mut agents[agent_index];
96
97 tracing::info!("Using provider: {}, model: {},id: {}", agent_state.info.provider, agent_state.info.model,agent_state.info.id);
98 match agent_state.agent.prompt(prompt).await {
99 Ok(content) => {
100 agent_state.record_success();
101 Ok(content)
102 }
103 Err(e) => {
104 agent_state.record_failure();
105 if !agent_state.is_valid() {
106 if let Some(cb) = &self.on_agent_invalid {
107 cb(agent_state.id);
108 }
109 }
110 Err(e)
111 }
112 }
113 }
114}
115
116impl AgentState {
117 fn new(agent: BoxAgent<'static>,id: i32, provider: String, model: String, max_failures: u32) -> Self {
118 Self {
119 id,
120 agent: Arc::new(agent),
121 info: AgentInfo{
122 id,
123 provider,
124 model,
125 failure_count: 0,
126 max_failures,
127 }
128 }
129 }
130
131 fn is_valid(&self) -> bool {
132 self.info.failure_count < self.info.max_failures
133 }
134
135 fn record_failure(&mut self) {
136 self.info.failure_count += 1;
137 }
138
139 fn record_success(&mut self) {
140 self.info.failure_count = 0;
141 }
142}
143
144impl RandAgent {
145 pub fn new(agents: Vec<(BoxAgent<'static>, i32, String, String)>) -> Self {
147 Self::with_max_failures_and_callback(agents, 3, None)
148 }
149
150 pub fn with_max_failures_and_callback(
152 agents: Vec<(BoxAgent<'static>, i32, String, String)>,
153 max_failures: u32,
154 on_agent_invalid: OnAgentInvalidCallback,
155 ) -> Self {
156 let agent_states = agents
157 .into_iter()
158 .map(|(agent, id, provider, model)| AgentState::new(agent, id, provider, model, max_failures))
159 .collect();
160 Self {
161 agents: Arc::new(Mutex::new(agent_states)),
162 on_agent_invalid,
163 }
164 }
165
166 pub fn with_max_failures(agents: Vec<(BoxAgent<'static>, i32, String, String)>, max_failures: u32) -> Self {
168 Self::with_max_failures_and_callback(agents, max_failures, None)
169 }
170
171 pub fn set_on_agent_invalid<F>(&mut self, callback: F)
173 where
174 F: Fn(i32) + Send + Sync + 'static,
175 {
176 self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
177 }
178
179 pub async fn add_agent(&self, agent: BoxAgent<'static>, id: i32, provider: String, model: String) {
181 let mut agents = self.agents.lock().await;
182 agents.push(AgentState::new(agent, id, provider, model, 3));
183 }
184
185 pub async fn add_agent_with_max_failures(&self, agent: BoxAgent<'static>, id: i32, provider: String, model: String, max_failures: u32) {
187 let mut agents = self.agents.lock().await;
188 agents.push(AgentState::new(agent, id, provider, model, max_failures));
189 }
190
191 pub async fn len(&self) -> usize {
193 let agents = self.agents.lock().await;
194 agents.iter().filter(|state| state.is_valid()).count()
195 }
196
197 pub async fn get_random_valid_agent_index(&self) -> Option<usize> {
199 let agents = self.agents.lock().await;
200 let valid_indices: Vec<usize> = agents
201 .iter()
202 .enumerate()
203 .filter(|(_, state)| state.is_valid())
204 .map(|(i, _)| i)
205 .collect();
206
207 if valid_indices.is_empty() {
208 return None;
209 }
210
211 let mut rng = rand::rng();
212 let random_index = rng.random_range(0..valid_indices.len());
213 Some(valid_indices[random_index])
214 }
215
216 pub async fn get_random_valid_agent_state(&self) -> Option<AgentState> {
219 let mut agents = self.agents.lock().await;
220
221 let valid_indices: Vec<usize> = agents
222 .iter()
223 .enumerate()
224 .filter(|(_, state)| state.is_valid())
225 .map(|(i, _)| i)
226 .collect();
227
228 if valid_indices.is_empty() {
229 return None;
230 }
231
232 let mut rng = rand::rng();
233 let random_index = rng.random_range(0..valid_indices.len());
234 let agent_index = valid_indices[random_index];
235 agents.get_mut(agent_index).cloned()
236 }
237
238
239 pub async fn total_len(&self) -> usize {
241 let agents = self.agents.lock().await;
242 agents.len()
243 }
244
245 pub async fn is_empty(&self) -> bool {
247 self.len().await == 0
248 }
249
250 pub async fn get_agents_info(&self) -> Vec<AgentInfo> {
252 let agents = self.agents.lock().await;
253 let agent_infos = agents.iter()
254 .map(|agent|{
255 agent.info.clone()
256 }).collect::<_>();
257 tracing::info!("agents info: {:?}", agent_infos);
258 agent_infos
259 }
260
261 pub async fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
263 let agents = self.agents.lock().await;
264 agents
265 .iter()
266 .enumerate()
267 .map(|(i, state)| (i, state.info.failure_count, state.info.max_failures))
268 .collect()
269 }
270
271 pub async fn reset_failures(&self) {
273 let mut agents = self.agents.lock().await;
274 for state in agents.iter_mut() {
275 state.info.failure_count = 0;
276 }
277 }
278
279 pub async fn get_agent_by_name(&self,provider_name: &str, model_name: &str) -> Option<AgentState> {
281 let mut agents = self.agents.lock().await;
282
283 for agent in agents.iter_mut() {
284 if agent.info.provider == provider_name && agent.info.model == model_name {
285 return Some(agent.clone());
286 }
287 }
288
289 None
290 }
291
292 pub async fn get_agent_by_id(&self,id:i32) -> Option<AgentState> {
294 let mut agents = self.agents.lock().await;
295
296 for agent in agents.iter_mut() {
297 if agent.info.id == id {
298 return Some(agent.clone());
299 }
300 }
301
302 None
303 }
304
305 pub async fn try_invoke_with_retry(&self, info: Message, retry_num: Option<usize>) -> Result<String, RandAgentError> {
307 let mut config = ExponentialBuilder::default();
308 if let Some(retry_num) = retry_num {
309 config = config.with_max_times(retry_num)
310 }
311
312 let info = Arc::new(info);
313
314 let content = (|| {
315 let agent = self.clone();
316 let prompt = info.clone();
317 async move {
318 agent.prompt((*prompt).clone()).await
319 }
320 })
321 .retry(config)
322 .sleep(tokio::time::sleep)
323 .notify(|err: &PromptError, dur: Duration| {
324 println!("retrying {err:?} after {dur:?}");
325 })
326 .await?;
327 Ok(content)
328 }
329
330 #[allow(refining_impl_trait)]
331 pub async fn prompt_with_info(&self, prompt: impl Into<Message> + Send) -> Result<(String,AgentInfo), PromptError> {
332 let agent_index = self.get_random_valid_agent_index().await
334 .ok_or(PromptError::MaxDepthError {
335 max_depth: 0,
336 chat_history: vec![],
337 prompt: "没有有效agent".into(),
338 })?;
339
340 let mut agents = self.agents.lock().await;
342 let agent_state = &mut agents[agent_index];
343
344 let agent_info = agent_state.info.clone();
345
346 tracing::info!("prompt_with_info Using provider: {}, model: {},id: {}", agent_state.info.provider, agent_state.info.model,agent_state.info.id);
347 match agent_state.agent.prompt(prompt).await {
348 Ok(content) => {
349 agent_state.record_success();
350 Ok((content, agent_info))
351 }
352 Err(e) => {
353 agent_state.record_failure();
354 if !agent_state.is_valid() {
355 if let Some(cb) = &self.on_agent_invalid {
356 cb(agent_state.id);
357 }
358 }
359 Err(e)
360 }
361 }
362 }
363
364 pub async fn try_invoke_with_info_retry(&self, info: Message, retry_num: Option<usize>) -> Result<(String,AgentInfo), RandAgentError> {
366 let mut config = ExponentialBuilder::default();
367 if let Some(retry_num) = retry_num {
368 config = config.with_max_times(retry_num)
369 }
370
371 let info = Arc::new(info);
372
373 let content = (|| {
374 let agent = self.clone();
375 let prompt = info.clone();
376 async move {
377 agent.prompt_with_info((*prompt).clone()).await
378 }
379 })
380 .retry(config)
381 .sleep(tokio::time::sleep)
382 .notify(|err: &PromptError, dur: Duration| {
383 println!("retrying {err:?} after {dur:?}");
384 })
385 .await?;
386 Ok(content)
387 }
388}
389
390pub struct RandAgentBuilder {
392 pub(crate) agents: Vec<(BoxAgent<'static>, i32, String, String)>,
393 max_failures: u32,
394 on_agent_invalid: OnAgentInvalidCallback,
395}
396
397impl RandAgentBuilder {
398 pub fn new() -> Self {
400 Self {
401 agents: Vec::new(),
402 max_failures: 3, on_agent_invalid: None,
404 }
405 }
406
407 pub fn max_failures(mut self, max_failures: u32) -> Self {
409 self.max_failures = max_failures;
410 self
411 }
412
413 pub fn on_agent_invalid<F>(mut self, callback: F) -> Self
415 where
416 F: Fn(i32) + Send + Sync + 'static,
417 {
418 self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
419 self
420 }
421
422 pub fn add_agent(mut self, agent: BoxAgent<'static>, id: i32, provider_name: String, model_name: String) -> Self {
429 self.agents.push((agent, id, provider_name, model_name));
430 self
431 }
432
433 pub fn add_builder(mut self, builder: Agent<CompletionModelHandle<'static>>, id: i32, provider_name: &str, model_name: &str) -> Self {
440 self.agents.push((builder, id, provider_name.to_string(), model_name.to_string()));
441 self
442 }
443
444 pub fn build(self) -> RandAgent {
446 RandAgent::with_max_failures_and_callback(self.agents, self.max_failures, self.on_agent_invalid)
447 }
448
449}
450
451impl Default for RandAgentBuilder {
452 fn default() -> Self {
453 Self::new()
454 }
455}