1use std::collections::HashMap;
2use std::str::FromStr;
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use serde_json::Value as JsonValue;
7use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow};
8use sqlx::Row;
9
10use langgraph_checkpoint::checkpoint::base::{
11 get_checkpoint_id, get_checkpoint_metadata, writes_idx_map, BaseCheckpointSaver,
12};
13use langgraph_checkpoint::checkpoint::types::*;
14use langgraph_checkpoint::config::RunnableConfig;
15use langgraph_checkpoint::error::CheckpointError;
16use langgraph_checkpoint::serde::base::SerializerProtocol;
17use langgraph_checkpoint::serde::jsonplus::JsonPlusSerializer;
18
19use crate::queries::*;
20
/// Checkpoint saver that persists its data in SQLite through an `sqlx` pool.
///
/// Channel values and pending writes are encoded and decoded with the
/// configured serializer.
pub struct SqliteSaver {
    /// Connection pool every query and transaction of this saver runs on.
    pool: SqlitePool,
    /// Serializer whose `dumps_typed` / `loads_typed` are used for blob payloads.
    serde: Arc<dyn SerializerProtocol>,
}
29
30impl SqliteSaver {
31 pub fn new(pool: SqlitePool) -> Self {
33 Self {
34 pool,
35 serde: Arc::new(JsonPlusSerializer::new()),
36 }
37 }
38
/// Builds a saver over `pool` that uses the caller-supplied serializer
/// instead of the default one chosen by `new`.
pub fn with_serde(pool: SqlitePool, serde: Arc<dyn SerializerProtocol>) -> Self {
    Self { pool, serde }
}
43
44 pub async fn from_conn_string(conn_string: &str) -> Result<Self, CheckpointError> {
50 let opts = SqliteConnectOptions::from_str(conn_string)
51 .map_err(|e| CheckpointError::Storage(e.to_string()))?
52 .create_if_missing(true)
53 .journal_mode(SqliteJournalMode::Wal);
54
55 let pool = SqlitePoolOptions::new()
56 .max_connections(5)
57 .connect_with(opts)
58 .await
59 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
60
61 Ok(Self::new(pool))
62 }
63
64 pub async fn setup(&self) -> Result<(), CheckpointError> {
66 sqlx::query(MIGRATIONS[0])
68 .execute(&self.pool)
69 .await
70 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
71
72 let row: Option<(i64,)> = sqlx::query_as(
73 "SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1",
74 )
75 .fetch_optional(&self.pool)
76 .await
77 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
78
79 let version = row.map(|(v,)| v).unwrap_or(-1);
80
81 for (i, migration) in MIGRATIONS.iter().enumerate() {
82 let v = i as i64;
83 if v > version {
84 sqlx::query(migration)
85 .execute(&self.pool)
86 .await
87 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
88 sqlx::query("INSERT INTO checkpoint_migrations (v) VALUES (?1)")
89 .bind(v)
90 .execute(&self.pool)
91 .await
92 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
93 }
94 }
95
96 Ok(())
97 }
98
/// Returns a shared reference to the connection pool this saver runs on.
pub fn pool(&self) -> &SqlitePool {
    &self.pool
}
103
104 fn make_config(thread_id: &str, checkpoint_ns: &str, checkpoint_id: &str) -> RunnableConfig {
106 serde_json::from_value(serde_json::json!({
107 "configurable": {
108 "thread_id": thread_id,
109 "checkpoint_ns": checkpoint_ns,
110 "checkpoint_id": checkpoint_id,
111 }
112 }))
113 .unwrap_or_default()
114 }
115
116 fn row_to_tuple(row: &SqliteRow) -> Result<CheckpointTuple, CheckpointError> {
120 let thread_id: String = row.get("thread_id");
121 let checkpoint_ns: String = row.get("checkpoint_ns");
122 let checkpoint_text: String = row.get("checkpoint");
123 let metadata_text: String = row.get("metadata");
124
125 let checkpoint: Checkpoint = serde_json::from_str(&checkpoint_text)
126 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
127 let metadata: CheckpointMetadata = serde_json::from_str(&metadata_text)
128 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
129
130 let parent_checkpoint_id: Option<String> = row.try_get("parent_checkpoint_id").ok();
131 let parent_config = parent_checkpoint_id.map(|pid| {
132 Self::make_config(&thread_id, &checkpoint_ns, &pid)
133 });
134
135 let tuple_config = Self::make_config(&thread_id, &checkpoint_ns, &checkpoint.id);
136
137 Ok(CheckpointTuple {
138 config: tuple_config,
139 checkpoint,
140 metadata,
141 parent_config,
142 pending_writes: None,
143 })
144 }
145
146 async fn load_blobs(
148 &self,
149 thread_id: &str,
150 checkpoint_ns: &str,
151 checkpoint_id: &str,
152 ) -> Result<HashMap<String, JsonValue>, CheckpointError> {
153 let rows = sqlx::query(SELECT_BLOBS_SQL)
154 .bind(thread_id)
155 .bind(checkpoint_ns)
156 .bind(checkpoint_id)
157 .fetch_all(&self.pool)
158 .await
159 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
160
161 let mut values: HashMap<String, JsonValue> = HashMap::new();
162 for row in rows {
163 let channel: String = row.get("channel");
164 let type_tag: String = row.get("type");
165 let blob: Option<Vec<u8>> = row.try_get("blob").ok();
166
167 if type_tag == "empty" || blob.is_none() {
168 continue;
169 }
170 let bytes = blob.unwrap();
171 let val = match self.serde.loads_typed(&type_tag, &bytes) {
172 Ok(any_val) => any_to_json(any_val),
173 Err(_) => continue,
174 };
175 values.insert(channel, val);
176 }
177 Ok(values)
178 }
179
180 async fn load_writes(
182 &self,
183 thread_id: &str,
184 checkpoint_ns: &str,
185 checkpoint_id: &str,
186 ) -> Result<Vec<PendingWrite>, CheckpointError> {
187 let rows = sqlx::query(SELECT_WRITES_SQL)
188 .bind(thread_id)
189 .bind(checkpoint_ns)
190 .bind(checkpoint_id)
191 .fetch_all(&self.pool)
192 .await
193 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
194
195 let mut writes = Vec::with_capacity(rows.len());
196 for row in rows {
197 let task_id: String = row.get("task_id");
198 let channel: String = row.get("channel");
199 let type_tag: Option<String> = row.try_get("type").ok();
200 let blob: Option<Vec<u8>> = row.try_get("blob").ok();
201
202 let value = match (type_tag.as_deref(), blob) {
203 (Some(tag), Some(bytes)) => match self.serde.loads_typed(tag, &bytes) {
204 Ok(any_val) => any_to_json(any_val),
205 Err(_) => JsonValue::Null,
206 },
207 _ => JsonValue::Null,
208 };
209 writes.push((task_id, channel, value));
210 }
211 Ok(writes)
212 }
213
214 fn dump_blobs(
216 &self,
217 thread_id: &str,
218 checkpoint_ns: &str,
219 values: &HashMap<String, JsonValue>,
220 versions: &ChannelVersions,
221 ) -> Vec<(String, String, String, String, String, Option<Vec<u8>>)> {
222 let mut result = Vec::new();
223 for (channel, ver) in versions {
224 let ver_str = match ver {
225 JsonValue::String(s) => s.clone(),
226 JsonValue::Number(n) => n.to_string(),
227 _ => continue,
228 };
229 if let Some(val) = values.get(channel) {
230 if let Ok((type_tag, blob)) = self.serde.dumps_typed(val) {
231 result.push((
232 thread_id.to_string(),
233 checkpoint_ns.to_string(),
234 channel.clone(),
235 ver_str,
236 type_tag,
237 Some(blob),
238 ));
239 }
240 } else {
241 result.push((
242 thread_id.to_string(),
243 checkpoint_ns.to_string(),
244 channel.clone(),
245 ver_str,
246 "empty".to_string(),
247 None,
248 ));
249 }
250 }
251 result
252 }
253
/// Lists checkpoints, newest `checkpoint_id` first, with their channel values
/// and pending writes loaded.
///
/// * `config` — optional; its `configurable.thread_id`, `configurable.checkpoint_ns`
///   and checkpoint id (via `get_checkpoint_id`) each add an equality condition
///   when present.
/// * `filter` — optional metadata filter; every key must pass
///   `validate_filter_key` and is compared with `json_extract` against the
///   JSON-encoded value.
/// * `before` — optional; restricts results to checkpoint ids strictly less
///   than the one it names.
/// * `limit` — optional maximum number of rows.
///
/// # Errors
/// `CheckpointError::Config` for an invalid filter key (as returned by
/// `validate_filter_key`); `CheckpointError::Storage` when a query fails;
/// any error from `row_to_tuple`, `load_blobs` or `load_writes`.
pub async fn alist(
    &self,
    config: Option<&RunnableConfig>,
    filter: Option<&HashMap<String, JsonValue>>,
    before: Option<&RunnableConfig>,
    limit: Option<usize>,
) -> Result<Vec<CheckpointTuple>, CheckpointError> {
    // `conditions[i]` references placeholder `?N` where N is the 1-based position
    // its value has in `binds`, so the two vectors must be pushed in lockstep.
    let mut conditions: Vec<String> = Vec::new();
    let mut binds: Vec<String> = Vec::new();

    if let Some(cfg) = config {
        if let Some(thread_id) = cfg
            .get("configurable")
            .and_then(|c| c.get("thread_id"))
            .and_then(|v| v.as_str())
        {
            conditions.push(format!("thread_id = ?{}", binds.len() + 1));
            binds.push(thread_id.to_string());
        }
        if let Some(ns) = cfg
            .get("configurable")
            .and_then(|c| c.get("checkpoint_ns"))
            .and_then(|v| v.as_str())
        {
            conditions.push(format!("checkpoint_ns = ?{}", binds.len() + 1));
            binds.push(ns.to_string());
        }
        if let Some(cid) = get_checkpoint_id(cfg) {
            conditions.push(format!("checkpoint_id = ?{}", binds.len() + 1));
            binds.push(cid);
        }
    }

    if let Some(meta_filter) = filter {
        for (key, value) in meta_filter {
            // The key is spliced into the SQL text, so it is validated first;
            // the value travels as a bound parameter.
            validate_filter_key(key)?;
            conditions.push(format!(
                "json_extract(metadata, '$.{}') = json_extract(?{}, '$')",
                key,
                binds.len() + 1
            ));
            // A value that fails to serialize is compared as JSON null.
            binds.push(serde_json::to_string(value).unwrap_or_else(|_| "null".to_string()));
        }
    }

    if let Some(before_cfg) = before {
        if let Some(before_id) = get_checkpoint_id(before_cfg) {
            conditions.push(format!("checkpoint_id < ?{}", binds.len() + 1));
            binds.push(before_id);
        }
    }

    let where_clause = if conditions.is_empty() {
        String::new()
    } else {
        format!("WHERE {}", conditions.join(" AND "))
    };

    let mut query = format!(
        "{} {} ORDER BY checkpoint_id DESC",
        SELECT_CHECKPOINT_SQL, where_clause
    );
    // `lim` is a `usize`, so formatting it into the SQL text cannot inject anything.
    if let Some(lim) = limit {
        query.push_str(&format!(" LIMIT {}", lim));
    }

    let mut q = sqlx::query(&query);
    for b in &binds {
        q = q.bind(b.as_str());
    }

    let rows = q
        .fetch_all(&self.pool)
        .await
        .map_err(|e| CheckpointError::Storage(e.to_string()))?;

    let mut results = Vec::with_capacity(rows.len());
    for row in rows {
        let mut tuple = Self::row_to_tuple(&row)?;
        let thread_id = row.get::<String, _>("thread_id");
        let ns = row.get::<String, _>("checkpoint_ns");
        let cid = tuple.checkpoint.id.clone();
        // Two extra queries per row: stored channel values, then pending writes.
        let blob_values = self.load_blobs(&thread_id, &ns, &cid).await?;
        // Only override the parsed channel values when blobs were actually found.
        if !blob_values.is_empty() {
            tuple.checkpoint.channel_values = blob_values;
        }
        tuple.pending_writes = Some(self.load_writes(&thread_id, &ns, &cid).await?);
        results.push(tuple);
    }
    Ok(results)
}
354}
355
356fn block_on_in_runtime<F, T>(future: F) -> Result<T, CheckpointError>
372where
373 F: std::future::Future<Output = Result<T, CheckpointError>>,
374{
375 match tokio::runtime::Handle::try_current() {
376 Ok(handle) => tokio::task::block_in_place(|| handle.block_on(future)),
377 Err(_) => {
378 let rt = tokio::runtime::Runtime::new()
379 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
380 rt.block_on(future)
381 }
382 }
383}
384
385fn validate_filter_key(key: &str) -> Result<(), CheckpointError> {
390 if key.is_empty()
391 || key
392 .chars()
393 .any(|c| !(c.is_ascii_alphanumeric() || c == '.' || c == '_' || c == '-'))
394 {
395 return Err(CheckpointError::Config(format!(
396 "invalid metadata filter key: {:?}",
397 key
398 )));
399 }
400 Ok(())
401}
402
403fn any_to_json(val: Box<dyn std::any::Any>) -> JsonValue {
405 if val.is::<JsonValue>() {
406 *val.downcast::<JsonValue>().unwrap()
407 } else if val.is::<String>() {
408 JsonValue::String(*val.downcast::<String>().unwrap())
409 } else if val.is::<Vec<u8>>() {
410 let b = val.downcast::<Vec<u8>>().unwrap();
411 JsonValue::Array(b.into_iter().map(|byte: u8| JsonValue::Number(byte.into())).collect())
412 } else {
413 JsonValue::Null
414 }
415}
416
417#[async_trait]
418impl BaseCheckpointSaver for SqliteSaver {
419 fn get_tuple(
420 &self,
421 config: &RunnableConfig,
422 ) -> Result<Option<CheckpointTuple>, CheckpointError> {
423 block_on_in_runtime(self.aget_tuple(config))
424 }
425
426 fn list(
427 &self,
428 config: Option<&RunnableConfig>,
429 filter: Option<&HashMap<String, JsonValue>>,
430 before: Option<&RunnableConfig>,
431 limit: Option<usize>,
432 ) -> Result<Vec<CheckpointTuple>, CheckpointError> {
433 block_on_in_runtime(self.alist(config, filter, before, limit))
434 }
435
436 fn put(
437 &self,
438 config: &RunnableConfig,
439 checkpoint: &Checkpoint,
440 metadata: &CheckpointMetadata,
441 new_versions: &ChannelVersions,
442 ) -> Result<RunnableConfig, CheckpointError> {
443 block_on_in_runtime(self.aput(config, checkpoint, metadata, new_versions))
444 }
445
446 fn put_writes(
447 &self,
448 config: &RunnableConfig,
449 writes: &[(String, String, JsonValue)],
450 task_id: &str,
451 task_path: &str,
452 ) -> Result<(), CheckpointError> {
453 block_on_in_runtime(self.aput_writes(
454 config,
455 writes.to_vec(),
456 task_id.to_string(),
457 task_path.to_string(),
458 ))
459 }
460
461 fn delete_thread(&self, thread_id: &str) -> Result<(), CheckpointError> {
462 block_on_in_runtime(self.adelete_thread(thread_id.to_string()))
463 }
464
465 async fn aget_tuple(
466 &self,
467 config: &RunnableConfig,
468 ) -> Result<Option<CheckpointTuple>, CheckpointError> {
469 let thread_id = config
470 .get("configurable")
471 .and_then(|c| c.get("thread_id"))
472 .and_then(|v| v.as_str())
473 .ok_or_else(|| CheckpointError::Config("missing thread_id".into()))?;
474
475 let checkpoint_ns = config
476 .get("configurable")
477 .and_then(|c| c.get("checkpoint_ns"))
478 .and_then(|v| v.as_str())
479 .unwrap_or("");
480
481 let checkpoint_id = get_checkpoint_id(config);
482
483 let row = if let Some(cid) = &checkpoint_id {
484 sqlx::query(&format!(
485 "{} WHERE thread_id = ?1 AND checkpoint_ns = ?2 AND checkpoint_id = ?3",
486 SELECT_CHECKPOINT_SQL
487 ))
488 .bind(thread_id)
489 .bind(checkpoint_ns)
490 .bind(cid.as_str())
491 .fetch_optional(&self.pool)
492 .await
493 .map_err(|e| CheckpointError::Storage(e.to_string()))?
494 } else {
495 sqlx::query(&format!(
496 "{} WHERE thread_id = ?1 AND checkpoint_ns = ?2 ORDER BY checkpoint_id DESC LIMIT 1",
497 SELECT_CHECKPOINT_SQL
498 ))
499 .bind(thread_id)
500 .bind(checkpoint_ns)
501 .fetch_optional(&self.pool)
502 .await
503 .map_err(|e| CheckpointError::Storage(e.to_string()))?
504 };
505
506 let row = match row {
507 Some(r) => r,
508 None => return Ok(None),
509 };
510
511 let mut tuple = Self::row_to_tuple(&row)?;
512 let cid = tuple.checkpoint.id.clone();
513 let blob_values = self.load_blobs(thread_id, checkpoint_ns, &cid).await?;
514 if !blob_values.is_empty() {
515 tuple.checkpoint.channel_values = blob_values;
516 }
517 tuple.pending_writes = Some(self.load_writes(thread_id, checkpoint_ns, &cid).await?);
518 Ok(Some(tuple))
519 }
520
521 async fn aput(
522 &self,
523 config: &RunnableConfig,
524 checkpoint: &Checkpoint,
525 metadata: &CheckpointMetadata,
526 new_versions: &ChannelVersions,
527 ) -> Result<RunnableConfig, CheckpointError> {
528 let configurable = config.get("configurable").cloned().unwrap_or_default();
529 let thread_id = configurable
530 .get("thread_id")
531 .and_then(|v| v.as_str())
532 .ok_or_else(|| CheckpointError::Config("missing thread_id".into()))?;
533 let checkpoint_ns = configurable
534 .get("checkpoint_ns")
535 .and_then(|v| v.as_str())
536 .unwrap_or("");
537 let parent_checkpoint_id: Option<String> = configurable
538 .get("checkpoint_id")
539 .and_then(|v| v.as_str())
540 .map(|s| s.to_string());
541
542 let next_config = Self::make_config(thread_id, checkpoint_ns, &checkpoint.id);
543
544 let mut checkpoint_value = serde_json::to_value(checkpoint)
547 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
548 if let Some(obj) = checkpoint_value.as_object_mut() {
549 obj.insert("channel_values".to_string(), JsonValue::Object(Default::default()));
550 }
551 let checkpoint_text = serde_json::to_string(&checkpoint_value)
552 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
553 let merged_metadata = get_checkpoint_metadata(config, metadata);
557 let metadata_text = serde_json::to_string(&merged_metadata)
558 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
559
560 let mut tx = self
561 .pool
562 .begin()
563 .await
564 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
565
566 let blobs = self.dump_blobs(
568 thread_id,
569 checkpoint_ns,
570 &checkpoint.channel_values,
571 new_versions,
572 );
573 for (tid, cns, channel, version, type_tag, blob) in &blobs {
574 sqlx::query(UPSERT_CHECKPOINT_BLOBS_SQL)
575 .bind(tid.as_str())
576 .bind(cns.as_str())
577 .bind(channel.as_str())
578 .bind(version.as_str())
579 .bind(type_tag.as_str())
580 .bind(blob.as_deref())
581 .execute(&mut *tx)
582 .await
583 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
584 }
585
586 sqlx::query(UPSERT_CHECKPOINTS_SQL)
588 .bind(thread_id)
589 .bind(checkpoint_ns)
590 .bind(checkpoint.id.as_str())
591 .bind(parent_checkpoint_id.as_deref())
592 .bind(checkpoint_text.as_str())
593 .bind(metadata_text.as_str())
594 .execute(&mut *tx)
595 .await
596 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
597
598 tx.commit()
599 .await
600 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
601
602 Ok(next_config)
603 }
604
605 async fn aput_writes(
606 &self,
607 config: &RunnableConfig,
608 writes: Vec<(String, String, JsonValue)>,
609 task_id: String,
610 task_path: String,
611 ) -> Result<(), CheckpointError> {
612 let configurable = config.get("configurable").cloned().unwrap_or_default();
613 let thread_id = configurable
614 .get("thread_id")
615 .and_then(|v| v.as_str())
616 .ok_or_else(|| CheckpointError::Config("missing thread_id".into()))?;
617 let checkpoint_ns = configurable
618 .get("checkpoint_ns")
619 .and_then(|v| v.as_str())
620 .unwrap_or("");
621 let checkpoint_id = configurable
633 .get("checkpoint_id")
634 .and_then(|v| v.as_str())
635 .unwrap_or("");
636
637 let idx_map = writes_idx_map();
638 let use_upsert = writes
639 .iter()
640 .all(|(channel, _, _)| idx_map.contains_key(channel.as_str()));
641
642 let query = if use_upsert {
643 UPSERT_CHECKPOINT_WRITES_SQL
644 } else {
645 INSERT_CHECKPOINT_WRITES_SQL
646 };
647
648 let mut tx = self
649 .pool
650 .begin()
651 .await
652 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
653
654 for (idx, (_task_id_in_tuple, channel, value)) in writes.iter().enumerate() {
655 let idx_val: i64 = idx_map
656 .get(channel.as_str())
657 .copied()
658 .unwrap_or(idx as i64);
659
660 let (type_tag, blob) = match self.serde.dumps_typed(value) {
661 Ok(pair) => pair,
662 Err(_) => continue,
663 };
664
665 sqlx::query(query)
666 .bind(thread_id)
667 .bind(checkpoint_ns)
668 .bind(checkpoint_id)
669 .bind(task_id.as_str())
670 .bind(task_path.as_str())
671 .bind(idx_val)
672 .bind(channel.as_str())
673 .bind(type_tag.as_str())
674 .bind(blob.as_slice())
675 .execute(&mut *tx)
676 .await
677 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
678 }
679
680 tx.commit()
681 .await
682 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
683
684 Ok(())
685 }
686
687 async fn adelete_thread(&self, thread_id: String) -> Result<(), CheckpointError> {
688 let mut tx = self
689 .pool
690 .begin()
691 .await
692 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
693
694 sqlx::query("DELETE FROM checkpoints WHERE thread_id = ?1")
695 .bind(thread_id.as_str())
696 .execute(&mut *tx)
697 .await
698 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
699
700 sqlx::query("DELETE FROM checkpoint_blobs WHERE thread_id = ?1")
701 .bind(thread_id.as_str())
702 .execute(&mut *tx)
703 .await
704 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
705
706 sqlx::query("DELETE FROM checkpoint_writes WHERE thread_id = ?1")
707 .bind(thread_id.as_str())
708 .execute(&mut *tx)
709 .await
710 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
711
712 tx.commit()
713 .await
714 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
715
716 Ok(())
717 }
718}
719
720#[cfg(test)]
721mod tests {
722 use super::*;
723 use std::collections::HashMap;
724
/// Test helper: connects to an in-memory SQLite database, runs `setup`, and
/// returns the ready-to-use saver. Panics if either step fails.
async fn fresh_saver() -> SqliteSaver {
    let saver = SqliteSaver::from_conn_string("sqlite::memory:")
        .await
        .expect("connect to in-memory sqlite");
    saver.setup().await.expect("setup migrations");
    saver
}
732
/// Test helper: config for `thread_id` in the empty namespace, without a
/// checkpoint id.
fn config_for(thread_id: &str) -> RunnableConfig {
    serde_json::from_value(serde_json::json!({
        "configurable": { "thread_id": thread_id, "checkpoint_ns": "" }
    }))
    .unwrap()
}
739
/// Test helper: config for `thread_id` in the empty namespace that also names
/// a specific `checkpoint_id`.
fn config_with_id(thread_id: &str, checkpoint_id: &str) -> RunnableConfig {
    serde_json::from_value(serde_json::json!({
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "",
            "checkpoint_id": checkpoint_id,
        }
    }))
    .unwrap()
}
750
/// Test helper: builds a checkpoint from `Checkpoint::empty()` holding the
/// given channel values, each at version `1`, and returns it together with a
/// matching `ChannelVersions` map.
fn make_checkpoint(channel_values: Vec<(&str, JsonValue)>) -> (Checkpoint, ChannelVersions) {
    let mut cp = Checkpoint::empty();
    let mut versions: ChannelVersions = HashMap::new();
    for (k, v) in channel_values {
        cp.channel_values.insert(k.to_string(), v);
        cp.channel_versions
            .insert(k.to_string(), JsonValue::Number(1.into()));
        versions.insert(k.to_string(), JsonValue::Number(1.into()));
    }
    (cp, versions)
}
762
/// `setup` must succeed when called again on an already-initialised database
/// (`fresh_saver` has run it once already).
#[tokio::test]
async fn test_setup_is_idempotent() {
    let saver = fresh_saver().await;
    saver.setup().await.expect("second setup");
}
769
/// Looking up a thread that has no checkpoints yields `Ok(None)`, not an error.
#[tokio::test]
async fn test_get_tuple_returns_none_when_empty() {
    let saver = fresh_saver().await;
    let cfg = config_for("missing");
    let result = saver.aget_tuple(&cfg).await.unwrap();
    assert!(result.is_none());
}
777
/// A checkpoint written with `aput` comes back from `aget_tuple` with the
/// same id, metadata step and channel values.
#[tokio::test]
async fn test_put_then_get_roundtrip() {
    let saver = fresh_saver().await;
    let (cp, versions) = make_checkpoint(vec![
        ("messages", serde_json::json!(["hello", "world"])),
        ("counter", serde_json::json!(7)),
    ]);
    let cfg = config_for("thread-A");
    let metadata = CheckpointMetadata {
        source: Some(CheckpointSource::Loop),
        step: Some(3),
        ..Default::default()
    };

    let next = saver.aput(&cfg, &cp, &metadata, &versions).await.unwrap();

    // The config returned by `aput` must name the checkpoint just stored.
    let returned_cid = next
        .get("configurable")
        .and_then(|c| c.get("checkpoint_id"))
        .and_then(|v| v.as_str())
        .unwrap();
    assert_eq!(returned_cid, cp.id);

    // `cfg` has no checkpoint id, so this exercises the "latest checkpoint" path.
    let tuple = saver.aget_tuple(&cfg).await.unwrap().expect("tuple exists");
    assert_eq!(tuple.checkpoint.id, cp.id);
    assert_eq!(tuple.metadata.step, Some(3));
    assert_eq!(
        tuple.checkpoint.channel_values.get("messages"),
        Some(&serde_json::json!(["hello", "world"]))
    );
    assert_eq!(
        tuple.checkpoint.channel_values.get("counter"),
        Some(&serde_json::json!(7))
    );
}
815
/// Writes stored with `aput_writes` are returned as `pending_writes` by
/// `aget_tuple`, in order, with their channel names and values intact.
///
/// The input tuples here are built with the channel name as their FIRST
/// element, while the returned pending writes are checked with the channel at
/// index 1 and the value at index 2.
#[tokio::test]
async fn test_put_writes_and_pending_writes_round_trip() {
    let saver = fresh_saver().await;
    let (cp, versions) = make_checkpoint(vec![("a", serde_json::json!(1))]);
    let cfg = config_for("thread-W");
    saver
        .aput(&cfg, &cp, &CheckpointMetadata::default(), &versions)
        .await
        .unwrap();

    // Writes are attached to a specific checkpoint, so the config names its id.
    let cfg_with_id = config_with_id("thread-W", &cp.id);
    let writes = vec![
        ("ch1".to_string(), "task-1".to_string(), serde_json::json!("v1")),
        ("ch2".to_string(), "task-1".to_string(), serde_json::json!(42)),
    ];
    saver
        .aput_writes(&cfg_with_id, writes, "task-1".into(), "".into())
        .await
        .unwrap();

    let tuple = saver.aget_tuple(&cfg_with_id).await.unwrap().unwrap();
    let pending = tuple.pending_writes.expect("pending writes loaded");
    assert_eq!(pending.len(), 2);
    assert_eq!(pending[0].1, "ch1");
    assert_eq!(pending[1].1, "ch2");
    assert_eq!(pending[1].2, serde_json::json!(42));
}
844
/// `alist` returns every checkpoint of the thread in non-increasing id order
/// and honours `limit`.
#[tokio::test]
async fn test_list_orders_descending_and_respects_limit() {
    let saver = fresh_saver().await;
    let cfg = config_for("thread-L");
    let mut ids = Vec::new();
    for i in 0..3 {
        let (cp, versions) = make_checkpoint(vec![("x", serde_json::json!(i))]);
        ids.push(cp.id.clone());
        saver
            .aput(&cfg, &cp, &CheckpointMetadata::default(), &versions)
            .await
            .unwrap();
    }

    let all = saver.alist(Some(&cfg), None, None, None).await.unwrap();
    assert_eq!(all.len(), 3);
    // Adjacent results must be ordered newest-first by checkpoint id.
    for w in all.windows(2) {
        assert!(w[0].checkpoint.id >= w[1].checkpoint.id);
    }
    // Every stored id must be present, whatever order they were generated in.
    let returned_ids: std::collections::HashSet<_> =
        all.iter().map(|t| t.checkpoint.id.clone()).collect();
    for id in &ids {
        assert!(returned_ids.contains(id));
    }

    let limited = saver.alist(Some(&cfg), None, None, Some(2)).await.unwrap();
    assert_eq!(limited.len(), 2);
}
877
/// After `adelete_thread`, neither `aget_tuple` nor `alist` finds anything
/// for the thread, even though a checkpoint and a pending write were stored.
#[tokio::test]
async fn test_delete_thread_removes_all_data() {
    let saver = fresh_saver().await;
    let (cp, versions) = make_checkpoint(vec![("x", serde_json::json!(1))]);
    let cfg = config_for("thread-D");
    saver
        .aput(&cfg, &cp, &CheckpointMetadata::default(), &versions)
        .await
        .unwrap();
    let cfg_with_id = config_with_id("thread-D", &cp.id);
    saver
        .aput_writes(
            &cfg_with_id,
            vec![("ch".into(), "task".into(), serde_json::json!("v"))],
            "task".into(),
            "".into(),
        )
        .await
        .unwrap();

    saver.adelete_thread("thread-D".into()).await.unwrap();
    assert!(saver.aget_tuple(&cfg).await.unwrap().is_none());
    let listed = saver.alist(Some(&cfg), None, None, None).await.unwrap();
    assert!(listed.is_empty());
}
903
/// Two checkpoints of the same thread that store the same channel under
/// different versions each read back their own value.
#[tokio::test]
async fn test_value_updates_when_version_increments() {
    let saver = fresh_saver().await;
    let cfg = config_for("thread-V");

    // First checkpoint: counter = 1 at version 1.
    let mut cp1 = Checkpoint::empty();
    cp1.channel_values
        .insert("counter".into(), JsonValue::Number(1.into()));
    cp1.channel_versions
        .insert("counter".into(), JsonValue::Number(1.into()));
    let mut versions1: ChannelVersions = HashMap::new();
    versions1.insert("counter".into(), JsonValue::Number(1.into()));
    saver
        .aput(&cfg, &cp1, &CheckpointMetadata::default(), &versions1)
        .await
        .unwrap();

    // Short pause between the two checkpoints.
    // NOTE(review): presumably so the generated checkpoint ids differ in time — confirm.
    tokio::time::sleep(std::time::Duration::from_millis(2)).await;

    // Second checkpoint: counter = 99 at version 2.
    let mut cp2 = Checkpoint::empty();
    cp2.channel_values
        .insert("counter".into(), JsonValue::Number(99.into()));
    cp2.channel_versions
        .insert("counter".into(), JsonValue::Number(2.into()));
    let mut versions2: ChannelVersions = HashMap::new();
    versions2.insert("counter".into(), JsonValue::Number(2.into()));
    saver
        .aput(&cfg, &cp2, &CheckpointMetadata::default(), &versions2)
        .await
        .unwrap();

    let cfg_cp2 = config_with_id("thread-V", &cp2.id);
    let tuple = saver.aget_tuple(&cfg_cp2).await.unwrap().unwrap();
    assert_eq!(
        tuple.checkpoint.channel_values.get("counter"),
        Some(&JsonValue::Number(99.into()))
    );

    // The earlier checkpoint must still report its own, older value.
    let cfg_cp1 = config_with_id("thread-V", &cp1.id);
    let earlier = saver.aget_tuple(&cfg_cp1).await.unwrap().unwrap();
    assert_eq!(
        earlier.checkpoint.channel_values.get("counter"),
        Some(&JsonValue::Number(1.into()))
    );
}
954
/// The metadata `filter` of `alist` matches on a string field, on a numeric
/// field, and on both combined.
#[tokio::test]
async fn test_metadata_filter_returns_only_matching_rows() {
    let saver = fresh_saver().await;
    let cfg = config_for("thread-F");

    // Three checkpoints: one `Input` at step 0 and two `Loop` at steps 1 and 2.
    for (source, step, val) in [
        (CheckpointSource::Input, 0, "a"),
        (CheckpointSource::Loop, 1, "b"),
        (CheckpointSource::Loop, 2, "c"),
    ] {
        let (cp, vers) = make_checkpoint(vec![("x", serde_json::json!(val))]);
        let meta = CheckpointMetadata {
            source: Some(source),
            step: Some(step),
            ..Default::default()
        };
        saver.aput(&cfg, &cp, &meta, &vers).await.unwrap();
        tokio::time::sleep(std::time::Duration::from_millis(2)).await;
    }

    // Filter on a string-valued field.
    let mut filter = HashMap::new();
    filter.insert("source".into(), serde_json::json!("loop"));
    let loop_only = saver
        .alist(Some(&cfg), Some(&filter), None, None)
        .await
        .unwrap();
    assert_eq!(loop_only.len(), 2);
    for t in &loop_only {
        assert_eq!(t.metadata.source, Some(CheckpointSource::Loop));
    }

    // Filter on a numeric field.
    let mut filter = HashMap::new();
    filter.insert("step".into(), serde_json::json!(1));
    let step_one = saver
        .alist(Some(&cfg), Some(&filter), None, None)
        .await
        .unwrap();
    assert_eq!(step_one.len(), 1);
    assert_eq!(step_one[0].metadata.step, Some(1));

    // Both conditions together must be AND-ed.
    let mut filter = HashMap::new();
    filter.insert("source".into(), serde_json::json!("loop"));
    filter.insert("step".into(), serde_json::json!(2));
    let combined = saver
        .alist(Some(&cfg), Some(&filter), None, None)
        .await
        .unwrap();
    assert_eq!(combined.len(), 1);
    assert_eq!(combined[0].metadata.step, Some(2));
}
1009
/// `validate_filter_key` accepts plain, dotted, snake_case, kebab-case and
/// alphanumeric keys, and rejects empty keys and keys containing quotes,
/// spaces, brackets, SQL punctuation or non-ASCII characters.
#[test]
fn test_validate_filter_key_rejects_injection_attempts() {
    // Accepted keys.
    assert!(validate_filter_key("source").is_ok());
    assert!(validate_filter_key("nested.field").is_ok());
    assert!(validate_filter_key("snake_case").is_ok());
    assert!(validate_filter_key("kebab-case").is_ok());
    assert!(validate_filter_key("Mixed123").is_ok());

    // Rejected keys.
    assert!(validate_filter_key("").is_err());
    assert!(validate_filter_key("source'; DROP TABLE--").is_err());
    assert!(validate_filter_key("a\"b").is_err());
    assert!(validate_filter_key("a b").is_err());
    assert!(validate_filter_key("[admin]").is_err());
    assert!(validate_filter_key("中文").is_err());
}
1027
/// A `langgraph_step` value placed in the config's `configurable` section
/// shows up as `metadata.step` on the stored checkpoint, even though the
/// metadata passed to `aput` is the default.
#[tokio::test]
async fn test_config_langgraph_step_merged_into_metadata() {
    let saver = fresh_saver().await;
    let cfg: RunnableConfig = serde_json::from_value(serde_json::json!({
        "configurable": {
            "thread_id": "thread-M",
            "checkpoint_ns": "",
            "langgraph_step": 7
        }
    }))
    .unwrap();

    let (cp, vers) = make_checkpoint(vec![("x", serde_json::json!(1))]);
    saver
        .aput(&cfg, &cp, &CheckpointMetadata::default(), &vers)
        .await
        .unwrap();

    let cfg_with_id = config_with_id("thread-M", &cp.id);
    let tuple = saver.aget_tuple(&cfg_with_id).await.unwrap().unwrap();
    assert_eq!(tuple.metadata.step, Some(7));
}
1055
/// The blocking `put` and `get_tuple` methods work when invoked from
/// `spawn_blocking` tasks of a multi-threaded Tokio runtime.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_sync_methods_work_inside_multi_thread_runtime() {
    let saver = fresh_saver().await;
    let saver = std::sync::Arc::new(saver);
    let cfg = config_for("thread-S");

    let (cp, vers) = make_checkpoint(vec![("k", serde_json::json!("v"))]);
    // Everything moved into the blocking closure must be owned, hence the clones.
    let s2 = saver.clone();
    let cfg2 = cfg.clone();
    let cp_clone = cp.clone();
    let vers_clone = vers.clone();
    let put_result = tokio::task::spawn_blocking(move || {
        s2.put(&cfg2, &cp_clone, &CheckpointMetadata::default(), &vers_clone)
    })
    .await
    .unwrap();
    assert!(put_result.is_ok());

    let s3 = saver.clone();
    let cfg3 = cfg.clone();
    let get_result = tokio::task::spawn_blocking(move || s3.get_tuple(&cfg3))
        .await
        .unwrap()
        .unwrap();
    assert!(get_result.is_some());
    assert_eq!(get_result.unwrap().checkpoint.id, cp.id);
}
1089
/// Storing a checkpoint with the config returned by a previous `aput` records
/// that previous checkpoint as its parent, exposed through `parent_config`.
#[tokio::test]
async fn test_parent_config_links_checkpoints() {
    let saver = fresh_saver().await;
    let (cp1, vers1) = make_checkpoint(vec![("x", serde_json::json!("a"))]);
    let cfg = config_for("thread-P");
    let next1 = saver
        .aput(&cfg, &cp1, &CheckpointMetadata::default(), &vers1)
        .await
        .unwrap();

    tokio::time::sleep(std::time::Duration::from_millis(2)).await;

    let (cp2, vers2) = make_checkpoint(vec![("x", serde_json::json!("b"))]);
    // `next1` carries cp1's id, which `aput` stores as cp2's parent id.
    saver
        .aput(&next1, &cp2, &CheckpointMetadata::default(), &vers2)
        .await
        .unwrap();

    let cfg_cp2 = config_with_id("thread-P", &cp2.id);
    let latest = saver.aget_tuple(&cfg_cp2).await.unwrap().unwrap();
    assert_eq!(latest.checkpoint.id, cp2.id);
    let parent = latest.parent_config.expect("parent_config present");
    let parent_id = parent
        .get("configurable")
        .and_then(|c| c.get("checkpoint_id"))
        .and_then(|v| v.as_str())
        .unwrap();
    assert_eq!(parent_id, cp1.id);
}
1125}