1use rand::Rng;
55use rig::agent::{Agent};
56use rig::client::builder::BoxAgent;
57use rig::completion::Prompt;
58use rig::client::completion::CompletionModelHandle;
59
60
61pub struct AgentState<'a> {
63 agent: BoxAgent<'a>,
64 provider: String,
65 model: String,
66 failure_count: u32,
67 max_failures: u32,
68}
69
70impl<'a> AgentState<'a> {
71 fn new(agent: BoxAgent<'a>, provider: String, model: String, max_failures: u32) -> Self {
72 Self {
73 agent,
74 provider,
75 model,
76 failure_count: 0,
77 max_failures,
78 }
79 }
80
81 fn is_valid(&self) -> bool {
82 self.failure_count < self.max_failures
83 }
84
85 fn record_failure(&mut self) {
86 self.failure_count += 1;
87 }
88
89 fn record_success(&mut self) {
90 self.failure_count = 0;
91 }
92}
93
94pub struct RandAgent<'a> {
96 agents: Vec<AgentState<'a>>,
97}
98
99impl<'a> RandAgent<'a> {
100 pub fn new(agents: Vec<(BoxAgent<'a>, String, String)>) -> Self {
102 Self::with_max_failures(agents, 3) }
104
105 pub fn with_max_failures(agents: Vec<(BoxAgent<'a>, String, String)>, max_failures: u32) -> Self {
107 let agent_states = agents
108 .into_iter()
109 .map(|(agent, provider, model)| AgentState::new(agent, provider, model, max_failures))
110 .collect();
111 Self {
112 agents: agent_states,
113 }
114 }
115
116
117 pub fn add_agent(&mut self, agent: BoxAgent<'a>, provider: String, model: String) {
119 self.agents.push(AgentState::new(agent, provider, model, 3)); }
121
122 pub fn add_agent_with_max_failures(&mut self, agent: BoxAgent<'a>, provider: String, model: String, max_failures: u32) {
124 self.agents.push(AgentState::new(agent, provider, model, max_failures));
125 }
126
127 pub fn len(&self) -> usize {
129 self.agents.iter().filter(|state| state.is_valid()).count()
130 }
131
132 pub fn total_len(&self) -> usize {
134 self.agents.len()
135 }
136
137 pub fn is_empty(&self) -> bool {
139 self.len() == 0
140 }
141
142 async fn get_random_valid_agent(&mut self) -> Option<&mut AgentState<'a>> {
144 let valid_indices: Vec<usize> = self
145 .agents
146 .iter()
147 .enumerate()
148 .filter(|(_, state)| state.is_valid())
149 .map(|(i, _)| i)
150 .collect();
151
152 if valid_indices.is_empty() {
153 return None;
154 }
155
156 let mut rng = rand::rng();
157 let random_index = rng.random_range(0..valid_indices.len());
158 let agent_index = valid_indices[random_index];
159 self.agents.get_mut(agent_index)
160 }
161
162 pub async fn prompt(
164 &mut self,
165 message: &str,
166 ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
167 let agent_state = self
168 .get_random_valid_agent()
169 .await
170 .ok_or("No valid agents available")?;
171
172 tracing::info!("Using provider: {}, model: {}", agent_state.provider, agent_state.model);
174 match agent_state.agent.prompt(message).await {
175 Ok(response) => {
176 agent_state.record_success();
177 Ok(response)
178 }
179 Err(e) => {
180 agent_state.record_failure();
181 Err(e.into())
182 }
183 }
184 }
185
186
187 pub fn agents(&self) -> &[AgentState<'a>] {
189 &self.agents
190 }
191
192 pub fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
194 self.agents
195 .iter()
196 .enumerate()
197 .map(|(i, state)| (i, state.failure_count, state.max_failures))
198 .collect()
199 }
200
201 pub fn reset_failures(&mut self) {
203 for state in &mut self.agents {
204 state.failure_count = 0;
205 }
206 }
207}
208
209
210
211pub struct RandAgentBuilder<'a> {
213 agents: Vec<(BoxAgent<'a>, String, String)>,
214 max_failures: u32,
215}
216
217impl<'a> RandAgentBuilder<'a> {
218 pub fn new() -> Self {
220 Self {
221 agents: Vec::new(),
222 max_failures: 3, }
224 }
225
226 pub fn max_failures(mut self, max_failures: u32) -> Self {
228 self.max_failures = max_failures;
229 self
230 }
231
232 pub fn add_agent(mut self, agent: BoxAgent<'a>, provider_name: String, model_name: String) -> Self {
239 self.agents.push((agent, provider_name, model_name));
240 self
241 }
242
243 pub fn add_builder(mut self, builder: Agent<CompletionModelHandle<'a>>, provider_name: &str, model_name: &str) -> Self {
252 self.agents.push((builder, provider_name.to_string(), model_name.to_string()));
253 self
254 }
255
256 pub fn build(self) -> RandAgent<'a> {
258 RandAgent::with_max_failures(self.agents, self.max_failures)
259 }
260}
261
262impl<'a> Default for RandAgentBuilder<'a> {
263 fn default() -> Self {
264 Self::new()
265 }
266}