1use crate::error::MultiError;
7use crate::mailbox::Mailbox;
8use crate::runner::AgentRunner;
9use crate::shared::SharedInfra;
10use crate::types::{AgentOutput, AgentSpec};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::task::JoinSet;
15use tracing::instrument;
16
/// Outcome of one full map-reduce run: one mapper output per input item plus
/// the reducer's combined answer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MapReduceResult {
    // The original task string passed to `MapReduce::run`.
    pub task: String,
    // Mapper outputs in input-item order; failed mappers are included with
    // `error` set rather than dropped.
    pub map_outputs: Vec<AgentOutput>,
    // The reducer agent's answer; empty if the reducer run failed.
    pub reduced_answer: String,
}
23
24impl MapReduceResult {
25 pub fn all_succeeded(&self) -> bool {
26 self.map_outputs.iter().all(|o| o.succeeded())
27 }
28}
29
/// Fan-out/fan-in orchestrator: runs a cloned `mapper` agent per input item
/// (bounded by `max_concurrent`) and then a single `reducer` agent over the
/// collected answers.
pub struct MapReduce {
    // Spec cloned for each item; each clone gets a `_{index}` name suffix.
    pub mapper: AgentSpec,
    // Spec of the agent that combines mapper summaries into one answer.
    pub reducer: AgentSpec,
    // Maximum number of mapper agents running concurrently (default 5).
    pub max_concurrent: usize,
}
35
36impl MapReduce {
37 pub fn new(mapper: AgentSpec, reducer: AgentSpec) -> Self {
38 Self {
39 mapper,
40 reducer,
41 max_concurrent: 5,
42 }
43 }
44
45 pub fn with_max_concurrent(mut self, n: usize) -> Self {
46 self.max_concurrent = n;
47 self
48 }
49
50 #[instrument(name = "multi.map_reduce", skip_all)]
51 pub async fn run(
52 &self,
53 task: &str,
54 items: &[String],
55 runner: &Arc<dyn AgentRunner>,
56 infra: &SharedInfra,
57 ) -> Result<MapReduceResult, MultiError> {
58 let semaphore = Arc::new(tokio::sync::Semaphore::new(self.max_concurrent));
60 let mut handles = JoinSet::new();
61 let mut task_indices = HashMap::new();
62
63 for (i, item) in items.iter().enumerate() {
64 let sem = Arc::clone(&semaphore);
65 let runner = Arc::clone(runner);
66 let rt = infra.make_runtime();
67 let mailbox = Mailbox::default();
68
69 let mut spec = self.mapper.clone();
70 spec.name = format!("{}_{}", self.mapper.name, i);
71
72 for tool in &spec.tools {
73 rt.register_tool(tool).await;
74 }
75
76 let subtask = format!("{}\n\nProcess this item: {}", task, item);
77
78 let handle = handles.spawn(async move {
79 let _permit = sem.acquire().await.unwrap();
80 (i, runner.run(&spec, &subtask, &rt, &mailbox).await)
81 });
82 task_indices.insert(handle.id(), i);
83 }
84
85 let mut indexed: Vec<(usize, AgentOutput)> = Vec::new();
86
87 while let Some(result) = handles.join_next().await {
88 match result {
89 Ok((i, Ok(output))) => indexed.push((i, output)),
90 Ok((i, Err(e))) => {
91 indexed.push((
92 i,
93 AgentOutput {
94 name: format!("{}_{}", self.mapper.name, i),
95 answer: String::new(),
96 turns: 0,
97 tool_calls: 0,
98 duration_ms: 0.0,
99 error: Some(e.to_string()),
100 outcome: None,
101 tokens: None,
102 },
103 ));
104 }
105 Err(e) => {
106 let i = task_indices
107 .get(&e.id())
108 .copied()
109 .expect("mapper task id should be tracked");
110 indexed.push((
111 i,
112 AgentOutput {
113 name: format!("{}_{}", self.mapper.name, i),
114 answer: String::new(),
115 turns: 0,
116 tool_calls: 0,
117 duration_ms: 0.0,
118 error: Some(format!("join error: {}", e)),
119 outcome: None,
120 tokens: None,
121 },
122 ));
123 }
124 }
125 }
126
127 indexed.sort_by_key(|(i, _)| *i);
128 let map_outputs: Vec<AgentOutput> = indexed.into_iter().map(|(_, o)| o).collect();
129
130 let summaries: Vec<String> = map_outputs
132 .iter()
133 .filter(|o| o.succeeded())
134 .map(|o| format!("- [{}] {}", o.name, truncate(&o.answer, 300)))
135 .collect();
136
137 let reduce_task = format!(
138 "Original task: {}\n\nResults from {} sub-agents:\n{}\n\n\
139 Combine these into a single coherent result.",
140 task,
141 map_outputs.len(),
142 summaries.join("\n")
143 );
144
145 let rt = infra.make_runtime();
146 let mailbox = Mailbox::default();
147 let reduced = runner
148 .run(&self.reducer, &reduce_task, &rt, &mailbox)
149 .await
150 .map(|o| o.answer)
151 .unwrap_or_default();
152
153 Ok(MapReduceResult {
154 task: task.to_string(),
155 map_outputs,
156 reduced_answer: reduced,
157 })
158 }
159}
160
/// Returns a prefix of `s` at most `max_len` bytes long, shortened to the
/// nearest UTF-8 character boundary so multi-byte characters are never split.
fn truncate(s: &str, max_len: usize) -> &str {
    if s.len() <= max_len {
        return s;
    }
    // Index 0 is always a valid boundary, so this search cannot fail.
    let cut = (0..=max_len)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0);
    &s[..cut]
}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{AgentOutput, AgentSpec};
    use car_engine::Runtime;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use tokio::sync::Notify;

    // Runner that always succeeds, answering with "<spec name> processed".
    struct CountRunner;

    #[async_trait::async_trait]
    impl crate::runner::AgentRunner for CountRunner {
        async fn run(
            &self,
            spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            Ok(AgentOutput {
                name: spec.name.clone(),
                answer: format!("{} processed", spec.name),
                turns: 1,
                tool_calls: 0,
                duration_ms: 5.0,
                error: None,
                outcome: None,
                tokens: None,
            })
        }
    }

    // Happy path: three items produce three mapper outputs and a non-empty
    // reduced answer.
    #[tokio::test]
    async fn test_map_reduce() {
        let mapper = AgentSpec::new("summarizer", "Summarize the file");
        let reducer = AgentSpec::new("combiner", "Combine summaries");
        let items: Vec<String> = vec!["file_a.rs", "file_b.rs", "file_c.rs"]
            .into_iter()
            .map(String::from)
            .collect();

        let runner: Arc<dyn crate::runner::AgentRunner> = Arc::new(CountRunner);
        let infra = SharedInfra::new();

        let result = MapReduce::new(mapper, reducer)
            .run("summarize codebase", &items, &runner, &infra)
            .await
            .unwrap();

        assert_eq!(result.map_outputs.len(), 3);
        assert!(!result.reduced_answer.is_empty());
    }

    // Runner that forces out-of-order completion: mapper_0 blocks until
    // mapper_1 has panicked, so the panicking task always finishes first.
    struct LaterPanicRunner {
        // Set by mapper_1 just before it panics; releases mapper_0.
        later_panicked: Arc<AtomicBool>,
        // Wakes mapper_0's wait loop after the flag is set.
        notify: Arc<Notify>,
    }

    #[async_trait::async_trait]
    impl crate::runner::AgentRunner for LaterPanicRunner {
        async fn run(
            &self,
            spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            match spec.name.as_str() {
                "mapper_0" => {
                    // Re-check the flag after each wakeup so a notification
                    // sent before we start waiting is not a missed signal.
                    while !self.later_panicked.load(Ordering::SeqCst) {
                        self.notify.notified().await;
                    }
                    Ok(AgentOutput {
                        name: spec.name.clone(),
                        answer: "mapper 0 completed".to_string(),
                        turns: 1,
                        tool_calls: 0,
                        duration_ms: 5.0,
                        error: None,
                        outcome: None,
                        tokens: None,
                    })
                }
                "mapper_1" => {
                    // Flip the flag and wake mapper_0 before panicking so the
                    // test cannot deadlock.
                    self.later_panicked.store(true, Ordering::SeqCst);
                    self.notify.notify_one();
                    panic!("mapper 1 panicked first");
                }
                // Any other spec (the reducer) succeeds trivially.
                _ => Ok(AgentOutput {
                    name: spec.name.clone(),
                    answer: "reduced".to_string(),
                    turns: 1,
                    tool_calls: 0,
                    duration_ms: 5.0,
                    error: None,
                    outcome: None,
                    tokens: None,
                }),
            }
        }
    }

    // A panicked mapper must still be reported under its original item index
    // (run() sorts outputs by index and maps join errors back via task ids),
    // even though it completes before earlier items.
    #[tokio::test]
    async fn panicking_mapper_keeps_original_item_index() {
        let runner: Arc<dyn crate::runner::AgentRunner> = Arc::new(LaterPanicRunner {
            later_panicked: Arc::new(AtomicBool::new(false)),
            notify: Arc::new(Notify::new()),
        });
        let infra = SharedInfra::new();
        let items = vec!["slow first".to_string(), "fast panic".to_string()];

        let result = MapReduce::new(
            AgentSpec::new("mapper", "map item"),
            AgentSpec::new("reducer", "reduce items"),
        )
        .with_max_concurrent(2)
        .run("preserve mapper order", &items, &runner, &infra)
        .await
        .unwrap();

        assert_eq!(result.map_outputs.len(), 2);
        assert_eq!(result.map_outputs[0].name, "mapper_0");
        assert!(result.map_outputs[0].succeeded());
        assert_eq!(result.map_outputs[1].name, "mapper_1");
        assert!(
            result.map_outputs[1]
                .error
                .as_deref()
                .is_some_and(|error| error.contains("panicked")),
            "expected mapper_1 to carry the panic error, got {:?}",
            result.map_outputs[1].error
        );
    }

    // Runner whose futures never complete; it records how many runs started
    // and, via DropGuard, how many of those pending futures were dropped.
    struct DropCountingRunner {
        started: Arc<AtomicUsize>,
        dropped: Arc<AtomicUsize>,
        notify: Arc<Notify>,
    }

    // Increments the shared counter when the owning future is dropped.
    struct DropGuard(Arc<AtomicUsize>);

    impl Drop for DropGuard {
        fn drop(&mut self) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[async_trait::async_trait]
    impl crate::runner::AgentRunner for DropCountingRunner {
        async fn run(
            &self,
            _spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            // Guard lives inside the pending future; its Drop impl fires only
            // when the future itself is dropped (i.e. the task was aborted).
            let _guard = DropGuard(self.dropped.clone());
            self.started.fetch_add(1, Ordering::SeqCst);
            self.notify.notify_one();
            std::future::pending::<Result<AgentOutput, MultiError>>().await
        }
    }

    // Cancelling the outer run() task must abort the spawned mapper tasks
    // (the JoinSet inside run() is dropped) rather than leaking them.
    #[tokio::test]
    async fn dropping_map_reduce_run_aborts_mapper_tasks() {
        let started = Arc::new(AtomicUsize::new(0));
        let dropped = Arc::new(AtomicUsize::new(0));
        let notify = Arc::new(Notify::new());
        let runner: Arc<dyn crate::runner::AgentRunner> = Arc::new(DropCountingRunner {
            started: started.clone(),
            dropped: dropped.clone(),
            notify: notify.clone(),
        });
        let infra = SharedInfra::new();
        let items = vec!["one".to_string(), "two".to_string()];

        let handle = tokio::spawn(async move {
            MapReduce::new(
                AgentSpec::new("worker", "run work"),
                AgentSpec::new("reducer", "reduce work"),
            )
            .with_max_concurrent(2)
            .run("parallel goal", &items, &runner, &infra)
            .await
        });

        // Wait until both mapper futures have actually started before
        // cancelling, so the abort exercises in-flight tasks.
        while started.load(Ordering::SeqCst) < 2 {
            notify.notified().await;
        }

        handle.abort();
        assert!(handle.await.unwrap_err().is_cancelled());

        // Aborted tasks are dropped asynchronously; poll the drop counter
        // with a deadline instead of asserting immediately.
        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(2);
        while std::time::Instant::now() < deadline {
            if dropped.load(Ordering::SeqCst) >= 2 {
                return;
            }
            tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        }
        panic!(
            "mapper futures were detached after MapReduce cancellation; dropped={}",
            dropped.load(Ordering::SeqCst)
        );
    }
}