rig_extra/
thread_safe_rand_agent.rs

1//! ## 多线程使用示例
2//!
3//! ```rust
4//! use rig_extra::extra_providers::{bigmodel::Client};
5//! use rig_extra::thread_safe_rand_agent::ThreadSafeRandAgentBuilder;
6//! use std::sync::Arc;
7//! use tokio::task;
8//! use rig::client::ProviderClient;
9//! 
10//! #[tokio::main]
11//! async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
//!     // Create one client per agent
//!     let client1 = Client::from_env();
//!     let client2 = Client::from_env();
//!     use rig::client::completion::CompletionClientDyn;
//!
//!     // Build the thread-safe RandAgent from those clients
20//!     let thread_safe_agent = ThreadSafeRandAgentBuilder::new()
21//!         .max_failures(3)
22//!         .add_agent(client1.agent("glm-4-flash").build(), "bigmodel".to_string(), "glm-4-flash".to_string())
23//!         .add_agent(client2.agent("glm-4-flash").build(), "bigmodel".to_string(), "glm-4-flash".to_string())
24//!         .build();
25//!
26//!     let agent_arc = Arc::new(thread_safe_agent);
27//!
28//!     // 创建多个并发任务
29//!     let mut handles = vec![];
30//!     for i in 0..5 {
31//!         let agent_clone = Arc::clone(&agent_arc);
32//!         let handle = task::spawn(async move {
33//!             let response = agent_clone.prompt(&format!("Hello from task {}", i)).await?;
34//!             println!("Task {} response: {}", i, response);
35//!             Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
36//!         });
37//!         handles.push(handle);
38//!     }
39//!
40//!     // 等待所有任务完成
41//!     for handle in handles {
42//!         handle.await??;
43//!     }
44//!
45//!     Ok(())
46//! }
47//! ```
48
49
50use std::sync::{Arc, Mutex};
51use rand::Rng;
52use rig::agent::Agent;
53use rig::client::builder::BoxAgent;
54use rig::client::completion::CompletionModelHandle;
55use rig::completion::Prompt;
56
57/// 线程安全的 RandAgent,支持多线程并发访问
58pub struct ThreadSafeRandAgent {
59    agents: Arc<Mutex<Vec<ThreadSafeAgentState>>>,
60}
61
62/// 线程安全的 Agent 状态
63pub struct ThreadSafeAgentState {
64    agent: Arc<BoxAgent<'static>>,
65    provider: String,
66    model: String,
67    failure_count: u32,
68    max_failures: u32,
69}
70
71impl ThreadSafeAgentState {
72    fn new(agent: BoxAgent<'static>, provider: String, model: String, max_failures: u32) -> Self {
73        Self {
74            agent: Arc::new(agent),
75            provider,
76            model,
77            failure_count: 0,
78            max_failures,
79        }
80    }
81
82    fn is_valid(&self) -> bool {
83        self.failure_count < self.max_failures
84    }
85
86    fn record_failure(&mut self) {
87        self.failure_count += 1;
88    }
89
90    fn record_success(&mut self) {
91        self.failure_count = 0;
92    }
93}
94
95impl ThreadSafeRandAgent {
96    /// 创建新的线程安全 RandAgent
97    pub fn new(agents: Vec<(BoxAgent<'static>, String, String)>) -> Self {
98        Self::with_max_failures(agents, 3)
99    }
100
101    /// 使用自定义最大失败次数创建线程安全 RandAgent
102    pub fn with_max_failures(agents: Vec<(BoxAgent<'static>, String, String)>, max_failures: u32) -> Self {
103        let agent_states = agents
104            .into_iter()
105            .map(|(agent, provider, model)| ThreadSafeAgentState::new(agent, provider, model, max_failures))
106            .collect();
107        Self {
108            agents: Arc::new(Mutex::new(agent_states)),
109        }
110    }
111
112    /// 添加代理到集合中
113    pub fn add_agent(&self, agent: BoxAgent<'static>, provider: String, model: String) {
114        let mut agents = self.agents.lock().unwrap();
115        agents.push(ThreadSafeAgentState::new(agent, provider, model, 3));
116    }
117
118    /// 使用自定义最大失败次数添加代理
119    pub fn add_agent_with_max_failures(&self, agent: BoxAgent<'static>, provider: String, model: String, max_failures: u32) {
120        let mut agents = self.agents.lock().unwrap();
121        agents.push(ThreadSafeAgentState::new(agent, provider, model, max_failures));
122    }
123
124    /// 获取有效代理数量
125    pub fn len(&self) -> usize {
126        let agents = self.agents.lock().unwrap();
127        agents.iter().filter(|state| state.is_valid()).count()
128    }
129
130    /// 获取总代理数量(包括无效的)
131    pub fn total_len(&self) -> usize {
132        let agents = self.agents.lock().unwrap();
133        agents.len()
134    }
135
136    /// 检查是否有有效代理
137    pub fn is_empty(&self) -> bool {
138        self.len() == 0
139    }
140
141
142
143    /// 向随机有效代理发送消息
144    pub async fn prompt(
145        &self,
146        message: &str,
147    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
148        // 第一步:选择代理并获取其信息
149        let (agent_index, provider, model) = {
150            let agents = self.agents.lock().unwrap();
151
152            // 找到所有有效代理的索引
153            let valid_indices: Vec<usize> = agents
154                .iter()
155                .enumerate()
156                .filter(|(_, state)| state.is_valid())
157                .map(|(i, _)| i)
158                .collect();
159
160            if valid_indices.is_empty() {
161                return Err("No valid agents available".into());
162            }
163
164            // 随机选择一个有效代理
165            let mut rng = rand::rng();
166            let random_index = rng.random_range(0..valid_indices.len());
167            let agent_index = valid_indices[random_index];
168
169            // 获取代理信息
170            let agent_state = &agents[agent_index];
171            let provider = agent_state.provider.clone();
172            let model = agent_state.model.clone();
173
174            (agent_index, provider, model)
175        };
176
177        // 打印使用的 provider 和 model
178        tracing::info!("Using provider: {}, model: {}", provider, model);
179
180        // 第二步:执行异步操作(在锁外执行)
181        let result = {
182            // 获取代理的 Arc 克隆以避免在异步操作中持有锁
183            let agent = {
184                let agents = self.agents.lock().unwrap();
185                Arc::clone(&agents[agent_index].agent)
186            };
187
188            // 在锁外执行异步操作
189            agent.prompt(message).await.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
190        };
191
192        // 第三步:根据结果更新失败计数
193        match &result {
194            Ok(_) => {
195                let mut agents = self.agents.lock().unwrap();
196                agents[agent_index].record_success();
197            }
198            Err(_) => {
199                let mut agents = self.agents.lock().unwrap();
200                agents[agent_index].record_failure();
201            }
202        }
203
204        result
205    }
206
207    /// 获取所有代理(用于调试或检查)
208    pub fn agents(&self) -> Vec<(String, String, u32, u32)> {
209        let agents = self.agents.lock().unwrap();
210        agents
211            .iter()
212            .map(|state| (
213                state.provider.clone(),
214                state.model.clone(),
215                state.failure_count,
216                state.max_failures
217            ))
218            .collect()
219    }
220
221    /// 获取失败统计
222    pub fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
223        let agents = self.agents.lock().unwrap();
224        agents
225            .iter()
226            .enumerate()
227            .map(|(i, state)| (i, state.failure_count, state.max_failures))
228            .collect()
229    }
230
231    /// 重置所有代理的失败计数
232    pub fn reset_failures(&self) {
233        let mut agents = self.agents.lock().unwrap();
234        for state in agents.iter_mut() {
235            state.failure_count = 0;
236        }
237    }
238}
239
240// 实现 Send + Sync trait
241unsafe impl Send for ThreadSafeRandAgent {}
242unsafe impl Sync for ThreadSafeRandAgent {}
243
244
245/// 线程安全 RandAgent 的构建器
246pub struct ThreadSafeRandAgentBuilder {
247    agents: Vec<(BoxAgent<'static>, String, String)>,
248    max_failures: u32,
249}
250
251impl ThreadSafeRandAgentBuilder {
252    /// 创建新的 ThreadSafeRandAgentBuilder
253    pub fn new() -> Self {
254        Self {
255            agents: Vec::new(),
256            max_failures: 3, // 默认最大失败次数
257        }
258    }
259
260    /// 设置连续失败的最大次数,超过后标记代理为无效
261    pub fn max_failures(mut self, max_failures: u32) -> Self {
262        self.max_failures = max_failures;
263        self
264    }
265
266    /// 添加代理到构建器
267    ///
268    /// # 参数
269    /// - agent: 代理实例(需要是 'static 生命周期)
270    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
271    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
272    pub fn add_agent(mut self, agent: BoxAgent<'static>, provider_name: String, model_name: String) -> Self {
273        self.agents.push((agent, provider_name, model_name));
274        self
275    }
276
277    /// 从 AgentBuilder 添加代理
278    ///
279    /// # 参数
280    /// - builder: AgentBuilder 实例(需要是 'static 生命周期)
281    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
282    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
283    pub fn add_builder(mut self, builder: Agent<CompletionModelHandle<'static>>, provider_name: &str, model_name: &str) -> Self {
284        self.agents.push((builder, provider_name.to_string(), model_name.to_string()));
285        self
286    }
287
288    /// 构建 ThreadSafeRandAgent
289    pub fn build(self) -> ThreadSafeRandAgent {
290        ThreadSafeRandAgent::with_max_failures(self.agents, self.max_failures)
291    }
292}
293
294impl Default for ThreadSafeRandAgentBuilder {
295    fn default() -> Self {
296        Self::new()
297    }
298}