Skip to main content

mofa_foundation/llm/
multi_agent.rs

//! Multi-agent collaboration patterns.
//!
//! Provides higher-level patterns for coordinating several agents:
//!
//! - **Chain**: agents run one after another; each agent's output is the next agent's input
//! - **Parallel**: every agent handles the same input and the results are aggregated
//! - **Debate**: agents take turns over several rounds, then the discussion is summarized
//! - **Supervised**: a supervisor agent evaluates the other agents' outputs
//! - **MapReduce**: every agent processes the input, then one agent reduces the results
//!
//! # Example
//!
//! ```rust,ignore
//! use mofa_foundation::llm::multi_agent::{AgentTeam, TeamPattern};
//!
//! // Build an agent team
//! let team = AgentTeam::new("writing-team")
//!     .add_member("analyst", analyst_agent)
//!     .add_member("writer", writer_agent)
//!     .add_member("editor", editor_agent)
//!     .with_pattern(TeamPattern::Chain)
//!     .build();
//!
//! let result = team.run("Analyze and write about Rust").await?;
//! ```
26
27use super::agent::LLMAgent;
28use super::types::{LLMError, LLMResult};
29use std::collections::HashMap;
30use std::sync::Arc;
31
/// Collaboration pattern used by an [`AgentTeam`].
///
/// Defaults to [`TeamPattern::Chain`], matching the team builder's default.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum TeamPattern {
    /// Members run one after another; each member's output is the next member's input.
    #[default]
    Chain,
    /// Every member answers the same input; the answers are concatenated and, when an
    /// aggregate prompt is configured, merged by the first member.
    Parallel,
    /// Members take turns speaking, then the first member summarizes the transcript.
    Debate {
        /// Number of rounds; every member speaks once per round.
        max_rounds: usize,
    },
    /// Worker members answer the input and a designated supervisor evaluates them.
    Supervised,
    /// Every member processes the input (map), then the first member merges the
    /// results (reduce).
    MapReduce,
    /// Custom pattern; currently executed the same way as [`TeamPattern::Chain`].
    Custom,
}
51
/// Identity and prompting configuration of a team member.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct AgentRole {
    /// Unique identifier used to look the member up inside a team.
    pub id: String,
    /// Human-readable name (used, e.g., as the speaker label in debates and in the
    /// headers of supervised results).
    pub name: String,
    /// Free-form description of the role.
    ///
    /// Descriptive metadata only: prompt construction in this module does not read
    /// it, so it is not sent to the model. Configure the underlying agent's system
    /// prompt separately if the role should shape its behavior.
    pub description: String,
    /// Optional prompt template. The `{input}` and `{context}` placeholders are filled
    /// in when the member executes a task.
    pub prompt_template: Option<String>,
}

impl AgentRole {
    /// Creates a role with the given id and display name, an empty description and no
    /// prompt template.
    #[must_use]
    pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            name: name.into(),
            description: String::new(),
            prompt_template: None,
        }
    }

    /// Sets the role description.
    #[must_use]
    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
        self.description = desc.into();
        self
    }

    /// Sets the prompt template (supports `{input}` and `{context}` placeholders).
    #[must_use]
    pub fn with_template(mut self, template: impl Into<String>) -> Self {
        self.prompt_template = Some(template.into());
        self
    }
}
85
86/// Agent 成员
87pub struct AgentMember {
88    /// 角色信息
89    pub role: AgentRole,
90    /// Agent 实例
91    pub agent: Arc<LLMAgent>,
92}
93
94impl AgentMember {
95    pub fn new(id: impl Into<String>, agent: Arc<LLMAgent>) -> Self {
96        let id = id.into();
97        Self {
98            role: AgentRole::new(&id, &id),
99            agent,
100        }
101    }
102
103    pub fn with_role(mut self, role: AgentRole) -> Self {
104        self.role = role;
105        self
106    }
107
108    /// 执行任务
109    pub async fn execute(&self, input: &str, context: Option<&str>) -> LLMResult<String> {
110        let prompt = if let Some(ref template) = self.role.prompt_template {
111            let mut p = template.replace("{input}", input);
112            if let Some(ctx) = context {
113                p = p.replace("{context}", ctx);
114            }
115            p
116        } else if let Some(ctx) = context {
117            format!("Context:\n{}\n\nTask:\n{}", ctx, input)
118        } else {
119            input.to_string()
120        };
121
122        self.agent.ask(&prompt).await
123    }
124}
125
126/// Agent 团队
127pub struct AgentTeam {
128    /// 团队 ID
129    pub id: String,
130    /// 团队名称
131    pub name: String,
132    /// 成员列表
133    members: Vec<AgentMember>,
134    /// 成员映射(按 ID)
135    member_map: HashMap<String, usize>,
136    /// 协作模式
137    pattern: TeamPattern,
138    /// 监督者 ID(用于 Supervised 模式)
139    supervisor_id: Option<String>,
140    /// 聚合提示词(用于并行和 MapReduce 模式)
141    aggregate_prompt: Option<String>,
142}
143
144impl AgentTeam {
145    /// 创建新的 Agent 团队
146    pub fn new(id: impl Into<String>) -> AgentTeamBuilder {
147        AgentTeamBuilder::new(id)
148    }
149
150    /// 链式执行
151    async fn run_chain(&self, input: &str) -> LLMResult<String> {
152        let mut current_output = input.to_string();
153
154        for member in &self.members {
155            current_output = member.execute(&current_output, None).await?;
156        }
157
158        Ok(current_output)
159    }
160
161    /// 并行执行
162    async fn run_parallel(&self, input: &str) -> LLMResult<String> {
163        let mut results = Vec::new();
164
165        // 由于 Agent 包含不可跨线程的闭包,这里顺序执行
166        // 未来可以通过重构 Agent 来实现真正的并行
167        for member in &self.members {
168            let result = member.execute(input, None).await?;
169            results.push((member.role.id.clone(), result));
170        }
171
172        // 聚合结果
173        let aggregated = results
174            .iter()
175            .map(|(id, result)| format!("=== {} ===\n{}", id, result))
176            .collect::<Vec<_>>()
177            .join("\n\n");
178
179        // 如果有聚合提示词,使用第一个 Agent 进行聚合
180        if let Some(ref aggregate_prompt) = self.aggregate_prompt
181            && let Some(first_member) = self.members.first()
182        {
183            let prompt = aggregate_prompt
184                .replace("{results}", &aggregated)
185                .replace("{input}", input);
186            return first_member.agent.ask(&prompt).await;
187        }
188
189        Ok(aggregated)
190    }
191
192    /// 辩论执行
193    async fn run_debate(&self, input: &str, max_rounds: usize) -> LLMResult<String> {
194        if self.members.len() < 2 {
195            return Err(LLMError::Other(
196                "Debate requires at least 2 agents".to_string(),
197            ));
198        }
199
200        let mut context = format!("Initial topic: {}\n\n", input);
201        let mut last_response = String::new();
202
203        for round in 0..max_rounds {
204            for (i, member) in self.members.iter().enumerate() {
205                let prompt = format!(
206                    "Round {}, Speaker {}: {}\n\n\
207                    Previous discussion:\n{}\n\n\
208                    Please provide your perspective. Be constructive and build on previous points.",
209                    round + 1,
210                    i + 1,
211                    member.role.name,
212                    context
213                );
214
215                let response = member.execute(&prompt, None).await?;
216                context.push_str(&format!(
217                    "[{} - Round {}]:\n{}\n\n",
218                    member.role.name,
219                    round + 1,
220                    response
221                ));
222                last_response = response;
223            }
224        }
225
226        // 最后总结
227        if let Some(first_member) = self.members.first() {
228            let summary_prompt = format!(
229                "Based on the following debate, provide a concise summary of the key points \
230                and conclusions:\n\n{}",
231                context
232            );
233            first_member.agent.ask(&summary_prompt).await
234        } else {
235            Ok(last_response)
236        }
237    }
238
239    /// 监督执行
240    async fn run_supervised(&self, input: &str) -> LLMResult<String> {
241        let supervisor_id = self.supervisor_id.as_ref().ok_or_else(|| {
242            LLMError::Other("Supervisor not specified for Supervised pattern".to_string())
243        })?;
244
245        let supervisor_idx = self
246            .member_map
247            .get(supervisor_id)
248            .ok_or_else(|| LLMError::Other(format!("Supervisor '{}' not found", supervisor_id)))?;
249
250        // 收集工作者结果
251        let mut worker_results = Vec::new();
252        for (i, member) in self.members.iter().enumerate() {
253            if i != *supervisor_idx {
254                let result = member.execute(input, None).await?;
255                worker_results.push((member.role.id.clone(), member.role.name.clone(), result));
256            }
257        }
258
259        // 让监督者评估
260        let results_text = worker_results
261            .iter()
262            .map(|(id, name, result)| format!("=== {} ({}) ===\n{}", name, id, result))
263            .collect::<Vec<_>>()
264            .join("\n\n");
265
266        let supervisor = &self.members[*supervisor_idx];
267        let eval_prompt = format!(
268            "You are the supervisor. Evaluate the following responses to the task: \"{}\"\n\n\
269            Responses:\n{}\n\n\
270            Please provide:\n\
271            1. An evaluation of each response\n\
272            2. The best response or a synthesized improved response\n\
273            3. Suggestions for improvement",
274            input, results_text
275        );
276
277        supervisor.agent.ask(&eval_prompt).await
278    }
279
280    /// MapReduce 执行
281    async fn run_map_reduce(&self, input: &str) -> LLMResult<String> {
282        // Map 阶段:每个 Agent 处理输入
283        let mut mapped_results = Vec::new();
284        for member in &self.members {
285            let result = member.execute(input, None).await?;
286            mapped_results.push((member.role.id.clone(), result));
287        }
288
289        // Reduce 阶段:聚合结果
290        let reduce_input = mapped_results
291            .iter()
292            .map(|(id, result)| format!("[{}]: {}", id, result))
293            .collect::<Vec<_>>()
294            .join("\n\n");
295
296        let reduce_prompt = if let Some(ref aggregate_prompt) = self.aggregate_prompt {
297            aggregate_prompt
298                .replace("{results}", &reduce_input)
299                .replace("{input}", input)
300        } else {
301            format!(
302                "Synthesize the following results into a coherent response:\n\n{}\n\n\
303                Original task: {}",
304                reduce_input, input
305            )
306        };
307
308        // 使用第一个 Agent 进行 reduce
309        if let Some(first_member) = self.members.first() {
310            first_member.agent.ask(&reduce_prompt).await
311        } else {
312            Ok(reduce_input)
313        }
314    }
315
316    /// 执行团队任务
317    pub async fn run(&self, input: impl Into<String>) -> LLMResult<String> {
318        let input = input.into();
319
320        match &self.pattern {
321            TeamPattern::Chain => self.run_chain(&input).await,
322            TeamPattern::Parallel => self.run_parallel(&input).await,
323            TeamPattern::Debate { max_rounds } => self.run_debate(&input, *max_rounds).await,
324            TeamPattern::Supervised => self.run_supervised(&input).await,
325            TeamPattern::MapReduce => self.run_map_reduce(&input).await,
326            TeamPattern::Custom => {
327                // 自定义模式默认使用链式
328                self.run_chain(&input).await
329            }
330        }
331    }
332
333    /// 获取成员
334    pub fn get_member(&self, id: &str) -> Option<&AgentMember> {
335        self.member_map.get(id).map(|idx| &self.members[*idx])
336    }
337
338    /// 获取所有成员 ID
339    pub fn member_ids(&self) -> Vec<&str> {
340        self.members.iter().map(|m| m.role.id.as_str()).collect()
341    }
342}
343
344/// Agent 团队构建器
345pub struct AgentTeamBuilder {
346    id: String,
347    name: String,
348    members: Vec<AgentMember>,
349    pattern: TeamPattern,
350    supervisor_id: Option<String>,
351    aggregate_prompt: Option<String>,
352}
353
354impl AgentTeamBuilder {
355    pub fn new(id: impl Into<String>) -> Self {
356        let id = id.into();
357        Self {
358            name: id.clone(),
359            id,
360            members: Vec::new(),
361            pattern: TeamPattern::Chain,
362            supervisor_id: None,
363            aggregate_prompt: None,
364        }
365    }
366
367    /// 设置名称
368    pub fn with_name(mut self, name: impl Into<String>) -> Self {
369        self.name = name.into();
370        self
371    }
372
373    /// 添加成员
374    pub fn add_member(mut self, id: impl Into<String>, agent: Arc<LLMAgent>) -> Self {
375        self.members.push(AgentMember::new(id, agent));
376        self
377    }
378
379    /// 添加带角色的成员
380    pub fn add_member_with_role(mut self, agent: Arc<LLMAgent>, role: AgentRole) -> Self {
381        let member = AgentMember::new(&role.id, agent).with_role(role);
382        self.members.push(member);
383        self
384    }
385
386    /// 设置协作模式
387    pub fn with_pattern(mut self, pattern: TeamPattern) -> Self {
388        self.pattern = pattern;
389        self
390    }
391
392    /// 设置监督者
393    pub fn with_supervisor(mut self, supervisor_id: impl Into<String>) -> Self {
394        self.supervisor_id = Some(supervisor_id.into());
395        self.pattern = TeamPattern::Supervised;
396        self
397    }
398
399    /// 设置聚合提示词
400    pub fn with_aggregate_prompt(mut self, prompt: impl Into<String>) -> Self {
401        self.aggregate_prompt = Some(prompt.into());
402        self
403    }
404
405    /// 构建团队
406    pub fn build(self) -> AgentTeam {
407        let member_map: HashMap<String, usize> = self
408            .members
409            .iter()
410            .enumerate()
411            .map(|(i, m)| (m.role.id.clone(), i))
412            .collect();
413
414        AgentTeam {
415            id: self.id,
416            name: self.name,
417            members: self.members,
418            member_map,
419            pattern: self.pattern,
420            supervisor_id: self.supervisor_id,
421            aggregate_prompt: self.aggregate_prompt,
422        }
423    }
424}
425
// ============================================================================
// Predefined teams
// ============================================================================
429
430/// 创建内容创作团队
431///
432/// 包含:研究员、写手、编辑
433pub fn content_creation_team(
434    researcher: Arc<LLMAgent>,
435    writer: Arc<LLMAgent>,
436    editor: Arc<LLMAgent>,
437) -> AgentTeam {
438    AgentTeamBuilder::new("content-creation")
439        .with_name("Content Creation Team")
440        .add_member_with_role(
441            researcher,
442            AgentRole::new("researcher", "Researcher")
443                .with_description("Research and gather information on the topic")
444                .with_template(
445                    "Research the following topic thoroughly and provide key findings:\n\n{input}",
446                ),
447        )
448        .add_member_with_role(
449            writer,
450            AgentRole::new("writer", "Writer")
451                .with_description("Write engaging content based on research")
452                .with_template(
453                    "Based on the following research, write an engaging article:\n\n{input}",
454                ),
455        )
456        .add_member_with_role(
457            editor,
458            AgentRole::new("editor", "Editor")
459                .with_description("Edit and polish the content")
460                .with_template(
461                    "Edit and improve the following article for clarity and engagement:\n\n{input}",
462                ),
463        )
464        .with_pattern(TeamPattern::Chain)
465        .build()
466}
467
468/// 创建代码审查团队
469///
470/// 包含:安全审查员、性能审查员、风格审查员、监督者
471pub fn code_review_team(
472    security_reviewer: Arc<LLMAgent>,
473    performance_reviewer: Arc<LLMAgent>,
474    style_reviewer: Arc<LLMAgent>,
475    supervisor: Arc<LLMAgent>,
476) -> AgentTeam {
477    AgentTeamBuilder::new("code-review")
478        .with_name("Code Review Team")
479        .add_member_with_role(
480            security_reviewer,
481            AgentRole::new("security", "Security Reviewer")
482                .with_description("Review code for security vulnerabilities"),
483        )
484        .add_member_with_role(
485            performance_reviewer,
486            AgentRole::new("performance", "Performance Reviewer")
487                .with_description("Review code for performance issues"),
488        )
489        .add_member_with_role(
490            style_reviewer,
491            AgentRole::new("style", "Style Reviewer")
492                .with_description("Review code for style and best practices"),
493        )
494        .add_member_with_role(
495            supervisor,
496            AgentRole::new("supervisor", "Lead Reviewer")
497                .with_description("Synthesize reviews and provide final feedback"),
498        )
499        .with_supervisor("supervisor")
500        .build()
501}
502
503/// 创建辩论团队
504///
505/// 两个 Agent 进行辩论
506pub fn debate_team(agent1: Arc<LLMAgent>, agent2: Arc<LLMAgent>, max_rounds: usize) -> AgentTeam {
507    AgentTeamBuilder::new("debate")
508        .with_name("Debate Team")
509        .add_member_with_role(
510            agent1,
511            AgentRole::new("debater1", "Debater 1")
512                .with_description("Present and defend your position"),
513        )
514        .add_member_with_role(
515            agent2,
516            AgentRole::new("debater2", "Debater 2")
517                .with_description("Present an alternative perspective"),
518        )
519        .with_pattern(TeamPattern::Debate { max_rounds })
520        .build()
521}
522
523/// 创建分析团队
524///
525/// 多个 Agent 并行分析,然后聚合结果
526pub fn analysis_team(analysts: Vec<(impl Into<String>, Arc<LLMAgent>)>) -> AgentTeam {
527    let mut builder = AgentTeamBuilder::new("analysis")
528        .with_name("Analysis Team")
529        .with_pattern(TeamPattern::MapReduce)
530        .with_aggregate_prompt(
531            "Synthesize the following analyses into a comprehensive report:\n\n{results}\n\n\
532            Original question: {input}",
533        );
534
535    for (id, agent) in analysts {
536        builder = builder.add_member(id, agent);
537    }
538
539    builder.build()
540}
541
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_team_builder() {
        // Configure a builder without any agents; only its state is inspected.
        let builder = AgentTeamBuilder::new("test-team")
            .with_name("Test Team")
            .with_pattern(TeamPattern::Chain);

        assert_eq!(builder.id, "test-team");
        assert_eq!(builder.name, "Test Team");
        assert!(matches!(builder.pattern, TeamPattern::Chain));
        assert!(builder.members.is_empty());
    }

    #[test]
    fn test_builder_defaults() {
        let builder = AgentTeamBuilder::new("solo");

        // The name falls back to the id and the pattern to Chain.
        assert_eq!(builder.name, "solo");
        assert!(matches!(builder.pattern, TeamPattern::Chain));
        assert!(builder.supervisor_id.is_none());
        assert!(builder.aggregate_prompt.is_none());
    }

    #[test]
    fn test_with_supervisor_switches_pattern() {
        let builder = AgentTeamBuilder::new("review").with_supervisor("lead");

        assert_eq!(builder.supervisor_id.as_deref(), Some("lead"));
        assert!(matches!(builder.pattern, TeamPattern::Supervised));
    }

    #[test]
    fn test_build_empty_team() {
        let team = AgentTeamBuilder::new("empty")
            .with_name("Empty Team")
            .with_aggregate_prompt("{results}")
            .build();

        assert_eq!(team.id, "empty");
        assert_eq!(team.name, "Empty Team");
        assert!(team.member_ids().is_empty());
        assert!(team.get_member("anyone").is_none());
    }

    #[test]
    fn test_agent_role() {
        let role = AgentRole::new("researcher", "Researcher")
            .with_description("Research topics")
            .with_template("{input}");

        assert_eq!(role.id, "researcher");
        assert_eq!(role.name, "Researcher");
        assert_eq!(role.description, "Research topics");
        assert_eq!(role.prompt_template.as_deref(), Some("{input}"));
    }
}