car_multi/patterns/
supervisor.rs1use crate::error::MultiError;
8use crate::mailbox::Mailbox;
9use crate::patterns::swarm::{Swarm, SwarmMode};
10use crate::runner::AgentRunner;
11use crate::shared::SharedInfra;
12use crate::types::{AgentOutput, AgentSpec};
13use serde::{Deserialize, Serialize};
14use std::sync::Arc;
15use tracing::instrument;
16
/// Outcome of a [`Supervisor::run`] call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SupervisorResult {
    /// The original task text passed to `run`.
    pub task: String,
    /// Worker outputs, one inner `Vec` per round in which the workers ran.
    pub rounds: Vec<Vec<AgentOutput>>,
    /// Supervisor replies in round order. If the coordination budget refused to
    /// start the supervisor, the last entry is that error's message instead.
    pub supervisor_feedback: Vec<String>,
    /// The approved summary, or a bracketed status message
    /// (`[max supervision rounds reached] …` / `[coordination budget exhausted] …`)
    /// when the run ended without approval.
    pub final_answer: String,
    /// `true` only when the supervisor's reply was accepted as an approval.
    pub approved: bool,
}

impl SupervisorResult {
    /// Number of worker rounds that were executed (the length of `rounds`).
    pub fn total_rounds(&self) -> usize {
        self.rounds.len()
    }
}
31
32pub struct Supervisor {
33 pub workers: Vec<AgentSpec>,
34 pub supervisor: AgentSpec,
35 pub max_rounds: u32,
36}
37
38impl Supervisor {
39 pub fn new(workers: Vec<AgentSpec>, supervisor: AgentSpec) -> Self {
40 Self {
41 workers,
42 supervisor,
43 max_rounds: 3,
44 }
45 }
46
47 pub fn with_max_rounds(mut self, max_rounds: u32) -> Self {
48 self.max_rounds = max_rounds;
49 self
50 }
51
52 #[instrument(name = "multi.supervisor", skip_all)]
53 pub async fn run(
54 &self,
55 task: &str,
56 runner: &Arc<dyn AgentRunner>,
57 infra: &SharedInfra,
58 ) -> Result<SupervisorResult, MultiError> {
59 let mut result = SupervisorResult {
60 task: task.to_string(),
61 rounds: Vec::new(),
62 supervisor_feedback: Vec::new(),
63 final_answer: String::new(),
64 approved: false,
65 };
66
67 let mut current_task = task.to_string();
68
69 for round_num in 0..self.max_rounds {
70 let swarm = Swarm::new(self.workers.clone(), SwarmMode::Parallel);
72 let swarm_result = swarm.run(¤t_task, runner, infra).await?;
73 result.rounds.push(swarm_result.outputs.clone());
74
75 let worker_summaries: Vec<String> = swarm_result
77 .outputs
78 .iter()
79 .filter(|o| o.succeeded())
80 .map(|o| format!("- {}: {}", o.name, truncate(&o.answer, 300)))
81 .collect();
82
83 let review_task = format!(
84 "Original task: {}\n\nRound {} results:\n{}\n\n\
85 Review these results. If they are satisfactory, respond with \
86 APPROVED followed by a final summary. Otherwise, provide specific \
87 feedback for improvement.",
88 task,
89 round_num + 1,
90 worker_summaries.join("\n")
91 );
92
93 if let Err(e) = infra.begin_agent() {
96 let note = e.to_string();
97 result.supervisor_feedback.push(note.clone());
98 result.final_answer = format!("[coordination budget exhausted] {}", note);
99 return Ok(result);
100 }
101
102 let mailbox = Mailbox::default();
103 let rt = infra.make_runtime();
104 let feedback_output = runner
105 .run(&self.supervisor, &review_task, &rt, &mailbox)
106 .await?;
107 infra.record_output(&feedback_output);
108 let feedback = feedback_output.answer;
109 result.supervisor_feedback.push(feedback.clone());
110
111 if feedback.to_uppercase().contains("APPROVED") {
112 let answer = strip_approved_prefix(&feedback);
114 result.final_answer = answer;
115 result.approved = true;
116 return Ok(result);
117 }
118
119 current_task = format!(
121 "{}\n\nSupervisor feedback from round {}:\n{}",
122 task,
123 round_num + 1,
124 feedback
125 );
126 }
127
128 result.final_answer = format!(
130 "[max supervision rounds reached] {}",
131 result.supervisor_feedback.last().unwrap_or(&String::new())
132 );
133 Ok(result)
134 }
135}
136
/// Removes a leading `APPROVED` marker from a supervisor reply and returns the rest.
///
/// The marker is matched ASCII-case-insensitively after any leading whitespace,
/// and must be followed by `:`, `.` or a whitespace character. Any whitespace
/// counts, so `\t` and `\r\n` work as well as `" "` and `"\n"`. The text after
/// that separator is trimmed. If no marker is found, the input is returned
/// unchanged; this includes a bare `"APPROVED"` with nothing after it.
fn strip_approved_prefix(s: &str) -> String {
    const MARKER: &str = "APPROVED";

    let body = s.trim_start();
    // `get` returns `None` (instead of panicking) when the input is shorter than
    // the marker or byte `MARKER.len()` falls inside a multi-byte character.
    let rest = match body.get(..MARKER.len()) {
        Some(head) if head.eq_ignore_ascii_case(MARKER) => &body[MARKER.len()..],
        _ => return s.to_string(),
    };

    let mut chars = rest.chars();
    match chars.next() {
        Some(sep) if sep == ':' || sep == '.' || sep.is_whitespace() => {
            chars.as_str().trim().to_string()
        }
        _ => s.to_string(),
    }
}
146
/// Returns the longest prefix of `s` that is at most `max_len` bytes long and
/// ends on a UTF-8 character boundary, so the result is always a valid `&str`.
fn truncate(s: &str, max_len: usize) -> &str {
    if s.len() <= max_len {
        s
    } else {
        // Scan downward from `max_len`. Index 0 is always a char boundary, so
        // `find` always succeeds and the fallback is never taken.
        let cut = (0..=max_len)
            .rev()
            .find(|&idx| s.is_char_boundary(idx))
            .unwrap_or(0);
        &s[..cut]
    }
}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{AgentOutput, AgentSpec};
    use car_engine::Runtime;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Test double: workers answer "work from <name>", and the agent named
    /// "supervisor" always approves. Every invocation is counted.
    struct ApprovingRunner {
        call_count: AtomicU32,
    }

    #[async_trait::async_trait]
    impl AgentRunner for ApprovingRunner {
        async fn run(
            &self,
            spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            let _n = self.call_count.fetch_add(1, Ordering::SeqCst);
            let answer = if spec.name == "supervisor" {
                "APPROVED: Everything looks good.".to_string()
            } else {
                format!("work from {}", spec.name)
            };
            Ok(AgentOutput {
                name: spec.name.clone(),
                answer,
                turns: 1,
                tool_calls: 0,
                duration_ms: 5.0,
                error: None,
                outcome: None,
                tokens: None,
            })
        }
    }

    #[tokio::test]
    async fn test_supervisor_approves_round_1() {
        let workers = vec![
            AgentSpec::new("coder", "Write code"),
            AgentSpec::new("tester", "Write tests"),
        ];
        let supervisor_spec = AgentSpec::new("supervisor", "Review and coordinate");

        let runner: Arc<dyn AgentRunner> = Arc::new(ApprovingRunner {
            call_count: AtomicU32::new(0),
        });
        let infra = SharedInfra::new();

        let result = Supervisor::new(workers, supervisor_spec)
            .run("build fibonacci", &runner, &infra)
            .await
            .unwrap();

        assert!(result.approved);
        assert_eq!(result.task, "build fibonacci");
        assert_eq!(result.total_rounds(), 1);
        assert_eq!(result.supervisor_feedback.len(), 1);
        assert_eq!(result.final_answer, "Everything looks good.");
    }

    #[test]
    fn test_strip_approved_prefix_known_forms() {
        assert_eq!(
            strip_approved_prefix("APPROVED: Everything looks good."),
            "Everything looks good."
        );
        assert_eq!(strip_approved_prefix("approved. done"), "done");
        assert_eq!(strip_approved_prefix("APPROVED\nSummary"), "Summary");
        assert_eq!(strip_approved_prefix("APPROVED"), "APPROVED");
        assert_eq!(strip_approved_prefix("Needs more work"), "Needs more work");
    }

    #[test]
    fn test_truncate_respects_char_boundaries() {
        assert_eq!(truncate("hello", 10), "hello");
        assert_eq!(truncate("hello", 3), "hel");
        // 'é' occupies bytes 1..3, so a cut at byte 2 must back off to byte 1.
        assert_eq!(truncate("héllo", 2), "h");
        assert_eq!(truncate("héllo", 3), "hé");
        assert_eq!(truncate("abc", 0), "");
    }
}