rig_extra/
rand_agent.rs

1//!
2//! RandAgent - 多代理随机选择器
3//!
4//! 该模块提供了一个 `RandAgent` 结构体,可以包装多个 AI 代理,
5//! 每次调用时随机选择一个代理来执行任务。
6//!
7//! ## 特性
8//!
9//! - 支持任意数量的 AI 代理
10//! - 每次调用时随机选择一个有效代理
11//! - 自动记录代理失败次数,连续失败达到阈值后标记为无效
12//! - 成功响应时自动重置失败计数
13//! - 线程安全的随机数生成
14//! - 提供构建器模式
15//! - 支持失败统计和重置功能
16//!
17//! ## 使用示例
18//!
19//! ```rust
20//! use rig_extra::extra_providers::{bigmodel::Client};
21//! use rig::client::ProviderClient;
22//! use rig::client::completion::CompletionClientDyn;
23//! use rig_extra::rand_agent::RandAgentBuilder;
24//! use rig_extra::error::RandAgentError;
25//! 
26//! #[tokio::main]
27//! async fn main() -> Result<(), RandAgentError> {
28//!     // 创建多个客户端
29//!     
30//! use rig_extra::error::RandAgentError;
31//! let client1 = Client::from_env();
32//!     let client2 = Client::from_env();
33//!
34//!     // 创建 agent
35//!     let agent1 = client1.agent("glm-4-flash").build();
36//!     let agent2 = client2.agent("glm-4-flash").build();
37//!
38//!     // 使用构建器创建 RandAgent,设置最大失败次数
39//!     let mut rand_agent = RandAgentBuilder::new()
40//!         .max_failures(3) // 连续失败3次后标记为无效
41//!         .add_agent(agent1, "bigmodel".to_string(), "glm-4-flash".to_string())
42//!         .add_agent(agent2, "bigmodel".to_string(), "glm-4-flash".to_string())
43//!         .build();
44//!
45//!     // 发送消息,会随机选择一个有效代理
46//!     let response = rand_agent.prompt("Hello!").await?;
47//!     println!("Response: {}", response);
48//!
49//!     // 查看失败统计
50//!     let stats = rand_agent.failure_stats();
51//!     println!("Failure stats: {:?}", stats);
52//!
53//!     Ok(())
54//! }
55//! ```
56
57use rand::Rng;
58use rig::agent::{Agent};
59use rig::client::builder::BoxAgent;
60use rig::completion::Prompt;
61use rig::client::completion::CompletionModelHandle;
62use crate::error::RandAgentError;
63
64
65/// Agent状态,包含agent实例和失败计数
66pub struct AgentState<'a> {
67    agent: BoxAgent<'a>,
68    provider: String,
69    model: String,
70    failure_count: u32,
71    max_failures: u32,
72}
73
74impl<'a> AgentState<'a> {
75    fn new(agent: BoxAgent<'a>, provider: String, model: String, max_failures: u32) -> Self {
76        Self {
77            agent,
78            provider,
79            model,
80            failure_count: 0,
81            max_failures,
82        }
83    }
84
85    fn is_valid(&self) -> bool {
86        self.failure_count < self.max_failures
87    }
88
89    fn record_failure(&mut self) {
90        self.failure_count += 1;
91    }
92
93    fn record_success(&mut self) {
94        self.failure_count = 0;
95    }
96}
97
98/// 包装多个代理的结构体,每次调用时随机选择一个代理
99pub struct RandAgent<'a> {
100    agents: Vec<AgentState<'a>>,
101}
102
103impl<'a> RandAgent<'a> {
104    /// 使用给定的代理创建新的 RandAgent
105    pub fn new(agents: Vec<(BoxAgent<'a>, String, String)>) -> Self {
106        Self::with_max_failures(agents, 3) // 默认最大失败次数为3
107    }
108
109    /// 使用自定义最大失败次数创建新的 RandAgent
110    pub fn with_max_failures(agents: Vec<(BoxAgent<'a>, String, String)>, max_failures: u32) -> Self {
111        let agent_states = agents
112            .into_iter()
113            .map(|(agent, provider, model)| AgentState::new(agent, provider, model, max_failures))
114            .collect();
115        Self {
116            agents: agent_states,
117        }
118    }
119
120    
121    /// 向集合中添加代理
122    pub fn add_agent(&mut self, agent: BoxAgent<'a>, provider: String, model: String) {
123        self.agents.push(AgentState::new(agent, provider, model, 3)); // 使用默认最大失败次数
124    }
125
126    /// 使用自定义最大失败次数向集合中添加代理
127    pub fn add_agent_with_max_failures(&mut self, agent: BoxAgent<'a>, provider: String, model: String, max_failures: u32) {
128        self.agents.push(AgentState::new(agent, provider, model, max_failures));
129    }
130
131    /// 获取有效代理的数量
132    pub fn len(&self) -> usize {
133        self.agents.iter().filter(|state| state.is_valid()).count()
134    }
135
136    /// 获取代理总数(包括无效的)
137    pub fn total_len(&self) -> usize {
138        self.agents.len()
139    }
140
141    /// 检查是否有有效代理
142    pub fn is_empty(&self) -> bool {
143        self.len() == 0
144    }
145
146    /// 从集合中获取一个随机有效代理
147    async fn get_random_valid_agent(&mut self) -> Option<&mut AgentState<'a>> {
148        let valid_indices: Vec<usize> = self
149            .agents
150            .iter()
151            .enumerate()
152            .filter(|(_, state)| state.is_valid())
153            .map(|(i, _)| i)
154            .collect();
155
156        if valid_indices.is_empty() {
157            return None;
158        }
159
160        let mut rng = rand::rng();
161        let random_index = rng.random_range(0..valid_indices.len());
162        let agent_index = valid_indices[random_index];
163        self.agents.get_mut(agent_index)
164    }
165
166    /// 使用随机有效代理发送消息
167    pub async fn prompt(
168        &mut self,
169        message: &str,
170    ) -> Result<String, RandAgentError> {
171        let agent_state = self
172            .get_random_valid_agent()
173            .await
174            .ok_or(RandAgentError::NoValidAgents)?;
175
176        // 打印使用的provider和model
177        tracing::info!("Using provider: {}, model: {}", agent_state.provider, agent_state.model);
178        match agent_state.agent.prompt(message).await {
179            Ok(response) => {
180                agent_state.record_success();
181                Ok(response)
182            }
183            Err(e) => {
184                agent_state.record_failure();
185                Err(RandAgentError::AgentError(Box::new(e)))
186            }
187        }
188    }
189
190    
191    /// 获取所有代理(用于调试或检查)
192    pub fn agents(&self) -> &[AgentState<'a>] {
193        &self.agents
194    }
195
196    /// 获取失败统计信息
197    pub fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
198        self.agents
199            .iter()
200            .enumerate()
201            .map(|(i, state)| (i, state.failure_count, state.max_failures))
202            .collect()
203    }
204
205    /// 重置所有代理的失败计数
206    pub fn reset_failures(&mut self) {
207        for state in &mut self.agents {
208            state.failure_count = 0;
209        }
210    }
211}
212
213
214
215/// 用于创建 RandAgent 实例的构建器
216pub struct RandAgentBuilder<'a> {
217    agents: Vec<(BoxAgent<'a>, String, String)>,
218    max_failures: u32,
219}
220
221impl<'a> RandAgentBuilder<'a> {
222    /// 创建新的 RandAgentBuilder
223    pub fn new() -> Self {
224        Self {
225            agents: Vec::new(),
226            max_failures: 3, // 默认最大失败次数
227        }
228    }
229
230    /// 设置标记代理为无效前的最大连续失败次数
231    pub fn max_failures(mut self, max_failures: u32) -> Self {
232        self.max_failures = max_failures;
233        self
234    }
235
236    /// 向构建器添加代理
237    ///
238    /// # 参数
239    /// - agent: 代理实例
240    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
241    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
242    pub fn add_agent(mut self, agent: BoxAgent<'a>, provider_name: String, model_name: String) -> Self {
243        self.agents.push((agent, provider_name, model_name));
244        self
245    }
246
247    /// 从 AgentBuilder 添加代理
248    ///
249    /// # 参数
250    /// - builder: AgentBuilder 实例
251    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
252    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
253    ///
254    /// 推荐优先使用 add_agent,add_builder 适用于直接传 AgentBuilder 的场景。
255    pub fn add_builder(mut self, builder: Agent<CompletionModelHandle<'a>>, provider_name: &str, model_name: &str) -> Self {
256        self.agents.push((builder, provider_name.to_string(), model_name.to_string()));
257        self
258    }
259
260    /// 构建 RandAgent
261    pub fn build(self) -> RandAgent<'a> {
262        RandAgent::with_max_failures(self.agents, self.max_failures)
263    }
264}
265
266impl<'a> Default for RandAgentBuilder<'a> {
267    fn default() -> Self {
268        Self::new()
269    }
270}