Skip to main content

langgraph_checkpoint_postgres_rs/
saver.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value as JsonValue;
6use sqlx::postgres::{PgPool, PgPoolOptions, PgRow};
7use sqlx::Row;
8
9use langgraph_checkpoint::checkpoint::base::{get_checkpoint_id, writes_idx_map, BaseCheckpointSaver};
10use langgraph_checkpoint::checkpoint::types::*;
11use langgraph_checkpoint::config::RunnableConfig;
12use langgraph_checkpoint::error::CheckpointError;
13use langgraph_checkpoint::serde::base::SerializerProtocol;
14use langgraph_checkpoint::serde::jsonplus::JsonPlusSerializer;
15
16use crate::queries::*;
17
/// One `checkpoint_blobs` row, with fields in bind order.
type BlobRow = (
    String,          // thread_id
    String,          // checkpoint_ns
    String,          // channel
    String,          // version
    String,          // type_tag ("empty" when the channel had no value)
    Option<Vec<u8>>, // serialized value (None for "empty" rows)
);

/// One `checkpoint_writes` row, with fields in bind order.
type WriteRow = (
    String,  // thread_id
    String,  // checkpoint_ns
    String,  // checkpoint_id
    String,  // task_id
    String,  // task_path
    i32,     // idx
    String,  // channel
    String,  // type_tag
    Vec<u8>, // serialized value
);
23
24/// Helper: create a RunnableConfig from a JSON value.
25fn config_from_json(val: serde_json::Value) -> RunnableConfig {
26    serde_json::from_value(val).unwrap_or_default()
27}
28
29/// Helper: downcast Box<dyn Any> from loads_typed to JsonValue.
30#[allow(dead_code)]
31fn any_to_json(val: Box<dyn std::any::Any + Send + Sync>) -> JsonValue {
32    if val.is::<JsonValue>() {
33        *val.downcast::<JsonValue>().unwrap()
34    } else if val.is::<String>() {
35        JsonValue::String(*val.downcast::<String>().unwrap())
36    } else if val.is::<Vec<u8>>() {
37        let b = val.downcast::<Vec<u8>>().unwrap();
38        JsonValue::Array(b.into_iter().map(|byte: u8| JsonValue::Number(byte.into())).collect())
39    } else {
40        // () and unknown types both map to Null
41        JsonValue::Null
42    }
43}
44
/// Postgres checkpoint saver built on a `sqlx` connection pool.
///
/// Exposes both async methods and blocking wrappers (see the `BaseCheckpointSaver` impl).
/// Channel values and pending writes are encoded with `serde` before being stored.
pub struct PostgresSaver {
    // Connection pool used for every query issued by this saver.
    pool: PgPool,
    // Serializer whose `dumps_typed` encodes channel blobs and pending writes.
    serde: Arc<dyn SerializerProtocol>,
}
50
51impl PostgresSaver {
52    /// Create a new PostgresSaver from a connection pool.
53    pub fn new(pool: PgPool) -> Self {
54        Self {
55            pool,
56            serde: Arc::new(JsonPlusSerializer::new()),
57        }
58    }
59
60    /// Create a new PostgresSaver with a custom serializer.
61    pub fn with_serde(pool: PgPool, serde: Arc<dyn SerializerProtocol>) -> Self {
62        Self { pool, serde }
63    }
64
65    /// Create a PostgresSaver from a connection string.
66    pub async fn from_conn_string(conn_string: &str) -> Result<Self, CheckpointError> {
67        let pool = PgPoolOptions::new()
68            .max_connections(5)
69            .connect(conn_string)
70            .await
71            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
72        Ok(Self::new(pool))
73    }
74
75    /// Run migrations to set up the checkpoint schema.
76    pub async fn setup(&self) -> Result<(), CheckpointError> {
77        sqlx::query(MIGRATIONS[0])
78            .execute(&self.pool)
79            .await
80            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
81
82        let row: Option<(i32,)> = sqlx::query_as(
83            "SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1",
84        )
85        .fetch_optional(&self.pool)
86        .await
87        .map_err(|e| CheckpointError::Storage(e.to_string()))?;
88
89        let version = row.map(|(v,)| v).unwrap_or(-1);
90
91        for (i, migration) in MIGRATIONS.iter().enumerate() {
92            let v = i as i32;
93            if v > version {
94                sqlx::query(migration)
95                    .execute(&self.pool)
96                    .await
97                    .map_err(|e| CheckpointError::Storage(e.to_string()))?;
98                sqlx::query("INSERT INTO checkpoint_migrations (v) VALUES ($1)")
99                    .bind(v)
100                    .execute(&self.pool)
101                    .await
102                    .map_err(|e| CheckpointError::Storage(e.to_string()))?;
103            }
104        }
105
106        Ok(())
107    }
108
109    /// Get the underlying connection pool.
110    pub fn pool(&self) -> &PgPool {
111        &self.pool
112    }
113
114    /// Build a WHERE clause from config, filter, and before parameters.
115    fn build_where_clause(
116        config: Option<&RunnableConfig>,
117        _filter: Option<&HashMap<String, JsonValue>>,
118        before: Option<&RunnableConfig>,
119    ) -> (String, Vec<String>) {
120        let mut wheres = Vec::new();
121        let mut params = Vec::new();
122
123        if let Some(config) = config {
124            if let Some(thread_id) = config
125                .get("configurable")
126                .and_then(|c| c.get("thread_id"))
127                .and_then(|v| v.as_str())
128            {
129                let idx = params.len() + 1;
130                wheres.push(format!("thread_id = ${}", idx));
131                params.push(thread_id.to_string());
132            }
133
134            if let Some(checkpoint_ns) = config
135                .get("configurable")
136                .and_then(|c| c.get("checkpoint_ns"))
137                .and_then(|v| v.as_str())
138            {
139                let idx = params.len() + 1;
140                wheres.push(format!("checkpoint_ns = ${}", idx));
141                params.push(checkpoint_ns.to_string());
142            }
143
144            if let Some(checkpoint_id) = get_checkpoint_id(config) {
145                let idx = params.len() + 1;
146                wheres.push(format!("checkpoint_id = ${}", idx));
147                params.push(checkpoint_id);
148            }
149        }
150
151        if let Some(before) = before {
152            if let Some(before_id) = get_checkpoint_id(before) {
153                let idx = params.len() + 1;
154                wheres.push(format!("checkpoint_id < ${}", idx));
155                params.push(before_id);
156            }
157        }
158
159        let where_clause = if wheres.is_empty() {
160            String::new()
161        } else {
162            format!("WHERE {}", wheres.join(" AND "))
163        };
164
165        (where_clause, params)
166    }
167
168    /// Serialize blobs for storage.
169    fn dump_blobs(
170        &self,
171        thread_id: &str,
172        checkpoint_ns: &str,
173        values: &HashMap<String, JsonValue>,
174        versions: &ChannelVersions,
175    ) -> Vec<BlobRow> {
176        let mut result = Vec::new();
177        for (k, ver) in versions {
178            let ver_str = match ver {
179                JsonValue::String(s) => s.clone(),
180                JsonValue::Number(n) => n.to_string(),
181                _ => continue,
182            };
183            if let Some(val) = values.get(k) {
184                if let Ok((type_tag, blob)) = self.serde.dumps_typed(val) {
185                    result.push((
186                        thread_id.to_string(),
187                        checkpoint_ns.to_string(),
188                        k.clone(),
189                        ver_str,
190                        type_tag,
191                        Some(blob),
192                    ));
193                }
194            } else {
195                result.push((
196                    thread_id.to_string(),
197                    checkpoint_ns.to_string(),
198                    k.clone(),
199                    ver_str,
200                    "empty".to_string(),
201                    None,
202                ));
203            }
204        }
205        result
206    }
207
208    /// Serialize writes for storage.
209    fn dump_writes(
210        &self,
211        thread_id: &str,
212        checkpoint_ns: &str,
213        checkpoint_id: &str,
214        task_id: &str,
215        task_path: &str,
216        writes: &[(String, String, JsonValue)],
217    ) -> Vec<WriteRow> {
218        let idx_map = writes_idx_map();
219        writes
220            .iter()
221            .enumerate()
222            .filter_map(|(idx, (_task_id, channel, value))| {
223                let idx_val = idx_map
224                    .get(channel.as_str())
225                    .copied()
226                    .unwrap_or(idx as i64) as i32;
227                if let Ok((type_tag, blob)) = self.serde.dumps_typed(value) {
228                    Some((
229                        thread_id.to_string(),
230                        checkpoint_ns.to_string(),
231                        checkpoint_id.to_string(),
232                        task_id.to_string(),
233                        task_path.to_string(),
234                        idx_val,
235                        channel.clone(),
236                        type_tag,
237                        blob,
238                    ))
239                } else {
240                    None
241                }
242            })
243            .collect()
244    }
245
246    /// Parse a checkpoint from a row and build a CheckpointTuple.
247    fn row_to_tuple(row: &PgRow) -> Result<CheckpointTuple, CheckpointError> {
248        let checkpoint_json: JsonValue = row.get("checkpoint");
249        let metadata_json: JsonValue = row.get("metadata");
250
251        let checkpoint: Checkpoint = serde_json::from_value(checkpoint_json)
252            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
253        let metadata: CheckpointMetadata = serde_json::from_value(metadata_json)
254            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
255
256        let thread_id: String = row.get("thread_id");
257        let checkpoint_ns: String = row.get("checkpoint_ns");
258
259        let tuple_config = config_from_json(serde_json::json!({
260            "configurable": {
261                "thread_id": thread_id,
262                "checkpoint_ns": checkpoint_ns,
263                "checkpoint_id": checkpoint.id,
264            }
265        }));
266
267        let parent_config: Option<RunnableConfig> = row
268            .get::<Option<String>, _>("parent_checkpoint_id")
269            .map(|pid| {
270                config_from_json(serde_json::json!({
271                    "configurable": {
272                        "thread_id": thread_id,
273                        "checkpoint_ns": checkpoint_ns,
274                        "checkpoint_id": pid,
275                    }
276                }))
277            });
278
279        Ok(CheckpointTuple {
280            config: tuple_config,
281            checkpoint,
282            metadata,
283            parent_config,
284            pending_writes: None,
285        })
286    }
287}
288
289#[async_trait]
290impl BaseCheckpointSaver for PostgresSaver {
291    fn get_tuple(
292        &self,
293        config: &RunnableConfig,
294    ) -> Result<Option<CheckpointTuple>, CheckpointError> {
295        // For sync calls, try to use existing runtime or create one
296        match tokio::runtime::Handle::try_current() {
297            Ok(handle) => handle.block_on(self.aget_tuple(config)),
298            Err(_) => {
299                let rt = tokio::runtime::Runtime::new()
300                    .map_err(|e| CheckpointError::Storage(e.to_string()))?;
301                rt.block_on(self.aget_tuple(config))
302            }
303        }
304    }
305
306    fn list(
307        &self,
308        config: Option<&RunnableConfig>,
309        filter: Option<&HashMap<String, JsonValue>>,
310        before: Option<&RunnableConfig>,
311        limit: Option<usize>,
312    ) -> Result<Vec<CheckpointTuple>, CheckpointError> {
313        match tokio::runtime::Handle::try_current() {
314            Ok(handle) => handle.block_on(self.alist(config, filter, before, limit)),
315            Err(_) => {
316                let rt = tokio::runtime::Runtime::new()
317                    .map_err(|e| CheckpointError::Storage(e.to_string()))?;
318                rt.block_on(self.alist(config, filter, before, limit))
319            }
320        }
321    }
322
323    fn put(
324        &self,
325        config: &RunnableConfig,
326        checkpoint: &Checkpoint,
327        metadata: &CheckpointMetadata,
328        new_versions: &ChannelVersions,
329    ) -> Result<RunnableConfig, CheckpointError> {
330        match tokio::runtime::Handle::try_current() {
331            Ok(handle) => handle.block_on(self.aput(config, checkpoint, metadata, new_versions)),
332            Err(_) => {
333                let rt = tokio::runtime::Runtime::new()
334                    .map_err(|e| CheckpointError::Storage(e.to_string()))?;
335                rt.block_on(self.aput(config, checkpoint, metadata, new_versions))
336            }
337        }
338    }
339
340    fn put_writes(
341        &self,
342        config: &RunnableConfig,
343        writes: &[(String, String, JsonValue)],
344        task_id: &str,
345        task_path: &str,
346    ) -> Result<(), CheckpointError> {
347        match tokio::runtime::Handle::try_current() {
348            Ok(handle) => handle.block_on(self.aput_writes(
349                config,
350                writes.to_vec(),
351                task_id.to_string(),
352                task_path.to_string(),
353            )),
354            Err(_) => {
355                let rt = tokio::runtime::Runtime::new()
356                    .map_err(|e| CheckpointError::Storage(e.to_string()))?;
357                rt.block_on(self.aput_writes(
358                    config,
359                    writes.to_vec(),
360                    task_id.to_string(),
361                    task_path.to_string(),
362                ))
363            }
364        }
365    }
366
367    fn delete_thread(&self, thread_id: &str) -> Result<(), CheckpointError> {
368        match tokio::runtime::Handle::try_current() {
369            Ok(handle) => handle.block_on(self.adelete_thread(thread_id.to_string())),
370            Err(_) => {
371                let rt = tokio::runtime::Runtime::new()
372                    .map_err(|e| CheckpointError::Storage(e.to_string()))?;
373                rt.block_on(self.adelete_thread(thread_id.to_string()))
374            }
375        }
376    }
377
378    async fn aget_tuple(
379        &self,
380        config: &RunnableConfig,
381    ) -> Result<Option<CheckpointTuple>, CheckpointError> {
382        let thread_id = config
383            .get("configurable")
384            .and_then(|c| c.get("thread_id"))
385            .and_then(|v| v.as_str())
386            .ok_or_else(|| CheckpointError::Config("missing thread_id".into()))?;
387
388        let checkpoint_ns = config
389            .get("configurable")
390            .and_then(|c| c.get("checkpoint_ns"))
391            .and_then(|v| v.as_str())
392            .unwrap_or("");
393
394        let checkpoint_id = get_checkpoint_id(config);
395
396        let row = if let Some(cid) = &checkpoint_id {
397            sqlx::query(&format!(
398                "{} WHERE thread_id = $1 AND checkpoint_ns = $2 AND checkpoint_id = $3",
399                SELECT_SQL
400            ))
401            .bind(thread_id)
402            .bind(checkpoint_ns)
403            .bind(cid.as_str())
404            .fetch_optional(&self.pool)
405            .await
406            .map_err(|e| CheckpointError::Storage(e.to_string()))?
407        } else {
408            sqlx::query(&format!(
409                "{} WHERE thread_id = $1 AND checkpoint_ns = $2 ORDER BY checkpoint_id DESC LIMIT 1",
410                SELECT_SQL
411            ))
412            .bind(thread_id)
413            .bind(checkpoint_ns)
414            .fetch_optional(&self.pool)
415            .await
416            .map_err(|e| CheckpointError::Storage(e.to_string()))?
417        };
418
419        match row {
420            Some(row) => Ok(Some(Self::row_to_tuple(&row)?)),
421            None => Ok(None),
422        }
423    }
424
425    async fn aput(
426        &self,
427        config: &RunnableConfig,
428        checkpoint: &Checkpoint,
429        metadata: &CheckpointMetadata,
430        new_versions: &ChannelVersions,
431    ) -> Result<RunnableConfig, CheckpointError> {
432        let configurable = config.get("configurable").cloned().unwrap_or_default();
433        let thread_id = configurable
434            .get("thread_id")
435            .and_then(|v| v.as_str())
436            .ok_or_else(|| CheckpointError::Config("missing thread_id".into()))?;
437        let checkpoint_ns = configurable
438            .get("checkpoint_ns")
439            .and_then(|v| v.as_str())
440            .unwrap_or("");
441        let parent_checkpoint_id: Option<String> = configurable
442            .get("checkpoint_id")
443            .and_then(|v| v.as_str())
444            .map(|s| s.to_string());
445
446        let next_config = config_from_json(serde_json::json!({
447            "configurable": {
448                "thread_id": thread_id,
449                "checkpoint_ns": checkpoint_ns,
450                "checkpoint_id": checkpoint.id,
451            }
452        }));
453
454        let checkpoint_json = serde_json::to_value(checkpoint)
455            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
456        let metadata_json = serde_json::to_value(metadata)
457            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
458
459        // Upsert blobs
460        let blobs = self.dump_blobs(
461            thread_id,
462            checkpoint_ns,
463            &checkpoint.channel_values,
464            new_versions,
465        );
466        for (tid, cns, channel, version, type_tag, blob) in &blobs {
467            sqlx::query(UPSERT_CHECKPOINT_BLOBS_SQL)
468                .bind(tid.as_str())
469                .bind(cns.as_str())
470                .bind(channel.as_str())
471                .bind(version.as_str())
472                .bind(type_tag.as_str())
473                .bind(blob.as_deref())
474                .execute(&self.pool)
475                .await
476                .map_err(|e| CheckpointError::Storage(e.to_string()))?;
477        }
478
479        // Upsert checkpoint
480        sqlx::query(UPSERT_CHECKPOINTS_SQL)
481            .bind(thread_id)
482            .bind(checkpoint_ns)
483            .bind(checkpoint.id.as_str())
484            .bind(parent_checkpoint_id.as_deref())
485            .bind(&checkpoint_json)
486            .bind(&metadata_json)
487            .execute(&self.pool)
488            .await
489            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
490
491        Ok(next_config)
492    }
493
494    async fn aput_writes(
495        &self,
496        config: &RunnableConfig,
497        writes: Vec<(String, String, JsonValue)>,
498        task_id: String,
499        task_path: String,
500    ) -> Result<(), CheckpointError> {
501        let configurable = config.get("configurable").cloned().unwrap_or_default();
502        let thread_id = configurable
503            .get("thread_id")
504            .and_then(|v| v.as_str())
505            .ok_or_else(|| CheckpointError::Config("missing thread_id".into()))?;
506        let checkpoint_ns = configurable
507            .get("checkpoint_ns")
508            .and_then(|v| v.as_str())
509            .unwrap_or("");
510        let checkpoint_id = configurable
511            .get("checkpoint_id")
512            .and_then(|v| v.as_str())
513            .unwrap_or("");
514
515        let idx_map = writes_idx_map();
516        let use_upsert = writes
517            .iter()
518            .all(|(channel, _, _)| idx_map.contains_key(channel.as_str()));
519
520        let query = if use_upsert {
521            UPSERT_CHECKPOINT_WRITES_SQL
522        } else {
523            INSERT_CHECKPOINT_WRITES_SQL
524        };
525
526        let dump = self.dump_writes(
527            thread_id,
528            checkpoint_ns,
529            checkpoint_id,
530            &task_id,
531            &task_path,
532            &writes,
533        );
534
535        for (tid, cns, cid, tid2, tpath, idx, channel, type_tag, blob) in &dump {
536            sqlx::query(query)
537                .bind(tid.as_str())
538                .bind(cns.as_str())
539                .bind(cid.as_str())
540                .bind(tid2.as_str())
541                .bind(tpath.as_str())
542                .bind(*idx)
543                .bind(channel.as_str())
544                .bind(type_tag.as_str())
545                .bind(blob.as_slice())
546                .execute(&self.pool)
547                .await
548                .map_err(|e| CheckpointError::Storage(e.to_string()))?;
549        }
550
551        Ok(())
552    }
553
554    async fn adelete_thread(&self, thread_id: String) -> Result<(), CheckpointError> {
555        let mut tx = self
556            .pool
557            .begin()
558            .await
559            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
560
561        sqlx::query("DELETE FROM checkpoints WHERE thread_id = $1")
562            .bind(thread_id.as_str())
563            .execute(&mut *tx)
564            .await
565            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
566
567        sqlx::query("DELETE FROM checkpoint_blobs WHERE thread_id = $1")
568            .bind(thread_id.as_str())
569            .execute(&mut *tx)
570            .await
571            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
572
573        sqlx::query("DELETE FROM checkpoint_writes WHERE thread_id = $1")
574            .bind(thread_id.as_str())
575            .execute(&mut *tx)
576            .await
577            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
578
579        tx.commit()
580            .await
581            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
582
583        Ok(())
584    }
585}
586
587/// Async list method for PostgresSaver.
588impl PostgresSaver {
589    pub async fn alist(
590        &self,
591        config: Option<&RunnableConfig>,
592        filter: Option<&HashMap<String, JsonValue>>,
593        before: Option<&RunnableConfig>,
594        limit: Option<usize>,
595    ) -> Result<Vec<CheckpointTuple>, CheckpointError> {
596        let (where_clause, _params) = Self::build_where_clause(config, filter, before);
597        let mut query = format!(
598            "{} {} ORDER BY checkpoint_id DESC",
599            SELECT_SQL, where_clause
600        );
601
602        if let Some(limit) = limit {
603            query.push_str(&format!(" LIMIT {}", limit));
604        }
605
606        // Build the query with bound params
607        let mut q = sqlx::query(&query);
608        if let Some(config) = config {
609            if let Some(thread_id) = config
610                .get("configurable")
611                .and_then(|c| c.get("thread_id"))
612                .and_then(|v| v.as_str())
613            {
614                q = q.bind(thread_id);
615            }
616            if let Some(checkpoint_ns) = config
617                .get("configurable")
618                .and_then(|c| c.get("checkpoint_ns"))
619                .and_then(|v| v.as_str())
620            {
621                q = q.bind(checkpoint_ns);
622            }
623            if let Some(checkpoint_id) = get_checkpoint_id(config) {
624                q = q.bind(checkpoint_id);
625            }
626        }
627        if let Some(before) = before {
628            if let Some(before_id) = get_checkpoint_id(before) {
629                q = q.bind(before_id);
630            }
631        }
632
633        let rows = q
634            .fetch_all(&self.pool)
635            .await
636            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
637
638        let mut results = Vec::new();
639        for row in rows {
640            results.push(Self::row_to_tuple(&row)?);
641        }
642
643        Ok(results)
644    }
645}
646
647/// CheckpointError needs a Config variant for missing config fields.
648/// We add it here since the base error type doesn't have one.
649#[allow(dead_code)]
650impl PostgresSaver {
651    /// Wrap a config error message into a CheckpointError::Storage.
652    fn config_error(msg: &str) -> CheckpointError {
653        CheckpointError::Storage(format!("config error: {}", msg))
654    }
655}