Skip to main content

mofa_kernel/agent/components/
coordinator.rs

1//! 协调组件
2//!
3//! 定义多 Agent 协调能力
4
5use crate::agent::context::AgentContext;
6use crate::agent::error::AgentResult;
7use crate::agent::types::AgentOutput;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12/// 协调器 Trait
13///
14/// 负责多 Agent 的任务分发和结果聚合
15///
16/// # 示例
17///
18/// ```rust,ignore
19/// use mofa_kernel::agent::components::coordinator::{Coordinator, CoordinationPattern, Task, DispatchResult};
20///
21/// struct SequentialCoordinator {
22///     agent_ids: Vec<String>,
23/// }
24///
25/// #[async_trait]
26/// impl Coordinator for SequentialCoordinator {
27///     async fn dispatch(&self, task: Task, ctx: &CoreAgentContext) -> AgentResult<Vec<DispatchResult>> {
28///         // Sequential dispatch implementation
29///     }
30///
31///     async fn aggregate(&self, results: Vec<AgentOutput>) -> AgentResult<AgentOutput> {
32///         // Combine results
33///     }
34///
35///     fn pattern(&self) -> CoordinationPattern {
36///         CoordinationPattern::Sequential
37///     }
38/// }
39/// ```
40#[async_trait]
41pub trait Coordinator: Send + Sync {
42    /// 分发任务给 Agent(s)
43    async fn dispatch(&self, task: Task, ctx: &AgentContext) -> AgentResult<Vec<DispatchResult>>;
44
45    /// 聚合多个 Agent 的结果
46    async fn aggregate(&self, results: Vec<AgentOutput>) -> AgentResult<AgentOutput>;
47
48    /// 获取协调模式
49    fn pattern(&self) -> CoordinationPattern;
50
51    /// 协调器名称
52    fn name(&self) -> &str {
53        "coordinator"
54    }
55
56    /// 选择执行任务的 Agent
57    async fn select_agents(&self, task: &Task, ctx: &AgentContext) -> AgentResult<Vec<String>> {
58        let _ = (task, ctx);
59        Ok(vec![])
60    }
61
62    /// 是否需要所有 Agent 完成
63    fn requires_all(&self) -> bool {
64        matches!(
65            self.pattern(),
66            CoordinationPattern::Parallel | CoordinationPattern::Consensus { .. }
67        )
68    }
69}
70
71/// 协调模式
72#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
73pub enum CoordinationPattern {
74    /// 顺序执行
75    #[default]
76    Sequential,
77    /// 并行执行
78    Parallel,
79    /// 层级执行 (带监督者)
80    Hierarchical {
81        /// 监督者 Agent ID
82        supervisor_id: String,
83    },
84    /// 共识模式 (需要达成一致)
85    Consensus {
86        /// 共识阈值 (0.0 - 1.0)
87        threshold: f32,
88    },
89    /// 辩论模式
90    Debate {
91        /// 最大轮次
92        max_rounds: usize,
93    },
94    /// MapReduce 模式
95    MapReduce,
96    /// 投票模式
97    Voting,
98    /// 自定义模式
99    Custom(String),
100}
101
102/// 任务定义
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct Task {
105    /// 任务 ID
106    pub id: String,
107    /// 任务类型
108    pub task_type: TaskType,
109    /// 任务内容
110    pub content: String,
111    /// 任务优先级
112    pub priority: TaskPriority,
113    /// 目标 Agent ID (可选,如果为空则由协调器选择)
114    pub target_agent: Option<String>,
115    /// 任务参数
116    pub params: HashMap<String, serde_json::Value>,
117    /// 任务元数据
118    pub metadata: HashMap<String, String>,
119    /// 创建时间
120    pub created_at: u64,
121    /// 超时时间 (毫秒)
122    pub timeout_ms: Option<u64>,
123}
124
125impl Task {
126    /// 创建新任务
127    pub fn new(id: impl Into<String>, content: impl Into<String>) -> Self {
128        let now = std::time::SystemTime::now()
129            .duration_since(std::time::UNIX_EPOCH)
130            .unwrap_or_default()
131            .as_millis() as u64;
132
133        Self {
134            id: id.into(),
135            task_type: TaskType::General,
136            content: content.into(),
137            priority: TaskPriority::Normal,
138            target_agent: None,
139            params: HashMap::new(),
140            metadata: HashMap::new(),
141            created_at: now,
142            timeout_ms: None,
143        }
144    }
145
146    /// 设置任务类型
147    pub fn with_type(mut self, task_type: TaskType) -> Self {
148        self.task_type = task_type;
149        self
150    }
151
152    /// 设置优先级
153    pub fn with_priority(mut self, priority: TaskPriority) -> Self {
154        self.priority = priority;
155        self
156    }
157
158    /// 设置目标 Agent
159    pub fn for_agent(mut self, agent_id: impl Into<String>) -> Self {
160        self.target_agent = Some(agent_id.into());
161        self
162    }
163
164    /// 添加参数
165    pub fn with_param(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
166        self.params.insert(key.into(), value);
167        self
168    }
169
170    /// 设置超时
171    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
172        self.timeout_ms = Some(timeout_ms);
173        self
174    }
175}
176
177/// 任务类型
178#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
179pub enum TaskType {
180    /// 通用任务
181    General,
182    /// 分析任务
183    Analysis,
184    /// 生成任务
185    Generation,
186    /// 审查任务
187    Review,
188    /// 决策任务
189    Decision,
190    /// 搜索任务
191    Search,
192    /// 自定义任务
193    Custom(String),
194}
195
196/// 任务优先级
197#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default)]
198pub enum TaskPriority {
199    Low = 0,
200    #[default]
201    Normal = 1,
202    High = 2,
203    Urgent = 3,
204}
205
206/// 分发结果
207#[derive(Debug, Clone, Serialize, Deserialize)]
208pub struct DispatchResult {
209    /// 任务 ID
210    pub task_id: String,
211    /// Agent ID
212    pub agent_id: String,
213    /// 执行状态
214    pub status: DispatchStatus,
215    /// 执行结果 (如果完成)
216    pub output: Option<AgentOutput>,
217    /// 错误信息 (如果失败)
218    pub error: Option<String>,
219    /// 执行时间 (毫秒)
220    pub duration_ms: u64,
221}
222
223impl DispatchResult {
224    /// 创建成功结果
225    pub fn success(
226        task_id: impl Into<String>,
227        agent_id: impl Into<String>,
228        output: AgentOutput,
229        duration_ms: u64,
230    ) -> Self {
231        Self {
232            task_id: task_id.into(),
233            agent_id: agent_id.into(),
234            status: DispatchStatus::Completed,
235            output: Some(output),
236            error: None,
237            duration_ms,
238        }
239    }
240
241    /// 创建失败结果
242    pub fn failure(
243        task_id: impl Into<String>,
244        agent_id: impl Into<String>,
245        error: impl Into<String>,
246        duration_ms: u64,
247    ) -> Self {
248        Self {
249            task_id: task_id.into(),
250            agent_id: agent_id.into(),
251            status: DispatchStatus::Failed,
252            output: None,
253            error: Some(error.into()),
254            duration_ms,
255        }
256    }
257
258    /// 创建待处理结果
259    pub fn pending(task_id: impl Into<String>, agent_id: impl Into<String>) -> Self {
260        Self {
261            task_id: task_id.into(),
262            agent_id: agent_id.into(),
263            status: DispatchStatus::Pending,
264            output: None,
265            error: None,
266            duration_ms: 0,
267        }
268    }
269}
270
271/// 分发状态
272#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
273pub enum DispatchStatus {
274    /// 待处理
275    Pending,
276    /// 运行中
277    Running,
278    /// 已完成
279    Completed,
280    /// 失败
281    Failed,
282    /// 超时
283    Timeout,
284    /// 取消
285    Cancelled,
286}
287
288// ============================================================================
289// 聚合策略
290// ============================================================================
291
292/// 结果聚合策略
293#[derive(Debug, Clone, Serialize, Deserialize, Default)]
294pub enum AggregationStrategy {
295    /// 连接所有结果
296    Concatenate { separator: String },
297    /// 取第一个成功的结果
298    FirstSuccess,
299    /// 收集所有结果
300    #[default]
301    CollectAll,
302    /// 投票选择
303    Vote,
304    /// 使用 LLM 总结
305    LLMSummarize { prompt_template: String },
306    /// 自定义聚合
307    Custom(String),
308}
309
310/// 聚合结果
311pub fn aggregate_outputs(
312    outputs: Vec<AgentOutput>,
313    strategy: &AggregationStrategy,
314) -> AgentResult<AgentOutput> {
315    match strategy {
316        AggregationStrategy::Concatenate { separator } => {
317            let texts: Vec<String> = outputs.iter().map(|o| o.to_text()).collect();
318            Ok(AgentOutput::text(texts.join(separator)))
319        }
320        AggregationStrategy::FirstSuccess => {
321            outputs.into_iter().find(|o| !o.is_error()).ok_or_else(|| {
322                crate::agent::error::AgentError::CoordinationError(
323                    "No successful output".to_string(),
324                )
325            })
326        }
327        AggregationStrategy::CollectAll => {
328            let texts: Vec<String> = outputs.iter().map(|o| o.to_text()).collect();
329            Ok(AgentOutput::json(serde_json::json!({
330                "results": texts,
331                "count": texts.len(),
332            })))
333        }
334        AggregationStrategy::Vote => {
335            // 简单投票:选择最常见的结果
336            let mut votes: HashMap<String, usize> = HashMap::new();
337            for output in &outputs {
338                let text = output.to_text();
339                *votes.entry(text).or_insert(0) += 1;
340            }
341            let winner = votes
342                .into_iter()
343                .max_by_key(|(_, count)| *count)
344                .map(|(text, _)| text)
345                .unwrap_or_default();
346            Ok(AgentOutput::text(winner))
347        }
348        AggregationStrategy::LLMSummarize { .. } => {
349            // LLM 总结需要外部 LLM 调用,这里只是占位
350            let texts: Vec<String> = outputs.iter().map(|o| o.to_text()).collect();
351            Ok(AgentOutput::text(texts.join("\n\n---\n\n")))
352        }
353        AggregationStrategy::Custom(_) => {
354            // 自定义聚合需要外部实现
355            let texts: Vec<String> = outputs.iter().map(|o| o.to_text()).collect();
356            Ok(AgentOutput::text(texts.join("\n")))
357        }
358    }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_creation() {
        let task = Task::new("task-1", "Do something")
            .with_type(TaskType::Analysis)
            .with_priority(TaskPriority::High)
            .for_agent("agent-1")
            .with_timeout(5000);

        assert_eq!(task.id, "task-1");
        assert_eq!(task.task_type, TaskType::Analysis);
        assert_eq!(task.priority, TaskPriority::High);
        assert_eq!(task.target_agent, Some("agent-1".to_string()));
        assert_eq!(task.timeout_ms, Some(5000));
    }

    #[test]
    fn test_task_defaults() {
        let task = Task::new("task-2", "Default task");

        assert_eq!(task.id, "task-2");
        assert_eq!(task.content, "Default task");
        assert_eq!(task.task_type, TaskType::General);
        assert_eq!(task.priority, TaskPriority::Normal);
        assert_eq!(task.target_agent, None);
        assert!(task.params.is_empty());
        assert!(task.metadata.is_empty());
        assert_eq!(task.timeout_ms, None);
    }

    #[test]
    fn test_task_with_param() {
        let task = Task::new("task-3", "Parameterised").with_param("k", serde_json::json!(42));
        assert_eq!(task.params.get("k"), Some(&serde_json::json!(42)));
    }

    #[test]
    fn test_priority_ordering_and_defaults() {
        assert!(TaskPriority::Low < TaskPriority::Normal);
        assert!(TaskPriority::Normal < TaskPriority::High);
        assert!(TaskPriority::High < TaskPriority::Urgent);
        assert_eq!(TaskPriority::default(), TaskPriority::Normal);
        assert_eq!(CoordinationPattern::default(), CoordinationPattern::Sequential);
        assert!(matches!(
            AggregationStrategy::default(),
            AggregationStrategy::CollectAll
        ));
    }

    #[test]
    fn test_dispatch_result() {
        let success =
            DispatchResult::success("task-1", "agent-1", AgentOutput::text("Result"), 100);
        assert_eq!(success.status, DispatchStatus::Completed);
        assert!(success.output.is_some());
        assert!(success.error.is_none());
        assert_eq!(success.duration_ms, 100);

        let failure = DispatchResult::failure("task-1", "agent-1", "Error occurred", 50);
        assert_eq!(failure.status, DispatchStatus::Failed);
        assert!(failure.output.is_none());
        assert_eq!(failure.error.as_deref(), Some("Error occurred"));
        assert_eq!(failure.duration_ms, 50);
    }

    #[test]
    fn test_dispatch_pending() {
        let pending = DispatchResult::pending("task-1", "agent-1");
        assert_eq!(pending.task_id, "task-1");
        assert_eq!(pending.agent_id, "agent-1");
        assert_eq!(pending.status, DispatchStatus::Pending);
        assert!(pending.output.is_none());
        assert!(pending.error.is_none());
        assert_eq!(pending.duration_ms, 0);
    }

    #[test]
    fn test_aggregate_concatenate() {
        let outputs = vec![
            AgentOutput::text("Part 1"),
            AgentOutput::text("Part 2"),
            AgentOutput::text("Part 3"),
        ];

        let strategy = AggregationStrategy::Concatenate {
            separator: " | ".to_string(),
        };

        let result = aggregate_outputs(outputs, &strategy).unwrap();
        assert_eq!(result.to_text(), "Part 1 | Part 2 | Part 3");
    }

    #[test]
    fn test_aggregate_concatenate_single() {
        let strategy = AggregationStrategy::Concatenate {
            separator: " | ".to_string(),
        };
        let result = aggregate_outputs(vec![AgentOutput::text("Only")], &strategy).unwrap();
        assert_eq!(result.to_text(), "Only");
    }

    #[test]
    fn test_aggregate_first_success() {
        let outputs = vec![
            AgentOutput::error("Error 1"),
            AgentOutput::text("Success"),
            AgentOutput::text("Another success"),
        ];

        let strategy = AggregationStrategy::FirstSuccess;
        let result = aggregate_outputs(outputs, &strategy).unwrap();
        assert_eq!(result.to_text(), "Success");
    }

    #[test]
    fn test_aggregate_first_success_without_success_is_error() {
        let all_errors = vec![AgentOutput::error("Error 1"), AgentOutput::error("Error 2")];
        assert!(aggregate_outputs(all_errors, &AggregationStrategy::FirstSuccess).is_err());
        assert!(aggregate_outputs(Vec::new(), &AggregationStrategy::FirstSuccess).is_err());
    }

    #[test]
    fn test_aggregate_vote() {
        let outputs = vec![
            AgentOutput::text("A"),
            AgentOutput::text("B"),
            AgentOutput::text("A"),
            AgentOutput::text("A"),
            AgentOutput::text("B"),
        ];

        let strategy = AggregationStrategy::Vote;
        let result = aggregate_outputs(outputs, &strategy).unwrap();
        assert_eq!(result.to_text(), "A"); // A has 3 votes, B has 2
    }
}