Skip to main content

cognis_graph/checkpoint/
in_memory.rs

//! In-process checkpointer, useful for tests and short-lived processes.
//! Checkpoints live only in memory and are NOT durable across restarts.
3
4use std::collections::HashMap;
5use std::sync::Mutex;
6
7use async_trait::async_trait;
8use uuid::Uuid;
9
10use cognis_core::{CognisError, Result};
11
12use crate::state::GraphState;
13
14use super::Checkpointer;
15
16/// Stores `(run_id, namespace, step) -> S` in a Mutex-protected HashMap.
17/// State must be `Clone` because saving stores a clone and loading returns one.
18pub struct InMemoryCheckpointer<S: GraphState + Clone> {
19    runs: Mutex<HashMap<(Uuid, String), HashMap<u64, S>>>,
20    active: Mutex<HashMap<(Uuid, String, u64), Vec<super::ActiveSnapshot>>>,
21    namespace: String,
22}
23
24impl<S: GraphState + Clone> Default for InMemoryCheckpointer<S> {
25    fn default() -> Self {
26        Self::new()
27    }
28}
29
30impl<S: GraphState + Clone> InMemoryCheckpointer<S> {
31    /// Empty checkpointer.
32    pub fn new() -> Self {
33        Self {
34            runs: Mutex::new(HashMap::new()),
35            active: Mutex::new(HashMap::new()),
36            namespace: String::new(),
37        }
38    }
39
40    /// Set the namespace for subgraph isolation. Operations on this instance
41    /// scope to `(run_id, namespace, step)`.
42    pub fn with_namespace(mut self, ns: impl Into<String>) -> Self {
43        self.namespace = ns.into();
44        self
45    }
46}
47
48#[async_trait]
49impl<S: GraphState + Clone> Checkpointer<S> for InMemoryCheckpointer<S> {
50    async fn save(&self, run_id: Uuid, step: u64, state: &S) -> Result<()> {
51        let mut runs = self
52            .runs
53            .lock()
54            .map_err(|e| CognisError::Internal(format!("checkpointer mutex poisoned: {e}")))?;
55        runs.entry((run_id, self.namespace.clone()))
56            .or_default()
57            .insert(step, state.clone());
58        Ok(())
59    }
60
61    async fn load(&self, run_id: Uuid, step: Option<u64>) -> Result<Option<S>> {
62        let runs = self
63            .runs
64            .lock()
65            .map_err(|e| CognisError::Internal(format!("checkpointer mutex poisoned: {e}")))?;
66        let Some(steps) = runs.get(&(run_id, self.namespace.clone())) else {
67            return Ok(None);
68        };
69        match step {
70            Some(s) => Ok(steps.get(&s).cloned()),
71            None => {
72                let max = steps.keys().copied().max();
73                Ok(max.and_then(|s| steps.get(&s).cloned()))
74            }
75        }
76    }
77
78    async fn list(&self, run_id: Uuid) -> Result<Vec<u64>> {
79        let runs = self
80            .runs
81            .lock()
82            .map_err(|e| CognisError::Internal(format!("checkpointer mutex poisoned: {e}")))?;
83        let mut steps: Vec<u64> = runs
84            .get(&(run_id, self.namespace.clone()))
85            .map(|s| s.keys().copied().collect())
86            .unwrap_or_default();
87        steps.sort();
88        Ok(steps)
89    }
90
91    async fn save_active(
92        &self,
93        run_id: Uuid,
94        step: u64,
95        active: &[super::ActiveSnapshot],
96    ) -> Result<()> {
97        let mut a = self
98            .active
99            .lock()
100            .map_err(|e| CognisError::Internal(format!("active mutex poisoned: {e}")))?;
101        a.insert((run_id, self.namespace.clone(), step), active.to_vec());
102        Ok(())
103    }
104
105    async fn load_active(&self, run_id: Uuid, step: u64) -> Result<Vec<super::ActiveSnapshot>> {
106        let a = self
107            .active
108            .lock()
109            .map_err(|e| CognisError::Internal(format!("active mutex poisoned: {e}")))?;
110        Ok(a.get(&(run_id, self.namespace.clone(), step))
111            .cloned()
112            .unwrap_or_default())
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal `GraphState` fixture: a counter whose updates are added to it.
    #[derive(Default, Clone, Debug, PartialEq)]
    struct S {
        n: u32,
    }
    #[derive(Default)]
    struct SU {
        n: u32,
    }
    impl GraphState for S {
        type Update = SU;
        fn apply(&mut self, u: Self::Update) {
            self.n += u.n;
        }
    }

    #[tokio::test]
    async fn save_then_load_explicit_step() {
        let cp = InMemoryCheckpointer::<S>::new();
        let id = Uuid::new_v4();
        cp.save(id, 0, &S { n: 1 }).await.unwrap();
        cp.save(id, 1, &S { n: 2 }).await.unwrap();
        cp.save(id, 2, &S { n: 3 }).await.unwrap();

        assert_eq!(cp.load(id, Some(0)).await.unwrap(), Some(S { n: 1 }));
        assert_eq!(cp.load(id, Some(1)).await.unwrap(), Some(S { n: 2 }));
        assert_eq!(cp.load(id, Some(99)).await.unwrap(), None);
    }

    #[tokio::test]
    async fn load_latest_when_step_is_none() {
        let cp = InMemoryCheckpointer::<S>::new();
        let id = Uuid::new_v4();
        cp.save(id, 0, &S { n: 1 }).await.unwrap();
        cp.save(id, 5, &S { n: 9 }).await.unwrap();
        // Inserted last but lower-numbered: "latest" is by step, not insertion.
        cp.save(id, 2, &S { n: 4 }).await.unwrap();
        assert_eq!(cp.load(id, None).await.unwrap(), Some(S { n: 9 }));
    }

    #[tokio::test]
    async fn list_returns_sorted_steps() {
        let cp = InMemoryCheckpointer::<S>::new();
        let id = Uuid::new_v4();
        // Step 1 appears twice; it must be listed only once.
        for s in [3u64, 1, 4, 1, 5, 9, 2, 6] {
            cp.save(id, s, &S { n: s as u32 }).await.unwrap();
        }
        assert_eq!(cp.list(id).await.unwrap(), vec![1, 2, 3, 4, 5, 6, 9]);
    }

    #[tokio::test]
    async fn unknown_run_returns_empty() {
        let cp = InMemoryCheckpointer::<S>::new();
        let unknown = Uuid::new_v4();
        assert_eq!(cp.load(unknown, None).await.unwrap(), None);
        assert!(cp.list(unknown).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn saving_same_step_overwrites() {
        let cp = InMemoryCheckpointer::<S>::new();
        let id = Uuid::new_v4();
        cp.save(id, 1, &S { n: 1 }).await.unwrap();
        cp.save(id, 1, &S { n: 2 }).await.unwrap();
        assert_eq!(cp.load(id, Some(1)).await.unwrap(), Some(S { n: 2 }));
        assert_eq!(cp.list(id).await.unwrap(), vec![1]);
    }

    #[tokio::test]
    async fn namespace_scopes_checkpoints() {
        let id = Uuid::new_v4();
        let root = InMemoryCheckpointer::<S>::new();
        root.save(id, 0, &S { n: 1 }).await.unwrap();

        // Re-scoping must hide entries written under the previous namespace.
        let sub = root.with_namespace("sub");
        assert_eq!(sub.load(id, None).await.unwrap(), None);
        assert!(sub.list(id).await.unwrap().is_empty());

        sub.save(id, 3, &S { n: 7 }).await.unwrap();
        assert_eq!(sub.load(id, None).await.unwrap(), Some(S { n: 7 }));
        assert_eq!(sub.list(id).await.unwrap(), vec![3]);
    }

    #[tokio::test]
    async fn load_active_without_save_is_empty() {
        let cp = InMemoryCheckpointer::<S>::new();
        assert!(cp.load_active(Uuid::new_v4(), 0).await.unwrap().is_empty());
    }
}