use crate::error::LlmError;
use rusqlite::Connection;
use std::sync::{Arc, Mutex};
use std::time::{SystemTime, UNIX_EPOCH};
/// Aggregated usage figures as produced by [`TrackingDb::get_usage_stats`].
///
/// `Default` yields all-zero statistics, which is what an empty window reports.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct UsageStats {
    /// Number of tracked requests.
    pub total_requests: u64,
    /// Sum of input tokens plus output tokens over all requests.
    pub total_tokens: u64,
    /// Sum of the per-request `cost` values.
    pub total_cost: f64,
    /// Mean request duration in milliseconds (`0.0` when there are no requests).
    pub avg_duration_ms: f64,
    /// Sum of cache-read tokens over all requests.
    pub total_cache_read_tokens: u64,
    /// Sum of cache-write tokens over all requests.
    pub total_cache_write_tokens: u64,
}
/// Handle to the SQLite database that records per-request usage.
///
/// Cloning is cheap: every clone shares the same underlying connection through
/// an `Arc`, and access to that connection is serialized by a `Mutex`.
#[derive(Clone)]
pub struct TrackingDb {
    conn: Arc<Mutex<Connection>>,
}
impl TrackingDb {
pub fn new() -> Result<Self, LlmError> {
let mut db_path = dirs::home_dir()
.ok_or_else(|| LlmError::PersistenceError("Home directory not found".to_string()))?;
db_path.push(".limit");
std::fs::create_dir_all(&db_path).map_err(|e| {
LlmError::PersistenceError(format!("Failed to create .limit directory: {}", e))
})?;
db_path.push("tracking.db");
Self::open(&db_path)
}
pub fn open(db_path: &std::path::Path) -> Result<Self, LlmError> {
let conn = Connection::open(db_path)
.map_err(|e| LlmError::PersistenceError(format!("Failed to open database: {}", e)))?;
conn.busy_timeout(std::time::Duration::from_secs(5))
.map_err(|e| {
LlmError::PersistenceError(format!("Failed to set busy timeout: {}", e))
})?;
let db = Self {
conn: Arc::new(Mutex::new(conn)),
};
db.init_tables()?;
Ok(db)
}
pub fn new_in_memory() -> Result<Self, LlmError> {
let conn = Connection::open_in_memory().map_err(|e| {
LlmError::PersistenceError(format!("Failed to open in-memory database: {}", e))
})?;
let db = Self {
conn: Arc::new(Mutex::new(conn)),
};
db.init_tables()?;
Ok(db)
}
fn init_tables(&self) -> Result<(), LlmError> {
let conn = self
.conn
.lock()
.map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
conn.execute(
"CREATE TABLE IF NOT EXISTS requests (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp INTEGER NOT NULL,
model TEXT NOT NULL,
input_tokens INTEGER NOT NULL,
output_tokens INTEGER NOT NULL,
cost REAL NOT NULL,
duration_ms INTEGER NOT NULL,
cache_read_tokens INTEGER NOT NULL DEFAULT 0,
cache_write_tokens INTEGER NOT NULL DEFAULT 0
)",
[],
)
.map_err(|e| LlmError::PersistenceError(format!("Failed to create table: {}", e)))?;
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_timestamp ON requests (timestamp)",
[],
)
.map_err(|e| LlmError::PersistenceError(format!("Failed to create index: {}", e)))?;
let _ = conn.execute(
"ALTER TABLE requests ADD COLUMN cache_read_tokens INTEGER NOT NULL DEFAULT 0",
[],
);
let _ = conn.execute(
"ALTER TABLE requests ADD COLUMN cache_write_tokens INTEGER NOT NULL DEFAULT 0",
[],
);
Ok(())
}
pub fn track_request(
&self,
model: &str,
input_tokens: u64,
output_tokens: u64,
cache_read_tokens: u64,
cache_write_tokens: u64,
cost: f64,
duration_ms: u64,
) -> Result<(), LlmError> {
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|e| LlmError::PersistenceError(format!("Failed to get timestamp: {}", e)))?
.as_secs() as i64;
let conn = self
.conn
.lock()
.map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
conn.execute(
"INSERT INTO requests (timestamp, model, input_tokens, output_tokens, cost, duration_ms, cache_read_tokens, cache_write_tokens)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
rusqlite::params![timestamp, model, input_tokens as i64, output_tokens as i64, cost, duration_ms as i64, cache_read_tokens as i64, cache_write_tokens as i64],
)
.map_err(|e| LlmError::PersistenceError(format!("Failed to insert request: {}", e)))?;
Ok(())
}
pub fn get_usage_stats(&self, days: u32) -> Result<UsageStats, LlmError> {
let cutoff_timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|e| LlmError::PersistenceError(format!("Failed to get timestamp: {}", e)))?
.as_secs() as i64
- (days as i64 * 86400);
let conn = self
.conn
.lock()
.map_err(|e| LlmError::PersistenceError(format!("Failed to acquire lock: {}", e)))?;
let mut stmt = conn
.prepare(
"SELECT
COUNT(*) as count,
SUM(input_tokens + output_tokens) as total_tokens,
SUM(cost) as total_cost,
AVG(duration_ms) as avg_duration,
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read,
COALESCE(SUM(cache_write_tokens), 0) as total_cache_write
FROM requests
WHERE timestamp >= ?1",
)
.map_err(|e| LlmError::PersistenceError(format!("Failed to prepare query: {}", e)))?;
let mut rows = stmt
.query([cutoff_timestamp])
.map_err(|e| LlmError::PersistenceError(format!("Failed to execute query: {}", e)))?;
if let Some(row) = rows
.next()
.map_err(|e| LlmError::PersistenceError(format!("Failed to read row: {}", e)))?
{
let total_requests: i64 = row.get(0).unwrap_or(0);
let total_tokens: i64 = row.get(1).unwrap_or(0);
let total_cost: f64 = row.get(2).unwrap_or(0.0);
let avg_duration_ms: f64 = row.get(3).unwrap_or(0.0);
let total_cache_read_tokens: i64 = row.get(4).unwrap_or(0);
let total_cache_write_tokens: i64 = row.get(5).unwrap_or(0);
Ok(UsageStats {
total_requests: total_requests.max(0) as u64,
total_tokens: total_tokens.max(0) as u64,
total_cost,
avg_duration_ms,
total_cache_read_tokens: total_cache_read_tokens.max(0) as u64,
total_cache_write_tokens: total_cache_write_tokens.max(0) as u64,
})
} else {
Ok(UsageStats::default())
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Opens a file-backed `TrackingDb` inside a fresh temporary directory.
    ///
    /// Callers must keep the returned `TempDir` alive for the whole test.
    /// Dropping it removes the directory and the database file, so nothing is
    /// left behind. The old approach deleted a `NamedTempFile` and then let
    /// SQLite recreate it at that path, which leaked the file.
    fn create_test_db() -> Result<(TempDir, TrackingDb), LlmError> {
        let dir = TempDir::new().map_err(|e| {
            LlmError::PersistenceError(format!("Failed to create temp dir: {}", e))
        })?;
        let db = TrackingDb::open(&dir.path().join("tracking.db"))?;
        Ok((dir, db))
    }

    #[test]
    fn test_create_tables() {
        let (_dir, db) = create_test_db().unwrap();
        let count: i64 = db
            .conn
            .lock()
            .unwrap()
            .query_row("SELECT COUNT(*) FROM requests", [], |row| row.get(0))
            .unwrap();
        assert_eq!(count, 0);
    }

    #[test]
    fn test_track_request() {
        let (_dir, db) = create_test_db().unwrap();
        db.track_request("claude-3-5-sonnet", 100, 50, 500, 100, 0.001, 1500)
            .unwrap();
        db.track_request("claude-3-opus", 200, 100, 0, 0, 0.005, 3000)
            .unwrap();
        let count: i64 = db
            .conn
            .lock()
            .unwrap()
            .query_row("SELECT COUNT(*) FROM requests", [], |row| row.get(0))
            .unwrap();
        assert_eq!(count, 2);
    }

    #[test]
    fn test_get_usage_stats() {
        let (_dir, db) = create_test_db().unwrap();
        db.track_request("model-1", 100, 50, 500, 100, 0.001, 1500)
            .unwrap();
        db.track_request("model-2", 200, 100, 0, 0, 0.005, 3000)
            .unwrap();
        db.track_request("model-3", 150, 75, 200, 50, 0.002, 2000)
            .unwrap();
        let stats = db.get_usage_stats(7).unwrap();
        assert_eq!(stats.total_requests, 3);
        assert_eq!(stats.total_tokens, 675);
        assert!((stats.total_cost - 0.008).abs() < 0.0001);
        assert!((stats.avg_duration_ms - 2166.6666666666665).abs() < 0.0001);
        assert_eq!(stats.total_cache_read_tokens, 700);
        assert_eq!(stats.total_cache_write_tokens, 150);
    }

    #[test]
    fn test_empty_usage_stats() {
        let (_dir, db) = create_test_db().unwrap();
        let stats = db.get_usage_stats(7).unwrap();
        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.total_tokens, 0);
        assert_eq!(stats.total_cost, 0.0);
        assert_eq!(stats.avg_duration_ms, 0.0);
        assert_eq!(stats.total_cache_read_tokens, 0);
        assert_eq!(stats.total_cache_write_tokens, 0);
    }

    #[test]
    fn test_usage_stats_excludes_rows_outside_window() {
        let (_dir, db) = create_test_db().unwrap();
        db.track_request("recent", 10, 5, 0, 0, 0.01, 100).unwrap();
        // A row stamped at the Unix epoch is far outside any 7-day window.
        db.conn
            .lock()
            .unwrap()
            .execute(
                "INSERT INTO requests (timestamp, model, input_tokens, output_tokens, cost, duration_ms)
                 VALUES (0, 'ancient', 1000, 1000, 1.0, 9999)",
                [],
            )
            .unwrap();
        let stats = db.get_usage_stats(7).unwrap();
        assert_eq!(stats.total_requests, 1);
        assert_eq!(stats.total_tokens, 15);
    }

    #[test]
    fn test_in_memory_db() {
        let db = TrackingDb::new_in_memory().unwrap();
        db.track_request("model", 3, 4, 1, 2, 0.25, 50).unwrap();
        let stats = db.get_usage_stats(1).unwrap();
        assert_eq!(stats.total_requests, 1);
        assert_eq!(stats.total_tokens, 7);
        assert_eq!(stats.total_cache_read_tokens, 1);
        assert_eq!(stats.total_cache_write_tokens, 2);
        assert!((stats.avg_duration_ms - 50.0).abs() < 1e-9);
    }

    #[test]
    fn test_reopen_preserves_data() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("tracking.db");
        {
            let db = TrackingDb::open(&path).unwrap();
            db.track_request("model", 1, 2, 0, 0, 0.0, 10).unwrap();
        }
        // Re-initialising an existing, up-to-date schema must succeed.
        let db = TrackingDb::open(&path).unwrap();
        assert_eq!(db.get_usage_stats(1).unwrap().total_requests, 1);
    }

    #[test]
    fn test_migrates_legacy_schema() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("legacy.db");
        {
            // Schema as it existed before the cache-token columns were added.
            let conn = Connection::open(&path).unwrap();
            conn.execute(
                "CREATE TABLE requests (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp INTEGER NOT NULL,
                    model TEXT NOT NULL,
                    input_tokens INTEGER NOT NULL,
                    output_tokens INTEGER NOT NULL,
                    cost REAL NOT NULL,
                    duration_ms INTEGER NOT NULL
                )",
                [],
            )
            .unwrap();
        }
        let db = TrackingDb::open(&path).unwrap();
        db.track_request("model", 10, 5, 3, 2, 0.5, 100).unwrap();
        let stats = db.get_usage_stats(1).unwrap();
        assert_eq!(stats.total_requests, 1);
        assert_eq!(stats.total_cache_read_tokens, 3);
        assert_eq!(stats.total_cache_write_tokens, 2);
    }

    #[test]
    fn test_concurrent_access() {
        let (_dir, db) = create_test_db().unwrap();
        let mut handles = Vec::new();
        for offset in [0u64, 10] {
            // Each worker shares the same connection behind the mutex.
            let worker = TrackingDb {
                conn: Arc::clone(&db.conn),
            };
            handles.push(std::thread::spawn(move || {
                for i in 0..10 {
                    worker
                        .track_request(
                            &format!("model-{}", i + offset),
                            100 + i,
                            50 + i,
                            0,
                            0,
                            0.001,
                            1000 + i * 100,
                        )
                        .unwrap();
                }
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
        let stats = db.get_usage_stats(7).unwrap();
        assert_eq!(stats.total_requests, 20);
        // Per worker: sum over i in 0..10 of (100 + i) + (50 + i) = 1590.
        assert_eq!(stats.total_tokens, 3180);
    }
}