1use llm_shield_core::Error;
28use serde::{Deserialize, Serialize};
29use sha2::{Digest, Sha256};
30use std::collections::HashMap;
31use std::path::{Path, PathBuf};
32use std::sync::Arc;
33
34pub type Result<T> = std::result::Result<T, Error>;
36
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
39pub enum ModelTask {
40 PromptInjection,
42 Toxicity,
44 Sentiment,
46 NamedEntityRecognition,
48}
49
50#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
52pub enum ModelVariant {
53 FP16,
55 FP32,
57 INT8,
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct ModelMetadata {
64 pub id: String,
66 pub task: ModelTask,
68 pub variant: ModelVariant,
70 pub url: String,
72 pub checksum: String,
74 pub size_bytes: usize,
76}
77
78#[derive(Debug, Serialize, Deserialize)]
80struct RegistryData {
81 cache_dir: Option<String>,
83 models: Vec<ModelMetadata>,
85}
86
87#[derive(Debug, Clone)]
95pub struct ModelRegistry {
96 models: Arc<HashMap<String, ModelMetadata>>,
98 cache_dir: Arc<PathBuf>,
100}
101
102impl ModelRegistry {
103 pub fn new() -> Self {
105 let cache_dir = Self::default_cache_dir();
106
107 Self {
108 models: Arc::new(HashMap::new()),
109 cache_dir: Arc::new(cache_dir),
110 }
111 }
112
113 pub fn from_file(path: &str) -> Result<Self> {
127 tracing::info!("Loading model registry from: {}", path);
128
129 let json = std::fs::read_to_string(path).map_err(|e| {
130 Error::model(format!("Failed to read registry file '{}': {}", path, e))
131 })?;
132
133 let data: RegistryData = serde_json::from_str(&json).map_err(|e| {
134 Error::model(format!("Failed to parse registry JSON: {}", e))
135 })?;
136
137 let mut models = HashMap::new();
138 for model in data.models {
139 let key = Self::model_key(&model.task, &model.variant);
140 tracing::debug!(
141 "Registered model: {} ({:?}/{:?})",
142 model.id,
143 model.task,
144 model.variant
145 );
146 models.insert(key, model);
147 }
148
149 let cache_dir = if let Some(dir) = data.cache_dir {
150 PathBuf::from(shellexpand::tilde(&dir).to_string())
151 } else {
152 Self::default_cache_dir()
153 };
154
155 tracing::info!(
156 "Registry loaded with {} models, cache_dir: {}",
157 models.len(),
158 cache_dir.display()
159 );
160
161 Ok(Self {
162 models: Arc::new(models),
163 cache_dir: Arc::new(cache_dir)
164 })
165 }
166
167 pub fn get_model_metadata(
193 &self,
194 task: ModelTask,
195 variant: ModelVariant,
196 ) -> Result<&ModelMetadata> {
197 let key = Self::model_key(&task, &variant);
198 self.models.get(&key).ok_or_else(|| {
199 Error::model(format!(
200 "Model not found in registry: {:?}/{:?}",
201 task, variant
202 ))
203 })
204 }
205
206 pub fn list_models(&self) -> Vec<&ModelMetadata> {
226 self.models.values().collect()
227 }
228
229 pub fn list_models_for_task(&self, task: ModelTask) -> Vec<&ModelMetadata> {
251 self.models
252 .values()
253 .filter(|m| m.task == task)
254 .collect()
255 }
256
257 pub fn get_available_variants(&self, task: ModelTask) -> Vec<ModelVariant> {
281 self.models
282 .values()
283 .filter(|m| m.task == task)
284 .map(|m| m.variant)
285 .collect()
286 }
287
288 pub fn has_model(&self, task: ModelTask, variant: ModelVariant) -> bool {
314 let key = Self::model_key(&task, &variant);
315 self.models.contains_key(&key)
316 }
317
318 pub fn model_count(&self) -> usize {
335 self.models.len()
336 }
337
338 pub fn is_empty(&self) -> bool {
357 self.models.is_empty()
358 }
359
360 pub async fn ensure_model_available(
377 &self,
378 task: ModelTask,
379 variant: ModelVariant,
380 ) -> Result<PathBuf> {
381 let metadata = self.get_model_metadata(task, variant)?;
382 let model_path = self.cache_dir.join(&metadata.id).join("model.onnx");
383
384 if model_path.exists() {
386 tracing::debug!("Model found in cache: {:?}", model_path);
387
388 if self.verify_checksum(&model_path, &metadata.checksum)? {
389 tracing::debug!("Checksum verified, using cached model");
390 return Ok(model_path);
391 } else {
392 tracing::warn!("Cached model checksum mismatch, re-downloading");
393 }
394 }
395
396 tracing::info!(
398 "Downloading model: {} from {}",
399 metadata.id,
400 metadata.url
401 );
402 self.download_model(metadata, &model_path).await?;
403
404 if !self.verify_checksum(&model_path, &metadata.checksum)? {
406 let _ = std::fs::remove_file(&model_path);
408 return Err(Error::model(format!(
409 "Checksum verification failed for model: {}",
410 metadata.id
411 )));
412 }
413
414 tracing::info!("Model downloaded and verified: {:?}", model_path);
415 Ok(model_path)
416 }
417
418 async fn download_model(&self, metadata: &ModelMetadata, dest: &Path) -> Result<()> {
420 if let Some(parent) = dest.parent() {
422 std::fs::create_dir_all(parent).map_err(|e| {
423 Error::model(format!(
424 "Failed to create cache directory '{}': {}",
425 parent.display(),
426 e
427 ))
428 })?;
429 }
430
431 if metadata.url.starts_with("file://") {
433 let src_path = metadata.url.strip_prefix("file://").unwrap();
434 std::fs::copy(src_path, dest).map_err(|e| {
435 Error::model(format!(
436 "Failed to copy model from '{}' to '{}': {}",
437 src_path,
438 dest.display(),
439 e
440 ))
441 })?;
442 return Ok(());
443 }
444
445 let response = reqwest::get(&metadata.url).await.map_err(|e| {
447 Error::model(format!(
448 "Failed to download model from '{}': {}",
449 metadata.url, e
450 ))
451 })?;
452
453 if !response.status().is_success() {
454 return Err(Error::model(format!(
455 "HTTP error downloading model: {}",
456 response.status()
457 )));
458 }
459
460 let bytes = response.bytes().await.map_err(|e| {
461 Error::model(format!("Failed to read response body: {}", e))
462 })?;
463
464 std::fs::write(dest, bytes).map_err(|e| {
466 Error::model(format!("Failed to write model to '{}': {}", dest.display(), e))
467 })?;
468
469 Ok(())
470 }
471
472 fn verify_checksum(&self, path: &Path, expected: &str) -> Result<bool> {
474 let bytes = std::fs::read(path).map_err(|e| {
475 Error::model(format!("Failed to read file '{}' for checksum: {}", path.display(), e))
476 })?;
477
478 let mut hasher = Sha256::new();
479 hasher.update(&bytes);
480 let hash = format!("{:x}", hasher.finalize());
481
482 Ok(hash == expected)
483 }
484
485 fn model_key(task: &ModelTask, variant: &ModelVariant) -> String {
487 format!("{:?}/{:?}", task, variant)
488 }
489
490 fn default_cache_dir() -> PathBuf {
492 dirs::cache_dir()
493 .unwrap_or_else(|| PathBuf::from(".cache"))
494 .join("llm-shield")
495 .join("models")
496 }
497}
498
499impl Default for ModelRegistry {
500 fn default() -> Self {
501 Self::new()
502 }
503}
504
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Keys are `"<task>/<variant>"` and differ for different pairs.
    #[test]
    fn test_model_key_generation() {
        let key1 = ModelRegistry::model_key(&ModelTask::PromptInjection, &ModelVariant::FP16);
        let key2 = ModelRegistry::model_key(&ModelTask::Toxicity, &ModelVariant::FP32);

        assert_eq!(key1, "PromptInjection/FP16");
        assert_eq!(key2, "Toxicity/FP32");
        assert_ne!(key1, key2);
    }

    /// The default cache directory ends in `llm-shield/models`.
    #[test]
    fn test_default_cache_dir() {
        let cache_dir = ModelRegistry::default_cache_dir();
        assert!(cache_dir.to_string_lossy().contains("llm-shield"));
        assert!(cache_dir.to_string_lossy().contains("models"));
    }

    /// A new registry is empty and points at the default cache directory.
    #[test]
    fn test_registry_creation() {
        let registry = ModelRegistry::new();
        assert_eq!(registry.models.len(), 0);
        assert!(registry.cache_dir.to_string_lossy().contains("llm-shield"));
    }

    /// A registry file round-trips into a queryable registry.
    #[test]
    fn test_registry_from_file() {
        let temp_dir = TempDir::new().unwrap();
        let registry_path = temp_dir.path().join("registry.json");

        let content = r#"{
            "cache_dir": "/tmp/test-cache",
            "models": [
                {
                    "id": "test-model",
                    "task": "PromptInjection",
                    "variant": "FP16",
                    "url": "https://example.com/model.onnx",
                    "checksum": "abc123",
                    "size_bytes": 1024
                }
            ]
        }"#;

        std::fs::write(&registry_path, content).unwrap();

        let registry = ModelRegistry::from_file(registry_path.to_str().unwrap()).unwrap();
        assert_eq!(registry.models.len(), 1);

        let metadata = registry
            .get_model_metadata(ModelTask::PromptInjection, ModelVariant::FP16)
            .unwrap();
        assert_eq!(metadata.id, "test-model");
        assert_eq!(metadata.url, "https://example.com/model.onnx");
    }

    /// Looking up an unregistered model is an error.
    #[test]
    fn test_get_missing_model() {
        let registry = ModelRegistry::new();
        let result = registry.get_model_metadata(ModelTask::PromptInjection, ModelVariant::FP16);
        assert!(result.is_err());
    }

    /// The checksum check accepts the real digest and rejects anything else.
    #[test]
    fn test_checksum_verification() {
        let temp_dir = TempDir::new().unwrap();
        let test_file = temp_dir.path().join("test.txt");
        let content = b"Hello, World!";
        std::fs::write(&test_file, content).unwrap();

        let mut hasher = Sha256::new();
        hasher.update(content);
        let correct_checksum = format!("{:x}", hasher.finalize());

        let registry = ModelRegistry::new();

        assert!(registry
            .verify_checksum(&test_file, &correct_checksum)
            .unwrap());

        assert!(!registry
            .verify_checksum(&test_file, "wrong_checksum")
            .unwrap());
    }

    /// A `file://` URL is copied byte-for-byte to the destination.
    #[tokio::test]
    async fn test_download_local_file() {
        let temp_dir = TempDir::new().unwrap();
        let src_file = temp_dir.path().join("source.onnx");
        let content = b"fake model data";
        std::fs::write(&src_file, content).unwrap();

        let mut hasher = Sha256::new();
        hasher.update(content);
        let checksum = format!("{:x}", hasher.finalize());

        let metadata = ModelMetadata {
            id: "test".to_string(),
            task: ModelTask::PromptInjection,
            variant: ModelVariant::FP16,
            url: format!("file://{}", src_file.display()),
            checksum,
            size_bytes: content.len(),
        };

        let dest_file = temp_dir.path().join("dest.onnx");
        let registry = ModelRegistry::new();

        registry.download_model(&metadata, &dest_file).await.unwrap();
        assert!(dest_file.exists());

        let downloaded = std::fs::read(&dest_file).unwrap();
        assert_eq!(downloaded, content);
    }

    /// `ModelTask` survives a JSON round trip.
    #[test]
    fn test_model_task_serialization() {
        let task = ModelTask::PromptInjection;
        let json = serde_json::to_string(&task).unwrap();
        let deserialized: ModelTask = serde_json::from_str(&json).unwrap();
        assert_eq!(task, deserialized);
    }

    /// `ModelVariant` survives a JSON round trip.
    #[test]
    fn test_model_variant_serialization() {
        let variant = ModelVariant::FP16;
        let json = serde_json::to_string(&variant).unwrap();
        let deserialized: ModelVariant = serde_json::from_str(&json).unwrap();
        assert_eq!(variant, deserialized);
    }
}