Skip to main content

langgraph_checkpoint_postgres/
postgres_saver.rs

1use langgraph_checkpoint::{Checkpoint, CheckpointError, CheckpointSaver};
2use postgres::error::SqlState;
3use postgres::{Client, NoTls};
4
/// DDL that `PostgresSaver::init_schema` runs to create the `checkpoints` table.
///
/// `IF NOT EXISTS` lets the statement be executed again against a database
/// where the table is already present. The composite primary key
/// `(thread_id, checkpoint_id)` is the table's only uniqueness constraint, and
/// `created_at` is filled in by the database via `DEFAULT NOW()` because the
/// insert statement in this file does not supply it.
const CREATE_TABLE_SQL: &str = r#"
CREATE TABLE IF NOT EXISTS checkpoints (
    thread_id TEXT NOT NULL,
    checkpoint_id TEXT NOT NULL,
    payload JSONB NOT NULL,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    PRIMARY KEY (thread_id, checkpoint_id)
);
"#;
14
/// Checkpoint saver that persists checkpoints in a PostgreSQL table.
///
/// Only the connection string is stored; cloning the saver therefore copies
/// that string and nothing else.
#[derive(Clone)]
pub struct PostgresSaver {
    /// Connection string handed to the PostgreSQL driver. It may embed
    /// credentials, so it is deliberately kept out of the `Debug` output.
    connection_string: String,
}

/// Hand-written so that formatting the saver with `{:?}` (logs, panic
/// messages, `dbg!`) never prints the connection string, which can contain
/// the database password.
impl std::fmt::Debug for PostgresSaver {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PostgresSaver")
            .field("connection_string", &"<redacted>")
            .finish()
    }
}
19
20impl PostgresSaver {
21    pub fn new(connection_string: impl Into<String>) -> Result<Self, CheckpointError> {
22        let saver = Self { connection_string: connection_string.into() };
23        saver.init_schema()?;
24        Ok(saver)
25    }
26
27    fn init_schema(&self) -> Result<(), CheckpointError> {
28        let mut client = self.open()?;
29        client.batch_execute(CREATE_TABLE_SQL).map_err(map_pg_error)?;
30        Ok(())
31    }
32
33    fn open(&self) -> Result<Client, CheckpointError> {
34        Client::connect(&self.connection_string, NoTls).map_err(map_pg_error)
35    }
36}
37
38impl CheckpointSaver for PostgresSaver {
39    fn put(&self, checkpoint: Checkpoint) -> Result<(), CheckpointError> {
40        let payload = serde_json::to_value(&checkpoint)
41            .map_err(|err| CheckpointError::Storage(err.to_string()))?;
42        let mut client = self.open()?;
43        let result = client.execute(
44            "INSERT INTO checkpoints (thread_id, checkpoint_id, payload) VALUES ($1, $2, $3)",
45            &[&checkpoint.thread_id, &checkpoint.checkpoint_id, &payload],
46        );
47        match result {
48            Ok(_) => Ok(()),
49            Err(err) => {
50                if let Some(code) = err.code() {
51                    if *code == SqlState::UNIQUE_VIOLATION {
52                        return Err(CheckpointError::Conflict(format!(
53                            "thread `{}` already has checkpoint `{}`",
54                            checkpoint.thread_id, checkpoint.checkpoint_id
55                        )));
56                    }
57                }
58                Err(map_pg_error(err))
59            }
60        }
61    }
62
63    fn get(
64        &self,
65        thread_id: &str,
66        checkpoint_id: &str,
67    ) -> Result<Option<Checkpoint>, CheckpointError> {
68        let mut client = self.open()?;
69        let row = client
70            .query_opt(
71                "SELECT payload FROM checkpoints WHERE thread_id = $1 AND checkpoint_id = $2",
72                &[&thread_id, &checkpoint_id],
73            )
74            .map_err(map_pg_error)?;
75        match row {
76            Some(row) => {
77                let payload: serde_json::Value = row.get(0);
78                let checkpoint = serde_json::from_value(payload)
79                    .map_err(|err| CheckpointError::Storage(err.to_string()))?;
80                Ok(Some(checkpoint))
81            }
82            None => Ok(None),
83        }
84    }
85
86    fn list(&self, thread_id: &str) -> Result<Vec<Checkpoint>, CheckpointError> {
87        let mut client = self.open()?;
88        let rows = client
89            .query(
90                "SELECT payload FROM checkpoints WHERE thread_id = $1 ORDER BY created_at ASC, checkpoint_id ASC",
91                &[&thread_id],
92            )
93            .map_err(map_pg_error)?;
94        rows.into_iter()
95            .map(|row| {
96                let payload: serde_json::Value = row.get(0);
97                serde_json::from_value(payload)
98                    .map_err(|err| CheckpointError::Storage(err.to_string()))
99            })
100            .collect()
101    }
102}
103
104fn map_pg_error(err: postgres::Error) -> CheckpointError {
105    CheckpointError::Storage(err.to_string())
106}