use crate::error::MultiError;
use crate::mailbox::Mailbox;
use crate::runner::AgentRunner;
use crate::shared::SharedInfra;
use crate::types::{AgentOutput, AgentSpec};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::task::JoinSet;
use tracing::instrument;
/// Aggregated outcome of a [`MapReduce::run`] invocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MapReduceResult {
    /// The original task string the run was invoked with.
    pub task: String,
    /// One output per input item, in input order. Mappers that failed or
    /// panicked are represented by an output carrying an `error`.
    pub map_outputs: Vec<AgentOutput>,
    /// The reducer agent's combined answer over the successful mapper outputs.
    pub reduced_answer: String,
}
impl MapReduceResult {
    /// Returns `true` when every mapper output completed without an error.
    ///
    /// An empty `map_outputs` vector vacuously counts as all-succeeded.
    pub fn all_succeeded(&self) -> bool {
        for output in &self.map_outputs {
            if !output.succeeded() {
                return false;
            }
        }
        true
    }
}
/// Fan-out/fan-in orchestrator: runs one mapper agent per input item
/// (bounded by `max_concurrent`) and then a single reducer agent over
/// the successful mapper summaries.
pub struct MapReduce {
    /// Spec cloned for each item; each clone's name gets an `_{index}` suffix.
    pub mapper: AgentSpec,
    /// Spec used for the single reduce step at the end of the run.
    pub reducer: AgentSpec,
    /// Maximum number of mapper agents allowed to run concurrently.
    pub max_concurrent: usize,
}
impl MapReduce {
pub fn new(mapper: AgentSpec, reducer: AgentSpec) -> Self {
Self {
mapper,
reducer,
max_concurrent: 5,
}
}
pub fn with_max_concurrent(mut self, n: usize) -> Self {
self.max_concurrent = n;
self
}
#[instrument(name = "multi.map_reduce", skip_all)]
pub async fn run(
&self,
task: &str,
items: &[String],
runner: &Arc<dyn AgentRunner>,
infra: &SharedInfra,
) -> Result<MapReduceResult, MultiError> {
let semaphore = Arc::new(tokio::sync::Semaphore::new(self.max_concurrent));
let mut handles = JoinSet::new();
let mut task_indices = HashMap::new();
for (i, item) in items.iter().enumerate() {
let sem = Arc::clone(&semaphore);
let runner = Arc::clone(runner);
let rt = infra.make_runtime();
let mailbox = Mailbox::default();
let mut spec = self.mapper.clone();
spec.name = format!("{}_{}", self.mapper.name, i);
for tool in &spec.tools {
rt.register_tool(tool).await;
}
let subtask = format!("{}\n\nProcess this item: {}", task, item);
let handle = handles.spawn(async move {
let _permit = sem.acquire().await.unwrap();
(i, runner.run(&spec, &subtask, &rt, &mailbox).await)
});
task_indices.insert(handle.id(), i);
}
let mut indexed: Vec<(usize, AgentOutput)> = Vec::new();
while let Some(result) = handles.join_next().await {
match result {
Ok((i, Ok(output))) => indexed.push((i, output)),
Ok((i, Err(e))) => {
indexed.push((
i,
AgentOutput {
name: format!("{}_{}", self.mapper.name, i),
answer: String::new(),
turns: 0,
tool_calls: 0,
duration_ms: 0.0,
error: Some(e.to_string()),
outcome: None,
tokens: None,
},
));
}
Err(e) => {
let i = task_indices
.get(&e.id())
.copied()
.expect("mapper task id should be tracked");
indexed.push((
i,
AgentOutput {
name: format!("{}_{}", self.mapper.name, i),
answer: String::new(),
turns: 0,
tool_calls: 0,
duration_ms: 0.0,
error: Some(format!("join error: {}", e)),
outcome: None,
tokens: None,
},
));
}
}
}
indexed.sort_by_key(|(i, _)| *i);
let map_outputs: Vec<AgentOutput> = indexed.into_iter().map(|(_, o)| o).collect();
let summaries: Vec<String> = map_outputs
.iter()
.filter(|o| o.succeeded())
.map(|o| format!("- [{}] {}", o.name, truncate(&o.answer, 300)))
.collect();
let reduce_task = format!(
"Original task: {}\n\nResults from {} sub-agents:\n{}\n\n\
Combine these into a single coherent result.",
task,
map_outputs.len(),
summaries.join("\n")
);
let rt = infra.make_runtime();
let mailbox = Mailbox::default();
let reduced = runner
.run(&self.reducer, &reduce_task, &rt, &mailbox)
.await
.map(|o| o.answer)
.unwrap_or_default();
Ok(MapReduceResult {
task: task.to_string(),
map_outputs,
reduced_answer: reduced,
})
}
}
/// Returns the longest prefix of `s` that is at most `max_len` bytes long,
/// never splitting a multi-byte UTF-8 character.
fn truncate(s: &str, max_len: usize) -> &str {
    if s.len() <= max_len {
        return s;
    }
    // Scan down from `max_len` to the nearest char boundary. Byte offset 0
    // is always a boundary, so the search is guaranteed to succeed.
    let end = (0..=max_len)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .expect("offset 0 is always a char boundary");
    &s[..end]
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{AgentOutput, AgentSpec};
    use car_engine::Runtime;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use tokio::sync::Notify;

    /// Stub runner that always succeeds, echoing the per-item spec name in
    /// its answer so tests can check which mapper produced which output.
    struct CountRunner;

    #[async_trait::async_trait]
    impl crate::runner::AgentRunner for CountRunner {
        async fn run(
            &self,
            spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            Ok(AgentOutput {
                name: spec.name.clone(),
                answer: format!("{} processed", spec.name),
                turns: 1,
                tool_calls: 0,
                duration_ms: 5.0,
                error: None,
                outcome: None,
                tokens: None,
            })
        }
    }

    // Happy path: three items yield three mapper outputs plus a non-empty
    // reducer answer.
    #[tokio::test]
    async fn test_map_reduce() {
        let mapper = AgentSpec::new("summarizer", "Summarize the file");
        let reducer = AgentSpec::new("combiner", "Combine summaries");
        let items: Vec<String> = vec!["file_a.rs", "file_b.rs", "file_c.rs"]
            .into_iter()
            .map(String::from)
            .collect();
        let runner: Arc<dyn crate::runner::AgentRunner> = Arc::new(CountRunner);
        let infra = SharedInfra::new();
        let result = MapReduce::new(mapper, reducer)
            .run("summarize codebase", &items, &runner, &infra)
            .await
            .unwrap();
        assert_eq!(result.map_outputs.len(), 3);
        assert!(!result.reduced_answer.is_empty());
    }

    /// Runner that forces mapper_1 to panic *before* mapper_0 completes, so
    /// the JoinSet observes the panicked task first and the id-to-index
    /// bookkeeping in `MapReduce::run` is exercised.
    struct LaterPanicRunner {
        // Set by mapper_1 immediately before it panics; mapper_0 spins on it.
        later_panicked: Arc<AtomicBool>,
        // Wakes mapper_0 once the flag is set. `Notify` stores one permit,
        // so a `notify_one` issued before the `notified().await` is not lost.
        notify: Arc<Notify>,
    }

    #[async_trait::async_trait]
    impl crate::runner::AgentRunner for LaterPanicRunner {
        async fn run(
            &self,
            spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            match spec.name.as_str() {
                "mapper_0" => {
                    // Block until mapper_1 has panicked, guaranteeing the
                    // panic result is yielded by join_next first.
                    while !self.later_panicked.load(Ordering::SeqCst) {
                        self.notify.notified().await;
                    }
                    Ok(AgentOutput {
                        name: spec.name.clone(),
                        answer: "mapper 0 completed".to_string(),
                        turns: 1,
                        tool_calls: 0,
                        duration_ms: 5.0,
                        error: None,
                        outcome: None,
                        tokens: None,
                    })
                }
                "mapper_1" => {
                    // Signal mapper_0 before panicking so it can finish.
                    self.later_panicked.store(true, Ordering::SeqCst);
                    self.notify.notify_one();
                    panic!("mapper 1 panicked first");
                }
                // Any other spec name is the reducer; return a fixed answer.
                _ => Ok(AgentOutput {
                    name: spec.name.clone(),
                    answer: "reduced".to_string(),
                    turns: 1,
                    tool_calls: 0,
                    duration_ms: 5.0,
                    error: None,
                    outcome: None,
                    tokens: None,
                }),
            }
        }
    }

    // Verifies that a mapper panic arriving out of completion order is still
    // attributed to the correct item index and recorded as an error output.
    #[tokio::test]
    async fn panicking_mapper_keeps_original_item_index() {
        let runner: Arc<dyn crate::runner::AgentRunner> = Arc::new(LaterPanicRunner {
            later_panicked: Arc::new(AtomicBool::new(false)),
            notify: Arc::new(Notify::new()),
        });
        let infra = SharedInfra::new();
        let items = vec!["slow first".to_string(), "fast panic".to_string()];
        let result = MapReduce::new(
            AgentSpec::new("mapper", "map item"),
            AgentSpec::new("reducer", "reduce items"),
        )
        // Both mappers must run concurrently for the ordering trick to work.
        .with_max_concurrent(2)
        .run("preserve mapper order", &items, &runner, &infra)
        .await
        .unwrap();
        assert_eq!(result.map_outputs.len(), 2);
        assert_eq!(result.map_outputs[0].name, "mapper_0");
        assert!(result.map_outputs[0].succeeded());
        assert_eq!(result.map_outputs[1].name, "mapper_1");
        assert!(
            result.map_outputs[1]
                .error
                .as_deref()
                .is_some_and(|error| error.contains("panicked")),
            "expected mapper_1 to carry the panic error, got {:?}",
            result.map_outputs[1].error
        );
    }

    /// Runner whose futures never complete; counts how many started and how
    /// many were dropped, to observe task cancellation.
    struct DropCountingRunner {
        started: Arc<AtomicUsize>,
        dropped: Arc<AtomicUsize>,
        notify: Arc<Notify>,
    }

    /// Increments the shared counter when the enclosing future is dropped
    /// (i.e. when the spawned mapper task is aborted).
    struct DropGuard(Arc<AtomicUsize>);

    impl Drop for DropGuard {
        fn drop(&mut self) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[async_trait::async_trait]
    impl crate::runner::AgentRunner for DropCountingRunner {
        async fn run(
            &self,
            _spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            let _guard = DropGuard(self.dropped.clone());
            self.started.fetch_add(1, Ordering::SeqCst);
            self.notify.notify_one();
            // Never resolves: the only way this future ends is cancellation,
            // which fires the DropGuard above.
            std::future::pending::<Result<AgentOutput, MultiError>>().await
        }
    }

    // Verifies that cancelling the `run` future aborts the spawned mapper
    // tasks (JoinSet aborts its tasks on drop) rather than leaking them.
    #[tokio::test]
    async fn dropping_map_reduce_run_aborts_mapper_tasks() {
        let started = Arc::new(AtomicUsize::new(0));
        let dropped = Arc::new(AtomicUsize::new(0));
        let notify = Arc::new(Notify::new());
        let runner: Arc<dyn crate::runner::AgentRunner> = Arc::new(DropCountingRunner {
            started: started.clone(),
            dropped: dropped.clone(),
            notify: notify.clone(),
        });
        let infra = SharedInfra::new();
        let items = vec!["one".to_string(), "two".to_string()];
        let handle = tokio::spawn(async move {
            MapReduce::new(
                AgentSpec::new("worker", "run work"),
                AgentSpec::new("reducer", "reduce work"),
            )
            .with_max_concurrent(2)
            .run("parallel goal", &items, &runner, &infra)
            .await
        });
        // Wait until both mapper futures are actually executing before
        // aborting, so the drop counters are meaningful.
        while started.load(Ordering::SeqCst) < 2 {
            notify.notified().await;
        }
        handle.abort();
        assert!(handle.await.unwrap_err().is_cancelled());
        // Cancellation propagates asynchronously; poll with a deadline
        // instead of asserting immediately.
        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(2);
        while std::time::Instant::now() < deadline {
            if dropped.load(Ordering::SeqCst) >= 2 {
                return;
            }
            tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        }
        panic!(
            "mapper futures were detached after MapReduce cancellation; dropped={}",
            dropped.load(Ordering::SeqCst)
        );
    }
}