1use crate::agent::context::AgentContext;
6use crate::agent::error::AgentResult;
7use crate::agent::types::AgentOutput;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12#[async_trait]
41pub trait Coordinator: Send + Sync {
42 async fn dispatch(&self, task: Task, ctx: &AgentContext) -> AgentResult<Vec<DispatchResult>>;
44
45 async fn aggregate(&self, results: Vec<AgentOutput>) -> AgentResult<AgentOutput>;
47
48 fn pattern(&self) -> CoordinationPattern;
50
51 fn name(&self) -> &str {
53 "coordinator"
54 }
55
56 async fn select_agents(&self, task: &Task, ctx: &AgentContext) -> AgentResult<Vec<String>> {
58 let _ = (task, ctx);
59 Ok(vec![])
60 }
61
62 fn requires_all(&self) -> bool {
64 matches!(
65 self.pattern(),
66 CoordinationPattern::Parallel | CoordinationPattern::Consensus { .. }
67 )
68 }
69}
70
/// How a [`Coordinator`] organizes the agents it works with.
///
/// Keep the variant order stable: index-based serde formats encode enum variants by position.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub enum CoordinationPattern {
    /// The default pattern; presumably agents run one after another — confirm in implementors.
    #[default]
    Sequential,
    /// Presumably agents run concurrently. `Coordinator::requires_all` returns `true` for it by default.
    Parallel,
    /// Presumably a supervisor agent directs the others — confirm in implementors.
    Hierarchical {
        /// Id of the supervising agent.
        supervisor_id: String,
    },
    /// Agreement-based coordination. `Coordinator::requires_all` returns `true` for it by default.
    Consensus {
        /// Agreement threshold; presumably a fraction in `[0.0, 1.0]` — not validated here, confirm.
        threshold: f32,
    },
    /// Multi-round exchange between agents.
    Debate {
        /// Upper bound on the number of debate rounds.
        max_rounds: usize,
    },
    /// Presumably split work across agents, then combine — confirm in implementors.
    MapReduce,
    /// Presumably the answer is chosen by agent votes — confirm in implementors.
    Voting,
    /// Application-defined pattern identified by name.
    Custom(String),
}
101
/// A unit of work handed to a [`Coordinator`].
///
/// Build one with [`Task::new`] followed by the `with_*` / `for_agent` builder methods.
/// Keep the field order stable: index-based serde formats encode struct fields by position.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Task {
    /// Caller-supplied task identifier.
    pub id: String,
    /// Category of work; `TaskType::General` when built with `Task::new`.
    pub task_type: TaskType,
    /// Free-form task content.
    pub content: String,
    /// Task priority; `TaskPriority::Normal` when built with `Task::new`.
    pub priority: TaskPriority,
    /// Agent the task is explicitly addressed to, if any (set by `for_agent`).
    pub target_agent: Option<String>,
    /// Structured parameters as JSON values (see `with_param`).
    pub params: HashMap<String, serde_json::Value>,
    /// Free-form string key/value metadata.
    pub metadata: HashMap<String, String>,
    /// Creation time in milliseconds since the Unix epoch, as set by `Task::new`
    /// (0 if the system clock reads earlier than the epoch).
    pub created_at: u64,
    /// Optional timeout set by `with_timeout`; `None` when unset.
    /// Presumably milliseconds, per the `_ms` suffix — confirm with the executor.
    pub timeout_ms: Option<u64>,
}
124
125impl Task {
126 pub fn new(id: impl Into<String>, content: impl Into<String>) -> Self {
128 let now = std::time::SystemTime::now()
129 .duration_since(std::time::UNIX_EPOCH)
130 .unwrap_or_default()
131 .as_millis() as u64;
132
133 Self {
134 id: id.into(),
135 task_type: TaskType::General,
136 content: content.into(),
137 priority: TaskPriority::Normal,
138 target_agent: None,
139 params: HashMap::new(),
140 metadata: HashMap::new(),
141 created_at: now,
142 timeout_ms: None,
143 }
144 }
145
146 pub fn with_type(mut self, task_type: TaskType) -> Self {
148 self.task_type = task_type;
149 self
150 }
151
152 pub fn with_priority(mut self, priority: TaskPriority) -> Self {
154 self.priority = priority;
155 self
156 }
157
158 pub fn for_agent(mut self, agent_id: impl Into<String>) -> Self {
160 self.target_agent = Some(agent_id.into());
161 self
162 }
163
164 pub fn with_param(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
166 self.params.insert(key.into(), value);
167 self
168 }
169
170 pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
172 self.timeout_ms = Some(timeout_ms);
173 self
174 }
175}
176
/// Coarse category of a [`Task`].
///
/// Keep the variant order stable: index-based serde formats encode enum variants by position.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskType {
    /// Uncategorized work; used by `Task::new`.
    General,
    /// Analysis task.
    Analysis,
    /// Content-generation task.
    Generation,
    /// Review task.
    Review,
    /// Decision-making task.
    Decision,
    /// Search task.
    Search,
    /// Application-defined category identified by name.
    Custom(String),
}
195
/// Priority of a [`Task`].
///
/// The derived `Ord` follows the declaration order and explicit discriminants:
/// `Low < Normal < High < Urgent`. Do not reorder the variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default)]
pub enum TaskPriority {
    Low = 0,
    /// Default priority, also used by `Task::new`.
    #[default]
    Normal = 1,
    High = 2,
    Urgent = 3,
}
205
/// Outcome of dispatching a [`Task`] to a single agent.
///
/// Keep the field order stable: index-based serde formats encode struct fields by position.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DispatchResult {
    /// Id of the dispatched task.
    pub task_id: String,
    /// Id of the agent that received the task.
    pub agent_id: String,
    /// Lifecycle state of this dispatch.
    pub status: DispatchStatus,
    /// Agent output; the constructors in this module set it only for completed dispatches.
    pub output: Option<AgentOutput>,
    /// Error message; the constructors in this module set it only for failed dispatches.
    pub error: Option<String>,
    /// Time spent on the dispatch (0 for pending results).
    /// Presumably milliseconds, per the `_ms` suffix — confirm with callers.
    pub duration_ms: u64,
}
222
223impl DispatchResult {
224 pub fn success(
226 task_id: impl Into<String>,
227 agent_id: impl Into<String>,
228 output: AgentOutput,
229 duration_ms: u64,
230 ) -> Self {
231 Self {
232 task_id: task_id.into(),
233 agent_id: agent_id.into(),
234 status: DispatchStatus::Completed,
235 output: Some(output),
236 error: None,
237 duration_ms,
238 }
239 }
240
241 pub fn failure(
243 task_id: impl Into<String>,
244 agent_id: impl Into<String>,
245 error: impl Into<String>,
246 duration_ms: u64,
247 ) -> Self {
248 Self {
249 task_id: task_id.into(),
250 agent_id: agent_id.into(),
251 status: DispatchStatus::Failed,
252 output: None,
253 error: Some(error.into()),
254 duration_ms,
255 }
256 }
257
258 pub fn pending(task_id: impl Into<String>, agent_id: impl Into<String>) -> Self {
260 Self {
261 task_id: task_id.into(),
262 agent_id: agent_id.into(),
263 status: DispatchStatus::Pending,
264 output: None,
265 error: None,
266 duration_ms: 0,
267 }
268 }
269}
270
/// Lifecycle state of a [`DispatchResult`].
///
/// Keep the variant order stable: index-based serde formats encode enum variants by position.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DispatchStatus {
    /// Not started yet; produced by `DispatchResult::pending`.
    Pending,
    /// In progress; not produced by the constructors in this module.
    Running,
    /// Finished successfully; produced by `DispatchResult::success`.
    Completed,
    /// Finished with an error; produced by `DispatchResult::failure`.
    Failed,
    /// Timed out; not produced by the constructors in this module.
    Timeout,
    /// Cancelled; not produced by the constructors in this module.
    Cancelled,
}
287
288#[derive(Debug, Clone, Serialize, Deserialize, Default)]
294pub enum AggregationStrategy {
295 Concatenate { separator: String },
297 FirstSuccess,
299 #[default]
301 CollectAll,
302 Vote,
304 LLMSummarize { prompt_template: String },
306 Custom(String),
308}
309
310pub fn aggregate_outputs(
312 outputs: Vec<AgentOutput>,
313 strategy: &AggregationStrategy,
314) -> AgentResult<AgentOutput> {
315 match strategy {
316 AggregationStrategy::Concatenate { separator } => {
317 let texts: Vec<String> = outputs.iter().map(|o| o.to_text()).collect();
318 Ok(AgentOutput::text(texts.join(separator)))
319 }
320 AggregationStrategy::FirstSuccess => {
321 outputs.into_iter().find(|o| !o.is_error()).ok_or_else(|| {
322 crate::agent::error::AgentError::CoordinationError(
323 "No successful output".to_string(),
324 )
325 })
326 }
327 AggregationStrategy::CollectAll => {
328 let texts: Vec<String> = outputs.iter().map(|o| o.to_text()).collect();
329 Ok(AgentOutput::json(serde_json::json!({
330 "results": texts,
331 "count": texts.len(),
332 })))
333 }
334 AggregationStrategy::Vote => {
335 let mut votes: HashMap<String, usize> = HashMap::new();
337 for output in &outputs {
338 let text = output.to_text();
339 *votes.entry(text).or_insert(0) += 1;
340 }
341 let winner = votes
342 .into_iter()
343 .max_by_key(|(_, count)| *count)
344 .map(|(text, _)| text)
345 .unwrap_or_default();
346 Ok(AgentOutput::text(winner))
347 }
348 AggregationStrategy::LLMSummarize { .. } => {
349 let texts: Vec<String> = outputs.iter().map(|o| o.to_text()).collect();
351 Ok(AgentOutput::text(texts.join("\n\n---\n\n")))
352 }
353 AggregationStrategy::Custom(_) => {
354 let texts: Vec<String> = outputs.iter().map(|o| o.to_text()).collect();
356 Ok(AgentOutput::text(texts.join("\n")))
357 }
358 }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_creation() {
        let task = Task::new("task-1", "Do something")
            .with_type(TaskType::Analysis)
            .with_priority(TaskPriority::High)
            .for_agent("agent-1")
            .with_timeout(5000);

        assert_eq!(task.id, "task-1");
        assert_eq!(task.task_type, TaskType::Analysis);
        assert_eq!(task.priority, TaskPriority::High);
        assert_eq!(task.target_agent, Some("agent-1".to_string()));
        assert_eq!(task.timeout_ms, Some(5000));
    }

    #[test]
    fn test_task_defaults() {
        let task = Task::new("task-2", "Defaults only");

        assert_eq!(task.content, "Defaults only");
        assert_eq!(task.task_type, TaskType::General);
        assert_eq!(task.priority, TaskPriority::Normal);
        assert_eq!(task.target_agent, None);
        assert!(task.params.is_empty());
        assert!(task.metadata.is_empty());
        assert_eq!(task.timeout_ms, None);
    }

    #[test]
    fn test_task_priority_ordering() {
        assert!(TaskPriority::Low < TaskPriority::Normal);
        assert!(TaskPriority::Normal < TaskPriority::High);
        assert!(TaskPriority::High < TaskPriority::Urgent);
        assert_eq!(TaskPriority::default(), TaskPriority::Normal);
    }

    #[test]
    fn test_coordination_pattern_default() {
        assert_eq!(CoordinationPattern::default(), CoordinationPattern::Sequential);
    }

    #[test]
    fn test_dispatch_result() {
        let success =
            DispatchResult::success("task-1", "agent-1", AgentOutput::text("Result"), 100);
        assert_eq!(success.status, DispatchStatus::Completed);
        assert!(success.output.is_some());

        let failure = DispatchResult::failure("task-1", "agent-1", "Error occurred", 50);
        assert_eq!(failure.status, DispatchStatus::Failed);
        assert!(failure.error.is_some());

        let pending = DispatchResult::pending("task-1", "agent-1");
        assert_eq!(pending.status, DispatchStatus::Pending);
        assert!(pending.output.is_none());
        assert!(pending.error.is_none());
        assert_eq!(pending.duration_ms, 0);
    }

    #[test]
    fn test_aggregate_concatenate() {
        let outputs = vec![
            AgentOutput::text("Part 1"),
            AgentOutput::text("Part 2"),
            AgentOutput::text("Part 3"),
        ];

        let strategy = AggregationStrategy::Concatenate {
            separator: " | ".to_string(),
        };

        let result = aggregate_outputs(outputs, &strategy).unwrap();
        assert_eq!(result.to_text(), "Part 1 | Part 2 | Part 3");
    }

    #[test]
    fn test_aggregate_first_success() {
        let outputs = vec![
            AgentOutput::error("Error 1"),
            AgentOutput::text("Success"),
            AgentOutput::text("Another success"),
        ];

        let strategy = AggregationStrategy::FirstSuccess;
        let result = aggregate_outputs(outputs, &strategy).unwrap();
        assert_eq!(result.to_text(), "Success");
    }

    #[test]
    fn test_aggregate_first_success_without_success_is_error() {
        let strategy = AggregationStrategy::FirstSuccess;

        let all_failed = vec![AgentOutput::error("Error 1"), AgentOutput::error("Error 2")];
        assert!(aggregate_outputs(all_failed, &strategy).is_err());

        assert!(aggregate_outputs(Vec::new(), &strategy).is_err());
    }

    #[test]
    fn test_aggregate_vote() {
        let outputs = vec![
            AgentOutput::text("A"),
            AgentOutput::text("B"),
            AgentOutput::text("A"),
            AgentOutput::text("A"),
            AgentOutput::text("B"),
        ];

        let strategy = AggregationStrategy::Vote;
        let result = aggregate_outputs(outputs, &strategy).unwrap();
        assert_eq!(result.to_text(), "A");
    }
}