Skip to main content

limit_llm/
tracking.rs

1use crate::error::LlmError;
2use rusqlite::Connection;
3use std::sync::{Arc, Mutex};
4use std::time::{SystemTime, UNIX_EPOCH};
5
/// Usage statistics for LLM requests
///
/// Aggregate figures over a set of tracked requests, as produced by
/// `TrackingDb::get_usage_stats`. The derived `Default` yields all-zero statistics.
#[derive(Debug, Clone, Default)]
pub struct UsageStats {
    /// Number of requests included in the aggregate.
    pub total_requests: u64,
    /// Sum of input and output tokens across the included requests.
    pub total_tokens: u64,
    /// Sum of the recorded `cost` values across the included requests.
    pub total_cost: f64,
    /// Mean of the recorded request durations, in milliseconds.
    pub avg_duration_ms: f64,
}
14
/// SQLite database for tracking LLM usage
///
/// Wraps a single `rusqlite::Connection` behind `Arc<Mutex<..>>`; every operation
/// locks the mutex before touching the connection.
pub struct TrackingDb {
    /// Shared connection handle; the mutex must be locked for each database operation.
    conn: Arc<Mutex<Connection>>,
}
19
20impl TrackingDb {
21    /// Initialize tracking database at ~/.limit/tracking.db
22    pub fn new() -> Result<Self, LlmError> {
23        let mut db_path = dirs::home_dir()
24            .ok_or_else(|| LlmError::PersistenceError("Home directory not found".to_string()))?;
25        db_path.push(".limit");
26
27        // Create .limit directory if it doesn't exist
28        std::fs::create_dir_all(&db_path).map_err(|e| {
29            LlmError::PersistenceError(format!("Failed to create .limit directory: {}", e))
30        })?;
31
32        db_path.push("tracking.db");
33
34        let conn = Connection::open(&db_path)
35            .map_err(|e| LlmError::PersistenceError(format!("Failed to open database: {}", e)))?;
36
37        // Set busy timeout for concurrent access (5 seconds)
38        conn.busy_timeout(std::time::Duration::from_secs(5))
39            .map_err(|e| {
40                LlmError::PersistenceError(format!("Failed to set busy timeout: {}", e))
41            })?;
42
43        let db = Self {
44            conn: Arc::new(Mutex::new(conn)),
45        };
46
47        db.init_tables()?;
48
49        Ok(db)
50    }
51
52    /// Create tables if they don't exist
53    fn init_tables(&self) -> Result<(), LlmError> {
54        let conn = self
55            .conn
56            .lock()
57            .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
58
59        conn.execute(
60            "CREATE TABLE IF NOT EXISTS requests (
61                id INTEGER PRIMARY KEY AUTOINCREMENT,
62                timestamp INTEGER NOT NULL,
63                model TEXT NOT NULL,
64                input_tokens INTEGER NOT NULL,
65                output_tokens INTEGER NOT NULL,
66                cost REAL NOT NULL,
67                duration_ms INTEGER NOT NULL
68            )",
69            [],
70        )
71        .map_err(|e| LlmError::PersistenceError(format!("Failed to create table: {}", e)))?;
72
73        // Create index on timestamp for faster queries
74        conn.execute(
75            "CREATE INDEX IF NOT EXISTS idx_timestamp ON requests (timestamp)",
76            [],
77        )
78        .map_err(|e| LlmError::PersistenceError(format!("Failed to create index: {}", e)))?;
79
80        Ok(())
81    }
82
83    /// Track a new LLM request
84    pub fn track_request(
85        &self,
86        model: &str,
87        input_tokens: u64,
88        output_tokens: u64,
89        cost: f64,
90        duration_ms: u64,
91    ) -> Result<(), LlmError> {
92        let timestamp = SystemTime::now()
93            .duration_since(UNIX_EPOCH)
94            .map_err(|e| LlmError::PersistenceError(format!("Failed to get timestamp: {}", e)))?
95            .as_secs() as i64;
96
97        let conn = self
98            .conn
99            .lock()
100            .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
101
102        conn.execute(
103            "INSERT INTO requests (timestamp, model, input_tokens, output_tokens, cost, duration_ms)
104             VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
105            rusqlite::params![timestamp, model, input_tokens as i64, output_tokens as i64, cost, duration_ms as i64],
106        )
107        .map_err(|e| LlmError::PersistenceError(format!("Failed to insert request: {}", e)))?;
108
109        Ok(())
110    }
111
112    /// Get usage statistics for the last N days
113    pub fn get_usage_stats(&self, days: u32) -> Result<UsageStats, LlmError> {
114        let cutoff_timestamp = SystemTime::now()
115            .duration_since(UNIX_EPOCH)
116            .map_err(|e| LlmError::PersistenceError(format!("Failed to get timestamp: {}", e)))?
117            .as_secs() as i64
118            - (days as i64 * 86400);
119
120        let conn = self
121            .conn
122            .lock()
123            .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
124
125        let mut stmt = conn
126            .prepare(
127                "SELECT 
128                COUNT(*) as count,
129                SUM(input_tokens + output_tokens) as total_tokens,
130                SUM(cost) as total_cost,
131                AVG(duration_ms) as avg_duration
132             FROM requests
133             WHERE timestamp >= ?1",
134            )
135            .map_err(|e| LlmError::PersistenceError(format!("Failed to prepare query: {}", e)))?;
136
137        let mut rows = stmt
138            .query([cutoff_timestamp])
139            .map_err(|e| LlmError::PersistenceError(format!("Failed to execute query: {}", e)))?;
140
141        if let Some(row) = rows
142            .next()
143            .map_err(|e| LlmError::PersistenceError(format!("Failed to read row: {}", e)))?
144        {
145            let total_requests: i64 = row.get(0).unwrap_or(0);
146            let total_tokens: i64 = row.get(1).unwrap_or(0);
147            let total_cost: f64 = row.get(2).unwrap_or(0.0);
148            let avg_duration_ms: f64 = row.get(3).unwrap_or(0.0);
149
150            Ok(UsageStats {
151                total_requests: total_requests.max(0) as u64,
152                total_tokens: total_tokens.max(0) as u64,
153                total_cost,
154                avg_duration_ms,
155            })
156        } else {
157            Ok(UsageStats::default())
158        }
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TrackingDb` backed by a private in-memory SQLite database.
    ///
    /// An in-memory database leaves no file behind after the test and avoids
    /// re-opening a temporary path after its file has already been deleted.
    fn create_test_db() -> Result<TrackingDb, LlmError> {
        let conn = Connection::open_in_memory().map_err(|e| {
            LlmError::PersistenceError(format!("Failed to open test database: {}", e))
        })?;

        conn.busy_timeout(std::time::Duration::from_secs(5))
            .map_err(|e| {
                LlmError::PersistenceError(format!("Failed to set busy timeout: {}", e))
            })?;

        let db = TrackingDb {
            conn: Arc::new(Mutex::new(conn)),
        };

        db.init_tables()?;

        Ok(db)
    }

    /// Counts the rows currently stored in the `requests` table.
    fn count_requests(db: &TrackingDb) -> rusqlite::Result<i64> {
        db.conn
            .lock()
            .unwrap()
            .query_row("SELECT COUNT(*) FROM requests", [], |row| row.get(0))
    }

    #[test]
    fn test_create_tables() {
        let db = create_test_db().unwrap();

        // Verify table exists by querying it
        assert!(count_requests(&db).is_ok());
    }

    #[test]
    fn test_track_request() {
        let db = create_test_db().unwrap();

        db.track_request("claude-3-5-sonnet", 100, 50, 0.001, 1500)
            .unwrap();
        db.track_request("claude-3-opus", 200, 100, 0.005, 3000)
            .unwrap();

        assert_eq!(count_requests(&db).unwrap(), 2);
    }

    #[test]
    fn test_get_usage_stats() {
        let db = create_test_db().unwrap();

        db.track_request("model-1", 100, 50, 0.001, 1500).unwrap();
        db.track_request("model-2", 200, 100, 0.005, 3000).unwrap();
        db.track_request("model-3", 150, 75, 0.002, 2000).unwrap();

        let stats = db.get_usage_stats(7).unwrap();

        assert_eq!(stats.total_requests, 3);
        assert_eq!(stats.total_tokens, 675); // (100+50) + (200+100) + (150+75)
        assert!((stats.total_cost - 0.008).abs() < 0.0001);
        assert!((stats.avg_duration_ms - 2166.6666666666665).abs() < 0.0001);
    }

    #[test]
    fn test_empty_usage_stats() {
        let db = create_test_db().unwrap();

        let stats = db.get_usage_stats(7).unwrap();

        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.total_tokens, 0);
        assert_eq!(stats.total_cost, 0.0);
        assert_eq!(stats.avg_duration_ms, 0.0);
    }

    #[test]
    fn test_concurrent_access() {
        let db = create_test_db().unwrap();

        // Two writers share the same connection handle and insert ten rows each.
        let handles: Vec<_> = (0..2u64)
            .map(|worker| {
                let worker_db = TrackingDb {
                    conn: Arc::clone(&db.conn),
                };

                std::thread::spawn(move || {
                    for i in 0..10u64 {
                        worker_db
                            .track_request(
                                &format!("model-{}", worker * 10 + i),
                                100 + i,
                                50 + i,
                                0.001,
                                1000 + i * 100,
                            )
                            .unwrap();
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        let stats = db.get_usage_stats(7).unwrap();
        assert_eq!(stats.total_requests, 20);
    }
}