use crate::bpe::BPETokenizer;
use crate::normalizer::Normalizer;
use crate::unigram::UnigramTokenizer;
use crate::wordpiece::WordPieceTokenizer;
use std::collections::HashMap;
use trustformers_core::errors::Result;
use super::config::TrainingConfig;
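/// Trains a byte-pair-encoding (BPE) tokenizer from raw text: word
/// frequencies are counted (after optional normalization), the vocabulary
/// is seeded with special tokens and single characters, and the most
/// frequent adjacent symbol pair is merged repeatedly until `vocab_size`
/// is reached or no pair repeats.
///
/// A minimal usage sketch, mirroring the tests below:
///
/// ```ignore
/// let config = TrainingConfig { vocab_size: 100, ..Default::default() };
/// let tokenizer = BPETrainer::new(config)
///     .with_normalizer(Box::new(LowercaseNormalizer))
///     .train(&texts)?;
/// ```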
pub struct BPETrainer {
config: TrainingConfig,
normalizer: Option<Box<dyn Normalizer>>,
}
impl BPETrainer {
pub fn new(config: TrainingConfig) -> Self {
Self {
config,
normalizer: None,
}
}
pub fn with_normalizer(mut self, normalizer: Box<dyn Normalizer>) -> Self {
self.normalizer = Some(normalizer);
self
}
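    /// Runs BPE training over `texts`, returning a tokenizer built from the
    /// learned vocabulary and merge rules.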
pub fn train(&self, texts: &[String]) -> Result<BPETokenizer> {
let mut word_freqs = HashMap::new();
for text in texts {
let processed_text = if let Some(ref normalizer) = self.normalizer {
normalizer.normalize(text)
} else {
text.clone()
};
for word in processed_text.split_whitespace() {
*word_freqs.entry(word.to_string()).or_insert(0) += 1;
}
}
let mut vocab = HashMap::new();
let mut merge_rules = Vec::new();
for (i, token) in self.config.special_tokens.iter().enumerate() {
vocab.insert(token.clone(), i as u32);
}
let mut next_id = self.config.special_tokens.len() as u32;
let mut char_freqs = HashMap::new();
for (word, freq) in &word_freqs {
for ch in word.chars() {
*char_freqs.entry(ch.to_string()).or_insert(0) += freq;
}
}
        // Sort characters before assigning ids so training output is
        // deterministic (HashMap iteration order is unspecified).
        let mut char_entries: Vec<_> = char_freqs.into_iter().collect();
        char_entries.sort_by(|a, b| a.0.cmp(&b.0));
        for (ch, freq) in char_entries {
            if freq >= self.config.min_frequency {
                vocab.insert(ch, next_id);
                next_id += 1;
            }
        }
let mut splits = HashMap::new();
for (word, freq) in word_freqs {
if word.chars().count() <= self.config.max_input_chars_per_word {
let split: Vec<String> = word.chars().map(|c| c.to_string()).collect();
splits.insert(word, (split, freq));
}
}
        // Core merge loop: fold the most frequent adjacent pair into a new
        // token until the vocabulary is full or no pair repeats.
        while vocab.len() < self.config.vocab_size {
let mut pair_freqs = HashMap::new();
for (split, freq) in splits.values() {
for i in 0..split.len().saturating_sub(1) {
let pair = (split[i].clone(), split[i + 1].clone());
*pair_freqs.entry(pair).or_insert(0) += freq;
}
}
if pair_freqs.is_empty() {
break;
}
            // Pick the most frequent pair; break frequency ties
            // lexicographically so training does not depend on HashMap
            // iteration order.
            let best_pair = pair_freqs
                .iter()
                .max_by_key(|(pair, &freq)| (freq, std::cmp::Reverse((*pair).clone())))
                .map(|(pair, _)| pair.clone())
                .expect("pair_freqs should be non-empty in training loop");
let merged_token = format!("{}{}", best_pair.0, best_pair.1);
vocab.insert(merged_token, next_id);
next_id += 1;
merge_rules.push(best_pair.clone());
let mut new_splits = HashMap::new();
for (word, (split, freq)) in splits {
let new_split = self.merge_word(&split, &best_pair);
new_splits.insert(word, (new_split, freq));
}
splits = new_splits;
}
Ok(BPETokenizer::new(vocab, merge_rules))
}
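    /// Replaces every non-overlapping occurrence of `pair` in `word` with
    /// the concatenated merge token, scanning left to right.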
fn merge_word(&self, word: &[String], pair: &(String, String)) -> Vec<String> {
let mut new_word = Vec::new();
let mut i = 0;
while i < word.len() {
if i < word.len() - 1 && word[i] == pair.0 && word[i + 1] == pair.1 {
new_word.push(format!("{}{}", pair.0, pair.1));
i += 2;
} else {
new_word.push(word[i].clone());
i += 1;
}
}
new_word
}
}
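/// Trains a WordPiece-style tokenizer: the vocabulary is seeded with
/// special tokens and single characters, then candidate subwords (with the
/// continuation prefix on non-initial pieces) are scored across the corpus
/// and admitted greedily until `vocab_size` is reached.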
pub struct WordPieceTrainer {
config: TrainingConfig,
normalizer: Option<Box<dyn Normalizer>>,
}
impl WordPieceTrainer {
pub fn new(config: TrainingConfig) -> Self {
Self {
config,
normalizer: None,
}
}
pub fn with_normalizer(mut self, normalizer: Box<dyn Normalizer>) -> Self {
self.normalizer = Some(normalizer);
self
}
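    /// Runs WordPiece training over `texts`: counts word frequencies, seeds
    /// the vocabulary, then admits the highest-scoring unseen subword each
    /// round until `vocab_size` is reached.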
pub fn train(&self, texts: &[String]) -> Result<WordPieceTokenizer> {
let mut word_freqs = HashMap::new();
for text in texts {
let processed_text = if let Some(ref normalizer) = self.normalizer {
normalizer.normalize(text)
} else {
text.clone()
};
for word in processed_text.split_whitespace() {
*word_freqs.entry(word.to_string()).or_insert(0) += 1;
}
}
let mut vocab = HashMap::new();
for (i, token) in self.config.special_tokens.iter().enumerate() {
vocab.insert(token.clone(), i as u32);
}
let mut next_id = self.config.special_tokens.len() as u32;
let mut char_set = std::collections::HashSet::new();
for word in word_freqs.keys() {
for ch in word.chars() {
char_set.insert(ch);
}
}
        // Sort so id assignment is deterministic across runs.
        let mut chars: Vec<char> = char_set.into_iter().collect();
        chars.sort();
        for ch in chars {
            vocab.insert(ch.to_string(), next_id);
            next_id += 1;
        }
        // Greedy growth: each round, score every unseen candidate piece over
        // the corpus and admit the single highest-scoring one.
        while vocab.len() < self.config.vocab_size {
let mut subword_scores = HashMap::new();
for (word, freq) in &word_freqs {
let subwords = self.generate_subwords(word, &vocab);
for subword in subwords {
if !vocab.contains_key(&subword) {
let score = self.score_subword(&subword, &word_freqs, &vocab);
*subword_scores.entry(subword).or_insert(0.0) += score * (*freq as f64);
}
}
}
if subword_scores.is_empty() {
break;
}
            // Highest score wins; ties fall back to lexicographic order so
            // the result does not depend on HashMap iteration order.
            let best_subword = subword_scores
                .iter()
                .max_by(|a, b| {
                    a.1.partial_cmp(b.1)
                        .unwrap_or(std::cmp::Ordering::Equal)
                        .then_with(|| b.0.cmp(a.0))
                })
                .map(|(subword, _)| subword.clone())
                .expect("subword_scores should be non-empty");
vocab.insert(best_subword, next_id);
next_id += 1;
}
Ok(WordPieceTokenizer::new(vocab, false))
}
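    /// Enumerates candidate subwords of `word`, prefixing non-initial pieces
    /// with the continuation marker and skipping pieces the vocabulary
    /// already contains.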
fn generate_subwords(&self, word: &str, vocab: &HashMap<String, u32>) -> Vec<String> {
let mut subwords = Vec::new();
let chars: Vec<char> = word.chars().collect();
for start in 0..chars.len() {
for end in (start + 1)..=chars.len() {
                // Non-initial pieces carry the continuation prefix, so with
                // the default "##" marker "hello" yields "he", "##el",
                // "##ell", and so on.
                let subword = if start > 0 {
format!(
"{}{}",
self.config.end_of_word_suffix,
chars[start..end].iter().collect::<String>()
)
} else {
chars[start..end].iter().collect::<String>()
};
                // Keep pieces of 2..=10 characters (measured in chars, not
                // bytes, to stay Unicode-safe) that the vocabulary does not
                // already contain.
                let n_chars = subword.chars().count();
                if n_chars > 1 && n_chars <= 10 && !vocab.contains_key(&subword) {
                    subwords.push(subword);
                }
}
}
subwords
}
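    /// Heuristic candidate score: the number of distinct words containing
    /// the piece (prefix stripped), scaled by the square root of its length
    /// to favor longer pieces.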
fn score_subword(
&self,
subword: &str,
word_freqs: &HashMap<String, usize>,
_vocab: &HashMap<String, u32>,
) -> f64 {
        // Count the distinct words containing the piece, stripping the same
        // continuation prefix that generate_subwords attached (the previous
        // hardcoded "##" broke when the configured prefix differed).
        let stripped = subword.trim_start_matches(self.config.end_of_word_suffix.as_str());
        let mut score = 0.0;
        for word in word_freqs.keys() {
            if word.contains(stripped) {
                score += 1.0;
            }
        }
        score * (subword.len() as f64).sqrt()
}
}
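/// Trains a Unigram tokenizer in the SentencePiece style: an oversized seed
/// vocabulary of characters and frequent substrings is pruned iteratively,
/// dropping the tokens whose removal loss is lowest, until `vocab_size` is
/// reached. Special tokens are never pruned.
///
/// A minimal usage sketch, mirroring the tests below:
///
/// ```ignore
/// let tokenizer = UnigramTrainer::new(config)
///     .with_shrinking_factor(0.8)
///     .with_iterations(5)
///     .train(&texts)?;
/// ```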
pub struct UnigramTrainer {
config: TrainingConfig,
normalizer: Option<Box<dyn Normalizer>>,
shrinking_factor: f64,
num_iterations: usize,
}
impl UnigramTrainer {
pub fn new(config: TrainingConfig) -> Self {
Self {
config,
normalizer: None,
            shrinking_factor: 0.75,
            num_iterations: 8,
}
}
pub fn with_normalizer(mut self, normalizer: Box<dyn Normalizer>) -> Self {
self.normalizer = Some(normalizer);
self
}
pub fn with_shrinking_factor(mut self, factor: f64) -> Self {
self.shrinking_factor = factor;
self
}
pub fn with_iterations(mut self, iterations: usize) -> Self {
self.num_iterations = iterations;
self
}
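    /// Runs Unigram training over `texts`: builds an oversized seed
    /// vocabulary, then prunes it over `num_iterations` passes, plus extra
    /// passes if the vocabulary is still over `vocab_size`.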
pub fn train(&self, texts: &[String]) -> Result<UnigramTokenizer> {
let mut word_freqs = HashMap::new();
for text in texts {
let processed_text = if let Some(ref normalizer) = self.normalizer {
normalizer.normalize(text)
} else {
text.clone()
};
for word in processed_text.split_whitespace() {
*word_freqs.entry(word.to_string()).or_insert(0) += 1;
}
}
let mut vocab = self.create_initial_vocabulary(&word_freqs)?;
for _ in 0..self.num_iterations {
vocab = self.prune_vocabulary(vocab, &word_freqs)?;
if vocab.len() <= self.config.vocab_size {
break;
}
}
        // Prune further if still over target; stop if a pass removes nothing
        // (e.g. only special tokens remain) so the loop cannot run forever.
        while vocab.len() > self.config.vocab_size {
            let before = vocab.len();
            vocab = self.prune_vocabulary(vocab, &word_freqs)?;
            if vocab.len() == before {
                break;
            }
        }
        // Sort tokens before numbering them so ids are deterministic
        // (HashMap iteration order is unspecified).
        let mut entries: Vec<_> = vocab.into_iter().collect();
        entries.sort_by(|a, b| a.0.cmp(&b.0));
        let mut vocab_map = HashMap::new();
        let mut scores_map = HashMap::new();
        for (i, (token, score)) in entries.into_iter().enumerate() {
            vocab_map.insert(token.clone(), i as u32);
            scores_map.insert(token, score as f32);
        }
UnigramTokenizer::new(vocab_map, scores_map)
}
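    /// Builds the oversized seed vocabulary: special tokens, every
    /// sufficiently frequent character (scored by log frequency), and the
    /// best substring candidates, capped at four times the target size.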
fn create_initial_vocabulary(
&self,
word_freqs: &HashMap<String, usize>,
) -> Result<HashMap<String, f64>> {
let mut vocab = HashMap::new();
for token in &self.config.special_tokens {
vocab.insert(token.clone(), 0.0);
}
let mut char_freqs = HashMap::new();
for (word, freq) in word_freqs {
for ch in word.chars() {
*char_freqs.entry(ch.to_string()).or_insert(0) += freq;
}
}
for (ch, freq) in char_freqs {
if freq >= self.config.min_frequency {
vocab.insert(ch, (freq as f64).ln());
}
}
let subword_candidates = self.generate_subword_candidates(word_freqs);
        // Seed with the best-scoring substring candidates, capped at four
        // times the target size; iterative pruning shrinks it from there.
        for (subword, score) in subword_candidates {
            if vocab.len() >= self.config.vocab_size * 4 {
                break;
            }
            vocab.insert(subword, score);
        }
Ok(vocab)
}
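    /// Collects every substring of 2 to 10 characters that clears
    /// `min_frequency`, scored by log frequency minus a mild length penalty,
    /// sorted best first.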
fn generate_subword_candidates(
&self,
word_freqs: &HashMap<String, usize>,
) -> Vec<(String, f64)> {
let mut subword_counts = HashMap::new();
for (word, freq) in word_freqs {
let chars: Vec<char> = word.chars().collect();
for start in 0..chars.len() {
for end in (start + 1)..=chars.len() {
if end - start > 1 && end - start <= 10 {
let subword = chars[start..end].iter().collect::<String>();
*subword_counts.entry(subword).or_insert(0) += freq;
}
}
}
}
let mut scored_subwords: Vec<_> = subword_counts
.into_iter()
.filter(|(_, freq)| *freq >= self.config.min_frequency)
.map(|(subword, freq)| {
let score = (freq as f64).ln() - (subword.len() as f64) * 0.1;
(subword, score)
})
.collect();
        // Best score first; break ties lexicographically for determinism.
        scored_subwords.sort_by(|a, b| {
            b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal).then_with(|| a.0.cmp(&b.0))
        });
scored_subwords
}
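    /// Removes the tokens with the lowest removal loss, shrinking the
    /// vocabulary by `shrinking_factor` per call but never below
    /// `vocab_size`. Special tokens are exempt.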
fn prune_vocabulary(
&self,
mut vocab: HashMap<String, f64>,
word_freqs: &HashMap<String, usize>,
) -> Result<HashMap<String, f64>> {
if vocab.len() <= self.config.vocab_size {
return Ok(vocab);
}
let mut loss_scores = Vec::new();
for token in vocab.keys() {
if self.config.special_tokens.contains(token) {
continue;
}
let loss = self.calculate_removal_loss(token, &vocab, word_freqs);
loss_scores.push((token.clone(), loss));
}
        // Drop the lowest-loss tokens first; break ties lexicographically so
        // pruning is deterministic.
        loss_scores.sort_by(|a, b| {
            a.1.partial_cmp(&b.1)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.0.cmp(&b.0))
        });
        // Shrink by `shrinking_factor` per pass, but never below `vocab_size`.
        let target_size = ((vocab.len() as f64) * self.shrinking_factor)
            .max(self.config.vocab_size as f64) as usize;
        let tokens_to_remove = vocab.len() - target_size;
        for (token, _) in loss_scores.iter().take(tokens_to_remove) {
            vocab.remove(token);
        }
Ok(vocab)
}
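    /// Estimates the cost of removing `token`: its score summed over every
    /// word containing it, damped by token length so that longer pieces are
    /// pruned before short, widely shared ones.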
fn calculate_removal_loss(
&self,
token: &str,
vocab: &HashMap<String, f64>,
word_freqs: &HashMap<String, usize>,
) -> f64 {
let mut total_loss = 0.0;
for (word, freq) in word_freqs {
if word.contains(token) {
let token_benefit = vocab.get(token).unwrap_or(&0.0) * (*freq as f64);
total_loss += token_benefit;
}
}
total_loss * (1.0 / (token.len() as f64 + 1.0))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::normalizer::LowercaseNormalizer;
use trustformers_core::traits::Tokenizer;
#[test]
fn test_bpe_training() {
let config = TrainingConfig {
vocab_size: 100,
min_frequency: 1,
special_tokens: vec!["[UNK]".to_string()],
..Default::default()
};
let trainer = BPETrainer::new(config).with_normalizer(Box::new(LowercaseNormalizer));
let texts = vec![
"hello world".to_string(),
"hello there".to_string(),
"world peace".to_string(),
];
        let tokenizer = trainer.train(&texts).expect("BPE training failed");
assert!(tokenizer.vocab_size() > 0);
assert!(tokenizer.vocab_size() <= 100);
let encoded = tokenizer.encode("hello world").expect("Encoding failed");
assert!(!encoded.input_ids.is_empty());
}
#[test]
fn test_wordpiece_training() {
let config = TrainingConfig {
vocab_size: 100,
min_frequency: 1,
special_tokens: vec![
"[UNK]".to_string(),
"[CLS]".to_string(),
"[SEP]".to_string(),
],
..Default::default()
};
let trainer = WordPieceTrainer::new(config);
let texts = vec!["hello world".to_string(), "hello there".to_string()];
        let tokenizer = trainer.train(&texts).expect("WordPiece training failed");
assert!(tokenizer.vocab_size() > 0);
assert!(tokenizer.vocab_size() <= 100);
}
#[test]
fn test_unigram_training() {
let config = TrainingConfig {
vocab_size: 50,
min_frequency: 1,
special_tokens: vec![
"<unk>".to_string(),
"<s>".to_string(),
"</s>".to_string(),
"<pad>".to_string(),
],
..Default::default()
};
let trainer = UnigramTrainer::new(config).with_shrinking_factor(0.8).with_iterations(5);
let texts = vec![
"hello world".to_string(),
"hello there".to_string(),
"world peace".to_string(),
"hello hello world".to_string(),
];
        let tokenizer = trainer.train(&texts).expect("Unigram training failed");
assert!(tokenizer.vocab_size() > 0);
assert!(tokenizer.vocab_size() <= 50);
let encoded = tokenizer.encode("hello world").expect("Encoding failed");
assert!(!encoded.input_ids.is_empty());
}
#[test]
fn test_bpe_merge_word() {
let config = TrainingConfig::default();
let trainer = BPETrainer::new(config);
let word = vec![
"h".to_string(),
"e".to_string(),
"l".to_string(),
"l".to_string(),
"o".to_string(),
];
let pair = ("l".to_string(), "l".to_string());
let merged = trainer.merge_word(&word, &pair);
assert_eq!(merged, vec!["h", "e", "ll", "o"]);
}
#[test]
fn test_wordpiece_subword_generation() {
let config = TrainingConfig::default();
let trainer = WordPieceTrainer::new(config);
let vocab = HashMap::new();
let subwords = trainer.generate_subwords("hello", &vocab);
assert!(!subwords.is_empty());
assert!(subwords.iter().any(|s| s == "he" || s == "##ell" || s == "hello"));
}
#[test]
fn test_trainer_with_normalizer() {
let config = TrainingConfig {
vocab_size: 50,
min_frequency: 1,
..Default::default()
};
let trainer = BPETrainer::new(config).with_normalizer(Box::new(LowercaseNormalizer));
let texts = vec!["Hello World".to_string(), "HELLO WORLD".to_string()];
        let tokenizer = trainer.train(&texts).expect("BPE training failed");
let encoded1 = tokenizer.encode("Hello World").expect("Encoding failed");
let encoded2 = tokenizer.encode("hello world").expect("Encoding failed");
assert!(!encoded1.input_ids.is_empty());
assert!(!encoded2.input_ids.is_empty());
}
}