// llm_shield_models/registry.rs

//! Model Registry for LLM Shield
//!
//! Manages model metadata, downloads, caching, and verification.
//!
//! ## Features
//!
//! - Model catalog management
//! - Automatic downloading with caching
//! - SHA-256 checksum verification
//! - Support for multiple model tasks and variants
//!
//! ## Example
//!
//! ```no_run
//! use llm_shield_models::{ModelRegistry, ModelTask, ModelVariant};
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let registry = ModelRegistry::from_file("models/registry.json")?;
//! let model_path = registry.ensure_model_available(
//!     ModelTask::PromptInjection,
//!     ModelVariant::FP16
//! ).await?;
//! # Ok(())
//! # }
//! ```

use llm_shield_core::Error;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::fs::File;
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;

34/// Result type alias
35pub type Result<T> = std::result::Result<T, Error>;
36
37/// Model task type
38#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
39pub enum ModelTask {
40    /// Prompt injection detection
41    PromptInjection,
42    /// Toxicity classification
43    Toxicity,
44    /// Sentiment analysis
45    Sentiment,
46    /// Named Entity Recognition (PII detection)
47    NamedEntityRecognition,
48}
49
50/// Model variant (precision/quantization)
51#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
52pub enum ModelVariant {
53    /// 16-bit floating point
54    FP16,
55    /// 32-bit floating point
56    FP32,
57    /// 8-bit integer quantization
58    INT8,
59}
60
61/// Model metadata
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct ModelMetadata {
64    /// Unique model identifier
65    pub id: String,
66    /// Task this model performs
67    pub task: ModelTask,
68    /// Model variant (precision)
69    pub variant: ModelVariant,
70    /// Download URL
71    pub url: String,
72    /// SHA-256 checksum
73    pub checksum: String,
74    /// Model size in bytes
75    pub size_bytes: usize,
76}
77
78/// Registry data structure (for deserialization)
79#[derive(Debug, Serialize, Deserialize)]
80struct RegistryData {
81    /// Cache directory path
82    cache_dir: Option<String>,
83    /// List of available models
84    models: Vec<ModelMetadata>,
85}
86
87/// Model registry for managing model lifecycle
88///
89/// ## Thread Safety
90///
91/// ModelRegistry uses Arc internally for efficient cloning and sharing
92/// across threads. The internal HashMap is immutable after creation,
93/// making concurrent reads safe without locks.
94#[derive(Debug, Clone)]
95pub struct ModelRegistry {
96    /// Model metadata by key (task/variant) - immutable after creation
97    models: Arc<HashMap<String, ModelMetadata>>,
98    /// Local cache directory
99    cache_dir: Arc<PathBuf>,
100}
101
102impl ModelRegistry {
103    /// Create a new registry with default cache directory
104    pub fn new() -> Self {
105        let cache_dir = Self::default_cache_dir();
106
107        Self {
108            models: Arc::new(HashMap::new()),
109            cache_dir: Arc::new(cache_dir),
110        }
111    }
112
113    /// Create a registry from a JSON file
114    ///
115    /// # Arguments
116    ///
117    /// * `path` - Path to registry.json file
118    ///
119    /// # Example
120    ///
121    /// ```no_run
122    /// # use llm_shield_models::ModelRegistry;
123    /// let registry = ModelRegistry::from_file("models/registry.json")?;
124    /// # Ok::<(), llm_shield_core::Error>(())
125    /// ```
126    pub fn from_file(path: &str) -> Result<Self> {
127        tracing::info!("Loading model registry from: {}", path);
128
129        let json = std::fs::read_to_string(path).map_err(|e| {
130            Error::model(format!("Failed to read registry file '{}': {}", path, e))
131        })?;
132
133        let data: RegistryData = serde_json::from_str(&json).map_err(|e| {
134            Error::model(format!("Failed to parse registry JSON: {}", e))
135        })?;
136
137        let mut models = HashMap::new();
138        for model in data.models {
139            let key = Self::model_key(&model.task, &model.variant);
140            tracing::debug!(
141                "Registered model: {} ({:?}/{:?})",
142                model.id,
143                model.task,
144                model.variant
145            );
146            models.insert(key, model);
147        }
148
149        let cache_dir = if let Some(dir) = data.cache_dir {
150            PathBuf::from(shellexpand::tilde(&dir).to_string())
151        } else {
152            Self::default_cache_dir()
153        };
154
155        tracing::info!(
156            "Registry loaded with {} models, cache_dir: {}",
157            models.len(),
158            cache_dir.display()
159        );
160
161        Ok(Self {
162            models: Arc::new(models),
163            cache_dir: Arc::new(cache_dir)
164        })
165    }
166
167    /// Get metadata for a specific model
168    ///
169    /// # Arguments
170    ///
171    /// * `task` - The model task
172    /// * `variant` - The model variant
173    ///
174    /// # Returns
175    ///
176    /// Reference to model metadata, or Error if not found
177    ///
178    /// # Example
179    ///
180    /// ```no_run
181    /// # use llm_shield_models::{ModelRegistry, ModelTask, ModelVariant};
182    /// # fn example() -> Result<(), llm_shield_core::Error> {
183    /// # let registry = ModelRegistry::new();
184    /// let metadata = registry.get_model_metadata(
185    ///     ModelTask::PromptInjection,
186    ///     ModelVariant::FP16
187    /// )?;
188    /// println!("Model ID: {}", metadata.id);
189    /// # Ok(())
190    /// # }
191    /// ```
192    pub fn get_model_metadata(
193        &self,
194        task: ModelTask,
195        variant: ModelVariant,
196    ) -> Result<&ModelMetadata> {
197        let key = Self::model_key(&task, &variant);
198        self.models.get(&key).ok_or_else(|| {
199            Error::model(format!(
200                "Model not found in registry: {:?}/{:?}",
201                task, variant
202            ))
203        })
204    }
205
206    /// List all available models in the registry
207    ///
208    /// # Returns
209    ///
210    /// Vector of references to all model metadata
211    ///
212    /// # Example
213    ///
214    /// ```no_run
215    /// # use llm_shield_models::ModelRegistry;
216    /// # fn example() -> Result<(), llm_shield_core::Error> {
217    /// # let registry = ModelRegistry::new();
218    /// let all_models = registry.list_models();
219    /// for model in all_models {
220    ///     println!("Model: {} ({:?}/{:?})", model.id, model.task, model.variant);
221    /// }
222    /// # Ok(())
223    /// # }
224    /// ```
225    pub fn list_models(&self) -> Vec<&ModelMetadata> {
226        self.models.values().collect()
227    }
228
229    /// List all models for a specific task
230    ///
231    /// # Arguments
232    ///
233    /// * `task` - The model task to filter by
234    ///
235    /// # Returns
236    ///
237    /// Vector of references to model metadata matching the task
238    ///
239    /// # Example
240    ///
241    /// ```no_run
242    /// # use llm_shield_models::{ModelRegistry, ModelTask};
243    /// # fn example() -> Result<(), llm_shield_core::Error> {
244    /// # let registry = ModelRegistry::new();
245    /// let models = registry.list_models_for_task(ModelTask::PromptInjection);
246    /// println!("Found {} PromptInjection models", models.len());
247    /// # Ok(())
248    /// # }
249    /// ```
250    pub fn list_models_for_task(&self, task: ModelTask) -> Vec<&ModelMetadata> {
251        self.models
252            .values()
253            .filter(|m| m.task == task)
254            .collect()
255    }
256
257    /// Get all available variants for a specific task
258    ///
259    /// # Arguments
260    ///
261    /// * `task` - The model task
262    ///
263    /// # Returns
264    ///
265    /// Vector of all available model variants for the task
266    ///
267    /// # Example
268    ///
269    /// ```no_run
270    /// # use llm_shield_models::{ModelRegistry, ModelTask};
271    /// # fn example() -> Result<(), llm_shield_core::Error> {
272    /// # let registry = ModelRegistry::new();
273    /// let variants = registry.get_available_variants(ModelTask::PromptInjection);
274    /// for variant in variants {
275    ///     println!("Available variant: {:?}", variant);
276    /// }
277    /// # Ok(())
278    /// # }
279    /// ```
280    pub fn get_available_variants(&self, task: ModelTask) -> Vec<ModelVariant> {
281        self.models
282            .values()
283            .filter(|m| m.task == task)
284            .map(|m| m.variant)
285            .collect()
286    }
287
288    /// Check if a specific model is available in the registry
289    ///
290    /// # Arguments
291    ///
292    /// * `task` - The model task
293    /// * `variant` - The model variant
294    ///
295    /// # Returns
296    ///
297    /// `true` if the model is registered, `false` otherwise
298    ///
299    /// # Example
300    ///
301    /// ```no_run
302    /// # use llm_shield_models::{ModelRegistry, ModelTask, ModelVariant};
303    /// # fn example() -> Result<(), llm_shield_core::Error> {
304    /// # let registry = ModelRegistry::new();
305    /// if registry.has_model(ModelTask::PromptInjection, ModelVariant::FP16) {
306    ///     println!("Model is available!");
307    /// } else {
308    ///     println!("Model not found");
309    /// }
310    /// # Ok(())
311    /// # }
312    /// ```
313    pub fn has_model(&self, task: ModelTask, variant: ModelVariant) -> bool {
314        let key = Self::model_key(&task, &variant);
315        self.models.contains_key(&key)
316    }
317
318    /// Get the total number of registered models
319    ///
320    /// # Returns
321    ///
322    /// Count of all models in the registry
323    ///
324    /// # Example
325    ///
326    /// ```no_run
327    /// # use llm_shield_models::ModelRegistry;
328    /// # fn example() -> Result<(), llm_shield_core::Error> {
329    /// # let registry = ModelRegistry::new();
330    /// println!("Registry contains {} models", registry.model_count());
331    /// # Ok(())
332    /// # }
333    /// ```
334    pub fn model_count(&self) -> usize {
335        self.models.len()
336    }
337
338    /// Check if the registry is empty
339    ///
340    /// # Returns
341    ///
342    /// `true` if no models are registered, `false` otherwise
343    ///
344    /// # Example
345    ///
346    /// ```no_run
347    /// # use llm_shield_models::ModelRegistry;
348    /// # fn example() -> Result<(), llm_shield_core::Error> {
349    /// # let registry = ModelRegistry::new();
350    /// if registry.is_empty() {
351    ///     println!("No models registered");
352    /// }
353    /// # Ok(())
354    /// # }
355    /// ```
356    pub fn is_empty(&self) -> bool {
357        self.models.is_empty()
358    }
359
360    /// Ensure a model is available locally (download if needed)
361    ///
362    /// This method:
363    /// 1. Checks if model is already cached
364    /// 2. Verifies checksum if cached
365    /// 3. Downloads if not cached or verification fails
366    /// 4. Verifies checksum after download
367    ///
368    /// # Arguments
369    ///
370    /// * `task` - The model task
371    /// * `variant` - The model variant
372    ///
373    /// # Returns
374    ///
375    /// Path to the local model file
376    pub async fn ensure_model_available(
377        &self,
378        task: ModelTask,
379        variant: ModelVariant,
380    ) -> Result<PathBuf> {
381        let metadata = self.get_model_metadata(task, variant)?;
382        let model_path = self.cache_dir.join(&metadata.id).join("model.onnx");
383
384        // Check if already cached and valid
385        if model_path.exists() {
386            tracing::debug!("Model found in cache: {:?}", model_path);
387
388            if self.verify_checksum(&model_path, &metadata.checksum)? {
389                tracing::debug!("Checksum verified, using cached model");
390                return Ok(model_path);
391            } else {
392                tracing::warn!("Cached model checksum mismatch, re-downloading");
393            }
394        }
395
396        // Download model
397        tracing::info!(
398            "Downloading model: {} from {}",
399            metadata.id,
400            metadata.url
401        );
402        self.download_model(metadata, &model_path).await?;
403
404        // Verify checksum
405        if !self.verify_checksum(&model_path, &metadata.checksum)? {
406            // Clean up failed download
407            let _ = std::fs::remove_file(&model_path);
408            return Err(Error::model(format!(
409                "Checksum verification failed for model: {}",
410                metadata.id
411            )));
412        }
413
414        tracing::info!("Model downloaded and verified: {:?}", model_path);
415        Ok(model_path)
416    }
417
418    /// Download a model from URL to local path
419    async fn download_model(&self, metadata: &ModelMetadata, dest: &Path) -> Result<()> {
420        // Create parent directory
421        if let Some(parent) = dest.parent() {
422            std::fs::create_dir_all(parent).map_err(|e| {
423                Error::model(format!(
424                    "Failed to create cache directory '{}': {}",
425                    parent.display(),
426                    e
427                ))
428            })?;
429        }
430
431        // Handle file:// URLs for testing
432        if metadata.url.starts_with("file://") {
433            let src_path = metadata.url.strip_prefix("file://").unwrap();
434            std::fs::copy(src_path, dest).map_err(|e| {
435                Error::model(format!(
436                    "Failed to copy model from '{}' to '{}': {}",
437                    src_path,
438                    dest.display(),
439                    e
440                ))
441            })?;
442            return Ok(());
443        }
444
445        // Download using reqwest for HTTP(S) URLs
446        let response = reqwest::get(&metadata.url).await.map_err(|e| {
447            Error::model(format!(
448                "Failed to download model from '{}': {}",
449                metadata.url, e
450            ))
451        })?;
452
453        if !response.status().is_success() {
454            return Err(Error::model(format!(
455                "HTTP error downloading model: {}",
456                response.status()
457            )));
458        }
459
460        let bytes = response.bytes().await.map_err(|e| {
461            Error::model(format!("Failed to read response body: {}", e))
462        })?;
463
464        // Write to file
465        std::fs::write(dest, bytes).map_err(|e| {
466            Error::model(format!("Failed to write model to '{}': {}", dest.display(), e))
467        })?;
468
469        Ok(())
470    }
471
472    /// Verify SHA-256 checksum of a file
473    fn verify_checksum(&self, path: &Path, expected: &str) -> Result<bool> {
474        let bytes = std::fs::read(path).map_err(|e| {
475            Error::model(format!("Failed to read file '{}' for checksum: {}", path.display(), e))
476        })?;
477
478        let mut hasher = Sha256::new();
479        hasher.update(&bytes);
480        let hash = format!("{:x}", hasher.finalize());
481
482        Ok(hash == expected)
483    }
484
485    /// Generate a key for model lookup
486    fn model_key(task: &ModelTask, variant: &ModelVariant) -> String {
487        format!("{:?}/{:?}", task, variant)
488    }
489
490    /// Get the default cache directory
491    fn default_cache_dir() -> PathBuf {
492        dirs::cache_dir()
493            .unwrap_or_else(|| PathBuf::from(".cache"))
494            .join("llm-shield")
495            .join("models")
496    }
497}
498
499impl Default for ModelRegistry {
500    fn default() -> Self {
501        Self::new()
502    }
503}
504
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Lowercase hex SHA-256 digest of `data`, in the format registry checksums use.
    fn sha256_hex(data: &[u8]) -> String {
        let mut hasher = Sha256::new();
        hasher.update(data);
        format!("{:x}", hasher.finalize())
    }

    #[test]
    fn test_model_key_generation() {
        let prompt_fp16 =
            ModelRegistry::model_key(&ModelTask::PromptInjection, &ModelVariant::FP16);
        let toxicity_fp32 = ModelRegistry::model_key(&ModelTask::Toxicity, &ModelVariant::FP32);

        assert_eq!(prompt_fp16, "PromptInjection/FP16");
        assert_eq!(toxicity_fp32, "Toxicity/FP32");
        assert_ne!(prompt_fp16, toxicity_fp32);
    }

    #[test]
    fn test_default_cache_dir() {
        let rendered = ModelRegistry::default_cache_dir()
            .to_string_lossy()
            .into_owned();
        assert!(rendered.contains("llm-shield"));
        assert!(rendered.contains("models"));
    }

    #[test]
    fn test_registry_creation() {
        let registry = ModelRegistry::new();
        assert!(registry.models.is_empty());
        assert!(registry.cache_dir.to_string_lossy().contains("llm-shield"));
    }

    #[test]
    fn test_registry_from_file() {
        let workspace = TempDir::new().unwrap();
        let registry_file = workspace.path().join("registry.json");

        let content = r#"{
            "cache_dir": "/tmp/test-cache",
            "models": [
                {
                    "id": "test-model",
                    "task": "PromptInjection",
                    "variant": "FP16",
                    "url": "https://example.com/model.onnx",
                    "checksum": "abc123",
                    "size_bytes": 1024
                }
            ]
        }"#;
        std::fs::write(&registry_file, content).unwrap();

        let registry = ModelRegistry::from_file(registry_file.to_str().unwrap()).unwrap();
        assert_eq!(registry.models.len(), 1);

        let entry = registry
            .get_model_metadata(ModelTask::PromptInjection, ModelVariant::FP16)
            .unwrap();
        assert_eq!(entry.id, "test-model");
        assert_eq!(entry.url, "https://example.com/model.onnx");
    }

    #[test]
    fn test_get_missing_model() {
        let lookup = ModelRegistry::new()
            .get_model_metadata(ModelTask::PromptInjection, ModelVariant::FP16)
            .map(|m| m.id.clone());
        assert!(lookup.is_err());
    }

    #[test]
    fn test_checksum_verification() {
        let workspace = TempDir::new().unwrap();
        let sample = workspace.path().join("test.txt");
        let payload = b"Hello, World!";
        std::fs::write(&sample, payload).unwrap();

        let expected = sha256_hex(payload);
        let registry = ModelRegistry::new();

        let matches_correct = registry.verify_checksum(&sample, &expected).unwrap();
        let matches_wrong = registry
            .verify_checksum(&sample, "wrong_checksum")
            .unwrap();

        assert!(matches_correct);
        assert!(!matches_wrong);
    }

    #[tokio::test]
    async fn test_download_local_file() {
        let workspace = TempDir::new().unwrap();
        let payload = b"fake model data";
        let source = workspace.path().join("source.onnx");
        std::fs::write(&source, payload).unwrap();

        let metadata = ModelMetadata {
            id: "test".to_string(),
            task: ModelTask::PromptInjection,
            variant: ModelVariant::FP16,
            url: format!("file://{}", source.display()),
            checksum: sha256_hex(payload),
            size_bytes: payload.len(),
        };

        let destination = workspace.path().join("dest.onnx");
        ModelRegistry::new()
            .download_model(&metadata, &destination)
            .await
            .unwrap();

        assert!(destination.exists());
        assert_eq!(std::fs::read(&destination).unwrap(), payload);
    }

    #[test]
    fn test_model_task_serialization() {
        let original = ModelTask::PromptInjection;
        let encoded = serde_json::to_string(&original).unwrap();
        let restored: ModelTask = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, restored);
    }

    #[test]
    fn test_model_variant_serialization() {
        let original = ModelVariant::FP16;
        let encoded = serde_json::to_string(&original).unwrap();
        let restored: ModelVariant = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, restored);
    }
}