Skip to main content

langgraph_checkpoint/
memory.rs

use std::collections::btree_map::Entry;
use std::collections::{BTreeMap, HashMap};
use std::sync::RwLock;

use crate::traits::{CheckpointError, CheckpointSaver};
use crate::types::{Checkpoint, CheckpointId, ThreadId};
6
7#[derive(Debug, Default)]
8pub struct InMemorySaver {
9    inner: RwLock<HashMap<ThreadId, BTreeMap<CheckpointId, Checkpoint>>>,
10}
11
12impl InMemorySaver {
13    #[must_use]
14    pub fn new() -> Self {
15        Self::default()
16    }
17}
18
19impl CheckpointSaver for InMemorySaver {
20    fn put(&self, checkpoint: Checkpoint) -> Result<(), CheckpointError> {
21        let mut guard = self.inner.write().map_err(|e| CheckpointError::Storage(e.to_string()))?;
22        let entries = guard.entry(checkpoint.thread_id.clone()).or_insert_with(BTreeMap::new);
23        if entries.contains_key(&checkpoint.checkpoint_id) {
24            return Err(CheckpointError::Conflict(format!(
25                "thread `{}` already has checkpoint `{}`",
26                checkpoint.thread_id, checkpoint.checkpoint_id
27            )));
28        }
29        entries.insert(checkpoint.checkpoint_id.clone(), checkpoint);
30        Ok(())
31    }
32
33    fn get(
34        &self,
35        thread_id: &str,
36        checkpoint_id: &str,
37    ) -> Result<Option<Checkpoint>, CheckpointError> {
38        let guard = self.inner.read().map_err(|e| CheckpointError::Storage(e.to_string()))?;
39        Ok(guard.get(thread_id).and_then(|checkpoints| checkpoints.get(checkpoint_id).cloned()))
40    }
41
42    fn list(&self, thread_id: &str) -> Result<Vec<Checkpoint>, CheckpointError> {
43        let guard = self.inner.read().map_err(|e| CheckpointError::Storage(e.to_string()))?;
44        let result = guard
45            .get(thread_id)
46            .map(|checkpoints| checkpoints.values().cloned().collect())
47            .unwrap_or_default();
48        Ok(result)
49    }
50}