use std::collections::HashMap;
use thiserror::Error;
use crate::{
bpe::BpeMerges,
tokenizer::{OxiTokenizer, TokenizerConfig},
vocab::Vocabulary,
};
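/// Configuration for [`BpeTrainer`].
///
/// `vocab_size` is the total target size, counting the 256 base byte tokens
/// and any special tokens; `min_frequency` is the smallest pair count that
/// still qualifies for a merge.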
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TrainerConfig {
pub vocab_size: usize,
pub min_frequency: usize,
pub add_special_tokens: bool,
pub byte_level: bool,
pub progress_interval: Option<usize>,
}
impl Default for TrainerConfig {
fn default() -> Self {
Self {
vocab_size: 1000,
min_frequency: 2,
add_special_tokens: true,
byte_level: true,
progress_interval: None,
}
}
}
impl TrainerConfig {
pub fn new(vocab_size: usize) -> Self {
Self {
vocab_size,
..Default::default()
}
}
pub fn with_min_frequency(mut self, freq: usize) -> Self {
self.min_frequency = freq;
self
}
pub fn with_special_tokens(mut self, add: bool) -> Self {
self.add_special_tokens = add;
self
}
}
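/// An ordered pair of adjacent symbol ids; the unit that BPE merges.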
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SymbolPair(pub u32, pub u32);
impl SymbolPair {
pub fn new(a: u32, b: u32) -> Self {
Self(a, b)
}
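    /// Builds the [`MergeRule`] that records merging this pair into `new_id`.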
pub fn merged_symbol(&self, new_id: u32, merged_text: String) -> MergeRule {
MergeRule {
left: self.0,
right: self.1,
merged: new_id,
merged_text,
}
}
}
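/// A learned merge: `left` followed by `right` is rewritten to `merged`.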
#[derive(Debug, Clone)]
pub struct MergeRule {
pub left: u32,
pub right: u32,
pub merged: u32,
pub merged_text: String,
}
#[derive(Debug, Clone)]
struct Word {
symbols: Vec<u32>,
freq: usize,
}
impl Word {
fn new(symbols: Vec<u32>, freq: usize) -> Self {
Self { symbols, freq }
}
}
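/// Statistics collected over one training run.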
#[derive(Debug, Clone)]
pub struct TrainingStats {
pub initial_vocab_size: usize,
pub final_vocab_size: usize,
pub num_merges_performed: usize,
pub num_merges_skipped: usize,
    pub corpus_size_bytes: usize,
pub unique_words: usize,
}
impl TrainingStats {
pub fn summary(&self) -> String {
        format!(
            "BPE training: {init} → {fin} tokens | \
             {merges} merges applied, {skipped} skipped | \
             corpus {bytes} bytes, {words} unique words",
            init = self.initial_vocab_size,
            fin = self.final_vocab_size,
            merges = self.num_merges_performed,
            skipped = self.num_merges_skipped,
            bytes = self.corpus_size_bytes,
            words = self.unique_words,
        )
}
}
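/// The result of [`BpeTrainer::train`]: the id-to-token vocabulary, the
/// ordered merge rules, and run statistics.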
#[derive(Debug)]
pub struct TrainedTokenizer {
pub vocab: HashMap<u32, String>,
pub merges: Vec<MergeRule>,
pub stats: TrainingStats,
}
impl TrainedTokenizer {
pub fn to_oxi_tokenizer(&self) -> OxiTokenizer {
let mut vocabulary = Vocabulary::new();
for (&id, token) in &self.vocab {
            // Byte-fallback tokens like "<0x0A>" also start with '<' and end
            // with '>', so exclude them explicitly when detecting specials.
            if token.starts_with('<') && token.ends_with('>') && !token.starts_with("<0x") {
                vocabulary.add_special(token, id);
            } else {
                vocabulary.insert(token, id);
            }
}
let mut bpe_merges = BpeMerges::new();
for rule in &self.merges {
            let left_str = self.token_str(rule.left);
            let right_str = self.token_str(rule.right);
}
let config = TokenizerConfig::default();
OxiTokenizer::new(vocabulary, bpe_merges, config)
}
pub fn merges_to_text(&self) -> String {
let mut out = String::new();
for rule in &self.merges {
            let left = self.token_str(rule.left);
            let right = self.token_str(rule.right);
out.push_str(left);
out.push(' ');
out.push_str(right);
out.push('\n');
}
out
}
    pub fn vocab_size(&self) -> usize {
        self.vocab.len()
    }

    /// Looks up the text for a token id, falling back to "" for unknown ids.
    fn token_str(&self, id: u32) -> &str {
        self.vocab.get(&id).map(String::as_str).unwrap_or("")
    }
}
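/// Errors that can abort BPE training.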
#[derive(Debug, Error)]
pub enum TrainerError {
#[error("empty corpus")]
EmptyCorpus,
#[error("vocab_size {0} must be > 256 (base byte vocabulary)")]
VocabSizeTooSmall(usize),
#[error("corpus has no valid words after pre-tokenization")]
NoValidWords,
}
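/// Learns a byte-level BPE vocabulary from a corpus.
///
/// Training seeds the vocabulary with all 256 single-byte tokens (plus
/// optional special tokens), then repeatedly merges the most frequent
/// adjacent pair until `vocab_size` is reached or no pair is frequent
/// enough.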
pub struct BpeTrainer {
config: TrainerConfig,
    /// Maps each input byte to its token id; populated at the start of `train`.
    byte_vocab: HashMap<u8, u32>,
next_id: u32,
}
impl BpeTrainer {
    pub fn new(config: TrainerConfig) -> Self {
        Self {
            config,
            byte_vocab: HashMap::new(),
            next_id: 0,
        }
}
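    /// Convenience constructor that uses [`TrainerConfig::default`].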
pub fn default_config() -> Self {
Self::new(TrainerConfig::default())
}
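    /// Runs BPE training over `corpus`.
    ///
    /// Each round counts adjacent symbol pairs across all words, picks the
    /// most frequent pair, assigns it a fresh id, and rewrites every word in
    /// place. Training stops at `vocab_size` or when no pair reaches
    /// `min_frequency`.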
pub fn train(&mut self, corpus: &[&str]) -> Result<TrainedTokenizer, TrainerError> {
if corpus.is_empty() {
return Err(TrainerError::EmptyCorpus);
}
        // The base vocabulary is the 256 byte tokens, plus four special
        // tokens when enabled; vocab_size must at least cover it.
        let min_size: usize = if self.config.add_special_tokens {
            256 + 4
        } else {
            256
        };
        if self.config.vocab_size < min_size {
            return Err(TrainerError::VocabSizeTooSmall(self.config.vocab_size));
        }
let mut id_to_token: HashMap<u32, String> = HashMap::new();
        // Special tokens occupy ids 0..=3, shifting every byte token up by 4.
        let byte_id_offset: u32 = if self.config.add_special_tokens { 4 } else { 0 };
        if self.config.add_special_tokens {
            id_to_token.insert(0, "<unk>".to_owned());
            id_to_token.insert(1, "<bos>".to_owned());
            id_to_token.insert(2, "<eos>".to_owned());
            id_to_token.insert(3, "<pad>".to_owned());
        }
        // Seed the vocabulary with all 256 single-byte tokens so any input
        // can be encoded without an out-of-vocabulary fallback.
        self.byte_vocab.clear();
        for byte in 0u8..=255u8 {
            let id = byte as u32 + byte_id_offset;
            let token = byte_token_string(byte);
            self.byte_vocab.insert(byte, id);
            id_to_token.insert(id, token);
        }
        self.next_id = 256 + byte_id_offset;
let initial_vocab_size = id_to_token.len();
        // str::len() counts bytes, matching the byte-level training units.
        let corpus_size_bytes: usize = corpus.iter().map(|s| s.len()).sum();
let word_freqs = self.pretokenize(corpus);
if word_freqs.is_empty() {
return Err(TrainerError::NoValidWords);
}
let unique_words = word_freqs.len();
let mut words: Vec<Word> = word_freqs
.into_iter()
.map(|(text, freq)| {
let symbols = self.encode_word(&text);
Word::new(symbols, freq)
})
.collect();
let num_merges = self.config.vocab_size.saturating_sub(self.next_id as usize);
let mut merge_rules: Vec<MergeRule> = Vec::with_capacity(num_merges);
let mut num_merges_skipped: usize = 0;
for merge_idx in 0..num_merges {
if let Some(interval) = self.config.progress_interval {
if interval > 0 && merge_idx % interval == 0 {
tracing::debug!(
merge = merge_idx,
total = num_merges,
vocab = self.next_id,
"BPE training progress",
);
}
}
            let pair_counts = self.count_pairs(&words);
            if pair_counts.is_empty() {
                num_merges_skipped += num_merges - merge_idx;
                break;
            }
            // Stop early once no pair reaches min_frequency; the unused
            // merge budget is recorded as skipped.
            let best = match self.best_pair(&pair_counts) {
                Some(b) => b,
                None => {
                    num_merges_skipped += num_merges - merge_idx;
                    break;
                }
            };
let (pair, _freq) = best;
let left_str = id_to_token.get(&pair.0).cloned().unwrap_or_default();
let right_str = id_to_token.get(&pair.1).cloned().unwrap_or_default();
let merged_text = format!("{left_str}{right_str}");
let new_id = self.next_id;
self.next_id += 1;
id_to_token.insert(new_id, merged_text.clone());
let rule = pair.merged_symbol(new_id, merged_text);
merge_rules.push(rule);
self.apply_merge(&mut words, &pair, new_id);
}
let final_vocab_size = id_to_token.len();
let num_merges_performed = merge_rules.len();
let stats = TrainingStats {
initial_vocab_size,
final_vocab_size,
num_merges_performed,
num_merges_skipped,
            corpus_size_bytes,
unique_words,
};
Ok(TrainedTokenizer {
vocab: id_to_token,
merges: merge_rules,
stats,
})
}
fn count_pairs(&self, words: &[Word]) -> HashMap<SymbolPair, usize> {
let mut counts: HashMap<SymbolPair, usize> = HashMap::new();
for word in words {
if word.symbols.len() < 2 {
continue;
}
            // Weight each pair occurrence by the word's corpus frequency.
            for window in word.symbols.windows(2) {
                let pair = SymbolPair::new(window[0], window[1]);
                *counts.entry(pair).or_insert(0) += word.freq;
            }
}
counts
}
fn best_pair(&self, pair_counts: &HashMap<SymbolPair, usize>) -> Option<(SymbolPair, usize)> {
        pair_counts
            .iter()
            .filter(|(_, &count)| count >= self.config.min_frequency)
            // Highest count wins; ties break toward the smallest left id,
            // then the smallest right id, keeping training deterministic
            // despite HashMap's random iteration order.
            .max_by(|(pair_a, &cnt_a), (pair_b, &cnt_b)| {
                cnt_a
                    .cmp(&cnt_b)
                    .then_with(|| pair_b.0.cmp(&pair_a.0))
                    .then_with(|| pair_b.1.cmp(&pair_a.1))
            })
.map(|(pair, &count)| (pair.clone(), count))
}
fn apply_merge(&self, words: &mut [Word], pair: &SymbolPair, new_id: u32) {
for word in words.iter_mut() {
if word.symbols.len() < 2 {
continue;
}
let mut i = 0;
while i + 1 < word.symbols.len() {
                if word.symbols[i] == pair.0 && word.symbols[i + 1] == pair.1 {
                    // Collapse the pair in place; i is not advanced, so runs
                    // like "aaa" merge greedily left to right.
                    word.symbols[i] = new_id;
                    word.symbols.remove(i + 1);
} else {
i += 1;
}
}
}
}
    fn pretokenize(&self, corpus: &[&str]) -> HashMap<String, usize> {
        let mut freq_map: HashMap<String, usize> = HashMap::new();
        for &doc in corpus {
            if self.config.byte_level {
                // split_whitespace never yields empty strings.
                for word in doc.split_whitespace() {
                    *freq_map.entry(word.to_owned()).or_insert(0) += 1;
                }
            } else if !doc.is_empty() {
                // Without whitespace pre-tokenization, the whole document is
                // treated as a single word.
                *freq_map.entry(doc.to_owned()).or_insert(0) += 1;
            }
        }
freq_map
}
    fn encode_word(&self, word: &str) -> Vec<u32> {
        // Every byte is present in byte_vocab once train() has seeded it, so
        // filter_map never actually drops a symbol here.
        word.as_bytes()
            .iter()
            .filter_map(|b| self.byte_vocab.get(b).copied())
            .collect()
    }
}
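/// Renders a byte as vocabulary text: printable ASCII maps to the character
/// itself, everything else to a `<0xNN>` fallback token.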
fn byte_token_string(byte: u8) -> String {
if byte.is_ascii() && !byte.is_ascii_control() {
(byte as char).to_string()
} else {
format!("<0x{byte:02X}>")
}
}
#[cfg(test)]
mod inline_tests {
use super::*;
#[test]
fn byte_token_string_printable() {
assert_eq!(byte_token_string(b'a'), "a");
assert_eq!(byte_token_string(b' '), " ");
assert_eq!(byte_token_string(b'~'), "~");
}
#[test]
fn byte_token_string_control() {
assert_eq!(byte_token_string(0x00), "<0x00>");
assert_eq!(byte_token_string(0x0A), "<0x0A>");
assert_eq!(byte_token_string(0xFF), "<0xFF>");
}
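
    // A small end-to-end sketch of the trainer: "low" dominates this tiny
    // corpus, so (l, o) and then (lo, w) are the most frequent pairs and a
    // "low" token should emerge. Assertions are kept loose on purpose, since
    // the exact merge count depends on min_frequency (2 by default).
    #[test]
    fn train_merges_frequent_pairs() {
        let mut trainer = BpeTrainer::new(TrainerConfig::new(300));
        let trained = trainer
            .train(&["low low low lower lowest"])
            .expect("non-empty corpus should train");
        assert!(trained.stats.num_merges_performed >= 2);
        assert!(trained.vocab.values().any(|t| t == "low"));
        // Special tokens are enabled by default.
        assert!(trained.vocab.values().any(|t| t == "<unk>"));
    }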
#[test]
fn count_pairs_basic() {
        // count_pairs only reads symbol ids and word frequencies, so no byte
        // vocabulary needs to be populated here.
        let trainer = BpeTrainer::new(TrainerConfig::new(300));
        let words = vec![Word::new(vec![0, 1, 0, 1], 3)];
        let counts = trainer.count_pairs(&words);
        assert_eq!(counts.get(&SymbolPair::new(0, 1)), Some(&6));
    }
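
    // Pins down the tie-breaking documented in best_pair: equal counts fall
    // back to the smallest left id, then the smallest right id. The ids used
    // here are arbitrary, not real byte ids.
    #[test]
    fn best_pair_breaks_ties_deterministically() {
        let trainer = BpeTrainer::new(TrainerConfig::new(300));
        let mut counts = HashMap::new();
        counts.insert(SymbolPair::new(2, 3), 5);
        counts.insert(SymbolPair::new(1, 9), 5);
        counts.insert(SymbolPair::new(1, 2), 5);
        counts.insert(SymbolPair::new(0, 0), 1); // below min_frequency
        assert_eq!(
            trainer.best_pair(&counts),
            Some((SymbolPair::new(1, 2), 5))
        );
    }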
#[test]
fn apply_merge_replaces_pair() {
let trainer = BpeTrainer::new(TrainerConfig::new(300));
let mut words = vec![Word::new(vec![0, 1, 0, 1], 1)];
trainer.apply_merge(&mut words, &SymbolPair::new(0, 1), 99);
assert_eq!(words[0].symbols, vec![99, 99]);
}
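
    // Checks the "left right" line format of merges_to_text on a tokenizer
    // assembled by hand; the stats values are placeholders, not the output
    // of a real training run.
    #[test]
    fn merges_to_text_one_line_per_rule() {
        let mut vocab = HashMap::new();
        vocab.insert(0, "a".to_owned());
        vocab.insert(1, "b".to_owned());
        vocab.insert(2, "ab".to_owned());
        let trained = TrainedTokenizer {
            vocab,
            merges: vec![MergeRule {
                left: 0,
                right: 1,
                merged: 2,
                merged_text: "ab".to_owned(),
            }],
            stats: TrainingStats {
                initial_vocab_size: 2,
                final_vocab_size: 3,
                num_merges_performed: 1,
                num_merges_skipped: 0,
                corpus_size_bytes: 2,
                unique_words: 1,
            },
        };
        assert_eq!(trained.merges_to_text(), "a b\n");
        assert_eq!(trained.vocab_size(), 3);
    }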
}