Skip to main content

oxibonsai_tokenizer/
trainer.rs

1//! BPE tokenizer trainer: learn merge rules from a text corpus.
2//!
3//! Algorithm (Sennrich et al. 2016):
4//! 1. Initialize vocabulary with byte-level characters (0–255).
5//! 2. Encode corpus as sequences of byte token IDs.
6//! 3. Repeat for `num_merges` iterations:
7//!    a. Count all adjacent symbol-pair frequencies.
8//!    b. Find the most frequent pair.
9//!    c. Merge that pair everywhere in the corpus.
10//!    d. Add the merged token to vocabulary.
11//! 4. Return trained vocabulary + merge rules.
12
13use std::collections::HashMap;
14
15use thiserror::Error;
16
17use crate::{
18    bpe::BpeMerges,
19    tokenizer::{OxiTokenizer, TokenizerConfig},
20    vocab::Vocabulary,
21};
22
23// ── TrainerConfig ─────────────────────────────────────────────────────────────
24
/// Configuration for the BPE trainer.
///
/// Marked `#[non_exhaustive]` so that new training knobs can be added in
/// future minor releases without a breaking change.  Downstream callers must
/// construct it via [`TrainerConfig::new`] or [`TrainerConfig::default`] and
/// may then adjust it with the `with_*` builder methods.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TrainerConfig {
    /// Target total vocabulary size: the 256 byte tokens, plus the 4 special
    /// tokens when `add_special_tokens` is set, plus the merged tokens learned
    /// on top. Training may stop short of this if no eligible pair remains.
    pub vocab_size: usize,
    /// Minimum (word-frequency-weighted) pair count required to perform a merge.
    pub min_frequency: usize,
    /// Whether to reserve IDs 0–3 for the special tokens `<unk>` (0),
    /// `<bos>` (1), `<eos>` (2) and `<pad>` (3).
    /// When `true`, byte tokens start at ID 4 instead of ID 0.
    pub add_special_tokens: bool,
    /// When `true`, split each document on whitespace before BPE; the
    /// whitespace itself is dropped, so merges never cross word boundaries.
    /// When `false`, each whole document is a single pre-token. Symbols are
    /// byte-level in both modes.
    pub byte_level: bool,
    /// If `Some(n)` with `n > 0`, emit a `tracing` debug event every `n` merges.
    pub progress_interval: Option<usize>,
}

impl Default for TrainerConfig {
    fn default() -> Self {
        Self {
            vocab_size: 1000,
            min_frequency: 2,
            add_special_tokens: true,
            byte_level: true,
            progress_interval: None,
        }
    }
}

impl TrainerConfig {
    /// Create a config targeting `vocab_size` tokens with all other fields at
    /// their defaults.
    pub fn new(vocab_size: usize) -> Self {
        Self {
            vocab_size,
            ..Default::default()
        }
    }

    /// Override the minimum pair frequency threshold.
    pub fn with_min_frequency(mut self, freq: usize) -> Self {
        self.min_frequency = freq;
        self
    }

    /// Enable or disable automatic special-token insertion.
    pub fn with_special_tokens(mut self, add: bool) -> Self {
        self.add_special_tokens = add;
        self
    }

    /// Enable or disable whitespace pre-tokenization (see [`TrainerConfig::byte_level`]).
    pub fn with_byte_level(mut self, byte_level: bool) -> Self {
        self.byte_level = byte_level;
        self
    }

    /// Set (or clear, with `None`) the progress-logging interval in merges.
    pub fn with_progress_interval(mut self, interval: Option<usize>) -> Self {
        self.progress_interval = interval;
        self
    }
}
80
81// ── SymbolPair ────────────────────────────────────────────────────────────────
82
/// An ordered pair of adjacent symbol IDs: `(left, right)`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SymbolPair(pub u32, pub u32);

impl SymbolPair {
    /// Build a pair whose left ID is `a` and right ID is `b`.
    pub fn new(a: u32, b: u32) -> Self {
        SymbolPair(a, b)
    }

    /// Describe merging this pair into token `new_id` (spelled `merged_text`)
    /// as a [`MergeRule`].
    pub fn merged_symbol(&self, new_id: u32, merged_text: String) -> MergeRule {
        let &SymbolPair(left, right) = self;
        MergeRule {
            left,
            right,
            merged: new_id,
            merged_text,
        }
    }
}

// ── MergeRule ─────────────────────────────────────────────────────────────────

/// One learned BPE merge: the `(left, right)` pair becomes token `merged`.
#[derive(Debug, Clone)]
pub struct MergeRule {
    /// Token ID on the left-hand side of the pair.
    pub left: u32,
    /// Token ID on the right-hand side of the pair.
    pub right: u32,
    /// Token ID produced by the merge.
    pub merged: u32,
    /// Text of the produced token.
    pub merged_text: String,
}
118
119// ── Word ──────────────────────────────────────────────────────────────────────
120
/// One pre-token of the training corpus: its current sequence of symbol IDs
/// and how many times it occurred.
#[derive(Debug, Clone)]
struct Word {
    /// Symbol IDs spelling the word; this shortens as merges are applied.
    symbols: Vec<u32>,
    /// Occurrence count of this word across the whole corpus.
    freq: usize,
}

impl Word {
    /// Bundle a symbol sequence with its occurrence count.
    fn new(symbols: Vec<u32>, freq: usize) -> Self {
        Word { freq, symbols }
    }
}
136
137// ── TrainingStats ─────────────────────────────────────────────────────────────
138
/// Statistics gathered during a training run.
#[derive(Debug, Clone)]
pub struct TrainingStats {
    /// Vocabulary size before any merges (256 byte tokens + optional specials).
    pub initial_vocab_size: usize,
    /// Vocabulary size at the end of training.
    pub final_vocab_size: usize,
    /// Number of merge rules learned (the length of [`TrainedTokenizer::merges`]).
    pub num_merges_performed: usize,
    /// Number of requested merges left unperformed because every remaining
    /// candidate pair fell below `min_frequency`. Stays 0 when training ends
    /// for any other reason (target size reached, or no adjacent pairs left).
    pub num_merges_skipped: usize,
    /// Total corpus size in **bytes** (sum of `str::len()` over all documents).
    /// Despite the field name, this is not a Unicode character count.
    pub corpus_size_chars: usize,
    /// Number of distinct pre-tokenized word types.
    pub unique_words: usize,
}

impl TrainingStats {
    /// Human-readable one-line summary of the training run.
    ///
    /// The corpus size is reported in bytes, matching `corpus_size_chars`.
    pub fn summary(&self) -> String {
        format!(
            "BPE training: {init} → {fin} tokens | \
             {merges} merges applied, {skipped} skipped | \
             corpus {chars} bytes, {words} unique words",
            init = self.initial_vocab_size,
            fin = self.final_vocab_size,
            merges = self.num_merges_performed,
            skipped = self.num_merges_skipped,
            chars = self.corpus_size_chars,
            words = self.unique_words,
        )
    }
}
172
173// ── TrainedTokenizer ──────────────────────────────────────────────────────────
174
175/// The result returned by [`BpeTrainer::train`].
176#[derive(Debug)]
177pub struct TrainedTokenizer {
178    /// Full ID → token-string mapping (byte tokens + merged tokens + specials).
179    pub vocab: HashMap<u32, String>,
180    /// Merge rules in the order they were learned (first learned = highest priority).
181    pub merges: Vec<MergeRule>,
182    /// Diagnostic information about the training run.
183    pub stats: TrainingStats,
184}
185
186impl TrainedTokenizer {
187    /// Convert this trained result into a ready-to-use [`OxiTokenizer`].
188    ///
189    /// The [`TokenizerConfig`] is set to defaults; callers may rebuild from the
190    /// raw `vocab` / `merges` fields if a custom config is needed.
191    pub fn to_oxi_tokenizer(&self) -> OxiTokenizer {
192        let mut vocabulary = Vocabulary::new();
193        // Determine whether special-token slots are present by checking IDs 0-3.
194        // Special tokens are identified by their angle-bracket names.
195        for (&id, token) in &self.vocab {
196            if token.starts_with('<') && token.ends_with('>') {
197                vocabulary.add_special(token, id);
198            } else {
199                vocabulary.insert(token, id);
200            }
201        }
202
203        let mut bpe_merges = BpeMerges::new();
204        for rule in &self.merges {
205            // Reconstruct the left and right token strings from the vocab map.
206            let left_str = self.vocab.get(&rule.left).map(|s| s.as_str()).unwrap_or("");
207            let right_str = self
208                .vocab
209                .get(&rule.right)
210                .map(|s| s.as_str())
211                .unwrap_or("");
212            bpe_merges.add_merge(left_str, right_str, rule.merged);
213        }
214
215        let config = TokenizerConfig::default();
216        OxiTokenizer::new(vocabulary, bpe_merges, config)
217    }
218
219    /// Serialize merge rules as plain text (one rule per line).
220    ///
221    /// Format: `<left_token> <right_token>`
222    /// (matching the HuggingFace `merges.txt` convention).
223    pub fn merges_to_text(&self) -> String {
224        let mut out = String::new();
225        for rule in &self.merges {
226            let left = self.vocab.get(&rule.left).map(|s| s.as_str()).unwrap_or("");
227            let right = self
228                .vocab
229                .get(&rule.right)
230                .map(|s| s.as_str())
231                .unwrap_or("");
232            out.push_str(left);
233            out.push(' ');
234            out.push_str(right);
235            out.push('\n');
236        }
237        out
238    }
239
240    /// Total number of tokens in the trained vocabulary.
241    pub fn vocab_size(&self) -> usize {
242        self.vocab.len()
243    }
244}
245
246// ── TrainerError ──────────────────────────────────────────────────────────────
247
248/// Errors that can occur during BPE training.
249#[derive(Debug, Error)]
250pub enum TrainerError {
251    /// The corpus slice was empty.
252    #[error("empty corpus")]
253    EmptyCorpus,
254    /// Requested `vocab_size` is too small to hold even the base byte vocabulary.
255    #[error("vocab_size {0} must be > 256 (base byte vocabulary)")]
256    VocabSizeTooSmall(usize),
257    /// Pre-tokenization produced no usable words.
258    #[error("corpus has no valid words after pre-tokenization")]
259    NoValidWords,
260}
261
262// ── BpeTrainer ────────────────────────────────────────────────────────────────
263
264/// BPE trainer that learns merge rules from a raw text corpus.
265///
266/// # Example
267///
268/// ```rust
269/// use oxibonsai_tokenizer::trainer::{BpeTrainer, TrainerConfig};
270///
271/// let mut trainer = BpeTrainer::new(TrainerConfig::new(512));
272/// let corpus = ["the quick brown fox", "the fox jumped"];
273/// let trained = trainer.train(&corpus).expect("training should succeed");
274/// println!("{}", trained.stats.summary());
275/// ```
276pub struct BpeTrainer {
277    config: TrainerConfig,
278    /// Byte value → initial token ID (256 entries when `add_special_tokens` is
279    /// false; otherwise IDs are offset by 4 to leave room for specials).
280    char_vocab: HashMap<u8, u32>,
281    /// The next token ID to assign to a newly merged token.
282    next_id: u32,
283}
284
285impl BpeTrainer {
286    /// Create a new trainer with the supplied configuration.
287    pub fn new(config: TrainerConfig) -> Self {
288        let char_vocab = HashMap::new(); // populated lazily in `train`
289        let next_id = 0;
290        Self {
291            config,
292            char_vocab,
293            next_id,
294        }
295    }
296
297    /// Convenience constructor with default configuration.
298    pub fn default_config() -> Self {
299        Self::new(TrainerConfig::default())
300    }
301
302    // ── Public entry point ────────────────────────────────────────────────
303
304    /// Train a BPE tokenizer on the supplied corpus.
305    ///
306    /// Each element of `corpus` is treated as an independent document.
307    /// The function is deterministic: given the same corpus and config it always
308    /// produces the same output.
309    pub fn train(&mut self, corpus: &[&str]) -> Result<TrainedTokenizer, TrainerError> {
310        // ── Validate inputs ───────────────────────────────────────────────
311        if corpus.is_empty() {
312            return Err(TrainerError::EmptyCorpus);
313        }
314
315        // We always need room for at least 256 byte tokens.
316        let min_size: usize = if self.config.add_special_tokens {
317            256 + 4
318        } else {
319            256
320        };
321        if self.config.vocab_size <= min_size.saturating_sub(1) {
322            return Err(TrainerError::VocabSizeTooSmall(self.config.vocab_size));
323        }
324
325        // ── Build initial byte vocabulary ─────────────────────────────────
326        let mut id_to_token: HashMap<u32, String> = HashMap::new();
327
328        // Reserve IDs 0-3 for special tokens when requested.
329        let byte_id_offset: u32 = if self.config.add_special_tokens { 4 } else { 0 };
330
331        if self.config.add_special_tokens {
332            id_to_token.insert(0, "<unk>".to_owned());
333            id_to_token.insert(1, "<bos>".to_owned());
334            id_to_token.insert(2, "<eos>".to_owned());
335            id_to_token.insert(3, "<pad>".to_owned());
336        }
337
338        self.char_vocab.clear();
339        for byte in 0u8..=255u8 {
340            let id = byte as u32 + byte_id_offset;
341            // Token string for a byte is the raw UTF-8 character if it is
342            // printable ASCII; otherwise use the `<0xHH>` byte-fallback form.
343            let token = byte_token_string(byte);
344            self.char_vocab.insert(byte, id);
345            id_to_token.insert(id, token);
346        }
347
348        self.next_id = 256 + byte_id_offset;
349
350        let initial_vocab_size = id_to_token.len();
351
352        // ── Pre-tokenize corpus ───────────────────────────────────────────
353        let corpus_size_chars: usize = corpus.iter().map(|s| s.len()).sum();
354        let word_freqs = self.pretokenize(corpus);
355
356        if word_freqs.is_empty() {
357            return Err(TrainerError::NoValidWords);
358        }
359
360        let unique_words = word_freqs.len();
361
362        // Convert word-frequency map to a Vec<Word> of symbol sequences.
363        let mut words: Vec<Word> = word_freqs
364            .into_iter()
365            .map(|(text, freq)| {
366                let symbols = self.encode_word(&text);
367                Word::new(symbols, freq)
368            })
369            .collect();
370
371        // ── BPE training loop ─────────────────────────────────────────────
372        let num_merges = self.config.vocab_size.saturating_sub(self.next_id as usize);
373        let mut merge_rules: Vec<MergeRule> = Vec::with_capacity(num_merges);
374        let mut num_merges_skipped: usize = 0;
375
376        for merge_idx in 0..num_merges {
377            // Log progress if requested.
378            if let Some(interval) = self.config.progress_interval {
379                if interval > 0 && merge_idx % interval == 0 {
380                    tracing::debug!(
381                        merge = merge_idx,
382                        total = num_merges,
383                        vocab = self.next_id,
384                        "BPE training progress",
385                    );
386                }
387            }
388
389            // Count pair frequencies.
390            let pair_counts = self.count_pairs(&words);
391            if pair_counts.is_empty() {
392                // No more pairs — corpus has been fully merged.
393                break;
394            }
395
396            // Select the best pair.
397            let best = match self.best_pair(&pair_counts) {
398                Some(b) => b,
399                None => {
400                    // All remaining pairs are below min_frequency.
401                    num_merges_skipped += num_merges - merge_idx;
402                    break;
403                }
404            };
405
406            let (pair, _freq) = best;
407
408            // Build the merged token string.
409            let left_str = id_to_token.get(&pair.0).cloned().unwrap_or_default();
410            let right_str = id_to_token.get(&pair.1).cloned().unwrap_or_default();
411            let merged_text = format!("{left_str}{right_str}");
412
413            // Assign a new ID to the merged token.
414            let new_id = self.next_id;
415            self.next_id += 1;
416            id_to_token.insert(new_id, merged_text.clone());
417
418            // Record the merge rule.
419            let rule = pair.merged_symbol(new_id, merged_text);
420            merge_rules.push(rule);
421
422            // Apply the merge throughout the corpus.
423            self.apply_merge(&mut words, &pair, new_id);
424        }
425
426        let final_vocab_size = id_to_token.len();
427        let num_merges_performed = merge_rules.len();
428
429        let stats = TrainingStats {
430            initial_vocab_size,
431            final_vocab_size,
432            num_merges_performed,
433            num_merges_skipped,
434            corpus_size_chars,
435            unique_words,
436        };
437
438        Ok(TrainedTokenizer {
439            vocab: id_to_token,
440            merges: merge_rules,
441            stats,
442        })
443    }
444
445    // ── Private helpers ───────────────────────────────────────────────────
446
447    /// Count the frequency of every adjacent symbol pair across all words.
448    ///
449    /// Each pair's count is weighted by the frequency of the word it appears in.
450    fn count_pairs(&self, words: &[Word]) -> HashMap<SymbolPair, usize> {
451        let mut counts: HashMap<SymbolPair, usize> = HashMap::new();
452        for word in words {
453            if word.symbols.len() < 2 {
454                continue;
455            }
456            for window in word.symbols.windows(2) {
457                let pair = SymbolPair::new(window[0], window[1]);
458                *counts.entry(pair).or_insert(0) += word.freq;
459            }
460        }
461        counts
462    }
463
464    /// Find the most frequent pair whose count meets or exceeds `min_frequency`.
465    ///
466    /// Ties are broken deterministically by preferring the pair with the smallest
467    /// (left, right) ID values so that training is fully reproducible.
468    fn best_pair(&self, pair_counts: &HashMap<SymbolPair, usize>) -> Option<(SymbolPair, usize)> {
469        pair_counts
470            .iter()
471            .filter(|(_, &count)| count >= self.config.min_frequency)
472            .max_by(|(pair_a, &cnt_a), (pair_b, &cnt_b)| {
473                // Primary: higher frequency wins.
474                // Secondary (tiebreak): lower IDs win (deterministic).
475                cnt_a
476                    .cmp(&cnt_b)
477                    .then_with(|| pair_b.0.cmp(&pair_a.0))
478                    .then_with(|| pair_b.1.cmp(&pair_a.1))
479            })
480            .map(|(pair, &count)| (pair.clone(), count))
481    }
482
483    /// Apply a merge rule to every occurrence of `pair` in all words in-place.
484    ///
485    /// When a match is found at position `i`, `symbols[i]` is replaced with
486    /// `new_id` and `symbols[i+1]` is removed.  The scan continues from
487    /// position `i` (not `i+1`) to handle non-overlapping matches correctly.
488    fn apply_merge(&self, words: &mut [Word], pair: &SymbolPair, new_id: u32) {
489        for word in words.iter_mut() {
490            if word.symbols.len() < 2 {
491                continue;
492            }
493            let mut i = 0;
494            while i + 1 < word.symbols.len() {
495                if word.symbols[i] == pair.0 && word.symbols[i + 1] == pair.1 {
496                    word.symbols[i] = new_id;
497                    word.symbols.remove(i + 1);
498                    // Do NOT advance `i`: the newly merged token at position `i`
499                    // may form another valid pair with the next symbol.
500                } else {
501                    i += 1;
502                }
503            }
504        }
505    }
506
507    /// Pre-tokenize the corpus into a map from word-string → frequency.
508    ///
509    /// When `byte_level` is set, text is split on whitespace boundaries so that
510    /// BPE operates on words rather than the full document.  Otherwise the
511    /// entire document is treated as one unit.
512    fn pretokenize(&self, corpus: &[&str]) -> HashMap<String, usize> {
513        let mut freq_map: HashMap<String, usize> = HashMap::new();
514        for &doc in corpus {
515            if self.config.byte_level {
516                // Split on whitespace; keep non-empty parts only.
517                for word in doc.split_whitespace() {
518                    if !word.is_empty() {
519                        *freq_map.entry(word.to_owned()).or_insert(0) += 1;
520                    }
521                }
522            } else {
523                // Treat the entire document as a single unit.
524                if !doc.is_empty() {
525                    *freq_map.entry(doc.to_owned()).or_insert(0) += 1;
526                }
527            }
528        }
529        freq_map
530    }
531
532    /// Encode a word string into its initial byte-level token ID sequence.
533    ///
534    /// Each byte of the UTF-8 representation becomes one symbol ID.
535    fn encode_word(&self, word: &str) -> Vec<u32> {
536        word.as_bytes()
537            .iter()
538            .filter_map(|b| self.char_vocab.get(b).copied())
539            .collect()
540    }
541}
542
543// ── Helpers ───────────────────────────────────────────────────────────────────
544
/// Return the canonical string representation for a byte token.
///
/// - Printable ASCII (0x20–0x7E, space included): the character itself.
/// - Everything else (C0 controls, DEL, and 0x80–0xFF): the `<0xHH>`
///   byte-fallback form with upper-case hex digits.
fn byte_token_string(byte: u8) -> String {
    match byte {
        0x20..=0x7E => char::from(byte).to_string(),
        _ => format!("<0x{byte:02X}>"),
    }
}
557
558// ── Tests (inline sanity checks) ──────────────────────────────────────────────
559
#[cfg(test)]
mod inline_tests {
    use super::*;

    #[test]
    fn byte_token_string_printable() {
        assert_eq!(byte_token_string(b'a'), "a");
        assert_eq!(byte_token_string(b' '), " ");
        assert_eq!(byte_token_string(b'~'), "~");
    }

    #[test]
    fn byte_token_string_control() {
        assert_eq!(byte_token_string(0x00), "<0x00>");
        assert_eq!(byte_token_string(0x0A), "<0x0A>");
        assert_eq!(byte_token_string(0xFF), "<0xFF>");
    }

    #[test]
    fn byte_token_string_del_and_high_bytes() {
        // DEL (0x7F) is ASCII but not printable; 0x80+ is not ASCII at all.
        assert_eq!(byte_token_string(0x7F), "<0x7F>");
        assert_eq!(byte_token_string(0x80), "<0x80>");
    }

    #[test]
    fn count_pairs_basic() {
        let trainer = BpeTrainer::new(TrainerConfig::new(300));
        let words = vec![Word::new(vec![0, 1, 0, 1], 3)];
        let counts = trainer.count_pairs(&words);
        assert_eq!(counts.get(&SymbolPair::new(0, 1)), Some(&6)); // appears twice × freq 3
        assert_eq!(counts.get(&SymbolPair::new(1, 0)), Some(&3));
        assert_eq!(counts.len(), 2);
    }

    #[test]
    fn count_pairs_ignores_short_words() {
        let trainer = BpeTrainer::new(TrainerConfig::new(300));
        let words = vec![Word::new(vec![7], 5), Word::new(Vec::new(), 2)];
        assert!(trainer.count_pairs(&words).is_empty());
    }

    #[test]
    fn apply_merge_replaces_pair() {
        let trainer = BpeTrainer::new(TrainerConfig::new(300));
        let mut words = vec![Word::new(vec![0, 1, 0, 1], 1)];
        trainer.apply_merge(&mut words, &SymbolPair::new(0, 1), 99);
        assert_eq!(words[0].symbols, vec![99, 99]);
    }

    #[test]
    fn apply_merge_is_left_to_right_non_overlapping() {
        let trainer = BpeTrainer::new(TrainerConfig::new(300));
        let mut words = vec![Word::new(vec![5, 5, 5], 1)];
        trainer.apply_merge(&mut words, &SymbolPair::new(5, 5), 42);
        assert_eq!(words[0].symbols, vec![42, 5]);
    }

    #[test]
    fn best_pair_breaks_ties_by_lowest_ids() {
        // Default min_frequency is 2, so the (0, 0) entry is filtered out.
        let trainer = BpeTrainer::new(TrainerConfig::new(300));
        let mut counts = HashMap::new();
        counts.insert(SymbolPair::new(2, 3), 4);
        counts.insert(SymbolPair::new(1, 9), 4);
        counts.insert(SymbolPair::new(0, 0), 1);
        assert_eq!(trainer.best_pair(&counts), Some((SymbolPair::new(1, 9), 4)));
    }

    #[test]
    fn best_pair_respects_min_frequency() {
        let trainer = BpeTrainer::new(TrainerConfig::new(300).with_min_frequency(5));
        let mut counts = HashMap::new();
        counts.insert(SymbolPair::new(1, 2), 4);
        assert_eq!(trainer.best_pair(&counts), None);
    }

    #[test]
    fn train_rejects_empty_corpus() {
        let mut trainer = BpeTrainer::new(TrainerConfig::new(300));
        assert!(matches!(trainer.train(&[]), Err(TrainerError::EmptyCorpus)));
    }

    #[test]
    fn train_rejects_vocab_smaller_than_base() {
        let mut with_specials = BpeTrainer::new(TrainerConfig::new(259));
        assert!(matches!(
            with_specials.train(&["ab"]),
            Err(TrainerError::VocabSizeTooSmall(259))
        ));

        // Without specials, exactly 256 tokens is accepted (zero merges).
        let mut bytes_only =
            BpeTrainer::new(TrainerConfig::new(256).with_special_tokens(false));
        let trained = bytes_only
            .train(&["ab ab"])
            .expect("256 fits the byte vocabulary");
        assert_eq!(trained.vocab_size(), 256);
        assert!(trained.merges.is_empty());
    }

    #[test]
    fn train_lays_out_special_and_byte_tokens() {
        let mut trainer = BpeTrainer::new(TrainerConfig::new(270));
        let trained = trainer.train(&["ab ab ab"]).expect("training should succeed");
        assert_eq!(trained.vocab.get(&0).map(String::as_str), Some("<unk>"));
        assert_eq!(trained.vocab.get(&1).map(String::as_str), Some("<bos>"));
        assert_eq!(trained.vocab.get(&2).map(String::as_str), Some("<eos>"));
        assert_eq!(trained.vocab.get(&3).map(String::as_str), Some("<pad>"));
        assert_eq!(
            trained.vocab.get(&(4 + u32::from(b'a'))).map(String::as_str),
            Some("a")
        );
        assert_eq!(trained.stats.initial_vocab_size, 260);
    }

    #[test]
    fn train_learns_most_frequent_merge() {
        let mut trainer = BpeTrainer::new(TrainerConfig::new(270));
        let trained = trainer.train(&["ab ab ab"]).expect("training should succeed");
        // "ab" occurs 3 times; after merging it no adjacent pairs remain.
        assert_eq!(trained.merges.len(), 1);
        assert_eq!(trained.merges[0].merged, 260);
        assert_eq!(trained.merges[0].merged_text, "ab");
        assert_eq!(trained.merges_to_text(), "a b\n");
        assert_eq!(trained.stats.final_vocab_size, 261);
        assert_eq!(trained.stats.num_merges_skipped, 0);
    }

    #[test]
    fn train_counts_skipped_merges_below_min_frequency() {
        // Every pair occurs once, below the default min_frequency of 2.
        let mut trainer = BpeTrainer::new(TrainerConfig::new(270));
        let trained = trainer.train(&["ab cd"]).expect("training should succeed");
        assert!(trained.merges.is_empty());
        assert_eq!(trained.stats.num_merges_skipped, 10);
    }

    #[test]
    fn train_is_deterministic() {
        let corpus = ["the quick brown fox", "the fox jumped over the lazy dog"];
        let mut first_trainer = BpeTrainer::new(TrainerConfig::new(300));
        let mut second_trainer = BpeTrainer::new(TrainerConfig::new(300));
        let first = first_trainer.train(&corpus).expect("training should succeed");
        let second = second_trainer.train(&corpus).expect("training should succeed");
        assert_eq!(first.merges_to_text(), second.merges_to_text());
        assert_eq!(first.vocab, second.vocab);
    }
}