use std::collections::HashMap;
use tokio::sync::Mutex;
use tracing::{debug, info, warn};
use crate::error::Result;
use super::super::types::SubagentResult;
use super::mailbox::{Mailbox, MailboxMessage, MessageKind};
/// Lifecycle of a single task tracked by [`TeamCoordinator`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TaskState {
    /// Waiting in the queue to be handed to a teammate.
    Pending,
    /// Handed to the teammate whose name is carried here.
    Assigned(String),
    /// A result has been recorded for the task.
    Completed,
    /// Given up on; carries the error message from the final failure report.
    Failed(String),
}
/// Bookkeeping for one task: who holds it, where it is in its lifecycle,
/// and how many times it has been handed out.
#[derive(Clone, Debug)]
pub struct TaskAssignment {
    /// The task text; the coordinator also uses it as the lookup key.
    pub task: String,
    /// Teammate most recently given the task. `None` until the first
    /// assignment, and again whenever a failure returns it to the queue.
    pub assigned_to: Option<String>,
    /// Current lifecycle state.
    pub state: TaskState,
    /// How many times the task has been assigned so far.
    pub attempts: u32,
}
/// Leader-side bookkeeping for distributing tasks to teammates.
///
/// Both maps sit behind `tokio::sync::Mutex`, so every method borrows `&self`
/// and the coordinator can be shared by concurrent async callers.
pub struct TeamCoordinator {
    /// Name used as the sender of assignment messages.
    leader: String,
    /// Attempt budget checked by `record_failure`.
    max_retries: u32,
    /// Task text -> assignment bookkeeping.
    tasks: Mutex<HashMap<String, TaskAssignment>>,
    /// Task text -> the most recently recorded result for that task.
    results: Mutex<HashMap<String, SubagentResult>>,
}
impl TeamCoordinator {
    /// Attempt budget used by [`TeamCoordinator::new`].
    const DEFAULT_MAX_RETRIES: u32 = 3;

    /// Creates a coordinator led by `leader` with the default attempt budget
    /// of 3 (see [`TeamCoordinator::with_max_retries`]).
    pub fn new(leader: impl Into<String>) -> Self {
        Self::with_max_retries(leader, Self::DEFAULT_MAX_RETRIES)
    }

    /// Creates a coordinator led by `leader` with a custom attempt budget.
    ///
    /// `max_retries` is compared against the number of times a task has been
    /// assigned. [`TeamCoordinator::record_failure`] only returns a task to the
    /// pending queue while `attempts < max_retries`. In effect it bounds the
    /// *total* number of attempts, not the retries after the first one.
    pub fn with_max_retries(leader: impl Into<String>, max_retries: u32) -> Self {
        Self {
            leader: leader.into(),
            max_retries,
            tasks: Mutex::new(HashMap::new()),
            results: Mutex::new(HashMap::new()),
        }
    }

    /// States that `is_complete` and `progress` count as finished.
    fn is_finished(state: &TaskState) -> bool {
        matches!(state, TaskState::Completed | TaskState::Failed(_))
    }

    /// Registers each task as `Pending` with zero attempts.
    ///
    /// A task that is already tracked is left untouched: its state and attempt
    /// count are preserved, so re-adding cannot reset in-flight or finished work.
    pub async fn add_tasks(&self, tasks: Vec<String>) {
        let mut map = self.tasks.lock().await;
        for task in tasks {
            if map.contains_key(&task) {
                continue;
            }
            debug!(task = %task, "Task added to pending queue");
            map.insert(
                task.clone(),
                TaskAssignment {
                    task,
                    assigned_to: None,
                    state: TaskState::Pending,
                    attempts: 0,
                },
            );
        }
    }

    /// Marks `task` as assigned to `agent_name` and bumps its attempt count.
    ///
    /// Returns the bookkeeping as it was before the change, so a failed
    /// hand-off can be undone. Returns `None` when `task` is not tracked.
    fn mark_assigned(
        map: &mut HashMap<String, TaskAssignment>,
        task: &str,
        agent_name: &str,
    ) -> Option<TaskAssignment> {
        let assignment = map.get_mut(task)?;
        let previous = assignment.clone();
        assignment.assigned_to = Some(agent_name.to_string());
        assignment.state = TaskState::Assigned(agent_name.to_string());
        assignment.attempts += 1;
        Some(previous)
    }

    /// Restores `previous` for `task`, but only if the task is still exactly
    /// as `mark_assigned` left it. A concurrent result or failure report
    /// recorded in the meantime must not be overwritten.
    async fn roll_back_assignment(&self, task: &str, agent_name: &str, previous: TaskAssignment) {
        let mut map = self.tasks.lock().await;
        if let Some(current) = map.get_mut(task) {
            let untouched = current.state == TaskState::Assigned(agent_name.to_string())
                && current.attempts == previous.attempts + 1;
            if untouched {
                warn!(
                    task = %task,
                    agent = %agent_name,
                    "Assignment message not delivered; reverting task state"
                );
                *current = previous;
            }
        }
    }

    /// Sends the `TaskAssigned` message from the leader to `agent_name`.
    ///
    /// If delivery fails and `previous` is known, the task's bookkeeping is
    /// rolled back. Otherwise the task would stay `Assigned` to a teammate who
    /// never heard about it, and `is_complete` could never become true.
    async fn send_assignment(
        &self,
        task: &str,
        agent_name: &str,
        previous: Option<TaskAssignment>,
        mailbox: &Mailbox,
    ) -> Result<()> {
        let msg = MailboxMessage::new(
            &self.leader,
            agent_name,
            MessageKind::TaskAssigned {
                task: task.to_string(),
                context: HashMap::new(),
            },
        );
        let sent = mailbox.send(msg).await;
        if sent.is_err() {
            if let Some(previous) = previous {
                self.roll_back_assignment(task, agent_name, previous).await;
            }
        }
        sent?;
        Ok(())
    }

    /// Assigns `task` to `agent_name` and sends the assignment message.
    ///
    /// The bookkeeping is updated before the message is sent. This way a
    /// result reported quickly by the teammate cannot be overwritten by a
    /// late "assigned" update. If `task` was never added, the message is
    /// still sent but no bookkeeping is recorded.
    ///
    /// # Errors
    /// Propagates the mailbox send error; in that case the task's previous
    /// state is restored.
    pub async fn assign(&self, task: &str, agent_name: &str, mailbox: &Mailbox) -> Result<()> {
        info!(task = %task, agent = %agent_name, "Assigning task to teammate");
        let previous = {
            let mut map = self.tasks.lock().await;
            Self::mark_assigned(&mut map, task, agent_name)
        };
        self.send_assignment(task, agent_name, previous, mailbox).await
    }

    /// Claims some `Pending` task for `agent_name` and sends it.
    ///
    /// Returns `Ok(None)` when nothing is pending. The choice among several
    /// pending tasks follows `HashMap` iteration order, so it is arbitrary.
    ///
    /// # Errors
    /// Propagates the mailbox send error; in that case the claim is undone.
    pub async fn assign_next(&self, agent_name: &str, mailbox: &Mailbox) -> Result<Option<String>> {
        // Pick and claim the task under a single lock acquisition, so two
        // concurrent callers can never both grab the same pending task.
        let claimed = {
            let mut map = self.tasks.lock().await;
            let next = map
                .iter()
                .find(|(_, a)| a.state == TaskState::Pending)
                .map(|(t, _)| t.clone());
            next.map(|task| {
                let previous = Self::mark_assigned(&mut map, &task, agent_name);
                (task, previous)
            })
        };
        match claimed {
            Some((task, previous)) => {
                info!(task = %task, agent = %agent_name, "Assigning task to teammate");
                self.send_assignment(&task, agent_name, previous, mailbox)
                    .await?;
                Ok(Some(task))
            }
            None => Ok(None),
        }
    }

    /// Marks `task` as `Completed` (if tracked) and stores `result`.
    ///
    /// The result replaces any earlier result for the same task. It is stored
    /// even when the task is not tracked.
    pub async fn record_result(&self, task: &str, result: SubagentResult) {
        debug!(task = %task, agent = %result.agent_name, "Recording task result");
        {
            let mut map = self.tasks.lock().await;
            if let Some(assignment) = map.get_mut(task) {
                assignment.state = TaskState::Completed;
            }
        }
        let mut results = self.results.lock().await;
        results.insert(task.to_string(), result);
    }

    /// Records a failed attempt at `task`.
    ///
    /// Returns `true` if the task went back to `Pending` for reassignment.
    /// Returns `false` in three cases:
    /// - the attempt budget is exhausted, and the task is now `Failed(error)`;
    /// - the task is not tracked;
    /// - the task is already `Completed`, since a late failure report must
    ///   not resurrect finished work.
    pub async fn record_failure(&self, task: &str, error: &str) -> bool {
        warn!(task = %task, error = %error, "Task failed");
        let mut map = self.tasks.lock().await;
        let assignment = match map.get_mut(task) {
            Some(assignment) => assignment,
            None => return false,
        };
        if assignment.state == TaskState::Completed {
            debug!(task = %task, "Ignoring failure report for already completed task");
            return false;
        }
        if assignment.attempts < self.max_retries {
            assignment.state = TaskState::Pending;
            assignment.assigned_to = None;
            info!(task = %task, attempt = assignment.attempts, "Task returned to pending for reassignment");
            true
        } else {
            assignment.state = TaskState::Failed(error.to_string());
            warn!(task = %task, attempts = assignment.attempts, "Task permanently failed after max retries");
            false
        }
    }

    /// Returns a snapshot of every stored `(task, result)` pair, in arbitrary order.
    pub async fn collect_results(&self) -> Vec<(String, SubagentResult)> {
        let results = self.results.lock().await;
        results
            .iter()
            .map(|(k, v)| (k.clone(), v.clone()))
            .collect()
    }

    /// Returns a copy of the stored result for `task`, if any.
    pub async fn get_result(&self, task: &str) -> Option<SubagentResult> {
        let results = self.results.lock().await;
        results.get(task).cloned()
    }

    /// Returns `true` when every tracked task is `Completed` or `Failed`.
    ///
    /// This is vacuously `true` when no tasks are tracked.
    pub async fn is_complete(&self) -> bool {
        let map = self.tasks.lock().await;
        map.values().all(|a| Self::is_finished(&a.state))
    }

    /// Returns `true` if any tracked task has permanently failed.
    pub async fn has_failures(&self) -> bool {
        let map = self.tasks.lock().await;
        map.values().any(|a| matches!(a.state, TaskState::Failed(_)))
    }

    /// Returns `(finished, total)`.
    ///
    /// `finished` counts tasks that are `Completed` or `Failed`.
    pub async fn progress(&self) -> (usize, usize) {
        let map = self.tasks.lock().await;
        let done = map.values().filter(|a| Self::is_finished(&a.state)).count();
        (done, map.len())
    }

    /// Returns a copy of the current state of `task`, or `None` if it is not tracked.
    pub async fn task_state(&self, task: &str) -> Option<TaskState> {
        let map = self.tasks.lock().await;
        map.get(task).map(|a| a.state.clone())
    }

    /// Returns a snapshot of all task bookkeeping, in arbitrary order.
    pub async fn all_tasks(&self) -> Vec<TaskAssignment> {
        let map = self.tasks.lock().await;
        map.values().cloned().collect()
    }

    /// Name used as the sender of assignment messages.
    pub fn leader(&self) -> &str {
        &self.leader
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the successful teammate result used by the completion tests.
    fn finished_result(agent: &str) -> SubagentResult {
        SubagentResult {
            agent_name: agent.into(),
            output: "done".into(),
            duration: std::time::Duration::from_millis(100),
            iterations: 1,
            tokens_used: None,
            was_truncated: false,
            mode: crate::agent::subagent::types::ExecutionMode::Teammate,
        }
    }

    #[tokio::test]
    async fn test_add_and_assign() {
        let coordinator = TeamCoordinator::new("leader");
        let inbox = Mailbox::new();
        coordinator
            .add_tasks(vec!["task1".into(), "task2".into()])
            .await;
        coordinator.assign("task1", "worker", &inbox).await.unwrap();

        let delivered = inbox.recv().await.unwrap();
        assert_eq!(delivered.to, "worker");
        if let MessageKind::TaskAssigned { task, .. } = delivered.kind {
            assert_eq!(task, "task1");
        } else {
            panic!("Wrong message kind");
        }

        let state = coordinator.task_state("task1").await.unwrap();
        assert_eq!(state, TaskState::Assigned("worker".to_string()));
    }

    #[tokio::test]
    async fn test_assign_next() {
        let coordinator = TeamCoordinator::new("leader");
        let inbox = Mailbox::new();
        coordinator.add_tasks(vec!["t1".into()]).await;

        let first = coordinator.assign_next("w", &inbox).await.unwrap();
        assert_eq!(first.as_deref(), Some("t1"));

        let second = coordinator.assign_next("w", &inbox).await.unwrap();
        assert_eq!(second, None);
    }

    #[tokio::test]
    async fn test_record_and_complete() {
        let coordinator = TeamCoordinator::new("leader");
        let inbox = Mailbox::new();
        coordinator.add_tasks(vec!["t1".into()]).await;
        coordinator.assign("t1", "w", &inbox).await.unwrap();

        coordinator.record_result("t1", finished_result("w")).await;

        assert!(coordinator.is_complete().await);
        assert_eq!(coordinator.progress().await, (1, 1));
    }

    #[tokio::test]
    async fn test_failure_and_retry() {
        let coordinator = TeamCoordinator::with_max_retries("leader", 2);
        let inbox = Mailbox::new();
        coordinator.add_tasks(vec!["t1".into()]).await;

        // First attempt fails but stays within the budget.
        coordinator.assign("t1", "w", &inbox).await.unwrap();
        assert!(coordinator.record_failure("t1", "timeout").await);
        assert_eq!(coordinator.task_state("t1").await, Some(TaskState::Pending));

        // Second attempt exhausts the budget.
        coordinator.assign("t1", "w2", &inbox).await.unwrap();
        assert!(!coordinator.record_failure("t1", "timeout again").await);

        assert!(coordinator.is_complete().await);
        assert!(coordinator.has_failures().await);
    }

    #[tokio::test]
    async fn test_progress() {
        let coordinator = TeamCoordinator::new("leader");
        coordinator
            .add_tasks(vec!["a".into(), "b".into(), "c".into()])
            .await;
        assert_eq!(coordinator.progress().await, (0, 3));
    }
}