use llm_shield_core::Error;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
/// Result alias used throughout this module; the error type is always
/// `llm_shield_core::Error`.
pub type Result<T> = std::result::Result<T, Error>;
/// The kind of work a registered model performs.
///
/// Together with [`ModelVariant`] it identifies a single entry in
/// [`ModelRegistry`]. Variants are (de)serialized through serde's derived
/// implementations, so the variant names are part of the registry file format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelTask {
PromptInjection,
Toxicity,
Sentiment,
NamedEntityRecognition,
}
/// The flavour of a model for a given [`ModelTask`].
///
/// NOTE(review): the names suggest numeric precision / quantization of the
/// model weights (16-bit float, 32-bit float, 8-bit integer) — nothing in this
/// file interprets them beyond using them as part of the registry key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelVariant {
FP16,
FP32,
INT8,
}
/// Description of one downloadable model, as stored in the registry file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
/// Model identifier; also used as the sub-directory name inside the cache
/// directory where the model file is stored.
pub id: String,
/// Task this model serves (first half of the registry key).
pub task: ModelTask,
/// Variant of the model (second half of the registry key).
pub variant: ModelVariant,
/// Location to fetch the model from. A `file://` prefix means "copy from the
/// local path that follows"; any other value is fetched with `reqwest::get`.
pub url: String,
/// Expected hex-encoded SHA-256 digest of the model file.
pub checksum: String,
/// Size of the model in bytes. Not read anywhere in this file —
/// presumably informational; verify against callers.
pub size_bytes: usize,
}
/// On-disk (JSON) shape of a registry file, as parsed by
/// `ModelRegistry::from_file`.
#[derive(Debug, Serialize, Deserialize)]
struct RegistryData {
/// Optional cache directory override; a leading `~` is expanded when loaded.
/// When absent the default cache directory is used.
cache_dir: Option<String>,
/// All models described by the registry file.
models: Vec<ModelMetadata>,
}
/// In-memory catalogue of known models plus the directory they are cached in.
///
/// Both fields are wrapped in `Arc`, so cloning a registry only bumps
/// reference counts and the clones share the same underlying data.
#[derive(Debug, Clone)]
pub struct ModelRegistry {
/// Registered models keyed by the `"{task:?}/{variant:?}"` string.
models: Arc<HashMap<String, ModelMetadata>>,
/// Root directory under which each model gets `<id>/model.onnx`.
cache_dir: Arc<PathBuf>,
}
impl ModelRegistry {
pub fn new() -> Self {
let cache_dir = Self::default_cache_dir();
Self {
models: Arc::new(HashMap::new()),
cache_dir: Arc::new(cache_dir),
}
}
/// Loads a registry from the JSON file at `path`.
///
/// Each entry in the file's `models` list is indexed by its
/// (task, variant) pair. If the file lists the same pair more than once the
/// later entry wins and a warning is logged for the entry that was replaced.
/// The optional `cache_dir` has a leading `~` expanded; when it is missing the
/// default cache directory is used.
///
/// # Errors
///
/// Returns a model error when the file cannot be read or is not valid
/// registry JSON.
pub fn from_file(path: &str) -> Result<Self> {
    tracing::info!("Loading model registry from: {}", path);
    let json = std::fs::read_to_string(path).map_err(|e| {
        Error::model(format!("Failed to read registry file '{}': {}", path, e))
    })?;
    let data: RegistryData = serde_json::from_str(&json).map_err(|e| {
        Error::model(format!("Failed to parse registry JSON: {}", e))
    })?;

    let mut models = HashMap::with_capacity(data.models.len());
    for model in data.models {
        let key = Self::model_key(&model.task, &model.variant);
        tracing::debug!(
            "Registered model: {} ({:?}/{:?})",
            model.id,
            model.task,
            model.variant
        );
        // `insert` hands back the previous value for this key; surfacing it
        // makes an accidental duplicate in the registry file visible instead
        // of silently dropping one of the entries.
        if let Some(replaced) = models.insert(key, model) {
            tracing::warn!(
                "Duplicate registry entry for {:?}/{:?}: model '{}' was replaced by a later entry",
                replaced.task,
                replaced.variant,
                replaced.id
            );
        }
    }

    let cache_dir = match data.cache_dir {
        Some(dir) => PathBuf::from(shellexpand::tilde(&dir).into_owned()),
        None => Self::default_cache_dir(),
    };
    tracing::info!(
        "Registry loaded with {} models, cache_dir: {}",
        models.len(),
        cache_dir.display()
    );
    Ok(Self {
        models: Arc::new(models),
        cache_dir: Arc::new(cache_dir),
    })
}
pub fn get_model_metadata(
&self,
task: ModelTask,
variant: ModelVariant,
) -> Result<&ModelMetadata> {
let key = Self::model_key(&task, &variant);
self.models.get(&key).ok_or_else(|| {
Error::model(format!(
"Model not found in registry: {:?}/{:?}",
task, variant
))
})
}
/// Returns references to every registered model, in the map's (unspecified)
/// iteration order.
pub fn list_models(&self) -> Vec<&ModelMetadata> {
    let mut all: Vec<&ModelMetadata> = Vec::with_capacity(self.models.len());
    all.extend(self.models.values());
    all
}
/// Returns references to every registered model whose task equals `task`.
pub fn list_models_for_task(&self, task: ModelTask) -> Vec<&ModelMetadata> {
    let mut matching: Vec<&ModelMetadata> = Vec::new();
    for entry in self.models.values() {
        if entry.task == task {
            matching.push(entry);
        }
    }
    matching
}
/// Returns the variants that are registered for `task`.
///
/// The order follows the map's iteration order and is therefore unspecified.
pub fn get_available_variants(&self, task: ModelTask) -> Vec<ModelVariant> {
    let mut variants: Vec<ModelVariant> = Vec::new();
    for entry in self.models.values() {
        if entry.task == task {
            variants.push(entry.variant);
        }
    }
    variants
}
/// Reports whether a model is registered for the given task/variant pair.
pub fn has_model(&self, task: ModelTask, variant: ModelVariant) -> bool {
    self.models.contains_key(&Self::model_key(&task, &variant))
}
pub fn model_count(&self) -> usize {
self.models.len()
}
pub fn is_empty(&self) -> bool {
self.models.is_empty()
}
pub async fn ensure_model_available(
&self,
task: ModelTask,
variant: ModelVariant,
) -> Result<PathBuf> {
let metadata = self.get_model_metadata(task, variant)?;
let model_path = self.cache_dir.join(&metadata.id).join("model.onnx");
if model_path.exists() {
tracing::debug!("Model found in cache: {:?}", model_path);
if self.verify_checksum(&model_path, &metadata.checksum)? {
tracing::debug!("Checksum verified, using cached model");
return Ok(model_path);
} else {
tracing::warn!("Cached model checksum mismatch, re-downloading");
}
}
tracing::info!(
"Downloading model: {} from {}",
metadata.id,
metadata.url
);
self.download_model(metadata, &model_path).await?;
if !self.verify_checksum(&model_path, &metadata.checksum)? {
let _ = std::fs::remove_file(&model_path);
return Err(Error::model(format!(
"Checksum verification failed for model: {}",
metadata.id
)));
}
tracing::info!("Model downloaded and verified: {:?}", model_path);
Ok(model_path)
}
/// Fetches the model described by `metadata` and stores it at `dest`.
///
/// The parent directory of `dest` is created first. A URL starting with
/// `file://` is treated as a local path (everything after the prefix) and
/// copied; any other URL is fetched with `reqwest::get`, and the complete
/// response body is buffered in memory before being written out.
///
/// NOTE(review): the filesystem calls here are the blocking `std::fs`
/// variants even though the function is `async`.
///
/// # Errors
///
/// Returns a model error when the directory cannot be created, the local
/// copy fails, the request fails or answers with a non-success status, the
/// body cannot be read, or the file cannot be written.
async fn download_model(&self, metadata: &ModelMetadata, dest: &Path) -> Result<()> {
    if let Some(parent) = dest.parent() {
        std::fs::create_dir_all(parent).map_err(|e| {
            Error::model(format!(
                "Failed to create cache directory '{}': {}",
                parent.display(),
                e
            ))
        })?;
    }

    // `strip_prefix` both tests for and removes the scheme, so no separate
    // `starts_with` check or `unwrap` is needed.
    if let Some(src_path) = metadata.url.strip_prefix("file://") {
        std::fs::copy(src_path, dest).map_err(|e| {
            Error::model(format!(
                "Failed to copy model from '{}' to '{}': {}",
                src_path,
                dest.display(),
                e
            ))
        })?;
        return Ok(());
    }

    let response = reqwest::get(&metadata.url).await.map_err(|e| {
        Error::model(format!(
            "Failed to download model from '{}': {}",
            metadata.url, e
        ))
    })?;
    if !response.status().is_success() {
        return Err(Error::model(format!(
            "HTTP error downloading model: {}",
            response.status()
        )));
    }
    let bytes = response.bytes().await.map_err(|e| {
        Error::model(format!("Failed to read response body: {}", e))
    })?;
    std::fs::write(dest, bytes).map_err(|e| {
        Error::model(format!("Failed to write model to '{}': {}", dest.display(), e))
    })?;
    Ok(())
}
fn verify_checksum(&self, path: &Path, expected: &str) -> Result<bool> {
let bytes = std::fs::read(path).map_err(|e| {
Error::model(format!("Failed to read file '{}' for checksum: {}", path.display(), e))
})?;
let mut hasher = Sha256::new();
hasher.update(&bytes);
let hash = format!("{:x}", hasher.finalize());
Ok(hash == expected)
}
/// Builds the map key for a task/variant pair: the `Debug` rendering of the
/// task, a `/`, then the `Debug` rendering of the variant
/// (for example `"PromptInjection/FP16"`).
fn model_key(task: &ModelTask, variant: &ModelVariant) -> String {
    let task_name = format!("{:?}", task);
    let variant_name = format!("{:?}", variant);
    task_name + "/" + &variant_name
}
/// Returns the default model cache location: `llm-shield/models` below the
/// directory reported by `dirs::cache_dir()`, or below a relative `.cache`
/// directory when none is reported.
fn default_cache_dir() -> PathBuf {
    let mut location = match dirs::cache_dir() {
        Some(base) => base,
        None => PathBuf::from(".cache"),
    };
    location.push("llm-shield");
    location.push("models");
    location
}
}
impl Default for ModelRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
// Keys are "<task>/<variant>" and differ for different pairs.
#[test]
fn test_model_key_generation() {
    let injection_key =
        ModelRegistry::model_key(&ModelTask::PromptInjection, &ModelVariant::FP16);
    let toxicity_key = ModelRegistry::model_key(&ModelTask::Toxicity, &ModelVariant::FP32);

    assert_eq!(injection_key, "PromptInjection/FP16");
    assert_eq!(toxicity_key, "Toxicity/FP32");
    assert_ne!(injection_key, toxicity_key);
}
// The default cache path mentions both the application and the models folder.
#[test]
fn test_default_cache_dir() {
    let rendered = ModelRegistry::default_cache_dir()
        .to_string_lossy()
        .into_owned();
    for fragment in ["llm-shield", "models"] {
        assert!(rendered.contains(fragment));
    }
}
// A fresh registry holds no models and points at the default cache location.
#[test]
fn test_registry_creation() {
    let fresh = ModelRegistry::new();
    assert_eq!(fresh.models.len(), 0);

    let cache_path = fresh.cache_dir.to_string_lossy();
    assert!(cache_path.contains("llm-shield"));
}
// Loading a registry file makes its single model retrievable by task/variant.
#[test]
fn test_registry_from_file() {
    let temp_dir = TempDir::new().unwrap();
    let registry_path = temp_dir.path().join("registry.json");
    let content = r#"{
"cache_dir": "/tmp/test-cache",
"models": [
{
"id": "test-model",
"task": "PromptInjection",
"variant": "FP16",
"url": "https://example.com/model.onnx",
"checksum": "abc123",
"size_bytes": 1024
}
]
}"#;
    // The path argument had been corrupted to `®istry_path` (an `&reg`
    // sequence mis-decoded as the registered-trademark sign).
    std::fs::write(&registry_path, content).unwrap();
    let registry = ModelRegistry::from_file(registry_path.to_str().unwrap()).unwrap();
    assert_eq!(registry.models.len(), 1);
    let metadata = registry
        .get_model_metadata(ModelTask::PromptInjection, ModelVariant::FP16)
        .unwrap();
    assert_eq!(metadata.id, "test-model");
    assert_eq!(metadata.url, "https://example.com/model.onnx");
}
// Looking up a pair in an empty registry is an error.
#[test]
fn test_get_missing_model() {
    let empty = ModelRegistry::new();
    assert!(empty
        .get_model_metadata(ModelTask::PromptInjection, ModelVariant::FP16)
        .is_err());
}
// The matching digest verifies; an arbitrary string does not.
#[test]
fn test_checksum_verification() {
    let scratch = TempDir::new().unwrap();
    let sample_path = scratch.path().join("test.txt");
    let payload = b"Hello, World!";
    std::fs::write(&sample_path, payload).unwrap();

    let expected_hex = format!("{:x}", Sha256::digest(payload));

    let registry = ModelRegistry::new();
    let matches = registry
        .verify_checksum(&sample_path, &expected_hex)
        .unwrap();
    assert!(matches);
    let mismatches = registry
        .verify_checksum(&sample_path, "wrong_checksum")
        .unwrap();
    assert!(!mismatches);
}
// A `file://` URL is served by copying the local file to the destination.
#[tokio::test]
async fn test_download_local_file() {
    let scratch = TempDir::new().unwrap();
    let payload = b"fake model data";

    let source_path = scratch.path().join("source.onnx");
    std::fs::write(&source_path, payload).unwrap();

    let entry = ModelMetadata {
        id: "test".to_string(),
        task: ModelTask::PromptInjection,
        variant: ModelVariant::FP16,
        url: format!("file://{}", source_path.display()),
        checksum: format!("{:x}", Sha256::digest(payload)),
        size_bytes: payload.len(),
    };

    let target_path = scratch.path().join("dest.onnx");
    ModelRegistry::new()
        .download_model(&entry, &target_path)
        .await
        .unwrap();

    assert!(target_path.exists());
    let copied = std::fs::read(&target_path).unwrap();
    assert_eq!(copied, payload);
}
// A task survives a JSON round trip unchanged.
#[test]
fn test_model_task_serialization() {
    let original = ModelTask::PromptInjection;
    let encoded = serde_json::to_string(&original).unwrap();
    let decoded = serde_json::from_str::<ModelTask>(&encoded).unwrap();
    assert_eq!(original, decoded);
}
// A variant survives a JSON round trip unchanged.
#[test]
fn test_model_variant_serialization() {
    let original = ModelVariant::FP16;
    let encoded = serde_json::to_string(&original).unwrap();
    let decoded = serde_json::from_str::<ModelVariant>(&encoded).unwrap();
    assert_eq!(original, decoded);
}
}