use crate::error::LlmError;
use rusqlite::Connection;
use std::sync::{Arc, Mutex, MutexGuard};
use std::time::{SystemTime, UNIX_EPOCH};

/// Aggregated request figures over a time window, as produced by
/// [`TrackingDb::get_usage_stats`].
///
/// Every field is zero when no requests fall inside the window.
#[derive(Clone, Debug, Default)]
pub struct UsageStats {
    /// Number of recorded requests in the window.
    pub total_requests: u64,
    /// Sum of input and output tokens across those requests.
    pub total_tokens: u64,
    /// Sum of the per-request `cost` values that were recorded.
    pub total_cost: f64,
    /// Mean recorded request duration, in milliseconds.
    pub avg_duration_ms: f64,
    /// Sum of cache-read tokens across the window.
    pub total_cache_read_tokens: u64,
    /// Sum of cache-write tokens across the window.
    pub total_cache_write_tokens: u64,
}

17pub struct TrackingDb {
19 conn: Arc<Mutex<Connection>>,
20}
21
22impl TrackingDb {
23 pub fn new() -> Result<Self, LlmError> {
25 let mut db_path = dirs::home_dir()
26 .ok_or_else(|| LlmError::PersistenceError("Home directory not found".to_string()))?;
27 db_path.push(".limit");
28
29 std::fs::create_dir_all(&db_path).map_err(|e| {
31 LlmError::PersistenceError(format!("Failed to create .limit directory: {}", e))
32 })?;
33
34 db_path.push("tracking.db");
35
36 Self::open(&db_path)
37 }
38
39 pub fn open(db_path: &std::path::Path) -> Result<Self, LlmError> {
41 let conn = Connection::open(db_path)
42 .map_err(|e| LlmError::PersistenceError(format!("Failed to open database: {}", e)))?;
43
44 conn.busy_timeout(std::time::Duration::from_secs(5))
46 .map_err(|e| {
47 LlmError::PersistenceError(format!("Failed to set busy timeout: {}", e))
48 })?;
49
50 let db = Self {
51 conn: Arc::new(Mutex::new(conn)),
52 };
53
54 db.init_tables()?;
55
56 Ok(db)
57 }
58
59 pub fn new_in_memory() -> Result<Self, LlmError> {
61 let conn = Connection::open_in_memory().map_err(|e| {
62 LlmError::PersistenceError(format!("Failed to open in-memory database: {}", e))
63 })?;
64
65 let db = Self {
66 conn: Arc::new(Mutex::new(conn)),
67 };
68
69 db.init_tables()?;
70
71 Ok(db)
72 }
73
74 fn init_tables(&self) -> Result<(), LlmError> {
76 let conn = self
77 .conn
78 .lock()
79 .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
80
81 conn.execute(
82 "CREATE TABLE IF NOT EXISTS requests (
83 id INTEGER PRIMARY KEY AUTOINCREMENT,
84 timestamp INTEGER NOT NULL,
85 model TEXT NOT NULL,
86 input_tokens INTEGER NOT NULL,
87 output_tokens INTEGER NOT NULL,
88 cost REAL NOT NULL,
89 duration_ms INTEGER NOT NULL,
90 cache_read_tokens INTEGER NOT NULL DEFAULT 0,
91 cache_write_tokens INTEGER NOT NULL DEFAULT 0
92 )",
93 [],
94 )
95 .map_err(|e| LlmError::PersistenceError(format!("Failed to create table: {}", e)))?;
96
97 conn.execute(
99 "CREATE INDEX IF NOT EXISTS idx_timestamp ON requests (timestamp)",
100 [],
101 )
102 .map_err(|e| LlmError::PersistenceError(format!("Failed to create index: {}", e)))?;
103
104 let _ = conn.execute(
106 "ALTER TABLE requests ADD COLUMN cache_read_tokens INTEGER NOT NULL DEFAULT 0",
107 [],
108 );
109 let _ = conn.execute(
110 "ALTER TABLE requests ADD COLUMN cache_write_tokens INTEGER NOT NULL DEFAULT 0",
111 [],
112 );
113
114 Ok(())
115 }
116
117 pub fn track_request(
119 &self,
120 model: &str,
121 input_tokens: u64,
122 output_tokens: u64,
123 cache_read_tokens: u64,
124 cache_write_tokens: u64,
125 cost: f64,
126 duration_ms: u64,
127 ) -> Result<(), LlmError> {
128 let timestamp = SystemTime::now()
129 .duration_since(UNIX_EPOCH)
130 .map_err(|e| LlmError::PersistenceError(format!("Failed to get timestamp: {}", e)))?
131 .as_secs() as i64;
132
133 let conn = self
134 .conn
135 .lock()
136 .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
137
138 conn.execute(
139 "INSERT INTO requests (timestamp, model, input_tokens, output_tokens, cost, duration_ms, cache_read_tokens, cache_write_tokens)
140 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
141 rusqlite::params![timestamp, model, input_tokens as i64, output_tokens as i64, cost, duration_ms as i64, cache_read_tokens as i64, cache_write_tokens as i64],
142 )
143 .map_err(|e| LlmError::PersistenceError(format!("Failed to insert request: {}", e)))?;
144
145 Ok(())
146 }
147
148 pub fn get_usage_stats(&self, days: u32) -> Result<UsageStats, LlmError> {
150 let cutoff_timestamp = SystemTime::now()
151 .duration_since(UNIX_EPOCH)
152 .map_err(|e| LlmError::PersistenceError(format!("Failed to get timestamp: {}", e)))?
153 .as_secs() as i64
154 - (days as i64 * 86400);
155
156 let conn = self
157 .conn
158 .lock()
159 .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
160
161 let mut stmt = conn
162 .prepare(
163 "SELECT
164 COUNT(*) as count,
165 SUM(input_tokens + output_tokens) as total_tokens,
166 SUM(cost) as total_cost,
167 AVG(duration_ms) as avg_duration,
168 COALESCE(SUM(cache_read_tokens), 0) as total_cache_read,
169 COALESCE(SUM(cache_write_tokens), 0) as total_cache_write
170 FROM requests
171 WHERE timestamp >= ?1",
172 )
173 .map_err(|e| LlmError::PersistenceError(format!("Failed to prepare query: {}", e)))?;
174
175 let mut rows = stmt
176 .query([cutoff_timestamp])
177 .map_err(|e| LlmError::PersistenceError(format!("Failed to execute query: {}", e)))?;
178
179 if let Some(row) = rows
180 .next()
181 .map_err(|e| LlmError::PersistenceError(format!("Failed to read row: {}", e)))?
182 {
183 let total_requests: i64 = row.get(0).unwrap_or(0);
184 let total_tokens: i64 = row.get(1).unwrap_or(0);
185 let total_cost: f64 = row.get(2).unwrap_or(0.0);
186 let avg_duration_ms: f64 = row.get(3).unwrap_or(0.0);
187 let total_cache_read_tokens: i64 = row.get(4).unwrap_or(0);
188 let total_cache_write_tokens: i64 = row.get(5).unwrap_or(0);
189
190 Ok(UsageStats {
191 total_requests: total_requests.max(0) as u64,
192 total_tokens: total_tokens.max(0) as u64,
193 total_cost,
194 avg_duration_ms,
195 total_cache_read_tokens: total_cache_read_tokens.max(0) as u64,
196 total_cache_write_tokens: total_cache_write_tokens.max(0) as u64,
197 })
198 } else {
199 Ok(UsageStats::default())
200 }
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// File-backed test database plus the temp file that backs it.
    ///
    /// Fields drop in declaration order, so the SQLite connection is closed
    /// before the temp file is deleted from disk.
    struct TestDb {
        db: TrackingDb,
        file: NamedTempFile,
    }

    /// Opens a fresh tracking database on a new temporary file via
    /// [`TrackingDb::open`].
    ///
    /// The temp file is kept alive inside the returned fixture. This means
    /// the database is removed when the test finishes, rather than the path
    /// being recreated by SQLite and left behind.
    fn create_test_db() -> TestDb {
        let file = NamedTempFile::new().expect("failed to create temp file");
        let db = TrackingDb::open(file.path()).expect("failed to open test database");
        TestDb { db, file }
    }

    #[test]
    fn test_create_tables() {
        let fixture = create_test_db();
        let db = &fixture.db;

        let result: rusqlite::Result<i64> =
            db.conn
                .lock()
                .unwrap()
                .query_row("SELECT COUNT(*) FROM requests", [], |row| row.get(0));
        assert!(result.is_ok());
    }

    #[test]
    fn test_track_request() {
        let fixture = create_test_db();
        let db = &fixture.db;

        db.track_request("claude-3-5-sonnet", 100, 50, 500, 100, 0.001, 1500)
            .unwrap();
        db.track_request("claude-3-opus", 200, 100, 0, 0, 0.005, 3000)
            .unwrap();

        let count: i64 = db
            .conn
            .lock()
            .unwrap()
            .query_row("SELECT COUNT(*) FROM requests", [], |row| row.get(0))
            .unwrap();
        assert_eq!(count, 2);
    }

    #[test]
    fn test_get_usage_stats() {
        let fixture = create_test_db();
        let db = &fixture.db;

        db.track_request("model-1", 100, 50, 500, 100, 0.001, 1500)
            .unwrap();
        db.track_request("model-2", 200, 100, 0, 0, 0.005, 3000)
            .unwrap();
        db.track_request("model-3", 150, 75, 200, 50, 0.002, 2000)
            .unwrap();

        let stats = db.get_usage_stats(7).unwrap();

        assert_eq!(stats.total_requests, 3);
        assert_eq!(stats.total_tokens, 675);
        assert!((stats.total_cost - 0.008).abs() < 0.0001);
        assert!((stats.avg_duration_ms - 2166.6666666666665).abs() < 0.0001);
        assert_eq!(stats.total_cache_read_tokens, 700);
        assert_eq!(stats.total_cache_write_tokens, 150);
    }

    #[test]
    fn test_empty_usage_stats() {
        let fixture = create_test_db();
        let db = &fixture.db;

        let stats = db.get_usage_stats(7).unwrap();

        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.total_tokens, 0);
        assert_eq!(stats.total_cost, 0.0);
        assert_eq!(stats.avg_duration_ms, 0.0);
        assert_eq!(stats.total_cache_read_tokens, 0);
        assert_eq!(stats.total_cache_write_tokens, 0);
    }

    #[test]
    fn test_in_memory_db() {
        let db = TrackingDb::new_in_memory().unwrap();

        db.track_request("model", 10, 20, 0, 0, 0.25, 400).unwrap();

        let stats = db.get_usage_stats(1).unwrap();
        assert_eq!(stats.total_requests, 1);
        assert_eq!(stats.total_tokens, 30);
        assert!((stats.avg_duration_ms - 400.0).abs() < 0.0001);
    }

    #[test]
    fn test_reopen_preserves_data() {
        let fixture = create_test_db();
        fixture
            .db
            .track_request("model", 1, 2, 0, 0, 0.1, 10)
            .unwrap();

        // Re-running schema initialisation on an existing database must succeed
        // and must not lose previously recorded rows.
        let reopened = TrackingDb::open(fixture.file.path()).unwrap();
        assert_eq!(reopened.get_usage_stats(1).unwrap().total_requests, 1);
    }

    #[test]
    fn test_migrates_legacy_schema() {
        let file = NamedTempFile::new().unwrap();
        {
            // Schema as it existed before the cache-token columns were added.
            let conn = Connection::open(file.path()).unwrap();
            conn.execute(
                "CREATE TABLE requests (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp INTEGER NOT NULL,
                    model TEXT NOT NULL,
                    input_tokens INTEGER NOT NULL,
                    output_tokens INTEGER NOT NULL,
                    cost REAL NOT NULL,
                    duration_ms INTEGER NOT NULL
                )",
                [],
            )
            .unwrap();
        }

        let db = TrackingDb::open(file.path()).unwrap();
        db.track_request("legacy-model", 10, 5, 3, 2, 0.5, 100)
            .unwrap();

        let stats = db.get_usage_stats(1).unwrap();
        assert_eq!(stats.total_requests, 1);
        assert_eq!(stats.total_tokens, 15);
        assert_eq!(stats.total_cache_read_tokens, 3);
        assert_eq!(stats.total_cache_write_tokens, 2);
    }

    #[test]
    fn test_concurrent_access() {
        let fixture = create_test_db();
        let db = &fixture.db;

        let db_clone1 = TrackingDb {
            conn: Arc::clone(&db.conn),
        };
        let db_clone2 = TrackingDb {
            conn: Arc::clone(&db.conn),
        };

        let handle1 = std::thread::spawn(move || {
            for i in 0..10 {
                db_clone1
                    .track_request(
                        &format!("model-{}", i),
                        100 + i,
                        50 + i,
                        0,
                        0,
                        0.001,
                        1000 + i * 100,
                    )
                    .unwrap();
            }
        });

        let handle2 = std::thread::spawn(move || {
            for i in 0..10 {
                db_clone2
                    .track_request(
                        &format!("model-{}", i + 10),
                        100 + i,
                        50 + i,
                        0,
                        0,
                        0.001,
                        1000 + i * 100,
                    )
                    .unwrap();
            }
        });

        handle1.join().unwrap();
        handle2.join().unwrap();

        let stats = db.get_usage_stats(7).unwrap();
        assert_eq!(stats.total_requests, 20);
    }
}