rig_extra/
rand_agent.rs

1//! ## 多线程使用示例
2//!
3//! ```rust
4//! use rig_extra::extra_providers::{bigmodel::Client};
5//! use rig_extra::rand_agent::RandAgentBuilder;
6//! use std::sync::Arc;
7//! use tokio::task;
8//! use rig::client::ProviderClient;
9//! use rig_extra::error::RandAgentError;
10//! #[tokio::main]
11//! async fn main() -> Result<(), RandAgentError> {
12//!     // 创建线程安全的 RandAgent
13//!
14//!     //创建多个客户端 
15//!     let client1 = Client::from_env();
16//!     let client2 = Client::from_env();
17//!     use rig::client::completion::CompletionClientDyn;
18//!     use rig::completion::Prompt;
19//!
20//!
21//!     let thread_safe_agent = RandAgentBuilder::new()
22//!         .max_failures(3)
23//!         .add_agent(client1.agent("glm-4-flash").build(),1, "bigmodel".to_string(), "glm-4-flash".to_string())
24//!         .add_agent(client2.agent("glm-4-flash").build(),2, "bigmodel".to_string(), "glm-4-flash".to_string())
25//!         .build();
26//!
27//!     let agent_arc = Arc::new(thread_safe_agent);
28//!
29//!     // 创建多个并发任务
30//!     let mut handles = vec![];
31//!     for i in 0..5 {
32//!         let agent_clone = Arc::clone(&agent_arc);
33//!         let handle = task::spawn(async move {
34//!             let response = agent_clone.prompt(&format!("Hello from task {}", i)).await?;
35//!             println!("Task {} response: {}", i, response);
36//!             Ok::<(), RandAgentError>(())
37//!         });
38//!         handles.push(handle);
39//!     }
40//!
41//!     // 等待所有任务完成
42//!     for handle in handles {
43//!         handle.await??;
44//!     }
45//!
46//!     Ok(())
47//! }
48//! ```
49
50use std::sync::Arc;
51use std::time::Duration;
52use backon::{Retryable, ExponentialBuilder};
53use rand::Rng;
54use rig::agent::Agent;
55use rig::client::builder::BoxAgent;
56use rig::client::completion::CompletionModelHandle;
57use rig::completion::{Message, Prompt, PromptError};
58use tokio::sync::Mutex;
59use crate::AgentInfo;
60use crate::error::RandAgentError;
61
/// Optional callback that receives the `id` of an agent which is no longer valid
/// after a failed prompt.
///
/// The alias only exists to keep the nested `Option<Arc<Box<dyn Fn ..>>>` type out
/// of struct fields and function signatures.
pub type OnAgentInvalidCallback = Option<Arc<Box<dyn Fn(i32) + Send + Sync + 'static>>>;
64
/// A pool of agents from which one valid agent is picked at random for each prompt.
///
/// The agent list is stored behind an `Arc<tokio::sync::Mutex<..>>`, so cloning a
/// `RandAgent` is cheap and every clone shares the same agents and failure counters.
#[derive(Clone)]
pub struct RandAgent {
    // Shared, mutex-guarded list of agents together with their failure bookkeeping.
    agents: Arc<Mutex<Vec<AgentState>>>,
    // Optional callback invoked with an agent's id when that agent stops being valid.
    on_agent_invalid: OnAgentInvalidCallback,
}
73
/// One agent of the pool together with its bookkeeping data.
///
/// Cloning an `AgentState` shares the underlying agent through the `Arc`, while
/// `info` is copied, so a clone carries a snapshot of the failure counters.
#[derive(Clone)]
pub struct AgentState {
    /// Caller-supplied identifier; this is the value passed to the invalidation callback.
    pub id: i32,
    /// The wrapped agent that actually answers prompts.
    pub agent: Arc<BoxAgent<'static>>,
    /// Provider/model names and the failure counters of this agent.
    pub info: AgentInfo,
}
81
82impl Prompt for RandAgent {
83    #[allow(refining_impl_trait)]
84    async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
85        // 第一步:选择代理并获取其索引
86        let agent_index = self.get_random_valid_agent_index().await
87            .ok_or(PromptError::MaxDepthError {
88                max_depth: 0,
89                chat_history: vec![],
90                prompt: "没有有效agent".into(),
91            })?;
92
93        // 第二步:加锁并获取可变引用
94        let mut agents = self.agents.lock().await;
95        let agent_state = &mut agents[agent_index];
96
97        tracing::info!("Using provider: {}, model: {},id: {}", agent_state.info.provider, agent_state.info.model,agent_state.info.id);
98        match agent_state.agent.prompt(prompt).await {
99            Ok(content) => {
100                agent_state.record_success();
101                Ok(content)
102            }
103            Err(e) => {
104                agent_state.record_failure();
105                if !agent_state.is_valid() {
106                    if let Some(cb) = &self.on_agent_invalid {
107                        cb(agent_state.id);
108                    }
109                }
110                Err(e)
111            }
112        }
113    }
114}
115
116impl AgentState {
117    fn new(agent: BoxAgent<'static>,id: i32, provider: String, model: String, max_failures: u32) -> Self {
118        Self {
119            id,
120            agent: Arc::new(agent),
121            info: AgentInfo{
122                id,
123                provider,
124                model,
125                failure_count: 0,
126                max_failures,
127            }
128        }
129    }
130
131    fn is_valid(&self) -> bool {
132        self.info.failure_count < self.info.max_failures
133    }
134
135    fn record_failure(&mut self) {
136        self.info.failure_count += 1;
137    }
138
139    fn record_success(&mut self) {
140        self.info.failure_count = 0;
141    }
142}
143
144impl RandAgent {
145    /// 创建新的线程安全 RandAgent
146    pub fn new(agents: Vec<(BoxAgent<'static>, i32, String, String)>) -> Self {
147        Self::with_max_failures_and_callback(agents, 3, None)
148    }
149
150    /// 使用自定义最大失败次数和回调创建线程安全 RandAgent
151    pub fn with_max_failures_and_callback(
152        agents: Vec<(BoxAgent<'static>, i32, String, String)>,
153        max_failures: u32,
154        on_agent_invalid: OnAgentInvalidCallback,
155    ) -> Self {
156        let agent_states = agents
157            .into_iter()
158            .map(|(agent, id, provider, model)| AgentState::new(agent, id, provider, model, max_failures))
159            .collect();
160        Self {
161            agents: Arc::new(Mutex::new(agent_states)),
162            on_agent_invalid,
163        }
164    }
165
166    /// 使用自定义最大失败次数创建线程安全 RandAgent
167    pub fn with_max_failures(agents: Vec<(BoxAgent<'static>, i32, String, String)>, max_failures: u32) -> Self {
168        Self::with_max_failures_and_callback(agents, max_failures, None)
169    }
170
171    /// 设置 agent 失效时的回调
172    pub fn set_on_agent_invalid<F>(&mut self, callback: F)
173    where
174        F: Fn(i32) + Send + Sync + 'static,
175    {
176        self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
177    }
178
179    /// 添加代理到集合中
180    pub async fn add_agent(&self, agent: BoxAgent<'static>, id: i32, provider: String, model: String) {
181        let mut agents = self.agents.lock().await;
182        agents.push(AgentState::new(agent, id, provider, model, 3));
183    }
184
185    /// 使用自定义最大失败次数添加代理
186    pub async fn add_agent_with_max_failures(&self, agent: BoxAgent<'static>, id: i32, provider: String, model: String, max_failures: u32) {
187        let mut agents = self.agents.lock().await;
188        agents.push(AgentState::new(agent, id, provider, model, max_failures));
189    }
190
191    /// 获取有效代理数量
192    pub async fn len(&self) -> usize {
193        let agents = self.agents.lock().await;
194        agents.iter().filter(|state| state.is_valid()).count()
195    }
196    
197    /// 从集合中获取一个随机有效代理的索引
198    pub async fn get_random_valid_agent_index(&self) -> Option<usize> {
199        let agents = self.agents.lock().await;
200        let valid_indices: Vec<usize> = agents
201            .iter()
202            .enumerate()
203            .filter(|(_, state)| state.is_valid())
204            .map(|(i, _)| i)
205            .collect();
206
207        if valid_indices.is_empty() {
208            return None;
209        }
210
211        let mut rng = rand::rng();
212        let random_index = rng.random_range(0..valid_indices.len());
213        Some(valid_indices[random_index])
214    }
215
216    /// 从集合中获取一个随机有效代理
217    /// 注意: 并不会增加失败计数
218    pub async fn get_random_valid_agent_state(&self) -> Option<AgentState> {
219        let mut agents = self.agents.lock().await;
220
221        let valid_indices: Vec<usize> = agents
222            .iter()
223            .enumerate()
224            .filter(|(_, state)| state.is_valid())
225            .map(|(i, _)| i)
226            .collect();
227
228        if valid_indices.is_empty() {
229            return None;
230        }
231
232        let mut rng = rand::rng();
233        let random_index = rng.random_range(0..valid_indices.len());
234        let agent_index = valid_indices[random_index];
235        agents.get_mut(agent_index).cloned()
236    }
237    
238
239    /// 获取总代理数量(包括无效的)
240    pub async fn total_len(&self) -> usize {
241        let agents = self.agents.lock().await;
242        agents.len()
243    }
244
245    /// 检查是否有有效代理
246    pub async fn is_empty(&self) -> bool {
247        self.len().await == 0
248    }
249    
250    /// 获取agent info
251    pub async fn get_agents_info(&self) -> Vec<AgentInfo> {
252        let  agents = self.agents.lock().await;
253        let agent_infos = agents.iter()
254            .map(|agent|{
255                agent.info.clone()
256            }).collect::<_>();
257        tracing::info!("agents info: {:?}", agent_infos);
258        agent_infos
259    }
260
261    /// 获取失败统计
262    pub async fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
263        let agents = self.agents.lock().await;
264        agents
265            .iter()
266            .enumerate()
267            .map(|(i, state)| (i, state.info.failure_count, state.info.max_failures))
268            .collect()
269    }
270
271    /// 重置所有代理的失败计数
272    pub async fn reset_failures(&self) {
273        let mut agents = self.agents.lock().await;
274        for state in agents.iter_mut() {
275            state.info.failure_count = 0;
276        }
277    }
278
279    /// 通过名称获取 agent 
280    pub async fn get_agent_by_name(&self,provider_name: &str, model_name: &str) -> Option<AgentState> {
281        let mut agents = self.agents.lock().await;
282
283        for agent in agents.iter_mut() {
284            if agent.info.provider == provider_name &&  agent.info.model == model_name {
285                return Some(agent.clone());
286            }
287        }
288
289        None
290    }   
291    
292    /// 通过id获取 agent 
293    pub async fn get_agent_by_id(&self,id:i32) -> Option<AgentState> {
294        let mut agents = self.agents.lock().await;
295
296        for agent in agents.iter_mut() {
297            if agent.info.id == id {
298                return Some(agent.clone());
299            }
300        }
301
302        None
303    }
304
305    /// 添加失败重试
306    pub async fn try_invoke_with_retry(&self, info: Message, retry_num: Option<usize>) -> Result<String, RandAgentError> {
307        let mut config = ExponentialBuilder::default();
308        if let Some(retry_num) = retry_num {
309            config = config.with_max_times(retry_num)
310        }
311
312        let info = Arc::new(info);
313
314        let content = (|| {
315            let agent = self.clone();
316            let prompt = info.clone();
317            async move {
318                agent.prompt((*prompt).clone()).await
319            }
320        })
321        .retry(config)
322        .sleep(tokio::time::sleep)
323        .notify(|err: &PromptError, dur: Duration| {
324            println!("retrying {err:?} after {dur:?}");
325        })
326        .await?;
327        Ok(content)
328    }
329
330    #[allow(refining_impl_trait)]
331    pub async fn prompt_with_info(&self, prompt: impl Into<Message> + Send) -> Result<(String,AgentInfo), PromptError> {
332        // 第一步:选择代理并获取其索引
333        let agent_index = self.get_random_valid_agent_index().await
334            .ok_or(PromptError::MaxDepthError {
335                max_depth: 0,
336                chat_history: vec![],
337                prompt: "没有有效agent".into(),
338            })?;
339
340        // 第二步:加锁并获取可变引用
341        let mut agents = self.agents.lock().await;
342        let agent_state = &mut agents[agent_index];
343        
344        let agent_info = agent_state.info.clone();
345        
346        tracing::info!("prompt_with_info Using provider: {}, model: {},id: {}", agent_state.info.provider, agent_state.info.model,agent_state.info.id);
347        match agent_state.agent.prompt(prompt).await {
348            Ok(content) => {
349                agent_state.record_success();
350                Ok((content, agent_info))
351            }
352            Err(e) => {
353                agent_state.record_failure();
354                if !agent_state.is_valid() {
355                    if let Some(cb) = &self.on_agent_invalid {
356                        cb(agent_state.id);
357                    }
358                }
359                Err(e)
360            }
361        }
362    }
363
364    /// 添加失败重试
365    pub async fn try_invoke_with_info_retry(&self, info: Message, retry_num: Option<usize>) -> Result<(String,AgentInfo), RandAgentError> {
366        let mut config = ExponentialBuilder::default();
367        if let Some(retry_num) = retry_num {
368            config = config.with_max_times(retry_num)
369        }
370
371        let info = Arc::new(info);
372
373        let content = (|| {
374            let agent = self.clone();
375            let prompt = info.clone();
376            async move {
377                agent.prompt_with_info((*prompt).clone()).await
378            }
379        })
380            .retry(config)
381            .sleep(tokio::time::sleep)
382            .notify(|err: &PromptError, dur: Duration| {
383                println!("retrying {err:?} after {dur:?}");
384            })
385            .await?;
386        Ok(content)
387    }
388}
389
/// Builder that collects agents and settings and turns them into a [`RandAgent`].
pub struct RandAgentBuilder {
    // Pending `(agent, id, provider name, model name)` entries.
    pub(crate) agents: Vec<(BoxAgent<'static>, i32, String, String)>,
    // Failure limit applied to every agent when the `RandAgent` is built.
    max_failures: u32,
    // Optional callback handed over to the built `RandAgent`.
    on_agent_invalid: OnAgentInvalidCallback,
}
396
397impl RandAgentBuilder {
398    /// 创建新的 RandAgentBuilder
399    pub fn new() -> Self {
400        Self {
401            agents: Vec::new(),
402            max_failures: 3, // 默认最大失败次数
403            on_agent_invalid: None,
404        }
405    }
406
407    /// 设置连续失败的最大次数,超过后标记代理为无效
408    pub fn max_failures(mut self, max_failures: u32) -> Self {
409        self.max_failures = max_failures;
410        self
411    }
412
413    /// 设置 agent 失效时的回调
414    pub fn on_agent_invalid<F>(mut self, callback: F) -> Self
415    where
416        F: Fn(i32) + Send + Sync + 'static,
417    {
418        self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
419        self
420    }
421
422    /// 添加代理到构建器
423    ///
424    /// # 参数
425    /// - agent: 代理实例(需要是 'static 生命周期)
426    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
427    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
428    pub fn add_agent(mut self, agent: BoxAgent<'static>, id: i32, provider_name: String, model_name: String) -> Self {
429        self.agents.push((agent, id, provider_name, model_name));
430        self
431    }
432
433    /// 从 AgentBuilder 添加代理
434    ///
435    /// # 参数
436    /// - builder: AgentBuilder 实例(需要是 'static 生命周期)
437    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
438    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
439    pub fn add_builder(mut self, builder: Agent<CompletionModelHandle<'static>>, id: i32, provider_name: &str, model_name: &str) -> Self {
440        self.agents.push((builder, id, provider_name.to_string(), model_name.to_string()));
441        self
442    }
443
444    /// 构建 RandAgent
445    pub fn build(self) -> RandAgent {
446        RandAgent::with_max_failures_and_callback(self.agents, self.max_failures, self.on_agent_invalid)
447    }
448    
449}
450
451impl Default for RandAgentBuilder {
452    fn default() -> Self {
453        Self::new()
454    }
455}