//! Tokenizer - Text Tokenization
//!
//! # File
//! `crates/axonml-text/src/tokenizer.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use crate::vocab::Vocab;
use std::collections::HashMap;

// =============================================================================
// Tokenizer Trait
// =============================================================================

/// Trait for text tokenization.
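///
/// # Examples
///
/// A minimal sketch of `tokenize` plus the default `encode`; the crate-root
/// import paths are assumed rather than confirmed by this file, so the block
/// is marked `ignore`:
///
/// ```ignore
/// use axonml_text::tokenizer::{Tokenizer, WhitespaceTokenizer};
/// use axonml_text::vocab::Vocab;
///
/// let tokenizer = WhitespaceTokenizer::new();
/// let mut vocab = Vocab::new();
/// vocab.add_token("hello");
/// vocab.add_token("world");
///
/// assert_eq!(tokenizer.tokenize("hello world"), vec!["hello", "world"]);
/// assert_eq!(tokenizer.encode("hello world", &vocab), vec![0, 1]);
/// ```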
pub trait Tokenizer: Send + Sync {
    /// Tokenizes a string into tokens.
    fn tokenize(&self, text: &str) -> Vec<String>;

    /// Tokenizes and encodes to indices using a vocabulary.
    fn encode(&self, text: &str, vocab: &Vocab) -> Vec<usize> {
        let tokens = self.tokenize(text);
        let token_refs: Vec<&str> = tokens.iter().map(std::string::String::as_str).collect();
        vocab.encode(&token_refs)
    }
}

// =============================================================================
// WhitespaceTokenizer
// =============================================================================

/// Simple whitespace-based tokenizer.
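///
/// # Examples
///
/// A minimal usage sketch; the import path is assumed, so the block is marked
/// `ignore`:
///
/// ```ignore
/// use axonml_text::tokenizer::{Tokenizer, WhitespaceTokenizer};
///
/// let tokenizer = WhitespaceTokenizer::lowercase();
/// assert_eq!(tokenizer.tokenize("Hello World"), vec!["hello", "world"]);
/// ```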
#[derive(Debug, Clone, Default)]
pub struct WhitespaceTokenizer {
    lowercase: bool,
}

impl WhitespaceTokenizer {
    /// Creates a new `WhitespaceTokenizer`.
    #[must_use]
    pub fn new() -> Self {
        Self { lowercase: false }
    }

    /// Creates a tokenizer that lowercases all tokens.
    #[must_use]
    pub fn lowercase() -> Self {
        Self { lowercase: true }
    }
}

impl Tokenizer for WhitespaceTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        text.split_whitespace()
            .map(|s| {
                if self.lowercase {
                    s.to_lowercase()
                } else {
                    s.to_string()
                }
            })
            .collect()
    }
}

// =============================================================================
// CharTokenizer
// =============================================================================

/// Character-level tokenizer.
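///
/// # Examples
///
/// A minimal usage sketch; the import path is assumed, so the block is marked
/// `ignore`:
///
/// ```ignore
/// use axonml_text::tokenizer::{CharTokenizer, Tokenizer};
///
/// let tokenizer = CharTokenizer::new();
/// assert_eq!(tokenizer.tokenize("Hi!"), vec!["H", "i", "!"]);
/// ```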
#[derive(Debug, Clone, Default)]
pub struct CharTokenizer {
    include_whitespace: bool,
}

impl CharTokenizer {
    /// Creates a new `CharTokenizer`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            include_whitespace: true,
        }
    }

    /// Creates a tokenizer that excludes whitespace.
    #[must_use]
    pub fn no_whitespace() -> Self {
        Self {
            include_whitespace: false,
        }
    }
}

impl Tokenizer for CharTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        if self.include_whitespace {
            text.chars().map(|c| c.to_string()).collect()
        } else {
            text.chars()
                .filter(|c| !c.is_whitespace())
                .map(|c| c.to_string())
                .collect()
        }
    }
}

// =============================================================================
// WordPunctTokenizer
// =============================================================================

/// Tokenizer that separates words and punctuation.
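///
/// # Examples
///
/// A minimal usage sketch; the import path is assumed, so the block is marked
/// `ignore`:
///
/// ```ignore
/// use axonml_text::tokenizer::{Tokenizer, WordPunctTokenizer};
///
/// let tokenizer = WordPunctTokenizer::new();
/// assert_eq!(tokenizer.tokenize("Hello, World!"), vec!["Hello", ",", "World", "!"]);
/// ```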
#[derive(Debug, Clone, Default)]
pub struct WordPunctTokenizer {
    lowercase: bool,
}

impl WordPunctTokenizer {
    /// Creates a new `WordPunctTokenizer`.
    #[must_use]
    pub fn new() -> Self {
        Self { lowercase: false }
    }

    /// Creates a tokenizer that lowercases all tokens.
    #[must_use]
    pub fn lowercase() -> Self {
        Self { lowercase: true }
    }
}

impl Tokenizer for WordPunctTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut tokens = Vec::new();
        let mut current = String::new();

        for c in text.chars() {
            if c.is_alphanumeric() {
                current.push(c);
            } else {
                if !current.is_empty() {
                    tokens.push(if self.lowercase {
                        current.to_lowercase()
                    } else {
                        current.clone()
                    });
                    current.clear();
                }
                if !c.is_whitespace() {
                    tokens.push(c.to_string());
                }
            }
        }

        if !current.is_empty() {
            tokens.push(if self.lowercase {
                current.to_lowercase()
            } else {
                current
            });
        }

        tokens
    }
}

// =============================================================================
// NGramTokenizer
// =============================================================================

/// N-gram tokenizer for subword or character n-grams.
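///
/// # Examples
///
/// A minimal usage sketch for both word- and character-level n-grams; the
/// import path is assumed, so the block is marked `ignore`:
///
/// ```ignore
/// use axonml_text::tokenizer::{NGramTokenizer, Tokenizer};
///
/// let bigrams = NGramTokenizer::word_ngrams(2);
/// assert_eq!(bigrams.tokenize("one two three"), vec!["one two", "two three"]);
///
/// let trigrams = NGramTokenizer::char_ngrams(3);
/// assert_eq!(trigrams.tokenize("hello"), vec!["hel", "ell", "llo"]);
/// ```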
#[derive(Debug, Clone)]
pub struct NGramTokenizer {
    n: usize,
    char_level: bool,
}

impl NGramTokenizer {
    /// Creates a word-level n-gram tokenizer.
    #[must_use]
    pub fn word_ngrams(n: usize) -> Self {
        Self {
            n: n.max(1),
            char_level: false,
        }
    }

    /// Creates a character-level n-gram tokenizer.
    #[must_use]
    pub fn char_ngrams(n: usize) -> Self {
        Self {
            n: n.max(1),
            char_level: true,
        }
    }
}

impl Tokenizer for NGramTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        if self.char_level {
            // Character n-grams
            let chars: Vec<char> = text.chars().collect();
            if chars.len() < self.n {
                return vec![text.to_string()];
            }

            chars
                .windows(self.n)
                .map(|w| w.iter().collect::<String>())
                .collect()
        } else {
            // Word n-grams
            let words: Vec<&str> = text.split_whitespace().collect();
            if words.len() < self.n {
                return vec![text.to_string()];
            }

            words.windows(self.n).map(|w| w.join(" ")).collect()
        }
    }
}

// =============================================================================
// BasicBPETokenizer
// =============================================================================

/// A basic Byte-Pair Encoding tokenizer.
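///
/// # Examples
///
/// A minimal training-and-tokenization sketch; the merges learned depend on
/// the corpus, so only loose assertions are shown. The import path is assumed,
/// so the block is marked `ignore`:
///
/// ```ignore
/// use axonml_text::tokenizer::{BasicBPETokenizer, Tokenizer};
///
/// let mut bpe = BasicBPETokenizer::new();
/// bpe.train("low low lower lowest", 10);
///
/// assert!(!bpe.get_vocab().is_empty());
/// assert!(!bpe.tokenize("lowest").is_empty());
/// ```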
#[derive(Debug, Clone)]
pub struct BasicBPETokenizer {
    merges: HashMap<(String, String), String>,
    vocab: Vec<String>,
}

impl BasicBPETokenizer {
    /// Creates a new BPE tokenizer.
    #[must_use]
    pub fn new() -> Self {
        Self {
            merges: HashMap::new(),
            vocab: Vec::new(),
        }
    }

    /// Trains the BPE tokenizer on text.
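    ///
    /// Training repeatedly counts adjacent symbol pairs across all words and
    /// merges the most frequent pair, performing at most `num_merges` merges.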
    pub fn train(&mut self, text: &str, num_merges: usize) {
        // Initialize vocabulary with characters
        let mut vocab: HashMap<String, usize> = HashMap::new();

        // Split text into words and append end-of-word markers
        for word in text.split_whitespace() {
            let word_with_end = format!("{word}</w>");
            let chars: Vec<String> = word_with_end.chars().map(|c| c.to_string()).collect();
            *vocab.entry(chars.join(" ")).or_insert(0) += 1;
        }

        for _ in 0..num_merges {
            // Count pairs
            let mut pairs: HashMap<(String, String), usize> = HashMap::new();
            for (word, count) in &vocab {
                let symbols: Vec<&str> = word.split(' ').collect();
                for i in 0..symbols.len().saturating_sub(1) {
                    let pair = (symbols[i].to_string(), symbols[i + 1].to_string());
                    *pairs.entry(pair).or_insert(0) += count;
                }
            }

            if pairs.is_empty() {
                break;
            }

            // Find most frequent pair
            let best_pair = pairs
                .into_iter()
                .max_by_key(|(_, count)| *count)
                .map(|(pair, _)| pair);

            if let Some((a, b)) = best_pair {
                let merged = format!("{a}{b}");
                self.merges.insert((a.clone(), b.clone()), merged.clone());

                // Update vocabulary: merge the pair symbol-by-symbol so that a
                // symbol that merely ends with `a` is never merged by accident.
                let mut new_vocab = HashMap::new();
                for (word, count) in vocab {
                    let symbols: Vec<&str> = word.split(' ').collect();
                    let mut new_symbols: Vec<String> = Vec::new();
                    let mut i = 0;
                    while i < symbols.len() {
                        if i + 1 < symbols.len() && symbols[i] == a && symbols[i + 1] == b {
                            new_symbols.push(merged.clone());
                            i += 2;
                        } else {
                            new_symbols.push(symbols[i].to_string());
                            i += 1;
                        }
                    }
                    *new_vocab.entry(new_symbols.join(" ")).or_insert(0) += count;
                }
                vocab = new_vocab;
            }
        }

        // Extract final vocabulary
        let mut all_symbols: std::collections::HashSet<String> = std::collections::HashSet::new();
        for word in vocab.keys() {
            for symbol in word.split(' ') {
                all_symbols.insert(symbol.to_string());
            }
        }
        self.vocab = all_symbols.into_iter().collect();
        self.vocab.sort();
    }

    /// Returns the vocabulary.
    #[must_use]
    pub fn get_vocab(&self) -> &[String] {
        &self.vocab
    }

    /// Applies BPE merges to a word.
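    ///
    /// Merges are applied greedily from left to right until no adjacent pair
    /// of symbols matches a learned merge.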
    fn apply_bpe(&self, word: &str) -> Vec<String> {
        let word_with_end = format!("{word}</w>");
        let mut symbols: Vec<String> = word_with_end.chars().map(|c| c.to_string()).collect();

        loop {
            let mut best_pair: Option<(usize, &str)> = None;

            for i in 0..symbols.len().saturating_sub(1) {
                let pair = (symbols[i].clone(), symbols[i + 1].clone());
                if let Some(merged) = self.merges.get(&pair) {
                    best_pair = Some((i, merged));
                    break;
                }
            }

            match best_pair {
                Some((i, merged)) => {
                    symbols[i] = merged.to_string();
                    symbols.remove(i + 1);
                }
                None => break,
            }
        }

        symbols
    }
}

impl Default for BasicBPETokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl Tokenizer for BasicBPETokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut tokens = Vec::new();

        for word in text.split_whitespace() {
            let word_tokens = self.apply_bpe(word);
            tokens.extend(word_tokens);
        }

        tokens
    }
}

// =============================================================================
// UnigramTokenizer (Simplified SentencePiece-Style)
// =============================================================================

/// A simplified unigram-style tokenizer.
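///
/// # Examples
///
/// A minimal greedy longest-match sketch; the import path is assumed, so the
/// block is marked `ignore`:
///
/// ```ignore
/// use axonml_text::tokenizer::{Tokenizer, UnigramTokenizer};
///
/// let tokenizer = UnigramTokenizer::from_tokens(&["hel", "lo", "wor", "ld"]);
/// assert_eq!(tokenizer.tokenize("hello world"), vec!["hel", "lo", "wor", "ld"]);
/// ```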
#[derive(Debug, Clone)]
pub struct UnigramTokenizer {
    vocab: HashMap<String, f32>,
    max_token_length: usize,
}

impl UnigramTokenizer {
    /// Creates a new unigram tokenizer from a vocabulary with scores.
    #[must_use]
    pub fn new(vocab: HashMap<String, f32>) -> Self {
        let max_len = vocab
            .keys()
            .map(std::string::String::len)
            .max()
            .unwrap_or(1);
        Self {
            vocab,
            max_token_length: max_len,
        }
    }

    /// Creates a tokenizer from a list of tokens (equal scores).
    #[must_use]
    pub fn from_tokens(tokens: &[&str]) -> Self {
        let vocab: HashMap<String, f32> = tokens.iter().map(|&t| (t.to_string(), 1.0)).collect();
        Self::new(vocab)
    }

    /// Tokenizes using a greedy longest-match approach.
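    ///
    /// Falls back to emitting the current character as its own token when no
    /// vocabulary entry matches at that position.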
    fn greedy_tokenize(&self, text: &str) -> Vec<String> {
        let mut tokens = Vec::new();
        let chars: Vec<char> = text.chars().collect();
        let mut i = 0;

        while i < chars.len() {
            let mut best_len = 1;
            let mut best_token = chars[i].to_string();

            // Try to find the longest matching token
            for len in 1..=self.max_token_length.min(chars.len() - i) {
                let candidate: String = chars[i..i + len].iter().collect();
                if self.vocab.contains_key(&candidate) {
                    best_len = len;
                    best_token = candidate;
                }
            }

            tokens.push(best_token);
            i += best_len;
        }

        tokens
    }
}

impl Tokenizer for UnigramTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        // Split by whitespace first, then tokenize each word
        let mut all_tokens = Vec::new();

        for word in text.split_whitespace() {
            let word_tokens = self.greedy_tokenize(word);
            all_tokens.extend(word_tokens);
        }

        all_tokens
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_whitespace_tokenizer() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("Hello World");

        assert_eq!(tokens, vec!["Hello", "World"]);
    }

    #[test]
    fn test_whitespace_tokenizer_lowercase() {
        let tokenizer = WhitespaceTokenizer::lowercase();
        let tokens = tokenizer.tokenize("Hello World");

        assert_eq!(tokens, vec!["hello", "world"]);
    }

    #[test]
    fn test_char_tokenizer() {
        let tokenizer = CharTokenizer::new();
        let tokens = tokenizer.tokenize("Hi!");

        assert_eq!(tokens, vec!["H", "i", "!"]);
    }

    #[test]
    fn test_char_tokenizer_no_whitespace() {
        let tokenizer = CharTokenizer::no_whitespace();
        let tokens = tokenizer.tokenize("Hi there!");

        assert_eq!(tokens, vec!["H", "i", "t", "h", "e", "r", "e", "!"]);
    }

    #[test]
    fn test_word_punct_tokenizer() {
        let tokenizer = WordPunctTokenizer::new();
        let tokens = tokenizer.tokenize("Hello, World!");

        assert_eq!(tokens, vec!["Hello", ",", "World", "!"]);
    }

    #[test]
    fn test_word_punct_tokenizer_lowercase() {
        let tokenizer = WordPunctTokenizer::lowercase();
        let tokens = tokenizer.tokenize("Hello, World!");

        assert_eq!(tokens, vec!["hello", ",", "world", "!"]);
    }

    #[test]
    fn test_ngram_word_tokenizer() {
        let tokenizer = NGramTokenizer::word_ngrams(2);
        let tokens = tokenizer.tokenize("one two three");

        assert_eq!(tokens, vec!["one two", "two three"]);
    }

    #[test]
    fn test_ngram_char_tokenizer() {
        let tokenizer = NGramTokenizer::char_ngrams(3);
        let tokens = tokenizer.tokenize("hello");

        assert_eq!(tokens, vec!["hel", "ell", "llo"]);
    }

    #[test]
    fn test_bpe_tokenizer_basic() {
        let mut tokenizer = BasicBPETokenizer::new();
        tokenizer.train("low lower lowest", 10);

        // Should have learned some merges
        assert!(!tokenizer.get_vocab().is_empty());
    }

    #[test]
    fn test_bpe_tokenizer_apply() {
        let mut tokenizer = BasicBPETokenizer::new();
        tokenizer.train("low low low lower lowest", 5);

        let tokens = tokenizer.tokenize("low");
        assert!(!tokens.is_empty());
    }

    #[test]
    fn test_unigram_tokenizer() {
        let tokenizer = UnigramTokenizer::from_tokens(&[
            "hel", "lo", "wor", "ld", "h", "e", "l", "o", "w", "r", "d",
        ]);
        let tokens = tokenizer.tokenize("hello world");

        // Should produce some tokens
        assert!(!tokens.is_empty());
    }

    #[test]
    fn test_tokenizer_encode() {
        let tokenizer = WhitespaceTokenizer::new();
        let mut vocab = Vocab::new();
        vocab.add_token("hello");
        vocab.add_token("world");

        let indices = tokenizer.encode("hello world", &vocab);
        assert_eq!(indices, vec![0, 1]);
    }

    #[test]
    fn test_tokenizer_with_multiple_spaces() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("hello    world");

        assert_eq!(tokens, vec!["hello", "world"]);
    }

    #[test]
    fn test_empty_text() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("");

        assert!(tokens.is_empty());
    }
}