1use crate::error::MultiError;
7use crate::mailbox::Mailbox;
8use crate::runner::AgentRunner;
9use crate::shared::SharedInfra;
10use crate::types::{AgentOutput, AgentSpec};
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct MapReduceResult {
16 pub task: String,
17 pub map_outputs: Vec<AgentOutput>,
18 pub reduced_answer: String,
19}
20
21impl MapReduceResult {
22 pub fn all_succeeded(&self) -> bool {
23 self.map_outputs.iter().all(|o| o.succeeded())
24 }
25}
26
27pub struct MapReduce {
28 pub mapper: AgentSpec,
29 pub reducer: AgentSpec,
30 pub max_concurrent: usize,
31}
32
33impl MapReduce {
34 pub fn new(mapper: AgentSpec, reducer: AgentSpec) -> Self {
35 Self {
36 mapper,
37 reducer,
38 max_concurrent: 5,
39 }
40 }
41
42 pub fn with_max_concurrent(mut self, n: usize) -> Self {
43 self.max_concurrent = n;
44 self
45 }
46
47 pub async fn run(
48 &self,
49 task: &str,
50 items: &[String],
51 runner: &Arc<dyn AgentRunner>,
52 infra: &SharedInfra,
53 ) -> Result<MapReduceResult, MultiError> {
54 let semaphore = Arc::new(tokio::sync::Semaphore::new(self.max_concurrent));
56 let mut handles = Vec::new();
57
58 for (i, item) in items.iter().enumerate() {
59 let sem = Arc::clone(&semaphore);
60 let runner = Arc::clone(runner);
61 let rt = infra.make_runtime();
62 let mailbox = Mailbox::default();
63
64 let mut spec = self.mapper.clone();
65 spec.name = format!("{}_{}", self.mapper.name, i);
66
67 for tool in &spec.tools {
68 rt.register_tool(tool).await;
69 }
70
71 let subtask = format!("{}\n\nProcess this item: {}", task, item);
72
73 handles.push(tokio::spawn(async move {
74 let _permit = sem.acquire().await.unwrap();
75 (i, runner.run(&spec, &subtask, &rt, &mailbox).await)
76 }));
77 }
78
79 let results = futures::future::join_all(handles).await;
80 let mut indexed: Vec<(usize, AgentOutput)> = Vec::new();
81
82 for result in results {
83 match result {
84 Ok((i, Ok(output))) => indexed.push((i, output)),
85 Ok((i, Err(e))) => {
86 indexed.push((
87 i,
88 AgentOutput {
89 name: format!("{}_{}", self.mapper.name, i),
90 answer: String::new(),
91 turns: 0,
92 tool_calls: 0,
93 duration_ms: 0.0,
94 error: Some(e.to_string()),
95 },
96 ));
97 }
98 Err(e) => {
99 indexed.push((
100 indexed.len(),
101 AgentOutput {
102 name: "unknown".to_string(),
103 answer: String::new(),
104 turns: 0,
105 tool_calls: 0,
106 duration_ms: 0.0,
107 error: Some(format!("join error: {}", e)),
108 },
109 ));
110 }
111 }
112 }
113
114 indexed.sort_by_key(|(i, _)| *i);
115 let map_outputs: Vec<AgentOutput> = indexed.into_iter().map(|(_, o)| o).collect();
116
117 let summaries: Vec<String> = map_outputs
119 .iter()
120 .filter(|o| o.succeeded())
121 .map(|o| format!("- [{}] {}", o.name, truncate(&o.answer, 300)))
122 .collect();
123
124 let reduce_task = format!(
125 "Original task: {}\n\nResults from {} sub-agents:\n{}\n\n\
126 Combine these into a single coherent result.",
127 task,
128 map_outputs.len(),
129 summaries.join("\n")
130 );
131
132 let rt = infra.make_runtime();
133 let mailbox = Mailbox::default();
134 let reduced = runner
135 .run(&self.reducer, &reduce_task, &rt, &mailbox)
136 .await
137 .map(|o| o.answer)
138 .unwrap_or_default();
139
140 Ok(MapReduceResult {
141 task: task.to_string(),
142 map_outputs,
143 reduced_answer: reduced,
144 })
145 }
146}
147
/// Returns the longest prefix of `s` that is at most `max_len` bytes long and
/// ends on a UTF-8 character boundary.
///
/// The input is returned unchanged when it already fits.
fn truncate(s: &str, max_len: usize) -> &str {
    if max_len >= s.len() {
        return s;
    }
    // Search backwards from the byte limit for the nearest valid cut point;
    // index 0 is always a boundary, so the search cannot come up empty.
    let cut = (0..=max_len)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}
158
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{AgentOutput, AgentSpec};
    use car_engine::Runtime;

    /// Stub runner that succeeds immediately and answers
    /// `"<spec name> processed"`.
    struct CountRunner;

    #[async_trait::async_trait]
    impl crate::runner::AgentRunner for CountRunner {
        async fn run(
            &self,
            spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            let answer = format!("{} processed", spec.name);
            let output = AgentOutput {
                name: spec.name.clone(),
                answer,
                turns: 1,
                tool_calls: 0,
                duration_ms: 5.0,
                error: None,
            };
            Ok(output)
        }
    }

    /// Three items should yield three mapper outputs and a non-empty
    /// reduced answer.
    #[tokio::test]
    async fn test_map_reduce() {
        let runner: Arc<dyn crate::runner::AgentRunner> = Arc::new(CountRunner);
        let infra = SharedInfra::new();

        let items: Vec<String> = ["file_a.rs", "file_b.rs", "file_c.rs"]
            .iter()
            .map(|name| (*name).to_owned())
            .collect();

        let pattern = MapReduce::new(
            AgentSpec::new("summarizer", "Summarize the file"),
            AgentSpec::new("combiner", "Combine summaries"),
        );

        let result = pattern
            .run("summarize codebase", &items, &runner, &infra)
            .await
            .unwrap();

        assert_eq!(result.map_outputs.len(), 3);
        assert!(!result.reduced_answer.is_empty());
    }
}