car-multi 0.14.0

Multi-agent coordination patterns for Common Agent Runtime
Documentation
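//! Map-Reduce — fan a task out over N caller-supplied items, run one mapper agent
//! per item concurrently, then reduce the results.
//!
//! The mapper spec is cloned (and renamed `<name>_<index>`) for each item, so each
//! mapper processes exactly one item. The reducer combines the successful mapper
//! outputs (each answer truncated to 300 bytes) into a single result.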

use crate::error::MultiError;
use crate::mailbox::Mailbox;
use crate::runner::AgentRunner;
use crate::shared::SharedInfra;
use crate::types::{AgentOutput, AgentSpec};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::task::JoinSet;
use tracing::instrument;

/// Outcome of a [`MapReduce::run`] call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MapReduceResult {
    /// The task text the map-reduce was run with.
    pub task: String,
    /// One output per input item, in item order; failed mappers appear with
    /// `error` set.
    pub map_outputs: Vec<AgentOutput>,
    /// The reducer's answer (empty when the reducer failed).
    pub reduced_answer: String,
}

impl MapReduceResult {
    /// Returns `true` when every entry in `map_outputs` reports `succeeded()`
    /// (vacuously `true` when there were no items).
    pub fn all_succeeded(&self) -> bool {
        let any_failed = self.map_outputs.iter().any(|output| !output.succeeded());
        !any_failed
    }
}

/// Fan-out/fan-in coordinator: one `mapper` agent per item, then one `reducer`.
pub struct MapReduce {
    /// Template spec that is cloned (and renamed per item) for every mapper.
    pub mapper: AgentSpec,
    /// Spec for the single agent that combines the mapper results.
    pub reducer: AgentSpec,
    /// Upper bound on the number of mappers running at the same time.
    pub max_concurrent: usize,
}

impl MapReduce {
    pub fn new(mapper: AgentSpec, reducer: AgentSpec) -> Self {
        Self {
            mapper,
            reducer,
            max_concurrent: 5,
        }
    }

    pub fn with_max_concurrent(mut self, n: usize) -> Self {
        self.max_concurrent = n;
        self
    }

    #[instrument(name = "multi.map_reduce", skip_all)]
    pub async fn run(
        &self,
        task: &str,
        items: &[String],
        runner: &Arc<dyn AgentRunner>,
        infra: &SharedInfra,
    ) -> Result<MapReduceResult, MultiError> {
        // Map phase: one agent per item, bounded concurrency
        let semaphore = Arc::new(tokio::sync::Semaphore::new(self.max_concurrent));
        let mut handles = JoinSet::new();
        let mut task_indices = HashMap::new();

        for (i, item) in items.iter().enumerate() {
            let sem = Arc::clone(&semaphore);
            let runner = Arc::clone(runner);
            let rt = infra.make_runtime();
            let mailbox = Mailbox::default();

            let mut spec = self.mapper.clone();
            spec.name = format!("{}_{}", self.mapper.name, i);

            for tool in &spec.tools {
                rt.register_tool(tool).await;
            }

            let subtask = format!("{}\n\nProcess this item: {}", task, item);

            let handle = handles.spawn(async move {
                let _permit = sem.acquire().await.unwrap();
                (i, runner.run(&spec, &subtask, &rt, &mailbox).await)
            });
            task_indices.insert(handle.id(), i);
        }

        let mut indexed: Vec<(usize, AgentOutput)> = Vec::new();

        while let Some(result) = handles.join_next().await {
            match result {
                Ok((i, Ok(output))) => indexed.push((i, output)),
                Ok((i, Err(e))) => {
                    indexed.push((
                        i,
                        AgentOutput {
                            name: format!("{}_{}", self.mapper.name, i),
                            answer: String::new(),
                            turns: 0,
                            tool_calls: 0,
                            duration_ms: 0.0,
                            error: Some(e.to_string()),
                            outcome: None,
                            tokens: None,
                        },
                    ));
                }
                Err(e) => {
                    let i = task_indices
                        .get(&e.id())
                        .copied()
                        .expect("mapper task id should be tracked");
                    indexed.push((
                        i,
                        AgentOutput {
                            name: format!("{}_{}", self.mapper.name, i),
                            answer: String::new(),
                            turns: 0,
                            tool_calls: 0,
                            duration_ms: 0.0,
                            error: Some(format!("join error: {}", e)),
                            outcome: None,
                            tokens: None,
                        },
                    ));
                }
            }
        }

        indexed.sort_by_key(|(i, _)| *i);
        let map_outputs: Vec<AgentOutput> = indexed.into_iter().map(|(_, o)| o).collect();

        // Reduce phase
        let summaries: Vec<String> = map_outputs
            .iter()
            .filter(|o| o.succeeded())
            .map(|o| format!("- [{}] {}", o.name, truncate(&o.answer, 300)))
            .collect();

        let reduce_task = format!(
            "Original task: {}\n\nResults from {} sub-agents:\n{}\n\n\
             Combine these into a single coherent result.",
            task,
            map_outputs.len(),
            summaries.join("\n")
        );

        let rt = infra.make_runtime();
        let mailbox = Mailbox::default();
        let reduced = runner
            .run(&self.reducer, &reduce_task, &rt, &mailbox)
            .await
            .map(|o| o.answer)
            .unwrap_or_default();

        Ok(MapReduceResult {
            task: task.to_string(),
            map_outputs,
            reduced_answer: reduced,
        })
    }
}

/// Returns the longest prefix of `s` that is at most `max_len` bytes long and
/// ends on a UTF-8 character boundary, so the slice can never panic.
fn truncate(s: &str, max_len: usize) -> &str {
    if max_len >= s.len() {
        s
    } else {
        // Step back from `max_len` to the nearest char boundary. Index 0 is
        // always a boundary, so the search always succeeds.
        let cut = (0..=max_len)
            .rev()
            .find(|&idx| s.is_char_boundary(idx))
            .unwrap_or(0);
        &s[..cut]
    }
}

/// Unit tests for `MapReduce`, driven by in-process fake `AgentRunner`s.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{AgentOutput, AgentSpec};
    use car_engine::Runtime;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use tokio::sync::Notify;

    /// Fake runner that succeeds immediately with "<spec name> processed".
    struct CountRunner;

    #[async_trait::async_trait]
    impl crate::runner::AgentRunner for CountRunner {
        async fn run(
            &self,
            spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            Ok(AgentOutput {
                name: spec.name.clone(),
                answer: format!("{} processed", spec.name),
                turns: 1,
                tool_calls: 0,
                duration_ms: 5.0,
                error: None,
                outcome: None,
                tokens: None,
            })
        }
    }

    /// Happy path: one output per item and a non-empty reducer answer.
    #[tokio::test]
    async fn test_map_reduce() {
        let mapper = AgentSpec::new("summarizer", "Summarize the file");
        let reducer = AgentSpec::new("combiner", "Combine summaries");
        let items: Vec<String> = vec!["file_a.rs", "file_b.rs", "file_c.rs"]
            .into_iter()
            .map(String::from)
            .collect();

        let runner: Arc<dyn crate::runner::AgentRunner> = Arc::new(CountRunner);
        let infra = SharedInfra::new();

        let result = MapReduce::new(mapper, reducer)
            .run("summarize codebase", &items, &runner, &infra)
            .await
            .unwrap();

        assert_eq!(result.map_outputs.len(), 3);
        assert!(!result.reduced_answer.is_empty());
    }

    /// Runner that makes item 1 panic while item 0 is still pending.
    /// `mapper_0` cannot finish until `mapper_1` has set `later_panicked`, which
    /// `mapper_1` does immediately before panicking. Any other name (the
    /// reducer) succeeds.
    struct LaterPanicRunner {
        later_panicked: Arc<AtomicBool>,
        notify: Arc<Notify>,
    }

    #[async_trait::async_trait]
    impl crate::runner::AgentRunner for LaterPanicRunner {
        async fn run(
            &self,
            spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            match spec.name.as_str() {
                "mapper_0" => {
                    // Block until mapper_1 has flagged its panic.
                    while !self.later_panicked.load(Ordering::SeqCst) {
                        self.notify.notified().await;
                    }
                    Ok(AgentOutput {
                        name: spec.name.clone(),
                        answer: "mapper 0 completed".to_string(),
                        turns: 1,
                        tool_calls: 0,
                        duration_ms: 5.0,
                        error: None,
                        outcome: None,
                        tokens: None,
                    })
                }
                "mapper_1" => {
                    // Flag first, then wake mapper_0, then panic.
                    self.later_panicked.store(true, Ordering::SeqCst);
                    self.notify.notify_one();
                    panic!("mapper 1 panicked first");
                }
                // Reducer (and any other agent) succeeds.
                _ => Ok(AgentOutput {
                    name: spec.name.clone(),
                    answer: "reduced".to_string(),
                    turns: 1,
                    tool_calls: 0,
                    duration_ms: 5.0,
                    error: None,
                    outcome: None,
                    tokens: None,
                }),
            }
        }
    }

    /// A panicking mapper must be reported at its own item index, carrying the
    /// panic in `error`. The outputs must stay in item order even though the
    /// panic completes before mapper_0.
    #[tokio::test]
    async fn panicking_mapper_keeps_original_item_index() {
        let runner: Arc<dyn crate::runner::AgentRunner> = Arc::new(LaterPanicRunner {
            later_panicked: Arc::new(AtomicBool::new(false)),
            notify: Arc::new(Notify::new()),
        });
        let infra = SharedInfra::new();
        let items = vec!["slow first".to_string(), "fast panic".to_string()];

        // Two permits so both mappers run concurrently (mapper_0 would
        // otherwise block mapper_1 forever).
        let result = MapReduce::new(
            AgentSpec::new("mapper", "map item"),
            AgentSpec::new("reducer", "reduce items"),
        )
        .with_max_concurrent(2)
        .run("preserve mapper order", &items, &runner, &infra)
        .await
        .unwrap();

        assert_eq!(result.map_outputs.len(), 2);
        assert_eq!(result.map_outputs[0].name, "mapper_0");
        assert!(result.map_outputs[0].succeeded());
        assert_eq!(result.map_outputs[1].name, "mapper_1");
        assert!(
            result.map_outputs[1]
                .error
                .as_deref()
                .is_some_and(|error| error.contains("panicked")),
            "expected mapper_1 to carry the panic error, got {:?}",
            result.map_outputs[1].error
        );
    }

    /// Runner whose `run` never completes. It counts how many runs started and
    /// how many run futures have been dropped (via `DropGuard`).
    struct DropCountingRunner {
        started: Arc<AtomicUsize>,
        dropped: Arc<AtomicUsize>,
        notify: Arc<Notify>,
    }

    /// Increments the shared counter when dropped, i.e. when the run future
    /// holding it is dropped.
    struct DropGuard(Arc<AtomicUsize>);

    impl Drop for DropGuard {
        fn drop(&mut self) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[async_trait::async_trait]
    impl crate::runner::AgentRunner for DropCountingRunner {
        async fn run(
            &self,
            _spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            let _guard = DropGuard(self.dropped.clone());
            self.started.fetch_add(1, Ordering::SeqCst);
            self.notify.notify_one();
            // Never resolves; the guard is only released if this future is dropped.
            std::future::pending::<Result<AgentOutput, MultiError>>().await
        }
    }

    /// Aborting the task that drives `run` must also drop the spawned mapper
    /// futures rather than leave them running detached.
    #[tokio::test]
    async fn dropping_map_reduce_run_aborts_mapper_tasks() {
        let started = Arc::new(AtomicUsize::new(0));
        let dropped = Arc::new(AtomicUsize::new(0));
        let notify = Arc::new(Notify::new());
        let runner: Arc<dyn crate::runner::AgentRunner> = Arc::new(DropCountingRunner {
            started: started.clone(),
            dropped: dropped.clone(),
            notify: notify.clone(),
        });
        let infra = SharedInfra::new();
        let items = vec!["one".to_string(), "two".to_string()];

        let handle = tokio::spawn(async move {
            MapReduce::new(
                AgentSpec::new("worker", "run work"),
                AgentSpec::new("reducer", "reduce work"),
            )
            .with_max_concurrent(2)
            .run("parallel goal", &items, &runner, &infra)
            .await
        });

        // Wait until both mappers are inside `run` (each increments `started`
        // and then notifies).
        while started.load(Ordering::SeqCst) < 2 {
            notify.notified().await;
        }

        handle.abort();
        assert!(handle.await.unwrap_err().is_cancelled());

        // Poll for up to 2s for both mapper futures to be dropped.
        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(2);
        while std::time::Instant::now() < deadline {
            if dropped.load(Ordering::SeqCst) >= 2 {
                return;
            }
            tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        }
        panic!(
            "mapper futures were detached after MapReduce cancellation; dropped={}",
            dropped.load(Ordering::SeqCst)
        );
    }
}