1use std::collections::HashMap;
18use std::sync::{Arc, Mutex};
19
20use async_trait::async_trait;
21
22use cognis_core::Result;
23use uuid::Uuid;
24
25use crate::goto::Goto;
26use crate::node::{Node, NodeCtx, NodeOut};
27use crate::state::GraphState;
28
29type Acc = Arc<Mutex<HashMap<Uuid, Vec<serde_json::Value>>>>;
31
32pub struct BarrierNode<S: GraphState> {
38 name: String,
39 expected: usize,
40 target: String,
41 acc: Acc,
42 _phantom: std::marker::PhantomData<fn() -> S>,
43}
44
45impl<S: GraphState> BarrierNode<S> {
46 pub fn new(name: impl Into<String>, expected: usize, target: impl Into<String>) -> Self {
49 Self {
50 name: name.into(),
51 expected,
52 target: target.into(),
53 acc: Arc::new(Mutex::new(HashMap::new())),
54 _phantom: std::marker::PhantomData,
55 }
56 }
57}
58
59#[async_trait]
60impl<S: GraphState> Node<S> for BarrierNode<S>
61where
62 S::Update: Default,
63{
64 async fn execute(&self, _state: &S, ctx: &NodeCtx<'_>) -> Result<NodeOut<S>> {
65 let payload = ctx.payload().cloned().unwrap_or(serde_json::Value::Null);
66 let (count, payloads) = {
67 let mut acc = self
68 .acc
69 .lock()
70 .map_err(|e| cognis_core::CognisError::Internal(format!("barrier mutex: {e}")))?;
71 let entry = acc.entry(ctx.run_id).or_default();
72 entry.push(payload);
73 (entry.len(), entry.clone())
74 };
75
76 if count < self.expected {
77 return Ok(NodeOut {
79 update: S::Update::default(),
80 goto: Goto::Halt,
81 });
82 }
83
84 {
87 let mut acc = self
88 .acc
89 .lock()
90 .map_err(|e| cognis_core::CognisError::Internal(format!("barrier mutex: {e}")))?;
91 acc.remove(&ctx.run_id);
92 }
93 Ok(NodeOut {
94 update: S::Update::default(),
95 goto: Goto::Send(vec![(self.target.clone(), serde_json::json!(payloads))]),
96 })
97 }
98
99 fn name(&self) -> &str {
100 &self.name
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::goto::Goto;
    use cognis_core::RunnableConfig;

    #[derive(Default, Clone)]
    struct S;
    #[derive(Default, Clone)]
    struct SU;
    impl GraphState for S {
        type Update = SU;
        fn apply(&mut self, _: SU) {}
    }

    #[tokio::test]
    async fn pending_parties_halt() {
        let b: BarrierNode<S> = BarrierNode::new("join", 3, "next");
        let cfg = RunnableConfig::default();
        let payload_a = serde_json::json!("a");
        let ctx = NodeCtx::new(Uuid::nil(), 0, &cfg).with_payload(&payload_a);
        let out = b.execute(&S, &ctx).await.unwrap();
        assert!(matches!(out.goto, Goto::Halt));
    }

    #[tokio::test]
    async fn last_party_dispatches_target_with_all_payloads() {
        let b: BarrierNode<S> = BarrierNode::new("join", 2, "next");
        let cfg = RunnableConfig::default();
        let p1 = serde_json::json!("first");
        let ctx1 = NodeCtx::new(Uuid::nil(), 0, &cfg).with_payload(&p1);
        let _ = b.execute(&S, &ctx1).await.unwrap();
        let p2 = serde_json::json!("second");
        let ctx2 = NodeCtx::new(Uuid::nil(), 0, &cfg).with_payload(&p2);
        let out = b.execute(&S, &ctx2).await.unwrap();
        match out.goto {
            Goto::Send(targets) => {
                assert_eq!(targets.len(), 1);
                assert_eq!(targets[0].0, "next");
                let arr = targets[0].1.as_array().unwrap();
                assert_eq!(arr.len(), 2);
                assert_eq!(arr[0], "first");
                assert_eq!(arr[1], "second");
            }
            _ => panic!("expected Goto::Send"),
        }
    }

    /// Arrivals from one run must never count toward another run's barrier.
    /// Uses two parties so that isolation is actually observable (with one party
    /// every arrival dispatches immediately regardless of keying).
    #[tokio::test]
    async fn barrier_resets_per_run() {
        let b: BarrierNode<S> = BarrierNode::new("join", 2, "next");
        let cfg = RunnableConfig::default();
        let run_a = Uuid::new_v4();
        let run_b = Uuid::new_v4();

        let pa1 = serde_json::json!("a1");
        let ctx_a1 = NodeCtx::new(run_a, 0, &cfg).with_payload(&pa1);
        let out = b.execute(&S, &ctx_a1).await.unwrap();
        assert!(matches!(out.goto, Goto::Halt));

        let pb1 = serde_json::json!("b1");
        let ctx_b1 = NodeCtx::new(run_b, 0, &cfg).with_payload(&pb1);
        let out = b.execute(&S, &ctx_b1).await.unwrap();
        assert!(
            matches!(out.goto, Goto::Halt),
            "an arrival from run_b must not complete run_a's barrier"
        );

        let pa2 = serde_json::json!("a2");
        let ctx_a2 = NodeCtx::new(run_a, 0, &cfg).with_payload(&pa2);
        let out = b.execute(&S, &ctx_a2).await.unwrap();
        match out.goto {
            Goto::Send(t) => {
                let arr = t[0].1.as_array().unwrap();
                assert_eq!(arr.len(), 2);
                assert_eq!(arr[0], "a1");
                assert_eq!(arr[1], "a2");
            }
            _ => panic!("expected Goto::Send for run_a"),
        }
    }

    /// After a run's barrier dispatches, its buffer is cleared so the next
    /// arrival for the same run starts counting from zero again.
    #[tokio::test]
    async fn buffer_is_cleared_after_dispatch() {
        let b: BarrierNode<S> = BarrierNode::new("join", 2, "next");
        let cfg = RunnableConfig::default();
        let run = Uuid::new_v4();
        let p = serde_json::json!("x");

        let ctx1 = NodeCtx::new(run, 0, &cfg).with_payload(&p);
        let _ = b.execute(&S, &ctx1).await.unwrap();
        let ctx2 = NodeCtx::new(run, 0, &cfg).with_payload(&p);
        let out = b.execute(&S, &ctx2).await.unwrap();
        assert!(matches!(out.goto, Goto::Send(_)));

        let ctx3 = NodeCtx::new(run, 0, &cfg).with_payload(&p);
        let out = b.execute(&S, &ctx3).await.unwrap();
        assert!(
            matches!(out.goto, Goto::Halt),
            "a fresh arrival after dispatch must wait for a new quorum"
        );
    }
}
171}