// car_multi/patterns/supervisor.rs
use crate::error::MultiError;
use crate::mailbox::Mailbox;
use crate::runner::AgentRunner;
use crate::shared::SharedInfra;
use crate::types::{AgentOutput, AgentSpec};
use crate::patterns::swarm::{Swarm, SwarmMode};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct SupervisorResult {
18 pub task: String,
19 pub rounds: Vec<Vec<AgentOutput>>,
20 pub supervisor_feedback: Vec<String>,
21 pub final_answer: String,
22 pub approved: bool,
23}
24
25impl SupervisorResult {
26 pub fn total_rounds(&self) -> usize {
27 self.rounds.len()
28 }
29}
30
31pub struct Supervisor {
32 pub workers: Vec<AgentSpec>,
33 pub supervisor: AgentSpec,
34 pub max_rounds: u32,
35}
36
37impl Supervisor {
38 pub fn new(workers: Vec<AgentSpec>, supervisor: AgentSpec) -> Self {
39 Self {
40 workers,
41 supervisor,
42 max_rounds: 3,
43 }
44 }
45
46 pub fn with_max_rounds(mut self, max_rounds: u32) -> Self {
47 self.max_rounds = max_rounds;
48 self
49 }
50
51 pub async fn run(
52 &self,
53 task: &str,
54 runner: &Arc<dyn AgentRunner>,
55 infra: &SharedInfra,
56 ) -> Result<SupervisorResult, MultiError> {
57 let mut result = SupervisorResult {
58 task: task.to_string(),
59 rounds: Vec::new(),
60 supervisor_feedback: Vec::new(),
61 final_answer: String::new(),
62 approved: false,
63 };
64
65 let mut current_task = task.to_string();
66
67 for round_num in 0..self.max_rounds {
68 let swarm = Swarm::new(self.workers.clone(), SwarmMode::Parallel);
70 let swarm_result = swarm.run(¤t_task, runner, infra).await?;
71 result.rounds.push(swarm_result.outputs.clone());
72
73 let worker_summaries: Vec<String> = swarm_result
75 .outputs
76 .iter()
77 .filter(|o| o.succeeded())
78 .map(|o| format!("- {}: {}", o.name, truncate(&o.answer, 300)))
79 .collect();
80
81 let review_task = format!(
82 "Original task: {}\n\nRound {} results:\n{}\n\n\
83 Review these results. If they are satisfactory, respond with \
84 APPROVED followed by a final summary. Otherwise, provide specific \
85 feedback for improvement.",
86 task,
87 round_num + 1,
88 worker_summaries.join("\n")
89 );
90
91 let mailbox = Mailbox::default();
92 let rt = infra.make_runtime();
93 let feedback_output = runner.run(&self.supervisor, &review_task, &rt, &mailbox).await?;
94 let feedback = feedback_output.answer;
95 result.supervisor_feedback.push(feedback.clone());
96
97 if feedback.to_uppercase().contains("APPROVED") {
98 let answer = strip_approved_prefix(&feedback);
100 result.final_answer = answer;
101 result.approved = true;
102 return Ok(result);
103 }
104
105 current_task = format!(
107 "{}\n\nSupervisor feedback from round {}:\n{}",
108 task,
109 round_num + 1,
110 feedback
111 );
112 }
113
114 result.final_answer = format!(
116 "[max supervision rounds reached] {}",
117 result.supervisor_feedback.last().unwrap_or(&String::new())
118 );
119 Ok(result)
120 }
121}
122
/// Removes a leading `APPROVED` verdict from supervisor feedback and returns
/// the remaining summary, trimmed.
///
/// The marker is matched ASCII case-insensitively after skipping leading
/// whitespace, and must be followed by end of input, whitespace, or one of
/// `:` `.` `!` `,`. Punctuation directly after the marker is dropped as well.
/// Input that does not start with the marker as a whole word is returned
/// unchanged.
fn strip_approved_prefix(s: &str) -> String {
    const MARKER: &str = "APPROVED";

    fn is_separator(c: char) -> bool {
        matches!(c, ':' | '.' | '!' | ',')
    }

    let candidate = s.trim_start();
    // `get` is `None` for inputs shorter than the marker or when the cut would
    // land inside a multi-byte character, so the slice below cannot panic.
    let rest = match candidate.get(..MARKER.len()) {
        Some(head) if head.eq_ignore_ascii_case(MARKER) => &candidate[MARKER.len()..],
        _ => return s.to_string(),
    };

    // Require a word boundary so that e.g. "APPROVEDLY" is left alone.
    let at_boundary = rest
        .chars()
        .next()
        .map_or(true, |c| c.is_whitespace() || is_separator(c));
    if !at_boundary {
        return s.to_string();
    }

    rest.trim_start_matches(is_separator).trim().to_string()
}
132
/// Returns the longest prefix of `s` that is at most `max_len` bytes long and
/// ends on a UTF-8 character boundary.
///
/// `s` is returned unchanged when it already fits. The result may be shorter
/// than `max_len` bytes when the limit falls inside a multi-byte character.
fn truncate(s: &str, max_len: usize) -> &str {
    if s.len() <= max_len {
        return s;
    }
    // Walk back from the byte limit to the nearest boundary; index 0 is always
    // a boundary, so the search cannot come up empty.
    let end = (0..=max_len)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..end]
}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{AgentOutput, AgentSpec};
    use car_engine::Runtime;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Stub runner: the agent named "supervisor" always answers with
    /// `supervisor_reply`; every other agent answers `work from <name>`.
    struct ScriptedRunner {
        supervisor_reply: &'static str,
        call_count: AtomicU32,
    }

    impl ScriptedRunner {
        fn shared(supervisor_reply: &'static str) -> Arc<dyn AgentRunner> {
            Arc::new(Self {
                supervisor_reply,
                call_count: AtomicU32::new(0),
            })
        }
    }

    #[async_trait::async_trait]
    impl AgentRunner for ScriptedRunner {
        async fn run(
            &self,
            spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let answer = if spec.name == "supervisor" {
                self.supervisor_reply.to_string()
            } else {
                format!("work from {}", spec.name)
            };
            Ok(AgentOutput {
                name: spec.name.clone(),
                answer,
                turns: 1,
                tool_calls: 0,
                duration_ms: 5.0,
                error: None,
            })
        }
    }

    fn sample_workers() -> Vec<AgentSpec> {
        vec![
            AgentSpec::new("coder", "Write code"),
            AgentSpec::new("tester", "Write tests"),
        ]
    }

    #[tokio::test]
    async fn test_supervisor_approves_round_1() {
        let supervisor_spec = AgentSpec::new("supervisor", "Review and coordinate");
        let runner = ScriptedRunner::shared("APPROVED: Everything looks good.");
        let infra = SharedInfra::new();

        let result = Supervisor::new(sample_workers(), supervisor_spec)
            .run("build fibonacci", &runner, &infra)
            .await
            .unwrap();

        assert!(result.approved);
        assert_eq!(result.total_rounds(), 1);
        assert_eq!(result.final_answer, "Everything looks good.");
    }

    #[tokio::test]
    async fn test_supervisor_stops_at_max_rounds() {
        let supervisor_spec = AgentSpec::new("supervisor", "Review and coordinate");
        let runner = ScriptedRunner::shared("Needs more tests.");
        let infra = SharedInfra::new();

        let result = Supervisor::new(sample_workers(), supervisor_spec)
            .with_max_rounds(2)
            .run("build fibonacci", &runner, &infra)
            .await
            .unwrap();

        assert!(!result.approved);
        assert_eq!(result.total_rounds(), 2);
        assert_eq!(result.supervisor_feedback.len(), 2);
        assert_eq!(
            result.final_answer,
            "[max supervision rounds reached] Needs more tests."
        );
    }

    #[test]
    fn test_truncate_respects_char_boundaries() {
        assert_eq!(truncate("hello", 3), "hel");
        assert_eq!(truncate("h\u{e9}llo", 2), "h");
        assert_eq!(truncate("short", 300), "short");
    }
}