Skip to main content

bob_adapters/
checkpoint_file.rs

//! File-backed turn checkpoint store adapter.
2
use std::{
    path::{Path, PathBuf},
    time::{SystemTime, UNIX_EPOCH},
};

use bob_core::{
    error::StoreError,
    ports::TurnCheckpointStorePort,
    types::{SessionId, TurnCheckpoint},
};
13
14/// Durable checkpoint store backed by per-session JSON snapshots.
15#[derive(Debug)]
16pub struct FileCheckpointStore {
17    root: PathBuf,
18    cache: scc::HashMap<SessionId, TurnCheckpoint>,
19    write_guard: tokio::sync::Mutex<()>,
20}
21
22impl FileCheckpointStore {
23    /// Create a file-backed checkpoint store rooted at `root`.
24    ///
25    /// # Errors
26    /// Returns a backend error when the root directory cannot be created.
27    pub fn new(root: PathBuf) -> Result<Self, StoreError> {
28        std::fs::create_dir_all(&root).map_err(|err| {
29            StoreError::Backend(format!("failed to create checkpoint dir: {err}"))
30        })?;
31        Ok(Self { root, cache: scc::HashMap::new(), write_guard: tokio::sync::Mutex::new(()) })
32    }
33
34    fn checkpoint_path(&self, session_id: &SessionId) -> PathBuf {
35        self.root.join(format!("{}.json", encode_session_id(session_id)))
36    }
37
38    fn temp_path_for(final_path: &Path) -> PathBuf {
39        let nanos = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_nanos();
40        final_path.with_extension(format!("json.tmp.{}.{}", std::process::id(), nanos))
41    }
42
43    fn quarantine_path_for(path: &Path) -> PathBuf {
44        let nanos = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_nanos();
45        let filename = path.file_name().and_then(std::ffi::OsStr::to_str).unwrap_or("checkpoint");
46        path.with_file_name(format!("{filename}.corrupt.{}.{}", std::process::id(), nanos))
47    }
48
49    async fn quarantine_corrupt_file(path: &Path) -> Result<PathBuf, StoreError> {
50        let quarantine_path = Self::quarantine_path_for(path);
51        tokio::fs::rename(path, &quarantine_path).await.map_err(|err| {
52            StoreError::Backend(format!(
53                "failed to quarantine corrupted checkpoint '{}': {err}",
54                path.display()
55            ))
56        })?;
57        Ok(quarantine_path)
58    }
59
60    async fn load_from_disk(
61        &self,
62        session_id: &SessionId,
63    ) -> Result<Option<TurnCheckpoint>, StoreError> {
64        let path = self.checkpoint_path(session_id);
65        let raw = match tokio::fs::read(&path).await {
66            Ok(raw) => raw,
67            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
68            Err(err) => {
69                return Err(StoreError::Backend(format!(
70                    "failed to read checkpoint '{}': {err}",
71                    path.display()
72                )));
73            }
74        };
75
76        let checkpoint = if let Ok(value) = serde_json::from_slice::<TurnCheckpoint>(&raw) {
77            value
78        } else {
79            let _ = Self::quarantine_corrupt_file(&path).await?;
80            return Ok(None);
81        };
82        Ok(Some(checkpoint))
83    }
84
85    async fn save_to_disk(
86        &self,
87        session_id: &SessionId,
88        checkpoint: &TurnCheckpoint,
89    ) -> Result<(), StoreError> {
90        let final_path = self.checkpoint_path(session_id);
91        let temp_path = Self::temp_path_for(&final_path);
92        let bytes = serde_json::to_vec_pretty(checkpoint)
93            .map_err(|err| StoreError::Serialization(err.to_string()))?;
94
95        tokio::fs::write(&temp_path, bytes).await.map_err(|err| {
96            StoreError::Backend(format!(
97                "failed to write temp checkpoint '{}': {err}",
98                temp_path.display()
99            ))
100        })?;
101
102        if let Err(rename_err) = tokio::fs::rename(&temp_path, &final_path).await {
103            if path_exists(&final_path).await {
104                tokio::fs::remove_file(&final_path).await.map_err(|remove_err| {
105                    StoreError::Backend(format!(
106                        "failed to replace existing checkpoint '{}' after rename error '{rename_err}': {remove_err}",
107                        final_path.display()
108                    ))
109                })?;
110                tokio::fs::rename(&temp_path, &final_path).await.map_err(|err| {
111                    StoreError::Backend(format!(
112                        "failed to replace checkpoint '{}' after fallback remove: {err}",
113                        final_path.display()
114                    ))
115                })?;
116            } else {
117                return Err(StoreError::Backend(format!(
118                    "failed to persist checkpoint '{}': {rename_err}",
119                    final_path.display()
120                )));
121            }
122        }
123        Ok(())
124    }
125}
126
127#[async_trait::async_trait]
128impl TurnCheckpointStorePort for FileCheckpointStore {
129    async fn save_checkpoint(&self, checkpoint: &TurnCheckpoint) -> Result<(), StoreError> {
130        let _lock = self.write_guard.lock().await;
131        self.save_to_disk(&checkpoint.session_id, checkpoint).await?;
132        let entry = self.cache.entry_async(checkpoint.session_id.clone()).await;
133        match entry {
134            scc::hash_map::Entry::Occupied(mut occ) => occ.get_mut().clone_from(checkpoint),
135            scc::hash_map::Entry::Vacant(vac) => {
136                let _ = vac.insert_entry(checkpoint.clone());
137            }
138        }
139        Ok(())
140    }
141
142    async fn load_latest(
143        &self,
144        session_id: &SessionId,
145    ) -> Result<Option<TurnCheckpoint>, StoreError> {
146        if let Some(value) = self.cache.read_async(session_id, |_k, v| v.clone()).await {
147            return Ok(Some(value));
148        }
149
150        let loaded = self.load_from_disk(session_id).await?;
151        if let Some(ref checkpoint) = loaded {
152            let entry = self.cache.entry_async(session_id.clone()).await;
153            match entry {
154                scc::hash_map::Entry::Occupied(mut occ) => occ.get_mut().clone_from(checkpoint),
155                scc::hash_map::Entry::Vacant(vac) => {
156                    let _ = vac.insert_entry(checkpoint.clone());
157                }
158            }
159        }
160        Ok(loaded)
161    }
162}
163
/// Turn a session id into a file stem that is safe on any filesystem.
///
/// Every byte of the id is rendered as two lowercase hex digits. The empty id
/// maps to the literal stem `session`, which can never collide with a hex
/// encoding because it contains non-hex letters.
fn encode_session_id(session_id: &str) -> String {
    if session_id.is_empty() {
        return String::from("session");
    }

    let stem = String::with_capacity(session_id.len().saturating_mul(2));
    session_id.bytes().fold(stem, |mut hex, byte| {
        use std::fmt::Write as _;
        // Formatting into a `String` cannot fail, so the result is ignored.
        let _ = write!(hex, "{byte:02x}");
        hex
    })
}
176
177async fn path_exists(path: &Path) -> bool {
178    tokio::fs::metadata(path).await.is_ok()
179}
180
#[cfg(test)]
mod tests {
    use bob_core::types::TokenUsage;

    use super::*;

    /// Loading a session that was never saved yields `Ok(None)`.
    #[tokio::test]
    async fn missing_checkpoint_returns_none() {
        let dir = tempfile::tempdir();
        assert!(dir.is_ok());
        let Ok(dir) = dir else { return };

        let store = FileCheckpointStore::new(dir.path().to_path_buf());
        assert!(store.is_ok());
        let Ok(store) = store else { return };

        let loaded = store.load_latest(&"missing".to_string()).await;
        assert!(matches!(loaded, Ok(None)));
    }

    /// A checkpoint saved by one store instance is readable by a fresh
    /// instance rooted at the same directory.
    #[tokio::test]
    async fn roundtrip_persists_across_store_recreation() {
        let dir = tempfile::tempdir();
        assert!(dir.is_ok());
        let Ok(dir) = dir else { return };

        let session_id = "s/1".to_string();
        let checkpoint = TurnCheckpoint {
            session_id: session_id.clone(),
            step: 3,
            tool_calls: 2,
            usage: TokenUsage { prompt_tokens: 5, completion_tokens: 6 },
        };

        let writer = FileCheckpointStore::new(dir.path().to_path_buf());
        assert!(writer.is_ok());
        let Ok(writer) = writer else { return };
        assert!(writer.save_checkpoint(&checkpoint).await.is_ok());

        let reader = FileCheckpointStore::new(dir.path().to_path_buf());
        assert!(reader.is_ok());
        let Ok(reader) = reader else { return };

        let loaded = reader.load_latest(&session_id).await;
        assert!(loaded.is_ok());
        let loaded = loaded.ok().flatten();
        assert!(loaded.is_some());
        assert_eq!(loaded.map(|cp| cp.step), Some(3));
    }

    /// An undecodable snapshot is moved aside to a `.corrupt.` file and the
    /// load reports "no checkpoint" instead of an error.
    #[tokio::test]
    async fn corrupted_checkpoint_is_quarantined_and_treated_as_missing() {
        let dir = tempfile::tempdir();
        assert!(dir.is_ok());
        let Ok(dir) = dir else { return };

        let session_id = "broken-checkpoint".to_string();
        let file_name = format!("{}.json", encode_session_id(&session_id));
        let checkpoint_path = dir.path().join(file_name);
        assert!(tokio::fs::write(&checkpoint_path, b"{not-json").await.is_ok());

        let store = FileCheckpointStore::new(dir.path().to_path_buf());
        assert!(store.is_ok());
        let Ok(store) = store else { return };

        let loaded = store.load_latest(&session_id).await;
        assert!(matches!(loaded, Ok(None)));
        assert!(!checkpoint_path.exists());

        let entries = std::fs::read_dir(dir.path());
        assert!(entries.is_ok());
        let Ok(entries) = entries else { return };
        let has_quarantine = entries
            .flatten()
            .any(|entry| entry.file_name().to_string_lossy().contains(".corrupt."));
        assert!(has_quarantine);
    }
}