use crate::error::MultiError;
use crate::mailbox::Mailbox;
use crate::runner::AgentRunner;
use crate::shared::SharedInfra;
use crate::types::{AgentOutput, AgentSpec};
use crate::patterns::swarm::{Swarm, SwarmMode};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tracing::instrument;
/// Outcome of a [`Supervisor::run`] invocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SupervisorResult {
    /// The task string originally passed to `run`.
    pub task: String,
    /// Worker outputs, one inner `Vec` per executed round, in round order.
    pub rounds: Vec<Vec<AgentOutput>>,
    /// The supervisor's raw response for each executed round, in round order.
    pub supervisor_feedback: Vec<String>,
    /// On approval: the approving response after `strip_approved_prefix`.
    /// Otherwise: the last feedback prefixed with `[max supervision rounds reached]`.
    pub final_answer: String,
    /// Whether the supervisor approved the work before `max_rounds` was exhausted.
    pub approved: bool,
}
impl SupervisorResult {
    /// Returns how many worker rounds were executed, i.e. the number of entries
    /// recorded in `rounds`.
    pub fn total_rounds(&self) -> usize {
        let Self { rounds, .. } = self;
        rounds.len()
    }
}
/// Runs a set of worker agents in parallel, round by round, and asks a supervisor
/// agent to review each round's results, feeding its critique into the next round.
pub struct Supervisor {
    /// Worker agents, run together as a parallel [`Swarm`] every round.
    pub workers: Vec<AgentSpec>,
    /// Agent that reviews each round's results and decides whether to approve.
    pub supervisor: AgentSpec,
    /// Maximum number of work/review rounds (set to 3 by `Supervisor::new`).
    pub max_rounds: u32,
}
impl Supervisor {
    /// Creates a supervisor over `workers`, reviewed by `supervisor`, with a default
    /// limit of 3 rounds.
    pub fn new(workers: Vec<AgentSpec>, supervisor: AgentSpec) -> Self {
        Self {
            workers,
            supervisor,
            max_rounds: 3,
        }
    }

    /// Sets the maximum number of work/review rounds (builder style).
    ///
    /// With `0`, `run` executes no rounds and returns an unapproved result.
    pub fn with_max_rounds(mut self, max_rounds: u32) -> Self {
        self.max_rounds = max_rounds;
        self
    }

    /// Runs up to `max_rounds` rounds of parallel work followed by supervisor review.
    ///
    /// Each round:
    /// 1. runs all workers as a parallel [`Swarm`] on the current task;
    /// 2. sends the successful worker answers (each cut to at most 300 bytes) to the
    ///    supervisor for review;
    /// 3. returns with `approved = true` if a line of the supervisor's response starts
    ///    with the word `APPROVED`. Otherwise, the feedback is appended to the original
    ///    task for the next round.
    ///
    /// If no round is approved, `final_answer` holds the last feedback prefixed with
    /// `[max supervision rounds reached]`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the worker swarm or from the supervisor agent run.
    #[instrument(name = "multi.supervisor", skip_all)]
    pub async fn run(
        &self,
        task: &str,
        runner: &Arc<dyn AgentRunner>,
        infra: &SharedInfra,
    ) -> Result<SupervisorResult, MultiError> {
        let mut result = SupervisorResult {
            task: task.to_string(),
            rounds: Vec::new(),
            supervisor_feedback: Vec::new(),
            final_answer: String::new(),
            approved: false,
        };
        let mut current_task = task.to_string();
        for round_num in 0..self.max_rounds {
            let round_label = round_num + 1;
            let swarm = Swarm::new(self.workers.clone(), SwarmMode::Parallel);
            let swarm_result = swarm.run(&current_task, runner, infra).await?;
            result.rounds.push(swarm_result.outputs.clone());

            let worker_summaries: Vec<String> = swarm_result
                .outputs
                .iter()
                .filter(|o| o.succeeded())
                .map(|o| format!("- {}: {}", o.name, truncate(&o.answer, 300)))
                .collect();
            // Say explicitly that nothing succeeded, rather than handing the supervisor
            // an empty list that it might approve.
            let results_text = if worker_summaries.is_empty() {
                "(no worker produced a successful result this round)".to_string()
            } else {
                worker_summaries.join("\n")
            };
            let review_task = format!(
                "Original task: {}\n\nRound {} results:\n{}\n\n\
                 Review these results. If they are satisfactory, respond with \
                 APPROVED followed by a final summary. Otherwise, provide specific \
                 feedback for improvement.",
                task, round_label, results_text
            );

            let mailbox = Mailbox::default();
            let rt = infra.make_runtime();
            let feedback = runner
                .run(&self.supervisor, &review_task, &rt, &mailbox)
                .await?
                .answer;

            if is_approval(&feedback) {
                result.final_answer = strip_approved_prefix(&feedback);
                result.approved = true;
                result.supervisor_feedback.push(feedback);
                return Ok(result);
            }

            current_task = format!(
                "{}\n\nSupervisor feedback from round {}:\n{}",
                task, round_label, feedback
            );
            result.supervisor_feedback.push(feedback);
        }

        let last_feedback = result
            .supervisor_feedback
            .last()
            .map(String::as_str)
            .unwrap_or("");
        result.final_answer = format!("[max supervision rounds reached] {}", last_feedback);
        Ok(result)
    }
}

/// Returns `true` when some line of `response` starts with the word `APPROVED`.
///
/// Leading whitespace and markdown decoration (`*`, `_`, `#`, `>`, `` ` ``) are skipped.
/// The marker is matched ASCII case-insensitively and must be a whole word. This
/// rejects responses such as "NOT APPROVED" or "unapproved", which a plain substring
/// search would treat as approval.
fn is_approval(response: &str) -> bool {
    const MARKER: &str = "APPROVED";
    response.lines().any(|line| {
        let line = line.trim_start_matches(|c: char| {
            c.is_whitespace() || matches!(c, '*' | '_' | '#' | '>' | '`')
        });
        // `get` yields None instead of panicking if the cut is not a char boundary.
        match line.get(..MARKER.len()) {
            Some(head) if head.eq_ignore_ascii_case(MARKER) => !matches!(
                line[MARKER.len()..].chars().next(),
                Some(c) if c.is_alphanumeric()
            ),
            _ => false,
        }
    })
}
/// Removes a leading `APPROVED` marker from a supervisor response and returns the
/// summary that follows it.
///
/// Matching is ASCII case-insensitive. It tolerates leading whitespace and markdown
/// decoration (e.g. `**APPROVED**: ...`), and strips separator punctuation after the
/// marker (`:`, `.`, `!`, `-`, `,`). The response is returned unchanged when:
/// - it does not begin with the whole word `APPROVED`, or
/// - nothing follows the marker.
fn strip_approved_prefix(s: &str) -> String {
    const MARKER: &str = "APPROVED";
    let body = s.trim_start_matches(|c: char| {
        c.is_whitespace() || matches!(c, '*' | '_' | '#' | '>' | '`')
    });
    // Compare the original bytes ASCII-insensitively instead of building an uppercase
    // copy: `to_uppercase` can change byte lengths, so offsets taken from the copy are
    // not guaranteed to be valid slice points in `s`.
    let has_marker = matches!(
        body.get(..MARKER.len()),
        Some(head) if head.eq_ignore_ascii_case(MARKER)
    );
    if !has_marker {
        return s.to_string();
    }
    // Safe to slice: `get` above proved MARKER.len() is a char boundary of `body`.
    let rest = &body[MARKER.len()..];
    // Require a word boundary so e.g. "APPROVEDLY" is not treated as the marker.
    if matches!(rest.chars().next(), Some(c) if c.is_alphanumeric()) {
        return s.to_string();
    }
    let summary = rest
        .trim_start_matches(|c: char| {
            c.is_whitespace() || matches!(c, ':' | '.' | '!' | '-' | ',' | '*' | '_' | '`')
        })
        .trim_end();
    if summary.is_empty() {
        s.to_string()
    } else {
        summary.to_string()
    }
}
/// Returns the longest prefix of `s` that is at most `max_len` bytes long and ends on
/// a UTF-8 character boundary, so the slice never splits a multi-byte character.
fn truncate(s: &str, max_len: usize) -> &str {
    if max_len >= s.len() {
        s
    } else {
        // Scan downward from `max_len`. Index 0 is always a boundary, so the fallback
        // is never actually needed.
        let cut = (0..=max_len)
            .rev()
            .find(|&idx| s.is_char_boundary(idx))
            .unwrap_or(0);
        &s[..cut]
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{AgentOutput, AgentSpec};
    use car_engine::Runtime;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Test double for the agent runner:
    /// - each worker answers `"work from <name>"`;
    /// - the agent named `"supervisor"` always answers with `supervisor_reply`.
    ///
    /// Supervisor calls are counted so tests can check how many review rounds ran.
    struct ScriptedRunner {
        supervisor_reply: &'static str,
        supervisor_calls: AtomicU32,
    }

    impl ScriptedRunner {
        fn new(supervisor_reply: &'static str) -> Self {
            Self {
                supervisor_reply,
                supervisor_calls: AtomicU32::new(0),
            }
        }
    }

    #[async_trait::async_trait]
    impl AgentRunner for ScriptedRunner {
        async fn run(
            &self,
            spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            let answer = if spec.name == "supervisor" {
                self.supervisor_calls.fetch_add(1, Ordering::SeqCst);
                self.supervisor_reply.to_string()
            } else {
                format!("work from {}", spec.name)
            };
            Ok(AgentOutput {
                name: spec.name.clone(),
                answer,
                turns: 1,
                tool_calls: 0,
                duration_ms: 5.0,
                error: None,
                outcome: None,
                tokens: None,
            })
        }
    }

    fn test_workers() -> Vec<AgentSpec> {
        vec![
            AgentSpec::new("coder", "Write code"),
            AgentSpec::new("tester", "Write tests"),
        ]
    }

    #[tokio::test]
    async fn test_supervisor_approves_round_1() {
        let scripted = Arc::new(ScriptedRunner::new("APPROVED: Everything looks good."));
        let runner: Arc<dyn AgentRunner> = Arc::clone(&scripted);
        let infra = SharedInfra::new();
        let supervisor_spec = AgentSpec::new("supervisor", "Review and coordinate");
        let result = Supervisor::new(test_workers(), supervisor_spec)
            .run("build fibonacci", &runner, &infra)
            .await
            .unwrap();
        assert!(result.approved);
        assert_eq!(result.total_rounds(), 1);
        assert_eq!(result.final_answer, "Everything looks good.");
        assert_eq!(result.supervisor_feedback.len(), 1);
        assert_eq!(scripted.supervisor_calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_supervisor_gives_up_after_max_rounds() {
        let scripted = Arc::new(ScriptedRunner::new("Needs more edge-case tests."));
        let runner: Arc<dyn AgentRunner> = Arc::clone(&scripted);
        let infra = SharedInfra::new();
        let supervisor_spec = AgentSpec::new("supervisor", "Review and coordinate");
        let result = Supervisor::new(test_workers(), supervisor_spec)
            .with_max_rounds(2)
            .run("build fibonacci", &runner, &infra)
            .await
            .unwrap();
        assert!(!result.approved);
        assert_eq!(result.total_rounds(), 2);
        assert_eq!(result.supervisor_feedback.len(), 2);
        assert_eq!(
            result.final_answer,
            "[max supervision rounds reached] Needs more edge-case tests."
        );
        assert_eq!(scripted.supervisor_calls.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_truncate_keeps_char_boundaries() {
        assert_eq!(truncate("hello", 10), "hello");
        assert_eq!(truncate("hello", 3), "hel");
        // 'é' occupies bytes 1..3, so a 2-byte limit must back off to just "h".
        assert_eq!(truncate("héllo", 2), "h");
    }

    #[test]
    fn test_strip_approved_prefix_basics() {
        assert_eq!(
            strip_approved_prefix("APPROVED: Everything looks good."),
            "Everything looks good."
        );
        assert_eq!(strip_approved_prefix("approved. fine"), "fine");
        assert_eq!(strip_approved_prefix("Needs more work"), "Needs more work");
    }
}