use llm_shield_core::Error;
use ndarray::Array2;
use ort::session::Session;
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};
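/// Post-processing applied to raw logits: `Softmax` for mutually exclusive
/// classes, `Sigmoid` for independent multi-label scores.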
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PostProcessing {
Softmax,
Sigmoid,
}
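/// Prediction for a single token, as produced by token-classification models.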
#[derive(Debug, Clone, PartialEq)]
pub struct TokenPrediction {
pub token_id: u32,
pub predicted_label: String,
pub predicted_class: usize,
pub confidence: f32,
pub all_scores: Vec<f32>,
}
impl TokenPrediction {
pub fn new(
token_id: u32,
predicted_label: String,
predicted_class: usize,
confidence: f32,
all_scores: Vec<f32>,
) -> Self {
Self {
token_id,
predicted_label,
predicted_class,
confidence,
all_scores,
}
}
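    /// Checks internal consistency: `confidence` and every score must lie in
    /// [0, 1], `predicted_class` must index into `all_scores` and agree with
    /// `confidence`, and the scores must form a distribution summing to ~1.0.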
pub fn validate(&self) -> Result<(), String> {
if self.confidence < 0.0 || self.confidence > 1.0 {
return Err(format!("Invalid confidence: {}", self.confidence));
}
if self.predicted_class >= self.all_scores.len() {
return Err(format!(
"Invalid predicted_class {} for {} scores",
self.predicted_class,
self.all_scores.len()
));
}
let expected_confidence = self.all_scores[self.predicted_class];
if (self.confidence - expected_confidence).abs() > 0.001 {
return Err(format!(
"Confidence mismatch: {} != {}",
self.confidence, expected_confidence
));
}
for (i, &score) in self.all_scores.iter().enumerate() {
if score < 0.0 || score > 1.0 {
return Err(format!("Invalid score at index {}: {}", i, score));
}
}
let sum: f32 = self.all_scores.iter().sum();
if (sum - 1.0).abs() > 0.01 {
return Err(format!("Scores don't sum to 1.0: {}", sum));
}
Ok(())
}
}
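/// Result of a sequence-classification run: per-label scores plus the
/// top-scoring class.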
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct InferenceResult {
pub labels: Vec<String>,
pub scores: Vec<f32>,
pub predicted_class: usize,
pub max_score: f32,
}
impl InferenceResult {
pub fn predicted_label(&self) -> Option<&str> {
self.labels.get(self.predicted_class).map(|s| s.as_str())
}
pub fn exceeds_threshold(&self, threshold: f32) -> bool {
self.max_score >= threshold
}
pub fn get_score_for_label(&self, label: &str) -> Option<f32> {
self.labels
.iter()
.position(|l| l == label)
.and_then(|idx| self.scores.get(idx).copied())
}
pub fn is_binary(&self) -> bool {
self.labels.len() == 2
}
pub fn get_threshold_violations(&self, thresholds: &[f32]) -> Vec<usize> {
if thresholds.len() != self.scores.len() {
tracing::warn!(
"Threshold count mismatch: {} thresholds for {} classes",
thresholds.len(),
self.scores.len()
);
return vec![];
}
self.scores
.iter()
.enumerate()
.filter_map(|(idx, &score)| {
if score >= thresholds[idx] {
Some(idx)
} else {
None
}
})
.collect()
}
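    /// Builds a result by applying softmax to raw logits (mutually exclusive classes).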
pub fn from_binary_logits(logits: Vec<f32>, labels: Vec<String>) -> Self {
let scores = InferenceEngine::softmax_static(&logits);
let (predicted_class, max_score) = scores
.iter()
.enumerate()
            .max_by(|(_, a), (_, b)| a.total_cmp(b))
.map(|(idx, &score)| (idx, score))
.unwrap_or((0, 0.0));
Self {
labels,
scores,
predicted_class,
max_score,
}
}
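    /// Builds a result by applying an element-wise sigmoid to raw logits
    /// (independent multi-label scores).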
pub fn from_multilabel_logits(logits: Vec<f32>, labels: Vec<String>) -> Self {
let scores = InferenceEngine::sigmoid_static(&logits);
let (predicted_class, max_score) = scores
.iter()
.enumerate()
            .max_by(|(_, a), (_, b)| a.total_cmp(b))
.map(|(idx, &score)| (idx, score))
.unwrap_or((0, 0.0));
Self {
labels,
scores,
predicted_class,
max_score,
}
}
}
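/// Runs ONNX Runtime inference over a shared, mutex-guarded `Session`.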
pub struct InferenceEngine {
session: Arc<Mutex<Session>>,
}
impl InferenceEngine {
pub fn new(session: Arc<Mutex<Session>>) -> Self {
Self { session }
}
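    /// Runs sequence classification on a blocking thread via `spawn_blocking`,
    /// so ONNX execution does not stall the async runtime.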
pub async fn infer_async(
&self,
input_ids: &[u32],
attention_mask: &[u32],
labels: &[String],
post_processing: PostProcessing,
) -> crate::Result<InferenceResult> {
let session = Arc::clone(&self.session);
let input_ids = input_ids.to_vec();
let attention_mask = attention_mask.to_vec();
let labels = labels.to_vec();
tokio::task::spawn_blocking(move || {
let mut session_guard = session.lock()
.map_err(|e| Error::model(format!("Failed to lock session: {}", e)))?;
Self::infer_sync(&mut *session_guard, &input_ids, &attention_mask, &labels, post_processing)
})
.await
.map_err(|e| Error::model(format!("Async inference task failed: {}", e)))?
}
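    /// Synchronous counterpart of `infer_async`; blocks the calling thread for
    /// the duration of the ONNX run.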
pub fn infer(
&self,
input_ids: &[u32],
attention_mask: &[u32],
labels: &[String],
post_processing: PostProcessing,
) -> crate::Result<InferenceResult> {
let mut session_guard = self.session.lock()
.map_err(|e| Error::model(format!("Failed to lock session: {}", e)))?;
Self::infer_sync(&mut *session_guard, input_ids, attention_mask, labels, post_processing)
}
fn infer_sync(
session: &mut Session,
input_ids: &[u32],
attention_mask: &[u32],
labels: &[String],
post_processing: PostProcessing,
) -> crate::Result<InferenceResult> {
let input_ids_i64: Vec<i64> = input_ids.iter().map(|&x| x as i64).collect();
let attention_mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();
let batch_size = 1;
let seq_length = input_ids.len();
let input_ids_array =
Array2::from_shape_vec((batch_size, seq_length), input_ids_i64)
.map_err(|e| Error::model(format!("Failed to create input array: {}", e)))?;
let attention_mask_array =
Array2::from_shape_vec((batch_size, seq_length), attention_mask_i64)
.map_err(|e| Error::model(format!("Failed to create attention mask array: {}", e)))?;
let input_ids_value = ort::value::Value::from_array(input_ids_array)
.map_err(|e| Error::model(format!("Failed to create input_ids value: {}", e)))?;
let attention_mask_value = ort::value::Value::from_array(attention_mask_array)
.map_err(|e| Error::model(format!("Failed to create attention_mask value: {}", e)))?;
let outputs = session
.run(ort::inputs![
"input_ids" => input_ids_value,
"attention_mask" => attention_mask_value,
])
.map_err(|e| Error::model(format!("Inference failed: {}", e)))?;
let logits = outputs["logits"]
.try_extract_tensor::<f32>()
.map_err(|e| Error::model(format!("Failed to extract logits: {}", e)))?;
let (_shape, data) = logits;
let logits_vec: Vec<f32> = data.to_vec();
let scores = match post_processing {
PostProcessing::Softmax => Self::softmax_static(&logits_vec),
PostProcessing::Sigmoid => Self::sigmoid_static(&logits_vec),
};
let (predicted_class, max_score) = scores
.iter()
.enumerate()
            .max_by(|(_, a), (_, b)| a.total_cmp(b))
.map(|(idx, &score)| (idx, score))
.unwrap_or((0, 0.0));
Ok(InferenceResult {
labels: labels.to_vec(),
scores,
predicted_class,
max_score,
})
}
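    /// Numerically stable softmax: subtracts the maximum logit before
    /// exponentiating and falls back to a uniform distribution if the
    /// exponentials underflow to zero.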
pub fn softmax_static(logits: &[f32]) -> Vec<f32> {
if logits.is_empty() {
return vec![];
}
let max_logit = logits
.iter()
.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let exp_logits: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
let sum_exp: f32 = exp_logits.iter().sum();
if sum_exp == 0.0 {
vec![1.0 / logits.len() as f32; logits.len()]
} else {
exp_logits.iter().map(|&x| x / sum_exp).collect()
}
}
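    /// Element-wise logistic sigmoid, `1 / (1 + e^(-x))`.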
pub fn sigmoid_static(logits: &[f32]) -> Vec<f32> {
logits
.iter()
.map(|&x| 1.0 / (1.0 + (-x).exp()))
.collect()
}
#[allow(dead_code)]
fn softmax(&self, logits: &[f32]) -> Vec<f32> {
Self::softmax_static(logits)
}
#[allow(dead_code)]
fn sigmoid(&self, logits: &[f32]) -> Vec<f32> {
Self::sigmoid_static(logits)
}
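    /// Runs token classification asynchronously, returning one `TokenPrediction`
    /// per input token. Inputs must be non-empty and of equal length.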
pub async fn infer_token_classification(
&self,
input_ids: &[u32],
attention_mask: &[u32],
labels: &[String],
) -> crate::Result<Vec<TokenPrediction>> {
if input_ids.is_empty() {
return Err(Error::model("input_ids cannot be empty"));
}
if input_ids.len() != attention_mask.len() {
return Err(Error::model(format!(
"input_ids length ({}) != attention_mask length ({})",
input_ids.len(),
attention_mask.len()
)));
}
if labels.is_empty() {
return Err(Error::model("labels cannot be empty"));
}
let session = Arc::clone(&self.session);
let input_ids = input_ids.to_vec();
let attention_mask = attention_mask.to_vec();
let labels = labels.to_vec();
tokio::task::spawn_blocking(move || {
let mut session_guard = session.lock()
.map_err(|e| Error::model(format!("Failed to lock session: {}", e)))?;
Self::infer_token_classification_sync(
&mut *session_guard,
&input_ids,
&attention_mask,
&labels
)
})
.await
.map_err(|e| Error::model(format!("Async inference task failed: {}", e)))?
}
fn infer_token_classification_sync(
session: &mut Session,
input_ids: &[u32],
attention_mask: &[u32],
labels: &[String],
) -> crate::Result<Vec<TokenPrediction>> {
let input_ids_i64: Vec<i64> = input_ids.iter().map(|&x| x as i64).collect();
let attention_mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();
let batch_size = 1;
let seq_length = input_ids.len();
let input_ids_array =
Array2::from_shape_vec((batch_size, seq_length), input_ids_i64)
.map_err(|e| Error::model(format!("Failed to create input array: {}", e)))?;
let attention_mask_array =
Array2::from_shape_vec((batch_size, seq_length), attention_mask_i64)
.map_err(|e| Error::model(format!("Failed to create attention mask array: {}", e)))?;
let input_ids_value = ort::value::Value::from_array(input_ids_array)
.map_err(|e| Error::model(format!("Failed to create input_ids value: {}", e)))?;
let attention_mask_value = ort::value::Value::from_array(attention_mask_array)
.map_err(|e| Error::model(format!("Failed to create attention_mask value: {}", e)))?;
let outputs = session
.run(ort::inputs![
"input_ids" => input_ids_value,
"attention_mask" => attention_mask_value,
])
.map_err(|e| Error::model(format!("Inference failed: {}", e)))?;
let logits = outputs["logits"]
.try_extract_tensor::<f32>()
.map_err(|e| Error::model(format!("Failed to extract logits: {}", e)))?;
let (shape, data) = logits;
if shape.len() != 3 {
return Err(Error::model(format!(
"Expected 3D logits tensor, got shape with {} dimensions",
shape.len()
)));
}
let actual_batch = shape[0] as usize;
let actual_seq_len = shape[1] as usize;
let num_labels = shape[2] as usize;
if actual_batch != batch_size {
return Err(Error::model(format!(
"Batch size mismatch: expected {}, got {}",
batch_size, actual_batch
)));
}
if actual_seq_len != seq_length {
return Err(Error::model(format!(
"Sequence length mismatch: expected {}, got {}",
seq_length, actual_seq_len
)));
}
if num_labels != labels.len() {
return Err(Error::model(format!(
"Label count mismatch: model has {} labels, provided {}",
num_labels,
labels.len()
)));
}
let mut predictions = Vec::with_capacity(seq_length);
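        // Logits are laid out row-major as [batch, seq_len, num_labels]; with a
        // batch size of 1, token `i` occupies data[i * num_labels..(i + 1) * num_labels].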
for token_idx in 0..seq_length {
let start_idx = token_idx * num_labels;
let end_idx = start_idx + num_labels;
let token_logits: Vec<f32> = data[start_idx..end_idx].to_vec();
let scores = Self::softmax_static(&token_logits);
let (predicted_class, max_score) = scores
.iter()
.enumerate()
                .max_by(|(_, a), (_, b)| a.total_cmp(b))
.map(|(idx, &score)| (idx, score))
.unwrap_or((0, 0.0));
let predicted_label = labels[predicted_class].clone();
predictions.push(TokenPrediction::new(
input_ids[token_idx],
predicted_label,
predicted_class,
max_score,
scores,
));
}
Ok(predictions)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_inference_result_predicted_label() {
let result = InferenceResult {
labels: vec!["safe".to_string(), "unsafe".to_string()],
scores: vec![0.8, 0.2],
predicted_class: 0,
max_score: 0.8,
};
assert_eq!(result.predicted_label(), Some("safe"));
assert!(result.exceeds_threshold(0.7));
assert!(!result.exceeds_threshold(0.9));
}
    #[test]
    fn test_softmax_values() {
        let scores = InferenceEngine::softmax_static(&[1.0, 2.0, 3.0]);
        let sum: f32 = scores.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
        assert!(scores[0] < scores[1] && scores[1] < scores[2]);
    }
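    // Additional coverage for the pure helpers; these exercise only the static
    // post-processing functions and the validation logic, so no ONNX session is
    // required.
    #[test]
    fn test_sigmoid_values() {
        let scores = InferenceEngine::sigmoid_static(&[0.0, 10.0, -10.0]);
        assert!((scores[0] - 0.5).abs() < 1e-6);
        assert!(scores[1] > 0.99);
        assert!(scores[2] < 0.01);
    }
    #[test]
    fn test_token_prediction_validate() {
        let valid = TokenPrediction::new(42, "O".to_string(), 0, 0.7, vec![0.7, 0.3]);
        assert!(valid.validate().is_ok());
        let invalid = TokenPrediction::new(42, "O".to_string(), 5, 0.7, vec![0.7, 0.3]);
        assert!(invalid.validate().is_err());
    }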
}