use crate::registry::{ModelRegistry, ModelTask, ModelVariant};
use llm_shield_core::Error;
use ort::session::Session;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::{Arc, RwLock, Mutex};
pub type Result<T> = std::result::Result<T, Error>;
/// Kind of model the loader can serve.
///
/// Each variant corresponds one-to-one to a registry `ModelTask` (see the `From`
/// conversions in both directions). Together with a `ModelVariant` it forms the key
/// of the loader's session cache, which is why it is `Copy + Eq + Hash`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelType {
    /// Prompt-injection detection model.
    PromptInjection,
    /// Toxicity classification model.
    Toxicity,
    /// Sentiment classification model.
    Sentiment,
    /// Named-entity-recognition model.
    NamedEntityRecognition,
}
impl From<ModelTask> for ModelType {
fn from(task: ModelTask) -> Self {
match task {
ModelTask::PromptInjection => ModelType::PromptInjection,
ModelTask::Toxicity => ModelType::Toxicity,
ModelTask::Sentiment => ModelType::Sentiment,
ModelTask::NamedEntityRecognition => ModelType::NamedEntityRecognition,
}
}
}
impl From<ModelType> for ModelTask {
fn from(model_type: ModelType) -> Self {
match model_type {
ModelType::PromptInjection => ModelTask::PromptInjection,
ModelType::Toxicity => ModelTask::Toxicity,
ModelType::Sentiment => ModelTask::Sentiment,
ModelType::NamedEntityRecognition => ModelTask::NamedEntityRecognition,
}
}
}
/// Settings for loading a model through `ModelLoader::load_with_config`.
///
/// Construct with `ModelConfig::new` and adjust with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ModelConfig {
    /// Which model to load; together with `variant` it is the cache key.
    pub model_type: ModelType,
    /// Which variant of the model to load (e.g. FP16 / INT8).
    pub variant: ModelVariant,
    /// Path to the model file.
    /// NOTE(review): `ModelLoader::load_with_config` resolves the path through the
    /// registry and does not read this field — confirm the intended use.
    pub model_path: PathBuf,
    /// Requested session thread-pool size (`ModelConfig::new` defaults it to
    /// `num_cpus::get()`, but never less than 1).
    pub thread_pool_size: usize,
    /// Requested graph optimization level in `0..=3`; `with_optimization_level`
    /// clamps larger values to 3 and `ModelConfig::new` defaults it to 3.
    pub optimization_level: u8,
}
impl ModelConfig {
pub fn new(model_type: ModelType, variant: ModelVariant, model_path: PathBuf) -> Self {
Self {
model_type,
variant,
model_path,
thread_pool_size: num_cpus::get().max(1),
optimization_level: 3, }
}
pub fn with_thread_pool_size(mut self, size: usize) -> Self {
self.thread_pool_size = size;
self
}
pub fn with_optimization_level(mut self, level: u8) -> Self {
self.optimization_level = level.min(3);
self
}
}
/// Counters describing loader activity; a snapshot is returned by `ModelLoader::stats`.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct LoaderStats {
    /// Number of sessions currently held in the cache (re-synced with the cache size on
    /// every load/unload, and reset to 0 by `unload_all`).
    pub total_loaded: usize,
    /// Number of sessions created by `load` / `load_with_config` (cache misses).
    pub total_loads: u64,
    /// Number of `load` / `load_with_config` calls answered from the cache.
    pub cache_hits: u64,
}
/// Loads ONNX Runtime sessions for registry models and caches them per
/// `(ModelType, ModelVariant)`.
///
/// All state is behind `Arc`s, so clones of a loader share the same cache and statistics.
pub struct ModelLoader {
    /// Registry used to look up model metadata and make model files available locally.
    registry: Arc<ModelRegistry>,
    /// Loaded sessions keyed by model type and variant. Each session sits behind its own
    /// `Mutex`, so callers must lock it to use it.
    cache: Arc<RwLock<HashMap<(ModelType, ModelVariant), Arc<Mutex<Session>>>>>,
    /// Activity counters exposed through `stats()`.
    stats: Arc<RwLock<LoaderStats>>,
}
impl ModelLoader {
pub fn new(registry: Arc<ModelRegistry>) -> Self {
Self {
registry,
cache: Arc::new(RwLock::new(HashMap::new())),
stats: Arc::new(RwLock::new(LoaderStats::default())),
}
}
pub fn with_registry(registry: Arc<ModelRegistry>) -> Self {
Self::new(registry)
}
pub async fn load(
&self,
model_type: ModelType,
variant: ModelVariant,
) -> Result<Arc<Mutex<Session>>> {
{
let cache = self.cache.read().unwrap();
if let Some(session) = cache.get(&(model_type, variant)) {
tracing::debug!(
"Model cache hit: {:?}/{:?}",
model_type,
variant
);
let mut stats = self.stats.write().unwrap();
stats.cache_hits += 1;
return Ok(Arc::clone(session));
}
}
tracing::info!(
"Loading model: {:?}/{:?}",
model_type,
variant
);
let task = ModelTask::from(model_type);
let metadata = self.registry.get_model_metadata(task, variant)?;
let model_path = self.registry.ensure_model_available(task, variant).await?;
tracing::debug!("Model path: {:?}", model_path);
let session = Self::create_session(&model_path, num_cpus::get().max(1), 3)?;
{
let mut cache = self.cache.write().unwrap();
let session_arc = Arc::new(Mutex::new(session));
cache.insert((model_type, variant), Arc::clone(&session_arc));
let mut stats = self.stats.write().unwrap();
stats.total_loaded = cache.len();
stats.total_loads += 1;
tracing::info!(
"Model loaded successfully: {} ({:?}/{:?})",
metadata.id,
model_type,
variant
);
Ok(session_arc)
}
}
pub async fn load_with_config(&self, config: ModelConfig) -> Result<Arc<Mutex<Session>>> {
let model_type = config.model_type;
let variant = config.variant;
{
let cache = self.cache.read().unwrap();
if let Some(session) = cache.get(&(model_type, variant)) {
let mut stats = self.stats.write().unwrap();
stats.cache_hits += 1;
return Ok(Arc::clone(session));
}
}
let task = ModelTask::from(model_type);
let model_path = self.registry.ensure_model_available(task, variant).await?;
let session = Self::create_session(
&model_path,
config.thread_pool_size,
config.optimization_level,
)?;
let mut cache = self.cache.write().unwrap();
let session_arc = Arc::new(Mutex::new(session));
cache.insert((model_type, variant), Arc::clone(&session_arc));
let mut stats = self.stats.write().unwrap();
stats.total_loaded = cache.len();
stats.total_loads += 1;
Ok(session_arc)
}
pub async fn preload(&self, models: Vec<(ModelType, ModelVariant)>) -> Result<()> {
tracing::info!("Preloading {} models", models.len());
for (model_type, variant) in models {
match self.load(model_type, variant).await {
Ok(_) => {
tracing::debug!("Preloaded: {:?}/{:?}", model_type, variant);
}
Err(e) => {
tracing::warn!(
"Failed to preload {:?}/{:?}: {}",
model_type,
variant,
e
);
return Err(e);
}
}
}
Ok(())
}
pub fn is_loaded(&self, model_type: ModelType, variant: ModelVariant) -> bool {
let cache = self.cache.read().unwrap();
cache.contains_key(&(model_type, variant))
}
pub fn unload(&self, model_type: ModelType, variant: ModelVariant) {
let mut cache = self.cache.write().unwrap();
if cache.remove(&(model_type, variant)).is_some() {
tracing::info!("Unloaded model: {:?}/{:?}", model_type, variant);
let mut stats = self.stats.write().unwrap();
stats.total_loaded = cache.len();
}
}
pub fn unload_all(&self) {
let mut cache = self.cache.write().unwrap();
let count = cache.len();
cache.clear();
tracing::info!("Unloaded all {} models", count);
let mut stats = self.stats.write().unwrap();
stats.total_loaded = 0;
}
pub fn len(&self) -> usize {
self.cache.read().unwrap().len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn loaded_models(&self) -> Vec<(ModelType, ModelVariant)> {
let cache = self.cache.read().unwrap();
cache.keys().copied().collect()
}
pub fn model_info(&self, model_type: ModelType, variant: ModelVariant) -> Option<String> {
let cache = self.cache.read().unwrap();
if cache.contains_key(&(model_type, variant)) {
Some(format!(
"Model: {:?}, Variant: {:?}, Status: loaded",
model_type, variant
))
} else {
None
}
}
pub fn stats(&self) -> LoaderStats {
self.stats.read().unwrap().clone()
}
fn create_session(
model_path: &PathBuf,
_thread_pool_size: usize,
_optimization_level: u8,
) -> Result<Session> {
let session = Session::builder()
.map_err(|e| Error::model(format!("Failed to create session builder: {}", e)))?
.commit_from_file(model_path)
.map_err(|e| {
Error::model(format!(
"Failed to load model from '{}': {}",
model_path.display(),
e
))
})?;
Ok(session)
}
}
impl Clone for ModelLoader {
fn clone(&self) -> Self {
Self {
registry: Arc::clone(&self.registry),
cache: Arc::clone(&self.cache),
stats: Arc::clone(&self.stats),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::registry::ModelTask;

    #[test]
    fn test_model_type_conversions() {
        // Registry task -> loader type, checked pair by pair.
        let forward = [
            (ModelTask::PromptInjection, ModelType::PromptInjection),
            (ModelTask::Toxicity, ModelType::Toxicity),
            (ModelTask::Sentiment, ModelType::Sentiment),
            (
                ModelTask::NamedEntityRecognition,
                ModelType::NamedEntityRecognition,
            ),
        ];
        for (task, expected) in forward {
            assert_eq!(ModelType::from(task), expected);
        }

        // Loader type -> registry task; `matches!` avoids requiring `PartialEq` on ModelTask.
        let back_pi = ModelTask::from(ModelType::PromptInjection);
        assert!(matches!(back_pi, ModelTask::PromptInjection));
        let back_tox = ModelTask::from(ModelType::Toxicity);
        assert!(matches!(back_tox, ModelTask::Toxicity));
        let back_sent = ModelTask::from(ModelType::Sentiment);
        assert!(matches!(back_sent, ModelTask::Sentiment));
        let back_ner = ModelTask::from(ModelType::NamedEntityRecognition);
        assert!(matches!(back_ner, ModelTask::NamedEntityRecognition));
    }

    #[test]
    fn test_model_config_defaults() {
        let path = PathBuf::from("/test/model.onnx");
        let cfg = ModelConfig::new(ModelType::PromptInjection, ModelVariant::FP16, path);
        assert!(cfg.thread_pool_size > 0);
        assert_eq!(cfg.optimization_level, 3);
    }

    #[test]
    fn test_model_config_builder_pattern() {
        let path = PathBuf::from("/test/model.onnx");
        let cfg = ModelConfig::new(ModelType::Toxicity, ModelVariant::INT8, path)
            .with_thread_pool_size(4)
            .with_optimization_level(2);
        assert_eq!(cfg.thread_pool_size, 4);
        assert_eq!(cfg.optimization_level, 2);
    }

    #[test]
    fn test_loader_stats_default() {
        let LoaderStats {
            total_loaded,
            total_loads,
            cache_hits,
        } = LoaderStats::default();
        assert_eq!(total_loaded, 0);
        assert_eq!(total_loads, 0);
        assert_eq!(cache_hits, 0);
    }
}