1use rand::Rng;
59use rig::agent::{Agent};
60use rig::client::builder::BoxAgent;
61use rig::completion::{Message, Prompt, PromptError};
62use rig::client::completion::CompletionModelHandle;
63use tokio::sync::Mutex;
64use crate::AgentInfo;
65
66pub struct AgentState<'a> {
68 id: i32,
69 agent: BoxAgent<'a>,
70 pub info: AgentInfo,
71}
72
73impl<'a> AgentState<'a> {
74 fn new(agent: BoxAgent<'a>,id:i32, provider: String, model: String, max_failures: u32) -> Self {
75 Self {
76 id,
77 agent,
78 info: AgentInfo{
79 id,
80 provider,
81 model,
82 failure_count: 0,
83 max_failures,
84 }
85 }
86 }
87
88 fn is_valid(&self) -> bool {
89 self.info.failure_count < self.info.max_failures
90 }
91
92 fn record_failure(&mut self) {
93 self.info.failure_count += 1;
94 }
95
96 fn record_success(&mut self) {
97 self.info.failure_count = 0;
98 }
99}
100
/// Optional hook invoked with an agent's id once that agent reaches its failure threshold
/// and is no longer selected.
pub type OnRandAgentInvalidCallback = Option<Box<dyn Fn(i32) + Sync + Send + 'static>>;
103
104#[deprecated(since = "0.7.0", note = "使用 `ThreadSafeRandAgent`")]
106pub struct RandAgent<'a> {
107 agents: Mutex<Vec<AgentState<'a>>>,
108 on_agent_invalid: OnRandAgentInvalidCallback,
109}
110
111impl Prompt for RandAgent<'_> {
112 #[allow(refining_impl_trait)]
113 async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
114 let mut agents = self.agents.lock().await;
115 let agent_state = Self::get_random_valid_agent(&mut agents)
116 .await
117 .ok_or(PromptError::MaxDepthError {
118 max_depth: 0,
119 chat_history: vec![],
120 prompt: "没有有效agent".into(),
121 })?;
122
123 tracing::info!("Using provider: {}, model: {}", agent_state.info.provider, agent_state.info.model);
124 match agent_state.agent.prompt(prompt).await {
125 Ok(content) => {
126 agent_state.record_success();
127 Ok(content)
128 }
129 Err(e) => {
130 agent_state.record_failure();
131 if !agent_state.is_valid() {
132 if let Some(cb) = &self.on_agent_invalid {
133 cb(agent_state.id);
134 }
135 }
136 Err(e)
137 }
138 }
139 }
140}
141
142impl<'a> RandAgent<'a> {
143 pub fn new(agents: Vec<(BoxAgent<'a>, i32, String, String)>) -> Self {
145 Self::with_max_failures_and_callback(agents, 3, None)
146 }
147
148 pub fn with_max_failures_and_callback(
150 agents: Vec<(BoxAgent<'a>, i32, String, String)>,
151 max_failures: u32,
152 on_agent_invalid: OnRandAgentInvalidCallback,
153 ) -> Self {
154 let agent_states = agents
155 .into_iter()
156 .map(|(agent, id, provider, model)| AgentState::new(agent, id, provider, model, max_failures))
157 .collect();
158 Self {
159 agents: Mutex::new(agent_states),
160 on_agent_invalid,
161 }
162 }
163
164 pub fn with_max_failures(agents: Vec<(BoxAgent<'a>, i32,String, String)>, max_failures: u32) -> Self {
166 Self::with_max_failures_and_callback(agents, max_failures, None)
167 }
168
169 pub fn set_on_agent_invalid<F>(&mut self, callback: F)
171 where
172 F: Fn(i32) + Send + Sync + 'static,
173 {
174 self.on_agent_invalid = Some(Box::new(callback));
175 }
176
177
178 pub async fn add_agent(&self, agent: BoxAgent<'a>, id: i32, provider: String, model: String) {
180 self.agents.lock().await.push(AgentState::new(agent, id, provider, model, 3)); }
182
183 pub async fn add_agent_with_max_failures(&self, agent: BoxAgent<'a>, id: i32, provider: String, model: String, max_failures: u32) {
185 self.agents.lock().await.push(AgentState::new(agent, id, provider, model, max_failures));
186 }
187
188 pub async fn len(&self) -> usize {
190 self.agents.lock().await.iter().filter(|state| state.is_valid()).count()
191 }
192
193 pub async fn total_len(&self) -> usize {
195 self.agents.lock().await.len()
196 }
197
198 pub async fn is_empty(&self) -> bool {
200 self.len().await == 0
201 }
202
203 async fn get_random_valid_agent<'b>(agents: &'b mut Vec<AgentState<'a>>) -> Option<&'b mut AgentState<'a>> {
205 let valid_indices: Vec<usize> = agents
206 .iter()
207 .enumerate()
208 .filter(|(_, state)| state.is_valid())
209 .map(|(i, _)| i)
210 .collect();
211
212 if valid_indices.is_empty() {
213 return None;
214 }
215
216 let mut rng = rand::rng();
217 let random_index = rng.random_range(0..valid_indices.len());
218 let agent_index = valid_indices[random_index];
219 agents.get_mut(agent_index)
220 }
221
222
223 pub async fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
225 self.agents
226 .lock()
227 .await
228 .iter()
229 .enumerate()
230 .map(|(i, state)| (i, state.info.failure_count, state.info.max_failures))
231 .collect()
232 }
233
234 pub async fn reset_failures(&self) {
236 for state in self.agents.lock().await.iter_mut() {
237 state.info.failure_count = 0;
238 }
239 }
240}
241
242
243
244pub struct RandAgentBuilder<'a> {
246 pub(crate) agents: Vec<(BoxAgent<'a>, i32, String, String)>,
247 max_failures: u32,
248 on_agent_invalid: Option<Box<dyn Fn(i32) + Send + Sync + 'static>>,
249}
250
251impl<'a> RandAgentBuilder<'a> {
252 pub fn new() -> Self {
254 Self {
255 agents: Vec::new(),
256 max_failures: 3, on_agent_invalid: None,
258 }
259 }
260
261 pub fn max_failures(mut self, max_failures: u32) -> Self {
263 self.max_failures = max_failures;
264 self
265 }
266
267 pub fn on_agent_invalid<F>(mut self, callback: F) -> Self
269 where
270 F: Fn(i32) + Send + Sync + 'static,
271 {
272 self.on_agent_invalid = Some(Box::new(callback));
273 self
274 }
275
276 pub fn add_agent(mut self, agent: BoxAgent<'a>, id: i32, provider_name: String, model_name: String) -> Self {
283 self.agents.push((agent, id, provider_name, model_name));
284 self
285 }
286
287 pub fn add_builder(mut self, builder: Agent<CompletionModelHandle<'a>>, id: i32, provider_name: &str, model_name: &str) -> Self {
296 self.agents.push((builder, id, provider_name.to_string(), model_name.to_string()));
297 self
298 }
299
300 pub fn build(self) -> RandAgent<'a> {
302 RandAgent::with_max_failures_and_callback(self.agents, self.max_failures, self.on_agent_invalid)
303 }
304}
305
306impl<'a> Default for RandAgentBuilder<'a> {
307 fn default() -> Self {
308 Self::new()
309 }
310}