Skip to main content

cognis_graph/
compiled.rs

1//! Compiled, executable graph. Implements `Runnable<S, S>` so a graph
2//! composes anywhere a `Runnable` is expected (including as a node
3//! inside another graph).
4
5use std::collections::HashSet;
6use std::sync::Arc;
7
8use async_trait::async_trait;
9
10use cognis_core::{Result, Runnable, RunnableConfig};
11
12use crate::builder::Graph;
13use crate::checkpoint::Checkpointer;
14use crate::durability::Durability;
15use crate::engine;
16use crate::state::GraphState;
17use crate::stream_mode::StreamModes;
18
19/// A validated, ready-to-run graph. Cheap to clone (the underlying nodes
20/// are `Arc<dyn Node<S>>`).
21#[derive(Clone)]
22pub struct CompiledGraph<S: GraphState> {
23    pub(crate) graph: Graph<S>,
24    pub(crate) checkpointer: Option<Arc<dyn Checkpointer<S>>>,
25    pub(crate) interrupt_before: HashSet<String>,
26    pub(crate) interrupt_after: HashSet<String>,
27    pub(crate) durability: Durability,
28}
29
30impl<S: GraphState> std::fmt::Debug for CompiledGraph<S> {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("CompiledGraph")
33            .field("node_count", &self.graph.nodes.len())
34            .field("has_checkpointer", &self.checkpointer.is_some())
35            .field("interrupt_before", &self.interrupt_before)
36            .field("interrupt_after", &self.interrupt_after)
37            .finish()
38    }
39}
40
41impl<S: GraphState> CompiledGraph<S> {
42    pub(crate) fn new(graph: Graph<S>) -> Self {
43        Self {
44            graph,
45            checkpointer: None,
46            interrupt_before: HashSet::new(),
47            interrupt_after: HashSet::new(),
48            durability: Durability::default(),
49        }
50    }
51
52    /// Override checkpoint timing relative to step execution. Default is
53    /// [`Durability::Sync`].
54    pub fn with_durability(mut self, d: Durability) -> Self {
55        self.durability = d;
56        self
57    }
58
59    /// Current durability mode.
60    pub fn durability(&self) -> &Durability {
61        &self.durability
62    }
63
64    /// Number of registered nodes — useful for testing / introspection.
65    pub fn node_count(&self) -> usize {
66        self.graph.nodes.len()
67    }
68
69    /// Names of all registered nodes.
70    pub fn node_names(&self) -> Vec<&str> {
71        self.graph.nodes.keys().map(|s| s.as_str()).collect()
72    }
73
74    /// Optional graph version tag (set via [`crate::builder::Graph::with_version`]).
75    pub fn version(&self) -> Option<&str> {
76        self.graph.version.as_deref()
77    }
78
79    /// All annotations attached to `node_name`, or an empty map if the
80    /// node has no annotations / isn't registered.
81    pub fn annotations(
82        &self,
83        node_name: &str,
84    ) -> &std::collections::HashMap<String, serde_json::Value> {
85        static EMPTY: std::sync::OnceLock<std::collections::HashMap<String, serde_json::Value>> =
86            std::sync::OnceLock::new();
87        self.graph
88            .annotations
89            .get(node_name)
90            .unwrap_or_else(|| EMPTY.get_or_init(std::collections::HashMap::new))
91    }
92
93    /// Look up a single annotation value by `(node_name, key)`.
94    pub fn annotation(&self, node_name: &str, key: &str) -> Option<&serde_json::Value> {
95        self.graph
96            .annotations
97            .get(node_name)
98            .and_then(|m| m.get(key))
99    }
100}
101
102impl<S: GraphState + Clone + Send + 'static> CompiledGraph<S> {
103    /// Attach a checkpointer; the engine will save state after each superstep.
104    pub fn with_checkpointer(mut self, cp: Arc<dyn Checkpointer<S>>) -> Self {
105        self.checkpointer = Some(cp);
106        self
107    }
108
109    /// Pause the graph BEFORE each named node executes. Requires a checkpointer
110    /// (errors at invoke time if not configured). Interrupt names are validated
111    /// at invoke time, not compile time, because `with_interrupt_before` runs
112    /// after `compile()`. Resume via `CompiledGraph::resume`.
113    pub fn with_interrupt_before<I, N>(mut self, names: I) -> Self
114    where
115        I: IntoIterator<Item = N>,
116        N: Into<String>,
117    {
118        self.interrupt_before
119            .extend(names.into_iter().map(Into::into));
120        self
121    }
122
123    /// Pause the graph AFTER each named node completes (state already updated).
124    /// Requires a checkpointer. Resume via `CompiledGraph::resume`.
125    pub fn with_interrupt_after<I, N>(mut self, names: I) -> Self
126    where
127        I: IntoIterator<Item = N>,
128        N: Into<String>,
129    {
130        self.interrupt_after
131            .extend(names.into_iter().map(Into::into));
132        self
133    }
134
135    /// Continue execution from a previously-interrupted run.
136    ///
137    /// `state` is the (possibly user-edited) state to seed the next superstep
138    /// with. `run_id` and `step` come from the original `GraphInterrupted` error.
139    /// The resume's `RunnableConfig::run_id` is set to `run_id` so observers
140    /// can correlate with the original run.
141    pub async fn resume(
142        &self,
143        run_id: uuid::Uuid,
144        step: u64,
145        state: S,
146        config: RunnableConfig,
147    ) -> Result<S>
148    where
149        S::Update: Clone,
150    {
151        let mut cfg = config;
152        cfg.run_id = run_id;
153        engine::resume(self, state, cfg, step).await
154    }
155
156    // ---------------------------------------------------------------
157    // State inspection (HITL / time-travel / debugging)
158    // ---------------------------------------------------------------
159
160    /// Latest checkpointed state for `run_id`. Returns `None` if there is
161    /// no checkpointer attached or no state recorded for that run.
162    pub async fn get_state(&self, run_id: uuid::Uuid) -> Result<Option<S>> {
163        match &self.checkpointer {
164            Some(cp) => cp.load(run_id, None).await,
165            None => Ok(None),
166        }
167    }
168
169    /// State at a specific superstep — for time-travel.
170    pub async fn get_state_at(&self, run_id: uuid::Uuid, step: u64) -> Result<Option<S>> {
171        match &self.checkpointer {
172            Some(cp) => cp.load(run_id, Some(step)).await,
173            None => Ok(None),
174        }
175    }
176
177    /// Full step history for `run_id`. Each `(step, state)` pair is one
178    /// superstep boundary — earliest first.
179    pub async fn get_state_history(&self, run_id: uuid::Uuid) -> Result<Vec<(u64, S)>> {
180        let cp = match &self.checkpointer {
181            Some(cp) => cp,
182            None => return Ok(Vec::new()),
183        };
184        let steps = cp.list(run_id).await?;
185        let mut out = Vec::with_capacity(steps.len());
186        for s in steps {
187            if let Some(state) = cp.load(run_id, Some(s)).await? {
188                out.push((s, state));
189            }
190        }
191        Ok(out)
192    }
193
194    /// Save a (possibly user-edited) state at `step` for `run_id`. Used to
195    /// patch state before resuming an interrupted run.
196    ///
197    /// Errors if no checkpointer is attached.
198    pub async fn update_state(&self, run_id: uuid::Uuid, step: u64, state: &S) -> Result<()> {
199        match &self.checkpointer {
200            Some(cp) => cp.save(run_id, step, state).await,
201            None => Err(cognis_core::CognisError::Configuration(
202                "update_state requires a checkpointer; attach via .with_checkpointer(...)".into(),
203            )),
204        }
205    }
206}
207
208impl<S> CompiledGraph<S>
209where
210    S: GraphState + Clone + Send + 'static,
211    <S as GraphState>::Update: Clone,
212{
213    /// Stream events filtered by [`StreamModes`] — see the `stream_mode`
214    /// module for what each mode captures.
215    pub async fn stream_mode(
216        &self,
217        input: S,
218        modes: StreamModes,
219        config: RunnableConfig,
220    ) -> Result<cognis_core::EventStream> {
221        use cognis_core::Observer;
222        use futures::StreamExt;
223        use tokio::sync::mpsc;
224        use tokio_stream::wrappers::UnboundedReceiverStream;
225
226        struct ChannelObserver(mpsc::UnboundedSender<cognis_core::Event>);
227        impl Observer for ChannelObserver {
228            fn on_event(&self, event: &cognis_core::Event) {
229                let _ = self.0.send(event.clone());
230            }
231        }
232
233        let (tx, rx) = mpsc::unbounded_channel::<cognis_core::Event>();
234        let observer: Arc<dyn Observer> = Arc::new(ChannelObserver(tx));
235        let mut cfg = config;
236        cfg.observers.push(observer);
237
238        let this = self.clone();
239        tokio::spawn(async move {
240            let _ = engine::run(&this, input, cfg).await;
241        });
242
243        let filtered = UnboundedReceiverStream::new(rx).filter(move |e| {
244            let keep = modes.matches(e);
245            async move { keep }
246        });
247
248        Ok(cognis_core::EventStream::new(filtered))
249    }
250}
251
252#[async_trait]
253impl<S> Runnable<S, S> for CompiledGraph<S>
254where
255    S: GraphState + Clone + Send + 'static,
256    <S as GraphState>::Update: Clone,
257{
258    async fn invoke(&self, input: S, config: RunnableConfig) -> Result<S> {
259        engine::run(self, input, config).await
260    }
261
262    fn name(&self) -> &str {
263        "CompiledGraph"
264    }
265
266    /// Override the default `stream_events` to emit real per-node events as
267    /// the engine runs (real-time, not synthetic OnEnd-only). Engine events
268    /// embed `serde_json::Value::Null` so `S: Serialize` is not required.
269    async fn stream_events(
270        &self,
271        input: S,
272        config: RunnableConfig,
273    ) -> Result<cognis_core::EventStream> {
274        use cognis_core::Observer;
275        use tokio::sync::mpsc;
276        use tokio_stream::wrappers::UnboundedReceiverStream;
277
278        struct ChannelObserver(mpsc::UnboundedSender<cognis_core::Event>);
279        impl Observer for ChannelObserver {
280            fn on_event(&self, event: &cognis_core::Event) {
281                let _ = self.0.send(event.clone());
282            }
283        }
284
285        let (tx, rx) = mpsc::unbounded_channel::<cognis_core::Event>();
286        let observer: Arc<dyn Observer> = Arc::new(ChannelObserver(tx));
287        let mut cfg = config;
288        cfg.observers.push(observer);
289
290        let this = self.clone();
291        tokio::spawn(async move {
292            let _ = engine::run(&this, input, cfg).await;
293        });
294
295        Ok(cognis_core::EventStream::new(UnboundedReceiverStream::new(
296            rx,
297        )))
298    }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use crate::goto::Goto;
    use crate::node::{node_fn, NodeOut};

    /// Minimal state: a single counter.
    #[derive(Default, Clone, Debug, PartialEq, serde::Serialize)]
    struct Counter {
        n: u32,
    }

    /// Update payload: the amount to add to the counter.
    #[derive(Default, Clone)]
    struct CounterUpdate {
        n: u32,
    }

    impl GraphState for Counter {
        type Update = CounterUpdate;
        fn apply(&mut self, u: Self::Update) {
            self.n += u.n;
        }
    }

    /// Expands to the `Ok(NodeOut { .. })` a test node returns: add `$n` to
    /// the counter, then route to `$goto`.
    macro_rules! step {
        ($n:expr, $goto:expr) => {
            Ok(NodeOut {
                update: CounterUpdate { n: $n },
                goto: $goto,
            })
        };
    }

    #[tokio::test]
    async fn linear_two_nodes_runs_to_end() {
        let first = node_fn::<Counter, _, _>("a", |_s, _c| async move {
            step!(1, Goto::node("b"))
        });
        let second = node_fn::<Counter, _, _>("b", |_s, _c| async move {
            step!(10, Goto::end())
        });
        let graph = Graph::<Counter>::new()
            .node("a", first)
            .node("b", second)
            .start_at("a")
            .compile()
            .unwrap();

        let final_state = graph
            .invoke(Counter::default(), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(final_state, Counter { n: 11 });
    }

    #[tokio::test]
    async fn cycle_terminates_via_state_check() {
        // The node routes back to itself until the counter reaches 5.
        let tick = node_fn::<Counter, _, _>("tick", |s, _c| {
            let done = s.n >= 5;
            async move {
                if done {
                    step!(0, Goto::end())
                } else {
                    step!(1, Goto::node("tick"))
                }
            }
        });
        let graph = Graph::<Counter>::new()
            .node("tick", tick)
            .start_at("tick")
            .compile()
            .unwrap();

        let final_state = graph
            .invoke(Counter::default(), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(final_state, Counter { n: 5 });
    }

    #[tokio::test]
    async fn recursion_limit_is_honored() {
        // A node that always routes to itself must trip the recursion limit.
        let spin = node_fn::<Counter, _, _>("loop", |_s, _c| async move {
            step!(1, Goto::node("loop"))
        });
        let graph = Graph::<Counter>::new()
            .node("loop", spin)
            .start_at("loop")
            .compile()
            .unwrap();

        let limited = RunnableConfig::default().with_recursion_limit(3);
        let outcome = graph.invoke(Counter::default(), limited).await;
        assert!(matches!(
            outcome,
            Err(cognis_core::CognisError::RecursionLimit { limit: 3 })
        ));
    }

    #[tokio::test]
    async fn compiled_graph_clones_and_runs() {
        let only = node_fn::<Counter, _, _>("a", |_s, _c| async move {
            step!(1, Goto::end())
        });
        let original = Graph::<Counter>::new()
            .node("a", only)
            .start_at("a")
            .compile()
            .unwrap();
        let copy = original.clone();

        let from_original = original
            .invoke(Counter::default(), RunnableConfig::default())
            .await
            .unwrap();
        let from_copy = copy
            .invoke(Counter::default(), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(from_original.n, 1);
        assert_eq!(from_copy.n, 1);
    }

    #[tokio::test]
    async fn route_to_unknown_node_errors() {
        // "ghost" is never registered, so routing to it has to fail at run time.
        let bad = node_fn::<Counter, _, _>("bad", |_s, _c| async move {
            step!(0, Goto::node("ghost"))
        });
        let graph = Graph::<Counter>::new()
            .node("bad", bad)
            .start_at("bad")
            .compile()
            .unwrap();

        let err = graph
            .invoke(Counter::default(), RunnableConfig::default())
            .await
            .unwrap_err();
        assert!(err.to_string().contains("ghost"));
    }

    #[tokio::test]
    async fn stream_events_emits_per_node() {
        use cognis_core::Event;
        use futures::StreamExt;

        let first = node_fn::<Counter, _, _>("a", |_, _| async move {
            step!(1, Goto::node("b"))
        });
        let second = node_fn::<Counter, _, _>("b", |_, _| async move {
            step!(1, Goto::end())
        });
        let graph = Graph::<Counter>::new()
            .node("a", first)
            .node("b", second)
            .start_at("a")
            .compile()
            .unwrap();

        let stream = graph
            .stream_events(Counter::default(), RunnableConfig::default())
            .await
            .unwrap();
        // Drain the stream; it ends once the background run has finished.
        let events: Vec<_> = stream.collect().await;

        let started = |name: &str| {
            events
                .iter()
                .any(|e| matches!(e, Event::OnNodeStart { node, .. } if node == name))
        };
        assert!(started("a"));
        assert!(started("b"));
        assert!(events.iter().any(|e| matches!(e, Event::OnEnd { .. })));
    }
}