Skip to main content

car_multi/patterns/
map_reduce.rs

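//! Map-Reduce — fan a task out over N caller-supplied items, run one mapper agent
//! per item with bounded parallelism, then reduce the results.
//!
//! The mapper spec is cloned (and renamed `<mapper>_<index>`) for each item, and each
//! mapper processes exactly one item. Mapper outputs are kept in item order.
//! The reducer combines the successful mapper outputs into a single result.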
5
6use crate::error::MultiError;
7use crate::mailbox::Mailbox;
8use crate::runner::AgentRunner;
9use crate::shared::SharedInfra;
10use crate::types::{AgentOutput, AgentSpec};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::task::JoinSet;
15use tracing::instrument;
16
/// Outcome of a [`MapReduce::run`] call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MapReduceResult {
    /// The task string that was passed to `run`.
    pub task: String,
    /// One output per input item, in the same order as the items. A mapper that
    /// failed (runner error or panicked/cancelled task) is represented by an
    /// output whose `error` is set and whose `answer` is empty.
    pub map_outputs: Vec<AgentOutput>,
    /// The reducer's answer; empty when the reducer run returned an error.
    pub reduced_answer: String,
}
23
24impl MapReduceResult {
25    pub fn all_succeeded(&self) -> bool {
26        self.map_outputs.iter().all(|o| o.succeeded())
27    }
28}
29
/// Map-reduce orchestration pattern: runs one copy of `mapper` per item, then
/// hands the successful mapper answers to a single `reducer` agent.
pub struct MapReduce {
    /// Template spec cloned for every item; each clone is renamed `<name>_<index>`.
    pub mapper: AgentSpec,
    /// Spec for the single agent that combines the mapper outputs.
    pub reducer: AgentSpec,
    /// Upper bound on how many mapper runs may execute at once
    /// (`MapReduce::new` sets this to 5).
    pub max_concurrent: usize,
}
35
36impl MapReduce {
37    pub fn new(mapper: AgentSpec, reducer: AgentSpec) -> Self {
38        Self {
39            mapper,
40            reducer,
41            max_concurrent: 5,
42        }
43    }
44
45    pub fn with_max_concurrent(mut self, n: usize) -> Self {
46        self.max_concurrent = n;
47        self
48    }
49
50    #[instrument(name = "multi.map_reduce", skip_all)]
51    pub async fn run(
52        &self,
53        task: &str,
54        items: &[String],
55        runner: &Arc<dyn AgentRunner>,
56        infra: &SharedInfra,
57    ) -> Result<MapReduceResult, MultiError> {
58        // Map phase: one agent per item, bounded concurrency
59        let semaphore = Arc::new(tokio::sync::Semaphore::new(self.max_concurrent));
60        let mut handles = JoinSet::new();
61        let mut task_indices = HashMap::new();
62
63        for (i, item) in items.iter().enumerate() {
64            let sem = Arc::clone(&semaphore);
65            let runner = Arc::clone(runner);
66            let rt = infra.make_runtime();
67            let mailbox = Mailbox::default();
68
69            let mut spec = self.mapper.clone();
70            spec.name = format!("{}_{}", self.mapper.name, i);
71
72            for tool in &spec.tools {
73                rt.register_tool(tool).await;
74            }
75
76            let subtask = format!("{}\n\nProcess this item: {}", task, item);
77
78            let handle = handles.spawn(async move {
79                let _permit = sem.acquire().await.unwrap();
80                (i, runner.run(&spec, &subtask, &rt, &mailbox).await)
81            });
82            task_indices.insert(handle.id(), i);
83        }
84
85        let mut indexed: Vec<(usize, AgentOutput)> = Vec::new();
86
87        while let Some(result) = handles.join_next().await {
88            match result {
89                Ok((i, Ok(output))) => indexed.push((i, output)),
90                Ok((i, Err(e))) => {
91                    indexed.push((
92                        i,
93                        AgentOutput {
94                            name: format!("{}_{}", self.mapper.name, i),
95                            answer: String::new(),
96                            turns: 0,
97                            tool_calls: 0,
98                            duration_ms: 0.0,
99                            error: Some(e.to_string()),
100                            outcome: None,
101                            tokens: None,
102                        },
103                    ));
104                }
105                Err(e) => {
106                    let i = task_indices
107                        .get(&e.id())
108                        .copied()
109                        .expect("mapper task id should be tracked");
110                    indexed.push((
111                        i,
112                        AgentOutput {
113                            name: format!("{}_{}", self.mapper.name, i),
114                            answer: String::new(),
115                            turns: 0,
116                            tool_calls: 0,
117                            duration_ms: 0.0,
118                            error: Some(format!("join error: {}", e)),
119                            outcome: None,
120                            tokens: None,
121                        },
122                    ));
123                }
124            }
125        }
126
127        indexed.sort_by_key(|(i, _)| *i);
128        let map_outputs: Vec<AgentOutput> = indexed.into_iter().map(|(_, o)| o).collect();
129
130        // Reduce phase
131        let summaries: Vec<String> = map_outputs
132            .iter()
133            .filter(|o| o.succeeded())
134            .map(|o| format!("- [{}] {}", o.name, truncate(&o.answer, 300)))
135            .collect();
136
137        let reduce_task = format!(
138            "Original task: {}\n\nResults from {} sub-agents:\n{}\n\n\
139             Combine these into a single coherent result.",
140            task,
141            map_outputs.len(),
142            summaries.join("\n")
143        );
144
145        let rt = infra.make_runtime();
146        let mailbox = Mailbox::default();
147        let reduced = runner
148            .run(&self.reducer, &reduce_task, &rt, &mailbox)
149            .await
150            .map(|o| o.answer)
151            .unwrap_or_default();
152
153        Ok(MapReduceResult {
154            task: task.to_string(),
155            map_outputs,
156            reduced_answer: reduced,
157        })
158    }
159}
160
/// Returns the longest prefix of `s` that is at most `max_len` bytes long and
/// ends on a UTF-8 character boundary.
///
/// When `s` already fits within `max_len` bytes it is returned unchanged.
fn truncate(s: &str, max_len: usize) -> &str {
    if s.len() > max_len {
        // Step back from `max_len` to the nearest char boundary. Index 0 is
        // always a boundary, so the search always finds one.
        let cut = (0..=max_len)
            .rev()
            .find(|&idx| s.is_char_boundary(idx))
            .unwrap_or(0);
        &s[..cut]
    } else {
        s
    }
}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{AgentOutput, AgentSpec};
    use car_engine::Runtime;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use tokio::sync::Notify;

    /// Builds the successful single-turn output that the fake runners return.
    fn ok_output(name: &str, answer: String) -> AgentOutput {
        AgentOutput {
            name: name.to_string(),
            answer,
            turns: 1,
            tool_calls: 0,
            duration_ms: 5.0,
            error: None,
            outcome: None,
            tokens: None,
        }
    }

    /// Runner that answers "<agent name> processed" for every agent.
    struct CountRunner;

    #[async_trait::async_trait]
    impl crate::runner::AgentRunner for CountRunner {
        async fn run(
            &self,
            spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            Ok(ok_output(&spec.name, format!("{} processed", spec.name)))
        }
    }

    #[tokio::test]
    async fn test_map_reduce() {
        let items: Vec<String> = ["file_a.rs", "file_b.rs", "file_c.rs"]
            .iter()
            .map(|path| path.to_string())
            .collect();
        let runner: Arc<dyn crate::runner::AgentRunner> = Arc::new(CountRunner);
        let infra = SharedInfra::new();
        let pattern = MapReduce::new(
            AgentSpec::new("summarizer", "Summarize the file"),
            AgentSpec::new("combiner", "Combine summaries"),
        );

        let result = pattern
            .run("summarize codebase", &items, &runner, &infra)
            .await
            .unwrap();

        assert_eq!(result.map_outputs.len(), 3);
        assert!(!result.reduced_answer.is_empty());
    }

    /// Runner where `mapper_0` cannot finish until `mapper_1` has panicked,
    /// so the panicking task completes first.
    struct LaterPanicRunner {
        later_panicked: Arc<AtomicBool>,
        notify: Arc<Notify>,
    }

    #[async_trait::async_trait]
    impl crate::runner::AgentRunner for LaterPanicRunner {
        async fn run(
            &self,
            spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            if spec.name == "mapper_1" {
                self.later_panicked.store(true, Ordering::SeqCst);
                self.notify.notify_one();
                panic!("mapper 1 panicked first");
            }

            if spec.name == "mapper_0" {
                // Park until mapper_1 has flagged its panic.
                while !self.later_panicked.load(Ordering::SeqCst) {
                    self.notify.notified().await;
                }
                return Ok(ok_output(&spec.name, "mapper 0 completed".to_string()));
            }

            // Any other agent (the reducer) succeeds immediately.
            Ok(ok_output(&spec.name, "reduced".to_string()))
        }
    }

    #[tokio::test]
    async fn panicking_mapper_keeps_original_item_index() {
        let runner: Arc<dyn crate::runner::AgentRunner> = Arc::new(LaterPanicRunner {
            later_panicked: Arc::new(AtomicBool::new(false)),
            notify: Arc::new(Notify::new()),
        });
        let infra = SharedInfra::new();
        let items = vec!["slow first".to_string(), "fast panic".to_string()];
        let pattern = MapReduce::new(
            AgentSpec::new("mapper", "map item"),
            AgentSpec::new("reducer", "reduce items"),
        )
        .with_max_concurrent(2);

        let result = pattern
            .run("preserve mapper order", &items, &runner, &infra)
            .await
            .unwrap();

        let outputs = &result.map_outputs;
        assert_eq!(outputs.len(), 2);

        let first = &outputs[0];
        assert_eq!(first.name, "mapper_0");
        assert!(first.succeeded());

        let second = &outputs[1];
        assert_eq!(second.name, "mapper_1");
        let carries_panic = second
            .error
            .as_deref()
            .is_some_and(|error| error.contains("panicked"));
        assert!(
            carries_panic,
            "expected mapper_1 to carry the panic error, got {:?}",
            second.error
        );
    }

    /// Runner that counts starts, never completes, and counts when its
    /// in-flight future is dropped.
    struct DropCountingRunner {
        started: Arc<AtomicUsize>,
        dropped: Arc<AtomicUsize>,
        notify: Arc<Notify>,
    }

    /// Increments the shared counter when dropped.
    struct DropGuard(Arc<AtomicUsize>);

    impl Drop for DropGuard {
        fn drop(&mut self) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[async_trait::async_trait]
    impl crate::runner::AgentRunner for DropCountingRunner {
        async fn run(
            &self,
            _spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            // Kept alive for the whole (never-ending) call; it only fires on drop.
            let _on_drop = DropGuard(Arc::clone(&self.dropped));
            self.started.fetch_add(1, Ordering::SeqCst);
            self.notify.notify_one();
            std::future::pending::<Result<AgentOutput, MultiError>>().await
        }
    }

    #[tokio::test]
    async fn dropping_map_reduce_run_aborts_mapper_tasks() {
        let started = Arc::new(AtomicUsize::new(0));
        let dropped = Arc::new(AtomicUsize::new(0));
        let notify = Arc::new(Notify::new());
        let runner: Arc<dyn crate::runner::AgentRunner> = Arc::new(DropCountingRunner {
            started: Arc::clone(&started),
            dropped: Arc::clone(&dropped),
            notify: Arc::clone(&notify),
        });
        let infra = SharedInfra::new();
        let items = vec!["one".to_string(), "two".to_string()];

        let run_handle = tokio::spawn(async move {
            MapReduce::new(
                AgentSpec::new("worker", "run work"),
                AgentSpec::new("reducer", "reduce work"),
            )
            .with_max_concurrent(2)
            .run("parallel goal", &items, &runner, &infra)
            .await
        });

        // Wait until both mappers are parked inside the runner.
        while started.load(Ordering::SeqCst) < 2 {
            notify.notified().await;
        }

        run_handle.abort();
        let join_error = run_handle.await.unwrap_err();
        assert!(join_error.is_cancelled());

        // Cancelling the outer task must also tear down the mapper futures it spawned.
        let poll_interval = std::time::Duration::from_millis(20);
        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(2);
        while std::time::Instant::now() < deadline {
            if dropped.load(Ordering::SeqCst) >= 2 {
                return;
            }
            tokio::time::sleep(poll_interval).await;
        }
        panic!(
            "mapper futures were detached after MapReduce cancellation; dropped={}",
            dropped.load(Ordering::SeqCst)
        );
    }
}