langgraph_checkpoint_sqlite/
sqlite.rs1use std::path::{Path, PathBuf};
2
3use langgraph_checkpoint::{Checkpoint, CheckpointError, CheckpointSaver};
4use rusqlite::{params, Connection, Error as SqlError, ErrorCode};
5
/// Schema for the checkpoint store, run by `SqliteSaver::init_schema`.
///
/// Rows are keyed by `(thread_id, checkpoint_id)` and hold the checkpoint as
/// JSON text in `payload`. `created_at` defaults to the insertion time in whole
/// Unix seconds. It is computed with `strftime('%s', 'now')` rather than
/// `unixepoch()` because `unixepoch()` was only added in SQLite 3.38.0, and the
/// schema would be rejected when linked against an older SQLite library.
const CREATE_TABLE_SQL: &str = r#"
CREATE TABLE IF NOT EXISTS checkpoints (
    thread_id TEXT NOT NULL,
    checkpoint_id TEXT NOT NULL,
    payload TEXT NOT NULL,
    created_at INTEGER NOT NULL DEFAULT (CAST(strftime('%s', 'now') AS INTEGER)),
    PRIMARY KEY (thread_id, checkpoint_id)
);
"#;
15
/// Checkpoint saver that persists checkpoints in a SQLite database file.
///
/// Only the database path is stored. Each operation opens its own connection,
/// so cloning a saver just copies the path.
#[derive(Clone, Debug)]
pub struct SqliteSaver {
    /// Filesystem location of the SQLite database.
    db_path: PathBuf,
}
20
21impl SqliteSaver {
22 pub fn new(path: impl AsRef<Path>) -> Result<Self, CheckpointError> {
23 let saver = Self { db_path: path.as_ref().to_path_buf() };
24 saver.init_schema()?;
25 Ok(saver)
26 }
27
28 fn init_schema(&self) -> Result<(), CheckpointError> {
29 let conn = self.open()?;
30 conn.execute_batch(CREATE_TABLE_SQL).map_err(map_sql_error)?;
31 Ok(())
32 }
33
34 fn open(&self) -> Result<Connection, CheckpointError> {
35 Connection::open(&self.db_path).map_err(map_sql_error)
36 }
37}
38
39impl CheckpointSaver for SqliteSaver {
40 fn put(&self, checkpoint: Checkpoint) -> Result<(), CheckpointError> {
41 let payload = serde_json::to_string(&checkpoint)
42 .map_err(|err| CheckpointError::Storage(err.to_string()))?;
43
44 let conn = self.open()?;
45 let result = conn.execute(
46 "INSERT INTO checkpoints (thread_id, checkpoint_id, payload) VALUES (?1, ?2, ?3)",
47 params![checkpoint.thread_id, checkpoint.checkpoint_id, payload],
48 );
49
50 match result {
51 Ok(_) => Ok(()),
52 Err(err) => {
53 if is_unique_violation(&err) {
54 Err(CheckpointError::Conflict(format!(
55 "thread `{}` already has checkpoint `{}`",
56 checkpoint.thread_id, checkpoint.checkpoint_id
57 )))
58 } else {
59 Err(map_sql_error(err))
60 }
61 }
62 }
63 }
64
65 fn get(
66 &self,
67 thread_id: &str,
68 checkpoint_id: &str,
69 ) -> Result<Option<Checkpoint>, CheckpointError> {
70 let conn = self.open()?;
71 let mut stmt = conn
72 .prepare("SELECT payload FROM checkpoints WHERE thread_id = ?1 AND checkpoint_id = ?2")
73 .map_err(map_sql_error)?;
74 let mut rows = stmt.query(params![thread_id, checkpoint_id]).map_err(map_sql_error)?;
75
76 if let Some(row) = rows.next().map_err(map_sql_error)? {
77 let payload: String = row.get(0).map_err(map_sql_error)?;
78 let checkpoint = serde_json::from_str(&payload)
79 .map_err(|err| CheckpointError::Storage(err.to_string()))?;
80 Ok(Some(checkpoint))
81 } else {
82 Ok(None)
83 }
84 }
85
86 fn list(&self, thread_id: &str) -> Result<Vec<Checkpoint>, CheckpointError> {
87 let conn = self.open()?;
88 let mut stmt = conn
89 .prepare(
90 "SELECT payload FROM checkpoints WHERE thread_id = ?1 ORDER BY created_at ASC, checkpoint_id ASC",
91 )
92 .map_err(map_sql_error)?;
93 let mut rows = stmt.query(params![thread_id]).map_err(map_sql_error)?;
94 let mut checkpoints = Vec::new();
95
96 while let Some(row) = rows.next().map_err(map_sql_error)? {
97 let payload: String = row.get(0).map_err(map_sql_error)?;
98 let checkpoint = serde_json::from_str(&payload)
99 .map_err(|err| CheckpointError::Storage(err.to_string()))?;
100 checkpoints.push(checkpoint);
101 }
102
103 Ok(checkpoints)
104 }
105}
106
107fn map_sql_error(err: SqlError) -> CheckpointError {
108 CheckpointError::Storage(err.to_string())
109}
110
111fn is_unique_violation(err: &SqlError) -> bool {
112 match err {
113 SqlError::SqliteFailure(info, _) => info.code == ErrorCode::ConstraintViolation,
114 _ => false,
115 }
116}