1use crate::error::LlmError;
2use rusqlite::Connection;
3use std::sync::{Arc, Mutex};
4use std::time::{SystemTime, UNIX_EPOCH};
5
/// Aggregated usage figures for the requests recorded in a time window.
///
/// Produced by `TrackingDb::get_usage_stats`; every field is zero when the
/// window contains no requests.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct UsageStats {
    /// Number of requests recorded in the window.
    pub total_requests: u64,
    /// Sum of `input_tokens + output_tokens` over all requests in the window.
    pub total_tokens: u64,
    /// Sum of the `cost` values recorded for each request (unit/currency is
    /// whatever callers pass to `track_request` — not fixed here).
    pub total_cost: f64,
    /// Mean of the recorded `duration_ms` values, in milliseconds.
    pub avg_duration_ms: f64,
    /// Sum of `cache_read_tokens` over all requests in the window.
    pub total_cache_read_tokens: u64,
    /// Sum of `cache_write_tokens` over all requests in the window.
    pub total_cache_write_tokens: u64,
}
16
17pub struct TrackingDb {
19 conn: Arc<Mutex<Connection>>,
20}
21
22impl TrackingDb {
23 pub fn new() -> Result<Self, LlmError> {
25 let mut db_path = dirs::home_dir()
26 .ok_or_else(|| LlmError::PersistenceError("Home directory not found".to_string()))?;
27 db_path.push(".limit");
28
29 std::fs::create_dir_all(&db_path).map_err(|e| {
31 LlmError::PersistenceError(format!("Failed to create .limit directory: {}", e))
32 })?;
33
34 db_path.push("tracking.db");
35
36 Self::open(&db_path)
37 }
38
39 pub fn open(db_path: &std::path::Path) -> Result<Self, LlmError> {
41 let conn = Connection::open(db_path)
42 .map_err(|e| LlmError::PersistenceError(format!("Failed to open database: {}", e)))?;
43
44 conn.busy_timeout(std::time::Duration::from_secs(5))
46 .map_err(|e| {
47 LlmError::PersistenceError(format!("Failed to set busy timeout: {}", e))
48 })?;
49
50 let db = Self {
51 conn: Arc::new(Mutex::new(conn)),
52 };
53
54 db.init_tables()?;
55
56 Ok(db)
57 }
58
59 pub fn new_in_memory() -> Result<Self, LlmError> {
61 let conn = Connection::open_in_memory().map_err(|e| {
62 LlmError::PersistenceError(format!("Failed to open in-memory database: {}", e))
63 })?;
64
65 let db = Self {
66 conn: Arc::new(Mutex::new(conn)),
67 };
68
69 db.init_tables()?;
70
71 Ok(db)
72 }
73
74 fn init_tables(&self) -> Result<(), LlmError> {
76 let conn = self
77 .conn
78 .lock()
79 .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
80
81 conn.execute(
82 "CREATE TABLE IF NOT EXISTS requests (
83 id INTEGER PRIMARY KEY AUTOINCREMENT,
84 timestamp INTEGER NOT NULL,
85 model TEXT NOT NULL,
86 input_tokens INTEGER NOT NULL,
87 output_tokens INTEGER NOT NULL,
88 cost REAL NOT NULL,
89 duration_ms INTEGER NOT NULL,
90 cache_read_tokens INTEGER NOT NULL DEFAULT 0,
91 cache_write_tokens INTEGER NOT NULL DEFAULT 0
92 )",
93 [],
94 )
95 .map_err(|e| LlmError::PersistenceError(format!("Failed to create table: {}", e)))?;
96
97 conn.execute(
99 "CREATE INDEX IF NOT EXISTS idx_timestamp ON requests (timestamp)",
100 [],
101 )
102 .map_err(|e| LlmError::PersistenceError(format!("Failed to create index: {}", e)))?;
103
104 let _ = conn.execute(
106 "ALTER TABLE requests ADD COLUMN cache_read_tokens INTEGER NOT NULL DEFAULT 0",
107 [],
108 );
109 let _ = conn.execute(
110 "ALTER TABLE requests ADD COLUMN cache_write_tokens INTEGER NOT NULL DEFAULT 0",
111 [],
112 );
113
114 Ok(())
115 }
116
117 #[allow(clippy::too_many_arguments)]
119 pub fn track_request(
120 &self,
121 model: &str,
122 input_tokens: u64,
123 output_tokens: u64,
124 cache_read_tokens: u64,
125 cache_write_tokens: u64,
126 cost: f64,
127 duration_ms: u64,
128 ) -> Result<(), LlmError> {
129 let timestamp = SystemTime::now()
130 .duration_since(UNIX_EPOCH)
131 .map_err(|e| LlmError::PersistenceError(format!("Failed to get timestamp: {}", e)))?
132 .as_secs() as i64;
133
134 let conn = self
135 .conn
136 .lock()
137 .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
138
139 conn.execute(
140 "INSERT INTO requests (timestamp, model, input_tokens, output_tokens, cost, duration_ms, cache_read_tokens, cache_write_tokens)
141 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
142 rusqlite::params![timestamp, model, input_tokens as i64, output_tokens as i64, cost, duration_ms as i64, cache_read_tokens as i64, cache_write_tokens as i64],
143 )
144 .map_err(|e| LlmError::PersistenceError(format!("Failed to insert request: {}", e)))?;
145
146 Ok(())
147 }
148
149 pub fn get_usage_stats(&self, days: u32) -> Result<UsageStats, LlmError> {
151 let cutoff_timestamp = SystemTime::now()
152 .duration_since(UNIX_EPOCH)
153 .map_err(|e| LlmError::PersistenceError(format!("Failed to get timestamp: {}", e)))?
154 .as_secs() as i64
155 - (days as i64 * 86400);
156
157 let conn = self
158 .conn
159 .lock()
160 .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
161
162 let mut stmt = conn
163 .prepare(
164 "SELECT
165 COUNT(*) as count,
166 SUM(input_tokens + output_tokens) as total_tokens,
167 SUM(cost) as total_cost,
168 AVG(duration_ms) as avg_duration,
169 COALESCE(SUM(cache_read_tokens), 0) as total_cache_read,
170 COALESCE(SUM(cache_write_tokens), 0) as total_cache_write
171 FROM requests
172 WHERE timestamp >= ?1",
173 )
174 .map_err(|e| LlmError::PersistenceError(format!("Failed to prepare query: {}", e)))?;
175
176 let mut rows = stmt
177 .query([cutoff_timestamp])
178 .map_err(|e| LlmError::PersistenceError(format!("Failed to execute query: {}", e)))?;
179
180 if let Some(row) = rows
181 .next()
182 .map_err(|e| LlmError::PersistenceError(format!("Failed to read row: {}", e)))?
183 {
184 let total_requests: i64 = row.get(0).unwrap_or(0);
185 let total_tokens: i64 = row.get(1).unwrap_or(0);
186 let total_cost: f64 = row.get(2).unwrap_or(0.0);
187 let avg_duration_ms: f64 = row.get(3).unwrap_or(0.0);
188 let total_cache_read_tokens: i64 = row.get(4).unwrap_or(0);
189 let total_cache_write_tokens: i64 = row.get(5).unwrap_or(0);
190
191 Ok(UsageStats {
192 total_requests: total_requests.max(0) as u64,
193 total_tokens: total_tokens.max(0) as u64,
194 total_cost,
195 avg_duration_ms,
196 total_cache_read_tokens: total_cache_read_tokens.max(0) as u64,
197 total_cache_write_tokens: total_cache_write_tokens.max(0) as u64,
198 })
199 } else {
200 Ok(UsageStats::default())
201 }
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Opens a fresh on-disk database inside a private temporary directory.
    ///
    /// The returned `TempDir` must stay alive while the database is used:
    /// dropping it deletes the directory together with the database file.
    /// (The previous approach dropped a `NamedTempFile` and reopened its path,
    /// which leaked the recreated file and let the path be reused in between.)
    fn create_test_db() -> Result<(TempDir, TrackingDb), LlmError> {
        let dir = TempDir::new().map_err(|e| {
            LlmError::PersistenceError(format!("Failed to create temp dir: {}", e))
        })?;
        let db = TrackingDb::open(&dir.path().join("tracking.db"))?;
        Ok((dir, db))
    }

    #[test]
    fn test_create_tables() {
        let (_dir, db) = create_test_db().unwrap();

        let result: rusqlite::Result<i64> = db
            .conn
            .lock()
            .unwrap()
            .query_row("SELECT COUNT(*) FROM requests", [], |row| row.get(0));
        assert!(result.is_ok());
    }

    #[test]
    fn test_track_request() {
        let (_dir, db) = create_test_db().unwrap();

        db.track_request("claude-3-5-sonnet", 100, 50, 500, 100, 0.001, 1500)
            .unwrap();
        db.track_request("claude-3-opus", 200, 100, 0, 0, 0.005, 3000)
            .unwrap();

        let count: i64 = db
            .conn
            .lock()
            .unwrap()
            .query_row("SELECT COUNT(*) FROM requests", [], |row| row.get(0))
            .unwrap();
        assert_eq!(count, 2);
    }

    #[test]
    fn test_get_usage_stats() {
        let (_dir, db) = create_test_db().unwrap();

        db.track_request("model-1", 100, 50, 500, 100, 0.001, 1500)
            .unwrap();
        db.track_request("model-2", 200, 100, 0, 0, 0.005, 3000)
            .unwrap();
        db.track_request("model-3", 150, 75, 200, 50, 0.002, 2000)
            .unwrap();

        let stats = db.get_usage_stats(7).unwrap();

        assert_eq!(stats.total_requests, 3);
        assert_eq!(stats.total_tokens, 675);
        assert!((stats.total_cost - 0.008).abs() < 0.0001);
        assert!((stats.avg_duration_ms - 2166.6666666666665).abs() < 0.0001);
        assert_eq!(stats.total_cache_read_tokens, 700);
        assert_eq!(stats.total_cache_write_tokens, 150);
    }

    #[test]
    fn test_empty_usage_stats() {
        let (_dir, db) = create_test_db().unwrap();

        let stats = db.get_usage_stats(7).unwrap();

        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.total_tokens, 0);
        assert_eq!(stats.total_cost, 0.0);
        assert_eq!(stats.avg_duration_ms, 0.0);
        assert_eq!(stats.total_cache_read_tokens, 0);
        assert_eq!(stats.total_cache_write_tokens, 0);
    }

    #[test]
    fn test_reopen_keeps_data_and_schema() {
        let (dir, db) = create_test_db().unwrap();
        db.track_request("model-a", 10, 20, 1, 2, 0.5, 100).unwrap();
        drop(db);

        // Re-running schema initialisation on an up-to-date database must be a no-op.
        let db = TrackingDb::open(&dir.path().join("tracking.db")).unwrap();
        let stats = db.get_usage_stats(7).unwrap();
        assert_eq!(stats.total_requests, 1);
        assert_eq!(stats.total_tokens, 30);
        assert_eq!(stats.total_cache_read_tokens, 1);
        assert_eq!(stats.total_cache_write_tokens, 2);
    }

    #[test]
    fn test_migrates_legacy_schema() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("tracking.db");

        // Build a database with the schema as it was before the cache columns existed.
        {
            let conn = Connection::open(&path).unwrap();
            conn.execute(
                "CREATE TABLE requests (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp INTEGER NOT NULL,
                    model TEXT NOT NULL,
                    input_tokens INTEGER NOT NULL,
                    output_tokens INTEGER NOT NULL,
                    cost REAL NOT NULL,
                    duration_ms INTEGER NOT NULL
                )",
                [],
            )
            .unwrap();
            conn.execute(
                "INSERT INTO requests (timestamp, model, input_tokens, output_tokens, cost, duration_ms)
                 VALUES (CAST(strftime('%s', 'now') AS INTEGER), 'legacy', 10, 5, 0.5, 100)",
                [],
            )
            .unwrap();
        }

        let db = TrackingDb::open(&path).unwrap();
        db.track_request("new", 1, 2, 3, 4, 0.25, 200).unwrap();

        let stats = db.get_usage_stats(1).unwrap();
        assert_eq!(stats.total_requests, 2);
        assert_eq!(stats.total_tokens, 18);
        assert_eq!(stats.total_cache_read_tokens, 3);
        assert_eq!(stats.total_cache_write_tokens, 4);
    }

    #[test]
    fn test_concurrent_access() {
        let (_dir, db) = create_test_db().unwrap();

        let db_clone1 = TrackingDb {
            conn: Arc::clone(&db.conn),
        };
        let db_clone2 = TrackingDb {
            conn: Arc::clone(&db.conn),
        };

        let handle1 = std::thread::spawn(move || {
            for i in 0..10 {
                db_clone1
                    .track_request(
                        &format!("model-{}", i),
                        100 + i,
                        50 + i,
                        0,
                        0,
                        0.001,
                        1000 + i * 100,
                    )
                    .unwrap();
            }
        });

        let handle2 = std::thread::spawn(move || {
            for i in 0..10 {
                db_clone2
                    .track_request(
                        &format!("model-{}", i + 10),
                        100 + i,
                        50 + i,
                        0,
                        0,
                        0.001,
                        1000 + i * 100,
                    )
                    .unwrap();
            }
        });

        handle1.join().unwrap();
        handle2.join().unwrap();

        let stats = db.get_usage_stats(7).unwrap();
        assert_eq!(stats.total_requests, 20);
    }
}
352}