langgraph_checkpoint_postgres/
postgres_saver.rs1use langgraph_checkpoint::{Checkpoint, CheckpointError, CheckpointSaver};
2use postgres::error::SqlState;
3use postgres::{Client, NoTls};
4
/// DDL executed by `PostgresSaver::init_schema` so the `checkpoints` table exists.
///
/// Each row is keyed by `(thread_id, checkpoint_id)`, stores the checkpoint in
/// the JSONB `payload` column, and gets a `created_at` that defaults to `NOW()`.
/// `IF NOT EXISTS` makes the statement safe to run on every start-up, but it
/// does not alter or validate a table that is already present.
const CREATE_TABLE_SQL: &str = r#"
CREATE TABLE IF NOT EXISTS checkpoints (
 thread_id TEXT NOT NULL,
 checkpoint_id TEXT NOT NULL,
 payload JSONB NOT NULL,
 created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
 PRIMARY KEY (thread_id, checkpoint_id)
);
"#;
14
/// Checkpoint saver that persists checkpoints in a PostgreSQL database.
///
/// Only the connection string is held; no live connection is stored in the
/// value itself, which is why it can be cheaply cloned.
#[derive(Clone, Debug)]
pub struct PostgresSaver {
    /// Connection string handed to `Client::connect` whenever a connection is opened.
    connection_string: String,
}
19
20impl PostgresSaver {
21 pub fn new(connection_string: impl Into<String>) -> Result<Self, CheckpointError> {
22 let saver = Self { connection_string: connection_string.into() };
23 saver.init_schema()?;
24 Ok(saver)
25 }
26
27 fn init_schema(&self) -> Result<(), CheckpointError> {
28 let mut client = self.open()?;
29 client.batch_execute(CREATE_TABLE_SQL).map_err(map_pg_error)?;
30 Ok(())
31 }
32
33 fn open(&self) -> Result<Client, CheckpointError> {
34 Client::connect(&self.connection_string, NoTls).map_err(map_pg_error)
35 }
36}
37
38impl CheckpointSaver for PostgresSaver {
39 fn put(&self, checkpoint: Checkpoint) -> Result<(), CheckpointError> {
40 let payload = serde_json::to_value(&checkpoint)
41 .map_err(|err| CheckpointError::Storage(err.to_string()))?;
42 let mut client = self.open()?;
43 let result = client.execute(
44 "INSERT INTO checkpoints (thread_id, checkpoint_id, payload) VALUES ($1, $2, $3)",
45 &[&checkpoint.thread_id, &checkpoint.checkpoint_id, &payload],
46 );
47 match result {
48 Ok(_) => Ok(()),
49 Err(err) => {
50 if let Some(code) = err.code() {
51 if *code == SqlState::UNIQUE_VIOLATION {
52 return Err(CheckpointError::Conflict(format!(
53 "thread `{}` already has checkpoint `{}`",
54 checkpoint.thread_id, checkpoint.checkpoint_id
55 )));
56 }
57 }
58 Err(map_pg_error(err))
59 }
60 }
61 }
62
63 fn get(
64 &self,
65 thread_id: &str,
66 checkpoint_id: &str,
67 ) -> Result<Option<Checkpoint>, CheckpointError> {
68 let mut client = self.open()?;
69 let row = client
70 .query_opt(
71 "SELECT payload FROM checkpoints WHERE thread_id = $1 AND checkpoint_id = $2",
72 &[&thread_id, &checkpoint_id],
73 )
74 .map_err(map_pg_error)?;
75 match row {
76 Some(row) => {
77 let payload: serde_json::Value = row.get(0);
78 let checkpoint = serde_json::from_value(payload)
79 .map_err(|err| CheckpointError::Storage(err.to_string()))?;
80 Ok(Some(checkpoint))
81 }
82 None => Ok(None),
83 }
84 }
85
86 fn list(&self, thread_id: &str) -> Result<Vec<Checkpoint>, CheckpointError> {
87 let mut client = self.open()?;
88 let rows = client
89 .query(
90 "SELECT payload FROM checkpoints WHERE thread_id = $1 ORDER BY created_at ASC, checkpoint_id ASC",
91 &[&thread_id],
92 )
93 .map_err(map_pg_error)?;
94 rows.into_iter()
95 .map(|row| {
96 let payload: serde_json::Value = row.get(0);
97 serde_json::from_value(payload)
98 .map_err(|err| CheckpointError::Storage(err.to_string()))
99 })
100 .collect()
101 }
102}
103
104fn map_pg_error(err: postgres::Error) -> CheckpointError {
105 CheckpointError::Storage(err.to_string())
106}