rig_extra/
rand_agent.rs

1//!
2//! RandAgent - 多代理随机选择器
3//!
4//! 该模块提供了一个 `RandAgent` 结构体,可以包装多个 AI 代理,
5//! 每次调用时随机选择一个代理来执行任务。
6//!
7//! ## 特性
8//!
9//! - 支持任意数量的 AI 代理
10//! - 每次调用时随机选择一个有效代理
11//! - 自动记录代理失败次数,连续失败达到阈值后标记为无效
12//! - 成功响应时自动重置失败计数
13//! - 线程安全的随机数生成
14//! - 提供构建器模式
15//! - 支持失败统计和重置功能
16//!
17//! ## 使用示例
18//!
19//! ```rust
20//! use rig_extra::extra_providers::{bigmodel::Client};
21//! use rig::client::ProviderClient;
22//! use rig::client::completion::CompletionClientDyn;
23//! use rig_extra::rand_agent::RandAgentBuilder;
24//! #[tokio::main]
25//! async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
26//!     // 创建多个客户端
27//!     
28//! let client1 = Client::from_env();
29//!     let client2 = Client::from_env();
30//!
31//!     // 创建 agent
32//!     let agent1 = client1.agent("glm-4-flash").build();
33//!     let agent2 = client2.agent("glm-4-flash").build();
34//!
35//!     // 使用构建器创建 RandAgent,设置最大失败次数
36//!     let mut rand_agent = RandAgentBuilder::new()
37//!         .max_failures(3) // 连续失败3次后标记为无效
38//!         .add_agent(agent1, "bigmodel".to_string(), "glm-4-flash".to_string())
39//!         .add_agent(agent2, "bigmodel".to_string(), "glm-4-flash".to_string())
40//!         .build();
41//!
42//!     // 发送消息,会随机选择一个有效代理
43//!     let response = rand_agent.prompt("Hello!").await?;
44//!     println!("Response: {}", response);
45//!
46//!     // 查看失败统计
47//!     let stats = rand_agent.failure_stats();
48//!     println!("Failure stats: {:?}", stats);
49//!
50//!     Ok(())
51//! }
52//! ```
53
54use rand::Rng;
55use rig::agent::{Agent};
56use rig::client::builder::BoxAgent;
57use rig::completion::Prompt;
58use rig::client::completion::CompletionModelHandle;
59
60
61/// Agent状态,包含agent实例和失败计数
62pub struct AgentState<'a> {
63    agent: BoxAgent<'a>,
64    provider: String,
65    model: String,
66    failure_count: u32,
67    max_failures: u32,
68}
69
70impl<'a> AgentState<'a> {
71    fn new(agent: BoxAgent<'a>, provider: String, model: String, max_failures: u32) -> Self {
72        Self {
73            agent,
74            provider,
75            model,
76            failure_count: 0,
77            max_failures,
78        }
79    }
80
81    fn is_valid(&self) -> bool {
82        self.failure_count < self.max_failures
83    }
84
85    fn record_failure(&mut self) {
86        self.failure_count += 1;
87    }
88
89    fn record_success(&mut self) {
90        self.failure_count = 0;
91    }
92}
93
94/// 包装多个代理的结构体,每次调用时随机选择一个代理
95pub struct RandAgent<'a> {
96    agents: Vec<AgentState<'a>>,
97}
98
99impl<'a> RandAgent<'a> {
100    /// 使用给定的代理创建新的 RandAgent
101    pub fn new(agents: Vec<(BoxAgent<'a>, String, String)>) -> Self {
102        Self::with_max_failures(agents, 3) // 默认最大失败次数为3
103    }
104
105    /// 使用自定义最大失败次数创建新的 RandAgent
106    pub fn with_max_failures(agents: Vec<(BoxAgent<'a>, String, String)>, max_failures: u32) -> Self {
107        let agent_states = agents
108            .into_iter()
109            .map(|(agent, provider, model)| AgentState::new(agent, provider, model, max_failures))
110            .collect();
111        Self {
112            agents: agent_states,
113        }
114    }
115
116    
117    /// 向集合中添加代理
118    pub fn add_agent(&mut self, agent: BoxAgent<'a>, provider: String, model: String) {
119        self.agents.push(AgentState::new(agent, provider, model, 3)); // 使用默认最大失败次数
120    }
121
122    /// 使用自定义最大失败次数向集合中添加代理
123    pub fn add_agent_with_max_failures(&mut self, agent: BoxAgent<'a>, provider: String, model: String, max_failures: u32) {
124        self.agents.push(AgentState::new(agent, provider, model, max_failures));
125    }
126
127    /// 获取有效代理的数量
128    pub fn len(&self) -> usize {
129        self.agents.iter().filter(|state| state.is_valid()).count()
130    }
131
132    /// 获取代理总数(包括无效的)
133    pub fn total_len(&self) -> usize {
134        self.agents.len()
135    }
136
137    /// 检查是否有有效代理
138    pub fn is_empty(&self) -> bool {
139        self.len() == 0
140    }
141
142    /// 从集合中获取一个随机有效代理
143    async fn get_random_valid_agent(&mut self) -> Option<&mut AgentState<'a>> {
144        let valid_indices: Vec<usize> = self
145            .agents
146            .iter()
147            .enumerate()
148            .filter(|(_, state)| state.is_valid())
149            .map(|(i, _)| i)
150            .collect();
151
152        if valid_indices.is_empty() {
153            return None;
154        }
155
156        let mut rng = rand::rng();
157        let random_index = rng.random_range(0..valid_indices.len());
158        let agent_index = valid_indices[random_index];
159        self.agents.get_mut(agent_index)
160    }
161
162    /// 使用随机有效代理发送消息
163    pub async fn prompt(
164        &mut self,
165        message: &str,
166    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
167        let agent_state = self
168            .get_random_valid_agent()
169            .await
170            .ok_or("No valid agents available")?;
171
172        // 打印使用的provider和model
173        tracing::info!("Using provider: {}, model: {}", agent_state.provider, agent_state.model);
174        match agent_state.agent.prompt(message).await {
175            Ok(response) => {
176                agent_state.record_success();
177                Ok(response)
178            }
179            Err(e) => {
180                agent_state.record_failure();
181                Err(e.into())
182            }
183        }
184    }
185
186    
187    /// 获取所有代理(用于调试或检查)
188    pub fn agents(&self) -> &[AgentState<'a>] {
189        &self.agents
190    }
191
192    /// 获取失败统计信息
193    pub fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
194        self.agents
195            .iter()
196            .enumerate()
197            .map(|(i, state)| (i, state.failure_count, state.max_failures))
198            .collect()
199    }
200
201    /// 重置所有代理的失败计数
202    pub fn reset_failures(&mut self) {
203        for state in &mut self.agents {
204            state.failure_count = 0;
205        }
206    }
207}
208
209
210
211/// 用于创建 RandAgent 实例的构建器
212pub struct RandAgentBuilder<'a> {
213    agents: Vec<(BoxAgent<'a>, String, String)>,
214    max_failures: u32,
215}
216
217impl<'a> RandAgentBuilder<'a> {
218    /// 创建新的 RandAgentBuilder
219    pub fn new() -> Self {
220        Self {
221            agents: Vec::new(),
222            max_failures: 3, // 默认最大失败次数
223        }
224    }
225
226    /// 设置标记代理为无效前的最大连续失败次数
227    pub fn max_failures(mut self, max_failures: u32) -> Self {
228        self.max_failures = max_failures;
229        self
230    }
231
232    /// 向构建器添加代理
233    ///
234    /// # 参数
235    /// - agent: 代理实例
236    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
237    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
238    pub fn add_agent(mut self, agent: BoxAgent<'a>, provider_name: String, model_name: String) -> Self {
239        self.agents.push((agent, provider_name, model_name));
240        self
241    }
242
243    /// 从 AgentBuilder 添加代理
244    ///
245    /// # 参数
246    /// - builder: AgentBuilder 实例
247    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
248    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
249    ///
250    /// 推荐优先使用 add_agent,add_builder 适用于直接传 AgentBuilder 的场景。
251    pub fn add_builder(mut self, builder: Agent<CompletionModelHandle<'a>>, provider_name: &str, model_name: &str) -> Self {
252        self.agents.push((builder, provider_name.to_string(), model_name.to_string()));
253        self
254    }
255
256    /// 构建 RandAgent
257    pub fn build(self) -> RandAgent<'a> {
258        RandAgent::with_max_failures(self.agents, self.max_failures)
259    }
260}
261
262impl<'a> Default for RandAgentBuilder<'a> {
263    fn default() -> Self {
264        Self::new()
265    }
266}