langgraph_checkpoint/
memory.rs1use std::collections::{BTreeMap, HashMap};
2use std::sync::RwLock;
3
4use crate::traits::{CheckpointError, CheckpointSaver};
5use crate::types::{Checkpoint, CheckpointId, ThreadId};
6
7#[derive(Debug, Default)]
8pub struct InMemorySaver {
9 inner: RwLock<HashMap<ThreadId, BTreeMap<CheckpointId, Checkpoint>>>,
10}
11
12impl InMemorySaver {
13 #[must_use]
14 pub fn new() -> Self {
15 Self::default()
16 }
17}
18
19impl CheckpointSaver for InMemorySaver {
20 fn put(&self, checkpoint: Checkpoint) -> Result<(), CheckpointError> {
21 let mut guard = self.inner.write().map_err(|e| CheckpointError::Storage(e.to_string()))?;
22 let entries = guard.entry(checkpoint.thread_id.clone()).or_insert_with(BTreeMap::new);
23 if entries.contains_key(&checkpoint.checkpoint_id) {
24 return Err(CheckpointError::Conflict(format!(
25 "thread `{}` already has checkpoint `{}`",
26 checkpoint.thread_id, checkpoint.checkpoint_id
27 )));
28 }
29 entries.insert(checkpoint.checkpoint_id.clone(), checkpoint);
30 Ok(())
31 }
32
33 fn get(
34 &self,
35 thread_id: &str,
36 checkpoint_id: &str,
37 ) -> Result<Option<Checkpoint>, CheckpointError> {
38 let guard = self.inner.read().map_err(|e| CheckpointError::Storage(e.to_string()))?;
39 Ok(guard.get(thread_id).and_then(|checkpoints| checkpoints.get(checkpoint_id).cloned()))
40 }
41
42 fn list(&self, thread_id: &str) -> Result<Vec<Checkpoint>, CheckpointError> {
43 let guard = self.inner.read().map_err(|e| CheckpointError::Storage(e.to_string()))?;
44 let result = guard
45 .get(thread_id)
46 .map(|checkpoints| checkpoints.values().cloned().collect())
47 .unwrap_or_default();
48 Ok(result)
49 }
50}