axonml_text/tokenizer.rs

//! Tokenizer - Text Tokenization
//!
//! Provides various tokenization strategies for text processing.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team
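//!
//! # Example
//!
//! A minimal usage sketch, marked `ignore` because the crate path
//! `axonml_text` is assumed from the repository layout and the `Vocab` API
//! is taken from the tests at the bottom of this file:
//!
//! ```ignore
//! use axonml_text::tokenizer::{Tokenizer, WhitespaceTokenizer};
//! use axonml_text::vocab::Vocab;
//!
//! let tokenizer = WhitespaceTokenizer::lowercase();
//! let mut vocab = Vocab::new();
//! vocab.add_token("hello");
//! vocab.add_token("world");
//!
//! // Tokenize to strings, or encode straight to vocabulary indices.
//! assert_eq!(tokenizer.tokenize("Hello World"), vec!["hello", "world"]);
//! assert_eq!(tokenizer.encode("Hello World", &vocab), vec![0, 1]);
//! ```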

use crate::vocab::Vocab;
use std::collections::HashMap;

// =============================================================================
// Tokenizer Trait
// =============================================================================

/// Trait for text tokenization.
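///
/// Implementors only need to provide `tokenize`; `encode` has a default
/// implementation built on top of it. A minimal sketch (marked `ignore`;
/// `CommaTokenizer` is hypothetical and not part of this crate):
///
/// ```ignore
/// use axonml_text::tokenizer::Tokenizer;
///
/// struct CommaTokenizer;
///
/// impl Tokenizer for CommaTokenizer {
///     fn tokenize(&self, text: &str) -> Vec<String> {
///         text.split(',').map(|s| s.trim().to_string()).collect()
///     }
/// }
///
/// assert_eq!(CommaTokenizer.tokenize("a, b"), vec!["a", "b"]);
/// ```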
pub trait Tokenizer: Send + Sync {
    /// Tokenizes a string into tokens.
    fn tokenize(&self, text: &str) -> Vec<String>;

    /// Tokenizes and encodes to indices using a vocabulary.
    fn encode(&self, text: &str, vocab: &Vocab) -> Vec<usize> {
        let tokens = self.tokenize(text);
        let token_refs: Vec<&str> = tokens.iter().map(std::string::String::as_str).collect();
        vocab.encode(&token_refs)
    }
}

// =============================================================================
// WhitespaceTokenizer
// =============================================================================

/// Simple whitespace-based tokenizer.
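///
/// Splits on Unicode whitespace; `lowercase()` additionally lowercases each
/// token. A short sketch (marked `ignore`; the crate path is assumed):
///
/// ```ignore
/// use axonml_text::tokenizer::{Tokenizer, WhitespaceTokenizer};
///
/// let tokenizer = WhitespaceTokenizer::lowercase();
/// assert_eq!(tokenizer.tokenize("Hello World"), vec!["hello", "world"]);
/// ```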
#[derive(Debug, Clone, Default)]
pub struct WhitespaceTokenizer {
    lowercase: bool,
}

impl WhitespaceTokenizer {
    /// Creates a new `WhitespaceTokenizer`.
    #[must_use]
    pub fn new() -> Self {
        Self { lowercase: false }
    }

    /// Creates a tokenizer that lowercases all tokens.
    #[must_use]
    pub fn lowercase() -> Self {
        Self { lowercase: true }
    }
}

impl Tokenizer for WhitespaceTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        text.split_whitespace()
            .map(|s| {
                if self.lowercase {
                    s.to_lowercase()
                } else {
                    s.to_string()
                }
            })
            .collect()
    }
}

// =============================================================================
// CharTokenizer
// =============================================================================

/// Character-level tokenizer.
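///
/// Emits one token per `char`; `no_whitespace()` drops whitespace characters.
/// A short sketch (marked `ignore`; the crate path is assumed):
///
/// ```ignore
/// use axonml_text::tokenizer::{CharTokenizer, Tokenizer};
///
/// let tokenizer = CharTokenizer::new();
/// assert_eq!(tokenizer.tokenize("Hi!"), vec!["H", "i", "!"]);
/// ```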
#[derive(Debug, Clone)]
pub struct CharTokenizer {
    include_whitespace: bool,
}

impl Default for CharTokenizer {
    /// Matches `CharTokenizer::new()`: whitespace characters are kept.
    /// (A derived `Default` would silently set `include_whitespace` to
    /// `false`, diverging from `new()`.)
    fn default() -> Self {
        Self::new()
    }
}

impl CharTokenizer {
    /// Creates a new `CharTokenizer`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            include_whitespace: true,
        }
    }

    /// Creates a tokenizer that excludes whitespace.
    #[must_use]
    pub fn no_whitespace() -> Self {
        Self {
            include_whitespace: false,
        }
    }
}

impl Tokenizer for CharTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        if self.include_whitespace {
            text.chars().map(|c| c.to_string()).collect()
        } else {
            text.chars()
                .filter(|c| !c.is_whitespace())
                .map(|c| c.to_string())
                .collect()
        }
    }
}

// =============================================================================
// WordPunctTokenizer
// =============================================================================

/// Tokenizer that separates words and punctuation.
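///
/// Alphanumeric runs become word tokens and every other non-whitespace
/// character becomes its own token. A short sketch (marked `ignore`; the
/// crate path is assumed):
///
/// ```ignore
/// use axonml_text::tokenizer::{Tokenizer, WordPunctTokenizer};
///
/// let tokenizer = WordPunctTokenizer::new();
/// assert_eq!(tokenizer.tokenize("Hello, World!"), vec!["Hello", ",", "World", "!"]);
/// ```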
#[derive(Debug, Clone, Default)]
pub struct WordPunctTokenizer {
    lowercase: bool,
}

impl WordPunctTokenizer {
    /// Creates a new `WordPunctTokenizer`.
    #[must_use]
    pub fn new() -> Self {
        Self { lowercase: false }
    }

    /// Creates a tokenizer that lowercases all tokens.
    #[must_use]
    pub fn lowercase() -> Self {
        Self { lowercase: true }
    }
}

impl Tokenizer for WordPunctTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut tokens = Vec::new();
        let mut current = String::new();

        for c in text.chars() {
            if c.is_alphanumeric() {
                current.push(c);
            } else {
                if !current.is_empty() {
                    tokens.push(if self.lowercase {
                        current.to_lowercase()
                    } else {
                        current.clone()
                    });
                    current.clear();
                }
                if !c.is_whitespace() {
                    tokens.push(c.to_string());
                }
            }
        }

        if !current.is_empty() {
            tokens.push(if self.lowercase {
                current.to_lowercase()
            } else {
                current
            });
        }

        tokens
    }
}

// =============================================================================
// NGramTokenizer
// =============================================================================

/// N-gram tokenizer producing word-level or character-level n-grams.
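///
/// A short sketch (marked `ignore`; the crate path is assumed):
///
/// ```ignore
/// use axonml_text::tokenizer::{NGramTokenizer, Tokenizer};
///
/// let word_bigrams = NGramTokenizer::word_ngrams(2);
/// assert_eq!(word_bigrams.tokenize("one two three"), vec!["one two", "two three"]);
///
/// let char_trigrams = NGramTokenizer::char_ngrams(3);
/// assert_eq!(char_trigrams.tokenize("hello"), vec!["hel", "ell", "llo"]);
/// ```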
#[derive(Debug, Clone)]
pub struct NGramTokenizer {
    n: usize,
    char_level: bool,
}

impl NGramTokenizer {
    /// Creates a word-level n-gram tokenizer.
    #[must_use]
    pub fn word_ngrams(n: usize) -> Self {
        Self {
            n: n.max(1),
            char_level: false,
        }
    }

    /// Creates a character-level n-gram tokenizer.
    #[must_use]
    pub fn char_ngrams(n: usize) -> Self {
        Self {
            n: n.max(1),
            char_level: true,
        }
    }
}

impl Tokenizer for NGramTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        if self.char_level {
            // Character n-grams
            let chars: Vec<char> = text.chars().collect();
            if chars.len() < self.n {
                return vec![text.to_string()];
            }

            chars
                .windows(self.n)
                .map(|w| w.iter().collect::<String>())
                .collect()
        } else {
            // Word n-grams
            let words: Vec<&str> = text.split_whitespace().collect();
            if words.len() < self.n {
                return vec![text.to_string()];
            }

            words.windows(self.n).map(|w| w.join(" ")).collect()
        }
    }
}

// =============================================================================
// BasicBPETokenizer
// =============================================================================

/// A basic Byte-Pair Encoding tokenizer.
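///
/// `train` learns merge rules from a whitespace-split corpus (each word is
/// given a `</w>` end-of-word marker); `tokenize` then applies those merges
/// to new text. A short sketch (marked `ignore`; the crate path is assumed):
///
/// ```ignore
/// use axonml_text::tokenizer::{BasicBPETokenizer, Tokenizer};
///
/// let mut tokenizer = BasicBPETokenizer::new();
/// tokenizer.train("low low low lower lowest", 5);
///
/// // Frequent character pairs such as "l" + "o" tend to be merged.
/// let tokens = tokenizer.tokenize("low");
/// assert!(!tokens.is_empty());
/// ```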
#[derive(Debug, Clone)]
pub struct BasicBPETokenizer {
    merges: HashMap<(String, String), String>,
    vocab: Vec<String>,
}

impl BasicBPETokenizer {
    /// Creates a new BPE tokenizer.
    #[must_use]
    pub fn new() -> Self {
        Self {
            merges: HashMap::new(),
            vocab: Vec::new(),
        }
    }

    /// Trains the BPE tokenizer on text.
    pub fn train(&mut self, text: &str, num_merges: usize) {
        // Word-frequency table keyed by the space-separated symbol sequence
        let mut vocab: HashMap<String, usize> = HashMap::new();

        // Split text into words, append an end-of-word marker, and start from
        // single-character symbols
        for word in text.split_whitespace() {
            let word_with_end = format!("{word}</w>");
            let chars: Vec<String> = word_with_end.chars().map(|c| c.to_string()).collect();
            *vocab.entry(chars.join(" ")).or_insert(0) += 1;
        }

        for _ in 0..num_merges {
            // Count adjacent symbol pairs
            let mut pairs: HashMap<(String, String), usize> = HashMap::new();
            for (word, count) in &vocab {
                let symbols: Vec<&str> = word.split(' ').collect();
                for i in 0..symbols.len().saturating_sub(1) {
                    let pair = (symbols[i].to_string(), symbols[i + 1].to_string());
                    *pairs.entry(pair).or_insert(0) += count;
                }
            }

            if pairs.is_empty() {
                break;
            }

            // Find most frequent pair
            let best_pair = pairs
                .into_iter()
                .max_by_key(|(_, count)| *count)
                .map(|(pair, _)| pair);

            if let Some((a, b)) = best_pair {
                let merged = format!("{a}{b}");
                self.merges.insert((a.clone(), b.clone()), merged.clone());

                // Update the vocabulary, merging the pair only where it occurs
                // as two adjacent whole symbols (a plain string replace could
                // also match across symbol boundaries)
                let mut new_vocab = HashMap::new();
                for (word, count) in vocab {
                    let symbols: Vec<&str> = word.split(' ').collect();
                    let mut new_symbols: Vec<String> = Vec::new();
                    let mut i = 0;
                    while i < symbols.len() {
                        if i + 1 < symbols.len() && symbols[i] == a && symbols[i + 1] == b {
                            new_symbols.push(merged.clone());
                            i += 2;
                        } else {
                            new_symbols.push(symbols[i].to_string());
                            i += 1;
                        }
                    }
                    *new_vocab.entry(new_symbols.join(" ")).or_insert(0) += count;
                }
                vocab = new_vocab;
            }
        }

        // Extract final vocabulary
        let mut all_symbols: std::collections::HashSet<String> = std::collections::HashSet::new();
        for word in vocab.keys() {
            for symbol in word.split(' ') {
                all_symbols.insert(symbol.to_string());
            }
        }
        self.vocab = all_symbols.into_iter().collect();
        self.vocab.sort();
    }

    /// Returns the vocabulary.
    #[must_use]
    pub fn get_vocab(&self) -> &[String] {
        &self.vocab
    }

    /// Applies BPE merges to a word.
    fn apply_bpe(&self, word: &str) -> Vec<String> {
        let word_with_end = format!("{word}</w>");
        let mut symbols: Vec<String> = word_with_end.chars().map(|c| c.to_string()).collect();

        loop {
            // Find the first adjacent pair that has a recorded merge
            let mut best_pair: Option<(usize, &str)> = None;

            for i in 0..symbols.len().saturating_sub(1) {
                let pair = (symbols[i].clone(), symbols[i + 1].clone());
                if let Some(merged) = self.merges.get(&pair) {
                    best_pair = Some((i, merged));
                    break;
                }
            }

            // Merge it in place, or stop when no pair can be merged
            match best_pair {
                Some((i, merged)) => {
                    symbols[i] = merged.to_string();
                    symbols.remove(i + 1);
                }
                None => break,
            }
        }

        symbols
    }
}

impl Default for BasicBPETokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl Tokenizer for BasicBPETokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut tokens = Vec::new();

        for word in text.split_whitespace() {
            let word_tokens = self.apply_bpe(word);
            tokens.extend(word_tokens);
        }

        tokens
    }
}

// =============================================================================
// UnigramTokenizer (Simplified SentencePiece-Style)
// =============================================================================

/// A simplified, SentencePiece-style unigram tokenizer.
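///
/// Each whitespace-separated word is segmented with a greedy longest-match
/// scan against the vocabulary (the per-token scores are stored but not used
/// by the matching step). A short sketch (marked `ignore`; the crate path is
/// assumed):
///
/// ```ignore
/// use axonml_text::tokenizer::{Tokenizer, UnigramTokenizer};
///
/// let tokenizer =
///     UnigramTokenizer::from_tokens(&["hel", "lo", "wor", "ld", "h", "e", "l", "o", "w", "r", "d"]);
/// assert_eq!(tokenizer.tokenize("hello world"), vec!["hel", "lo", "wor", "ld"]);
/// ```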
#[derive(Debug, Clone)]
pub struct UnigramTokenizer {
    vocab: HashMap<String, f32>,
    max_token_length: usize,
}

impl UnigramTokenizer {
    /// Creates a new unigram tokenizer from a vocabulary with scores.
    #[must_use]
    pub fn new(vocab: HashMap<String, f32>) -> Self {
        // Byte length is an upper bound on character count, which is all the
        // greedy scan below needs to cap its candidate lengths.
        let max_len = vocab
            .keys()
            .map(std::string::String::len)
            .max()
            .unwrap_or(1);
        Self {
            vocab,
            max_token_length: max_len,
        }
    }

    /// Creates a tokenizer from a list of tokens (equal scores).
    #[must_use]
    pub fn from_tokens(tokens: &[&str]) -> Self {
        let vocab: HashMap<String, f32> = tokens.iter().map(|&t| (t.to_string(), 1.0)).collect();
        Self::new(vocab)
    }

    /// Tokenizes using a greedy longest-match approach.
    fn greedy_tokenize(&self, text: &str) -> Vec<String> {
        let mut tokens = Vec::new();
        let chars: Vec<char> = text.chars().collect();
        let mut i = 0;

        while i < chars.len() {
            // Fall back to the single character if nothing in the vocabulary matches
            let mut best_len = 1;
            let mut best_token = chars[i].to_string();

            // Try to find the longest matching token
            for len in 1..=self.max_token_length.min(chars.len() - i) {
                let candidate: String = chars[i..i + len].iter().collect();
                if self.vocab.contains_key(&candidate) {
                    best_len = len;
                    best_token = candidate;
                }
            }

            tokens.push(best_token);
            i += best_len;
        }

        tokens
    }
}

impl Tokenizer for UnigramTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        // Split by whitespace first, then tokenize each word
        let mut all_tokens = Vec::new();

        for word in text.split_whitespace() {
            let word_tokens = self.greedy_tokenize(word);
            all_tokens.extend(word_tokens);
        }

        all_tokens
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_whitespace_tokenizer() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("Hello World");

        assert_eq!(tokens, vec!["Hello", "World"]);
    }

    #[test]
    fn test_whitespace_tokenizer_lowercase() {
        let tokenizer = WhitespaceTokenizer::lowercase();
        let tokens = tokenizer.tokenize("Hello World");

        assert_eq!(tokens, vec!["hello", "world"]);
    }

    #[test]
    fn test_char_tokenizer() {
        let tokenizer = CharTokenizer::new();
        let tokens = tokenizer.tokenize("Hi!");

        assert_eq!(tokens, vec!["H", "i", "!"]);
    }

    #[test]
    fn test_char_tokenizer_no_whitespace() {
        let tokenizer = CharTokenizer::no_whitespace();
        let tokens = tokenizer.tokenize("Hi there!");

        assert_eq!(tokens, vec!["H", "i", "t", "h", "e", "r", "e", "!"]);
    }

    #[test]
    fn test_word_punct_tokenizer() {
        let tokenizer = WordPunctTokenizer::new();
        let tokens = tokenizer.tokenize("Hello, World!");

        assert_eq!(tokens, vec!["Hello", ",", "World", "!"]);
    }

    #[test]
    fn test_word_punct_tokenizer_lowercase() {
        let tokenizer = WordPunctTokenizer::lowercase();
        let tokens = tokenizer.tokenize("Hello, World!");

        assert_eq!(tokens, vec!["hello", ",", "world", "!"]);
    }

    #[test]
    fn test_ngram_word_tokenizer() {
        let tokenizer = NGramTokenizer::word_ngrams(2);
        let tokens = tokenizer.tokenize("one two three");

        assert_eq!(tokens, vec!["one two", "two three"]);
    }

    #[test]
    fn test_ngram_char_tokenizer() {
        let tokenizer = NGramTokenizer::char_ngrams(3);
        let tokens = tokenizer.tokenize("hello");

        assert_eq!(tokens, vec!["hel", "ell", "llo"]);
    }

    #[test]
    fn test_bpe_tokenizer_basic() {
        let mut tokenizer = BasicBPETokenizer::new();
        tokenizer.train("low lower lowest", 10);

        // Should have learned some merges
        assert!(!tokenizer.get_vocab().is_empty());
    }

    #[test]
    fn test_bpe_tokenizer_apply() {
        let mut tokenizer = BasicBPETokenizer::new();
        tokenizer.train("low low low lower lowest", 5);

        let tokens = tokenizer.tokenize("low");
        assert!(!tokens.is_empty());
    }

    #[test]
    fn test_unigram_tokenizer() {
        let tokenizer = UnigramTokenizer::from_tokens(&[
            "hel", "lo", "wor", "ld", "h", "e", "l", "o", "w", "r", "d",
        ]);
        let tokens = tokenizer.tokenize("hello world");

        // Should produce some tokens
        assert!(!tokens.is_empty());
    }

    #[test]
    fn test_tokenizer_encode() {
        let tokenizer = WhitespaceTokenizer::new();
        let mut vocab = Vocab::new();
        vocab.add_token("hello");
        vocab.add_token("world");

        let indices = tokenizer.encode("hello world", &vocab);
        assert_eq!(indices, vec![0, 1]);
    }

    #[test]
    fn test_tokenizer_with_multiple_spaces() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("hello    world");

        assert_eq!(tokens, vec!["hello", "world"]);
    }

    #[test]
    fn test_empty_text() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("");

        assert!(tokens.is_empty());
    }
}