rig_extra/
thread_safe_rand_agent.rs1use std::sync::{Arc, Mutex};
51use rand::Rng;
52use rig::agent::Agent;
53use rig::client::builder::BoxAgent;
54use rig::client::completion::CompletionModelHandle;
55use rig::completion::Prompt;
56
57use crate::error::RandAgentError;
58
59pub struct ThreadSafeRandAgent {
61 agents: Arc<Mutex<Vec<ThreadSafeAgentState>>>,
62}
63
64pub struct ThreadSafeAgentState {
66 agent: Arc<BoxAgent<'static>>,
67 provider: String,
68 model: String,
69 failure_count: u32,
70 max_failures: u32,
71}
72
73impl ThreadSafeAgentState {
74 fn new(agent: BoxAgent<'static>, provider: String, model: String, max_failures: u32) -> Self {
75 Self {
76 agent: Arc::new(agent),
77 provider,
78 model,
79 failure_count: 0,
80 max_failures,
81 }
82 }
83
84 fn is_valid(&self) -> bool {
85 self.failure_count < self.max_failures
86 }
87
88 fn record_failure(&mut self) {
89 self.failure_count += 1;
90 }
91
92 fn record_success(&mut self) {
93 self.failure_count = 0;
94 }
95}
96
97impl ThreadSafeRandAgent {
98 pub fn new(agents: Vec<(BoxAgent<'static>, String, String)>) -> Self {
100 Self::with_max_failures(agents, 3)
101 }
102
103 pub fn with_max_failures(agents: Vec<(BoxAgent<'static>, String, String)>, max_failures: u32) -> Self {
105 let agent_states = agents
106 .into_iter()
107 .map(|(agent, provider, model)| ThreadSafeAgentState::new(agent, provider, model, max_failures))
108 .collect();
109 Self {
110 agents: Arc::new(Mutex::new(agent_states)),
111 }
112 }
113
114 pub fn add_agent(&self, agent: BoxAgent<'static>, provider: String, model: String) {
116 let mut agents = self.agents.lock().unwrap();
117 agents.push(ThreadSafeAgentState::new(agent, provider, model, 3));
118 }
119
120 pub fn add_agent_with_max_failures(&self, agent: BoxAgent<'static>, provider: String, model: String, max_failures: u32) {
122 let mut agents = self.agents.lock().unwrap();
123 agents.push(ThreadSafeAgentState::new(agent, provider, model, max_failures));
124 }
125
126 pub fn len(&self) -> usize {
128 let agents = self.agents.lock().unwrap();
129 agents.iter().filter(|state| state.is_valid()).count()
130 }
131
132 pub fn total_len(&self) -> usize {
134 let agents = self.agents.lock().unwrap();
135 agents.len()
136 }
137
138 pub fn is_empty(&self) -> bool {
140 self.len() == 0
141 }
142
143
144
145 pub async fn prompt(
147 &self,
148 message: &str,
149 ) -> Result<String, RandAgentError> {
150 let (agent_index, provider, model) = {
152 let agents = self.agents.lock().unwrap();
153
154 let valid_indices: Vec<usize> = agents
156 .iter()
157 .enumerate()
158 .filter(|(_, state)| state.is_valid())
159 .map(|(i, _)| i)
160 .collect();
161
162 if valid_indices.is_empty() {
163 return Err(RandAgentError::NoValidAgents);
164 }
165
166 let mut rng = rand::rng();
168 let random_index = rng.random_range(0..valid_indices.len());
169 let agent_index = valid_indices[random_index];
170
171 let agent_state = &agents[agent_index];
173 let provider = agent_state.provider.clone();
174 let model = agent_state.model.clone();
175
176 (agent_index, provider, model)
177 };
178
179 tracing::info!("Using provider: {}, model: {}", provider, model);
181
182 let result = {
184 let agent = {
186 let agents = self.agents.lock().unwrap();
187 Arc::clone(&agents[agent_index].agent)
188 };
189
190 agent.prompt(message).await.map_err(|e| RandAgentError::AgentError(Box::new(e)))
192 };
193
194 match &result {
196 Ok(_) => {
197 let mut agents = self.agents.lock().unwrap();
198 agents[agent_index].record_success();
199 }
200 Err(_) => {
201 let mut agents = self.agents.lock().unwrap();
202 agents[agent_index].record_failure();
203 }
204 }
205
206 result
207 }
208
209 pub fn agents(&self) -> Vec<(String, String, u32, u32)> {
211 let agents = self.agents.lock().unwrap();
212 agents
213 .iter()
214 .map(|state| (
215 state.provider.clone(),
216 state.model.clone(),
217 state.failure_count,
218 state.max_failures
219 ))
220 .collect()
221 }
222
223 pub fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
225 let agents = self.agents.lock().unwrap();
226 agents
227 .iter()
228 .enumerate()
229 .map(|(i, state)| (i, state.failure_count, state.max_failures))
230 .collect()
231 }
232
233 pub fn reset_failures(&self) {
235 let mut agents = self.agents.lock().unwrap();
236 for state in agents.iter_mut() {
237 state.failure_count = 0;
238 }
239 }
240}
241
242unsafe impl Send for ThreadSafeRandAgent {}
244unsafe impl Sync for ThreadSafeRandAgent {}
245
246
247pub struct ThreadSafeRandAgentBuilder {
249 agents: Vec<(BoxAgent<'static>, String, String)>,
250 max_failures: u32,
251}
252
253impl ThreadSafeRandAgentBuilder {
254 pub fn new() -> Self {
256 Self {
257 agents: Vec::new(),
258 max_failures: 3, }
260 }
261
262 pub fn max_failures(mut self, max_failures: u32) -> Self {
264 self.max_failures = max_failures;
265 self
266 }
267
268 pub fn add_agent(mut self, agent: BoxAgent<'static>, provider_name: String, model_name: String) -> Self {
275 self.agents.push((agent, provider_name, model_name));
276 self
277 }
278
279 pub fn add_builder(mut self, builder: Agent<CompletionModelHandle<'static>>, provider_name: &str, model_name: &str) -> Self {
286 self.agents.push((builder, provider_name.to_string(), model_name.to_string()));
287 self
288 }
289
290 pub fn build(self) -> ThreadSafeRandAgent {
292 ThreadSafeRandAgent::with_max_failures(self.agents, self.max_failures)
293 }
294}
295
296impl Default for ThreadSafeRandAgentBuilder {
297 fn default() -> Self {
298 Self::new()
299 }
300}