cognis_core/wrappers/
assign.rs1use std::collections::HashMap;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use futures::future::join_all;
12use serde::Serialize;
13
14use crate::runnable::{Runnable, RunnableConfig};
15use crate::Result;
16
17#[derive(Debug, Clone, Serialize)]
19pub struct AssignOutput<I, O> {
20 pub input: I,
22 pub assigned: HashMap<String, O>,
24}
25
26pub struct Assign<I, O> {
28 branches: Vec<(String, Arc<dyn Runnable<I, O>>)>,
29}
30
31impl<I, O> Default for Assign<I, O>
32where
33 I: Send + Sync + Clone + 'static,
34 O: Send + 'static,
35{
36 fn default() -> Self {
37 Self {
38 branches: Vec::new(),
39 }
40 }
41}
42
43impl<I, O> Assign<I, O>
44where
45 I: Send + Sync + Clone + 'static,
46 O: Send + 'static,
47{
48 pub fn new() -> Self {
50 Self::default()
51 }
52
53 pub fn branch(mut self, name: impl Into<String>, runnable: Arc<dyn Runnable<I, O>>) -> Self {
59 let name = name.into();
60 if self.branches.iter().any(|(n, _)| n == &name) {
61 panic!("Assign::branch: duplicate branch name `{name}`");
62 }
63 self.branches.push((name, runnable));
64 self
65 }
66}
67
68#[async_trait]
69impl<I, O> Runnable<I, AssignOutput<I, O>> for Assign<I, O>
70where
71 I: Send + Sync + Clone + 'static,
72 O: Send + 'static,
73{
74 async fn invoke(&self, input: I, config: RunnableConfig) -> Result<AssignOutput<I, O>> {
75 let futs = self.branches.iter().map(|(name, r)| {
76 let r = r.clone();
77 let cfg = config.clone();
78 let n = name.clone();
79 let i = input.clone();
80 async move {
81 let out = r.invoke(i, cfg).await?;
82 Ok::<(String, O), crate::CognisError>((n, out))
83 }
84 });
85 let mut assigned = HashMap::with_capacity(self.branches.len());
86 for r in join_all(futs).await {
87 let (k, v) = r?;
88 assigned.insert(k, v);
89 }
90 Ok(AssignOutput { input, assigned })
91 }
92 fn name(&self) -> &str {
93 "Assign"
94 }
95}
96
#[cfg(test)]
mod tests {
    use super::*;

    /// Test runnable that adds a fixed amount to its input.
    struct AddN(u32);

    #[async_trait]
    impl Runnable<u32, u32> for AddN {
        async fn invoke(&self, input: u32, _: RunnableConfig) -> Result<u32> {
            Ok(input + self.0)
        }
    }

    #[tokio::test]
    async fn assign_passes_input_through_and_collects_branches() {
        let a: Assign<u32, u32> = Assign::new()
            .branch("plus_one", Arc::new(AddN(1)))
            .branch("plus_ten", Arc::new(AddN(10)));
        let out = a.invoke(5, RunnableConfig::default()).await.unwrap();
        assert_eq!(out.input, 5);
        assert_eq!(out.assigned["plus_one"], 6);
        assert_eq!(out.assigned["plus_ten"], 15);
    }

    /// With no branches registered, the input must still pass through and
    /// the assigned map must be empty.
    #[tokio::test]
    async fn assign_without_branches_returns_input_and_empty_map() {
        let a: Assign<u32, u32> = Assign::new();
        let out = a.invoke(7, RunnableConfig::default()).await.unwrap();
        assert_eq!(out.input, 7);
        assert!(out.assigned.is_empty());
    }

    #[test]
    #[should_panic(expected = "duplicate branch name")]
    fn duplicate_branch_name_panics() {
        let _: Assign<u32, u32> = Assign::new()
            .branch("dup", Arc::new(AddN(1)))
            .branch("dup", Arc::new(AddN(2)));
    }
}