1use std::{
4 path::{Path, PathBuf},
5 time::{SystemTime, UNIX_EPOCH},
6};
7
8use bob_core::{
9 error::{CostError, StoreError},
10 ports::CostMeterPort,
11 types::{SessionId, TokenUsage, ToolResult},
12};
13
/// Usage counters accumulated for one session; cached in memory and persisted as JSON.
#[derive(Clone, Copy, Debug, Default)]
struct SessionCost {
    /// Running sum of `TokenUsage::total()` over every recorded LLM call.
    total_tokens: u64,
    /// Number of tool results recorded for the session.
    tool_calls: u64,
}
19
20impl SessionCost {
21 fn from_json_slice(raw: &[u8]) -> Result<Self, StoreError> {
22 let value = serde_json::from_slice::<serde_json::Value>(raw)
23 .map_err(|err| StoreError::Serialization(err.to_string()))?;
24 let object = value
25 .as_object()
26 .ok_or_else(|| StoreError::Serialization("expected JSON object".to_string()))?;
27 let total_tokens =
28 object.get("total_tokens").and_then(serde_json::Value::as_u64).unwrap_or(0);
29 let tool_calls = object.get("tool_calls").and_then(serde_json::Value::as_u64).unwrap_or(0);
30 Ok(Self { total_tokens, tool_calls })
31 }
32
33 fn to_json_vec(self) -> Result<Vec<u8>, StoreError> {
34 serde_json::to_vec_pretty(&serde_json::json!({
35 "total_tokens": self.total_tokens,
36 "tool_calls": self.tool_calls,
37 }))
38 .map_err(|err| StoreError::Serialization(err.to_string()))
39 }
40}
41
42#[derive(Debug)]
44pub struct FileCostMeter {
45 root: PathBuf,
46 session_token_budget: Option<u64>,
47 cache: scc::HashMap<SessionId, SessionCost>,
48 write_guard: tokio::sync::Mutex<()>,
49}
50
51impl FileCostMeter {
52 pub fn new(root: PathBuf, session_token_budget: Option<u64>) -> Result<Self, CostError> {
57 std::fs::create_dir_all(&root)
58 .map_err(|err| CostError::Backend(format!("failed to create cost dir: {err}")))?;
59 Ok(Self {
60 root,
61 session_token_budget,
62 cache: scc::HashMap::new(),
63 write_guard: tokio::sync::Mutex::new(()),
64 })
65 }
66
67 fn cost_path(&self, session_id: &SessionId) -> PathBuf {
68 self.root.join(format!("{}.json", encode_session_id(session_id)))
69 }
70
71 fn temp_path_for(final_path: &Path) -> PathBuf {
72 let nanos = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_nanos();
73 final_path.with_extension(format!("json.tmp.{}.{}", std::process::id(), nanos))
74 }
75
76 fn quarantine_path_for(path: &Path) -> PathBuf {
77 let nanos = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_nanos();
78 let filename = path.file_name().and_then(std::ffi::OsStr::to_str).unwrap_or("cost");
79 path.with_file_name(format!("{filename}.corrupt.{}.{}", std::process::id(), nanos))
80 }
81
82 async fn quarantine_corrupt_file(path: &Path) -> Result<PathBuf, CostError> {
83 let quarantine_path = Self::quarantine_path_for(path);
84 tokio::fs::rename(path, &quarantine_path).await.map_err(|err| {
85 CostError::Backend(format!(
86 "failed to quarantine corrupted cost snapshot '{}': {err}",
87 path.display()
88 ))
89 })?;
90 Ok(quarantine_path)
91 }
92
93 async fn load_from_disk(
94 &self,
95 session_id: &SessionId,
96 ) -> Result<Option<SessionCost>, CostError> {
97 let path = self.cost_path(session_id);
98 let raw = match tokio::fs::read(&path).await {
99 Ok(raw) => raw,
100 Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
101 Err(err) => {
102 return Err(CostError::Backend(format!(
103 "failed to read cost snapshot '{}': {err}",
104 path.display()
105 )));
106 }
107 };
108
109 if let Ok(cost) = SessionCost::from_json_slice(&raw) {
110 return Ok(Some(cost));
111 }
112
113 let _ = Self::quarantine_corrupt_file(&path).await?;
114 Ok(None)
115 }
116
117 async fn save_to_disk(
118 &self,
119 session_id: &SessionId,
120 cost: SessionCost,
121 ) -> Result<(), CostError> {
122 let final_path = self.cost_path(session_id);
123 let temp_path = Self::temp_path_for(&final_path);
124 let bytes = cost.to_json_vec().map_err(|err| {
125 CostError::Backend(format!("failed to serialize cost snapshot: {err}"))
126 })?;
127
128 tokio::fs::write(&temp_path, bytes).await.map_err(|err| {
129 CostError::Backend(format!(
130 "failed to write temp cost snapshot '{}': {err}",
131 temp_path.display()
132 ))
133 })?;
134
135 if let Err(rename_err) = tokio::fs::rename(&temp_path, &final_path).await {
136 if path_exists(&final_path).await {
137 tokio::fs::remove_file(&final_path).await.map_err(|remove_err| {
138 CostError::Backend(format!(
139 "failed to replace existing cost snapshot '{}' after rename error '{rename_err}': {remove_err}",
140 final_path.display()
141 ))
142 })?;
143 tokio::fs::rename(&temp_path, &final_path).await.map_err(|err| {
144 CostError::Backend(format!(
145 "failed to replace cost snapshot '{}' after fallback remove: {err}",
146 final_path.display()
147 ))
148 })?;
149 } else {
150 return Err(CostError::Backend(format!(
151 "failed to persist cost snapshot '{}': {rename_err}",
152 final_path.display()
153 )));
154 }
155 }
156 Ok(())
157 }
158
159 fn ensure_session_budget(
160 &self,
161 session_id: &SessionId,
162 total_tokens: u64,
163 ) -> Result<(), CostError> {
164 let Some(limit) = self.session_token_budget else {
165 return Ok(());
166 };
167 if total_tokens > limit {
168 return Err(CostError::BudgetExceeded(format!(
169 "session '{session_id}' exceeded token budget ({total_tokens}>{limit})"
170 )));
171 }
172 Ok(())
173 }
174
175 async fn read_session_cost(&self, session_id: &SessionId) -> Result<SessionCost, CostError> {
176 if let Some(cost) = self.cache.read_async(session_id, |_k, value| *value).await {
177 return Ok(cost);
178 }
179
180 let loaded = self.load_from_disk(session_id).await?.unwrap_or_default();
181 let entry = self.cache.entry_async(session_id.clone()).await;
182 match entry {
183 scc::hash_map::Entry::Occupied(mut occ) => {
184 *occ.get_mut() = loaded;
185 }
186 scc::hash_map::Entry::Vacant(vac) => {
187 let _ = vac.insert_entry(loaded);
188 }
189 }
190 Ok(loaded)
191 }
192
193 async fn write_session_cost(
194 &self,
195 session_id: &SessionId,
196 session_cost: SessionCost,
197 ) -> Result<(), CostError> {
198 self.save_to_disk(session_id, session_cost).await?;
199 let entry = self.cache.entry_async(session_id.clone()).await;
200 match entry {
201 scc::hash_map::Entry::Occupied(mut occ) => {
202 *occ.get_mut() = session_cost;
203 }
204 scc::hash_map::Entry::Vacant(vac) => {
205 let _ = vac.insert_entry(session_cost);
206 }
207 }
208 Ok(())
209 }
210}
211
212#[async_trait::async_trait]
213impl CostMeterPort for FileCostMeter {
214 async fn check_budget(&self, session_id: &SessionId) -> Result<(), CostError> {
215 let Some(limit) = self.session_token_budget else {
216 return Ok(());
217 };
218 let session_cost = self.read_session_cost(session_id).await?;
219 if session_cost.total_tokens >= limit {
220 return Err(CostError::BudgetExceeded(format!(
221 "session '{session_id}' reached token budget ({}>={limit})",
222 session_cost.total_tokens
223 )));
224 }
225 Ok(())
226 }
227
228 async fn record_llm_usage(
229 &self,
230 session_id: &SessionId,
231 _model: &str,
232 usage: &TokenUsage,
233 ) -> Result<(), CostError> {
234 let _lock = self.write_guard.lock().await;
235 let mut session_cost = self.read_session_cost(session_id).await?;
236 session_cost.total_tokens =
237 session_cost.total_tokens.saturating_add(u64::from(usage.total()));
238 self.write_session_cost(session_id, session_cost).await?;
239 self.ensure_session_budget(session_id, session_cost.total_tokens)
240 }
241
242 async fn record_tool_result(
243 &self,
244 session_id: &SessionId,
245 _tool_result: &ToolResult,
246 ) -> Result<(), CostError> {
247 let _lock = self.write_guard.lock().await;
248 let mut session_cost = self.read_session_cost(session_id).await?;
249 session_cost.tool_calls = session_cost.tool_calls.saturating_add(1);
250 self.write_session_cost(session_id, session_cost).await
251 }
252}
253
/// Turns a session id into a filesystem-safe file stem.
///
/// Every byte of the id's UTF-8 form becomes two lowercase hex digits, so the result
/// contains only `[0-9a-f]` and can never contain a path separator. The empty id maps to
/// the literal stem `session`. That stem cannot collide with any hex encoding, because
/// `s` is not a hex digit.
fn encode_session_id(session_id: &str) -> String {
    use std::fmt::Write as _;

    if session_id.is_empty() {
        "session".to_string()
    } else {
        session_id.bytes().fold(
            String::with_capacity(session_id.len().saturating_mul(2)),
            |mut hex, byte| {
                // Writing into a `String` cannot fail.
                let _ = write!(hex, "{byte:02x}");
                hex
            },
        )
    }
}
266
267async fn path_exists(path: &Path) -> bool {
268 tokio::fs::metadata(path).await.is_ok()
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a scratch directory, asserting that creation succeeded.
    fn scratch_dir() -> Option<tempfile::TempDir> {
        let dir = tempfile::tempdir();
        assert!(dir.is_ok());
        dir.ok()
    }

    /// Opens a meter rooted at `root`, asserting that construction succeeded.
    fn open_meter(root: &Path, budget: Option<u64>) -> Option<FileCostMeter> {
        let meter = FileCostMeter::new(root.to_path_buf(), budget);
        assert!(meter.is_ok());
        meter.ok()
    }

    #[tokio::test]
    async fn no_budget_never_blocks() {
        let Some(dir) = scratch_dir() else { return };
        let Some(meter) = open_meter(dir.path(), None) else { return };
        let session = "s1".to_string();

        assert!(meter.check_budget(&session).await.is_ok());
        let recorded = meter
            .record_llm_usage(
                &session,
                "test-model",
                &TokenUsage { prompt_tokens: 10, completion_tokens: 5 },
            )
            .await;
        assert!(recorded.is_ok());
        assert!(meter.check_budget(&session).await.is_ok());
    }

    #[tokio::test]
    async fn usage_persists_across_recreation() {
        let Some(dir) = scratch_dir() else { return };
        let session = "s1".to_string();

        let Some(first) = open_meter(dir.path(), Some(50)) else { return };
        let usage = first
            .record_llm_usage(
                &session,
                "test-model",
                &TokenUsage { prompt_tokens: 30, completion_tokens: 0 },
            )
            .await;
        assert!(usage.is_ok());

        let Some(second) = open_meter(dir.path(), Some(50)) else { return };
        let budget = second.check_budget(&session).await;
        assert!(budget.is_ok(), "persisted usage below budget should pass");
        let overflow = second
            .record_llm_usage(
                &session,
                "test-model",
                &TokenUsage { prompt_tokens: 25, completion_tokens: 0 },
            )
            .await;
        assert!(overflow.is_err(), "persisted and new usage should trigger budget");

        let Some(third) = open_meter(dir.path(), Some(50)) else { return };
        let budget = third.check_budget(&session).await;
        assert!(budget.is_err(), "budget state should survive process restart");
    }

    #[tokio::test]
    async fn corrupted_snapshot_is_quarantined_and_treated_as_empty() {
        let Some(dir) = scratch_dir() else { return };
        let session = "broken-cost".to_string();
        let path = dir.path().join(format!("{}.json", encode_session_id(&session)));
        let write = tokio::fs::write(&path, b"{not-json").await;
        assert!(write.is_ok());

        let Some(meter) = open_meter(dir.path(), Some(10)) else { return };
        let budget = meter.check_budget(&session).await;
        assert!(budget.is_ok(), "corrupt snapshot should not block runtime start");
        assert!(!path.exists(), "corrupt file should be quarantined");

        let listing = std::fs::read_dir(dir.path());
        assert!(listing.is_ok());
        let has_quarantine = listing
            .map(|entries| {
                entries
                    .flatten()
                    .any(|entry| entry.file_name().to_string_lossy().contains(".corrupt."))
            })
            .unwrap_or(false);
        assert!(has_quarantine);
    }
}