//! `limit_llm/tracking.rs` — SQLite-backed tracking of LLM request usage.

1use crate::error::LlmError;
2use rusqlite::Connection;
3use std::sync::{Arc, Mutex};
4use std::time::{SystemTime, UNIX_EPOCH};
5
/// Aggregated usage statistics for tracked LLM requests.
///
/// Produced by `TrackingDb::get_usage_stats`; every field is an aggregate over
/// the requests that fall inside the queried time window. `Default` yields the
/// all-zero value that is also reported when no requests match.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct UsageStats {
    /// Number of requests in the window.
    pub total_requests: u64,
    /// Sum of input and output tokens over all requests in the window.
    pub total_tokens: u64,
    /// Sum of the recorded per-request cost values.
    pub total_cost: f64,
    /// Mean request duration in milliseconds (`0.0` when there are no requests).
    pub avg_duration_ms: f64,
}
14
15/// SQLite database for tracking LLM usage
16pub struct TrackingDb {
17    conn: Arc<Mutex<Connection>>,
18}
19
20impl TrackingDb {
21    /// Initialize tracking database at ~/.limit/tracking.db
22    pub fn new() -> Result<Self, LlmError> {
23        let mut db_path = dirs::home_dir()
24            .ok_or_else(|| LlmError::PersistenceError("Home directory not found".to_string()))?;
25        db_path.push(".limit");
26
27        // Create .limit directory if it doesn't exist
28        std::fs::create_dir_all(&db_path).map_err(|e| {
29            LlmError::PersistenceError(format!("Failed to create .limit directory: {}", e))
30        })?;
31
32        db_path.push("tracking.db");
33
34        Self::open(&db_path)
35    }
36
37    /// Open tracking database at a specific path (for testing)
38    pub fn open(db_path: &std::path::Path) -> Result<Self, LlmError> {
39        let conn = Connection::open(db_path)
40            .map_err(|e| LlmError::PersistenceError(format!("Failed to open database: {}", e)))?;
41
42        // Set busy timeout for concurrent access (5 seconds)
43        conn.busy_timeout(std::time::Duration::from_secs(5))
44            .map_err(|e| {
45                LlmError::PersistenceError(format!("Failed to set busy timeout: {}", e))
46            })?;
47
48        let db = Self {
49            conn: Arc::new(Mutex::new(conn)),
50        };
51
52        db.init_tables()?;
53
54        Ok(db)
55    }
56
57    /// Create an in-memory database (for testing)
58    pub fn new_in_memory() -> Result<Self, LlmError> {
59        let conn = Connection::open_in_memory().map_err(|e| {
60            LlmError::PersistenceError(format!("Failed to open in-memory database: {}", e))
61        })?;
62
63        let db = Self {
64            conn: Arc::new(Mutex::new(conn)),
65        };
66
67        db.init_tables()?;
68
69        Ok(db)
70    }
71
72    /// Create tables if they don't exist
73    fn init_tables(&self) -> Result<(), LlmError> {
74        let conn = self
75            .conn
76            .lock()
77            .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
78
79        conn.execute(
80            "CREATE TABLE IF NOT EXISTS requests (
81                id INTEGER PRIMARY KEY AUTOINCREMENT,
82                timestamp INTEGER NOT NULL,
83                model TEXT NOT NULL,
84                input_tokens INTEGER NOT NULL,
85                output_tokens INTEGER NOT NULL,
86                cost REAL NOT NULL,
87                duration_ms INTEGER NOT NULL
88            )",
89            [],
90        )
91        .map_err(|e| LlmError::PersistenceError(format!("Failed to create table: {}", e)))?;
92
93        // Create index on timestamp for faster queries
94        conn.execute(
95            "CREATE INDEX IF NOT EXISTS idx_timestamp ON requests (timestamp)",
96            [],
97        )
98        .map_err(|e| LlmError::PersistenceError(format!("Failed to create index: {}", e)))?;
99
100        Ok(())
101    }
102
103    /// Track a new LLM request
104    pub fn track_request(
105        &self,
106        model: &str,
107        input_tokens: u64,
108        output_tokens: u64,
109        cost: f64,
110        duration_ms: u64,
111    ) -> Result<(), LlmError> {
112        let timestamp = SystemTime::now()
113            .duration_since(UNIX_EPOCH)
114            .map_err(|e| LlmError::PersistenceError(format!("Failed to get timestamp: {}", e)))?
115            .as_secs() as i64;
116
117        let conn = self
118            .conn
119            .lock()
120            .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
121
122        conn.execute(
123            "INSERT INTO requests (timestamp, model, input_tokens, output_tokens, cost, duration_ms)
124             VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
125            rusqlite::params![timestamp, model, input_tokens as i64, output_tokens as i64, cost, duration_ms as i64],
126        )
127        .map_err(|e| LlmError::PersistenceError(format!("Failed to insert request: {}", e)))?;
128
129        Ok(())
130    }
131
132    /// Get usage statistics for the last N days
133    pub fn get_usage_stats(&self, days: u32) -> Result<UsageStats, LlmError> {
134        let cutoff_timestamp = SystemTime::now()
135            .duration_since(UNIX_EPOCH)
136            .map_err(|e| LlmError::PersistenceError(format!("Failed to get timestamp: {}", e)))?
137            .as_secs() as i64
138            - (days as i64 * 86400);
139
140        let conn = self
141            .conn
142            .lock()
143            .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
144
145        let mut stmt = conn
146            .prepare(
147                "SELECT 
148                COUNT(*) as count,
149                SUM(input_tokens + output_tokens) as total_tokens,
150                SUM(cost) as total_cost,
151                AVG(duration_ms) as avg_duration
152             FROM requests
153             WHERE timestamp >= ?1",
154            )
155            .map_err(|e| LlmError::PersistenceError(format!("Failed to prepare query: {}", e)))?;
156
157        let mut rows = stmt
158            .query([cutoff_timestamp])
159            .map_err(|e| LlmError::PersistenceError(format!("Failed to execute query: {}", e)))?;
160
161        if let Some(row) = rows
162            .next()
163            .map_err(|e| LlmError::PersistenceError(format!("Failed to read row: {}", e)))?
164        {
165            let total_requests: i64 = row.get(0).unwrap_or(0);
166            let total_tokens: i64 = row.get(1).unwrap_or(0);
167            let total_cost: f64 = row.get(2).unwrap_or(0.0);
168            let avg_duration_ms: f64 = row.get(3).unwrap_or(0.0);
169
170            Ok(UsageStats {
171                total_requests: total_requests.max(0) as u64,
172                total_tokens: total_tokens.max(0) as u64,
173                total_cost,
174                avg_duration_ms,
175            })
176        } else {
177            Ok(UsageStats::default())
178        }
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{NamedTempFile, TempPath};

    /// File-backed test database together with the guard for its temp file.
    ///
    /// Struct fields are dropped in declaration order, so the connection in
    /// `db` is closed before `_file` removes the backing file from disk.
    struct TestDb {
        db: TrackingDb,
        _file: TempPath,
    }

    /// Opens a fresh temporary database through `TrackingDb::open`.
    ///
    /// The temp file is kept (as a path guard) for the lifetime of the
    /// returned fixture and deleted when it is dropped, so test runs do not
    /// leave database files behind.
    fn create_test_db() -> Result<TestDb, LlmError> {
        let file = NamedTempFile::new()
            .map_err(|e| LlmError::PersistenceError(format!("Failed to create temp file: {}", e)))?
            .into_temp_path();

        let db = TrackingDb::open(&file)?;

        Ok(TestDb { db, _file: file })
    }

    /// Spawns a thread that records ten requests over the shared connection.
    ///
    /// `model_offset` shifts the generated model names so that each writer
    /// produces distinct names.
    fn spawn_writer(
        conn: Arc<Mutex<Connection>>,
        model_offset: u64,
    ) -> std::thread::JoinHandle<()> {
        std::thread::spawn(move || {
            let db = TrackingDb { conn };
            for i in 0..10u64 {
                db.track_request(
                    &format!("model-{}", i + model_offset),
                    100 + i,
                    50 + i,
                    0.001,
                    1000 + i * 100,
                )
                .unwrap();
            }
        })
    }

    #[test]
    fn test_create_tables() {
        let fixture = create_test_db().unwrap();
        let db = &fixture.db;

        // The table exists if it can be queried.
        let result: rusqlite::Result<i64> =
            db.conn
                .lock()
                .unwrap()
                .query_row("SELECT COUNT(*) FROM requests", [], |row| row.get(0));
        assert!(result.is_ok());
    }

    #[test]
    fn test_track_request() {
        let fixture = create_test_db().unwrap();
        let db = &fixture.db;

        db.track_request("claude-3-5-sonnet", 100, 50, 0.001, 1500)
            .unwrap();
        db.track_request("claude-3-opus", 200, 100, 0.005, 3000)
            .unwrap();

        let count: i64 = db
            .conn
            .lock()
            .unwrap()
            .query_row("SELECT COUNT(*) FROM requests", [], |row| row.get(0))
            .unwrap();
        assert_eq!(count, 2);
    }

    #[test]
    fn test_get_usage_stats() {
        let fixture = create_test_db().unwrap();
        let db = &fixture.db;

        db.track_request("model-1", 100, 50, 0.001, 1500).unwrap();
        db.track_request("model-2", 200, 100, 0.005, 3000).unwrap();
        db.track_request("model-3", 150, 75, 0.002, 2000).unwrap();

        let stats = db.get_usage_stats(7).unwrap();

        assert_eq!(stats.total_requests, 3);
        assert_eq!(stats.total_tokens, 675); // (100+50) + (200+100) + (150+75)
        assert!((stats.total_cost - 0.008).abs() < 0.0001);
        assert!((stats.avg_duration_ms - 2166.6666666666665).abs() < 0.0001);
    }

    #[test]
    fn test_empty_usage_stats() {
        let fixture = create_test_db().unwrap();
        let db = &fixture.db;

        let stats = db.get_usage_stats(7).unwrap();

        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.total_tokens, 0);
        assert_eq!(stats.total_cost, 0.0);
        assert_eq!(stats.avg_duration_ms, 0.0);
    }

    #[test]
    fn test_concurrent_access() {
        let fixture = create_test_db().unwrap();
        let db = &fixture.db;

        // Two writers share the same mutex-guarded connection.
        let first = spawn_writer(Arc::clone(&db.conn), 0);
        let second = spawn_writer(Arc::clone(&db.conn), 10);

        first.join().unwrap();
        second.join().unwrap();

        let stats = db.get_usage_stats(7).unwrap();
        assert_eq!(stats.total_requests, 20);
    }
}