Skip to main content

mofa_foundation/agent/components/
coordinator.rs

1//! 协调组件
2//!
3//! 从 kernel 层导入 Coordinator trait,提供具体实现
4
5use async_trait::async_trait;
6use mofa_kernel::agent::AgentResult;
7use mofa_kernel::agent::components::coordinator::{
8    AggregationStrategy, CoordinationPattern, Coordinator, DispatchResult, Task, aggregate_outputs,
9};
10use mofa_kernel::agent::context::AgentContext;
11use mofa_kernel::agent::types::AgentOutput;
12
13// ============================================================================
14// 具体协调器实现
15// ============================================================================
16
/// Sequential coordinator.
///
/// Dispatches a task to a fixed list of agents, in the order the agent ids were
/// supplied to [`SequentialCoordinator::new`].
#[derive(Debug, Clone)]
pub struct SequentialCoordinator {
    /// Ids of the agents this coordinator targets, in dispatch order.
    agent_ids: Vec<String>,
}

impl SequentialCoordinator {
    /// Creates a sequential coordinator over the given agent ids.
    ///
    /// The ids are stored exactly as given: no deduplication or validation is
    /// performed, and an empty list is accepted.
    pub fn new(agent_ids: Vec<String>) -> Self {
        Self { agent_ids }
    }
}
30
31#[async_trait]
32impl Coordinator for SequentialCoordinator {
33    async fn dispatch(&self, task: Task, _ctx: &AgentContext) -> AgentResult<Vec<DispatchResult>> {
34        // 简化实现:为每个 agent 创建待处理结果
35        let mut results = Vec::new();
36        for agent_id in &self.agent_ids {
37            results.push(DispatchResult::pending(&task.id, agent_id));
38        }
39        Ok(results)
40    }
41
42    async fn aggregate(&self, results: Vec<AgentOutput>) -> AgentResult<AgentOutput> {
43        let texts: Vec<String> = results.iter().map(|o| o.to_text()).collect();
44        Ok(AgentOutput::text(texts.join("\n\n---\n\n")))
45    }
46
47    fn pattern(&self) -> CoordinationPattern {
48        CoordinationPattern::Sequential
49    }
50
51    fn name(&self) -> &str {
52        "sequential"
53    }
54
55    async fn select_agents(&self, _task: &Task, _ctx: &AgentContext) -> AgentResult<Vec<String>> {
56        Ok(self.agent_ids.clone())
57    }
58}
59
/// Parallel coordinator.
///
/// Targets all configured agents at once and reports
/// `CoordinationPattern::Parallel`. No concurrent execution is performed in this
/// file; presumably the caller runs the dispatched work concurrently — confirm
/// against the executor that consumes this coordinator.
#[derive(Debug, Clone)]
pub struct ParallelCoordinator {
    /// Ids of the agents this coordinator targets, in configuration order.
    agent_ids: Vec<String>,
}

impl ParallelCoordinator {
    /// Creates a parallel coordinator over the given agent ids.
    ///
    /// The ids are stored exactly as given: no deduplication or validation is
    /// performed, and an empty list is accepted.
    pub fn new(agent_ids: Vec<String>) -> Self {
        Self { agent_ids }
    }
}
73
74#[async_trait]
75impl Coordinator for ParallelCoordinator {
76    async fn dispatch(&self, task: Task, _ctx: &AgentContext) -> AgentResult<Vec<DispatchResult>> {
77        let mut results = Vec::new();
78        for agent_id in &self.agent_ids {
79            results.push(DispatchResult::pending(&task.id, agent_id));
80        }
81        Ok(results)
82    }
83
84    async fn aggregate(&self, results: Vec<AgentOutput>) -> AgentResult<AgentOutput> {
85        aggregate_outputs(results, &AggregationStrategy::CollectAll)
86    }
87
88    fn pattern(&self) -> CoordinationPattern {
89        CoordinationPattern::Parallel
90    }
91
92    fn name(&self) -> &str {
93        "parallel"
94    }
95
96    async fn select_agents(&self, _task: &Task, _ctx: &AgentContext) -> AgentResult<Vec<String>> {
97        Ok(self.agent_ids.clone())
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds owned agent ids from string literals.
    fn ids(names: &[&str]) -> Vec<String> {
        names.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_sequential_coordinator() {
        let coordinator = SequentialCoordinator::new(ids(&["agent-1", "agent-2"]));
        assert_eq!(coordinator.name(), "sequential");
        assert_eq!(coordinator.pattern(), CoordinationPattern::Sequential);
    }

    #[test]
    fn test_parallel_coordinator() {
        let coordinator = ParallelCoordinator::new(ids(&["agent-1", "agent-2"]));
        assert_eq!(coordinator.name(), "parallel");
        assert_eq!(coordinator.pattern(), CoordinationPattern::Parallel);
    }

    #[tokio::test]
    async fn test_sequential_dispatch() {
        let coordinator = SequentialCoordinator::new(ids(&["agent-1", "agent-2"]));
        let ctx = AgentContext::new("test");
        let task = Task::new("task-1", "Do something");

        let results = coordinator.dispatch(task, &ctx).await.unwrap();
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn test_parallel_dispatch() {
        let coordinator = ParallelCoordinator::new(ids(&["agent-1", "agent-2", "agent-3"]));
        let ctx = AgentContext::new("test");
        let task = Task::new("task-1", "Do something");

        let results = coordinator.dispatch(task, &ctx).await.unwrap();
        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn test_dispatch_with_no_agents() {
        let ctx = AgentContext::new("test");

        let sequential = SequentialCoordinator::new(Vec::new());
        let results = sequential
            .dispatch(Task::new("task-1", "Do something"), &ctx)
            .await
            .unwrap();
        assert!(results.is_empty());

        let parallel = ParallelCoordinator::new(Vec::new());
        let results = parallel
            .dispatch(Task::new("task-1", "Do something"), &ctx)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_select_agents_returns_configured_ids() {
        let agent_ids = ids(&["agent-1", "agent-2"]);
        let ctx = AgentContext::new("test");
        let task = Task::new("task-1", "Do something");

        let sequential = SequentialCoordinator::new(agent_ids.clone());
        assert_eq!(sequential.select_agents(&task, &ctx).await.unwrap(), agent_ids);

        let parallel = ParallelCoordinator::new(agent_ids.clone());
        assert_eq!(parallel.select_agents(&task, &ctx).await.unwrap(), agent_ids);
    }

    #[tokio::test]
    async fn test_sequential_aggregate() {
        let coordinator = SequentialCoordinator::new(ids(&["agent-1"]));
        let results = vec![AgentOutput::text("Result 1"), AgentOutput::text("Result 2")];

        let aggregated = coordinator.aggregate(results).await.unwrap();
        let text = aggregated.to_text();
        assert!(text.contains("Result 1"));
        assert!(text.contains("Result 2"));
        assert!(text.contains("---"));

        // Sequential aggregation must preserve input order.
        let first = text.find("Result 1").unwrap();
        let second = text.find("Result 2").unwrap();
        assert!(first < second);
    }
}