Skip to main content

openhawk_core/
token_tracker.rs

1use rusqlite::{Connection, params};
2use thiserror::Error;
3
4#[derive(Debug, Error)]
5pub enum TokenError {
6    #[error("database error: {0}")]
7    Database(String),
8}
9
10impl From<rusqlite::Error> for TokenError {
11    fn from(e: rusqlite::Error) -> Self {
12        TokenError::Database(e.to_string())
13    }
14}
15
/// Aggregated token usage for one agent, as returned by `TokenTracker::get_agent_stats`
/// and `TokenTracker::get_all_stats` (sums over every recorded row for that agent).
#[derive(Debug, Clone, PartialEq)]
pub struct AgentTokenStats {
    /// The `agent_pid` the rows were recorded under.
    pub agent_pid: u32,
    /// Sum of prompt tokens.
    pub total_prompt_tokens: u64,
    /// Sum of completion tokens.
    pub total_completion_tokens: u64,
    /// `total_prompt_tokens + total_completion_tokens`.
    pub total_tokens: u64,
    /// Sum of per-row estimated costs; rows recorded without pricing contribute nothing.
    pub estimated_cost: f64,
}
23
/// Token usage and estimated cost for one calendar day of a single agent.
#[derive(Debug, Clone, PartialEq)]
pub struct DailyTokenUsage {
    /// Day in `YYYY-MM-DD` form, as produced by SQLite's `date()`.
    pub date: String,
    /// Prompt plus completion tokens recorded on that day.
    pub tokens: u64,
    /// Sum of estimated costs recorded on that day.
    pub cost: f64,
}
29
30pub struct TokenTracker {
31    db: Connection,
32}
33
34impl TokenTracker {
35    pub fn new(db: Connection) -> Self {
36        Self { db }
37    }
38
39    pub fn record(
40        &self,
41        agent_pid: u32,
42        provider: &str,
43        prompt_tokens: u32,
44        completion_tokens: u32,
45        pricing: Option<f64>,
46    ) -> Result<(), TokenError> {
47        let estimated_cost = pricing.map(|price| (prompt_tokens + completion_tokens) as f64 * price);
48        self.db.execute(
49            "INSERT INTO token_usage (agent_pid, timestamp, provider, prompt_tokens, completion_tokens, estimated_cost) \
50             VALUES (?1, datetime('now'), ?2, ?3, ?4, ?5)",
51            params![agent_pid, provider, prompt_tokens, completion_tokens, estimated_cost],
52        )?;
53        Ok(())
54    }
55
56    pub fn get_agent_stats(&self, agent_pid: u32) -> Result<AgentTokenStats, TokenError> {
57        let (prompt, completion, cost) = self.db.query_row(
58            "SELECT COALESCE(SUM(prompt_tokens), 0), COALESCE(SUM(completion_tokens), 0), COALESCE(SUM(estimated_cost), 0.0) \
59             FROM token_usage WHERE agent_pid = ?1",
60            params![agent_pid],
61            |row| Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?, row.get::<_, f64>(2)?)),
62        )?;
63        let total_prompt = prompt as u64;
64        let total_completion = completion as u64;
65        Ok(AgentTokenStats {
66            agent_pid,
67            total_prompt_tokens: total_prompt,
68            total_completion_tokens: total_completion,
69            total_tokens: total_prompt + total_completion,
70            estimated_cost: cost,
71        })
72    }
73
74    pub fn get_all_stats(&self) -> Result<Vec<(u32, AgentTokenStats)>, TokenError> {
75        let mut stmt = self.db.prepare(
76            "SELECT agent_pid, COALESCE(SUM(prompt_tokens), 0), COALESCE(SUM(completion_tokens), 0), COALESCE(SUM(estimated_cost), 0.0) \
77             FROM token_usage GROUP BY agent_pid ORDER BY agent_pid",
78        )?;
79        let rows = stmt.query_map([], |row| {
80            let pid = row.get::<_, i64>(0)? as u32;
81            let prompt = row.get::<_, i64>(1)? as u64;
82            let completion = row.get::<_, i64>(2)? as u64;
83            let cost = row.get::<_, f64>(3)?;
84            Ok((pid, prompt, completion, cost))
85        })?;
86        let mut result = Vec::new();
87        for row in rows {
88            let (pid, prompt, completion, cost) = row?;
89            result.push((pid, AgentTokenStats {
90                agent_pid: pid,
91                total_prompt_tokens: prompt,
92                total_completion_tokens: completion,
93                total_tokens: prompt + completion,
94                estimated_cost: cost,
95            }));
96        }
97        Ok(result)
98    }
99
100    pub fn get_7day_trend(&self, agent_pid: u32) -> Result<Vec<DailyTokenUsage>, TokenError> {
101        let mut stmt = self.db.prepare(
102            "SELECT date(timestamp) AS day, COALESCE(SUM(prompt_tokens + completion_tokens), 0), COALESCE(SUM(estimated_cost), 0.0) \
103             FROM token_usage WHERE agent_pid = ?1 AND timestamp >= datetime('now', '-7 days') \
104             GROUP BY day ORDER BY day ASC",
105        )?;
106        let rows = stmt.query_map(params![agent_pid], |row| {
107            Ok(DailyTokenUsage {
108                date: row.get(0)?,
109                tokens: row.get::<_, i64>(1)? as u64,
110                cost: row.get(2)?,
111            })
112        })?;
113        let mut trend = Vec::new();
114        for row in rows {
115            trend.push(row?);
116        }
117        Ok(trend)
118    }
119
120    pub fn check_budget(&self, agent_pid: u32, budget_tokens: u64) -> Result<bool, TokenError> {
121        let total: i64 = self.db.query_row(
122            "SELECT COALESCE(SUM(prompt_tokens + completion_tokens), 0) FROM token_usage WHERE agent_pid = ?1",
123            params![agent_pid],
124            |row| row.get(0),
125        )?;
126        Ok(total as u64 > budget_tokens)
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::init_database;
    use tempfile::NamedTempFile;

    /// Per-token price shared by the priced test records.
    const GPT_PRICE: f64 = 0.00003;

    /// Builds a tracker over a fresh on-disk database.
    ///
    /// The temp file is handed back alongside the tracker so the caller keeps it alive
    /// for the duration of the test.
    fn make_tracker() -> (NamedTempFile, TokenTracker) {
        let tmp = NamedTempFile::new().unwrap();
        let tracker = TokenTracker::new(init_database(tmp.path()).unwrap());
        (tmp, tracker)
    }

    #[test]
    fn record_stores_correct_values() {
        let (_tmp, tracker) = make_tracker();
        tracker.record(1, "openai", 100, 50, Some(GPT_PRICE)).unwrap();

        let AgentTokenStats {
            total_prompt_tokens,
            total_completion_tokens,
            total_tokens,
            ..
        } = tracker.get_agent_stats(1).unwrap();
        assert_eq!(total_prompt_tokens, 100);
        assert_eq!(total_completion_tokens, 50);
        assert_eq!(total_tokens, 150);
    }

    #[test]
    fn record_without_pricing_stores_zero_cost() {
        let (_tmp, tracker) = make_tracker();
        tracker.record(2, "ollama", 200, 100, None).unwrap();

        let stats = tracker.get_agent_stats(2).unwrap();
        assert_eq!(stats.total_tokens, 300);
        assert!(stats.estimated_cost.abs() < f64::EPSILON);
    }

    #[test]
    fn multiple_records_accumulate() {
        let (_tmp, tracker) = make_tracker();
        for (prompt, completion) in [(100, 50), (200, 100)] {
            tracker
                .record(3, "openai", prompt, completion, Some(GPT_PRICE))
                .unwrap();
        }

        assert_eq!(tracker.get_agent_stats(3).unwrap().total_tokens, 450);
    }

    #[test]
    fn cost_calculation_uses_pricing() {
        let (_tmp, tracker) = make_tracker();
        tracker.record(4, "openai", 1000, 500, Some(GPT_PRICE)).unwrap();

        let cost = tracker.get_agent_stats(4).unwrap().estimated_cost;
        assert!((cost - 1500.0 * GPT_PRICE).abs() < 1e-9);
    }

    #[test]
    fn get_all_stats_returns_per_agent() {
        let (_tmp, tracker) = make_tracker();
        tracker.record(10, "openai", 100, 50, Some(GPT_PRICE)).unwrap();
        tracker.record(11, "ollama", 200, 100, None).unwrap();

        let all = tracker.get_all_stats().unwrap();
        assert_eq!(all.len(), 2);
        assert!(all.iter().any(|(pid, _)| *pid == 10));
        assert!(all.iter().any(|(pid, _)| *pid == 11));
    }

    #[test]
    fn check_budget_returns_false_when_under() {
        let (_tmp, tracker) = make_tracker();
        tracker.record(20, "openai", 100, 50, None).unwrap();

        let exceeded = tracker.check_budget(20, 1000).unwrap();
        assert!(!exceeded);
    }

    #[test]
    fn check_budget_returns_true_when_over() {
        let (_tmp, tracker) = make_tracker();
        tracker.record(21, "openai", 600, 500, None).unwrap();

        let exceeded = tracker.check_budget(21, 1000).unwrap();
        assert!(exceeded);
    }

    #[test]
    fn check_budget_exact_limit_not_exceeded() {
        let (_tmp, tracker) = make_tracker();
        tracker.record(22, "openai", 500, 500, None).unwrap();

        let exceeded = tracker.check_budget(22, 1000).unwrap();
        assert!(!exceeded);
    }

    #[test]
    fn check_budget_no_records_not_exceeded() {
        let (_tmp, tracker) = make_tracker();

        let exceeded = tracker.check_budget(99, 100).unwrap();
        assert!(!exceeded);
    }

    #[test]
    fn get_7day_trend_returns_empty_for_no_records() {
        let (_tmp, tracker) = make_tracker();

        let trend = tracker.get_7day_trend(50).unwrap();
        assert!(trend.is_empty());
    }

    #[test]
    fn get_7day_trend_includes_recent_records() {
        let (_tmp, tracker) = make_tracker();
        for (prompt, completion) in [(100, 50), (200, 100)] {
            tracker
                .record(30, "openai", prompt, completion, Some(GPT_PRICE))
                .unwrap();
        }

        let trend = tracker.get_7day_trend(30).unwrap();
        assert!(!trend.is_empty());
        assert_eq!(trend.iter().map(|day| day.tokens).sum::<u64>(), 450);
    }

    #[test]
    fn get_7day_trend_date_format_is_yyyy_mm_dd() {
        let (_tmp, tracker) = make_tracker();
        tracker.record(31, "openai", 10, 5, None).unwrap();

        let trend = tracker.get_7day_trend(31).unwrap();
        assert!(!trend.is_empty());
        let date = trend[0].date.as_str();
        assert_eq!(date.len(), 10, "date should be YYYY-MM-DD: {date}");
        assert_eq!(&date[4..5], "-");
        assert_eq!(&date[7..8], "-");
    }

    #[test]
    fn get_7day_trend_cost_matches_records() {
        let (_tmp, tracker) = make_tracker();
        tracker.record(32, "openai", 100, 50, Some(GPT_PRICE)).unwrap();

        let trend = tracker.get_7day_trend(32).unwrap();
        assert!(!trend.is_empty());
        let total_cost: f64 = trend.iter().map(|day| day.cost).sum();
        assert!((total_cost - 150.0 * GPT_PRICE).abs() < 1e-9);
    }
}