Skip to main content

rig_extra/
rand_agent.rs

1//! ## 多线程使用示例
2//!
3//! ```rust
4//! use rig_extra::agent_variant::AgentVariant;
5//! use rig_extra::extra_providers::{bigmodel_old::Client};
6//! use rig_extra::rand_agent::RandAgentBuilder;
7//! use std::sync::Arc;
8//! use tokio::task;
9//! use rig::client::ProviderClient;
10//! use rig_extra::error::RandAgentError;
11//! #[tokio::main]
12//! async fn main() -> Result<(), RandAgentError> {
13//!     // 创建线程安全的 RandAgent
14//!
15//!     //创建多个客户端
16//!     let client1 = Client::from_env();
17//!     let client2 = Client::from_env();
18//!     use rig::completion::Prompt;
19//!
20//!
21//!     let thread_safe_agent = RandAgentBuilder::new()
22//!         .max_failures(3)
23//!         .add_agent(AgentVariant::Bigmodel(client1.agent("glm-4-flash").build()),1, "bigmodel".to_string(), "glm-4-flash".to_string())
24//!         .add_agent(AgentVariant::Bigmodel(client2.agent("glm-4-flash").build()),2, "bigmodel".to_string(), "glm-4-flash".to_string())
25//!         .build();
26//!
27//!     let agent_arc = Arc::new(thread_safe_agent);
28//!
29//!     // 创建多个并发任务
30//!     let mut handles = vec![];
31//!     for i in 0..5 {
32//!         let agent_clone = Arc::clone(&agent_arc);
33//!         let handle = task::spawn(async move {
34//!             let response = agent_clone.prompt(&format!("Hello from task {}", i)).await?;
35//!             println!("Task {} response: {}", i, response);
36//!             Ok::<(), RandAgentError>(())
37//!         });
38//!         handles.push(handle);
39//!     }
40//!
41//!     // 等待所有任务完成
42//!     for handle in handles {
43//!         handle.await??;
44//!     }
45//!
46//!     Ok(())
47//! }
48//! ```
49
50use crate::AgentInfo;
51use crate::agent_variant::AgentVariant;
52use crate::error::RandAgentError;
53use backon::{ExponentialBuilder, Retryable};
54use rand::Rng;
55use rig::completion::{CompletionError, Message, Prompt, PromptError};
56use std::sync::Arc;
57use std::time::Duration;
58use tokio::sync::Mutex;
59
/// Optional callback that receives the `id` of an agent found invalid after a failed request.
///
/// `None` means no callback is registered. The alias exists to keep the field and parameter
/// types that carry the callback short.
pub type OnAgentInvalidCallback = Option<Arc<Box<dyn Fn(i32) + Send + Sync + 'static>>>;
62
/// A pool of agents that answers each prompt through one randomly chosen valid agent.
///
/// The agent list is stored behind an `Arc<tokio::sync::Mutex<..>>`, so every clone of a
/// `RandAgent` shares the same agents and failure counters.
// NOTE(review): the previous comment said "prefer RandAgent over RandAgent; RandAgent is no
// longer maintained", which is self-contradictory. It presumably referred to an older type
// that has since been renamed or removed — confirm and move the deprecation note there.
#[derive(Clone)]
pub struct RandAgent {
    /// Shared agent states, including each agent's failure counter.
    agents: Arc<Mutex<Vec<AgentState>>>,
    /// Optional callback invoked with an agent's id when it is invalid after a failure.
    on_agent_invalid: OnAgentInvalidCallback,
}
71
/// One agent of the pool together with its bookkeeping data.
#[derive(Clone)]
pub struct AgentState {
    /// Caller-supplied identifier; this is the value handed to the invalidation callback.
    pub id: i32,
    /// The wrapped agent, reference-counted so that cloning the state does not clone the agent.
    pub agent: Arc<AgentVariant>,
    /// Id, provider name, model name, current failure count and failure limit.
    pub info: AgentInfo,
}
79
80impl Prompt for RandAgent {
81    #[allow(refining_impl_trait)]
82    async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
83        // 第一步:选择代理并获取其索引
84        let agent_index = self
85            .get_random_valid_agent_index()
86            .await
87            .ok_or(CompletionError::ProviderError("没有有效agent".to_string()))?;
88
89        // 第二步:加锁并获取可变引用
90        let mut agents = self.agents.lock().await;
91        let agent_state = &mut agents[agent_index];
92
93        tracing::info!(
94            "Using provider: {}, model: {},id: {}",
95            agent_state.info.provider,
96            agent_state.info.model,
97            agent_state.info.id
98        );
99        match agent_state.agent.prompt(prompt).await {
100            Ok(content) => {
101                agent_state.record_success();
102                Ok(content)
103            }
104            Err(e) => {
105                agent_state.record_failure();
106                if !agent_state.is_valid()
107                    && let Some(cb) = &self.on_agent_invalid
108                {
109                    cb(agent_state.id);
110                }
111                Err(e)
112            }
113        }
114    }
115}
116
117impl AgentState {
118    fn new(
119        agent: AgentVariant,
120        id: i32,
121        provider: String,
122        model: String,
123        max_failures: u32,
124    ) -> Self {
125        Self {
126            id,
127            agent: Arc::new(agent),
128            info: AgentInfo {
129                id,
130                provider,
131                model,
132                failure_count: 0,
133                max_failures,
134            },
135        }
136    }
137
138    fn is_valid(&self) -> bool {
139        self.info.failure_count < self.info.max_failures
140    }
141
142    fn record_failure(&mut self) {
143        self.info.failure_count += 1;
144    }
145
146    fn record_success(&mut self) {
147        self.info.failure_count = 0;
148    }
149}
150
151impl RandAgent {
152    /// 创建新的线程安全 RandAgent
153    pub fn new(agents: Vec<(AgentVariant, i32, String, String)>) -> Self {
154        Self::with_max_failures_and_callback(agents, 3, None)
155    }
156
157    /// 使用自定义最大失败次数和回调创建线程安全 RandAgent
158    pub fn with_max_failures_and_callback(
159        agents: Vec<(AgentVariant, i32, String, String)>,
160        max_failures: u32,
161        on_agent_invalid: OnAgentInvalidCallback,
162    ) -> Self {
163        let agent_states = agents
164            .into_iter()
165            .map(|(agent, id, provider, model)| {
166                AgentState::new(agent, id, provider, model, max_failures)
167            })
168            .collect();
169        Self {
170            agents: Arc::new(Mutex::new(agent_states)),
171            on_agent_invalid,
172        }
173    }
174
175    /// 使用自定义最大失败次数创建线程安全 RandAgent
176    pub fn with_max_failures(
177        agents: Vec<(AgentVariant, i32, String, String)>,
178        max_failures: u32,
179    ) -> Self {
180        Self::with_max_failures_and_callback(agents, max_failures, None)
181    }
182
183    /// 设置 agent 失效时的回调
184    pub fn set_on_agent_invalid<F>(&mut self, callback: F)
185    where
186        F: Fn(i32) + Send + Sync + 'static,
187    {
188        self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
189    }
190
191    /// 添加代理到集合中
192    pub async fn add_agent(&self, agent: AgentVariant, id: i32, provider: String, model: String) {
193        let mut agents = self.agents.lock().await;
194        agents.push(AgentState::new(agent, id, provider, model, 3));
195    }
196
197    /// 使用自定义最大失败次数添加代理
198    pub async fn add_agent_with_max_failures(
199        &self,
200        agent: AgentVariant,
201        id: i32,
202        provider: String,
203        model: String,
204        max_failures: u32,
205    ) {
206        let mut agents = self.agents.lock().await;
207        agents.push(AgentState::new(agent, id, provider, model, max_failures));
208    }
209
210    /// 获取有效代理数量
211    pub async fn len(&self) -> usize {
212        let agents = self.agents.lock().await;
213        agents.iter().filter(|state| state.is_valid()).count()
214    }
215
216    /// 从集合中获取一个随机有效代理的索引
217    pub async fn get_random_valid_agent_index(&self) -> Option<usize> {
218        let agents = self.agents.lock().await;
219        let valid_indices: Vec<usize> = agents
220            .iter()
221            .enumerate()
222            .filter(|(_, state)| state.is_valid())
223            .map(|(i, _)| i)
224            .collect();
225
226        if valid_indices.is_empty() {
227            return None;
228        }
229
230        let mut rng = rand::rng();
231        let random_index = rng.random_range(0..valid_indices.len());
232        Some(valid_indices[random_index])
233    }
234
235    /// 从集合中获取一个随机有效代理
236    /// 注意: 并不会增加失败计数
237    pub async fn get_random_valid_agent_state(&self) -> Option<AgentState> {
238        let mut agents = self.agents.lock().await;
239
240        let valid_indices: Vec<usize> = agents
241            .iter()
242            .enumerate()
243            .filter(|(_, state)| state.is_valid())
244            .map(|(i, _)| i)
245            .collect();
246
247        if valid_indices.is_empty() {
248            return None;
249        }
250
251        let mut rng = rand::rng();
252        let random_index = rng.random_range(0..valid_indices.len());
253        let agent_index = valid_indices[random_index];
254        agents.get_mut(agent_index).cloned()
255    }
256
257    /// 获取总代理数量(包括无效的)
258    pub async fn total_len(&self) -> usize {
259        let agents = self.agents.lock().await;
260        agents.len()
261    }
262
263    /// 检查是否有有效代理
264    pub async fn is_empty(&self) -> bool {
265        self.len().await == 0
266    }
267
268    /// 获取agent info
269    pub async fn get_agents_info(&self) -> Vec<AgentInfo> {
270        let agents = self.agents.lock().await;
271        let agent_infos = agents.iter().map(|agent| agent.info.clone()).collect::<_>();
272        tracing::info!("agents info: {:?}", agent_infos);
273        agent_infos
274    }
275
276    /// 获取失败统计
277    pub async fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
278        let agents = self.agents.lock().await;
279        agents
280            .iter()
281            .enumerate()
282            .map(|(i, state)| (i, state.info.failure_count, state.info.max_failures))
283            .collect()
284    }
285
286    /// 重置所有代理的失败计数
287    pub async fn reset_failures(&self) {
288        let mut agents = self.agents.lock().await;
289        for state in agents.iter_mut() {
290            state.info.failure_count = 0;
291        }
292    }
293
294    /// 通过名称获取 agent
295    pub async fn get_agent_by_name(
296        &self,
297        provider_name: &str,
298        model_name: &str,
299    ) -> Option<AgentState> {
300        let mut agents = self.agents.lock().await;
301
302        for agent in agents.iter_mut() {
303            if agent.info.provider == provider_name && agent.info.model == model_name {
304                return Some(agent.clone());
305            }
306        }
307
308        None
309    }
310
311    /// 通过id获取 agent
312    pub async fn get_agent_by_id(&self, id: i32) -> Option<AgentState> {
313        let mut agents = self.agents.lock().await;
314
315        for agent in agents.iter_mut() {
316            if agent.info.id == id {
317                return Some(agent.clone());
318            }
319        }
320
321        None
322    }
323
324    /// 添加失败重试
325    pub async fn try_invoke_with_retry(
326        &self,
327        info: Message,
328        retry_num: Option<usize>,
329    ) -> Result<String, RandAgentError> {
330        let mut config = ExponentialBuilder::default();
331        if let Some(retry_num) = retry_num {
332            config = config.with_max_times(retry_num)
333        }
334
335        let info = Arc::new(info);
336
337        let content = (|| {
338            let agent = self.clone();
339            let prompt = info.clone();
340            async move { agent.prompt((*prompt).clone()).await }
341        })
342        .retry(config)
343        .sleep(tokio::time::sleep)
344        .notify(|err: &PromptError, dur: Duration| {
345            println!("retrying {err:?} after {dur:?}");
346        })
347        .await?;
348        Ok(content)
349    }
350
351    #[allow(refining_impl_trait)]
352    pub async fn prompt_with_info(
353        &self,
354        prompt: impl Into<Message> + Send,
355    ) -> Result<(String, AgentInfo), PromptError> {
356        // 第一步:选择代理并获取其索引
357        let agent_index = self
358            .get_random_valid_agent_index()
359            .await
360            .ok_or(CompletionError::ProviderError("没有有效agent".to_string()))?;
361
362        // 第二步:加锁并获取可变引用
363        let mut agents = self.agents.lock().await;
364        let agent_state = &mut agents[agent_index];
365
366        let agent_info = agent_state.info.clone();
367
368        tracing::info!(
369            "prompt_with_info Using provider: {}, model: {},id: {}",
370            agent_state.info.provider,
371            agent_state.info.model,
372            agent_state.info.id
373        );
374        match agent_state.agent.prompt(prompt).await {
375            Ok(content) => {
376                agent_state.record_success();
377                Ok((content, agent_info))
378            }
379            Err(e) => {
380                agent_state.record_failure();
381                if !agent_state.is_valid()
382                    && let Some(cb) = &self.on_agent_invalid
383                {
384                    cb(agent_state.id);
385                }
386                Err(e)
387            }
388        }
389    }
390
391    /// 添加失败重试
392    pub async fn try_invoke_with_info_retry(
393        &self,
394        info: Message,
395        retry_num: Option<usize>,
396    ) -> Result<(String, AgentInfo), RandAgentError> {
397        let mut config = ExponentialBuilder::default();
398        if let Some(retry_num) = retry_num {
399            config = config.with_max_times(retry_num)
400        }
401
402        let info = Arc::new(info);
403
404        let content = (|| {
405            let agent = self.clone();
406            let prompt = info.clone();
407            async move { agent.prompt_with_info((*prompt).clone()).await }
408        })
409        .retry(config)
410        .sleep(tokio::time::sleep)
411        .notify(|err: &PromptError, dur: Duration| {
412            println!("retrying {err:?} after {dur:?}");
413        })
414        .await?;
415        Ok(content)
416    }
417}
418
/// Builder that collects agents and settings and then produces a [`RandAgent`].
pub struct RandAgentBuilder {
    /// Pending `(agent, id, provider name, model name)` entries, in insertion order.
    pub(crate) agents: Vec<(AgentVariant, i32, String, String)>,
    /// Failure limit applied to every agent when the pool is built.
    max_failures: u32,
    /// Optional callback handed to the built pool.
    on_agent_invalid: OnAgentInvalidCallback,
}
425
426impl RandAgentBuilder {
427    /// 创建新的 RandAgentBuilder
428    pub fn new() -> Self {
429        Self {
430            agents: Vec::new(),
431            max_failures: 3, // 默认最大失败次数
432            on_agent_invalid: None,
433        }
434    }
435
436    /// 设置连续失败的最大次数,超过后标记代理为无效
437    pub fn max_failures(mut self, max_failures: u32) -> Self {
438        self.max_failures = max_failures;
439        self
440    }
441
442    /// 设置 agent 失效时的回调
443    pub fn on_agent_invalid<F>(mut self, callback: F) -> Self
444    where
445        F: Fn(i32) + Send + Sync + 'static,
446    {
447        self.on_agent_invalid = Some(Arc::new(Box::new(callback)));
448        self
449    }
450
451    /// 添加代理到构建器
452    ///
453    /// # 参数
454    /// - agent: 代理实例(AgentVariant 枚举)
455    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
456    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
457    pub fn add_agent(
458        mut self,
459        agent: AgentVariant,
460        id: i32,
461        provider_name: String,
462        model_name: String,
463    ) -> Self {
464        self.agents.push((agent, id, provider_name, model_name));
465        self
466    }
467
468    /// 构建 RandAgent
469    pub fn build(self) -> RandAgent {
470        RandAgent::with_max_failures_and_callback(
471            self.agents,
472            self.max_failures,
473            self.on_agent_invalid,
474        )
475    }
476}
477
478impl Default for RandAgentBuilder {
479    fn default() -> Self {
480        Self::new()
481    }
482}