Skip to main content

langgraph_checkpoint_rs/checkpoint/
base.rs

1use std::collections::HashMap;
2use async_trait::async_trait;
3use serde_json::Value as JsonValue;
4use crate::config::RunnableConfig;
5use crate::error::CheckpointError;
6use super::types::*;
7
/// Index map for the special write channels (`WRITES_IDX_MAP`).
///
/// Each reserved channel name maps to a fixed negative index:
/// `__error__` → -1, `__scheduled__` → -2, `__interrupt__` → -3,
/// `__resume__` → -4. A new `HashMap` is built on every call.
pub fn writes_idx_map() -> HashMap<&'static str, i64> {
    [
        ("__error__", -1),
        ("__scheduled__", -2),
        ("__interrupt__", -3),
        ("__resume__", -4),
    ]
    .iter()
    .copied()
    .collect()
}
17
/// Keys that are never copied into checkpoint metadata.
///
/// The slice is backed by a `static` array, so it is returned without any
/// allocation and keeps a fixed order.
pub fn excluded_metadata_keys() -> &'static [&'static str] {
    static EXCLUDED: [&str; 9] = [
        "thread_id",
        "checkpoint_id",
        "checkpoint_ns",
        "checkpoint_map",
        "langgraph_step",
        "langgraph_node",
        "langgraph_triggers",
        "langgraph_path",
        "langgraph_checkpoint_ns",
    ];
    &EXCLUDED
}
32
33/// Base checkpoint saver trait. Mirrors Python's BaseCheckpointSaver.
34///
35/// All methods that store/retrieve checkpoints must be implemented.
36/// Async versions default to wrapping sync versions via spawn_blocking.
37#[async_trait]
38pub trait BaseCheckpointSaver: Send + Sync {
39    /// Get a checkpoint tuple by config.
40    fn get_tuple(&self, config: &RunnableConfig) -> Result<Option<CheckpointTuple>, CheckpointError>;
41
42    /// List checkpoint tuples.
43    fn list(
44        &self,
45        config: Option<&RunnableConfig>,
46        filter: Option<&HashMap<String, JsonValue>>,
47        before: Option<&RunnableConfig>,
48        limit: Option<usize>,
49    ) -> Result<Vec<CheckpointTuple>, CheckpointError>;
50
51    /// Store a checkpoint.
52    fn put(
53        &self,
54        config: &RunnableConfig,
55        checkpoint: &Checkpoint,
56        metadata: &CheckpointMetadata,
57        new_versions: &ChannelVersions,
58    ) -> Result<RunnableConfig, CheckpointError>;
59
60    /// Store pending writes for a checkpoint.
61    fn put_writes(
62        &self,
63        config: &RunnableConfig,
64        writes: &[(String, String, JsonValue)],
65        task_id: &str,
66        task_path: &str,
67    ) -> Result<(), CheckpointError>;
68
69    /// Delete all checkpoints for a thread.
70    fn delete_thread(&self, thread_id: &str) -> Result<(), CheckpointError>;
71
72    /// Get the next version for a channel.
73    fn get_next_version(&self, current: Option<&ChannelVersion>) -> ChannelVersion {
74        match current {
75            Some(JsonValue::Number(n)) => {
76                let v = n.as_i64().unwrap_or(0) + 1;
77                JsonValue::Number(v.into())
78            }
79            Some(JsonValue::String(s)) => {
80                // Parse "NNN.random" format
81                let num: i64 = s.split('.').next().unwrap_or("0").parse().unwrap_or(0);
82                JsonValue::String(format!("{:032}.{:016}", num + 1, rand::random::<u64>()))
83            }
84            _ => JsonValue::Number(1i64.into()),
85        }
86    }
87
88    // Async mirrors with default implementations
89
90    async fn aget_tuple(&self, config: &RunnableConfig) -> Result<Option<CheckpointTuple>, CheckpointError> {
91        let config = config.clone();
92        let this = self;
93        // Use blocking for default impl
94        tokio::task::block_in_place(|| this.get_tuple(&config))
95    }
96
97    async fn aput(
98        &self,
99        config: &RunnableConfig,
100        checkpoint: &Checkpoint,
101        metadata: &CheckpointMetadata,
102        new_versions: &ChannelVersions,
103    ) -> Result<RunnableConfig, CheckpointError> {
104        let config = config.clone();
105        let checkpoint = checkpoint.clone();
106        let metadata = metadata.clone();
107        let new_versions = new_versions.clone();
108        tokio::task::block_in_place(|| {
109            self.put(&config, &checkpoint, &metadata, &new_versions)
110        })
111    }
112
113    async fn aput_writes(
114        &self,
115        config: &RunnableConfig,
116        writes: Vec<(String, String, JsonValue)>,
117        task_id: String,
118        task_path: String,
119    ) -> Result<(), CheckpointError> {
120        let config = config.clone();
121        tokio::task::block_in_place(|| {
122            self.put_writes(&config, &writes, &task_id, &task_path)
123        })
124    }
125
126    async fn adelete_thread(&self, thread_id: String) -> Result<(), CheckpointError> {
127        let this = self;
128        tokio::task::block_in_place(|| this.delete_thread(&thread_id))
129    }
130}
131
132/// Helper to extract checkpoint_id from config
133pub fn get_checkpoint_id(config: &RunnableConfig) -> Option<String> {
134    config
135        .get("configurable")
136        .and_then(|c| c.get("checkpoint_id"))
137        .and_then(|v| v.as_str())
138        .map(|s| s.to_string())
139}
140
141/// Helper to extract checkpoint metadata from config
142pub fn get_checkpoint_metadata(
143    config: &RunnableConfig,
144    metadata: &CheckpointMetadata,
145) -> CheckpointMetadata {
146    let mut meta = metadata.clone();
147    if let Some(step) = config
148        .get("configurable")
149        .and_then(|c| c.get("langgraph_step"))
150        .and_then(|v| v.as_i64())
151    {
152        meta.step = Some(step);
153    }
154    meta
155}
156
157/// Copy a checkpoint
158pub fn copy_checkpoint(checkpoint: &Checkpoint) -> Checkpoint {
159    checkpoint.copy()
160}
161
162/// Create an empty checkpoint
163pub fn empty_checkpoint() -> Checkpoint {
164    Checkpoint::empty()
165}
166
167/// Create a checkpoint from current channel state
168pub fn create_checkpoint(
169    checkpoint: &Checkpoint,
170    channel_values: HashMap<String, JsonValue>,
171    _step: i64,
172) -> Checkpoint {
173    use chrono::Utc;
174    use crate::checkpoint::id::uuid6;
175
176    Checkpoint {
177        v: LATEST_VERSION,
178        id: uuid6(),
179        ts: Utc::now().to_rfc3339(),
180        channel_values,
181        channel_versions: checkpoint.channel_versions.clone(),
182        versions_seen: checkpoint.versions_seen.clone(),
183        updated_channels: checkpoint.updated_channels.clone(),
184    }
185}
186
// Minimal in-file stand-in for `rand::random`, used to build the random suffix
// of string channel versions. No external `rand` crate is involved.
mod rand {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    /// Return a pseudo-random value taken from a freshly keyed SipHash.
    ///
    /// `RandomState::new()` is initialized with random keys, so hashing the
    /// same input gives a different `u64` on each call. The output is NOT
    /// reproducible across calls or runs (there is no fixed seed). It is also
    /// not suitable for cryptographic use.
    pub fn random<T: From<u64>>() -> T {
        let state = RandomState::new();
        let mut hasher = state.build_hasher();
        // The hashed input is arbitrary; all variability comes from the keys.
        hasher.write_u64(42);
        T::from(hasher.finish())
    }
}