// llm_shield_models/model_loader.rs

1//! Model Loader with ONNX Runtime Integration
2//!
3//! ## SPARC Phase 3: Implementation
4//!
5//! This module provides lazy loading, caching, and thread-safe access to ONNX models.
6//!
7//! ## Features
8//!
9//! - **Lazy Loading**: Models are only loaded when first requested
10//! - **Caching**: Loaded models are cached for reuse
11//! - **Thread-Safe**: Uses Arc + RwLock for concurrent access
12//! - **Registry Integration**: Uses ModelRegistry for model discovery
13//! - **Graceful Error Handling**: Comprehensive error messages
14//!
15//! ## Usage Example
16//!
17//! ```no_run
18//! use llm_shield_models::{ModelLoader, ModelRegistry, ModelType, ModelVariant};
19//! use std::sync::Arc;
20//!
21//! # async fn example() -> Result<(), llm_shield_core::Error> {
22//! // Create registry
23//! let registry = ModelRegistry::from_file("models/registry.json")?;
24//!
25//! // Create loader
26//! let loader = ModelLoader::new(Arc::new(registry));
27//!
28//! // Load model (lazy - only loads once)
29//! let session = loader.load(ModelType::PromptInjection, ModelVariant::FP16).await?;
30//!
31//! // Use session for inference...
32//! # Ok(())
33//! # }
34//! ```
35
use crate::registry::{ModelRegistry, ModelTask, ModelVariant};
use llm_shield_core::Error;
use ort::session::Session;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex, RwLock};
42
43/// Result type alias
44pub type Result<T> = std::result::Result<T, Error>;
45
46/// Model type identifier
47#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
48pub enum ModelType {
49    /// Prompt injection detection
50    PromptInjection,
51    /// Toxicity classification
52    Toxicity,
53    /// Sentiment analysis
54    Sentiment,
55    /// Named Entity Recognition (PII detection)
56    NamedEntityRecognition,
57}
58
59/// Conversion from ModelTask to ModelType
60impl From<ModelTask> for ModelType {
61    fn from(task: ModelTask) -> Self {
62        match task {
63            ModelTask::PromptInjection => ModelType::PromptInjection,
64            ModelTask::Toxicity => ModelType::Toxicity,
65            ModelTask::Sentiment => ModelType::Sentiment,
66            ModelTask::NamedEntityRecognition => ModelType::NamedEntityRecognition,
67        }
68    }
69}
70
71/// Conversion from ModelType to ModelTask
72impl From<ModelType> for ModelTask {
73    fn from(model_type: ModelType) -> Self {
74        match model_type {
75            ModelType::PromptInjection => ModelTask::PromptInjection,
76            ModelType::Toxicity => ModelTask::Toxicity,
77            ModelType::Sentiment => ModelTask::Sentiment,
78            ModelType::NamedEntityRecognition => ModelTask::NamedEntityRecognition,
79        }
80    }
81}
82
83/// Configuration for loading a model
84#[derive(Debug, Clone)]
85pub struct ModelConfig {
86    /// Type of model
87    pub model_type: ModelType,
88    /// Model variant (precision)
89    pub variant: ModelVariant,
90    /// Path to ONNX model file
91    pub model_path: PathBuf,
92    /// Number of threads for inference
93    pub thread_pool_size: usize,
94    /// ONNX graph optimization level (0-3)
95    pub optimization_level: u8,
96}
97
98impl ModelConfig {
99    /// Create a new model configuration
100    ///
101    /// # Arguments
102    ///
103    /// * `model_type` - Type of model (PromptInjection, Toxicity, Sentiment)
104    /// * `variant` - Model variant (FP32, FP16, INT8)
105    /// * `model_path` - Path to ONNX model file
106    ///
107    /// # Example
108    ///
109    /// ```
110    /// use llm_shield_models::{ModelConfig, ModelType, ModelVariant};
111    /// use std::path::PathBuf;
112    ///
113    /// let config = ModelConfig::new(
114    ///     ModelType::PromptInjection,
115    ///     ModelVariant::FP16,
116    ///     PathBuf::from("/path/to/model.onnx")
117    /// );
118    /// ```
119    pub fn new(model_type: ModelType, variant: ModelVariant, model_path: PathBuf) -> Self {
120        Self {
121            model_type,
122            variant,
123            model_path,
124            thread_pool_size: num_cpus::get().max(1),
125            optimization_level: 3, // Max optimization
126        }
127    }
128
129    /// Set the thread pool size
130    pub fn with_thread_pool_size(mut self, size: usize) -> Self {
131        self.thread_pool_size = size;
132        self
133    }
134
135    /// Set the optimization level (0-3)
136    pub fn with_optimization_level(mut self, level: u8) -> Self {
137        self.optimization_level = level.min(3);
138        self
139    }
140}
141
142/// Statistics about loaded models
143#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
144pub struct LoaderStats {
145    /// Number of models currently loaded
146    pub total_loaded: usize,
147    /// Total number of load operations
148    pub total_loads: u64,
149    /// Total number of cache hits
150    pub cache_hits: u64,
151}
152
153/// Model loader with lazy loading and caching
154///
155/// ## Thread Safety
156///
157/// ModelLoader uses Arc + RwLock internally for thread-safe access.
158/// Multiple threads can safely load and access models concurrently.
159///
160/// ## Caching
161///
162/// Once a model is loaded, it stays in memory until explicitly unloaded.
163/// Subsequent calls to `load()` with the same model type/variant return
164/// the cached session.
165pub struct ModelLoader {
166    /// Model registry for metadata
167    registry: Arc<ModelRegistry>,
168    /// Loaded ONNX sessions cache
169    cache: Arc<RwLock<HashMap<(ModelType, ModelVariant), Arc<Mutex<Session>>>>>,
170    /// Statistics
171    stats: Arc<RwLock<LoaderStats>>,
172}
173
174impl ModelLoader {
175    /// Create a new model loader
176    ///
177    /// # Arguments
178    ///
179    /// * `registry` - Model registry for metadata and downloads
180    ///
181    /// # Example
182    ///
183    /// ```no_run
184    /// use llm_shield_models::{ModelLoader, ModelRegistry};
185    /// use std::sync::Arc;
186    ///
187    /// # fn example() -> Result<(), llm_shield_core::Error> {
188    /// let registry = ModelRegistry::from_file("models/registry.json")?;
189    /// let loader = ModelLoader::new(Arc::new(registry));
190    /// # Ok(())
191    /// # }
192    /// ```
193    pub fn new(registry: Arc<ModelRegistry>) -> Self {
194        Self {
195            registry,
196            cache: Arc::new(RwLock::new(HashMap::new())),
197            stats: Arc::new(RwLock::new(LoaderStats::default())),
198        }
199    }
200
201    /// Create a new model loader (alias for `new`)
202    pub fn with_registry(registry: Arc<ModelRegistry>) -> Self {
203        Self::new(registry)
204    }
205
206    /// Load a model (lazily, with caching)
207    ///
208    /// If the model is already loaded, returns the cached session.
209    /// Otherwise, loads the model from disk using the registry.
210    ///
211    /// # Arguments
212    ///
213    /// * `model_type` - Type of model to load
214    /// * `variant` - Model variant (precision)
215    ///
216    /// # Returns
217    ///
218    /// Arc to ONNX Runtime session
219    ///
220    /// # Example
221    ///
222    /// ```no_run
223    /// use llm_shield_models::{ModelLoader, ModelRegistry, ModelType, ModelVariant};
224    /// use std::sync::Arc;
225    ///
226    /// # async fn example() -> Result<(), llm_shield_core::Error> {
227    /// # let registry = ModelRegistry::new();
228    /// # let loader = ModelLoader::new(Arc::new(registry));
229    /// let session = loader.load(ModelType::PromptInjection, ModelVariant::FP16).await?;
230    /// # Ok(())
231    /// # }
232    /// ```
233    pub async fn load(
234        &self,
235        model_type: ModelType,
236        variant: ModelVariant,
237    ) -> Result<Arc<Mutex<Session>>> {
238        // Check cache first (read lock)
239        {
240            let cache = self.cache.read().unwrap();
241            if let Some(session) = cache.get(&(model_type, variant)) {
242                tracing::debug!(
243                    "Model cache hit: {:?}/{:?}",
244                    model_type,
245                    variant
246                );
247                let mut stats = self.stats.write().unwrap();
248                stats.cache_hits += 1;
249                return Ok(Arc::clone(session));
250            }
251        }
252
253        // Not in cache - load it (write lock)
254        tracing::info!(
255            "Loading model: {:?}/{:?}",
256            model_type,
257            variant
258        );
259
260        // Convert to task and get metadata
261        let task = ModelTask::from(model_type);
262        let metadata = self.registry.get_model_metadata(task, variant)?;
263
264        // Ensure model is downloaded
265        let model_path = self.registry.ensure_model_available(task, variant).await?;
266
267        tracing::debug!("Model path: {:?}", model_path);
268
269        // Create ONNX session
270        let session = Self::create_session(&model_path, num_cpus::get().max(1), 3)?;
271
272        // Cache the session (wrapped in Mutex for ORT 2.0 API)
273        {
274            let mut cache = self.cache.write().unwrap();
275            let session_arc = Arc::new(Mutex::new(session));
276            cache.insert((model_type, variant), Arc::clone(&session_arc));
277
278            // Update stats
279            let mut stats = self.stats.write().unwrap();
280            stats.total_loaded = cache.len();
281            stats.total_loads += 1;
282
283            tracing::info!(
284                "Model loaded successfully: {} ({:?}/{:?})",
285                metadata.id,
286                model_type,
287                variant
288            );
289
290            Ok(session_arc)
291        }
292    }
293
294    /// Load a model with custom configuration
295    ///
296    /// # Arguments
297    ///
298    /// * `config` - Model configuration
299    ///
300    /// # Returns
301    ///
302    /// Arc to ONNX Runtime session
303    pub async fn load_with_config(&self, config: ModelConfig) -> Result<Arc<Mutex<Session>>> {
304        let model_type = config.model_type;
305        let variant = config.variant;
306
307        // Check cache first
308        {
309            let cache = self.cache.read().unwrap();
310            if let Some(session) = cache.get(&(model_type, variant)) {
311                let mut stats = self.stats.write().unwrap();
312                stats.cache_hits += 1;
313                return Ok(Arc::clone(session));
314            }
315        }
316
317        // Get model path from registry
318        let task = ModelTask::from(model_type);
319        let model_path = self.registry.ensure_model_available(task, variant).await?;
320
321        // Create session with custom config
322        let session = Self::create_session(
323            &model_path,
324            config.thread_pool_size,
325            config.optimization_level,
326        )?;
327
328        // Cache it (wrapped in Mutex for ORT 2.0 API)
329        let mut cache = self.cache.write().unwrap();
330        let session_arc = Arc::new(Mutex::new(session));
331        cache.insert((model_type, variant), Arc::clone(&session_arc));
332
333        let mut stats = self.stats.write().unwrap();
334        stats.total_loaded = cache.len();
335        stats.total_loads += 1;
336
337        Ok(session_arc)
338    }
339
340    /// Preload multiple models
341    ///
342    /// Useful for warming up the cache before first use.
343    ///
344    /// # Arguments
345    ///
346    /// * `models` - List of (ModelType, ModelVariant) tuples to preload
347    ///
348    /// # Example
349    ///
350    /// ```no_run
351    /// use llm_shield_models::{ModelLoader, ModelRegistry, ModelType, ModelVariant};
352    /// use std::sync::Arc;
353    ///
354    /// # async fn example() -> Result<(), llm_shield_core::Error> {
355    /// # let registry = ModelRegistry::new();
356    /// # let loader = ModelLoader::new(Arc::new(registry));
357    /// let models = vec![
358    ///     (ModelType::PromptInjection, ModelVariant::FP16),
359    ///     (ModelType::Toxicity, ModelVariant::FP16),
360    /// ];
361    /// loader.preload(models).await?;
362    /// # Ok(())
363    /// # }
364    /// ```
365    pub async fn preload(&self, models: Vec<(ModelType, ModelVariant)>) -> Result<()> {
366        tracing::info!("Preloading {} models", models.len());
367
368        for (model_type, variant) in models {
369            match self.load(model_type, variant).await {
370                Ok(_) => {
371                    tracing::debug!("Preloaded: {:?}/{:?}", model_type, variant);
372                }
373                Err(e) => {
374                    tracing::warn!(
375                        "Failed to preload {:?}/{:?}: {}",
376                        model_type,
377                        variant,
378                        e
379                    );
380                    return Err(e);
381                }
382            }
383        }
384
385        Ok(())
386    }
387
388    /// Check if a model is loaded
389    pub fn is_loaded(&self, model_type: ModelType, variant: ModelVariant) -> bool {
390        let cache = self.cache.read().unwrap();
391        cache.contains_key(&(model_type, variant))
392    }
393
394    /// Unload a specific model
395    ///
396    /// Removes the model from cache, freeing memory.
397    pub fn unload(&self, model_type: ModelType, variant: ModelVariant) {
398        let mut cache = self.cache.write().unwrap();
399        if cache.remove(&(model_type, variant)).is_some() {
400            tracing::info!("Unloaded model: {:?}/{:?}", model_type, variant);
401            let mut stats = self.stats.write().unwrap();
402            stats.total_loaded = cache.len();
403        }
404    }
405
406    /// Unload all models
407    ///
408    /// Clears the entire cache, freeing all memory.
409    pub fn unload_all(&self) {
410        let mut cache = self.cache.write().unwrap();
411        let count = cache.len();
412        cache.clear();
413        tracing::info!("Unloaded all {} models", count);
414        let mut stats = self.stats.write().unwrap();
415        stats.total_loaded = 0;
416    }
417
418    /// Get the number of loaded models
419    pub fn len(&self) -> usize {
420        self.cache.read().unwrap().len()
421    }
422
423    /// Check if no models are loaded
424    pub fn is_empty(&self) -> bool {
425        self.len() == 0
426    }
427
428    /// Get list of loaded models
429    pub fn loaded_models(&self) -> Vec<(ModelType, ModelVariant)> {
430        let cache = self.cache.read().unwrap();
431        cache.keys().copied().collect()
432    }
433
434    /// Get information about a loaded model
435    ///
436    /// Returns None if model is not loaded.
437    pub fn model_info(&self, model_type: ModelType, variant: ModelVariant) -> Option<String> {
438        let cache = self.cache.read().unwrap();
439        if cache.contains_key(&(model_type, variant)) {
440            Some(format!(
441                "Model: {:?}, Variant: {:?}, Status: loaded",
442                model_type, variant
443            ))
444        } else {
445            None
446        }
447    }
448
449    /// Get loader statistics
450    pub fn stats(&self) -> LoaderStats {
451        self.stats.read().unwrap().clone()
452    }
453
454    /// Create an ONNX Runtime session
455    fn create_session(
456        model_path: &PathBuf,
457        _thread_pool_size: usize,
458        _optimization_level: u8,
459    ) -> Result<Session> {
460        // Create session with default settings
461        // Note: with_optimization_level and with_intra_threads APIs vary by ort version
462        let session = Session::builder()
463            .map_err(|e| Error::model(format!("Failed to create session builder: {}", e)))?
464            .commit_from_file(model_path)
465            .map_err(|e| {
466                Error::model(format!(
467                    "Failed to load model from '{}': {}",
468                    model_path.display(),
469                    e
470                ))
471            })?;
472
473        Ok(session)
474    }
475}
476
477impl Clone for ModelLoader {
478    /// Clone creates a new reference to the same underlying cache
479    ///
480    /// All clones share the same loaded models and statistics.
481    fn clone(&self) -> Self {
482        Self {
483            registry: Arc::clone(&self.registry),
484            cache: Arc::clone(&self.cache),
485            stats: Arc::clone(&self.stats),
486        }
487    }
488}
489
490#[cfg(test)]
491mod tests {
492    use super::*;
493    use crate::registry::ModelTask;
494
495    #[test]
496    fn test_model_type_conversions() {
497        // ModelTask -> ModelType
498        assert!(matches!(
499            ModelType::from(ModelTask::PromptInjection),
500            ModelType::PromptInjection
501        ));
502        assert!(matches!(
503            ModelType::from(ModelTask::Toxicity),
504            ModelType::Toxicity
505        ));
506        assert!(matches!(
507            ModelType::from(ModelTask::Sentiment),
508            ModelType::Sentiment
509        ));
510        assert!(matches!(
511            ModelType::from(ModelTask::NamedEntityRecognition),
512            ModelType::NamedEntityRecognition
513        ));
514
515        // ModelType -> ModelTask
516        assert!(matches!(
517            ModelTask::from(ModelType::PromptInjection),
518            ModelTask::PromptInjection
519        ));
520        assert!(matches!(
521            ModelTask::from(ModelType::Toxicity),
522            ModelTask::Toxicity
523        ));
524        assert!(matches!(
525            ModelTask::from(ModelType::Sentiment),
526            ModelTask::Sentiment
527        ));
528        assert!(matches!(
529            ModelTask::from(ModelType::NamedEntityRecognition),
530            ModelTask::NamedEntityRecognition
531        ));
532    }
533
534    #[test]
535    fn test_model_config_defaults() {
536        let config = ModelConfig::new(
537            ModelType::PromptInjection,
538            ModelVariant::FP16,
539            PathBuf::from("/test/model.onnx"),
540        );
541
542        assert!(config.thread_pool_size > 0);
543        assert_eq!(config.optimization_level, 3);
544    }
545
546    #[test]
547    fn test_model_config_builder_pattern() {
548        let config = ModelConfig::new(
549            ModelType::Toxicity,
550            ModelVariant::INT8,
551            PathBuf::from("/test/model.onnx"),
552        )
553        .with_thread_pool_size(4)
554        .with_optimization_level(2);
555
556        assert_eq!(config.thread_pool_size, 4);
557        assert_eq!(config.optimization_level, 2);
558    }
559
560    #[test]
561    fn test_loader_stats_default() {
562        let stats = LoaderStats::default();
563        assert_eq!(stats.total_loaded, 0);
564        assert_eq!(stats.total_loads, 0);
565        assert_eq!(stats.cache_hits, 0);
566    }
567}