1use rand::Rng;
55use rig::agent::{Agent, AgentBuilder};
56use rig::client::builder::BoxAgent;
57use rig::completion::Prompt;
58use rig::client::completion::CompletionModelHandle;
59
60
61pub struct AgentState<'a> {
63 agent: BoxAgent<'a>,
64 provider: String,
65 model: String,
66 failure_count: u32,
67 max_failures: u32,
68}
69
70impl<'a> AgentState<'a> {
71 fn new(agent: BoxAgent<'a>, provider: String, model: String, max_failures: u32) -> Self {
72 Self {
73 agent,
74 provider,
75 model,
76 failure_count: 0,
77 max_failures,
78 }
79 }
80
81 fn is_valid(&self) -> bool {
82 self.failure_count < self.max_failures
83 }
84
85 fn record_failure(&mut self) {
86 self.failure_count += 1;
87 }
88
89 fn record_success(&mut self) {
90 self.failure_count = 0;
91 }
92}
93
94pub struct RandAgent<'a> {
96 agents: Vec<AgentState<'a>>,
97}
98
99impl<'a> RandAgent<'a> {
100 pub fn new(agents: Vec<(BoxAgent<'a>, String, String)>) -> Self {
102 Self::with_max_failures(agents, 3) }
104
105 pub fn with_max_failures(agents: Vec<(BoxAgent<'a>, String, String)>, max_failures: u32) -> Self {
107 let agent_states = agents
108 .into_iter()
109 .map(|(agent, provider, model)| AgentState::new(agent, provider, model, max_failures))
110 .collect();
111 Self {
112 agents: agent_states,
113 }
114 }
115
116 #[deprecated(note = "请使用 RandAgentBuilder::add_agent/add_builder 方式构建并传递 provider/model")]
118 pub fn from_builders(_builders: Vec<AgentBuilder<CompletionModelHandle<'a>>>) -> Self {
119 unimplemented!("from_builders 已废弃,请使用 RandAgentBuilder::add_agent/add_builder 方式");
120 }
121
122 #[deprecated(note = "请使用 RandAgentBuilder::add_agent/add_builder 方式构建并传递 provider/model")]
124 pub fn from_builders_with_max_failures(
125 _builders: Vec<AgentBuilder<CompletionModelHandle<'a>>>,
126 _max_failures: u32,
127 ) -> Self {
128 unimplemented!("from_builders_with_max_failures 已废弃,请使用 RandAgentBuilder::add_agent/add_builder 方式");
129 }
130
131 pub fn add_agent(&mut self, agent: BoxAgent<'a>, provider: String, model: String) {
133 self.agents.push(AgentState::new(agent, provider, model, 3)); }
135
136 pub fn add_agent_with_max_failures(&mut self, agent: BoxAgent<'a>, provider: String, model: String, max_failures: u32) {
138 self.agents.push(AgentState::new(agent, provider, model, max_failures));
139 }
140
141 pub fn len(&self) -> usize {
143 self.agents.iter().filter(|state| state.is_valid()).count()
144 }
145
146 pub fn total_len(&self) -> usize {
148 self.agents.len()
149 }
150
151 pub fn is_empty(&self) -> bool {
153 self.len() == 0
154 }
155
156 async fn get_random_valid_agent(&mut self) -> Option<&mut AgentState<'a>> {
158 let valid_indices: Vec<usize> = self
159 .agents
160 .iter()
161 .enumerate()
162 .filter(|(_, state)| state.is_valid())
163 .map(|(i, _)| i)
164 .collect();
165
166 if valid_indices.is_empty() {
167 return None;
168 }
169
170 let mut rng = rand::rng();
171 let random_index = rng.random_range(0..valid_indices.len());
172 let agent_index = valid_indices[random_index];
173 self.agents.get_mut(agent_index)
174 }
175
176 pub async fn prompt(
178 &mut self,
179 message: &str,
180 ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
181 let agent_state = self
182 .get_random_valid_agent()
183 .await
184 .ok_or("No valid agents available")?;
185
186 tracing::info!("Using provider: {}, model: {}", agent_state.provider, agent_state.model);
188 match agent_state.agent.prompt(message).await {
189 Ok(response) => {
190 agent_state.record_success();
191 Ok(response)
192 }
193 Err(e) => {
194 agent_state.record_failure();
195 Err(e.into())
196 }
197 }
198 }
199
200 pub async fn stream_prompt(
202 &self,
203 _message: &str,
204 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
205 Err("Streaming not implemented for RandAgent".into())
206 }
207
208 pub fn agents(&self) -> &[AgentState<'a>] {
210 &self.agents
211 }
212
213 pub fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
215 self.agents
216 .iter()
217 .enumerate()
218 .map(|(i, state)| (i, state.failure_count, state.max_failures))
219 .collect()
220 }
221
222 pub fn reset_failures(&mut self) {
224 for state in &mut self.agents {
225 state.failure_count = 0;
226 }
227 }
228}
229
230pub struct RandAgentBuilder<'a> {
235 agents: Vec<(BoxAgent<'a>, String, String)>,
236 max_failures: u32,
237}
238
239impl<'a> RandAgentBuilder<'a> {
240 pub fn new() -> Self {
242 Self {
243 agents: Vec::new(),
244 max_failures: 3, }
246 }
247
248 pub fn max_failures(mut self, max_failures: u32) -> Self {
250 self.max_failures = max_failures;
251 self
252 }
253
254 pub fn add_agent(mut self, agent: BoxAgent<'a>, provider_name: String, model_name: String) -> Self {
261 self.agents.push((agent, provider_name, model_name));
262 self
263 }
264
265 pub fn add_builder(mut self, builder: Agent<CompletionModelHandle<'a>>, provider_name: &str, model_name: &str) -> Self {
274 self.agents.push((builder, provider_name.to_string(), model_name.to_string()));
275 self
276 }
277
278 pub fn build(self) -> RandAgent<'a> {
280 RandAgent::with_max_failures(self.agents, self.max_failures)
281 }
282}
283
284impl<'a> Default for RandAgentBuilder<'a> {
285 fn default() -> Self {
286 Self::new()
287 }
288}