rig_extra/
thread_safe_rand_agent.rs

1//! ## 多线程使用示例
2//!
3//! ```rust
4//! use rig_extra::extra_providers::{bigmodel::Client};
5//! use rig_extra::thread_safe_rand_agent::ThreadSafeRandAgentBuilder;
6//! use std::sync::Arc;
7//! use tokio::task;
8//! use rig::client::ProviderClient;
9//! use rig_extra::error::RandAgentError;
10//! #[tokio::main]
11//! async fn main() -> Result<(), RandAgentError> {
12//!     // 创建线程安全的 RandAgent
13//!
14//!     //创建多个客户端 
15//!     let client1 = Client::from_env();
16//!     let client2 = Client::from_env();
17//!     use rig::client::completion::CompletionClientDyn;
18//!     use rig::completion::Prompt;
19//!
20//!
21//!     let thread_safe_agent = ThreadSafeRandAgentBuilder::new()
22//!         .max_failures(3)
23//!         .add_agent(client1.agent("glm-4-flash").build(),1, "bigmodel".to_string(), "glm-4-flash".to_string())
24//!         .add_agent(client2.agent("glm-4-flash").build(),2, "bigmodel".to_string(), "glm-4-flash".to_string())
25//!         .build();
26//!
27//!     let agent_arc = Arc::new(thread_safe_agent);
28//!
29//!     // 创建多个并发任务
30//!     let mut handles = vec![];
31//!     for i in 0..5 {
32//!         let agent_clone = Arc::clone(&agent_arc);
33//!         let handle = task::spawn(async move {
34//!             let response = agent_clone.prompt(&format!("Hello from task {}", i)).await?;
35//!             println!("Task {} response: {}", i, response);
36//!             Ok::<(), RandAgentError>(())
37//!         });
38//!         handles.push(handle);
39//!     }
40//!
41//!     // 等待所有任务完成
42//!     for handle in handles {
43//!         handle.await??;
44//!     }
45//!
46//!     Ok(())
47//! }
48//! ```
49
50use std::sync::Arc;
51use rand::Rng;
52use rig::agent::Agent;
53use rig::client::builder::BoxAgent;
54use rig::client::completion::CompletionModelHandle;
55use rig::completion::{Message, Prompt, PromptError};
56use tokio::sync::Mutex;
57
/// Optional callback that receives the `id` of an agent once it is no longer valid.
///
/// The alias exists only to keep the nested `Option<Arc<Box<dyn Fn ...>>>` type out of
/// struct fields and function signatures.
pub type OnAgentInvalidCallback = Option<Arc<Box<dyn Fn(i32) + Send + Sync + 'static>>>;
60
/// A thread-safe `RandAgent` that supports concurrent access from multiple threads.
///
/// Prefer `ThreadSafeRandAgent` over `RandAgent`: `RandAgent` is no longer maintained,
/// while `ThreadSafeRandAgent` supports multi-threaded concurrent access and is safer.
pub struct ThreadSafeRandAgent {
    // Per-agent state (agent handle plus failure counters), shared behind an async mutex.
    agents: Arc<Mutex<Vec<ThreadSafeAgentState>>>,
    // Optional hook called with an agent's id after a failure leaves that agent invalid.
    on_agent_invalid: OnAgentInvalidCallback,
}
68
/// State tracked for a single agent inside a thread-safe agent pool.
#[derive(Clone)]
pub struct ThreadSafeAgentState {
    /// Caller-supplied identifier; this is the value handed to the invalidation callback.
    pub id: i32,
    /// The wrapped agent, reference-counted so cloning the state does not clone the agent.
    pub agent: Arc<BoxAgent<'static>>,
    /// Provider name, used for logging and reporting.
    pub provider: String,
    /// Model name, used for logging and reporting.
    pub model: String,
    /// Failures recorded since the last success; a success resets it to 0.
    pub failure_count: u32,
    /// Threshold: the agent counts as valid only while `failure_count < max_failures`.
    pub max_failures: u32,
}
79
80impl Prompt for ThreadSafeRandAgent {
81    #[allow(refining_impl_trait)]
82    async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
83        // 第一步:选择代理并获取其索引
84        let agent_index = self.get_random_valid_agent_index().await
85            .ok_or(PromptError::MaxDepthError {
86                max_depth: 0,
87                chat_history: vec![],
88                prompt: "没有有效agent".into(),
89            })?;
90
91        // 第二步:加锁并获取可变引用
92        let mut agents = self.agents.lock().await;
93        let agent_state = &mut agents[agent_index];
94
95        tracing::info!("Using provider: {}, model: {}", agent_state.provider, agent_state.model);
96        match agent_state.agent.prompt(prompt).await {
97            Ok(content) => {
98                agent_state.record_success();
99                Ok(content)
100            }
101            Err(e) => {
102                agent_state.record_failure();
103                if !agent_state.is_valid() {
104                    if let Some(cb) = &self.on_agent_invalid {
105                        cb(agent_state.id);
106                    }
107                }
108                Err(e)
109            }
110        }
111    }
112}
113
114impl ThreadSafeAgentState {
115    fn new(agent: BoxAgent<'static>,id: i32, provider: String, model: String, max_failures: u32) -> Self {
116        Self {
117            id,
118            agent: Arc::new(agent),
119            provider,
120            model,
121            failure_count: 0,
122            max_failures,
123        }
124    }
125
126    fn is_valid(&self) -> bool {
127        self.failure_count < self.max_failures
128    }
129
130    fn record_failure(&mut self) {
131        self.failure_count += 1;
132    }
133
134    fn record_success(&mut self) {
135        self.failure_count = 0;
136    }
137}
138
139impl ThreadSafeRandAgent {
140    /// 创建新的线程安全 RandAgent
141    pub fn new(agents: Vec<(BoxAgent<'static>, i32, String, String)>) -> Self {
142        Self::with_max_failures_and_callback(agents, 3, None)
143    }
144
145    /// 使用自定义最大失败次数和回调创建线程安全 RandAgent
146    pub fn with_max_failures_and_callback(
147        agents: Vec<(BoxAgent<'static>, i32, String, String)>,
148        max_failures: u32,
149        on_agent_invalid: OnAgentInvalidCallback,
150    ) -> Self {
151        let agent_states = agents
152            .into_iter()
153            .map(|(agent, id, provider, model)| ThreadSafeAgentState::new(agent, id, provider, model, max_failures))
154            .collect();
155        Self {
156            agents: Arc::new(Mutex::new(agent_states)),
157            on_agent_invalid,
158        }
159    }
160
161    /// 使用自定义最大失败次数创建线程安全 RandAgent
162    pub fn with_max_failures(agents: Vec<(BoxAgent<'static>, i32, String, String)>, max_failures: u32) -> Self {
163        Self::with_max_failures_and_callback(agents, max_failures, None)
164    }
165
166    /// 设置 agent 失效时的回调
167    pub fn set_on_agent_invalid<F>(&mut self, callback: F)
168    where
169        F: Fn(i32) + Send + Sync + 'static,
170    {
171        self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
172    }
173
174    /// 添加代理到集合中
175    pub async fn add_agent(&self, agent: BoxAgent<'static>, id: i32, provider: String, model: String) {
176        let mut agents = self.agents.lock().await;
177        agents.push(ThreadSafeAgentState::new(agent, id, provider, model, 3));
178    }
179
180    /// 使用自定义最大失败次数添加代理
181    pub async fn add_agent_with_max_failures(&self, agent: BoxAgent<'static>, id: i32, provider: String, model: String, max_failures: u32) {
182        let mut agents = self.agents.lock().await;
183        agents.push(ThreadSafeAgentState::new(agent, id, provider, model, max_failures));
184    }
185
186    /// 获取有效代理数量
187    pub async fn len(&self) -> usize {
188        let agents = self.agents.lock().await;
189        agents.iter().filter(|state| state.is_valid()).count()
190    }
191    
192    /// 从集合中获取一个随机有效代理的索引
193    pub async fn get_random_valid_agent_index(&self) -> Option<usize> {
194        let agents = self.agents.lock().await;
195        let valid_indices: Vec<usize> = agents
196            .iter()
197            .enumerate()
198            .filter(|(_, state)| state.is_valid())
199            .map(|(i, _)| i)
200            .collect();
201
202        if valid_indices.is_empty() {
203            return None;
204        }
205
206        let mut rng = rand::rng();
207        let random_index = rng.random_range(0..valid_indices.len());
208        Some(valid_indices[random_index])
209    }
210
211    /// 从集合中获取一个随机有效代理
212    /// 注意: 并不会增加失败计数
213    pub async fn get_random_valid_agent_state(&self) -> Option<ThreadSafeAgentState> {
214        let mut agents = self.agents.lock().await;
215
216        let valid_indices: Vec<usize> = agents
217            .iter()
218            .enumerate()
219            .filter(|(_, state)| state.is_valid())
220            .map(|(i, _)| i)
221            .collect();
222
223        if valid_indices.is_empty() {
224            return None;
225        }
226
227        let mut rng = rand::rng();
228        let random_index = rng.random_range(0..valid_indices.len());
229        let agent_index = valid_indices[random_index];
230        agents.get_mut(agent_index).cloned()
231    }
232    
233
234    /// 获取总代理数量(包括无效的)
235    pub async fn total_len(&self) -> usize {
236        let agents = self.agents.lock().await;
237        agents.len()
238    }
239
240    /// 检查是否有有效代理
241    pub async fn is_empty(&self) -> bool {
242        self.len().await == 0
243    }
244    
245    /// 获取所有代理(用于调试或检查)
246    pub async fn agents(&self) -> Vec<(String, String, u32, u32)> {
247        let agents = self.agents.lock().await;
248        agents
249            .iter()
250            .map(|state| (
251                state.provider.clone(),
252                state.model.clone(),
253                state.failure_count,
254                state.max_failures
255            ))
256            .collect()
257    }
258
259    /// 获取失败统计
260    pub async fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
261        let agents = self.agents.lock().await;
262        agents
263            .iter()
264            .enumerate()
265            .map(|(i, state)| (i, state.failure_count, state.max_failures))
266            .collect()
267    }
268
269    /// 重置所有代理的失败计数
270    pub async fn reset_failures(&self) {
271        let mut agents = self.agents.lock().await;
272        for state in agents.iter_mut() {
273            state.failure_count = 0;
274        }
275    }
276}
277
/// Builder for a thread-safe `RandAgent`.
pub struct ThreadSafeRandAgentBuilder {
    // Agents collected so far as `(agent, id, provider, model)` tuples.
    agents: Vec<(BoxAgent<'static>, i32, String, String)>,
    // Failure threshold handed to the constructor when `build` is called.
    max_failures: u32,
    // Optional invalidation callback handed to the constructor when `build` is called.
    on_agent_invalid: OnAgentInvalidCallback,
}
284
285impl ThreadSafeRandAgentBuilder {
286    /// 创建新的 ThreadSafeRandAgentBuilder
287    pub fn new() -> Self {
288        Self {
289            agents: Vec::new(),
290            max_failures: 3, // 默认最大失败次数
291            on_agent_invalid: None,
292        }
293    }
294
295    /// 设置连续失败的最大次数,超过后标记代理为无效
296    pub fn max_failures(mut self, max_failures: u32) -> Self {
297        self.max_failures = max_failures;
298        self
299    }
300
301    /// 设置 agent 失效时的回调
302    pub fn on_agent_invalid<F>(mut self, callback: F) -> Self
303    where
304        F: Fn(i32) + Send + Sync + 'static,
305    {
306        self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
307        self
308    }
309
310    /// 添加代理到构建器
311    ///
312    /// # 参数
313    /// - agent: 代理实例(需要是 'static 生命周期)
314    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
315    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
316    pub fn add_agent(mut self, agent: BoxAgent<'static>, id: i32, provider_name: String, model_name: String) -> Self {
317        self.agents.push((agent, id, provider_name, model_name));
318        self
319    }
320
321    /// 从 AgentBuilder 添加代理
322    ///
323    /// # 参数
324    /// - builder: AgentBuilder 实例(需要是 'static 生命周期)
325    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
326    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
327    pub fn add_builder(mut self, builder: Agent<CompletionModelHandle<'static>>, id: i32, provider_name: &str, model_name: &str) -> Self {
328        self.agents.push((builder, id, provider_name.to_string(), model_name.to_string()));
329        self
330    }
331
332    /// 构建 ThreadSafeRandAgent
333    pub fn build(self) -> ThreadSafeRandAgent {
334        ThreadSafeRandAgent::with_max_failures_and_callback(self.agents, self.max_failures, self.on_agent_invalid)
335    }
336}
337
338impl Default for ThreadSafeRandAgentBuilder {
339    fn default() -> Self {
340        Self::new()
341    }
342}