Skip to main content

car_multi/patterns/
map_reduce.rs

1//! Map-Reduce — split a task into N items, run agents in parallel, reduce results.
2//!
3//! The mapper spec is cloned for each item. Each mapper processes one item.
4//! The reducer combines all mapper outputs into a single result.
5
6use crate::error::MultiError;
7use crate::mailbox::Mailbox;
8use crate::runner::AgentRunner;
9use crate::shared::SharedInfra;
10use crate::types::{AgentOutput, AgentSpec};
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct MapReduceResult {
16    pub task: String,
17    pub map_outputs: Vec<AgentOutput>,
18    pub reduced_answer: String,
19}
20
21impl MapReduceResult {
22    pub fn all_succeeded(&self) -> bool {
23        self.map_outputs.iter().all(|o| o.succeeded())
24    }
25}
26
27pub struct MapReduce {
28    pub mapper: AgentSpec,
29    pub reducer: AgentSpec,
30    pub max_concurrent: usize,
31}
32
33impl MapReduce {
34    pub fn new(mapper: AgentSpec, reducer: AgentSpec) -> Self {
35        Self {
36            mapper,
37            reducer,
38            max_concurrent: 5,
39        }
40    }
41
42    pub fn with_max_concurrent(mut self, n: usize) -> Self {
43        self.max_concurrent = n;
44        self
45    }
46
47    pub async fn run(
48        &self,
49        task: &str,
50        items: &[String],
51        runner: &Arc<dyn AgentRunner>,
52        infra: &SharedInfra,
53    ) -> Result<MapReduceResult, MultiError> {
54        // Map phase: one agent per item, bounded concurrency
55        let semaphore = Arc::new(tokio::sync::Semaphore::new(self.max_concurrent));
56        let mut handles = Vec::new();
57
58        for (i, item) in items.iter().enumerate() {
59            let sem = Arc::clone(&semaphore);
60            let runner = Arc::clone(runner);
61            let rt = infra.make_runtime();
62            let mailbox = Mailbox::default();
63
64            let mut spec = self.mapper.clone();
65            spec.name = format!("{}_{}", self.mapper.name, i);
66
67            for tool in &spec.tools {
68                rt.register_tool(tool).await;
69            }
70
71            let subtask = format!("{}\n\nProcess this item: {}", task, item);
72
73            handles.push(tokio::spawn(async move {
74                let _permit = sem.acquire().await.unwrap();
75                (i, runner.run(&spec, &subtask, &rt, &mailbox).await)
76            }));
77        }
78
79        let results = futures::future::join_all(handles).await;
80        let mut indexed: Vec<(usize, AgentOutput)> = Vec::new();
81
82        for result in results {
83            match result {
84                Ok((i, Ok(output))) => indexed.push((i, output)),
85                Ok((i, Err(e))) => {
86                    indexed.push((
87                        i,
88                        AgentOutput {
89                            name: format!("{}_{}", self.mapper.name, i),
90                            answer: String::new(),
91                            turns: 0,
92                            tool_calls: 0,
93                            duration_ms: 0.0,
94                            error: Some(e.to_string()),
95                        },
96                    ));
97                }
98                Err(e) => {
99                    indexed.push((
100                        indexed.len(),
101                        AgentOutput {
102                            name: "unknown".to_string(),
103                            answer: String::new(),
104                            turns: 0,
105                            tool_calls: 0,
106                            duration_ms: 0.0,
107                            error: Some(format!("join error: {}", e)),
108                        },
109                    ));
110                }
111            }
112        }
113
114        indexed.sort_by_key(|(i, _)| *i);
115        let map_outputs: Vec<AgentOutput> = indexed.into_iter().map(|(_, o)| o).collect();
116
117        // Reduce phase
118        let summaries: Vec<String> = map_outputs
119            .iter()
120            .filter(|o| o.succeeded())
121            .map(|o| format!("- [{}] {}", o.name, truncate(&o.answer, 300)))
122            .collect();
123
124        let reduce_task = format!(
125            "Original task: {}\n\nResults from {} sub-agents:\n{}\n\n\
126             Combine these into a single coherent result.",
127            task,
128            map_outputs.len(),
129            summaries.join("\n")
130        );
131
132        let rt = infra.make_runtime();
133        let mailbox = Mailbox::default();
134        let reduced = runner
135            .run(&self.reducer, &reduce_task, &rt, &mailbox)
136            .await
137            .map(|o| o.answer)
138            .unwrap_or_default();
139
140        Ok(MapReduceResult {
141            task: task.to_string(),
142            map_outputs,
143            reduced_answer: reduced,
144        })
145    }
146}
147
/// Returns the longest prefix of `s` that is at most `max_len` bytes long and
/// ends on a UTF-8 character boundary.
///
/// `s` is returned unchanged when it already fits within `max_len` bytes.
fn truncate(s: &str, max_len: usize) -> &str {
    if max_len >= s.len() {
        s
    } else {
        // Walk back from `max_len` to the nearest boundary; index 0 always is one.
        let cut = (0..=max_len)
            .rev()
            .find(|&idx| s.is_char_boundary(idx))
            .unwrap_or(0);
        &s[..cut]
    }
}
158
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{AgentOutput, AgentSpec};
    use car_engine::Runtime;

    /// Stub runner that always succeeds and answers with "<agent name> processed".
    struct CountRunner;

    #[async_trait::async_trait]
    impl crate::runner::AgentRunner for CountRunner {
        async fn run(
            &self,
            spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            let name = spec.name.clone();
            let answer = format!("{} processed", spec.name);
            Ok(AgentOutput {
                name,
                answer,
                turns: 1,
                tool_calls: 0,
                duration_ms: 5.0,
                error: None,
            })
        }
    }

    /// Three items should yield three mapper outputs and a non-empty reduced answer.
    #[tokio::test]
    async fn test_map_reduce() {
        let files: Vec<String> = ["file_a.rs", "file_b.rs", "file_c.rs"]
            .iter()
            .map(|path| path.to_string())
            .collect();

        let infra = SharedInfra::new();
        let runner: Arc<dyn crate::runner::AgentRunner> = Arc::new(CountRunner);

        let pattern = MapReduce::new(
            AgentSpec::new("summarizer", "Summarize the file"),
            AgentSpec::new("combiner", "Combine summaries"),
        );
        let outcome = pattern
            .run("summarize codebase", &files, &runner, &infra)
            .await
            .unwrap();

        assert_eq!(outcome.map_outputs.len(), 3);
        assert!(!outcome.reduced_answer.is_empty());
    }
}