Skip to main content

langgraph_checkpoint_sqlite/
sqlite.rs

1use std::path::{Path, PathBuf};
2
3use langgraph_checkpoint::{Checkpoint, CheckpointError, CheckpointSaver};
4use rusqlite::{params, Connection, Error as SqlError, ErrorCode};
5
/// Schema for the `checkpoints` table, executed by [`SqliteSaver::new`].
///
/// `IF NOT EXISTS` makes re-running it against an existing database a no-op.
/// `(thread_id, checkpoint_id)` is the primary key, so a thread can hold a given
/// checkpoint id at most once. `created_at` defaults to the insertion time in whole
/// Unix seconds.
///
/// The default is computed with `strftime('%s', 'now')` cast to `INTEGER` instead of
/// `unixepoch()`. `unixepoch()` was only added in SQLite 3.38.0, so inserts would fail
/// when linked against an older SQLite library. Both expressions yield the same value.
const CREATE_TABLE_SQL: &str = r#"
CREATE TABLE IF NOT EXISTS checkpoints (
    thread_id TEXT NOT NULL,
    checkpoint_id TEXT NOT NULL,
    payload TEXT NOT NULL,
    created_at INTEGER NOT NULL DEFAULT (CAST(strftime('%s', 'now') AS INTEGER)),
    PRIMARY KEY (thread_id, checkpoint_id)
);
"#;
15
16#[derive(Debug, Clone)]
17pub struct SqliteSaver {
18    db_path: PathBuf,
19}
20
21impl SqliteSaver {
22    pub fn new(path: impl AsRef<Path>) -> Result<Self, CheckpointError> {
23        let saver = Self { db_path: path.as_ref().to_path_buf() };
24        saver.init_schema()?;
25        Ok(saver)
26    }
27
28    fn init_schema(&self) -> Result<(), CheckpointError> {
29        let conn = self.open()?;
30        conn.execute_batch(CREATE_TABLE_SQL).map_err(map_sql_error)?;
31        Ok(())
32    }
33
34    fn open(&self) -> Result<Connection, CheckpointError> {
35        Connection::open(&self.db_path).map_err(map_sql_error)
36    }
37}
38
39impl CheckpointSaver for SqliteSaver {
40    fn put(&self, checkpoint: Checkpoint) -> Result<(), CheckpointError> {
41        let payload = serde_json::to_string(&checkpoint)
42            .map_err(|err| CheckpointError::Storage(err.to_string()))?;
43
44        let conn = self.open()?;
45        let result = conn.execute(
46            "INSERT INTO checkpoints (thread_id, checkpoint_id, payload) VALUES (?1, ?2, ?3)",
47            params![checkpoint.thread_id, checkpoint.checkpoint_id, payload],
48        );
49
50        match result {
51            Ok(_) => Ok(()),
52            Err(err) => {
53                if is_unique_violation(&err) {
54                    Err(CheckpointError::Conflict(format!(
55                        "thread `{}` already has checkpoint `{}`",
56                        checkpoint.thread_id, checkpoint.checkpoint_id
57                    )))
58                } else {
59                    Err(map_sql_error(err))
60                }
61            }
62        }
63    }
64
65    fn get(
66        &self,
67        thread_id: &str,
68        checkpoint_id: &str,
69    ) -> Result<Option<Checkpoint>, CheckpointError> {
70        let conn = self.open()?;
71        let mut stmt = conn
72            .prepare("SELECT payload FROM checkpoints WHERE thread_id = ?1 AND checkpoint_id = ?2")
73            .map_err(map_sql_error)?;
74        let mut rows = stmt.query(params![thread_id, checkpoint_id]).map_err(map_sql_error)?;
75
76        if let Some(row) = rows.next().map_err(map_sql_error)? {
77            let payload: String = row.get(0).map_err(map_sql_error)?;
78            let checkpoint = serde_json::from_str(&payload)
79                .map_err(|err| CheckpointError::Storage(err.to_string()))?;
80            Ok(Some(checkpoint))
81        } else {
82            Ok(None)
83        }
84    }
85
86    fn list(&self, thread_id: &str) -> Result<Vec<Checkpoint>, CheckpointError> {
87        let conn = self.open()?;
88        let mut stmt = conn
89            .prepare(
90                "SELECT payload FROM checkpoints WHERE thread_id = ?1 ORDER BY created_at ASC, checkpoint_id ASC",
91            )
92            .map_err(map_sql_error)?;
93        let mut rows = stmt.query(params![thread_id]).map_err(map_sql_error)?;
94        let mut checkpoints = Vec::new();
95
96        while let Some(row) = rows.next().map_err(map_sql_error)? {
97            let payload: String = row.get(0).map_err(map_sql_error)?;
98            let checkpoint = serde_json::from_str(&payload)
99                .map_err(|err| CheckpointError::Storage(err.to_string()))?;
100            checkpoints.push(checkpoint);
101        }
102
103        Ok(checkpoints)
104    }
105}
106
107fn map_sql_error(err: SqlError) -> CheckpointError {
108    CheckpointError::Storage(err.to_string())
109}
110
111fn is_unique_violation(err: &SqlError) -> bool {
112    match err {
113        SqlError::SqliteFailure(info, _) => info.code == ErrorCode::ConstraintViolation,
114        _ => false,
115    }
116}