rig_extra/
thread_safe_rand_agent.rs

//! ## Multi-threaded usage example
//!
//! Each prompt is sent to a randomly chosen agent that is still valid; an agent that fails
//! `max_failures` times in a row is skipped from then on.
//!
//! ```rust
//! use rig_extra::extra_providers::{bigmodel::Client};
//! use rig_extra::thread_safe_rand_agent::ThreadSafeRandAgentBuilder;
//! use std::sync::Arc;
//! use tokio::task;
//! use rig::client::ProviderClient;
//! use rig_extra::error::RandAgentError;
//! #[tokio::main]
//! async fn main() -> Result<(), RandAgentError> {
//!     // Create several clients
//!     let client1 = Client::from_env();
//!     let client2 = Client::from_env();
//!     use rig::client::completion::CompletionClientDyn;
//!     use rig::completion::Prompt;
//!
//!     // Build a thread-safe RandAgent from those clients
//!     let thread_safe_agent = ThreadSafeRandAgentBuilder::new()
//!         .max_failures(3)
//!         .add_agent(client1.agent("glm-4-flash").build(),1, "bigmodel".to_string(), "glm-4-flash".to_string())
//!         .add_agent(client2.agent("glm-4-flash").build(),2, "bigmodel".to_string(), "glm-4-flash".to_string())
//!         .build();
//!
//!     let agent_arc = Arc::new(thread_safe_agent);
//!
//!     // Spawn several concurrent tasks
//!     let mut handles = vec![];
//!     for i in 0..5 {
//!         let agent_clone = Arc::clone(&agent_arc);
//!         let handle = task::spawn(async move {
//!             let response = agent_clone.prompt(&format!("Hello from task {}", i)).await?;
//!             println!("Task {} response: {}", i, response);
//!             Ok::<(), RandAgentError>(())
//!         });
//!         handles.push(handle);
//!     }
//!
//!     // Wait for all tasks to finish
//!     for handle in handles {
//!         handle.await??;
//!     }
//!
//!     Ok(())
//! }
//! ```
49
50use std::sync::Arc;
51use rand::Rng;
52use rig::agent::Agent;
53use rig::client::builder::BoxAgent;
54use rig::client::completion::CompletionModelHandle;
55use rig::completion::{Message, Prompt, PromptError};
56use tokio::sync::Mutex;
57use crate::AgentInfo;
58
/// Optional callback invoked with an agent's id when that agent is marked invalid.
///
/// The alias only exists to keep signatures that carry the callback short.
pub type OnAgentInvalidCallback = Option<Arc<Box<dyn Fn(i32) + Sync + Send + 'static>>>;
61
62/// 推荐使用 ThreadSafeRandAgent,不推荐使用 RandAgent。
63/// RandAgent 已不再维护,ThreadSafeRandAgent 支持多线程并发访问且更安全。
64/// 线程安全的 RandAgent,支持多线程并发访问
65pub struct ThreadSafeRandAgent {
66    agents: Arc<Mutex<Vec<ThreadSafeAgentState>>>,
67    on_agent_invalid: OnAgentInvalidCallback,
68}
69
70/// 线程安全的 Agent 状态
71#[derive(Clone)]
72pub struct ThreadSafeAgentState {
73    pub id: i32,
74    pub agent: Arc<BoxAgent<'static>>,
75    pub info: AgentInfo,
76}
77
78impl Prompt for ThreadSafeRandAgent {
79    #[allow(refining_impl_trait)]
80    async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
81        // 第一步:选择代理并获取其索引
82        let agent_index = self.get_random_valid_agent_index().await
83            .ok_or(PromptError::MaxDepthError {
84                max_depth: 0,
85                chat_history: vec![],
86                prompt: "没有有效agent".into(),
87            })?;
88
89        // 第二步:加锁并获取可变引用
90        let mut agents = self.agents.lock().await;
91        let agent_state = &mut agents[agent_index];
92
93        tracing::info!("Using provider: {}, model: {}", agent_state.info.provider, agent_state.info.model);
94        match agent_state.agent.prompt(prompt).await {
95            Ok(content) => {
96                agent_state.record_success();
97                Ok(content)
98            }
99            Err(e) => {
100                agent_state.record_failure();
101                if !agent_state.is_valid() {
102                    if let Some(cb) = &self.on_agent_invalid {
103                        cb(agent_state.id);
104                    }
105                }
106                Err(e)
107            }
108        }
109    }
110}
111
112impl ThreadSafeAgentState {
113    fn new(agent: BoxAgent<'static>,id: i32, provider: String, model: String, max_failures: u32) -> Self {
114        Self {
115            id,
116            agent: Arc::new(agent),
117            info: AgentInfo{
118                id,
119                provider,
120                model,
121                failure_count: 0,
122                max_failures,
123            }
124        }
125    }
126
127    fn is_valid(&self) -> bool {
128        self.info.failure_count < self.info.max_failures
129    }
130
131    fn record_failure(&mut self) {
132        self.info.failure_count += 1;
133    }
134
135    fn record_success(&mut self) {
136        self.info.failure_count = 0;
137    }
138}
139
140impl ThreadSafeRandAgent {
141    /// 创建新的线程安全 RandAgent
142    pub fn new(agents: Vec<(BoxAgent<'static>, i32, String, String)>) -> Self {
143        Self::with_max_failures_and_callback(agents, 3, None)
144    }
145
146    /// 使用自定义最大失败次数和回调创建线程安全 RandAgent
147    pub fn with_max_failures_and_callback(
148        agents: Vec<(BoxAgent<'static>, i32, String, String)>,
149        max_failures: u32,
150        on_agent_invalid: OnAgentInvalidCallback,
151    ) -> Self {
152        let agent_states = agents
153            .into_iter()
154            .map(|(agent, id, provider, model)| ThreadSafeAgentState::new(agent, id, provider, model, max_failures))
155            .collect();
156        Self {
157            agents: Arc::new(Mutex::new(agent_states)),
158            on_agent_invalid,
159        }
160    }
161
162    /// 使用自定义最大失败次数创建线程安全 RandAgent
163    pub fn with_max_failures(agents: Vec<(BoxAgent<'static>, i32, String, String)>, max_failures: u32) -> Self {
164        Self::with_max_failures_and_callback(agents, max_failures, None)
165    }
166
167    /// 设置 agent 失效时的回调
168    pub fn set_on_agent_invalid<F>(&mut self, callback: F)
169    where
170        F: Fn(i32) + Send + Sync + 'static,
171    {
172        self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
173    }
174
175    /// 添加代理到集合中
176    pub async fn add_agent(&self, agent: BoxAgent<'static>, id: i32, provider: String, model: String) {
177        let mut agents = self.agents.lock().await;
178        agents.push(ThreadSafeAgentState::new(agent, id, provider, model, 3));
179    }
180
181    /// 使用自定义最大失败次数添加代理
182    pub async fn add_agent_with_max_failures(&self, agent: BoxAgent<'static>, id: i32, provider: String, model: String, max_failures: u32) {
183        let mut agents = self.agents.lock().await;
184        agents.push(ThreadSafeAgentState::new(agent, id, provider, model, max_failures));
185    }
186
187    /// 获取有效代理数量
188    pub async fn len(&self) -> usize {
189        let agents = self.agents.lock().await;
190        agents.iter().filter(|state| state.is_valid()).count()
191    }
192    
193    /// 从集合中获取一个随机有效代理的索引
194    pub async fn get_random_valid_agent_index(&self) -> Option<usize> {
195        let agents = self.agents.lock().await;
196        let valid_indices: Vec<usize> = agents
197            .iter()
198            .enumerate()
199            .filter(|(_, state)| state.is_valid())
200            .map(|(i, _)| i)
201            .collect();
202
203        if valid_indices.is_empty() {
204            return None;
205        }
206
207        let mut rng = rand::rng();
208        let random_index = rng.random_range(0..valid_indices.len());
209        Some(valid_indices[random_index])
210    }
211
212    /// 从集合中获取一个随机有效代理
213    /// 注意: 并不会增加失败计数
214    pub async fn get_random_valid_agent_state(&self) -> Option<ThreadSafeAgentState> {
215        let mut agents = self.agents.lock().await;
216
217        let valid_indices: Vec<usize> = agents
218            .iter()
219            .enumerate()
220            .filter(|(_, state)| state.is_valid())
221            .map(|(i, _)| i)
222            .collect();
223
224        if valid_indices.is_empty() {
225            return None;
226        }
227
228        let mut rng = rand::rng();
229        let random_index = rng.random_range(0..valid_indices.len());
230        let agent_index = valid_indices[random_index];
231        agents.get_mut(agent_index).cloned()
232    }
233    
234
235    /// 获取总代理数量(包括无效的)
236    pub async fn total_len(&self) -> usize {
237        let agents = self.agents.lock().await;
238        agents.len()
239    }
240
241    /// 检查是否有有效代理
242    pub async fn is_empty(&self) -> bool {
243        self.len().await == 0
244    }
245    
246    /// 获取所有代理(用于调试或检查)
247    #[deprecated(since = "0.6.1", note = "Renamed to `get_agent_info`")]
248    pub async fn agents(&self) -> Vec<(String, String, u32, u32)> {
249        let agents = self.agents.lock().await;
250        agents
251            .iter()
252            .map(|state| (
253                state.info.provider.clone(),
254                state.info.model.clone(),
255                state.info.failure_count,
256                state.info.max_failures
257            ))
258            .collect()
259    }
260
261    /// 获取agent info
262    pub async fn get_agents_info(&self) -> Vec<AgentInfo> {
263        let  agents = self.agents.lock().await;
264        let agent_infos = agents.iter()
265            .map(|agent|{
266                agent.info.clone()
267            }).collect::<_>();
268        tracing::info!("agents info: {:?}", agent_infos);
269        agent_infos
270    }
271
272    /// 获取失败统计
273    pub async fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
274        let agents = self.agents.lock().await;
275        agents
276            .iter()
277            .enumerate()
278            .map(|(i, state)| (i, state.info.failure_count, state.info.max_failures))
279            .collect()
280    }
281
282    /// 重置所有代理的失败计数
283    pub async fn reset_failures(&self) {
284        let mut agents = self.agents.lock().await;
285        for state in agents.iter_mut() {
286            state.info.failure_count = 0;
287        }
288    }
289
290    /// 通过名称获取 agent 
291    pub async fn get_agent_by_name(&self,provider_name: &str, model_name: &str) -> Option<ThreadSafeAgentState> {
292        let mut agents = self.agents.lock().await;
293
294        for agent in agents.iter_mut() {
295            if agent.info.provider == provider_name &&  agent.info.model == model_name {
296                return Some(agent.clone());
297            }
298        }
299
300        None
301    }   
302    
303    /// 通过id获取 agent 
304    pub async fn get_agent_by_id(&self,id:i32) -> Option<ThreadSafeAgentState> {
305        let mut agents = self.agents.lock().await;
306
307        for agent in agents.iter_mut() {
308            if agent.info.id == id {
309                return Some(agent.clone());
310            }
311        }
312
313        None
314    }
315}
316
317/// 线程安全 RandAgent 的构建器
318pub struct ThreadSafeRandAgentBuilder {
319    pub(crate) agents: Vec<(BoxAgent<'static>, i32, String, String)>,
320    max_failures: u32,
321    on_agent_invalid: OnAgentInvalidCallback,
322}
323
324impl ThreadSafeRandAgentBuilder {
325    /// 创建新的 ThreadSafeRandAgentBuilder
326    pub fn new() -> Self {
327        Self {
328            agents: Vec::new(),
329            max_failures: 3, // 默认最大失败次数
330            on_agent_invalid: None,
331        }
332    }
333
334    /// 设置连续失败的最大次数,超过后标记代理为无效
335    pub fn max_failures(mut self, max_failures: u32) -> Self {
336        self.max_failures = max_failures;
337        self
338    }
339
340    /// 设置 agent 失效时的回调
341    pub fn on_agent_invalid<F>(mut self, callback: F) -> Self
342    where
343        F: Fn(i32) + Send + Sync + 'static,
344    {
345        self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
346        self
347    }
348
349    /// 添加代理到构建器
350    ///
351    /// # 参数
352    /// - agent: 代理实例(需要是 'static 生命周期)
353    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
354    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
355    pub fn add_agent(mut self, agent: BoxAgent<'static>, id: i32, provider_name: String, model_name: String) -> Self {
356        self.agents.push((agent, id, provider_name, model_name));
357        self
358    }
359
360    /// 从 AgentBuilder 添加代理
361    ///
362    /// # 参数
363    /// - builder: AgentBuilder 实例(需要是 'static 生命周期)
364    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
365    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
366    pub fn add_builder(mut self, builder: Agent<CompletionModelHandle<'static>>, id: i32, provider_name: &str, model_name: &str) -> Self {
367        self.agents.push((builder, id, provider_name.to_string(), model_name.to_string()));
368        self
369    }
370
371    /// 构建 ThreadSafeRandAgent
372    pub fn build(self) -> ThreadSafeRandAgent {
373        ThreadSafeRandAgent::with_max_failures_and_callback(self.agents, self.max_failures, self.on_agent_invalid)
374    }
375    
376}
377
378impl Default for ThreadSafeRandAgentBuilder {
379    fn default() -> Self {
380        Self::new()
381    }
382}