//! Tokenizer loading and text encoding built on the Hugging Face `tokenizers` crate.

use llm_shield_core::Error;
use std::sync::Arc;
use tokenizers::{
    PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams,
    TruncationStrategy,
};

pub type Result<T> = std::result::Result<T, Error>;

/// Configuration for padding, truncation, and special-token handling.
#[derive(Debug, Clone)]
pub struct TokenizerConfig {
    /// Maximum sequence length in tokens; also the fixed padding length.
    pub max_length: usize,
    /// Pad every sequence to `max_length`.
    pub padding: bool,
    /// Truncate sequences longer than `max_length`.
    pub truncation: bool,
    /// Add model-specific special tokens (e.g. `[CLS]`/`[SEP]`) when encoding.
    pub add_special_tokens: bool,
}

impl Default for TokenizerConfig {
    fn default() -> Self {
        Self {
            max_length: 512,
            padding: true,
            truncation: true,
            add_special_tokens: true,
        }
    }
}

/// A tokenized sequence: token ids, attention mask, and per-token offsets
/// mapping each token back into the source text.
#[derive(Debug, Clone)]
pub struct Encoding {
    pub input_ids: Vec<u32>,
    pub attention_mask: Vec<u32>,
    pub offsets: Vec<(usize, usize)>,
}

impl Encoding {
    /// Builds an encoding with zeroed offsets, for callers that do not need
    /// offset information.
    pub fn new(input_ids: Vec<u32>, attention_mask: Vec<u32>) -> Self {
        let offsets = vec![(0, 0); input_ids.len()];
        Self {
            input_ids,
            attention_mask,
            offsets,
        }
    }

    /// Builds an encoding with explicit token offsets.
    pub fn with_offsets(
        input_ids: Vec<u32>,
        attention_mask: Vec<u32>,
        offsets: Vec<(usize, usize)>,
    ) -> Self {
        Self {
            input_ids,
            attention_mask,
            offsets,
        }
    }

    /// Number of tokens in the sequence.
    pub fn len(&self) -> usize {
        self.input_ids.len()
    }

    pub fn is_empty(&self) -> bool {
        self.input_ids.is_empty()
    }

    /// Converts ids and mask to `i64`, the integer width commonly expected by
    /// ONNX transformer inputs.
    pub fn to_arrays(&self) -> (Vec<i64>, Vec<i64>) {
        let input_ids = self.input_ids.iter().map(|&x| x as i64).collect();
        let attention_mask = self.attention_mask.iter().map(|&x| x as i64).collect();
        (input_ids, attention_mask)
    }
}
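
// Both `encode` and `encode_batch` below convert the raw encoding returned by
// the `tokenizers` crate in exactly the same way; a `From` impl keeps that
// conversion in one place. It uses only accessors already relied on above
// (`get_ids`, `get_attention_mask`, `get_offsets`).
impl From<tokenizers::Encoding> for Encoding {
    fn from(enc: tokenizers::Encoding) -> Self {
        Self::with_offsets(
            enc.get_ids().to_vec(),
            enc.get_attention_mask().to_vec(),
            enc.get_offsets().to_vec(),
        )
    }
}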

/// Thread-safe wrapper around a Hugging Face tokenizer. Cloning is cheap:
/// the underlying tokenizer is shared behind an `Arc`, so `Clone` can be
/// derived rather than hand-written.
#[derive(Clone)]
pub struct TokenizerWrapper {
    tokenizer: Arc<Tokenizer>,
    config: TokenizerConfig,
}

impl TokenizerWrapper {
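    /// Loads a tokenizer from `models/{model_name}/tokenizer.json` on the
    /// local filesystem (despite the `from_pretrained` name, nothing is
    /// downloaded) and applies the padding/truncation settings from `config`.
    ///
    /// A minimal usage sketch; the `"bert-base-uncased"` directory name is an
    /// illustrative assumption, not a bundled model:
    ///
    /// ```ignore
    /// let wrapper = TokenizerWrapper::from_pretrained(
    ///     "bert-base-uncased",
    ///     TokenizerConfig::default(),
    /// )?;
    /// let encoding = wrapper.encode("Hello, world!")?;
    /// let (input_ids, attention_mask) = encoding.to_arrays();
    /// ```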
    pub fn from_pretrained(model_name: &str, config: TokenizerConfig) -> Result<Self> {
        tracing::info!("Loading tokenizer from: {}", model_name);

        let tokenizer_path = format!("models/{}/tokenizer.json", model_name);
        if !std::path::Path::new(&tokenizer_path).exists() {
            return Err(Error::model(format!(
                "Tokenizer not found at '{}'. Please download the tokenizer files first.",
                tokenizer_path
            )));
        }

        let mut tokenizer = Tokenizer::from_file(&tokenizer_path).map_err(|e| {
            Error::model(format!(
                "Failed to load tokenizer from '{}': {}",
                tokenizer_path, e
            ))
        })?;

        if config.padding {
            // NOTE: pad_id 0 and "[PAD]" match BERT-style vocabularies; models
            // with a different pad token need these adjusted.
            let padding = PaddingParams {
                strategy: PaddingStrategy::Fixed(config.max_length),
                direction: PaddingDirection::Right,
                pad_to_multiple_of: None,
                pad_id: 0,
                pad_type_id: 0,
                pad_token: String::from("[PAD]"),
            };
            tokenizer.with_padding(Some(padding));
        }

        if config.truncation {
            let truncation = TruncationParams {
                max_length: config.max_length,
                strategy: TruncationStrategy::LongestFirst,
                stride: 0,
                direction: tokenizers::TruncationDirection::Right,
            };
            tokenizer
                .with_truncation(Some(truncation))
                .map_err(|e| Error::model(format!("Failed to configure truncation: {}", e)))?;
        }

        tracing::debug!(
            "Tokenizer loaded: max_length={}, padding={}, truncation={}",
            config.max_length,
            config.padding,
            config.truncation
        );

        Ok(Self {
            tokenizer: Arc::new(tokenizer),
            config,
        })
    }

    /// Encodes a single text into token ids, attention mask, and offsets.
    pub fn encode(&self, text: &str) -> Result<Encoding> {
        let encoding = self
            .tokenizer
            .encode(text, self.config.add_special_tokens)
            .map_err(|e| Error::model(format!("Failed to encode text: {}", e)))?;
        Ok(encoding.into())
    }

    /// Encodes a batch of texts in one call, returning one `Encoding` per
    /// input, in order.
    pub fn encode_batch(&self, texts: &[&str]) -> Result<Vec<Encoding>> {
        if texts.is_empty() {
            return Ok(vec![]);
        }
        let encodings = self
            .tokenizer
            .encode_batch(texts.to_vec(), self.config.add_special_tokens)
            .map_err(|e| Error::model(format!("Failed to encode batch: {}", e)))?;
        Ok(encodings.into_iter().map(Encoding::from).collect())
    }

    /// Returns the tokenizer configuration.
    pub fn config(&self) -> &TokenizerConfig {
        &self.config
    }

    /// Returns the vocabulary size, counting added tokens. The flag passed to
    /// `get_vocab_size` is `with_added_tokens`, which is unrelated to the
    /// `add_special_tokens` encoding option this config carries.
    pub fn vocab_size(&self) -> usize {
        self.tokenizer.get_vocab_size(true)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenizer_config_default() {
        let config = TokenizerConfig::default();
        assert_eq!(config.max_length, 512);
        assert!(config.padding);
        assert!(config.truncation);
        assert!(config.add_special_tokens);
    }

    #[test]
    fn test_encoding_creation() {
        let encoding = Encoding::new(vec![101, 2023, 2003, 102], vec![1, 1, 1, 1]);
        assert_eq!(encoding.len(), 4);
        assert!(!encoding.is_empty());
    }

    #[test]
    fn test_encoding_to_arrays() {
        let encoding = Encoding::new(vec![101, 2023, 102], vec![1, 1, 1]);
        let (input_ids, attention_mask) = encoding.to_arrays();
        assert_eq!(input_ids, vec![101i64, 2023, 102]);
        assert_eq!(attention_mask, vec![1i64, 1, 1]);
    }

    #[test]
    fn test_encoding_empty() {
        let encoding = Encoding::new(vec![], vec![]);
        assert!(encoding.is_empty());
        assert_eq!(encoding.len(), 0);
    }
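
    // Additional coverage for the offset-handling constructors defined above:
    // `with_offsets` should preserve the offsets it is given, and `new` should
    // zero-fill one offset pair per token.
    #[test]
    fn test_encoding_with_offsets() {
        let encoding = Encoding::with_offsets(
            vec![101, 2023, 102],
            vec![1, 1, 1],
            vec![(0, 0), (0, 4), (0, 0)],
        );
        assert_eq!(encoding.len(), 3);
        assert_eq!(encoding.offsets[1], (0, 4));
    }

    #[test]
    fn test_encoding_new_zero_fills_offsets() {
        let encoding = Encoding::new(vec![101, 102], vec![1, 1]);
        assert_eq!(encoding.offsets, vec![(0, 0), (0, 0)]);
    }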
}