1use std::sync::Arc;
51use rand::Rng;
52use rig::agent::Agent;
53use rig::client::builder::BoxAgent;
54use rig::client::completion::CompletionModelHandle;
55use rig::completion::{Message, Prompt, PromptError};
56use tokio::sync::Mutex;
57use crate::AgentInfo;
58
/// Optional callback that receives an agent's `id` when that agent is found to be
/// invalid after a failed prompt. `None` means no callback is registered.
59pub type OnAgentInvalidCallback = Option<Arc<Box<dyn Fn(i32) + Send + Sync + 'static>>>;
61
/// A pool of rig agents from which one valid agent is chosen at random for each prompt.
/// Agents whose failure count reaches their threshold are no longer selected.
62pub struct ThreadSafeRandAgent {
/// Per-agent state behind an async (tokio) mutex, shared through an `Arc`.
66 agents: Arc<Mutex<Vec<ThreadSafeAgentState>>>,
/// Invoked with the agent id when an agent becomes invalid; `None` disables it.
67 on_agent_invalid: OnAgentInvalidCallback,
68}
69
/// One entry of a `ThreadSafeRandAgent` pool: the agent plus its bookkeeping.
/// Cloning shares the agent (it sits behind an `Arc`) and copies the `AgentInfo`.
70#[derive(Clone)]
72pub struct ThreadSafeAgentState {
/// Caller-supplied identifier; the constructor also copies it into `info.id`.
73 pub id: i32,
/// The wrapped agent, shared via `Arc`.
74 pub agent: Arc<BoxAgent<'static>>,
/// Provider/model names and the failure counter/threshold.
75 pub info: AgentInfo,
76}
77
78impl Prompt for ThreadSafeRandAgent {
79 #[allow(refining_impl_trait)]
80 async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
81 let agent_index = self.get_random_valid_agent_index().await
83 .ok_or(PromptError::MaxDepthError {
84 max_depth: 0,
85 chat_history: vec![],
86 prompt: "没有有效agent".into(),
87 })?;
88
89 let mut agents = self.agents.lock().await;
91 let agent_state = &mut agents[agent_index];
92
93 tracing::info!("Using provider: {}, model: {}", agent_state.info.provider, agent_state.info.model);
94 match agent_state.agent.prompt(prompt).await {
95 Ok(content) => {
96 agent_state.record_success();
97 Ok(content)
98 }
99 Err(e) => {
100 agent_state.record_failure();
101 if !agent_state.is_valid() {
102 if let Some(cb) = &self.on_agent_invalid {
103 cb(agent_state.id);
104 }
105 }
106 Err(e)
107 }
108 }
109 }
110}
111
112impl ThreadSafeAgentState {
113 fn new(agent: BoxAgent<'static>,id: i32, provider: String, model: String, max_failures: u32) -> Self {
114 Self {
115 id,
116 agent: Arc::new(agent),
117 info: AgentInfo{
118 id,
119 provider,
120 model,
121 failure_count: 0,
122 max_failures,
123 }
124 }
125 }
126
127 fn is_valid(&self) -> bool {
128 self.info.failure_count < self.info.max_failures
129 }
130
131 fn record_failure(&mut self) {
132 self.info.failure_count += 1;
133 }
134
135 fn record_success(&mut self) {
136 self.info.failure_count = 0;
137 }
138}
139
140impl ThreadSafeRandAgent {
141 pub fn new(agents: Vec<(BoxAgent<'static>, i32, String, String)>) -> Self {
143 Self::with_max_failures_and_callback(agents, 3, None)
144 }
145
146 pub fn with_max_failures_and_callback(
148 agents: Vec<(BoxAgent<'static>, i32, String, String)>,
149 max_failures: u32,
150 on_agent_invalid: OnAgentInvalidCallback,
151 ) -> Self {
152 let agent_states = agents
153 .into_iter()
154 .map(|(agent, id, provider, model)| ThreadSafeAgentState::new(agent, id, provider, model, max_failures))
155 .collect();
156 Self {
157 agents: Arc::new(Mutex::new(agent_states)),
158 on_agent_invalid,
159 }
160 }
161
162 pub fn with_max_failures(agents: Vec<(BoxAgent<'static>, i32, String, String)>, max_failures: u32) -> Self {
164 Self::with_max_failures_and_callback(agents, max_failures, None)
165 }
166
167 pub fn set_on_agent_invalid<F>(&mut self, callback: F)
169 where
170 F: Fn(i32) + Send + Sync + 'static,
171 {
172 self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
173 }
174
175 pub async fn add_agent(&self, agent: BoxAgent<'static>, id: i32, provider: String, model: String) {
177 let mut agents = self.agents.lock().await;
178 agents.push(ThreadSafeAgentState::new(agent, id, provider, model, 3));
179 }
180
181 pub async fn add_agent_with_max_failures(&self, agent: BoxAgent<'static>, id: i32, provider: String, model: String, max_failures: u32) {
183 let mut agents = self.agents.lock().await;
184 agents.push(ThreadSafeAgentState::new(agent, id, provider, model, max_failures));
185 }
186
187 pub async fn len(&self) -> usize {
189 let agents = self.agents.lock().await;
190 agents.iter().filter(|state| state.is_valid()).count()
191 }
192
193 pub async fn get_random_valid_agent_index(&self) -> Option<usize> {
195 let agents = self.agents.lock().await;
196 let valid_indices: Vec<usize> = agents
197 .iter()
198 .enumerate()
199 .filter(|(_, state)| state.is_valid())
200 .map(|(i, _)| i)
201 .collect();
202
203 if valid_indices.is_empty() {
204 return None;
205 }
206
207 let mut rng = rand::rng();
208 let random_index = rng.random_range(0..valid_indices.len());
209 Some(valid_indices[random_index])
210 }
211
212 pub async fn get_random_valid_agent_state(&self) -> Option<ThreadSafeAgentState> {
215 let mut agents = self.agents.lock().await;
216
217 let valid_indices: Vec<usize> = agents
218 .iter()
219 .enumerate()
220 .filter(|(_, state)| state.is_valid())
221 .map(|(i, _)| i)
222 .collect();
223
224 if valid_indices.is_empty() {
225 return None;
226 }
227
228 let mut rng = rand::rng();
229 let random_index = rng.random_range(0..valid_indices.len());
230 let agent_index = valid_indices[random_index];
231 agents.get_mut(agent_index).cloned()
232 }
233
234
235 pub async fn total_len(&self) -> usize {
237 let agents = self.agents.lock().await;
238 agents.len()
239 }
240
241 pub async fn is_empty(&self) -> bool {
243 self.len().await == 0
244 }
245
246 #[deprecated(since = "0.6.1", note = "Renamed to `get_agent_info`")]
248 pub async fn agents(&self) -> Vec<(String, String, u32, u32)> {
249 let agents = self.agents.lock().await;
250 agents
251 .iter()
252 .map(|state| (
253 state.info.provider.clone(),
254 state.info.model.clone(),
255 state.info.failure_count,
256 state.info.max_failures
257 ))
258 .collect()
259 }
260
261 pub async fn get_agents_info(&self) -> Vec<AgentInfo> {
263 let agents = self.agents.lock().await;
264 let agent_infos = agents.iter()
265 .map(|agent|{
266 agent.info.clone()
267 }).collect::<_>();
268 tracing::info!("agents info: {:?}", agent_infos);
269 agent_infos
270 }
271
272 pub async fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
274 let agents = self.agents.lock().await;
275 agents
276 .iter()
277 .enumerate()
278 .map(|(i, state)| (i, state.info.failure_count, state.info.max_failures))
279 .collect()
280 }
281
282 pub async fn reset_failures(&self) {
284 let mut agents = self.agents.lock().await;
285 for state in agents.iter_mut() {
286 state.info.failure_count = 0;
287 }
288 }
289
290 pub async fn get_agent_by_name(&self,provider_name: &str, model_name: &str) -> Option<ThreadSafeAgentState> {
292 let mut agents = self.agents.lock().await;
293
294 for agent in agents.iter_mut() {
295 if agent.info.provider == provider_name && agent.info.model == model_name {
296 return Some(agent.clone());
297 }
298 }
299
300 None
301 }
302
303 pub async fn get_agent_by_id(&self,id:i32) -> Option<ThreadSafeAgentState> {
305 let mut agents = self.agents.lock().await;
306
307 for agent in agents.iter_mut() {
308 if agent.info.id == id {
309 return Some(agent.clone());
310 }
311 }
312
313 None
314 }
315}
316
/// Fluent builder that collects agents and settings, then produces a `ThreadSafeRandAgent`.
317pub struct ThreadSafeRandAgentBuilder {
/// Pending `(agent, id, provider, model)` entries.
319 pub(crate) agents: Vec<(BoxAgent<'static>, i32, String, String)>,
/// Failure threshold applied to every collected agent by `build` (defaults to 3 in `new`).
320 max_failures: u32,
/// Optional invalidation callback forwarded to the built pool.
321 on_agent_invalid: OnAgentInvalidCallback,
322}
323
324impl ThreadSafeRandAgentBuilder {
325 pub fn new() -> Self {
327 Self {
328 agents: Vec::new(),
329 max_failures: 3, on_agent_invalid: None,
331 }
332 }
333
334 pub fn max_failures(mut self, max_failures: u32) -> Self {
336 self.max_failures = max_failures;
337 self
338 }
339
340 pub fn on_agent_invalid<F>(mut self, callback: F) -> Self
342 where
343 F: Fn(i32) + Send + Sync + 'static,
344 {
345 self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
346 self
347 }
348
349 pub fn add_agent(mut self, agent: BoxAgent<'static>, id: i32, provider_name: String, model_name: String) -> Self {
356 self.agents.push((agent, id, provider_name, model_name));
357 self
358 }
359
360 pub fn add_builder(mut self, builder: Agent<CompletionModelHandle<'static>>, id: i32, provider_name: &str, model_name: &str) -> Self {
367 self.agents.push((builder, id, provider_name.to_string(), model_name.to_string()));
368 self
369 }
370
371 pub fn build(self) -> ThreadSafeRandAgent {
373 ThreadSafeRandAgent::with_max_failures_and_callback(self.agents, self.max_failures, self.on_agent_invalid)
374 }
375
376}
377
378impl Default for ThreadSafeRandAgentBuilder {
379 fn default() -> Self {
380 Self::new()
381 }
382}