1use crate::error::{MemvidError, Result};
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9use tokenizers::Tokenizer;
10use unicode_normalization::UnicodeNormalization;
11
/// Options controlling how text is cleaned up and shaped before it is returned as
/// tokenized model input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextConfig {
    /// Target sequence length: shorter sequences are padded up to this many entries,
    /// and longer ones are cut down to it when `truncate` is set.
    pub max_length: usize,
    /// Whether sequences longer than `max_length` are truncated.
    pub truncate: bool,
    /// Forwarded to the tokenizer's encode call; in the fallback path it controls
    /// whether the leading/trailing marker ids are added.
    pub add_special_tokens: bool,
    /// Whether text is normalized to Unicode NFC during preprocessing.
    pub normalize_unicode: bool,
    /// Whether text is lowercased during preprocessing.
    pub lowercase: bool,
}
26
27impl Default for TextConfig {
28 fn default() -> Self {
29 Self {
30 max_length: 384,
31 truncate: true,
32 add_special_tokens: true,
33 normalize_unicode: true,
34 lowercase: false, }
36 }
37}
38
/// Result of tokenizing one piece of text, after padding/truncation has been applied.
#[derive(Debug, Clone)]
pub struct TokenizedText {
    /// Token ids; padded positions hold `0`.
    pub input_ids: Vec<u32>,
    /// Mask over `input_ids`; padded positions hold `0`. The fallback tokenizer marks
    /// every non-padded position with `1`.
    pub attention_mask: Vec<u32>,
    /// Segment/type ids as reported by the tokenizer; all zeros in the fallback path.
    pub token_type_ids: Vec<u32>,
    /// Length in bytes (`str::len`) of the original text, measured before preprocessing.
    pub original_length: usize,
}
51
/// Preprocesses and tokenizes text according to a [`TextConfig`].
pub struct TextProcessor {
    /// Loaded tokenizer; `None` until `load_tokenizer` succeeds, in which case the
    /// hash-based fallback tokenization is used instead.
    tokenizer: Option<Tokenizer>,
    /// Preprocessing and sequence-shaping options.
    config: TextConfig,
}
59
60impl TextProcessor {
61 pub fn new(config: TextConfig) -> Self {
63 Self {
64 tokenizer: None,
65 config,
66 }
67 }
68
69 pub fn load_tokenizer<P: AsRef<Path>>(&mut self, model_dir: P) -> Result<()> {
71 let tokenizer_path = model_dir.as_ref().join("tokenizer.json");
72
73 if tokenizer_path.exists() {
74 match Tokenizer::from_file(&tokenizer_path) {
75 Ok(tokenizer) => {
76 self.tokenizer = Some(tokenizer);
77 log::info!("Loaded tokenizer from {:?}", tokenizer_path);
78 Ok(())
79 }
80 Err(e) => {
81 log::warn!("Failed to load tokenizer from {:?}: {}", tokenizer_path, e);
82 Err(MemvidError::MachineLearning(format!(
83 "Failed to load tokenizer: {}",
84 e
85 )))
86 }
87 }
88 } else {
89 log::warn!("Tokenizer file not found at {:?}", tokenizer_path);
90 Err(MemvidError::MachineLearning(
91 "Tokenizer file not found".to_string(),
92 ))
93 }
94 }
95
96 pub fn preprocess_text(&self, text: &str) -> String {
98 let mut processed = text.to_string();
99
100 if self.config.normalize_unicode {
102 processed = processed.nfc().collect::<String>();
103 }
104
105 if self.config.lowercase {
107 processed = processed.to_lowercase();
108 }
109
110 processed = processed.trim().to_string();
112
113 processed = processed
115 .split_whitespace()
116 .collect::<Vec<&str>>()
117 .join(" ");
118
119 processed
120 }
121
122 pub fn tokenize(&self, text: &str) -> Result<TokenizedText> {
124 let preprocessed = self.preprocess_text(text);
125 let original_length = text.len();
126
127 if let Some(ref tokenizer) = self.tokenizer {
128 let encoding = tokenizer
130 .encode(preprocessed.clone(), self.config.add_special_tokens)
131 .map_err(|e| MemvidError::MachineLearning(format!("Tokenization failed: {}", e)))?;
132
133 let input_ids = encoding.get_ids().to_vec();
134 let attention_mask = encoding.get_attention_mask().to_vec();
135 let token_type_ids = encoding.get_type_ids().to_vec();
136
137 let (input_ids, attention_mask, token_type_ids) =
139 self.pad_or_truncate(input_ids, attention_mask, token_type_ids);
140
141 Ok(TokenizedText {
142 input_ids,
143 attention_mask,
144 token_type_ids,
145 original_length,
146 })
147 } else {
148 log::warn!("No tokenizer loaded, using fallback tokenization");
150 self.fallback_tokenize(&preprocessed, original_length)
151 }
152 }
153
154 pub fn tokenize_batch(&self, texts: &[String]) -> Result<Vec<TokenizedText>> {
156 let mut results = Vec::new();
157
158 if let Some(ref tokenizer) = self.tokenizer {
159 let preprocessed: Vec<String> = texts
161 .iter()
162 .map(|text| self.preprocess_text(text))
163 .collect();
164
165 let encodings = tokenizer
166 .encode_batch(preprocessed.clone(), self.config.add_special_tokens)
167 .map_err(|e| {
168 MemvidError::MachineLearning(format!("Batch tokenization failed: {}", e))
169 })?;
170
171 for (encoding, original_text) in encodings.iter().zip(texts.iter()) {
172 let input_ids = encoding.get_ids().to_vec();
173 let attention_mask = encoding.get_attention_mask().to_vec();
174 let token_type_ids = encoding.get_type_ids().to_vec();
175
176 let (input_ids, attention_mask, token_type_ids) =
177 self.pad_or_truncate(input_ids, attention_mask, token_type_ids);
178
179 results.push(TokenizedText {
180 input_ids,
181 attention_mask,
182 token_type_ids,
183 original_length: original_text.len(),
184 });
185 }
186 } else {
187 for text in texts {
189 results.push(self.tokenize(text)?);
190 }
191 }
192
193 Ok(results)
194 }
195
196 fn pad_or_truncate(
198 &self,
199 mut input_ids: Vec<u32>,
200 mut attention_mask: Vec<u32>,
201 mut token_type_ids: Vec<u32>,
202 ) -> (Vec<u32>, Vec<u32>, Vec<u32>) {
203 let max_len = self.config.max_length;
204
205 if input_ids.len() > max_len && self.config.truncate {
206 input_ids.truncate(max_len);
208 attention_mask.truncate(max_len);
209 token_type_ids.truncate(max_len);
210 } else if input_ids.len() < max_len {
211 let pad_len = max_len - input_ids.len();
213 input_ids.extend(vec![0; pad_len]); attention_mask.extend(vec![0; pad_len]); token_type_ids.extend(vec![0; pad_len]); }
217
218 (input_ids, attention_mask, token_type_ids)
219 }
220
221 fn fallback_tokenize(&self, text: &str, original_length: usize) -> Result<TokenizedText> {
223 let words: Vec<&str> = text.split_whitespace().collect();
225 let mut input_ids = Vec::new();
226
227 if self.config.add_special_tokens {
229 input_ids.push(101); }
231
232 for word in words.iter().take(self.config.max_length - 2) {
234 let mut hasher = std::collections::hash_map::DefaultHasher::new();
236 use std::hash::{Hash, Hasher};
237 word.hash(&mut hasher);
238 let token_id = (hasher.finish() % 30000 + 1000) as u32; input_ids.push(token_id);
240 }
241
242 if self.config.add_special_tokens {
244 input_ids.push(102); }
246
247 let seq_len = input_ids.len();
249 let attention_mask = vec![1u32; seq_len];
250 let token_type_ids = vec![0u32; seq_len];
251
252 let (input_ids, attention_mask, token_type_ids) =
254 self.pad_or_truncate(input_ids, attention_mask, token_type_ids);
255
256 log::debug!(
257 "Fallback tokenization: {} words -> {} tokens",
258 words.len(),
259 seq_len
260 );
261
262 Ok(TokenizedText {
263 input_ids,
264 attention_mask,
265 token_type_ids,
266 original_length,
267 })
268 }
269
270 pub fn vocab_size(&self) -> Option<usize> {
272 self.tokenizer.as_ref().map(|t| t.get_vocab_size(false))
273 }
274
    /// Returns a reference to the configuration this processor was created with.
    pub fn config(&self) -> &TextConfig {
        &self.config
    }
279
    /// Returns `true` when a tokenizer is loaded, i.e. tokenization will not use the
    /// fallback path.
    pub fn has_tokenizer(&self) -> bool {
        self.tokenizer.is_some()
    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tokenizer-less processor, so every call exercises the fallback path.
    fn processor_with(config: TextConfig) -> TextProcessor {
        TextProcessor::new(config)
    }

    /// The default configuration keeps its sequence length and flags.
    #[test]
    fn test_text_config_default() {
        let defaults = TextConfig::default();
        assert_eq!(defaults.max_length, 384);
        assert!(defaults.truncate);
        assert!(defaults.add_special_tokens);
    }

    /// Preprocessing lowercases and strips surrounding whitespace.
    #[test]
    fn test_text_preprocessing() {
        let processor = processor_with(TextConfig {
            normalize_unicode: true,
            lowercase: true,
            ..Default::default()
        });

        let cleaned = processor.preprocess_text(" Hello WORLD! ");
        assert_eq!(cleaned, "hello world!");
    }

    /// Without a tokenizer, `tokenize` still yields sequences of exactly `max_length`.
    #[test]
    fn test_fallback_tokenization() {
        let processor = processor_with(TextConfig::default());
        let expected_len = processor.config().max_length;

        let input = "Hello world test";
        let output = processor.tokenize(input).unwrap();

        assert!(!output.input_ids.is_empty());
        assert_eq!(output.input_ids.len(), expected_len);
        assert_eq!(output.attention_mask.len(), expected_len);
        assert_eq!(output.original_length, input.len());
    }

    /// Batch tokenization without a tokenizer yields one padded result per input.
    #[test]
    fn test_batch_tokenization_fallback() {
        let processor = processor_with(TextConfig::default());
        let expected_len = processor.config().max_length;

        let inputs: Vec<String> = ["First sentence", "Second sentence", "Third sentence"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        let outputs = processor.tokenize_batch(&inputs).unwrap();
        assert_eq!(outputs.len(), 3);

        for output in outputs.iter() {
            assert_eq!(output.input_ids.len(), expected_len);
            assert_eq!(output.attention_mask.len(), expected_len);
        }
    }

    /// Long inputs are cut to `max_length`; short inputs are padded with masked zeros.
    #[test]
    fn test_padding_truncation() {
        let processor = processor_with(TextConfig {
            max_length: 10,
            truncate: true,
            ..Default::default()
        });

        let long_output = processor
            .tokenize("This is a very long sentence that should be truncated")
            .unwrap();
        assert_eq!(long_output.input_ids.len(), 10);

        let short_output = processor.tokenize("Short").unwrap();
        assert_eq!(short_output.input_ids.len(), 10);

        // Padding must be visible in the attention mask as at least one zero entry.
        assert!(short_output.attention_mask.contains(&0));
    }
}