use crate::error::{Result, TinyDancerError};
use crate::types::Candidate;
use parking_lot::Mutex;
use rusqlite::{params, Connection};
use std::path::Path;
use std::sync::Arc;
/// SQLite-backed persistence for routing candidates and routing history.
///
/// A single [`Connection`] is kept behind `Arc<Mutex<_>>`; every method on
/// `Storage` locks it for the duration of its database work, so calls are
/// serialized on that one connection.
pub struct Storage {
    /// The shared, lock-protected SQLite connection.
    conn: Arc<Mutex<Connection>>,
}
impl Storage {
pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
let conn = Connection::open(path)?;
conn.execute_batch(
"PRAGMA journal_mode=WAL;
PRAGMA synchronous=NORMAL;
PRAGMA cache_size=1000000000;
PRAGMA temp_store=memory;",
)?;
let storage = Self {
conn: Arc::new(Mutex::new(conn)),
};
storage.init_schema()?;
Ok(storage)
}
pub fn in_memory() -> Result<Self> {
let conn = Connection::open_in_memory()?;
let storage = Self {
conn: Arc::new(Mutex::new(conn)),
};
storage.init_schema()?;
Ok(storage)
}
fn init_schema(&self) -> Result<()> {
let conn = self.conn.lock();
conn.execute(
"CREATE TABLE IF NOT EXISTS candidates (
id TEXT PRIMARY KEY,
embedding BLOB NOT NULL,
metadata TEXT NOT NULL,
created_at INTEGER NOT NULL,
access_count INTEGER DEFAULT 0,
success_rate REAL DEFAULT 0.0,
last_accessed INTEGER
)",
[],
)?;
conn.execute(
"CREATE TABLE IF NOT EXISTS routing_history (
id INTEGER PRIMARY KEY AUTOINCREMENT,
candidate_id TEXT NOT NULL,
query_embedding BLOB NOT NULL,
confidence REAL NOT NULL,
use_lightweight INTEGER NOT NULL,
uncertainty REAL NOT NULL,
timestamp INTEGER NOT NULL,
inference_time_us INTEGER NOT NULL,
FOREIGN KEY(candidate_id) REFERENCES candidates(id)
)",
[],
)?;
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_candidates_created_at ON candidates(created_at)",
[],
)?;
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_routing_timestamp ON routing_history(timestamp)",
[],
)?;
Ok(())
}
pub fn insert_candidate(&self, candidate: &Candidate) -> Result<()> {
let conn = self.conn.lock();
let embedding_bytes = bytemuck::cast_slice::<f32, u8>(&candidate.embedding);
let metadata_json = serde_json::to_string(&candidate.metadata)?;
conn.execute(
"INSERT OR REPLACE INTO candidates
(id, embedding, metadata, created_at, access_count, success_rate, last_accessed)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
params![
&candidate.id,
embedding_bytes,
metadata_json,
candidate.created_at,
candidate.access_count,
candidate.success_rate,
chrono::Utc::now().timestamp()
],
)?;
Ok(())
}
pub fn get_candidate(&self, id: &str) -> Result<Option<Candidate>> {
let conn = self.conn.lock();
let mut stmt = conn.prepare(
"SELECT id, embedding, metadata, created_at, access_count, success_rate
FROM candidates WHERE id = ?1",
)?;
let mut rows = stmt.query(params![id])?;
if let Some(row) = rows.next()? {
let id: String = row.get(0)?;
let embedding_bytes: Vec<u8> = row.get(1)?;
let metadata_json: String = row.get(2)?;
let created_at: i64 = row.get(3)?;
let access_count: u64 = row.get(4)?;
let success_rate: f32 = row.get(5)?;
let embedding = bytemuck::cast_slice::<u8, f32>(&embedding_bytes).to_vec();
let metadata = serde_json::from_str(&metadata_json)?;
Ok(Some(Candidate {
id,
embedding,
metadata,
created_at,
access_count,
success_rate,
}))
} else {
Ok(None)
}
}
pub fn query_candidates(&self, limit: usize) -> Result<Vec<Candidate>> {
let conn = self.conn.lock();
let mut stmt = conn.prepare(
"SELECT id, embedding, metadata, created_at, access_count, success_rate
FROM candidates
ORDER BY created_at DESC
LIMIT ?1",
)?;
let rows = stmt.query_map(params![limit], |row| {
let id: String = row.get(0)?;
let embedding_bytes: Vec<u8> = row.get(1)?;
let metadata_json: String = row.get(2)?;
let created_at: i64 = row.get(3)?;
let access_count: u64 = row.get(4)?;
let success_rate: f32 = row.get(5)?;
let embedding = bytemuck::cast_slice::<u8, f32>(&embedding_bytes).to_vec();
let metadata = serde_json::from_str(&metadata_json).unwrap_or_default();
Ok(Candidate {
id,
embedding,
metadata,
created_at,
access_count,
success_rate,
})
})?;
let candidates: Result<Vec<Candidate>> = rows
.map(|r| r.map_err(|e| TinyDancerError::DatabaseError(e)))
.collect();
candidates
}
pub fn increment_access_count(&self, id: &str) -> Result<()> {
let conn = self.conn.lock();
conn.execute(
"UPDATE candidates
SET access_count = access_count + 1,
last_accessed = ?1
WHERE id = ?2",
params![chrono::Utc::now().timestamp(), id],
)?;
Ok(())
}
pub fn record_routing(
&self,
candidate_id: &str,
query_embedding: &[f32],
confidence: f32,
use_lightweight: bool,
uncertainty: f32,
inference_time_us: u64,
) -> Result<()> {
let conn = self.conn.lock();
let query_bytes = bytemuck::cast_slice::<f32, u8>(query_embedding);
conn.execute(
"INSERT INTO routing_history
(candidate_id, query_embedding, confidence, use_lightweight, uncertainty, timestamp, inference_time_us)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
params![
candidate_id,
query_bytes,
confidence,
use_lightweight as i32,
uncertainty,
chrono::Utc::now().timestamp(),
inference_time_us as i64
],
)?;
Ok(())
}
pub fn get_statistics(&self) -> Result<RoutingStatistics> {
let conn = self.conn.lock();
let total_routes: i64 =
conn.query_row("SELECT COUNT(*) FROM routing_history", [], |row| row.get(0))?;
let lightweight_routes: i64 = conn.query_row(
"SELECT COUNT(*) FROM routing_history WHERE use_lightweight = 1",
[],
|row| row.get(0),
)?;
let avg_inference_time: f64 = conn
.query_row(
"SELECT AVG(inference_time_us) FROM routing_history",
[],
|row| row.get(0),
)
.unwrap_or(0.0);
Ok(RoutingStatistics {
total_routes: total_routes as u64,
lightweight_routes: lightweight_routes as u64,
powerful_routes: (total_routes - lightweight_routes) as u64,
avg_inference_time_us: avg_inference_time,
})
}
}
/// Aggregate counters over the recorded routing history, as produced by
/// `Storage::get_statistics`.
#[derive(Clone, Debug)]
pub struct RoutingStatistics {
    /// Number of routing decisions recorded.
    pub total_routes: u64,
    /// Routes recorded with `use_lightweight` set.
    pub lightweight_routes: u64,
    /// Remaining routes: `total_routes - lightweight_routes`.
    pub powerful_routes: u64,
    /// Mean recorded inference time in microseconds (`0.0` when there are no routes).
    pub avg_inference_time_us: f64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a candidate with a distinctive embedding so round-trips are checkable.
    fn sample_candidate(id: &str, created_at: i64) -> Candidate {
        Candidate {
            id: id.to_string(),
            embedding: (0..384).map(|i| i as f32 * 0.25).collect(),
            metadata: HashMap::new(),
            created_at,
            access_count: 3,
            success_rate: 0.75,
        }
    }

    #[test]
    fn test_storage_creation() {
        let storage = Storage::in_memory().unwrap();
        let stats = storage.get_statistics().unwrap();
        assert_eq!(stats.total_routes, 0);
        assert_eq!(stats.lightweight_routes, 0);
        assert_eq!(stats.powerful_routes, 0);
        assert_eq!(stats.avg_inference_time_us, 0.0);
    }

    #[test]
    fn test_candidate_insertion() {
        let storage = Storage::in_memory().unwrap();
        let candidate = sample_candidate("test-1", chrono::Utc::now().timestamp());
        storage.insert_candidate(&candidate).unwrap();

        let retrieved = storage
            .get_candidate("test-1")
            .unwrap()
            .expect("candidate was just inserted");
        assert_eq!(retrieved.id, candidate.id);
        assert_eq!(retrieved.embedding, candidate.embedding);
        assert!(retrieved.metadata.is_empty());
        assert_eq!(retrieved.created_at, candidate.created_at);
        assert_eq!(retrieved.access_count, candidate.access_count);
        assert_eq!(retrieved.success_rate, candidate.success_rate);
    }

    #[test]
    fn test_missing_candidate_is_none() {
        let storage = Storage::in_memory().unwrap();
        assert!(storage.get_candidate("does-not-exist").unwrap().is_none());
    }

    #[test]
    fn test_increment_access_count() {
        let storage = Storage::in_memory().unwrap();
        storage.insert_candidate(&sample_candidate("c", 1)).unwrap();
        storage.increment_access_count("c").unwrap();
        storage.increment_access_count("c").unwrap();
        let retrieved = storage.get_candidate("c").unwrap().unwrap();
        assert_eq!(retrieved.access_count, 5);
    }

    #[test]
    fn test_query_candidates_newest_first_with_limit() {
        let storage = Storage::in_memory().unwrap();
        for (id, created_at) in [("old", 100), ("mid", 200), ("new", 300)] {
            storage
                .insert_candidate(&sample_candidate(id, created_at))
                .unwrap();
        }
        let ids: Vec<String> = storage
            .query_candidates(2)
            .unwrap()
            .into_iter()
            .map(|c| c.id)
            .collect();
        assert_eq!(ids, vec!["new".to_string(), "mid".to_string()]);
    }

    #[test]
    fn test_routing_statistics() {
        let storage = Storage::in_memory().unwrap();
        storage
            .record_routing("c", &[0.1, 0.2], 0.9, true, 0.1, 100)
            .unwrap();
        storage
            .record_routing("c", &[0.3, 0.4], 0.4, false, 0.6, 300)
            .unwrap();
        let stats = storage.get_statistics().unwrap();
        assert_eq!(stats.total_routes, 2);
        assert_eq!(stats.lightweight_routes, 1);
        assert_eq!(stats.powerful_routes, 1);
        assert_eq!(stats.avg_inference_time_us, 200.0);
    }
}