Skip to main content

car_multi/patterns/
supervisor.rs

1//! Supervisor — one agent monitors and controls workers.
2//!
3//! Each round: workers execute in parallel → supervisor reviews → feedback or approve.
4//! Workers get the supervisor's feedback in the next round.
5//! Stops when supervisor says "APPROVED" or max rounds reached.
6
7use crate::error::MultiError;
8use crate::mailbox::Mailbox;
9use crate::runner::AgentRunner;
10use crate::shared::SharedInfra;
11use crate::types::{AgentOutput, AgentSpec};
12use crate::patterns::swarm::{Swarm, SwarmMode};
13use serde::{Deserialize, Serialize};
14use std::sync::Arc;
15
/// Outcome of a [`Supervisor::run`] call: the per-round worker outputs, the
/// supervisor's replies, and the final verdict.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SupervisorResult {
    /// The original task text passed to `run` (without any appended feedback).
    pub task: String,
    /// Worker outputs, one inner `Vec` per executed round, in round order.
    pub rounds: Vec<Vec<AgentOutput>>,
    /// The supervisor's reply for each round, in round order.
    pub supervisor_feedback: Vec<String>,
    /// On approval: the approving reply with a leading `APPROVED` marker stripped.
    /// Otherwise: `"[max supervision rounds reached] "` followed by the last
    /// supervisor reply (empty when no round ran).
    pub final_answer: String,
    /// Whether the supervisor approved the results before the round limit was hit.
    pub approved: bool,
}
24
25impl SupervisorResult {
26    pub fn total_rounds(&self) -> usize {
27        self.rounds.len()
28    }
29}
30
/// Supervisor pattern: a set of workers whose combined output is reviewed by a
/// single supervising agent, round after round.
pub struct Supervisor {
    /// Specs of the worker agents; cloned into a fresh `Swarm` every round.
    pub workers: Vec<AgentSpec>,
    /// Spec of the reviewing agent that is run after each worker round.
    pub supervisor: AgentSpec,
    /// Upper bound on the number of worker/review rounds (`new` sets it to 3).
    pub max_rounds: u32,
}
36
37impl Supervisor {
38    pub fn new(workers: Vec<AgentSpec>, supervisor: AgentSpec) -> Self {
39        Self {
40            workers,
41            supervisor,
42            max_rounds: 3,
43        }
44    }
45
46    pub fn with_max_rounds(mut self, max_rounds: u32) -> Self {
47        self.max_rounds = max_rounds;
48        self
49    }
50
51    pub async fn run(
52        &self,
53        task: &str,
54        runner: &Arc<dyn AgentRunner>,
55        infra: &SharedInfra,
56    ) -> Result<SupervisorResult, MultiError> {
57        let mut result = SupervisorResult {
58            task: task.to_string(),
59            rounds: Vec::new(),
60            supervisor_feedback: Vec::new(),
61            final_answer: String::new(),
62            approved: false,
63        };
64
65        let mut current_task = task.to_string();
66
67        for round_num in 0..self.max_rounds {
68            // Workers execute in parallel
69            let swarm = Swarm::new(self.workers.clone(), SwarmMode::Parallel);
70            let swarm_result = swarm.run(&current_task, runner, infra).await?;
71            result.rounds.push(swarm_result.outputs.clone());
72
73            // Supervisor reviews
74            let worker_summaries: Vec<String> = swarm_result
75                .outputs
76                .iter()
77                .filter(|o| o.succeeded())
78                .map(|o| format!("- {}: {}", o.name, truncate(&o.answer, 300)))
79                .collect();
80
81            let review_task = format!(
82                "Original task: {}\n\nRound {} results:\n{}\n\n\
83                 Review these results. If they are satisfactory, respond with \
84                 APPROVED followed by a final summary. Otherwise, provide specific \
85                 feedback for improvement.",
86                task,
87                round_num + 1,
88                worker_summaries.join("\n")
89            );
90
91            let mailbox = Mailbox::default();
92            let rt = infra.make_runtime();
93            let feedback_output = runner.run(&self.supervisor, &review_task, &rt, &mailbox).await?;
94            let feedback = feedback_output.answer;
95            result.supervisor_feedback.push(feedback.clone());
96
97            if feedback.to_uppercase().contains("APPROVED") {
98                // Extract answer after APPROVED marker
99                let answer = strip_approved_prefix(&feedback);
100                result.final_answer = answer;
101                result.approved = true;
102                return Ok(result);
103            }
104
105            // Feed supervisor's feedback back to workers
106            current_task = format!(
107                "{}\n\nSupervisor feedback from round {}:\n{}",
108                task,
109                round_num + 1,
110                feedback
111            );
112        }
113
114        // Max rounds reached
115        result.final_answer = format!(
116            "[max supervision rounds reached] {}",
117            result.supervisor_feedback.last().unwrap_or(&String::new())
118        );
119        Ok(result)
120    }
121}
122
/// Removes a leading `APPROVED` marker from a supervisor reply and returns the
/// trimmed remainder.
///
/// The marker is matched ASCII case-insensitively after skipping leading
/// whitespace, and must be followed by `:`, `.`, any whitespace (including
/// tabs and `\r\n`), or the end of the text. A reply that consists of the
/// marker alone yields an empty string, matching what `"APPROVED."` yields.
///
/// If the reply does not start with the marker as a whole word (e.g. the marker
/// appears mid-text, or the text starts with `APPROVEDLY`), the input is
/// returned unchanged.
fn strip_approved_prefix(s: &str) -> String {
    const MARKER: &str = "APPROVED";

    let trimmed = s.trim_start();
    // `get` yields `None` when the text is shorter than the marker or the cut
    // would land inside a multi-byte character, so slicing below cannot panic.
    let rest = match trimmed.get(..MARKER.len()) {
        Some(head) if head.eq_ignore_ascii_case(MARKER) => &trimmed[MARKER.len()..],
        _ => return s.to_string(),
    };

    let mut chars = rest.chars();
    match chars.next() {
        // Nothing follows the marker: there is no summary to extract.
        None => String::new(),
        Some(sep) if sep == ':' || sep == '.' || sep.is_whitespace() => {
            chars.as_str().trim().to_string()
        }
        // The marker is only the beginning of a longer word; leave the text alone.
        Some(_) => s.to_string(),
    }
}
132
/// Returns the longest prefix of `s` that is at most `max_len` bytes long and
/// ends on a UTF-8 character boundary.
///
/// Strings that already fit are returned whole; nothing is appended to mark
/// the cut.
fn truncate(s: &str, max_len: usize) -> &str {
    if s.len() <= max_len {
        return s;
    }
    // Search downwards from the byte limit for the nearest character boundary.
    // Index 0 is always a boundary, so the fallback is never actually taken.
    let cut = (0..=max_len)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{AgentOutput, AgentSpec};
    use car_engine::Runtime;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Stub runner: workers echo their name, and the agent named "supervisor"
    /// approves immediately. Every invocation bumps `call_count`.
    struct ApprovingRunner {
        call_count: AtomicU32,
    }

    #[async_trait::async_trait]
    impl AgentRunner for ApprovingRunner {
        async fn run(
            &self,
            spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);

            // Only the supervisor (which runs after the workers) issues a verdict.
            let answer = if spec.name != "supervisor" {
                format!("work from {}", spec.name)
            } else {
                String::from("APPROVED: Everything looks good.")
            };

            let output = AgentOutput {
                name: spec.name.clone(),
                answer,
                turns: 1,
                tool_calls: 0,
                duration_ms: 5.0,
                error: None,
            };
            Ok(output)
        }
    }

    /// A supervisor that approves on its first review ends the run after one
    /// round, with the marker stripped from the final answer.
    #[tokio::test]
    async fn test_supervisor_approves_round_1() {
        let coder = AgentSpec::new("coder", "Write code");
        let tester = AgentSpec::new("tester", "Write tests");
        let reviewer = AgentSpec::new("supervisor", "Review and coordinate");

        let runner: Arc<dyn AgentRunner> = Arc::new(ApprovingRunner {
            call_count: AtomicU32::new(0),
        });
        let infra = SharedInfra::new();

        let pattern = Supervisor::new(vec![coder, tester], reviewer);
        let outcome = pattern
            .run("build fibonacci", &runner, &infra)
            .await
            .unwrap();

        assert!(outcome.approved);
        assert_eq!(outcome.total_rounds(), 1);
        assert_eq!(outcome.final_answer, "Everything looks good.");
    }
}
204}