1use crate::error::LlmError;
2use rusqlite::Connection;
3use std::sync::{Arc, Mutex};
4use std::time::{SystemTime, UNIX_EPOCH};
5
/// Aggregated request metrics over a time window, as produced by
/// `TrackingDb::get_usage_stats`.
///
/// The `Default` value (all zeros) is what is reported when nothing was
/// recorded in the window.
#[derive(Clone, Debug, Default)]
pub struct UsageStats {
    /// Number of requests recorded in the window.
    pub total_requests: u64,
    /// Sum of input and output tokens across those requests.
    pub total_tokens: u64,
    /// Sum of the per-request `cost` values.
    pub total_cost: f64,
    /// Mean request duration in milliseconds (0.0 when there were no requests).
    pub avg_duration_ms: f64,
}
14
15pub struct TrackingDb {
17 conn: Arc<Mutex<Connection>>,
18}
19
20impl TrackingDb {
21 pub fn new() -> Result<Self, LlmError> {
23 let mut db_path = dirs::home_dir()
24 .ok_or_else(|| LlmError::PersistenceError("Home directory not found".to_string()))?;
25 db_path.push(".limit");
26
27 std::fs::create_dir_all(&db_path).map_err(|e| {
29 LlmError::PersistenceError(format!("Failed to create .limit directory: {}", e))
30 })?;
31
32 db_path.push("tracking.db");
33
34 Self::open(&db_path)
35 }
36
37 pub fn open(db_path: &std::path::Path) -> Result<Self, LlmError> {
39 let conn = Connection::open(db_path)
40 .map_err(|e| LlmError::PersistenceError(format!("Failed to open database: {}", e)))?;
41
42 conn.busy_timeout(std::time::Duration::from_secs(5))
44 .map_err(|e| {
45 LlmError::PersistenceError(format!("Failed to set busy timeout: {}", e))
46 })?;
47
48 let db = Self {
49 conn: Arc::new(Mutex::new(conn)),
50 };
51
52 db.init_tables()?;
53
54 Ok(db)
55 }
56
57 pub fn new_in_memory() -> Result<Self, LlmError> {
59 let conn = Connection::open_in_memory().map_err(|e| {
60 LlmError::PersistenceError(format!("Failed to open in-memory database: {}", e))
61 })?;
62
63 let db = Self {
64 conn: Arc::new(Mutex::new(conn)),
65 };
66
67 db.init_tables()?;
68
69 Ok(db)
70 }
71
72 fn init_tables(&self) -> Result<(), LlmError> {
74 let conn = self
75 .conn
76 .lock()
77 .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
78
79 conn.execute(
80 "CREATE TABLE IF NOT EXISTS requests (
81 id INTEGER PRIMARY KEY AUTOINCREMENT,
82 timestamp INTEGER NOT NULL,
83 model TEXT NOT NULL,
84 input_tokens INTEGER NOT NULL,
85 output_tokens INTEGER NOT NULL,
86 cost REAL NOT NULL,
87 duration_ms INTEGER NOT NULL
88 )",
89 [],
90 )
91 .map_err(|e| LlmError::PersistenceError(format!("Failed to create table: {}", e)))?;
92
93 conn.execute(
95 "CREATE INDEX IF NOT EXISTS idx_timestamp ON requests (timestamp)",
96 [],
97 )
98 .map_err(|e| LlmError::PersistenceError(format!("Failed to create index: {}", e)))?;
99
100 Ok(())
101 }
102
103 pub fn track_request(
105 &self,
106 model: &str,
107 input_tokens: u64,
108 output_tokens: u64,
109 cost: f64,
110 duration_ms: u64,
111 ) -> Result<(), LlmError> {
112 let timestamp = SystemTime::now()
113 .duration_since(UNIX_EPOCH)
114 .map_err(|e| LlmError::PersistenceError(format!("Failed to get timestamp: {}", e)))?
115 .as_secs() as i64;
116
117 let conn = self
118 .conn
119 .lock()
120 .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
121
122 conn.execute(
123 "INSERT INTO requests (timestamp, model, input_tokens, output_tokens, cost, duration_ms)
124 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
125 rusqlite::params![timestamp, model, input_tokens as i64, output_tokens as i64, cost, duration_ms as i64],
126 )
127 .map_err(|e| LlmError::PersistenceError(format!("Failed to insert request: {}", e)))?;
128
129 Ok(())
130 }
131
132 pub fn get_usage_stats(&self, days: u32) -> Result<UsageStats, LlmError> {
134 let cutoff_timestamp = SystemTime::now()
135 .duration_since(UNIX_EPOCH)
136 .map_err(|e| LlmError::PersistenceError(format!("Failed to get timestamp: {}", e)))?
137 .as_secs() as i64
138 - (days as i64 * 86400);
139
140 let conn = self
141 .conn
142 .lock()
143 .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
144
145 let mut stmt = conn
146 .prepare(
147 "SELECT
148 COUNT(*) as count,
149 SUM(input_tokens + output_tokens) as total_tokens,
150 SUM(cost) as total_cost,
151 AVG(duration_ms) as avg_duration
152 FROM requests
153 WHERE timestamp >= ?1",
154 )
155 .map_err(|e| LlmError::PersistenceError(format!("Failed to prepare query: {}", e)))?;
156
157 let mut rows = stmt
158 .query([cutoff_timestamp])
159 .map_err(|e| LlmError::PersistenceError(format!("Failed to execute query: {}", e)))?;
160
161 if let Some(row) = rows
162 .next()
163 .map_err(|e| LlmError::PersistenceError(format!("Failed to read row: {}", e)))?
164 {
165 let total_requests: i64 = row.get(0).unwrap_or(0);
166 let total_tokens: i64 = row.get(1).unwrap_or(0);
167 let total_cost: f64 = row.get(2).unwrap_or(0.0);
168 let avg_duration_ms: f64 = row.get(3).unwrap_or(0.0);
169
170 Ok(UsageStats {
171 total_requests: total_requests.max(0) as u64,
172 total_tokens: total_tokens.max(0) as u64,
173 total_cost,
174 avg_duration_ms,
175 })
176 } else {
177 Ok(UsageStats::default())
178 }
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// A file-backed `TrackingDb` together with the temp file that holds it.
    ///
    /// Struct fields drop in declaration order, so the database connection is
    /// closed before `NamedTempFile` removes the file from disk.
    struct TestDb {
        db: TrackingDb,
        _file: NamedTempFile,
    }

    /// Creates a fresh, empty tracking database in a temporary file.
    ///
    /// The temp file stays alive inside the returned `TestDb`, so it is removed
    /// when the test ends. Dropping it up front would delete the placeholder and
    /// leave the file SQLite then recreates at that path behind in the temp dir.
    fn create_test_db() -> Result<TestDb, LlmError> {
        let file = NamedTempFile::new().map_err(|e| {
            LlmError::PersistenceError(format!("Failed to create temp file: {}", e))
        })?;

        // SQLite treats an existing zero-length file as an empty database, so
        // the temp file can be opened in place through the production path.
        let db = TrackingDb::open(file.path())?;

        Ok(TestDb { db, _file: file })
    }

    /// Current Unix time in seconds, for seeding rows with explicit timestamps.
    fn now_secs() -> i64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is before the Unix epoch")
            .as_secs() as i64
    }

    #[test]
    fn test_create_tables() {
        let test_db = create_test_db().unwrap();

        let result: rusqlite::Result<i64> = test_db
            .db
            .conn
            .lock()
            .unwrap()
            .query_row("SELECT COUNT(*) FROM requests", [], |row| row.get(0));
        assert!(result.is_ok());
    }

    #[test]
    fn test_track_request() {
        let test_db = create_test_db().unwrap();
        let db = &test_db.db;

        db.track_request("claude-3-5-sonnet", 100, 50, 0.001, 1500)
            .unwrap();
        db.track_request("claude-3-opus", 200, 100, 0.005, 3000)
            .unwrap();

        let count: i64 = db
            .conn
            .lock()
            .unwrap()
            .query_row("SELECT COUNT(*) FROM requests", [], |row| row.get(0))
            .unwrap();
        assert_eq!(count, 2);
    }

    #[test]
    fn test_get_usage_stats() {
        let test_db = create_test_db().unwrap();
        let db = &test_db.db;

        db.track_request("model-1", 100, 50, 0.001, 1500).unwrap();
        db.track_request("model-2", 200, 100, 0.005, 3000).unwrap();
        db.track_request("model-3", 150, 75, 0.002, 2000).unwrap();

        let stats = db.get_usage_stats(7).unwrap();

        assert_eq!(stats.total_requests, 3);
        assert_eq!(stats.total_tokens, 675);
        assert!((stats.total_cost - 0.008).abs() < 0.0001);
        assert!((stats.avg_duration_ms - 2166.6666666666665).abs() < 0.0001);
    }

    #[test]
    fn test_usage_stats_excludes_old_requests() {
        let test_db = create_test_db().unwrap();
        let db = &test_db.db;

        db.track_request("recent-model", 10, 5, 0.5, 100).unwrap();

        // Seed a row 30 days old directly; it must fall outside a 7-day window.
        let old_timestamp = now_secs() - 30 * 86_400;
        db.conn
            .lock()
            .unwrap()
            .execute(
                "INSERT INTO requests (timestamp, model, input_tokens, output_tokens, cost, duration_ms)
                 VALUES (?1, 'old-model', 1000, 1000, 9.0, 9000)",
                [old_timestamp],
            )
            .unwrap();

        let stats = db.get_usage_stats(7).unwrap();

        assert_eq!(stats.total_requests, 1);
        assert_eq!(stats.total_tokens, 15);
        assert!((stats.total_cost - 0.5).abs() < 0.0001);
        assert!((stats.avg_duration_ms - 100.0).abs() < 0.0001);
    }

    #[test]
    fn test_empty_usage_stats() {
        let test_db = create_test_db().unwrap();

        let stats = test_db.db.get_usage_stats(7).unwrap();

        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.total_tokens, 0);
        assert_eq!(stats.total_cost, 0.0);
        assert_eq!(stats.avg_duration_ms, 0.0);
    }

    #[test]
    fn test_concurrent_access() {
        let test_db = create_test_db().unwrap();
        let db = &test_db.db;

        // Two extra handles sharing the same Arc'd connection.
        let db_clone1 = TrackingDb {
            conn: Arc::clone(&db.conn),
        };
        let db_clone2 = TrackingDb {
            conn: Arc::clone(&db.conn),
        };

        let handle1 = std::thread::spawn(move || {
            for i in 0..10 {
                db_clone1
                    .track_request(
                        &format!("model-{}", i),
                        100 + i,
                        50 + i,
                        0.001,
                        1000 + i * 100,
                    )
                    .unwrap();
            }
        });

        let handle2 = std::thread::spawn(move || {
            for i in 0..10 {
                db_clone2
                    .track_request(
                        &format!("model-{}", i + 10),
                        100 + i,
                        50 + i,
                        0.001,
                        1000 + i * 100,
                    )
                    .unwrap();
            }
        });

        handle1.join().unwrap();
        handle2.join().unwrap();

        let stats = db.get_usage_stats(7).unwrap();
        assert_eq!(stats.total_requests, 20);
    }
}