synwire_checkpoint/
memory.rs1use std::collections::HashMap;
4use std::sync::Arc;
5
6use tokio::sync::RwLock;
7
8use crate::base::BaseCheckpointSaver;
9use crate::types::{
10 Checkpoint, CheckpointConfig, CheckpointError, CheckpointMetadata, CheckpointTuple,
11};
12
13#[derive(Debug, Clone, Default)]
18pub struct InMemoryCheckpointSaver {
19 storage: Arc<RwLock<HashMap<String, Vec<CheckpointTuple>>>>,
20}
21
22impl InMemoryCheckpointSaver {
23 pub fn new() -> Self {
25 Self::default()
26 }
27}
28
29#[allow(clippy::significant_drop_tightening)]
30impl BaseCheckpointSaver for InMemoryCheckpointSaver {
31 fn get_tuple<'a>(
32 &'a self,
33 config: &'a CheckpointConfig,
34 ) -> synwire_core::BoxFuture<'a, Result<Option<CheckpointTuple>, CheckpointError>> {
35 Box::pin(async move {
36 let storage = self.storage.read().await;
37 let Some(tuples) = storage.get(&config.thread_id) else {
38 return Ok(None);
39 };
40 Ok(config.checkpoint_id.as_ref().map_or_else(
41 || tuples.last().cloned(),
42 |checkpoint_id| {
43 tuples
44 .iter()
45 .find(|t| t.checkpoint.id == *checkpoint_id)
46 .cloned()
47 },
48 ))
49 })
50 }
51
52 fn list<'a>(
53 &'a self,
54 config: &'a CheckpointConfig,
55 limit: Option<usize>,
56 ) -> synwire_core::BoxFuture<'a, Result<Vec<CheckpointTuple>, CheckpointError>> {
57 Box::pin(async move {
58 let storage = self.storage.read().await;
59 let Some(tuples) = storage.get(&config.thread_id) else {
60 return Ok(Vec::new());
61 };
62 let mut result: Vec<CheckpointTuple> = tuples.iter().rev().cloned().collect();
63 if let Some(limit) = limit {
64 result.truncate(limit);
65 }
66 Ok(result)
67 })
68 }
69
70 fn put<'a>(
71 &'a self,
72 config: &'a CheckpointConfig,
73 checkpoint: Checkpoint,
74 metadata: CheckpointMetadata,
75 ) -> synwire_core::BoxFuture<'a, Result<CheckpointConfig, CheckpointError>> {
76 Box::pin(async move {
77 let new_config = CheckpointConfig {
78 thread_id: config.thread_id.clone(),
79 checkpoint_id: Some(checkpoint.id.clone()),
80 };
81
82 let mut storage = self.storage.write().await;
83 let tuples = storage.entry(config.thread_id.clone()).or_default();
84 let parent_config = tuples.last().map(|t| t.config.clone());
85
86 tuples.push(CheckpointTuple {
87 config: new_config.clone(),
88 checkpoint,
89 metadata,
90 parent_config,
91 });
92
93 Ok(new_config)
94 })
95 }
96}
97
#[cfg(test)]
// In tests, a panic on an unexpected `None`/`Err` is the intended failure signal.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::types::CheckpointSource;
    use serde_json::json;

    /// Builds a checkpoint with id `id` whose `messages` channel holds an empty
    /// JSON array, paired with `Loop`-sourced metadata for `step`.
    fn make_checkpoint(id: &str, step: i64) -> (Checkpoint, CheckpointMetadata) {
        let mut cp = Checkpoint::new(id.to_owned());
        let _prev = cp.channel_values.insert("messages".into(), json!([]));
        let metadata = CheckpointMetadata {
            source: CheckpointSource::Loop,
            step,
            writes: HashMap::new(),
            parents: HashMap::new(),
        };
        (cp, metadata)
    }

    /// Shorthand for building a `CheckpointConfig`.
    fn thread_config(thread_id: &str, checkpoint_id: Option<&str>) -> CheckpointConfig {
        CheckpointConfig {
            thread_id: thread_id.into(),
            checkpoint_id: checkpoint_id.map(str::to_owned),
        }
    }

    #[tokio::test]
    async fn put_and_get_round_trip() {
        let saver = InMemoryCheckpointSaver::new();
        let config = thread_config("thread-1", None);
        let (cp, meta) = make_checkpoint("cp-1", 0);
        let result_config = saver.put(&config, cp, meta).await.unwrap();
        assert_eq!(result_config.checkpoint_id.as_deref(), Some("cp-1"));
        assert_eq!(result_config.thread_id, "thread-1");

        // Latest checkpoint of the thread.
        let tuple = saver.get_tuple(&config).await.unwrap().unwrap();
        assert_eq!(tuple.checkpoint.id, "cp-1");
        assert_eq!(tuple.checkpoint.channel_values["messages"], json!([]));

        // Explicit checkpoint id.
        let specific = thread_config("thread-1", Some("cp-1"));
        let tuple = saver.get_tuple(&specific).await.unwrap().unwrap();
        assert_eq!(tuple.checkpoint.id, "cp-1");

        // Unknown checkpoint id in a known thread.
        let unknown_cp = thread_config("thread-1", Some("no-such-checkpoint"));
        assert!(saver.get_tuple(&unknown_cp).await.unwrap().is_none());

        // Unknown thread.
        let missing = thread_config("no-such-thread", None);
        assert!(saver.get_tuple(&missing).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn list_returns_in_order() {
        let saver = InMemoryCheckpointSaver::new();
        let config = thread_config("thread-1", None);

        for i in 0..5 {
            let (cp, meta) = make_checkpoint(&format!("cp-{i}"), i64::from(i));
            let _cfg = saver.put(&config, cp, meta).await.unwrap();
        }

        let all = saver.list(&config, None).await.unwrap();
        assert_eq!(all.len(), 5);
        assert_eq!(all[0].checkpoint.id, "cp-4");
        assert_eq!(all[4].checkpoint.id, "cp-0");

        let limited = saver.list(&config, Some(2)).await.unwrap();
        assert_eq!(limited.len(), 2);
        assert_eq!(limited[0].checkpoint.id, "cp-4");
        assert_eq!(limited[1].checkpoint.id, "cp-3");

        // Limits at and beyond the boundaries.
        assert!(saver.list(&config, Some(0)).await.unwrap().is_empty());
        assert_eq!(saver.list(&config, Some(10)).await.unwrap().len(), 5);

        // Each checkpoint's parent is the one stored just before it.
        let parent = all[0].parent_config.as_ref().unwrap();
        assert_eq!(parent.checkpoint_id.as_deref(), Some("cp-3"));
        assert_eq!(parent.thread_id, "thread-1");
        // The first checkpoint of a thread has no parent.
        assert!(all[4].parent_config.is_none());

        // Unknown thread lists nothing.
        let missing = thread_config("no-such-thread", None);
        assert!(saver.list(&missing, None).await.unwrap().is_empty());
    }

    #[test]
    fn format_version_default() {
        let cp = Checkpoint::new("test".into());
        assert_eq!(cp.format_version, "1.0");
    }
}