rig_extra/
rand_agent.rs

1//! ## 多线程使用示例
2//!
3//! ```rust
4//! use rig_extra::extra_providers::{bigmodel::Client};
5//! use rig_extra::rand_agent::RandAgentBuilder;
6//! use std::sync::Arc;
7//! use tokio::task;
8//! use rig::client::ProviderClient;
9//! use rig_extra::error::RandAgentError;
10//! #[tokio::main]
11//! async fn main() -> Result<(), RandAgentError> {
12//!     // 创建线程安全的 RandAgent
13//!
14//!     //创建多个客户端
15//!     let client1 = Client::from_env();
16//!     let client2 = Client::from_env();
17//!     use rig::client::completion::CompletionClientDyn;
18//!     use rig::completion::Prompt;
19//!
20//!
21//!     let thread_safe_agent = RandAgentBuilder::new()
22//!         .max_failures(3)
23//!         .add_agent(client1.agent("glm-4-flash").build(),1, "bigmodel".to_string(), "glm-4-flash".to_string())
24//!         .add_agent(client2.agent("glm-4-flash").build(),2, "bigmodel".to_string(), "glm-4-flash".to_string())
25//!         .build();
26//!
27//!     let agent_arc = Arc::new(thread_safe_agent);
28//!
29//!     // 创建多个并发任务
30//!     let mut handles = vec![];
31//!     for i in 0..5 {
32//!         let agent_clone = Arc::clone(&agent_arc);
33//!         let handle = task::spawn(async move {
34//!             let response = agent_clone.prompt(&format!("Hello from task {}", i)).await?;
35//!             println!("Task {} response: {}", i, response);
36//!             Ok::<(), RandAgentError>(())
37//!         });
38//!         handles.push(handle);
39//!     }
40//!
41//!     // 等待所有任务完成
42//!     for handle in handles {
43//!         handle.await??;
44//!     }
45//!
46//!     Ok(())
47//! }
48//! ```
49
50use crate::AgentInfo;
51use crate::error::RandAgentError;
52use backon::{ExponentialBuilder, Retryable};
53use rand::Rng;
54use rig::agent::Agent;
55use rig::client::builder::BoxAgent;
56use rig::client::completion::CompletionModelHandle;
57use rig::completion::{Message, Prompt, PromptError};
58use std::sync::Arc;
59use std::time::Duration;
60use tokio::sync::Mutex;
61
/// Optional callback invoked with an agent's id once that agent is found to be invalid
/// after a failed prompt. Aliased to keep struct fields and signatures readable.
pub type OnAgentInvalidCallback = Option<Arc<Box<dyn Fn(i32) + Send + Sync + 'static>>>;
64
/// A pool of agents that answers each prompt with a randomly chosen valid agent.
///
/// The agent list is stored behind `Arc<tokio::sync::Mutex<..>>`, so every clone of a
/// `RandAgent` shares the same agents and the same failure counters.
#[derive(Clone)]
pub struct RandAgent {
    // Shared, lock-protected list of agents together with their failure bookkeeping.
    agents: Arc<Mutex<Vec<AgentState>>>,
    // Optional callback that receives the id of an agent that has become invalid.
    on_agent_invalid: OnAgentInvalidCallback,
}
73
/// One pool entry: the agent handle plus its descriptive and failure-tracking info.
///
/// Cloning is cheap for the agent itself (the `Arc` is shared); `info` is copied, so a
/// clone is a snapshot of the counters at the time of cloning.
#[derive(Clone)]
pub struct AgentState {
    /// Caller-supplied identifier; also stored in `info.id`.
    pub id: i32,
    /// Shared handle to the underlying agent.
    pub agent: Arc<BoxAgent<'static>>,
    /// Provider/model names and the failure counter with its limit.
    pub info: AgentInfo,
}
81
82impl Prompt for RandAgent {
83    #[allow(refining_impl_trait)]
84    async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
85        // 第一步:选择代理并获取其索引
86        let agent_index =
87            self.get_random_valid_agent_index()
88                .await
89                .ok_or(PromptError::MaxDepthError {
90                    max_depth: 0,
91                    chat_history: Box::new(vec![]),
92                    prompt: "没有有效agent".into(),
93                })?;
94
95        // 第二步:加锁并获取可变引用
96        let mut agents = self.agents.lock().await;
97        let agent_state = &mut agents[agent_index];
98
99        tracing::info!(
100            "Using provider: {}, model: {},id: {}",
101            agent_state.info.provider,
102            agent_state.info.model,
103            agent_state.info.id
104        );
105        match agent_state.agent.prompt(prompt).await {
106            Ok(content) => {
107                agent_state.record_success();
108                Ok(content)
109            }
110            Err(e) => {
111                agent_state.record_failure();
112                if !agent_state.is_valid() {
113                    if let Some(cb) = &self.on_agent_invalid {
114                        cb(agent_state.id);
115                    }
116                }
117                Err(e)
118            }
119        }
120    }
121}
122
123impl AgentState {
124    fn new(
125        agent: BoxAgent<'static>,
126        id: i32,
127        provider: String,
128        model: String,
129        max_failures: u32,
130    ) -> Self {
131        Self {
132            id,
133            agent: Arc::new(agent),
134            info: AgentInfo {
135                id,
136                provider,
137                model,
138                failure_count: 0,
139                max_failures,
140            },
141        }
142    }
143
144    fn is_valid(&self) -> bool {
145        self.info.failure_count < self.info.max_failures
146    }
147
148    fn record_failure(&mut self) {
149        self.info.failure_count += 1;
150    }
151
152    fn record_success(&mut self) {
153        self.info.failure_count = 0;
154    }
155}
156
157impl RandAgent {
158    /// 创建新的线程安全 RandAgent
159    pub fn new(agents: Vec<(BoxAgent<'static>, i32, String, String)>) -> Self {
160        Self::with_max_failures_and_callback(agents, 3, None)
161    }
162
163    /// 使用自定义最大失败次数和回调创建线程安全 RandAgent
164    pub fn with_max_failures_and_callback(
165        agents: Vec<(BoxAgent<'static>, i32, String, String)>,
166        max_failures: u32,
167        on_agent_invalid: OnAgentInvalidCallback,
168    ) -> Self {
169        let agent_states = agents
170            .into_iter()
171            .map(|(agent, id, provider, model)| {
172                AgentState::new(agent, id, provider, model, max_failures)
173            })
174            .collect();
175        Self {
176            agents: Arc::new(Mutex::new(agent_states)),
177            on_agent_invalid,
178        }
179    }
180
181    /// 使用自定义最大失败次数创建线程安全 RandAgent
182    pub fn with_max_failures(
183        agents: Vec<(BoxAgent<'static>, i32, String, String)>,
184        max_failures: u32,
185    ) -> Self {
186        Self::with_max_failures_and_callback(agents, max_failures, None)
187    }
188
189    /// 设置 agent 失效时的回调
190    pub fn set_on_agent_invalid<F>(&mut self, callback: F)
191    where
192        F: Fn(i32) + Send + Sync + 'static,
193    {
194        self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
195    }
196
197    /// 添加代理到集合中
198    pub async fn add_agent(
199        &self,
200        agent: BoxAgent<'static>,
201        id: i32,
202        provider: String,
203        model: String,
204    ) {
205        let mut agents = self.agents.lock().await;
206        agents.push(AgentState::new(agent, id, provider, model, 3));
207    }
208
209    /// 使用自定义最大失败次数添加代理
210    pub async fn add_agent_with_max_failures(
211        &self,
212        agent: BoxAgent<'static>,
213        id: i32,
214        provider: String,
215        model: String,
216        max_failures: u32,
217    ) {
218        let mut agents = self.agents.lock().await;
219        agents.push(AgentState::new(agent, id, provider, model, max_failures));
220    }
221
222    /// 获取有效代理数量
223    pub async fn len(&self) -> usize {
224        let agents = self.agents.lock().await;
225        agents.iter().filter(|state| state.is_valid()).count()
226    }
227
228    /// 从集合中获取一个随机有效代理的索引
229    pub async fn get_random_valid_agent_index(&self) -> Option<usize> {
230        let agents = self.agents.lock().await;
231        let valid_indices: Vec<usize> = agents
232            .iter()
233            .enumerate()
234            .filter(|(_, state)| state.is_valid())
235            .map(|(i, _)| i)
236            .collect();
237
238        if valid_indices.is_empty() {
239            return None;
240        }
241
242        let mut rng = rand::rng();
243        let random_index = rng.random_range(0..valid_indices.len());
244        Some(valid_indices[random_index])
245    }
246
247    /// 从集合中获取一个随机有效代理
248    /// 注意: 并不会增加失败计数
249    pub async fn get_random_valid_agent_state(&self) -> Option<AgentState> {
250        let mut agents = self.agents.lock().await;
251
252        let valid_indices: Vec<usize> = agents
253            .iter()
254            .enumerate()
255            .filter(|(_, state)| state.is_valid())
256            .map(|(i, _)| i)
257            .collect();
258
259        if valid_indices.is_empty() {
260            return None;
261        }
262
263        let mut rng = rand::rng();
264        let random_index = rng.random_range(0..valid_indices.len());
265        let agent_index = valid_indices[random_index];
266        agents.get_mut(agent_index).cloned()
267    }
268
269    /// 获取总代理数量(包括无效的)
270    pub async fn total_len(&self) -> usize {
271        let agents = self.agents.lock().await;
272        agents.len()
273    }
274
275    /// 检查是否有有效代理
276    pub async fn is_empty(&self) -> bool {
277        self.len().await == 0
278    }
279
280    /// 获取agent info
281    pub async fn get_agents_info(&self) -> Vec<AgentInfo> {
282        let agents = self.agents.lock().await;
283        let agent_infos = agents.iter().map(|agent| agent.info.clone()).collect::<_>();
284        tracing::info!("agents info: {:?}", agent_infos);
285        agent_infos
286    }
287
288    /// 获取失败统计
289    pub async fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
290        let agents = self.agents.lock().await;
291        agents
292            .iter()
293            .enumerate()
294            .map(|(i, state)| (i, state.info.failure_count, state.info.max_failures))
295            .collect()
296    }
297
298    /// 重置所有代理的失败计数
299    pub async fn reset_failures(&self) {
300        let mut agents = self.agents.lock().await;
301        for state in agents.iter_mut() {
302            state.info.failure_count = 0;
303        }
304    }
305
306    /// 通过名称获取 agent
307    pub async fn get_agent_by_name(
308        &self,
309        provider_name: &str,
310        model_name: &str,
311    ) -> Option<AgentState> {
312        let mut agents = self.agents.lock().await;
313
314        for agent in agents.iter_mut() {
315            if agent.info.provider == provider_name && agent.info.model == model_name {
316                return Some(agent.clone());
317            }
318        }
319
320        None
321    }
322
323    /// 通过id获取 agent
324    pub async fn get_agent_by_id(&self, id: i32) -> Option<AgentState> {
325        let mut agents = self.agents.lock().await;
326
327        for agent in agents.iter_mut() {
328            if agent.info.id == id {
329                return Some(agent.clone());
330            }
331        }
332
333        None
334    }
335
336    /// 添加失败重试
337    pub async fn try_invoke_with_retry(
338        &self,
339        info: Message,
340        retry_num: Option<usize>,
341    ) -> Result<String, RandAgentError> {
342        let mut config = ExponentialBuilder::default();
343        if let Some(retry_num) = retry_num {
344            config = config.with_max_times(retry_num)
345        }
346
347        let info = Arc::new(info);
348
349        let content = (|| {
350            let agent = self.clone();
351            let prompt = info.clone();
352            async move { agent.prompt((*prompt).clone()).await }
353        })
354        .retry(config)
355        .sleep(tokio::time::sleep)
356        .notify(|err: &PromptError, dur: Duration| {
357            println!("retrying {err:?} after {dur:?}");
358        })
359        .await?;
360        Ok(content)
361    }
362
363    #[allow(refining_impl_trait)]
364    pub async fn prompt_with_info(
365        &self,
366        prompt: impl Into<Message> + Send,
367    ) -> Result<(String, AgentInfo), PromptError> {
368        // 第一步:选择代理并获取其索引
369        let agent_index =
370            self.get_random_valid_agent_index()
371                .await
372                .ok_or(PromptError::MaxDepthError {
373                    max_depth: 0,
374                    chat_history: Box::new(vec![]),
375                    prompt: "没有有效agent".into(),
376                })?;
377
378        // 第二步:加锁并获取可变引用
379        let mut agents = self.agents.lock().await;
380        let agent_state = &mut agents[agent_index];
381
382        let agent_info = agent_state.info.clone();
383
384        tracing::info!(
385            "prompt_with_info Using provider: {}, model: {},id: {}",
386            agent_state.info.provider,
387            agent_state.info.model,
388            agent_state.info.id
389        );
390        match agent_state.agent.prompt(prompt).await {
391            Ok(content) => {
392                agent_state.record_success();
393                Ok((content, agent_info))
394            }
395            Err(e) => {
396                agent_state.record_failure();
397                if !agent_state.is_valid() {
398                    if let Some(cb) = &self.on_agent_invalid {
399                        cb(agent_state.id);
400                    }
401                }
402                Err(e)
403            }
404        }
405    }
406
407    /// 添加失败重试
408    pub async fn try_invoke_with_info_retry(
409        &self,
410        info: Message,
411        retry_num: Option<usize>,
412    ) -> Result<(String, AgentInfo), RandAgentError> {
413        let mut config = ExponentialBuilder::default();
414        if let Some(retry_num) = retry_num {
415            config = config.with_max_times(retry_num)
416        }
417
418        let info = Arc::new(info);
419
420        let content = (|| {
421            let agent = self.clone();
422            let prompt = info.clone();
423            async move { agent.prompt_with_info((*prompt).clone()).await }
424        })
425        .retry(config)
426        .sleep(tokio::time::sleep)
427        .notify(|err: &PromptError, dur: Duration| {
428            println!("retrying {err:?} after {dur:?}");
429        })
430        .await?;
431        Ok(content)
432    }
433}
434
/// Builder that collects agents and settings and produces a [`RandAgent`].
pub struct RandAgentBuilder {
    // Pending `(agent, id, provider, model)` entries, in insertion order.
    pub(crate) agents: Vec<(BoxAgent<'static>, i32, String, String)>,
    // Failure limit applied to every agent when the pool is built.
    max_failures: u32,
    // Optional callback handed to the built pool.
    on_agent_invalid: OnAgentInvalidCallback,
}
441
442impl RandAgentBuilder {
443    /// 创建新的 RandAgentBuilder
444    pub fn new() -> Self {
445        Self {
446            agents: Vec::new(),
447            max_failures: 3, // 默认最大失败次数
448            on_agent_invalid: None,
449        }
450    }
451
452    /// 设置连续失败的最大次数,超过后标记代理为无效
453    pub fn max_failures(mut self, max_failures: u32) -> Self {
454        self.max_failures = max_failures;
455        self
456    }
457
458    /// 设置 agent 失效时的回调
459    pub fn on_agent_invalid<F>(mut self, callback: F) -> Self
460    where
461        F: Fn(i32) + Send + Sync + 'static,
462    {
463        self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
464        self
465    }
466
467    /// 添加代理到构建器
468    ///
469    /// # 参数
470    /// - agent: 代理实例(需要是 'static 生命周期)
471    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
472    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
473    pub fn add_agent(
474        mut self,
475        agent: BoxAgent<'static>,
476        id: i32,
477        provider_name: String,
478        model_name: String,
479    ) -> Self {
480        self.agents.push((agent, id, provider_name, model_name));
481        self
482    }
483
484    /// 从 AgentBuilder 添加代理
485    ///
486    /// # 参数
487    /// - builder: AgentBuilder 实例(需要是 'static 生命周期)
488    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
489    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
490    pub fn add_builder(
491        mut self,
492        builder: Agent<CompletionModelHandle<'static>>,
493        id: i32,
494        provider_name: &str,
495        model_name: &str,
496    ) -> Self {
497        self.agents.push((
498            builder,
499            id,
500            provider_name.to_string(),
501            model_name.to_string(),
502        ));
503        self
504    }
505
506    /// 构建 RandAgent
507    pub fn build(self) -> RandAgent {
508        RandAgent::with_max_failures_and_callback(
509            self.agents,
510            self.max_failures,
511            self.on_agent_invalid,
512        )
513    }
514}
515
516impl Default for RandAgentBuilder {
517    fn default() -> Self {
518        Self::new()
519    }
520}