pub struct InferenceEngine { /* private fields */ }
Expand description
Inference engine for running ONNX model inference
§Features
- Synchronous and asynchronous inference
- Binary and multi-label classification
- Automatic post-processing (softmax/sigmoid)
- Batch inference support (optional)
§Example
ⓘ
use llm_shield_models::InferenceEngine;
use std::sync::Arc;
let engine = InferenceEngine::new(session);
// Run inference
let result = engine.infer_async(
&input_ids,
&attention_mask,
&labels,
PostProcessing::Softmax,
).await?;
println!("Predicted: {}", result.predicted_label().unwrap());
println!("Confidence: {:.2}", result.max_score);
Implementations§
Source§impl InferenceEngine
impl InferenceEngine
Source
pub fn new(session: Arc<Mutex<Session>>) -> Self
pub fn new(session: Arc<Mutex<Session>>) -> Self
Create a new inference engine
§Arguments
session - ONNX Runtime session wrapped in Arc<Mutex<Session>> for thread-safe mutable access
Source
pub async fn infer_async(
&self,
input_ids: &[u32],
attention_mask: &[u32],
labels: &[String],
post_processing: PostProcessing,
) -> Result<InferenceResult>
pub async fn infer_async( &self, input_ids: &[u32], attention_mask: &[u32], labels: &[String], post_processing: PostProcessing, ) -> Result<InferenceResult>
Source
pub fn infer(
&self,
input_ids: &[u32],
attention_mask: &[u32],
labels: &[String],
post_processing: PostProcessing,
) -> Result<InferenceResult>
pub fn infer( &self, input_ids: &[u32], attention_mask: &[u32], labels: &[String], post_processing: PostProcessing, ) -> Result<InferenceResult>
Source
pub fn softmax_static(logits: &[f32]) -> Vec<f32>
pub fn softmax_static(logits: &[f32]) -> Vec<f32>
Apply softmax to logits (static method)
Softmax converts logits to probabilities that sum to 1.0. Used for single-label classification (mutually exclusive classes).
§Arguments
logits - Raw model output logits
§Returns
Probability distribution (sums to 1.0)
§Example
use llm_shield_models::InferenceEngine;
let logits = vec![1.0, 2.0, 0.5];
let probs = InferenceEngine::softmax_static(&logits);
// Probabilities sum to 1.0
let sum: f32 = probs.iter().sum();
assert!((sum - 1.0).abs() < 0.001);
Source
pub fn sigmoid_static(logits: &[f32]) -> Vec<f32>
pub fn sigmoid_static(logits: &[f32]) -> Vec<f32>
Apply sigmoid to logits (static method)
Sigmoid converts each logit independently to [0, 1]. Used for multi-label classification (non-exclusive classes).
§Arguments
logits - Raw model output logits
§Returns
Independent probabilities (do NOT sum to 1.0)
§Example
use llm_shield_models::InferenceEngine;
let logits = vec![0.0, 2.0, -2.0];
let probs = InferenceEngine::sigmoid_static(&logits);
// sigmoid(0) ≈ 0.5
assert!((probs[0] - 0.5).abs() < 0.01);
// All probabilities in [0, 1]
for p in probs {
assert!(p >= 0.0 && p <= 1.0);
}
Source
pub async fn infer_token_classification(
&self,
input_ids: &[u32],
attention_mask: &[u32],
labels: &[String],
) -> Result<Vec<TokenPrediction>>
pub async fn infer_token_classification( &self, input_ids: &[u32], attention_mask: &[u32], labels: &[String], ) -> Result<Vec<TokenPrediction>>
Run token-level classification inference (for NER/token classification)
§Arguments
- input_ids - Token IDs from tokenizer
- attention_mask - Attention mask (1 = real token, 0 = padding)
- labels - BIO tag labels (e.g., ["O", "B-PERSON", "I-PERSON", ...])
§Returns
Vector of predictions, one per input token
§Errors
- Model inference failure
- Invalid tensor shapes
- Label count mismatch
§Example
ⓘ
use llm_shield_models::InferenceEngine;
let engine = InferenceEngine::new(session);
let labels: Vec<String> = ["O", "B-PERSON", "I-PERSON"].iter().map(|s| s.to_string()).collect();
let predictions = engine.infer_token_classification(
&input_ids,
&attention_mask,
&labels
).await?;
for pred in predictions {
println!("{}: {:.2}", pred.predicted_label, pred.confidence);
}
Auto Trait Implementations§
impl Freeze for InferenceEngine
impl RefUnwindSafe for InferenceEngine
impl Send for InferenceEngine
impl Sync for InferenceEngine
impl Unpin for InferenceEngine
impl UnwindSafe for InferenceEngine
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T where
T: ?Sized,
impl<T> BorrowMut<T> for T where
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more
Source§impl<T> Instrument for T
impl<T> Instrument for T
Source§fn instrument(self, span: Span) -> Instrumented<Self>
fn instrument(self, span: Span) -> Instrumented<Self>
Source§fn in_current_span(self) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more