// llm_shield_models/model_loader.rs
1//! Model Loader with ONNX Runtime Integration
2//!
3//! ## SPARC Phase 3: Implementation
4//!
5//! This module provides lazy loading, caching, and thread-safe access to ONNX models.
6//!
7//! ## Features
8//!
9//! - **Lazy Loading**: Models are only loaded when first requested
10//! - **Caching**: Loaded models are cached for reuse
11//! - **Thread-Safe**: Uses Arc + RwLock for concurrent access
12//! - **Registry Integration**: Uses ModelRegistry for model discovery
13//! - **Graceful Error Handling**: Comprehensive error messages
14//!
15//! ## Usage Example
16//!
17//! ```no_run
18//! use llm_shield_models::{ModelLoader, ModelRegistry, ModelType, ModelVariant};
19//! use std::sync::Arc;
20//!
21//! # async fn example() -> Result<(), llm_shield_core::Error> {
22//! // Create registry
23//! let registry = ModelRegistry::from_file("models/registry.json")?;
24//!
25//! // Create loader
26//! let loader = ModelLoader::new(Arc::new(registry));
27//!
28//! // Load model (lazy - only loads once)
29//! let session = loader.load(ModelType::PromptInjection, ModelVariant::FP16).await?;
30//!
31//! // Use session for inference...
32//! # Ok(())
33//! # }
34//! ```
35
use crate::registry::{ModelRegistry, ModelTask, ModelVariant};
use llm_shield_core::Error;
use ort::session::Session;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex, RwLock};
42
/// Crate-local result alias: `Ok(T)` or the shared [`llm_shield_core::Error`].
pub type Result<T> = std::result::Result<T, Error>;
45
/// Model type identifier
///
/// Identifies which detector a session backs. Together with [`ModelVariant`]
/// it forms the cache key inside [`ModelLoader`], hence the `Eq + Hash` derives.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelType {
    /// Prompt injection detection
    PromptInjection,
    /// Toxicity classification
    Toxicity,
    /// Sentiment analysis
    Sentiment,
    /// Named Entity Recognition (PII detection)
    NamedEntityRecognition,
}
58
59/// Conversion from ModelTask to ModelType
60impl From<ModelTask> for ModelType {
61 fn from(task: ModelTask) -> Self {
62 match task {
63 ModelTask::PromptInjection => ModelType::PromptInjection,
64 ModelTask::Toxicity => ModelType::Toxicity,
65 ModelTask::Sentiment => ModelType::Sentiment,
66 ModelTask::NamedEntityRecognition => ModelType::NamedEntityRecognition,
67 }
68 }
69}
70
71/// Conversion from ModelType to ModelTask
72impl From<ModelType> for ModelTask {
73 fn from(model_type: ModelType) -> Self {
74 match model_type {
75 ModelType::PromptInjection => ModelTask::PromptInjection,
76 ModelType::Toxicity => ModelTask::Toxicity,
77 ModelType::Sentiment => ModelTask::Sentiment,
78 ModelType::NamedEntityRecognition => ModelTask::NamedEntityRecognition,
79 }
80 }
81}
82
/// Configuration for loading a model
///
/// Built via [`ModelConfig::new`] plus the `with_*` builder methods, then
/// consumed by `ModelLoader::load_with_config`.
#[derive(Debug, Clone)]
pub struct ModelConfig {
    /// Type of model
    pub model_type: ModelType,
    /// Model variant (precision)
    pub variant: ModelVariant,
    /// Path to ONNX model file
    pub model_path: PathBuf,
    /// Number of threads for inference (defaults to the CPU count, at least 1)
    pub thread_pool_size: usize,
    /// ONNX graph optimization level (0-3; defaults to 3)
    pub optimization_level: u8,
}
97
98impl ModelConfig {
99 /// Create a new model configuration
100 ///
101 /// # Arguments
102 ///
103 /// * `model_type` - Type of model (PromptInjection, Toxicity, Sentiment)
104 /// * `variant` - Model variant (FP32, FP16, INT8)
105 /// * `model_path` - Path to ONNX model file
106 ///
107 /// # Example
108 ///
109 /// ```
110 /// use llm_shield_models::{ModelConfig, ModelType, ModelVariant};
111 /// use std::path::PathBuf;
112 ///
113 /// let config = ModelConfig::new(
114 /// ModelType::PromptInjection,
115 /// ModelVariant::FP16,
116 /// PathBuf::from("/path/to/model.onnx")
117 /// );
118 /// ```
119 pub fn new(model_type: ModelType, variant: ModelVariant, model_path: PathBuf) -> Self {
120 Self {
121 model_type,
122 variant,
123 model_path,
124 thread_pool_size: num_cpus::get().max(1),
125 optimization_level: 3, // Max optimization
126 }
127 }
128
129 /// Set the thread pool size
130 pub fn with_thread_pool_size(mut self, size: usize) -> Self {
131 self.thread_pool_size = size;
132 self
133 }
134
135 /// Set the optimization level (0-3)
136 pub fn with_optimization_level(mut self, level: u8) -> Self {
137 self.optimization_level = level.min(3);
138 self
139 }
140}
141
/// Statistics about loaded models
///
/// Snapshot returned by `ModelLoader::stats`; counters are shared across all
/// clones of a loader.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct LoaderStats {
    /// Number of models currently loaded (cache size)
    pub total_loaded: usize,
    /// Total number of load operations (cache misses that built a session)
    pub total_loads: u64,
    /// Total number of cache hits
    pub cache_hits: u64,
}
152
/// Model loader with lazy loading and caching
///
/// ## Thread Safety
///
/// ModelLoader uses Arc + RwLock internally for thread-safe access.
/// Multiple threads can safely load and access models concurrently.
///
/// ## Caching
///
/// Once a model is loaded, it stays in memory until explicitly unloaded.
/// Subsequent calls to `load()` with the same model type/variant return
/// the cached session.
pub struct ModelLoader {
    /// Model registry for metadata
    registry: Arc<ModelRegistry>,
    /// Loaded ONNX sessions cache, keyed by (type, variant).
    /// Sessions are wrapped in `Mutex` for the ORT 2.0 API (exclusive access).
    cache: Arc<RwLock<HashMap<(ModelType, ModelVariant), Arc<Mutex<Session>>>>>,
    /// Statistics (shared by all clones of this loader)
    stats: Arc<RwLock<LoaderStats>>,
}
173
174impl ModelLoader {
175 /// Create a new model loader
176 ///
177 /// # Arguments
178 ///
179 /// * `registry` - Model registry for metadata and downloads
180 ///
181 /// # Example
182 ///
183 /// ```no_run
184 /// use llm_shield_models::{ModelLoader, ModelRegistry};
185 /// use std::sync::Arc;
186 ///
187 /// # fn example() -> Result<(), llm_shield_core::Error> {
188 /// let registry = ModelRegistry::from_file("models/registry.json")?;
189 /// let loader = ModelLoader::new(Arc::new(registry));
190 /// # Ok(())
191 /// # }
192 /// ```
193 pub fn new(registry: Arc<ModelRegistry>) -> Self {
194 Self {
195 registry,
196 cache: Arc::new(RwLock::new(HashMap::new())),
197 stats: Arc::new(RwLock::new(LoaderStats::default())),
198 }
199 }
200
201 /// Create a new model loader (alias for `new`)
202 pub fn with_registry(registry: Arc<ModelRegistry>) -> Self {
203 Self::new(registry)
204 }
205
206 /// Load a model (lazily, with caching)
207 ///
208 /// If the model is already loaded, returns the cached session.
209 /// Otherwise, loads the model from disk using the registry.
210 ///
211 /// # Arguments
212 ///
213 /// * `model_type` - Type of model to load
214 /// * `variant` - Model variant (precision)
215 ///
216 /// # Returns
217 ///
218 /// Arc to ONNX Runtime session
219 ///
220 /// # Example
221 ///
222 /// ```no_run
223 /// use llm_shield_models::{ModelLoader, ModelRegistry, ModelType, ModelVariant};
224 /// use std::sync::Arc;
225 ///
226 /// # async fn example() -> Result<(), llm_shield_core::Error> {
227 /// # let registry = ModelRegistry::new();
228 /// # let loader = ModelLoader::new(Arc::new(registry));
229 /// let session = loader.load(ModelType::PromptInjection, ModelVariant::FP16).await?;
230 /// # Ok(())
231 /// # }
232 /// ```
233 pub async fn load(
234 &self,
235 model_type: ModelType,
236 variant: ModelVariant,
237 ) -> Result<Arc<Mutex<Session>>> {
238 // Check cache first (read lock)
239 {
240 let cache = self.cache.read().unwrap();
241 if let Some(session) = cache.get(&(model_type, variant)) {
242 tracing::debug!(
243 "Model cache hit: {:?}/{:?}",
244 model_type,
245 variant
246 );
247 let mut stats = self.stats.write().unwrap();
248 stats.cache_hits += 1;
249 return Ok(Arc::clone(session));
250 }
251 }
252
253 // Not in cache - load it (write lock)
254 tracing::info!(
255 "Loading model: {:?}/{:?}",
256 model_type,
257 variant
258 );
259
260 // Convert to task and get metadata
261 let task = ModelTask::from(model_type);
262 let metadata = self.registry.get_model_metadata(task, variant)?;
263
264 // Ensure model is downloaded
265 let model_path = self.registry.ensure_model_available(task, variant).await?;
266
267 tracing::debug!("Model path: {:?}", model_path);
268
269 // Create ONNX session
270 let session = Self::create_session(&model_path, num_cpus::get().max(1), 3)?;
271
272 // Cache the session (wrapped in Mutex for ORT 2.0 API)
273 {
274 let mut cache = self.cache.write().unwrap();
275 let session_arc = Arc::new(Mutex::new(session));
276 cache.insert((model_type, variant), Arc::clone(&session_arc));
277
278 // Update stats
279 let mut stats = self.stats.write().unwrap();
280 stats.total_loaded = cache.len();
281 stats.total_loads += 1;
282
283 tracing::info!(
284 "Model loaded successfully: {} ({:?}/{:?})",
285 metadata.id,
286 model_type,
287 variant
288 );
289
290 Ok(session_arc)
291 }
292 }
293
294 /// Load a model with custom configuration
295 ///
296 /// # Arguments
297 ///
298 /// * `config` - Model configuration
299 ///
300 /// # Returns
301 ///
302 /// Arc to ONNX Runtime session
303 pub async fn load_with_config(&self, config: ModelConfig) -> Result<Arc<Mutex<Session>>> {
304 let model_type = config.model_type;
305 let variant = config.variant;
306
307 // Check cache first
308 {
309 let cache = self.cache.read().unwrap();
310 if let Some(session) = cache.get(&(model_type, variant)) {
311 let mut stats = self.stats.write().unwrap();
312 stats.cache_hits += 1;
313 return Ok(Arc::clone(session));
314 }
315 }
316
317 // Get model path from registry
318 let task = ModelTask::from(model_type);
319 let model_path = self.registry.ensure_model_available(task, variant).await?;
320
321 // Create session with custom config
322 let session = Self::create_session(
323 &model_path,
324 config.thread_pool_size,
325 config.optimization_level,
326 )?;
327
328 // Cache it (wrapped in Mutex for ORT 2.0 API)
329 let mut cache = self.cache.write().unwrap();
330 let session_arc = Arc::new(Mutex::new(session));
331 cache.insert((model_type, variant), Arc::clone(&session_arc));
332
333 let mut stats = self.stats.write().unwrap();
334 stats.total_loaded = cache.len();
335 stats.total_loads += 1;
336
337 Ok(session_arc)
338 }
339
340 /// Preload multiple models
341 ///
342 /// Useful for warming up the cache before first use.
343 ///
344 /// # Arguments
345 ///
346 /// * `models` - List of (ModelType, ModelVariant) tuples to preload
347 ///
348 /// # Example
349 ///
350 /// ```no_run
351 /// use llm_shield_models::{ModelLoader, ModelRegistry, ModelType, ModelVariant};
352 /// use std::sync::Arc;
353 ///
354 /// # async fn example() -> Result<(), llm_shield_core::Error> {
355 /// # let registry = ModelRegistry::new();
356 /// # let loader = ModelLoader::new(Arc::new(registry));
357 /// let models = vec![
358 /// (ModelType::PromptInjection, ModelVariant::FP16),
359 /// (ModelType::Toxicity, ModelVariant::FP16),
360 /// ];
361 /// loader.preload(models).await?;
362 /// # Ok(())
363 /// # }
364 /// ```
365 pub async fn preload(&self, models: Vec<(ModelType, ModelVariant)>) -> Result<()> {
366 tracing::info!("Preloading {} models", models.len());
367
368 for (model_type, variant) in models {
369 match self.load(model_type, variant).await {
370 Ok(_) => {
371 tracing::debug!("Preloaded: {:?}/{:?}", model_type, variant);
372 }
373 Err(e) => {
374 tracing::warn!(
375 "Failed to preload {:?}/{:?}: {}",
376 model_type,
377 variant,
378 e
379 );
380 return Err(e);
381 }
382 }
383 }
384
385 Ok(())
386 }
387
388 /// Check if a model is loaded
389 pub fn is_loaded(&self, model_type: ModelType, variant: ModelVariant) -> bool {
390 let cache = self.cache.read().unwrap();
391 cache.contains_key(&(model_type, variant))
392 }
393
394 /// Unload a specific model
395 ///
396 /// Removes the model from cache, freeing memory.
397 pub fn unload(&self, model_type: ModelType, variant: ModelVariant) {
398 let mut cache = self.cache.write().unwrap();
399 if cache.remove(&(model_type, variant)).is_some() {
400 tracing::info!("Unloaded model: {:?}/{:?}", model_type, variant);
401 let mut stats = self.stats.write().unwrap();
402 stats.total_loaded = cache.len();
403 }
404 }
405
406 /// Unload all models
407 ///
408 /// Clears the entire cache, freeing all memory.
409 pub fn unload_all(&self) {
410 let mut cache = self.cache.write().unwrap();
411 let count = cache.len();
412 cache.clear();
413 tracing::info!("Unloaded all {} models", count);
414 let mut stats = self.stats.write().unwrap();
415 stats.total_loaded = 0;
416 }
417
418 /// Get the number of loaded models
419 pub fn len(&self) -> usize {
420 self.cache.read().unwrap().len()
421 }
422
423 /// Check if no models are loaded
424 pub fn is_empty(&self) -> bool {
425 self.len() == 0
426 }
427
428 /// Get list of loaded models
429 pub fn loaded_models(&self) -> Vec<(ModelType, ModelVariant)> {
430 let cache = self.cache.read().unwrap();
431 cache.keys().copied().collect()
432 }
433
434 /// Get information about a loaded model
435 ///
436 /// Returns None if model is not loaded.
437 pub fn model_info(&self, model_type: ModelType, variant: ModelVariant) -> Option<String> {
438 let cache = self.cache.read().unwrap();
439 if cache.contains_key(&(model_type, variant)) {
440 Some(format!(
441 "Model: {:?}, Variant: {:?}, Status: loaded",
442 model_type, variant
443 ))
444 } else {
445 None
446 }
447 }
448
449 /// Get loader statistics
450 pub fn stats(&self) -> LoaderStats {
451 self.stats.read().unwrap().clone()
452 }
453
454 /// Create an ONNX Runtime session
455 fn create_session(
456 model_path: &PathBuf,
457 _thread_pool_size: usize,
458 _optimization_level: u8,
459 ) -> Result<Session> {
460 // Create session with default settings
461 // Note: with_optimization_level and with_intra_threads APIs vary by ort version
462 let session = Session::builder()
463 .map_err(|e| Error::model(format!("Failed to create session builder: {}", e)))?
464 .commit_from_file(model_path)
465 .map_err(|e| {
466 Error::model(format!(
467 "Failed to load model from '{}': {}",
468 model_path.display(),
469 e
470 ))
471 })?;
472
473 Ok(session)
474 }
475}
476
477impl Clone for ModelLoader {
478 /// Clone creates a new reference to the same underlying cache
479 ///
480 /// All clones share the same loaded models and statistics.
481 fn clone(&self) -> Self {
482 Self {
483 registry: Arc::clone(&self.registry),
484 cache: Arc::clone(&self.cache),
485 stats: Arc::clone(&self.stats),
486 }
487 }
488}
489
#[cfg(test)]
mod tests {
    use super::*;
    use crate::registry::ModelTask;

    #[test]
    fn test_model_type_conversions() {
        // ModelTask -> ModelType (ModelType derives Debug + PartialEq, so
        // assert_eq! gives better failure messages than matches!).
        assert_eq!(
            ModelType::from(ModelTask::PromptInjection),
            ModelType::PromptInjection
        );
        assert_eq!(ModelType::from(ModelTask::Toxicity), ModelType::Toxicity);
        assert_eq!(ModelType::from(ModelTask::Sentiment), ModelType::Sentiment);
        assert_eq!(
            ModelType::from(ModelTask::NamedEntityRecognition),
            ModelType::NamedEntityRecognition
        );

        // ModelType -> ModelTask (matches! avoids requiring PartialEq on ModelTask).
        assert!(matches!(
            ModelTask::from(ModelType::PromptInjection),
            ModelTask::PromptInjection
        ));
        assert!(matches!(
            ModelTask::from(ModelType::Toxicity),
            ModelTask::Toxicity
        ));
        assert!(matches!(
            ModelTask::from(ModelType::Sentiment),
            ModelTask::Sentiment
        ));
        assert!(matches!(
            ModelTask::from(ModelType::NamedEntityRecognition),
            ModelTask::NamedEntityRecognition
        ));
    }

    #[test]
    fn test_model_config_defaults() {
        let cfg = ModelConfig::new(
            ModelType::PromptInjection,
            ModelVariant::FP16,
            PathBuf::from("/test/model.onnx"),
        );

        assert!(cfg.thread_pool_size > 0);
        assert_eq!(cfg.optimization_level, 3);
    }

    #[test]
    fn test_model_config_builder_pattern() {
        let cfg = ModelConfig::new(
            ModelType::Toxicity,
            ModelVariant::INT8,
            PathBuf::from("/test/model.onnx"),
        )
        .with_thread_pool_size(4)
        .with_optimization_level(2);

        assert_eq!(cfg.thread_pool_size, 4);
        assert_eq!(cfg.optimization_level, 2);
    }

    #[test]
    fn test_loader_stats_default() {
        let stats = LoaderStats::default();
        assert_eq!(stats.total_loaded, 0);
        assert_eq!(stats.total_loads, 0);
        assert_eq!(stats.cache_hits, 0);
    }
}
567}