//! Tokenizer - Text Tokenization
//!
//! # File
//! `crates/axonml-text/src/tokenizer.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use crate::vocab::Vocab;
use std::collections::HashMap;

// =============================================================================
// Tokenizer Trait
// =============================================================================

/// Trait for text tokenization.
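///
/// # Examples
///
/// A minimal sketch of the `tokenize`/`encode` flow, mirroring the unit tests
/// below. The import paths are assumptions (hence `ignore`); adjust them to
/// however the crate re-exports these types:
///
/// ```ignore
/// use axonml_text::tokenizer::{Tokenizer, WhitespaceTokenizer};
/// use axonml_text::vocab::Vocab;
///
/// let tokenizer = WhitespaceTokenizer::new();
/// let mut vocab = Vocab::new();
/// vocab.add_token("hello");
/// vocab.add_token("world");
///
/// // `encode` tokenizes, then maps each token to its vocabulary index.
/// assert_eq!(tokenizer.encode("hello world", &vocab), vec![0, 1]);
/// ```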
pub trait Tokenizer: Send + Sync {
    /// Tokenizes a string into tokens.
    fn tokenize(&self, text: &str) -> Vec<String>;

    /// Tokenizes and encodes to indices using a vocabulary.
    fn encode(&self, text: &str, vocab: &Vocab) -> Vec<usize> {
        let tokens = self.tokenize(text);
        let token_refs: Vec<&str> = tokens.iter().map(std::string::String::as_str).collect();
        vocab.encode(&token_refs)
    }
}

// =============================================================================
// WhitespaceTokenizer
// =============================================================================

/// Simple whitespace-based tokenizer.
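///
/// # Examples
///
/// A small sketch, mirroring the unit tests (import path assumed, hence
/// `ignore`):
///
/// ```ignore
/// use axonml_text::tokenizer::{Tokenizer, WhitespaceTokenizer};
///
/// assert_eq!(
///     WhitespaceTokenizer::new().tokenize("Hello World"),
///     vec!["Hello", "World"]
/// );
/// // The `lowercase()` constructor folds case while splitting.
/// assert_eq!(
///     WhitespaceTokenizer::lowercase().tokenize("Hello World"),
///     vec!["hello", "world"]
/// );
/// ```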
#[derive(Debug, Clone, Default)]
pub struct WhitespaceTokenizer {
    lowercase: bool,
}

impl WhitespaceTokenizer {
    /// Creates a new `WhitespaceTokenizer`.
    #[must_use]
    pub fn new() -> Self {
        Self { lowercase: false }
    }

    /// Creates a tokenizer that lowercases all tokens.
    #[must_use]
    pub fn lowercase() -> Self {
        Self { lowercase: true }
    }
}

impl Tokenizer for WhitespaceTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        text.split_whitespace()
            .map(|s| {
                if self.lowercase {
                    s.to_lowercase()
                } else {
                    s.to_string()
                }
            })
            .collect()
    }
}

// =============================================================================
// CharTokenizer
// =============================================================================

/// Character-level tokenizer.
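///
/// # Examples
///
/// A small sketch, mirroring the unit tests (import path assumed, hence
/// `ignore`):
///
/// ```ignore
/// use axonml_text::tokenizer::{CharTokenizer, Tokenizer};
///
/// assert_eq!(CharTokenizer::new().tokenize("Hi!"), vec!["H", "i", "!"]);
/// // `no_whitespace()` drops whitespace characters entirely.
/// assert_eq!(
///     CharTokenizer::no_whitespace().tokenize("Hi there!"),
///     vec!["H", "i", "t", "h", "e", "r", "e", "!"]
/// );
/// ```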
#[derive(Debug, Clone, Default)]
pub struct CharTokenizer {
    include_whitespace: bool,
}

impl CharTokenizer {
    /// Creates a new `CharTokenizer`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            include_whitespace: true,
        }
    }

    /// Creates a tokenizer that excludes whitespace.
    #[must_use]
    pub fn no_whitespace() -> Self {
        Self {
            include_whitespace: false,
        }
    }
}

impl Tokenizer for CharTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        if self.include_whitespace {
            text.chars().map(|c| c.to_string()).collect()
        } else {
            text.chars()
                .filter(|c| !c.is_whitespace())
                .map(|c| c.to_string())
                .collect()
        }
    }
}

// =============================================================================
// WordPunctTokenizer
// =============================================================================

/// Tokenizer that separates words and punctuation.
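///
/// # Examples
///
/// A small sketch, mirroring the unit tests (import path assumed, hence
/// `ignore`):
///
/// ```ignore
/// use axonml_text::tokenizer::{Tokenizer, WordPunctTokenizer};
///
/// // Punctuation becomes its own token; whitespace is dropped.
/// assert_eq!(
///     WordPunctTokenizer::new().tokenize("Hello, World!"),
///     vec!["Hello", ",", "World", "!"]
/// );
/// ```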
#[derive(Debug, Clone, Default)]
pub struct WordPunctTokenizer {
    lowercase: bool,
}

impl WordPunctTokenizer {
    /// Creates a new `WordPunctTokenizer`.
    #[must_use]
    pub fn new() -> Self {
        Self { lowercase: false }
    }

    /// Creates a tokenizer that lowercases all tokens.
    #[must_use]
    pub fn lowercase() -> Self {
        Self { lowercase: true }
    }
}

impl Tokenizer for WordPunctTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut tokens = Vec::new();
        let mut current = String::new();

        for c in text.chars() {
            if c.is_alphanumeric() {
                current.push(c);
            } else {
                if !current.is_empty() {
                    tokens.push(if self.lowercase {
                        current.to_lowercase()
                    } else {
                        current.clone()
                    });
                    current.clear();
                }
                if !c.is_whitespace() {
                    tokens.push(c.to_string());
                }
            }
        }

        if !current.is_empty() {
            tokens.push(if self.lowercase {
                current.to_lowercase()
            } else {
                current
            });
        }

        tokens
    }
}

// =============================================================================
// NGramTokenizer
// =============================================================================

/// N-gram tokenizer for subword or character n-grams.
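///
/// # Examples
///
/// A small sketch, mirroring the unit tests (import path assumed, hence
/// `ignore`):
///
/// ```ignore
/// use axonml_text::tokenizer::{NGramTokenizer, Tokenizer};
///
/// // Word bigrams slide a 2-word window over the text.
/// assert_eq!(
///     NGramTokenizer::word_ngrams(2).tokenize("one two three"),
///     vec!["one two", "two three"]
/// );
/// // Character trigrams slide a 3-char window.
/// assert_eq!(
///     NGramTokenizer::char_ngrams(3).tokenize("hello"),
///     vec!["hel", "ell", "llo"]
/// );
/// ```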
#[derive(Debug, Clone)]
pub struct NGramTokenizer {
    n: usize,
    char_level: bool,
}

impl NGramTokenizer {
    /// Creates a word-level n-gram tokenizer.
    #[must_use]
    pub fn word_ngrams(n: usize) -> Self {
        Self {
            n: n.max(1),
            char_level: false,
        }
    }

    /// Creates a character-level n-gram tokenizer.
    #[must_use]
    pub fn char_ngrams(n: usize) -> Self {
        Self {
            n: n.max(1),
            char_level: true,
        }
    }
}

impl Tokenizer for NGramTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        if self.char_level {
            // Character n-grams
            let chars: Vec<char> = text.chars().collect();
            if chars.len() < self.n {
                return vec![text.to_string()];
            }

            chars
                .windows(self.n)
                .map(|w| w.iter().collect::<String>())
                .collect()
        } else {
            // Word n-grams
            let words: Vec<&str> = text.split_whitespace().collect();
            if words.len() < self.n {
                return vec![text.to_string()];
            }

            words.windows(self.n).map(|w| w.join(" ")).collect()
        }
    }
}

// =============================================================================
// BasicBPETokenizer
// =============================================================================

/// A basic Byte-Pair Encoding tokenizer.
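///
/// # Examples
///
/// A small sketch of the train-then-tokenize flow, mirroring the unit tests
/// (import path assumed, hence `ignore`):
///
/// ```ignore
/// use axonml_text::tokenizer::{BasicBPETokenizer, Tokenizer};
///
/// let mut tokenizer = BasicBPETokenizer::new();
/// // Learn up to 10 merges from a tiny corpus.
/// tokenizer.train("low lower lowest", 10);
/// assert!(!tokenizer.get_vocab().is_empty());
///
/// // `tokenize` applies the learned merges to each whitespace-split word.
/// let tokens = tokenizer.tokenize("low");
/// assert!(!tokens.is_empty());
/// ```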
#[derive(Debug, Clone)]
pub struct BasicBPETokenizer {
    merges: HashMap<(String, String), String>,
    /// Merge priority: lower value = higher priority (learned first = most frequent).
    merge_priority: HashMap<(String, String), usize>,
    vocab: Vec<String>,
}

impl BasicBPETokenizer {
    /// Creates a new BPE tokenizer.
    #[must_use]
    pub fn new() -> Self {
        Self {
            merges: HashMap::new(),
            merge_priority: HashMap::new(),
            vocab: Vec::new(),
        }
    }

    /// Trains the BPE tokenizer on text, learning at most `num_merges` merges.
    pub fn train(&mut self, text: &str, num_merges: usize) {
        // Initialize vocabulary with characters
        let mut vocab: HashMap<String, usize> = HashMap::new();

        // Split text into words, append the `</w>` end-of-word marker, and
        // store each word as space-separated symbols with its count
        for word in text.split_whitespace() {
            let word_with_end = format!("{word}</w>");
            let chars: Vec<String> = word_with_end.chars().map(|c| c.to_string()).collect();
            *vocab.entry(chars.join(" ")).or_insert(0) += 1;
        }

        for merge_idx in 0..num_merges {
            // Count pairs
            let mut pairs: HashMap<(String, String), usize> = HashMap::new();
            for (word, count) in &vocab {
                let symbols: Vec<&str> = word.split(' ').collect();
                for i in 0..symbols.len().saturating_sub(1) {
                    let pair = (symbols[i].to_string(), symbols[i + 1].to_string());
                    *pairs.entry(pair).or_insert(0) += count;
                }
            }

            if pairs.is_empty() {
                break;
            }

            // Find most frequent pair
            let best_pair = pairs
                .into_iter()
                .max_by_key(|(_, count)| *count)
                .map(|(pair, _)| pair);

            if let Some((a, b)) = best_pair {
                let merged = format!("{a}{b}");
                let pair_key = (a.clone(), b.clone());
                self.merges.insert(pair_key.clone(), merged.clone());
                // Record priority: lower index = higher priority (most frequent)
                self.merge_priority.insert(pair_key, merge_idx);

                // Update vocabulary, merging symbol-by-symbol. A plain string
                // replace of "{a} {b}" could match across symbol boundaries
                // (e.g. pattern "b c" matches inside "ab c"), so walk the
                // symbol list instead.
                let mut new_vocab = HashMap::new();
                for (word, count) in vocab {
                    let symbols: Vec<&str> = word.split(' ').collect();
                    let mut out: Vec<String> = Vec::with_capacity(symbols.len());
                    let mut i = 0;
                    while i < symbols.len() {
                        if i + 1 < symbols.len() && symbols[i] == a && symbols[i + 1] == b {
                            out.push(merged.clone());
                            i += 2;
                        } else {
                            out.push(symbols[i].to_string());
                            i += 1;
                        }
                    }
                    *new_vocab.entry(out.join(" ")).or_insert(0) += count;
                }
                vocab = new_vocab;
            }
        }

        // Extract final vocabulary
        let mut all_symbols: std::collections::HashSet<String> = std::collections::HashSet::new();
        for word in vocab.keys() {
            for symbol in word.split(' ') {
                all_symbols.insert(symbol.to_string());
            }
        }
        self.vocab = all_symbols.into_iter().collect();
        self.vocab.sort();
    }

    /// Returns the vocabulary.
    #[must_use]
    pub fn get_vocab(&self) -> &[String] {
        &self.vocab
    }

    /// Applies BPE merges to a word in priority order.
    ///
    /// Scans all adjacent pairs, selects the one with highest priority
    /// (lowest merge index = most frequent), and applies it. Repeats until
    /// no more merges are applicable.
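    ///
    /// As a hypothetical trace: with learned merges `("l","o") -> "lo"` and
    /// `("lo","w") -> "low"`, the word "low" starts as the symbols
    /// `["l","o","w","<","/","w",">"]`, becomes `["lo","w","<","/","w",">"]`,
    /// then `["low","<","/","w",">"]`.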
    fn apply_bpe(&self, word: &str) -> Vec<String> {
        let word_with_end = format!("{word}</w>");
        let mut symbols: Vec<String> = word_with_end.chars().map(|c| c.to_string()).collect();

        loop {
            // Find the pair with highest priority (lowest priority index)
            let mut best: Option<(usize, usize, &str)> = None; // (position, priority, merged)

            for i in 0..symbols.len().saturating_sub(1) {
                let pair = (symbols[i].clone(), symbols[i + 1].clone());
                if let Some(merged) = self.merges.get(&pair) {
                    let priority = self
                        .merge_priority
                        .get(&pair)
                        .copied()
                        .unwrap_or(usize::MAX);
                    if best.map_or(true, |(_, p, _)| priority < p) {
                        best = Some((i, priority, merged));
                    }
                }
            }

            match best {
                Some((i, _, merged)) => {
                    symbols[i] = merged.to_string();
                    symbols.remove(i + 1);
                }
                None => break,
            }
        }

        symbols
    }
}

impl Default for BasicBPETokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl Tokenizer for BasicBPETokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut tokens = Vec::new();

        for word in text.split_whitespace() {
            let word_tokens = self.apply_bpe(word);
            tokens.extend(word_tokens);
        }

        tokens
    }
}

// =============================================================================
// UnigramTokenizer (Simplified SentencePiece-Style)
// =============================================================================

/// A simplified unigram-style tokenizer, in the spirit of SentencePiece.
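///
/// # Examples
///
/// A small sketch, mirroring the unit tests (import path assumed, hence
/// `ignore`):
///
/// ```ignore
/// use axonml_text::tokenizer::{Tokenizer, UnigramTokenizer};
///
/// // Equal-score vocabulary; Viterbi picks a valid segmentation.
/// let tokenizer = UnigramTokenizer::from_tokens(&[
///     "hel", "lo", "wor", "ld", "h", "e", "l", "o", "w", "r", "d",
/// ]);
/// let tokens = tokenizer.tokenize("hello world");
/// assert!(!tokens.is_empty());
/// ```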
#[derive(Debug, Clone)]
pub struct UnigramTokenizer {
    vocab: HashMap<String, f32>,
    max_token_length: usize,
}

impl UnigramTokenizer {
    /// Creates a new unigram tokenizer from a vocabulary with scores.
    #[must_use]
    pub fn new(vocab: HashMap<String, f32>) -> Self {
        let max_len = vocab
            .keys()
            .map(std::string::String::len)
            .max()
            .unwrap_or(1);
        Self {
            vocab,
            max_token_length: max_len,
        }
    }

    /// Creates a tokenizer from a list of tokens (equal scores).
    #[must_use]
    pub fn from_tokens(tokens: &[&str]) -> Self {
        let vocab: HashMap<String, f32> = tokens.iter().map(|&t| (t.to_string(), 1.0)).collect();
        Self::new(vocab)
    }

    /// Tokenizes using the Viterbi algorithm for optimal segmentation.
    ///
    /// Finds the segmentation that maximizes the sum of log-probabilities
    /// (the natural log of each token's unigram score). Falls back to single
    /// characters for OOV spans.
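    ///
    /// Recurrence sketch: `dp[i] = max over tokens t ending at i of
    /// dp[i - len(t)] + ln(score(t))`, with `dp[0] = 0.0`.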
    fn viterbi_tokenize(&self, text: &str) -> Vec<String> {
        let chars: Vec<char> = text.chars().collect();
        let n = chars.len();
        if n == 0 {
            return Vec::new();
        }

        // dp[i] = best log-prob score to segment chars[0..i]
        // back[i] = length of the last token in the best segmentation ending at i
        let mut dp = vec![f32::NEG_INFINITY; n + 1];
        let mut back = vec![1usize; n + 1];
        dp[0] = 0.0;

        for i in 1..=n {
            for len in 1..=self.max_token_length.min(i) {
                let start = i - len;
                let candidate: String = chars[start..i].iter().collect();
                if let Some(&score) = self.vocab.get(&candidate) {
                    let log_score = score.ln();
                    let total = dp[start] + log_score;
                    if total > dp[i] {
                        dp[i] = total;
                        back[i] = len;
                    }
                }
            }
            // If no vocab entry covers this position, use a single-character fallback
            if dp[i] == f32::NEG_INFINITY {
                dp[i] = dp[i - 1] + (-10.0); // penalty for OOV character
                back[i] = 1;
            }
        }

        // Backtrack to recover tokens
        let mut tokens = Vec::new();
        let mut pos = n;
        while pos > 0 {
            let len = back[pos];
            let token: String = chars[pos - len..pos].iter().collect();
            tokens.push(token);
            pos -= len;
        }
        tokens.reverse();
        tokens
    }
}

impl Tokenizer for UnigramTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        // Split by whitespace first, then tokenize each word with Viterbi
        let mut all_tokens = Vec::new();

        for word in text.split_whitespace() {
            let word_tokens = self.viterbi_tokenize(word);
            all_tokens.extend(word_tokens);
        }

        all_tokens
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_whitespace_tokenizer() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("Hello World");

        assert_eq!(tokens, vec!["Hello", "World"]);
    }

    #[test]
    fn test_whitespace_tokenizer_lowercase() {
        let tokenizer = WhitespaceTokenizer::lowercase();
        let tokens = tokenizer.tokenize("Hello World");

        assert_eq!(tokens, vec!["hello", "world"]);
    }

    #[test]
    fn test_char_tokenizer() {
        let tokenizer = CharTokenizer::new();
        let tokens = tokenizer.tokenize("Hi!");

        assert_eq!(tokens, vec!["H", "i", "!"]);
    }

    #[test]
    fn test_char_tokenizer_no_whitespace() {
        let tokenizer = CharTokenizer::no_whitespace();
        let tokens = tokenizer.tokenize("Hi there!");

        assert_eq!(tokens, vec!["H", "i", "t", "h", "e", "r", "e", "!"]);
    }

    #[test]
    fn test_word_punct_tokenizer() {
        let tokenizer = WordPunctTokenizer::new();
        let tokens = tokenizer.tokenize("Hello, World!");

        assert_eq!(tokens, vec!["Hello", ",", "World", "!"]);
    }

    #[test]
    fn test_word_punct_tokenizer_lowercase() {
        let tokenizer = WordPunctTokenizer::lowercase();
        let tokens = tokenizer.tokenize("Hello, World!");

        assert_eq!(tokens, vec!["hello", ",", "world", "!"]);
    }

    #[test]
    fn test_ngram_word_tokenizer() {
        let tokenizer = NGramTokenizer::word_ngrams(2);
        let tokens = tokenizer.tokenize("one two three");

        assert_eq!(tokens, vec!["one two", "two three"]);
    }

    #[test]
    fn test_ngram_char_tokenizer() {
        let tokenizer = NGramTokenizer::char_ngrams(3);
        let tokens = tokenizer.tokenize("hello");

        assert_eq!(tokens, vec!["hel", "ell", "llo"]);
    }

    #[test]
    fn test_bpe_tokenizer_basic() {
        let mut tokenizer = BasicBPETokenizer::new();
        tokenizer.train("low lower lowest", 10);

        // Should have learned some merges
        assert!(!tokenizer.get_vocab().is_empty());
    }

    #[test]
    fn test_bpe_tokenizer_apply() {
        let mut tokenizer = BasicBPETokenizer::new();
        tokenizer.train("low low low lower lowest", 5);

        let tokens = tokenizer.tokenize("low");
        assert!(!tokens.is_empty());
    }

    #[test]
    fn test_unigram_tokenizer() {
        let tokenizer = UnigramTokenizer::from_tokens(&[
            "hel", "lo", "wor", "ld", "h", "e", "l", "o", "w", "r", "d",
        ]);
        let tokens = tokenizer.tokenize("hello world");

        // Should produce some tokens
        assert!(!tokens.is_empty());
    }

    #[test]
    fn test_tokenizer_encode() {
        let tokenizer = WhitespaceTokenizer::new();
        let mut vocab = Vocab::new();
        vocab.add_token("hello");
        vocab.add_token("world");

        let indices = tokenizer.encode("hello world", &vocab);
        assert_eq!(indices, vec![0, 1]);
    }

    #[test]
    fn test_tokenizer_with_multiple_spaces() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("hello    world");

        assert_eq!(tokens, vec!["hello", "world"]);
    }

    #[test]
    fn test_empty_text() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("");

        assert!(tokens.is_empty());
    }
}