use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::time::{SystemTime, UNIX_EPOCH};
/// Persistence interface for daemon state: per-model usage statistics, cached
/// model metadata, cached prompt tokenizations, chat sessions, and string
/// configuration values.
///
/// Write methods return `()`. The implementation in this file
/// ([`DaemonStore`]) treats writes as best-effort and discards storage errors.
/// The `Send + Sync` bound allows one backend to be shared across threads,
/// e.g. behind `Arc<dyn StorageBackend>`.
pub trait StorageBackend: Send + Sync {
    // ---- model usage statistics ----

    /// Returns the stored statistics for model `alias`, if a record exists.
    fn get_model_stats(&self, alias: &str) -> Option<PersistedModelStats>;

    /// Adds the given deltas to the statistics of `alias` and folds
    /// `avg_tps_x100` into its running throughput average.
    ///
    /// `DaemonStore` ignores an `avg_tps_x100` of 0 (treats it as "no sample").
    fn update_model_stats(
        &self,
        alias: &str,
        requests_delta: u64,
        tokens_generated_delta: u64,
        tokens_prompt_delta: u64,
        avg_tps_x100: u64,
    );

    /// Records one load of model `alias` that took `load_time_ms`.
    fn record_model_load(&self, alias: &str, load_time_ms: u64);

    /// Returns every `(alias, stats)` pair currently stored.
    fn all_model_stats(&self) -> Vec<(String, PersistedModelStats)>;

    // ---- model metadata cache ----

    /// Returns cached metadata for the model file at `path`, only if it was
    /// cached under the same `mtime` (a changed mtime is a cache miss).
    fn get_model_meta(&self, path: &str, mtime: u64) -> Option<CachedModelMeta>;

    /// Caches `meta` for `path`, keyed together with `meta.mtime`.
    fn set_model_meta(&self, path: &str, meta: &CachedModelMeta);

    // ---- prompt token cache ----

    /// Returns the cached tokens stored under `hash`.
    /// NOTE(review): `hash` is presumably a digest of the prompt text — confirm with callers.
    fn get_prompt_tokens(&self, hash: &[u8]) -> Option<Vec<i32>>;

    /// Stores `tokens` under `hash`, replacing any previous entry.
    fn set_prompt_tokens(&self, hash: &[u8], tokens: &[i32]);

    // ---- chat sessions ----

    /// Inserts or replaces `session`, keyed by `session.id`.
    fn save_session(&self, session: &StoredSession);

    /// Returns the session with identifier `id`, if present.
    fn get_session(&self, id: &str) -> Option<StoredSession>;

    /// Returns all stored sessions (`DaemonStore` orders them by
    /// `updated_at`, newest first).
    fn list_sessions(&self) -> Vec<StoredSession>;

    /// Removes the session with identifier `id`; absent ids are a no-op.
    fn delete_session(&self, id: &str);

    // ---- string configuration ----

    /// Returns the configuration value stored under `key`.
    fn get_config(&self, key: &str) -> Option<String>;

    /// Stores `value` under `key`, replacing any previous value.
    fn set_config(&self, key: &str, value: &str);

    // ---- durability ----

    /// Asks the backend to persist buffered writes.
    fn flush(&self);
}
/// sled-backed implementation of [`StorageBackend`], with one sled tree per
/// kind of data. Values in every tree except `config` are JSON-encoded.
///
/// Field order is kept as-is on purpose: it determines drop order.
pub struct DaemonStore {
/// Database handle; used to open the trees and to flush.
db: sled::Db,
/// JSON `PersistedModelStats` keyed by model alias bytes.
model_stats: sled::Tree,
/// JSON `CachedModelMeta` keyed by `"{path}:{mtime}"`.
model_meta: sled::Tree,
/// JSON token arrays keyed by caller-supplied hash bytes.
prompt_cache: sled::Tree,
/// JSON `StoredSession` keyed by session id bytes.
sessions: sled::Tree,
/// Raw UTF-8 string values keyed by config key bytes.
config: sled::Tree,
}
/// Cumulative usage statistics for one model alias, persisted as JSON.
///
/// NOTE(review): this derives `Deserialize` without `#[serde(default)]`, so a
/// record missing any field fails to decode. `DaemonStore` then treats the
/// alias as having no stats and overwrites the record with fresh totals on the
/// next update. Consider `#[serde(default)]` before adding fields.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PersistedModelStats {
/// Sum of all `requests_delta` values reported.
pub requests_total: u64,
/// Sum of all `tokens_generated_delta` values reported.
pub tokens_generated: u64,
/// Sum of all `tokens_prompt_delta` values reported.
pub tokens_prompt: u64,
/// Running average of reported throughput, weighted 3:1 toward the previous
/// value; 0 means no sample has been recorded yet. Presumably tokens/sec
/// scaled by 100, per the `_x100` suffix — confirm with producers.
pub avg_tokens_per_sec_x100: u64,
/// Number of recorded model loads.
pub load_count: u64,
/// Time of the most recent update or load, in whole seconds since the Unix epoch.
pub last_used: u64,
/// Sum of all recorded load durations.
pub total_load_time_ms: u64,
}
/// Metadata extracted from a model file, cached so it need not be re-read.
///
/// The semantics below are inferred from field names — confirm against the
/// code that populates this struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedModelMeta {
/// Presumably the total parameter count.
pub n_params: u64,
/// Presumably the number of transformer layers.
pub n_layers: u32,
/// Presumably the embedding dimension.
pub n_embd: u32,
/// Presumably the vocabulary size.
pub n_vocab: u32,
/// Presumably the model file size in bytes.
pub file_size: u64,
/// Quantization type label, when known.
pub quant_type: Option<String>,
/// File modification time this entry was captured at. It is part of the
/// cache key, so a changed mtime makes the entry unreachable.
pub mtime: u64,
}
/// A persisted chat session, stored as JSON under its `id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredSession {
/// Unique session identifier; also the storage key.
pub id: String,
/// Name/alias of the model the session uses.
pub model: String,
/// Conversation history in stored order.
pub messages: Vec<StoredMessage>,
/// Creation timestamp (units not established here — presumably Unix seconds).
pub created_at: u64,
/// Last-modified timestamp; `list_sessions` sorts by this, newest first.
pub updated_at: u64,
}
/// One message within a [`StoredSession`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredMessage {
/// Speaker role, e.g. `"user"` or `"assistant"` as used in the tests.
pub role: String,
/// Message text.
pub content: String,
}
impl DaemonStore {
pub fn open_default() -> Result<Self, sled::Error> {
let db_path = Self::default_path();
Self::open(&db_path)
}
pub fn open(path: &PathBuf) -> Result<Self, sled::Error> {
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent).ok();
}
let db = sled::open(path)?;
let model_stats = db.open_tree("model_stats")?;
let model_meta = db.open_tree("model_meta")?;
let prompt_cache = db.open_tree("prompt_cache")?;
let sessions = db.open_tree("sessions")?;
let config = db.open_tree("config")?;
Ok(Self {
db,
model_stats,
model_meta,
prompt_cache,
sessions,
config,
})
}
pub fn default_path() -> PathBuf {
dirs::home_dir()
.unwrap_or_else(|| PathBuf::from("."))
.join(".mullama")
.join("db")
}
pub fn get_model_stats(&self, alias: &str) -> Option<PersistedModelStats> {
self.model_stats
.get(alias.as_bytes())
.ok()
.flatten()
.and_then(|bytes| serde_json::from_slice(&bytes).ok())
}
pub fn update_model_stats(
&self,
alias: &str,
requests_delta: u64,
tokens_generated_delta: u64,
tokens_prompt_delta: u64,
avg_tps_x100: u64,
) {
let mut stats = self.get_model_stats(alias).unwrap_or_default();
stats.requests_total += requests_delta;
stats.tokens_generated += tokens_generated_delta;
stats.tokens_prompt += tokens_prompt_delta;
if avg_tps_x100 > 0 {
if stats.avg_tokens_per_sec_x100 == 0 {
stats.avg_tokens_per_sec_x100 = avg_tps_x100;
} else {
stats.avg_tokens_per_sec_x100 =
(stats.avg_tokens_per_sec_x100 * 3 + avg_tps_x100) / 4;
}
}
stats.last_used = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
if let Ok(bytes) = serde_json::to_vec(&stats) {
let _ = self.model_stats.insert(alias.as_bytes(), bytes);
}
}
pub fn record_model_load(&self, alias: &str, load_time_ms: u64) {
let mut stats = self.get_model_stats(alias).unwrap_or_default();
stats.load_count += 1;
stats.total_load_time_ms += load_time_ms;
stats.last_used = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
if let Ok(bytes) = serde_json::to_vec(&stats) {
let _ = self.model_stats.insert(alias.as_bytes(), bytes);
}
}
pub fn all_model_stats(&self) -> Vec<(String, PersistedModelStats)> {
self.model_stats
.iter()
.filter_map(|item| {
let (key, value) = item.ok()?;
let alias = String::from_utf8(key.to_vec()).ok()?;
let stats: PersistedModelStats = serde_json::from_slice(&value).ok()?;
Some((alias, stats))
})
.collect()
}
pub fn get_model_meta(&self, path: &str, mtime: u64) -> Option<CachedModelMeta> {
let key = format!("{}:{}", path, mtime);
self.model_meta
.get(key.as_bytes())
.ok()
.flatten()
.and_then(|bytes| serde_json::from_slice(&bytes).ok())
}
pub fn set_model_meta(&self, path: &str, meta: &CachedModelMeta) {
let key = format!("{}:{}", path, meta.mtime);
if let Ok(bytes) = serde_json::to_vec(meta) {
let _ = self.model_meta.insert(key.as_bytes(), bytes);
}
}
pub fn get_prompt_tokens(&self, hash: &[u8]) -> Option<Vec<i32>> {
self.prompt_cache
.get(hash)
.ok()
.flatten()
.and_then(|bytes| serde_json::from_slice(&bytes).ok())
}
pub fn set_prompt_tokens(&self, hash: &[u8], tokens: &[i32]) {
if let Ok(bytes) = serde_json::to_vec(tokens) {
let _ = self.prompt_cache.insert(hash, bytes);
}
}
pub fn save_session(&self, session: &StoredSession) {
if let Ok(bytes) = serde_json::to_vec(session) {
let _ = self.sessions.insert(session.id.as_bytes(), bytes);
}
}
pub fn get_session(&self, id: &str) -> Option<StoredSession> {
self.sessions
.get(id.as_bytes())
.ok()
.flatten()
.and_then(|bytes| serde_json::from_slice(&bytes).ok())
}
pub fn list_sessions(&self) -> Vec<StoredSession> {
let mut sessions: Vec<StoredSession> = self
.sessions
.iter()
.filter_map(|item| {
let (_key, value) = item.ok()?;
serde_json::from_slice(&value).ok()
})
.collect();
sessions.sort_by(|a, b| b.updated_at.cmp(&a.updated_at));
sessions
}
pub fn delete_session(&self, id: &str) {
let _ = self.sessions.remove(id.as_bytes());
}
pub fn get_config(&self, key: &str) -> Option<String> {
self.config
.get(key.as_bytes())
.ok()
.flatten()
.and_then(|bytes| String::from_utf8(bytes.to_vec()).ok())
}
pub fn set_config(&self, key: &str, value: &str) {
let _ = self.config.insert(key.as_bytes(), value.as_bytes());
}
pub fn flush(&self) {
let _ = self.db.flush();
}
}
impl StorageBackend for DaemonStore {
fn get_model_stats(&self, alias: &str) -> Option<PersistedModelStats> {
self.get_model_stats(alias)
}
fn update_model_stats(
&self,
alias: &str,
requests_delta: u64,
tokens_generated_delta: u64,
tokens_prompt_delta: u64,
avg_tps_x100: u64,
) {
self.update_model_stats(
alias,
requests_delta,
tokens_generated_delta,
tokens_prompt_delta,
avg_tps_x100,
);
}
fn record_model_load(&self, alias: &str, load_time_ms: u64) {
self.record_model_load(alias, load_time_ms);
}
fn all_model_stats(&self) -> Vec<(String, PersistedModelStats)> {
self.all_model_stats()
}
fn get_model_meta(&self, path: &str, mtime: u64) -> Option<CachedModelMeta> {
self.get_model_meta(path, mtime)
}
fn set_model_meta(&self, path: &str, meta: &CachedModelMeta) {
self.set_model_meta(path, meta);
}
fn get_prompt_tokens(&self, hash: &[u8]) -> Option<Vec<i32>> {
self.get_prompt_tokens(hash)
}
fn set_prompt_tokens(&self, hash: &[u8], tokens: &[i32]) {
self.set_prompt_tokens(hash, tokens);
}
fn save_session(&self, session: &StoredSession) {
self.save_session(session);
}
fn get_session(&self, id: &str) -> Option<StoredSession> {
self.get_session(id)
}
fn list_sessions(&self) -> Vec<StoredSession> {
self.list_sessions()
}
fn delete_session(&self, id: &str) {
self.delete_session(id);
}
fn get_config(&self, key: &str) -> Option<String> {
self.get_config(key)
}
fn set_config(&self, key: &str, value: &str) {
self.set_config(key, value);
}
fn flush(&self) {
self.flush();
}
}
impl Drop for DaemonStore {
fn drop(&mut self) {
self.flush();
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Opens a fresh store inside a new temp directory.
    ///
    /// The returned `TempDir` guard must outlive the store. Bind it as
    /// `let (_dir, store) = ...` (not `_`), so it is dropped after `store`.
    fn temp_store(name: &str) -> (tempfile::TempDir, DaemonStore) {
        let dir = tempfile::tempdir().unwrap();
        let store = DaemonStore::open(&dir.path().join(name)).unwrap();
        (dir, store)
    }

    #[test]
    fn test_store_model_stats() {
        let (_dir, store) = temp_store("test.db");
        assert!(store.get_model_stats("test").is_none());

        store.update_model_stats("test", 10, 500, 100, 4500);
        let stats = store.get_model_stats("test").unwrap();
        assert_eq!(stats.requests_total, 10);
        assert_eq!(stats.tokens_generated, 500);
        assert_eq!(stats.tokens_prompt, 100);
        // The first sample seeds the running average directly.
        assert_eq!(stats.avg_tokens_per_sec_x100, 4500);
        assert!(stats.last_used > 0);

        store.update_model_stats("test", 5, 200, 50, 5000);
        let stats = store.get_model_stats("test").unwrap();
        assert_eq!(stats.requests_total, 15);
        assert_eq!(stats.tokens_generated, 700);
        assert_eq!(stats.tokens_prompt, 150);
        // (4500 * 3 + 5000) / 4 == 4625
        assert_eq!(stats.avg_tokens_per_sec_x100, 4625);

        // A zero sample leaves the average untouched.
        store.update_model_stats("test", 1, 0, 0, 0);
        let stats = store.get_model_stats("test").unwrap();
        assert_eq!(stats.avg_tokens_per_sec_x100, 4625);
        assert_eq!(stats.requests_total, 16);
    }

    #[test]
    fn test_store_sessions() {
        let (_dir, store) = temp_store("test.db");
        let session = StoredSession {
            id: "sess_1".to_string(),
            model: "llama3".to_string(),
            messages: vec![
                StoredMessage {
                    role: "user".to_string(),
                    content: "Hello".to_string(),
                },
                StoredMessage {
                    role: "assistant".to_string(),
                    content: "Hi there!".to_string(),
                },
            ],
            created_at: 1000,
            updated_at: 2000,
        };
        store.save_session(&session);
        let loaded = store.get_session("sess_1").unwrap();
        assert_eq!(loaded.model, "llama3");
        assert_eq!(loaded.messages.len(), 2);
    }

    #[test]
    fn test_list_sessions_newest_first_and_delete() {
        let (_dir, store) = temp_store("list.db");
        for (id, updated_at) in [("a", 100), ("b", 300), ("c", 200)] {
            store.save_session(&StoredSession {
                id: id.to_string(),
                model: "m".to_string(),
                messages: Vec::new(),
                created_at: 1,
                updated_at,
            });
        }
        let ids: Vec<String> = store.list_sessions().into_iter().map(|s| s.id).collect();
        assert_eq!(ids, vec!["b", "c", "a"]);

        store.delete_session("c");
        assert!(store.get_session("c").is_none());
        assert_eq!(store.list_sessions().len(), 2);
        // Deleting an absent id is a harmless no-op.
        store.delete_session("does_not_exist");
    }

    #[test]
    fn test_store_config() {
        let (_dir, store) = temp_store("test.db");
        assert!(store.get_config("last_gpu_layers").is_none());
        store.set_config("last_gpu_layers", "33");
        assert_eq!(store.get_config("last_gpu_layers").unwrap(), "33");
        store.set_config("last_gpu_layers", "40");
        assert_eq!(store.get_config("last_gpu_layers").unwrap(), "40");
    }

    #[test]
    fn test_model_meta_keyed_by_mtime() {
        let (_dir, store) = temp_store("meta.db");
        let meta = CachedModelMeta {
            n_params: 7_000_000_000,
            n_layers: 32,
            n_embd: 4096,
            n_vocab: 32000,
            file_size: 1234,
            quant_type: Some("Q4_K_M".to_string()),
            mtime: 42,
        };
        store.set_model_meta("/models/a.gguf", &meta);

        let hit = store.get_model_meta("/models/a.gguf", 42).unwrap();
        assert_eq!(hit.n_layers, 32);
        assert_eq!(hit.quant_type.as_deref(), Some("Q4_K_M"));
        // A different mtime is a cache miss.
        assert!(store.get_model_meta("/models/a.gguf", 43).is_none());
    }

    #[test]
    fn test_prompt_tokens_roundtrip() {
        let (_dir, store) = temp_store("prompt.db");
        assert!(store.get_prompt_tokens(b"abc").is_none());
        store.set_prompt_tokens(b"abc", &[1, -2, 3]);
        assert_eq!(store.get_prompt_tokens(b"abc"), Some(vec![1, -2, 3]));
        store.set_prompt_tokens(b"empty", &[]);
        assert_eq!(store.get_prompt_tokens(b"empty"), Some(Vec::new()));
    }

    #[test]
    fn test_trait_dispatch_through_dyn_storage_backend() {
        let (_dir, store) = temp_store("trait_test.db");
        let backend: Arc<dyn StorageBackend> = Arc::new(store);
        assert!(backend.get_model_stats("dyn_test").is_none());
        backend.update_model_stats("dyn_test", 5, 100, 50, 3000);
        let stats = backend.get_model_stats("dyn_test").unwrap();
        assert_eq!(stats.requests_total, 5);
        assert_eq!(stats.tokens_generated, 100);
        assert_eq!(stats.tokens_prompt, 50);
        backend.record_model_load("dyn_test", 1500);
        let stats = backend.get_model_stats("dyn_test").unwrap();
        assert_eq!(stats.load_count, 1);
        assert_eq!(stats.total_load_time_ms, 1500);
        // Recording a load must not disturb the request counters.
        assert_eq!(stats.requests_total, 5);
        let all = backend.all_model_stats();
        assert!(!all.is_empty());
        assert!(all.iter().any(|(alias, _)| alias == "dyn_test"));
        backend.set_config("trait_key", "trait_value");
        assert_eq!(backend.get_config("trait_key").unwrap(), "trait_value");
        let session = StoredSession {
            id: "dyn_sess".to_string(),
            model: "test_model".to_string(),
            messages: vec![StoredMessage {
                role: "user".to_string(),
                content: "via trait".to_string(),
            }],
            created_at: 100,
            updated_at: 200,
        };
        backend.save_session(&session);
        let loaded = backend.get_session("dyn_sess").unwrap();
        assert_eq!(loaded.messages[0].content, "via trait");
        backend.delete_session("dyn_sess");
        assert!(backend.get_session("dyn_sess").is_none());
        backend.flush();
    }
}