#![allow(clippy::map_entry)]
use super::{Pair, WithFirstLastIterator, Word, BPE};
use crate::tokenizer::{AddedToken, Model, Result, Trainer};
use indicatif::{ProgressBar, ProgressStyle};
use rayon::prelude::*;
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
/// A candidate merge tracked in the training priority queue.
#[derive(Debug, Eq)]
struct Merge {
    /// The pair of symbol ids that would be merged.
    pair: Pair,
    /// Current number of occurrences of this pair across the corpus.
    count: u32,
    /// Indices (into the tokenized word list) of words containing this pair.
    pos: HashSet<usize>,
}
impl PartialEq for Merge {
    /// Two merges compare equal when both count and pair match;
    /// the `pos` set is deliberately left out of the comparison.
    fn eq(&self, other: &Self) -> bool {
        (self.count, &self.pair) == (other.count, &other.pair)
    }
}
impl PartialOrd for Merge {
    /// Total comparison (never returns `None`): order primarily by `count`;
    /// on ties, compare the pairs with *reversed* operands so the smaller
    /// pair ranks higher and the max-heap pops the lowest pair id first.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = match self.count.cmp(&other.count) {
            Ordering::Equal => other.pair.cmp(&self.pair),
            by_count => by_count,
        };
        Some(ordering)
    }
}
impl Ord for Merge {
    /// Same ordering as `partial_cmp`, expressed directly so there is no
    /// `unwrap()` panic path: higher `count` wins, ties are broken by the
    /// reversed pair comparison (smallest pair first in the max-heap).
    fn cmp(&self, other: &Self) -> Ordering {
        self.count
            .cmp(&other.count)
            .then_with(|| other.pair.cmp(&self.pair))
    }
}
/// All tunable options for BPE training, shared between the builder and
/// the trainer itself.
struct Config {
    /// Minimum number of occurrences a pair needs before it may be merged.
    min_frequency: u32,
    /// Target size of the final vocabulary (special tokens included).
    vocab_size: usize,
    /// Whether to display a progress bar while training.
    show_progress: bool,
    /// Tokens inserted at the front of the vocabulary before training.
    special_tokens: Vec<AddedToken>,
    /// Optional cap on the number of distinct initial characters kept.
    limit_alphabet: Option<usize>,
    /// Characters that are always part of the alphabet, even if unseen.
    initial_alphabet: HashSet<char>,
    /// Prefix prepended to non-word-initial subwords (e.g. "##").
    continuing_subword_prefix: Option<String>,
    /// Suffix appended to word-final subwords (e.g. "</w>").
    end_of_word_suffix: Option<String>,
}
/// Fluent builder for configuring and creating a [`BpeTrainer`].
pub struct BpeTrainerBuilder {
    // Accumulated options; turned into a `BpeTrainer` by `build()`.
    config: Config,
}
impl Default for BpeTrainerBuilder {
    /// A builder pre-populated with the standard defaults: vocabulary of
    /// 30 000, progress bar enabled, no frequency cutoff, no alphabet
    /// limit, and no subword prefix/suffix.
    fn default() -> Self {
        let config = Config {
            min_frequency: 0,
            vocab_size: 30000,
            show_progress: true,
            special_tokens: vec![],
            limit_alphabet: None,
            initial_alphabet: HashSet::new(),
            continuing_subword_prefix: None,
            end_of_word_suffix: None,
        };
        Self { config }
    }
}
impl BpeTrainerBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn min_frequency(mut self, frequency: u32) -> Self {
self.config.min_frequency = frequency;
self
}
pub fn vocab_size(mut self, size: usize) -> Self {
self.config.vocab_size = size;
self
}
pub fn show_progress(mut self, show: bool) -> Self {
self.config.show_progress = show;
self
}
pub fn special_tokens(mut self, tokens: Vec<AddedToken>) -> Self {
self.config.special_tokens = tokens;
self
}
pub fn limit_alphabet(mut self, limit: usize) -> Self {
self.config.limit_alphabet = Some(limit);
self
}
pub fn initial_alphabet(mut self, alphabet: HashSet<char>) -> Self {
self.config.initial_alphabet = alphabet;
self
}
pub fn continuing_subword_prefix(mut self, prefix: String) -> Self {
self.config.continuing_subword_prefix = Some(prefix);
self
}
pub fn end_of_word_suffix(mut self, suffix: String) -> Self {
self.config.end_of_word_suffix = Some(suffix);
self
}
pub fn build(self) -> BpeTrainer {
BpeTrainer {
min_frequency: self.config.min_frequency,
vocab_size: self.config.vocab_size,
show_progress: self.config.show_progress,
special_tokens: self.config.special_tokens,
limit_alphabet: self.config.limit_alphabet,
initial_alphabet: self.config.initial_alphabet,
continuing_subword_prefix: self.config.continuing_subword_prefix,
end_of_word_suffix: self.config.end_of_word_suffix,
}
}
}
/// Trains a BPE model from word counts.
///
/// Fields mirror [`Config`]; construct via [`BpeTrainer::builder`] or
/// [`BpeTrainer::new`].
pub struct BpeTrainer {
    /// Minimum number of occurrences a pair needs before it may be merged.
    min_frequency: u32,
    /// Target size of the final vocabulary (special tokens included).
    vocab_size: usize,
    /// Whether to display a progress bar while training.
    show_progress: bool,
    /// Tokens inserted at the front of the vocabulary before training.
    special_tokens: Vec<AddedToken>,
    /// Optional cap on the number of distinct initial characters kept.
    limit_alphabet: Option<usize>,
    /// Characters that are always part of the alphabet, even if unseen.
    initial_alphabet: HashSet<char>,
    /// Prefix prepended to non-word-initial subwords (e.g. "##").
    continuing_subword_prefix: Option<String>,
    /// Suffix appended to word-final subwords (e.g. "</w>").
    end_of_word_suffix: Option<String>,
}
impl Default for BpeTrainer {
fn default() -> Self {
Self::builder().build()
}
}
impl BpeTrainer {
pub fn new(min_frequency: u32, vocab_size: usize) -> Self {
Self {
min_frequency,
vocab_size,
..Default::default()
}
}
/// Entry point for configuring a `BpeTrainer` option by option.
pub fn builder() -> BpeTrainerBuilder {
    BpeTrainerBuilder::default()
}
/// Build the progress bar used during training, or `None` when progress
/// reporting is disabled.
fn setup_progress(&self) -> Option<ProgressBar> {
    if !self.show_progress {
        return None;
    }
    let bar = ProgressBar::new(0);
    bar.set_style(
        ProgressStyle::default_bar()
            .template("[{elapsed_precise}] {msg:<40!} {wide_bar} {pos:<9!}/{len:>9!}"),
    );
    Some(bar)
}
/// Fix the bar's final length, mark it finished and move to a new line.
/// No-op when progress reporting is disabled.
fn finalize_progress(&self, p: &Option<ProgressBar>, final_len: usize) {
    match p {
        Some(bar) => {
            bar.set_length(final_len as u64);
            bar.finish();
            println!();
        }
        None => {}
    }
}
/// Re-arm the progress bar for the next training phase: new message,
/// new total length, position reset to zero.
fn update_progress(&self, p: &Option<ProgressBar>, len: usize, message: &str) {
    match p {
        Some(bar) => {
            bar.set_message(message);
            bar.set_length(len as u64);
            bar.reset();
        }
        None => {}
    }
}
/// Register the configured special tokens at the front of the vocabulary,
/// skipping any that are already present.
fn add_special_tokens(&self, w2id: &mut HashMap<String, u32>, id2w: &mut Vec<String>) {
    for token in &self.special_tokens {
        if w2id.contains_key(&token.content) {
            continue;
        }
        id2w.push(token.content.clone());
        w2id.insert(token.content.clone(), (id2w.len() - 1) as u32);
    }
}
/// Seed the vocabulary with the single characters seen in `wc`, always
/// keeping `initial_alphabet` and, when `limit_alphabet` is set, pruning
/// the least frequent characters beyond the limit.
fn compute_alphabet(
    &self,
    wc: &HashMap<String, u32>,
    w2id: &mut HashMap<String, u32>,
    id2w: &mut Vec<String>,
) {
    // Tally every character, weighted by its word's count.
    let mut alphabet: HashMap<char, usize> = HashMap::new();
    for (word, count) in wc {
        for c in word.chars() {
            *alphabet.entry(c).or_insert(0) += *count as usize;
        }
    }
    // Characters from the initial alphabet get a sentinel count so the
    // pruning below can never remove them.
    for c in &self.initial_alphabet {
        alphabet.insert(*c, std::usize::MAX);
    }
    let mut kept = alphabet.iter().collect::<Vec<_>>();
    // How many characters exceed the configured alphabet limit.
    let to_remove = self
        .limit_alphabet
        .map(|limit| alphabet.len().saturating_sub(limit))
        .unwrap_or(0);
    if to_remove > 0 {
        // Drop the least frequent characters first.
        kept.sort_unstable_by_key(|k| *k.1);
        kept.drain(..to_remove);
    }
    // Insert in code-point order so id assignment is deterministic.
    kept.sort_unstable_by_key(|k| (*k.0) as u32);
    for (c, _) in kept {
        let s = c.to_string();
        if !w2id.contains_key(&s) {
            id2w.push(s.clone());
            w2id.insert(s, (id2w.len() - 1) as u32);
        }
    }
}
/// Turn every counted word into a `Word` of symbol ids, materializing
/// prefixed/suffixed character variants in the vocabulary as needed.
/// Characters pruned from the alphabet are silently dropped.
/// Returns the tokenized words and their counts in matching order.
fn tokenize_words(
    &self,
    wc: &HashMap<String, u32>,
    w2id: &mut HashMap<String, u32>,
    id2w: &mut Vec<String>,
    p: &Option<ProgressBar>,
) -> (Vec<Word>, Vec<u32>) {
    let mut words: Vec<Word> = Vec::with_capacity(wc.len());
    let mut counts: Vec<u32> = Vec::with_capacity(wc.len());
    for (word, count) in wc {
        counts.push(*count);
        let mut current_word = Word::new();
        for (is_first, is_last, c) in word.chars().with_first_and_last() {
            let base = c.to_string();
            // A character missing from the vocabulary was pruned by
            // `limit_alphabet`: skip it entirely.
            if !w2id.contains_key(&base) {
                continue;
            }
            let mut s = base;
            // Non-leading characters carry the continuing-subword prefix.
            if !is_first {
                if let Some(prefix) = &self.continuing_subword_prefix {
                    s = format!("{}{}", prefix, s);
                }
            }
            // The final character carries the end-of-word suffix.
            if is_last {
                if let Some(suffix) = &self.end_of_word_suffix {
                    s = format!("{}{}", s, suffix);
                }
            }
            // Decorated variants may be new entries for the vocabulary.
            if !w2id.contains_key(&s) {
                id2w.push(s.clone());
                w2id.insert(s.clone(), (id2w.len() - 1) as u32);
            }
            current_word.add(w2id[&s]);
        }
        words.push(current_word);
        if let Some(p) = p {
            p.inc(1);
        }
    }
    (words, counts)
}
/// Count all adjacent symbol pairs across `words` — in parallel when the
/// corpus is large enough to amortize the fan-out — returning the signed
/// pair counts plus, for each pair, the set of word indices containing it.
fn count_pairs(
    &self,
    words: &[Word],
    counts: &[u32],
    p: &Option<ProgressBar>,
) -> (HashMap<Pair, i32>, HashMap<Pair, HashSet<usize>>) {
    let mut pair_counts: HashMap<Pair, i32> = HashMap::with_capacity(self.vocab_size * 2);
    let mut where_to_update: HashMap<Pair, HashSet<usize>> =
        HashMap::with_capacity(self.vocab_size * 2);
    // Only fan out when each thread gets a meaningful amount of work.
    let n_threads = if words.len() > rayon::current_num_threads() * 5 {
        rayon::current_num_threads()
    } else {
        1
    };
    // Ceiling division: every chunk but possibly the last is `batch` long.
    let batch = if words.len() % n_threads > 0 {
        (words.len() + (n_threads - words.len() % n_threads)) / n_threads
    } else {
        words.len() / n_threads
    };
    let results = (0..n_threads)
        .into_par_iter()
        .map(|n| {
            // Per-thread maps, merged into the shared ones afterwards.
            let mut pair_counts = HashMap::new();
            let mut where_to_update: HashMap<Pair, HashSet<usize>> = HashMap::new();
            let mut done = 0;
            for (i, word) in words.chunks(batch).nth(n).unwrap().iter().enumerate() {
                // `index` is the word's position in the full corpus.
                let index = i + (n * batch);
                // Hoisted: the word count is invariant over its windows.
                let count = counts[index];
                for window in word.get_chars().windows(2) {
                    let cur_pair: Pair = (window[0], window[1]);
                    // Single entry-API lookup replaces the original
                    // contains_key + insert + get_mut triple.
                    *pair_counts.entry(cur_pair).or_insert(0) += count as i32;
                    where_to_update
                        .entry(cur_pair)
                        .or_insert_with(HashSet::new)
                        .insert(index);
                }
                done += 1;
                // Batch progress updates to limit bar overhead.
                if done % 1000 == 0 {
                    if let Some(p) = &p {
                        p.inc(1000);
                    }
                }
            }
            if let Some(p) = &p {
                p.inc(done % 1000);
            }
            (pair_counts, where_to_update)
        })
        .collect::<Vec<_>>();
    // Fold the per-thread maps into the shared ones.
    results.into_iter().for_each(|(pc, wt)| {
        pc.into_iter().for_each(|(pair, c)| {
            *pair_counts.entry(pair).or_insert(0) += c;
        });
        wt.into_iter().for_each(|(pair, s)| {
            // `extend` unions in place; the previous
            // `*set = set.union(&s).copied().collect()` rebuilt a brand-new
            // HashSet (full reallocation) on every merge.
            where_to_update
                .entry(pair)
                .or_insert_with(HashSet::new)
                .extend(s);
        });
    });
    (pair_counts, where_to_update)
}
/// Train a BPE model from pre-counted words.
///
/// Pipeline: register special tokens → compute the base alphabet →
/// tokenize words into symbol ids → count pairs → repeatedly merge the
/// most frequent pair until `vocab_size` is reached or no pair is
/// frequent enough. Returns the built `BPE` plus the special tokens.
pub fn train(&self, word_counts: HashMap<String, u32>) -> Result<(BPE, Vec<AddedToken>)> {
    let mut word_to_id: HashMap<String, u32> = HashMap::with_capacity(self.vocab_size);
    let mut id_to_word: Vec<String> = Vec::with_capacity(self.vocab_size);
    let progress = self.setup_progress();
    // 1. Special tokens take the lowest ids.
    self.add_special_tokens(&mut word_to_id, &mut id_to_word);
    // 2. Single characters form the initial vocabulary.
    self.compute_alphabet(&word_counts, &mut word_to_id, &mut id_to_word);
    // 3. Turn each counted word into a sequence of symbol ids.
    self.update_progress(&progress, word_counts.len(), "Tokenize words");
    let (words, counts) =
        self.tokenize_words(&word_counts, &mut word_to_id, &mut id_to_word, &progress);
    self.finalize_progress(&progress, words.len());
    // 4. Initial pair statistics; seed the max-heap (ordering per `Ord for Merge`).
    self.update_progress(&progress, words.len(), "Count pairs");
    let (mut pair_counts, mut where_to_update) = self.count_pairs(&words, &counts, &progress);
    let mut queue = BinaryHeap::with_capacity(pair_counts.len());
    where_to_update.drain().for_each(|(pair, pos)| {
        let count = pair_counts[&pair];
        if count > 0 {
            queue.push(Merge {
                pair,
                count: count as u32,
                pos,
            });
        }
    });
    self.finalize_progress(&progress, words.len());
    // 5. Greedily merge the currently most frequent pair.
    self.update_progress(&progress, self.vocab_size, "Compute merges");
    let mut merges: Vec<(Pair, u32)> = vec![];
    loop {
        if word_to_id.len() >= self.vocab_size {
            break;
        }
        if queue.is_empty() {
            break;
        }
        let mut top = queue.pop().unwrap();
        // Lazy invalidation: heap entries are not updated when counts
        // change; a stale entry is refreshed and pushed back instead.
        if top.count != pair_counts[&top.pair] as u32 {
            top.count = pair_counts[&top.pair] as u32;
            queue.push(top);
            continue;
        }
        if top.count < 1 || self.min_frequency > top.count {
            break;
        }
        let part_a = &id_to_word[top.pair.0 as usize];
        let mut part_b = id_to_word[top.pair.1 as usize].to_owned();
        // Strip the continuing-subword prefix from the right-hand part so
        // the merged token is spelled without an interior prefix.
        if let Some(prefix) = &self.continuing_subword_prefix {
            if part_b.starts_with(prefix) {
                let prefix_byte_len = prefix.chars().map(|c| c.len_utf8()).sum();
                part_b = part_b[prefix_byte_len..].to_string();
            }
        }
        // Register the merged token in the vocabulary.
        let new_token = format!("{}{}", part_a, part_b);
        let new_token_id = id_to_word.len() as u32;
        id_to_word.push(new_token.clone());
        word_to_id.insert(new_token.clone(), new_token_id);
        merges.push((top.pair, new_token_id));
        // Apply the merge to every word containing the pair, collecting
        // per-pair count deltas.
        // NOTE(review): the raw-pointer cast defeats the borrow checker to
        // mutate `words` from inside a parallel iterator. `top.pos` is a
        // HashSet, so the indices are distinct and no two tasks touch the
        // same `Word`, but this is still `unsafe` without a stated
        // invariant — worth replacing with a safe formulation.
        let changes = top
            .pos
            .par_iter()
            .flat_map(|i| {
                let w = &words[*i] as *const _ as *mut _;
                unsafe {
                    let word: &mut Word = &mut (*w);
                    word.merge(top.pair.0, top.pair.1, new_token_id)
                        .into_iter()
                        .map(|c| (c, *i))
                        .collect::<Vec<_>>()
                }
            })
            .collect::<Vec<_>>();
        // Fold the deltas back into the global pair statistics.
        for ((pair, change), iw) in changes {
            // Delta is signed: pairs destroyed by the merge go negative.
            let count = change * counts[iw] as i32;
            pair_counts
                .entry(pair)
                .and_modify(|c| *c += count)
                .or_insert(count);
            if change > 0 {
                // Newly created pairs must be (re)examined in this word.
                where_to_update
                    .entry(pair)
                    .and_modify(|h| {
                        h.insert(iw);
                    })
                    .or_insert_with(|| {
                        let mut h = HashSet::new();
                        h.insert(iw);
                        h
                    });
            }
        }
        // Queue every pair whose occurrences changed.
        where_to_update.drain().for_each(|(pair, pos)| {
            let count = pair_counts[&pair];
            if count > 0 {
                queue.push(Merge {
                    pair,
                    count: count as u32,
                    pos,
                });
            }
        });
        if let Some(p) = &progress {
            p.inc(1);
        }
    }
    self.finalize_progress(&progress, merges.len());
    // 6. Assemble the model: merge rank is the order of creation.
    let mut builder = BPE::builder().vocab_and_merges(
        word_to_id,
        merges
            .into_iter()
            .enumerate()
            .map(|(index, (pair, new_id))| (pair, (index as u32, new_id)))
            .collect(),
    );
    if let Some(prefix) = &self.continuing_subword_prefix {
        builder = builder.continuing_subword_prefix(prefix.to_owned());
    }
    if let Some(suffix) = &self.end_of_word_suffix {
        builder = builder.end_of_word_suffix(suffix.to_owned());
    }
    Ok((
        builder
            .build()
            .expect("Trainer should know how to build BPE"),
        self.special_tokens.clone(),
    ))
}
}
impl Trainer for BpeTrainer {
    /// Trait-object entry point: delegates to the inherent `train` and
    /// boxes the resulting model.
    fn train(
        &self,
        word_counts: HashMap<String, u32>,
    ) -> Result<(Box<dyn Model>, Vec<AddedToken>)> {
        let (bpe, tokens) = self.train(word_counts)?;
        Ok((Box::new(bpe), tokens))
    }

    /// Accumulate `tokens` into the running word-count map.
    fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {
        for token in tokens {
            // `entry` takes ownership of `token`, so the per-token
            // `token.clone()` of the original is an avoidable allocation.
            *words.entry(token).or_insert(0) += 1;
        }
    }

    /// Whether a progress bar should be displayed during training.
    fn should_show_progress(&self) -> bool {
        self.show_progress
    }
}
#[cfg(test)]
mod tests {
    use super::{BpeTrainer, Pair};
    use std::collections::HashMap;

    // End-to-end check: train on a tiny corpus with min_frequency = 2 and
    // verify both the resulting vocabulary (specials, sorted alphabet,
    // then merged tokens) and the ranked merge table.
    #[test]
    fn test_train() {
        let word_counts: HashMap<String, u32> = [
            ("roses".into(), 1),
            ("are".into(), 2),
            ("red".into(), 1),
            ("voilets".into(), 1),
            ("blue".into(), 1),
            ("BERT".into(), 1),
            ("is".into(), 2),
            ("big".into(), 1),
            ("and".into(), 1),
            ("so".into(), 1),
            ("GPT-2".into(), 1),
        ]
        .iter()
        .cloned()
        .collect();
        let trainer = BpeTrainer::builder()
            .show_progress(false)
            .min_frequency(2)
            .build();
        let (model, _) = trainer.train(word_counts).unwrap();
        // Ids 0..=21 are the alphabet in code-point order; 22..=24 are the
        // merges "re", "are" and "is" (the only pairs with count >= 2).
        let expected_vocab: HashMap<String, u32> = [
            ("-".into(), 0),
            ("2".into(), 1),
            ("B".into(), 2),
            ("E".into(), 3),
            ("G".into(), 4),
            ("P".into(), 5),
            ("R".into(), 6),
            ("T".into(), 7),
            ("a".into(), 8),
            ("b".into(), 9),
            ("d".into(), 10),
            ("e".into(), 11),
            ("g".into(), 12),
            ("i".into(), 13),
            ("l".into(), 14),
            ("n".into(), 15),
            ("o".into(), 16),
            ("r".into(), 17),
            ("s".into(), 18),
            ("t".into(), 19),
            ("u".into(), 20),
            ("v".into(), 21),
            ("re".into(), 22),
            ("are".into(), 23),
            ("is".into(), 24),
        ]
        .iter()
        .cloned()
        .collect();
        assert_eq!(model.vocab, expected_vocab);
        // Each merge maps (left id, right id) -> (rank, merged id).
        let expected_merges: HashMap<Pair, (u32, u32)> = [
            ((17, 11), (0, 22)),
            ((8, 22), (1, 23)),
            ((13, 18), (2, 24)),
        ]
        .iter()
        .cloned()
        .collect();
        assert_eq!(model.merges, expected_merges);
    }
}