1use std::collections::HashMap;
34use std::sync::{Arc, RwLock};
35
36use chrono::{DateTime, Utc};
37#[cfg(feature = "serde")]
38use serde::{Deserialize, Serialize};
39use uuid::Uuid;
40
41use khive_types::Hash32;
42
43use crate::context::FoldContext;
44use crate::error::FoldError;
45
46#[derive(Debug, Clone)]
51#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
52pub struct Checkpoint<S> {
53 pub id: String,
55
56 pub state: S,
58
59 pub uuid: Uuid,
61
62 pub hash: Hash32,
68
69 pub entries_processed: usize,
71
72 pub context: FoldContext,
74
75 pub fold_version: usize,
77
78 pub created_at: DateTime<Utc>,
80}
81
82impl<S: Serialize> Checkpoint<S> {
83 #[allow(clippy::too_many_arguments)]
87 pub fn new(
88 id: impl Into<String>,
89 state: S,
90 uuid: Uuid,
91 entries_processed: usize,
92 context: FoldContext,
93 fold_version: usize,
94 ) -> Result<Self, FoldError> {
95 let bytes = serde_json::to_vec(&state)?;
96 let hash = Hash32::from_blake3(&bytes);
97 Ok(Self {
98 id: id.into(),
99 state,
100 uuid,
101 hash,
102 entries_processed,
103 context,
104 fold_version,
105 created_at: DateTime::<Utc>::default(),
108 })
109 }
110
111 #[allow(clippy::too_many_arguments)]
116 pub fn with_hash(
117 id: impl Into<String>,
118 state: S,
119 uuid: Uuid,
120 hash: Hash32,
121 entries_processed: usize,
122 context: FoldContext,
123 fold_version: usize,
124 ) -> Self {
125 Self {
126 id: id.into(),
127 state,
128 uuid,
129 hash,
130 entries_processed,
131 context,
132 fold_version,
133 created_at: DateTime::<Utc>::default(),
135 }
136 }
137}
138
139pub trait CheckpointStore<S> {
146 fn save(&self, checkpoint: Checkpoint<S>) -> Result<(), FoldError>
148 where
149 S: Clone + Serialize;
150
151 fn load(&self, id: &str) -> Result<Option<Checkpoint<S>>, FoldError>
157 where
158 S: Clone + Serialize;
159
160 fn load_latest(&self, prefix: &str) -> Result<Option<Checkpoint<S>>, FoldError>
165 where
166 S: Clone + Serialize;
167
168 fn delete(&self, id: &str) -> Result<(), FoldError>;
173
174 fn list(&self) -> Result<Vec<String>, FoldError>;
178}
179
180pub struct InMemoryCheckpointStore<S> {
186 inner: Arc<RwLock<HashMap<String, Checkpoint<S>>>>,
187}
188
189impl<S> InMemoryCheckpointStore<S> {
190 pub fn new() -> Self {
192 Self {
193 inner: Arc::new(RwLock::new(HashMap::new())),
194 }
195 }
196}
197
198impl<S> Default for InMemoryCheckpointStore<S> {
199 fn default() -> Self {
200 Self::new()
201 }
202}
203
204impl<S: Clone + Send + Sync + Serialize + 'static> CheckpointStore<S>
205 for InMemoryCheckpointStore<S>
206{
207 fn save(&self, checkpoint: Checkpoint<S>) -> Result<(), FoldError>
208 where
209 S: Clone + Serialize,
210 {
211 let bytes = serde_json::to_vec(&checkpoint.state)?;
213 let computed = Hash32::from_blake3(&bytes);
214 let mut stored = checkpoint;
215 stored.hash = computed;
216
217 let mut guard = self
218 .inner
219 .write()
220 .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
221 guard.insert(stored.id.clone(), stored);
222 Ok(())
223 }
224
225 fn load(&self, id: &str) -> Result<Option<Checkpoint<S>>, FoldError>
226 where
227 S: Clone + Serialize,
228 {
229 let guard = self
230 .inner
231 .read()
232 .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
233 let Some(checkpoint) = guard.get(id).cloned() else {
234 return Ok(None);
235 };
236
237 let bytes = serde_json::to_vec(&checkpoint.state)?;
239 let computed = Hash32::from_blake3(&bytes);
240 if !checkpoint.hash.eq_ct(&computed) {
241 return Err(FoldError::IntegrityMismatch {
242 id: id.to_owned(),
243 stored: checkpoint.hash.to_string(),
244 computed: computed.to_string(),
245 });
246 }
247
248 Ok(Some(checkpoint))
249 }
250
251 fn load_latest(&self, prefix: &str) -> Result<Option<Checkpoint<S>>, FoldError>
252 where
253 S: Clone + Serialize,
254 {
255 let guard = self
256 .inner
257 .read()
258 .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
259
260 let latest = guard
261 .values()
262 .filter(|c| c.id.starts_with(prefix))
263 .max_by_key(|c| (c.created_at, c.uuid));
265
266 Ok(latest.cloned())
267 }
268
269 fn delete(&self, id: &str) -> Result<(), FoldError> {
270 let mut guard = self
271 .inner
272 .write()
273 .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
274 if guard.remove(id).is_none() {
275 return Err(FoldError::CheckpointNotFound(id.to_owned()));
276 }
277 Ok(())
278 }
279
280 fn list(&self) -> Result<Vec<String>, FoldError> {
281 let guard = self
282 .inner
283 .read()
284 .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
285 Ok(guard.keys().cloned().collect())
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `String`-state checkpoint whose state text and entry count
    /// both encode `entries`.
    fn sample_checkpoint(id: &str, entries: usize) -> Checkpoint<String> {
        let state = format!("state-{entries}");
        Checkpoint::new(id, state, Uuid::new_v4(), entries, FoldContext::new(), 1)
            .expect("sample_checkpoint should not fail serialization")
    }

    /// Shorthand for an empty store holding `String` state.
    fn empty_store() -> InMemoryCheckpointStore<String> {
        InMemoryCheckpointStore::new()
    }

    #[test]
    fn save_and_load_roundtrip() {
        let store = empty_store();
        store.save(sample_checkpoint("my-index:ckpt-1", 100)).unwrap();

        let restored = store.load("my-index:ckpt-1").unwrap().unwrap();
        assert_eq!(restored.state, "state-100");
        assert_eq!(restored.entries_processed, 100);
    }

    #[test]
    fn load_missing_returns_none() {
        let lookup = empty_store().load("nonexistent").unwrap();
        assert!(lookup.is_none());
    }

    #[test]
    fn load_latest_returns_most_recent() {
        use chrono::Duration;

        let store = empty_store();
        let epoch = DateTime::<Utc>::default();

        // Stagger creation times so ckpt-3 is unambiguously the newest.
        let fixtures = [
            ("idx:ckpt-1", 10, 0),
            ("idx:ckpt-2", 20, 5),
            ("idx:ckpt-3", 30, 10),
        ];
        for (id, entries, offset_ms) in fixtures {
            let mut ckpt = sample_checkpoint(id, entries);
            ckpt.created_at = epoch + Duration::milliseconds(offset_ms);
            store.save(ckpt).unwrap();
        }

        let newest = store.load_latest("idx").unwrap().unwrap();
        assert_eq!(newest.entries_processed, 30);
    }

    #[test]
    fn load_latest_no_match_returns_none() {
        let store = empty_store();
        store.save(sample_checkpoint("other:ckpt-1", 5)).unwrap();

        let lookup = store.load_latest("my-index").unwrap();
        assert!(lookup.is_none());
    }

    #[test]
    fn load_latest_prefix_isolation() {
        let store = empty_store();
        for (id, entries) in [("alpha:ckpt-1", 10), ("beta:ckpt-1", 999)] {
            store.save(sample_checkpoint(id, entries)).unwrap();
        }

        let alpha = store.load_latest("alpha").unwrap().unwrap();
        assert_eq!(alpha.entries_processed, 10);
    }

    #[test]
    fn checkpoint_fields_accessible() {
        let ckpt: Checkpoint<u32> =
            Checkpoint::new("test:ckpt", 42u32, Uuid::new_v4(), 7, FoldContext::new(), 3).unwrap();
        let Checkpoint {
            state,
            entries_processed,
            fold_version,
            ..
        } = ckpt;
        assert_eq!(state, 42);
        assert_eq!(entries_processed, 7);
        assert_eq!(fold_version, 3);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_roundtrip() {
        let original = sample_checkpoint("serde:test", 42);
        let json = serde_json::to_string(&original).expect("serialize");
        let restored: Checkpoint<String> = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(original.id, restored.id);
        assert_eq!(original.state, restored.state);
        assert_eq!(original.entries_processed, restored.entries_processed);
        assert_eq!(original.fold_version, restored.fold_version);
        assert_eq!(original.uuid, restored.uuid);
        assert_eq!(original.hash.as_bytes(), restored.hash.as_bytes());
    }

    #[test]
    fn delete_existing_succeeds() {
        let store = empty_store();
        store.save(sample_checkpoint("del:ckpt-1", 1)).unwrap();

        store.delete("del:ckpt-1").unwrap();

        let lookup = store.load("del:ckpt-1").unwrap();
        assert!(lookup.is_none());
    }

    #[test]
    fn delete_nonexistent_returns_not_found() {
        let err = empty_store().delete("nope").unwrap_err();
        assert!(
            matches!(err, FoldError::CheckpointNotFound(ref id) if id == "nope"),
            "expected CheckpointNotFound, got {err:?}"
        );
    }

    #[test]
    fn list_returns_all_ids() {
        let store = empty_store();
        for (id, entries) in [("a:ckpt-1", 1), ("b:ckpt-1", 2), ("c:ckpt-1", 3)] {
            store.save(sample_checkpoint(id, entries)).unwrap();
        }

        let mut ids = store.list().unwrap();
        ids.sort_unstable();
        assert_eq!(ids, ["a:ckpt-1", "b:ckpt-1", "c:ckpt-1"]);
    }

    #[test]
    fn list_empty_store() {
        let ids = empty_store().list().unwrap();
        assert!(ids.is_empty());
    }

    #[test]
    fn save_overwrite_replaces_previous() {
        let store = empty_store();
        store.save(sample_checkpoint("overwrite:ckpt-1", 10)).unwrap();

        // Same id, different payload: the second save must win.
        let replacement = Checkpoint::new(
            "overwrite:ckpt-1",
            "new-state".to_string(),
            Uuid::new_v4(),
            99,
            FoldContext::new(),
            2,
        )
        .unwrap();
        store.save(replacement).unwrap();

        let current = store.load("overwrite:ckpt-1").unwrap().unwrap();
        assert_eq!(current.state, "new-state");
        assert_eq!(current.entries_processed, 99);

        let occurrences = store
            .list()
            .unwrap()
            .into_iter()
            .filter(|id| id == "overwrite:ckpt-1")
            .count();
        assert_eq!(occurrences, 1);
    }

    #[test]
    fn integrity_mismatch_on_corrupted_hash() {
        let store = empty_store();
        store.save(sample_checkpoint("integrity:ckpt-1", 5)).unwrap();

        // Corrupt the stored hash behind the store's back; the scope releases
        // the write lock before `load` needs the read lock.
        {
            let mut map = store.inner.write().unwrap();
            if let Some(entry) = map.get_mut("integrity:ckpt-1") {
                entry.hash = Hash32::ZERO;
            }
        }

        let err = store.load("integrity:ckpt-1").unwrap_err();
        assert!(
            matches!(err, FoldError::IntegrityMismatch { .. }),
            "expected IntegrityMismatch, got {err:?}"
        );
    }

    #[test]
    fn concurrent_saves_all_land() {
        use std::sync::Arc;
        use std::thread;

        let n = 20usize;
        let store = Arc::new(empty_store());

        let workers: Vec<_> = (0..n)
            .map(|i| {
                let shared = Arc::clone(&store);
                thread::spawn(move || {
                    let id = format!("concurrent:ckpt-{i}");
                    shared.save(sample_checkpoint(&id, i)).unwrap();
                })
            })
            .collect();
        for worker in workers {
            worker.join().expect("thread panicked");
        }

        let ids = store.list().unwrap();
        assert_eq!(ids.len(), n, "expected {n} checkpoints, got {}", ids.len());
    }
}