1use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5
6use chrono::{DateTime, Utc};
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Serialize};
9use uuid::Uuid;
10
11use khive_types::Hash32;
12
13use crate::context::FoldContext;
14use crate::error::FoldError;
15
16#[derive(Debug, Clone)]
18#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
19pub struct Checkpoint<S> {
20 pub id: String,
22
23 pub state: S,
25
26 pub uuid: Uuid,
28
29 pub hash: Hash32,
31
32 pub entries_processed: usize,
34
35 pub context: FoldContext,
37
38 pub fold_version: usize,
40
41 pub created_at: DateTime<Utc>,
43}
44
45impl<S: Serialize> Checkpoint<S> {
46 #[allow(clippy::too_many_arguments)]
51 pub fn new(
52 id: impl Into<String>,
53 state: S,
54 uuid: Uuid,
55 entries_processed: usize,
56 context: FoldContext,
57 fold_version: usize,
58 ) -> Result<Self, FoldError> {
59 let bytes = serde_json::to_vec(&state)?;
60 let hash = Hash32::from_blake3(&bytes);
61 Ok(Self {
62 id: id.into(),
63 state,
64 uuid,
65 hash,
66 entries_processed,
67 context,
68 fold_version,
69 created_at: DateTime::<Utc>::default(),
72 })
73 }
74
75 #[allow(clippy::too_many_arguments)]
79 pub fn with_hash(
80 id: impl Into<String>,
81 state: S,
82 uuid: Uuid,
83 hash: Hash32,
84 entries_processed: usize,
85 context: FoldContext,
86 fold_version: usize,
87 ) -> Self {
88 Self {
89 id: id.into(),
90 state,
91 uuid,
92 hash,
93 entries_processed,
94 context,
95 fold_version,
96 created_at: DateTime::<Utc>::default(),
98 }
99 }
100}
101
102pub trait CheckpointStore<S> {
104 fn save(&self, checkpoint: Checkpoint<S>) -> Result<(), FoldError>
106 where
107 S: Clone + Serialize;
108
109 fn load(&self, id: &str) -> Result<Option<Checkpoint<S>>, FoldError>
111 where
112 S: Clone + Serialize;
113
114 fn load_latest(&self, prefix: &str) -> Result<Option<Checkpoint<S>>, FoldError>
116 where
117 S: Clone + Serialize;
118
119 fn delete(&self, id: &str) -> Result<(), FoldError>;
121
122 fn list(&self) -> Result<Vec<String>, FoldError>;
124}
125
126pub struct InMemoryCheckpointStore<S> {
128 inner: Arc<RwLock<HashMap<String, Checkpoint<S>>>>,
129}
130
131impl<S> InMemoryCheckpointStore<S> {
132 pub fn new() -> Self {
134 Self {
135 inner: Arc::new(RwLock::new(HashMap::new())),
136 }
137 }
138}
139
140impl<S> Default for InMemoryCheckpointStore<S> {
141 fn default() -> Self {
142 Self::new()
143 }
144}
145
146impl<S: Clone + Send + Sync + Serialize + 'static> CheckpointStore<S>
147 for InMemoryCheckpointStore<S>
148{
149 fn save(&self, checkpoint: Checkpoint<S>) -> Result<(), FoldError>
150 where
151 S: Clone + Serialize,
152 {
153 let bytes = serde_json::to_vec(&checkpoint.state)?;
155 let computed = Hash32::from_blake3(&bytes);
156 let mut stored = checkpoint;
157 stored.hash = computed;
158
159 let mut guard = self
160 .inner
161 .write()
162 .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
163 guard.insert(stored.id.clone(), stored);
164 Ok(())
165 }
166
167 fn load(&self, id: &str) -> Result<Option<Checkpoint<S>>, FoldError>
168 where
169 S: Clone + Serialize,
170 {
171 let guard = self
172 .inner
173 .read()
174 .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
175 let Some(checkpoint) = guard.get(id).cloned() else {
176 return Ok(None);
177 };
178
179 let bytes = serde_json::to_vec(&checkpoint.state)?;
181 let computed = Hash32::from_blake3(&bytes);
182 if !checkpoint.hash.eq_ct(&computed) {
183 return Err(FoldError::IntegrityMismatch {
184 id: id.to_owned(),
185 stored: checkpoint.hash.to_string(),
186 computed: computed.to_string(),
187 });
188 }
189
190 Ok(Some(checkpoint))
191 }
192
193 fn load_latest(&self, prefix: &str) -> Result<Option<Checkpoint<S>>, FoldError>
194 where
195 S: Clone + Serialize,
196 {
197 let guard = self
198 .inner
199 .read()
200 .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
201
202 let latest = guard
203 .values()
204 .filter(|c| c.id.starts_with(prefix))
205 .max_by_key(|c| (c.created_at, c.uuid));
207
208 Ok(latest.cloned())
209 }
210
211 fn delete(&self, id: &str) -> Result<(), FoldError> {
212 let mut guard = self
213 .inner
214 .write()
215 .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
216 if guard.remove(id).is_none() {
217 return Err(FoldError::CheckpointNotFound(id.to_owned()));
218 }
219 Ok(())
220 }
221
222 fn list(&self) -> Result<Vec<String>, FoldError> {
223 let guard = self
224 .inner
225 .read()
226 .map_err(|e| FoldError::LockPoisoned(e.to_string()))?;
227 Ok(guard.keys().cloned().collect())
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an empty store of `String` states.
    fn empty_store() -> InMemoryCheckpointStore<String> {
        InMemoryCheckpointStore::new()
    }

    /// Builds a checkpoint with state `"state-{entries}"`, a random UUID, a
    /// fresh context and fold version 1.
    fn sample_checkpoint(id: &str, entries: usize) -> Checkpoint<String> {
        let state = format!("state-{entries}");
        Checkpoint::new(id, state, Uuid::new_v4(), entries, FoldContext::new(), 1)
            .expect("sample_checkpoint should not fail serialization")
    }

    /// A saved checkpoint comes back with the same state and counter.
    #[test]
    fn save_and_load_roundtrip() {
        let store = empty_store();
        store
            .save(sample_checkpoint("my-index:ckpt-1", 100))
            .unwrap();

        let restored = store.load("my-index:ckpt-1").unwrap().unwrap();
        assert_eq!(restored.state, "state-100");
        assert_eq!(restored.entries_processed, 100);
    }

    /// Loading an unknown id is `Ok(None)`, not an error.
    #[test]
    fn load_missing_returns_none() {
        let missing = empty_store().load("nonexistent").unwrap();
        assert!(missing.is_none());
    }

    /// With distinct `created_at` values the newest checkpoint wins.
    #[test]
    fn load_latest_returns_most_recent() {
        use chrono::Duration;

        let store = empty_store();
        let base = DateTime::<Utc>::default();

        // Timestamps are assigned explicitly because the constructor gives
        // every checkpoint the same default `created_at`.
        let fixtures = [
            ("idx:ckpt-1", 10, 0),
            ("idx:ckpt-2", 20, 5),
            ("idx:ckpt-3", 30, 10),
        ];
        for (id, entries, offset_ms) in fixtures {
            let mut ckpt = sample_checkpoint(id, entries);
            ckpt.created_at = base + Duration::milliseconds(offset_ms);
            store.save(ckpt).unwrap();
        }

        let newest = store.load_latest("idx").unwrap().unwrap();
        assert_eq!(newest.entries_processed, 30);
    }

    /// A prefix that matches no id yields `Ok(None)`.
    #[test]
    fn load_latest_no_match_returns_none() {
        let store = empty_store();
        store.save(sample_checkpoint("other:ckpt-1", 5)).unwrap();

        let found = store.load_latest("my-index").unwrap();
        assert!(found.is_none());
    }

    /// Checkpoints under a different prefix are never selected.
    #[test]
    fn load_latest_prefix_isolation() {
        let store = empty_store();
        for (id, entries) in [("alpha:ckpt-1", 10), ("beta:ckpt-1", 999)] {
            store.save(sample_checkpoint(id, entries)).unwrap();
        }

        let newest_alpha = store.load_latest("alpha").unwrap().unwrap();
        assert_eq!(newest_alpha.entries_processed, 10);
    }

    /// The constructor stores its arguments in the matching public fields.
    #[test]
    fn checkpoint_fields_accessible() {
        let built: Checkpoint<u32> =
            Checkpoint::new("test:ckpt", 42u32, Uuid::new_v4(), 7, FoldContext::new(), 3).unwrap();

        assert_eq!(built.state, 42);
        assert_eq!(built.entries_processed, 7);
        assert_eq!(built.fold_version, 3);
    }

    /// A checkpoint survives a JSON round trip (needs the derived serde impls).
    #[cfg(feature = "serde")]
    #[test]
    fn serde_roundtrip() {
        let original = sample_checkpoint("serde:test", 42);

        let json = serde_json::to_string(&original).expect("serialize");
        let decoded: Checkpoint<String> = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(original.id, decoded.id);
        assert_eq!(original.state, decoded.state);
        assert_eq!(original.entries_processed, decoded.entries_processed);
        assert_eq!(original.fold_version, decoded.fold_version);
        assert_eq!(original.uuid, decoded.uuid);
        assert_eq!(original.hash.as_bytes(), decoded.hash.as_bytes());
    }

    /// After a delete the id can no longer be loaded.
    #[test]
    fn delete_existing_succeeds() {
        let store = empty_store();
        store.save(sample_checkpoint("del:ckpt-1", 1)).unwrap();

        store.delete("del:ckpt-1").unwrap();

        let after = store.load("del:ckpt-1").unwrap();
        assert!(after.is_none());
    }

    /// Deleting an unknown id reports `CheckpointNotFound` with that id.
    #[test]
    fn delete_nonexistent_returns_not_found() {
        let err = empty_store().delete("nope").unwrap_err();
        assert!(
            matches!(err, FoldError::CheckpointNotFound(ref id) if id == "nope"),
            "expected CheckpointNotFound, got {err:?}"
        );
    }

    /// `list` reports every saved id (order is not guaranteed, hence the sort).
    #[test]
    fn list_returns_all_ids() {
        let store = empty_store();
        for (id, entries) in [("a:ckpt-1", 1), ("b:ckpt-1", 2), ("c:ckpt-1", 3)] {
            store.save(sample_checkpoint(id, entries)).unwrap();
        }

        let mut ids = store.list().unwrap();
        ids.sort();
        assert_eq!(ids, ["a:ckpt-1", "b:ckpt-1", "c:ckpt-1"]);
    }

    /// A new store lists no ids.
    #[test]
    fn list_empty_store() {
        let ids = empty_store().list().unwrap();
        assert!(ids.is_empty());
    }

    /// Saving twice under one id keeps only the second checkpoint.
    #[test]
    fn save_overwrite_replaces_previous() {
        let store = empty_store();
        store
            .save(sample_checkpoint("overwrite:ckpt-1", 10))
            .unwrap();

        let replacement = Checkpoint::new(
            "overwrite:ckpt-1",
            "new-state".to_string(),
            Uuid::new_v4(),
            99,
            FoldContext::new(),
            2,
        )
        .unwrap();
        store.save(replacement).unwrap();

        let current = store.load("overwrite:ckpt-1").unwrap().unwrap();
        assert_eq!(current.state, "new-state");
        assert_eq!(current.entries_processed, 99);

        // The id must appear exactly once in the listing.
        let occurrences = store
            .list()
            .unwrap()
            .iter()
            .filter(|id| id.as_str() == "overwrite:ckpt-1")
            .count();
        assert_eq!(occurrences, 1);
    }

    /// Tampering with the stored hash makes `load` fail the integrity check.
    #[test]
    fn integrity_mismatch_on_corrupted_hash() {
        let store = empty_store();
        store.save(sample_checkpoint("integrity:ckpt-1", 5)).unwrap();

        {
            // Corrupt the entry in place; the scope releases the write lock
            // before `load` needs the read lock.
            let mut map = store.inner.write().unwrap();
            if let Some(entry) = map.get_mut("integrity:ckpt-1") {
                entry.hash = Hash32::ZERO;
            }
        }

        let err = store.load("integrity:ckpt-1").unwrap_err();
        assert!(
            matches!(err, FoldError::IntegrityMismatch { .. }),
            "expected IntegrityMismatch, got {err:?}"
        );
    }

    /// Saves issued from many threads are all retained.
    #[test]
    fn concurrent_saves_all_land() {
        use std::sync::Arc;
        use std::thread;

        let store = Arc::new(InMemoryCheckpointStore::<String>::new());
        let n = 20usize;

        let mut workers = Vec::with_capacity(n);
        for i in 0..n {
            let shared = Arc::clone(&store);
            workers.push(thread::spawn(move || {
                let id = format!("concurrent:ckpt-{i}");
                shared.save(sample_checkpoint(&id, i)).unwrap();
            }));
        }
        for worker in workers {
            worker.join().expect("thread panicked");
        }

        let ids = store.list().unwrap();
        assert_eq!(ids.len(), n, "expected {n} checkpoints, got {}", ids.len());
    }
}