Skip to main content

memvid_rs/ml/
text.rs

//! Text preprocessing and tokenization for ML models
//!
//! This module provides text preprocessing capabilities including tokenization,
//! normalization, and preparation for ML model inference.
5
6use crate::error::{MemvidError, Result};
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9use tokenizers::Tokenizer;
10use unicode_normalization::UnicodeNormalization;
11
12/// Text preprocessing configuration
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct TextConfig {
15    /// Maximum sequence length
16    pub max_length: usize,
17    /// Whether to truncate long sequences
18    pub truncate: bool,
19    /// Whether to add special tokens (CLS, SEP)
20    pub add_special_tokens: bool,
21    /// Whether to normalize unicode
22    pub normalize_unicode: bool,
23    /// Whether to lowercase text
24    pub lowercase: bool,
25}
26
27impl Default for TextConfig {
28    fn default() -> Self {
29        Self {
30            max_length: 384,
31            truncate: true,
32            add_special_tokens: true,
33            normalize_unicode: true,
34            lowercase: false, // SentenceTransformers typically preserve case
35        }
36    }
37}
38
/// Tokenized text ready for model inference
///
/// The three vectors are produced together and are expected to have the same
/// length. Equality compares every field, which makes tokenizer output easy to
/// assert on in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenizedText {
    /// Token IDs
    pub input_ids: Vec<u32>,
    /// Attention mask (1 for real tokens, 0 for padding)
    pub attention_mask: Vec<u32>,
    /// Token type IDs (for BERT-style models)
    pub token_type_ids: Vec<u32>,
    /// Length of the original text in bytes (`str::len`), measured before any
    /// preprocessing was applied
    pub original_length: usize,
}
51
52/// Text preprocessor and tokenizer
53pub struct TextProcessor {
54    /// Tokenizer instance
55    tokenizer: Option<Tokenizer>,
56    /// Configuration
57    config: TextConfig,
58}
59
60impl TextProcessor {
61    /// Create new text processor
62    pub fn new(config: TextConfig) -> Self {
63        Self {
64            tokenizer: None,
65            config,
66        }
67    }
68
69    /// Load tokenizer from model directory
70    pub fn load_tokenizer<P: AsRef<Path>>(&mut self, model_dir: P) -> Result<()> {
71        let tokenizer_path = model_dir.as_ref().join("tokenizer.json");
72
73        if tokenizer_path.exists() {
74            match Tokenizer::from_file(&tokenizer_path) {
75                Ok(tokenizer) => {
76                    self.tokenizer = Some(tokenizer);
77                    log::info!("Loaded tokenizer from {:?}", tokenizer_path);
78                    Ok(())
79                }
80                Err(e) => {
81                    log::warn!("Failed to load tokenizer from {:?}: {}", tokenizer_path, e);
82                    Err(MemvidError::MachineLearning(format!(
83                        "Failed to load tokenizer: {}",
84                        e
85                    )))
86                }
87            }
88        } else {
89            log::warn!("Tokenizer file not found at {:?}", tokenizer_path);
90            Err(MemvidError::MachineLearning(
91                "Tokenizer file not found".to_string(),
92            ))
93        }
94    }
95
96    /// Preprocess text (normalize, clean, etc.)
97    pub fn preprocess_text(&self, text: &str) -> String {
98        let mut processed = text.to_string();
99
100        // Unicode normalization
101        if self.config.normalize_unicode {
102            processed = processed.nfc().collect::<String>();
103        }
104
105        // Lowercase if configured
106        if self.config.lowercase {
107            processed = processed.to_lowercase();
108        }
109
110        // Basic cleaning
111        processed = processed.trim().to_string();
112
113        // Remove excessive whitespace
114        processed = processed
115            .split_whitespace()
116            .collect::<Vec<&str>>()
117            .join(" ");
118
119        processed
120    }
121
122    /// Tokenize text for model inference
123    pub fn tokenize(&self, text: &str) -> Result<TokenizedText> {
124        let preprocessed = self.preprocess_text(text);
125        let original_length = text.len();
126
127        if let Some(ref tokenizer) = self.tokenizer {
128            // Use real tokenizer
129            let encoding = tokenizer
130                .encode(preprocessed.clone(), self.config.add_special_tokens)
131                .map_err(|e| MemvidError::MachineLearning(format!("Tokenization failed: {}", e)))?;
132
133            let input_ids = encoding.get_ids().to_vec();
134            let attention_mask = encoding.get_attention_mask().to_vec();
135            let token_type_ids = encoding.get_type_ids().to_vec();
136
137            // Truncate or pad to max_length
138            let (input_ids, attention_mask, token_type_ids) =
139                self.pad_or_truncate(input_ids, attention_mask, token_type_ids);
140
141            Ok(TokenizedText {
142                input_ids,
143                attention_mask,
144                token_type_ids,
145                original_length,
146            })
147        } else {
148            // Fallback to simple word-based tokenization
149            log::warn!("No tokenizer loaded, using fallback tokenization");
150            self.fallback_tokenize(&preprocessed, original_length)
151        }
152    }
153
154    /// Tokenize multiple texts in batch
155    pub fn tokenize_batch(&self, texts: &[String]) -> Result<Vec<TokenizedText>> {
156        let mut results = Vec::new();
157
158        if let Some(ref tokenizer) = self.tokenizer {
159            // Batch tokenization for efficiency
160            let preprocessed: Vec<String> = texts
161                .iter()
162                .map(|text| self.preprocess_text(text))
163                .collect();
164
165            let encodings = tokenizer
166                .encode_batch(preprocessed.clone(), self.config.add_special_tokens)
167                .map_err(|e| {
168                    MemvidError::MachineLearning(format!("Batch tokenization failed: {}", e))
169                })?;
170
171            for (encoding, original_text) in encodings.iter().zip(texts.iter()) {
172                let input_ids = encoding.get_ids().to_vec();
173                let attention_mask = encoding.get_attention_mask().to_vec();
174                let token_type_ids = encoding.get_type_ids().to_vec();
175
176                let (input_ids, attention_mask, token_type_ids) =
177                    self.pad_or_truncate(input_ids, attention_mask, token_type_ids);
178
179                results.push(TokenizedText {
180                    input_ids,
181                    attention_mask,
182                    token_type_ids,
183                    original_length: original_text.len(),
184                });
185            }
186        } else {
187            // Fallback to individual tokenization
188            for text in texts {
189                results.push(self.tokenize(text)?);
190            }
191        }
192
193        Ok(results)
194    }
195
196    /// Pad or truncate sequences to max_length
197    fn pad_or_truncate(
198        &self,
199        mut input_ids: Vec<u32>,
200        mut attention_mask: Vec<u32>,
201        mut token_type_ids: Vec<u32>,
202    ) -> (Vec<u32>, Vec<u32>, Vec<u32>) {
203        let max_len = self.config.max_length;
204
205        if input_ids.len() > max_len && self.config.truncate {
206            // Truncate
207            input_ids.truncate(max_len);
208            attention_mask.truncate(max_len);
209            token_type_ids.truncate(max_len);
210        } else if input_ids.len() < max_len {
211            // Pad with zeros (or appropriate padding tokens)
212            let pad_len = max_len - input_ids.len();
213            input_ids.extend(vec![0; pad_len]); // 0 is typically PAD token
214            attention_mask.extend(vec![0; pad_len]); // 0 for padding
215            token_type_ids.extend(vec![0; pad_len]); // 0 for padding
216        }
217
218        (input_ids, attention_mask, token_type_ids)
219    }
220
221    /// Fallback tokenization when no real tokenizer is available
222    fn fallback_tokenize(&self, text: &str, original_length: usize) -> Result<TokenizedText> {
223        // Simple word-based tokenization for fallback
224        let words: Vec<&str> = text.split_whitespace().collect();
225        let mut input_ids = Vec::new();
226
227        // Add CLS token if configured
228        if self.config.add_special_tokens {
229            input_ids.push(101); // [CLS] token ID
230        }
231
232        // Convert words to simple hash-based IDs (for testing)
233        for word in words.iter().take(self.config.max_length - 2) {
234            // Leave space for special tokens
235            let mut hasher = std::collections::hash_map::DefaultHasher::new();
236            use std::hash::{Hash, Hasher};
237            word.hash(&mut hasher);
238            let token_id = (hasher.finish() % 30000 + 1000) as u32; // Keep in reasonable range
239            input_ids.push(token_id);
240        }
241
242        // Add SEP token if configured
243        if self.config.add_special_tokens {
244            input_ids.push(102); // [SEP] token ID
245        }
246
247        // Create attention mask and token type IDs
248        let seq_len = input_ids.len();
249        let attention_mask = vec![1u32; seq_len];
250        let token_type_ids = vec![0u32; seq_len];
251
252        // Pad to max_length
253        let (input_ids, attention_mask, token_type_ids) =
254            self.pad_or_truncate(input_ids, attention_mask, token_type_ids);
255
256        log::debug!(
257            "Fallback tokenization: {} words -> {} tokens",
258            words.len(),
259            seq_len
260        );
261
262        Ok(TokenizedText {
263            input_ids,
264            attention_mask,
265            token_type_ids,
266            original_length,
267        })
268    }
269
270    /// Get tokenizer vocabulary size
271    pub fn vocab_size(&self) -> Option<usize> {
272        self.tokenizer.as_ref().map(|t| t.get_vocab_size(false))
273    }
274
    /// Borrow the configuration this processor was created with.
    pub fn config(&self) -> &TextConfig {
        &self.config
    }
279
    /// Check if real tokenizer is loaded
    ///
    /// Returns `false` while tokenization would go through the word-based
    /// fallback.
    pub fn has_tokenizer(&self) -> bool {
        self.tokenizer.is_some()
    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration keeps its documented values.
    #[test]
    fn test_text_config_default() {
        let config = TextConfig::default();
        assert_eq!(config.max_length, 384);
        assert!(config.truncate);
        assert!(config.add_special_tokens);
        assert!(config.normalize_unicode);
        assert!(!config.lowercase);
    }

    /// Preprocessing lowercases, NFC-normalizes and collapses whitespace.
    #[test]
    fn test_text_preprocessing() {
        let config = TextConfig {
            normalize_unicode: true,
            lowercase: true,
            ..Default::default()
        };
        let processor = TextProcessor::new(config);

        let text = "  Hello    WORLD!  ";
        let processed = processor.preprocess_text(text);
        assert_eq!(processed, "hello world!");

        // "e" + combining acute accent composes to a single code point under NFC.
        assert_eq!(processor.preprocess_text("Cafe\u{301}"), "caf\u{e9}");

        // Whitespace-only input collapses to the empty string.
        assert_eq!(processor.preprocess_text("  \t\n "), "");
    }

    /// Without a tokenizer, the fallback yields CLS + words + SEP, padded to
    /// `max_length`.
    #[test]
    fn test_fallback_tokenization() {
        let config = TextConfig::default();
        let max_length = config.max_length;
        let processor = TextProcessor::new(config);
        assert!(!processor.has_tokenizer());

        let text = "Hello world test";
        let tokenized = processor.tokenize(text).unwrap();

        assert!(!tokenized.input_ids.is_empty());
        assert_eq!(tokenized.input_ids.len(), max_length);
        assert_eq!(tokenized.attention_mask.len(), max_length);
        assert_eq!(tokenized.token_type_ids.len(), max_length);
        assert_eq!(tokenized.original_length, text.len());

        // [CLS] + 3 words + [SEP] are the only real tokens.
        assert_eq!(tokenized.input_ids[0], 101);
        assert_eq!(tokenized.input_ids[4], 102);
        assert_eq!(tokenized.attention_mask.iter().sum::<u32>(), 5);
    }

    /// Batch tokenization without a tokenizer falls back per text and keeps
    /// the input order and count.
    #[test]
    fn test_batch_tokenization_fallback() {
        let config = TextConfig::default();
        let max_length = config.max_length;
        let processor = TextProcessor::new(config);

        let texts = vec![
            "First sentence".to_string(),
            "Second sentence".to_string(),
            "Third sentence".to_string(),
        ];

        let tokenized = processor.tokenize_batch(&texts).unwrap();
        assert_eq!(tokenized.len(), 3);

        for (tokens, text) in tokenized.iter().zip(&texts) {
            assert_eq!(tokens.input_ids.len(), max_length);
            assert_eq!(tokens.attention_mask.len(), max_length);
            assert_eq!(tokens.token_type_ids.len(), max_length);
            assert_eq!(tokens.original_length, text.len());
        }
    }

    /// Long input is capped at `max_length`; short input is zero-padded.
    #[test]
    fn test_padding_truncation() {
        let config = TextConfig {
            max_length: 10,
            truncate: true,
            ..Default::default()
        };
        let processor = TextProcessor::new(config);

        // Test truncation: 10 words do not fit next to the two special tokens.
        let long_text = "This is a very long sentence that should be truncated";
        let tokenized = processor.tokenize(long_text).unwrap();
        assert_eq!(tokenized.input_ids.len(), 10);
        assert_eq!(tokenized.input_ids[0], 101);
        assert_eq!(tokenized.input_ids[9], 102);
        assert!(tokenized.attention_mask.iter().all(|&m| m == 1));

        // Test padding
        let short_text = "Short";
        let tokenized = processor.tokenize(short_text).unwrap();
        assert_eq!(tokenized.input_ids.len(), 10);

        // [CLS] + 1 word + [SEP], so padding starts at index 3.
        let padding_start = tokenized.attention_mask.iter().position(|&x| x == 0);
        assert_eq!(padding_start, Some(3));

        // Check that padding tokens are 0 in every sequence.
        assert!(tokenized.input_ids[3..].iter().all(|&id| id == 0));
        assert!(tokenized.attention_mask[3..].iter().all(|&m| m == 0));
        assert!(tokenized.token_type_ids[3..].iter().all(|&t| t == 0));
    }
}
371}