1use crate::error::LlmError;
2use rusqlite::Connection;
3use std::sync::{Arc, Mutex};
4use std::time::{SystemTime, UNIX_EPOCH};
5
/// Aggregated request statistics for a time window, as produced by
/// [`TrackingDb::get_usage_stats`].
///
/// `Default` yields an all-zero value, which is also what an empty window reports.
#[derive(Clone, Debug, Default)]
pub struct UsageStats {
    /// Number of requests recorded inside the window.
    pub total_requests: u64,
    /// Sum of `input_tokens + output_tokens` over those requests.
    pub total_tokens: u64,
    /// Sum of the per-request `cost` values, in whatever unit callers passed to
    /// `track_request`.
    pub total_cost: f64,
    /// Mean request duration in milliseconds (`0.0` when there were no requests).
    pub avg_duration_ms: f64,
}
14
15pub struct TrackingDb {
17 conn: Arc<Mutex<Connection>>,
18}
19
20impl TrackingDb {
21 pub fn new() -> Result<Self, LlmError> {
23 let mut db_path = dirs::home_dir()
24 .ok_or_else(|| LlmError::PersistenceError("Home directory not found".to_string()))?;
25 db_path.push(".limit");
26
27 std::fs::create_dir_all(&db_path).map_err(|e| {
29 LlmError::PersistenceError(format!("Failed to create .limit directory: {}", e))
30 })?;
31
32 db_path.push("tracking.db");
33
34 let conn = Connection::open(&db_path)
35 .map_err(|e| LlmError::PersistenceError(format!("Failed to open database: {}", e)))?;
36
37 conn.busy_timeout(std::time::Duration::from_secs(5))
39 .map_err(|e| {
40 LlmError::PersistenceError(format!("Failed to set busy timeout: {}", e))
41 })?;
42
43 let db = Self {
44 conn: Arc::new(Mutex::new(conn)),
45 };
46
47 db.init_tables()?;
48
49 Ok(db)
50 }
51
52 fn init_tables(&self) -> Result<(), LlmError> {
54 let conn = self
55 .conn
56 .lock()
57 .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
58
59 conn.execute(
60 "CREATE TABLE IF NOT EXISTS requests (
61 id INTEGER PRIMARY KEY AUTOINCREMENT,
62 timestamp INTEGER NOT NULL,
63 model TEXT NOT NULL,
64 input_tokens INTEGER NOT NULL,
65 output_tokens INTEGER NOT NULL,
66 cost REAL NOT NULL,
67 duration_ms INTEGER NOT NULL
68 )",
69 [],
70 )
71 .map_err(|e| LlmError::PersistenceError(format!("Failed to create table: {}", e)))?;
72
73 conn.execute(
75 "CREATE INDEX IF NOT EXISTS idx_timestamp ON requests (timestamp)",
76 [],
77 )
78 .map_err(|e| LlmError::PersistenceError(format!("Failed to create index: {}", e)))?;
79
80 Ok(())
81 }
82
83 pub fn track_request(
85 &self,
86 model: &str,
87 input_tokens: u64,
88 output_tokens: u64,
89 cost: f64,
90 duration_ms: u64,
91 ) -> Result<(), LlmError> {
92 let timestamp = SystemTime::now()
93 .duration_since(UNIX_EPOCH)
94 .map_err(|e| LlmError::PersistenceError(format!("Failed to get timestamp: {}", e)))?
95 .as_secs() as i64;
96
97 let conn = self
98 .conn
99 .lock()
100 .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
101
102 conn.execute(
103 "INSERT INTO requests (timestamp, model, input_tokens, output_tokens, cost, duration_ms)
104 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
105 rusqlite::params![timestamp, model, input_tokens as i64, output_tokens as i64, cost, duration_ms as i64],
106 )
107 .map_err(|e| LlmError::PersistenceError(format!("Failed to insert request: {}", e)))?;
108
109 Ok(())
110 }
111
112 pub fn get_usage_stats(&self, days: u32) -> Result<UsageStats, LlmError> {
114 let cutoff_timestamp = SystemTime::now()
115 .duration_since(UNIX_EPOCH)
116 .map_err(|e| LlmError::PersistenceError(format!("Failed to get timestamp: {}", e)))?
117 .as_secs() as i64
118 - (days as i64 * 86400);
119
120 let conn = self
121 .conn
122 .lock()
123 .map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
124
125 let mut stmt = conn
126 .prepare(
127 "SELECT
128 COUNT(*) as count,
129 SUM(input_tokens + output_tokens) as total_tokens,
130 SUM(cost) as total_cost,
131 AVG(duration_ms) as avg_duration
132 FROM requests
133 WHERE timestamp >= ?1",
134 )
135 .map_err(|e| LlmError::PersistenceError(format!("Failed to prepare query: {}", e)))?;
136
137 let mut rows = stmt
138 .query([cutoff_timestamp])
139 .map_err(|e| LlmError::PersistenceError(format!("Failed to execute query: {}", e)))?;
140
141 if let Some(row) = rows
142 .next()
143 .map_err(|e| LlmError::PersistenceError(format!("Failed to read row: {}", e)))?
144 {
145 let total_requests: i64 = row.get(0).unwrap_or(0);
146 let total_tokens: i64 = row.get(1).unwrap_or(0);
147 let total_cost: f64 = row.get(2).unwrap_or(0.0);
148 let avg_duration_ms: f64 = row.get(3).unwrap_or(0.0);
149
150 Ok(UsageStats {
151 total_requests: total_requests.max(0) as u64,
152 total_tokens: total_tokens.max(0) as u64,
153 total_cost,
154 avg_duration_ms,
155 })
156 } else {
157 Ok(UsageStats::default())
158 }
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TrackingDb` backed by a private in-memory SQLite database.
    ///
    /// The previous approach dropped a `NamedTempFile` guard and then reopened
    /// its path, which left an orphaned database file behind after every test
    /// and let other processes race for the path. An in-memory database leaves
    /// nothing on disk. Handles built from `Arc::clone(&db.conn)` still share
    /// the single connection.
    fn create_test_db() -> Result<TrackingDb, LlmError> {
        let conn = Connection::open_in_memory().map_err(|e| {
            LlmError::PersistenceError(format!("Failed to open test database: {}", e))
        })?;

        let db = TrackingDb {
            conn: Arc::new(Mutex::new(conn)),
        };
        db.init_tables()?;
        Ok(db)
    }

    /// Counts the rows in `requests` directly through the shared connection.
    fn row_count(db: &TrackingDb) -> i64 {
        db.conn
            .lock()
            .unwrap()
            .query_row("SELECT COUNT(*) FROM requests", [], |row| row.get(0))
            .unwrap()
    }

    #[test]
    fn test_create_tables() {
        let db = create_test_db().unwrap();
        assert_eq!(row_count(&db), 0);
    }

    #[test]
    fn test_init_tables_is_idempotent() {
        let db = create_test_db().unwrap();
        db.track_request("model", 1, 1, 0.0, 1).unwrap();

        // Re-running schema creation must neither fail nor drop existing data.
        db.init_tables().unwrap();
        assert_eq!(row_count(&db), 1);
    }

    #[test]
    fn test_track_request() {
        let db = create_test_db().unwrap();

        db.track_request("claude-3-5-sonnet", 100, 50, 0.001, 1500)
            .unwrap();
        db.track_request("claude-3-opus", 200, 100, 0.005, 3000)
            .unwrap();

        assert_eq!(row_count(&db), 2);
    }

    #[test]
    fn test_track_request_stores_values() {
        let db = create_test_db().unwrap();
        db.track_request("claude-3-5-sonnet", 100, 50, 0.001, 1500)
            .unwrap();

        let (timestamp, model, input, output, cost, duration): (i64, String, i64, i64, f64, i64) =
            db.conn
                .lock()
                .unwrap()
                .query_row(
                    "SELECT timestamp, model, input_tokens, output_tokens, cost, duration_ms
                     FROM requests",
                    [],
                    |row| {
                        Ok((
                            row.get(0)?,
                            row.get(1)?,
                            row.get(2)?,
                            row.get(3)?,
                            row.get(4)?,
                            row.get(5)?,
                        ))
                    },
                )
                .unwrap();

        assert!(timestamp > 0);
        assert_eq!(model, "claude-3-5-sonnet");
        assert_eq!(input, 100);
        assert_eq!(output, 50);
        assert!((cost - 0.001).abs() < 1e-12);
        assert_eq!(duration, 1500);
    }

    #[test]
    fn test_get_usage_stats() {
        let db = create_test_db().unwrap();

        db.track_request("model-1", 100, 50, 0.001, 1500).unwrap();
        db.track_request("model-2", 200, 100, 0.005, 3000).unwrap();
        db.track_request("model-3", 150, 75, 0.002, 2000).unwrap();

        let stats = db.get_usage_stats(7).unwrap();

        assert_eq!(stats.total_requests, 3);
        assert_eq!(stats.total_tokens, 675);
        assert!((stats.total_cost - 0.008).abs() < 0.0001);
        assert!((stats.avg_duration_ms - 2166.6666666666665).abs() < 0.0001);
    }

    #[test]
    fn test_usage_stats_excludes_old_requests() {
        let db = create_test_db().unwrap();
        db.track_request("recent", 100, 50, 0.001, 1000).unwrap();

        // A row stamped at the Unix epoch is far outside any 7-day window.
        db.conn
            .lock()
            .unwrap()
            .execute(
                "INSERT INTO requests (timestamp, model, input_tokens, output_tokens, cost, duration_ms)
                 VALUES (0, 'ancient', 1000, 1000, 1.0, 9000)",
                [],
            )
            .unwrap();

        let stats = db.get_usage_stats(7).unwrap();
        assert_eq!(stats.total_requests, 1);
        assert_eq!(stats.total_tokens, 150);
        assert!((stats.total_cost - 0.001).abs() < 1e-9);
        assert!((stats.avg_duration_ms - 1000.0).abs() < 1e-9);
    }

    #[test]
    fn test_empty_usage_stats() {
        let db = create_test_db().unwrap();

        let stats = db.get_usage_stats(7).unwrap();

        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.total_tokens, 0);
        assert_eq!(stats.total_cost, 0.0);
        assert_eq!(stats.avg_duration_ms, 0.0);
    }

    #[test]
    fn test_concurrent_access() {
        let db = create_test_db().unwrap();

        let handles: Vec<_> = (0..2u64)
            .map(|thread_idx| {
                let worker = TrackingDb {
                    conn: Arc::clone(&db.conn),
                };
                std::thread::spawn(move || {
                    for i in 0..10u64 {
                        worker
                            .track_request(
                                &format!("model-{}", i + thread_idx * 10),
                                100 + i,
                                50 + i,
                                0.001,
                                1000 + i * 100,
                            )
                            .unwrap();
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        let stats = db.get_usage_stats(7).unwrap();
        assert_eq!(stats.total_requests, 20);
        // Each thread contributes sum_{i=0..9} (150 + 2i) = 1590 tokens.
        assert_eq!(stats.total_tokens, 3180);
    }
}