bob_adapters/
checkpoint_file.rs1use std::{
4 path::{Path, PathBuf},
5 time::{SystemTime, UNIX_EPOCH},
6};
7
8use bob_core::{
9 error::StoreError,
10 ports::TurnCheckpointStorePort,
11 types::{SessionId, TurnCheckpoint},
12};
13
14#[derive(Debug)]
16pub struct FileCheckpointStore {
17 root: PathBuf,
18 cache: scc::HashMap<SessionId, TurnCheckpoint>,
19 write_guard: tokio::sync::Mutex<()>,
20}
21
22impl FileCheckpointStore {
23 pub fn new(root: PathBuf) -> Result<Self, StoreError> {
28 std::fs::create_dir_all(&root).map_err(|err| {
29 StoreError::Backend(format!("failed to create checkpoint dir: {err}"))
30 })?;
31 Ok(Self { root, cache: scc::HashMap::new(), write_guard: tokio::sync::Mutex::new(()) })
32 }
33
34 fn checkpoint_path(&self, session_id: &SessionId) -> PathBuf {
35 self.root.join(format!("{}.json", encode_session_id(session_id)))
36 }
37
38 fn temp_path_for(final_path: &Path) -> PathBuf {
39 let nanos = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_nanos();
40 final_path.with_extension(format!("json.tmp.{}.{}", std::process::id(), nanos))
41 }
42
43 fn quarantine_path_for(path: &Path) -> PathBuf {
44 let nanos = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_nanos();
45 let filename = path.file_name().and_then(std::ffi::OsStr::to_str).unwrap_or("checkpoint");
46 path.with_file_name(format!("{filename}.corrupt.{}.{}", std::process::id(), nanos))
47 }
48
49 async fn quarantine_corrupt_file(path: &Path) -> Result<PathBuf, StoreError> {
50 let quarantine_path = Self::quarantine_path_for(path);
51 tokio::fs::rename(path, &quarantine_path).await.map_err(|err| {
52 StoreError::Backend(format!(
53 "failed to quarantine corrupted checkpoint '{}': {err}",
54 path.display()
55 ))
56 })?;
57 Ok(quarantine_path)
58 }
59
60 async fn load_from_disk(
61 &self,
62 session_id: &SessionId,
63 ) -> Result<Option<TurnCheckpoint>, StoreError> {
64 let path = self.checkpoint_path(session_id);
65 let raw = match tokio::fs::read(&path).await {
66 Ok(raw) => raw,
67 Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
68 Err(err) => {
69 return Err(StoreError::Backend(format!(
70 "failed to read checkpoint '{}': {err}",
71 path.display()
72 )));
73 }
74 };
75
76 let checkpoint = if let Ok(value) = serde_json::from_slice::<TurnCheckpoint>(&raw) {
77 value
78 } else {
79 let _ = Self::quarantine_corrupt_file(&path).await?;
80 return Ok(None);
81 };
82 Ok(Some(checkpoint))
83 }
84
85 async fn save_to_disk(
86 &self,
87 session_id: &SessionId,
88 checkpoint: &TurnCheckpoint,
89 ) -> Result<(), StoreError> {
90 let final_path = self.checkpoint_path(session_id);
91 let temp_path = Self::temp_path_for(&final_path);
92 let bytes = serde_json::to_vec_pretty(checkpoint)
93 .map_err(|err| StoreError::Serialization(err.to_string()))?;
94
95 tokio::fs::write(&temp_path, bytes).await.map_err(|err| {
96 StoreError::Backend(format!(
97 "failed to write temp checkpoint '{}': {err}",
98 temp_path.display()
99 ))
100 })?;
101
102 if let Err(rename_err) = tokio::fs::rename(&temp_path, &final_path).await {
103 if path_exists(&final_path).await {
104 tokio::fs::remove_file(&final_path).await.map_err(|remove_err| {
105 StoreError::Backend(format!(
106 "failed to replace existing checkpoint '{}' after rename error '{rename_err}': {remove_err}",
107 final_path.display()
108 ))
109 })?;
110 tokio::fs::rename(&temp_path, &final_path).await.map_err(|err| {
111 StoreError::Backend(format!(
112 "failed to replace checkpoint '{}' after fallback remove: {err}",
113 final_path.display()
114 ))
115 })?;
116 } else {
117 return Err(StoreError::Backend(format!(
118 "failed to persist checkpoint '{}': {rename_err}",
119 final_path.display()
120 )));
121 }
122 }
123 Ok(())
124 }
125}
126
127#[async_trait::async_trait]
128impl TurnCheckpointStorePort for FileCheckpointStore {
129 async fn save_checkpoint(&self, checkpoint: &TurnCheckpoint) -> Result<(), StoreError> {
130 let _lock = self.write_guard.lock().await;
131 self.save_to_disk(&checkpoint.session_id, checkpoint).await?;
132 let entry = self.cache.entry_async(checkpoint.session_id.clone()).await;
133 match entry {
134 scc::hash_map::Entry::Occupied(mut occ) => occ.get_mut().clone_from(checkpoint),
135 scc::hash_map::Entry::Vacant(vac) => {
136 let _ = vac.insert_entry(checkpoint.clone());
137 }
138 }
139 Ok(())
140 }
141
142 async fn load_latest(
143 &self,
144 session_id: &SessionId,
145 ) -> Result<Option<TurnCheckpoint>, StoreError> {
146 if let Some(value) = self.cache.read_async(session_id, |_k, v| v.clone()).await {
147 return Ok(Some(value));
148 }
149
150 let loaded = self.load_from_disk(session_id).await?;
151 if let Some(ref checkpoint) = loaded {
152 let entry = self.cache.entry_async(session_id.clone()).await;
153 match entry {
154 scc::hash_map::Entry::Occupied(mut occ) => occ.get_mut().clone_from(checkpoint),
155 scc::hash_map::Entry::Vacant(vac) => {
156 let _ = vac.insert_entry(checkpoint.clone());
157 }
158 }
159 }
160 Ok(loaded)
161 }
162}
163
/// Turns a session id into a filesystem-safe file stem.
///
/// Every byte of the id is rendered as two lowercase hex digits. An empty id
/// maps to the literal `"session"`.
fn encode_session_id(session_id: &str) -> String {
    use std::fmt::Write as _;

    if session_id.is_empty() {
        return String::from("session");
    }

    // Two hex digits per input byte.
    let hex_len = session_id.len().saturating_mul(2);
    session_id.bytes().fold(String::with_capacity(hex_len), |mut hex, byte| {
        // Writing into a `String` cannot fail, so the result is ignored.
        let _ = write!(hex, "{byte:02x}");
        hex
    })
}
176
177async fn path_exists(path: &Path) -> bool {
178 tokio::fs::metadata(path).await.is_ok()
179}
180
#[cfg(test)]
mod tests {
    use bob_core::types::TokenUsage;

    use super::*;

    /// Loading a session that was never saved yields `Ok(None)`.
    #[tokio::test]
    async fn missing_checkpoint_returns_none() {
        let dir = tempfile::tempdir();
        assert!(dir.is_ok());
        let Ok(dir) = dir else { return };

        let store = FileCheckpointStore::new(dir.path().to_path_buf());
        assert!(store.is_ok());
        let Ok(store) = store else { return };

        let loaded = store.load_latest(&"missing".to_string()).await;
        assert!(matches!(loaded, Ok(None)));
    }

    /// A checkpoint saved by one store instance is readable from a second
    /// instance created over the same directory.
    #[tokio::test]
    async fn roundtrip_persists_across_store_recreation() {
        let dir = tempfile::tempdir();
        assert!(dir.is_ok());
        let Ok(dir) = dir else { return };

        let checkpoint = TurnCheckpoint {
            session_id: "s/1".to_string(),
            step: 3,
            tool_calls: 2,
            usage: TokenUsage { prompt_tokens: 5, completion_tokens: 6 },
        };

        let writer = FileCheckpointStore::new(dir.path().to_path_buf());
        assert!(writer.is_ok());
        let Ok(writer) = writer else { return };
        assert!(writer.save_checkpoint(&checkpoint).await.is_ok());

        let reader = FileCheckpointStore::new(dir.path().to_path_buf());
        assert!(reader.is_ok());
        let Ok(reader) = reader else { return };

        let loaded = reader.load_latest(&"s/1".to_string()).await;
        assert!(loaded.is_ok());
        let loaded = loaded.ok().flatten();
        assert!(loaded.is_some());
        assert_eq!(loaded.as_ref().map(|cp| cp.step), Some(3));
    }

    /// An unparseable checkpoint file is renamed to a `.corrupt.` sibling and
    /// the load reports no checkpoint.
    #[tokio::test]
    async fn corrupted_checkpoint_is_quarantined_and_treated_as_missing() {
        let dir = tempfile::tempdir();
        assert!(dir.is_ok());
        let Ok(dir) = dir else { return };

        let session_id = "broken-checkpoint".to_string();
        let file_name = format!("{}.json", encode_session_id(&session_id));
        let checkpoint_path = dir.path().join(file_name);
        assert!(tokio::fs::write(&checkpoint_path, b"{not-json").await.is_ok());

        let store = FileCheckpointStore::new(dir.path().to_path_buf());
        assert!(store.is_ok());
        let Ok(store) = store else { return };

        let loaded = store.load_latest(&session_id).await;
        assert!(matches!(loaded, Ok(None)));
        assert!(!checkpoint_path.exists());

        let read_dir = std::fs::read_dir(dir.path());
        assert!(read_dir.is_ok());
        // Unreadable directory entries are skipped, matching `flatten()`.
        let has_quarantine = read_dir
            .map(|entries| {
                entries
                    .flatten()
                    .any(|entry| entry.file_name().to_string_lossy().contains(".corrupt."))
            })
            .unwrap_or(false);
        assert!(has_quarantine);
    }
}