rig_extra/
thread_safe_rand_agent.rs1use std::sync::{Arc, Mutex};
51use rand::Rng;
52use rig::agent::Agent;
53use rig::client::builder::BoxAgent;
54use rig::client::completion::CompletionModelHandle;
55use rig::completion::Prompt;
56
57pub struct ThreadSafeRandAgent {
59 agents: Arc<Mutex<Vec<ThreadSafeAgentState>>>,
60}
61
62pub struct ThreadSafeAgentState {
64 agent: Arc<BoxAgent<'static>>,
65 provider: String,
66 model: String,
67 failure_count: u32,
68 max_failures: u32,
69}
70
71impl ThreadSafeAgentState {
72 fn new(agent: BoxAgent<'static>, provider: String, model: String, max_failures: u32) -> Self {
73 Self {
74 agent: Arc::new(agent),
75 provider,
76 model,
77 failure_count: 0,
78 max_failures,
79 }
80 }
81
82 fn is_valid(&self) -> bool {
83 self.failure_count < self.max_failures
84 }
85
86 fn record_failure(&mut self) {
87 self.failure_count += 1;
88 }
89
90 fn record_success(&mut self) {
91 self.failure_count = 0;
92 }
93}
94
95impl ThreadSafeRandAgent {
96 pub fn new(agents: Vec<(BoxAgent<'static>, String, String)>) -> Self {
98 Self::with_max_failures(agents, 3)
99 }
100
101 pub fn with_max_failures(agents: Vec<(BoxAgent<'static>, String, String)>, max_failures: u32) -> Self {
103 let agent_states = agents
104 .into_iter()
105 .map(|(agent, provider, model)| ThreadSafeAgentState::new(agent, provider, model, max_failures))
106 .collect();
107 Self {
108 agents: Arc::new(Mutex::new(agent_states)),
109 }
110 }
111
112 pub fn add_agent(&self, agent: BoxAgent<'static>, provider: String, model: String) {
114 let mut agents = self.agents.lock().unwrap();
115 agents.push(ThreadSafeAgentState::new(agent, provider, model, 3));
116 }
117
118 pub fn add_agent_with_max_failures(&self, agent: BoxAgent<'static>, provider: String, model: String, max_failures: u32) {
120 let mut agents = self.agents.lock().unwrap();
121 agents.push(ThreadSafeAgentState::new(agent, provider, model, max_failures));
122 }
123
124 pub fn len(&self) -> usize {
126 let agents = self.agents.lock().unwrap();
127 agents.iter().filter(|state| state.is_valid()).count()
128 }
129
130 pub fn total_len(&self) -> usize {
132 let agents = self.agents.lock().unwrap();
133 agents.len()
134 }
135
136 pub fn is_empty(&self) -> bool {
138 self.len() == 0
139 }
140
141
142
143 pub async fn prompt(
145 &self,
146 message: &str,
147 ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
148 let (agent_index, provider, model) = {
150 let agents = self.agents.lock().unwrap();
151
152 let valid_indices: Vec<usize> = agents
154 .iter()
155 .enumerate()
156 .filter(|(_, state)| state.is_valid())
157 .map(|(i, _)| i)
158 .collect();
159
160 if valid_indices.is_empty() {
161 return Err("No valid agents available".into());
162 }
163
164 let mut rng = rand::rng();
166 let random_index = rng.random_range(0..valid_indices.len());
167 let agent_index = valid_indices[random_index];
168
169 let agent_state = &agents[agent_index];
171 let provider = agent_state.provider.clone();
172 let model = agent_state.model.clone();
173
174 (agent_index, provider, model)
175 };
176
177 tracing::info!("Using provider: {}, model: {}", provider, model);
179
180 let result = {
182 let agent = {
184 let agents = self.agents.lock().unwrap();
185 Arc::clone(&agents[agent_index].agent)
186 };
187
188 agent.prompt(message).await.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
190 };
191
192 match &result {
194 Ok(_) => {
195 let mut agents = self.agents.lock().unwrap();
196 agents[agent_index].record_success();
197 }
198 Err(_) => {
199 let mut agents = self.agents.lock().unwrap();
200 agents[agent_index].record_failure();
201 }
202 }
203
204 result
205 }
206
207 pub fn agents(&self) -> Vec<(String, String, u32, u32)> {
209 let agents = self.agents.lock().unwrap();
210 agents
211 .iter()
212 .map(|state| (
213 state.provider.clone(),
214 state.model.clone(),
215 state.failure_count,
216 state.max_failures
217 ))
218 .collect()
219 }
220
221 pub fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
223 let agents = self.agents.lock().unwrap();
224 agents
225 .iter()
226 .enumerate()
227 .map(|(i, state)| (i, state.failure_count, state.max_failures))
228 .collect()
229 }
230
231 pub fn reset_failures(&self) {
233 let mut agents = self.agents.lock().unwrap();
234 for state in agents.iter_mut() {
235 state.failure_count = 0;
236 }
237 }
238}
239
240unsafe impl Send for ThreadSafeRandAgent {}
242unsafe impl Sync for ThreadSafeRandAgent {}
243
244
245pub struct ThreadSafeRandAgentBuilder {
247 agents: Vec<(BoxAgent<'static>, String, String)>,
248 max_failures: u32,
249}
250
251impl ThreadSafeRandAgentBuilder {
252 pub fn new() -> Self {
254 Self {
255 agents: Vec::new(),
256 max_failures: 3, }
258 }
259
260 pub fn max_failures(mut self, max_failures: u32) -> Self {
262 self.max_failures = max_failures;
263 self
264 }
265
266 pub fn add_agent(mut self, agent: BoxAgent<'static>, provider_name: String, model_name: String) -> Self {
273 self.agents.push((agent, provider_name, model_name));
274 self
275 }
276
277 pub fn add_builder(mut self, builder: Agent<CompletionModelHandle<'static>>, provider_name: &str, model_name: &str) -> Self {
284 self.agents.push((builder, provider_name.to_string(), model_name.to_string()));
285 self
286 }
287
288 pub fn build(self) -> ThreadSafeRandAgent {
290 ThreadSafeRandAgent::with_max_failures(self.agents, self.max_failures)
291 }
292}
293
294impl Default for ThreadSafeRandAgentBuilder {
295 fn default() -> Self {
296 Self::new()
297 }
298}