1use crate::error::MultiError;
7use crate::mailbox::Mailbox;
8use crate::runner::AgentRunner;
9use crate::shared::SharedInfra;
10use crate::types::{AgentOutput, AgentSpec};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::task::JoinSet;
15use tracing::instrument;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct MapReduceResult {
19 pub task: String,
20 pub map_outputs: Vec<AgentOutput>,
21 pub reduced_answer: String,
22}
23
24impl MapReduceResult {
25 pub fn all_succeeded(&self) -> bool {
26 self.map_outputs.iter().all(|o| o.succeeded())
27 }
28}
29
30pub struct MapReduce {
31 pub mapper: AgentSpec,
32 pub reducer: AgentSpec,
33 pub max_concurrent: usize,
34}
35
36impl MapReduce {
37 pub fn new(mapper: AgentSpec, reducer: AgentSpec) -> Self {
38 Self {
39 mapper,
40 reducer,
41 max_concurrent: 5,
42 }
43 }
44
45 pub fn with_max_concurrent(mut self, n: usize) -> Self {
46 self.max_concurrent = n;
47 self
48 }
49
50 #[instrument(name = "multi.map_reduce", skip_all)]
51 pub async fn run(
52 &self,
53 task: &str,
54 items: &[String],
55 runner: &Arc<dyn AgentRunner>,
56 infra: &SharedInfra,
57 ) -> Result<MapReduceResult, MultiError> {
58 let semaphore = Arc::new(tokio::sync::Semaphore::new(self.max_concurrent));
60 let mut handles = JoinSet::new();
61 let mut task_indices = HashMap::new();
62 let mut indexed: Vec<(usize, AgentOutput)> = Vec::new();
63
64 for (i, item) in items.iter().enumerate() {
65 if let Err(e) = infra.begin_agent() {
68 let name = format!("{}_{}", self.mapper.name, i);
69 indexed.push((i, crate::budget::budget_skipped_output(&name, &e)));
70 continue;
71 }
72
73 let sem = Arc::clone(&semaphore);
74 let runner = Arc::clone(runner);
75 let rt = infra.make_runtime();
76 let mailbox = Mailbox::default();
77
78 let mut spec = self.mapper.clone();
79 spec.name = format!("{}_{}", self.mapper.name, i);
80
81 for tool in &spec.tools {
82 rt.register_tool(tool).await;
83 }
84
85 let subtask = format!("{}\n\nProcess this item: {}", task, item);
86
87 let handle = handles.spawn(async move {
88 let _permit = sem.acquire().await.unwrap();
89 (i, runner.run(&spec, &subtask, &rt, &mailbox).await)
90 });
91 task_indices.insert(handle.id(), i);
92 }
93
94 while let Some(result) = handles.join_next().await {
95 match result {
96 Ok((i, Ok(output))) => {
97 infra.record_output(&output);
98 indexed.push((i, output));
99 }
100 Ok((i, Err(e))) => {
101 indexed.push((
104 i,
105 AgentOutput {
106 name: format!("{}_{}", self.mapper.name, i),
107 answer: String::new(),
108 turns: 0,
109 tool_calls: 0,
110 duration_ms: 0.0,
111 error: Some(e.to_string()),
112 outcome: None,
113 tokens: None,
114 },
115 ));
116 }
117 Err(e) => {
118 let i = task_indices
119 .get(&e.id())
120 .copied()
121 .expect("mapper task id should be tracked");
122 indexed.push((
123 i,
124 AgentOutput {
125 name: format!("{}_{}", self.mapper.name, i),
126 answer: String::new(),
127 turns: 0,
128 tool_calls: 0,
129 duration_ms: 0.0,
130 error: Some(format!("join error: {}", e)),
131 outcome: None,
132 tokens: None,
133 },
134 ));
135 }
136 }
137 }
138
139 indexed.sort_by_key(|(i, _)| *i);
140 let map_outputs: Vec<AgentOutput> = indexed.into_iter().map(|(_, o)| o).collect();
141
142 let summaries: Vec<String> = map_outputs
144 .iter()
145 .filter(|o| o.succeeded())
146 .map(|o| format!("- [{}] {}", o.name, truncate(&o.answer, 300)))
147 .collect();
148
149 let reduce_task = format!(
150 "Original task: {}\n\nResults from {} sub-agents:\n{}\n\n\
151 Combine these into a single coherent result.",
152 task,
153 map_outputs.len(),
154 summaries.join("\n")
155 );
156
157 let reduced = if infra.begin_agent().is_ok() {
160 let rt = infra.make_runtime();
161 let mailbox = Mailbox::default();
162 match runner.run(&self.reducer, &reduce_task, &rt, &mailbox).await {
163 Ok(o) => {
164 infra.record_output(&o);
165 o.answer
166 }
167 Err(_) => String::new(),
168 }
169 } else {
170 summaries.join("\n")
171 };
172
173 Ok(MapReduceResult {
174 task: task.to_string(),
175 map_outputs,
176 reduced_answer: reduced,
177 })
178 }
179}
180
/// Returns the longest prefix of `s` that is at most `max_len` bytes long and ends
/// on a UTF-8 character boundary.
///
/// The input is returned unchanged when it already fits within `max_len` bytes.
fn truncate(s: &str, max_len: usize) -> &str {
    if max_len >= s.len() {
        return s;
    }
    // Walk back from the byte limit to the nearest boundary; index 0 always is one.
    let cut = (0..=max_len)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}
191
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{AgentOutput, AgentSpec};
    use car_engine::Runtime;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use tokio::sync::Notify;

    /// Builds the successful single-turn output shared by the stub runners below.
    fn ok_output(name: &str, answer: String) -> AgentOutput {
        AgentOutput {
            name: name.to_string(),
            answer,
            turns: 1,
            tool_calls: 0,
            duration_ms: 5.0,
            error: None,
            outcome: None,
            tokens: None,
        }
    }

    /// Runner that succeeds immediately for every spec.
    struct CountRunner;

    #[async_trait::async_trait]
    impl crate::runner::AgentRunner for CountRunner {
        async fn run(
            &self,
            spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            let answer = format!("{} processed", spec.name);
            Ok(ok_output(&spec.name, answer))
        }
    }

    #[tokio::test]
    async fn test_map_reduce() {
        let items: Vec<String> = ["file_a.rs", "file_b.rs", "file_c.rs"]
            .iter()
            .map(|name| name.to_string())
            .collect();
        let pipeline = MapReduce::new(
            AgentSpec::new("summarizer", "Summarize the file"),
            AgentSpec::new("combiner", "Combine summaries"),
        );

        let runner: Arc<dyn crate::runner::AgentRunner> = Arc::new(CountRunner);
        let infra = SharedInfra::new();

        let result = pipeline
            .run("summarize codebase", &items, &runner, &infra)
            .await
            .unwrap();

        assert_eq!(result.map_outputs.len(), 3);
        assert!(!result.reduced_answer.is_empty());
    }

    /// Runner whose `mapper_0` only finishes after `mapper_1` has panicked, so the
    /// join error is observed before the successful output.
    struct LaterPanicRunner {
        later_panicked: Arc<AtomicBool>,
        notify: Arc<Notify>,
    }

    #[async_trait::async_trait]
    impl crate::runner::AgentRunner for LaterPanicRunner {
        async fn run(
            &self,
            spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            let agent = spec.name.as_str();
            match agent {
                "mapper_0" => {
                    // Hold back until the sibling mapper has flagged its panic.
                    while !self.later_panicked.load(Ordering::SeqCst) {
                        self.notify.notified().await;
                    }
                    Ok(ok_output(agent, "mapper 0 completed".to_string()))
                }
                "mapper_1" => {
                    self.later_panicked.store(true, Ordering::SeqCst);
                    self.notify.notify_one();
                    panic!("mapper 1 panicked first");
                }
                // Anything else (the reducer) just succeeds.
                _ => Ok(ok_output(agent, "reduced".to_string())),
            }
        }
    }

    #[tokio::test]
    async fn panicking_mapper_keeps_original_item_index() {
        let runner: Arc<dyn crate::runner::AgentRunner> = Arc::new(LaterPanicRunner {
            later_panicked: Arc::new(AtomicBool::new(false)),
            notify: Arc::new(Notify::new()),
        });
        let infra = SharedInfra::new();
        let items = vec!["slow first".to_string(), "fast panic".to_string()];

        let pipeline = MapReduce::new(
            AgentSpec::new("mapper", "map item"),
            AgentSpec::new("reducer", "reduce items"),
        )
        .with_max_concurrent(2);

        let result = pipeline
            .run("preserve mapper order", &items, &runner, &infra)
            .await
            .unwrap();

        assert_eq!(result.map_outputs.len(), 2);

        let first = &result.map_outputs[0];
        assert_eq!(first.name, "mapper_0");
        assert!(first.succeeded());

        let second = &result.map_outputs[1];
        assert_eq!(second.name, "mapper_1");
        assert!(
            second
                .error
                .as_deref()
                .is_some_and(|error| error.contains("panicked")),
            "expected mapper_1 to carry the panic error, got {:?}",
            second.error
        );
    }

    /// Runner that never completes; it counts how many runs started and how many
    /// of those run futures were dropped.
    struct DropCountingRunner {
        started: Arc<AtomicUsize>,
        dropped: Arc<AtomicUsize>,
        notify: Arc<Notify>,
    }

    /// Increments the wrapped counter when dropped.
    struct DropGuard(Arc<AtomicUsize>);

    impl Drop for DropGuard {
        fn drop(&mut self) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[async_trait::async_trait]
    impl crate::runner::AgentRunner for DropCountingRunner {
        async fn run(
            &self,
            _spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            // The guard lives inside this future, so it fires when the future is dropped.
            let _guard = DropGuard(self.dropped.clone());
            self.started.fetch_add(1, Ordering::SeqCst);
            self.notify.notify_one();
            std::future::pending::<Result<AgentOutput, MultiError>>().await
        }
    }

    #[tokio::test]
    async fn dropping_map_reduce_run_aborts_mapper_tasks() {
        use std::time::{Duration, Instant};

        let started = Arc::new(AtomicUsize::new(0));
        let dropped = Arc::new(AtomicUsize::new(0));
        let notify = Arc::new(Notify::new());
        let runner: Arc<dyn crate::runner::AgentRunner> = Arc::new(DropCountingRunner {
            started: started.clone(),
            dropped: dropped.clone(),
            notify: notify.clone(),
        });
        let infra = SharedInfra::new();
        let items = vec!["one".to_string(), "two".to_string()];

        let handle = tokio::spawn(async move {
            MapReduce::new(
                AgentSpec::new("worker", "run work"),
                AgentSpec::new("reducer", "reduce work"),
            )
            .with_max_concurrent(2)
            .run("parallel goal", &items, &runner, &infra)
            .await
        });

        // Wait until both mappers are parked inside the runner.
        while started.load(Ordering::SeqCst) < 2 {
            notify.notified().await;
        }

        handle.abort();
        assert!(handle.await.unwrap_err().is_cancelled());

        // Give the runtime up to two seconds to drop the aborted mapper futures.
        let deadline = Instant::now() + Duration::from_secs(2);
        loop {
            if Instant::now() >= deadline {
                break;
            }
            if dropped.load(Ordering::SeqCst) >= 2 {
                return;
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        panic!(
            "mapper futures were detached after MapReduce cancellation; dropped={}",
            dropped.load(Ordering::SeqCst)
        );
    }
}