cognis_graph/checkpoint/
in_memory.rs1use std::collections::HashMap;
5use std::sync::Mutex;
6
7use async_trait::async_trait;
8use uuid::Uuid;
9
10use cognis_core::{CognisError, Result};
11
12use crate::state::GraphState;
13
14use super::Checkpointer;
15
16pub struct InMemoryCheckpointer<S: GraphState + Clone> {
19 runs: Mutex<HashMap<(Uuid, String), HashMap<u64, S>>>,
20 active: Mutex<HashMap<(Uuid, String, u64), Vec<super::ActiveSnapshot>>>,
21 namespace: String,
22}
23
24impl<S: GraphState + Clone> Default for InMemoryCheckpointer<S> {
25 fn default() -> Self {
26 Self::new()
27 }
28}
29
30impl<S: GraphState + Clone> InMemoryCheckpointer<S> {
31 pub fn new() -> Self {
33 Self {
34 runs: Mutex::new(HashMap::new()),
35 active: Mutex::new(HashMap::new()),
36 namespace: String::new(),
37 }
38 }
39
40 pub fn with_namespace(mut self, ns: impl Into<String>) -> Self {
43 self.namespace = ns.into();
44 self
45 }
46}
47
48#[async_trait]
49impl<S: GraphState + Clone> Checkpointer<S> for InMemoryCheckpointer<S> {
50 async fn save(&self, run_id: Uuid, step: u64, state: &S) -> Result<()> {
51 let mut runs = self
52 .runs
53 .lock()
54 .map_err(|e| CognisError::Internal(format!("checkpointer mutex poisoned: {e}")))?;
55 runs.entry((run_id, self.namespace.clone()))
56 .or_default()
57 .insert(step, state.clone());
58 Ok(())
59 }
60
61 async fn load(&self, run_id: Uuid, step: Option<u64>) -> Result<Option<S>> {
62 let runs = self
63 .runs
64 .lock()
65 .map_err(|e| CognisError::Internal(format!("checkpointer mutex poisoned: {e}")))?;
66 let Some(steps) = runs.get(&(run_id, self.namespace.clone())) else {
67 return Ok(None);
68 };
69 match step {
70 Some(s) => Ok(steps.get(&s).cloned()),
71 None => {
72 let max = steps.keys().copied().max();
73 Ok(max.and_then(|s| steps.get(&s).cloned()))
74 }
75 }
76 }
77
78 async fn list(&self, run_id: Uuid) -> Result<Vec<u64>> {
79 let runs = self
80 .runs
81 .lock()
82 .map_err(|e| CognisError::Internal(format!("checkpointer mutex poisoned: {e}")))?;
83 let mut steps: Vec<u64> = runs
84 .get(&(run_id, self.namespace.clone()))
85 .map(|s| s.keys().copied().collect())
86 .unwrap_or_default();
87 steps.sort();
88 Ok(steps)
89 }
90
91 async fn save_active(
92 &self,
93 run_id: Uuid,
94 step: u64,
95 active: &[super::ActiveSnapshot],
96 ) -> Result<()> {
97 let mut a = self
98 .active
99 .lock()
100 .map_err(|e| CognisError::Internal(format!("active mutex poisoned: {e}")))?;
101 a.insert((run_id, self.namespace.clone(), step), active.to_vec());
102 Ok(())
103 }
104
105 async fn load_active(&self, run_id: Uuid, step: u64) -> Result<Vec<super::ActiveSnapshot>> {
106 let a = self
107 .active
108 .lock()
109 .map_err(|e| CognisError::Internal(format!("active mutex poisoned: {e}")))?;
110 Ok(a.get(&(run_id, self.namespace.clone(), step))
111 .cloned()
112 .unwrap_or_default())
113 }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal test state: a counter that updates add to.
    #[derive(Default, Clone, Debug, PartialEq)]
    struct S {
        n: u32,
    }
    #[derive(Default)]
    struct SU {
        n: u32,
    }
    impl GraphState for S {
        type Update = SU;
        fn apply(&mut self, u: Self::Update) {
            self.n += u.n;
        }
    }

    #[tokio::test]
    async fn save_then_load_explicit_step() {
        let cp = InMemoryCheckpointer::<S>::new();
        let id = Uuid::new_v4();
        cp.save(id, 0, &S { n: 1 }).await.unwrap();
        cp.save(id, 1, &S { n: 2 }).await.unwrap();
        cp.save(id, 2, &S { n: 3 }).await.unwrap();

        assert_eq!(cp.load(id, Some(0)).await.unwrap(), Some(S { n: 1 }));
        assert_eq!(cp.load(id, Some(1)).await.unwrap(), Some(S { n: 2 }));
        assert_eq!(cp.load(id, Some(99)).await.unwrap(), None);
    }

    #[tokio::test]
    async fn load_latest_when_step_is_none() {
        let cp = InMemoryCheckpointer::<S>::new();
        let id = Uuid::new_v4();
        cp.save(id, 0, &S { n: 1 }).await.unwrap();
        cp.save(id, 5, &S { n: 9 }).await.unwrap();
        cp.save(id, 2, &S { n: 4 }).await.unwrap();
        assert_eq!(cp.load(id, None).await.unwrap(), Some(S { n: 9 }));
    }

    #[tokio::test]
    async fn list_returns_sorted_steps() {
        let cp = InMemoryCheckpointer::<S>::new();
        let id = Uuid::new_v4();
        for s in [3u64, 1, 4, 1, 5, 9, 2, 6] {
            cp.save(id, s, &S { n: s as u32 }).await.unwrap();
        }
        assert_eq!(cp.list(id).await.unwrap(), vec![1, 2, 3, 4, 5, 6, 9]);
    }

    #[tokio::test]
    async fn unknown_run_returns_empty() {
        let cp = InMemoryCheckpointer::<S>::new();
        let unknown = Uuid::new_v4();
        assert_eq!(cp.load(unknown, None).await.unwrap(), None);
        assert!(cp.list(unknown).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn saving_same_step_twice_keeps_latest() {
        let cp = InMemoryCheckpointer::<S>::new();
        let id = Uuid::new_v4();
        cp.save(id, 3, &S { n: 1 }).await.unwrap();
        cp.save(id, 3, &S { n: 7 }).await.unwrap();
        assert_eq!(cp.load(id, Some(3)).await.unwrap(), Some(S { n: 7 }));
        assert_eq!(cp.list(id).await.unwrap(), vec![3]);
    }

    #[tokio::test]
    async fn namespaced_checkpointer_round_trips() {
        let cp = InMemoryCheckpointer::<S>::new().with_namespace("child");
        let id = Uuid::new_v4();
        cp.save(id, 0, &S { n: 5 }).await.unwrap();
        assert_eq!(cp.load(id, None).await.unwrap(), Some(S { n: 5 }));
        assert_eq!(cp.list(id).await.unwrap(), vec![0]);
    }

    #[tokio::test]
    async fn active_snapshots_default_to_empty() {
        let cp = InMemoryCheckpointer::<S>::new();
        let id = Uuid::new_v4();
        assert!(cp.load_active(id, 0).await.unwrap().is_empty());
        cp.save_active(id, 1, &[]).await.unwrap();
        assert!(cp.load_active(id, 1).await.unwrap().is_empty());
    }
}