rig_extra/
rand_agent.rs

1//! ## 多线程使用示例
2//!
3//! ```rust
4//! use rig_extra::agent_variant::AgentVariant;
5//! use rig_extra::extra_providers::{bigmodel::Client};
6//! use rig_extra::rand_agent::RandAgentBuilder;
7//! use std::sync::Arc;
8//! use tokio::task;
9//! use rig::client::ProviderClient;
10//! use rig_extra::error::RandAgentError;
11//! #[tokio::main]
12//! async fn main() -> Result<(), RandAgentError> {
13//!     // 创建线程安全的 RandAgent
14//!
15//!     //创建多个客户端
16//!     let client1 = Client::from_env();
17//!     let client2 = Client::from_env();
18//!     use rig::completion::Prompt;
19//!
20//!
21//!     let thread_safe_agent = RandAgentBuilder::new()
22//!         .max_failures(3)
23//!         .add_agent(AgentVariant::Bigmodel(client1.agent("glm-4-flash").build()),1, "bigmodel".to_string(), "glm-4-flash".to_string())
24//!         .add_agent(AgentVariant::Bigmodel(client2.agent("glm-4-flash").build()),2, "bigmodel".to_string(), "glm-4-flash".to_string())
25//!         .build();
26//!
27//!     let agent_arc = Arc::new(thread_safe_agent);
28//!
29//!     // 创建多个并发任务
30//!     let mut handles = vec![];
31//!     for i in 0..5 {
32//!         let agent_clone = Arc::clone(&agent_arc);
33//!         let handle = task::spawn(async move {
34//!             let response = agent_clone.prompt(&format!("Hello from task {}", i)).await?;
35//!             println!("Task {} response: {}", i, response);
36//!             Ok::<(), RandAgentError>(())
37//!         });
38//!         handles.push(handle);
39//!     }
40//!
41//!     // 等待所有任务完成
42//!     for handle in handles {
43//!         handle.await??;
44//!     }
45//!
46//!     Ok(())
47//! }
48//! ```
49
50use crate::AgentInfo;
51use crate::agent_variant::AgentVariant;
52use crate::error::RandAgentError;
53use backon::{ExponentialBuilder, Retryable};
54use rand::Rng;
55use rig::completion::{Message, Prompt, PromptError};
56use std::sync::Arc;
57use std::time::Duration;
58use tokio::sync::Mutex;
59
/// Callback invoked with an agent's id once that agent is no longer valid.
///
/// Declared as an alias to keep the field and parameter types that carry the callback short.
/// `None` means no callback is registered.
pub type OnAgentInvalidCallback = Option<Arc<Box<dyn Fn(i32) + Send + Sync + 'static>>>;
62
/// Thread-safe `RandAgent` that supports concurrent access from multiple threads.
///
/// NOTE(review): the previous comment recommended "RandAgent" over "RandAgent" and described
/// "RandAgent" as no longer maintained. Both names are identical, so one of them presumably
/// referred to an older type before a rename — confirm which type was meant.
#[derive(Clone)]
pub struct RandAgent {
    /// Managed agent states, shared between clones and guarded by an async mutex.
    agents: Arc<Mutex<Vec<AgentState>>>,
    /// Optional callback run with the agent id when an agent stops being valid.
    on_agent_invalid: OnAgentInvalidCallback,
}
71
/// Per-agent state held by the thread-safe `RandAgent`.
#[derive(Clone)]
pub struct AgentState {
    /// Caller-supplied identifier; the same value is also stored in `info.id`.
    pub id: i32,
    /// The wrapped agent, reference-counted so cloning the state does not copy the agent.
    pub agent: Arc<AgentVariant>,
    /// Provider/model metadata together with the failure counter and its limit.
    pub info: AgentInfo,
}
79
80impl Prompt for RandAgent {
81    #[allow(refining_impl_trait)]
82    async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
83        // 第一步:选择代理并获取其索引
84        let agent_index =
85            self.get_random_valid_agent_index()
86                .await
87                .ok_or(PromptError::MaxDepthError {
88                    max_depth: 0,
89                    chat_history: Box::new(vec![]),
90                    prompt: "没有有效agent".into(),
91                })?;
92
93        // 第二步:加锁并获取可变引用
94        let mut agents = self.agents.lock().await;
95        let agent_state = &mut agents[agent_index];
96
97        tracing::info!(
98            "Using provider: {}, model: {},id: {}",
99            agent_state.info.provider,
100            agent_state.info.model,
101            agent_state.info.id
102        );
103        match agent_state.agent.prompt(prompt).await {
104            Ok(content) => {
105                agent_state.record_success();
106                Ok(content)
107            }
108            Err(e) => {
109                agent_state.record_failure();
110                if !agent_state.is_valid()
111                    && let Some(cb) = &self.on_agent_invalid
112                {
113                    cb(agent_state.id);
114                }
115                Err(e)
116            }
117        }
118    }
119}
120
121impl AgentState {
122    fn new(
123        agent: AgentVariant,
124        id: i32,
125        provider: String,
126        model: String,
127        max_failures: u32,
128    ) -> Self {
129        Self {
130            id,
131            agent: Arc::new(agent),
132            info: AgentInfo {
133                id,
134                provider,
135                model,
136                failure_count: 0,
137                max_failures,
138            },
139        }
140    }
141
142    fn is_valid(&self) -> bool {
143        self.info.failure_count < self.info.max_failures
144    }
145
146    fn record_failure(&mut self) {
147        self.info.failure_count += 1;
148    }
149
150    fn record_success(&mut self) {
151        self.info.failure_count = 0;
152    }
153}
154
155impl RandAgent {
156    /// 创建新的线程安全 RandAgent
157    pub fn new(agents: Vec<(AgentVariant, i32, String, String)>) -> Self {
158        Self::with_max_failures_and_callback(agents, 3, None)
159    }
160
161    /// 使用自定义最大失败次数和回调创建线程安全 RandAgent
162    pub fn with_max_failures_and_callback(
163        agents: Vec<(AgentVariant, i32, String, String)>,
164        max_failures: u32,
165        on_agent_invalid: OnAgentInvalidCallback,
166    ) -> Self {
167        let agent_states = agents
168            .into_iter()
169            .map(|(agent, id, provider, model)| {
170                AgentState::new(agent, id, provider, model, max_failures)
171            })
172            .collect();
173        Self {
174            agents: Arc::new(Mutex::new(agent_states)),
175            on_agent_invalid,
176        }
177    }
178
179    /// 使用自定义最大失败次数创建线程安全 RandAgent
180    pub fn with_max_failures(
181        agents: Vec<(AgentVariant, i32, String, String)>,
182        max_failures: u32,
183    ) -> Self {
184        Self::with_max_failures_and_callback(agents, max_failures, None)
185    }
186
187    /// 设置 agent 失效时的回调
188    pub fn set_on_agent_invalid<F>(&mut self, callback: F)
189    where
190        F: Fn(i32) + Send + Sync + 'static,
191    {
192        self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
193    }
194
195    /// 添加代理到集合中
196    pub async fn add_agent(&self, agent: AgentVariant, id: i32, provider: String, model: String) {
197        let mut agents = self.agents.lock().await;
198        agents.push(AgentState::new(agent, id, provider, model, 3));
199    }
200
201    /// 使用自定义最大失败次数添加代理
202    pub async fn add_agent_with_max_failures(
203        &self,
204        agent: AgentVariant,
205        id: i32,
206        provider: String,
207        model: String,
208        max_failures: u32,
209    ) {
210        let mut agents = self.agents.lock().await;
211        agents.push(AgentState::new(agent, id, provider, model, max_failures));
212    }
213
214    /// 获取有效代理数量
215    pub async fn len(&self) -> usize {
216        let agents = self.agents.lock().await;
217        agents.iter().filter(|state| state.is_valid()).count()
218    }
219
220    /// 从集合中获取一个随机有效代理的索引
221    pub async fn get_random_valid_agent_index(&self) -> Option<usize> {
222        let agents = self.agents.lock().await;
223        let valid_indices: Vec<usize> = agents
224            .iter()
225            .enumerate()
226            .filter(|(_, state)| state.is_valid())
227            .map(|(i, _)| i)
228            .collect();
229
230        if valid_indices.is_empty() {
231            return None;
232        }
233
234        let mut rng = rand::rng();
235        let random_index = rng.random_range(0..valid_indices.len());
236        Some(valid_indices[random_index])
237    }
238
239    /// 从集合中获取一个随机有效代理
240    /// 注意: 并不会增加失败计数
241    pub async fn get_random_valid_agent_state(&self) -> Option<AgentState> {
242        let mut agents = self.agents.lock().await;
243
244        let valid_indices: Vec<usize> = agents
245            .iter()
246            .enumerate()
247            .filter(|(_, state)| state.is_valid())
248            .map(|(i, _)| i)
249            .collect();
250
251        if valid_indices.is_empty() {
252            return None;
253        }
254
255        let mut rng = rand::rng();
256        let random_index = rng.random_range(0..valid_indices.len());
257        let agent_index = valid_indices[random_index];
258        agents.get_mut(agent_index).cloned()
259    }
260
261    /// 获取总代理数量(包括无效的)
262    pub async fn total_len(&self) -> usize {
263        let agents = self.agents.lock().await;
264        agents.len()
265    }
266
267    /// 检查是否有有效代理
268    pub async fn is_empty(&self) -> bool {
269        self.len().await == 0
270    }
271
272    /// 获取agent info
273    pub async fn get_agents_info(&self) -> Vec<AgentInfo> {
274        let agents = self.agents.lock().await;
275        let agent_infos = agents.iter().map(|agent| agent.info.clone()).collect::<_>();
276        tracing::info!("agents info: {:?}", agent_infos);
277        agent_infos
278    }
279
280    /// 获取失败统计
281    pub async fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
282        let agents = self.agents.lock().await;
283        agents
284            .iter()
285            .enumerate()
286            .map(|(i, state)| (i, state.info.failure_count, state.info.max_failures))
287            .collect()
288    }
289
290    /// 重置所有代理的失败计数
291    pub async fn reset_failures(&self) {
292        let mut agents = self.agents.lock().await;
293        for state in agents.iter_mut() {
294            state.info.failure_count = 0;
295        }
296    }
297
298    /// 通过名称获取 agent
299    pub async fn get_agent_by_name(
300        &self,
301        provider_name: &str,
302        model_name: &str,
303    ) -> Option<AgentState> {
304        let mut agents = self.agents.lock().await;
305
306        for agent in agents.iter_mut() {
307            if agent.info.provider == provider_name && agent.info.model == model_name {
308                return Some(agent.clone());
309            }
310        }
311
312        None
313    }
314
315    /// 通过id获取 agent
316    pub async fn get_agent_by_id(&self, id: i32) -> Option<AgentState> {
317        let mut agents = self.agents.lock().await;
318
319        for agent in agents.iter_mut() {
320            if agent.info.id == id {
321                return Some(agent.clone());
322            }
323        }
324
325        None
326    }
327
328    /// 添加失败重试
329    pub async fn try_invoke_with_retry(
330        &self,
331        info: Message,
332        retry_num: Option<usize>,
333    ) -> Result<String, RandAgentError> {
334        let mut config = ExponentialBuilder::default();
335        if let Some(retry_num) = retry_num {
336            config = config.with_max_times(retry_num)
337        }
338
339        let info = Arc::new(info);
340
341        let content = (|| {
342            let agent = self.clone();
343            let prompt = info.clone();
344            async move { agent.prompt((*prompt).clone()).await }
345        })
346        .retry(config)
347        .sleep(tokio::time::sleep)
348        .notify(|err: &PromptError, dur: Duration| {
349            println!("retrying {err:?} after {dur:?}");
350        })
351        .await?;
352        Ok(content)
353    }
354
355    #[allow(refining_impl_trait)]
356    pub async fn prompt_with_info(
357        &self,
358        prompt: impl Into<Message> + Send,
359    ) -> Result<(String, AgentInfo), PromptError> {
360        // 第一步:选择代理并获取其索引
361        let agent_index =
362            self.get_random_valid_agent_index()
363                .await
364                .ok_or(PromptError::MaxDepthError {
365                    max_depth: 0,
366                    chat_history: Box::new(vec![]),
367                    prompt: "没有有效agent".into(),
368                })?;
369
370        // 第二步:加锁并获取可变引用
371        let mut agents = self.agents.lock().await;
372        let agent_state = &mut agents[agent_index];
373
374        let agent_info = agent_state.info.clone();
375
376        tracing::info!(
377            "prompt_with_info Using provider: {}, model: {},id: {}",
378            agent_state.info.provider,
379            agent_state.info.model,
380            agent_state.info.id
381        );
382        match agent_state.agent.prompt(prompt).await {
383            Ok(content) => {
384                agent_state.record_success();
385                Ok((content, agent_info))
386            }
387            Err(e) => {
388                agent_state.record_failure();
389                if !agent_state.is_valid()
390                    && let Some(cb) = &self.on_agent_invalid
391                {
392                    cb(agent_state.id);
393                }
394                Err(e)
395            }
396        }
397    }
398
399    /// 添加失败重试
400    pub async fn try_invoke_with_info_retry(
401        &self,
402        info: Message,
403        retry_num: Option<usize>,
404    ) -> Result<(String, AgentInfo), RandAgentError> {
405        let mut config = ExponentialBuilder::default();
406        if let Some(retry_num) = retry_num {
407            config = config.with_max_times(retry_num)
408        }
409
410        let info = Arc::new(info);
411
412        let content = (|| {
413            let agent = self.clone();
414            let prompt = info.clone();
415            async move { agent.prompt_with_info((*prompt).clone()).await }
416        })
417        .retry(config)
418        .sleep(tokio::time::sleep)
419        .notify(|err: &PromptError, dur: Duration| {
420            println!("retrying {err:?} after {dur:?}");
421        })
422        .await?;
423        Ok(content)
424    }
425}
426
/// Builder for the thread-safe `RandAgent`.
pub struct RandAgentBuilder {
    /// Agents queued so far, as `(agent, id, provider_name, model_name)` tuples.
    pub(crate) agents: Vec<(AgentVariant, i32, String, String)>,
    /// Failure limit handed to every agent when the `RandAgent` is built.
    max_failures: u32,
    /// Optional callback run with the agent id when an agent becomes invalid.
    on_agent_invalid: OnAgentInvalidCallback,
}
433
434impl RandAgentBuilder {
435    /// 创建新的 RandAgentBuilder
436    pub fn new() -> Self {
437        Self {
438            agents: Vec::new(),
439            max_failures: 3, // 默认最大失败次数
440            on_agent_invalid: None,
441        }
442    }
443
444    /// 设置连续失败的最大次数,超过后标记代理为无效
445    pub fn max_failures(mut self, max_failures: u32) -> Self {
446        self.max_failures = max_failures;
447        self
448    }
449
450    /// 设置 agent 失效时的回调
451    pub fn on_agent_invalid<F>(mut self, callback: F) -> Self
452    where
453        F: Fn(i32) + Send + Sync + 'static,
454    {
455        self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
456        self
457    }
458
459    /// 添加代理到构建器
460    ///
461    /// # 参数
462    /// - agent: 代理实例(AgentVariant 枚举)
463    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
464    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
465    pub fn add_agent(
466        mut self,
467        agent: AgentVariant,
468        id: i32,
469        provider_name: String,
470        model_name: String,
471    ) -> Self {
472        self.agents.push((agent, id, provider_name, model_name));
473        self
474    }
475
476    /// 构建 RandAgent
477    pub fn build(self) -> RandAgent {
478        RandAgent::with_max_failures_and_callback(
479            self.agents,
480            self.max_failures,
481            self.on_agent_invalid,
482        )
483    }
484}
485
486impl Default for RandAgentBuilder {
487    fn default() -> Self {
488        Self::new()
489    }
490}