rig_extra/
rand_agent.rs

1//!
2//! RandAgent - 多代理随机选择器
3//!
4//! 该模块提供了一个 `RandAgent` 结构体,可以包装多个 AI 代理,
5//! 每次调用时随机选择一个代理来执行任务。
6//!
7//! ## 特性
8//!
9//! - 支持任意数量的 AI 代理
10//! - 每次调用时随机选择一个有效代理
11//! - 自动记录代理失败次数,连续失败达到阈值后标记为无效
12//! - 成功响应时自动重置失败计数
13//! - 线程安全的随机数生成
14//! - 提供构建器模式
15//! - 支持失败统计和重置功能
16//!
17//! ## 使用示例
18//!
19//! ```rust
20//! use rig_extra::extra_providers::{bigmodel::Client};
21//! use rig::client::ProviderClient;
22//! use rig::client::completion::CompletionClientDyn;
23//! use rig_extra::rand_agent::RandAgentBuilder;
24//! use rig_extra::error::RandAgentError;
25//!
26//! #[tokio::main]
27//! async fn main() -> Result<(), RandAgentError> {
28//!     use rig::completion::Prompt;
29//! // 创建多个客户端
30//!     
31//! use rig_extra::error::RandAgentError;
32//! let client1 = Client::from_env();
33//!     let client2 = Client::from_env();
34//!
35//!     // 创建 agent
36//!     let agent1 = client1.agent("glm-4-flash").build();
37//!     let agent2 = client2.agent("glm-4-flash").build();
38//!
39//!     // 使用构建器创建 RandAgent,设置最大失败次数
40//!     let mut rand_agent = RandAgentBuilder::new()
41//!         .max_failures(3) // 连续失败3次后标记为无效
42//!         .add_agent(agent1,1, "bigmodel".to_string(), "glm-4-flash".to_string())
43//!         .add_agent(agent2, 2,"bigmodel".to_string(), "glm-4-flash".to_string())
44//!         .build();
45//!
46//!     // 发送消息,会随机选择一个有效代理
47//!     let response = rand_agent.prompt("Hello!").await?;
48//!     println!("Response: {}", response);
49//!
50//!     // 查看失败统计
51//!     let stats = rand_agent.failure_stats().await;
52//!     println!("Failure stats: {:?}", stats);
53//!
54//!     Ok(())
55//! }
56//! ```
57
58use rand::Rng;
59use rig::agent::{Agent};
60use rig::client::builder::BoxAgent;
61use rig::completion::{Message, Prompt, PromptError};
62use rig::client::completion::CompletionModelHandle;
63use tokio::sync::Mutex;
64use crate::AgentInfo;
65
66/// Agent状态,包含agent实例和失败计数
67pub struct AgentState<'a> {
68    id: i32,
69    agent: BoxAgent<'a>,
70    pub info: AgentInfo,
71}
72
73impl<'a> AgentState<'a> {
74    fn new(agent: BoxAgent<'a>,id:i32, provider: String, model: String, max_failures: u32) -> Self {
75        Self {
76            id,
77            agent,
78            info: AgentInfo{
79                id,
80                provider,
81                model,
82                failure_count: 0,
83                max_failures,
84            }
85        }
86    }
87
88    fn is_valid(&self) -> bool {
89        self.info.failure_count < self.info.max_failures
90    }
91
92    fn record_failure(&mut self) {
93        self.info.failure_count += 1;
94    }
95
96    fn record_success(&mut self) {
97        self.info.failure_count = 0;
98    }
99}
100
101/// 代理失效回调类型,减少类型复杂度
102pub type OnRandAgentInvalidCallback = Option<Box<dyn Fn(i32) + Send + Sync + 'static>>;
103
104/// 包装多个代理的结构体,每次调用时随机选择一个代理
105#[deprecated(since = "0.7.0", note = "使用 `ThreadSafeRandAgent`")]
106pub struct RandAgent<'a> {
107    agents: Mutex<Vec<AgentState<'a>>>,
108    on_agent_invalid: OnRandAgentInvalidCallback,
109}
110
111impl Prompt for RandAgent<'_> {
112    #[allow(refining_impl_trait)]
113    async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
114        let mut agents = self.agents.lock().await;
115        let agent_state = Self::get_random_valid_agent(&mut agents)
116            .await
117            .ok_or(PromptError::MaxDepthError {
118                max_depth: 0,
119                chat_history: vec![],
120                prompt: "没有有效agent".into(),
121            })?;
122
123        tracing::info!("Using provider: {}, model: {}", agent_state.info.provider, agent_state.info.model);
124        match agent_state.agent.prompt(prompt).await {
125            Ok(content) => {
126                agent_state.record_success();
127                Ok(content)
128            }
129            Err(e) => {
130                agent_state.record_failure();
131                if !agent_state.is_valid() {
132                    if let Some(cb) = &self.on_agent_invalid {
133                        cb(agent_state.id);
134                    }
135                }
136                Err(e)
137            }
138        }
139    }
140}
141
142impl<'a> RandAgent<'a> {
143    /// 使用给定的代理创建新的 RandAgent
144    pub fn new(agents: Vec<(BoxAgent<'a>, i32, String, String)>) -> Self {
145        Self::with_max_failures_and_callback(agents, 3, None)
146    }
147
148    /// 使用自定义最大失败次数和回调创建新的 RandAgent
149    pub fn with_max_failures_and_callback(
150        agents: Vec<(BoxAgent<'a>, i32, String, String)>,
151        max_failures: u32,
152        on_agent_invalid: OnRandAgentInvalidCallback,
153    ) -> Self {
154        let agent_states = agents
155            .into_iter()
156            .map(|(agent, id, provider, model)| AgentState::new(agent, id, provider, model, max_failures))
157            .collect();
158        Self {
159            agents: Mutex::new(agent_states),
160            on_agent_invalid,
161        }
162    }
163
164    /// 使用自定义最大失败次数创建新的 RandAgent
165    pub fn with_max_failures(agents: Vec<(BoxAgent<'a>, i32,String, String)>, max_failures: u32) -> Self {
166        Self::with_max_failures_and_callback(agents, max_failures, None)
167    }
168
169    /// 设置 agent 失效时的回调
170    pub fn set_on_agent_invalid<F>(&mut self, callback: F)
171    where
172        F: Fn(i32) + Send + Sync + 'static,
173    {
174        self.on_agent_invalid = Some(Box::new(callback));
175    }
176
177    
178    /// 向集合中添加代理
179    pub async fn add_agent(&self, agent: BoxAgent<'a>, id: i32, provider: String, model: String) {
180        self.agents.lock().await.push(AgentState::new(agent, id, provider, model, 3)); // 使用默认最大失败次数
181    }
182
183    /// 使用自定义最大失败次数向集合中添加代理
184    pub async fn add_agent_with_max_failures(&self, agent: BoxAgent<'a>, id: i32, provider: String, model: String, max_failures: u32) {
185        self.agents.lock().await.push(AgentState::new(agent, id, provider, model, max_failures));
186    }
187
188    /// 获取有效代理的数量
189    pub async fn len(&self) -> usize {
190        self.agents.lock().await.iter().filter(|state| state.is_valid()).count()
191    }
192
193    /// 获取代理总数(包括无效的)
194    pub async fn total_len(&self) -> usize {
195        self.agents.lock().await.len()
196    }
197
198    /// 检查是否有有效代理
199    pub async fn is_empty(&self) -> bool {
200        self.len().await == 0
201    }
202
203    /// 从集合中获取一个随机有效代理
204    async fn get_random_valid_agent<'b>(agents: &'b mut Vec<AgentState<'a>>) -> Option<&'b mut AgentState<'a>> {
205        let valid_indices: Vec<usize> = agents
206            .iter()
207            .enumerate()
208            .filter(|(_, state)| state.is_valid())
209            .map(|(i, _)| i)
210            .collect();
211
212        if valid_indices.is_empty() {
213            return None;
214        }
215
216        let mut rng = rand::rng();
217        let random_index = rng.random_range(0..valid_indices.len());
218        let agent_index = valid_indices[random_index];
219        agents.get_mut(agent_index)
220    }
221    
222
223    /// 获取失败统计信息
224    pub async fn failure_stats(&self) -> Vec<(usize, u32, u32)> {
225        self.agents
226            .lock()
227            .await
228            .iter()
229            .enumerate()
230            .map(|(i, state)| (i, state.info.failure_count, state.info.max_failures))
231            .collect()
232    }
233
234    /// 重置所有代理的失败计数
235    pub async fn reset_failures(&self) {
236        for state in self.agents.lock().await.iter_mut() {
237            state.info.failure_count = 0;
238        }
239    }
240}
241
242
243
244/// 用于创建 RandAgent 实例的构建器
245pub struct RandAgentBuilder<'a> {
246    pub(crate) agents: Vec<(BoxAgent<'a>, i32, String, String)>,
247    max_failures: u32,
248    on_agent_invalid: Option<Box<dyn Fn(i32) + Send + Sync + 'static>>,
249}
250
251impl<'a> RandAgentBuilder<'a> {
252    /// 创建新的 RandAgentBuilder
253    pub fn new() -> Self {
254        Self {
255            agents: Vec::new(),
256            max_failures: 3, // 默认最大失败次数
257            on_agent_invalid: None,
258        }
259    }
260
261    /// 设置标记代理为无效前的最大连续失败次数
262    pub fn max_failures(mut self, max_failures: u32) -> Self {
263        self.max_failures = max_failures;
264        self
265    }
266
267    /// 设置 agent 失效时的回调
268    pub fn on_agent_invalid<F>(mut self, callback: F) -> Self
269    where
270        F: Fn(i32) + Send + Sync + 'static,
271    {
272        self.on_agent_invalid = Some(Box::new(callback));
273        self
274    }
275
276    /// 向构建器添加代理
277    ///
278    /// # 参数
279    /// - agent: 代理实例
280    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
281    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
282    pub fn add_agent(mut self, agent: BoxAgent<'a>, id: i32, provider_name: String, model_name: String) -> Self {
283        self.agents.push((agent, id, provider_name, model_name));
284        self
285    }
286
287    /// 从 AgentBuilder 添加代理
288    ///
289    /// # 参数
290    /// - builder: AgentBuilder 实例
291    /// - provider_name: 提供方名称(如 openai、bigmodel 等)
292    /// - model_name: 模型名称(如 gpt-3.5、glm-4-flash 等)
293    ///
294    /// 推荐优先使用 add_agent,add_builder 适用于直接传 AgentBuilder 的场景。
295    pub fn add_builder(mut self, builder: Agent<CompletionModelHandle<'a>>, id: i32, provider_name: &str, model_name: &str) -> Self {
296        self.agents.push((builder, id, provider_name.to_string(), model_name.to_string()));
297        self
298    }
299
300    /// 构建 RandAgent
301    pub fn build(self) -> RandAgent<'a> {
302        RandAgent::with_max_failures_and_callback(self.agents, self.max_failures, self.on_agent_invalid)
303    }
304}
305
306impl<'a> Default for RandAgentBuilder<'a> {
307    fn default() -> Self {
308        Self::new()
309    }
310}