Skip to main content

bob_adapters/
cost_simple.rs

1//! Simple in-memory cost meter adapter.
2
3use bob_core::{
4    error::CostError,
5    ports::CostMeterPort,
6    types::{SessionId, TokenUsage, ToolResult},
7};
8
/// Running usage counters kept for a single session.
#[derive(Clone, Copy, Debug, Default)]
struct SessionCost {
    /// Sum of tokens from every LLM call recorded for the session.
    total_tokens: u64,
    /// Count of tool results recorded for the session.
    tool_calls: u64,
}
14
15/// In-memory cost meter with optional per-session token budget.
16#[derive(Debug)]
17pub struct SimpleCostMeter {
18    session_token_budget: Option<u64>,
19    sessions: scc::HashMap<SessionId, SessionCost>,
20}
21
22impl SimpleCostMeter {
23    #[must_use]
24    pub fn new(session_token_budget: Option<u64>) -> Self {
25        Self { session_token_budget, sessions: scc::HashMap::new() }
26    }
27
28    fn ensure_session_budget(
29        &self,
30        session_id: &SessionId,
31        total_tokens: u64,
32    ) -> Result<(), CostError> {
33        let Some(limit) = self.session_token_budget else {
34            return Ok(());
35        };
36        if total_tokens > limit {
37            return Err(CostError::BudgetExceeded(format!(
38                "session '{session_id}' exceeded token budget ({total_tokens}>{limit})"
39            )));
40        }
41        Ok(())
42    }
43}
44
45#[async_trait::async_trait]
46impl CostMeterPort for SimpleCostMeter {
47    async fn check_budget(&self, session_id: &SessionId) -> Result<(), CostError> {
48        let Some(limit) = self.session_token_budget else {
49            return Ok(());
50        };
51        let total = self.sessions.read_async(session_id, |_k, v| v.total_tokens).await.unwrap_or(0);
52        if total >= limit {
53            return Err(CostError::BudgetExceeded(format!(
54                "session '{session_id}' reached token budget ({total}>={limit})"
55            )));
56        }
57        Ok(())
58    }
59
60    async fn record_llm_usage(
61        &self,
62        session_id: &SessionId,
63        _model: &str,
64        usage: &TokenUsage,
65    ) -> Result<(), CostError> {
66        let usage_tokens = u64::from(usage.total());
67        let entry = self.sessions.entry_async(session_id.clone()).await;
68        let total_after = match entry {
69            scc::hash_map::Entry::Occupied(mut occ) => {
70                occ.get_mut().total_tokens += usage_tokens;
71                occ.get().total_tokens
72            }
73            scc::hash_map::Entry::Vacant(vac) => {
74                let inserted =
75                    vac.insert_entry(SessionCost { total_tokens: usage_tokens, tool_calls: 0 });
76                inserted.get().total_tokens
77            }
78        };
79        self.ensure_session_budget(session_id, total_after)
80    }
81
82    async fn record_tool_result(
83        &self,
84        session_id: &SessionId,
85        _tool_result: &ToolResult,
86    ) -> Result<(), CostError> {
87        let entry = self.sessions.entry_async(session_id.clone()).await;
88        match entry {
89            scc::hash_map::Entry::Occupied(mut occ) => {
90                occ.get_mut().tool_calls += 1;
91            }
92            scc::hash_map::Entry::Vacant(vac) => {
93                let _ = vac.insert_entry(SessionCost { total_tokens: 0, tool_calls: 1 });
94            }
95        }
96        Ok(())
97    }
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn no_budget_never_blocks() {
        let meter = SimpleCostMeter::new(None);
        let session = "s1".to_string();
        assert!(meter.check_budget(&session).await.is_ok());
        assert!(
            meter
                .record_llm_usage(
                    &session,
                    "test-model",
                    &TokenUsage { prompt_tokens: 30, completion_tokens: 20 }
                )
                .await
                .is_ok()
        );
        assert!(meter.check_budget(&session).await.is_ok());
    }

    #[tokio::test]
    async fn check_budget_blocks_after_limit_is_reached() {
        let meter = SimpleCostMeter::new(Some(100));
        let session = "s1".to_string();

        assert!(
            meter
                .record_llm_usage(
                    &session,
                    "test-model",
                    &TokenUsage { prompt_tokens: 60, completion_tokens: 40 }
                )
                .await
                .is_ok()
        );
        let result = meter.check_budget(&session).await;
        assert!(result.is_err());
        let message = result.err().map(|err| err.to_string()).unwrap_or_default();
        assert!(message.contains("budget"));
    }

    #[tokio::test]
    async fn record_usage_fails_when_exceeding_limit() {
        let meter = SimpleCostMeter::new(Some(50));
        let session = "s1".to_string();

        let result = meter
            .record_llm_usage(
                &session,
                "test-model",
                &TokenUsage { prompt_tokens: 40, completion_tokens: 20 },
            )
            .await;
        assert!(result.is_err());
        let message = result.err().map(|err| err.to_string()).unwrap_or_default();
        assert!(message.contains("exceeded"));
    }

    #[tokio::test]
    async fn usage_accumulates_across_calls() {
        let meter = SimpleCostMeter::new(Some(100));
        let session = "s1".to_string();

        // Three calls of 30 tokens each keep the running total (90) under budget.
        for _ in 0..3 {
            assert!(
                meter
                    .record_llm_usage(
                        &session,
                        "test-model",
                        &TokenUsage { prompt_tokens: 20, completion_tokens: 10 }
                    )
                    .await
                    .is_ok()
            );
        }
        assert!(meter.check_budget(&session).await.is_ok());

        // A fourth call pushes the running total to 120, past the budget.
        let result = meter
            .record_llm_usage(
                &session,
                "test-model",
                &TokenUsage { prompt_tokens: 20, completion_tokens: 10 },
            )
            .await;
        let message = result.err().map(|err| err.to_string()).unwrap_or_default();
        assert!(message.contains("exceeded"));
        assert!(meter.check_budget(&session).await.is_err());
    }

    #[tokio::test]
    async fn budgets_are_tracked_per_session() {
        let meter = SimpleCostMeter::new(Some(50));
        let heavy = "heavy".to_string();
        let fresh = "fresh".to_string();

        assert!(
            meter
                .record_llm_usage(
                    &heavy,
                    "test-model",
                    &TokenUsage { prompt_tokens: 40, completion_tokens: 20 }
                )
                .await
                .is_err()
        );
        assert!(meter.check_budget(&heavy).await.is_err());
        // One session's usage must not count against another session.
        assert!(meter.check_budget(&fresh).await.is_ok());
    }
}