llm_shield_models/inference.rs

//! Inference Engine
//!
//! Handles model inference and result processing.
//!
//! ## Features
//!
//! - Binary and multi-label classification
//! - Softmax and sigmoid post-processing
//! - Threshold-based decision making
//! - Async inference API
//! - Support for different model tasks
//!
//! ## Example
//!
//! ```rust,ignore
//! use llm_shield_models::{InferenceEngine, PostProcessing};
//! use std::sync::{Arc, Mutex};
//!
//! let engine = InferenceEngine::new(Arc::new(Mutex::new(session)));
//! let result = engine
//!     .infer_async(&input_ids, &attention_mask, &labels, PostProcessing::Softmax)
//!     .await?;
//! ```

use llm_shield_core::Error;
use ndarray::Array2;
use ort::session::Session;
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};

/// Post-processing method for model outputs
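///
/// Choosing a method depends on whether the classes are mutually exclusive.
/// A sketch (assumes `engine`, tokenized inputs, and `labels` in scope):
///
/// ```rust,ignore
/// // Mutually exclusive classes (e.g., SAFE vs INJECTION): Softmax.
/// let single = engine
///     .infer_async(&input_ids, &attention_mask, &labels, PostProcessing::Softmax)
///     .await?;
///
/// // Independent classes (e.g., toxicity categories): Sigmoid.
/// let multi = engine
///     .infer_async(&input_ids, &attention_mask, &labels, PostProcessing::Sigmoid)
///     .await?;
/// ```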
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PostProcessing {
    /// Softmax (for single-label classification)
    /// Outputs sum to 1.0
    Softmax,

    /// Sigmoid (for multi-label classification)
    /// Each output is independent [0, 1]
    Sigmoid,
}

/// Prediction for a single token in token classification
#[derive(Debug, Clone, PartialEq)]
pub struct TokenPrediction {
    /// Token ID from vocabulary
    pub token_id: u32,

    /// Predicted label (e.g., "B-PERSON", "I-EMAIL", "O")
    pub predicted_label: String,

    /// Index of predicted class
    pub predicted_class: usize,

    /// Confidence score for predicted class (0.0-1.0)
    pub confidence: f32,

    /// Probability distribution over all classes (after softmax)
    pub all_scores: Vec<f32>,
}

impl TokenPrediction {
    /// Create a new token prediction
    pub fn new(
        token_id: u32,
        predicted_label: String,
        predicted_class: usize,
        confidence: f32,
        all_scores: Vec<f32>,
    ) -> Self {
        Self {
            token_id,
            predicted_label,
            predicted_class,
            confidence,
            all_scores,
        }
    }

    /// Validate invariants
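    ///
    /// # Example
    ///
    /// A minimal sketch with scores taken from a softmax distribution:
    ///
    /// ```rust,ignore
    /// let pred = TokenPrediction::new(
    ///     101,                     // token_id
    ///     "B-PERSON".to_string(),  // predicted_label
    ///     1,                       // predicted_class
    ///     0.7,                     // confidence (must match all_scores[1])
    ///     vec![0.2, 0.7, 0.1],     // all_scores (sums to 1.0)
    /// );
    /// assert!(pred.validate().is_ok());
    /// ```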
    pub fn validate(&self) -> Result<(), String> {
        // Confidence in valid range
        if self.confidence < 0.0 || self.confidence > 1.0 {
            return Err(format!("Invalid confidence: {}", self.confidence));
        }

        // Predicted class is valid index
        if self.predicted_class >= self.all_scores.len() {
            return Err(format!(
                "Invalid predicted_class {} for {} scores",
                self.predicted_class,
                self.all_scores.len()
            ));
        }

        // Confidence matches all_scores
        let expected_confidence = self.all_scores[self.predicted_class];
        if (self.confidence - expected_confidence).abs() > 0.001 {
            return Err(format!(
                "Confidence mismatch: {} != {}",
                self.confidence, expected_confidence
            ));
        }

        // All scores in valid range
        for (i, &score) in self.all_scores.iter().enumerate() {
            if score < 0.0 || score > 1.0 {
                return Err(format!("Invalid score at index {}: {}", i, score));
            }
        }

        // Sum of scores is approximately 1.0 (softmax invariant)
        let sum: f32 = self.all_scores.iter().sum();
        if (sum - 1.0).abs() > 0.01 {
            return Err(format!("Scores don't sum to 1.0: {}", sum));
        }

        Ok(())
    }
}

/// Inference result with classification predictions
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct InferenceResult {
    /// Predicted class labels
    pub labels: Vec<String>,

    /// Confidence scores for each class (after post-processing)
    pub scores: Vec<f32>,

    /// Predicted class index (highest score)
    pub predicted_class: usize,

    /// Maximum confidence score
    pub max_score: f32,
}

impl InferenceResult {
    /// Get the predicted label
    ///
    /// # Example
    ///
    /// ```
    /// use llm_shield_models::InferenceResult;
    ///
    /// let result = InferenceResult {
    ///     labels: vec!["SAFE".to_string(), "INJECTION".to_string()],
    ///     scores: vec![0.8, 0.2],
    ///     predicted_class: 0,
    ///     max_score: 0.8,
    /// };
    ///
    /// assert_eq!(result.predicted_label(), Some("SAFE"));
    /// ```
    pub fn predicted_label(&self) -> Option<&str> {
        self.labels.get(self.predicted_class).map(|s| s.as_str())
    }

    /// Check if the prediction confidence meets or exceeds a threshold
    ///
    /// # Arguments
    ///
    /// * `threshold` - Minimum confidence threshold (0.0 to 1.0)
    ///
    /// # Example
    ///
    /// ```
    /// use llm_shield_models::InferenceResult;
    ///
    /// let result = InferenceResult {
    ///     labels: vec!["SAFE".to_string(), "INJECTION".to_string()],
    ///     scores: vec![0.3, 0.7],
    ///     predicted_class: 1,
    ///     max_score: 0.7,
    /// };
    ///
    /// assert!(result.exceeds_threshold(0.5));
    /// assert!(!result.exceeds_threshold(0.8));
    /// ```
    pub fn exceeds_threshold(&self, threshold: f32) -> bool {
        self.max_score >= threshold
    }

    /// Get score for a specific label
    ///
    /// # Arguments
    ///
    /// * `label` - The label to get the score for
    ///
    /// # Returns
    ///
    /// The confidence score for the label, or None if not found
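    ///
    /// # Example
    ///
    /// ```
    /// use llm_shield_models::InferenceResult;
    ///
    /// let result = InferenceResult {
    ///     labels: vec!["SAFE".to_string(), "INJECTION".to_string()],
    ///     scores: vec![0.8, 0.2],
    ///     predicted_class: 0,
    ///     max_score: 0.8,
    /// };
    ///
    /// assert_eq!(result.get_score_for_label("INJECTION"), Some(0.2));
    /// assert_eq!(result.get_score_for_label("UNKNOWN"), None);
    /// ```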
    pub fn get_score_for_label(&self, label: &str) -> Option<f32> {
        self.labels
            .iter()
            .position(|l| l == label)
            .and_then(|idx| self.scores.get(idx).copied())
    }

    /// Check if this is a binary classification result
    pub fn is_binary(&self) -> bool {
        self.labels.len() == 2
    }

    /// Get indices of labels that meet or exceed their respective thresholds
    ///
    /// Used for multi-label classification where each class has its own threshold.
    ///
    /// # Arguments
    ///
    /// * `thresholds` - Per-class thresholds (must match number of labels)
    ///
    /// # Returns
    ///
    /// Vector of class indices whose scores meet or exceed their thresholds
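    ///
    /// # Example
    ///
    /// ```
    /// use llm_shield_models::InferenceResult;
    ///
    /// let result = InferenceResult {
    ///     labels: vec!["toxic".to_string(), "insult".to_string()],
    ///     scores: vec![0.9, 0.4],
    ///     predicted_class: 0,
    ///     max_score: 0.9,
    /// };
    ///
    /// // Only "toxic" (index 0) meets its per-class threshold.
    /// assert_eq!(result.get_threshold_violations(&[0.5, 0.5]), vec![0]);
    /// ```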
    pub fn get_threshold_violations(&self, thresholds: &[f32]) -> Vec<usize> {
        if thresholds.len() != self.scores.len() {
            tracing::warn!(
                "Threshold count mismatch: {} thresholds for {} classes",
                thresholds.len(),
                self.scores.len()
            );
            return vec![];
        }

        self.scores
            .iter()
            .enumerate()
            .filter_map(|(idx, &score)| {
                if score >= thresholds[idx] {
                    Some(idx)
                } else {
                    None
                }
            })
            .collect()
    }

    /// Create InferenceResult from logits using softmax (single-label)
    ///
    /// # Arguments
    ///
    /// * `logits` - Raw model output logits
    /// * `labels` - Class labels
    ///
    /// # Example
    ///
    /// ```
    /// use llm_shield_models::InferenceResult;
    ///
    /// let logits = vec![1.0, 2.0, 0.5];
    /// let labels = vec!["A".to_string(), "B".to_string(), "C".to_string()];
    /// let result = InferenceResult::from_binary_logits(logits, labels);
    ///
    /// // B should have highest probability
    /// assert_eq!(result.predicted_class, 1);
    /// ```
    pub fn from_binary_logits(logits: Vec<f32>, labels: Vec<String>) -> Self {
        let scores = InferenceEngine::softmax_static(&logits);
        let (predicted_class, max_score) = scores
            .iter()
            .enumerate()
            // total_cmp avoids the panic partial_cmp().unwrap() would hit on NaN logits
            .max_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(idx, &score)| (idx, score))
            .unwrap_or((0, 0.0));

        Self {
            labels,
            scores,
            predicted_class,
            max_score,
        }
    }

    /// Create InferenceResult from logits using sigmoid (multi-label)
    ///
    /// # Arguments
    ///
    /// * `logits` - Raw model output logits
    /// * `labels` - Class labels
    ///
    /// # Example
    ///
    /// ```
    /// use llm_shield_models::InferenceResult;
    ///
    /// let logits = vec![2.0, -1.0, 1.0];
    /// let labels = vec!["toxic".to_string(), "threat".to_string(), "insult".to_string()];
    /// let result = InferenceResult::from_multilabel_logits(logits, labels);
    ///
    /// // All scores should be in [0, 1]
    /// for score in &result.scores {
    ///     assert!(*score >= 0.0 && *score <= 1.0);
    /// }
    /// ```
    pub fn from_multilabel_logits(logits: Vec<f32>, labels: Vec<String>) -> Self {
        let scores = InferenceEngine::sigmoid_static(&logits);
        let (predicted_class, max_score) = scores
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(idx, &score)| (idx, score))
            .unwrap_or((0, 0.0));

        Self {
            labels,
            scores,
            predicted_class,
            max_score,
        }
    }
}

/// Inference engine for running ONNX model inference
///
/// ## Features
///
/// - Synchronous and asynchronous inference
/// - Binary and multi-label classification
/// - Automatic post-processing (softmax/sigmoid)
/// - Batch inference support (optional)
///
/// ## Example
///
/// ```rust,ignore
/// use llm_shield_models::{InferenceEngine, PostProcessing};
/// use std::sync::{Arc, Mutex};
///
/// let engine = InferenceEngine::new(Arc::new(Mutex::new(session)));
///
/// // Run inference
/// let result = engine.infer_async(
///     &input_ids,
///     &attention_mask,
///     &labels,
///     PostProcessing::Softmax,
/// ).await?;
///
/// println!("Predicted: {}", result.predicted_label().unwrap());
/// println!("Confidence: {:.2}", result.max_score);
/// ```
pub struct InferenceEngine {
    session: Arc<Mutex<Session>>,
}

impl InferenceEngine {
    /// Create a new inference engine
    ///
    /// # Arguments
    ///
    /// * `session` - ONNX Runtime session wrapped in `Arc<Mutex<_>>` for thread-safe mutable access
    pub fn new(session: Arc<Mutex<Session>>) -> Self {
        Self { session }
    }

    /// Run inference on input IDs (async)
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Tokenized input IDs
    /// * `attention_mask` - Attention mask (1 for real tokens, 0 for padding)
    /// * `labels` - Class labels
    /// * `post_processing` - Post-processing method (Softmax or Sigmoid)
    ///
    /// # Returns
    ///
    /// InferenceResult with predictions and confidence scores
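    ///
    /// # Example
    ///
    /// A sketch (assumes a constructed `engine`, tokenized inputs, and labels):
    ///
    /// ```rust,ignore
    /// let result = engine
    ///     .infer_async(&input_ids, &attention_mask, &labels, PostProcessing::Softmax)
    ///     .await?;
    /// if result.exceeds_threshold(0.8) {
    ///     println!("High confidence: {:?}", result.predicted_label());
    /// }
    /// ```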
    pub async fn infer_async(
        &self,
        input_ids: &[u32],
        attention_mask: &[u32],
        labels: &[String],
        post_processing: PostProcessing,
    ) -> crate::Result<InferenceResult> {
        // Run inference in a blocking thread pool to avoid stalling the async runtime
        let session = Arc::clone(&self.session);
        let input_ids = input_ids.to_vec();
        let attention_mask = attention_mask.to_vec();
        let labels = labels.to_vec();

        tokio::task::spawn_blocking(move || {
            let mut session_guard = session.lock()
                .map_err(|e| Error::model(format!("Failed to lock session: {}", e)))?;
            Self::infer_sync(&mut *session_guard, &input_ids, &attention_mask, &labels, post_processing)
        })
        .await
        .map_err(|e| Error::model(format!("Async inference task failed: {}", e)))?
    }

    /// Run inference on input IDs (synchronous)
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Tokenized input IDs
    /// * `attention_mask` - Attention mask (1 for real tokens, 0 for padding)
    /// * `labels` - Class labels
    /// * `post_processing` - Post-processing method (Softmax or Sigmoid)
    ///
    /// # Returns
    ///
    /// InferenceResult with predictions and confidence scores
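    ///
    /// # Example
    ///
    /// A sketch of the blocking variant, for use outside an async runtime:
    ///
    /// ```rust,ignore
    /// let result = engine.infer(&input_ids, &attention_mask, &labels, PostProcessing::Sigmoid)?;
    /// println!("Max score: {:.2}", result.max_score);
    /// ```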
    pub fn infer(
        &self,
        input_ids: &[u32],
        attention_mask: &[u32],
        labels: &[String],
        post_processing: PostProcessing,
    ) -> crate::Result<InferenceResult> {
        let mut session_guard = self.session.lock()
            .map_err(|e| Error::model(format!("Failed to lock session: {}", e)))?;
        Self::infer_sync(&mut *session_guard, input_ids, attention_mask, labels, post_processing)
    }

    /// Internal synchronous inference implementation
    fn infer_sync(
        session: &mut Session,
        input_ids: &[u32],
        attention_mask: &[u32],
        labels: &[String],
        post_processing: PostProcessing,
    ) -> crate::Result<InferenceResult> {
        // Convert to i64 for ONNX
        let input_ids_i64: Vec<i64> = input_ids.iter().map(|&x| x as i64).collect();
        let attention_mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();

        let batch_size = 1;
        let seq_length = input_ids.len();

        // Create input arrays [batch_size, seq_length]
        let input_ids_array =
            Array2::from_shape_vec((batch_size, seq_length), input_ids_i64)
                .map_err(|e| Error::model(format!("Failed to create input array: {}", e)))?;

        let attention_mask_array =
            Array2::from_shape_vec((batch_size, seq_length), attention_mask_i64)
                .map_err(|e| Error::model(format!("Failed to create attention mask array: {}", e)))?;

        // Create ONNX values
        let input_ids_value = ort::value::Value::from_array(input_ids_array)
            .map_err(|e| Error::model(format!("Failed to create input_ids value: {}", e)))?;
        let attention_mask_value = ort::value::Value::from_array(attention_mask_array)
            .map_err(|e| Error::model(format!("Failed to create attention_mask value: {}", e)))?;

        // Run inference
        let outputs = session
            .run(ort::inputs![
                "input_ids" => input_ids_value,
                "attention_mask" => attention_mask_value,
            ])
            .map_err(|e| Error::model(format!("Inference failed: {}", e)))?;

        // Extract logits; try_extract_tensor yields a (shape, data) pair
        let (_shape, data) = outputs["logits"]
            .try_extract_tensor::<f32>()
            .map_err(|e| Error::model(format!("Failed to extract logits: {}", e)))?;
        let logits_vec: Vec<f32> = data.to_vec();

        // Apply post-processing
        let scores = match post_processing {
            PostProcessing::Softmax => Self::softmax_static(&logits_vec),
            PostProcessing::Sigmoid => Self::sigmoid_static(&logits_vec),
        };

        // Find predicted class; total_cmp avoids the panic that
        // partial_cmp().unwrap() would hit on NaN logits
        let (predicted_class, max_score) = scores
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(idx, &score)| (idx, score))
            .unwrap_or((0, 0.0));

        Ok(InferenceResult {
            labels: labels.to_vec(),
            scores,
            predicted_class,
            max_score,
        })
    }

    /// Apply softmax to logits (static method)
    ///
    /// Softmax converts logits to probabilities that sum to 1.0.
    /// Used for single-label classification (mutually exclusive classes).
    ///
    /// # Arguments
    ///
    /// * `logits` - Raw model output logits
    ///
    /// # Returns
    ///
    /// Probability distribution (sums to 1.0)
    ///
    /// # Example
    ///
    /// ```
    /// use llm_shield_models::InferenceEngine;
    ///
    /// let logits = vec![1.0, 2.0, 0.5];
    /// let probs = InferenceEngine::softmax_static(&logits);
    ///
    /// // Probabilities sum to 1.0
    /// let sum: f32 = probs.iter().sum();
    /// assert!((sum - 1.0).abs() < 0.001);
    /// ```
    pub fn softmax_static(logits: &[f32]) -> Vec<f32> {
        if logits.is_empty() {
            return vec![];
        }

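        // softmax(x_i) = exp(x_i - max(x)) / sum_j exp(x_j - max(x));
        // the shifted form is algebraically identical to the textbook softmax
        // but cannot overflow in exp() for large logits.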
        // Find max for numerical stability
        let max_logit = logits
            .iter()
            .fold(f32::NEG_INFINITY, |a, &b| a.max(b));

        // Compute exp(logit - max)
        let exp_logits: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();

        // Sum of exponentials
        let sum_exp: f32 = exp_logits.iter().sum();

        // Normalize
        if sum_exp == 0.0 {
            // Edge case: all logits are very negative
            vec![1.0 / logits.len() as f32; logits.len()]
        } else {
            exp_logits.iter().map(|&x| x / sum_exp).collect()
        }
    }

    /// Apply sigmoid to logits (static method)
    ///
    /// Sigmoid converts each logit independently to [0, 1].
    /// Used for multi-label classification (non-exclusive classes).
    ///
    /// # Arguments
    ///
    /// * `logits` - Raw model output logits
    ///
    /// # Returns
    ///
    /// Independent probabilities (do NOT sum to 1.0)
    ///
    /// # Example
    ///
    /// ```
    /// use llm_shield_models::InferenceEngine;
    ///
    /// let logits = vec![0.0, 2.0, -2.0];
    /// let probs = InferenceEngine::sigmoid_static(&logits);
    ///
    /// // sigmoid(0) ≈ 0.5
    /// assert!((probs[0] - 0.5).abs() < 0.01);
    ///
    /// // All probabilities in [0, 1]
    /// for p in probs {
    ///     assert!(p >= 0.0 && p <= 1.0);
    /// }
    /// ```
    pub fn sigmoid_static(logits: &[f32]) -> Vec<f32> {
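        // sigmoid(x) = 1 / (1 + e^(-x)); exp(-x) saturates to 0 or +inf at the
        // extremes, so each output lands cleanly in [0, 1] without NaN.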
        logits
            .iter()
            .map(|&x| 1.0 / (1.0 + (-x).exp()))
            .collect()
    }

    /// Apply softmax to logits (instance method)
    #[allow(dead_code)]
    fn softmax(&self, logits: &[f32]) -> Vec<f32> {
        Self::softmax_static(logits)
    }

    /// Apply sigmoid to logits (instance method)
    #[allow(dead_code)]
    fn sigmoid(&self, logits: &[f32]) -> Vec<f32> {
        Self::sigmoid_static(logits)
    }

    /// Run token-level classification inference (for NER/token classification)
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Token IDs from tokenizer
    /// * `attention_mask` - Attention mask (1=real token, 0=padding)
    /// * `labels` - BIO tag labels (e.g., ["O", "B-PERSON", "I-PERSON", ...])
    ///
    /// # Returns
    ///
    /// Vector of predictions, one per input token
    ///
    /// # Errors
    ///
    /// - Model inference failure
    /// - Invalid tensor shapes
    /// - Label count mismatch
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// use llm_shield_models::InferenceEngine;
    /// use std::sync::{Arc, Mutex};
    ///
    /// let engine = InferenceEngine::new(Arc::new(Mutex::new(session)));
    /// let labels: Vec<String> = ["O", "B-PERSON", "I-PERSON"]
    ///     .iter()
    ///     .map(|s| s.to_string())
    ///     .collect();
    ///
    /// let predictions = engine.infer_token_classification(
    ///     &input_ids,
    ///     &attention_mask,
    ///     &labels,
    /// ).await?;
    ///
    /// for pred in predictions {
    ///     println!("{}: {:.2}", pred.predicted_label, pred.confidence);
    /// }
    /// ```
    pub async fn infer_token_classification(
        &self,
        input_ids: &[u32],
        attention_mask: &[u32],
        labels: &[String],
    ) -> crate::Result<Vec<TokenPrediction>> {
        // Validation
        if input_ids.is_empty() {
            return Err(Error::model("input_ids cannot be empty"));
        }
        if input_ids.len() != attention_mask.len() {
            return Err(Error::model(format!(
                "input_ids length ({}) != attention_mask length ({})",
                input_ids.len(),
                attention_mask.len()
            )));
        }
        if labels.is_empty() {
            return Err(Error::model("labels cannot be empty"));
        }

        // Run inference in a blocking thread pool to avoid stalling the async runtime
        let session = Arc::clone(&self.session);
        let input_ids = input_ids.to_vec();
        let attention_mask = attention_mask.to_vec();
        let labels = labels.to_vec();

        tokio::task::spawn_blocking(move || {
            let mut session_guard = session.lock()
                .map_err(|e| Error::model(format!("Failed to lock session: {}", e)))?;
            Self::infer_token_classification_sync(
                &mut *session_guard,
                &input_ids,
                &attention_mask,
                &labels,
            )
        })
        .await
        .map_err(|e| Error::model(format!("Async inference task failed: {}", e)))?
    }

    /// Internal synchronous token classification implementation
    fn infer_token_classification_sync(
        session: &mut Session,
        input_ids: &[u32],
        attention_mask: &[u32],
        labels: &[String],
    ) -> crate::Result<Vec<TokenPrediction>> {
        // Convert to i64 for ONNX
        let input_ids_i64: Vec<i64> = input_ids.iter().map(|&x| x as i64).collect();
        let attention_mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();

        let batch_size = 1;
        let seq_length = input_ids.len();

        // Create input arrays [batch_size, seq_length]
        let input_ids_array =
            Array2::from_shape_vec((batch_size, seq_length), input_ids_i64)
                .map_err(|e| Error::model(format!("Failed to create input array: {}", e)))?;

        let attention_mask_array =
            Array2::from_shape_vec((batch_size, seq_length), attention_mask_i64)
                .map_err(|e| Error::model(format!("Failed to create attention mask array: {}", e)))?;

        // Create ONNX values
        let input_ids_value = ort::value::Value::from_array(input_ids_array)
            .map_err(|e| Error::model(format!("Failed to create input_ids value: {}", e)))?;
        let attention_mask_value = ort::value::Value::from_array(attention_mask_array)
            .map_err(|e| Error::model(format!("Failed to create attention_mask value: {}", e)))?;

        // Run inference
        let outputs = session
            .run(ort::inputs![
                "input_ids" => input_ids_value,
                "attention_mask" => attention_mask_value,
            ])
            .map_err(|e| Error::model(format!("Inference failed: {}", e)))?;

        // Extract logits [batch_size, seq_length, num_labels];
        // try_extract_tensor yields a (shape, data) pair
        let (shape, data) = outputs["logits"]
            .try_extract_tensor::<f32>()
            .map_err(|e| Error::model(format!("Failed to extract logits: {}", e)))?;

        // Validate shape: should be [batch_size, seq_length, num_labels]
        if shape.len() != 3 {
            return Err(Error::model(format!(
                "Expected 3D logits tensor, got shape with {} dimensions",
                shape.len()
            )));
        }

        let actual_batch = shape[0] as usize;
        let actual_seq_len = shape[1] as usize;
        let num_labels = shape[2] as usize;

        if actual_batch != batch_size {
            return Err(Error::model(format!(
                "Batch size mismatch: expected {}, got {}",
                batch_size, actual_batch
            )));
        }

        if actual_seq_len != seq_length {
            return Err(Error::model(format!(
                "Sequence length mismatch: expected {}, got {}",
                seq_length, actual_seq_len
            )));
        }

        if num_labels != labels.len() {
            return Err(Error::model(format!(
                "Label count mismatch: model has {} labels, provided {}",
                num_labels,
                labels.len()
            )));
        }

        // Process each token
        let mut predictions = Vec::with_capacity(seq_length);

        for token_idx in 0..seq_length {
            // Extract logits for this token
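            // The [1, seq_length, num_labels] tensor is flattened row-major,
            // so each token's logits occupy a contiguous num_labels-long slice.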
            let start_idx = token_idx * num_labels;
            let end_idx = start_idx + num_labels;
            let token_logits: Vec<f32> = data[start_idx..end_idx].to_vec();

            // Apply softmax
            let scores = Self::softmax_static(&token_logits);

            // Find predicted class
            let (predicted_class, max_score) = scores
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.total_cmp(b))
                .map(|(idx, &score)| (idx, score))
                .unwrap_or((0, 0.0));

            let predicted_label = labels[predicted_class].clone();

            predictions.push(TokenPrediction::new(
                input_ids[token_idx],
                predicted_label,
                predicted_class,
                max_score,
                scores,
            ));
        }

        Ok(predictions)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_inference_result_predicted_label() {
        let result = InferenceResult {
            labels: vec!["safe".to_string(), "unsafe".to_string()],
            scores: vec![0.8, 0.2],
            predicted_class: 0,
            max_score: 0.8,
        };

        assert_eq!(result.predicted_label(), Some("safe"));
        assert!(result.exceeds_threshold(0.7));
        assert!(!result.exceeds_threshold(0.9));
    }

    #[test]
    fn test_softmax_values() {
        // softmax_static is a pure function, so it can be verified without a session.
        let probs = InferenceEngine::softmax_static(&[1.0, 2.0, 0.5]);

        // Probabilities sum to 1.0 and the largest logit wins.
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-3);
        assert!(probs[1] > probs[0] && probs[1] > probs[2]);
    }
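
    // Extra pure-function checks; these exercise public helpers that need no
    // ONNX session.
    #[test]
    fn test_sigmoid_values() {
        let probs = InferenceEngine::sigmoid_static(&[0.0, 2.0, -2.0]);
        // sigmoid(0) = 0.5, and sigmoid(2) + sigmoid(-2) = 1 by symmetry.
        assert!((probs[0] - 0.5).abs() < 1e-3);
        assert!((probs[1] + probs[2] - 1.0).abs() < 1e-3);
        assert!(probs.iter().all(|p| (0.0..=1.0).contains(p)));
    }

    #[test]
    fn test_get_threshold_violations() {
        let result = InferenceResult {
            labels: vec!["toxic".to_string(), "insult".to_string()],
            scores: vec![0.9, 0.4],
            predicted_class: 0,
            max_score: 0.9,
        };

        // Only index 0 meets its per-class threshold.
        assert_eq!(result.get_threshold_violations(&[0.5, 0.5]), vec![0]);
        // A threshold-count mismatch yields no violations.
        assert!(result.get_threshold_violations(&[0.5]).is_empty());
    }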
}