Skip to main content

khive_fold/
checkpoint.rs

1//! Generic checkpoint envelope and in-memory store for fold-managed indexes.
2
3use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5
6use chrono::{DateTime, Utc};
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Serialize};
9use uuid::Uuid;
10
11use khive_types::Hash32;
12
13use crate::context::FoldContext;
14use crate::error::FoldError;
15
16/// Generic checkpoint envelope wrapping an arbitrary fold state snapshot.
17#[derive(Debug, Clone)]
18#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
19pub struct Checkpoint<S> {
20    /// Human-readable checkpoint identifier (e.g. `"hnsw_idx:ckpt-1"`).
21    pub id: String,
22
23    /// The snapshot state captured at this checkpoint.
24    pub state: S,
25
26    /// Unique identifier for this checkpoint instance.
27    pub uuid: Uuid,
28
29    /// BLAKE3 content hash of the state; verified on load.
30    pub hash: Hash32,
31
32    /// Number of entries processed when this checkpoint was taken.
33    pub entries_processed: usize,
34
35    /// Fold context at checkpoint time.
36    pub context: FoldContext,
37
38    /// Monotonically increasing fold schema version.
39    pub fold_version: usize,
40
41    /// Wall-clock time when this checkpoint was created.
42    pub created_at: DateTime<Utc>,
43}
44
45impl<S: Serialize> Checkpoint<S> {
46    /// Create a new checkpoint, computing the BLAKE3 hash of the state.
47    // REASON: Checkpoint::new requires id, state, uuid, entries_processed, context, and
48    // fold_version — each is a semantically distinct field with no natural grouping into
49    // a builder or sub-struct without breaking the public API.
50    #[allow(clippy::too_many_arguments)]
51    pub fn new(
52        id: impl Into<String>,
53        state: S,
54        uuid: Uuid,
55        entries_processed: usize,
56        context: FoldContext,
57        fold_version: usize,
58    ) -> Result<Self, FoldError> {
59        let bytes = serde_json::to_vec(&state)?;
60        let hash = Hash32::from_blake3(&bytes);
61        Ok(Self {
62            id: id.into(),
63            state,
64            uuid,
65            hash,
66            entries_processed,
67            context,
68            fold_version,
69            // Foundation layer does not call Utc::now() — epoch is the safe default.
70            // Callers that need the current time should set created_at after construction.
71            created_at: DateTime::<Utc>::default(),
72        })
73    }
74
75    /// Create a checkpoint with a pre-computed hash (for deserialization / testing).
76    // REASON: with_hash mirrors the new() parameter set (minus auto-computed hash) for
77    // deserialization and testing; same structural constraint as new() above.
78    #[allow(clippy::too_many_arguments)]
79    pub fn with_hash(
80        id: impl Into<String>,
81        state: S,
82        uuid: Uuid,
83        hash: Hash32,
84        entries_processed: usize,
85        context: FoldContext,
86        fold_version: usize,
87    ) -> Self {
88        Self {
89            id: id.into(),
90            state,
91            uuid,
92            hash,
93            entries_processed,
94            context,
95            fold_version,
96            // Foundation layer does not call Utc::now() — epoch is the safe default.
97            created_at: DateTime::<Utc>::default(),
98        }
99    }
100}
101
102/// Trait for checkpoint persistence backends.
103pub trait CheckpointStore<S> {
104    /// Persist a checkpoint, computing and storing an integrity hash.
105    fn save(&self, checkpoint: Checkpoint<S>) -> Result<(), FoldError>
106    where
107        S: Clone + Serialize;
108
109    /// Load a checkpoint by exact `id`, verifying the integrity hash.
110    fn load(&self, id: &str) -> Result<Option<Checkpoint<S>>, FoldError>
111    where
112        S: Clone + Serialize;
113
114    /// Load the most recently created checkpoint whose `id` starts with `prefix`.
115    fn load_latest(&self, prefix: &str) -> Result<Option<Checkpoint<S>>, FoldError>
116    where
117        S: Clone + Serialize;
118
119    /// Delete the checkpoint with the given `id`.
120    fn delete(&self, id: &str) -> Result<(), FoldError>;
121
122    /// List all checkpoint `id` strings currently stored.
123    fn list(&self) -> Result<Vec<String>, FoldError>;
124}
125
126/// In-memory checkpoint store backed by a `RwLock<HashMap>`.
127pub struct InMemoryCheckpointStore<S> {
128    inner: Arc<RwLock<HashMap<String, Checkpoint<S>>>>,
129}
130
131impl<S> InMemoryCheckpointStore<S> {
132    /// Create a new empty in-memory store.
133    pub fn new() -> Self {
134        Self {
135            inner: Arc::new(RwLock::new(HashMap::new())),
136        }
137    }
138}
139
140impl<S> Default for InMemoryCheckpointStore<S> {
141    fn default() -> Self {
142        Self::new()
143    }
144}
145
146impl<S: Clone + Send + Sync + Serialize + 'static> CheckpointStore<S>
147    for InMemoryCheckpointStore<S>
148{
149    fn save(&self, checkpoint: Checkpoint<S>) -> Result<(), FoldError>
150    where
151        S: Clone + Serialize,
152    {
153        // Recompute the hash from the state to ensure the stored hash is canonical.
154        let bytes = serde_json::to_vec(&checkpoint.state)?;
155        let computed = Hash32::from_blake3(&bytes);
156        let mut stored = checkpoint;
157        stored.hash = computed;
158
159        let mut guard = self
160            .inner
161            .write()
162            .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
163        guard.insert(stored.id.clone(), stored);
164        Ok(())
165    }
166
167    fn load(&self, id: &str) -> Result<Option<Checkpoint<S>>, FoldError>
168    where
169        S: Clone + Serialize,
170    {
171        let guard = self
172            .inner
173            .read()
174            .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
175        let Some(checkpoint) = guard.get(id).cloned() else {
176            return Ok(None);
177        };
178
179        // Verify integrity: recompute hash from state and compare.
180        let bytes = serde_json::to_vec(&checkpoint.state)?;
181        let computed = Hash32::from_blake3(&bytes);
182        if !checkpoint.hash.eq_ct(&computed) {
183            return Err(FoldError::IntegrityMismatch {
184                id: id.to_owned(),
185                stored: checkpoint.hash.to_string(),
186                computed: computed.to_string(),
187            });
188        }
189
190        Ok(Some(checkpoint))
191    }
192
193    fn load_latest(&self, prefix: &str) -> Result<Option<Checkpoint<S>>, FoldError>
194    where
195        S: Clone + Serialize,
196    {
197        let guard = self
198            .inner
199            .read()
200            .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
201
202        let latest = guard
203            .values()
204            .filter(|c| c.id.starts_with(prefix))
205            // Tiebreak on uuid for determinism when created_at is equal.
206            .max_by_key(|c| (c.created_at, c.uuid));
207
208        Ok(latest.cloned())
209    }
210
211    fn delete(&self, id: &str) -> Result<(), FoldError> {
212        let mut guard = self
213            .inner
214            .write()
215            .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
216        if guard.remove(id).is_none() {
217            return Err(FoldError::CheckpointNotFound(id.to_owned()));
218        }
219        Ok(())
220    }
221
222    fn list(&self) -> Result<Vec<String>, FoldError> {
223        let guard = self
224            .inner
225            .read()
226            .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
227        Ok(guard.keys().cloned().collect())
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// In-memory store specialised to `String` state, as used by the tests below.
    type StringStore = InMemoryCheckpointStore<String>;

    /// Build a checkpoint named `id` whose state is `"state-<entries>"`, with a random
    /// uuid, a fresh fold context and fold version 1.
    fn sample_checkpoint(id: &str, entries: usize) -> Checkpoint<String> {
        let state = format!("state-{entries}");
        Checkpoint::new(id, state, Uuid::new_v4(), entries, FoldContext::new(), 1)
            .expect("sample_checkpoint should not fail serialization")
    }

    #[test]
    fn save_and_load_roundtrip() {
        let store = StringStore::new();
        store
            .save(sample_checkpoint("my-index:ckpt-1", 100))
            .unwrap();

        let loaded = store.load("my-index:ckpt-1").unwrap().unwrap();
        assert_eq!(loaded.state, "state-100");
        assert_eq!(loaded.entries_processed, 100);
    }

    #[test]
    fn load_missing_returns_none() {
        let store = StringStore::new();
        let missing = store.load("nonexistent").unwrap();
        assert!(missing.is_none());
    }

    #[test]
    fn load_latest_returns_most_recent() {
        use chrono::Duration;

        let store = StringStore::new();
        let base = DateTime::<Utc>::default();

        // Each checkpoint gets an explicit, strictly increasing created_at so the
        // outcome of load_latest never depends on wall-clock time.
        let fixtures = [
            ("idx:ckpt-1", 10usize, 0i64),
            ("idx:ckpt-2", 20, 5),
            ("idx:ckpt-3", 30, 10),
        ];
        for (id, entries, offset_ms) in fixtures {
            let mut ckpt = sample_checkpoint(id, entries);
            ckpt.created_at = base + Duration::milliseconds(offset_ms);
            store.save(ckpt).unwrap();
        }

        let latest = store.load_latest("idx").unwrap().unwrap();
        assert_eq!(latest.entries_processed, 30);
    }

    #[test]
    fn load_latest_no_match_returns_none() {
        let store = StringStore::new();
        store.save(sample_checkpoint("other:ckpt-1", 5)).unwrap();

        let result = store.load_latest("my-index").unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn load_latest_prefix_isolation() {
        let store = StringStore::new();
        for (id, entries) in [("alpha:ckpt-1", 10), ("beta:ckpt-1", 999)] {
            store.save(sample_checkpoint(id, entries)).unwrap();
        }

        let latest_alpha = store.load_latest("alpha").unwrap().unwrap();
        assert_eq!(latest_alpha.entries_processed, 10);
    }

    #[test]
    fn checkpoint_fields_accessible() {
        let uuid = Uuid::new_v4();
        let ckpt: Checkpoint<u32> =
            Checkpoint::new("test:ckpt", 42u32, uuid, 7, FoldContext::new(), 3).unwrap();
        assert_eq!(ckpt.state, 42);
        assert_eq!(ckpt.entries_processed, 7);
        assert_eq!(ckpt.fold_version, 3);
    }

    // --- Additional tests (F-NEW-8) ---

    #[cfg(feature = "serde")]
    #[test]
    fn serde_roundtrip() {
        let original = sample_checkpoint("serde:test", 42);
        let json = serde_json::to_string(&original).expect("serialize");
        let restored: Checkpoint<String> = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(original.id, restored.id);
        assert_eq!(original.state, restored.state);
        assert_eq!(original.entries_processed, restored.entries_processed);
        assert_eq!(original.fold_version, restored.fold_version);
        assert_eq!(original.uuid, restored.uuid);
        // The hash bytes must come back exactly as they went in.
        assert_eq!(original.hash.as_bytes(), restored.hash.as_bytes());
    }

    #[test]
    fn delete_existing_succeeds() {
        let store = StringStore::new();
        store.save(sample_checkpoint("del:ckpt-1", 1)).unwrap();
        store.delete("del:ckpt-1").unwrap();

        let after = store.load("del:ckpt-1").unwrap();
        assert!(after.is_none());
    }

    #[test]
    fn delete_nonexistent_returns_not_found() {
        let store = StringStore::new();
        let err = store.delete("nope").unwrap_err();
        assert!(
            matches!(err, FoldError::CheckpointNotFound(ref id) if id == "nope"),
            "expected CheckpointNotFound, got {err:?}"
        );
    }

    #[test]
    fn list_returns_all_ids() {
        let store = StringStore::new();
        for (id, entries) in [("a:ckpt-1", 1), ("b:ckpt-1", 2), ("c:ckpt-1", 3)] {
            store.save(sample_checkpoint(id, entries)).unwrap();
        }

        // The store does not promise an order, so sort before comparing.
        let mut ids = store.list().unwrap();
        ids.sort();
        assert_eq!(ids, vec!["a:ckpt-1", "b:ckpt-1", "c:ckpt-1"]);
    }

    #[test]
    fn list_empty_store() {
        let store = StringStore::new();
        let ids = store.list().unwrap();
        assert!(ids.is_empty());
    }

    #[test]
    fn save_overwrite_replaces_previous() {
        let store = StringStore::new();
        store
            .save(sample_checkpoint("overwrite:ckpt-1", 10))
            .unwrap();

        // Second save reuses the id but carries different state and metadata.
        let replacement = Checkpoint::new(
            "overwrite:ckpt-1",
            "new-state".to_string(),
            Uuid::new_v4(),
            99,
            FoldContext::new(),
            2,
        )
        .unwrap();
        store.save(replacement).unwrap();

        let loaded = store.load("overwrite:ckpt-1").unwrap().unwrap();
        assert_eq!(loaded.state, "new-state");
        assert_eq!(loaded.entries_processed, 99);

        // The id must appear exactly once in the listing.
        let occurrences = store
            .list()
            .unwrap()
            .iter()
            .filter(|id| id.as_str() == "overwrite:ckpt-1")
            .count();
        assert_eq!(occurrences, 1);
    }

    #[test]
    fn integrity_mismatch_on_corrupted_hash() {
        let store = StringStore::new();
        store
            .save(sample_checkpoint("integrity:ckpt-1", 5))
            .unwrap();

        // Reach into the map and overwrite the stored hash with ZERO. The write guard
        // is a temporary of this statement and is released before `load` runs.
        if let Some(entry) = store.inner.write().unwrap().get_mut("integrity:ckpt-1") {
            entry.hash = Hash32::ZERO;
        }

        let err = store.load("integrity:ckpt-1").unwrap_err();
        assert!(
            matches!(err, FoldError::IntegrityMismatch { .. }),
            "expected IntegrityMismatch, got {err:?}"
        );
    }

    #[test]
    fn concurrent_saves_all_land() {
        use std::sync::Arc;
        use std::thread;

        let store = Arc::new(InMemoryCheckpointStore::<String>::new());
        let n = 20usize;

        // Spawn every writer first, then join them all.
        let mut handles = Vec::with_capacity(n);
        for i in 0..n {
            let shared = Arc::clone(&store);
            handles.push(thread::spawn(move || {
                shared
                    .save(sample_checkpoint(&format!("concurrent:ckpt-{i}"), i))
                    .unwrap();
            }));
        }
        for handle in handles {
            handle.join().expect("thread panicked");
        }

        let ids = store.list().unwrap();
        assert_eq!(ids.len(), n, "expected {n} checkpoints, got {}", ids.len());
    }
}
434}