Skip to main content

limit_llm/
tracking.rs

1use crate::error::LlmError;
2use rusqlite::Connection;
3use std::sync::{Arc, Mutex};
4use std::time::{SystemTime, UNIX_EPOCH};
5
/// Usage statistics for LLM requests
///
/// Holds the aggregates that `TrackingDb::get_usage_stats` computes over the
/// rows of the `requests` table falling inside the queried time window.
/// The derived `Default` yields all-zero statistics.
#[derive(Debug, Clone, Default)]
pub struct UsageStats {
    /// Number of tracked requests in the window (`COUNT(*)`).
    pub total_requests: u64,
    /// Sum of `input_tokens + output_tokens` across those requests.
    pub total_tokens: u64,
    /// Sum of the recorded per-request `cost` values.
    pub total_cost: f64,
    /// Mean of the recorded `duration_ms` values, in milliseconds.
    pub avg_duration_ms: f64,
    /// Sum of the recorded `cache_read_tokens` values.
    pub total_cache_read_tokens: u64,
    /// Sum of the recorded `cache_write_tokens` values.
    pub total_cache_write_tokens: u64,
}
16
/// SQLite database for tracking LLM usage
///
/// Wraps a single `rusqlite::Connection` behind `Arc<Mutex<..>>`; every
/// method locks the mutex before touching the connection.
pub struct TrackingDb {
    /// Shared, mutex-guarded SQLite connection.
    conn: Arc<Mutex<Connection>>,
}
21
22impl TrackingDb {
23    /// Initialize tracking database at ~/.limit/tracking.db
24    pub fn new() -> Result<Self, LlmError> {
25        let mut db_path = dirs::home_dir()
26            .ok_or_else(|| LlmError::PersistenceError("Home directory not found".to_string()))?;
27        db_path.push(".limit");
28
29        // Create .limit directory if it doesn't exist
30        std::fs::create_dir_all(&db_path).map_err(|e| {
31            LlmError::PersistenceError(format!("Failed to create .limit directory: {}", e))
32        })?;
33
34        db_path.push("tracking.db");
35
36        Self::open(&db_path)
37    }
38
39    /// Open tracking database at a specific path (for testing)
40    pub fn open(db_path: &std::path::Path) -> Result<Self, LlmError> {
41        let conn = Connection::open(db_path)
42            .map_err(|e| LlmError::PersistenceError(format!("Failed to open database: {}", e)))?;
43
44        // Set busy timeout for concurrent access (5 seconds)
45        conn.busy_timeout(std::time::Duration::from_secs(5))
46            .map_err(|e| {
47                LlmError::PersistenceError(format!("Failed to set busy timeout: {}", e))
48            })?;
49
50        let db = Self {
51            conn: Arc::new(Mutex::new(conn)),
52        };
53
54        db.init_tables()?;
55
56        Ok(db)
57    }
58
59    /// Create an in-memory database (for testing)
60    pub fn new_in_memory() -> Result<Self, LlmError> {
61        let conn = Connection::open_in_memory().map_err(|e| {
62            LlmError::PersistenceError(format!("Failed to open in-memory database: {}", e))
63        })?;
64
65        let db = Self {
66            conn: Arc::new(Mutex::new(conn)),
67        };
68
69        db.init_tables()?;
70
71        Ok(db)
72    }
73
74    /// Create tables if they don't exist
75    fn init_tables(&self) -> Result<(), LlmError> {
76        let conn = self
77            .conn
78            .lock()
79            .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
80
81        conn.execute(
82            "CREATE TABLE IF NOT EXISTS requests (
83                id INTEGER PRIMARY KEY AUTOINCREMENT,
84                timestamp INTEGER NOT NULL,
85                model TEXT NOT NULL,
86                input_tokens INTEGER NOT NULL,
87                output_tokens INTEGER NOT NULL,
88                cost REAL NOT NULL,
89                duration_ms INTEGER NOT NULL,
90                cache_read_tokens INTEGER NOT NULL DEFAULT 0,
91                cache_write_tokens INTEGER NOT NULL DEFAULT 0
92            )",
93            [],
94        )
95        .map_err(|e| LlmError::PersistenceError(format!("Failed to create table: {}", e)))?;
96
97        // Create index on timestamp for faster queries
98        conn.execute(
99            "CREATE INDEX IF NOT EXISTS idx_timestamp ON requests (timestamp)",
100            [],
101        )
102        .map_err(|e| LlmError::PersistenceError(format!("Failed to create index: {}", e)))?;
103
104        // Migration: Add cache columns to existing databases
105        let _ = conn.execute(
106            "ALTER TABLE requests ADD COLUMN cache_read_tokens INTEGER NOT NULL DEFAULT 0",
107            [],
108        );
109        let _ = conn.execute(
110            "ALTER TABLE requests ADD COLUMN cache_write_tokens INTEGER NOT NULL DEFAULT 0",
111            [],
112        );
113
114        Ok(())
115    }
116
117    /// Track a new LLM request
118    pub fn track_request(
119        &self,
120        model: &str,
121        input_tokens: u64,
122        output_tokens: u64,
123        cache_read_tokens: u64,
124        cache_write_tokens: u64,
125        cost: f64,
126        duration_ms: u64,
127    ) -> Result<(), LlmError> {
128        let timestamp = SystemTime::now()
129            .duration_since(UNIX_EPOCH)
130            .map_err(|e| LlmError::PersistenceError(format!("Failed to get timestamp: {}", e)))?
131            .as_secs() as i64;
132
133        let conn = self
134            .conn
135            .lock()
136            .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
137
138        conn.execute(
139            "INSERT INTO requests (timestamp, model, input_tokens, output_tokens, cost, duration_ms, cache_read_tokens, cache_write_tokens)
140             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
141            rusqlite::params![timestamp, model, input_tokens as i64, output_tokens as i64, cost, duration_ms as i64, cache_read_tokens as i64, cache_write_tokens as i64],
142        )
143        .map_err(|e| LlmError::PersistenceError(format!("Failed to insert request: {}", e)))?;
144
145        Ok(())
146    }
147
148    /// Get usage statistics for the last N days
149    pub fn get_usage_stats(&self, days: u32) -> Result<UsageStats, LlmError> {
150        let cutoff_timestamp = SystemTime::now()
151            .duration_since(UNIX_EPOCH)
152            .map_err(|e| LlmError::PersistenceError(format!("Failed to get timestamp: {}", e)))?
153            .as_secs() as i64
154            - (days as i64 * 86400);
155
156        let conn = self
157            .conn
158            .lock()
159            .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
160
161        let mut stmt = conn
162            .prepare(
163                "SELECT 
164                COUNT(*) as count,
165                SUM(input_tokens + output_tokens) as total_tokens,
166                SUM(cost) as total_cost,
167                AVG(duration_ms) as avg_duration,
168                COALESCE(SUM(cache_read_tokens), 0) as total_cache_read,
169                COALESCE(SUM(cache_write_tokens), 0) as total_cache_write
170             FROM requests
171             WHERE timestamp >= ?1",
172            )
173            .map_err(|e| LlmError::PersistenceError(format!("Failed to prepare query: {}", e)))?;
174
175        let mut rows = stmt
176            .query([cutoff_timestamp])
177            .map_err(|e| LlmError::PersistenceError(format!("Failed to execute query: {}", e)))?;
178
179        if let Some(row) = rows
180            .next()
181            .map_err(|e| LlmError::PersistenceError(format!("Failed to read row: {}", e)))?
182        {
183            let total_requests: i64 = row.get(0).unwrap_or(0);
184            let total_tokens: i64 = row.get(1).unwrap_or(0);
185            let total_cost: f64 = row.get(2).unwrap_or(0.0);
186            let avg_duration_ms: f64 = row.get(3).unwrap_or(0.0);
187            let total_cache_read_tokens: i64 = row.get(4).unwrap_or(0);
188            let total_cache_write_tokens: i64 = row.get(5).unwrap_or(0);
189
190            Ok(UsageStats {
191                total_requests: total_requests.max(0) as u64,
192                total_tokens: total_tokens.max(0) as u64,
193                total_cost,
194                avg_duration_ms,
195                total_cache_read_tokens: total_cache_read_tokens.max(0) as u64,
196                total_cache_write_tokens: total_cache_write_tokens.max(0) as u64,
197            })
198        } else {
199            Ok(UsageStats::default())
200        }
201    }
202}
203
204#[cfg(test)]
205mod tests {
206    use super::*;
207    use tempfile::NamedTempFile;
208
209    fn create_test_db() -> Result<TrackingDb, LlmError> {
210        let temp_file = NamedTempFile::new().map_err(|e| {
211            LlmError::PersistenceError(format!("Failed to create temp file: {}", e))
212        })?;
213
214        let path = temp_file.path().to_path_buf();
215        // Drop temp file to free the path for SQLite
216        drop(temp_file);
217
218        let conn = Connection::open(&path).map_err(|e| {
219            LlmError::PersistenceError(format!("Failed to open test database: {}", e))
220        })?;
221
222        conn.busy_timeout(std::time::Duration::from_secs(5))
223            .map_err(|e| {
224                LlmError::PersistenceError(format!("Failed to set busy timeout: {}", e))
225            })?;
226
227        let db = TrackingDb {
228            conn: Arc::new(Mutex::new(conn)),
229        };
230
231        db.init_tables()?;
232
233        Ok(db)
234    }
235
236    #[test]
237    fn test_create_tables() {
238        let db = create_test_db().unwrap();
239
240        // Verify table exists by querying it
241        let result: rusqlite::Result<i64> =
242            db.conn
243                .lock()
244                .unwrap()
245                .query_row("SELECT COUNT(*) FROM requests", [], |row| row.get(0));
246        assert!(result.is_ok());
247    }
248
249    #[test]
250    fn test_track_request() {
251        let db = create_test_db().unwrap();
252
253        db.track_request("claude-3-5-sonnet", 100, 50, 500, 100, 0.001, 1500)
254            .unwrap();
255        db.track_request("claude-3-opus", 200, 100, 0, 0, 0.005, 3000)
256            .unwrap();
257
258        let count: i64 = db
259            .conn
260            .lock()
261            .unwrap()
262            .query_row("SELECT COUNT(*) FROM requests", [], |row| row.get(0))
263            .unwrap();
264        assert_eq!(count, 2);
265    }
266
267    #[test]
268    fn test_get_usage_stats() {
269        let db = create_test_db().unwrap();
270
271        db.track_request("model-1", 100, 50, 500, 100, 0.001, 1500)
272            .unwrap();
273        db.track_request("model-2", 200, 100, 0, 0, 0.005, 3000)
274            .unwrap();
275        db.track_request("model-3", 150, 75, 200, 50, 0.002, 2000)
276            .unwrap();
277
278        let stats = db.get_usage_stats(7).unwrap();
279
280        assert_eq!(stats.total_requests, 3);
281        assert_eq!(stats.total_tokens, 675);
282        assert!((stats.total_cost - 0.008).abs() < 0.0001);
283        assert!((stats.avg_duration_ms - 2166.6666666666665).abs() < 0.0001);
284        assert_eq!(stats.total_cache_read_tokens, 700);
285        assert_eq!(stats.total_cache_write_tokens, 150);
286    }
287
288    #[test]
289    fn test_empty_usage_stats() {
290        let db = create_test_db().unwrap();
291
292        let stats = db.get_usage_stats(7).unwrap();
293
294        assert_eq!(stats.total_requests, 0);
295        assert_eq!(stats.total_tokens, 0);
296        assert_eq!(stats.total_cost, 0.0);
297        assert_eq!(stats.avg_duration_ms, 0.0);
298        assert_eq!(stats.total_cache_read_tokens, 0);
299        assert_eq!(stats.total_cache_write_tokens, 0);
300    }
301
302    #[test]
303    fn test_concurrent_access() {
304        let db = create_test_db().unwrap();
305
306        let db_clone1 = TrackingDb {
307            conn: Arc::clone(&db.conn),
308        };
309        let db_clone2 = TrackingDb {
310            conn: Arc::clone(&db.conn),
311        };
312
313        let handle1 = std::thread::spawn(move || {
314            for i in 0..10 {
315                db_clone1
316                    .track_request(
317                        &format!("model-{}", i),
318                        100 + i,
319                        50 + i,
320                        0,
321                        0,
322                        0.001,
323                        1000 + i * 100,
324                    )
325                    .unwrap();
326            }
327        });
328
329        let handle2 = std::thread::spawn(move || {
330            for i in 0..10 {
331                db_clone2
332                    .track_request(
333                        &format!("model-{}", i + 10),
334                        100 + i,
335                        50 + i,
336                        0,
337                        0,
338                        0.001,
339                        1000 + i * 100,
340                    )
341                    .unwrap();
342            }
343        });
344
345        handle1.join().unwrap();
346        handle2.join().unwrap();
347
348        let stats = db.get_usage_stats(7).unwrap();
349        assert_eq!(stats.total_requests, 20);
350    }
351}