Skip to main content

car_multi/patterns/
map_reduce.rs

1//! Map-Reduce — split a task into N items, run agents in parallel, reduce results.
2//!
3//! The mapper spec is cloned for each item. Each mapper processes one item.
4//! The reducer combines all mapper outputs into a single result.
5
6use crate::error::MultiError;
7use crate::mailbox::Mailbox;
8use crate::runner::AgentRunner;
9use crate::shared::SharedInfra;
10use crate::types::{AgentOutput, AgentSpec};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::task::JoinSet;
15use tracing::instrument;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct MapReduceResult {
19    pub task: String,
20    pub map_outputs: Vec<AgentOutput>,
21    pub reduced_answer: String,
22}
23
24impl MapReduceResult {
25    pub fn all_succeeded(&self) -> bool {
26        self.map_outputs.iter().all(|o| o.succeeded())
27    }
28}
29
30pub struct MapReduce {
31    pub mapper: AgentSpec,
32    pub reducer: AgentSpec,
33    pub max_concurrent: usize,
34}
35
36impl MapReduce {
37    pub fn new(mapper: AgentSpec, reducer: AgentSpec) -> Self {
38        Self {
39            mapper,
40            reducer,
41            max_concurrent: 5,
42        }
43    }
44
45    pub fn with_max_concurrent(mut self, n: usize) -> Self {
46        self.max_concurrent = n;
47        self
48    }
49
50    #[instrument(name = "multi.map_reduce", skip_all)]
51    pub async fn run(
52        &self,
53        task: &str,
54        items: &[String],
55        runner: &Arc<dyn AgentRunner>,
56        infra: &SharedInfra,
57    ) -> Result<MapReduceResult, MultiError> {
58        // Map phase: one agent per item, bounded concurrency
59        let semaphore = Arc::new(tokio::sync::Semaphore::new(self.max_concurrent));
60        let mut handles = JoinSet::new();
61        let mut task_indices = HashMap::new();
62        let mut indexed: Vec<(usize, AgentOutput)> = Vec::new();
63
64        for (i, item) in items.iter().enumerate() {
65            // Budget gate per mapper. Denied items are recorded as skipped and
66            // never spawned.
67            if let Err(e) = infra.begin_agent() {
68                let name = format!("{}_{}", self.mapper.name, i);
69                indexed.push((i, crate::budget::budget_skipped_output(&name, &e)));
70                continue;
71            }
72
73            let sem = Arc::clone(&semaphore);
74            let runner = Arc::clone(runner);
75            let rt = infra.make_runtime();
76            let mailbox = Mailbox::default();
77
78            let mut spec = self.mapper.clone();
79            spec.name = format!("{}_{}", self.mapper.name, i);
80
81            for tool in &spec.tools {
82                rt.register_tool(tool).await;
83            }
84
85            let subtask = format!("{}\n\nProcess this item: {}", task, item);
86
87            let handle = handles.spawn(async move {
88                let _permit = sem.acquire().await.unwrap();
89                (i, runner.run(&spec, &subtask, &rt, &mailbox).await)
90            });
91            task_indices.insert(handle.id(), i);
92        }
93
94        while let Some(result) = handles.join_next().await {
95            match result {
96                Ok((i, Ok(output))) => {
97                    infra.record_output(&output);
98                    indexed.push((i, output));
99                }
100                Ok((i, Err(e))) => {
101                    // Spend before a runner Err is not metered (no token payload
102                    // on the error path) — the budget can under-count failures.
103                    indexed.push((
104                        i,
105                        AgentOutput {
106                            name: format!("{}_{}", self.mapper.name, i),
107                            answer: String::new(),
108                            turns: 0,
109                            tool_calls: 0,
110                            duration_ms: 0.0,
111                            error: Some(e.to_string()),
112                            outcome: None,
113                            tokens: None,
114                        },
115                    ));
116                }
117                Err(e) => {
118                    let i = task_indices
119                        .get(&e.id())
120                        .copied()
121                        .expect("mapper task id should be tracked");
122                    indexed.push((
123                        i,
124                        AgentOutput {
125                            name: format!("{}_{}", self.mapper.name, i),
126                            answer: String::new(),
127                            turns: 0,
128                            tool_calls: 0,
129                            duration_ms: 0.0,
130                            error: Some(format!("join error: {}", e)),
131                            outcome: None,
132                            tokens: None,
133                        },
134                    ));
135                }
136            }
137        }
138
139        indexed.sort_by_key(|(i, _)| *i);
140        let map_outputs: Vec<AgentOutput> = indexed.into_iter().map(|(_, o)| o).collect();
141
142        // Reduce phase
143        let summaries: Vec<String> = map_outputs
144            .iter()
145            .filter(|o| o.succeeded())
146            .map(|o| format!("- [{}] {}", o.name, truncate(&o.answer, 300)))
147            .collect();
148
149        let reduce_task = format!(
150            "Original task: {}\n\nResults from {} sub-agents:\n{}\n\n\
151             Combine these into a single coherent result.",
152            task,
153            map_outputs.len(),
154            summaries.join("\n")
155        );
156
157        // Gate the reducer; on denial return the unreduced map summaries so the
158        // caller still gets the mapper work it paid for.
159        let reduced = if infra.begin_agent().is_ok() {
160            let rt = infra.make_runtime();
161            let mailbox = Mailbox::default();
162            match runner.run(&self.reducer, &reduce_task, &rt, &mailbox).await {
163                Ok(o) => {
164                    infra.record_output(&o);
165                    o.answer
166                }
167                Err(_) => String::new(),
168            }
169        } else {
170            summaries.join("\n")
171        };
172
173        Ok(MapReduceResult {
174            task: task.to_string(),
175            map_outputs,
176            reduced_answer: reduced,
177        })
178    }
179}
180
/// Returns the longest prefix of `s` that is at most `max_len` bytes long and
/// ends on a UTF-8 character boundary.
///
/// `s` is returned unchanged when it already fits. The result may be shorter
/// than `max_len` when that offset falls inside a multi-byte character.
fn truncate(s: &str, max_len: usize) -> &str {
    if s.len() > max_len {
        // Walk back from `max_len` to the nearest boundary; offset 0 always
        // qualifies, so the search cannot come up empty.
        let cut = (0..=max_len)
            .rev()
            .find(|&offset| s.is_char_boundary(offset))
            .unwrap_or(0);
        &s[..cut]
    } else {
        s
    }
}
191
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{AgentOutput, AgentSpec};
    use car_engine::Runtime;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use tokio::sync::Notify;

    /// Builds the successful single-turn output shared by the stub runners.
    fn ok_output(name: &str, answer: String) -> AgentOutput {
        AgentOutput {
            name: name.to_string(),
            answer,
            turns: 1,
            tool_calls: 0,
            duration_ms: 5.0,
            error: None,
            outcome: None,
            tokens: None,
        }
    }

    /// Runner that succeeds immediately and echoes the agent name in its answer.
    struct CountRunner;

    #[async_trait::async_trait]
    impl crate::runner::AgentRunner for CountRunner {
        async fn run(
            &self,
            spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            Ok(ok_output(&spec.name, format!("{} processed", spec.name)))
        }
    }

    /// Happy path: every item yields a mapper output and the reducer answers.
    #[tokio::test]
    async fn test_map_reduce() {
        let items: Vec<String> = ["file_a.rs", "file_b.rs", "file_c.rs"]
            .iter()
            .map(|&path| path.to_owned())
            .collect();
        let runner: Arc<dyn crate::runner::AgentRunner> = Arc::new(CountRunner);
        let infra = SharedInfra::new();
        let pattern = MapReduce::new(
            AgentSpec::new("summarizer", "Summarize the file"),
            AgentSpec::new("combiner", "Combine summaries"),
        );

        let result = pattern
            .run("summarize codebase", &items, &runner, &infra)
            .await
            .unwrap();

        assert_eq!(result.map_outputs.len(), 3);
        assert!(!result.reduced_answer.is_empty());
    }

    /// Runner in which `mapper_1` panics and `mapper_0` only completes after
    /// that panic has been signalled, so the failure is always observed first.
    struct LaterPanicRunner {
        later_panicked: Arc<AtomicBool>,
        notify: Arc<Notify>,
    }

    #[async_trait::async_trait]
    impl crate::runner::AgentRunner for LaterPanicRunner {
        async fn run(
            &self,
            spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            let agent = spec.name.as_str();

            if agent == "mapper_1" {
                // Publish the flag before waking mapper_0 so its re-check sees it.
                self.later_panicked.store(true, Ordering::SeqCst);
                self.notify.notify_one();
                panic!("mapper 1 panicked first");
            }

            if agent == "mapper_0" {
                while !self.later_panicked.load(Ordering::SeqCst) {
                    self.notify.notified().await;
                }
                return Ok(ok_output(agent, "mapper 0 completed".to_string()));
            }

            // Any other name is the reducer.
            Ok(ok_output(agent, "reduced".to_string()))
        }
    }

    /// A mapper that panics early must still land at its own item index.
    #[tokio::test]
    async fn panicking_mapper_keeps_original_item_index() {
        let items = vec!["slow first".to_string(), "fast panic".to_string()];
        let infra = SharedInfra::new();
        let runner: Arc<dyn crate::runner::AgentRunner> = Arc::new(LaterPanicRunner {
            later_panicked: Arc::new(AtomicBool::new(false)),
            notify: Arc::new(Notify::new()),
        });
        let pattern = MapReduce::new(
            AgentSpec::new("mapper", "map item"),
            AgentSpec::new("reducer", "reduce items"),
        )
        .with_max_concurrent(2);

        let result = pattern
            .run("preserve mapper order", &items, &runner, &infra)
            .await
            .unwrap();

        assert_eq!(result.map_outputs.len(), 2);

        let survivor = &result.map_outputs[0];
        assert_eq!(survivor.name, "mapper_0");
        assert!(survivor.succeeded());

        let casualty = &result.map_outputs[1];
        assert_eq!(casualty.name, "mapper_1");
        assert!(
            casualty
                .error
                .as_deref()
                .is_some_and(|error| error.contains("panicked")),
            "expected mapper_1 to carry the panic error, got {:?}",
            casualty.error
        );
    }

    /// Increments its counter when dropped; used to observe future teardown.
    struct DropGuard(Arc<AtomicUsize>);

    impl Drop for DropGuard {
        fn drop(&mut self) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Runner that never finishes: it records that it started, then pends
    /// forever while holding a `DropGuard`.
    struct DropCountingRunner {
        started: Arc<AtomicUsize>,
        dropped: Arc<AtomicUsize>,
        notify: Arc<Notify>,
    }

    #[async_trait::async_trait]
    impl crate::runner::AgentRunner for DropCountingRunner {
        async fn run(
            &self,
            _spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            let _guard = DropGuard(Arc::clone(&self.dropped));
            // Count first, then wake the test, so the waiter's re-check sees it.
            self.started.fetch_add(1, Ordering::SeqCst);
            self.notify.notify_one();
            std::future::pending::<Result<AgentOutput, MultiError>>().await
        }
    }

    /// Aborting the task that drives `run` must also tear down the mappers.
    #[tokio::test]
    async fn dropping_map_reduce_run_aborts_mapper_tasks() {
        use std::time::{Duration, Instant};

        let started = Arc::new(AtomicUsize::new(0));
        let dropped = Arc::new(AtomicUsize::new(0));
        let notify = Arc::new(Notify::new());
        let runner: Arc<dyn crate::runner::AgentRunner> = Arc::new(DropCountingRunner {
            started: Arc::clone(&started),
            dropped: Arc::clone(&dropped),
            notify: Arc::clone(&notify),
        });
        let infra = SharedInfra::new();
        let items = vec!["one".to_string(), "two".to_string()];

        let driver = tokio::spawn(async move {
            MapReduce::new(
                AgentSpec::new("worker", "run work"),
                AgentSpec::new("reducer", "reduce work"),
            )
            .with_max_concurrent(2)
            .run("parallel goal", &items, &runner, &infra)
            .await
        });

        // Wait until both mappers are inside the runner.
        while started.load(Ordering::SeqCst) < 2 {
            notify.notified().await;
        }

        driver.abort();
        let join_error = driver.await.unwrap_err();
        assert!(join_error.is_cancelled());

        // Poll for up to two seconds for both guards to be dropped.
        let deadline = Instant::now() + Duration::from_secs(2);
        loop {
            if Instant::now() >= deadline {
                break;
            }
            if dropped.load(Ordering::SeqCst) >= 2 {
                return;
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        panic!(
            "mapper futures were detached after MapReduce cancellation; dropped={}",
            dropped.load(Ordering::SeqCst)
        );
    }
}