use std::collections::HashMap;
/// Character-level tokenizer: every distinct `char` of the source text is one token.
///
/// Ids are assigned in ascending `char` order, so the mapping depends only on
/// the set of characters present in the text handed to [`CharTokenizer::from_text`].
#[derive(Clone)]
pub struct CharTokenizer {
    /// `(character, id)` pairs, listed in ascending character order.
    pub char_to_idx: Vec<(char, usize)>,
    /// Inverse table: the character stored at position `id`.
    pub idx_to_char: Vec<char>,
    /// Number of distinct characters in the vocabulary.
    pub vocab_size: usize,
}

impl CharTokenizer {
    /// Builds the vocabulary from the distinct characters of `text`.
    pub fn from_text(text: &str) -> Self {
        // A `BTreeSet` both de-duplicates and yields the characters in sorted order.
        let distinct: std::collections::BTreeSet<char> = text.chars().collect();
        let idx_to_char: Vec<char> = distinct.into_iter().collect();

        let mut char_to_idx: Vec<(char, usize)> = Vec::with_capacity(idx_to_char.len());
        for (id, ch) in idx_to_char.iter().copied().enumerate() {
            char_to_idx.push((ch, id));
        }

        let vocab_size = idx_to_char.len();
        Self {
            char_to_idx,
            idx_to_char,
            vocab_size,
        }
    }

    /// Maps each character of `text` to its id.
    ///
    /// Characters that are not part of the vocabulary are encoded as id `0`.
    pub fn encode(&self, text: &str) -> Vec<usize> {
        let mut ids = Vec::new();
        for wanted in text.chars() {
            let id = self
                .char_to_idx
                .iter()
                .find_map(|&(ch, id)| if ch == wanted { Some(id) } else { None })
                .unwrap_or(0);
            ids.push(id);
        }
        ids
    }

    /// Maps ids back to characters and concatenates them.
    ///
    /// Ids outside the vocabulary are rendered as `'?'`.
    pub fn decode(&self, tokens: &[usize]) -> String {
        let mut text = String::new();
        for &id in tokens {
            match self.idx_to_char.get(id) {
                Some(&ch) => text.push(ch),
                None => text.push('?'),
            }
        }
        text
    }
}
/// Byte-pair-encoding style tokenizer that works on Unicode characters.
///
/// The vocabulary holds every distinct character of the training text (ids in
/// ascending character order) followed by one token per learned merge, in the
/// order the merges were learned.
#[derive(Clone)]
pub struct BpeTokenizer {
    /// Learned merges as `(left, right)` token pairs, in the order they are applied.
    merges: Vec<(String, String)>,
    /// Token string to token id.
    token_to_idx: HashMap<String, usize>,
    /// Token id to token string.
    idx_to_token: Vec<String>,
    /// Number of entries in the id-to-token table.
    pub vocab_size: usize,
}

impl BpeTokenizer {
    /// Learns merges from `text` until the vocabulary reaches `target_vocab`
    /// entries or no adjacent pair occurs at least twice.
    ///
    /// Each line of `text` is trained on separately, so merges never span a
    /// line break. Training is deterministic: when several pairs share the
    /// highest count, the lexicographically smallest `(left, right)` pair is
    /// merged first. A one-line summary is printed to stdout.
    ///
    /// Note: if two different merges concatenate to the same string, both get
    /// their own id and the later one wins in the string-to-id map.
    pub fn train(text: &str, target_vocab: usize) -> Self {
        // A `BTreeSet` de-duplicates and yields the characters already sorted.
        let base_chars: std::collections::BTreeSet<char> = text.chars().collect();
        let base_vocab_size = base_chars.len();

        let mut token_to_idx: HashMap<String, usize> = HashMap::new();
        let mut idx_to_token: Vec<String> = Vec::with_capacity(base_vocab_size);
        for (i, c) in base_chars.into_iter().enumerate() {
            let s = c.to_string();
            token_to_idx.insert(s.clone(), i);
            idx_to_token.push(s);
        }

        let mut corpus_tokens: Vec<Vec<String>> = text
            .lines()
            .map(|line| line.chars().map(|c| c.to_string()).collect())
            .collect();

        let num_merges = target_vocab.saturating_sub(base_vocab_size);
        // Deliberately not pre-sized from `num_merges`: the target is caller
        // supplied and may be far larger than the number of merges that exist.
        let mut merges: Vec<(String, String)> = Vec::new();

        for _ in 0..num_merges {
            let (left, right) = match Self::most_frequent_pair(&corpus_tokens) {
                Some(pair) => pair,
                None => break,
            };
            let merged = format!("{}{}", left, right);
            token_to_idx.insert(merged.clone(), idx_to_token.len());
            for seq in &mut corpus_tokens {
                Self::merge_pair(seq, &left, &right, &merged);
            }
            idx_to_token.push(merged);
            merges.push((left, right));
        }

        let vocab_size = idx_to_token.len();
        println!(
            "BPE: {} merges, vocab = {} (base {} + {} merges)",
            merges.len(),
            vocab_size,
            base_vocab_size,
            merges.len()
        );
        Self {
            merges,
            token_to_idx,
            idx_to_token,
            vocab_size,
        }
    }

    /// Splits `text` into characters, applies every learned merge in training
    /// order and returns the resulting token ids.
    ///
    /// Tokens missing from the vocabulary are encoded as id `0`.
    pub fn encode(&self, text: &str) -> Vec<usize> {
        let mut tokens: Vec<String> = text.chars().map(|c| c.to_string()).collect();
        for (left, right) in &self.merges {
            if tokens.len() < 2 {
                // Nothing left that could form a pair.
                break;
            }
            let merged = format!("{}{}", left, right);
            Self::merge_pair(&mut tokens, left, right, &merged);
        }
        tokens
            .iter()
            .map(|t| self.token_to_idx.get(t).copied().unwrap_or(0))
            .collect()
    }

    /// Concatenates the token strings for `tokens`.
    ///
    /// Ids outside the vocabulary are rendered as `"?"`.
    pub fn decode(&self, tokens: &[usize]) -> String {
        tokens
            .iter()
            .map(|&idx| self.idx_to_token.get(idx).map(|s| s.as_str()).unwrap_or("?"))
            .collect()
    }

    /// Returns the adjacent pair with the highest count, provided it occurs at
    /// least twice. Ties go to the lexicographically smallest pair, so the
    /// result does not depend on `HashMap` iteration order.
    fn most_frequent_pair(corpus: &[Vec<String>]) -> Option<(String, String)> {
        // Keys borrow from the corpus, so counting allocates no strings.
        let mut pair_counts: HashMap<(&str, &str), usize> = HashMap::new();
        for seq in corpus {
            for window in seq.windows(2) {
                *pair_counts
                    .entry((window[0].as_str(), window[1].as_str()))
                    .or_insert(0) += 1;
            }
        }
        pair_counts
            .into_iter()
            .filter(|&(_, count)| count >= 2)
            .max_by(|a, b| a.1.cmp(&b.1).then_with(|| b.0.cmp(&a.0)))
            .map(|((left, right), _)| (left.to_owned(), right.to_owned()))
    }

    /// Replaces each non-overlapping, left-to-right occurrence of `left`
    /// followed by `right` with `merged`, in one linear pass.
    fn merge_pair(tokens: &mut Vec<String>, left: &str, right: &str, merged: &str) {
        let len = tokens.len();
        let mut read = 0;
        let mut write = 0;
        while read < len {
            let is_match =
                read + 1 < len && tokens[read] == left && tokens[read + 1] == right;
            if is_match {
                tokens[write] = merged.to_owned();
                read += 2;
            } else {
                // `write <= read`, so this only moves a kept token forward over
                // a slot that has already been consumed.
                tokens.swap(write, read);
                read += 1;
            }
            write += 1;
        }
        tokens.truncate(write);
    }
}
/// Tokenizer that merges adjacent tokens whose pointwise mutual information
/// exceeds `ln(φ)` (φ ≈ 1.618).
///
/// The vocabulary holds every distinct character of the training text (ids in
/// ascending character order) followed by one token per selected merge.
#[derive(Clone)]
pub struct MiTokenizer {
    /// Selected merges as `(left, right)` token pairs, in the order they are applied.
    merges: Vec<(String, String)>,
    /// Token string to token id.
    token_to_idx: HashMap<String, usize>,
    /// Token id to token string.
    idx_to_token: Vec<String>,
    /// Number of entries in the id-to-token table.
    pub vocab_size: usize,
}

impl MiTokenizer {
    /// Learns merges from `text` in at most eight rounds.
    ///
    /// Each round re-counts unigrams and bigrams over the current
    /// tokenization, selects pairs via [`MiTokenizer::select_merges`] and
    /// applies them in score order. Training stops early once the vocabulary
    /// reaches `target_vocab` or a round selects nothing. Each line of `text`
    /// is trained on separately, so merges never span a line break.
    ///
    /// Training is deterministic: pairs with equal scores are ordered
    /// lexicographically. Progress is printed to stdout after each round, with
    /// a summary at the end.
    pub fn train(text: &str, target_vocab: usize) -> Self {
        /// Upper bound on the number of count/select/merge rounds.
        const MAX_ROUNDS: usize = 8;

        // A `BTreeSet` de-duplicates and yields the characters already sorted.
        let base_chars: std::collections::BTreeSet<char> = text.chars().collect();
        let mut token_to_idx: HashMap<String, usize> = HashMap::new();
        let mut idx_to_token: Vec<String> = Vec::with_capacity(base_chars.len());
        for (i, c) in base_chars.into_iter().enumerate() {
            let s = c.to_string();
            token_to_idx.insert(s.clone(), i);
            idx_to_token.push(s);
        }

        let mut corpus: Vec<Vec<String>> = text
            .lines()
            .map(|line| line.chars().map(|c| c.to_string()).collect())
            .collect();

        let mut all_merges: Vec<(String, String)> = Vec::new();
        let phi_threshold = (1.618033988_f64).ln();

        for round in 0..MAX_ROUNDS {
            let remaining = target_vocab.saturating_sub(idx_to_token.len());
            if remaining == 0 {
                break;
            }

            let round_merges = Self::select_merges(&corpus, phi_threshold, remaining);
            if round_merges.is_empty() {
                break;
            }

            for (left, right) in &round_merges {
                let merged = format!("{}{}", left, right);
                token_to_idx.insert(merged.clone(), idx_to_token.len());
                for seq in &mut corpus {
                    Self::merge_pair(seq, left, right, &merged);
                }
                idx_to_token.push(merged);
            }

            println!(
                "MI round {}: {} merges (MI > ln(φ)={:.3}), vocab = {}",
                round,
                round_merges.len(),
                phi_threshold,
                idx_to_token.len()
            );
            all_merges.extend(round_merges);
        }

        let vocab_size = idx_to_token.len();
        println!(
            "MI tokenizer: {} total merges, vocab = {}",
            all_merges.len(),
            vocab_size
        );
        Self {
            merges: all_merges,
            token_to_idx,
            idx_to_token,
            vocab_size,
        }
    }

    /// Splits `text` into characters, applies every learned merge in training
    /// order and returns the resulting token ids.
    ///
    /// Tokens missing from the vocabulary are encoded as id `0`.
    pub fn encode(&self, text: &str) -> Vec<usize> {
        let mut tokens: Vec<String> = text.chars().map(|c| c.to_string()).collect();
        for (left, right) in &self.merges {
            if tokens.len() < 2 {
                // Nothing left that could form a pair.
                break;
            }
            let merged = format!("{}{}", left, right);
            Self::merge_pair(&mut tokens, left, right, &merged);
        }
        tokens
            .iter()
            .map(|t| self.token_to_idx.get(t).copied().unwrap_or(0))
            .collect()
    }

    /// Concatenates the token strings for `tokens`.
    ///
    /// Ids outside the vocabulary are rendered as `"?"`.
    pub fn decode(&self, tokens: &[usize]) -> String {
        tokens
            .iter()
            .map(|&idx| self.idx_to_token.get(idx).map(|s| s.as_str()).unwrap_or("?"))
            .collect()
    }

    /// Scores each adjacent pair that occurs at least twice with
    /// `ln(p(ab) / (p(a) * p(b)))`, where every probability is a count divided
    /// by the total number of tokens, and returns up to `limit` pairs scoring
    /// above `threshold`.
    ///
    /// Pairs are ordered by descending score, with equal scores ordered
    /// lexicographically, so the result does not depend on `HashMap`
    /// iteration order. Returns an empty list for fewer than two tokens.
    fn select_merges(
        corpus: &[Vec<String>],
        threshold: f64,
        limit: usize,
    ) -> Vec<(String, String)> {
        // Keys borrow from the corpus, so counting allocates no strings.
        let mut unigram: HashMap<&str, usize> = HashMap::new();
        let mut bigram: HashMap<(&str, &str), usize> = HashMap::new();
        let mut total: usize = 0;
        for seq in corpus {
            total += seq.len();
            for tok in seq {
                *unigram.entry(tok.as_str()).or_default() += 1;
            }
            for w in seq.windows(2) {
                *bigram.entry((w[0].as_str(), w[1].as_str())).or_default() += 1;
            }
        }
        if total < 2 {
            return Vec::new();
        }

        let total_f = total as f64;
        let mut scored: Vec<((&str, &str), f64)> = bigram
            .into_iter()
            .filter_map(|((a, b), count)| {
                if count < 2 {
                    return None;
                }
                // Both halves of a bigram were counted as unigrams in the same
                // pass, so these lookups always succeed with a count >= 1.
                let p_ab = count as f64 / total_f;
                let p_a = unigram.get(a).copied().unwrap_or(1) as f64 / total_f;
                let p_b = unigram.get(b).copied().unwrap_or(1) as f64 / total_f;
                let mi = (p_ab / (p_a * p_b)).ln();
                if mi > threshold {
                    Some(((a, b), mi))
                } else {
                    None
                }
            })
            .collect();

        scored.sort_by(|x, y| {
            y.1.partial_cmp(&x.1)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| x.0.cmp(&y.0))
        });

        scored
            .into_iter()
            .take(limit)
            .map(|((a, b), _)| (a.to_owned(), b.to_owned()))
            .collect()
    }

    /// Replaces each non-overlapping, left-to-right occurrence of `left`
    /// followed by `right` with `merged`, in one linear pass.
    fn merge_pair(tokens: &mut Vec<String>, left: &str, right: &str, merged: &str) {
        let len = tokens.len();
        let mut read = 0;
        let mut write = 0;
        while read < len {
            let is_match =
                read + 1 < len && tokens[read] == left && tokens[read + 1] == right;
            if is_match {
                tokens[write] = merged.to_owned();
                read += 2;
            } else {
                // `write <= read`, so this only moves a kept token forward over
                // a slot that has already been consumed.
                tokens.swap(write, read);
                read += 1;
            }
            write += 1;
        }
        tokens.truncate(write);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decoding an encoded string with the character tokenizer returns the input.
    #[test]
    fn char_roundtrip() {
        let sample = "hello world";
        let tokenizer = CharTokenizer::from_text(sample);
        let ids = tokenizer.encode(sample);
        assert_eq!(tokenizer.decode(&ids), sample);
    }

    /// Repeated characters are counted once in the vocabulary.
    #[test]
    fn char_vocab_size_correct() {
        let tokenizer = CharTokenizer::from_text("abcabc");
        assert_eq!(tokenizer.vocab_size, 3);
    }

    /// Training on repetitive text learns merges, and encoding stays lossless.
    #[test]
    fn bpe_trains_and_encodes() {
        let corpus = "abababab cdcdcdcd abababab";
        let tokenizer = BpeTokenizer::train(corpus, 20);
        assert!(tokenizer.vocab_size > 6, "BPE should have merged some pairs");
        let ids = tokenizer.encode("abab");
        assert_eq!(tokenizer.decode(&ids), "abab");
    }

    /// Decoding an encoded training text with BPE returns the input.
    #[test]
    fn bpe_roundtrip() {
        let corpus = "the cat sat on the mat the cat sat on the mat";
        let tokenizer = BpeTokenizer::train(corpus, 30);
        let ids = tokenizer.encode(corpus);
        assert_eq!(tokenizer.decode(&ids), corpus);
    }

    /// BPE emits fewer tokens than the text's length on repetitive input.
    #[test]
    fn bpe_compression() {
        let corpus = "aaaa bbbb aaaa bbbb aaaa bbbb";
        let tokenizer = BpeTokenizer::train(corpus, 20);
        let char_len = corpus.len();
        let bpe_len = tokenizer.encode(corpus).len();
        assert!(bpe_len < char_len, "BPE should compress: {} < {}", bpe_len, char_len);
    }

    /// Decoding an encoded training text with the MI tokenizer returns the input.
    #[test]
    fn mi_roundtrip() {
        let corpus = "the cat sat on the mat the cat sat on the mat";
        let tokenizer = MiTokenizer::train(corpus, 30);
        let ids = tokenizer.encode(corpus);
        assert_eq!(tokenizer.decode(&ids), corpus);
    }

    /// The MI tokenizer emits fewer tokens than the text's length on repetitive input.
    #[test]
    fn mi_compression() {
        let corpus = "aaaa bbbb aaaa bbbb aaaa bbbb cccc dddd cccc dddd";
        let tokenizer = MiTokenizer::train(corpus, 30);
        let char_len = corpus.len();
        let mi_len = tokenizer.encode(corpus).len();
        assert!(mi_len < char_len, "MI should compress: {} < {}", mi_len, char_len);
    }

    /// A word that dominates the corpus ends up encoded in fewer than three tokens.
    #[test]
    fn mi_merges_high_mi_pairs() {
        let corpus = "the the the the the the the the the the other this that them then";
        let tokenizer = MiTokenizer::train(corpus, 50);
        assert!(
            tokenizer.vocab_size > 10,
            "MI should have merged pairs, got vocab={}",
            tokenizer.vocab_size
        );
        let ids = tokenizer.encode("the");
        assert!(
            ids.len() < 3,
            "\"the\" should be compressed: {} tokens",
            ids.len()
        );
    }
}