Skip to main content

limit_llm/
tracking.rs

1use crate::error::LlmError;
2use rusqlite::Connection;
3use std::sync::{Arc, Mutex};
4use std::time::{SystemTime, UNIX_EPOCH};
5
/// Aggregated usage figures for the LLM requests recorded inside a time window.
///
/// Returned by `TrackingDb::get_usage_stats`; the derived `Default` is the all-zero value
/// reported when no request falls inside the window.
#[derive(Clone, Debug, Default)]
pub struct UsageStats {
    /// Number of recorded requests in the window.
    pub total_requests: u64,
    /// Sum of `input_tokens + output_tokens`; cache tokens are reported separately below.
    pub total_tokens: u64,
    /// Sum of the per-request `cost` values.
    pub total_cost: f64,
    /// Mean of the per-request `duration_ms` values (0.0 for an empty window).
    pub avg_duration_ms: f64,
    /// Sum of the per-request `cache_read_tokens` values.
    pub total_cache_read_tokens: u64,
    /// Sum of the per-request `cache_write_tokens` values.
    pub total_cache_write_tokens: u64,
}
16
17/// SQLite database for tracking LLM usage
18pub struct TrackingDb {
19    conn: Arc<Mutex<Connection>>,
20}
21
22impl TrackingDb {
23    /// Initialize tracking database at ~/.limit/tracking.db
24    pub fn new() -> Result<Self, LlmError> {
25        let mut db_path = dirs::home_dir()
26            .ok_or_else(|| LlmError::PersistenceError("Home directory not found".to_string()))?;
27        db_path.push(".limit");
28
29        // Create .limit directory if it doesn't exist
30        std::fs::create_dir_all(&db_path).map_err(|e| {
31            LlmError::PersistenceError(format!("Failed to create .limit directory: {}", e))
32        })?;
33
34        db_path.push("tracking.db");
35
36        Self::open(&db_path)
37    }
38
39    /// Open tracking database at a specific path (for testing)
40    pub fn open(db_path: &std::path::Path) -> Result<Self, LlmError> {
41        let conn = Connection::open(db_path)
42            .map_err(|e| LlmError::PersistenceError(format!("Failed to open database: {}", e)))?;
43
44        // Set busy timeout for concurrent access (5 seconds)
45        conn.busy_timeout(std::time::Duration::from_secs(5))
46            .map_err(|e| {
47                LlmError::PersistenceError(format!("Failed to set busy timeout: {}", e))
48            })?;
49
50        let db = Self {
51            conn: Arc::new(Mutex::new(conn)),
52        };
53
54        db.init_tables()?;
55
56        Ok(db)
57    }
58
59    /// Create an in-memory database (for testing)
60    pub fn new_in_memory() -> Result<Self, LlmError> {
61        let conn = Connection::open_in_memory().map_err(|e| {
62            LlmError::PersistenceError(format!("Failed to open in-memory database: {}", e))
63        })?;
64
65        let db = Self {
66            conn: Arc::new(Mutex::new(conn)),
67        };
68
69        db.init_tables()?;
70
71        Ok(db)
72    }
73
74    /// Create tables if they don't exist
75    fn init_tables(&self) -> Result<(), LlmError> {
76        let conn = self
77            .conn
78            .lock()
79            .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
80
81        conn.execute(
82            "CREATE TABLE IF NOT EXISTS requests (
83                id INTEGER PRIMARY KEY AUTOINCREMENT,
84                timestamp INTEGER NOT NULL,
85                model TEXT NOT NULL,
86                input_tokens INTEGER NOT NULL,
87                output_tokens INTEGER NOT NULL,
88                cost REAL NOT NULL,
89                duration_ms INTEGER NOT NULL,
90                cache_read_tokens INTEGER NOT NULL DEFAULT 0,
91                cache_write_tokens INTEGER NOT NULL DEFAULT 0
92            )",
93            [],
94        )
95        .map_err(|e| LlmError::PersistenceError(format!("Failed to create table: {}", e)))?;
96
97        // Create index on timestamp for faster queries
98        conn.execute(
99            "CREATE INDEX IF NOT EXISTS idx_timestamp ON requests (timestamp)",
100            [],
101        )
102        .map_err(|e| LlmError::PersistenceError(format!("Failed to create index: {}", e)))?;
103
104        // Migration: Add cache columns to existing databases
105        let _ = conn.execute(
106            "ALTER TABLE requests ADD COLUMN cache_read_tokens INTEGER NOT NULL DEFAULT 0",
107            [],
108        );
109        let _ = conn.execute(
110            "ALTER TABLE requests ADD COLUMN cache_write_tokens INTEGER NOT NULL DEFAULT 0",
111            [],
112        );
113
114        Ok(())
115    }
116
117    /// Track a new LLM request
118    #[allow(clippy::too_many_arguments)]
119    pub fn track_request(
120        &self,
121        model: &str,
122        input_tokens: u64,
123        output_tokens: u64,
124        cache_read_tokens: u64,
125        cache_write_tokens: u64,
126        cost: f64,
127        duration_ms: u64,
128    ) -> Result<(), LlmError> {
129        let timestamp = SystemTime::now()
130            .duration_since(UNIX_EPOCH)
131            .map_err(|e| LlmError::PersistenceError(format!("Failed to get timestamp: {}", e)))?
132            .as_secs() as i64;
133
134        let conn = self
135            .conn
136            .lock()
137            .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
138
139        conn.execute(
140            "INSERT INTO requests (timestamp, model, input_tokens, output_tokens, cost, duration_ms, cache_read_tokens, cache_write_tokens)
141             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
142            rusqlite::params![timestamp, model, input_tokens as i64, output_tokens as i64, cost, duration_ms as i64, cache_read_tokens as i64, cache_write_tokens as i64],
143        )
144        .map_err(|e| LlmError::PersistenceError(format!("Failed to insert request: {}", e)))?;
145
146        Ok(())
147    }
148
149    /// Get usage statistics for the last N days
150    pub fn get_usage_stats(&self, days: u32) -> Result<UsageStats, LlmError> {
151        let cutoff_timestamp = SystemTime::now()
152            .duration_since(UNIX_EPOCH)
153            .map_err(|e| LlmError::PersistenceError(format!("Failed to get timestamp: {}", e)))?
154            .as_secs() as i64
155            - (days as i64 * 86400);
156
157        let conn = self
158            .conn
159            .lock()
160            .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
161
162        let mut stmt = conn
163            .prepare(
164                "SELECT 
165                COUNT(*) as count,
166                SUM(input_tokens + output_tokens) as total_tokens,
167                SUM(cost) as total_cost,
168                AVG(duration_ms) as avg_duration,
169                COALESCE(SUM(cache_read_tokens), 0) as total_cache_read,
170                COALESCE(SUM(cache_write_tokens), 0) as total_cache_write
171             FROM requests
172             WHERE timestamp >= ?1",
173            )
174            .map_err(|e| LlmError::PersistenceError(format!("Failed to prepare query: {}", e)))?;
175
176        let mut rows = stmt
177            .query([cutoff_timestamp])
178            .map_err(|e| LlmError::PersistenceError(format!("Failed to execute query: {}", e)))?;
179
180        if let Some(row) = rows
181            .next()
182            .map_err(|e| LlmError::PersistenceError(format!("Failed to read row: {}", e)))?
183        {
184            let total_requests: i64 = row.get(0).unwrap_or(0);
185            let total_tokens: i64 = row.get(1).unwrap_or(0);
186            let total_cost: f64 = row.get(2).unwrap_or(0.0);
187            let avg_duration_ms: f64 = row.get(3).unwrap_or(0.0);
188            let total_cache_read_tokens: i64 = row.get(4).unwrap_or(0);
189            let total_cache_write_tokens: i64 = row.get(5).unwrap_or(0);
190
191            Ok(UsageStats {
192                total_requests: total_requests.max(0) as u64,
193                total_tokens: total_tokens.max(0) as u64,
194                total_cost,
195                avg_duration_ms,
196                total_cache_read_tokens: total_cache_read_tokens.max(0) as u64,
197                total_cache_write_tokens: total_cache_write_tokens.max(0) as u64,
198            })
199        } else {
200            Ok(UsageStats::default())
201        }
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// File name of the database created inside each test's temporary directory.
    const TEST_DB_FILE: &str = "tracking.db";

    /// Open a file-backed `TrackingDb` inside a fresh temporary directory.
    ///
    /// The `TempDir` comes first in the tuple and must be bound to a named variable
    /// (e.g. `_dir`, not `_`) so it lives as long as the database. Destructuring
    /// `let (_dir, db)` drops `db` before `_dir`, so the connection is closed before the
    /// directory, the SQLite file, and any journal files are removed. Nothing is left
    /// behind in the system temp directory.
    fn create_test_db() -> Result<(TempDir, TrackingDb), LlmError> {
        let dir = TempDir::new().map_err(|e| {
            LlmError::PersistenceError(format!("Failed to create temp dir: {}", e))
        })?;
        let db = TrackingDb::open(&dir.path().join(TEST_DB_FILE))?;
        Ok((dir, db))
    }

    #[test]
    fn test_create_tables() {
        let (_dir, db) = create_test_db().unwrap();

        // Verify table exists by querying it
        let result: rusqlite::Result<i64> =
            db.conn
                .lock()
                .unwrap()
                .query_row("SELECT COUNT(*) FROM requests", [], |row| row.get(0));
        assert!(result.is_ok());
    }

    #[test]
    fn test_track_request() {
        let (_dir, db) = create_test_db().unwrap();

        db.track_request("claude-3-5-sonnet", 100, 50, 500, 100, 0.001, 1500)
            .unwrap();
        db.track_request("claude-3-opus", 200, 100, 0, 0, 0.005, 3000)
            .unwrap();

        let count: i64 = db
            .conn
            .lock()
            .unwrap()
            .query_row("SELECT COUNT(*) FROM requests", [], |row| row.get(0))
            .unwrap();
        assert_eq!(count, 2);
    }

    #[test]
    fn test_get_usage_stats() {
        let (_dir, db) = create_test_db().unwrap();

        db.track_request("model-1", 100, 50, 500, 100, 0.001, 1500)
            .unwrap();
        db.track_request("model-2", 200, 100, 0, 0, 0.005, 3000)
            .unwrap();
        db.track_request("model-3", 150, 75, 200, 50, 0.002, 2000)
            .unwrap();

        let stats = db.get_usage_stats(7).unwrap();

        assert_eq!(stats.total_requests, 3);
        assert_eq!(stats.total_tokens, 675);
        assert!((stats.total_cost - 0.008).abs() < 0.0001);
        assert!((stats.avg_duration_ms - 2166.6666666666665).abs() < 0.0001);
        assert_eq!(stats.total_cache_read_tokens, 700);
        assert_eq!(stats.total_cache_write_tokens, 150);
    }

    #[test]
    fn test_empty_usage_stats() {
        let (_dir, db) = create_test_db().unwrap();

        let stats = db.get_usage_stats(7).unwrap();

        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.total_tokens, 0);
        assert_eq!(stats.total_cost, 0.0);
        assert_eq!(stats.avg_duration_ms, 0.0);
        assert_eq!(stats.total_cache_read_tokens, 0);
        assert_eq!(stats.total_cache_write_tokens, 0);
    }

    #[test]
    fn test_in_memory_database() {
        let db = TrackingDb::new_in_memory().unwrap();

        db.track_request("model-1", 10, 5, 1, 2, 0.5, 100).unwrap();

        let stats = db.get_usage_stats(1).unwrap();
        assert_eq!(stats.total_requests, 1);
        assert_eq!(stats.total_tokens, 15);
        assert_eq!(stats.total_cache_read_tokens, 1);
        assert_eq!(stats.total_cache_write_tokens, 2);
    }

    #[test]
    fn test_reopen_existing_database() {
        let (dir, db) = create_test_db().unwrap();
        db.track_request("model-1", 10, 5, 0, 0, 0.001, 100).unwrap();
        drop(db);

        // Schema setup must be idempotent on an up-to-date database and keep existing rows.
        let reopened = TrackingDb::open(&dir.path().join(TEST_DB_FILE)).unwrap();
        assert_eq!(reopened.get_usage_stats(7).unwrap().total_requests, 1);
    }

    #[test]
    fn test_migrates_legacy_schema() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("legacy.db");

        // Simulate a database created before the cache columns existed.
        {
            let conn = Connection::open(&path).unwrap();
            conn.execute(
                "CREATE TABLE requests (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp INTEGER NOT NULL,
                    model TEXT NOT NULL,
                    input_tokens INTEGER NOT NULL,
                    output_tokens INTEGER NOT NULL,
                    cost REAL NOT NULL,
                    duration_ms INTEGER NOT NULL
                )",
                [],
            )
            .unwrap();
        }

        let db = TrackingDb::open(&path).unwrap();
        db.track_request("model-1", 1, 2, 3, 4, 0.5, 10).unwrap();

        let stats = db.get_usage_stats(1).unwrap();
        assert_eq!(stats.total_requests, 1);
        assert_eq!(stats.total_cache_read_tokens, 3);
        assert_eq!(stats.total_cache_write_tokens, 4);
    }

    #[test]
    fn test_concurrent_access() {
        let (_dir, db) = create_test_db().unwrap();

        let db_clone1 = TrackingDb {
            conn: Arc::clone(&db.conn),
        };
        let db_clone2 = TrackingDb {
            conn: Arc::clone(&db.conn),
        };

        let handle1 = std::thread::spawn(move || {
            for i in 0..10 {
                db_clone1
                    .track_request(
                        &format!("model-{}", i),
                        100 + i,
                        50 + i,
                        0,
                        0,
                        0.001,
                        1000 + i * 100,
                    )
                    .unwrap();
            }
        });

        let handle2 = std::thread::spawn(move || {
            for i in 0..10 {
                db_clone2
                    .track_request(
                        &format!("model-{}", i + 10),
                        100 + i,
                        50 + i,
                        0,
                        0,
                        0.001,
                        1000 + i * 100,
                    )
                    .unwrap();
            }
        });

        handle1.join().unwrap();
        handle2.join().unwrap();

        let stats = db.get_usage_stats(7).unwrap();
        assert_eq!(stats.total_requests, 20);
    }
}