axonml_text/tokenizer.rs

//! Tokenizer - Text Tokenization
//!
//! Provides various tokenization strategies for text processing.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

use crate::vocab::Vocab;
use std::collections::HashMap;

// =============================================================================
// Tokenizer Trait
// =============================================================================

/// Trait for text tokenization.
pub trait Tokenizer: Send + Sync {
    /// Tokenizes a string into tokens.
    fn tokenize(&self, text: &str) -> Vec<String>;

    /// Tokenizes and encodes to indices using a vocabulary.
    fn encode(&self, text: &str, vocab: &Vocab) -> Vec<usize> {
        let tokens = self.tokenize(text);
        let token_refs: Vec<&str> = tokens.iter().map(std::string::String::as_str).collect();
        vocab.encode(&token_refs)
    }
}
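
// Minimal usage sketch for the default `encode` method (illustrative only; it
// assumes, as the tests at the bottom of this file do, that `Vocab` assigns
// indices in insertion order):
//
//     let tokenizer = WhitespaceTokenizer::new();
//     let mut vocab = Vocab::new();
//     vocab.add_token("hello");
//     vocab.add_token("world");
//     assert_eq!(tokenizer.encode("hello world", &vocab), vec![0, 1]);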

// =============================================================================
// WhitespaceTokenizer
// =============================================================================

/// Simple whitespace-based tokenizer.
#[derive(Debug, Clone, Default)]
pub struct WhitespaceTokenizer {
    lowercase: bool,
}

impl WhitespaceTokenizer {
    /// Creates a new `WhitespaceTokenizer`.
    #[must_use] pub fn new() -> Self {
        Self { lowercase: false }
    }

    /// Creates a tokenizer that lowercases all tokens.
    #[must_use] pub fn lowercase() -> Self {
        Self { lowercase: true }
    }
}

impl Tokenizer for WhitespaceTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        text.split_whitespace()
            .map(|s| {
                if self.lowercase {
                    s.to_lowercase()
                } else {
                    s.to_string()
                }
            })
            .collect()
    }
}
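
// Usage sketch (illustrative only): `split_whitespace` keeps punctuation
// attached to words and collapses runs of whitespace.
//
//     let tok = WhitespaceTokenizer::lowercase();
//     assert_eq!(tok.tokenize("Hello,   World!"), vec!["hello,", "world!"]);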

// =============================================================================
// CharTokenizer
// =============================================================================

/// Character-level tokenizer.
#[derive(Debug, Clone, Default)]
pub struct CharTokenizer {
    include_whitespace: bool,
}

impl CharTokenizer {
    /// Creates a new `CharTokenizer`.
    #[must_use] pub fn new() -> Self {
        Self {
            include_whitespace: true,
        }
    }

    /// Creates a tokenizer that excludes whitespace.
    #[must_use] pub fn no_whitespace() -> Self {
        Self {
            include_whitespace: false,
        }
    }
}

impl Tokenizer for CharTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        if self.include_whitespace {
            text.chars().map(|c| c.to_string()).collect()
        } else {
            text.chars()
                .filter(|c| !c.is_whitespace())
                .map(|c| c.to_string())
                .collect()
        }
    }
}
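
// Usage sketch (illustrative only): `chars()` iterates Unicode scalar values,
// so multi-byte characters stay intact as single tokens.
//
//     let tok = CharTokenizer::new();
//     assert_eq!(tok.tokenize("año"), vec!["a", "ñ", "o"]);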

// =============================================================================
// WordPunctTokenizer
// =============================================================================

/// Tokenizer that separates words and punctuation.
#[derive(Debug, Clone, Default)]
pub struct WordPunctTokenizer {
    lowercase: bool,
}

impl WordPunctTokenizer {
    /// Creates a new `WordPunctTokenizer`.
    #[must_use] pub fn new() -> Self {
        Self { lowercase: false }
    }

    /// Creates a tokenizer that lowercases all tokens.
    #[must_use] pub fn lowercase() -> Self {
        Self { lowercase: true }
    }
}

impl Tokenizer for WordPunctTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut tokens = Vec::new();
        let mut current = String::new();

        for c in text.chars() {
            if c.is_alphanumeric() {
                current.push(c);
            } else {
                if !current.is_empty() {
                    tokens.push(if self.lowercase {
                        current.to_lowercase()
                    } else {
                        current.clone()
                    });
                    current.clear();
                }
                if !c.is_whitespace() {
                    tokens.push(c.to_string());
                }
            }
        }

        if !current.is_empty() {
            tokens.push(if self.lowercase {
                current.to_lowercase()
            } else {
                current
            });
        }

        tokens
    }
}
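
// Usage sketch (illustrative only): runs of alphanumeric characters form one
// token, every other non-whitespace character becomes its own token.
//
//     let tok = WordPunctTokenizer::new();
//     assert_eq!(tok.tokenize("v2.0 (beta)"), vec!["v2", ".", "0", "(", "beta", ")"]);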

// =============================================================================
// NGramTokenizer
// =============================================================================

/// N-gram tokenizer for subword or character n-grams.
#[derive(Debug, Clone)]
pub struct NGramTokenizer {
    n: usize,
    char_level: bool,
}

impl NGramTokenizer {
    /// Creates a word-level n-gram tokenizer.
    #[must_use] pub fn word_ngrams(n: usize) -> Self {
        Self {
            n: n.max(1),
            char_level: false,
        }
    }

    /// Creates a character-level n-gram tokenizer.
    #[must_use] pub fn char_ngrams(n: usize) -> Self {
        Self {
            n: n.max(1),
            char_level: true,
        }
    }
}

impl Tokenizer for NGramTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        if self.char_level {
            // Character n-grams
            let chars: Vec<char> = text.chars().collect();
            if chars.len() < self.n {
                return vec![text.to_string()];
            }

            chars
                .windows(self.n)
                .map(|w| w.iter().collect::<String>())
                .collect()
        } else {
            // Word n-grams
            let words: Vec<&str> = text.split_whitespace().collect();
            if words.len() < self.n {
                return vec![text.to_string()];
            }

            words.windows(self.n).map(|w| w.join(" ")).collect()
        }
    }
}
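
// Usage sketch (illustrative only): n-grams come from a sliding window of size
// `n`; input shorter than `n` is returned as a single token.
//
//     let bigrams = NGramTokenizer::word_ngrams(2);
//     assert_eq!(bigrams.tokenize("one two three"), vec!["one two", "two three"]);
//     assert_eq!(bigrams.tokenize("one"), vec!["one"]);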

// =============================================================================
// BasicBPETokenizer
// =============================================================================

/// A basic Byte-Pair Encoding tokenizer.
#[derive(Debug, Clone)]
pub struct BasicBPETokenizer {
    merges: HashMap<(String, String), String>,
    vocab: Vec<String>,
}

impl BasicBPETokenizer {
    /// Creates a new BPE tokenizer.
    #[must_use] pub fn new() -> Self {
        Self {
            merges: HashMap::new(),
            vocab: Vec::new(),
        }
    }

    /// Trains the BPE tokenizer on text.
    pub fn train(&mut self, text: &str, num_merges: usize) {
        // Initialize vocabulary with characters
        let mut vocab: HashMap<String, usize> = HashMap::new();

        // Split text into words and append an end-of-word marker as a single symbol
        for word in text.split_whitespace() {
            let mut chars: Vec<String> = word.chars().map(|c| c.to_string()).collect();
            chars.push("</w>".to_string());
            *vocab.entry(chars.join(" ")).or_insert(0) += 1;
        }

        for _ in 0..num_merges {
            // Count pairs
            let mut pairs: HashMap<(String, String), usize> = HashMap::new();
            for (word, count) in &vocab {
                let symbols: Vec<&str> = word.split(' ').collect();
                for i in 0..symbols.len().saturating_sub(1) {
                    let pair = (symbols[i].to_string(), symbols[i + 1].to_string());
                    *pairs.entry(pair).or_insert(0) += count;
                }
            }

            if pairs.is_empty() {
                break;
            }

            // Find the most frequent pair (ties are broken arbitrarily)
            let best_pair = pairs
                .into_iter()
                .max_by_key(|(_, count)| *count)
                .map(|(pair, _)| pair);

            if let Some((a, b)) = best_pair {
                let merged = format!("{a}{b}");
                self.merges.insert((a.clone(), b.clone()), merged.clone());

                // Update vocabulary, merging the pair only on whole-symbol
                // boundaries (a plain substring replace could match inside a longer symbol).
                let mut new_vocab = HashMap::new();
                for (word, count) in vocab {
                    let mut merged_symbols: Vec<String> = Vec::new();
                    for symbol in word.split(' ') {
                        if merged_symbols.last() == Some(&a) && symbol == b {
                            *merged_symbols.last_mut().unwrap() = merged.clone();
                        } else {
                            merged_symbols.push(symbol.to_string());
                        }
                    }
                    *new_vocab.entry(merged_symbols.join(" ")).or_insert(0) += count;
                }
                vocab = new_vocab;
            }
        }

        // Extract final vocabulary
        let mut all_symbols: std::collections::HashSet<String> = std::collections::HashSet::new();
        for word in vocab.keys() {
            for symbol in word.split(' ') {
                all_symbols.insert(symbol.to_string());
            }
        }
        self.vocab = all_symbols.into_iter().collect();
        self.vocab.sort();
    }
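
    // Worked sketch of one training pass (illustrative only), assuming the
    // corpus "low low lower" and the end-of-word marker `</w>` as one symbol:
    //
    //   initial vocab:  "l o w </w>" x2,  "l o w e r </w>" x1
    //   pair counts:    (l,o)=3  (o,w)=3  (w,</w>)=2  (w,e)=1  (e,r)=1  (r,</w>)=1
    //   first merge:    one of the most frequent pairs, e.g. (l,o) -> "lo",
    //                   giving "lo w </w>" x2 and "lo w e r </w>" x1.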

    /// Returns the vocabulary.
    #[must_use] pub fn get_vocab(&self) -> &[String] {
        &self.vocab
    }

    /// Applies BPE merges to a word.
    fn apply_bpe(&self, word: &str) -> Vec<String> {
        // Start from single characters plus the end-of-word marker, mirroring `train`.
        let mut symbols: Vec<String> = word.chars().map(|c| c.to_string()).collect();
        symbols.push("</w>".to_string());

        loop {
            // Apply the first mergeable pair found, scanning left to right
            // (merge priority/order is not tracked by this basic tokenizer).
            let mut next_merge: Option<(usize, &str)> = None;

            for i in 0..symbols.len().saturating_sub(1) {
                let pair = (symbols[i].clone(), symbols[i + 1].clone());
                if let Some(merged) = self.merges.get(&pair) {
                    next_merge = Some((i, merged));
                    break;
                }
            }

            match next_merge {
                Some((i, merged)) => {
                    symbols[i] = merged.to_string();
                    symbols.remove(i + 1);
                }
                None => break,
            }
        }

        symbols
    }
}

impl Default for BasicBPETokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl Tokenizer for BasicBPETokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut tokens = Vec::new();

        for word in text.split_whitespace() {
            let word_tokens = self.apply_bpe(word);
            tokens.extend(word_tokens);
        }

        tokens
    }
}
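
// Usage sketch (illustrative only; the learned merges depend on the corpus and
// on arbitrary tie-breaking, so only structural properties are asserted here):
//
//     let mut bpe = BasicBPETokenizer::new();
//     bpe.train("low low low lower lowest", 10);
//     let tokens = bpe.tokenize("low");
//     assert!(!tokens.is_empty());
//     assert!(tokens.last().unwrap().ends_with("</w>"));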

// =============================================================================
// UnigramTokenizer (simplified SentencePiece-style)
// =============================================================================

/// A simplified unigram-style tokenizer.
#[derive(Debug, Clone)]
pub struct UnigramTokenizer {
    vocab: HashMap<String, f32>,
    max_token_length: usize,
}

impl UnigramTokenizer {
    /// Creates a new unigram tokenizer from a vocabulary with scores.
    ///
    /// Note: the greedy matcher below picks the longest match and does not
    /// currently use the scores.
    #[must_use] pub fn new(vocab: HashMap<String, f32>) -> Self {
        // Longest vocabulary entry measured in characters, matching the
        // character-indexed window used by `greedy_tokenize`.
        let max_len = vocab.keys().map(|t| t.chars().count()).max().unwrap_or(1);
        Self {
            vocab,
            max_token_length: max_len,
        }
    }

    /// Creates a tokenizer from a list of tokens (equal scores).
    #[must_use] pub fn from_tokens(tokens: &[&str]) -> Self {
        let vocab: HashMap<String, f32> = tokens.iter().map(|&t| (t.to_string(), 1.0)).collect();
        Self::new(vocab)
    }

    /// Tokenizes using a greedy longest-match approach.
    fn greedy_tokenize(&self, text: &str) -> Vec<String> {
        let mut tokens = Vec::new();
        let chars: Vec<char> = text.chars().collect();
        let mut i = 0;

        while i < chars.len() {
            // Fall back to the single character if nothing in the vocabulary matches.
            let mut best_len = 1;
            let mut best_token = chars[i].to_string();

            // Try to find the longest matching token
            for len in 1..=self.max_token_length.min(chars.len() - i) {
                let candidate: String = chars[i..i + len].iter().collect();
                if self.vocab.contains_key(&candidate) {
                    best_len = len;
                    best_token = candidate;
                }
            }

            tokens.push(best_token);
            i += best_len;
        }

        tokens
    }
}

impl Tokenizer for UnigramTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        // Split by whitespace first, then tokenize each word
        let mut all_tokens = Vec::new();

        for word in text.split_whitespace() {
            let word_tokens = self.greedy_tokenize(word);
            all_tokens.extend(word_tokens);
        }

        all_tokens
    }
}
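
// Usage sketch (illustrative only): greedy longest-match segmentation, with a
// single-character fallback for anything not covered by the vocabulary.
//
//     let tok = UnigramTokenizer::from_tokens(&["hel", "lo", "wor", "ld"]);
//     assert_eq!(tok.tokenize("hello world"), vec!["hel", "lo", "wor", "ld"]);
//     assert_eq!(tok.tokenize("x"), vec!["x"]); // "x" is not in the vocab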

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_whitespace_tokenizer() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("Hello World");

        assert_eq!(tokens, vec!["Hello", "World"]);
    }

    #[test]
    fn test_whitespace_tokenizer_lowercase() {
        let tokenizer = WhitespaceTokenizer::lowercase();
        let tokens = tokenizer.tokenize("Hello World");

        assert_eq!(tokens, vec!["hello", "world"]);
    }

    #[test]
    fn test_char_tokenizer() {
        let tokenizer = CharTokenizer::new();
        let tokens = tokenizer.tokenize("Hi!");

        assert_eq!(tokens, vec!["H", "i", "!"]);
    }

    #[test]
    fn test_char_tokenizer_no_whitespace() {
        let tokenizer = CharTokenizer::no_whitespace();
        let tokens = tokenizer.tokenize("Hi there!");

        assert_eq!(tokens, vec!["H", "i", "t", "h", "e", "r", "e", "!"]);
    }

    #[test]
    fn test_word_punct_tokenizer() {
        let tokenizer = WordPunctTokenizer::new();
        let tokens = tokenizer.tokenize("Hello, World!");

        assert_eq!(tokens, vec!["Hello", ",", "World", "!"]);
    }

    #[test]
    fn test_word_punct_tokenizer_lowercase() {
        let tokenizer = WordPunctTokenizer::lowercase();
        let tokens = tokenizer.tokenize("Hello, World!");

        assert_eq!(tokens, vec!["hello", ",", "world", "!"]);
    }

    #[test]
    fn test_ngram_word_tokenizer() {
        let tokenizer = NGramTokenizer::word_ngrams(2);
        let tokens = tokenizer.tokenize("one two three");

        assert_eq!(tokens, vec!["one two", "two three"]);
    }

    #[test]
    fn test_ngram_char_tokenizer() {
        let tokenizer = NGramTokenizer::char_ngrams(3);
        let tokens = tokenizer.tokenize("hello");

        assert_eq!(tokens, vec!["hel", "ell", "llo"]);
    }

    #[test]
    fn test_bpe_tokenizer_basic() {
        let mut tokenizer = BasicBPETokenizer::new();
        tokenizer.train("low lower lowest", 10);

        // Should have learned some merges
        assert!(!tokenizer.get_vocab().is_empty());
    }

    #[test]
    fn test_bpe_tokenizer_apply() {
        let mut tokenizer = BasicBPETokenizer::new();
        tokenizer.train("low low low lower lowest", 5);

        let tokens = tokenizer.tokenize("low");
        assert!(!tokens.is_empty());
    }

    #[test]
    fn test_unigram_tokenizer() {
        let tokenizer = UnigramTokenizer::from_tokens(&[
            "hel", "lo", "wor", "ld", "h", "e", "l", "o", "w", "r", "d",
        ]);
        let tokens = tokenizer.tokenize("hello world");

        // Should produce some tokens
        assert!(!tokens.is_empty());
    }

    #[test]
    fn test_tokenizer_encode() {
        let tokenizer = WhitespaceTokenizer::new();
        let mut vocab = Vocab::new();
        vocab.add_token("hello");
        vocab.add_token("world");

        let indices = tokenizer.encode("hello world", &vocab);
        assert_eq!(indices, vec![0, 1]);
    }

    #[test]
    fn test_tokenizer_with_multiple_spaces() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("hello    world");

        assert_eq!(tokens, vec!["hello", "world"]);
    }

    #[test]
    fn test_empty_text() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("");

        assert!(tokens.is_empty());
    }
}