Skip to main content

cognis_graph/
snapshot.rs

1//! Compiled-graph snapshot — the *shape* of a graph (not its node closures).
2//!
3//! Use cases: schema-evolution checks, cross-version comparison, generated
4//! docs, "spec" files in CI. `from_snapshot` rebuilds via a caller-supplied
5//! `node_factory` because the closures themselves aren't serializable.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use serde::{Deserialize, Serialize};
11
12use cognis_core::{CognisError, Result};
13
14use crate::builder::Graph;
15use crate::compiled::CompiledGraph;
16use crate::node::Node;
17use crate::state::GraphState;
18
19/// Serializable shape of a `CompiledGraph<S>`.
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct GraphSnapshot {
22    /// Registered node names (alphabetical).
23    pub nodes: Vec<String>,
24    /// Static `(from, to)` edges declared via `Graph::edge`.
25    pub edges: Vec<(String, String)>,
26    /// Start node, if set.
27    pub start: Option<String>,
28    /// Optional version tag (see [`crate::Graph::with_version`]).
29    pub version: Option<String>,
30}
31
32impl<S: GraphState> CompiledGraph<S> {
33    /// Snapshot the graph's static shape.
34    pub fn snapshot(&self) -> GraphSnapshot {
35        let mut nodes: Vec<String> = self.graph.nodes.keys().cloned().collect();
36        nodes.sort();
37        let mut edges: Vec<(String, String)> = self
38            .graph
39            .edges
40            .iter()
41            .map(|(f, t)| (f.clone(), t.clone()))
42            .collect();
43        edges.sort();
44        GraphSnapshot {
45            nodes,
46            edges,
47            start: self.graph.start.clone(),
48            version: self.graph.version.clone(),
49        }
50    }
51}
52
53/// Factory function used by [`Graph::from_snapshot`] — looks up a node
54/// implementation by name.
55pub type NodeFactory<S> = dyn Fn(&str) -> Option<Arc<dyn Node<S>>>;
56
57impl<S: GraphState> Graph<S> {
58    /// Reconstruct a `Graph<S>` from a [`GraphSnapshot`] using a factory
59    /// to produce node implementations by name. The factory must return
60    /// a node for every name in `snap.nodes`; otherwise this errors.
61    pub fn from_snapshot(snap: &GraphSnapshot, node_factory: &NodeFactory<S>) -> Result<Self> {
62        let mut g = Graph::<S>::new();
63        for name in &snap.nodes {
64            let n = node_factory(name).ok_or_else(|| {
65                CognisError::Configuration(format!(
66                    "from_snapshot: factory returned no node for `{name}`"
67                ))
68            })?;
69            g.nodes.insert(name.clone(), n);
70        }
71        let mut edges: HashMap<String, String> = HashMap::new();
72        for (f, t) in &snap.edges {
73            edges.insert(f.clone(), t.clone());
74        }
75        g.edges = edges;
76        g.start = snap.start.clone();
77        g.version = snap.version.clone();
78        Ok(g)
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use crate::goto::Goto;
    use crate::node::{node_fn, NodeOut};

    /// Minimal state: a counter that updates add into.
    #[derive(Default, Clone, Debug, PartialEq)]
    struct S {
        n: u32,
    }

    /// Additive update for [`S`].
    #[derive(Default)]
    struct SU {
        n: u32,
    }

    impl GraphState for S {
        type Update = SU;
        fn apply(&mut self, u: Self::Update) {
            self.n += u.n;
        }
    }

    #[test]
    fn snapshot_roundtrip_via_factory() {
        let g = Graph::<S>::new()
            .node(
                "a",
                node_fn::<S, _, _>("a", |_, _| async {
                    Ok(NodeOut {
                        update: SU { n: 1 },
                        goto: Goto::end(),
                    })
                }),
            )
            .start_at("a")
            .compile()
            .unwrap();
        let snap = g.snapshot();
        assert_eq!(snap.nodes, vec!["a"]);
        assert_eq!(snap.start.as_deref(), Some("a"));

        let factory = |name: &str| -> Option<Arc<dyn Node<S>>> {
            if name == "a" {
                Some(Arc::new(node_fn::<S, _, _>("a", |_, _| async {
                    Ok(NodeOut {
                        update: SU { n: 1 },
                        goto: Goto::end(),
                    })
                })))
            } else {
                None
            }
        };
        let g2 = Graph::<S>::from_snapshot(&snap, &factory).unwrap();
        let snap2 = g2.compile().unwrap().snapshot();
        // Every field of the shape must survive the roundtrip, not just nodes/start.
        assert_eq!(snap2.nodes, snap.nodes);
        assert_eq!(snap2.edges, snap.edges);
        assert_eq!(snap2.start, snap.start);
        assert_eq!(snap2.version, snap.version);
    }

    #[test]
    fn from_snapshot_errors_when_factory_lacks_a_node() {
        let snap = GraphSnapshot {
            nodes: vec!["a".to_string()],
            edges: Vec::new(),
            start: Some("a".to_string()),
            version: None,
        };
        let factory = |_: &str| -> Option<Arc<dyn Node<S>>> { None };
        assert!(Graph::<S>::from_snapshot(&snap, &factory).is_err());
    }
}
139}