use std::collections::HashMap;
/// A minimal byte-pair-encoding (BPE) tokenizer operating on lowercase
/// characters, with `</w>` appended to every word as an end-of-word marker.
#[derive(Debug, Clone, Default)]
pub struct BPETokenizer {
    /// Maps each token string to its id. Ids are dense: `0..vocab.len()`.
    pub vocab: HashMap<String, usize>,
    /// Learned merge rules, in the order they were learned (and must be applied).
    pub merges: Vec<(String, String)>,
    /// Number of entries in `vocab` at the end of training.
    pub vocab_size: usize,
}

impl BPETokenizer {
    /// Marker appended to every word so merges can distinguish word endings.
    const END_OF_WORD: &'static str = "</w>";

    /// Learns a BPE vocabulary and merge list from `corpus`.
    ///
    /// Each corpus entry is lowercased and trimmed, then treated as one word;
    /// entries that are empty after trimming are skipped. Merging repeats until
    /// the vocabulary holds `target_vocab_size` tokens or no adjacent pair is
    /// left to merge, so the result may be smaller (no pairs left) or larger
    /// (the initial alphabet alone exceeds the target) than requested.
    ///
    /// Training is deterministic: initial symbols receive ids in sorted order,
    /// and when several pairs share the highest count the lexicographically
    /// smallest pair is merged.
    pub fn train(corpus: &[String], target_vocab_size: usize) -> Self {
        // Count identical words once; the key is the word's current token sequence.
        let mut word_freqs: HashMap<Vec<String>, usize> = HashMap::new();
        for word in corpus {
            let lowered = word.to_lowercase();
            let cleaned = lowered.trim();
            if cleaned.is_empty() {
                continue;
            }
            let mut tokens: Vec<String> = cleaned.chars().map(|c| c.to_string()).collect();
            tokens.push(Self::END_OF_WORD.to_owned());
            *word_freqs.entry(tokens).or_insert(0) += 1;
        }

        // Sort the initial symbols so id assignment does not depend on the
        // arbitrary iteration order of the hash map.
        let mut alphabet: Vec<&String> = word_freqs.keys().flatten().collect();
        alphabet.sort_unstable();
        alphabet.dedup();
        let mut vocab: HashMap<String, usize> = alphabet
            .into_iter()
            .cloned()
            .enumerate()
            .map(|(id, token)| (token, id))
            .collect();

        let mut merges: Vec<(String, String)> = Vec::new();
        while vocab.len() < target_vocab_size {
            // Count adjacent pairs over borrowed keys; only the winner is cloned.
            let best_pair = {
                let mut pair_counts: HashMap<(&str, &str), usize> = HashMap::new();
                for (tokens, freq) in &word_freqs {
                    for pair in tokens.windows(2) {
                        *pair_counts
                            .entry((pair[0].as_str(), pair[1].as_str()))
                            .or_insert(0) += freq;
                    }
                }
                pair_counts
                    .into_iter()
                    // Highest count wins; ties go to the smallest pair so the
                    // outcome never depends on hash-map iteration order.
                    .max_by(|(pair_a, count_a), (pair_b, count_b)| {
                        count_a.cmp(count_b).then_with(|| pair_b.cmp(pair_a))
                    })
                    .map(|((left, right), _)| (left.to_owned(), right.to_owned()))
            };
            let (left, right) = match best_pair {
                Some(pair) => pair,
                None => break,
            };

            let merged = format!("{}{}", left, right);
            // The merged string may already be a token (e.g. the characters
            // `<`, `/`, `w`, `>` merging into the end-of-word marker). Keep its
            // existing id so ids stay unique and dense.
            let next_id = vocab.len();
            vocab.entry(merged.clone()).or_insert(next_id);

            let mut merged_freqs: HashMap<Vec<String>, usize> =
                HashMap::with_capacity(word_freqs.len());
            for (tokens, freq) in word_freqs {
                let rewritten = Self::apply_merge(tokens, &left, &right, &merged);
                *merged_freqs.entry(rewritten).or_insert(0) += freq;
            }
            word_freqs = merged_freqs;
            merges.push((left, right));
        }

        let vocab_size = vocab.len();
        Self {
            vocab,
            merges,
            vocab_size,
        }
    }

    /// Splits `text` into BPE tokens by replaying the learned merges in order.
    ///
    /// The text is lowercased and treated as a single word: it is split into
    /// characters, `</w>` is appended once at the end, and each merge rule is
    /// applied left to right. Whitespace is not trimmed or split on here.
    pub fn encode(&self, text: &str) -> Vec<String> {
        let mut tokens: Vec<String> = text
            .to_lowercase()
            .chars()
            .map(|c| c.to_string())
            .collect();
        tokens.push(Self::END_OF_WORD.to_owned());
        for (left, right) in &self.merges {
            let merged = format!("{}{}", left, right);
            tokens = Self::apply_merge(tokens, left, right, &merged);
        }
        tokens
    }

    /// Joins `tokens` back into text, turning each `</w>` marker into a space
    /// and trimming leading/trailing whitespace from the result.
    pub fn decode(&self, tokens: &[String]) -> String {
        tokens
            .concat()
            .replace(Self::END_OF_WORD, " ")
            .trim()
            .to_string()
    }

    /// Replaces every non-overlapping, left-to-right occurrence of the adjacent
    /// pair (`left`, `right`) in `tokens` with `merged`, moving all other
    /// tokens through unchanged.
    fn apply_merge(tokens: Vec<String>, left: &str, right: &str, merged: &str) -> Vec<String> {
        let mut out = Vec::with_capacity(tokens.len());
        let mut rest = tokens.into_iter().peekable();
        while let Some(token) = rest.next() {
            if token == left && rest.peek().map_or(false, |next| next == right) {
                // Consume the right-hand token as well.
                rest.next();
                out.push(merged.to_owned());
            } else {
                out.push(token);
            }
        }
        out
    }
}