Skip to main content

khive_fold/
checkpoint.rs

1//! Checkpoint protocol for fold-based index persistence.
2//!
3//! Provides generic snapshot envelopes and in-memory storage for use
4//! by HNSW and other fold-managed indexes.
5//!
6//! # Formal proof reference
7//!
8//! `proofs/Retrieval/HNSW.lean` — checkpoint correctness guarantees
9//! used in HNSW snapshot/restore cycles
10//! (khive.Retrieval.HNSW.checkpoint_correctness).
11//!
12//! # Architecture
13//!
14//! ```text
15//! HnswIndex ──snapshot──> HnswSnapshot ──wrap──> Checkpoint<HnswSnapshot>
16//!                                                       │
17//!                                         CheckpointStore::save(...)
18//! ```
19//!
20//! The snapshot types and this checkpoint envelope are always available;
21//! the fold feature flag in consuming crates gates whether they are exposed
22//! to callers.
23//!
24//! # Integrity model
25//!
26//! `save` serializes `state` to canonical JSON, computes a BLAKE3 hash, and
27//! stores it in `Checkpoint.hash`.  `load` recomputes the hash from the stored
28//! bytes and returns `FoldError::IntegrityMismatch` if they disagree.  The hash
29//! field is therefore always meaningful — `Hash32::ZERO` is only valid if the
30//! canonical serialization of `state` actually hashes to zero (practically
31//! impossible).
32
33use std::collections::HashMap;
34use std::sync::{Arc, RwLock};
35
36use chrono::{DateTime, Utc};
37use serde::{Deserialize, Serialize};
38use uuid::Uuid;
39
40use khive_types::Hash32;
41
42use crate::context::FoldContext;
43use crate::error::FoldError;
44
45/// Generic checkpoint envelope wrapping an arbitrary fold state snapshot.
46///
47/// Carries metadata (ID, timestamp, hash, fold version) alongside the
48/// serializable state so consumers can verify and load the correct snapshot.
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct Checkpoint<S> {
51    /// Human-readable checkpoint identifier (e.g. `"hnsw_idx:ckpt-1"`).
52    pub id: String,
53
54    /// The snapshot state captured at this checkpoint.
55    pub state: S,
56
57    /// Unique identifier for this checkpoint instance.
58    pub uuid: Uuid,
59
60    /// BLAKE3 content hash of the canonical JSON serialization of `state`.
61    ///
62    /// Computed by [`CheckpointStore::save`] and verified by
63    /// [`CheckpointStore::load`].  A mismatch returns
64    /// [`FoldError::IntegrityMismatch`].
65    pub hash: Hash32,
66
67    /// Number of entries processed when this checkpoint was taken.
68    pub entries_processed: usize,
69
70    /// Fold context at checkpoint time.
71    pub context: FoldContext,
72
73    /// Monotonically increasing fold schema version.
74    pub fold_version: usize,
75
76    /// Wall-clock time when this checkpoint was created.
77    pub created_at: DateTime<Utc>,
78}
79
80impl<S: Serialize> Checkpoint<S> {
81    /// Create a new checkpoint, computing the BLAKE3 hash of the state.
82    ///
83    /// Returns `FoldError::Serialization` if `state` cannot be serialized to JSON.
84    #[allow(clippy::too_many_arguments)]
85    pub fn new(
86        id: impl Into<String>,
87        state: S,
88        uuid: Uuid,
89        entries_processed: usize,
90        context: FoldContext,
91        fold_version: usize,
92    ) -> Result<Self, FoldError> {
93        let bytes = serde_json::to_vec(&state)?;
94        let hash = Hash32::from_blake3(&bytes);
95        Ok(Self {
96            id: id.into(),
97            state,
98            uuid,
99            hash,
100            entries_processed,
101            context,
102            fold_version,
103            created_at: Utc::now(),
104        })
105    }
106
107    /// Create a checkpoint with a pre-computed hash (for deserialization / testing).
108    ///
109    /// Callers are responsible for ensuring `hash` is consistent with `state`.
110    /// Prefer [`Checkpoint::new`] for production use.
111    #[allow(clippy::too_many_arguments)]
112    pub fn with_hash(
113        id: impl Into<String>,
114        state: S,
115        uuid: Uuid,
116        hash: Hash32,
117        entries_processed: usize,
118        context: FoldContext,
119        fold_version: usize,
120    ) -> Self {
121        Self {
122            id: id.into(),
123            state,
124            uuid,
125            hash,
126            entries_processed,
127            context,
128            fold_version,
129            created_at: Utc::now(),
130        }
131    }
132}
133
134/// Trait for checkpoint persistence backends.
135///
136/// The key is the checkpoint `id` string. `load_latest` returns the
137/// checkpoint whose prefix matches — defined as all checkpoints whose
138/// `id` starts with the given prefix, selecting the most recently created.
139/// Ties on `created_at` are broken by `uuid` (lexicographic) for determinism.
140pub trait CheckpointStore<S> {
141    /// Persist a checkpoint, computing and storing an integrity hash.
142    fn save(&self, checkpoint: Checkpoint<S>) -> Result<(), FoldError>
143    where
144        S: Clone + Serialize;
145
146    /// Load a checkpoint by its exact `id`, verifying the integrity hash.
147    ///
148    /// Returns `Ok(None)` when no checkpoint with that `id` exists.
149    /// Returns `Err(FoldError::IntegrityMismatch)` if the stored hash does not
150    /// match the recomputed hash of the loaded state.
151    fn load(&self, id: &str) -> Result<Option<Checkpoint<S>>, FoldError>
152    where
153        S: Clone + Serialize;
154
155    /// Load the most recently created checkpoint whose `id` starts with `prefix`.
156    ///
157    /// Ties on `created_at` are broken by `uuid` for determinism.
158    /// Returns `None` when no checkpoints match the prefix.
159    fn load_latest(&self, prefix: &str) -> Result<Option<Checkpoint<S>>, FoldError>
160    where
161        S: Clone + Serialize;
162
163    /// Delete the checkpoint with the given `id`.
164    ///
165    /// Returns `Err(FoldError::CheckpointNotFound)` if no checkpoint with that
166    /// `id` exists.
167    fn delete(&self, id: &str) -> Result<(), FoldError>;
168
169    /// List all checkpoint `id` strings currently stored.
170    ///
171    /// The order is unspecified; callers should sort if a stable order is needed.
172    fn list(&self) -> Result<Vec<String>, FoldError>;
173}
174
175/// In-memory checkpoint store backed by a `RwLock<HashMap>`.
176///
177/// Suitable for tests and single-process usage where durability is not
178/// required. Production deployments should implement [`CheckpointStore`]
179/// with durable storage (e.g. SQLite via `khive-db`).
180pub struct InMemoryCheckpointStore<S> {
181    inner: Arc<RwLock<HashMap<String, Checkpoint<S>>>>,
182}
183
184impl<S> InMemoryCheckpointStore<S> {
185    /// Create a new empty in-memory store.
186    pub fn new() -> Self {
187        Self {
188            inner: Arc::new(RwLock::new(HashMap::new())),
189        }
190    }
191}
192
193impl<S> Default for InMemoryCheckpointStore<S> {
194    fn default() -> Self {
195        Self::new()
196    }
197}
198
199impl<S: Clone + Send + Sync + Serialize + 'static> CheckpointStore<S>
200    for InMemoryCheckpointStore<S>
201{
202    fn save(&self, checkpoint: Checkpoint<S>) -> Result<(), FoldError>
203    where
204        S: Clone + Serialize,
205    {
206        // Recompute the hash from the state to ensure the stored hash is canonical.
207        let bytes = serde_json::to_vec(&checkpoint.state)?;
208        let computed = Hash32::from_blake3(&bytes);
209        let mut stored = checkpoint;
210        stored.hash = computed;
211
212        let mut guard = self
213            .inner
214            .write()
215            .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
216        guard.insert(stored.id.clone(), stored);
217        Ok(())
218    }
219
220    fn load(&self, id: &str) -> Result<Option<Checkpoint<S>>, FoldError>
221    where
222        S: Clone + Serialize,
223    {
224        let guard = self
225            .inner
226            .read()
227            .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
228        let Some(checkpoint) = guard.get(id).cloned() else {
229            return Ok(None);
230        };
231
232        // Verify integrity: recompute hash from state and compare.
233        let bytes = serde_json::to_vec(&checkpoint.state)?;
234        let computed = Hash32::from_blake3(&bytes);
235        if !checkpoint.hash.eq_ct(&computed) {
236            return Err(FoldError::IntegrityMismatch {
237                id: id.to_owned(),
238                stored: checkpoint.hash.to_string(),
239                computed: computed.to_string(),
240            });
241        }
242
243        Ok(Some(checkpoint))
244    }
245
246    fn load_latest(&self, prefix: &str) -> Result<Option<Checkpoint<S>>, FoldError>
247    where
248        S: Clone + Serialize,
249    {
250        let guard = self
251            .inner
252            .read()
253            .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
254
255        let latest = guard
256            .values()
257            .filter(|c| c.id.starts_with(prefix))
258            // Tiebreak on uuid for determinism when created_at is equal.
259            .max_by_key(|c| (c.created_at, c.uuid));
260
261        Ok(latest.cloned())
262    }
263
264    fn delete(&self, id: &str) -> Result<(), FoldError> {
265        let mut guard = self
266            .inner
267            .write()
268            .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
269        if guard.remove(id).is_none() {
270            return Err(FoldError::CheckpointNotFound(id.to_owned()));
271        }
272        Ok(())
273    }
274
275    fn list(&self) -> Result<Vec<String>, FoldError> {
276        let guard = self
277            .inner
278            .read()
279            .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
280        Ok(guard.keys().cloned().collect())
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Store flavour used by every test in this module.
    type StringStore = InMemoryCheckpointStore<String>;

    /// Build a checkpoint with state `"state-{entries}"`, a random uuid and
    /// fold version 1.
    fn sample_checkpoint(id: &str, entries: usize) -> Checkpoint<String> {
        let state = format!("state-{entries}");
        Checkpoint::new(id, state, Uuid::new_v4(), entries, FoldContext::new(), 1)
            .expect("sample_checkpoint should not fail serialization")
    }

    #[test]
    fn save_and_load_roundtrip() {
        let store = StringStore::new();
        store
            .save(sample_checkpoint("my-index:ckpt-1", 100))
            .unwrap();

        let loaded = store.load("my-index:ckpt-1").unwrap().unwrap();
        assert_eq!(loaded.state, "state-100");
        assert_eq!(loaded.entries_processed, 100);
    }

    #[test]
    fn load_missing_returns_none() {
        let store = StringStore::new();
        let outcome = store.load("nonexistent").unwrap();
        assert!(outcome.is_none());
    }

    #[test]
    fn load_latest_returns_most_recent() {
        let store = StringStore::new();
        let pause = std::time::Duration::from_millis(5);
        let fixtures = [("idx:ckpt-1", 10), ("idx:ckpt-2", 20), ("idx:ckpt-3", 30)];

        for (position, &(id, entries)) in fixtures.iter().enumerate() {
            // Pause between checkpoints so each gets a distinct created_at.
            if position > 0 {
                std::thread::sleep(pause);
            }
            store.save(sample_checkpoint(id, entries)).unwrap();
        }

        let latest = store.load_latest("idx").unwrap().unwrap();
        assert_eq!(latest.entries_processed, 30);
    }

    #[test]
    fn load_latest_no_match_returns_none() {
        let store = StringStore::new();
        store.save(sample_checkpoint("other:ckpt-1", 5)).unwrap();

        let outcome = store.load_latest("my-index").unwrap();
        assert!(outcome.is_none());
    }

    #[test]
    fn load_latest_prefix_isolation() {
        let store = StringStore::new();
        let fixtures = [("alpha:ckpt-1", 10), ("beta:ckpt-1", 999)];
        for &(id, entries) in fixtures.iter() {
            store.save(sample_checkpoint(id, entries)).unwrap();
        }

        let latest_alpha = store.load_latest("alpha").unwrap().unwrap();
        assert_eq!(latest_alpha.entries_processed, 10);
    }

    #[test]
    fn checkpoint_fields_accessible() {
        let built: Result<Checkpoint<u32>, FoldError> =
            Checkpoint::new("test:ckpt", 42u32, Uuid::new_v4(), 7, FoldContext::new(), 3);
        let ckpt = built.unwrap();
        assert_eq!(ckpt.state, 42);
        assert_eq!(ckpt.entries_processed, 7);
        assert_eq!(ckpt.fold_version, 3);
    }

    // --- Additional tests (F-NEW-8) ---

    #[test]
    fn serde_roundtrip() {
        let ckpt = sample_checkpoint("serde:test", 42);
        let encoded = serde_json::to_string(&ckpt).expect("serialize");
        let restored: Checkpoint<String> = serde_json::from_str(&encoded).expect("deserialize");

        assert_eq!(ckpt.id, restored.id);
        assert_eq!(ckpt.state, restored.state);
        assert_eq!(ckpt.entries_processed, restored.entries_processed);
        assert_eq!(ckpt.fold_version, restored.fold_version);
        assert_eq!(ckpt.uuid, restored.uuid);
        // The hash must come back byte-for-byte identical.
        assert_eq!(ckpt.hash.as_bytes(), restored.hash.as_bytes());
    }

    #[test]
    fn delete_existing_succeeds() {
        let store = StringStore::new();
        store.save(sample_checkpoint("del:ckpt-1", 1)).unwrap();
        store.delete("del:ckpt-1").unwrap();

        let after_delete = store.load("del:ckpt-1").unwrap();
        assert!(after_delete.is_none());
    }

    #[test]
    fn delete_nonexistent_returns_not_found() {
        let store = StringStore::new();
        let err = store.delete("nope").unwrap_err();
        assert!(
            matches!(err, FoldError::CheckpointNotFound(ref id) if id == "nope"),
            "expected CheckpointNotFound, got {err:?}"
        );
    }

    #[test]
    fn list_returns_all_ids() {
        let store = StringStore::new();
        let fixtures = [("a:ckpt-1", 1), ("b:ckpt-1", 2), ("c:ckpt-1", 3)];
        for &(id, entries) in fixtures.iter() {
            store.save(sample_checkpoint(id, entries)).unwrap();
        }

        // `list` promises no order, so sort before comparing.
        let mut ids = store.list().unwrap();
        ids.sort();
        assert_eq!(ids, vec!["a:ckpt-1", "b:ckpt-1", "c:ckpt-1"]);
    }

    #[test]
    fn list_empty_store() {
        let store = StringStore::new();
        let ids = store.list().unwrap();
        assert!(ids.is_empty());
    }

    #[test]
    fn save_overwrite_replaces_previous() {
        let store = StringStore::new();
        store
            .save(sample_checkpoint("overwrite:ckpt-1", 10))
            .unwrap();

        // Second save reuses the id but carries a different payload.
        let replacement = Checkpoint::new(
            "overwrite:ckpt-1",
            "new-state".to_string(),
            Uuid::new_v4(),
            99,
            FoldContext::new(),
            2,
        )
        .unwrap();
        store.save(replacement).unwrap();

        let loaded = store.load("overwrite:ckpt-1").unwrap().unwrap();
        assert_eq!(loaded.state, "new-state");
        assert_eq!(loaded.entries_processed, 99);

        // The id must appear exactly once in the listing.
        let occurrences = store
            .list()
            .unwrap()
            .iter()
            .filter(|id| *id == "overwrite:ckpt-1")
            .count();
        assert_eq!(occurrences, 1);
    }

    #[test]
    fn integrity_mismatch_on_corrupted_hash() {
        let store = StringStore::new();
        store
            .save(sample_checkpoint("integrity:ckpt-1", 5))
            .unwrap();

        // Reach into the map and overwrite the stored hash with ZERO; the
        // scope ends the write lock before `load` is called.
        {
            let mut entries = store.inner.write().unwrap();
            if let Some(stored) = entries.get_mut("integrity:ckpt-1") {
                stored.hash = Hash32::ZERO;
            }
        }

        let err = store.load("integrity:ckpt-1").unwrap_err();
        assert!(
            matches!(err, FoldError::IntegrityMismatch { .. }),
            "expected IntegrityMismatch, got {err:?}"
        );
    }

    #[test]
    fn concurrent_saves_all_land() {
        let store = Arc::new(StringStore::new());
        let n = 20usize;

        let mut workers = Vec::with_capacity(n);
        for i in 0..n {
            let shared = Arc::clone(&store);
            workers.push(std::thread::spawn(move || {
                let id = format!("concurrent:ckpt-{i}");
                shared.save(sample_checkpoint(&id, i)).unwrap();
            }));
        }
        for worker in workers {
            worker.join().expect("thread panicked");
        }

        let ids = store.list().unwrap();
        assert_eq!(ids.len(), n, "expected {n} checkpoints, got {}", ids.len());
    }
}