1use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5
6use chrono::{DateTime, Utc};
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Serialize};
9use uuid::Uuid;
10
11use khive_types::Hash32;
12
13use crate::context::FoldContext;
14use crate::error::FoldError;
15
/// A snapshot of fold state `S` together with the metadata needed to identify,
/// order, and integrity-check it.
///
/// With the `serde` feature enabled the whole checkpoint derives
/// `Serialize`/`Deserialize`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Checkpoint<S> {
    /// Identifier of the checkpoint. `InMemoryCheckpointStore` uses it as the
    /// map key, and `load_latest` matches it against a string prefix.
    pub id: String,

    /// The captured state. Its JSON serialization is what `hash` is computed over.
    pub state: S,

    /// Caller-supplied UUID. `InMemoryCheckpointStore::load_latest` uses it as the
    /// tie-breaker when two checkpoints share the same `created_at`.
    pub uuid: Uuid,

    /// Hash of the serialized `state`. `Checkpoint::new` derives it with
    /// `Hash32::from_blake3`; `Checkpoint::with_hash` stores the caller's value as-is.
    pub hash: Hash32,

    /// Count supplied by the caller; presumably the number of log entries folded
    /// into `state` so far — TODO confirm against the fold driver.
    pub entries_processed: usize,

    /// Fold context captured alongside the state (semantics are defined by
    /// `crate::context::FoldContext`, not visible here).
    pub context: FoldContext,

    /// Caller-supplied version number; presumably the version of the fold logic
    /// that produced `state` — TODO confirm.
    pub fold_version: usize,

    /// Timestamp used by `load_latest` to order checkpoints. Both constructors set
    /// it to `DateTime::<Utc>::default()`, not the current time, so callers that
    /// need a meaningful ordering must assign it themselves.
    pub created_at: DateTime<Utc>,
}
44
45impl<S: Serialize> Checkpoint<S> {
46 #[allow(clippy::too_many_arguments)]
51 pub fn new(
52 id: impl Into<String>,
53 state: S,
54 uuid: Uuid,
55 entries_processed: usize,
56 context: FoldContext,
57 fold_version: usize,
58 ) -> Result<Self, FoldError> {
59 let bytes = serde_json::to_vec(&state)?;
60 let hash = Hash32::from_blake3(&bytes);
61 Ok(Self {
62 id: id.into(),
63 state,
64 uuid,
65 hash,
66 entries_processed,
67 context,
68 fold_version,
69 created_at: DateTime::<Utc>::default(),
72 })
73 }
74
75 #[allow(clippy::too_many_arguments)]
79 pub fn with_hash(
80 id: impl Into<String>,
81 state: S,
82 uuid: Uuid,
83 hash: Hash32,
84 entries_processed: usize,
85 context: FoldContext,
86 fold_version: usize,
87 ) -> Self {
88 Self {
89 id: id.into(),
90 state,
91 uuid,
92 hash,
93 entries_processed,
94 context,
95 fold_version,
96 created_at: DateTime::<Utc>::default(),
98 }
99 }
100}
101
/// Storage interface for [`Checkpoint`]s holding state of type `S`.
///
/// Every method takes `&self`, so implementations that mutate need interior
/// mutability (the in-memory store in this file uses an `RwLock`).
pub trait CheckpointStore<S> {
    /// Stores `checkpoint`. The key is presumably `checkpoint.id`, as in the
    /// in-memory implementation — implementors should confirm.
    ///
    /// # Errors
    ///
    /// Returns a [`FoldError`] when the checkpoint cannot be stored.
    fn save(&self, checkpoint: Checkpoint<S>) -> Result<(), FoldError>
    where
        S: Clone + Serialize;

    /// Looks up the checkpoint stored under `id`.
    ///
    /// The `Option` in the return type lets an implementation report a missing
    /// id as `Ok(None)` rather than as an error; the in-memory store does so.
    fn load(&self, id: &str) -> Result<Option<Checkpoint<S>>, FoldError>
    where
        S: Clone + Serialize;

    /// Looks up the most recent checkpoint among those selected by `prefix`.
    ///
    /// The in-memory store selects ids that start with `prefix` and ranks them
    /// by `(created_at, uuid)`; other implementations may define "latest"
    /// differently — confirm per implementation.
    fn load_latest(&self, prefix: &str) -> Result<Option<Checkpoint<S>>, FoldError>
    where
        S: Clone + Serialize;

    /// Removes the checkpoint stored under `id`.
    ///
    /// The in-memory store returns `FoldError::CheckpointNotFound` when `id` is
    /// absent.
    fn delete(&self, id: &str) -> Result<(), FoldError>;

    /// Returns the ids of all stored checkpoints.
    fn list(&self) -> Result<Vec<String>, FoldError>;
}
125
/// A [`CheckpointStore`] that keeps every checkpoint in a process-local map.
///
/// Nothing is persisted; the contents live only as long as the store.
pub struct InMemoryCheckpointStore<S> {
    /// Checkpoints keyed by `Checkpoint::id`, behind an `RwLock` so the `&self`
    /// trait methods can read and write the map.
    inner: Arc<RwLock<HashMap<String, Checkpoint<S>>>>,
}
130
131impl<S> InMemoryCheckpointStore<S> {
132 pub fn new() -> Self {
134 Self {
135 inner: Arc::new(RwLock::new(HashMap::new())),
136 }
137 }
138}
139
140impl<S> Default for InMemoryCheckpointStore<S> {
141 fn default() -> Self {
142 Self::new()
143 }
144}
145
146impl<S: Clone + Send + Sync + Serialize + 'static> CheckpointStore<S>
147 for InMemoryCheckpointStore<S>
148{
149 fn save(&self, checkpoint: Checkpoint<S>) -> Result<(), FoldError>
150 where
151 S: Clone + Serialize,
152 {
153 let bytes = serde_json::to_vec(&checkpoint.state)?;
155 let computed = Hash32::from_blake3(&bytes);
156 let mut stored = checkpoint;
157 stored.hash = computed;
158
159 let mut guard = self
160 .inner
161 .write()
162 .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
163 guard.insert(stored.id.clone(), stored);
164 Ok(())
165 }
166
167 fn load(&self, id: &str) -> Result<Option<Checkpoint<S>>, FoldError>
168 where
169 S: Clone + Serialize,
170 {
171 let guard = self
172 .inner
173 .read()
174 .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
175 let Some(checkpoint) = guard.get(id).cloned() else {
176 return Ok(None);
177 };
178
179 let bytes = serde_json::to_vec(&checkpoint.state)?;
181 let computed = Hash32::from_blake3(&bytes);
182 if !checkpoint.hash.eq_ct(&computed) {
183 return Err(FoldError::IntegrityMismatch {
184 id: id.to_owned(),
185 stored: checkpoint.hash.to_string(),
186 computed: computed.to_string(),
187 });
188 }
189
190 Ok(Some(checkpoint))
191 }
192
193 fn load_latest(&self, prefix: &str) -> Result<Option<Checkpoint<S>>, FoldError>
194 where
195 S: Clone + Serialize,
196 {
197 let guard = self
198 .inner
199 .read()
200 .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
201
202 let latest = guard
203 .values()
204 .filter(|c| c.id.starts_with(prefix))
205 .max_by_key(|c| (c.created_at, c.uuid));
207
208 Ok(latest.cloned())
209 }
210
211 fn delete(&self, id: &str) -> Result<(), FoldError> {
212 let mut guard = self
213 .inner
214 .write()
215 .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
216 if guard.remove(id).is_none() {
217 return Err(FoldError::CheckpointNotFound(id.to_owned()));
218 }
219 Ok(())
220 }
221
222 fn list(&self) -> Result<Vec<String>, FoldError> {
223 let guard = self
224 .inner
225 .read()
226 .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
227 let keys: Vec<String> = guard.keys().cloned().collect();
228 Ok(sort_checkpoint_keys(keys))
229 }
230}
231
/// Returns `keys` in ascending lexicographic order, as defined by `String`'s `Ord`.
///
/// The vector is sorted in place and handed back, so no new vector is allocated.
/// An unstable sort is sufficient because equal strings are indistinguishable.
pub fn sort_checkpoint_keys(mut keys: Vec<String>) -> Vec<String> {
    keys.sort_unstable();
    keys
}
241
#[cfg(test)]
mod tests {
    use super::*;

    /// Store type exercised by every test in this module.
    type StringStore = InMemoryCheckpointStore<String>;

    /// Builds a `Checkpoint<String>` with state `"state-{entries}"`, a fresh
    /// random UUID, a new `FoldContext`, and fold version 1.
    fn sample_checkpoint(id: &str, entries: usize) -> Checkpoint<String> {
        let state = format!("state-{entries}");
        Checkpoint::new(id, state, Uuid::new_v4(), entries, FoldContext::new(), 1)
            .expect("sample_checkpoint should not fail serialization")
    }

    /// A saved checkpoint comes back with the same state and counters.
    #[test]
    fn save_and_load_roundtrip() {
        let store = StringStore::new();
        store
            .save(sample_checkpoint("my-index:ckpt-1", 100))
            .unwrap();

        let loaded = store.load("my-index:ckpt-1").unwrap().unwrap();
        assert_eq!(loaded.state, "state-100");
        assert_eq!(loaded.entries_processed, 100);
    }

    /// Loading an unknown id is `Ok(None)`, not an error.
    #[test]
    fn load_missing_returns_none() {
        let store = StringStore::new();
        let missing = store.load("nonexistent").unwrap();
        assert!(missing.is_none());
    }

    /// With explicit, strictly increasing timestamps, the last one wins.
    #[test]
    fn load_latest_returns_most_recent() {
        use chrono::Duration;

        let store = StringStore::new();
        let base = DateTime::<Utc>::default();

        // Timestamps are assigned explicitly because the constructors do not
        // record wall-clock time.
        let fixtures = [
            ("idx:ckpt-1", 10, 0),
            ("idx:ckpt-2", 20, 5),
            ("idx:ckpt-3", 30, 10),
        ];
        for (id, entries, offset_ms) in fixtures {
            let mut ckpt = sample_checkpoint(id, entries);
            ckpt.created_at = base + Duration::milliseconds(offset_ms);
            store.save(ckpt).unwrap();
        }

        let latest = store.load_latest("idx").unwrap().unwrap();
        assert_eq!(latest.entries_processed, 30);
    }

    /// A prefix that matches no id yields `Ok(None)`.
    #[test]
    fn load_latest_no_match_returns_none() {
        let store = StringStore::new();
        store.save(sample_checkpoint("other:ckpt-1", 5)).unwrap();

        let found = store.load_latest("my-index").unwrap();
        assert!(found.is_none());
    }

    /// Checkpoints under a different prefix are never returned.
    #[test]
    fn load_latest_prefix_isolation() {
        let store = StringStore::new();
        store.save(sample_checkpoint("alpha:ckpt-1", 10)).unwrap();
        store.save(sample_checkpoint("beta:ckpt-1", 999)).unwrap();

        let latest_alpha = store.load_latest("alpha").unwrap().unwrap();
        assert_eq!(latest_alpha.entries_processed, 10);
    }

    /// The constructor stores its arguments in the public fields.
    #[test]
    fn checkpoint_fields_accessible() {
        let ckpt: Checkpoint<u32> =
            Checkpoint::new("test:ckpt", 42u32, Uuid::new_v4(), 7, FoldContext::new(), 3).unwrap();
        assert_eq!(ckpt.state, 42);
        assert_eq!(ckpt.entries_processed, 7);
        assert_eq!(ckpt.fold_version, 3);
    }

    /// JSON serialization followed by deserialization preserves the fields.
    #[cfg(feature = "serde")]
    #[test]
    fn serde_roundtrip() {
        let original = sample_checkpoint("serde:test", 42);
        let json = serde_json::to_string(&original).expect("serialize");
        let restored: Checkpoint<String> = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(original.id, restored.id);
        assert_eq!(original.state, restored.state);
        assert_eq!(original.entries_processed, restored.entries_processed);
        assert_eq!(original.fold_version, restored.fold_version);
        assert_eq!(original.uuid, restored.uuid);
        assert_eq!(original.hash.as_bytes(), restored.hash.as_bytes());
    }

    /// After a successful delete the id can no longer be loaded.
    #[test]
    fn delete_existing_succeeds() {
        let store = StringStore::new();
        store.save(sample_checkpoint("del:ckpt-1", 1)).unwrap();
        store.delete("del:ckpt-1").unwrap();

        let after = store.load("del:ckpt-1").unwrap();
        assert!(after.is_none());
    }

    /// Deleting an unknown id reports `CheckpointNotFound` carrying that id.
    #[test]
    fn delete_nonexistent_returns_not_found() {
        let store = StringStore::new();
        let err = store.delete("nope").unwrap_err();
        assert!(
            matches!(err, FoldError::CheckpointNotFound(ref id) if id == "nope"),
            "expected CheckpointNotFound, got {err:?}"
        );
    }

    /// `list` reports every saved id.
    #[test]
    fn list_returns_all_ids() {
        let store = StringStore::new();
        for (id, entries) in [("a:ckpt-1", 1), ("b:ckpt-1", 2), ("c:ckpt-1", 3)] {
            store.save(sample_checkpoint(id, entries)).unwrap();
        }

        let mut ids = store.list().unwrap();
        ids.sort();
        assert_eq!(ids, vec!["a:ckpt-1", "b:ckpt-1", "c:ckpt-1"]);
    }

    /// `list` on a fresh store is empty.
    #[test]
    fn list_empty_store() {
        let store = StringStore::new();
        let ids = store.list().unwrap();
        assert!(ids.is_empty());
    }

    /// Saving twice under one id keeps only the second checkpoint.
    #[test]
    fn save_overwrite_replaces_previous() {
        let store = StringStore::new();
        store
            .save(sample_checkpoint("overwrite:ckpt-1", 10))
            .unwrap();

        let replacement = Checkpoint::new(
            "overwrite:ckpt-1",
            "new-state".to_string(),
            Uuid::new_v4(),
            99,
            FoldContext::new(),
            2,
        )
        .unwrap();
        store.save(replacement).unwrap();

        let loaded = store.load("overwrite:ckpt-1").unwrap().unwrap();
        assert_eq!(loaded.state, "new-state");
        assert_eq!(loaded.entries_processed, 99);

        let ids = store.list().unwrap();
        assert_eq!(ids.iter().filter(|id| *id == "overwrite:ckpt-1").count(), 1);
    }

    /// Tampering with the stored hash makes `load` fail with `IntegrityMismatch`.
    #[test]
    fn integrity_mismatch_on_corrupted_hash() {
        let store = StringStore::new();
        store
            .save(sample_checkpoint("integrity:ckpt-1", 5))
            .unwrap();

        // Scope the write guard so it is released before `load` takes the read lock.
        {
            let mut entries = store.inner.write().unwrap();
            if let Some(entry) = entries.get_mut("integrity:ckpt-1") {
                entry.hash = Hash32::ZERO;
            }
        }

        let err = store.load("integrity:ckpt-1").unwrap_err();
        assert!(
            matches!(err, FoldError::IntegrityMismatch { .. }),
            "expected IntegrityMismatch, got {err:?}"
        );
    }

    /// Saves issued from many threads at once are all retained.
    #[test]
    fn concurrent_saves_all_land() {
        use std::sync::Arc;
        use std::thread;

        let store = Arc::new(InMemoryCheckpointStore::<String>::new());
        let n = 20usize;

        let mut workers = Vec::with_capacity(n);
        for i in 0..n {
            let shared = Arc::clone(&store);
            workers.push(thread::spawn(move || {
                shared
                    .save(sample_checkpoint(&format!("concurrent:ckpt-{i}"), i))
                    .unwrap();
            }));
        }
        for worker in workers {
            worker.join().expect("thread panicked");
        }

        let ids = store.list().unwrap();
        assert_eq!(ids.len(), n, "expected {n} checkpoints, got {}", ids.len());
    }

    /// The sorting helper orders keys lexicographically.
    #[test]
    fn sort_checkpoint_keys_produces_lexicographic_order() {
        let unsorted: Vec<String> = ["z:ckpt-3", "m:ckpt-2", "a:ckpt-1"]
            .into_iter()
            .map(String::from)
            .collect();
        let sorted = sort_checkpoint_keys(unsorted);
        assert_eq!(
            sorted,
            vec!["a:ckpt-1", "m:ckpt-2", "z:ckpt-3"],
            "sort_checkpoint_keys must produce lexicographic order; got {sorted:?}"
        );
    }

    /// `list` returns ids already sorted, whatever the insertion order.
    #[test]
    fn list_is_sorted() {
        let store = StringStore::new();
        for (id, entries) in [("z:ckpt-1", 1), ("a:ckpt-1", 2), ("m:ckpt-1", 3)] {
            store.save(sample_checkpoint(id, entries)).unwrap();
        }

        let ids = store.list().unwrap();
        assert_eq!(
            ids,
            vec!["a:ckpt-1", "m:ckpt-1", "z:ckpt-1"],
            "list() must return sorted keys; got {ids:?}"
        );
    }
}