1use std::collections::HashMap;
34use std::sync::Arc;
35use std::time::{Duration, SystemTime, UNIX_EPOCH};
36
37use async_trait::async_trait;
38use parking_lot::RwLock;
39use serde::{de::DeserializeOwned, Deserialize, Serialize};
40use tracing::{debug, info};
41
/// Identifier of a single checkpoint (a UUIDv4 string when produced by
/// `CheckpointMetadata::new`).
pub type CheckpointId = String;

/// Identifier of a thread: the key under which an ordered sequence of
/// checkpoints is stored.
pub type ThreadId = String;
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct CheckpointMetadata {
51 pub id: CheckpointId,
53 pub thread_id: ThreadId,
55 pub created_at: u64,
57 pub parent_id: Option<CheckpointId>,
59 pub step: u64,
61 pub label: Option<String>,
63 pub tags: Vec<String>,
65 pub state_size: usize,
67 pub custom: HashMap<String, String>,
69}
70
71impl CheckpointMetadata {
72 pub fn new(thread_id: impl Into<String>, step: u64) -> Self {
73 let now = SystemTime::now()
74 .duration_since(UNIX_EPOCH)
75 .unwrap_or_default()
76 .as_secs();
77
78 Self {
79 id: uuid::Uuid::new_v4().to_string(),
80 thread_id: thread_id.into(),
81 created_at: now,
82 parent_id: None,
83 step,
84 label: None,
85 tags: Vec::new(),
86 state_size: 0,
87 custom: HashMap::new(),
88 }
89 }
90
91 pub fn with_parent(mut self, parent_id: impl Into<String>) -> Self {
92 self.parent_id = Some(parent_id.into());
93 self
94 }
95
96 pub fn with_label(mut self, label: impl Into<String>) -> Self {
97 self.label = Some(label.into());
98 self
99 }
100
101 pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
102 self.tags.push(tag.into());
103 self
104 }
105
106 pub fn with_custom(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
107 self.custom.insert(key.into(), value.into());
108 self
109 }
110}
111
112#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct Checkpoint<S> {
115 pub metadata: CheckpointMetadata,
117 pub state: S,
119}
120
121impl<S> Checkpoint<S> {
122 pub fn new(thread_id: impl Into<String>, step: u64, state: S) -> Self
123 where
124 S: Serialize,
125 {
126 let state_size = serde_json::to_vec(&state).map(|v| v.len()).unwrap_or(0);
127 Self {
128 metadata: CheckpointMetadata {
129 state_size,
130 ..CheckpointMetadata::new(thread_id, step)
131 },
132 state,
133 }
134 }
135
136 pub fn id(&self) -> &str {
137 &self.metadata.id
138 }
139
140 pub fn thread_id(&self) -> &str {
141 &self.metadata.thread_id
142 }
143
144 pub fn step(&self) -> u64 {
145 self.metadata.step
146 }
147}
148
149#[derive(Debug, thiserror::Error)]
151pub enum CheckpointError {
152 #[error("Checkpoint not found: {0}")]
153 NotFound(CheckpointId),
154
155 #[error("Thread not found: {0}")]
156 ThreadNotFound(ThreadId),
157
158 #[error("Serialization error: {0}")]
159 Serialization(String),
160
161 #[error("Storage error: {0}")]
162 Storage(String),
163
164 #[error("Invalid state: {0}")]
165 InvalidState(String),
166}
167
168#[async_trait]
170pub trait CheckpointStore: Send + Sync {
171 async fn save(
173 &self,
174 thread_id: &str,
175 metadata: CheckpointMetadata,
176 state: Vec<u8>,
177 ) -> Result<CheckpointId, CheckpointError>;
178
179 async fn load(
181 &self,
182 thread_id: &str,
183 checkpoint_id: &str,
184 ) -> Result<(CheckpointMetadata, Vec<u8>), CheckpointError>;
185
186 async fn load_latest(
188 &self,
189 thread_id: &str,
190 ) -> Result<(CheckpointMetadata, Vec<u8>), CheckpointError>;
191
192 async fn list(&self, thread_id: &str) -> Result<Vec<CheckpointMetadata>, CheckpointError>;
194
195 async fn delete(&self, thread_id: &str, checkpoint_id: &str) -> Result<(), CheckpointError>;
197
198 async fn delete_thread(&self, thread_id: &str) -> Result<(), CheckpointError>;
200
201 async fn count(&self, thread_id: &str) -> Result<usize, CheckpointError>;
203
204 async fn list_threads(&self) -> Result<Vec<ThreadId>, CheckpointError>;
206}
207
208#[derive(Default)]
210pub struct MemoryCheckpointStore {
211 checkpoints: RwLock<HashMap<ThreadId, Vec<(CheckpointMetadata, Vec<u8>)>>>,
212}
213
214impl MemoryCheckpointStore {
215 pub fn new() -> Self {
216 Self::default()
217 }
218}
219
220#[async_trait]
221impl CheckpointStore for MemoryCheckpointStore {
222 async fn save(
223 &self,
224 thread_id: &str,
225 metadata: CheckpointMetadata,
226 state: Vec<u8>,
227 ) -> Result<CheckpointId, CheckpointError> {
228 let id = metadata.id.clone();
229 let mut checkpoints = self.checkpoints.write();
230 checkpoints
231 .entry(thread_id.to_string())
232 .or_default()
233 .push((metadata, state));
234 Ok(id)
235 }
236
237 async fn load(
238 &self,
239 thread_id: &str,
240 checkpoint_id: &str,
241 ) -> Result<(CheckpointMetadata, Vec<u8>), CheckpointError> {
242 let checkpoints = self.checkpoints.read();
243 let thread_checkpoints = checkpoints
244 .get(thread_id)
245 .ok_or_else(|| CheckpointError::ThreadNotFound(thread_id.to_string()))?;
246
247 thread_checkpoints
248 .iter()
249 .find(|(m, _)| m.id == checkpoint_id)
250 .cloned()
251 .ok_or_else(|| CheckpointError::NotFound(checkpoint_id.to_string()))
252 }
253
254 async fn load_latest(
255 &self,
256 thread_id: &str,
257 ) -> Result<(CheckpointMetadata, Vec<u8>), CheckpointError> {
258 let checkpoints = self.checkpoints.read();
259 let thread_checkpoints = checkpoints
260 .get(thread_id)
261 .ok_or_else(|| CheckpointError::ThreadNotFound(thread_id.to_string()))?;
262
263 thread_checkpoints
264 .last()
265 .cloned()
266 .ok_or_else(|| CheckpointError::ThreadNotFound(thread_id.to_string()))
267 }
268
269 async fn list(&self, thread_id: &str) -> Result<Vec<CheckpointMetadata>, CheckpointError> {
270 let checkpoints = self.checkpoints.read();
271 Ok(checkpoints
272 .get(thread_id)
273 .map(|v| v.iter().map(|(m, _)| m.clone()).collect())
274 .unwrap_or_default())
275 }
276
277 async fn delete(&self, thread_id: &str, checkpoint_id: &str) -> Result<(), CheckpointError> {
278 let mut checkpoints = self.checkpoints.write();
279 if let Some(thread_checkpoints) = checkpoints.get_mut(thread_id) {
280 thread_checkpoints.retain(|(m, _)| m.id != checkpoint_id);
281 }
282 Ok(())
283 }
284
285 async fn delete_thread(&self, thread_id: &str) -> Result<(), CheckpointError> {
286 let mut checkpoints = self.checkpoints.write();
287 checkpoints.remove(thread_id);
288 Ok(())
289 }
290
291 async fn count(&self, thread_id: &str) -> Result<usize, CheckpointError> {
292 let checkpoints = self.checkpoints.read();
293 Ok(checkpoints.get(thread_id).map(|v| v.len()).unwrap_or(0))
294 }
295
296 async fn list_threads(&self) -> Result<Vec<ThreadId>, CheckpointError> {
297 let checkpoints = self.checkpoints.read();
298 Ok(checkpoints.keys().cloned().collect())
299 }
300}
301
/// Configuration for `CheckpointManager`.
#[derive(Clone, Debug)]
pub struct CheckpointConfig {
    /// Maximum number of checkpoints kept per thread. The oldest are pruned
    /// after each save; `None` disables pruning.
    pub max_checkpoints_per_thread: Option<usize>,
    /// NOTE(review): not read anywhere in this file; presumably an interval for
    /// automatic checkpointing (unit unclear) — confirm before relying on it.
    pub auto_checkpoint_interval: Option<u64>,
    /// NOTE(review): not read anywhere in this file; states are stored as plain JSON.
    pub compress: bool,
    /// NOTE(review): not read anywhere in this file; checkpoints are never expired here.
    pub ttl: Option<Duration>,
}
314
315impl Default for CheckpointConfig {
316 fn default() -> Self {
317 Self {
318 max_checkpoints_per_thread: Some(100),
319 auto_checkpoint_interval: None,
320 compress: false,
321 ttl: None,
322 }
323 }
324}
325
326pub struct CheckpointManager<Store: CheckpointStore> {
328 store: Arc<Store>,
329 config: CheckpointConfig,
330 steps: RwLock<HashMap<ThreadId, u64>>,
332}
333
334impl<Store: CheckpointStore> CheckpointManager<Store> {
335 pub fn new(store: Store) -> Self {
336 Self {
337 store: Arc::new(store),
338 config: CheckpointConfig::default(),
339 steps: RwLock::new(HashMap::new()),
340 }
341 }
342
343 pub fn with_config(mut self, config: CheckpointConfig) -> Self {
344 self.config = config;
345 self
346 }
347
348 pub async fn save<S: Serialize + Send>(
350 &self,
351 thread_id: &str,
352 state: &S,
353 ) -> Result<CheckpointId, CheckpointError> {
354 self.save_with_label(thread_id, state, None).await
355 }
356
357 pub async fn save_with_label<S: Serialize + Send>(
359 &self,
360 thread_id: &str,
361 state: &S,
362 label: Option<String>,
363 ) -> Result<CheckpointId, CheckpointError> {
364 let state_bytes =
366 serde_json::to_vec(state).map_err(|e| CheckpointError::Serialization(e.to_string()))?;
367
368 let step = {
370 let mut steps = self.steps.write();
371 let step = steps.entry(thread_id.to_string()).or_insert(0);
372 *step += 1;
373 *step
374 };
375
376 let parent_id = self
378 .store
379 .load_latest(thread_id)
380 .await
381 .ok()
382 .map(|(m, _)| m.id);
383
384 let mut metadata = CheckpointMetadata::new(thread_id, step);
386 metadata.state_size = state_bytes.len();
387 if let Some(parent) = parent_id {
388 metadata = metadata.with_parent(parent);
389 }
390 if let Some(lbl) = label {
391 metadata = metadata.with_label(lbl);
392 }
393
394 let id = self.store.save(thread_id, metadata, state_bytes).await?;
396
397 if let Some(max) = self.config.max_checkpoints_per_thread {
399 self.cleanup_old_checkpoints(thread_id, max).await?;
400 }
401
402 debug!(thread_id, checkpoint_id = %id, step, "Checkpoint saved");
403 Ok(id)
404 }
405
406 pub async fn load<S: DeserializeOwned>(
408 &self,
409 thread_id: &str,
410 checkpoint_id: &str,
411 ) -> Result<Checkpoint<S>, CheckpointError> {
412 let (metadata, state_bytes) = self.store.load(thread_id, checkpoint_id).await?;
413 let state = serde_json::from_slice(&state_bytes)
414 .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
415
416 debug!(
417 thread_id,
418 checkpoint_id,
419 step = metadata.step,
420 "Checkpoint loaded"
421 );
422 Ok(Checkpoint { metadata, state })
423 }
424
425 pub async fn load_latest<S: DeserializeOwned>(
427 &self,
428 thread_id: &str,
429 ) -> Result<Checkpoint<S>, CheckpointError> {
430 let (metadata, state_bytes) = self.store.load_latest(thread_id).await?;
431 let state = serde_json::from_slice(&state_bytes)
432 .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
433
434 debug!(thread_id, checkpoint_id = %metadata.id, step = metadata.step, "Latest checkpoint loaded");
435 Ok(Checkpoint { metadata, state })
436 }
437
438 pub async fn history(
440 &self,
441 thread_id: &str,
442 ) -> Result<Vec<CheckpointMetadata>, CheckpointError> {
443 self.store.list(thread_id).await
444 }
445
446 pub async fn fork<S: Serialize + DeserializeOwned + Send>(
448 &self,
449 source_thread_id: &str,
450 checkpoint_id: &str,
451 new_thread_id: &str,
452 ) -> Result<CheckpointId, CheckpointError> {
453 let checkpoint: Checkpoint<S> = self.load(source_thread_id, checkpoint_id).await?;
455
456 let id = self.save(new_thread_id, &checkpoint.state).await?;
458
459 info!(
460 source_thread = source_thread_id,
461 source_checkpoint = checkpoint_id,
462 new_thread = new_thread_id,
463 new_checkpoint = %id,
464 "Thread forked"
465 );
466
467 Ok(id)
468 }
469
470 pub async fn rewind(
472 &self,
473 thread_id: &str,
474 checkpoint_id: &str,
475 ) -> Result<(), CheckpointError> {
476 let history = self.store.list(thread_id).await?;
477
478 let target_idx = history
480 .iter()
481 .position(|m| m.id == checkpoint_id)
482 .ok_or_else(|| CheckpointError::NotFound(checkpoint_id.to_string()))?;
483
484 for checkpoint in history.iter().skip(target_idx + 1) {
486 self.store.delete(thread_id, &checkpoint.id).await?;
487 }
488
489 {
491 let mut steps = self.steps.write();
492 if let Some(target) = history.get(target_idx) {
493 steps.insert(thread_id.to_string(), target.step);
494 }
495 }
496
497 info!(thread_id, checkpoint_id, "Rewound to checkpoint");
498 Ok(())
499 }
500
501 pub async fn find_by_tag(
503 &self,
504 thread_id: &str,
505 tag: &str,
506 ) -> Result<Vec<CheckpointMetadata>, CheckpointError> {
507 let history = self.store.list(thread_id).await?;
508 Ok(history
509 .into_iter()
510 .filter(|m| m.tags.contains(&tag.to_string()))
511 .collect())
512 }
513
514 pub async fn find_by_label(
516 &self,
517 thread_id: &str,
518 label: &str,
519 ) -> Result<Option<CheckpointMetadata>, CheckpointError> {
520 let history = self.store.list(thread_id).await?;
521 Ok(history
522 .into_iter()
523 .find(|m| m.label.as_deref() == Some(label)))
524 }
525
526 pub async fn delete_thread(&self, thread_id: &str) -> Result<(), CheckpointError> {
528 self.store.delete_thread(thread_id).await?;
529 self.steps.write().remove(thread_id);
530 info!(thread_id, "Thread deleted");
531 Ok(())
532 }
533
534 pub fn current_step(&self, thread_id: &str) -> u64 {
536 self.steps.read().get(thread_id).copied().unwrap_or(0)
537 }
538
539 pub async fn list_threads(&self) -> Result<Vec<ThreadId>, CheckpointError> {
541 self.store.list_threads().await
542 }
543
544 async fn cleanup_old_checkpoints(
545 &self,
546 thread_id: &str,
547 max: usize,
548 ) -> Result<(), CheckpointError> {
549 let count = self.store.count(thread_id).await?;
550 if count > max {
551 let history = self.store.list(thread_id).await?;
552 let to_delete = count - max;
553
554 for checkpoint in history.iter().take(to_delete) {
555 self.store.delete(thread_id, &checkpoint.id).await?;
556 debug!(thread_id, checkpoint_id = %checkpoint.id, "Old checkpoint deleted");
557 }
558 }
559 Ok(())
560 }
561}
562
563pub type MemoryCheckpointManager = CheckpointManager<MemoryCheckpointStore>;
565
566impl MemoryCheckpointManager {
567 pub fn in_memory() -> Self {
568 Self::new(MemoryCheckpointStore::new())
569 }
570}
571
572pub struct CheckpointBuilder<'a, S: Serialize + Send, Store: CheckpointStore> {
574 manager: &'a CheckpointManager<Store>,
575 thread_id: String,
576 state: &'a S,
577 label: Option<String>,
578 tags: Vec<String>,
579 custom: HashMap<String, String>,
580}
581
582impl<'a, S: Serialize + Send, Store: CheckpointStore> CheckpointBuilder<'a, S, Store> {
583 pub fn new(
584 manager: &'a CheckpointManager<Store>,
585 thread_id: impl Into<String>,
586 state: &'a S,
587 ) -> Self {
588 Self {
589 manager,
590 thread_id: thread_id.into(),
591 state,
592 label: None,
593 tags: Vec::new(),
594 custom: HashMap::new(),
595 }
596 }
597
598 pub fn label(mut self, label: impl Into<String>) -> Self {
599 self.label = Some(label.into());
600 self
601 }
602
603 pub fn tag(mut self, tag: impl Into<String>) -> Self {
604 self.tags.push(tag.into());
605 self
606 }
607
608 pub fn custom(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
609 self.custom.insert(key.into(), value.into());
610 self
611 }
612
613 pub async fn save(self) -> Result<CheckpointId, CheckpointError> {
614 self.manager
615 .save_with_label(&self.thread_id, self.state, self.label)
616 .await
617 }
618}
619
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestState {
        messages: Vec<String>,
        counter: u32,
    }

    /// Builds a state from borrowed message texts and a counter.
    fn make_state(messages: &[&str], counter: u32) -> TestState {
        TestState {
            messages: messages.iter().map(|m| String::from(*m)).collect(),
            counter,
        }
    }

    /// A state with no messages and a zero counter.
    fn empty_state() -> TestState {
        make_state(&[], 0)
    }

    #[tokio::test]
    async fn test_save_and_load_checkpoint() {
        let manager = MemoryCheckpointManager::in_memory();
        let original = make_state(&["hello"], 1);

        let checkpoint_id = manager.save("thread1", &original).await.unwrap();
        let restored: Checkpoint<TestState> =
            manager.load("thread1", &checkpoint_id).await.unwrap();

        assert_eq!(restored.state, original);
        assert_eq!(restored.metadata.step, 1);
    }

    #[tokio::test]
    async fn test_load_latest() {
        let manager = MemoryCheckpointManager::in_memory();
        let first = make_state(&["first"], 1);
        let second = make_state(&["second"], 2);

        for snapshot in [&first, &second].iter().copied() {
            manager.save("thread1", snapshot).await.unwrap();
        }

        let newest: Checkpoint<TestState> = manager.load_latest("thread1").await.unwrap();
        assert_eq!(newest.state, second);
        assert_eq!(newest.metadata.step, 2);
    }

    #[tokio::test]
    async fn test_checkpoint_history() {
        let manager = MemoryCheckpointManager::in_memory();
        let blank = empty_state();

        for _ in 0..3 {
            manager.save("thread1", &blank).await.unwrap();
        }

        let steps: Vec<u64> = manager
            .history("thread1")
            .await
            .unwrap()
            .iter()
            .map(|meta| meta.step)
            .collect();
        assert_eq!(steps, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_fork_thread() {
        let manager = MemoryCheckpointManager::in_memory();
        let original = make_state(&["original"], 5);

        let source_id = manager.save("thread1", &original).await.unwrap();
        manager
            .fork::<TestState>("thread1", &source_id, "thread2")
            .await
            .unwrap();

        let copy: Checkpoint<TestState> = manager.load_latest("thread2").await.unwrap();
        assert_eq!(copy.state, original);
    }

    #[tokio::test]
    async fn test_rewind() {
        let manager = MemoryCheckpointManager::in_memory();

        let mut saved_ids = Vec::with_capacity(5);
        for i in 0..5u32 {
            let snapshot = TestState {
                messages: vec![format!("msg{}", i)],
                counter: i,
            };
            saved_ids.push(manager.save("thread1", &snapshot).await.unwrap());
        }

        manager.rewind("thread1", &saved_ids[2]).await.unwrap();

        assert_eq!(manager.history("thread1").await.unwrap().len(), 3);
        let newest: Checkpoint<TestState> = manager.load_latest("thread1").await.unwrap();
        assert_eq!(newest.state.counter, 2);
    }

    #[tokio::test]
    async fn test_max_checkpoints_cleanup() {
        let manager = MemoryCheckpointManager::in_memory().with_config(CheckpointConfig {
            max_checkpoints_per_thread: Some(3),
            ..Default::default()
        });
        let blank = empty_state();

        for _ in 0..5 {
            manager.save("thread1", &blank).await.unwrap();
        }

        let history = manager.history("thread1").await.unwrap();
        assert_eq!(history.len(), 3);
        // Only the three newest checkpoints survive pruning.
        let steps: Vec<u64> = history.iter().map(|meta| meta.step).collect();
        assert_eq!(steps, vec![3, 4, 5]);
    }

    #[tokio::test]
    async fn test_delete_thread() {
        let manager = MemoryCheckpointManager::in_memory();
        let blank = empty_state();

        for _ in 0..2 {
            manager.save("thread1", &blank).await.unwrap();
        }
        manager.delete_thread("thread1").await.unwrap();

        assert!(manager.history("thread1").await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_list_threads() {
        let manager = MemoryCheckpointManager::in_memory();
        let blank = empty_state();

        for thread in &["thread1", "thread2", "thread3"] {
            manager.save(thread, &blank).await.unwrap();
        }

        assert_eq!(manager.list_threads().await.unwrap().len(), 3);
    }

    #[tokio::test]
    async fn test_current_step() {
        let manager = MemoryCheckpointManager::in_memory();
        let blank = empty_state();

        assert_eq!(manager.current_step("thread1"), 0);
        for expected in 1..=2u64 {
            manager.save("thread1", &blank).await.unwrap();
            assert_eq!(manager.current_step("thread1"), expected);
        }
    }

    #[tokio::test]
    async fn test_checkpoint_with_label() {
        let manager = MemoryCheckpointManager::in_memory();
        let blank = empty_state();

        manager
            .save_with_label("thread1", &blank, Some("important".to_string()))
            .await
            .unwrap();

        let found = manager
            .find_by_label("thread1", "important")
            .await
            .unwrap()
            .expect("labelled checkpoint should be found");
        assert_eq!(found.label.as_deref(), Some("important"));
    }

    #[tokio::test]
    async fn test_parent_chain() {
        let manager = MemoryCheckpointManager::in_memory();
        let blank = empty_state();

        let first = manager.save("thread1", &blank).await.unwrap();
        let second = manager.save("thread1", &blank).await.unwrap();
        let _third = manager.save("thread1", &blank).await.unwrap();

        // Each checkpoint's parent is the one saved immediately before it.
        let parents: Vec<Option<String>> = manager
            .history("thread1")
            .await
            .unwrap()
            .iter()
            .map(|meta| meta.parent_id.clone())
            .collect();
        assert_eq!(parents, vec![None, Some(first), Some(second)]);
    }
}