Skip to main content

khive_fold/
checkpoint.rs

//! Checkpoint protocol for fold-based index persistence.
//!
//! Provides generic snapshot envelopes and in-memory storage for use
//! by HNSW and other fold-managed indexes.
//!
//! # Formal proof reference
//!
//! `proofs/Retrieval/HNSW.lean` — checkpoint correctness guarantees
//! used in HNSW snapshot/restore cycles
//! (khive.Retrieval.HNSW.checkpoint_correctness).
//!
//! # Architecture
//!
//! ```text
//! HnswIndex ──snapshot──> HnswSnapshot ──wrap──> Checkpoint<HnswSnapshot>
//!                                                       │
//!                                         CheckpointStore::save(...)
//! ```
//!
//! The snapshot types and this checkpoint envelope are always available;
//! the fold feature flag in consuming crates gates whether they are exposed
//! to callers.
//!
//! # Integrity model
//!
//! `save` serializes `state` to JSON with `serde_json`, computes a BLAKE3 hash
//! of those bytes, and stores it in `Checkpoint.hash`.  `load` re-serializes the
//! stored `state`, recomputes the hash, and returns `FoldError::IntegrityMismatch`
//! if the two disagree.  The hash field is therefore always meaningful —
//! `Hash32::ZERO` is only valid if the serialization of `state` actually hashes
//! to zero (practically impossible).
//!
//! The JSON encoding is only as canonical as `state`'s `Serialize` impl: state
//! types whose serialization order is not fixed (for example, ones containing a
//! `std::collections::HashMap`) may not hash identically after being rebuilt
//! from a serialized form.
32
33use std::collections::HashMap;
34use std::sync::{Arc, RwLock};
35
36use chrono::{DateTime, Utc};
37#[cfg(feature = "serde")]
38use serde::{Deserialize, Serialize};
39use uuid::Uuid;
40
41use khive_types::Hash32;
42
43use crate::context::FoldContext;
44use crate::error::FoldError;
45
46/// Generic checkpoint envelope wrapping an arbitrary fold state snapshot.
47///
48/// Carries metadata (ID, timestamp, hash, fold version) alongside the
49/// serializable state so consumers can verify and load the correct snapshot.
50#[derive(Debug, Clone)]
51#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
52pub struct Checkpoint<S> {
53    /// Human-readable checkpoint identifier (e.g. `"hnsw_idx:ckpt-1"`).
54    pub id: String,
55
56    /// The snapshot state captured at this checkpoint.
57    pub state: S,
58
59    /// Unique identifier for this checkpoint instance.
60    pub uuid: Uuid,
61
62    /// BLAKE3 content hash of the canonical JSON serialization of `state`.
63    ///
64    /// Computed by [`CheckpointStore::save`] and verified by
65    /// [`CheckpointStore::load`].  A mismatch returns
66    /// [`FoldError::IntegrityMismatch`].
67    pub hash: Hash32,
68
69    /// Number of entries processed when this checkpoint was taken.
70    pub entries_processed: usize,
71
72    /// Fold context at checkpoint time.
73    pub context: FoldContext,
74
75    /// Monotonically increasing fold schema version.
76    pub fold_version: usize,
77
78    /// Wall-clock time when this checkpoint was created.
79    pub created_at: DateTime<Utc>,
80}
81
82impl<S: Serialize> Checkpoint<S> {
83    /// Create a new checkpoint, computing the BLAKE3 hash of the state.
84    ///
85    /// Returns `FoldError::Serialization` if `state` cannot be serialized to JSON.
86    #[allow(clippy::too_many_arguments)]
87    pub fn new(
88        id: impl Into<String>,
89        state: S,
90        uuid: Uuid,
91        entries_processed: usize,
92        context: FoldContext,
93        fold_version: usize,
94    ) -> Result<Self, FoldError> {
95        let bytes = serde_json::to_vec(&state)?;
96        let hash = Hash32::from_blake3(&bytes);
97        Ok(Self {
98            id: id.into(),
99            state,
100            uuid,
101            hash,
102            entries_processed,
103            context,
104            fold_version,
105            // ADR-024: no clock calls in the foundation layer.
106            // Callers that need the current time should set created_at after construction.
107            created_at: DateTime::<Utc>::default(),
108        })
109    }
110
111    /// Create a checkpoint with a pre-computed hash (for deserialization / testing).
112    ///
113    /// Callers are responsible for ensuring `hash` is consistent with `state`.
114    /// Prefer [`Checkpoint::new`] for production use.
115    #[allow(clippy::too_many_arguments)]
116    pub fn with_hash(
117        id: impl Into<String>,
118        state: S,
119        uuid: Uuid,
120        hash: Hash32,
121        entries_processed: usize,
122        context: FoldContext,
123        fold_version: usize,
124    ) -> Self {
125        Self {
126            id: id.into(),
127            state,
128            uuid,
129            hash,
130            entries_processed,
131            context,
132            fold_version,
133            // ADR-024: no clock calls in the foundation layer.
134            created_at: DateTime::<Utc>::default(),
135        }
136    }
137}
138
139/// Trait for checkpoint persistence backends.
140///
141/// The key is the checkpoint `id` string. `load_latest` returns the
142/// checkpoint whose prefix matches — defined as all checkpoints whose
143/// `id` starts with the given prefix, selecting the most recently created.
144/// Ties on `created_at` are broken by `uuid` (lexicographic) for determinism.
145pub trait CheckpointStore<S> {
146    /// Persist a checkpoint, computing and storing an integrity hash.
147    fn save(&self, checkpoint: Checkpoint<S>) -> Result<(), FoldError>
148    where
149        S: Clone + Serialize;
150
151    /// Load a checkpoint by its exact `id`, verifying the integrity hash.
152    ///
153    /// Returns `Ok(None)` when no checkpoint with that `id` exists.
154    /// Returns `Err(FoldError::IntegrityMismatch)` if the stored hash does not
155    /// match the recomputed hash of the loaded state.
156    fn load(&self, id: &str) -> Result<Option<Checkpoint<S>>, FoldError>
157    where
158        S: Clone + Serialize;
159
160    /// Load the most recently created checkpoint whose `id` starts with `prefix`.
161    ///
162    /// Ties on `created_at` are broken by `uuid` for determinism.
163    /// Returns `None` when no checkpoints match the prefix.
164    fn load_latest(&self, prefix: &str) -> Result<Option<Checkpoint<S>>, FoldError>
165    where
166        S: Clone + Serialize;
167
168    /// Delete the checkpoint with the given `id`.
169    ///
170    /// Returns `Err(FoldError::CheckpointNotFound)` if no checkpoint with that
171    /// `id` exists.
172    fn delete(&self, id: &str) -> Result<(), FoldError>;
173
174    /// List all checkpoint `id` strings currently stored.
175    ///
176    /// The order is unspecified; callers should sort if a stable order is needed.
177    fn list(&self) -> Result<Vec<String>, FoldError>;
178}
179
180/// In-memory checkpoint store backed by a `RwLock<HashMap>`.
181///
182/// Suitable for tests and single-process usage where durability is not
183/// required. Production deployments should implement [`CheckpointStore`]
184/// with durable storage (e.g. SQLite via `khive-db`).
185pub struct InMemoryCheckpointStore<S> {
186    inner: Arc<RwLock<HashMap<String, Checkpoint<S>>>>,
187}
188
189impl<S> InMemoryCheckpointStore<S> {
190    /// Create a new empty in-memory store.
191    pub fn new() -> Self {
192        Self {
193            inner: Arc::new(RwLock::new(HashMap::new())),
194        }
195    }
196}
197
198impl<S> Default for InMemoryCheckpointStore<S> {
199    fn default() -> Self {
200        Self::new()
201    }
202}
203
204impl<S: Clone + Send + Sync + Serialize + 'static> CheckpointStore<S>
205    for InMemoryCheckpointStore<S>
206{
207    fn save(&self, checkpoint: Checkpoint<S>) -> Result<(), FoldError>
208    where
209        S: Clone + Serialize,
210    {
211        // Recompute the hash from the state to ensure the stored hash is canonical.
212        let bytes = serde_json::to_vec(&checkpoint.state)?;
213        let computed = Hash32::from_blake3(&bytes);
214        let mut stored = checkpoint;
215        stored.hash = computed;
216
217        let mut guard = self
218            .inner
219            .write()
220            .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
221        guard.insert(stored.id.clone(), stored);
222        Ok(())
223    }
224
225    fn load(&self, id: &str) -> Result<Option<Checkpoint<S>>, FoldError>
226    where
227        S: Clone + Serialize,
228    {
229        let guard = self
230            .inner
231            .read()
232            .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
233        let Some(checkpoint) = guard.get(id).cloned() else {
234            return Ok(None);
235        };
236
237        // Verify integrity: recompute hash from state and compare.
238        let bytes = serde_json::to_vec(&checkpoint.state)?;
239        let computed = Hash32::from_blake3(&bytes);
240        if !checkpoint.hash.eq_ct(&computed) {
241            return Err(FoldError::IntegrityMismatch {
242                id: id.to_owned(),
243                stored: checkpoint.hash.to_string(),
244                computed: computed.to_string(),
245            });
246        }
247
248        Ok(Some(checkpoint))
249    }
250
251    fn load_latest(&self, prefix: &str) -> Result<Option<Checkpoint<S>>, FoldError>
252    where
253        S: Clone + Serialize,
254    {
255        let guard = self
256            .inner
257            .read()
258            .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
259
260        let latest = guard
261            .values()
262            .filter(|c| c.id.starts_with(prefix))
263            // Tiebreak on uuid for determinism when created_at is equal.
264            .max_by_key(|c| (c.created_at, c.uuid));
265
266        Ok(latest.cloned())
267    }
268
269    fn delete(&self, id: &str) -> Result<(), FoldError> {
270        let mut guard = self
271            .inner
272            .write()
273            .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
274        if guard.remove(id).is_none() {
275            return Err(FoldError::CheckpointNotFound(id.to_owned()));
276        }
277        Ok(())
278    }
279
280    fn list(&self) -> Result<Vec<String>, FoldError> {
281        let guard = self
282            .inner
283            .read()
284            .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
285        Ok(guard.keys().cloned().collect())
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    fn sample_checkpoint(id: &str, entries: usize) -> Checkpoint<String> {
        Checkpoint::new(
            id,
            format!("state-{entries}"),
            Uuid::new_v4(),
            entries,
            FoldContext::new(),
            1,
        )
        .expect("sample_checkpoint should not fail serialization")
    }

    /// BLAKE3 digest of the JSON encoding of `state`, computed independently
    /// of the code under test.
    fn expected_hash(state: &str) -> Hash32 {
        let bytes = serde_json::to_vec(state).expect("a string always serializes");
        Hash32::from_blake3(&bytes)
    }

    #[test]
    fn save_and_load_roundtrip() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        let ckpt = sample_checkpoint("my-index:ckpt-1", 100);
        store.save(ckpt).unwrap();
        let loaded = store.load("my-index:ckpt-1").unwrap().unwrap();
        assert_eq!(loaded.state, "state-100");
        assert_eq!(loaded.entries_processed, 100);
    }

    #[test]
    fn load_missing_returns_none() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        assert!(store.load("nonexistent").unwrap().is_none());
    }

    #[test]
    fn load_latest_returns_most_recent() {
        use chrono::Duration;

        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        let base = DateTime::<Utc>::default();

        // Build checkpoints with explicit, strictly ordered created_at values
        // so load_latest is deterministic without relying on wall-clock time.
        let mut ckpt1 = sample_checkpoint("idx:ckpt-1", 10);
        ckpt1.created_at = base;
        let mut ckpt2 = sample_checkpoint("idx:ckpt-2", 20);
        ckpt2.created_at = base + Duration::milliseconds(5);
        let mut ckpt3 = sample_checkpoint("idx:ckpt-3", 30);
        ckpt3.created_at = base + Duration::milliseconds(10);

        store.save(ckpt1).unwrap();
        store.save(ckpt2).unwrap();
        store.save(ckpt3).unwrap();

        let latest = store.load_latest("idx").unwrap().unwrap();
        assert_eq!(latest.entries_processed, 30);
    }

    #[test]
    fn load_latest_breaks_created_at_ties_by_uuid() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        // Both checkpoints keep the default created_at, so only the uuid
        // decides which one is "latest".
        let low = Checkpoint::new(
            "tie:ckpt-a",
            "a".to_string(),
            Uuid::from_u128(1),
            1,
            FoldContext::new(),
            1,
        )
        .unwrap();
        let high = Checkpoint::new(
            "tie:ckpt-b",
            "b".to_string(),
            Uuid::from_u128(2),
            2,
            FoldContext::new(),
            1,
        )
        .unwrap();
        store.save(high).unwrap();
        store.save(low).unwrap();

        let latest = store.load_latest("tie").unwrap().unwrap();
        assert_eq!(latest.entries_processed, 2);
    }

    #[test]
    fn load_latest_no_match_returns_none() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        store.save(sample_checkpoint("other:ckpt-1", 5)).unwrap();
        assert!(store.load_latest("my-index").unwrap().is_none());
    }

    #[test]
    fn load_latest_prefix_isolation() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        store.save(sample_checkpoint("alpha:ckpt-1", 10)).unwrap();
        store.save(sample_checkpoint("beta:ckpt-1", 999)).unwrap();

        let latest_alpha = store.load_latest("alpha").unwrap().unwrap();
        assert_eq!(latest_alpha.entries_processed, 10);
    }

    #[test]
    fn checkpoint_fields_accessible() {
        let ckpt: Checkpoint<u32> =
            Checkpoint::new("test:ckpt", 42u32, Uuid::new_v4(), 7, FoldContext::new(), 3).unwrap();
        assert_eq!(ckpt.state, 42);
        assert_eq!(ckpt.entries_processed, 7);
        assert_eq!(ckpt.fold_version, 3);
    }

    #[test]
    fn new_hash_matches_json_encoding_of_state() {
        let ckpt = sample_checkpoint("hash:ckpt-1", 3);
        assert!(ckpt.hash.eq_ct(&expected_hash(&ckpt.state)));
    }

    #[test]
    fn new_leaves_created_at_at_default() {
        // ADR-024: no clock calls in the foundation layer.
        let ckpt = sample_checkpoint("clock:ckpt-1", 1);
        assert_eq!(ckpt.created_at, DateTime::<Utc>::default());
    }

    // --- Additional tests (F-NEW-8) ---

    #[cfg(feature = "serde")]
    #[test]
    fn serde_roundtrip() {
        let ckpt = sample_checkpoint("serde:test", 42);
        let json = serde_json::to_string(&ckpt).expect("serialize");
        let restored: Checkpoint<String> = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(ckpt.id, restored.id);
        assert_eq!(ckpt.state, restored.state);
        assert_eq!(ckpt.entries_processed, restored.entries_processed);
        assert_eq!(ckpt.fold_version, restored.fold_version);
        assert_eq!(ckpt.uuid, restored.uuid);
        // Hash bytes should survive the roundtrip unchanged.
        assert_eq!(ckpt.hash.as_bytes(), restored.hash.as_bytes());
    }

    #[test]
    fn delete_existing_succeeds() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        store.save(sample_checkpoint("del:ckpt-1", 1)).unwrap();
        store.delete("del:ckpt-1").unwrap();
        assert!(store.load("del:ckpt-1").unwrap().is_none());
    }

    #[test]
    fn delete_nonexistent_returns_not_found() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        let err = store.delete("nope").unwrap_err();
        assert!(
            matches!(err, FoldError::CheckpointNotFound(ref id) if id == "nope"),
            "expected CheckpointNotFound, got {err:?}"
        );
    }

    #[test]
    fn list_returns_all_ids() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        store.save(sample_checkpoint("a:ckpt-1", 1)).unwrap();
        store.save(sample_checkpoint("b:ckpt-1", 2)).unwrap();
        store.save(sample_checkpoint("c:ckpt-1", 3)).unwrap();
        let mut ids = store.list().unwrap();
        ids.sort();
        assert_eq!(ids, vec!["a:ckpt-1", "b:ckpt-1", "c:ckpt-1"]);
    }

    #[test]
    fn list_empty_store() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        assert!(store.list().unwrap().is_empty());
    }

    #[test]
    fn save_overwrite_replaces_previous() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        let ckpt1 = sample_checkpoint("overwrite:ckpt-1", 10);
        store.save(ckpt1).unwrap();

        // Save again with the same id but different state.
        let ckpt2 = Checkpoint::new(
            "overwrite:ckpt-1",
            "new-state".to_string(),
            Uuid::new_v4(),
            99,
            FoldContext::new(),
            2,
        )
        .unwrap();
        store.save(ckpt2).unwrap();

        let loaded = store.load("overwrite:ckpt-1").unwrap().unwrap();
        assert_eq!(loaded.state, "new-state");
        assert_eq!(loaded.entries_processed, 99);
        // Only one entry with that id.
        let ids = store.list().unwrap();
        assert_eq!(ids.iter().filter(|id| *id == "overwrite:ckpt-1").count(), 1);
    }

    #[test]
    fn save_recomputes_caller_supplied_hash() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        // A deliberately wrong hash: save must replace it with the real one.
        let ckpt = Checkpoint::with_hash(
            "rehash:ckpt-1",
            "payload".to_string(),
            Uuid::new_v4(),
            Hash32::ZERO,
            1,
            FoldContext::new(),
            1,
        );
        store.save(ckpt).unwrap();

        let loaded = store.load("rehash:ckpt-1").unwrap().unwrap();
        assert!(loaded.hash.eq_ct(&expected_hash(&loaded.state)));
    }

    #[test]
    fn integrity_mismatch_on_corrupted_hash() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        let ckpt = sample_checkpoint("integrity:ckpt-1", 5);
        store.save(ckpt).unwrap();

        // Directly corrupt the stored hash by replacing it with ZERO.
        {
            let mut guard = store.inner.write().unwrap();
            if let Some(c) = guard.get_mut("integrity:ckpt-1") {
                c.hash = Hash32::ZERO;
            }
        }

        let err = store.load("integrity:ckpt-1").unwrap_err();
        assert!(
            matches!(err, FoldError::IntegrityMismatch { .. }),
            "expected IntegrityMismatch, got {err:?}"
        );
    }

    #[test]
    fn concurrent_saves_all_land() {
        use std::thread;

        let store = Arc::new(InMemoryCheckpointStore::<String>::new());
        let n = 20usize;
        let handles: Vec<_> = (0..n)
            .map(|i| {
                let s = Arc::clone(&store);
                thread::spawn(move || {
                    s.save(sample_checkpoint(&format!("concurrent:ckpt-{i}"), i))
                        .unwrap();
                })
            })
            .collect();
        for h in handles {
            h.join().expect("thread panicked");
        }
        let ids = store.list().unwrap();
        assert_eq!(ids.len(), n, "expected {n} checkpoints, got {}", ids.len());
    }
}
488        }
489        let ids = store.list().unwrap();
490        assert_eq!(ids.len(), n, "expected {n} checkpoints, got {}", ids.len());
491    }
492}