oxibonsai_tokenizer/trainer.rs
1//! BPE tokenizer trainer: learn merge rules from a text corpus.
2//!
3//! Algorithm (Sennrich et al. 2016):
4//! 1. Initialize vocabulary with byte-level characters (0–255).
5//! 2. Encode corpus as sequences of byte token IDs.
6//! 3. Repeat for `num_merges` iterations:
7//! a. Count all adjacent symbol-pair frequencies.
8//! b. Find the most frequent pair.
9//! c. Merge that pair everywhere in the corpus.
10//! d. Add the merged token to vocabulary.
11//! 4. Return trained vocabulary + merge rules.
12
13use std::collections::HashMap;
14
15use thiserror::Error;
16
17use crate::{
18 bpe::BpeMerges,
19 tokenizer::{OxiTokenizer, TokenizerConfig},
20 vocab::Vocabulary,
21};
22
23// ── TrainerConfig ─────────────────────────────────────────────────────────────
24
/// Configuration for the BPE trainer.
///
/// Marked `#[non_exhaustive]` so that new training knobs can be added in
/// future minor releases without a breaking change. Downstream callers must
/// construct it via [`TrainerConfig::new`] or [`TrainerConfig::default`] and
/// then adjust it with the `with_*` builder methods (or by assigning the
/// public fields directly).
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TrainerConfig {
    /// Target vocabulary size: 256 byte tokens, plus 4 special tokens when
    /// `add_special_tokens` is set, plus one token per learned merge.
    pub vocab_size: usize,
    /// Minimum pair frequency required to perform a merge.
    pub min_frequency: usize,
    /// Whether to reserve IDs 0–3 for the special tokens
    /// `<unk>`=0, `<bos>`=1, `<eos>`=2, `<pad>`=3.
    /// When `true`, byte tokens start at ID 4 instead of ID 0.
    pub add_special_tokens: bool,
    /// When `true`, split each document on whitespace before BPE so merges
    /// never cross word boundaries; when `false`, each document is one unit.
    pub byte_level: bool,
    /// If `Some(n)` with `n > 0`, emit a debug-level progress log every `n`
    /// merges. `None` and `Some(0)` disable progress logging.
    pub progress_interval: Option<usize>,
}

impl Default for TrainerConfig {
    /// Defaults: 1000-token vocabulary, `min_frequency` 2, special tokens
    /// enabled, whitespace pre-tokenization enabled, no progress logging.
    fn default() -> Self {
        Self {
            vocab_size: 1000,
            min_frequency: 2,
            add_special_tokens: true,
            byte_level: true,
            progress_interval: None,
        }
    }
}

impl TrainerConfig {
    /// Create a config targeting `vocab_size` tokens with all other fields at
    /// their defaults.
    pub fn new(vocab_size: usize) -> Self {
        Self {
            vocab_size,
            ..Default::default()
        }
    }

    /// Override the minimum pair frequency threshold.
    pub fn with_min_frequency(mut self, freq: usize) -> Self {
        self.min_frequency = freq;
        self
    }

    /// Enable or disable automatic special-token insertion.
    pub fn with_special_tokens(mut self, add: bool) -> Self {
        self.add_special_tokens = add;
        self
    }

    /// Enable or disable whitespace pre-tokenization.
    pub fn with_byte_level(mut self, byte_level: bool) -> Self {
        self.byte_level = byte_level;
        self
    }

    /// Set how often (in merges) progress is logged; `None` disables logging.
    pub fn with_progress_interval(mut self, interval: Option<usize>) -> Self {
        self.progress_interval = interval;
        self
    }
}
80
81// ── SymbolPair ────────────────────────────────────────────────────────────────
82
83/// A pair of adjacent symbol IDs (left, right).
84#[derive(Debug, Clone, PartialEq, Eq, Hash)]
85pub struct SymbolPair(pub u32, pub u32);
86
87impl SymbolPair {
88 /// Construct a pair from two token IDs.
89 pub fn new(a: u32, b: u32) -> Self {
90 Self(a, b)
91 }
92
93 /// Produce the [`MergeRule`] that results from merging this pair into `new_id`.
94 pub fn merged_symbol(&self, new_id: u32, merged_text: String) -> MergeRule {
95 MergeRule {
96 left: self.0,
97 right: self.1,
98 merged: new_id,
99 merged_text,
100 }
101 }
102}
103
104// ── MergeRule ─────────────────────────────────────────────────────────────────
105
/// A single BPE merge rule: (left, right) → merged token.
///
/// Rules are built by [`SymbolPair::merged_symbol`] and collected by the
/// trainer in the order the merges were learned.
#[derive(Debug, Clone)]
pub struct MergeRule {
    /// ID of the left symbol in the pair.
    pub left: u32,
    /// ID of the right symbol in the pair.
    pub right: u32,
    /// ID assigned to the merged token.
    pub merged: u32,
    /// String representation of the merged token (the left token's string
    /// followed by the right token's string).
    pub merged_text: String,
}
118
119// ── Word ──────────────────────────────────────────────────────────────────────
120
/// One distinct pre-token of the training corpus: its current symbol-ID
/// sequence plus how often it occurs.
#[derive(Debug, Clone)]
struct Word {
    /// Symbol IDs making up the word; gets shorter as merges are applied.
    symbols: Vec<u32>,
    /// Occurrence count of this word in the corpus.
    freq: usize,
}

impl Word {
    /// Bundle a symbol sequence with its corpus frequency.
    fn new(symbols: Vec<u32>, freq: usize) -> Self {
        Word { symbols, freq }
    }
}
136
137// ── TrainingStats ─────────────────────────────────────────────────────────────
138
/// Statistics gathered during a training run.
#[derive(Debug, Clone)]
pub struct TrainingStats {
    /// Vocabulary size before any merges (256 byte tokens + optional specials).
    pub initial_vocab_size: usize,
    /// Vocabulary size at the end of training.
    pub final_vocab_size: usize,
    /// Number of merge operations successfully applied.
    pub num_merges_performed: usize,
    /// Number of requested merges that were not performed because no
    /// remaining pair reached `min_frequency`. Stays 0 when training stops
    /// because the corpus ran out of pairs entirely.
    pub num_merges_skipped: usize,
    /// Total corpus size in **bytes** (sum of `str::len()` over all
    /// documents). The field name says "chars" for historical reasons.
    pub corpus_size_chars: usize,
    /// Number of distinct pre-tokenized word types.
    pub unique_words: usize,
}

impl TrainingStats {
    /// Render the run as a single human-readable line.
    pub fn summary(&self) -> String {
        let Self {
            initial_vocab_size,
            final_vocab_size,
            num_merges_performed,
            num_merges_skipped,
            corpus_size_chars,
            unique_words,
        } = self;
        format!(
            "BPE training: {initial_vocab_size} → {final_vocab_size} tokens | \
             {num_merges_performed} merges applied, {num_merges_skipped} skipped | \
             corpus {corpus_size_chars} bytes, {unique_words} unique words"
        )
    }
}
172
173// ── TrainedTokenizer ──────────────────────────────────────────────────────────
174
175/// The result returned by [`BpeTrainer::train`].
176#[derive(Debug)]
177pub struct TrainedTokenizer {
178 /// Full ID → token-string mapping (byte tokens + merged tokens + specials).
179 pub vocab: HashMap<u32, String>,
180 /// Merge rules in the order they were learned (first learned = highest priority).
181 pub merges: Vec<MergeRule>,
182 /// Diagnostic information about the training run.
183 pub stats: TrainingStats,
184}
185
186impl TrainedTokenizer {
187 /// Convert this trained result into a ready-to-use [`OxiTokenizer`].
188 ///
189 /// The [`TokenizerConfig`] is set to defaults; callers may rebuild from the
190 /// raw `vocab` / `merges` fields if a custom config is needed.
191 pub fn to_oxi_tokenizer(&self) -> OxiTokenizer {
192 let mut vocabulary = Vocabulary::new();
193 // Determine whether special-token slots are present by checking IDs 0-3.
194 // Special tokens are identified by their angle-bracket names.
195 for (&id, token) in &self.vocab {
196 if token.starts_with('<') && token.ends_with('>') {
197 vocabulary.add_special(token, id);
198 } else {
199 vocabulary.insert(token, id);
200 }
201 }
202
203 let mut bpe_merges = BpeMerges::new();
204 for rule in &self.merges {
205 // Reconstruct the left and right token strings from the vocab map.
206 let left_str = self.vocab.get(&rule.left).map(|s| s.as_str()).unwrap_or("");
207 let right_str = self
208 .vocab
209 .get(&rule.right)
210 .map(|s| s.as_str())
211 .unwrap_or("");
212 bpe_merges.add_merge(left_str, right_str, rule.merged);
213 }
214
215 let config = TokenizerConfig::default();
216 OxiTokenizer::new(vocabulary, bpe_merges, config)
217 }
218
219 /// Serialize merge rules as plain text (one rule per line).
220 ///
221 /// Format: `<left_token> <right_token>`
222 /// (matching the HuggingFace `merges.txt` convention).
223 pub fn merges_to_text(&self) -> String {
224 let mut out = String::new();
225 for rule in &self.merges {
226 let left = self.vocab.get(&rule.left).map(|s| s.as_str()).unwrap_or("");
227 let right = self
228 .vocab
229 .get(&rule.right)
230 .map(|s| s.as_str())
231 .unwrap_or("");
232 out.push_str(left);
233 out.push(' ');
234 out.push_str(right);
235 out.push('\n');
236 }
237 out
238 }
239
240 /// Total number of tokens in the trained vocabulary.
241 pub fn vocab_size(&self) -> usize {
242 self.vocab.len()
243 }
244}
245
246// ── TrainerError ──────────────────────────────────────────────────────────────
247
248/// Errors that can occur during BPE training.
249#[derive(Debug, Error)]
250pub enum TrainerError {
251 /// The corpus slice was empty.
252 #[error("empty corpus")]
253 EmptyCorpus,
254 /// Requested `vocab_size` is too small to hold even the base byte vocabulary.
255 #[error("vocab_size {0} must be > 256 (base byte vocabulary)")]
256 VocabSizeTooSmall(usize),
257 /// Pre-tokenization produced no usable words.
258 #[error("corpus has no valid words after pre-tokenization")]
259 NoValidWords,
260}
261
262// ── BpeTrainer ────────────────────────────────────────────────────────────────
263
/// BPE trainer that learns merge rules from a raw text corpus.
///
/// Besides its configuration the trainer keeps scratch state (`char_vocab`,
/// `next_id`) that is rebuilt at the start of every `train` call.
///
/// # Example
///
/// ```rust
/// use oxibonsai_tokenizer::trainer::{BpeTrainer, TrainerConfig};
///
/// let mut trainer = BpeTrainer::new(TrainerConfig::new(512));
/// let corpus = ["the quick brown fox", "the fox jumped"];
/// let trained = trainer.train(&corpus).expect("training should succeed");
/// println!("{}", trained.stats.summary());
/// ```
pub struct BpeTrainer {
    /// Training configuration supplied at construction time.
    config: TrainerConfig,
    /// Byte value → initial token ID. Empty until `train` runs; afterwards it
    /// holds all 256 bytes, with IDs shifted up by 4 when
    /// `add_special_tokens` is set to leave room for the specials.
    char_vocab: HashMap<u8, u32>,
    /// The next token ID to assign to a newly merged token.
    next_id: u32,
}
284
285impl BpeTrainer {
286 /// Create a new trainer with the supplied configuration.
287 pub fn new(config: TrainerConfig) -> Self {
288 let char_vocab = HashMap::new(); // populated lazily in `train`
289 let next_id = 0;
290 Self {
291 config,
292 char_vocab,
293 next_id,
294 }
295 }
296
297 /// Convenience constructor with default configuration.
298 pub fn default_config() -> Self {
299 Self::new(TrainerConfig::default())
300 }
301
302 // ── Public entry point ────────────────────────────────────────────────
303
304 /// Train a BPE tokenizer on the supplied corpus.
305 ///
306 /// Each element of `corpus` is treated as an independent document.
307 /// The function is deterministic: given the same corpus and config it always
308 /// produces the same output.
309 pub fn train(&mut self, corpus: &[&str]) -> Result<TrainedTokenizer, TrainerError> {
310 // ── Validate inputs ───────────────────────────────────────────────
311 if corpus.is_empty() {
312 return Err(TrainerError::EmptyCorpus);
313 }
314
315 // We always need room for at least 256 byte tokens.
316 let min_size: usize = if self.config.add_special_tokens {
317 256 + 4
318 } else {
319 256
320 };
321 if self.config.vocab_size <= min_size.saturating_sub(1) {
322 return Err(TrainerError::VocabSizeTooSmall(self.config.vocab_size));
323 }
324
325 // ── Build initial byte vocabulary ─────────────────────────────────
326 let mut id_to_token: HashMap<u32, String> = HashMap::new();
327
328 // Reserve IDs 0-3 for special tokens when requested.
329 let byte_id_offset: u32 = if self.config.add_special_tokens { 4 } else { 0 };
330
331 if self.config.add_special_tokens {
332 id_to_token.insert(0, "<unk>".to_owned());
333 id_to_token.insert(1, "<bos>".to_owned());
334 id_to_token.insert(2, "<eos>".to_owned());
335 id_to_token.insert(3, "<pad>".to_owned());
336 }
337
338 self.char_vocab.clear();
339 for byte in 0u8..=255u8 {
340 let id = byte as u32 + byte_id_offset;
341 // Token string for a byte is the raw UTF-8 character if it is
342 // printable ASCII; otherwise use the `<0xHH>` byte-fallback form.
343 let token = byte_token_string(byte);
344 self.char_vocab.insert(byte, id);
345 id_to_token.insert(id, token);
346 }
347
348 self.next_id = 256 + byte_id_offset;
349
350 let initial_vocab_size = id_to_token.len();
351
352 // ── Pre-tokenize corpus ───────────────────────────────────────────
353 let corpus_size_chars: usize = corpus.iter().map(|s| s.len()).sum();
354 let word_freqs = self.pretokenize(corpus);
355
356 if word_freqs.is_empty() {
357 return Err(TrainerError::NoValidWords);
358 }
359
360 let unique_words = word_freqs.len();
361
362 // Convert word-frequency map to a Vec<Word> of symbol sequences.
363 let mut words: Vec<Word> = word_freqs
364 .into_iter()
365 .map(|(text, freq)| {
366 let symbols = self.encode_word(&text);
367 Word::new(symbols, freq)
368 })
369 .collect();
370
371 // ── BPE training loop ─────────────────────────────────────────────
372 let num_merges = self.config.vocab_size.saturating_sub(self.next_id as usize);
373 let mut merge_rules: Vec<MergeRule> = Vec::with_capacity(num_merges);
374 let mut num_merges_skipped: usize = 0;
375
376 for merge_idx in 0..num_merges {
377 // Log progress if requested.
378 if let Some(interval) = self.config.progress_interval {
379 if interval > 0 && merge_idx % interval == 0 {
380 tracing::debug!(
381 merge = merge_idx,
382 total = num_merges,
383 vocab = self.next_id,
384 "BPE training progress",
385 );
386 }
387 }
388
389 // Count pair frequencies.
390 let pair_counts = self.count_pairs(&words);
391 if pair_counts.is_empty() {
392 // No more pairs — corpus has been fully merged.
393 break;
394 }
395
396 // Select the best pair.
397 let best = match self.best_pair(&pair_counts) {
398 Some(b) => b,
399 None => {
400 // All remaining pairs are below min_frequency.
401 num_merges_skipped += num_merges - merge_idx;
402 break;
403 }
404 };
405
406 let (pair, _freq) = best;
407
408 // Build the merged token string.
409 let left_str = id_to_token.get(&pair.0).cloned().unwrap_or_default();
410 let right_str = id_to_token.get(&pair.1).cloned().unwrap_or_default();
411 let merged_text = format!("{left_str}{right_str}");
412
413 // Assign a new ID to the merged token.
414 let new_id = self.next_id;
415 self.next_id += 1;
416 id_to_token.insert(new_id, merged_text.clone());
417
418 // Record the merge rule.
419 let rule = pair.merged_symbol(new_id, merged_text);
420 merge_rules.push(rule);
421
422 // Apply the merge throughout the corpus.
423 self.apply_merge(&mut words, &pair, new_id);
424 }
425
426 let final_vocab_size = id_to_token.len();
427 let num_merges_performed = merge_rules.len();
428
429 let stats = TrainingStats {
430 initial_vocab_size,
431 final_vocab_size,
432 num_merges_performed,
433 num_merges_skipped,
434 corpus_size_chars,
435 unique_words,
436 };
437
438 Ok(TrainedTokenizer {
439 vocab: id_to_token,
440 merges: merge_rules,
441 stats,
442 })
443 }
444
445 // ── Private helpers ───────────────────────────────────────────────────
446
447 /// Count the frequency of every adjacent symbol pair across all words.
448 ///
449 /// Each pair's count is weighted by the frequency of the word it appears in.
450 fn count_pairs(&self, words: &[Word]) -> HashMap<SymbolPair, usize> {
451 let mut counts: HashMap<SymbolPair, usize> = HashMap::new();
452 for word in words {
453 if word.symbols.len() < 2 {
454 continue;
455 }
456 for window in word.symbols.windows(2) {
457 let pair = SymbolPair::new(window[0], window[1]);
458 *counts.entry(pair).or_insert(0) += word.freq;
459 }
460 }
461 counts
462 }
463
464 /// Find the most frequent pair whose count meets or exceeds `min_frequency`.
465 ///
466 /// Ties are broken deterministically by preferring the pair with the smallest
467 /// (left, right) ID values so that training is fully reproducible.
468 fn best_pair(&self, pair_counts: &HashMap<SymbolPair, usize>) -> Option<(SymbolPair, usize)> {
469 pair_counts
470 .iter()
471 .filter(|(_, &count)| count >= self.config.min_frequency)
472 .max_by(|(pair_a, &cnt_a), (pair_b, &cnt_b)| {
473 // Primary: higher frequency wins.
474 // Secondary (tiebreak): lower IDs win (deterministic).
475 cnt_a
476 .cmp(&cnt_b)
477 .then_with(|| pair_b.0.cmp(&pair_a.0))
478 .then_with(|| pair_b.1.cmp(&pair_a.1))
479 })
480 .map(|(pair, &count)| (pair.clone(), count))
481 }
482
483 /// Apply a merge rule to every occurrence of `pair` in all words in-place.
484 ///
485 /// When a match is found at position `i`, `symbols[i]` is replaced with
486 /// `new_id` and `symbols[i+1]` is removed. The scan continues from
487 /// position `i` (not `i+1`) to handle non-overlapping matches correctly.
488 fn apply_merge(&self, words: &mut [Word], pair: &SymbolPair, new_id: u32) {
489 for word in words.iter_mut() {
490 if word.symbols.len() < 2 {
491 continue;
492 }
493 let mut i = 0;
494 while i + 1 < word.symbols.len() {
495 if word.symbols[i] == pair.0 && word.symbols[i + 1] == pair.1 {
496 word.symbols[i] = new_id;
497 word.symbols.remove(i + 1);
498 // Do NOT advance `i`: the newly merged token at position `i`
499 // may form another valid pair with the next symbol.
500 } else {
501 i += 1;
502 }
503 }
504 }
505 }
506
507 /// Pre-tokenize the corpus into a map from word-string → frequency.
508 ///
509 /// When `byte_level` is set, text is split on whitespace boundaries so that
510 /// BPE operates on words rather than the full document. Otherwise the
511 /// entire document is treated as one unit.
512 fn pretokenize(&self, corpus: &[&str]) -> HashMap<String, usize> {
513 let mut freq_map: HashMap<String, usize> = HashMap::new();
514 for &doc in corpus {
515 if self.config.byte_level {
516 // Split on whitespace; keep non-empty parts only.
517 for word in doc.split_whitespace() {
518 if !word.is_empty() {
519 *freq_map.entry(word.to_owned()).or_insert(0) += 1;
520 }
521 }
522 } else {
523 // Treat the entire document as a single unit.
524 if !doc.is_empty() {
525 *freq_map.entry(doc.to_owned()).or_insert(0) += 1;
526 }
527 }
528 }
529 freq_map
530 }
531
532 /// Encode a word string into its initial byte-level token ID sequence.
533 ///
534 /// Each byte of the UTF-8 representation becomes one symbol ID.
535 fn encode_word(&self, word: &str) -> Vec<u32> {
536 word.as_bytes()
537 .iter()
538 .filter_map(|b| self.char_vocab.get(b).copied())
539 .collect()
540 }
541}
542
543// ── Helpers ───────────────────────────────────────────────────────────────────
544
/// Map a byte to the string used for its base vocabulary token.
///
/// - `0x20..=0x7E` (printable ASCII, space included): the character itself.
/// - Any other byte (controls, DEL, `0x80..=0xFF`): the `<0xHH>` fallback
///   form with two upper-case hex digits.
fn byte_token_string(byte: u8) -> String {
    match byte {
        0x20..=0x7E => char::from(byte).to_string(),
        _ => format!("<0x{byte:02X}>"),
    }
}
557
558// ── Tests (inline sanity checks) ──────────────────────────────────────────────
559
// Sanity checks for the private helpers in this module.
#[cfg(test)]
mod inline_tests {
    use super::*;

    /// Printable ASCII bytes (space through `~`) map to themselves.
    #[test]
    fn byte_token_string_printable() {
        assert_eq!(byte_token_string(b'a'), "a");
        assert_eq!(byte_token_string(b' '), " ");
        assert_eq!(byte_token_string(b'~'), "~");
    }

    /// Control bytes and non-ASCII bytes use the `<0xHH>` fallback form.
    #[test]
    fn byte_token_string_control() {
        assert_eq!(byte_token_string(0x00), "<0x00>");
        assert_eq!(byte_token_string(0x0A), "<0x0A>");
        assert_eq!(byte_token_string(0xFF), "<0xFF>");
    }

    /// Pair counts are weighted by the frequency of the containing word.
    #[test]
    fn count_pairs_basic() {
        let mut trainer = BpeTrainer::new(TrainerConfig::new(300));
        // `count_pairs` does not consult `char_vocab`; these entries only
        // illustrate which bytes the symbol IDs 0 and 1 stand for.
        trainer.char_vocab.insert(b'a', 0);
        trainer.char_vocab.insert(b'b', 1);
        let words = vec![Word::new(vec![0, 1, 0, 1], 3)];
        let counts = trainer.count_pairs(&words);
        assert_eq!(counts.get(&SymbolPair::new(0, 1)), Some(&6)); // appears twice × freq 3
    }

    /// Every non-overlapping occurrence of the pair collapses to the new ID.
    #[test]
    fn apply_merge_replaces_pair() {
        let trainer = BpeTrainer::new(TrainerConfig::new(300));
        let mut words = vec![Word::new(vec![0, 1, 0, 1], 1)];
        trainer.apply_merge(&mut words, &SymbolPair::new(0, 1), 99);
        assert_eq!(words[0].symbols, vec![99, 99]);
    }
}
595}