use llm_shield_models::{TokenizerWrapper, TokenizerConfig};
#[test]
fn test_tokenizer_config_default() {
    // Default config: 512-token limit with padding, truncation and
    // special tokens all switched on.
    let defaults = TokenizerConfig::default();
    assert!(defaults.padding);
    assert!(defaults.truncation);
    assert!(defaults.add_special_tokens);
    assert_eq!(defaults.max_length, 512);
}
#[test]
fn test_tokenizer_config_custom() {
    // A fully customized config should round-trip every field as given.
    let custom = TokenizerConfig {
        add_special_tokens: false,
        truncation: false,
        padding: false,
        max_length: 256,
    };
    assert_eq!(custom.max_length, 256);
    assert!(!custom.padding);
    assert!(!custom.truncation);
    assert!(!custom.add_special_tokens);
}
#[test]
fn test_tokenizer_from_pretrained_deberta() {
    // Loading the DeBERTa v3 base tokenizer with defaults must succeed.
    let model = "microsoft/deberta-v3-base";
    let result = TokenizerWrapper::from_pretrained(model, TokenizerConfig::default());
    assert!(result.is_ok(), "Failed to load DeBERTa tokenizer: {:?}", result.err());
}
#[test]
fn test_tokenizer_from_pretrained_roberta() {
    // Loading the RoBERTa base tokenizer with defaults must succeed.
    let model = "roberta-base";
    let result = TokenizerWrapper::from_pretrained(model, TokenizerConfig::default());
    assert!(result.is_ok(), "Failed to load RoBERTa tokenizer: {:?}", result.err());
}
#[test]
fn test_tokenizer_from_invalid_model() {
    // A model id that does not exist must surface an error, not panic.
    let outcome = TokenizerWrapper::from_pretrained(
        "invalid/nonexistent-model",
        TokenizerConfig::default(),
    );
    assert!(matches!(outcome, Err(_)));
}
#[test]
fn test_encode_simple_text() {
    // Disable padding so the token count reflects the actual text.
    // With the default config (padding: true), input_ids is padded out to
    // max_length (512) — see test_padding_to_max_length — which would make
    // the `< 20` bound below impossible to satisfy.
    let config = TokenizerConfig {
        padding: false,
        ..Default::default()
    };
    let tokenizer = TokenizerWrapper::from_pretrained(
        "microsoft/deberta-v3-base",
        config,
    ).unwrap();
    let text = "Hello, world!";
    let encoding = tokenizer.encode(text).unwrap();
    assert!(!encoding.input_ids.is_empty());
    // input_ids and attention_mask must stay in lockstep.
    assert_eq!(encoding.input_ids.len(), encoding.attention_mask.len());
    // A short sentence should produce only a handful of tokens.
    assert!(encoding.input_ids.len() < 20);
}
#[test]
fn test_encode_empty_string() {
    // Even empty input must yield at least one token (e.g. special tokens).
    let tokenizer = TokenizerWrapper::from_pretrained(
        "microsoft/deberta-v3-base",
        TokenizerConfig::default(),
    ).unwrap();
    let empty_encoding = tokenizer.encode("").unwrap();
    assert!(!empty_encoding.input_ids.is_empty());
}
#[test]
fn test_truncation_at_max_length() {
    // With truncation on, no encoding may exceed max_length tokens.
    let cap = 512;
    let tokenizer = TokenizerWrapper::from_pretrained(
        "microsoft/deberta-v3-base",
        TokenizerConfig {
            max_length: cap,
            truncation: true,
            ..Default::default()
        },
    ).unwrap();
    let oversized_input = "This is a test sentence. ".repeat(1000);
    let encoding = tokenizer.encode(&oversized_input).unwrap();
    assert!(encoding.input_ids.len() <= cap);
}
#[test]
fn test_truncation_disabled() {
    // With truncation off, encoding a moderately long text still succeeds.
    let no_truncation = TokenizerConfig {
        max_length: 512,
        truncation: false,
        ..Default::default()
    };
    let tokenizer = TokenizerWrapper::from_pretrained(
        "microsoft/deberta-v3-base",
        no_truncation,
    ).unwrap();
    let repeated = "This is a test sentence. ".repeat(100);
    let encoding = tokenizer.encode(&repeated).unwrap();
    assert!(!encoding.input_ids.is_empty());
}
#[test]
fn test_padding_to_max_length() {
    // A short input with padding on must be padded out to exactly
    // max_length, with zeros in the attention mask for the pad positions.
    let tokenizer = TokenizerWrapper::from_pretrained(
        "microsoft/deberta-v3-base",
        TokenizerConfig {
            max_length: 512,
            padding: true,
            truncation: true,
            ..Default::default()
        },
    ).unwrap();
    let encoding = tokenizer.encode("Hello").unwrap();
    assert_eq!(encoding.input_ids.len(), 512);
    assert_eq!(encoding.attention_mask.len(), 512);
    let pad_positions = encoding
        .attention_mask
        .iter()
        .filter(|&&mask_value| mask_value == 0)
        .count();
    assert!(pad_positions > 0, "Expected padding tokens");
}
#[test]
fn test_padding_disabled() {
    // Without padding, a short input must stay well under max_length.
    let unpadded = TokenizerConfig {
        max_length: 512,
        padding: false,
        truncation: true,
        ..Default::default()
    };
    let tokenizer = TokenizerWrapper::from_pretrained(
        "microsoft/deberta-v3-base",
        unpadded,
    ).unwrap();
    let encoding = tokenizer.encode("Hello").unwrap();
    assert!(encoding.input_ids.len() < 512);
}
#[test]
fn test_special_tokens_added() {
    // With special tokens on, a single word yields at least 3 tokens
    // (the word plus surrounding special tokens).
    let with_specials = TokenizerConfig {
        add_special_tokens: true,
        padding: false,
        ..Default::default()
    };
    let tokenizer = TokenizerWrapper::from_pretrained(
        "microsoft/deberta-v3-base",
        with_specials,
    ).unwrap();
    let encoding = tokenizer.encode("test").unwrap();
    assert!(encoding.input_ids.len() >= 3);
}
#[test]
fn test_special_tokens_disabled() {
    // With special tokens off, encoding still produces content tokens.
    let tokenizer = TokenizerWrapper::from_pretrained(
        "microsoft/deberta-v3-base",
        TokenizerConfig {
            max_length: 512,
            truncation: true,
            padding: false,
            add_special_tokens: false,
        },
    ).unwrap();
    let encoding = tokenizer.encode("test").unwrap();
    assert!(!encoding.input_ids.is_empty());
}
#[test]
fn test_encode_batch() {
    // Batch encoding returns one encoding per input, all padded to a
    // common length under the default (padding: true) config.
    let tokenizer = TokenizerWrapper::from_pretrained(
        "microsoft/deberta-v3-base",
        TokenizerConfig::default(),
    ).unwrap();
    let inputs = vec![
        "First sentence",
        "Second sentence is longer",
        "Third",
    ];
    let batch = tokenizer.encode_batch(&inputs).unwrap();
    assert_eq!(batch.len(), 3);
    let expected_len = batch[0].input_ids.len();
    assert!(batch.iter().all(|e| e.input_ids.len() == expected_len));
}
#[test]
fn test_encode_batch_empty() {
    // An empty batch must succeed and yield zero encodings.
    let tokenizer = TokenizerWrapper::from_pretrained(
        "microsoft/deberta-v3-base",
        TokenizerConfig::default(),
    ).unwrap();
    let no_inputs: Vec<&str> = vec![];
    let batch = tokenizer.encode_batch(&no_inputs).unwrap();
    assert_eq!(batch.len(), 0);
}
#[test]
fn test_tokenizer_thread_safety() {
    // Encode concurrently from four threads sharing one Arc'd tokenizer;
    // every thread must succeed and return a non-empty encoding.
    use std::sync::Arc;
    use std::thread;
    let shared = Arc::new(
        TokenizerWrapper::from_pretrained(
            "microsoft/deberta-v3-base",
            TokenizerConfig::default(),
        ).unwrap()
    );
    let workers: Vec<_> = (0..4)
        .map(|i| {
            let tokenizer = Arc::clone(&shared);
            thread::spawn(move || {
                let text = format!("Test sentence number {}", i);
                tokenizer.encode(&text).unwrap()
            })
        })
        .collect();
    for worker in workers {
        let encoding = worker.join().unwrap();
        assert!(!encoding.input_ids.is_empty());
    }
}
#[test]
fn test_encode_unicode() {
    // Non-ASCII scripts and emoji must all encode to a non-empty sequence.
    let tokenizer = TokenizerWrapper::from_pretrained(
        "microsoft/deberta-v3-base",
        TokenizerConfig::default(),
    ).unwrap();
    // The Arabic entry was previously mojibake ("Ù…Ø±ØØ¨Ø§" — UTF-8 bytes
    // of "مرحبا" misdecoded as Latin-1); restored to the intended text.
    let texts = vec![
        "Hello 世界", "Привет мир", "مرحبا", "🚀 Emoji test 🎉",
    ];
    for text in texts {
        let encoding = tokenizer.encode(text).unwrap();
        assert!(!encoding.input_ids.is_empty(), "Failed to encode: {}", text);
    }
}
#[test]
fn test_encode_special_characters() {
    // Whitespace control characters (newlines, tabs) must encode cleanly.
    let tokenizer = TokenizerWrapper::from_pretrained(
        "microsoft/deberta-v3-base",
        TokenizerConfig::default(),
    ).unwrap();
    let whitespace_heavy = "Test with\nnewlines\tand\ttabs";
    let encoding = tokenizer.encode(whitespace_heavy).unwrap();
    assert!(!encoding.input_ids.is_empty());
}
#[test]
fn test_encoding_properties() {
    // Basic invariants of an encoding: both vectors non-empty and of
    // equal length.
    let tokenizer = TokenizerWrapper::from_pretrained(
        "microsoft/deberta-v3-base",
        TokenizerConfig::default(),
    ).unwrap();
    let encoding = tokenizer.encode("This is a test sentence").unwrap();
    let ids = &encoding.input_ids;
    let mask = &encoding.attention_mask;
    assert!(!ids.is_empty());
    assert!(!mask.is_empty());
    assert_eq!(ids.len(), mask.len());
}
#[test]
fn test_different_tokenizer_types() {
    // Both the DeBERTa and RoBERTa tokenizer families must load and encode.
    let models = vec!["microsoft/deberta-v3-base", "roberta-base"];
    for model_name in models {
        let loaded = TokenizerWrapper::from_pretrained(
            model_name,
            TokenizerConfig::default(),
        );
        assert!(loaded.is_ok(), "Failed to load tokenizer: {}", model_name);
        let encoding = loaded.unwrap().encode("test").unwrap();
        assert!(!encoding.input_ids.is_empty());
    }
}
#[test]
fn test_config_validation() {
    // Several valid config permutations should all yield a working
    // tokenizer. Tuples: (max_length, padding, truncation, add_special_tokens).
    let cases = [
        (128, true, true, true),
        (256, false, true, true),
        (512, true, false, false),
    ];
    for (max_length, padding, truncation, add_special_tokens) in cases {
        let tokenizer = TokenizerWrapper::from_pretrained(
            "microsoft/deberta-v3-base",
            TokenizerConfig {
                max_length,
                padding,
                truncation,
                add_special_tokens,
            },
        ).unwrap();
        let encoding = tokenizer.encode("test").unwrap();
        assert!(!encoding.input_ids.is_empty());
    }
}
#[test]
fn test_very_long_text() {
    // A very long input under the default config (padding + truncation on)
    // must come out at exactly max_length (512) tokens.
    let tokenizer = TokenizerWrapper::from_pretrained(
        "microsoft/deberta-v3-base",
        TokenizerConfig::default(),
    ).unwrap();
    let huge_input = "word ".repeat(5000);
    let encoding = tokenizer.encode(&huge_input).unwrap();
    assert_eq!(encoding.input_ids.len(), 512);
}
#[test]
fn test_encoding_to_vec() {
    // Cloning the id and mask vectors out of an encoding preserves
    // non-emptiness and matching lengths.
    let tokenizer = TokenizerWrapper::from_pretrained(
        "microsoft/deberta-v3-base",
        TokenizerConfig::default(),
    ).unwrap();
    let encoding = tokenizer.encode("test").unwrap();
    let cloned_ids = encoding.input_ids.clone();
    let cloned_mask = encoding.attention_mask.clone();
    assert!(!cloned_ids.is_empty());
    assert_eq!(cloned_ids.len(), cloned_mask.len());
}
#[test]
fn test_edge_cases() {
    // Degenerate one-off inputs (single char, lone space, punctuation,
    // digits) must each produce at least one token.
    let tokenizer = TokenizerWrapper::from_pretrained(
        "microsoft/deberta-v3-base",
        TokenizerConfig::default(),
    ).unwrap();
    let inputs = vec!["a", " ", "!@#$%^&*()", "1234567890"];
    for input in inputs {
        let encoding = tokenizer.encode(input).unwrap();
        assert!(!encoding.input_ids.is_empty());
    }
}