Skip to main content

cognis_core/wrappers/
assign.rs

1//! `Assign` — fan out an input to N runnables in parallel and merge their
2//! outputs into a single map alongside the input itself.
3//!
4//! This is the LCEL `RunnablePassthrough.assign(...)` pattern: keep the
5//! original input, but enrich it with values computed by branches.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use futures::future::join_all;
12use serde::Serialize;
13
14use crate::runnable::{Runnable, RunnableConfig};
15use crate::Result;
16
17/// Output type of [`Assign::invoke`].
18#[derive(Debug, Clone, Serialize)]
19pub struct AssignOutput<I, O> {
20    /// The original input, passed through unchanged.
21    pub input: I,
22    /// Per-branch outputs keyed by branch name.
23    pub assigned: HashMap<String, O>,
24}
25
26/// Runs each branch on a clone of the input and bundles the results.
27pub struct Assign<I, O> {
28    branches: Vec<(String, Arc<dyn Runnable<I, O>>)>,
29}
30
31impl<I, O> Default for Assign<I, O>
32where
33    I: Send + Sync + Clone + 'static,
34    O: Send + 'static,
35{
36    fn default() -> Self {
37        Self {
38            branches: Vec::new(),
39        }
40    }
41}
42
43impl<I, O> Assign<I, O>
44where
45    I: Send + Sync + Clone + 'static,
46    O: Send + 'static,
47{
48    /// Empty assign builder.
49    pub fn new() -> Self {
50        Self::default()
51    }
52
53    /// Add a named branch.
54    ///
55    /// **Panics** on duplicate names. Two branches under the same key
56    /// would race into the result map and the loser would be silently
57    /// dropped — almost always a programming error.
58    pub fn branch(mut self, name: impl Into<String>, runnable: Arc<dyn Runnable<I, O>>) -> Self {
59        let name = name.into();
60        if self.branches.iter().any(|(n, _)| n == &name) {
61            panic!("Assign::branch: duplicate branch name `{name}`");
62        }
63        self.branches.push((name, runnable));
64        self
65    }
66}
67
68#[async_trait]
69impl<I, O> Runnable<I, AssignOutput<I, O>> for Assign<I, O>
70where
71    I: Send + Sync + Clone + 'static,
72    O: Send + 'static,
73{
74    async fn invoke(&self, input: I, config: RunnableConfig) -> Result<AssignOutput<I, O>> {
75        let futs = self.branches.iter().map(|(name, r)| {
76            let r = r.clone();
77            let cfg = config.clone();
78            let n = name.clone();
79            let i = input.clone();
80            async move {
81                let out = r.invoke(i, cfg).await?;
82                Ok::<(String, O), crate::CognisError>((n, out))
83            }
84        });
85        let mut assigned = HashMap::with_capacity(self.branches.len());
86        for r in join_all(futs).await {
87            let (k, v) = r?;
88            assigned.insert(k, v);
89        }
90        Ok(AssignOutput { input, assigned })
91    }
92    fn name(&self) -> &str {
93        "Assign"
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::*;

    /// Test runnable that adds a fixed offset to its input.
    struct AddN(u32);

    #[async_trait]
    impl Runnable<u32, u32> for AddN {
        async fn invoke(&self, input: u32, _: RunnableConfig) -> Result<u32> {
            Ok(input + self.0)
        }
    }

    #[tokio::test]
    async fn assign_passes_input_through_and_collects_branches() {
        let a: Assign<u32, u32> = Assign::new()
            .branch("plus_one", Arc::new(AddN(1)))
            .branch("plus_ten", Arc::new(AddN(10)));
        let out = a.invoke(5, RunnableConfig::default()).await.unwrap();
        assert_eq!(out.input, 5);
        assert_eq!(out.assigned["plus_one"], 6);
        assert_eq!(out.assigned["plus_ten"], 15);
    }

    /// Edge case: with no branches the input is echoed and the map is empty.
    #[tokio::test]
    async fn assign_without_branches_echoes_input() {
        let a: Assign<u32, u32> = Assign::new();
        let out = a.invoke(7, RunnableConfig::default()).await.unwrap();
        assert_eq!(out.input, 7);
        assert!(out.assigned.is_empty());
    }

    #[test]
    fn reports_its_name() {
        let a: Assign<u32, u32> = Assign::new();
        assert_eq!(a.name(), "Assign");
    }

    #[test]
    #[should_panic(expected = "duplicate branch name")]
    fn duplicate_branch_name_panics() {
        let _: Assign<u32, u32> = Assign::new()
            .branch("dup", Arc::new(AddN(1)))
            .branch("dup", Arc::new(AddN(2)));
    }
}
128}