Skip to main content

langgraph_checkpoint_sqlite_rs/
saver.rs

1use std::collections::HashMap;
2use std::str::FromStr;
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use serde_json::Value as JsonValue;
7use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow};
8use sqlx::Row;
9
10use langgraph_checkpoint::checkpoint::base::{
11    get_checkpoint_id, get_checkpoint_metadata, writes_idx_map, BaseCheckpointSaver,
12};
13use langgraph_checkpoint::checkpoint::types::*;
14use langgraph_checkpoint::config::RunnableConfig;
15use langgraph_checkpoint::error::CheckpointError;
16use langgraph_checkpoint::serde::base::SerializerProtocol;
17use langgraph_checkpoint::serde::jsonplus::JsonPlusSerializer;
18
19use crate::queries::*;
20
/// Async SQLite checkpoint saver using sqlx.
///
/// Uses a three-table schema (`checkpoints`, `checkpoint_blobs`,
/// `checkpoint_writes`) consistent with the Postgres implementation.
pub struct SqliteSaver {
    /// Connection pool that every query and transaction of this saver runs on.
    pool: SqlitePool,
    /// Serializer used to encode and decode channel blobs and pending
    /// writes (via `dumps_typed` / `loads_typed`).
    serde: Arc<dyn SerializerProtocol>,
}
29
30impl SqliteSaver {
31    /// Create a new SqliteSaver from an existing connection pool.
32    pub fn new(pool: SqlitePool) -> Self {
33        Self {
34            pool,
35            serde: Arc::new(JsonPlusSerializer::new()),
36        }
37    }
38
39    /// Create a new SqliteSaver with a custom serializer.
40    pub fn with_serde(pool: SqlitePool, serde: Arc<dyn SerializerProtocol>) -> Self {
41        Self { pool, serde }
42    }
43
44    /// Create a SqliteSaver from a connection string.
45    ///
46    /// Accepts standard sqlx URIs such as `"sqlite::memory:"` or
47    /// `"sqlite:./checkpoints.db"`. The database file is created if
48    /// it does not exist, and WAL journal mode is enabled.
49    pub async fn from_conn_string(conn_string: &str) -> Result<Self, CheckpointError> {
50        let opts = SqliteConnectOptions::from_str(conn_string)
51            .map_err(|e| CheckpointError::Storage(e.to_string()))?
52            .create_if_missing(true)
53            .journal_mode(SqliteJournalMode::Wal);
54
55        let pool = SqlitePoolOptions::new()
56            .max_connections(5)
57            .connect_with(opts)
58            .await
59            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
60
61        Ok(Self::new(pool))
62    }
63
64    /// Run migrations to set up the checkpoint schema. Idempotent.
65    pub async fn setup(&self) -> Result<(), CheckpointError> {
66        // Bootstrap migrations table first
67        sqlx::query(MIGRATIONS[0])
68            .execute(&self.pool)
69            .await
70            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
71
72        let row: Option<(i64,)> = sqlx::query_as(
73            "SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1",
74        )
75        .fetch_optional(&self.pool)
76        .await
77        .map_err(|e| CheckpointError::Storage(e.to_string()))?;
78
79        let version = row.map(|(v,)| v).unwrap_or(-1);
80
81        for (i, migration) in MIGRATIONS.iter().enumerate() {
82            let v = i as i64;
83            if v > version {
84                sqlx::query(migration)
85                    .execute(&self.pool)
86                    .await
87                    .map_err(|e| CheckpointError::Storage(e.to_string()))?;
88                sqlx::query("INSERT INTO checkpoint_migrations (v) VALUES (?1)")
89                    .bind(v)
90                    .execute(&self.pool)
91                    .await
92                    .map_err(|e| CheckpointError::Storage(e.to_string()))?;
93            }
94        }
95
96        Ok(())
97    }
98
99    /// Get the underlying connection pool.
100    pub fn pool(&self) -> &SqlitePool {
101        &self.pool
102    }
103
104    /// Helper: build a RunnableConfig referring to a specific checkpoint.
105    fn make_config(thread_id: &str, checkpoint_ns: &str, checkpoint_id: &str) -> RunnableConfig {
106        serde_json::from_value(serde_json::json!({
107            "configurable": {
108                "thread_id": thread_id,
109                "checkpoint_ns": checkpoint_ns,
110                "checkpoint_id": checkpoint_id,
111            }
112        }))
113        .unwrap_or_default()
114    }
115
116    /// Convert a checkpoint row into a `CheckpointTuple`. Channel values
117    /// from the row's JSON `checkpoint` column are kept as-is; the
118    /// authoritative blob storage is reconciled separately by the caller.
119    fn row_to_tuple(row: &SqliteRow) -> Result<CheckpointTuple, CheckpointError> {
120        let thread_id: String = row.get("thread_id");
121        let checkpoint_ns: String = row.get("checkpoint_ns");
122        let checkpoint_text: String = row.get("checkpoint");
123        let metadata_text: String = row.get("metadata");
124
125        let checkpoint: Checkpoint = serde_json::from_str(&checkpoint_text)
126            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
127        let metadata: CheckpointMetadata = serde_json::from_str(&metadata_text)
128            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
129
130        let parent_checkpoint_id: Option<String> = row.try_get("parent_checkpoint_id").ok();
131        let parent_config = parent_checkpoint_id.map(|pid| {
132            Self::make_config(&thread_id, &checkpoint_ns, &pid)
133        });
134
135        let tuple_config = Self::make_config(&thread_id, &checkpoint_ns, &checkpoint.id);
136
137        Ok(CheckpointTuple {
138            config: tuple_config,
139            checkpoint,
140            metadata,
141            parent_config,
142            pending_writes: None,
143        })
144    }
145
146    /// Load blobs for a checkpoint and merge them into the channel_values map.
147    async fn load_blobs(
148        &self,
149        thread_id: &str,
150        checkpoint_ns: &str,
151        checkpoint_id: &str,
152    ) -> Result<HashMap<String, JsonValue>, CheckpointError> {
153        let rows = sqlx::query(SELECT_BLOBS_SQL)
154            .bind(thread_id)
155            .bind(checkpoint_ns)
156            .bind(checkpoint_id)
157            .fetch_all(&self.pool)
158            .await
159            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
160
161        let mut values: HashMap<String, JsonValue> = HashMap::new();
162        for row in rows {
163            let channel: String = row.get("channel");
164            let type_tag: String = row.get("type");
165            let blob: Option<Vec<u8>> = row.try_get("blob").ok();
166
167            if type_tag == "empty" || blob.is_none() {
168                continue;
169            }
170            let bytes = blob.unwrap();
171            let val = match self.serde.loads_typed(&type_tag, &bytes) {
172                Ok(any_val) => any_to_json(any_val),
173                Err(_) => continue,
174            };
175            values.insert(channel, val);
176        }
177        Ok(values)
178    }
179
180    /// Load pending writes for a checkpoint.
181    async fn load_writes(
182        &self,
183        thread_id: &str,
184        checkpoint_ns: &str,
185        checkpoint_id: &str,
186    ) -> Result<Vec<PendingWrite>, CheckpointError> {
187        let rows = sqlx::query(SELECT_WRITES_SQL)
188            .bind(thread_id)
189            .bind(checkpoint_ns)
190            .bind(checkpoint_id)
191            .fetch_all(&self.pool)
192            .await
193            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
194
195        let mut writes = Vec::with_capacity(rows.len());
196        for row in rows {
197            let task_id: String = row.get("task_id");
198            let channel: String = row.get("channel");
199            let type_tag: Option<String> = row.try_get("type").ok();
200            let blob: Option<Vec<u8>> = row.try_get("blob").ok();
201
202            let value = match (type_tag.as_deref(), blob) {
203                (Some(tag), Some(bytes)) => match self.serde.loads_typed(tag, &bytes) {
204                    Ok(any_val) => any_to_json(any_val),
205                    Err(_) => JsonValue::Null,
206                },
207                _ => JsonValue::Null,
208            };
209            writes.push((task_id, channel, value));
210        }
211        Ok(writes)
212    }
213
214    /// Serialize new channel values into blob rows.
215    fn dump_blobs(
216        &self,
217        thread_id: &str,
218        checkpoint_ns: &str,
219        values: &HashMap<String, JsonValue>,
220        versions: &ChannelVersions,
221    ) -> Vec<(String, String, String, String, String, Option<Vec<u8>>)> {
222        let mut result = Vec::new();
223        for (channel, ver) in versions {
224            let ver_str = match ver {
225                JsonValue::String(s) => s.clone(),
226                JsonValue::Number(n) => n.to_string(),
227                _ => continue,
228            };
229            if let Some(val) = values.get(channel) {
230                if let Ok((type_tag, blob)) = self.serde.dumps_typed(val) {
231                    result.push((
232                        thread_id.to_string(),
233                        checkpoint_ns.to_string(),
234                        channel.clone(),
235                        ver_str,
236                        type_tag,
237                        Some(blob),
238                    ));
239                }
240            } else {
241                result.push((
242                    thread_id.to_string(),
243                    checkpoint_ns.to_string(),
244                    channel.clone(),
245                    ver_str,
246                    "empty".to_string(),
247                    None,
248                ));
249            }
250        }
251        result
252    }
253
254    /// Async list method.
255    pub async fn alist(
256        &self,
257        config: Option<&RunnableConfig>,
258        filter: Option<&HashMap<String, JsonValue>>,
259        before: Option<&RunnableConfig>,
260        limit: Option<usize>,
261    ) -> Result<Vec<CheckpointTuple>, CheckpointError> {
262        // Build dynamic WHERE with positional parameters (?1, ?2, ...).
263        let mut conditions: Vec<String> = Vec::new();
264        let mut binds: Vec<String> = Vec::new();
265
266        if let Some(cfg) = config {
267            if let Some(thread_id) = cfg
268                .get("configurable")
269                .and_then(|c| c.get("thread_id"))
270                .and_then(|v| v.as_str())
271            {
272                conditions.push(format!("thread_id = ?{}", binds.len() + 1));
273                binds.push(thread_id.to_string());
274            }
275            if let Some(ns) = cfg
276                .get("configurable")
277                .and_then(|c| c.get("checkpoint_ns"))
278                .and_then(|v| v.as_str())
279            {
280                conditions.push(format!("checkpoint_ns = ?{}", binds.len() + 1));
281                binds.push(ns.to_string());
282            }
283            if let Some(cid) = get_checkpoint_id(cfg) {
284                conditions.push(format!("checkpoint_id = ?{}", binds.len() + 1));
285                binds.push(cid);
286            }
287        }
288
289        // Metadata filter: emits one `json_extract(metadata, '$.key') =
290        // json_extract(?, '$')` clause per key. Both sides go through
291        // `json_extract` so comparison is type-uniform regardless of
292        // whether the value is a string, number, bool, array, or object.
293        // Filter keys are validated against an allow-list to prevent
294        // SQL injection via the inlined `'$.{key}'` JSON path.
295        if let Some(meta_filter) = filter {
296            for (key, value) in meta_filter {
297                validate_filter_key(key)?;
298                conditions.push(format!(
299                    "json_extract(metadata, '$.{}') = json_extract(?{}, '$')",
300                    key,
301                    binds.len() + 1
302                ));
303                binds.push(serde_json::to_string(value).unwrap_or_else(|_| "null".to_string()));
304            }
305        }
306
307        if let Some(before_cfg) = before {
308            if let Some(before_id) = get_checkpoint_id(before_cfg) {
309                conditions.push(format!("checkpoint_id < ?{}", binds.len() + 1));
310                binds.push(before_id);
311            }
312        }
313
314        let where_clause = if conditions.is_empty() {
315            String::new()
316        } else {
317            format!("WHERE {}", conditions.join(" AND "))
318        };
319
320        let mut query = format!(
321            "{} {} ORDER BY checkpoint_id DESC",
322            SELECT_CHECKPOINT_SQL, where_clause
323        );
324        if let Some(lim) = limit {
325            query.push_str(&format!(" LIMIT {}", lim));
326        }
327
328        let mut q = sqlx::query(&query);
329        for b in &binds {
330            q = q.bind(b.as_str());
331        }
332
333        let rows = q
334            .fetch_all(&self.pool)
335            .await
336            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
337
338        let mut results = Vec::with_capacity(rows.len());
339        for row in rows {
340            let mut tuple = Self::row_to_tuple(&row)?;
341            // Reconcile channel values from blobs.
342            let thread_id = row.get::<String, _>("thread_id");
343            let ns = row.get::<String, _>("checkpoint_ns");
344            let cid = tuple.checkpoint.id.clone();
345            let blob_values = self.load_blobs(&thread_id, &ns, &cid).await?;
346            if !blob_values.is_empty() {
347                tuple.checkpoint.channel_values = blob_values;
348            }
349            tuple.pending_writes = Some(self.load_writes(&thread_id, &ns, &cid).await?);
350            results.push(tuple);
351        }
352        Ok(results)
353    }
354}
355
356/// Bridge an async future to a sync caller.
357///
358/// **Local triage** — see PR notes. The trait's sync methods (`get_tuple`,
359/// `put`, `put_writes`, `delete_thread`, `list`) get invoked from inside
360/// `langgraph::graph::state::run_pregel_inner`, which is itself an
361/// `async fn`. Calling `Handle::block_on` from within a runtime panics
362/// with *"Cannot start a runtime from within a runtime"*. The proper
363/// fix is to make the graph runner call `aget_tuple`/`aput`/etc.; that
364/// touches `langgraph` and is out of scope for this crate.
365///
366/// As a stopgap we use `block_in_place` to escape the worker thread and
367/// then drive the future via the existing handle. **This requires a
368/// multi-thread runtime** — calling sync saver methods from a
369/// `current_thread` runtime will still panic. The `langgraph-checkpoint-postgres`
370/// crate has the identical limitation today.
371fn block_on_in_runtime<F, T>(future: F) -> Result<T, CheckpointError>
372where
373    F: std::future::Future<Output = Result<T, CheckpointError>>,
374{
375    match tokio::runtime::Handle::try_current() {
376        Ok(handle) => tokio::task::block_in_place(|| handle.block_on(future)),
377        Err(_) => {
378            let rt = tokio::runtime::Runtime::new()
379                .map_err(|e| CheckpointError::Storage(e.to_string()))?;
380            rt.block_on(future)
381        }
382    }
383}
384
385/// Validate a metadata filter key. Allowed: ASCII letters, digits,
386/// dot, underscore, hyphen. Empty strings rejected. The validated key
387/// is interpolated into the SQL JSON path (`'$.{key}'`) so anything
388/// that could break out of the literal must be rejected here.
389fn validate_filter_key(key: &str) -> Result<(), CheckpointError> {
390    if key.is_empty()
391        || key
392            .chars()
393            .any(|c| !(c.is_ascii_alphanumeric() || c == '.' || c == '_' || c == '-'))
394    {
395        return Err(CheckpointError::Config(format!(
396            "invalid metadata filter key: {:?}",
397            key
398        )));
399    }
400    Ok(())
401}
402
403/// Best-effort conversion of a deserialized value (Box<dyn Any>) back into JSON.
404fn any_to_json(val: Box<dyn std::any::Any>) -> JsonValue {
405    if val.is::<JsonValue>() {
406        *val.downcast::<JsonValue>().unwrap()
407    } else if val.is::<String>() {
408        JsonValue::String(*val.downcast::<String>().unwrap())
409    } else if val.is::<Vec<u8>>() {
410        let b = val.downcast::<Vec<u8>>().unwrap();
411        JsonValue::Array(b.into_iter().map(|byte: u8| JsonValue::Number(byte.into())).collect())
412    } else {
413        JsonValue::Null
414    }
415}
416
417#[async_trait]
418impl BaseCheckpointSaver for SqliteSaver {
419    fn get_tuple(
420        &self,
421        config: &RunnableConfig,
422    ) -> Result<Option<CheckpointTuple>, CheckpointError> {
423        block_on_in_runtime(self.aget_tuple(config))
424    }
425
426    fn list(
427        &self,
428        config: Option<&RunnableConfig>,
429        filter: Option<&HashMap<String, JsonValue>>,
430        before: Option<&RunnableConfig>,
431        limit: Option<usize>,
432    ) -> Result<Vec<CheckpointTuple>, CheckpointError> {
433        block_on_in_runtime(self.alist(config, filter, before, limit))
434    }
435
436    fn put(
437        &self,
438        config: &RunnableConfig,
439        checkpoint: &Checkpoint,
440        metadata: &CheckpointMetadata,
441        new_versions: &ChannelVersions,
442    ) -> Result<RunnableConfig, CheckpointError> {
443        block_on_in_runtime(self.aput(config, checkpoint, metadata, new_versions))
444    }
445
446    fn put_writes(
447        &self,
448        config: &RunnableConfig,
449        writes: &[(String, String, JsonValue)],
450        task_id: &str,
451        task_path: &str,
452    ) -> Result<(), CheckpointError> {
453        block_on_in_runtime(self.aput_writes(
454            config,
455            writes.to_vec(),
456            task_id.to_string(),
457            task_path.to_string(),
458        ))
459    }
460
461    fn delete_thread(&self, thread_id: &str) -> Result<(), CheckpointError> {
462        block_on_in_runtime(self.adelete_thread(thread_id.to_string()))
463    }
464
465    async fn aget_tuple(
466        &self,
467        config: &RunnableConfig,
468    ) -> Result<Option<CheckpointTuple>, CheckpointError> {
469        let thread_id = config
470            .get("configurable")
471            .and_then(|c| c.get("thread_id"))
472            .and_then(|v| v.as_str())
473            .ok_or_else(|| CheckpointError::Config("missing thread_id".into()))?;
474
475        let checkpoint_ns = config
476            .get("configurable")
477            .and_then(|c| c.get("checkpoint_ns"))
478            .and_then(|v| v.as_str())
479            .unwrap_or("");
480
481        let checkpoint_id = get_checkpoint_id(config);
482
483        let row = if let Some(cid) = &checkpoint_id {
484            sqlx::query(&format!(
485                "{} WHERE thread_id = ?1 AND checkpoint_ns = ?2 AND checkpoint_id = ?3",
486                SELECT_CHECKPOINT_SQL
487            ))
488            .bind(thread_id)
489            .bind(checkpoint_ns)
490            .bind(cid.as_str())
491            .fetch_optional(&self.pool)
492            .await
493            .map_err(|e| CheckpointError::Storage(e.to_string()))?
494        } else {
495            sqlx::query(&format!(
496                "{} WHERE thread_id = ?1 AND checkpoint_ns = ?2 ORDER BY checkpoint_id DESC LIMIT 1",
497                SELECT_CHECKPOINT_SQL
498            ))
499            .bind(thread_id)
500            .bind(checkpoint_ns)
501            .fetch_optional(&self.pool)
502            .await
503            .map_err(|e| CheckpointError::Storage(e.to_string()))?
504        };
505
506        let row = match row {
507            Some(r) => r,
508            None => return Ok(None),
509        };
510
511        let mut tuple = Self::row_to_tuple(&row)?;
512        let cid = tuple.checkpoint.id.clone();
513        let blob_values = self.load_blobs(thread_id, checkpoint_ns, &cid).await?;
514        if !blob_values.is_empty() {
515            tuple.checkpoint.channel_values = blob_values;
516        }
517        tuple.pending_writes = Some(self.load_writes(thread_id, checkpoint_ns, &cid).await?);
518        Ok(Some(tuple))
519    }
520
521    async fn aput(
522        &self,
523        config: &RunnableConfig,
524        checkpoint: &Checkpoint,
525        metadata: &CheckpointMetadata,
526        new_versions: &ChannelVersions,
527    ) -> Result<RunnableConfig, CheckpointError> {
528        let configurable = config.get("configurable").cloned().unwrap_or_default();
529        let thread_id = configurable
530            .get("thread_id")
531            .and_then(|v| v.as_str())
532            .ok_or_else(|| CheckpointError::Config("missing thread_id".into()))?;
533        let checkpoint_ns = configurable
534            .get("checkpoint_ns")
535            .and_then(|v| v.as_str())
536            .unwrap_or("");
537        let parent_checkpoint_id: Option<String> = configurable
538            .get("checkpoint_id")
539            .and_then(|v| v.as_str())
540            .map(|s| s.to_string());
541
542        let next_config = Self::make_config(thread_id, checkpoint_ns, &checkpoint.id);
543
544        // Strip channel_values from the JSON checkpoint payload to avoid
545        // duplicating them in the row body — they live in checkpoint_blobs.
546        let mut checkpoint_value = serde_json::to_value(checkpoint)
547            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
548        if let Some(obj) = checkpoint_value.as_object_mut() {
549            obj.insert("channel_values".to_string(), JsonValue::Object(Default::default()));
550        }
551        let checkpoint_text = serde_json::to_string(&checkpoint_value)
552            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
553        // Merge config-level fields (e.g. `langgraph_step`) into the
554        // metadata before persisting, so `list(filter=...)` over those
555        // fields can find the row. Mirrors Python's get_checkpoint_metadata.
556        let merged_metadata = get_checkpoint_metadata(config, metadata);
557        let metadata_text = serde_json::to_string(&merged_metadata)
558            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
559
560        let mut tx = self
561            .pool
562            .begin()
563            .await
564            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
565
566        // Upsert blobs
567        let blobs = self.dump_blobs(
568            thread_id,
569            checkpoint_ns,
570            &checkpoint.channel_values,
571            new_versions,
572        );
573        for (tid, cns, channel, version, type_tag, blob) in &blobs {
574            sqlx::query(UPSERT_CHECKPOINT_BLOBS_SQL)
575                .bind(tid.as_str())
576                .bind(cns.as_str())
577                .bind(channel.as_str())
578                .bind(version.as_str())
579                .bind(type_tag.as_str())
580                .bind(blob.as_deref())
581                .execute(&mut *tx)
582                .await
583                .map_err(|e| CheckpointError::Storage(e.to_string()))?;
584        }
585
586        // Upsert checkpoint row
587        sqlx::query(UPSERT_CHECKPOINTS_SQL)
588            .bind(thread_id)
589            .bind(checkpoint_ns)
590            .bind(checkpoint.id.as_str())
591            .bind(parent_checkpoint_id.as_deref())
592            .bind(checkpoint_text.as_str())
593            .bind(metadata_text.as_str())
594            .execute(&mut *tx)
595            .await
596            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
597
598        tx.commit()
599            .await
600            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
601
602        Ok(next_config)
603    }
604
605    async fn aput_writes(
606        &self,
607        config: &RunnableConfig,
608        writes: Vec<(String, String, JsonValue)>,
609        task_id: String,
610        task_path: String,
611    ) -> Result<(), CheckpointError> {
612        let configurable = config.get("configurable").cloned().unwrap_or_default();
613        let thread_id = configurable
614            .get("thread_id")
615            .and_then(|v| v.as_str())
616            .ok_or_else(|| CheckpointError::Config("missing thread_id".into()))?;
617        let checkpoint_ns = configurable
618            .get("checkpoint_ns")
619            .and_then(|v| v.as_str())
620            .unwrap_or("");
621        // NOTE: align with langgraph-checkpoint-postgres which silently
622        // defaults to empty string when checkpoint_id is missing. The
623        // graph runner currently calls put_writes with the *input*
624        // config (not the new config returned by put), so on the first
625        // step checkpoint_id is absent. Erroring here would crash the
626        // run; defaulting matches existing behavior. This does NOT
627        // semantically fix interrupt/resume — pending writes still get
628        // attached to checkpoint_id="" and won't be reachable from
629        // get_tuple of the real latest checkpoint. Tracked as a
630        // cross-crate issue (graph runner should pass the post-put
631        // config, or saver should resolve latest checkpoint here).
632        let checkpoint_id = configurable
633            .get("checkpoint_id")
634            .and_then(|v| v.as_str())
635            .unwrap_or("");
636
637        let idx_map = writes_idx_map();
638        let use_upsert = writes
639            .iter()
640            .all(|(channel, _, _)| idx_map.contains_key(channel.as_str()));
641
642        let query = if use_upsert {
643            UPSERT_CHECKPOINT_WRITES_SQL
644        } else {
645            INSERT_CHECKPOINT_WRITES_SQL
646        };
647
648        let mut tx = self
649            .pool
650            .begin()
651            .await
652            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
653
654        for (idx, (_task_id_in_tuple, channel, value)) in writes.iter().enumerate() {
655            let idx_val: i64 = idx_map
656                .get(channel.as_str())
657                .copied()
658                .unwrap_or(idx as i64);
659
660            let (type_tag, blob) = match self.serde.dumps_typed(value) {
661                Ok(pair) => pair,
662                Err(_) => continue,
663            };
664
665            sqlx::query(query)
666                .bind(thread_id)
667                .bind(checkpoint_ns)
668                .bind(checkpoint_id)
669                .bind(task_id.as_str())
670                .bind(task_path.as_str())
671                .bind(idx_val)
672                .bind(channel.as_str())
673                .bind(type_tag.as_str())
674                .bind(blob.as_slice())
675                .execute(&mut *tx)
676                .await
677                .map_err(|e| CheckpointError::Storage(e.to_string()))?;
678        }
679
680        tx.commit()
681            .await
682            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
683
684        Ok(())
685    }
686
687    async fn adelete_thread(&self, thread_id: String) -> Result<(), CheckpointError> {
688        let mut tx = self
689            .pool
690            .begin()
691            .await
692            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
693
694        sqlx::query("DELETE FROM checkpoints WHERE thread_id = ?1")
695            .bind(thread_id.as_str())
696            .execute(&mut *tx)
697            .await
698            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
699
700        sqlx::query("DELETE FROM checkpoint_blobs WHERE thread_id = ?1")
701            .bind(thread_id.as_str())
702            .execute(&mut *tx)
703            .await
704            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
705
706        sqlx::query("DELETE FROM checkpoint_writes WHERE thread_id = ?1")
707            .bind(thread_id.as_str())
708            .execute(&mut *tx)
709            .await
710            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
711
712        tx.commit()
713            .await
714            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
715
716        Ok(())
717    }
718}
719
720#[cfg(test)]
721mod tests {
722    use super::*;
723    use std::collections::HashMap;
724
725    async fn fresh_saver() -> SqliteSaver {
726        let saver = SqliteSaver::from_conn_string("sqlite::memory:")
727            .await
728            .expect("connect to in-memory sqlite");
729        saver.setup().await.expect("setup migrations");
730        saver
731    }
732
733    fn config_for(thread_id: &str) -> RunnableConfig {
734        serde_json::from_value(serde_json::json!({
735            "configurable": { "thread_id": thread_id, "checkpoint_ns": "" }
736        }))
737        .unwrap()
738    }
739
740    fn config_with_id(thread_id: &str, checkpoint_id: &str) -> RunnableConfig {
741        serde_json::from_value(serde_json::json!({
742            "configurable": {
743                "thread_id": thread_id,
744                "checkpoint_ns": "",
745                "checkpoint_id": checkpoint_id,
746            }
747        }))
748        .unwrap()
749    }
750
751    fn make_checkpoint(channel_values: Vec<(&str, JsonValue)>) -> (Checkpoint, ChannelVersions) {
752        let mut cp = Checkpoint::empty();
753        let mut versions: ChannelVersions = HashMap::new();
754        for (k, v) in channel_values {
755            cp.channel_values.insert(k.to_string(), v);
756            cp.channel_versions
757                .insert(k.to_string(), JsonValue::Number(1.into()));
758            versions.insert(k.to_string(), JsonValue::Number(1.into()));
759        }
760        (cp, versions)
761    }
762
763    #[tokio::test]
764    async fn test_setup_is_idempotent() {
765        let saver = fresh_saver().await;
766        // calling setup again should not error
767        saver.setup().await.expect("second setup");
768    }
769
770    #[tokio::test]
771    async fn test_get_tuple_returns_none_when_empty() {
772        let saver = fresh_saver().await;
773        let cfg = config_for("missing");
774        let result = saver.aget_tuple(&cfg).await.unwrap();
775        assert!(result.is_none());
776    }
777
778    #[tokio::test]
779    async fn test_put_then_get_roundtrip() {
780        let saver = fresh_saver().await;
781        let (cp, versions) = make_checkpoint(vec![
782            ("messages", serde_json::json!(["hello", "world"])),
783            ("counter", serde_json::json!(7)),
784        ]);
785        let cfg = config_for("thread-A");
786        let metadata = CheckpointMetadata {
787            source: Some(CheckpointSource::Loop),
788            step: Some(3),
789            ..Default::default()
790        };
791
792        let next = saver.aput(&cfg, &cp, &metadata, &versions).await.unwrap();
793
794        // The returned config should reference the new checkpoint id
795        let returned_cid = next
796            .get("configurable")
797            .and_then(|c| c.get("checkpoint_id"))
798            .and_then(|v| v.as_str())
799            .unwrap();
800        assert_eq!(returned_cid, cp.id);
801
802        // Fetch back and compare
803        let tuple = saver.aget_tuple(&cfg).await.unwrap().expect("tuple exists");
804        assert_eq!(tuple.checkpoint.id, cp.id);
805        assert_eq!(tuple.metadata.step, Some(3));
806        assert_eq!(
807            tuple.checkpoint.channel_values.get("messages"),
808            Some(&serde_json::json!(["hello", "world"]))
809        );
810        assert_eq!(
811            tuple.checkpoint.channel_values.get("counter"),
812            Some(&serde_json::json!(7))
813        );
814    }
815
816    #[tokio::test]
817    async fn test_put_writes_and_pending_writes_round_trip() {
818        let saver = fresh_saver().await;
819        let (cp, versions) = make_checkpoint(vec![("a", serde_json::json!(1))]);
820        let cfg = config_for("thread-W");
821        saver
822            .aput(&cfg, &cp, &CheckpointMetadata::default(), &versions)
823            .await
824            .unwrap();
825
826        let cfg_with_id = config_with_id("thread-W", &cp.id);
827        let writes = vec![
828            ("ch1".to_string(), "task-1".to_string(), serde_json::json!("v1")),
829            ("ch2".to_string(), "task-1".to_string(), serde_json::json!(42)),
830        ];
831        saver
832            .aput_writes(&cfg_with_id, writes, "task-1".into(), "".into())
833            .await
834            .unwrap();
835
836        let tuple = saver.aget_tuple(&cfg_with_id).await.unwrap().unwrap();
837        let pending = tuple.pending_writes.expect("pending writes loaded");
838        assert_eq!(pending.len(), 2);
839        // Order: by task_path, task_id, idx
840        assert_eq!(pending[0].1, "ch1");
841        assert_eq!(pending[1].1, "ch2");
842        assert_eq!(pending[1].2, serde_json::json!(42));
843    }
844
845    #[tokio::test]
846    async fn test_list_orders_descending_and_respects_limit() {
847        let saver = fresh_saver().await;
848        let cfg = config_for("thread-L");
849        let mut ids = Vec::new();
850        for i in 0..3 {
851            let (cp, versions) = make_checkpoint(vec![("x", serde_json::json!(i))]);
852            ids.push(cp.id.clone());
853            saver
854                .aput(&cfg, &cp, &CheckpointMetadata::default(), &versions)
855                .await
856                .unwrap();
857        }
858
859        let all = saver.alist(Some(&cfg), None, None, None).await.unwrap();
860        assert_eq!(all.len(), 3);
861        // ORDER BY checkpoint_id DESC — verify returned ids are in
862        // descending lexicographic order (not tied to insertion order,
863        // since UUIDv7 within the same millisecond is not monotonic).
864        for w in all.windows(2) {
865            assert!(w[0].checkpoint.id >= w[1].checkpoint.id);
866        }
867        // All three checkpoint ids should appear in the result set
868        let returned_ids: std::collections::HashSet<_> =
869            all.iter().map(|t| t.checkpoint.id.clone()).collect();
870        for id in &ids {
871            assert!(returned_ids.contains(id));
872        }
873
874        let limited = saver.alist(Some(&cfg), None, None, Some(2)).await.unwrap();
875        assert_eq!(limited.len(), 2);
876    }
877
878    #[tokio::test]
879    async fn test_delete_thread_removes_all_data() {
880        let saver = fresh_saver().await;
881        let (cp, versions) = make_checkpoint(vec![("x", serde_json::json!(1))]);
882        let cfg = config_for("thread-D");
883        saver
884            .aput(&cfg, &cp, &CheckpointMetadata::default(), &versions)
885            .await
886            .unwrap();
887        let cfg_with_id = config_with_id("thread-D", &cp.id);
888        saver
889            .aput_writes(
890                &cfg_with_id,
891                vec![("ch".into(), "task".into(), serde_json::json!("v"))],
892                "task".into(),
893                "".into(),
894            )
895            .await
896            .unwrap();
897
898        saver.adelete_thread("thread-D".into()).await.unwrap();
899        assert!(saver.aget_tuple(&cfg).await.unwrap().is_none());
900        let listed = saver.alist(Some(&cfg), None, None, None).await.unwrap();
901        assert!(listed.is_empty());
902    }
903
904    #[tokio::test]
905    async fn test_value_updates_when_version_increments() {
906        // Blob storage is keyed by (channel, version). Ensure that when
907        // a caller bumps the version, the new value is what's read back.
908        let saver = fresh_saver().await;
909        let cfg = config_for("thread-V");
910
911        // First put: counter=1 at version 1
912        let mut cp1 = Checkpoint::empty();
913        cp1.channel_values
914            .insert("counter".into(), JsonValue::Number(1.into()));
915        cp1.channel_versions
916            .insert("counter".into(), JsonValue::Number(1.into()));
917        let mut versions1: ChannelVersions = HashMap::new();
918        versions1.insert("counter".into(), JsonValue::Number(1.into()));
919        saver
920            .aput(&cfg, &cp1, &CheckpointMetadata::default(), &versions1)
921            .await
922            .unwrap();
923
924        tokio::time::sleep(std::time::Duration::from_millis(2)).await;
925
926        // Second put: counter=99 at version 2 — fresh blob row.
927        let mut cp2 = Checkpoint::empty();
928        cp2.channel_values
929            .insert("counter".into(), JsonValue::Number(99.into()));
930        cp2.channel_versions
931            .insert("counter".into(), JsonValue::Number(2.into()));
932        let mut versions2: ChannelVersions = HashMap::new();
933        versions2.insert("counter".into(), JsonValue::Number(2.into()));
934        saver
935            .aput(&cfg, &cp2, &CheckpointMetadata::default(), &versions2)
936            .await
937            .unwrap();
938
939        let cfg_cp2 = config_with_id("thread-V", &cp2.id);
940        let tuple = saver.aget_tuple(&cfg_cp2).await.unwrap().unwrap();
941        assert_eq!(
942            tuple.checkpoint.channel_values.get("counter"),
943            Some(&JsonValue::Number(99.into()))
944        );
945
946        // The earlier checkpoint should still see its own value.
947        let cfg_cp1 = config_with_id("thread-V", &cp1.id);
948        let earlier = saver.aget_tuple(&cfg_cp1).await.unwrap().unwrap();
949        assert_eq!(
950            earlier.checkpoint.channel_values.get("counter"),
951            Some(&JsonValue::Number(1.into()))
952        );
953    }
954
955    #[tokio::test]
956    async fn test_metadata_filter_returns_only_matching_rows() {
957        let saver = fresh_saver().await;
958        let cfg = config_for("thread-F");
959
960        // Three checkpoints with distinct `source` and `step` metadata.
961        for (source, step, val) in [
962            (CheckpointSource::Input, 0, "a"),
963            (CheckpointSource::Loop, 1, "b"),
964            (CheckpointSource::Loop, 2, "c"),
965        ] {
966            let (cp, vers) = make_checkpoint(vec![("x", serde_json::json!(val))]);
967            let meta = CheckpointMetadata {
968                source: Some(source),
969                step: Some(step),
970                ..Default::default()
971            };
972            saver.aput(&cfg, &cp, &meta, &vers).await.unwrap();
973            tokio::time::sleep(std::time::Duration::from_millis(2)).await;
974        }
975
976        // Filter source = "loop" → 2 results
977        let mut filter = HashMap::new();
978        filter.insert("source".into(), serde_json::json!("loop"));
979        let loop_only = saver
980            .alist(Some(&cfg), Some(&filter), None, None)
981            .await
982            .unwrap();
983        assert_eq!(loop_only.len(), 2);
984        for t in &loop_only {
985            assert_eq!(t.metadata.source, Some(CheckpointSource::Loop));
986        }
987
988        // Filter step = 1 → 1 result
989        let mut filter = HashMap::new();
990        filter.insert("step".into(), serde_json::json!(1));
991        let step_one = saver
992            .alist(Some(&cfg), Some(&filter), None, None)
993            .await
994            .unwrap();
995        assert_eq!(step_one.len(), 1);
996        assert_eq!(step_one[0].metadata.step, Some(1));
997
998        // Combined filter: source = "loop" AND step = 2 → 1 result
999        let mut filter = HashMap::new();
1000        filter.insert("source".into(), serde_json::json!("loop"));
1001        filter.insert("step".into(), serde_json::json!(2));
1002        let combined = saver
1003            .alist(Some(&cfg), Some(&filter), None, None)
1004            .await
1005            .unwrap();
1006        assert_eq!(combined.len(), 1);
1007        assert_eq!(combined[0].metadata.step, Some(2));
1008    }
1009
/// Runs `validate_filter_key` over a table of keys that must be accepted and
/// a table of keys that must be rejected.
#[test]
fn test_validate_filter_key_rejects_injection_attempts() {
    // Plain identifiers, dotted paths, snake/kebab case and mixed alphanumerics.
    let accepted = [
        "source",
        "nested.field",
        "snake_case",
        "kebab-case",
        "Mixed123",
    ];
    for key in accepted {
        assert!(
            validate_filter_key(key).is_ok(),
            "expected {:?} to be accepted",
            key
        );
    }

    // Empty string, quotes, semicolons, brackets, spaces and non-ASCII text.
    let rejected = [
        "",
        "source'; DROP TABLE--",
        "a\"b",
        "a b",
        "[admin]",
        "中文",
    ];
    for key in rejected {
        assert!(
            validate_filter_key(key).is_err(),
            "expected {:?} to be rejected",
            key
        );
    }
}
1027
1028    #[tokio::test]
1029    async fn test_config_langgraph_step_merged_into_metadata() {
1030        // When the caller includes `langgraph_step` in configurable, the
1031        // saver should fold it into the persisted metadata (so list-with-filter
1032        // can later find rows by step).
1033        let saver = fresh_saver().await;
1034        let cfg: RunnableConfig = serde_json::from_value(serde_json::json!({
1035            "configurable": {
1036                "thread_id": "thread-M",
1037                "checkpoint_ns": "",
1038                "langgraph_step": 7
1039            }
1040        }))
1041        .unwrap();
1042
1043        let (cp, vers) = make_checkpoint(vec![("x", serde_json::json!(1))]);
1044        // Metadata passed in does NOT have step set — it should be filled
1045        // from the config.
1046        saver
1047            .aput(&cfg, &cp, &CheckpointMetadata::default(), &vers)
1048            .await
1049            .unwrap();
1050
1051        let cfg_with_id = config_with_id("thread-M", &cp.id);
1052        let tuple = saver.aget_tuple(&cfg_with_id).await.unwrap().unwrap();
1053        assert_eq!(tuple.metadata.step, Some(7));
1054    }
1055
1056    /// Verifies the sync wrapper does not panic when called from inside
1057    /// a multi-thread tokio runtime (the situation that occurs when the
1058    /// graph runner — itself an async fn — invokes `cp.get_tuple(...)`).
1059    /// See `block_on_in_runtime` doc comment.
1060    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1061    async fn test_sync_methods_work_inside_multi_thread_runtime() {
1062        let saver = fresh_saver().await;
1063        let saver = std::sync::Arc::new(saver);
1064        let cfg = config_for("thread-S");
1065
1066        // Drive sync `put` and `get_tuple` on a blocking task — this is
1067        // what `block_in_place` is designed to guard against.
1068        let (cp, vers) = make_checkpoint(vec![("k", serde_json::json!("v"))]);
1069        let s2 = saver.clone();
1070        let cfg2 = cfg.clone();
1071        let cp_clone = cp.clone();
1072        let vers_clone = vers.clone();
1073        let put_result = tokio::task::spawn_blocking(move || {
1074            s2.put(&cfg2, &cp_clone, &CheckpointMetadata::default(), &vers_clone)
1075        })
1076        .await
1077        .unwrap();
1078        assert!(put_result.is_ok());
1079
1080        let s3 = saver.clone();
1081        let cfg3 = cfg.clone();
1082        let get_result = tokio::task::spawn_blocking(move || s3.get_tuple(&cfg3))
1083            .await
1084            .unwrap()
1085            .unwrap();
1086        assert!(get_result.is_some());
1087        assert_eq!(get_result.unwrap().checkpoint.id, cp.id);
1088    }
1089
1090    #[tokio::test]
1091    async fn test_parent_config_links_checkpoints() {
1092        let saver = fresh_saver().await;
1093        let (cp1, vers1) = make_checkpoint(vec![("x", serde_json::json!("a"))]);
1094        let cfg = config_for("thread-P");
1095        let next1 = saver
1096            .aput(&cfg, &cp1, &CheckpointMetadata::default(), &vers1)
1097            .await
1098            .unwrap();
1099
1100        // Sleep briefly so UUIDv7 timestamps differ; otherwise two
1101        // checkpoints created within the same millisecond can sort
1102        // unpredictably.
1103        tokio::time::sleep(std::time::Duration::from_millis(2)).await;
1104
1105        // Second put using next1 as the parent config — its checkpoint_id
1106        // becomes the parent_checkpoint_id of cp2.
1107        let (cp2, vers2) = make_checkpoint(vec![("x", serde_json::json!("b"))]);
1108        saver
1109            .aput(&next1, &cp2, &CheckpointMetadata::default(), &vers2)
1110            .await
1111            .unwrap();
1112
1113        // Look up cp2 explicitly to avoid relying on lex ordering.
1114        let cfg_cp2 = config_with_id("thread-P", &cp2.id);
1115        let latest = saver.aget_tuple(&cfg_cp2).await.unwrap().unwrap();
1116        assert_eq!(latest.checkpoint.id, cp2.id);
1117        let parent = latest.parent_config.expect("parent_config present");
1118        let parent_id = parent
1119            .get("configurable")
1120            .and_then(|c| c.get("checkpoint_id"))
1121            .and_then(|v| v.as_str())
1122            .unwrap();
1123        assert_eq!(parent_id, cp1.id);
1124    }
1125}