mofa_foundation/agent/components/
coordinator.rs1use async_trait::async_trait;
6use mofa_kernel::agent::AgentResult;
7use mofa_kernel::agent::components::coordinator::{
8 AggregationStrategy, CoordinationPattern, Coordinator, DispatchResult, Task, aggregate_outputs,
9};
10use mofa_kernel::agent::context::AgentContext;
11use mofa_kernel::agent::types::AgentOutput;
12
/// Coordinator that reports [`CoordinationPattern::Sequential`] and assigns every task
/// to a fixed, ordered list of agents.
#[derive(Debug, Clone)]
pub struct SequentialCoordinator {
    /// IDs of the agents tasks are assigned to, kept in the order they were supplied.
    agent_ids: Vec<String>,
}

impl SequentialCoordinator {
    /// Creates a coordinator over `agent_ids`, preserving their order.
    pub fn new(agent_ids: Vec<String>) -> Self {
        Self { agent_ids }
    }
}
30
31#[async_trait]
32impl Coordinator for SequentialCoordinator {
33 async fn dispatch(&self, task: Task, _ctx: &AgentContext) -> AgentResult<Vec<DispatchResult>> {
34 let mut results = Vec::new();
36 for agent_id in &self.agent_ids {
37 results.push(DispatchResult::pending(&task.id, agent_id));
38 }
39 Ok(results)
40 }
41
42 async fn aggregate(&self, results: Vec<AgentOutput>) -> AgentResult<AgentOutput> {
43 let texts: Vec<String> = results.iter().map(|o| o.to_text()).collect();
44 Ok(AgentOutput::text(texts.join("\n\n---\n\n")))
45 }
46
47 fn pattern(&self) -> CoordinationPattern {
48 CoordinationPattern::Sequential
49 }
50
51 fn name(&self) -> &str {
52 "sequential"
53 }
54
55 async fn select_agents(&self, _task: &Task, _ctx: &AgentContext) -> AgentResult<Vec<String>> {
56 Ok(self.agent_ids.clone())
57 }
58}
59
/// Coordinator that reports [`CoordinationPattern::Parallel`] and assigns every task
/// to a fixed list of agents.
#[derive(Debug, Clone)]
pub struct ParallelCoordinator {
    /// IDs of the agents tasks are assigned to, kept in the order they were supplied.
    agent_ids: Vec<String>,
}

impl ParallelCoordinator {
    /// Creates a coordinator over `agent_ids`, preserving their order.
    pub fn new(agent_ids: Vec<String>) -> Self {
        Self { agent_ids }
    }
}
73
74#[async_trait]
75impl Coordinator for ParallelCoordinator {
76 async fn dispatch(&self, task: Task, _ctx: &AgentContext) -> AgentResult<Vec<DispatchResult>> {
77 let mut results = Vec::new();
78 for agent_id in &self.agent_ids {
79 results.push(DispatchResult::pending(&task.id, agent_id));
80 }
81 Ok(results)
82 }
83
84 async fn aggregate(&self, results: Vec<AgentOutput>) -> AgentResult<AgentOutput> {
85 aggregate_outputs(results, &AggregationStrategy::CollectAll)
86 }
87
88 fn pattern(&self) -> CoordinationPattern {
89 CoordinationPattern::Parallel
90 }
91
92 fn name(&self) -> &str {
93 "parallel"
94 }
95
96 async fn select_agents(&self, _task: &Task, _ctx: &AgentContext) -> AgentResult<Vec<String>> {
97 Ok(self.agent_ids.clone())
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: two agents in a fixed order.
    fn two_agents() -> Vec<String> {
        vec!["agent-1".to_string(), "agent-2".to_string()]
    }

    #[test]
    fn test_sequential_coordinator() {
        let coordinator = SequentialCoordinator::new(two_agents());
        assert_eq!(coordinator.name(), "sequential");
        assert_eq!(coordinator.pattern(), CoordinationPattern::Sequential);
    }

    #[test]
    fn test_parallel_coordinator() {
        let coordinator = ParallelCoordinator::new(two_agents());
        assert_eq!(coordinator.name(), "parallel");
        assert_eq!(coordinator.pattern(), CoordinationPattern::Parallel);
    }

    #[tokio::test]
    async fn test_sequential_dispatch() {
        let coordinator = SequentialCoordinator::new(two_agents());
        let ctx = AgentContext::new("test");
        let task = Task::new("task-1", "Do something");

        let results = coordinator.dispatch(task, &ctx).await.unwrap();
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn test_parallel_dispatch() {
        let coordinator = ParallelCoordinator::new(two_agents());
        let ctx = AgentContext::new("test");
        let task = Task::new("task-1", "Do something");

        let results = coordinator.dispatch(task, &ctx).await.unwrap();
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn test_sequential_select_agents() {
        let coordinator = SequentialCoordinator::new(two_agents());
        let ctx = AgentContext::new("test");
        let task = Task::new("task-1", "Do something");

        let selected = coordinator.select_agents(&task, &ctx).await.unwrap();
        assert_eq!(selected, two_agents());
    }

    #[tokio::test]
    async fn test_parallel_select_agents() {
        let coordinator = ParallelCoordinator::new(two_agents());
        let ctx = AgentContext::new("test");
        let task = Task::new("task-1", "Do something");

        let selected = coordinator.select_agents(&task, &ctx).await.unwrap();
        assert_eq!(selected, two_agents());
    }

    #[tokio::test]
    async fn test_sequential_aggregate() {
        let coordinator = SequentialCoordinator::new(vec!["agent-1".to_string()]);
        let results = vec![AgentOutput::text("Result 1"), AgentOutput::text("Result 2")];

        let aggregated = coordinator.aggregate(results).await.unwrap();
        let text = aggregated.to_text();
        let first = text.find("Result 1").expect("first output missing from aggregate");
        let second = text.find("Result 2").expect("second output missing from aggregate");
        // Sequential aggregation must keep outputs in the order they were produced.
        assert!(first < second, "aggregate reordered outputs: {text:?}");
    }
}