Skip to main content

bob_adapters/
cost_file.rs

1//! File-backed cost meter adapter.
2
3use std::{
4    path::{Path, PathBuf},
5    time::{SystemTime, UNIX_EPOCH},
6};
7
8use bob_core::{
9    error::{CostError, StoreError},
10    ports::CostMeterPort,
11    types::{SessionId, TokenUsage, ToolResult},
12};
13
/// Running usage totals tracked for a single session.
#[derive(Clone, Copy, Debug, Default)]
struct SessionCost {
    /// Cumulative number of tokens recorded for the session.
    total_tokens: u64,
    /// Number of tool results recorded for the session.
    tool_calls: u64,
}
19
20impl SessionCost {
21    fn from_json_slice(raw: &[u8]) -> Result<Self, StoreError> {
22        let value = serde_json::from_slice::<serde_json::Value>(raw)
23            .map_err(|err| StoreError::Serialization(err.to_string()))?;
24        let object = value
25            .as_object()
26            .ok_or_else(|| StoreError::Serialization("expected JSON object".to_string()))?;
27        let total_tokens =
28            object.get("total_tokens").and_then(serde_json::Value::as_u64).unwrap_or(0);
29        let tool_calls = object.get("tool_calls").and_then(serde_json::Value::as_u64).unwrap_or(0);
30        Ok(Self { total_tokens, tool_calls })
31    }
32
33    fn to_json_vec(self) -> Result<Vec<u8>, StoreError> {
34        serde_json::to_vec_pretty(&serde_json::json!({
35            "total_tokens": self.total_tokens,
36            "tool_calls": self.tool_calls,
37        }))
38        .map_err(|err| StoreError::Serialization(err.to_string()))
39    }
40}
41
42/// Durable cost meter with optional per-session token budget.
43#[derive(Debug)]
44pub struct FileCostMeter {
45    root: PathBuf,
46    session_token_budget: Option<u64>,
47    cache: scc::HashMap<SessionId, SessionCost>,
48    write_guard: tokio::sync::Mutex<()>,
49}
50
51impl FileCostMeter {
52    /// Create a file-backed cost meter rooted at `root`.
53    ///
54    /// # Errors
55    /// Returns a backend error when the root directory cannot be created.
56    pub fn new(root: PathBuf, session_token_budget: Option<u64>) -> Result<Self, CostError> {
57        std::fs::create_dir_all(&root)
58            .map_err(|err| CostError::Backend(format!("failed to create cost dir: {err}")))?;
59        Ok(Self {
60            root,
61            session_token_budget,
62            cache: scc::HashMap::new(),
63            write_guard: tokio::sync::Mutex::new(()),
64        })
65    }
66
67    fn cost_path(&self, session_id: &SessionId) -> PathBuf {
68        self.root.join(format!("{}.json", encode_session_id(session_id)))
69    }
70
71    fn temp_path_for(final_path: &Path) -> PathBuf {
72        let nanos = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_nanos();
73        final_path.with_extension(format!("json.tmp.{}.{}", std::process::id(), nanos))
74    }
75
76    fn quarantine_path_for(path: &Path) -> PathBuf {
77        let nanos = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_nanos();
78        let filename = path.file_name().and_then(std::ffi::OsStr::to_str).unwrap_or("cost");
79        path.with_file_name(format!("{filename}.corrupt.{}.{}", std::process::id(), nanos))
80    }
81
82    async fn quarantine_corrupt_file(path: &Path) -> Result<PathBuf, CostError> {
83        let quarantine_path = Self::quarantine_path_for(path);
84        tokio::fs::rename(path, &quarantine_path).await.map_err(|err| {
85            CostError::Backend(format!(
86                "failed to quarantine corrupted cost snapshot '{}': {err}",
87                path.display()
88            ))
89        })?;
90        Ok(quarantine_path)
91    }
92
93    async fn load_from_disk(
94        &self,
95        session_id: &SessionId,
96    ) -> Result<Option<SessionCost>, CostError> {
97        let path = self.cost_path(session_id);
98        let raw = match tokio::fs::read(&path).await {
99            Ok(raw) => raw,
100            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
101            Err(err) => {
102                return Err(CostError::Backend(format!(
103                    "failed to read cost snapshot '{}': {err}",
104                    path.display()
105                )));
106            }
107        };
108
109        if let Ok(cost) = SessionCost::from_json_slice(&raw) {
110            return Ok(Some(cost));
111        }
112
113        let _ = Self::quarantine_corrupt_file(&path).await?;
114        Ok(None)
115    }
116
117    async fn save_to_disk(
118        &self,
119        session_id: &SessionId,
120        cost: SessionCost,
121    ) -> Result<(), CostError> {
122        let final_path = self.cost_path(session_id);
123        let temp_path = Self::temp_path_for(&final_path);
124        let bytes = cost.to_json_vec().map_err(|err| {
125            CostError::Backend(format!("failed to serialize cost snapshot: {err}"))
126        })?;
127
128        tokio::fs::write(&temp_path, bytes).await.map_err(|err| {
129            CostError::Backend(format!(
130                "failed to write temp cost snapshot '{}': {err}",
131                temp_path.display()
132            ))
133        })?;
134
135        if let Err(rename_err) = tokio::fs::rename(&temp_path, &final_path).await {
136            if path_exists(&final_path).await {
137                tokio::fs::remove_file(&final_path).await.map_err(|remove_err| {
138                    CostError::Backend(format!(
139                        "failed to replace existing cost snapshot '{}' after rename error '{rename_err}': {remove_err}",
140                        final_path.display()
141                    ))
142                })?;
143                tokio::fs::rename(&temp_path, &final_path).await.map_err(|err| {
144                    CostError::Backend(format!(
145                        "failed to replace cost snapshot '{}' after fallback remove: {err}",
146                        final_path.display()
147                    ))
148                })?;
149            } else {
150                return Err(CostError::Backend(format!(
151                    "failed to persist cost snapshot '{}': {rename_err}",
152                    final_path.display()
153                )));
154            }
155        }
156        Ok(())
157    }
158
159    fn ensure_session_budget(
160        &self,
161        session_id: &SessionId,
162        total_tokens: u64,
163    ) -> Result<(), CostError> {
164        let Some(limit) = self.session_token_budget else {
165            return Ok(());
166        };
167        if total_tokens > limit {
168            return Err(CostError::BudgetExceeded(format!(
169                "session '{session_id}' exceeded token budget ({total_tokens}>{limit})"
170            )));
171        }
172        Ok(())
173    }
174
175    async fn read_session_cost(&self, session_id: &SessionId) -> Result<SessionCost, CostError> {
176        if let Some(cost) = self.cache.read_async(session_id, |_k, value| *value).await {
177            return Ok(cost);
178        }
179
180        let loaded = self.load_from_disk(session_id).await?.unwrap_or_default();
181        let entry = self.cache.entry_async(session_id.clone()).await;
182        match entry {
183            scc::hash_map::Entry::Occupied(mut occ) => {
184                *occ.get_mut() = loaded;
185            }
186            scc::hash_map::Entry::Vacant(vac) => {
187                let _ = vac.insert_entry(loaded);
188            }
189        }
190        Ok(loaded)
191    }
192
193    async fn write_session_cost(
194        &self,
195        session_id: &SessionId,
196        session_cost: SessionCost,
197    ) -> Result<(), CostError> {
198        self.save_to_disk(session_id, session_cost).await?;
199        let entry = self.cache.entry_async(session_id.clone()).await;
200        match entry {
201            scc::hash_map::Entry::Occupied(mut occ) => {
202                *occ.get_mut() = session_cost;
203            }
204            scc::hash_map::Entry::Vacant(vac) => {
205                let _ = vac.insert_entry(session_cost);
206            }
207        }
208        Ok(())
209    }
210}
211
212#[async_trait::async_trait]
213impl CostMeterPort for FileCostMeter {
214    async fn check_budget(&self, session_id: &SessionId) -> Result<(), CostError> {
215        let Some(limit) = self.session_token_budget else {
216            return Ok(());
217        };
218        let session_cost = self.read_session_cost(session_id).await?;
219        if session_cost.total_tokens >= limit {
220            return Err(CostError::BudgetExceeded(format!(
221                "session '{session_id}' reached token budget ({}>={limit})",
222                session_cost.total_tokens
223            )));
224        }
225        Ok(())
226    }
227
228    async fn record_llm_usage(
229        &self,
230        session_id: &SessionId,
231        _model: &str,
232        usage: &TokenUsage,
233    ) -> Result<(), CostError> {
234        let _lock = self.write_guard.lock().await;
235        let mut session_cost = self.read_session_cost(session_id).await?;
236        session_cost.total_tokens =
237            session_cost.total_tokens.saturating_add(u64::from(usage.total()));
238        self.write_session_cost(session_id, session_cost).await?;
239        self.ensure_session_budget(session_id, session_cost.total_tokens)
240    }
241
242    async fn record_tool_result(
243        &self,
244        session_id: &SessionId,
245        _tool_result: &ToolResult,
246    ) -> Result<(), CostError> {
247        let _lock = self.write_guard.lock().await;
248        let mut session_cost = self.read_session_cost(session_id).await?;
249        session_cost.tool_calls = session_cost.tool_calls.saturating_add(1);
250        self.write_session_cost(session_id, session_cost).await
251    }
252}
253
/// Turn a session id into a file stem that is safe on any filesystem.
///
/// Each byte of the id is written as two lowercase hex digits. The empty id
/// maps to the literal `session`, which can never collide with a hex-encoded
/// id because it contains non-hex letters.
fn encode_session_id(session_id: &str) -> String {
    use std::fmt::Write as _;

    if session_id.is_empty() {
        return "session".to_string();
    }

    let buffer = String::with_capacity(session_id.len().saturating_mul(2));
    session_id.bytes().fold(buffer, |mut encoded, byte| {
        // Writing into a `String` cannot fail, so the result is ignored.
        let _ = write!(encoded, "{byte:02x}");
        encoded
    })
}
266
267async fn path_exists(path: &Path) -> bool {
268    tokio::fs::metadata(path).await.is_ok()
269}
270
271#[cfg(test)]
272mod tests {
273    use super::*;
274
275    #[tokio::test]
276    async fn no_budget_never_blocks() {
277        let dir = tempfile::tempdir();
278        assert!(dir.is_ok());
279        let dir = match dir {
280            Ok(value) => value,
281            Err(_) => return,
282        };
283
284        let meter = FileCostMeter::new(dir.path().to_path_buf(), None);
285        assert!(meter.is_ok());
286        let meter = match meter {
287            Ok(value) => value,
288            Err(_) => return,
289        };
290        let session = "s1".to_string();
291        assert!(meter.check_budget(&session).await.is_ok());
292        assert!(
293            meter
294                .record_llm_usage(
295                    &session,
296                    "test-model",
297                    &TokenUsage { prompt_tokens: 10, completion_tokens: 5 }
298                )
299                .await
300                .is_ok()
301        );
302        assert!(meter.check_budget(&session).await.is_ok());
303    }
304
305    #[tokio::test]
306    async fn usage_persists_across_recreation() {
307        let dir = tempfile::tempdir();
308        assert!(dir.is_ok());
309        let dir = match dir {
310            Ok(value) => value,
311            Err(_) => return,
312        };
313        let session = "s1".to_string();
314
315        let first = FileCostMeter::new(dir.path().to_path_buf(), Some(50));
316        assert!(first.is_ok());
317        let first = match first {
318            Ok(value) => value,
319            Err(_) => return,
320        };
321        let usage = first
322            .record_llm_usage(
323                &session,
324                "test-model",
325                &TokenUsage { prompt_tokens: 30, completion_tokens: 0 },
326            )
327            .await;
328        assert!(usage.is_ok());
329
330        let second = FileCostMeter::new(dir.path().to_path_buf(), Some(50));
331        assert!(second.is_ok());
332        let second = match second {
333            Ok(value) => value,
334            Err(_) => return,
335        };
336        let budget = second.check_budget(&session).await;
337        assert!(budget.is_ok(), "persisted usage below budget should pass");
338        let overflow = second
339            .record_llm_usage(
340                &session,
341                "test-model",
342                &TokenUsage { prompt_tokens: 25, completion_tokens: 0 },
343            )
344            .await;
345        assert!(overflow.is_err(), "persisted and new usage should trigger budget");
346
347        let third = FileCostMeter::new(dir.path().to_path_buf(), Some(50));
348        assert!(third.is_ok());
349        let third = match third {
350            Ok(value) => value,
351            Err(_) => return,
352        };
353        let budget = third.check_budget(&session).await;
354        assert!(budget.is_err(), "budget state should survive process restart");
355    }
356
357    #[tokio::test]
358    async fn corrupted_snapshot_is_quarantined_and_treated_as_empty() {
359        let dir = tempfile::tempdir();
360        assert!(dir.is_ok());
361        let dir = match dir {
362            Ok(value) => value,
363            Err(_) => return,
364        };
365        let session = "broken-cost".to_string();
366        let encoded = encode_session_id(&session);
367        let path = dir.path().join(format!("{encoded}.json"));
368        let write = tokio::fs::write(&path, b"{not-json").await;
369        assert!(write.is_ok());
370
371        let meter = FileCostMeter::new(dir.path().to_path_buf(), Some(10));
372        assert!(meter.is_ok());
373        let meter = match meter {
374            Ok(value) => value,
375            Err(_) => return,
376        };
377        let budget = meter.check_budget(&session).await;
378        assert!(budget.is_ok(), "corrupt snapshot should not block runtime start");
379        assert!(!path.exists(), "corrupt file should be quarantined");
380
381        let mut has_quarantine = false;
382        let read_dir = std::fs::read_dir(dir.path());
383        assert!(read_dir.is_ok());
384        if let Ok(entries) = read_dir {
385            for entry in entries.flatten() {
386                let name = entry.file_name().to_string_lossy().to_string();
387                if name.contains(".corrupt.") {
388                    has_quarantine = true;
389                    break;
390                }
391            }
392        }
393        assert!(has_quarantine);
394    }
395}