rig_extra/
thread_safe_rand_agent.rs

1//! ## 多线程使用示例
2//!
3//! ```rust
4//! use rig_extra::extra_providers::{bigmodel::Client};
5//! use rig_extra::thread_safe_rand_agent::ThreadSafeRandAgentBuilder;
6//! use std::sync::Arc;
7//! use tokio::task;
8//! use rig::client::ProviderClient;
9//! use rig_extra::error::RandAgentError;
10//! #[tokio::main]
11//! async fn main() -> Result<(), RandAgentError> {
12//!     // 创建线程安全的 RandAgent
13//!
14//!     //创建多个客户端 
15//!     let client1 = Client::from_env();
16//!     let client2 = Client::from_env();
17//!     use rig::client::completion::CompletionClientDyn;
18//! 
19//!
20//!     let thread_safe_agent = ThreadSafeRandAgentBuilder::new()
21//!         .max_failures(3)
22//!         .add_agent(client1.agent("glm-4-flash").build(), "bigmodel".to_string(), "glm-4-flash".to_string())
23//!         .add_agent(client2.agent("glm-4-flash").build(), "bigmodel".to_string(), "glm-4-flash".to_string())
24//!         .build();
25//!
26//!     let agent_arc = Arc::new(thread_safe_agent);
27//!
28//!     // 创建多个并发任务
29//!     let mut handles = vec![];
30//!     for i in 0..5 {
31//!         let agent_clone = Arc::clone(&agent_arc);
32//!         let handle = task::spawn(async move {
33//!             let response = agent_clone.prompt(&format!("Hello from task {}", i)).await?;
34//!             println!("Task {} response: {}", i, response);
35//!             Ok::<(), RandAgentError>(())
36//!         });
37//!         handles.push(handle);
38//!     }
39//!
40//!     // 等待所有任务完成
41//!     for handle in handles {
42//!         handle.await??;
43//!     }
44//!
45//!     Ok(())
46//! }
47//! ```
48
49
50use std::sync::{Arc, Mutex};
51use rand::Rng;
52use rig::agent::Agent;
53use rig::client::builder::BoxAgent;
54use rig::client::completion::CompletionModelHandle;
55use rig::completion::Prompt;
56
57use crate::error::RandAgentError;
58
59/// 线程安全的 RandAgent,支持多线程并发访问
60pub struct ThreadSafeRandAgent {
61    agents: Arc<Mutex<Vec<ThreadSafeAgentState>>>,
62}
63
64/// 线程安全的 Agent 状态
65pub struct ThreadSafeAgentState {
66    agent: Arc<BoxAgent<'static>>,
67    provider: String,
68    model: String,
69    failure_count: u32,
70    max_failures: u32,
71}
72
73impl ThreadSafeAgentState {
74    fn new(agent: BoxAgent<'static>, provider: String, model: String, max_failures: u32) -> Self {
75        Self {
76            agent: Arc::new(agent),
77            provider,
78            model,
79            failure_count: 0,
80            max_failures,
81        }
82    }
83
84    fn is_valid(&self) -> bool {
85        self.failure_count < self.max_failures
86    }
87
88    fn record_failure(&mut self) {
89        self.failure_count += 1;
90    }
91
92    fn record_success(&mut self) {
93        self.failure_count = 0;
94    }
95}
96
97impl ThreadSafeRandAgent {
98    /// 创建新的线程安全 RandAgent
99    pub fn new(agents: Vec<(BoxAgent<'static>, String, String)>) -> Self {
100        Self::with_max_failures(agents, 3)
101    }
102
103    /// 使用自定义最大失败次数创建线程安全 RandAgent
104    pub fn with_max_failures(agents: Vec<(BoxAgent<'static>, String, String)>, max_failures: u32) -> Self {
105        let agent_states = agents
106            .into_iter()
107            .map(|(agent, provider, model)| ThreadSafeAgentState::new(agent, provider, model, max_failures))
108            .collect();
109        Self {
110            agents: Arc::new(Mutex::new(agent_states)),
111        }
112    }
113
114    /// 添加代理到集合中
115    pub fn add_agent(&self, agent: BoxAgent<'static>, provider: String, model: String) {
116        let mut agents = self.agents.lock().unwrap();
117        agents.push(ThreadSafeAgentState::new(agent, provider, model, 3));
118    }
119
120    /// 使用自定义最大失败次数添加代理
121    pub fn add_agent_with_max_failures(&self, agent: BoxAgent<'static>, provider: String, model: String, max_failures: u32) {
122        let mut agents = self.agents.lock().unwrap();
123        agents.push(ThreadSafeAgentState::new(agent, provider, model, max_failures));
124    }
125
126    /// 获取有效代理数量
127    pub fn len(&self) -> usize {
128        let agents = self.agents.lock().unwrap();
129        agents.iter().filter(|state| state.is_valid()).count()
130    }
131
132    /// 获取总代理数量(包括无效的)
133    pub fn total_len(&self) -> usize {
134        let agents = self.agents.lock().unwrap();
135        agents.len()
136    }
137
138    /// 检查是否有有效代理
139    pub fn is_empty(&self) -> bool {
140        self.len() == 0
141    }
142
143
144
145    /// 向随机有效代理发送消息
146    pub async fn prompt(
147        &self,
148        message: &str,
149    ) -> Result<String, RandAgentError> {
150        // 第一步:选择代理并获取其信息
151        let (agent_index, provider, model) = {
152            let agents = self.agents.lock().unwrap();
153
154            // 找到所有有效代理的索引
155            let valid_indices: Vec<usize> = agents
156                .iter()
157                .enumerate()
158                .filter(|(_, state)| state.is_valid())
159                .map(|(i, _)| i)
160                .collect();
161
162            if valid_indices.is_empty() {
163                return Err(RandAgentError::NoValidAgents);
164            }
165
166            // 随机选择一个有效代理
167            let mut rng = rand::rng();
168            let random_index = rng.random_range(0..valid_indices.len());
169            let agent_index = valid_indices[random_index];
170
171            // 获取代理信息
172            let agent_state = &agents[agent_index];
173            let provider = agent_state.provider.clone();
174            let model = agent_state.model.clone();
175
176            (agent_index, provider, model)
177        };
178
179        // 打印使用的 provider 和 model
180        tracing::info!("Using provider: {}, model: {}", provider, model);
181
182        // 第二步:执行异步操作(在锁外执行)
183        let result = {
184            // 获取代理的 Arc 克隆以避免在异步操作中持有锁
185            let agent = {
186                let agents = self.agents.lock().unwrap();
187                Arc::clone(&agents[agent_index].agent)
188            };
189
190            // 在锁外执行异步操作
191            agent.prompt(message).await.map_err(|e| RandAgentError::AgentError(Box::new(e)))
192        };
193
194        // 第三步:根据结果更新失败计数
195        match &result {
196            Ok(_) => {
197                let mut agents = self.agents.lock().unwrap();
198                agents[agent_index].record_success();
199            }
200            Err(_) => {
201                let mut agents = self.agents.lock().unwrap();
202                agents[agent_index].record_failure();
203            }
204        }
205
206        result
207    }
208
209    /// 获取所有代理(用于调试或检查)
210    pub fn agents(&self) -> Vec<(String, String, u32, u32)> {
211        let agents = self.agents.lock().unwrap();
212        agents
213            .iter()
214            .map(|state| (
215                state.provider.clone(),
216                state.model.clone(),
217                state.failure_count,
218                state.max_failures
219            ))
220            .collect()
221    }
222
223    /// 获取失败统计
224    pub fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
225        let agents = self.agents.lock().unwrap();
226        agents
227            .iter()
228            .enumerate()
229            .map(|(i, state)| (i, state.failure_count, state.max_failures))
230            .collect()
231    }
232
233    /// 重置所有代理的失败计数
234    pub fn reset_failures(&self) {
235        let mut agents = self.agents.lock().unwrap();
236        for state in agents.iter_mut() {
237            state.failure_count = 0;
238        }
239    }
240}
241
242// 实现 Send + Sync trait
243unsafe impl Send for ThreadSafeRandAgent {}
244unsafe impl Sync for ThreadSafeRandAgent {}
245
246
247/// 线程安全 RandAgent 的构建器
248pub struct ThreadSafeRandAgentBuilder {
249    agents: Vec<(BoxAgent<'static>, String, String)>,
250    max_failures: u32,
251}
252
253impl ThreadSafeRandAgentBuilder {
254    /// 创建新的 ThreadSafeRandAgentBuilder
255    pub fn new() -> Self {
256        Self {
257            agents: Vec::new(),
258            max_failures: 3, // 默认最大失败次数
259        }
260    }
261
262    /// 设置连续失败的最大次数,超过后标记代理为无效
263    pub fn max_failures(mut self, max_failures: u32) -> Self {
264        self.max_failures = max_failures;
265        self
266    }
267
268    /// 添加代理到构建器
269    ///
270    /// # 参数
271    /// - agent: 代理实例(需要是 'static 生命周期)
272    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
273    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
274    pub fn add_agent(mut self, agent: BoxAgent<'static>, provider_name: String, model_name: String) -> Self {
275        self.agents.push((agent, provider_name, model_name));
276        self
277    }
278
279    /// 从 AgentBuilder 添加代理
280    ///
281    /// # 参数
282    /// - builder: AgentBuilder 实例(需要是 'static 生命周期)
283    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
284    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
285    pub fn add_builder(mut self, builder: Agent<CompletionModelHandle<'static>>, provider_name: &str, model_name: &str) -> Self {
286        self.agents.push((builder, provider_name.to_string(), model_name.to_string()));
287        self
288    }
289
290    /// 构建 ThreadSafeRandAgent
291    pub fn build(self) -> ThreadSafeRandAgent {
292        ThreadSafeRandAgent::with_max_failures(self.agents, self.max_failures)
293    }
294}
295
296impl Default for ThreadSafeRandAgentBuilder {
297    fn default() -> Self {
298        Self::new()
299    }
300}