use llm_shield_models::{
InferenceEngine, TokenPrediction, TokenizerWrapper, TokenizerConfig,
};
use std::sync::{Arc, Mutex};
#[tokio::test]
#[ignore] async fn test_token_classification_basic() {
let text = "John Smith works at Microsoft";
}
/// A well-formed prediction must pass validation: confidence matches the
/// score of its own class, and the class scores sum to 1.0.
#[test]
fn test_token_prediction_invariants() {
    let class_scores = vec![0.03, 0.95, 0.02];
    let prediction =
        TokenPrediction::new(101, "B-PERSON".to_string(), 1, 0.95, class_scores);
    assert!(prediction.validate().is_ok());
}
/// Validation should reject a confidence outside the [0, 1] range
/// (here 1.5), even when the per-class scores themselves are sane.
#[test]
fn test_token_prediction_invalid_confidence() {
    let out_of_range_confidence = 1.5;
    let prediction = TokenPrediction::new(
        101,
        "B-PERSON".to_string(),
        1,
        out_of_range_confidence,
        vec![0.03, 0.95, 0.02],
    );
    assert!(prediction.validate().is_err());
}
/// Validation should reject a prediction whose confidence (0.80) disagrees
/// with the score reported for its own class index (scores[1] == 0.95).
#[test]
fn test_token_prediction_confidence_mismatch() {
    let mismatched_confidence = 0.80;
    let prediction = TokenPrediction::new(
        101,
        "B-PERSON".to_string(),
        1,
        mismatched_confidence,
        vec![0.03, 0.95, 0.02],
    );
    assert!(prediction.validate().is_err());
}
/// Validation should reject class scores that do not form a probability
/// distribution: 0.05 + 0.95 + 0.05 = 1.05.
#[test]
fn test_token_prediction_scores_dont_sum_to_one() {
    let non_normalized_scores = vec![0.05, 0.95, 0.05];
    let prediction = TokenPrediction::new(
        101,
        "B-PERSON".to_string(),
        1,
        0.95,
        non_normalized_scores,
    );
    assert!(prediction.validate().is_err());
}
/// Validation should reject a class index (10) that is out of bounds for
/// the three-element score vector.
#[test]
fn test_token_prediction_invalid_class_index() {
    let out_of_bounds_class = 10;
    let prediction = TokenPrediction::new(
        101,
        "B-PERSON".to_string(),
        out_of_bounds_class,
        0.95,
        vec![0.03, 0.95, 0.02],
    );
    assert!(prediction.validate().is_err());
}