use crate::agent::context::AgentContext;
use crate::agent::error::AgentResult;
use crate::agent::types::AgentOutput;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Coordinates the execution of a [`Task`] across agents and merges what they
/// produce.
#[async_trait]
pub trait Coordinator: Send + Sync {
    /// Dispatches `task` and returns one [`DispatchResult`] per dispatch made.
    async fn dispatch(&self, task: Task, ctx: &AgentContext) -> AgentResult<Vec<DispatchResult>>;

    /// Combines agent outputs into a single [`AgentOutput`].
    async fn aggregate(&self, results: Vec<AgentOutput>) -> AgentResult<AgentOutput>;

    /// The [`CoordinationPattern`] implemented by this coordinator.
    fn pattern(&self) -> CoordinationPattern;

    /// Display name of the coordinator; `"coordinator"` unless overridden.
    fn name(&self) -> &str {
        "coordinator"
    }

    /// Picks the agent ids that should receive `task`.
    ///
    /// The default implementation selects nobody and returns an empty list.
    async fn select_agents(&self, task: &Task, ctx: &AgentContext) -> AgentResult<Vec<String>> {
        // The default ignores its inputs; bind them to silence unused warnings.
        let _ = task;
        let _ = ctx;
        Ok(Vec::new())
    }

    /// `true` when `pattern()` is `Parallel` or `Consensus`, otherwise `false`.
    ///
    /// NOTE(review): presumably signals that results from every selected agent
    /// are needed before aggregating — confirm against callers.
    fn requires_all(&self) -> bool {
        let current = self.pattern();
        matches!(
            current,
            CoordinationPattern::Parallel | CoordinationPattern::Consensus { .. }
        )
    }
}
/// Coordination pattern reported by [`Coordinator::pattern`].
///
/// `Sequential` is the default. Only `PartialEq` (not `Eq`) is derived because
/// `Consensus` carries an `f32`. Variant and field names form the serde
/// representation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub enum CoordinationPattern {
    #[default]
    Sequential,
    Parallel,
    /// NOTE(review): `supervisor_id` presumably identifies the supervising agent — confirm.
    Hierarchical {
        supervisor_id: String,
    },
    /// NOTE(review): the scale of `threshold` (fraction vs. count) is not defined here — confirm.
    Consensus { threshold: f32 },
    /// `max_rounds`: upper bound on debate rounds, per the field name.
    Debate { max_rounds: usize },
    MapReduce,
    Voting,
    /// Implementation-defined pattern identified by name.
    Custom(String),
}
/// A unit of work passed to [`Coordinator::dispatch`].
///
/// Typically created with [`Task::new`] and refined with the builder methods.
/// Field order is kept stable because it is part of the serde representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Task {
    /// Identifier supplied by the caller of `Task::new`.
    pub id: String,
    /// Category of work; `Task::new` starts with `TaskType::General`.
    pub task_type: TaskType,
    /// Task text supplied by the caller of `Task::new`.
    pub content: String,
    /// Urgency; `Task::new` starts with `TaskPriority::Normal`.
    pub priority: TaskPriority,
    /// Agent to route to, if any (set by `for_agent`).
    pub target_agent: Option<String>,
    /// Structured parameters (added by `with_param`).
    pub params: HashMap<String, serde_json::Value>,
    /// Free-form string metadata.
    pub metadata: HashMap<String, String>,
    /// Creation time; `Task::new` stores milliseconds since the Unix epoch.
    pub created_at: u64,
    /// Optional timeout, in milliseconds per the field name (set by `with_timeout`).
    pub timeout_ms: Option<u64>,
}
impl Task {
    /// Creates a task with the given `id` and `content`.
    ///
    /// The task starts as [`TaskType::General`] with [`TaskPriority::Normal`],
    /// no target agent, empty `params` and `metadata`, and no timeout.
    /// `created_at` is the current system time in milliseconds since the Unix
    /// epoch, or `0` if the clock reads earlier than the epoch.
    #[must_use]
    pub fn new(id: impl Into<String>, content: impl Into<String>) -> Self {
        let since_epoch = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default();
        // `as_millis` yields a u128; clamp to u64::MAX first so the cast can
        // never silently truncate.
        let now = since_epoch.as_millis().min(u128::from(u64::MAX)) as u64;
        Self {
            id: id.into(),
            task_type: TaskType::General,
            content: content.into(),
            priority: TaskPriority::Normal,
            target_agent: None,
            params: HashMap::new(),
            metadata: HashMap::new(),
            created_at: now,
            timeout_ms: None,
        }
    }

    /// Sets the task category.
    #[must_use = "builder methods consume the task and return the updated one"]
    pub fn with_type(mut self, task_type: TaskType) -> Self {
        self.task_type = task_type;
        self
    }

    /// Sets the task priority.
    #[must_use = "builder methods consume the task and return the updated one"]
    pub fn with_priority(mut self, priority: TaskPriority) -> Self {
        self.priority = priority;
        self
    }

    /// Routes the task to a specific agent by id.
    #[must_use = "builder methods consume the task and return the updated one"]
    pub fn for_agent(mut self, agent_id: impl Into<String>) -> Self {
        self.target_agent = Some(agent_id.into());
        self
    }

    /// Inserts a parameter, replacing any existing value under the same key.
    #[must_use = "builder methods consume the task and return the updated one"]
    pub fn with_param(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.params.insert(key.into(), value);
        self
    }

    /// Sets the timeout in milliseconds.
    #[must_use = "builder methods consume the task and return the updated one"]
    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
        self.timeout_ms = Some(timeout_ms);
        self
    }
}
/// Category of work a [`Task`] represents; `Task::new` uses `General`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskType {
    General,
    Analysis,
    Generation,
    Review,
    Decision,
    Search,
    /// Application-specific category identified by name.
    Custom(String),
}
/// Relative urgency of a [`Task`]; `Normal` is the default.
///
/// The derived ordering follows the explicit discriminants, so
/// `Low < Normal < High < Urgent`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum TaskPriority {
    Low = 0,
    #[default]
    Normal = 1,
    High = 2,
    Urgent = 3,
}
/// Outcome of sending one [`Task`] to one agent.
///
/// Normally built through `DispatchResult::success`, `failure` or `pending`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DispatchResult {
    /// Id of the dispatched task.
    pub task_id: String,
    /// Id of the agent that received the task.
    pub agent_id: String,
    /// Current lifecycle state of this dispatch.
    pub status: DispatchStatus,
    /// Agent output; filled in by `success`.
    pub output: Option<AgentOutput>,
    /// Error message; filled in by `failure`.
    pub error: Option<String>,
    /// Elapsed time in milliseconds, per the field name; `pending` sets it to 0.
    pub duration_ms: u64,
}
impl DispatchResult {
    /// Builds a [`DispatchStatus::Completed`] result carrying `output` and no error.
    pub fn success(
        task_id: impl Into<String>,
        agent_id: impl Into<String>,
        output: AgentOutput,
        duration_ms: u64,
    ) -> Self {
        let task_id = task_id.into();
        let agent_id = agent_id.into();
        Self {
            status: DispatchStatus::Completed,
            output: Some(output),
            error: None,
            task_id,
            agent_id,
            duration_ms,
        }
    }

    /// Builds a [`DispatchStatus::Failed`] result carrying `error` and no output.
    pub fn failure(
        task_id: impl Into<String>,
        agent_id: impl Into<String>,
        error: impl Into<String>,
        duration_ms: u64,
    ) -> Self {
        let task_id = task_id.into();
        let agent_id = agent_id.into();
        let message = error.into();
        Self {
            status: DispatchStatus::Failed,
            output: None,
            error: Some(message),
            task_id,
            agent_id,
            duration_ms,
        }
    }

    /// Builds a [`DispatchStatus::Pending`] result with no output, no error
    /// and a zero duration.
    pub fn pending(task_id: impl Into<String>, agent_id: impl Into<String>) -> Self {
        let task_id = task_id.into();
        let agent_id = agent_id.into();
        Self {
            status: DispatchStatus::Pending,
            output: None,
            error: None,
            duration_ms: 0,
            task_id,
            agent_id,
        }
    }
}
/// Lifecycle state of a single dispatch, stored in [`DispatchResult::status`].
///
/// The `DispatchResult` constructors produce only `Pending`, `Completed` and
/// `Failed`. The enum is fieldless, so it derives `Copy` (matching
/// `TaskPriority`) and `Hash`, which lets it be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DispatchStatus {
    /// Produced by `DispatchResult::pending`.
    Pending,
    Running,
    /// Produced by `DispatchResult::success`.
    Completed,
    /// Produced by `DispatchResult::failure`.
    Failed,
    Timeout,
    Cancelled,
}
/// How `aggregate_outputs` merges several agent outputs into one.
///
/// `CollectAll` is the default.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum AggregationStrategy {
    /// Join every output's text with `separator`.
    Concatenate { separator: String },
    /// Keep the first output that is not an error.
    FirstSuccess,
    /// Return a JSON object listing every output's text and the count.
    #[default]
    CollectAll,
    /// Return the text produced most often.
    Vote,
    /// Summarization hook. `aggregate_outputs` currently ignores
    /// `prompt_template` and joins the texts with a `---` divider.
    LLMSummarize { prompt_template: String },
    /// Named custom strategy. `aggregate_outputs` ignores the name and joins
    /// the texts with newlines.
    Custom(String),
}
/// Merges agent outputs into a single [`AgentOutput`] according to `strategy`.
///
/// * `Concatenate`: each output's `to_text()` joined with the separator.
/// * `FirstSuccess`: the first output for which `is_error()` is false.
/// * `CollectAll`: JSON `{"results": [texts...], "count": n}`.
/// * `Vote`: the most frequent text. Ties go to the text that appeared first,
///   and an empty input yields an empty text.
/// * `LLMSummarize`: placeholder. `prompt_template` is ignored and the texts
///   are joined with a `---` divider.
/// * `Custom`: texts joined with newlines; the custom name is ignored.
///
/// # Errors
///
/// Returns `AgentError::CoordinationError` for `FirstSuccess` when no output
/// is a success, which includes an empty `outputs`.
pub fn aggregate_outputs(
    outputs: Vec<AgentOutput>,
    strategy: &AggregationStrategy,
) -> AgentResult<AgentOutput> {
    /// Text projection shared by the text-based strategies.
    fn texts_of(outputs: &[AgentOutput]) -> Vec<String> {
        outputs.iter().map(|o| o.to_text()).collect()
    }

    match strategy {
        AggregationStrategy::Concatenate { separator } => {
            Ok(AgentOutput::text(texts_of(&outputs).join(separator)))
        }
        AggregationStrategy::FirstSuccess => {
            outputs.into_iter().find(|o| !o.is_error()).ok_or_else(|| {
                crate::agent::error::AgentError::CoordinationError(
                    "No successful output".to_string(),
                )
            })
        }
        AggregationStrategy::CollectAll => {
            let texts = texts_of(&outputs);
            Ok(AgentOutput::json(serde_json::json!({
                "results": texts,
                "count": texts.len(),
            })))
        }
        AggregationStrategy::Vote => {
            // Maps each text to (vote count, index of first occurrence).
            // HashMap iteration order is randomized, so choosing the maximum
            // by count alone made ties resolve differently from run to run.
            // Including the first-seen index in the key makes every key
            // unique: the earliest answer wins a tie.
            let mut votes: HashMap<String, (usize, usize)> = HashMap::new();
            for (position, output) in outputs.iter().enumerate() {
                votes.entry(output.to_text()).or_insert((0, position)).0 += 1;
            }
            let winner = votes
                .into_iter()
                .max_by_key(|&(_, (count, first_seen))| {
                    (count, std::cmp::Reverse(first_seen))
                })
                .map(|(text, _)| text)
                .unwrap_or_default();
            Ok(AgentOutput::text(winner))
        }
        AggregationStrategy::LLMSummarize { .. } => {
            Ok(AgentOutput::text(texts_of(&outputs).join("\n\n---\n\n")))
        }
        AggregationStrategy::Custom(_) => {
            Ok(AgentOutput::text(texts_of(&outputs).join("\n")))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_creation() {
        let task = Task::new("task-1", "Do something")
            .with_type(TaskType::Analysis)
            .with_priority(TaskPriority::High)
            .for_agent("agent-1")
            .with_timeout(5000);

        let Task {
            id,
            task_type,
            priority,
            target_agent,
            timeout_ms,
            ..
        } = task;
        assert_eq!(id, "task-1");
        assert_eq!(task_type, TaskType::Analysis);
        assert_eq!(priority, TaskPriority::High);
        assert_eq!(target_agent.as_deref(), Some("agent-1"));
        assert_eq!(timeout_ms, Some(5000));
    }

    #[test]
    fn test_dispatch_result() {
        let ok = DispatchResult::success("task-1", "agent-1", AgentOutput::text("Result"), 100);
        assert_eq!(ok.status, DispatchStatus::Completed);
        assert!(ok.output.is_some());

        let failed = DispatchResult::failure("task-1", "agent-1", "Error occurred", 50);
        assert_eq!(failed.status, DispatchStatus::Failed);
        assert!(failed.error.is_some());
    }

    #[test]
    fn test_aggregate_concatenate() {
        let parts = vec![
            AgentOutput::text("Part 1"),
            AgentOutput::text("Part 2"),
            AgentOutput::text("Part 3"),
        ];
        let strategy = AggregationStrategy::Concatenate {
            separator: String::from(" | "),
        };

        let joined = aggregate_outputs(parts, &strategy).expect("concatenation should succeed");
        assert_eq!(joined.to_text(), "Part 1 | Part 2 | Part 3");
    }

    #[test]
    fn test_aggregate_first_success() {
        let candidates = vec![
            AgentOutput::error("Error 1"),
            AgentOutput::text("Success"),
            AgentOutput::text("Another success"),
        ];

        let picked = aggregate_outputs(candidates, &AggregationStrategy::FirstSuccess)
            .expect("a non-error output is present");
        assert_eq!(picked.to_text(), "Success");
    }

    #[test]
    fn test_aggregate_vote() {
        let ballots = vec![
            AgentOutput::text("A"),
            AgentOutput::text("B"),
            AgentOutput::text("A"),
            AgentOutput::text("A"),
            AgentOutput::text("B"),
        ];

        let elected = aggregate_outputs(ballots, &AggregationStrategy::Vote)
            .expect("voting should succeed");
        assert_eq!(elected.to_text(), "A");
    }
}