Skip to main content

khive_fold/
checkpoint.rs

1//! Generic checkpoint envelope and in-memory store for fold-managed indexes.
2
3use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5
6use chrono::{DateTime, Utc};
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Serialize};
9use uuid::Uuid;
10
11use khive_types::Hash32;
12
13use crate::context::FoldContext;
14use crate::error::FoldError;
15
16/// Generic checkpoint envelope wrapping an arbitrary fold state snapshot.
17#[derive(Debug, Clone)]
18#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
19pub struct Checkpoint<S> {
20    /// Human-readable checkpoint identifier (e.g. `"hnsw_idx:ckpt-1"`).
21    pub id: String,
22
23    /// The snapshot state captured at this checkpoint.
24    pub state: S,
25
26    /// Unique identifier for this checkpoint instance.
27    pub uuid: Uuid,
28
29    /// BLAKE3 content hash of the state; verified on load.
30    pub hash: Hash32,
31
32    /// Number of entries processed when this checkpoint was taken.
33    pub entries_processed: usize,
34
35    /// Fold context at checkpoint time.
36    pub context: FoldContext,
37
38    /// Monotonically increasing fold schema version.
39    pub fold_version: usize,
40
41    /// Wall-clock time when this checkpoint was created.
42    pub created_at: DateTime<Utc>,
43}
44
45impl<S: Serialize> Checkpoint<S> {
46    /// Create a new checkpoint, computing the BLAKE3 hash of the state.
47    // REASON: Checkpoint::new requires id, state, uuid, entries_processed, context, and
48    // fold_version — each is a semantically distinct field with no natural grouping into
49    // a builder or sub-struct without breaking the public API.
50    #[allow(clippy::too_many_arguments)]
51    pub fn new(
52        id: impl Into<String>,
53        state: S,
54        uuid: Uuid,
55        entries_processed: usize,
56        context: FoldContext,
57        fold_version: usize,
58    ) -> Result<Self, FoldError> {
59        let bytes = serde_json::to_vec(&state)?;
60        let hash = Hash32::from_blake3(&bytes);
61        Ok(Self {
62            id: id.into(),
63            state,
64            uuid,
65            hash,
66            entries_processed,
67            context,
68            fold_version,
69            // Foundation layer does not call Utc::now() — epoch is the safe default.
70            // Callers that need the current time should set created_at after construction.
71            created_at: DateTime::<Utc>::default(),
72        })
73    }
74
75    /// Create a checkpoint with a pre-computed hash (for deserialization / testing).
76    // REASON: with_hash mirrors the new() parameter set (minus auto-computed hash) for
77    // deserialization and testing; same structural constraint as new() above.
78    #[allow(clippy::too_many_arguments)]
79    pub fn with_hash(
80        id: impl Into<String>,
81        state: S,
82        uuid: Uuid,
83        hash: Hash32,
84        entries_processed: usize,
85        context: FoldContext,
86        fold_version: usize,
87    ) -> Self {
88        Self {
89            id: id.into(),
90            state,
91            uuid,
92            hash,
93            entries_processed,
94            context,
95            fold_version,
96            // Foundation layer does not call Utc::now() — epoch is the safe default.
97            created_at: DateTime::<Utc>::default(),
98        }
99    }
100}
101
102/// Trait for checkpoint persistence backends.
103pub trait CheckpointStore<S> {
104    /// Persist a checkpoint, computing and storing an integrity hash.
105    fn save(&self, checkpoint: Checkpoint<S>) -> Result<(), FoldError>
106    where
107        S: Clone + Serialize;
108
109    /// Load a checkpoint by exact `id`, verifying the integrity hash.
110    fn load(&self, id: &str) -> Result<Option<Checkpoint<S>>, FoldError>
111    where
112        S: Clone + Serialize;
113
114    /// Load the most recently created checkpoint whose `id` starts with `prefix`.
115    fn load_latest(&self, prefix: &str) -> Result<Option<Checkpoint<S>>, FoldError>
116    where
117        S: Clone + Serialize;
118
119    /// Delete the checkpoint with the given `id`.
120    fn delete(&self, id: &str) -> Result<(), FoldError>;
121
122    /// List all checkpoint `id` strings currently stored.
123    fn list(&self) -> Result<Vec<String>, FoldError>;
124}
125
126/// In-memory checkpoint store backed by a `RwLock<HashMap>`.
127pub struct InMemoryCheckpointStore<S> {
128    inner: Arc<RwLock<HashMap<String, Checkpoint<S>>>>,
129}
130
131impl<S> InMemoryCheckpointStore<S> {
132    /// Create a new empty in-memory store.
133    pub fn new() -> Self {
134        Self {
135            inner: Arc::new(RwLock::new(HashMap::new())),
136        }
137    }
138}
139
140impl<S> Default for InMemoryCheckpointStore<S> {
141    fn default() -> Self {
142        Self::new()
143    }
144}
145
146impl<S: Clone + Send + Sync + Serialize + 'static> CheckpointStore<S>
147    for InMemoryCheckpointStore<S>
148{
149    fn save(&self, checkpoint: Checkpoint<S>) -> Result<(), FoldError>
150    where
151        S: Clone + Serialize,
152    {
153        // Recompute the hash from the state to ensure the stored hash is canonical.
154        let bytes = serde_json::to_vec(&checkpoint.state)?;
155        let computed = Hash32::from_blake3(&bytes);
156        let mut stored = checkpoint;
157        stored.hash = computed;
158
159        let mut guard = self
160            .inner
161            .write()
162            .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
163        guard.insert(stored.id.clone(), stored);
164        Ok(())
165    }
166
167    fn load(&self, id: &str) -> Result<Option<Checkpoint<S>>, FoldError>
168    where
169        S: Clone + Serialize,
170    {
171        let guard = self
172            .inner
173            .read()
174            .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
175        let Some(checkpoint) = guard.get(id).cloned() else {
176            return Ok(None);
177        };
178
179        // Verify integrity: recompute hash from state and compare.
180        let bytes = serde_json::to_vec(&checkpoint.state)?;
181        let computed = Hash32::from_blake3(&bytes);
182        if !checkpoint.hash.eq_ct(&computed) {
183            return Err(FoldError::IntegrityMismatch {
184                id: id.to_owned(),
185                stored: checkpoint.hash.to_string(),
186                computed: computed.to_string(),
187            });
188        }
189
190        Ok(Some(checkpoint))
191    }
192
193    fn load_latest(&self, prefix: &str) -> Result<Option<Checkpoint<S>>, FoldError>
194    where
195        S: Clone + Serialize,
196    {
197        let guard = self
198            .inner
199            .read()
200            .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
201
202        let latest = guard
203            .values()
204            .filter(|c| c.id.starts_with(prefix))
205            // Tiebreak on uuid for determinism when created_at is equal.
206            .max_by_key(|c| (c.created_at, c.uuid));
207
208        Ok(latest.cloned())
209    }
210
211    fn delete(&self, id: &str) -> Result<(), FoldError> {
212        let mut guard = self
213            .inner
214            .write()
215            .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
216        if guard.remove(id).is_none() {
217            return Err(FoldError::CheckpointNotFound(id.to_owned()));
218        }
219        Ok(())
220    }
221
222    fn list(&self) -> Result<Vec<String>, FoldError> {
223        let guard = self
224            .inner
225            .read()
226            .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
227        let keys: Vec<String> = guard.keys().cloned().collect();
228        Ok(sort_checkpoint_keys(keys))
229    }
230}
231
/// Return `keys` rearranged into ascending lexicographic order (byte-wise `String` `Ord`).
///
/// This is a standalone public helper so that the ordering used by `list()` can be
/// unit-tested directly with deliberately unsorted input, independent of `HashMap`
/// iteration order.
pub fn sort_checkpoint_keys(keys: Vec<String>) -> Vec<String> {
    let mut ordered = keys;
    // Strings that compare equal have identical contents, so sort stability is moot and
    // the non-allocating unstable sort yields the same sequence.
    ordered.sort_unstable();
    ordered
}
241
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `Checkpoint<String>` whose state is `"state-{entries}"`. It gets a random
    /// uuid, fold_version 1, and the constructor's default `created_at`.
    fn sample_checkpoint(id: &str, entries: usize) -> Checkpoint<String> {
        Checkpoint::new(
            id,
            format!("state-{entries}"),
            Uuid::new_v4(),
            entries,
            FoldContext::new(),
            1,
        )
        .expect("sample_checkpoint should not fail serialization")
    }

    #[test]
    fn save_and_load_roundtrip() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        let ckpt = sample_checkpoint("my-index:ckpt-1", 100);
        store.save(ckpt).unwrap();
        let loaded = store.load("my-index:ckpt-1").unwrap().unwrap();
        assert_eq!(loaded.state, "state-100");
        assert_eq!(loaded.entries_processed, 100);
    }

    #[test]
    fn load_missing_returns_none() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        assert!(store.load("nonexistent").unwrap().is_none());
    }

    #[test]
    fn load_latest_returns_most_recent() {
        use chrono::Duration;

        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        let base = DateTime::<Utc>::default();

        // Build checkpoints with explicit, strictly ordered created_at values
        // so load_latest is deterministic without relying on wall-clock time.
        let mut ckpt1 = sample_checkpoint("idx:ckpt-1", 10);
        ckpt1.created_at = base;
        let mut ckpt2 = sample_checkpoint("idx:ckpt-2", 20);
        ckpt2.created_at = base + Duration::milliseconds(5);
        let mut ckpt3 = sample_checkpoint("idx:ckpt-3", 30);
        ckpt3.created_at = base + Duration::milliseconds(10);

        store.save(ckpt1).unwrap();
        store.save(ckpt2).unwrap();
        store.save(ckpt3).unwrap();

        let latest = store.load_latest("idx").unwrap().unwrap();
        assert_eq!(latest.entries_processed, 30);
    }

    #[test]
    fn load_latest_no_match_returns_none() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        store.save(sample_checkpoint("other:ckpt-1", 5)).unwrap();
        assert!(store.load_latest("my-index").unwrap().is_none());
    }

    #[test]
    fn load_latest_prefix_isolation() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        store.save(sample_checkpoint("alpha:ckpt-1", 10)).unwrap();
        store.save(sample_checkpoint("beta:ckpt-1", 999)).unwrap();

        let latest_alpha = store.load_latest("alpha").unwrap().unwrap();
        assert_eq!(latest_alpha.entries_processed, 10);
    }

    /// With identical `created_at` values, `load_latest` must fall back to the larger
    /// `uuid`, regardless of insertion order.
    #[test]
    fn load_latest_tiebreaks_on_uuid() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        let larger = Checkpoint::new(
            "tie:ckpt-a",
            "state-2".to_string(),
            Uuid::from_u128(2),
            2,
            FoldContext::new(),
            1,
        )
        .unwrap();
        let smaller = Checkpoint::new(
            "tie:ckpt-b",
            "state-1".to_string(),
            Uuid::from_u128(1),
            1,
            FoldContext::new(),
            1,
        )
        .unwrap();
        // Both constructors leave created_at at the same default value.
        assert_eq!(larger.created_at, smaller.created_at);

        store.save(larger).unwrap();
        store.save(smaller).unwrap();

        let latest = store.load_latest("tie").unwrap().unwrap();
        assert_eq!(latest.entries_processed, 2);
    }

    #[test]
    fn checkpoint_fields_accessible() {
        let ckpt: Checkpoint<u32> =
            Checkpoint::new("test:ckpt", 42u32, Uuid::new_v4(), 7, FoldContext::new(), 3).unwrap();
        assert_eq!(ckpt.state, 42);
        assert_eq!(ckpt.entries_processed, 7);
        assert_eq!(ckpt.fold_version, 3);
    }

    // --- Additional tests (F-NEW-8) ---

    #[cfg(feature = "serde")]
    #[test]
    fn serde_roundtrip() {
        let ckpt = sample_checkpoint("serde:test", 42);
        let json = serde_json::to_string(&ckpt).expect("serialize");
        let restored: Checkpoint<String> = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(ckpt.id, restored.id);
        assert_eq!(ckpt.state, restored.state);
        assert_eq!(ckpt.entries_processed, restored.entries_processed);
        assert_eq!(ckpt.fold_version, restored.fold_version);
        assert_eq!(ckpt.uuid, restored.uuid);
        // Hash bytes should survive the roundtrip unchanged.
        assert_eq!(ckpt.hash.as_bytes(), restored.hash.as_bytes());
    }

    #[test]
    fn delete_existing_succeeds() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        store.save(sample_checkpoint("del:ckpt-1", 1)).unwrap();
        store.delete("del:ckpt-1").unwrap();
        assert!(store.load("del:ckpt-1").unwrap().is_none());
    }

    #[test]
    fn delete_nonexistent_returns_not_found() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        let err = store.delete("nope").unwrap_err();
        assert!(
            matches!(err, FoldError::CheckpointNotFound(ref id) if id == "nope"),
            "expected CheckpointNotFound, got {err:?}"
        );
    }

    #[test]
    fn list_returns_all_ids() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        store.save(sample_checkpoint("a:ckpt-1", 1)).unwrap();
        store.save(sample_checkpoint("b:ckpt-1", 2)).unwrap();
        store.save(sample_checkpoint("c:ckpt-1", 3)).unwrap();
        let mut ids = store.list().unwrap();
        ids.sort();
        assert_eq!(ids, vec!["a:ckpt-1", "b:ckpt-1", "c:ckpt-1"]);
    }

    #[test]
    fn list_empty_store() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        assert!(store.list().unwrap().is_empty());
    }

    #[test]
    fn save_overwrite_replaces_previous() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        let ckpt1 = sample_checkpoint("overwrite:ckpt-1", 10);
        store.save(ckpt1).unwrap();

        // Save again with the same id but different state.
        let ckpt2 = Checkpoint::new(
            "overwrite:ckpt-1",
            "new-state".to_string(),
            Uuid::new_v4(),
            99,
            FoldContext::new(),
            2,
        )
        .unwrap();
        store.save(ckpt2).unwrap();

        let loaded = store.load("overwrite:ckpt-1").unwrap().unwrap();
        assert_eq!(loaded.state, "new-state");
        assert_eq!(loaded.entries_processed, 99);
        // Only one entry with that id.
        let ids = store.list().unwrap();
        assert_eq!(ids.iter().filter(|id| *id == "overwrite:ckpt-1").count(), 1);
    }

    /// `save` must replace a caller-supplied (here: bogus) hash with the canonical hash
    /// of the state, so a subsequent `load` passes the integrity check.
    #[test]
    fn save_recomputes_supplied_hash() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        let bogus = Checkpoint::with_hash(
            "canon:ckpt-1",
            "payload".to_string(),
            Uuid::new_v4(),
            Hash32::ZERO,
            1,
            FoldContext::new(),
            1,
        );
        store.save(bogus).unwrap();

        // Reference hash for the same state, computed by the normal constructor.
        let reference = Checkpoint::new(
            "reference",
            "payload".to_string(),
            Uuid::new_v4(),
            1,
            FoldContext::new(),
            1,
        )
        .unwrap();

        let loaded = store.load("canon:ckpt-1").unwrap().unwrap();
        assert!(loaded.hash.eq_ct(&reference.hash));
    }

    #[test]
    fn integrity_mismatch_on_corrupted_hash() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        let ckpt = sample_checkpoint("integrity:ckpt-1", 5);
        store.save(ckpt).unwrap();

        // Directly corrupt the stored hash by replacing it with ZERO.
        {
            let mut guard = store.inner.write().unwrap();
            if let Some(c) = guard.get_mut("integrity:ckpt-1") {
                c.hash = Hash32::ZERO;
            }
        }

        let err = store.load("integrity:ckpt-1").unwrap_err();
        assert!(
            matches!(err, FoldError::IntegrityMismatch { .. }),
            "expected IntegrityMismatch, got {err:?}"
        );
    }

    #[test]
    fn concurrent_saves_all_land() {
        // `Arc` is already in scope through `use super::*`.
        use std::thread;

        let store = Arc::new(InMemoryCheckpointStore::<String>::new());
        let n = 20usize;
        let handles: Vec<_> = (0..n)
            .map(|i| {
                let s = Arc::clone(&store);
                thread::spawn(move || {
                    s.save(sample_checkpoint(&format!("concurrent:ckpt-{i}"), i))
                        .unwrap();
                })
            })
            .collect();
        for h in handles {
            h.join().expect("thread panicked");
        }
        let ids = store.list().unwrap();
        assert_eq!(ids.len(), n, "expected {n} checkpoints, got {}", ids.len());
    }

    /// `sort_checkpoint_keys` must return lexicographic order on an intentionally
    /// reverse-sorted input.
    ///
    /// This is a fail-before/pass-after unit test for the ordering helper itself:
    /// the old `HashMap.keys().cloned().collect()` path returned keys in HashMap
    /// iteration order (non-deterministic). Passing a reversed vector guarantees
    /// the test fails against any implementation that skips the sort step.
    #[test]
    fn sort_checkpoint_keys_produces_lexicographic_order() {
        // Intentionally REVERSE alphabetical — worst case for unsorted implementations.
        let unsorted = vec![
            "z:ckpt-3".to_string(),
            "m:ckpt-2".to_string(),
            "a:ckpt-1".to_string(),
        ];
        let sorted = sort_checkpoint_keys(unsorted);
        assert_eq!(
            sorted,
            vec!["a:ckpt-1", "m:ckpt-2", "z:ckpt-3"],
            "sort_checkpoint_keys must produce lexicographic order; got {sorted:?}"
        );
    }

    /// Integration: `InMemoryCheckpointStore::list` must return keys in
    /// lexicographic order regardless of insertion order.
    #[test]
    fn list_is_sorted() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        // Insert in non-alphabetical order.
        store.save(sample_checkpoint("z:ckpt-1", 1)).unwrap();
        store.save(sample_checkpoint("a:ckpt-1", 2)).unwrap();
        store.save(sample_checkpoint("m:ckpt-1", 3)).unwrap();
        let ids = store.list().unwrap();
        assert_eq!(
            ids,
            vec!["a:ckpt-1", "m:ckpt-1", "z:ckpt-1"],
            "list() must return sorted keys; got {ids:?}"
        );
    }
}