bob_adapters/
cost_simple.rs1use bob_core::{
4 error::CostError,
5 ports::CostMeterPort,
6 types::{SessionId, TokenUsage, ToolResult},
7};
8
/// Running usage counters kept for a single session.
#[derive(Clone, Copy, Debug, Default)]
struct SessionCost {
    /// Accumulated token count from every LLM usage recorded for the session.
    total_tokens: u64,
    /// Number of tool results recorded for the session.
    tool_calls: u64,
}
14
15#[derive(Debug)]
17pub struct SimpleCostMeter {
18 session_token_budget: Option<u64>,
19 sessions: scc::HashMap<SessionId, SessionCost>,
20}
21
22impl SimpleCostMeter {
23 #[must_use]
24 pub fn new(session_token_budget: Option<u64>) -> Self {
25 Self { session_token_budget, sessions: scc::HashMap::new() }
26 }
27
28 fn ensure_session_budget(
29 &self,
30 session_id: &SessionId,
31 total_tokens: u64,
32 ) -> Result<(), CostError> {
33 let Some(limit) = self.session_token_budget else {
34 return Ok(());
35 };
36 if total_tokens > limit {
37 return Err(CostError::BudgetExceeded(format!(
38 "session '{session_id}' exceeded token budget ({total_tokens}>{limit})"
39 )));
40 }
41 Ok(())
42 }
43}
44
45#[async_trait::async_trait]
46impl CostMeterPort for SimpleCostMeter {
47 async fn check_budget(&self, session_id: &SessionId) -> Result<(), CostError> {
48 let Some(limit) = self.session_token_budget else {
49 return Ok(());
50 };
51 let total = self.sessions.read_async(session_id, |_k, v| v.total_tokens).await.unwrap_or(0);
52 if total >= limit {
53 return Err(CostError::BudgetExceeded(format!(
54 "session '{session_id}' reached token budget ({total}>={limit})"
55 )));
56 }
57 Ok(())
58 }
59
60 async fn record_llm_usage(
61 &self,
62 session_id: &SessionId,
63 _model: &str,
64 usage: &TokenUsage,
65 ) -> Result<(), CostError> {
66 let usage_tokens = u64::from(usage.total());
67 let entry = self.sessions.entry_async(session_id.clone()).await;
68 let total_after = match entry {
69 scc::hash_map::Entry::Occupied(mut occ) => {
70 occ.get_mut().total_tokens += usage_tokens;
71 occ.get().total_tokens
72 }
73 scc::hash_map::Entry::Vacant(vac) => {
74 let inserted =
75 vac.insert_entry(SessionCost { total_tokens: usage_tokens, tool_calls: 0 });
76 inserted.get().total_tokens
77 }
78 };
79 self.ensure_session_budget(session_id, total_after)
80 }
81
82 async fn record_tool_result(
83 &self,
84 session_id: &SessionId,
85 _tool_result: &ToolResult,
86 ) -> Result<(), CostError> {
87 let entry = self.sessions.entry_async(session_id.clone()).await;
88 match entry {
89 scc::hash_map::Entry::Occupied(mut occ) => {
90 occ.get_mut().tool_calls += 1;
91 }
92 scc::hash_map::Entry::Vacant(vac) => {
93 let _ = vac.insert_entry(SessionCost { total_tokens: 0, tool_calls: 1 });
94 }
95 }
96 Ok(())
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn no_budget_never_blocks() {
        let meter = SimpleCostMeter::new(None);
        let session = SessionId::from("s1");
        let usage = TokenUsage { prompt_tokens: 30, completion_tokens: 20 };

        let before = meter.check_budget(&session).await;
        assert!(before.is_ok());

        let recorded = meter.record_llm_usage(&session, "test-model", &usage).await;
        assert!(recorded.is_ok());

        let after = meter.check_budget(&session).await;
        assert!(after.is_ok());
    }

    #[tokio::test]
    async fn check_budget_blocks_after_limit_is_reached() {
        let meter = SimpleCostMeter::new(Some(100));
        let session = SessionId::from("s1");
        let usage = TokenUsage { prompt_tokens: 60, completion_tokens: 40 };

        let recorded = meter.record_llm_usage(&session, "test-model", &usage).await;
        assert!(recorded.is_ok());

        let verdict = meter.check_budget(&session).await;
        assert!(verdict.is_err());
        let message = verdict.map_or_else(|err| err.to_string(), |()| String::new());
        assert!(message.contains("budget"));
    }

    #[tokio::test]
    async fn record_usage_fails_when_exceeding_limit() {
        let meter = SimpleCostMeter::new(Some(50));
        let session = SessionId::from("s1");
        let usage = TokenUsage { prompt_tokens: 40, completion_tokens: 20 };

        let recorded = meter.record_llm_usage(&session, "test-model", &usage).await;
        assert!(recorded.is_err());
        let message = recorded.map_or_else(|err| err.to_string(), |()| String::new());
        assert!(message.contains("exceeded"));
    }
}