1#[cfg(feature = "sqlite")]
4use crate::error::GraphError;
5use crate::error::Result;
6use crate::state::Checkpoint;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12#[async_trait]
14pub trait Checkpointer: Send + Sync {
15 async fn save(&self, checkpoint: &Checkpoint) -> Result<String>;
17
18 async fn load(&self, thread_id: &str) -> Result<Option<Checkpoint>>;
20
21 async fn load_by_id(&self, checkpoint_id: &str) -> Result<Option<Checkpoint>>;
23
24 async fn list(&self, thread_id: &str) -> Result<Vec<Checkpoint>>;
26
27 async fn delete(&self, thread_id: &str) -> Result<()>;
29}
30
31#[derive(Default)]
33pub struct MemoryCheckpointer {
34 checkpoints: Arc<RwLock<HashMap<String, Vec<Checkpoint>>>>,
35}
36
37impl MemoryCheckpointer {
38 pub fn new() -> Self {
40 Self::default()
41 }
42}
43
44#[async_trait]
45impl Checkpointer for MemoryCheckpointer {
46 async fn save(&self, checkpoint: &Checkpoint) -> Result<String> {
47 let mut store = self.checkpoints.write().await;
48 let thread_checkpoints = store.entry(checkpoint.thread_id.clone()).or_insert_with(Vec::new);
49
50 let checkpoint_id = checkpoint.checkpoint_id.clone();
51 thread_checkpoints.push(checkpoint.clone());
52
53 Ok(checkpoint_id)
54 }
55
56 async fn load(&self, thread_id: &str) -> Result<Option<Checkpoint>> {
57 let store = self.checkpoints.read().await;
58 Ok(store.get(thread_id).and_then(|checkpoints| checkpoints.last()).cloned())
59 }
60
61 async fn load_by_id(&self, checkpoint_id: &str) -> Result<Option<Checkpoint>> {
62 let store = self.checkpoints.read().await;
63 for checkpoints in store.values() {
64 for checkpoint in checkpoints {
65 if checkpoint.checkpoint_id == checkpoint_id {
66 return Ok(Some(checkpoint.clone()));
67 }
68 }
69 }
70 Ok(None)
71 }
72
73 async fn list(&self, thread_id: &str) -> Result<Vec<Checkpoint>> {
74 let store = self.checkpoints.read().await;
75 Ok(store.get(thread_id).cloned().unwrap_or_default())
76 }
77
78 async fn delete(&self, thread_id: &str) -> Result<()> {
79 let mut store = self.checkpoints.write().await;
80 store.remove(thread_id);
81 Ok(())
82 }
83}
84
/// [`Checkpointer`] backed by a SQLite database reached through an `sqlx` connection pool.
///
/// Checkpoints are rows of the `graph_checkpoints` table: `state`, `pending_nodes` and
/// `metadata` are stored as JSON text, `step` as an integer and `created_at` as an
/// RFC 3339 string.
#[cfg(feature = "sqlite")]
pub struct SqliteCheckpointer {
    pool: sqlx::SqlitePool,
}

#[cfg(feature = "sqlite")]
impl SqliteCheckpointer {
    /// Connects to `database_url` and makes sure the checkpoint schema exists.
    ///
    /// # Errors
    /// Returns [`GraphError::CheckpointError`] if the connection cannot be established or
    /// the schema statements fail.
    pub async fn new(database_url: &str) -> Result<Self> {
        let pool = sqlx::SqlitePool::connect(database_url)
            .await
            .map_err(|e| GraphError::CheckpointError(e.to_string()))?;
        Self::from_pool(pool).await
    }

    /// Creates a checkpointer over a fresh in-memory SQLite database.
    ///
    /// An in-memory SQLite database is destroyed as soon as its last connection closes.
    /// sqlx pools by default close idle connections after a timeout and retire connections
    /// after a maximum lifetime. Either would silently discard the table and every saved
    /// checkpoint. The pool is therefore pinned to a single connection that is never
    /// reaped. A single connection also serializes access, so no two connections compete
    /// for locks on the shared in-memory database.
    ///
    /// # Errors
    /// Returns [`GraphError::CheckpointError`] if the database cannot be opened or the
    /// schema statements fail.
    pub async fn in_memory() -> Result<Self> {
        let pool = sqlx::sqlite::SqlitePoolOptions::new()
            .max_connections(1)
            .idle_timeout(None::<std::time::Duration>)
            .max_lifetime(None::<std::time::Duration>)
            .connect(":memory:")
            .await
            .map_err(|e| GraphError::CheckpointError(e.to_string()))?;
        Self::from_pool(pool).await
    }

    /// Creates the `graph_checkpoints` table and its per-thread index when missing, then
    /// wraps `pool`.
    async fn from_pool(pool: sqlx::SqlitePool) -> Result<Self> {
        sqlx::query(
            r#"
            CREATE TABLE IF NOT EXISTS graph_checkpoints (
                id TEXT PRIMARY KEY,
                thread_id TEXT NOT NULL,
                state TEXT NOT NULL,
                step INTEGER NOT NULL,
                pending_nodes TEXT NOT NULL,
                metadata TEXT,
                created_at TEXT NOT NULL
            )
            "#,
        )
        .execute(&pool)
        .await
        .map_err(|e| GraphError::CheckpointError(e.to_string()))?;

        // Supports the per-thread "latest" and "history" queries.
        sqlx::query(
            r#"
            CREATE INDEX IF NOT EXISTS idx_graph_checkpoints_thread
            ON graph_checkpoints(thread_id, created_at DESC)
            "#,
        )
        .execute(&pool)
        .await
        .map_err(|e| GraphError::CheckpointError(e.to_string()))?;

        Ok(Self { pool })
    }
}
135
/// Raw `graph_checkpoints` row in the column order used by every `SELECT` below:
/// `(id, thread_id, state, step, pending_nodes, metadata, created_at)`.
#[cfg(feature = "sqlite")]
type CheckpointRow = (String, String, String, i64, String, String, String);

/// Wraps any displayable error (sqlx, chrono, ...) as a [`GraphError::CheckpointError`].
#[cfg(feature = "sqlite")]
fn checkpoint_error<E: std::fmt::Display>(err: E) -> GraphError {
    GraphError::CheckpointError(err.to_string())
}

/// Rebuilds a [`Checkpoint`] from a raw row.
///
/// The JSON columns are deserialized and the RFC 3339 timestamp is converted to UTC.
///
/// # Errors
/// JSON decode failures propagate through `?`. A `step` that does not fit in `usize`
/// (e.g. a negative value in a hand-edited row) and an unparsable timestamp become
/// [`GraphError::CheckpointError`].
#[cfg(feature = "sqlite")]
fn decode_checkpoint_row(row: CheckpointRow) -> Result<Checkpoint> {
    let (id, thread_id, state, step, pending_nodes, metadata, created_at) = row;
    Ok(Checkpoint {
        checkpoint_id: id,
        thread_id,
        state: serde_json::from_str(&state)?,
        // Checked conversion: an `as` cast would silently wrap a negative value into a
        // huge step number.
        step: usize::try_from(step)
            .map_err(|e| GraphError::CheckpointError(format!("invalid step {step}: {e}")))?,
        pending_nodes: serde_json::from_str(&pending_nodes)?,
        metadata: serde_json::from_str(&metadata)?,
        created_at: chrono::DateTime::parse_from_rfc3339(&created_at)
            .map_err(checkpoint_error)?
            .with_timezone(&chrono::Utc),
    })
}

#[cfg(feature = "sqlite")]
#[async_trait]
impl Checkpointer for SqliteCheckpointer {
    /// Inserts `checkpoint` as a new row and returns its id.
    ///
    /// # Errors
    /// Fails in these cases:
    /// - serialization of a JSON column fails;
    /// - `step` does not fit in SQLite's 64-bit `INTEGER`;
    /// - the insert fails, including when a row with the same id already exists
    ///   (`id` is the table's primary key).
    async fn save(&self, checkpoint: &Checkpoint) -> Result<String> {
        let state_json = serde_json::to_string(&checkpoint.state)?;
        let pending_json = serde_json::to_string(&checkpoint.pending_nodes)?;
        let metadata_json = serde_json::to_string(&checkpoint.metadata)?;
        let created_at = checkpoint.created_at.to_rfc3339();
        let step = i64::try_from(checkpoint.step).map_err(|e| {
            GraphError::CheckpointError(format!("step {} out of range: {e}", checkpoint.step))
        })?;

        sqlx::query(
            r#"
            INSERT INTO graph_checkpoints (id, thread_id, state, step, pending_nodes, metadata, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?)
            "#,
        )
        .bind(&checkpoint.checkpoint_id)
        .bind(&checkpoint.thread_id)
        .bind(&state_json)
        .bind(step)
        .bind(&pending_json)
        .bind(&metadata_json)
        .bind(&created_at)
        .execute(&self.pool)
        .await
        .map_err(checkpoint_error)?;

        Ok(checkpoint.checkpoint_id.clone())
    }

    /// Returns the newest checkpoint of `thread_id`.
    ///
    /// Ties on `created_at` are broken by `rowid`, i.e. the most recently inserted row wins.
    /// This matches the insertion-order semantics of the in-memory backend and keeps the
    /// result deterministic when timestamps collide.
    async fn load(&self, thread_id: &str) -> Result<Option<Checkpoint>> {
        let row: Option<CheckpointRow> = sqlx::query_as(
            r#"
            SELECT id, thread_id, state, step, pending_nodes, metadata, created_at
            FROM graph_checkpoints
            WHERE thread_id = ?
            ORDER BY created_at DESC, rowid DESC
            LIMIT 1
            "#,
        )
        .bind(thread_id)
        .fetch_optional(&self.pool)
        .await
        .map_err(checkpoint_error)?;

        row.map(decode_checkpoint_row).transpose()
    }

    /// Fetches a single checkpoint by its primary key.
    async fn load_by_id(&self, checkpoint_id: &str) -> Result<Option<Checkpoint>> {
        let row: Option<CheckpointRow> = sqlx::query_as(
            r#"
            SELECT id, thread_id, state, step, pending_nodes, metadata, created_at
            FROM graph_checkpoints
            WHERE id = ?
            "#,
        )
        .bind(checkpoint_id)
        .fetch_optional(&self.pool)
        .await
        .map_err(checkpoint_error)?;

        row.map(decode_checkpoint_row).transpose()
    }

    /// Returns the history of `thread_id`, oldest first.
    ///
    /// Rows with equal `created_at` are kept in insertion (`rowid`) order. Decoding stops
    /// at the first row that fails to decode.
    async fn list(&self, thread_id: &str) -> Result<Vec<Checkpoint>> {
        let rows: Vec<CheckpointRow> = sqlx::query_as(
            r#"
            SELECT id, thread_id, state, step, pending_nodes, metadata, created_at
            FROM graph_checkpoints
            WHERE thread_id = ?
            ORDER BY created_at ASC, rowid ASC
            "#,
        )
        .bind(thread_id)
        .fetch_all(&self.pool)
        .await
        .map_err(checkpoint_error)?;

        rows.into_iter().map(decode_checkpoint_row).collect()
    }

    /// Deletes every row of `thread_id`; deleting an unknown thread is a no-op.
    async fn delete(&self, thread_id: &str) -> Result<()> {
        sqlx::query("DELETE FROM graph_checkpoints WHERE thread_id = ?")
            .bind(thread_id)
            .execute(&self.pool)
            .await
            .map_err(checkpoint_error)?;
        Ok(())
    }
}
271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::State;

    /// Walks a single thread through the save -> load -> list -> delete lifecycle of
    /// [`MemoryCheckpointer`].
    #[tokio::test]
    async fn test_memory_checkpointer() {
        const THREAD: &str = "thread_1";
        let checkpointer = MemoryCheckpointer::new();

        // The first checkpoint becomes the thread's latest.
        let initial = Checkpoint::new(THREAD, State::new(), 0, vec![String::from("node_a")]);
        let saved_id = checkpointer.save(&initial).await.unwrap();
        assert!(!saved_id.is_empty());

        let latest = checkpointer
            .load(THREAD)
            .await
            .unwrap()
            .expect("thread_1 should have a checkpoint after the first save");
        assert_eq!(latest.step, 0);

        // A second save supersedes the first as the latest checkpoint...
        let next = Checkpoint::new(THREAD, State::new(), 1, vec![String::from("node_b")]);
        checkpointer.save(&next).await.unwrap();

        let latest = checkpointer
            .load(THREAD)
            .await
            .unwrap()
            .expect("thread_1 should still have a checkpoint after the second save");
        assert_eq!(latest.step, 1);

        // ...while the history retains both.
        let history = checkpointer.list(THREAD).await.unwrap();
        assert_eq!(history.len(), 2);

        // Deleting the thread removes its entire history.
        checkpointer.delete(THREAD).await.unwrap();
        assert!(checkpointer.load(THREAD).await.unwrap().is_none());
    }
}