// cognis_core/compose/parallel.rs
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;

use async_trait::async_trait;
use futures::future::join_all;

use crate::runnable::{Runnable, RunnableConfig};
use crate::Result;

12pub struct Parallel<I, O> {
19 branches: Vec<(String, Arc<dyn Runnable<I, O>>)>,
20}
21
22impl<I, O> Default for Parallel<I, O>
23where
24 I: Send + Sync + Clone + 'static,
25 O: Send + 'static,
26{
27 fn default() -> Self {
28 Self {
29 branches: Vec::new(),
30 }
31 }
32}
33
34impl<I, O> Parallel<I, O>
35where
36 I: Send + Sync + Clone + 'static,
37 O: Send + 'static,
38{
39 pub fn new() -> Self {
41 Self::default()
42 }
43
44 pub fn branch(mut self, name: impl Into<String>, runnable: Arc<dyn Runnable<I, O>>) -> Self {
51 let name = name.into();
52 if self.branches.iter().any(|(n, _)| n == &name) {
53 panic!("Parallel::branch: duplicate branch name `{name}`");
54 }
55 self.branches.push((name, runnable));
56 self
57 }
58}
59
60#[async_trait]
61impl<I, O> Runnable<I, HashMap<String, O>> for Parallel<I, O>
62where
63 I: Send + Sync + Clone + 'static,
64 O: Send + 'static,
65{
66 async fn invoke(&self, input: I, config: RunnableConfig) -> Result<HashMap<String, O>> {
67 let futs = self.branches.iter().map(|(name, r)| {
68 let r = r.clone();
69 let cfg = config.clone();
70 let i = input.clone();
71 let n = name.clone();
72 async move {
73 let out = r.invoke(i, cfg).await?;
74 Ok::<(String, O), crate::CognisError>((n, out))
75 }
76 });
77 let mut out = HashMap::with_capacity(self.branches.len());
78 for r in join_all(futs).await {
79 let (k, v) = r?;
80 out.insert(k, v);
81 }
82 Ok(out)
83 }
84
85 fn name(&self) -> &str {
86 "Parallel"
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Test runnable that adds a fixed offset to its input.
    struct AddN(u32);

    #[async_trait]
    impl Runnable<u32, u32> for AddN {
        async fn invoke(&self, input: u32, _: RunnableConfig) -> Result<u32> {
            Ok(input + self.0)
        }
    }

    #[tokio::test]
    async fn fans_out_and_collects() {
        let p: Parallel<u32, u32> = Parallel::new()
            .branch("plus_one", Arc::new(AddN(1)))
            .branch("plus_ten", Arc::new(AddN(10)));
        let out = p.invoke(5, RunnableConfig::default()).await.unwrap();
        assert_eq!(out["plus_one"], 6);
        assert_eq!(out["plus_ten"], 15);
        assert_eq!(out.len(), 2);
    }

    #[tokio::test]
    async fn empty_parallel_yields_empty_map() {
        let p: Parallel<u32, u32> = Parallel::new();
        let out = p.invoke(5, RunnableConfig::default()).await.unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn reports_its_name() {
        let p: Parallel<u32, u32> = Parallel::new();
        assert_eq!(p.name(), "Parallel");
    }

    #[test]
    #[should_panic(expected = "duplicate branch name")]
    fn duplicate_branch_name_panics() {
        let _: Parallel<u32, u32> = Parallel::new()
            .branch("dup", Arc::new(AddN(1)))
            .branch("dup", Arc::new(AddN(2)));
    }
}