1use std::sync::Arc;
51use rand::Rng;
52use rig::agent::Agent;
53use rig::client::builder::BoxAgent;
54use rig::client::completion::CompletionModelHandle;
55use rig::completion::{Message, Prompt, PromptError};
56use tokio::sync::Mutex;
57
/// Optional hook that is called with the `id` of an agent once that agent has
/// been marked invalid (its consecutive-failure budget is exhausted).
///
/// `None` means no notification is sent. The closure is `Send + Sync + 'static`
/// and reference-counted, so one hook can be shared by clones of the handle.
pub type OnAgentInvalidCallback =
    Option<Arc<Box<dyn Fn(i32) + Send + Sync + 'static>>>;
60
/// A pool of agents that answers each prompt with a randomly chosen agent that
/// is still valid.
///
/// An agent becomes invalid after `max_failures` consecutive failed prompts.
/// A success resets its count. Invalid agents are skipped by random selection
/// until their failure counts are reset.
pub struct ThreadSafeRandAgent {
    /// Every registered agent together with its failure bookkeeping, shared
    /// behind an async (tokio) mutex.
    agents: Arc<Mutex<Vec<ThreadSafeAgentState>>>,
    /// Optional hook invoked with an agent's `id` when a failed prompt leaves
    /// that agent invalid.
    on_agent_invalid: OnAgentInvalidCallback,
}
68
/// One agent in a [`ThreadSafeRandAgent`] pool plus its failure bookkeeping.
///
/// Cloning is cheap for the agent itself (it is behind an `Arc`); the provider
/// and model strings are copied.
#[derive(Clone)]
pub struct ThreadSafeAgentState {
    /// Caller-chosen identifier; this is the value passed to the
    /// invalid-agent callback.
    pub id: i32,
    /// Shared handle to the underlying agent.
    pub agent: Arc<BoxAgent<'static>>,
    /// Provider name, used for logging and reporting.
    pub provider: String,
    /// Model name, used for logging and reporting.
    pub model: String,
    /// Failures recorded since the last success (or the last reset).
    pub failure_count: u32,
    /// The agent counts as valid while `failure_count < max_failures`.
    pub max_failures: u32,
}
79
80impl Prompt for ThreadSafeRandAgent {
81 #[allow(refining_impl_trait)]
82 async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
83 let agent_index = self.get_random_valid_agent_index().await
85 .ok_or(PromptError::MaxDepthError {
86 max_depth: 0,
87 chat_history: vec![],
88 prompt: "没有有效agent".into(),
89 })?;
90
91 let mut agents = self.agents.lock().await;
93 let agent_state = &mut agents[agent_index];
94
95 tracing::info!("Using provider: {}, model: {}", agent_state.provider, agent_state.model);
96 match agent_state.agent.prompt(prompt).await {
97 Ok(content) => {
98 agent_state.record_success();
99 Ok(content)
100 }
101 Err(e) => {
102 agent_state.record_failure();
103 if !agent_state.is_valid() {
104 if let Some(cb) = &self.on_agent_invalid {
105 cb(agent_state.id);
106 }
107 }
108 Err(e)
109 }
110 }
111 }
112}
113
114impl ThreadSafeAgentState {
115 fn new(agent: BoxAgent<'static>,id: i32, provider: String, model: String, max_failures: u32) -> Self {
116 Self {
117 id,
118 agent: Arc::new(agent),
119 provider,
120 model,
121 failure_count: 0,
122 max_failures,
123 }
124 }
125
126 fn is_valid(&self) -> bool {
127 self.failure_count < self.max_failures
128 }
129
130 fn record_failure(&mut self) {
131 self.failure_count += 1;
132 }
133
134 fn record_success(&mut self) {
135 self.failure_count = 0;
136 }
137}
138
139impl ThreadSafeRandAgent {
140 pub fn new(agents: Vec<(BoxAgent<'static>, i32, String, String)>) -> Self {
142 Self::with_max_failures_and_callback(agents, 3, None)
143 }
144
145 pub fn with_max_failures_and_callback(
147 agents: Vec<(BoxAgent<'static>, i32, String, String)>,
148 max_failures: u32,
149 on_agent_invalid: OnAgentInvalidCallback,
150 ) -> Self {
151 let agent_states = agents
152 .into_iter()
153 .map(|(agent, id, provider, model)| ThreadSafeAgentState::new(agent, id, provider, model, max_failures))
154 .collect();
155 Self {
156 agents: Arc::new(Mutex::new(agent_states)),
157 on_agent_invalid,
158 }
159 }
160
161 pub fn with_max_failures(agents: Vec<(BoxAgent<'static>, i32, String, String)>, max_failures: u32) -> Self {
163 Self::with_max_failures_and_callback(agents, max_failures, None)
164 }
165
166 pub fn set_on_agent_invalid<F>(&mut self, callback: F)
168 where
169 F: Fn(i32) + Send + Sync + 'static,
170 {
171 self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
172 }
173
174 pub async fn add_agent(&self, agent: BoxAgent<'static>, id: i32, provider: String, model: String) {
176 let mut agents = self.agents.lock().await;
177 agents.push(ThreadSafeAgentState::new(agent, id, provider, model, 3));
178 }
179
180 pub async fn add_agent_with_max_failures(&self, agent: BoxAgent<'static>, id: i32, provider: String, model: String, max_failures: u32) {
182 let mut agents = self.agents.lock().await;
183 agents.push(ThreadSafeAgentState::new(agent, id, provider, model, max_failures));
184 }
185
186 pub async fn len(&self) -> usize {
188 let agents = self.agents.lock().await;
189 agents.iter().filter(|state| state.is_valid()).count()
190 }
191
192 pub async fn get_random_valid_agent_index(&self) -> Option<usize> {
194 let agents = self.agents.lock().await;
195 let valid_indices: Vec<usize> = agents
196 .iter()
197 .enumerate()
198 .filter(|(_, state)| state.is_valid())
199 .map(|(i, _)| i)
200 .collect();
201
202 if valid_indices.is_empty() {
203 return None;
204 }
205
206 let mut rng = rand::rng();
207 let random_index = rng.random_range(0..valid_indices.len());
208 Some(valid_indices[random_index])
209 }
210
211 pub async fn get_random_valid_agent_state(&self) -> Option<ThreadSafeAgentState> {
214 let mut agents = self.agents.lock().await;
215
216 let valid_indices: Vec<usize> = agents
217 .iter()
218 .enumerate()
219 .filter(|(_, state)| state.is_valid())
220 .map(|(i, _)| i)
221 .collect();
222
223 if valid_indices.is_empty() {
224 return None;
225 }
226
227 let mut rng = rand::rng();
228 let random_index = rng.random_range(0..valid_indices.len());
229 let agent_index = valid_indices[random_index];
230 agents.get_mut(agent_index).cloned()
231 }
232
233
234 pub async fn total_len(&self) -> usize {
236 let agents = self.agents.lock().await;
237 agents.len()
238 }
239
240 pub async fn is_empty(&self) -> bool {
242 self.len().await == 0
243 }
244
245 pub async fn agents(&self) -> Vec<(String, String, u32, u32)> {
247 let agents = self.agents.lock().await;
248 agents
249 .iter()
250 .map(|state| (
251 state.provider.clone(),
252 state.model.clone(),
253 state.failure_count,
254 state.max_failures
255 ))
256 .collect()
257 }
258
259 pub async fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
261 let agents = self.agents.lock().await;
262 agents
263 .iter()
264 .enumerate()
265 .map(|(i, state)| (i, state.failure_count, state.max_failures))
266 .collect()
267 }
268
269 pub async fn reset_failures(&self) {
271 let mut agents = self.agents.lock().await;
272 for state in agents.iter_mut() {
273 state.failure_count = 0;
274 }
275 }
276}
277
/// Fluent builder for a [`ThreadSafeRandAgent`] pool.
pub struct ThreadSafeRandAgentBuilder {
    /// Pending `(agent, id, provider name, model name)` entries, in insertion order.
    agents: Vec<(BoxAgent<'static>, i32, String, String)>,
    /// Consecutive-failure budget applied to every agent when the pool is built.
    max_failures: u32,
    /// Invalid-agent hook handed to the built pool.
    on_agent_invalid: OnAgentInvalidCallback,
}
284
285impl ThreadSafeRandAgentBuilder {
286 pub fn new() -> Self {
288 Self {
289 agents: Vec::new(),
290 max_failures: 3, on_agent_invalid: None,
292 }
293 }
294
295 pub fn max_failures(mut self, max_failures: u32) -> Self {
297 self.max_failures = max_failures;
298 self
299 }
300
301 pub fn on_agent_invalid<F>(mut self, callback: F) -> Self
303 where
304 F: Fn(i32) + Send + Sync + 'static,
305 {
306 self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
307 self
308 }
309
310 pub fn add_agent(mut self, agent: BoxAgent<'static>, id: i32, provider_name: String, model_name: String) -> Self {
317 self.agents.push((agent, id, provider_name, model_name));
318 self
319 }
320
321 pub fn add_builder(mut self, builder: Agent<CompletionModelHandle<'static>>, id: i32, provider_name: &str, model_name: &str) -> Self {
328 self.agents.push((builder, id, provider_name.to_string(), model_name.to_string()));
329 self
330 }
331
332 pub fn build(self) -> ThreadSafeRandAgent {
334 ThreadSafeRandAgent::with_max_failures_and_callback(self.agents, self.max_failures, self.on_agent_invalid)
335 }
336}
337
338impl Default for ThreadSafeRandAgentBuilder {
339 fn default() -> Self {
340 Self::new()
341 }
342}