1use rand::Rng;
58use rig::agent::{Agent};
59use rig::client::builder::BoxAgent;
60use rig::completion::Prompt;
61use rig::client::completion::CompletionModelHandle;
62use crate::error::RandAgentError;
63
64
65pub struct AgentState<'a> {
67 agent: BoxAgent<'a>,
68 provider: String,
69 model: String,
70 failure_count: u32,
71 max_failures: u32,
72}
73
74impl<'a> AgentState<'a> {
75 fn new(agent: BoxAgent<'a>, provider: String, model: String, max_failures: u32) -> Self {
76 Self {
77 agent,
78 provider,
79 model,
80 failure_count: 0,
81 max_failures,
82 }
83 }
84
85 fn is_valid(&self) -> bool {
86 self.failure_count < self.max_failures
87 }
88
89 fn record_failure(&mut self) {
90 self.failure_count += 1;
91 }
92
93 fn record_success(&mut self) {
94 self.failure_count = 0;
95 }
96}
97
98pub struct RandAgent<'a> {
100 agents: Vec<AgentState<'a>>,
101}
102
103impl<'a> RandAgent<'a> {
104 pub fn new(agents: Vec<(BoxAgent<'a>, String, String)>) -> Self {
106 Self::with_max_failures(agents, 3) }
108
109 pub fn with_max_failures(agents: Vec<(BoxAgent<'a>, String, String)>, max_failures: u32) -> Self {
111 let agent_states = agents
112 .into_iter()
113 .map(|(agent, provider, model)| AgentState::new(agent, provider, model, max_failures))
114 .collect();
115 Self {
116 agents: agent_states,
117 }
118 }
119
120
121 pub fn add_agent(&mut self, agent: BoxAgent<'a>, provider: String, model: String) {
123 self.agents.push(AgentState::new(agent, provider, model, 3)); }
125
126 pub fn add_agent_with_max_failures(&mut self, agent: BoxAgent<'a>, provider: String, model: String, max_failures: u32) {
128 self.agents.push(AgentState::new(agent, provider, model, max_failures));
129 }
130
131 pub fn len(&self) -> usize {
133 self.agents.iter().filter(|state| state.is_valid()).count()
134 }
135
136 pub fn total_len(&self) -> usize {
138 self.agents.len()
139 }
140
141 pub fn is_empty(&self) -> bool {
143 self.len() == 0
144 }
145
146 async fn get_random_valid_agent(&mut self) -> Option<&mut AgentState<'a>> {
148 let valid_indices: Vec<usize> = self
149 .agents
150 .iter()
151 .enumerate()
152 .filter(|(_, state)| state.is_valid())
153 .map(|(i, _)| i)
154 .collect();
155
156 if valid_indices.is_empty() {
157 return None;
158 }
159
160 let mut rng = rand::rng();
161 let random_index = rng.random_range(0..valid_indices.len());
162 let agent_index = valid_indices[random_index];
163 self.agents.get_mut(agent_index)
164 }
165
166 pub async fn prompt(
168 &mut self,
169 message: &str,
170 ) -> Result<String, RandAgentError> {
171 let agent_state = self
172 .get_random_valid_agent()
173 .await
174 .ok_or(RandAgentError::NoValidAgents)?;
175
176 tracing::info!("Using provider: {}, model: {}", agent_state.provider, agent_state.model);
178 match agent_state.agent.prompt(message).await {
179 Ok(response) => {
180 agent_state.record_success();
181 Ok(response)
182 }
183 Err(e) => {
184 agent_state.record_failure();
185 Err(RandAgentError::AgentError(Box::new(e)))
186 }
187 }
188 }
189
190
191 pub fn agents(&self) -> &[AgentState<'a>] {
193 &self.agents
194 }
195
196 pub fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
198 self.agents
199 .iter()
200 .enumerate()
201 .map(|(i, state)| (i, state.failure_count, state.max_failures))
202 .collect()
203 }
204
205 pub fn reset_failures(&mut self) {
207 for state in &mut self.agents {
208 state.failure_count = 0;
209 }
210 }
211}
212
213
214
215pub struct RandAgentBuilder<'a> {
217 agents: Vec<(BoxAgent<'a>, String, String)>,
218 max_failures: u32,
219}
220
221impl<'a> RandAgentBuilder<'a> {
222 pub fn new() -> Self {
224 Self {
225 agents: Vec::new(),
226 max_failures: 3, }
228 }
229
230 pub fn max_failures(mut self, max_failures: u32) -> Self {
232 self.max_failures = max_failures;
233 self
234 }
235
236 pub fn add_agent(mut self, agent: BoxAgent<'a>, provider_name: String, model_name: String) -> Self {
243 self.agents.push((agent, provider_name, model_name));
244 self
245 }
246
247 pub fn add_builder(mut self, builder: Agent<CompletionModelHandle<'a>>, provider_name: &str, model_name: &str) -> Self {
256 self.agents.push((builder, provider_name.to_string(), model_name.to_string()));
257 self
258 }
259
260 pub fn build(self) -> RandAgent<'a> {
262 RandAgent::with_max_failures(self.agents, self.max_failures)
263 }
264}
265
266impl<'a> Default for RandAgentBuilder<'a> {
267 fn default() -> Self {
268 Self::new()
269 }
270}