// cortexai_agents/checkpoint.rs

//! # Checkpointing
//!
//! State persistence system for agent execution with recovery and time-travel capabilities.
//!
//! Inspired by LangGraph's checkpointing pattern.
//!
//! ## Features
//!
//! - **State Snapshots**: Capture agent state at any point
//! - **Recovery**: Resume execution from any checkpoint
//! - **Time-Travel**: Navigate between checkpoints
//! - **Branching**: Create alternate execution paths
//! - **Multiple Backends**: Memory, SQLite, or custom storage
//!
//! ## Example
//!
//! ```rust,ignore
//! use cortex::checkpoint::{Checkpoint, CheckpointManager, MemoryCheckpointStore};
//!
//! let store = MemoryCheckpointStore::new();
//! let manager = CheckpointManager::new(store);
//!
//! // Save a checkpoint
//! let checkpoint_id = manager.save("thread_1", &state).await?;
//!
//! // Later, restore from checkpoint (`load` returns the metadata together with the state)
//! let checkpoint: Checkpoint<MyState> = manager.load("thread_1", &checkpoint_id).await?;
//! let state = checkpoint.state;
//!
//! // Time-travel: list all checkpoints of the thread
//! let history = manager.history("thread_1").await?;
//! ```
32
33use std::collections::HashMap;
34use std::sync::Arc;
35use std::time::{Duration, SystemTime, UNIX_EPOCH};
36
37use async_trait::async_trait;
38use parking_lot::RwLock;
39use serde::{de::DeserializeOwned, Deserialize, Serialize};
40use tracing::{debug, info};
41
/// Identifier of a single checkpoint.
///
/// A plain `String` alias; [`CheckpointMetadata::new`] fills it with a freshly generated
/// UUID v4 rendered as text.
pub type CheckpointId = String;

/// Identifier of a thread, i.e. one conversation or session whose checkpoints belong together.
pub type ThreadId = String;
47
48/// Metadata associated with a checkpoint
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct CheckpointMetadata {
51    /// Unique checkpoint ID
52    pub id: CheckpointId,
53    /// Thread this checkpoint belongs to
54    pub thread_id: ThreadId,
55    /// When the checkpoint was created
56    pub created_at: u64,
57    /// Optional parent checkpoint (for branching)
58    pub parent_id: Option<CheckpointId>,
59    /// Step number in the execution
60    pub step: u64,
61    /// Optional human-readable label
62    pub label: Option<String>,
63    /// Custom tags for filtering
64    pub tags: Vec<String>,
65    /// Size of the state in bytes
66    pub state_size: usize,
67    /// Additional custom metadata
68    pub custom: HashMap<String, String>,
69}
70
71impl CheckpointMetadata {
72    pub fn new(thread_id: impl Into<String>, step: u64) -> Self {
73        let now = SystemTime::now()
74            .duration_since(UNIX_EPOCH)
75            .unwrap_or_default()
76            .as_secs();
77
78        Self {
79            id: uuid::Uuid::new_v4().to_string(),
80            thread_id: thread_id.into(),
81            created_at: now,
82            parent_id: None,
83            step,
84            label: None,
85            tags: Vec::new(),
86            state_size: 0,
87            custom: HashMap::new(),
88        }
89    }
90
91    pub fn with_parent(mut self, parent_id: impl Into<String>) -> Self {
92        self.parent_id = Some(parent_id.into());
93        self
94    }
95
96    pub fn with_label(mut self, label: impl Into<String>) -> Self {
97        self.label = Some(label.into());
98        self
99    }
100
101    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
102        self.tags.push(tag.into());
103        self
104    }
105
106    pub fn with_custom(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
107        self.custom.insert(key.into(), value.into());
108        self
109    }
110}
111
112/// A complete checkpoint including state
113#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct Checkpoint<S> {
115    /// Checkpoint metadata
116    pub metadata: CheckpointMetadata,
117    /// The actual state
118    pub state: S,
119}
120
121impl<S> Checkpoint<S> {
122    pub fn new(thread_id: impl Into<String>, step: u64, state: S) -> Self
123    where
124        S: Serialize,
125    {
126        let state_size = serde_json::to_vec(&state).map(|v| v.len()).unwrap_or(0);
127        Self {
128            metadata: CheckpointMetadata {
129                state_size,
130                ..CheckpointMetadata::new(thread_id, step)
131            },
132            state,
133        }
134    }
135
136    pub fn id(&self) -> &str {
137        &self.metadata.id
138    }
139
140    pub fn thread_id(&self) -> &str {
141        &self.metadata.thread_id
142    }
143
144    pub fn step(&self) -> u64 {
145        self.metadata.step
146    }
147}
148
149/// Error type for checkpoint operations
150#[derive(Debug, thiserror::Error)]
151pub enum CheckpointError {
152    #[error("Checkpoint not found: {0}")]
153    NotFound(CheckpointId),
154
155    #[error("Thread not found: {0}")]
156    ThreadNotFound(ThreadId),
157
158    #[error("Serialization error: {0}")]
159    Serialization(String),
160
161    #[error("Storage error: {0}")]
162    Storage(String),
163
164    #[error("Invalid state: {0}")]
165    InvalidState(String),
166}
167
168/// Trait for checkpoint storage backends
169#[async_trait]
170pub trait CheckpointStore: Send + Sync {
171    /// Save a checkpoint, returns the checkpoint ID
172    async fn save(
173        &self,
174        thread_id: &str,
175        metadata: CheckpointMetadata,
176        state: Vec<u8>,
177    ) -> Result<CheckpointId, CheckpointError>;
178
179    /// Load a specific checkpoint
180    async fn load(
181        &self,
182        thread_id: &str,
183        checkpoint_id: &str,
184    ) -> Result<(CheckpointMetadata, Vec<u8>), CheckpointError>;
185
186    /// Load the latest checkpoint for a thread
187    async fn load_latest(
188        &self,
189        thread_id: &str,
190    ) -> Result<(CheckpointMetadata, Vec<u8>), CheckpointError>;
191
192    /// List all checkpoints for a thread
193    async fn list(&self, thread_id: &str) -> Result<Vec<CheckpointMetadata>, CheckpointError>;
194
195    /// Delete a specific checkpoint
196    async fn delete(&self, thread_id: &str, checkpoint_id: &str) -> Result<(), CheckpointError>;
197
198    /// Delete all checkpoints for a thread
199    async fn delete_thread(&self, thread_id: &str) -> Result<(), CheckpointError>;
200
201    /// Get checkpoint count for a thread
202    async fn count(&self, thread_id: &str) -> Result<usize, CheckpointError>;
203
204    /// List all threads with checkpoints
205    async fn list_threads(&self) -> Result<Vec<ThreadId>, CheckpointError>;
206}
207
208/// In-memory checkpoint store (for testing and development)
209#[derive(Default)]
210pub struct MemoryCheckpointStore {
211    checkpoints: RwLock<HashMap<ThreadId, Vec<(CheckpointMetadata, Vec<u8>)>>>,
212}
213
214impl MemoryCheckpointStore {
215    pub fn new() -> Self {
216        Self::default()
217    }
218}
219
220#[async_trait]
221impl CheckpointStore for MemoryCheckpointStore {
222    async fn save(
223        &self,
224        thread_id: &str,
225        metadata: CheckpointMetadata,
226        state: Vec<u8>,
227    ) -> Result<CheckpointId, CheckpointError> {
228        let id = metadata.id.clone();
229        let mut checkpoints = self.checkpoints.write();
230        checkpoints
231            .entry(thread_id.to_string())
232            .or_default()
233            .push((metadata, state));
234        Ok(id)
235    }
236
237    async fn load(
238        &self,
239        thread_id: &str,
240        checkpoint_id: &str,
241    ) -> Result<(CheckpointMetadata, Vec<u8>), CheckpointError> {
242        let checkpoints = self.checkpoints.read();
243        let thread_checkpoints = checkpoints
244            .get(thread_id)
245            .ok_or_else(|| CheckpointError::ThreadNotFound(thread_id.to_string()))?;
246
247        thread_checkpoints
248            .iter()
249            .find(|(m, _)| m.id == checkpoint_id)
250            .cloned()
251            .ok_or_else(|| CheckpointError::NotFound(checkpoint_id.to_string()))
252    }
253
254    async fn load_latest(
255        &self,
256        thread_id: &str,
257    ) -> Result<(CheckpointMetadata, Vec<u8>), CheckpointError> {
258        let checkpoints = self.checkpoints.read();
259        let thread_checkpoints = checkpoints
260            .get(thread_id)
261            .ok_or_else(|| CheckpointError::ThreadNotFound(thread_id.to_string()))?;
262
263        thread_checkpoints
264            .last()
265            .cloned()
266            .ok_or_else(|| CheckpointError::ThreadNotFound(thread_id.to_string()))
267    }
268
269    async fn list(&self, thread_id: &str) -> Result<Vec<CheckpointMetadata>, CheckpointError> {
270        let checkpoints = self.checkpoints.read();
271        Ok(checkpoints
272            .get(thread_id)
273            .map(|v| v.iter().map(|(m, _)| m.clone()).collect())
274            .unwrap_or_default())
275    }
276
277    async fn delete(&self, thread_id: &str, checkpoint_id: &str) -> Result<(), CheckpointError> {
278        let mut checkpoints = self.checkpoints.write();
279        if let Some(thread_checkpoints) = checkpoints.get_mut(thread_id) {
280            thread_checkpoints.retain(|(m, _)| m.id != checkpoint_id);
281        }
282        Ok(())
283    }
284
285    async fn delete_thread(&self, thread_id: &str) -> Result<(), CheckpointError> {
286        let mut checkpoints = self.checkpoints.write();
287        checkpoints.remove(thread_id);
288        Ok(())
289    }
290
291    async fn count(&self, thread_id: &str) -> Result<usize, CheckpointError> {
292        let checkpoints = self.checkpoints.read();
293        Ok(checkpoints.get(thread_id).map(|v| v.len()).unwrap_or(0))
294    }
295
296    async fn list_threads(&self) -> Result<Vec<ThreadId>, CheckpointError> {
297        let checkpoints = self.checkpoints.read();
298        Ok(checkpoints.keys().cloned().collect())
299    }
300}
301
/// Per-thread checkpoint limit used by [`CheckpointConfig::default`].
const DEFAULT_MAX_CHECKPOINTS_PER_THREAD: usize = 100;

/// Settings that tune how a checkpoint manager behaves.
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// Maximum number of checkpoints to keep per thread; `None` disables the limit.
    pub max_checkpoints_per_thread: Option<usize>,
    /// Auto-checkpoint every N steps.
    ///
    /// NOTE(review): nothing in this file reads this field; confirm where it is enforced.
    pub auto_checkpoint_interval: Option<u64>,
    /// Whether to compress state before storing.
    ///
    /// NOTE(review): nothing in this file reads this field; confirm where it is enforced.
    pub compress: bool,
    /// TTL for checkpoints (auto-delete after this duration).
    ///
    /// NOTE(review): nothing in this file reads this field; confirm where it is enforced.
    pub ttl: Option<Duration>,
}

impl Default for CheckpointConfig {
    /// Keeps at most 100 checkpoints per thread; auto-checkpointing, compression and TTL
    /// are all switched off.
    fn default() -> Self {
        let max_checkpoints_per_thread = Some(DEFAULT_MAX_CHECKPOINTS_PER_THREAD);
        Self {
            max_checkpoints_per_thread,
            auto_checkpoint_interval: None,
            compress: false,
            ttl: None,
        }
    }
}
325
326/// Main checkpoint manager
327pub struct CheckpointManager<Store: CheckpointStore> {
328    store: Arc<Store>,
329    config: CheckpointConfig,
330    /// Current step per thread
331    steps: RwLock<HashMap<ThreadId, u64>>,
332}
333
334impl<Store: CheckpointStore> CheckpointManager<Store> {
335    pub fn new(store: Store) -> Self {
336        Self {
337            store: Arc::new(store),
338            config: CheckpointConfig::default(),
339            steps: RwLock::new(HashMap::new()),
340        }
341    }
342
343    pub fn with_config(mut self, config: CheckpointConfig) -> Self {
344        self.config = config;
345        self
346    }
347
348    /// Save current state as a checkpoint
349    pub async fn save<S: Serialize + Send>(
350        &self,
351        thread_id: &str,
352        state: &S,
353    ) -> Result<CheckpointId, CheckpointError> {
354        self.save_with_label(thread_id, state, None).await
355    }
356
357    /// Save current state with an optional label
358    pub async fn save_with_label<S: Serialize + Send>(
359        &self,
360        thread_id: &str,
361        state: &S,
362        label: Option<String>,
363    ) -> Result<CheckpointId, CheckpointError> {
364        // Serialize state
365        let state_bytes =
366            serde_json::to_vec(state).map_err(|e| CheckpointError::Serialization(e.to_string()))?;
367
368        // Get and increment step
369        let step = {
370            let mut steps = self.steps.write();
371            let step = steps.entry(thread_id.to_string()).or_insert(0);
372            *step += 1;
373            *step
374        };
375
376        // Get parent ID (previous checkpoint)
377        let parent_id = self
378            .store
379            .load_latest(thread_id)
380            .await
381            .ok()
382            .map(|(m, _)| m.id);
383
384        // Create metadata
385        let mut metadata = CheckpointMetadata::new(thread_id, step);
386        metadata.state_size = state_bytes.len();
387        if let Some(parent) = parent_id {
388            metadata = metadata.with_parent(parent);
389        }
390        if let Some(lbl) = label {
391            metadata = metadata.with_label(lbl);
392        }
393
394        // Save checkpoint
395        let id = self.store.save(thread_id, metadata, state_bytes).await?;
396
397        // Cleanup old checkpoints if needed
398        if let Some(max) = self.config.max_checkpoints_per_thread {
399            self.cleanup_old_checkpoints(thread_id, max).await?;
400        }
401
402        debug!(thread_id, checkpoint_id = %id, step, "Checkpoint saved");
403        Ok(id)
404    }
405
406    /// Load a specific checkpoint
407    pub async fn load<S: DeserializeOwned>(
408        &self,
409        thread_id: &str,
410        checkpoint_id: &str,
411    ) -> Result<Checkpoint<S>, CheckpointError> {
412        let (metadata, state_bytes) = self.store.load(thread_id, checkpoint_id).await?;
413        let state = serde_json::from_slice(&state_bytes)
414            .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
415
416        debug!(
417            thread_id,
418            checkpoint_id,
419            step = metadata.step,
420            "Checkpoint loaded"
421        );
422        Ok(Checkpoint { metadata, state })
423    }
424
425    /// Load the latest checkpoint for a thread
426    pub async fn load_latest<S: DeserializeOwned>(
427        &self,
428        thread_id: &str,
429    ) -> Result<Checkpoint<S>, CheckpointError> {
430        let (metadata, state_bytes) = self.store.load_latest(thread_id).await?;
431        let state = serde_json::from_slice(&state_bytes)
432            .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
433
434        debug!(thread_id, checkpoint_id = %metadata.id, step = metadata.step, "Latest checkpoint loaded");
435        Ok(Checkpoint { metadata, state })
436    }
437
438    /// Get checkpoint history for a thread
439    pub async fn history(
440        &self,
441        thread_id: &str,
442    ) -> Result<Vec<CheckpointMetadata>, CheckpointError> {
443        self.store.list(thread_id).await
444    }
445
446    /// Fork from a checkpoint (create a new thread starting from this point)
447    pub async fn fork<S: Serialize + DeserializeOwned + Send>(
448        &self,
449        source_thread_id: &str,
450        checkpoint_id: &str,
451        new_thread_id: &str,
452    ) -> Result<CheckpointId, CheckpointError> {
453        // Load the source checkpoint
454        let checkpoint: Checkpoint<S> = self.load(source_thread_id, checkpoint_id).await?;
455
456        // Save as first checkpoint in new thread
457        let id = self.save(new_thread_id, &checkpoint.state).await?;
458
459        info!(
460            source_thread = source_thread_id,
461            source_checkpoint = checkpoint_id,
462            new_thread = new_thread_id,
463            new_checkpoint = %id,
464            "Thread forked"
465        );
466
467        Ok(id)
468    }
469
470    /// Rewind to a previous checkpoint (delete all checkpoints after it)
471    pub async fn rewind(
472        &self,
473        thread_id: &str,
474        checkpoint_id: &str,
475    ) -> Result<(), CheckpointError> {
476        let history = self.store.list(thread_id).await?;
477
478        // Find the target checkpoint
479        let target_idx = history
480            .iter()
481            .position(|m| m.id == checkpoint_id)
482            .ok_or_else(|| CheckpointError::NotFound(checkpoint_id.to_string()))?;
483
484        // Delete all checkpoints after the target
485        for checkpoint in history.iter().skip(target_idx + 1) {
486            self.store.delete(thread_id, &checkpoint.id).await?;
487        }
488
489        // Update step counter
490        {
491            let mut steps = self.steps.write();
492            if let Some(target) = history.get(target_idx) {
493                steps.insert(thread_id.to_string(), target.step);
494            }
495        }
496
497        info!(thread_id, checkpoint_id, "Rewound to checkpoint");
498        Ok(())
499    }
500
501    /// Get checkpoints by tag
502    pub async fn find_by_tag(
503        &self,
504        thread_id: &str,
505        tag: &str,
506    ) -> Result<Vec<CheckpointMetadata>, CheckpointError> {
507        let history = self.store.list(thread_id).await?;
508        Ok(history
509            .into_iter()
510            .filter(|m| m.tags.contains(&tag.to_string()))
511            .collect())
512    }
513
514    /// Get checkpoints by label
515    pub async fn find_by_label(
516        &self,
517        thread_id: &str,
518        label: &str,
519    ) -> Result<Option<CheckpointMetadata>, CheckpointError> {
520        let history = self.store.list(thread_id).await?;
521        Ok(history
522            .into_iter()
523            .find(|m| m.label.as_deref() == Some(label)))
524    }
525
526    /// Delete a thread and all its checkpoints
527    pub async fn delete_thread(&self, thread_id: &str) -> Result<(), CheckpointError> {
528        self.store.delete_thread(thread_id).await?;
529        self.steps.write().remove(thread_id);
530        info!(thread_id, "Thread deleted");
531        Ok(())
532    }
533
534    /// Get current step for a thread
535    pub fn current_step(&self, thread_id: &str) -> u64 {
536        self.steps.read().get(thread_id).copied().unwrap_or(0)
537    }
538
539    /// List all threads
540    pub async fn list_threads(&self) -> Result<Vec<ThreadId>, CheckpointError> {
541        self.store.list_threads().await
542    }
543
544    async fn cleanup_old_checkpoints(
545        &self,
546        thread_id: &str,
547        max: usize,
548    ) -> Result<(), CheckpointError> {
549        let count = self.store.count(thread_id).await?;
550        if count > max {
551            let history = self.store.list(thread_id).await?;
552            let to_delete = count - max;
553
554            for checkpoint in history.iter().take(to_delete) {
555                self.store.delete(thread_id, &checkpoint.id).await?;
556                debug!(thread_id, checkpoint_id = %checkpoint.id, "Old checkpoint deleted");
557            }
558        }
559        Ok(())
560    }
561}
562
563/// Convenience type for the default memory-based checkpoint manager
564pub type MemoryCheckpointManager = CheckpointManager<MemoryCheckpointStore>;
565
566impl MemoryCheckpointManager {
567    pub fn in_memory() -> Self {
568        Self::new(MemoryCheckpointStore::new())
569    }
570}
571
572/// Builder for creating checkpoints with fluent API
573pub struct CheckpointBuilder<'a, S: Serialize + Send, Store: CheckpointStore> {
574    manager: &'a CheckpointManager<Store>,
575    thread_id: String,
576    state: &'a S,
577    label: Option<String>,
578    tags: Vec<String>,
579    custom: HashMap<String, String>,
580}
581
582impl<'a, S: Serialize + Send, Store: CheckpointStore> CheckpointBuilder<'a, S, Store> {
583    pub fn new(
584        manager: &'a CheckpointManager<Store>,
585        thread_id: impl Into<String>,
586        state: &'a S,
587    ) -> Self {
588        Self {
589            manager,
590            thread_id: thread_id.into(),
591            state,
592            label: None,
593            tags: Vec::new(),
594            custom: HashMap::new(),
595        }
596    }
597
598    pub fn label(mut self, label: impl Into<String>) -> Self {
599        self.label = Some(label.into());
600        self
601    }
602
603    pub fn tag(mut self, tag: impl Into<String>) -> Self {
604        self.tags.push(tag.into());
605        self
606    }
607
608    pub fn custom(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
609        self.custom.insert(key.into(), value.into());
610        self
611    }
612
613    pub async fn save(self) -> Result<CheckpointId, CheckpointError> {
614        self.manager
615            .save_with_label(&self.thread_id, self.state, self.label)
616            .await
617    }
618}
619
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal serializable state used by every test.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestState {
        messages: Vec<String>,
        counter: u32,
    }

    /// State with no messages and a zero counter.
    fn blank_state() -> TestState {
        TestState {
            messages: Vec::new(),
            counter: 0,
        }
    }

    /// State holding a single message and the given counter.
    fn state_with(message: &str, counter: u32) -> TestState {
        TestState {
            messages: vec![message.to_string()],
            counter,
        }
    }

    /// A saved checkpoint can be loaded back by id with identical state.
    #[tokio::test]
    async fn test_save_and_load_checkpoint() {
        let manager = MemoryCheckpointManager::in_memory();
        let state = state_with("hello", 1);

        let id = manager.save("thread1", &state).await.unwrap();
        let loaded: Checkpoint<TestState> = manager.load("thread1", &id).await.unwrap();

        assert_eq!(loaded.state, state);
        assert_eq!(loaded.metadata.step, 1);
    }

    /// `load_latest` returns the most recently saved checkpoint.
    #[tokio::test]
    async fn test_load_latest() {
        let manager = MemoryCheckpointManager::in_memory();
        let state1 = state_with("first", 1);
        let state2 = state_with("second", 2);

        manager.save("thread1", &state1).await.unwrap();
        manager.save("thread1", &state2).await.unwrap();

        let loaded: Checkpoint<TestState> = manager.load_latest("thread1").await.unwrap();
        assert_eq!(loaded.state, state2);
        assert_eq!(loaded.metadata.step, 2);
    }

    /// History lists every checkpoint with consecutive step numbers.
    #[tokio::test]
    async fn test_checkpoint_history() {
        let manager = MemoryCheckpointManager::in_memory();
        let state = blank_state();

        for _ in 0..3 {
            manager.save("thread1", &state).await.unwrap();
        }

        let history = manager.history("thread1").await.unwrap();
        assert_eq!(history.len(), 3);
        assert_eq!(history[0].step, 1);
        assert_eq!(history[1].step, 2);
        assert_eq!(history[2].step, 3);
    }

    /// Forking copies the checkpoint's state into the new thread.
    #[tokio::test]
    async fn test_fork_thread() {
        let manager = MemoryCheckpointManager::in_memory();
        let state = state_with("original", 5);

        let checkpoint_id = manager.save("thread1", &state).await.unwrap();

        manager
            .fork::<TestState>("thread1", &checkpoint_id, "thread2")
            .await
            .unwrap();

        let forked: Checkpoint<TestState> = manager.load_latest("thread2").await.unwrap();
        assert_eq!(forked.state, state);
    }

    /// Rewinding drops every checkpoint saved after the target.
    #[tokio::test]
    async fn test_rewind() {
        let manager = MemoryCheckpointManager::in_memory();

        let mut checkpoint_ids = Vec::with_capacity(5);
        for i in 0..5u32 {
            let state = state_with(&format!("msg{}", i), i);
            checkpoint_ids.push(manager.save("thread1", &state).await.unwrap());
        }

        // Rewind to checkpoint 2 (3rd checkpoint, index 2)
        manager.rewind("thread1", &checkpoint_ids[2]).await.unwrap();

        let history = manager.history("thread1").await.unwrap();
        assert_eq!(history.len(), 3);

        let latest: Checkpoint<TestState> = manager.load_latest("thread1").await.unwrap();
        assert_eq!(latest.state.counter, 2);
    }

    /// Only the newest `max_checkpoints_per_thread` checkpoints survive.
    #[tokio::test]
    async fn test_max_checkpoints_cleanup() {
        let manager = MemoryCheckpointManager::in_memory().with_config(CheckpointConfig {
            max_checkpoints_per_thread: Some(3),
            ..Default::default()
        });
        let state = blank_state();

        // Save 5 checkpoints
        for _ in 0..5 {
            manager.save("thread1", &state).await.unwrap();
        }

        // Should only have 3 checkpoints
        let history = manager.history("thread1").await.unwrap();
        assert_eq!(history.len(), 3);

        // Should have kept the latest 3 (steps 3, 4, 5)
        assert_eq!(history[0].step, 3);
        assert_eq!(history[1].step, 4);
        assert_eq!(history[2].step, 5);
    }

    /// Deleting a thread leaves no history behind.
    #[tokio::test]
    async fn test_delete_thread() {
        let manager = MemoryCheckpointManager::in_memory();
        let state = blank_state();

        manager.save("thread1", &state).await.unwrap();
        manager.save("thread1", &state).await.unwrap();

        manager.delete_thread("thread1").await.unwrap();

        let history = manager.history("thread1").await.unwrap();
        assert!(history.is_empty());
    }

    /// Every thread that received a checkpoint is listed.
    #[tokio::test]
    async fn test_list_threads() {
        let manager = MemoryCheckpointManager::in_memory();
        let state = blank_state();

        for thread in ["thread1", "thread2", "thread3"] {
            manager.save(thread, &state).await.unwrap();
        }

        let threads = manager.list_threads().await.unwrap();
        assert_eq!(threads.len(), 3);
    }

    /// The step counter starts at 0 and advances by one per save.
    #[tokio::test]
    async fn test_current_step() {
        let manager = MemoryCheckpointManager::in_memory();
        let state = blank_state();

        assert_eq!(manager.current_step("thread1"), 0);

        manager.save("thread1", &state).await.unwrap();
        assert_eq!(manager.current_step("thread1"), 1);

        manager.save("thread1", &state).await.unwrap();
        assert_eq!(manager.current_step("thread1"), 2);
    }

    /// A labelled checkpoint can be found again by its label.
    #[tokio::test]
    async fn test_checkpoint_with_label() {
        let manager = MemoryCheckpointManager::in_memory();
        let state = blank_state();

        manager
            .save_with_label("thread1", &state, Some("important".to_string()))
            .await
            .unwrap();

        let found = manager.find_by_label("thread1", "important").await.unwrap();
        assert!(found.is_some());
        assert_eq!(found.unwrap().label.as_deref(), Some("important"));
    }

    /// Each checkpoint points at the one saved immediately before it.
    #[tokio::test]
    async fn test_parent_chain() {
        let manager = MemoryCheckpointManager::in_memory();
        let state = blank_state();

        let id1 = manager.save("thread1", &state).await.unwrap();
        let id2 = manager.save("thread1", &state).await.unwrap();
        let _id3 = manager.save("thread1", &state).await.unwrap();

        let history = manager.history("thread1").await.unwrap();

        // First checkpoint has no parent
        assert!(history[0].parent_id.is_none());

        // Second checkpoint's parent is first
        assert_eq!(history[1].parent_id.as_deref(), Some(id1.as_str()));

        // Third checkpoint's parent is second
        assert_eq!(history[2].parent_id.as_deref(), Some(id2.as_str()));
    }
}