// cognis_graph/barrier.rs
1//! Named barrier — wait for N parties before dispatching a target node.
2//!
3//! Use case: nodes `a`, `b`, `c` all need to complete before node `d`
4//! runs once with all three outputs in hand. V2's existing
5//! [`Goto::Multiple`] runs N nodes in parallel + atomically merges, but
6//! doesn't give downstream nodes a single "all parties present" handoff.
7//!
8//! [`BarrierNode`] is a node implementation users add to their graph
9//! under a name (e.g. `"join"`). Parties dispatch to it via
10//! [`Goto::Send`] with their payload; the barrier accumulates payloads
11//! until `expected` parties have arrived, then forwards everything to
12//! the target node in the next superstep.
13//!
14//! Pending parties get [`Goto::Halt`] — their branch stops without
15//! terminating the graph.
16
17use std::collections::HashMap;
18use std::sync::{Arc, Mutex};
19
20use async_trait::async_trait;
21
22use cognis_core::Result;
23use uuid::Uuid;
24
25use crate::goto::Goto;
26use crate::node::{Node, NodeCtx, NodeOut};
27use crate::state::GraphState;
28
/// Per-run accumulator: maps `run_id → ordered Vec<payload>`.
///
/// Wrapped in `Arc<Mutex<…>>` because one `BarrierNode` instance holds the
/// accumulator for every run, and parties may invoke the node concurrently
/// (the module docs note parallel fan-out via `Goto::Multiple`).
type Acc = Arc<Mutex<HashMap<Uuid, Vec<serde_json::Value>>>>;
31
/// Wait for N writes before dispatching `target`.
///
/// Pending parties return [`Goto::Halt`]; the last party's invocation
/// dispatches `target` via [`Goto::Send`] with the full payload list as
/// a single JSON array.
pub struct BarrierNode<S: GraphState> {
    /// Graph-node name this barrier is registered under (returned by `name()`).
    name: String,
    /// Number of parties that must arrive before `target` is dispatched.
    expected: usize,
    /// Node dispatched once all `expected` payloads are in hand.
    target: String,
    /// Per-run payload accumulator, keyed by `run_id` (see [`Acc`]).
    acc: Acc,
    /// Ties `S` to the type without storing a value of it.
    /// NOTE(review): `fn() -> S` (rather than `S`) presumably keeps the node
    /// `Send`/`Sync` regardless of `S` — confirm against the graph's bounds.
    _phantom: std::marker::PhantomData<fn() -> S>,
}
44
45impl<S: GraphState> BarrierNode<S> {
46    /// Build a barrier that waits for `expected` parties before
47    /// dispatching `target`.
48    pub fn new(name: impl Into<String>, expected: usize, target: impl Into<String>) -> Self {
49        Self {
50            name: name.into(),
51            expected,
52            target: target.into(),
53            acc: Arc::new(Mutex::new(HashMap::new())),
54            _phantom: std::marker::PhantomData,
55        }
56    }
57}
58
59#[async_trait]
60impl<S: GraphState> Node<S> for BarrierNode<S>
61where
62    S::Update: Default,
63{
64    async fn execute(&self, _state: &S, ctx: &NodeCtx<'_>) -> Result<NodeOut<S>> {
65        let payload = ctx.payload().cloned().unwrap_or(serde_json::Value::Null);
66        let (count, payloads) = {
67            let mut acc = self
68                .acc
69                .lock()
70                .map_err(|e| cognis_core::CognisError::Internal(format!("barrier mutex: {e}")))?;
71            let entry = acc.entry(ctx.run_id).or_default();
72            entry.push(payload);
73            (entry.len(), entry.clone())
74        };
75
76        if count < self.expected {
77            // Still waiting on more parties. Halt this branch.
78            return Ok(NodeOut {
79                update: S::Update::default(),
80                goto: Goto::Halt,
81            });
82        }
83
84        // All parties present — dispatch target with the full payload list,
85        // then clear our per-run accumulator so the barrier is reusable.
86        {
87            let mut acc = self
88                .acc
89                .lock()
90                .map_err(|e| cognis_core::CognisError::Internal(format!("barrier mutex: {e}")))?;
91            acc.remove(&ctx.run_id);
92        }
93        Ok(NodeOut {
94            update: S::Update::default(),
95            goto: Goto::Send(vec![(self.target.clone(), serde_json::json!(payloads))]),
96        })
97    }
98
99    fn name(&self) -> &str {
100        &self.name
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::goto::Goto;
    use cognis_core::RunnableConfig;

    /// Minimal no-op state for exercising the barrier.
    #[derive(Default, Clone)]
    struct S;
    /// Matching (empty) update type.
    #[derive(Default, Clone)]
    struct SU;
    impl GraphState for S {
        type Update = SU;
        fn apply(&mut self, _: SU) {}
    }

    /// A party arriving before the barrier is full must get `Goto::Halt`.
    #[tokio::test]
    async fn pending_parties_halt() {
        let barrier: BarrierNode<S> = BarrierNode::new("join", 3, "next");
        let cfg = RunnableConfig::default();
        let first = serde_json::json!("a");
        let ctx = NodeCtx::new(Uuid::nil(), 0, &cfg).with_payload(&first);
        let out = barrier.execute(&S, &ctx).await.unwrap();
        assert!(matches!(out.goto, Goto::Halt));
    }

    /// The final party's invocation dispatches `target` once with every
    /// payload, in arrival order, as a single JSON array.
    #[tokio::test]
    async fn last_party_dispatches_target_with_all_payloads() {
        let barrier: BarrierNode<S> = BarrierNode::new("join", 2, "next");
        let cfg = RunnableConfig::default();

        let first = serde_json::json!("first");
        let ctx_first = NodeCtx::new(Uuid::nil(), 0, &cfg).with_payload(&first);
        let _ = barrier.execute(&S, &ctx_first).await.unwrap();

        let second = serde_json::json!("second");
        let ctx_second = NodeCtx::new(Uuid::nil(), 0, &cfg).with_payload(&second);
        let out = barrier.execute(&S, &ctx_second).await.unwrap();

        if let Goto::Send(targets) = out.goto {
            assert_eq!(targets.len(), 1);
            let (target, batch) = &targets[0];
            assert_eq!(target, "next");
            let arr = batch.as_array().unwrap();
            assert_eq!(arr.len(), 2);
            assert_eq!(arr[0], "first");
            assert_eq!(arr[1], "second");
        } else {
            panic!("expected Goto::Send");
        }
    }

    /// Runs are isolated: reusing the same barrier instance under a new
    /// `run_id` starts from an empty accumulator.
    #[tokio::test]
    async fn barrier_resets_per_run() {
        let barrier: BarrierNode<S> = BarrierNode::new("join", 1, "next");
        let cfg = RunnableConfig::default();
        let payload = serde_json::json!("only");

        // First run: fires immediately (expected = 1); result ignored.
        let ctx_a = NodeCtx::new(Uuid::new_v4(), 0, &cfg).with_payload(&payload);
        let _ = barrier.execute(&S, &ctx_a).await.unwrap();

        // Second, distinct run must see exactly one payload, not two.
        let ctx_b = NodeCtx::new(Uuid::new_v4(), 0, &cfg).with_payload(&payload);
        let out = barrier.execute(&S, &ctx_b).await.unwrap();
        if let Goto::Send(sent) = out.goto {
            assert_eq!(sent[0].1.as_array().unwrap().len(), 1);
        } else {
            panic!("expected Goto::Send");
        }
    }
}
171}