1use std::collections::HashMap;
8use std::sync::Arc;
9
10use serde::{Deserialize, Serialize};
11
12use cognis_core::{CognisError, Result};
13
14use crate::builder::Graph;
15use crate::compiled::CompiledGraph;
16use crate::node::Node;
17use crate::state::GraphState;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct GraphSnapshot {
22 pub nodes: Vec<String>,
24 pub edges: Vec<(String, String)>,
26 pub start: Option<String>,
28 pub version: Option<String>,
30}
31
32impl<S: GraphState> CompiledGraph<S> {
33 pub fn snapshot(&self) -> GraphSnapshot {
35 let mut nodes: Vec<String> = self.graph.nodes.keys().cloned().collect();
36 nodes.sort();
37 let mut edges: Vec<(String, String)> = self
38 .graph
39 .edges
40 .iter()
41 .map(|(f, t)| (f.clone(), t.clone()))
42 .collect();
43 edges.sort();
44 GraphSnapshot {
45 nodes,
46 edges,
47 start: self.graph.start.clone(),
48 version: self.graph.version.clone(),
49 }
50 }
51}
52
53pub type NodeFactory<S> = dyn Fn(&str) -> Option<Arc<dyn Node<S>>>;
56
57impl<S: GraphState> Graph<S> {
58 pub fn from_snapshot(snap: &GraphSnapshot, node_factory: &NodeFactory<S>) -> Result<Self> {
62 let mut g = Graph::<S>::new();
63 for name in &snap.nodes {
64 let n = node_factory(name).ok_or_else(|| {
65 CognisError::Configuration(format!(
66 "from_snapshot: factory returned no node for `{name}`"
67 ))
68 })?;
69 g.nodes.insert(name.clone(), n);
70 }
71 let mut edges: HashMap<String, String> = HashMap::new();
72 for (f, t) in &snap.edges {
73 edges.insert(f.clone(), t.clone());
74 }
75 g.edges = edges;
76 g.start = snap.start.clone();
77 g.version = snap.version.clone();
78 Ok(g)
79 }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use crate::goto::Goto;
    use crate::node::{node_fn, NodeOut};

    #[derive(Default, Clone, Debug, PartialEq)]
    struct S {
        n: u32,
    }
    #[derive(Default)]
    struct SU {
        n: u32,
    }
    impl GraphState for S {
        type Update = SU;
        fn apply(&mut self, u: Self::Update) {
            self.n += u.n;
        }
    }

    #[test]
    fn snapshot_roundtrip_via_factory() {
        let g = Graph::<S>::new()
            .node(
                "a",
                node_fn::<S, _, _>("a", |_, _| async {
                    Ok(NodeOut {
                        update: SU { n: 1 },
                        goto: Goto::end(),
                    })
                }),
            )
            .start_at("a")
            .compile()
            .unwrap();
        let snap = g.snapshot();
        assert_eq!(snap.nodes, vec!["a"]);
        assert_eq!(snap.start.as_deref(), Some("a"));

        let factory = |name: &str| -> Option<Arc<dyn Node<S>>> {
            if name == "a" {
                Some(Arc::new(node_fn::<S, _, _>("a", |_, _| async {
                    Ok(NodeOut {
                        update: SU { n: 1 },
                        goto: Goto::end(),
                    })
                })))
            } else {
                None
            }
        };
        let g2 = Graph::<S>::from_snapshot(&snap, &factory).unwrap();
        let snap2 = g2.compile().unwrap().snapshot();
        // Check every field so a field dropped by the round trip is caught.
        assert_eq!(snap2.nodes, snap.nodes);
        assert_eq!(snap2.edges, snap.edges);
        assert_eq!(snap2.start, snap.start);
        assert_eq!(snap2.version, snap.version);
    }

    #[test]
    fn from_snapshot_fails_when_factory_has_no_node() {
        let snap = GraphSnapshot {
            nodes: vec!["missing".to_string()],
            edges: Vec::new(),
            start: Some("missing".to_string()),
            version: None,
        };
        let factory = |_: &str| -> Option<Arc<dyn Node<S>>> { None };
        let res = Graph::<S>::from_snapshot(&snap, &factory);
        assert!(matches!(res, Err(CognisError::Configuration(_))));
    }
}