// llm_shield_models/inference.rs
1//! Inference Engine
2//!
3//! Handles model inference and result processing.
4//!
5//! ## Features
6//!
7//! - Binary and multi-label classification
8//! - Softmax and sigmoid post-processing
9//! - Threshold-based decision making
10//! - Async inference API
11//! - Support for different model tasks
12//!
13//! ## Example
14//!
15//! ```rust,ignore
16//! use llm_shield_models::InferenceEngine;
17//!
18//! let engine = InferenceEngine::new(session);
19//! let result = engine.infer(&input_ids, &attention_mask, &labels).await?;
20//! ```
21
22use llm_shield_core::Error;
23use ndarray::Array2;
24use ort::session::Session;
25use serde::{Deserialize, Serialize};
26use std::sync::{Arc, Mutex};
27
/// Post-processing method for model outputs
///
/// Selects how raw logits are converted to scores in
/// [`InferenceEngine::infer`] / [`InferenceEngine::infer_async`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PostProcessing {
    /// Softmax (for single-label classification)
    /// Outputs sum to 1.0
    Softmax,

    /// Sigmoid (for multi-label classification)
    /// Each output is independent [0, 1]
    Sigmoid,
}
39
/// Prediction for a single token in token classification
#[derive(Debug, Clone, PartialEq)]
pub struct TokenPrediction {
    /// Token ID from vocabulary
    pub token_id: u32,

    /// Predicted label (e.g., "B-PERSON", "I-EMAIL", "O")
    pub predicted_label: String,

    /// Index of predicted class
    pub predicted_class: usize,

    /// Confidence score for predicted class (0.0-1.0)
    pub confidence: f32,

    /// Probability distribution over all classes (after softmax)
    pub all_scores: Vec<f32>,
}

impl TokenPrediction {
    /// Create a new token prediction
    ///
    /// Invariants are not checked here; call [`TokenPrediction::validate`]
    /// to verify them.
    pub fn new(
        token_id: u32,
        predicted_label: String,
        predicted_class: usize,
        confidence: f32,
        all_scores: Vec<f32>,
    ) -> Self {
        Self {
            token_id,
            predicted_label,
            predicted_class,
            confidence,
            all_scores,
        }
    }

    /// Validate invariants
    ///
    /// Checks that:
    /// - `confidence` is in `[0.0, 1.0]` (NaN is rejected)
    /// - `predicted_class` is a valid index into `all_scores`
    /// - `confidence` matches `all_scores[predicted_class]`
    /// - every score is in `[0.0, 1.0]` (NaN is rejected)
    /// - scores sum to approximately 1.0 (softmax invariant)
    ///
    /// # Errors
    ///
    /// Returns a human-readable description of the first violated invariant.
    pub fn validate(&self) -> Result<(), String> {
        // Confidence in valid range. `RangeInclusive::contains` is false for
        // NaN, so NaN confidences are rejected here; the previous
        // `< 0.0 || > 1.0` form silently accepted NaN because both
        // comparisons are false for it.
        if !(0.0..=1.0).contains(&self.confidence) {
            return Err(format!("Invalid confidence: {}", self.confidence));
        }

        // Predicted class is valid index
        if self.predicted_class >= self.all_scores.len() {
            return Err(format!(
                "Invalid predicted_class {} for {} scores",
                self.predicted_class,
                self.all_scores.len()
            ));
        }

        // Confidence matches all_scores
        let expected_confidence = self.all_scores[self.predicted_class];
        if (self.confidence - expected_confidence).abs() > 0.001 {
            return Err(format!(
                "Confidence mismatch: {} != {}",
                self.confidence, expected_confidence
            ));
        }

        // All scores in valid range (NaN rejected, same reasoning as above)
        for (i, &score) in self.all_scores.iter().enumerate() {
            if !(0.0..=1.0).contains(&score) {
                return Err(format!("Invalid score at index {}: {}", i, score));
            }
        }

        // Sum of scores is approximately 1.0 (softmax invariant). A NaN sum
        // would slip past this comparison, but NaN elements are already
        // rejected by the per-score check above.
        let sum: f32 = self.all_scores.iter().sum();
        if (sum - 1.0).abs() > 0.01 {
            return Err(format!("Scores don't sum to 1.0: {}", sum));
        }

        Ok(())
    }
}
118
/// Inference result with classification predictions
///
/// Produced by [`InferenceEngine`]. `scores` holds post-processed values:
/// softmax probabilities sum to 1.0, sigmoid scores are independent per class.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct InferenceResult {
    /// Predicted class labels
    pub labels: Vec<String>,

    /// Confidence scores for each class (after post-processing);
    /// parallel to `labels`
    pub scores: Vec<f32>,

    /// Predicted class index (highest score); indexes into `labels`/`scores`
    pub predicted_class: usize,

    /// Maximum confidence score, i.e. `scores[predicted_class]`
    pub max_score: f32,
}
134
135impl InferenceResult {
136 /// Get the predicted label
137 ///
138 /// # Example
139 ///
140 /// ```
141 /// use llm_shield_models::InferenceResult;
142 ///
143 /// let result = InferenceResult {
144 /// labels: vec!["SAFE".to_string(), "INJECTION".to_string()],
145 /// scores: vec![0.8, 0.2],
146 /// predicted_class: 0,
147 /// max_score: 0.8,
148 /// };
149 ///
150 /// assert_eq!(result.predicted_label(), Some("SAFE"));
151 /// ```
152 pub fn predicted_label(&self) -> Option<&str> {
153 self.labels.get(self.predicted_class).map(|s| s.as_str())
154 }
155
156 /// Check if prediction confidence exceeds threshold
157 ///
158 /// # Arguments
159 ///
160 /// * `threshold` - Minimum confidence threshold (0.0 to 1.0)
161 ///
162 /// # Example
163 ///
164 /// ```
165 /// use llm_shield_models::InferenceResult;
166 ///
167 /// let result = InferenceResult {
168 /// labels: vec!["SAFE".to_string(), "INJECTION".to_string()],
169 /// scores: vec![0.3, 0.7],
170 /// predicted_class: 1,
171 /// max_score: 0.7,
172 /// };
173 ///
174 /// assert!(result.exceeds_threshold(0.5));
175 /// assert!(!result.exceeds_threshold(0.8));
176 /// ```
177 pub fn exceeds_threshold(&self, threshold: f32) -> bool {
178 self.max_score >= threshold
179 }
180
181 /// Get score for a specific label
182 ///
183 /// # Arguments
184 ///
185 /// * `label` - The label to get the score for
186 ///
187 /// # Returns
188 ///
189 /// The confidence score for the label, or None if not found
190 pub fn get_score_for_label(&self, label: &str) -> Option<f32> {
191 self.labels
192 .iter()
193 .position(|l| l == label)
194 .and_then(|idx| self.scores.get(idx).copied())
195 }
196
197 /// Check if this is a binary classification result
198 pub fn is_binary(&self) -> bool {
199 self.labels.len() == 2
200 }
201
202 /// Get indices of labels that exceed their respective thresholds
203 ///
204 /// Used for multi-label classification where each class has its own threshold.
205 ///
206 /// # Arguments
207 ///
208 /// * `thresholds` - Per-class thresholds (must match number of labels)
209 ///
210 /// # Returns
211 ///
212 /// Vector of class indices that exceed their thresholds
213 pub fn get_threshold_violations(&self, thresholds: &[f32]) -> Vec<usize> {
214 if thresholds.len() != self.scores.len() {
215 tracing::warn!(
216 "Threshold count mismatch: {} thresholds for {} classes",
217 thresholds.len(),
218 self.scores.len()
219 );
220 return vec![];
221 }
222
223 self.scores
224 .iter()
225 .enumerate()
226 .filter_map(|(idx, &score)| {
227 if score >= thresholds[idx] {
228 Some(idx)
229 } else {
230 None
231 }
232 })
233 .collect()
234 }
235
236 /// Create InferenceResult from logits using softmax (single-label)
237 ///
238 /// # Arguments
239 ///
240 /// * `logits` - Raw model output logits
241 /// * `labels` - Class labels
242 ///
243 /// # Example
244 ///
245 /// ```
246 /// use llm_shield_models::InferenceResult;
247 ///
248 /// let logits = vec![1.0, 2.0, 0.5];
249 /// let labels = vec!["A".to_string(), "B".to_string(), "C".to_string()];
250 /// let result = InferenceResult::from_binary_logits(logits, labels);
251 ///
252 /// // B should have highest probability
253 /// assert_eq!(result.predicted_class, 1);
254 /// ```
255 pub fn from_binary_logits(logits: Vec<f32>, labels: Vec<String>) -> Self {
256 let scores = InferenceEngine::softmax_static(&logits);
257 let (predicted_class, max_score) = scores
258 .iter()
259 .enumerate()
260 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
261 .map(|(idx, &score)| (idx, score))
262 .unwrap_or((0, 0.0));
263
264 Self {
265 labels,
266 scores,
267 predicted_class,
268 max_score,
269 }
270 }
271
272 /// Create InferenceResult from logits using sigmoid (multi-label)
273 ///
274 /// # Arguments
275 ///
276 /// * `logits` - Raw model output logits
277 /// * `labels` - Class labels
278 ///
279 /// # Example
280 ///
281 /// ```
282 /// use llm_shield_models::InferenceResult;
283 ///
284 /// let logits = vec![2.0, -1.0, 1.0];
285 /// let labels = vec!["toxic".to_string(), "threat".to_string(), "insult".to_string()];
286 /// let result = InferenceResult::from_multilabel_logits(logits, labels);
287 ///
288 /// // All scores should be in [0, 1]
289 /// for score in &result.scores {
290 /// assert!(*score >= 0.0 && *score <= 1.0);
291 /// }
292 /// ```
293 pub fn from_multilabel_logits(logits: Vec<f32>, labels: Vec<String>) -> Self {
294 let scores = InferenceEngine::sigmoid_static(&logits);
295 let (predicted_class, max_score) = scores
296 .iter()
297 .enumerate()
298 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
299 .map(|(idx, &score)| (idx, score))
300 .unwrap_or((0, 0.0));
301
302 Self {
303 labels,
304 scores,
305 predicted_class,
306 max_score,
307 }
308 }
309}
310
311/// Inference engine for running ONNX model inference
312///
313/// ## Features
314///
315/// - Synchronous and asynchronous inference
316/// - Binary and multi-label classification
317/// - Automatic post-processing (softmax/sigmoid)
318/// - Batch inference support (optional)
319///
320/// ## Example
321///
322/// ```rust,ignore
323/// use llm_shield_models::InferenceEngine;
324/// use std::sync::Arc;
325///
326/// let engine = InferenceEngine::new(session);
327///
328/// // Run inference
329/// let result = engine.infer(
330/// &input_ids,
331/// &attention_mask,
332/// &labels,
333/// PostProcessing::Softmax,
334/// ).await?;
335///
336/// println!("Predicted: {}", result.predicted_label().unwrap());
337/// println!("Confidence: {:.2}", result.max_score);
338/// ```
pub struct InferenceEngine {
    /// ONNX Runtime session. The `Mutex` provides the exclusive `&mut Session`
    /// that `Session::run` requires; the `Arc` lets the engine hand the session
    /// to `spawn_blocking` worker threads in the async paths.
    session: Arc<Mutex<Session>>,
}
342
343impl InferenceEngine {
344 /// Create a new inference engine
345 ///
346 /// # Arguments
347 ///
348 /// * `session` - ONNX Runtime session wrapped in Arc<Mutex<>> for thread-safe mutable access
349 pub fn new(session: Arc<Mutex<Session>>) -> Self {
350 Self { session }
351 }
352
353 /// Run inference on input IDs (async)
354 ///
355 /// # Arguments
356 ///
357 /// * `input_ids` - Tokenized input IDs
358 /// * `attention_mask` - Attention mask (1 for real tokens, 0 for padding)
359 /// * `labels` - Class labels
360 /// * `post_processing` - Post-processing method (Softmax or Sigmoid)
361 ///
362 /// # Returns
363 ///
364 /// InferenceResult with predictions and confidence scores
365 pub async fn infer_async(
366 &self,
367 input_ids: &[u32],
368 attention_mask: &[u32],
369 labels: &[String],
370 post_processing: PostProcessing,
371 ) -> crate::Result<InferenceResult> {
372 // Run inference in blocking thread pool to avoid blocking async runtime
373 let session = Arc::clone(&self.session);
374 let input_ids = input_ids.to_vec();
375 let attention_mask = attention_mask.to_vec();
376 let labels = labels.to_vec();
377
378 tokio::task::spawn_blocking(move || {
379 let mut session_guard = session.lock()
380 .map_err(|e| Error::model(format!("Failed to lock session: {}", e)))?;
381 Self::infer_sync(&mut *session_guard, &input_ids, &attention_mask, &labels, post_processing)
382 })
383 .await
384 .map_err(|e| Error::model(format!("Async inference task failed: {}", e)))?
385 }
386
387 /// Run inference on input IDs (synchronous)
388 ///
389 /// # Arguments
390 ///
391 /// * `input_ids` - Tokenized input IDs
392 /// * `attention_mask` - Attention mask (1 for real tokens, 0 for padding)
393 /// * `labels` - Class labels
394 /// * `post_processing` - Post-processing method (Softmax or Sigmoid)
395 ///
396 /// # Returns
397 ///
398 /// InferenceResult with predictions and confidence scores
399 pub fn infer(
400 &self,
401 input_ids: &[u32],
402 attention_mask: &[u32],
403 labels: &[String],
404 post_processing: PostProcessing,
405 ) -> crate::Result<InferenceResult> {
406 let mut session_guard = self.session.lock()
407 .map_err(|e| Error::model(format!("Failed to lock session: {}", e)))?;
408 Self::infer_sync(&mut *session_guard, input_ids, attention_mask, labels, post_processing)
409 }
410
411 /// Internal synchronous inference implementation
412 fn infer_sync(
413 session: &mut Session,
414 input_ids: &[u32],
415 attention_mask: &[u32],
416 labels: &[String],
417 post_processing: PostProcessing,
418 ) -> crate::Result<InferenceResult> {
419 // Convert to i64 for ONNX
420 let input_ids_i64: Vec<i64> = input_ids.iter().map(|&x| x as i64).collect();
421 let attention_mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();
422
423 let batch_size = 1;
424 let seq_length = input_ids.len();
425
426 // Create input arrays
427 let input_ids_array =
428 Array2::from_shape_vec((batch_size, seq_length), input_ids_i64)
429 .map_err(|e| Error::model(format!("Failed to create input array: {}", e)))?;
430
431 let attention_mask_array =
432 Array2::from_shape_vec((batch_size, seq_length), attention_mask_i64)
433 .map_err(|e| Error::model(format!("Failed to create attention mask array: {}", e)))?;
434
435 // Create ONNX values
436 let input_ids_value = ort::value::Value::from_array(input_ids_array)
437 .map_err(|e| Error::model(format!("Failed to create input_ids value: {}", e)))?;
438 let attention_mask_value = ort::value::Value::from_array(attention_mask_array)
439 .map_err(|e| Error::model(format!("Failed to create attention_mask value: {}", e)))?;
440
441 // Run inference
442 let outputs = session
443 .run(ort::inputs![
444 "input_ids" => input_ids_value,
445 "attention_mask" => attention_mask_value,
446 ])
447 .map_err(|e| Error::model(format!("Inference failed: {}", e)))?;
448
449 // Extract logits
450 let logits = outputs["logits"]
451 .try_extract_tensor::<f32>()
452 .map_err(|e| Error::model(format!("Failed to extract logits: {}", e)))?;
453
454 // Extract logits as Vec<f32> - logits is now (shape, data)
455 let (_shape, data) = logits;
456 let logits_vec: Vec<f32> = data.to_vec();
457
458 // Apply post-processing
459 let scores = match post_processing {
460 PostProcessing::Softmax => Self::softmax_static(&logits_vec),
461 PostProcessing::Sigmoid => Self::sigmoid_static(&logits_vec),
462 };
463
464 // Find predicted class
465 let (predicted_class, max_score) = scores
466 .iter()
467 .enumerate()
468 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
469 .map(|(idx, &score)| (idx, score))
470 .unwrap_or((0, 0.0));
471
472 Ok(InferenceResult {
473 labels: labels.to_vec(),
474 scores,
475 predicted_class,
476 max_score,
477 })
478 }
479
480 /// Apply softmax to logits (static method)
481 ///
482 /// Softmax converts logits to probabilities that sum to 1.0.
483 /// Used for single-label classification (mutually exclusive classes).
484 ///
485 /// # Arguments
486 ///
487 /// * `logits` - Raw model output logits
488 ///
489 /// # Returns
490 ///
491 /// Probability distribution (sums to 1.0)
492 ///
493 /// # Example
494 ///
495 /// ```
496 /// use llm_shield_models::InferenceEngine;
497 ///
498 /// let logits = vec![1.0, 2.0, 0.5];
499 /// let probs = InferenceEngine::softmax_static(&logits);
500 ///
501 /// // Probabilities sum to 1.0
502 /// let sum: f32 = probs.iter().sum();
503 /// assert!((sum - 1.0).abs() < 0.001);
504 /// ```
505 pub fn softmax_static(logits: &[f32]) -> Vec<f32> {
506 if logits.is_empty() {
507 return vec![];
508 }
509
510 // Find max for numerical stability
511 let max_logit = logits
512 .iter()
513 .fold(f32::NEG_INFINITY, |a, &b| a.max(b));
514
515 // Compute exp(logit - max)
516 let exp_logits: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
517
518 // Sum of exponentials
519 let sum_exp: f32 = exp_logits.iter().sum();
520
521 // Normalize
522 if sum_exp == 0.0 {
523 // Edge case: all logits are very negative
524 vec![1.0 / logits.len() as f32; logits.len()]
525 } else {
526 exp_logits.iter().map(|&x| x / sum_exp).collect()
527 }
528 }
529
530 /// Apply sigmoid to logits (static method)
531 ///
532 /// Sigmoid converts each logit independently to [0, 1].
533 /// Used for multi-label classification (non-exclusive classes).
534 ///
535 /// # Arguments
536 ///
537 /// * `logits` - Raw model output logits
538 ///
539 /// # Returns
540 ///
541 /// Independent probabilities (do NOT sum to 1.0)
542 ///
543 /// # Example
544 ///
545 /// ```
546 /// use llm_shield_models::InferenceEngine;
547 ///
548 /// let logits = vec![0.0, 2.0, -2.0];
549 /// let probs = InferenceEngine::sigmoid_static(&logits);
550 ///
551 /// // sigmoid(0) ≈ 0.5
552 /// assert!((probs[0] - 0.5).abs() < 0.01);
553 ///
554 /// // All probabilities in [0, 1]
555 /// for p in probs {
556 /// assert!(p >= 0.0 && p <= 1.0);
557 /// }
558 /// ```
559 pub fn sigmoid_static(logits: &[f32]) -> Vec<f32> {
560 logits
561 .iter()
562 .map(|&x| 1.0 / (1.0 + (-x).exp()))
563 .collect()
564 }
565
566 /// Apply softmax to logits (instance method)
567 #[allow(dead_code)]
568 fn softmax(&self, logits: &[f32]) -> Vec<f32> {
569 Self::softmax_static(logits)
570 }
571
572 /// Apply sigmoid to logits (instance method)
573 #[allow(dead_code)]
574 fn sigmoid(&self, logits: &[f32]) -> Vec<f32> {
575 Self::sigmoid_static(logits)
576 }
577
578 /// Run token-level classification inference (for NER/token classification)
579 ///
580 /// # Arguments
581 ///
582 /// * `input_ids` - Token IDs from tokenizer
583 /// * `attention_mask` - Attention mask (1=real token, 0=padding)
584 /// * `labels` - BIO tag labels (e.g., ["O", "B-PERSON", "I-PERSON", ...])
585 ///
586 /// # Returns
587 ///
588 /// Vector of predictions, one per input token
589 ///
590 /// # Errors
591 ///
592 /// - Model inference failure
593 /// - Invalid tensor shapes
594 /// - Label count mismatch
595 ///
596 /// # Example
597 ///
598 /// ```rust,ignore
599 /// use llm_shield_models::InferenceEngine;
600 ///
601 /// let engine = InferenceEngine::new(session);
602 /// let labels = vec!["O", "B-PERSON", "I-PERSON"];
603 ///
604 /// let predictions = engine.infer_token_classification(
605 /// &input_ids,
606 /// &attention_mask,
607 /// &labels
608 /// ).await?;
609 ///
610 /// for pred in predictions {
611 /// println!("{}: {:.2}", pred.predicted_label, pred.confidence);
612 /// }
613 /// ```
614 pub async fn infer_token_classification(
615 &self,
616 input_ids: &[u32],
617 attention_mask: &[u32],
618 labels: &[String],
619 ) -> crate::Result<Vec<TokenPrediction>> {
620 // Validation
621 if input_ids.is_empty() {
622 return Err(Error::model("input_ids cannot be empty"));
623 }
624 if input_ids.len() != attention_mask.len() {
625 return Err(Error::model(format!(
626 "input_ids length ({}) != attention_mask length ({})",
627 input_ids.len(),
628 attention_mask.len()
629 )));
630 }
631 if labels.is_empty() {
632 return Err(Error::model("labels cannot be empty"));
633 }
634
635 // Run inference in blocking thread pool to avoid blocking async runtime
636 let session = Arc::clone(&self.session);
637 let input_ids = input_ids.to_vec();
638 let attention_mask = attention_mask.to_vec();
639 let labels = labels.to_vec();
640
641 tokio::task::spawn_blocking(move || {
642 let mut session_guard = session.lock()
643 .map_err(|e| Error::model(format!("Failed to lock session: {}", e)))?;
644 Self::infer_token_classification_sync(
645 &mut *session_guard,
646 &input_ids,
647 &attention_mask,
648 &labels
649 )
650 })
651 .await
652 .map_err(|e| Error::model(format!("Async inference task failed: {}", e)))?
653 }
654
655 /// Internal synchronous token classification implementation
656 fn infer_token_classification_sync(
657 session: &mut Session,
658 input_ids: &[u32],
659 attention_mask: &[u32],
660 labels: &[String],
661 ) -> crate::Result<Vec<TokenPrediction>> {
662 // Convert to i64 for ONNX
663 let input_ids_i64: Vec<i64> = input_ids.iter().map(|&x| x as i64).collect();
664 let attention_mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();
665
666 let batch_size = 1;
667 let seq_length = input_ids.len();
668
669 // Create input arrays [batch_size, seq_length]
670 let input_ids_array =
671 Array2::from_shape_vec((batch_size, seq_length), input_ids_i64)
672 .map_err(|e| Error::model(format!("Failed to create input array: {}", e)))?;
673
674 let attention_mask_array =
675 Array2::from_shape_vec((batch_size, seq_length), attention_mask_i64)
676 .map_err(|e| Error::model(format!("Failed to create attention mask array: {}", e)))?;
677
678 // Create ONNX values
679 let input_ids_value = ort::value::Value::from_array(input_ids_array)
680 .map_err(|e| Error::model(format!("Failed to create input_ids value: {}", e)))?;
681 let attention_mask_value = ort::value::Value::from_array(attention_mask_array)
682 .map_err(|e| Error::model(format!("Failed to create attention_mask value: {}", e)))?;
683
684 // Run inference
685 let outputs = session
686 .run(ort::inputs![
687 "input_ids" => input_ids_value,
688 "attention_mask" => attention_mask_value,
689 ])
690 .map_err(|e| Error::model(format!("Inference failed: {}", e)))?;
691
692 // Extract logits [batch_size, seq_length, num_labels]
693 let logits = outputs["logits"]
694 .try_extract_tensor::<f32>()
695 .map_err(|e| Error::model(format!("Failed to extract logits: {}", e)))?;
696
697 // logits is (shape, data)
698 let (shape, data) = logits;
699
700 // Validate shape: should be [batch_size, seq_length, num_labels]
701 if shape.len() != 3 {
702 return Err(Error::model(format!(
703 "Expected 3D logits tensor, got shape with {} dimensions",
704 shape.len()
705 )));
706 }
707
708 let actual_batch = shape[0] as usize;
709 let actual_seq_len = shape[1] as usize;
710 let num_labels = shape[2] as usize;
711
712 if actual_batch != batch_size {
713 return Err(Error::model(format!(
714 "Batch size mismatch: expected {}, got {}",
715 batch_size, actual_batch
716 )));
717 }
718
719 if actual_seq_len != seq_length {
720 return Err(Error::model(format!(
721 "Sequence length mismatch: expected {}, got {}",
722 seq_length, actual_seq_len
723 )));
724 }
725
726 if num_labels != labels.len() {
727 return Err(Error::model(format!(
728 "Label count mismatch: model has {} labels, provided {}",
729 num_labels,
730 labels.len()
731 )));
732 }
733
734 // Process each token
735 let mut predictions = Vec::with_capacity(seq_length);
736
737 for token_idx in 0..seq_length {
738 // Extract logits for this token
739 let start_idx = token_idx * num_labels;
740 let end_idx = start_idx + num_labels;
741 let token_logits: Vec<f32> = data[start_idx..end_idx].to_vec();
742
743 // Apply softmax
744 let scores = Self::softmax_static(&token_logits);
745
746 // Find predicted class
747 let (predicted_class, max_score) = scores
748 .iter()
749 .enumerate()
750 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
751 .map(|(idx, &score)| (idx, score))
752 .unwrap_or((0, 0.0));
753
754 let predicted_label = labels[predicted_class].clone();
755
756 predictions.push(TokenPrediction::new(
757 input_ids[token_idx],
758 predicted_label,
759 predicted_class,
760 max_score,
761 scores,
762 ));
763 }
764
765 Ok(predictions)
766 }
767}
768
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_inference_result_predicted_label() {
        let result = InferenceResult {
            labels: vec!["safe".to_string(), "unsafe".to_string()],
            scores: vec![0.8, 0.2],
            predicted_class: 0,
            max_score: 0.8,
        };

        assert_eq!(result.predicted_label(), Some("safe"));
        assert!(result.exceeds_threshold(0.7));
        assert!(!result.exceeds_threshold(0.9));
    }

    #[test]
    fn test_get_score_for_label_and_is_binary() {
        let result = InferenceResult {
            labels: vec!["safe".to_string(), "unsafe".to_string()],
            scores: vec![0.8, 0.2],
            predicted_class: 0,
            max_score: 0.8,
        };

        assert_eq!(result.get_score_for_label("unsafe"), Some(0.2));
        assert_eq!(result.get_score_for_label("missing"), None);
        assert!(result.is_binary());
    }

    #[test]
    fn test_softmax_static_properties() {
        let probs = InferenceEngine::softmax_static(&[1.0, 2.0, 0.5]);

        // Probabilities sum to 1.0 and preserve the logit ordering.
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-4);
        assert!(probs[1] > probs[0] && probs[0] > probs[2]);

        // Empty input yields empty output rather than panicking.
        assert!(InferenceEngine::softmax_static(&[]).is_empty());
    }

    #[test]
    fn test_sigmoid_static_properties() {
        let probs = InferenceEngine::sigmoid_static(&[0.0, 2.0, -2.0]);

        // sigmoid(0) == 0.5; every output stays in [0, 1].
        assert!((probs[0] - 0.5).abs() < 1e-6);
        assert!(probs.iter().all(|p| (0.0..=1.0).contains(p)));
    }
}
793}