1use std::collections::HashMap;
34use std::sync::{Arc, RwLock};
35
36use chrono::{DateTime, Utc};
37use serde::{Deserialize, Serialize};
38use uuid::Uuid;
39
40use khive_types::Hash32;
41
42use crate::context::FoldContext;
43use crate::error::FoldError;
44
/// A snapshot of fold state together with the metadata stored alongside it.
///
/// `S` is the state type. The struct derives `Serialize`/`Deserialize`, so a
/// checkpoint can round-trip through serde formats such as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint<S> {
    /// Identifier of this checkpoint. `InMemoryCheckpointStore` uses it as the
    /// map key, and `load_latest` matches it against a string prefix.
    pub id: String,

    /// The captured state.
    pub state: S,

    /// UUID attached to this checkpoint; `load_latest` uses it to break ties
    /// between checkpoints that share the same `created_at`.
    pub uuid: Uuid,

    /// BLAKE3 hash of the `serde_json` encoding of `state`. Computed by
    /// `Checkpoint::new`, taken verbatim by `Checkpoint::with_hash`, and
    /// recomputed by `InMemoryCheckpointStore::save`.
    pub hash: Hash32,

    /// Number of entries processed when this checkpoint was taken
    /// (presumably entries already folded into `state` — confirm with callers).
    pub entries_processed: usize,

    /// Fold context captured alongside the state.
    pub context: FoldContext,

    /// Version number supplied by the caller for the fold that produced this
    /// checkpoint (NOTE(review): how it is compared on resume is not shown here).
    pub fold_version: usize,

    /// UTC time at which this value was constructed (not when it was saved).
    /// `load_latest` orders checkpoints by this field.
    pub created_at: DateTime<Utc>,
}
79
80impl<S: Serialize> Checkpoint<S> {
81 #[allow(clippy::too_many_arguments)]
85 pub fn new(
86 id: impl Into<String>,
87 state: S,
88 uuid: Uuid,
89 entries_processed: usize,
90 context: FoldContext,
91 fold_version: usize,
92 ) -> Result<Self, FoldError> {
93 let bytes = serde_json::to_vec(&state)?;
94 let hash = Hash32::from_blake3(&bytes);
95 Ok(Self {
96 id: id.into(),
97 state,
98 uuid,
99 hash,
100 entries_processed,
101 context,
102 fold_version,
103 created_at: Utc::now(),
104 })
105 }
106
107 #[allow(clippy::too_many_arguments)]
112 pub fn with_hash(
113 id: impl Into<String>,
114 state: S,
115 uuid: Uuid,
116 hash: Hash32,
117 entries_processed: usize,
118 context: FoldContext,
119 fold_version: usize,
120 ) -> Self {
121 Self {
122 id: id.into(),
123 state,
124 uuid,
125 hash,
126 entries_processed,
127 context,
128 fold_version,
129 created_at: Utc::now(),
130 }
131 }
132}
133
/// Storage backend for [`Checkpoint`]s, keyed by checkpoint id.
///
/// `S` is the checkpoint state type. All operations report failures as
/// [`FoldError`].
pub trait CheckpointStore<S> {
    /// Stores `checkpoint` under its `id`.
    ///
    /// The in-memory implementation replaces any existing checkpoint with the
    /// same id and recomputes `hash` from `state` before storing it.
    fn save(&self, checkpoint: Checkpoint<S>) -> Result<(), FoldError>
    where
        S: Clone + Serialize;

    /// Loads the checkpoint whose id is exactly `id`, or `Ok(None)` if absent.
    ///
    /// The in-memory implementation re-hashes the state and returns
    /// `FoldError::IntegrityMismatch` if it differs from the stored hash.
    fn load(&self, id: &str) -> Result<Option<Checkpoint<S>>, FoldError>
    where
        S: Clone + Serialize;

    /// Loads the most recent checkpoint whose id starts with `prefix`, or
    /// `Ok(None)` if no id matches.
    ///
    /// The in-memory implementation picks the greatest `(created_at, uuid)`
    /// pair. It uses a plain string-prefix match, so `"idx"` also matches ids
    /// such as `"idx2:..."`.
    fn load_latest(&self, prefix: &str) -> Result<Option<Checkpoint<S>>, FoldError>
    where
        S: Clone + Serialize;

    /// Removes the checkpoint with this `id`.
    ///
    /// The in-memory implementation returns `FoldError::CheckpointNotFound`
    /// when no such checkpoint exists.
    fn delete(&self, id: &str) -> Result<(), FoldError>;

    /// Returns the ids of all stored checkpoints.
    ///
    /// The in-memory implementation returns them in unspecified (hash-map) order.
    fn list(&self) -> Result<Vec<String>, FoldError>;
}
174
/// A [`CheckpointStore`] that keeps checkpoints in a process-local `HashMap`.
///
/// Nothing is persisted; checkpoints live only in memory.
pub struct InMemoryCheckpointStore<S> {
    /// Checkpoints keyed by id. The `RwLock` lets lookups share access while
    /// saves and deletes take exclusive access.
    inner: Arc<RwLock<HashMap<String, Checkpoint<S>>>>,
}
183
184impl<S> InMemoryCheckpointStore<S> {
185 pub fn new() -> Self {
187 Self {
188 inner: Arc::new(RwLock::new(HashMap::new())),
189 }
190 }
191}
192
193impl<S> Default for InMemoryCheckpointStore<S> {
194 fn default() -> Self {
195 Self::new()
196 }
197}
198
199impl<S: Clone + Send + Sync + Serialize + 'static> CheckpointStore<S>
200 for InMemoryCheckpointStore<S>
201{
202 fn save(&self, checkpoint: Checkpoint<S>) -> Result<(), FoldError>
203 where
204 S: Clone + Serialize,
205 {
206 let bytes = serde_json::to_vec(&checkpoint.state)?;
208 let computed = Hash32::from_blake3(&bytes);
209 let mut stored = checkpoint;
210 stored.hash = computed;
211
212 let mut guard = self
213 .inner
214 .write()
215 .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
216 guard.insert(stored.id.clone(), stored);
217 Ok(())
218 }
219
220 fn load(&self, id: &str) -> Result<Option<Checkpoint<S>>, FoldError>
221 where
222 S: Clone + Serialize,
223 {
224 let guard = self
225 .inner
226 .read()
227 .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
228 let Some(checkpoint) = guard.get(id).cloned() else {
229 return Ok(None);
230 };
231
232 let bytes = serde_json::to_vec(&checkpoint.state)?;
234 let computed = Hash32::from_blake3(&bytes);
235 if !checkpoint.hash.eq_ct(&computed) {
236 return Err(FoldError::IntegrityMismatch {
237 id: id.to_owned(),
238 stored: checkpoint.hash.to_string(),
239 computed: computed.to_string(),
240 });
241 }
242
243 Ok(Some(checkpoint))
244 }
245
246 fn load_latest(&self, prefix: &str) -> Result<Option<Checkpoint<S>>, FoldError>
247 where
248 S: Clone + Serialize,
249 {
250 let guard = self
251 .inner
252 .read()
253 .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
254
255 let latest = guard
256 .values()
257 .filter(|c| c.id.starts_with(prefix))
258 .max_by_key(|c| (c.created_at, c.uuid));
260
261 Ok(latest.cloned())
262 }
263
264 fn delete(&self, id: &str) -> Result<(), FoldError> {
265 let mut guard = self
266 .inner
267 .write()
268 .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
269 if guard.remove(id).is_none() {
270 return Err(FoldError::CheckpointNotFound(id.to_owned()));
271 }
272 Ok(())
273 }
274
275 fn list(&self) -> Result<Vec<String>, FoldError> {
276 let guard = self
277 .inner
278 .read()
279 .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
280 Ok(guard.keys().cloned().collect())
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Checkpoint<String>` with state `"state-{entries}"`, fold version 1.
    fn sample_checkpoint(id: &str, entries: usize) -> Checkpoint<String> {
        Checkpoint::new(
            id,
            format!("state-{entries}"),
            Uuid::new_v4(),
            entries,
            FoldContext::new(),
            1,
        )
        .expect("sample_checkpoint should not fail serialization")
    }

    #[test]
    fn save_and_load_roundtrip() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        let ckpt = sample_checkpoint("my-index:ckpt-1", 100);
        store.save(ckpt).unwrap();
        let loaded = store.load("my-index:ckpt-1").unwrap().unwrap();
        assert_eq!(loaded.state, "state-100");
        assert_eq!(loaded.entries_processed, 100);
    }

    #[test]
    fn save_keeps_hash_from_new() {
        // `Checkpoint::new` and the store must hash state identically.
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        let ckpt = sample_checkpoint("hash:ckpt-1", 3);
        let expected = ckpt.hash.to_string();
        store.save(ckpt).unwrap();
        let loaded = store.load("hash:ckpt-1").unwrap().unwrap();
        assert_eq!(loaded.hash.to_string(), expected);
    }

    #[test]
    fn load_missing_returns_none() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        assert!(store.load("nonexistent").unwrap().is_none());
    }

    #[test]
    fn load_latest_returns_most_recent() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();

        // Pin `created_at` explicitly instead of sleeping between saves: clock
        // resolution varies by platform, and equal timestamps would fall back to
        // the random-UUID tie-break. Saving out of order also shows the result
        // follows `created_at`, not insertion order.
        let base = Utc::now();
        for (seconds, entries) in [(2_i64, 30_usize), (0, 10), (1, 20)] {
            let mut ckpt = sample_checkpoint(&format!("idx:ckpt-{entries}"), entries);
            ckpt.created_at = base + chrono::Duration::seconds(seconds);
            store.save(ckpt).unwrap();
        }

        let latest = store.load_latest("idx").unwrap().unwrap();
        assert_eq!(latest.entries_processed, 30);
    }

    #[test]
    fn load_latest_no_match_returns_none() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        store.save(sample_checkpoint("other:ckpt-1", 5)).unwrap();
        assert!(store.load_latest("my-index").unwrap().is_none());
    }

    #[test]
    fn load_latest_prefix_isolation() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        store.save(sample_checkpoint("alpha:ckpt-1", 10)).unwrap();
        store.save(sample_checkpoint("beta:ckpt-1", 999)).unwrap();

        let latest_alpha = store.load_latest("alpha").unwrap().unwrap();
        assert_eq!(latest_alpha.entries_processed, 10);
    }

    #[test]
    fn checkpoint_fields_accessible() {
        let ckpt: Checkpoint<u32> =
            Checkpoint::new("test:ckpt", 42u32, Uuid::new_v4(), 7, FoldContext::new(), 3).unwrap();
        assert_eq!(ckpt.state, 42);
        assert_eq!(ckpt.entries_processed, 7);
        assert_eq!(ckpt.fold_version, 3);
    }

    #[test]
    fn serde_roundtrip() {
        let ckpt = sample_checkpoint("serde:test", 42);
        let json = serde_json::to_string(&ckpt).expect("serialize");
        let restored: Checkpoint<String> = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(ckpt.id, restored.id);
        assert_eq!(ckpt.state, restored.state);
        assert_eq!(ckpt.entries_processed, restored.entries_processed);
        assert_eq!(ckpt.fold_version, restored.fold_version);
        assert_eq!(ckpt.uuid, restored.uuid);
        assert_eq!(ckpt.hash.as_bytes(), restored.hash.as_bytes());
    }

    #[test]
    fn delete_existing_succeeds() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        store.save(sample_checkpoint("del:ckpt-1", 1)).unwrap();
        store.delete("del:ckpt-1").unwrap();
        assert!(store.load("del:ckpt-1").unwrap().is_none());
    }

    #[test]
    fn delete_nonexistent_returns_not_found() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        let err = store.delete("nope").unwrap_err();
        assert!(
            matches!(err, FoldError::CheckpointNotFound(ref id) if id == "nope"),
            "expected CheckpointNotFound, got {err:?}"
        );
    }

    #[test]
    fn list_returns_all_ids() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        store.save(sample_checkpoint("a:ckpt-1", 1)).unwrap();
        store.save(sample_checkpoint("b:ckpt-1", 2)).unwrap();
        store.save(sample_checkpoint("c:ckpt-1", 3)).unwrap();
        // `list` order is unspecified, so sort before comparing.
        let mut ids = store.list().unwrap();
        ids.sort();
        assert_eq!(ids, vec!["a:ckpt-1", "b:ckpt-1", "c:ckpt-1"]);
    }

    #[test]
    fn list_empty_store() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        assert!(store.list().unwrap().is_empty());
    }

    #[test]
    fn save_overwrite_replaces_previous() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        let ckpt1 = sample_checkpoint("overwrite:ckpt-1", 10);
        store.save(ckpt1).unwrap();

        let ckpt2 = Checkpoint::new(
            "overwrite:ckpt-1",
            "new-state".to_string(),
            Uuid::new_v4(),
            99,
            FoldContext::new(),
            2,
        )
        .unwrap();
        store.save(ckpt2).unwrap();

        let loaded = store.load("overwrite:ckpt-1").unwrap().unwrap();
        assert_eq!(loaded.state, "new-state");
        assert_eq!(loaded.entries_processed, 99);
        let ids = store.list().unwrap();
        assert_eq!(ids.iter().filter(|id| *id == "overwrite:ckpt-1").count(), 1);
    }

    #[test]
    fn integrity_mismatch_on_corrupted_hash() {
        let store: InMemoryCheckpointStore<String> = InMemoryCheckpointStore::new();
        let ckpt = sample_checkpoint("integrity:ckpt-1", 5);
        store.save(ckpt).unwrap();

        // Corrupt the stored hash directly, bypassing `save`'s recomputation.
        {
            let mut guard = store.inner.write().unwrap();
            if let Some(c) = guard.get_mut("integrity:ckpt-1") {
                c.hash = Hash32::ZERO;
            }
        }

        let err = store.load("integrity:ckpt-1").unwrap_err();
        assert!(
            matches!(err, FoldError::IntegrityMismatch { .. }),
            "expected IntegrityMismatch, got {err:?}"
        );
    }

    #[test]
    fn concurrent_saves_all_land() {
        use std::thread;

        // `Arc` is already in scope via `use super::*`.
        let store = Arc::new(InMemoryCheckpointStore::<String>::new());
        let n = 20usize;
        let handles: Vec<_> = (0..n)
            .map(|i| {
                let s = Arc::clone(&store);
                thread::spawn(move || {
                    s.save(sample_checkpoint(&format!("concurrent:ckpt-{i}"), i))
                        .unwrap();
                })
            })
            .collect();
        for h in handles {
            h.join().expect("thread panicked");
        }
        let ids = store.list().unwrap();
        assert_eq!(ids.len(), n, "expected {n} checkpoints, got {}", ids.len());
    }
}