Skip to main content

cognis_graph/
node.rs

1//! Node trait + per-superstep context + helper closure adapter.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use cognis_core::stream::Observer;
7use cognis_core::{Event, Result, RunnableConfig};
8use uuid::Uuid;
9
10use crate::goto::Goto;
11use crate::state::GraphState;
12
13/// Output of a node's `execute`: a typed state delta + where to go next.
14pub struct NodeOut<S: GraphState> {
15    /// State update — applied via per-field reducers in this superstep's atomic merge.
16    pub update: S::Update,
17    /// Routing decision.
18    pub goto: Goto,
19}
20
21impl<S: GraphState> NodeOut<S> {
22    /// Convenience: terminal node with a state update.
23    pub fn end_with(update: S::Update) -> Self {
24        Self {
25            update,
26            goto: Goto::End,
27        }
28    }
29
30    /// Convenience: route somewhere with no state delta (Default::default()).
31    pub fn goto_only(goto: Goto) -> Self {
32        Self {
33            update: S::Update::default(),
34            goto,
35        }
36    }
37}
38
39/// Per-superstep context handed to every `Node::execute` call. Carries
40/// run-correlation metadata and the active `RunnableConfig`. The lifetime
41/// is the superstep — don't hold across awaits beyond `execute` returning.
42pub struct NodeCtx<'a> {
43    /// Correlation ID for this run.
44    pub run_id: Uuid,
45    /// Superstep counter (0-indexed).
46    pub step: u64,
47    /// The active runnable config (recursion_limit, observers, cancel_token, …).
48    pub config: &'a RunnableConfig,
49    /// Per-target payload when this node is invoked as a `Goto::Send` target.
50    /// `None` for all other dispatch types.
51    payload: Option<&'a serde_json::Value>,
52    /// Engine-supplied: how many supersteps remain before the recursion
53    /// limit fires. `None` when running outside the engine (e.g. unit
54    /// tests). `is_last_step()` derives from this.
55    remaining_steps: Option<u32>,
56}
57
58impl<'a> NodeCtx<'a> {
59    /// Create a new `NodeCtx`. Primarily used by the engine; exposed publicly
60    /// so node implementations in external crates can construct test contexts.
61    pub fn new(run_id: Uuid, step: u64, config: &'a RunnableConfig) -> Self {
62        Self {
63            run_id,
64            step,
65            config,
66            payload: None,
67            remaining_steps: None,
68        }
69    }
70
71    /// Engine-internal: attach a Send payload.
72    pub(crate) fn with_payload(mut self, payload: &'a serde_json::Value) -> Self {
73        self.payload = Some(payload);
74        self
75    }
76
77    /// Engine-internal: set the remaining-step budget.
78    pub(crate) fn with_remaining_steps(mut self, remaining: u32) -> Self {
79        self.remaining_steps = Some(remaining);
80        self
81    }
82
83    /// The Send payload accompanying this dispatch, if any. Returns `None`
84    /// when the node is invoked via `Goto::Node` or `Goto::Multiple`.
85    pub fn payload(&self) -> Option<&serde_json::Value> {
86        self.payload
87    }
88
89    /// Number of supersteps remaining before the recursion limit fires.
90    /// `None` when running outside the engine (unit tests).
91    pub fn remaining_steps(&self) -> Option<u32> {
92        self.remaining_steps
93    }
94
95    /// True if this is the final superstep — i.e. the engine will not run
96    /// another step after this one returns. Mirrors V1 `IsLastStep`.
97    pub fn is_last_step(&self) -> bool {
98        matches!(self.remaining_steps, Some(0) | Some(1))
99    }
100
101    /// Notify every observer in `config.observers` of an event.
102    pub fn emit(&self, event: &Event) {
103        self.config.emit(event);
104    }
105
106    /// Emit a `Custom` event on the run's observer stream. Used by
107    /// `StreamMode::Custom` consumers to surface node-authored progress
108    /// signals (mirrors V1 `StreamWriter`).
109    pub fn write_custom(&self, kind: impl Into<String>, payload: serde_json::Value) {
110        self.config.emit(&Event::Custom {
111            kind: kind.into(),
112            payload,
113            run_id: self.run_id,
114        });
115    }
116
117    /// True if the run was cancelled.
118    pub fn is_cancelled(&self) -> bool {
119        self.config.is_cancelled()
120    }
121
122    /// Convenience accessor for observers.
123    pub fn observers(&self) -> &[Arc<dyn Observer>] {
124        &self.config.observers
125    }
126}
127
/// Per-task retry policy.
///
/// The engine wraps `Node::execute` calls in a retry loop when the node
/// returns a `Some` policy and the call fails with a
/// [`cognis_core::CognisError`] whose `is_retryable()` is true.
///
/// `PartialEq` is derived so that policies can be compared directly, e.g.
/// `node.retry_policy() == Some(NodeRetryPolicy::default())`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NodeRetryPolicy {
    /// Maximum total attempts, counting the first call.
    pub max_attempts: u32,
    /// Backoff before the first retry, in milliseconds.
    pub initial_delay_ms: u64,
    /// Factor the delay is multiplied by after each failed attempt.
    pub backoff_multiplier: f64,
    /// Upper bound on any single delay, in milliseconds.
    pub max_delay_ms: u64,
}

impl Default for NodeRetryPolicy {
    /// Three attempts in total. The first retry waits 100 ms, the delay is
    /// multiplied by 2.0 after each failure, and it is capped at 30 000 ms.
    fn default() -> Self {
        Self {
            max_attempts: 3,
            initial_delay_ms: 100,
            backoff_multiplier: 2.0,
            max_delay_ms: 30_000,
        }
    }
}
153
154/// The unit of computation in a graph. Async, takes a `&S` snapshot of state
155/// + per-step context, returns a delta + a routing decision.
156#[async_trait]
157pub trait Node<S: GraphState>: Send + Sync {
158    /// Execute one superstep of this node.
159    async fn execute(&self, state: &S, ctx: &NodeCtx<'_>) -> Result<NodeOut<S>>;
160
161    /// Friendly name for telemetry / logging. Default uses the type name.
162    fn name(&self) -> &str {
163        std::any::type_name::<Self>()
164    }
165
166    /// Per-task retry policy. Default `None` means "no retry — propagate
167    /// the error". When `Some`, the engine retries `execute` on retryable
168    /// errors with exponential backoff.
169    fn retry_policy(&self) -> Option<NodeRetryPolicy> {
170        None
171    }
172}
173
/// Closure adapter that lets any `Fn(&S, &NodeCtx) -> Future` act as a
/// [`Node`].
///
/// Construct one with [`node_fn`].
pub struct NodeFn<S, F> {
    /// Wrapped closure, called once per `execute`.
    f: F,
    /// Name reported by `Node::name`.
    name: String,
    /// Ties the adapter to its state type without owning an `S`. The
    /// `fn() -> S` form keeps the marker itself `Send + Sync` for any `S`.
    _state: std::marker::PhantomData<fn() -> S>,
}
180
181/// Build a `NodeFn` from a closure. The closure receives `(&S, &NodeCtx)`
182/// and returns `Future<Output = Result<NodeOut<S>>>`.
183pub fn node_fn<S, F, Fut>(name: impl Into<String>, f: F) -> NodeFn<S, F>
184where
185    S: GraphState,
186    F: Fn(&S, &NodeCtx<'_>) -> Fut + Send + Sync + 'static,
187    Fut: std::future::Future<Output = Result<NodeOut<S>>> + Send,
188{
189    NodeFn {
190        name: name.into(),
191        f,
192        _state: std::marker::PhantomData,
193    }
194}
195
196#[async_trait]
197impl<S, F, Fut> Node<S> for NodeFn<S, F>
198where
199    S: GraphState,
200    F: Fn(&S, &NodeCtx<'_>) -> Fut + Send + Sync + 'static,
201    Fut: std::future::Future<Output = Result<NodeOut<S>>> + Send,
202{
203    async fn execute(&self, state: &S, ctx: &NodeCtx<'_>) -> Result<NodeOut<S>> {
204        (self.f)(state, ctx).await
205    }
206
207    fn name(&self) -> &str {
208        &self.name
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::goto::Goto;
    use crate::state::GraphState;

    /// Minimal graph state: a single counter.
    #[derive(Default, Clone, Debug, PartialEq)]
    struct S {
        n: u32,
    }

    /// Delta for `S`; `apply` adds it to the counter.
    #[derive(Default)]
    struct SU {
        n: u32,
    }

    impl GraphState for S {
        type Update = SU;
        fn apply(&mut self, update: Self::Update) {
            self.n += update.n;
        }
    }

    #[tokio::test]
    async fn node_fn_executes() {
        let n = node_fn::<S, _, _>("incr", |state, _ctx| {
            // The future cannot borrow `state`, so copy the value out first.
            let cur = state.n;
            async move {
                Ok(NodeOut {
                    update: SU { n: cur + 1 },
                    goto: Goto::end(),
                })
            }
        });
        let cfg = RunnableConfig::default();
        let ctx = NodeCtx::new(Uuid::nil(), 0, &cfg);
        let s = S { n: 5 };
        let out = n.execute(&s, &ctx).await.unwrap();
        assert_eq!(out.update.n, 6);
        assert!(out.goto.is_end());
        assert_eq!(n.name(), "incr");
    }

    #[test]
    fn node_fn_default_retry_policy_is_none() {
        let n = node_fn::<S, _, _>("noop", |_state, _ctx| async move {
            Ok(NodeOut::<S>::goto_only(Goto::end()))
        });
        assert!(n.retry_policy().is_none());
        assert_eq!(n.name(), "noop");
    }

    #[test]
    fn node_ctx_payload_default_none() {
        let cfg = RunnableConfig::default();
        let ctx = NodeCtx::new(Uuid::nil(), 0, &cfg);
        assert!(ctx.payload().is_none());
    }

    #[test]
    fn node_ctx_with_payload() {
        let cfg = RunnableConfig::default();
        let payload = serde_json::json!({"x": 42});
        let ctx = NodeCtx::new(Uuid::nil(), 0, &cfg).with_payload(&payload);
        assert_eq!(ctx.payload().unwrap()["x"], 42);
    }

    #[test]
    fn node_ctx_without_budget_is_never_last_step() {
        let cfg = RunnableConfig::default();
        let ctx = NodeCtx::new(Uuid::nil(), 0, &cfg);
        assert_eq!(ctx.remaining_steps(), None);
        assert!(!ctx.is_last_step());
    }

    #[test]
    fn node_ctx_is_last_step_boundaries() {
        let cfg = RunnableConfig::default();
        for (remaining, expected) in [(0, true), (1, true), (2, false), (u32::MAX, false)] {
            let ctx = NodeCtx::new(Uuid::nil(), 0, &cfg).with_remaining_steps(remaining);
            assert_eq!(ctx.remaining_steps(), Some(remaining));
            assert_eq!(ctx.is_last_step(), expected, "remaining = {remaining}");
        }
    }

    #[test]
    fn retry_policy_defaults() {
        let p = NodeRetryPolicy::default();
        assert_eq!(p.max_attempts, 3);
        assert_eq!(p.initial_delay_ms, 100);
        assert_eq!(p.backoff_multiplier, 2.0);
        assert_eq!(p.max_delay_ms, 30_000);
    }

    #[test]
    fn nodeout_constructors() {
        let upd = SU { n: 10 };
        let no: NodeOut<S> = NodeOut::end_with(upd);
        assert!(no.goto.is_end());

        let no2: NodeOut<S> = NodeOut::goto_only(Goto::node("next"));
        assert_eq!(no2.update.n, 0);
        assert_eq!(no2.goto, Goto::node("next"));
    }
}