1use rusqlite::{Connection, params};
2use thiserror::Error;
3
4#[derive(Debug, Error)]
5pub enum TokenError {
6 #[error("database error: {0}")]
7 Database(String),
8}
9
10impl From<rusqlite::Error> for TokenError {
11 fn from(e: rusqlite::Error) -> Self {
12 TokenError::Database(e.to_string())
13 }
14}
15
/// Aggregated token usage and estimated cost for a single agent.
///
/// Derives `Debug`/`Clone`/`PartialEq` so results can be logged, copied and
/// compared directly (e.g. in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct AgentTokenStats {
    /// Identifier of the agent these totals belong to (the `agent_pid` column).
    pub agent_pid: u32,
    /// Sum of prompt tokens over all recorded requests.
    pub total_prompt_tokens: u64,
    /// Sum of completion tokens over all recorded requests.
    pub total_completion_tokens: u64,
    /// `total_prompt_tokens + total_completion_tokens`.
    pub total_tokens: u64,
    /// Sum of per-request estimated costs; requests recorded without pricing
    /// contribute nothing.
    pub estimated_cost: f64,
}
23
/// Token and cost totals for one calendar day of an agent's usage.
#[derive(Debug, Clone, PartialEq)]
pub struct DailyTokenUsage {
    /// Day in `YYYY-MM-DD` form, as produced by SQLite's `date()` function.
    pub date: String,
    /// Prompt plus completion tokens recorded on that day.
    pub tokens: u64,
    /// Sum of estimated costs recorded on that day.
    pub cost: f64,
}
29
30pub struct TokenTracker {
31 db: Connection,
32}
33
34impl TokenTracker {
35 pub fn new(db: Connection) -> Self {
36 Self { db }
37 }
38
39 pub fn record(
40 &self,
41 agent_pid: u32,
42 provider: &str,
43 prompt_tokens: u32,
44 completion_tokens: u32,
45 pricing: Option<f64>,
46 ) -> Result<(), TokenError> {
47 let estimated_cost = pricing.map(|price| (prompt_tokens + completion_tokens) as f64 * price);
48 self.db.execute(
49 "INSERT INTO token_usage (agent_pid, timestamp, provider, prompt_tokens, completion_tokens, estimated_cost) \
50 VALUES (?1, datetime('now'), ?2, ?3, ?4, ?5)",
51 params![agent_pid, provider, prompt_tokens, completion_tokens, estimated_cost],
52 )?;
53 Ok(())
54 }
55
56 pub fn get_agent_stats(&self, agent_pid: u32) -> Result<AgentTokenStats, TokenError> {
57 let (prompt, completion, cost) = self.db.query_row(
58 "SELECT COALESCE(SUM(prompt_tokens), 0), COALESCE(SUM(completion_tokens), 0), COALESCE(SUM(estimated_cost), 0.0) \
59 FROM token_usage WHERE agent_pid = ?1",
60 params![agent_pid],
61 |row| Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?, row.get::<_, f64>(2)?)),
62 )?;
63 let total_prompt = prompt as u64;
64 let total_completion = completion as u64;
65 Ok(AgentTokenStats {
66 agent_pid,
67 total_prompt_tokens: total_prompt,
68 total_completion_tokens: total_completion,
69 total_tokens: total_prompt + total_completion,
70 estimated_cost: cost,
71 })
72 }
73
74 pub fn get_all_stats(&self) -> Result<Vec<(u32, AgentTokenStats)>, TokenError> {
75 let mut stmt = self.db.prepare(
76 "SELECT agent_pid, COALESCE(SUM(prompt_tokens), 0), COALESCE(SUM(completion_tokens), 0), COALESCE(SUM(estimated_cost), 0.0) \
77 FROM token_usage GROUP BY agent_pid ORDER BY agent_pid",
78 )?;
79 let rows = stmt.query_map([], |row| {
80 let pid = row.get::<_, i64>(0)? as u32;
81 let prompt = row.get::<_, i64>(1)? as u64;
82 let completion = row.get::<_, i64>(2)? as u64;
83 let cost = row.get::<_, f64>(3)?;
84 Ok((pid, prompt, completion, cost))
85 })?;
86 let mut result = Vec::new();
87 for row in rows {
88 let (pid, prompt, completion, cost) = row?;
89 result.push((pid, AgentTokenStats {
90 agent_pid: pid,
91 total_prompt_tokens: prompt,
92 total_completion_tokens: completion,
93 total_tokens: prompt + completion,
94 estimated_cost: cost,
95 }));
96 }
97 Ok(result)
98 }
99
100 pub fn get_7day_trend(&self, agent_pid: u32) -> Result<Vec<DailyTokenUsage>, TokenError> {
101 let mut stmt = self.db.prepare(
102 "SELECT date(timestamp) AS day, COALESCE(SUM(prompt_tokens + completion_tokens), 0), COALESCE(SUM(estimated_cost), 0.0) \
103 FROM token_usage WHERE agent_pid = ?1 AND timestamp >= datetime('now', '-7 days') \
104 GROUP BY day ORDER BY day ASC",
105 )?;
106 let rows = stmt.query_map(params![agent_pid], |row| {
107 Ok(DailyTokenUsage {
108 date: row.get(0)?,
109 tokens: row.get::<_, i64>(1)? as u64,
110 cost: row.get(2)?,
111 })
112 })?;
113 let mut trend = Vec::new();
114 for row in rows {
115 trend.push(row?);
116 }
117 Ok(trend)
118 }
119
120 pub fn check_budget(&self, agent_pid: u32, budget_tokens: u64) -> Result<bool, TokenError> {
121 let total: i64 = self.db.query_row(
122 "SELECT COALESCE(SUM(prompt_tokens + completion_tokens), 0) FROM token_usage WHERE agent_pid = ?1",
123 params![agent_pid],
124 |row| row.get(0),
125 )?;
126 Ok(total as u64 > budget_tokens)
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::init_database;
    use tempfile::NamedTempFile;

    /// Per-token price shared by the priced scenarios.
    const PRICE: f64 = 0.00003;

    /// Opens a tracker over a fresh temporary database.
    ///
    /// The temp file is handed back with the tracker so the caller keeps it
    /// alive for the duration of the test.
    fn make_tracker() -> (NamedTempFile, TokenTracker) {
        let file = NamedTempFile::new().unwrap();
        let tracker = TokenTracker::new(init_database(file.path()).unwrap());
        (file, tracker)
    }

    /// Asserts that `actual` is within `tolerance` of `expected`.
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!((actual - expected).abs() < tolerance);
    }

    /// Shorthand for an infallible `check_budget` call.
    fn over_budget(tracker: &TokenTracker, pid: u32, budget: u64) -> bool {
        tracker.check_budget(pid, budget).unwrap()
    }

    /// Shorthand for an infallible `get_7day_trend` call.
    fn trend_for(tracker: &TokenTracker, pid: u32) -> Vec<DailyTokenUsage> {
        tracker.get_7day_trend(pid).unwrap()
    }

    #[test]
    fn record_stores_correct_values() {
        let (_file, tracker) = make_tracker();
        tracker.record(1, "openai", 100, 50, Some(PRICE)).unwrap();

        let AgentTokenStats {
            total_prompt_tokens,
            total_completion_tokens,
            total_tokens,
            ..
        } = tracker.get_agent_stats(1).unwrap();
        assert_eq!(total_prompt_tokens, 100);
        assert_eq!(total_completion_tokens, 50);
        assert_eq!(total_tokens, 150);
    }

    #[test]
    fn record_without_pricing_stores_zero_cost() {
        let (_file, tracker) = make_tracker();
        tracker.record(2, "ollama", 200, 100, None).unwrap();

        let stats = tracker.get_agent_stats(2).unwrap();
        assert_eq!(stats.total_tokens, 300);
        assert_close(stats.estimated_cost, 0.0, f64::EPSILON);
    }

    #[test]
    fn multiple_records_accumulate() {
        let (_file, tracker) = make_tracker();
        for (prompt, completion) in [(100, 50), (200, 100)] {
            tracker
                .record(3, "openai", prompt, completion, Some(PRICE))
                .unwrap();
        }
        assert_eq!(tracker.get_agent_stats(3).unwrap().total_tokens, 450);
    }

    #[test]
    fn cost_calculation_uses_pricing() {
        let (_file, tracker) = make_tracker();
        tracker.record(4, "openai", 1000, 500, Some(PRICE)).unwrap();

        let cost = tracker.get_agent_stats(4).unwrap().estimated_cost;
        assert_close(cost, 1500.0 * PRICE, 1e-9);
    }

    #[test]
    fn get_all_stats_returns_per_agent() {
        let (_file, tracker) = make_tracker();
        tracker.record(10, "openai", 100, 50, Some(PRICE)).unwrap();
        tracker.record(11, "ollama", 200, 100, None).unwrap();

        let all = tracker.get_all_stats().unwrap();
        assert_eq!(all.len(), 2);
        assert!(all.iter().any(|(pid, _)| *pid == 10));
        assert!(all.iter().any(|(pid, _)| *pid == 11));
    }

    #[test]
    fn check_budget_returns_false_when_under() {
        let (_file, tracker) = make_tracker();
        tracker.record(20, "openai", 100, 50, None).unwrap();
        assert!(!over_budget(&tracker, 20, 1000));
    }

    #[test]
    fn check_budget_returns_true_when_over() {
        let (_file, tracker) = make_tracker();
        tracker.record(21, "openai", 600, 500, None).unwrap();
        assert!(over_budget(&tracker, 21, 1000));
    }

    #[test]
    fn check_budget_exact_limit_not_exceeded() {
        let (_file, tracker) = make_tracker();
        tracker.record(22, "openai", 500, 500, None).unwrap();
        assert!(!over_budget(&tracker, 22, 1000));
    }

    #[test]
    fn check_budget_no_records_not_exceeded() {
        let (_file, tracker) = make_tracker();
        assert!(!over_budget(&tracker, 99, 100));
    }

    #[test]
    fn get_7day_trend_returns_empty_for_no_records() {
        let (_file, tracker) = make_tracker();
        assert!(trend_for(&tracker, 50).is_empty());
    }

    #[test]
    fn get_7day_trend_includes_recent_records() {
        let (_file, tracker) = make_tracker();
        tracker.record(30, "openai", 100, 50, Some(PRICE)).unwrap();
        tracker.record(30, "openai", 200, 100, Some(PRICE)).unwrap();

        let trend = trend_for(&tracker, 30);
        assert!(!trend.is_empty());
        let summed: u64 = trend.iter().map(|day| day.tokens).sum();
        assert_eq!(summed, 450);
    }

    #[test]
    fn get_7day_trend_date_format_is_yyyy_mm_dd() {
        let (_file, tracker) = make_tracker();
        tracker.record(31, "openai", 10, 5, None).unwrap();

        let trend = trend_for(&tracker, 31);
        assert!(!trend.is_empty());
        let date = trend[0].date.as_str();
        assert_eq!(date.len(), 10, "date should be YYYY-MM-DD: {date}");
        assert_eq!(&date[4..5], "-");
        assert_eq!(&date[7..8], "-");
    }

    #[test]
    fn get_7day_trend_cost_matches_records() {
        let (_file, tracker) = make_tracker();
        tracker.record(32, "openai", 100, 50, Some(PRICE)).unwrap();

        let trend = trend_for(&tracker, 32);
        assert!(!trend.is_empty());
        let summed: f64 = trend.iter().map(|day| day.cost).sum();
        assert_close(summed, 150.0 * PRICE, 1e-9);
    }
}