rig_extra/
rand_agent.rs

1//!
2//! RandAgent - 多代理随机选择器
3//!
4//! 该模块提供了一个 `RandAgent` 结构体,可以包装多个 AI 代理,
5//! 每次调用时随机选择一个代理来执行任务。
6//!
7//! ## 特性
8//!
9//! - 支持任意数量的 AI 代理
10//! - 每次调用时随机选择一个有效代理
11//! - 自动记录代理失败次数,连续失败达到阈值后标记为无效
12//! - 成功响应时自动重置失败计数
13//! - 线程安全的随机数生成
14//! - 提供构建器模式
15//! - 支持失败统计和重置功能
16//!
17//! ## 使用示例
18//!
19//! ```rust
20//! use rig_extra::extra_providers::{bigmodel::Client};
21//! use rig::client::ProviderClient;
22//! use rig::client::completion::CompletionClientDyn;
23//! use rig_extra::rand_agent::RandAgentBuilder;
24//! use rig_extra::error::RandAgentError;
25//!
26//! #[tokio::main]
27//! async fn main() -> Result<(), RandAgentError> {
28//!     use rig::completion::Prompt;
29//! // 创建多个客户端
30//!     
31//! use rig_extra::error::RandAgentError;
32//! let client1 = Client::from_env();
33//!     let client2 = Client::from_env();
34//!
35//!     // 创建 agent
36//!     let agent1 = client1.agent("glm-4-flash").build();
37//!     let agent2 = client2.agent("glm-4-flash").build();
38//!
39//!     // 使用构建器创建 RandAgent,设置最大失败次数
40//!     let mut rand_agent = RandAgentBuilder::new()
41//!         .max_failures(3) // 连续失败3次后标记为无效
42//!         .add_agent(agent1,1, "bigmodel".to_string(), "glm-4-flash".to_string())
43//!         .add_agent(agent2, 2,"bigmodel".to_string(), "glm-4-flash".to_string())
44//!         .build();
45//!
46//!     // 发送消息,会随机选择一个有效代理
47//!     let response = rand_agent.prompt("Hello!").await?;
48//!     println!("Response: {}", response);
49//!
50//!     // 查看失败统计
51//!     let stats = rand_agent.failure_stats().await;
52//!     println!("Failure stats: {:?}", stats);
53//!
54//!     Ok(())
55//! }
56//! ```
57
58use rand::Rng;
59use rig::agent::{Agent};
60use rig::client::builder::BoxAgent;
61use rig::completion::{Message, Prompt, PromptError};
62use rig::client::completion::CompletionModelHandle;
63use tokio::sync::Mutex;
64
65
66/// Agent状态,包含agent实例和失败计数
67pub struct AgentState<'a> {
68    id: i32,
69    agent: BoxAgent<'a>,
70    provider: String,
71    model: String,
72    failure_count: u32,
73    max_failures: u32,
74}
75
76impl<'a> AgentState<'a> {
77    fn new(agent: BoxAgent<'a>,id:i32, provider: String, model: String, max_failures: u32) -> Self {
78        Self {
79            id,
80            agent,
81            provider,
82            model,
83            failure_count: 0,
84            max_failures,
85        }
86    }
87
88    fn is_valid(&self) -> bool {
89        self.failure_count < self.max_failures
90    }
91
92    fn record_failure(&mut self) {
93        self.failure_count += 1;
94    }
95
96    fn record_success(&mut self) {
97        self.failure_count = 0;
98    }
99}
100
101/// 代理失效回调类型,减少类型复杂度
102pub type OnRandAgentInvalidCallback = Option<Box<dyn Fn(i32) + Send + Sync + 'static>>;
103
104/// 包装多个代理的结构体,每次调用时随机选择一个代理
105pub struct RandAgent<'a> {
106    agents: Mutex<Vec<AgentState<'a>>>,
107    on_agent_invalid: OnRandAgentInvalidCallback,
108}
109
110impl Prompt for RandAgent<'_> {
111    #[allow(refining_impl_trait)]
112    async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
113        let mut agents = self.agents.lock().await;
114        let agent_state = Self::get_random_valid_agent(&mut agents)
115            .await
116            .ok_or(PromptError::MaxDepthError {
117                max_depth: 0,
118                chat_history: vec![],
119                prompt: "没有有效agent".into(),
120            })?;
121
122        tracing::info!("Using provider: {}, model: {}", agent_state.provider, agent_state.model);
123        match agent_state.agent.prompt(prompt).await {
124            Ok(content) => {
125                agent_state.record_success();
126                Ok(content)
127            }
128            Err(e) => {
129                agent_state.record_failure();
130                if !agent_state.is_valid() {
131                    if let Some(cb) = &self.on_agent_invalid {
132                        cb(agent_state.id);
133                    }
134                }
135                Err(e)
136            }
137        }
138    }
139}
140
141impl<'a> RandAgent<'a> {
142    /// 使用给定的代理创建新的 RandAgent
143    pub fn new(agents: Vec<(BoxAgent<'a>, i32, String, String)>) -> Self {
144        Self::with_max_failures_and_callback(agents, 3, None)
145    }
146
147    /// 使用自定义最大失败次数和回调创建新的 RandAgent
148    pub fn with_max_failures_and_callback(
149        agents: Vec<(BoxAgent<'a>, i32, String, String)>,
150        max_failures: u32,
151        on_agent_invalid: OnRandAgentInvalidCallback,
152    ) -> Self {
153        let agent_states = agents
154            .into_iter()
155            .map(|(agent, id, provider, model)| AgentState::new(agent, id, provider, model, max_failures))
156            .collect();
157        Self {
158            agents: Mutex::new(agent_states),
159            on_agent_invalid,
160        }
161    }
162
163    /// 使用自定义最大失败次数创建新的 RandAgent
164    pub fn with_max_failures(agents: Vec<(BoxAgent<'a>, i32,String, String)>, max_failures: u32) -> Self {
165        Self::with_max_failures_and_callback(agents, max_failures, None)
166    }
167
168    /// 设置 agent 失效时的回调
169    pub fn set_on_agent_invalid<F>(&mut self, callback: F)
170    where
171        F: Fn(i32) + Send + Sync + 'static,
172    {
173        self.on_agent_invalid = Some(Box::new(callback));
174    }
175
176    
177    /// 向集合中添加代理
178    pub async fn add_agent(&self, agent: BoxAgent<'a>, id: i32, provider: String, model: String) {
179        self.agents.lock().await.push(AgentState::new(agent, id, provider, model, 3)); // 使用默认最大失败次数
180    }
181
182    /// 使用自定义最大失败次数向集合中添加代理
183    pub async fn add_agent_with_max_failures(&self, agent: BoxAgent<'a>, id: i32, provider: String, model: String, max_failures: u32) {
184        self.agents.lock().await.push(AgentState::new(agent, id, provider, model, max_failures));
185    }
186
187    /// 获取有效代理的数量
188    pub async fn len(&self) -> usize {
189        self.agents.lock().await.iter().filter(|state| state.is_valid()).count()
190    }
191
192    /// 获取代理总数(包括无效的)
193    pub async fn total_len(&self) -> usize {
194        self.agents.lock().await.len()
195    }
196
197    /// 检查是否有有效代理
198    pub async fn is_empty(&self) -> bool {
199        self.len().await == 0
200    }
201
202    /// 从集合中获取一个随机有效代理
203    async fn get_random_valid_agent<'b>(agents: &'b mut Vec<AgentState<'a>>) -> Option<&'b mut AgentState<'a>> {
204        let valid_indices: Vec<usize> = agents
205            .iter()
206            .enumerate()
207            .filter(|(_, state)| state.is_valid())
208            .map(|(i, _)| i)
209            .collect();
210
211        if valid_indices.is_empty() {
212            return None;
213        }
214
215        let mut rng = rand::rng();
216        let random_index = rng.random_range(0..valid_indices.len());
217        let agent_index = valid_indices[random_index];
218        agents.get_mut(agent_index)
219    }
220    
221
222    /// 获取失败统计信息
223    pub async fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
224        self.agents
225            .lock()
226            .await
227            .iter()
228            .enumerate()
229            .map(|(i, state)| (i, state.failure_count, state.max_failures))
230            .collect()
231    }
232
233    /// 重置所有代理的失败计数
234    pub async fn reset_failures(&self) {
235        for state in self.agents.lock().await.iter_mut() {
236            state.failure_count = 0;
237        }
238    }
239}
240
241
242
243/// 用于创建 RandAgent 实例的构建器
244pub struct RandAgentBuilder<'a> {
245    agents: Vec<(BoxAgent<'a>, i32, String, String)>,
246    max_failures: u32,
247    on_agent_invalid: Option<Box<dyn Fn(i32) + Send + Sync + 'static>>,
248}
249
250impl<'a> RandAgentBuilder<'a> {
251    /// 创建新的 RandAgentBuilder
252    pub fn new() -> Self {
253        Self {
254            agents: Vec::new(),
255            max_failures: 3, // 默认最大失败次数
256            on_agent_invalid: None,
257        }
258    }
259
260    /// 设置标记代理为无效前的最大连续失败次数
261    pub fn max_failures(mut self, max_failures: u32) -> Self {
262        self.max_failures = max_failures;
263        self
264    }
265
266    /// 设置 agent 失效时的回调
267    pub fn on_agent_invalid<F>(mut self, callback: F) -> Self
268    where
269        F: Fn(i32) + Send + Sync + 'static,
270    {
271        self.on_agent_invalid = Some(Box::new(callback));
272        self
273    }
274
275    /// 向构建器添加代理
276    ///
277    /// # 参数
278    /// - agent: 代理实例
279    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
280    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
281    pub fn add_agent(mut self, agent: BoxAgent<'a>, id: i32, provider_name: String, model_name: String) -> Self {
282        self.agents.push((agent, id, provider_name, model_name));
283        self
284    }
285
286    /// 从 AgentBuilder 添加代理
287    ///
288    /// # 参数
289    /// - builder: AgentBuilder 实例
290    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
291    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
292    ///
293    /// 推荐优先使用 add_agent,add_builder 适用于直接传 AgentBuilder 的场景。
294    pub fn add_builder(mut self, builder: Agent<CompletionModelHandle<'a>>, id: i32, provider_name: &str, model_name: &str) -> Self {
295        self.agents.push((builder, id, provider_name.to_string(), model_name.to_string()));
296        self
297    }
298
299    /// 构建 RandAgent
300    pub fn build(self) -> RandAgent<'a> {
301        RandAgent::with_max_failures_and_callback(self.agents, self.max_failures, self.on_agent_invalid)
302    }
303}
304
305impl<'a> Default for RandAgentBuilder<'a> {
306    fn default() -> Self {
307        Self::new()
308    }
309}