adk_graph/
checkpoint.rs

1//! Checkpointing for persistent graph state
2
3#[cfg(feature = "sqlite")]
4use crate::error::GraphError;
5use crate::error::Result;
6use crate::state::Checkpoint;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12/// Checkpointer trait for persistence
13#[async_trait]
14pub trait Checkpointer: Send + Sync {
15    /// Save a checkpoint
16    async fn save(&self, checkpoint: &Checkpoint) -> Result<String>;
17
18    /// Load the latest checkpoint for a thread
19    async fn load(&self, thread_id: &str) -> Result<Option<Checkpoint>>;
20
21    /// Load a specific checkpoint by ID
22    async fn load_by_id(&self, checkpoint_id: &str) -> Result<Option<Checkpoint>>;
23
24    /// List all checkpoints for a thread (for time travel)
25    async fn list(&self, thread_id: &str) -> Result<Vec<Checkpoint>>;
26
27    /// Delete checkpoints for a thread
28    async fn delete(&self, thread_id: &str) -> Result<()>;
29}
30
31/// In-memory checkpointer for development and testing
32#[derive(Default)]
33pub struct MemoryCheckpointer {
34    checkpoints: Arc<RwLock<HashMap<String, Vec<Checkpoint>>>>,
35}
36
37impl MemoryCheckpointer {
38    /// Create a new in-memory checkpointer
39    pub fn new() -> Self {
40        Self::default()
41    }
42}
43
44#[async_trait]
45impl Checkpointer for MemoryCheckpointer {
46    async fn save(&self, checkpoint: &Checkpoint) -> Result<String> {
47        let mut store = self.checkpoints.write().await;
48        let thread_checkpoints = store.entry(checkpoint.thread_id.clone()).or_insert_with(Vec::new);
49
50        let checkpoint_id = checkpoint.checkpoint_id.clone();
51        thread_checkpoints.push(checkpoint.clone());
52
53        Ok(checkpoint_id)
54    }
55
56    async fn load(&self, thread_id: &str) -> Result<Option<Checkpoint>> {
57        let store = self.checkpoints.read().await;
58        Ok(store.get(thread_id).and_then(|checkpoints| checkpoints.last()).cloned())
59    }
60
61    async fn load_by_id(&self, checkpoint_id: &str) -> Result<Option<Checkpoint>> {
62        let store = self.checkpoints.read().await;
63        for checkpoints in store.values() {
64            for checkpoint in checkpoints {
65                if checkpoint.checkpoint_id == checkpoint_id {
66                    return Ok(Some(checkpoint.clone()));
67                }
68            }
69        }
70        Ok(None)
71    }
72
73    async fn list(&self, thread_id: &str) -> Result<Vec<Checkpoint>> {
74        let store = self.checkpoints.read().await;
75        Ok(store.get(thread_id).cloned().unwrap_or_default())
76    }
77
78    async fn delete(&self, thread_id: &str) -> Result<()> {
79        let mut store = self.checkpoints.write().await;
80        store.remove(thread_id);
81        Ok(())
82    }
83}
84
/// SQLite checkpointer for production use.
///
/// Holds a `sqlx` SQLite connection pool. The pool is a reference-counted
/// handle, so cloning a `SqliteCheckpointer` is cheap and every clone talks to
/// the same pool.
#[cfg(feature = "sqlite")]
#[derive(Clone)]
pub struct SqliteCheckpointer {
    pool: sqlx::SqlitePool,
}
90
#[cfg(feature = "sqlite")]
impl SqliteCheckpointer {
    /// Open a pool for `database_url` and ensure the `graph_checkpoints`
    /// table and its per-thread index exist.
    ///
    /// # Errors
    /// Returns `GraphError::CheckpointError` if connecting fails or if either
    /// schema statement is rejected.
    pub async fn new(database_url: &str) -> Result<Self> {
        let as_checkpoint_error = |err: sqlx::Error| GraphError::CheckpointError(err.to_string());

        let pool = sqlx::SqlitePool::connect(database_url)
            .await
            .map_err(as_checkpoint_error)?;

        // Both statements use IF NOT EXISTS, so reopening an existing database is harmless.
        let schema = [
            r#"
            CREATE TABLE IF NOT EXISTS graph_checkpoints (
                id TEXT PRIMARY KEY,
                thread_id TEXT NOT NULL,
                state TEXT NOT NULL,
                step INTEGER NOT NULL,
                pending_nodes TEXT NOT NULL,
                metadata TEXT,
                created_at TEXT NOT NULL
            )
            "#,
            r#"
            CREATE INDEX IF NOT EXISTS idx_graph_checkpoints_thread
            ON graph_checkpoints(thread_id, created_at DESC)
            "#,
        ];
        for statement in schema {
            sqlx::query(statement)
                .execute(&pool)
                .await
                .map_err(as_checkpoint_error)?;
        }

        Ok(Self { pool })
    }

    /// Convenience constructor backed by SQLite's `:memory:` database (for testing).
    ///
    /// NOTE(review): an in-memory SQLite database only lives as long as a
    /// connection to it stays open; confirm the pool never drops to zero
    /// connections while the checkpointer is in use.
    pub async fn in_memory() -> Result<Self> {
        Self::new(":memory:").await
    }
}
135
/// Column tuple returned by the `SELECT id, thread_id, state, step,
/// pending_nodes, metadata, created_at` queries below, in that order.
#[cfg(feature = "sqlite")]
type CheckpointRow = (String, String, String, i64, String, String, String);

/// Wrap any displayable error (sqlx, chrono parse, integer range) as a
/// `GraphError::CheckpointError` carrying its message.
#[cfg(feature = "sqlite")]
fn checkpoint_error<E: std::fmt::Display>(err: E) -> GraphError {
    GraphError::CheckpointError(err.to_string())
}

/// Decode one `graph_checkpoints` row back into a [`Checkpoint`].
///
/// # Errors
/// Fails if a JSON column does not deserialize, if `step` is negative or does
/// not fit in `usize`, or if `created_at` is not valid RFC 3339.
#[cfg(feature = "sqlite")]
fn checkpoint_from_row(row: CheckpointRow) -> Result<Checkpoint> {
    let (id, thread_id, state, step, pending_nodes, metadata, created_at) = row;

    // A negative or oversized value must not wrap silently into a bogus step.
    let step = usize::try_from(step).map_err(|_| {
        GraphError::CheckpointError(format!("checkpoint {} has invalid step {}", id, step))
    })?;
    let created_at = chrono::DateTime::parse_from_rfc3339(&created_at)
        .map_err(checkpoint_error)?
        .with_timezone(&chrono::Utc);

    Ok(Checkpoint {
        checkpoint_id: id,
        thread_id,
        state: serde_json::from_str(&state)?,
        step,
        pending_nodes: serde_json::from_str(&pending_nodes)?,
        metadata: serde_json::from_str(&metadata)?,
        created_at,
    })
}

#[cfg(feature = "sqlite")]
#[async_trait]
impl Checkpointer for SqliteCheckpointer {
    /// Insert `checkpoint` as a new row and return its id.
    ///
    /// State, pending nodes and metadata are stored as JSON text.
    async fn save(&self, checkpoint: &Checkpoint) -> Result<String> {
        let state_json = serde_json::to_string(&checkpoint.state)?;
        let pending_json = serde_json::to_string(&checkpoint.pending_nodes)?;
        let metadata_json = serde_json::to_string(&checkpoint.metadata)?;
        // Stored as RFC 3339 text in UTC; the ORDER BY clauses below sort on this text.
        let created_at = checkpoint.created_at.to_rfc3339();
        // SQLite INTEGER is i64; refuse to wrap an out-of-range step.
        let step = i64::try_from(checkpoint.step).map_err(checkpoint_error)?;

        sqlx::query(
            r#"
            INSERT INTO graph_checkpoints (id, thread_id, state, step, pending_nodes, metadata, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?)
            "#,
        )
        .bind(&checkpoint.checkpoint_id)
        .bind(&checkpoint.thread_id)
        .bind(&state_json)
        .bind(step)
        .bind(&pending_json)
        .bind(&metadata_json)
        .bind(&created_at)
        .execute(&self.pool)
        .await
        .map_err(checkpoint_error)?;

        Ok(checkpoint.checkpoint_id.clone())
    }

    /// Most recent checkpoint for `thread_id`, or `None` if the thread has none.
    async fn load(&self, thread_id: &str) -> Result<Option<Checkpoint>> {
        // `rowid` breaks ties between checkpoints that share a `created_at`
        // value, so the most recently inserted one is returned deterministically.
        let row: Option<CheckpointRow> = sqlx::query_as(
            r#"
            SELECT id, thread_id, state, step, pending_nodes, metadata, created_at
            FROM graph_checkpoints
            WHERE thread_id = ?
            ORDER BY created_at DESC, rowid DESC
            LIMIT 1
            "#,
        )
        .bind(thread_id)
        .fetch_optional(&self.pool)
        .await
        .map_err(checkpoint_error)?;

        row.map(checkpoint_from_row).transpose()
    }

    /// Checkpoint with the given id (any thread), or `None` if absent.
    async fn load_by_id(&self, checkpoint_id: &str) -> Result<Option<Checkpoint>> {
        let row: Option<CheckpointRow> = sqlx::query_as(
            r#"
            SELECT id, thread_id, state, step, pending_nodes, metadata, created_at
            FROM graph_checkpoints
            WHERE id = ?
            "#,
        )
        .bind(checkpoint_id)
        .fetch_optional(&self.pool)
        .await
        .map_err(checkpoint_error)?;

        row.map(checkpoint_from_row).transpose()
    }

    /// All checkpoints for `thread_id`, oldest first; ties on `created_at`
    /// keep insertion order.
    async fn list(&self, thread_id: &str) -> Result<Vec<Checkpoint>> {
        let rows: Vec<CheckpointRow> = sqlx::query_as(
            r#"
            SELECT id, thread_id, state, step, pending_nodes, metadata, created_at
            FROM graph_checkpoints
            WHERE thread_id = ?
            ORDER BY created_at ASC, rowid ASC
            "#,
        )
        .bind(thread_id)
        .fetch_all(&self.pool)
        .await
        .map_err(checkpoint_error)?;

        rows.into_iter().map(checkpoint_from_row).collect()
    }

    /// Remove every checkpoint row belonging to `thread_id`.
    async fn delete(&self, thread_id: &str) -> Result<()> {
        sqlx::query("DELETE FROM graph_checkpoints WHERE thread_id = ?")
            .bind(thread_id)
            .execute(&self.pool)
            .await
            .map_err(checkpoint_error)?;
        Ok(())
    }
}
271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::State;

    #[tokio::test]
    async fn test_memory_checkpointer() {
        let cp = MemoryCheckpointer::new();

        // Create and save checkpoint
        let checkpoint = Checkpoint::new("thread_1", State::new(), 0, vec!["node_a".to_string()]);
        let id = cp.save(&checkpoint).await.unwrap();
        assert!(!id.is_empty());

        // Load latest
        let loaded = cp.load("thread_1").await.unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().step, 0);

        // Save another checkpoint
        let checkpoint2 = Checkpoint::new("thread_1", State::new(), 1, vec!["node_b".to_string()]);
        cp.save(&checkpoint2).await.unwrap();

        // Load latest should return step 1
        let loaded = cp.load("thread_1").await.unwrap();
        assert_eq!(loaded.unwrap().step, 1);

        // List should return both
        let all = cp.list("thread_1").await.unwrap();
        assert_eq!(all.len(), 2);

        // Delete
        cp.delete("thread_1").await.unwrap();
        let loaded = cp.load("thread_1").await.unwrap();
        assert!(loaded.is_none());
    }

    /// Covers lookup by id, list ordering, unknown threads, and that deleting
    /// one thread leaves other threads untouched.
    #[tokio::test]
    async fn test_memory_checkpointer_lookup_and_isolation() {
        let cp = MemoryCheckpointer::new();

        let first = Checkpoint::new("thread_1", State::new(), 0, vec!["node_a".to_string()]);
        let first_id = cp.save(&first).await.unwrap();
        let second = Checkpoint::new("thread_1", State::new(), 1, vec!["node_b".to_string()]);
        cp.save(&second).await.unwrap();
        let other = Checkpoint::new("thread_2", State::new(), 7, vec!["node_c".to_string()]);
        cp.save(&other).await.unwrap();

        // Lookup by id, and a miss for an unknown id
        let by_id = cp.load_by_id(&first_id).await.unwrap();
        assert_eq!(by_id.map(|c| c.step), Some(0));
        assert!(cp.load_by_id("no-such-checkpoint").await.unwrap().is_none());

        // History comes back in save order
        let steps: Vec<usize> = cp
            .list("thread_1")
            .await
            .unwrap()
            .iter()
            .map(|c| c.step)
            .collect();
        assert_eq!(steps, vec![0, 1]);

        // Unknown threads are empty rather than errors
        assert!(cp.list("missing").await.unwrap().is_empty());
        assert!(cp.load("missing").await.unwrap().is_none());

        // Deleting one thread must not touch another
        cp.delete("thread_1").await.unwrap();
        assert!(cp.list("thread_1").await.unwrap().is_empty());
        assert_eq!(cp.load("thread_2").await.unwrap().map(|c| c.step), Some(7));
    }
}