use crate::error::{MemvidError, Result};
use serde::{Deserialize, Serialize};
use std::path::Path;
use tokenizers::Tokenizer;
use unicode_normalization::UnicodeNormalization;
/// Configuration controlling how [`TextProcessor`] preprocesses and tokenizes text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextConfig {
/// Target sequence length. Shorter outputs are padded with zeros up to this length;
/// longer outputs are cut to it when `truncate` is set.
pub max_length: usize,
/// Whether sequences longer than `max_length` are truncated. When false, padding still
/// applies but over-long sequences are left at their full length.
pub truncate: bool,
/// Whether to add special tokens. This flag is forwarded to the loaded tokenizer's
/// `encode`. In the hash-based fallback path it wraps the sequence in IDs 101/102
/// (presumably BERT's [CLS]/[SEP] — confirm against the target model).
pub add_special_tokens: bool,
/// Whether to apply Unicode NFC normalization before tokenizing.
pub normalize_unicode: bool,
/// Whether to lowercase text before tokenizing.
pub lowercase: bool,
}
impl Default for TextConfig {
fn default() -> Self {
Self {
max_length: 384,
truncate: true,
add_special_tokens: true,
normalize_unicode: true,
lowercase: false, }
}
}
/// Output of [`TextProcessor`] tokenization, after padding/truncation to the configured length.
#[derive(Debug, Clone)]
pub struct TokenizedText {
/// Token IDs; padded positions hold 0.
pub input_ids: Vec<u32>,
/// Attention mask aligned with `input_ids`; padded positions hold 0. The fallback
/// tokenizer marks every real token with 1.
pub attention_mask: Vec<u32>,
/// Token type (segment) IDs aligned with `input_ids`; padded positions hold 0.
pub token_type_ids: Vec<u32>,
/// Length in bytes (`str::len`) of the raw input text, measured before preprocessing.
pub original_length: usize,
}
/// Preprocesses and tokenizes text according to a [`TextConfig`].
pub struct TextProcessor {
/// Model tokenizer loaded from `tokenizer.json`; when `None`, a hash-based fallback is used.
tokenizer: Option<Tokenizer>,
/// Preprocessing and tokenization settings.
config: TextConfig,
}
impl TextProcessor {
pub fn new(config: TextConfig) -> Self {
Self {
tokenizer: None,
config,
}
}
pub fn load_tokenizer<P: AsRef<Path>>(&mut self, model_dir: P) -> Result<()> {
let tokenizer_path = model_dir.as_ref().join("tokenizer.json");
if tokenizer_path.exists() {
match Tokenizer::from_file(&tokenizer_path) {
Ok(tokenizer) => {
self.tokenizer = Some(tokenizer);
log::info!("Loaded tokenizer from {:?}", tokenizer_path);
Ok(())
}
Err(e) => {
log::warn!("Failed to load tokenizer from {:?}: {}", tokenizer_path, e);
Err(MemvidError::MachineLearning(format!(
"Failed to load tokenizer: {}",
e
)))
}
}
} else {
log::warn!("Tokenizer file not found at {:?}", tokenizer_path);
Err(MemvidError::MachineLearning(
"Tokenizer file not found".to_string(),
))
}
}
pub fn preprocess_text(&self, text: &str) -> String {
let mut processed = text.to_string();
if self.config.normalize_unicode {
processed = processed.nfc().collect::<String>();
}
if self.config.lowercase {
processed = processed.to_lowercase();
}
processed = processed.trim().to_string();
processed = processed
.split_whitespace()
.collect::<Vec<&str>>()
.join(" ");
processed
}
pub fn tokenize(&self, text: &str) -> Result<TokenizedText> {
let preprocessed = self.preprocess_text(text);
let original_length = text.len();
if let Some(ref tokenizer) = self.tokenizer {
let encoding = tokenizer
.encode(preprocessed.clone(), self.config.add_special_tokens)
.map_err(|e| MemvidError::MachineLearning(format!("Tokenization failed: {}", e)))?;
let input_ids = encoding.get_ids().to_vec();
let attention_mask = encoding.get_attention_mask().to_vec();
let token_type_ids = encoding.get_type_ids().to_vec();
let (input_ids, attention_mask, token_type_ids) =
self.pad_or_truncate(input_ids, attention_mask, token_type_ids);
Ok(TokenizedText {
input_ids,
attention_mask,
token_type_ids,
original_length,
})
} else {
log::warn!("No tokenizer loaded, using fallback tokenization");
self.fallback_tokenize(&preprocessed, original_length)
}
}
pub fn tokenize_batch(&self, texts: &[String]) -> Result<Vec<TokenizedText>> {
let mut results = Vec::new();
if let Some(ref tokenizer) = self.tokenizer {
let preprocessed: Vec<String> = texts
.iter()
.map(|text| self.preprocess_text(text))
.collect();
let encodings = tokenizer
.encode_batch(preprocessed.clone(), self.config.add_special_tokens)
.map_err(|e| {
MemvidError::MachineLearning(format!("Batch tokenization failed: {}", e))
})?;
for (encoding, original_text) in encodings.iter().zip(texts.iter()) {
let input_ids = encoding.get_ids().to_vec();
let attention_mask = encoding.get_attention_mask().to_vec();
let token_type_ids = encoding.get_type_ids().to_vec();
let (input_ids, attention_mask, token_type_ids) =
self.pad_or_truncate(input_ids, attention_mask, token_type_ids);
results.push(TokenizedText {
input_ids,
attention_mask,
token_type_ids,
original_length: original_text.len(),
});
}
} else {
for text in texts {
results.push(self.tokenize(text)?);
}
}
Ok(results)
}
fn pad_or_truncate(
&self,
mut input_ids: Vec<u32>,
mut attention_mask: Vec<u32>,
mut token_type_ids: Vec<u32>,
) -> (Vec<u32>, Vec<u32>, Vec<u32>) {
let max_len = self.config.max_length;
if input_ids.len() > max_len && self.config.truncate {
input_ids.truncate(max_len);
attention_mask.truncate(max_len);
token_type_ids.truncate(max_len);
} else if input_ids.len() < max_len {
let pad_len = max_len - input_ids.len();
input_ids.extend(vec![0; pad_len]); attention_mask.extend(vec![0; pad_len]); token_type_ids.extend(vec![0; pad_len]); }
(input_ids, attention_mask, token_type_ids)
}
fn fallback_tokenize(&self, text: &str, original_length: usize) -> Result<TokenizedText> {
let words: Vec<&str> = text.split_whitespace().collect();
let mut input_ids = Vec::new();
if self.config.add_special_tokens {
input_ids.push(101); }
for word in words.iter().take(self.config.max_length - 2) {
let mut hasher = std::collections::hash_map::DefaultHasher::new();
use std::hash::{Hash, Hasher};
word.hash(&mut hasher);
let token_id = (hasher.finish() % 30000 + 1000) as u32; input_ids.push(token_id);
}
if self.config.add_special_tokens {
input_ids.push(102); }
let seq_len = input_ids.len();
let attention_mask = vec![1u32; seq_len];
let token_type_ids = vec![0u32; seq_len];
let (input_ids, attention_mask, token_type_ids) =
self.pad_or_truncate(input_ids, attention_mask, token_type_ids);
log::debug!(
"Fallback tokenization: {} words -> {} tokens",
words.len(),
seq_len
);
Ok(TokenizedText {
input_ids,
attention_mask,
token_type_ids,
original_length,
})
}
pub fn vocab_size(&self) -> Option<usize> {
self.tokenizer.as_ref().map(|t| t.get_vocab_size(false))
}
pub fn config(&self) -> &TextConfig {
&self.config
}
pub fn has_tokenizer(&self) -> bool {
self.tokenizer.is_some()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_text_config_default() {
        let config = TextConfig::default();
        assert_eq!(config.max_length, 384);
        assert!(config.truncate);
        assert!(config.add_special_tokens);
        assert!(config.normalize_unicode);
        assert!(!config.lowercase);
    }

    #[test]
    fn test_text_preprocessing() {
        let config = TextConfig {
            normalize_unicode: true,
            lowercase: true,
            ..Default::default()
        };
        let processor = TextProcessor::new(config);
        let text = " Hello WORLD! ";
        let processed = processor.preprocess_text(text);
        assert_eq!(processed, "hello world!");
    }

    #[test]
    fn test_whitespace_and_unicode_normalization() {
        let processor = TextProcessor::new(TextConfig::default());
        // Internal runs of mixed whitespace collapse to a single space.
        assert_eq!(processor.preprocess_text("  a \t b\n\n c  "), "a b c");
        assert_eq!(processor.preprocess_text("   "), "");
        // Lowercasing is off by default.
        assert_eq!(processor.preprocess_text("MiXeD"), "MiXeD");
        // NFC composes 'e' + combining acute accent into a single code point.
        assert_eq!(processor.preprocess_text("e\u{301}"), "\u{e9}");
    }

    #[test]
    fn test_fallback_tokenization() {
        let config = TextConfig::default();
        let max_length = config.max_length;
        let processor = TextProcessor::new(config);
        let text = "Hello world test";
        let tokenized = processor.tokenize(text).unwrap();
        assert!(!tokenized.input_ids.is_empty());
        assert_eq!(tokenized.input_ids.len(), max_length);
        assert_eq!(tokenized.attention_mask.len(), max_length);
        assert_eq!(tokenized.original_length, text.len());
    }

    #[test]
    fn test_fallback_special_tokens_and_mask() {
        let processor = TextProcessor::new(TextConfig::default());
        let tokenized = processor.tokenize("alpha beta gamma").unwrap();
        assert_eq!(tokenized.input_ids[0], 101);
        assert_eq!(tokenized.input_ids[4], 102);
        assert!(tokenized.input_ids[1..4]
            .iter()
            .all(|&id| (1000..31000).contains(&id)));
        assert_eq!(
            tokenized.attention_mask.iter().filter(|&&m| m == 1).count(),
            5
        );
        assert!(tokenized.input_ids[5..].iter().all(|&id| id == 0));
        assert!(tokenized.token_type_ids.iter().all(|&t| t == 0));
    }

    #[test]
    fn test_batch_tokenization_fallback() {
        let config = TextConfig::default();
        let max_length = config.max_length;
        let processor = TextProcessor::new(config);
        let texts = vec![
            "First sentence".to_string(),
            "Second sentence".to_string(),
            "Third sentence".to_string(),
        ];
        let tokenized = processor.tokenize_batch(&texts).unwrap();
        assert_eq!(tokenized.len(), 3);
        for tokens in &tokenized {
            assert_eq!(tokens.input_ids.len(), max_length);
            assert_eq!(tokens.attention_mask.len(), max_length);
        }
    }

    #[test]
    fn test_batch_matches_single_fallback() {
        let processor = TextProcessor::new(TextConfig::default());
        let texts = vec!["same words here".to_string(), "other  text".to_string()];
        let batch = processor.tokenize_batch(&texts).unwrap();
        assert_eq!(batch.len(), texts.len());
        for (text, batched) in texts.iter().zip(&batch) {
            let single = processor.tokenize(text).unwrap();
            assert_eq!(single.input_ids, batched.input_ids);
            assert_eq!(single.attention_mask, batched.attention_mask);
            assert_eq!(single.token_type_ids, batched.token_type_ids);
            assert_eq!(single.original_length, batched.original_length);
        }
    }

    #[test]
    fn test_padding_truncation() {
        let config = TextConfig {
            max_length: 10,
            truncate: true,
            ..Default::default()
        };
        let processor = TextProcessor::new(config);
        let long_text = "This is a very long sentence that should be truncated";
        let tokenized = processor.tokenize(long_text).unwrap();
        assert_eq!(tokenized.input_ids.len(), 10);
        let short_text = "Short";
        let tokenized = processor.tokenize(short_text).unwrap();
        assert_eq!(tokenized.input_ids.len(), 10);
        let padding_start = tokenized.attention_mask.iter().position(|&x| x == 0);
        assert!(padding_start.is_some());
    }
}