// cognis_core/compose/parallel.rs

//! Parallel fan-out: run multiple runnables concurrently on clones of the same input.

3use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use futures::future::join_all;
8
9use crate::runnable::{Runnable, RunnableConfig};
10use crate::Result;
11
12/// Runs multiple runnables concurrently on a shared input and returns
13/// their outputs keyed by name.
14///
15/// Each branch must share the same input type `I` (cloned per branch) and
16/// the same output type `O`. For heterogeneous outputs, wrap each branch's
17/// output in an enum or use `serde_json::Value`.
18pub struct Parallel<I, O> {
19    branches: Vec<(String, Arc<dyn Runnable<I, O>>)>,
20}
21
22impl<I, O> Default for Parallel<I, O>
23where
24    I: Send + Sync + Clone + 'static,
25    O: Send + 'static,
26{
27    fn default() -> Self {
28        Self {
29            branches: Vec::new(),
30        }
31    }
32}
33
34impl<I, O> Parallel<I, O>
35where
36    I: Send + Sync + Clone + 'static,
37    O: Send + 'static,
38{
39    /// Empty builder.
40    pub fn new() -> Self {
41        Self::default()
42    }
43
44    /// Add a named branch.
45    ///
46    /// **Panics** if `name` is already registered. Two branches under the
47    /// same key would race into the result map and the loser would be
48    /// silently dropped — almost always a programming error, so we fail
49    /// loudly at construction time rather than at run time.
50    pub fn branch(mut self, name: impl Into<String>, runnable: Arc<dyn Runnable<I, O>>) -> Self {
51        let name = name.into();
52        if self.branches.iter().any(|(n, _)| n == &name) {
53            panic!("Parallel::branch: duplicate branch name `{name}`");
54        }
55        self.branches.push((name, runnable));
56        self
57    }
58}
59
60#[async_trait]
61impl<I, O> Runnable<I, HashMap<String, O>> for Parallel<I, O>
62where
63    I: Send + Sync + Clone + 'static,
64    O: Send + 'static,
65{
66    async fn invoke(&self, input: I, config: RunnableConfig) -> Result<HashMap<String, O>> {
67        let futs = self.branches.iter().map(|(name, r)| {
68            let r = r.clone();
69            let cfg = config.clone();
70            let i = input.clone();
71            let n = name.clone();
72            async move {
73                let out = r.invoke(i, cfg).await?;
74                Ok::<(String, O), crate::CognisError>((n, out))
75            }
76        });
77        let mut out = HashMap::with_capacity(self.branches.len());
78        for r in join_all(futs).await {
79            let (k, v) = r?;
80            out.insert(k, v);
81        }
82        Ok(out)
83    }
84
85    fn name(&self) -> &str {
86        "Parallel"
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Test runnable that adds a fixed amount to its input.
    struct AddN(u32);

    #[async_trait]
    impl Runnable<u32, u32> for AddN {
        async fn invoke(&self, input: u32, _: RunnableConfig) -> Result<u32> {
            Ok(input + self.0)
        }
    }

    #[tokio::test]
    async fn fans_out_and_collects() {
        let p: Parallel<u32, u32> = Parallel::new()
            .branch("plus_one", Arc::new(AddN(1)))
            .branch("plus_ten", Arc::new(AddN(10)));
        let out = p.invoke(5, RunnableConfig::default()).await.unwrap();
        assert_eq!(out["plus_one"], 6);
        assert_eq!(out["plus_ten"], 15);
        assert_eq!(out.len(), 2);
    }

    // Edge case: no branches registered must still succeed, with an empty map.
    #[tokio::test]
    async fn empty_parallel_yields_empty_map() {
        let p: Parallel<u32, u32> = Parallel::new();
        let out = p.invoke(7, RunnableConfig::default()).await.unwrap();
        assert!(out.is_empty());
        assert_eq!(p.name(), "Parallel");
    }

    #[test]
    #[should_panic(expected = "duplicate branch name")]
    fn duplicate_branch_name_panics() {
        let _: Parallel<u32, u32> = Parallel::new()
            .branch("dup", Arc::new(AddN(1)))
            .branch("dup", Arc::new(AddN(2)));
    }
}