1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value as JsonValue;
6use sqlx::postgres::{PgPool, PgPoolOptions, PgRow};
7use sqlx::Row;
8
9use langgraph_checkpoint::checkpoint::base::{get_checkpoint_id, writes_idx_map, BaseCheckpointSaver};
10use langgraph_checkpoint::checkpoint::types::*;
11use langgraph_checkpoint::config::RunnableConfig;
12use langgraph_checkpoint::error::CheckpointError;
13use langgraph_checkpoint::serde::base::SerializerProtocol;
14use langgraph_checkpoint::serde::jsonplus::JsonPlusSerializer;
15
16use crate::queries::*;
17
/// One `checkpoint_blobs` row as assembled by `PostgresSaver::dump_blobs`.
///
/// The blob is `None` (with type tag `"empty"`) when a channel has a version but
/// no value in `channel_values`.
type BlobRow = (
    String,          // thread_id
    String,          // checkpoint_ns
    String,          // channel
    String,          // version (string or number rendered as text)
    String,          // type tag from the serializer, or "empty"
    Option<Vec<u8>>, // serialized value
);

/// One `checkpoint_writes` row as assembled by `PostgresSaver::dump_writes`.
type WriteRow = (
    String,  // thread_id
    String,  // checkpoint_ns
    String,  // checkpoint_id
    String,  // task_id
    String,  // task_path
    i32,     // write index (fixed for special channels, else position)
    String,  // channel
    String,  // type tag from the serializer
    Vec<u8>, // serialized value
);
23
24fn config_from_json(val: serde_json::Value) -> RunnableConfig {
26 serde_json::from_value(val).unwrap_or_default()
27}
28
29#[allow(dead_code)]
31fn any_to_json(val: Box<dyn std::any::Any + Send + Sync>) -> JsonValue {
32 if val.is::<JsonValue>() {
33 *val.downcast::<JsonValue>().unwrap()
34 } else if val.is::<String>() {
35 JsonValue::String(*val.downcast::<String>().unwrap())
36 } else if val.is::<Vec<u8>>() {
37 let b = val.downcast::<Vec<u8>>().unwrap();
38 JsonValue::Array(b.into_iter().map(|byte: u8| JsonValue::Number(byte.into())).collect())
39 } else {
40 JsonValue::Null
42 }
43}
44
45pub struct PostgresSaver {
47 pool: PgPool,
48 serde: Arc<dyn SerializerProtocol>,
49}
50
51impl PostgresSaver {
52 pub fn new(pool: PgPool) -> Self {
54 Self {
55 pool,
56 serde: Arc::new(JsonPlusSerializer::new()),
57 }
58 }
59
60 pub fn with_serde(pool: PgPool, serde: Arc<dyn SerializerProtocol>) -> Self {
62 Self { pool, serde }
63 }
64
65 pub async fn from_conn_string(conn_string: &str) -> Result<Self, CheckpointError> {
67 let pool = PgPoolOptions::new()
68 .max_connections(5)
69 .connect(conn_string)
70 .await
71 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
72 Ok(Self::new(pool))
73 }
74
75 pub async fn setup(&self) -> Result<(), CheckpointError> {
77 sqlx::query(MIGRATIONS[0])
78 .execute(&self.pool)
79 .await
80 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
81
82 let row: Option<(i32,)> = sqlx::query_as(
83 "SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1",
84 )
85 .fetch_optional(&self.pool)
86 .await
87 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
88
89 let version = row.map(|(v,)| v).unwrap_or(-1);
90
91 for (i, migration) in MIGRATIONS.iter().enumerate() {
92 let v = i as i32;
93 if v > version {
94 sqlx::query(migration)
95 .execute(&self.pool)
96 .await
97 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
98 sqlx::query("INSERT INTO checkpoint_migrations (v) VALUES ($1)")
99 .bind(v)
100 .execute(&self.pool)
101 .await
102 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
103 }
104 }
105
106 Ok(())
107 }
108
109 pub fn pool(&self) -> &PgPool {
111 &self.pool
112 }
113
114 fn build_where_clause(
116 config: Option<&RunnableConfig>,
117 _filter: Option<&HashMap<String, JsonValue>>,
118 before: Option<&RunnableConfig>,
119 ) -> (String, Vec<String>) {
120 let mut wheres = Vec::new();
121 let mut params = Vec::new();
122
123 if let Some(config) = config {
124 if let Some(thread_id) = config
125 .get("configurable")
126 .and_then(|c| c.get("thread_id"))
127 .and_then(|v| v.as_str())
128 {
129 let idx = params.len() + 1;
130 wheres.push(format!("thread_id = ${}", idx));
131 params.push(thread_id.to_string());
132 }
133
134 if let Some(checkpoint_ns) = config
135 .get("configurable")
136 .and_then(|c| c.get("checkpoint_ns"))
137 .and_then(|v| v.as_str())
138 {
139 let idx = params.len() + 1;
140 wheres.push(format!("checkpoint_ns = ${}", idx));
141 params.push(checkpoint_ns.to_string());
142 }
143
144 if let Some(checkpoint_id) = get_checkpoint_id(config) {
145 let idx = params.len() + 1;
146 wheres.push(format!("checkpoint_id = ${}", idx));
147 params.push(checkpoint_id);
148 }
149 }
150
151 if let Some(before) = before {
152 if let Some(before_id) = get_checkpoint_id(before) {
153 let idx = params.len() + 1;
154 wheres.push(format!("checkpoint_id < ${}", idx));
155 params.push(before_id);
156 }
157 }
158
159 let where_clause = if wheres.is_empty() {
160 String::new()
161 } else {
162 format!("WHERE {}", wheres.join(" AND "))
163 };
164
165 (where_clause, params)
166 }
167
168 fn dump_blobs(
170 &self,
171 thread_id: &str,
172 checkpoint_ns: &str,
173 values: &HashMap<String, JsonValue>,
174 versions: &ChannelVersions,
175 ) -> Vec<BlobRow> {
176 let mut result = Vec::new();
177 for (k, ver) in versions {
178 let ver_str = match ver {
179 JsonValue::String(s) => s.clone(),
180 JsonValue::Number(n) => n.to_string(),
181 _ => continue,
182 };
183 if let Some(val) = values.get(k) {
184 if let Ok((type_tag, blob)) = self.serde.dumps_typed(val) {
185 result.push((
186 thread_id.to_string(),
187 checkpoint_ns.to_string(),
188 k.clone(),
189 ver_str,
190 type_tag,
191 Some(blob),
192 ));
193 }
194 } else {
195 result.push((
196 thread_id.to_string(),
197 checkpoint_ns.to_string(),
198 k.clone(),
199 ver_str,
200 "empty".to_string(),
201 None,
202 ));
203 }
204 }
205 result
206 }
207
208 fn dump_writes(
210 &self,
211 thread_id: &str,
212 checkpoint_ns: &str,
213 checkpoint_id: &str,
214 task_id: &str,
215 task_path: &str,
216 writes: &[(String, String, JsonValue)],
217 ) -> Vec<WriteRow> {
218 let idx_map = writes_idx_map();
219 writes
220 .iter()
221 .enumerate()
222 .filter_map(|(idx, (_task_id, channel, value))| {
223 let idx_val = idx_map
224 .get(channel.as_str())
225 .copied()
226 .unwrap_or(idx as i64) as i32;
227 if let Ok((type_tag, blob)) = self.serde.dumps_typed(value) {
228 Some((
229 thread_id.to_string(),
230 checkpoint_ns.to_string(),
231 checkpoint_id.to_string(),
232 task_id.to_string(),
233 task_path.to_string(),
234 idx_val,
235 channel.clone(),
236 type_tag,
237 blob,
238 ))
239 } else {
240 None
241 }
242 })
243 .collect()
244 }
245
246 fn row_to_tuple(row: &PgRow) -> Result<CheckpointTuple, CheckpointError> {
248 let checkpoint_json: JsonValue = row.get("checkpoint");
249 let metadata_json: JsonValue = row.get("metadata");
250
251 let checkpoint: Checkpoint = serde_json::from_value(checkpoint_json)
252 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
253 let metadata: CheckpointMetadata = serde_json::from_value(metadata_json)
254 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
255
256 let thread_id: String = row.get("thread_id");
257 let checkpoint_ns: String = row.get("checkpoint_ns");
258
259 let tuple_config = config_from_json(serde_json::json!({
260 "configurable": {
261 "thread_id": thread_id,
262 "checkpoint_ns": checkpoint_ns,
263 "checkpoint_id": checkpoint.id,
264 }
265 }));
266
267 let parent_config: Option<RunnableConfig> = row
268 .get::<Option<String>, _>("parent_checkpoint_id")
269 .map(|pid| {
270 config_from_json(serde_json::json!({
271 "configurable": {
272 "thread_id": thread_id,
273 "checkpoint_ns": checkpoint_ns,
274 "checkpoint_id": pid,
275 }
276 }))
277 });
278
279 Ok(CheckpointTuple {
280 config: tuple_config,
281 checkpoint,
282 metadata,
283 parent_config,
284 pending_writes: None,
285 })
286 }
287}
288
289#[async_trait]
290impl BaseCheckpointSaver for PostgresSaver {
291 fn get_tuple(
292 &self,
293 config: &RunnableConfig,
294 ) -> Result<Option<CheckpointTuple>, CheckpointError> {
295 match tokio::runtime::Handle::try_current() {
297 Ok(handle) => handle.block_on(self.aget_tuple(config)),
298 Err(_) => {
299 let rt = tokio::runtime::Runtime::new()
300 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
301 rt.block_on(self.aget_tuple(config))
302 }
303 }
304 }
305
306 fn list(
307 &self,
308 config: Option<&RunnableConfig>,
309 filter: Option<&HashMap<String, JsonValue>>,
310 before: Option<&RunnableConfig>,
311 limit: Option<usize>,
312 ) -> Result<Vec<CheckpointTuple>, CheckpointError> {
313 match tokio::runtime::Handle::try_current() {
314 Ok(handle) => handle.block_on(self.alist(config, filter, before, limit)),
315 Err(_) => {
316 let rt = tokio::runtime::Runtime::new()
317 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
318 rt.block_on(self.alist(config, filter, before, limit))
319 }
320 }
321 }
322
323 fn put(
324 &self,
325 config: &RunnableConfig,
326 checkpoint: &Checkpoint,
327 metadata: &CheckpointMetadata,
328 new_versions: &ChannelVersions,
329 ) -> Result<RunnableConfig, CheckpointError> {
330 match tokio::runtime::Handle::try_current() {
331 Ok(handle) => handle.block_on(self.aput(config, checkpoint, metadata, new_versions)),
332 Err(_) => {
333 let rt = tokio::runtime::Runtime::new()
334 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
335 rt.block_on(self.aput(config, checkpoint, metadata, new_versions))
336 }
337 }
338 }
339
340 fn put_writes(
341 &self,
342 config: &RunnableConfig,
343 writes: &[(String, String, JsonValue)],
344 task_id: &str,
345 task_path: &str,
346 ) -> Result<(), CheckpointError> {
347 match tokio::runtime::Handle::try_current() {
348 Ok(handle) => handle.block_on(self.aput_writes(
349 config,
350 writes.to_vec(),
351 task_id.to_string(),
352 task_path.to_string(),
353 )),
354 Err(_) => {
355 let rt = tokio::runtime::Runtime::new()
356 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
357 rt.block_on(self.aput_writes(
358 config,
359 writes.to_vec(),
360 task_id.to_string(),
361 task_path.to_string(),
362 ))
363 }
364 }
365 }
366
367 fn delete_thread(&self, thread_id: &str) -> Result<(), CheckpointError> {
368 match tokio::runtime::Handle::try_current() {
369 Ok(handle) => handle.block_on(self.adelete_thread(thread_id.to_string())),
370 Err(_) => {
371 let rt = tokio::runtime::Runtime::new()
372 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
373 rt.block_on(self.adelete_thread(thread_id.to_string()))
374 }
375 }
376 }
377
378 async fn aget_tuple(
379 &self,
380 config: &RunnableConfig,
381 ) -> Result<Option<CheckpointTuple>, CheckpointError> {
382 let thread_id = config
383 .get("configurable")
384 .and_then(|c| c.get("thread_id"))
385 .and_then(|v| v.as_str())
386 .ok_or_else(|| CheckpointError::Config("missing thread_id".into()))?;
387
388 let checkpoint_ns = config
389 .get("configurable")
390 .and_then(|c| c.get("checkpoint_ns"))
391 .and_then(|v| v.as_str())
392 .unwrap_or("");
393
394 let checkpoint_id = get_checkpoint_id(config);
395
396 let row = if let Some(cid) = &checkpoint_id {
397 sqlx::query(&format!(
398 "{} WHERE thread_id = $1 AND checkpoint_ns = $2 AND checkpoint_id = $3",
399 SELECT_SQL
400 ))
401 .bind(thread_id)
402 .bind(checkpoint_ns)
403 .bind(cid.as_str())
404 .fetch_optional(&self.pool)
405 .await
406 .map_err(|e| CheckpointError::Storage(e.to_string()))?
407 } else {
408 sqlx::query(&format!(
409 "{} WHERE thread_id = $1 AND checkpoint_ns = $2 ORDER BY checkpoint_id DESC LIMIT 1",
410 SELECT_SQL
411 ))
412 .bind(thread_id)
413 .bind(checkpoint_ns)
414 .fetch_optional(&self.pool)
415 .await
416 .map_err(|e| CheckpointError::Storage(e.to_string()))?
417 };
418
419 match row {
420 Some(row) => Ok(Some(Self::row_to_tuple(&row)?)),
421 None => Ok(None),
422 }
423 }
424
425 async fn aput(
426 &self,
427 config: &RunnableConfig,
428 checkpoint: &Checkpoint,
429 metadata: &CheckpointMetadata,
430 new_versions: &ChannelVersions,
431 ) -> Result<RunnableConfig, CheckpointError> {
432 let configurable = config.get("configurable").cloned().unwrap_or_default();
433 let thread_id = configurable
434 .get("thread_id")
435 .and_then(|v| v.as_str())
436 .ok_or_else(|| CheckpointError::Config("missing thread_id".into()))?;
437 let checkpoint_ns = configurable
438 .get("checkpoint_ns")
439 .and_then(|v| v.as_str())
440 .unwrap_or("");
441 let parent_checkpoint_id: Option<String> = configurable
442 .get("checkpoint_id")
443 .and_then(|v| v.as_str())
444 .map(|s| s.to_string());
445
446 let next_config = config_from_json(serde_json::json!({
447 "configurable": {
448 "thread_id": thread_id,
449 "checkpoint_ns": checkpoint_ns,
450 "checkpoint_id": checkpoint.id,
451 }
452 }));
453
454 let checkpoint_json = serde_json::to_value(checkpoint)
455 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
456 let metadata_json = serde_json::to_value(metadata)
457 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
458
459 let blobs = self.dump_blobs(
461 thread_id,
462 checkpoint_ns,
463 &checkpoint.channel_values,
464 new_versions,
465 );
466 for (tid, cns, channel, version, type_tag, blob) in &blobs {
467 sqlx::query(UPSERT_CHECKPOINT_BLOBS_SQL)
468 .bind(tid.as_str())
469 .bind(cns.as_str())
470 .bind(channel.as_str())
471 .bind(version.as_str())
472 .bind(type_tag.as_str())
473 .bind(blob.as_deref())
474 .execute(&self.pool)
475 .await
476 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
477 }
478
479 sqlx::query(UPSERT_CHECKPOINTS_SQL)
481 .bind(thread_id)
482 .bind(checkpoint_ns)
483 .bind(checkpoint.id.as_str())
484 .bind(parent_checkpoint_id.as_deref())
485 .bind(&checkpoint_json)
486 .bind(&metadata_json)
487 .execute(&self.pool)
488 .await
489 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
490
491 Ok(next_config)
492 }
493
494 async fn aput_writes(
495 &self,
496 config: &RunnableConfig,
497 writes: Vec<(String, String, JsonValue)>,
498 task_id: String,
499 task_path: String,
500 ) -> Result<(), CheckpointError> {
501 let configurable = config.get("configurable").cloned().unwrap_or_default();
502 let thread_id = configurable
503 .get("thread_id")
504 .and_then(|v| v.as_str())
505 .ok_or_else(|| CheckpointError::Config("missing thread_id".into()))?;
506 let checkpoint_ns = configurable
507 .get("checkpoint_ns")
508 .and_then(|v| v.as_str())
509 .unwrap_or("");
510 let checkpoint_id = configurable
511 .get("checkpoint_id")
512 .and_then(|v| v.as_str())
513 .unwrap_or("");
514
515 let idx_map = writes_idx_map();
516 let use_upsert = writes
517 .iter()
518 .all(|(channel, _, _)| idx_map.contains_key(channel.as_str()));
519
520 let query = if use_upsert {
521 UPSERT_CHECKPOINT_WRITES_SQL
522 } else {
523 INSERT_CHECKPOINT_WRITES_SQL
524 };
525
526 let dump = self.dump_writes(
527 thread_id,
528 checkpoint_ns,
529 checkpoint_id,
530 &task_id,
531 &task_path,
532 &writes,
533 );
534
535 for (tid, cns, cid, tid2, tpath, idx, channel, type_tag, blob) in &dump {
536 sqlx::query(query)
537 .bind(tid.as_str())
538 .bind(cns.as_str())
539 .bind(cid.as_str())
540 .bind(tid2.as_str())
541 .bind(tpath.as_str())
542 .bind(*idx)
543 .bind(channel.as_str())
544 .bind(type_tag.as_str())
545 .bind(blob.as_slice())
546 .execute(&self.pool)
547 .await
548 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
549 }
550
551 Ok(())
552 }
553
554 async fn adelete_thread(&self, thread_id: String) -> Result<(), CheckpointError> {
555 let mut tx = self
556 .pool
557 .begin()
558 .await
559 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
560
561 sqlx::query("DELETE FROM checkpoints WHERE thread_id = $1")
562 .bind(thread_id.as_str())
563 .execute(&mut *tx)
564 .await
565 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
566
567 sqlx::query("DELETE FROM checkpoint_blobs WHERE thread_id = $1")
568 .bind(thread_id.as_str())
569 .execute(&mut *tx)
570 .await
571 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
572
573 sqlx::query("DELETE FROM checkpoint_writes WHERE thread_id = $1")
574 .bind(thread_id.as_str())
575 .execute(&mut *tx)
576 .await
577 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
578
579 tx.commit()
580 .await
581 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
582
583 Ok(())
584 }
585}
586
587impl PostgresSaver {
589 pub async fn alist(
590 &self,
591 config: Option<&RunnableConfig>,
592 filter: Option<&HashMap<String, JsonValue>>,
593 before: Option<&RunnableConfig>,
594 limit: Option<usize>,
595 ) -> Result<Vec<CheckpointTuple>, CheckpointError> {
596 let (where_clause, _params) = Self::build_where_clause(config, filter, before);
597 let mut query = format!(
598 "{} {} ORDER BY checkpoint_id DESC",
599 SELECT_SQL, where_clause
600 );
601
602 if let Some(limit) = limit {
603 query.push_str(&format!(" LIMIT {}", limit));
604 }
605
606 let mut q = sqlx::query(&query);
608 if let Some(config) = config {
609 if let Some(thread_id) = config
610 .get("configurable")
611 .and_then(|c| c.get("thread_id"))
612 .and_then(|v| v.as_str())
613 {
614 q = q.bind(thread_id);
615 }
616 if let Some(checkpoint_ns) = config
617 .get("configurable")
618 .and_then(|c| c.get("checkpoint_ns"))
619 .and_then(|v| v.as_str())
620 {
621 q = q.bind(checkpoint_ns);
622 }
623 if let Some(checkpoint_id) = get_checkpoint_id(config) {
624 q = q.bind(checkpoint_id);
625 }
626 }
627 if let Some(before) = before {
628 if let Some(before_id) = get_checkpoint_id(before) {
629 q = q.bind(before_id);
630 }
631 }
632
633 let rows = q
634 .fetch_all(&self.pool)
635 .await
636 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
637
638 let mut results = Vec::new();
639 for row in rows {
640 results.push(Self::row_to_tuple(&row)?);
641 }
642
643 Ok(results)
644 }
645}
646
647#[allow(dead_code)]
650impl PostgresSaver {
651 fn config_error(msg: &str) -> CheckpointError {
653 CheckpointError::Storage(format!("config error: {}", msg))
654 }
655}