rig_extra/
rand_agent.rs

1//!
2//! RandAgent - 多代理随机选择器
3//!
4//! 该模块提供了一个 `RandAgent` 结构体,可以包装多个 AI 代理,
5//! 每次调用时随机选择一个代理来执行任务。
6//!
7//! ## 特性
8//!
9//! - 支持任意数量的 AI 代理
10//! - 每次调用时随机选择一个有效代理
11//! - 自动记录代理失败次数,连续失败达到阈值后标记为无效
12//! - 成功响应时自动重置失败计数
13//! - 线程安全的随机数生成
14//! - 提供构建器模式
15//! - 支持失败统计和重置功能
16//!
17//! ## 使用示例
18//!
19//! ```rust
20//! use rig_extra::extra_providers::{bigmodel::Client};
21//! use rig::client::ProviderClient;
22//! use rig::client::completion::CompletionClientDyn;
23//! use rig_extra::rand_agent::RandAgentBuilder;
24//! #[tokio::main]
25//! async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
26//!     // 创建多个客户端
27//!
28//!     let client1 = Client::from_env();
29//!     let client2 = Client::from_env();
30//!
31//!     // 创建 agent
32//!     let agent1 = client1.agent("glm-4-flash").build();
33//!     let agent2 = client2.agent("glm-4-flash").build();
34//!
35//!     // 使用构建器创建 RandAgent,设置最大失败次数
36//!     let mut rand_agent = RandAgentBuilder::new()
37//!         .max_failures(3) // 连续失败3次后标记为无效
38//!         .add_agent(agent1, "bigmodel".to_string(), "glm-4-flash".to_string())
39//!         .add_agent(agent2, "bigmodel".to_string(), "glm-4-flash".to_string())
40//!         .build();
41//!
42//!     // 发送消息,会随机选择一个有效代理
43//!     let response = rand_agent.prompt("Hello!").await?;
44//!     println!("Response: {}", response);
45//!
46//!     // 查看失败统计
47//!     let stats = rand_agent.failure_stats();
48//!     println!("Failure stats: {:?}", stats);
49//!
50//!     Ok(())
51//! }
52//! ```
53
54use rand::Rng;
55use rig::agent::{Agent, AgentBuilder};
56use rig::client::builder::BoxAgent;
57use rig::completion::Prompt;
58use rig::client::completion::CompletionModelHandle;
59
60
61/// Agent状态,包含agent实例和失败计数
62pub struct AgentState<'a> {
63    agent: BoxAgent<'a>,
64    provider: String,
65    model: String,
66    failure_count: u32,
67    max_failures: u32,
68}
69
70impl<'a> AgentState<'a> {
71    fn new(agent: BoxAgent<'a>, provider: String, model: String, max_failures: u32) -> Self {
72        Self {
73            agent,
74            provider,
75            model,
76            failure_count: 0,
77            max_failures,
78        }
79    }
80
81    fn is_valid(&self) -> bool {
82        self.failure_count < self.max_failures
83    }
84
85    fn record_failure(&mut self) {
86        self.failure_count += 1;
87    }
88
89    fn record_success(&mut self) {
90        self.failure_count = 0;
91    }
92}
93
94/// A wrapper struct that holds multiple agents and randomly selects one for each invocation
95pub struct RandAgent<'a> {
96    agents: Vec<AgentState<'a>>,
97}
98
99impl<'a> RandAgent<'a> {
100    /// Create a new RandAgent with the given agents
101    pub fn new(agents: Vec<(BoxAgent<'a>, String, String)>) -> Self {
102        Self::with_max_failures(agents, 3) // 默认最大失败次数为3
103    }
104
105    /// Create a new RandAgent with custom max failure count
106    pub fn with_max_failures(agents: Vec<(BoxAgent<'a>, String, String)>, max_failures: u32) -> Self {
107        let agent_states = agents
108            .into_iter()
109            .map(|(agent, provider, model)| AgentState::new(agent, provider, model, max_failures))
110            .collect();
111        Self {
112            agents: agent_states,
113        }
114    }
115
116    /// Create a RandAgent from multiple AgentBuilders
117    #[deprecated(note = "请使用 RandAgentBuilder::add_agent/add_builder 方式构建并传递 provider/model")]
118    pub fn from_builders(_builders: Vec<AgentBuilder<CompletionModelHandle<'a>>>) -> Self {
119        unimplemented!("from_builders 已废弃,请使用 RandAgentBuilder::add_agent/add_builder 方式");
120    }
121
122    /// Create a RandAgent from multiple AgentBuilders with custom max failure count
123    #[deprecated(note = "请使用 RandAgentBuilder::add_agent/add_builder 方式构建并传递 provider/model")]
124    pub fn from_builders_with_max_failures(
125        _builders: Vec<AgentBuilder<CompletionModelHandle<'a>>>,
126        _max_failures: u32,
127    ) -> Self {
128        unimplemented!("from_builders_with_max_failures 已废弃,请使用 RandAgentBuilder::add_agent/add_builder 方式");
129    }
130
131    /// Add an agent to the collection
132    pub fn add_agent(&mut self, agent: BoxAgent<'a>, provider: String, model: String) {
133        self.agents.push(AgentState::new(agent, provider, model, 3)); // 使用默认最大失败次数
134    }
135
136    /// Add an agent to the collection with custom max failure count
137    pub fn add_agent_with_max_failures(&mut self, agent: BoxAgent<'a>, provider: String, model: String, max_failures: u32) {
138        self.agents.push(AgentState::new(agent, provider, model, max_failures));
139    }
140
141    /// Get the number of valid agents
142    pub fn len(&self) -> usize {
143        self.agents.iter().filter(|state| state.is_valid()).count()
144    }
145
146    /// Get the total number of agents (including invalid ones)
147    pub fn total_len(&self) -> usize {
148        self.agents.len()
149    }
150
151    /// Check if there are any valid agents
152    pub fn is_empty(&self) -> bool {
153        self.len() == 0
154    }
155
156    /// Get a random valid agent from the collection
157    async fn get_random_valid_agent(&mut self) -> Option<&mut AgentState<'a>> {
158        let valid_indices: Vec<usize> = self
159            .agents
160            .iter()
161            .enumerate()
162            .filter(|(_, state)| state.is_valid())
163            .map(|(i, _)| i)
164            .collect();
165
166        if valid_indices.is_empty() {
167            return None;
168        }
169
170        let mut rng = rand::rng();
171        let random_index = rng.random_range(0..valid_indices.len());
172        let agent_index = valid_indices[random_index];
173        self.agents.get_mut(agent_index)
174    }
175
176    /// Prompt a random valid agent with the given message
177    pub async fn prompt(
178        &mut self,
179        message: &str,
180    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
181        let agent_state = self
182            .get_random_valid_agent()
183            .await
184            .ok_or("No valid agents available")?;
185
186        // 打印使用的provider和model
187        tracing::info!("Using provider: {}, model: {}", agent_state.provider, agent_state.model);
188        match agent_state.agent.prompt(message).await {
189            Ok(response) => {
190                agent_state.record_success();
191                Ok(response)
192            }
193            Err(e) => {
194                agent_state.record_failure();
195                Err(e.into())
196            }
197        }
198    }
199
200    /// Stream a response from a random agent (not implemented due to trait compatibility issues)
201    pub async fn stream_prompt(
202        &self,
203        _message: &str,
204    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
205        Err("Streaming not implemented for RandAgent".into())
206    }
207
208    /// Get all agents (for debugging or inspection)
209    pub fn agents(&self) -> &[AgentState<'a>] {
210        &self.agents
211    }
212
213    /// Get failure statistics
214    pub fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
215        self.agents
216            .iter()
217            .enumerate()
218            .map(|(i, state)| (i, state.failure_count, state.max_failures))
219            .collect()
220    }
221
222    /// Reset failure counts for all agents
223    pub fn reset_failures(&mut self) {
224        for state in &mut self.agents {
225            state.failure_count = 0;
226        }
227    }
228}
229
230// Note: RandAgent cannot implement Clone because BoxAgent<'a> may not implement Clone
231// If you need to clone a RandAgent, you'll need to rebuild it from the original agents
232
233/// Builder for creating RandAgent instances
234pub struct RandAgentBuilder<'a> {
235    agents: Vec<(BoxAgent<'a>, String, String)>,
236    max_failures: u32,
237}
238
239impl<'a> RandAgentBuilder<'a> {
240    /// Create a new RandAgentBuilder
241    pub fn new() -> Self {
242        Self {
243            agents: Vec::new(),
244            max_failures: 3, // 默认最大失败次数
245        }
246    }
247
248    /// Set the maximum number of consecutive failures before marking an agent as invalid
249    pub fn max_failures(mut self, max_failures: u32) -> Self {
250        self.max_failures = max_failures;
251        self
252    }
253
254    /// Add an agent to the builder
255    ///
256    /// # 参数
257    /// - agent: 代理实例
258    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
259    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
260    pub fn add_agent(mut self, agent: BoxAgent<'a>, provider_name: String, model_name: String) -> Self {
261        self.agents.push((agent, provider_name, model_name));
262        self
263    }
264
265    /// Add an agent from an AgentBuilder
266    ///
267    /// # 参数
268    /// - builder: AgentBuilder 实例
269    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
270    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
271    ///
272    /// 推荐优先使用 add_agent,add_builder 适用于直接传 AgentBuilder 的场景。
273    pub fn add_builder(mut self, builder: Agent<CompletionModelHandle<'a>>, provider_name: &str, model_name: &str) -> Self {
274        self.agents.push((builder, provider_name.to_string(), model_name.to_string()));
275        self
276    }
277
278    /// Build the RandAgent
279    pub fn build(self) -> RandAgent<'a> {
280        RandAgent::with_max_failures(self.agents, self.max_failures)
281    }
282}
283
284impl<'a> Default for RandAgentBuilder<'a> {
285    fn default() -> Self {
286        Self::new()
287    }
288}