#![allow(clippy::map_entry)]
use super::{Pair, WithFirstLastIterator, Word, BPE};
use crate::parallelism::*;
use crate::tokenizer::{AddedToken, Result, Trainer};
use crate::utils::progress::{ProgressBar, ProgressStyle};
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
/// A candidate merge stored in the training priority queue.
///
/// * `pair`  - the two token ids that would be merged.
/// * `count` - the frequency of `pair` recorded when this entry was queued.
/// * `pos`   - indices of the words in which `pair` was observed.
#[derive(Debug, Eq)]
struct Merge {
    pair: Pair,
    count: u32,
    pos: HashSet<usize>,
}

impl PartialEq for Merge {
    /// Two candidates are equal when `count` and `pair` match; `pos` is not compared.
    fn eq(&self, other: &Self) -> bool {
        self.count == other.count && self.pair == other.pair
    }
}

impl PartialOrd for Merge {
    /// Delegates to [`Ord::cmp`] so the partial and total orderings cannot diverge.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Merge {
    /// Orders by `count`; on equal counts the `pair` comparison is reversed,
    /// so a max-heap yields the smallest pair first among ties.
    fn cmp(&self, other: &Self) -> Ordering {
        self.count
            .cmp(&other.count)
            .then_with(|| other.pair.cmp(&self.pair))
    }
}
/// Settings accumulated by `BpeTrainerBuilder`; `build` copies each field
/// into the field of the same name on `BpeTrainer`.
struct Config {
/// Merging stops once the most frequent remaining pair occurs fewer times than this.
min_frequency: u32,
/// Merging stops once the vocabulary holds at least this many entries.
vocab_size: usize,
/// Whether a progress bar is created during training.
show_progress: bool,
/// Tokens inserted into the vocabulary before anything else.
special_tokens: Vec<AddedToken>,
/// When set, the maximum number of distinct characters kept in the alphabet.
limit_alphabet: Option<usize>,
/// Characters given the maximum possible frequency when the alphabet is computed.
initial_alphabet: HashSet<char>,
/// When set, prepended to every character token that is not the first of its word.
continuing_subword_prefix: Option<String>,
/// When set, appended to the token of the last character of each word.
end_of_word_suffix: Option<String>,
}
/// Builder for `BpeTrainer`: each setter overrides one setting and
/// `build` produces the trainer.
pub struct BpeTrainerBuilder {
// Settings collected so far; starts from the defaults.
config: Config,
}
impl Default for BpeTrainerBuilder {
    /// Default settings: no minimum pair frequency, a vocabulary of 30000
    /// entries, progress display enabled, and no special tokens, alphabet
    /// limit, initial alphabet, subword prefix or end-of-word suffix.
    fn default() -> Self {
        let config = Config {
            end_of_word_suffix: None,
            continuing_subword_prefix: None,
            initial_alphabet: HashSet::new(),
            limit_alphabet: None,
            special_tokens: Vec::new(),
            show_progress: true,
            vocab_size: 30_000,
            min_frequency: 0,
        };
        Self { config }
    }
}
impl BpeTrainerBuilder {
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn min_frequency(mut self, frequency: u32) -> Self {
self.config.min_frequency = frequency;
self
}
#[must_use]
pub fn vocab_size(mut self, size: usize) -> Self {
self.config.vocab_size = size;
self
}
#[must_use]
pub fn show_progress(mut self, show: bool) -> Self {
self.config.show_progress = show;
self
}
#[must_use]
pub fn special_tokens(mut self, tokens: Vec<AddedToken>) -> Self {
self.config.special_tokens = tokens;
self
}
#[must_use]
pub fn limit_alphabet(mut self, limit: usize) -> Self {
self.config.limit_alphabet = Some(limit);
self
}
#[must_use]
pub fn initial_alphabet(mut self, alphabet: HashSet<char>) -> Self {
self.config.initial_alphabet = alphabet;
self
}
#[must_use]
pub fn continuing_subword_prefix(mut self, prefix: String) -> Self {
self.config.continuing_subword_prefix = Some(prefix);
self
}
#[must_use]
pub fn end_of_word_suffix(mut self, suffix: String) -> Self {
self.config.end_of_word_suffix = Some(suffix);
self
}
pub fn build(self) -> BpeTrainer {
BpeTrainer {
min_frequency: self.config.min_frequency,
vocab_size: self.config.vocab_size,
show_progress: self.config.show_progress,
special_tokens: self.config.special_tokens,
limit_alphabet: self.config.limit_alphabet,
initial_alphabet: self.config.initial_alphabet,
continuing_subword_prefix: self.config.continuing_subword_prefix,
end_of_word_suffix: self.config.end_of_word_suffix,
words: HashMap::new(),
}
}
}
/// Trainer that builds a `BPE` model (vocabulary and merges) from word counts.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
pub struct BpeTrainer {
/// Merging stops once the most frequent remaining pair occurs fewer times than this.
pub min_frequency: u32,
/// Merging stops once the vocabulary holds at least this many entries.
pub vocab_size: usize,
/// Whether a progress bar is created during training.
pub show_progress: bool,
/// Tokens inserted into the vocabulary before anything else; also returned by `do_train`.
pub special_tokens: Vec<AddedToken>,
/// When set, the maximum number of distinct characters kept in the alphabet.
pub limit_alphabet: Option<usize>,
/// Characters given the maximum possible frequency when the alphabet is computed.
pub initial_alphabet: HashSet<char>,
/// When set, prepended to every character token that is not the first of its word.
pub continuing_subword_prefix: Option<String>,
/// When set, appended to the token of the last character of each word.
pub end_of_word_suffix: Option<String>,
// Word -> occurrence count; written by `feed` and read by `train`.
words: HashMap<String, u32>,
}
impl Default for BpeTrainer {
fn default() -> Self {
Self::builder().build()
}
}
impl BpeTrainer {
pub fn new(min_frequency: u32, vocab_size: usize) -> Self {
Self {
min_frequency,
vocab_size,
..Default::default()
}
}
/// Returns a builder initialised with the default settings.
pub fn builder() -> BpeTrainerBuilder {
    BpeTrainerBuilder::default()
}
/// Creates the progress bar used during training, or `None` when
/// `show_progress` is disabled. The bar starts with a length of 0.
fn setup_progress(&self) -> Option<ProgressBar> {
    if !self.show_progress {
        return None;
    }
    let bar = ProgressBar::new(0);
    let style = ProgressStyle::default_bar()
        .template("[{elapsed_precise}] {msg:<40!} {wide_bar} {pos:<9!}/{len:>9!}");
    bar.set_style(style);
    Some(bar)
}
/// Closes the current progress phase: sets the bar length to `final_len`,
/// marks it finished and prints an empty line. Does nothing without a bar.
fn finalize_progress(&self, p: &Option<ProgressBar>, final_len: usize) {
    let bar = match p {
        Some(bar) => bar,
        None => return,
    };
    bar.set_length(final_len as u64);
    bar.finish();
    println!();
}
/// Starts a new progress phase: sets the message and length, sets the draw
/// delta to one hundredth of `len`, and resets the bar. Does nothing
/// without a bar.
fn update_progress(&self, p: &Option<ProgressBar>, len: usize, message: &str) {
    let bar = match p {
        Some(bar) => bar,
        None => return,
    };
    let total = len as u64;
    bar.set_message(message);
    bar.set_length(total);
    bar.set_draw_delta(total / 100);
    bar.reset();
}
fn add_special_tokens(&self, w2id: &mut HashMap<String, u32>, id2w: &mut Vec<String>) {
for token in &self.special_tokens {
if !w2id.contains_key(&token.content) {
id2w.push(token.content.to_owned());
w2id.insert(token.content.to_owned(), (id2w.len() - 1) as u32);
}
}
}
/// Computes the character alphabet and registers it in the vocabulary.
///
/// Every character of every word in `wc` is counted, weighted by the word's
/// count. Characters in `initial_alphabet` are given the maximum frequency.
/// When `limit_alphabet` is set and the alphabet is larger than the limit,
/// the least frequent characters are dropped; among equally frequent
/// characters the ones with the smallest code point are dropped first, so
/// the result does not depend on hash-map iteration order.
///
/// The surviving characters are added to `w2id` / `id2w` in ascending
/// code-point order, skipping any that are already present.
fn compute_alphabet(
    &self,
    wc: &HashMap<String, u32>,
    w2id: &mut HashMap<String, u32>,
    id2w: &mut Vec<String>,
) {
    // Weighted character frequencies.
    let mut alphabet: HashMap<char, usize> = HashMap::new();
    for (word, count) in wc {
        for c in word.chars() {
            *alphabet.entry(c).or_insert(0) += *count as usize;
        }
    }

    // Requested characters get the highest possible frequency so they are
    // the last candidates for removal.
    for c in &self.initial_alphabet {
        alphabet.insert(*c, usize::MAX);
    }

    let mut kept: Vec<(char, usize)> = alphabet.into_iter().collect();
    let to_remove = self
        .limit_alphabet
        .map_or(0, |limit| kept.len().saturating_sub(limit));
    if to_remove > 0 {
        // The character is part of the key so that ties in frequency are
        // broken deterministically.
        kept.sort_unstable_by_key(|&(c, freq)| (freq, c));
        kept.drain(..to_remove);
    }

    // Ids are assigned in code-point order.
    kept.sort_unstable_by_key(|&(c, _)| c as u32);
    for (c, _) in kept {
        let s = c.to_string();
        if !w2id.contains_key(&s) {
            let id = id2w.len() as u32;
            id2w.push(s.clone());
            w2id.insert(s, id);
        }
    }
}
/// Splits every word of `wc` into character tokens.
///
/// Characters that are not in the vocabulary are skipped. For the others,
/// the continuing-subword prefix is prepended when the character is not the
/// first of the word and the end-of-word suffix is appended when it is the
/// last; decorated tokens missing from the vocabulary are registered on the
/// fly. Returns the tokenized words and, at matching indices, their counts.
/// The progress bar, if any, advances by one per word.
fn tokenize_words(
    &self,
    wc: &HashMap<String, u32>,
    w2id: &mut HashMap<String, u32>,
    id2w: &mut Vec<String>,
    p: &Option<ProgressBar>,
) -> (Vec<Word>, Vec<u32>) {
    let capacity = wc.len();
    let mut counts: Vec<u32> = Vec::with_capacity(capacity);
    let mut words: Vec<Word> = Vec::with_capacity(capacity);

    for (text, &freq) in wc.iter() {
        counts.push(freq);
        let mut tokenized = Word::new();

        for (is_first, is_last, ch) in text.chars().with_first_and_last() {
            let base = ch.to_string();
            // Only characters that made it into the alphabet are kept.
            if !w2id.contains_key(&base) {
                continue;
            }

            let prefixed = match &self.continuing_subword_prefix {
                Some(prefix) if !is_first => format!("{}{}", prefix, base),
                _ => base,
            };
            let token = match &self.end_of_word_suffix {
                Some(suffix) if is_last => format!("{}{}", prefixed, suffix),
                _ => prefixed,
            };

            let id = match w2id.get(&token).copied() {
                Some(id) => id,
                None => {
                    let id = id2w.len() as u32;
                    id2w.push(token.clone());
                    w2id.insert(token, id);
                    id
                }
            };
            tokenized.add(id, 1);
        }

        words.push(tokenized);
        if let Some(bar) = p {
            bar.inc(1);
        }
    }

    (words, counts)
}
/// Counts every adjacent token pair across `words`.
///
/// Returns two maps keyed by pair: the total frequency (each occurrence in
/// word `i` contributes `counts[i]`) and the set of word indices in which
/// the pair occurs. Each word is processed independently and the per-word
/// results are then merged. The progress bar, if any, advances by one per
/// word.
fn count_pairs(
    &self,
    words: &[Word],
    counts: &[u32],
    p: &Option<ProgressBar>,
) -> (HashMap<Pair, i32>, HashMap<Pair, HashSet<usize>>) {
    words
        .maybe_par_iter()
        .enumerate()
        .map(|(i, word)| {
            let mut pair_counts: HashMap<Pair, i32> = HashMap::new();
            let mut where_to_update: HashMap<Pair, HashSet<usize>> = HashMap::new();

            for window in word.get_chars().windows(2) {
                let cur_pair: Pair = (window[0], window[1]);
                let count = counts[i];
                *pair_counts.entry(cur_pair).or_insert(0) += count as i32;
                where_to_update.entry(cur_pair).or_default().insert(i);
            }

            if let Some(p) = p {
                p.inc(1);
            }

            (pair_counts, where_to_update)
        })
        .reduce(
            || (HashMap::new(), HashMap::new()),
            |(mut pair_counts, mut where_to_update), (pc, wtu)| {
                for (k, v) in pc {
                    *pair_counts.entry(k).or_insert(0) += v;
                }
                for (k, v) in wtu {
                    // Grow the existing set in place instead of rebuilding
                    // it from a union on every merge step.
                    where_to_update.entry(k).or_default().extend(v);
                }
                (pair_counts, where_to_update)
            },
        )
}
pub fn do_train(
&self,
word_counts: &HashMap<String, u32>,
model: &mut BPE,
) -> Result<Vec<AddedToken>> {
let mut word_to_id: HashMap<String, u32> = HashMap::with_capacity(self.vocab_size);
let mut id_to_word: Vec<String> = Vec::with_capacity(self.vocab_size);
let progress = self.setup_progress();
self.add_special_tokens(&mut word_to_id, &mut id_to_word);
self.compute_alphabet(word_counts, &mut word_to_id, &mut id_to_word);
self.update_progress(&progress, word_counts.len(), "Tokenize words");
let (words, counts) =
self.tokenize_words(word_counts, &mut word_to_id, &mut id_to_word, &progress);
self.finalize_progress(&progress, words.len());
self.update_progress(&progress, words.len(), "Count pairs");
let (mut pair_counts, mut where_to_update) = self.count_pairs(&words, &counts, &progress);
let mut queue = BinaryHeap::with_capacity(pair_counts.len());
where_to_update.drain().for_each(|(pair, pos)| {
let count = pair_counts[&pair];
if count > 0 {
queue.push(Merge {
pair,
count: count as u32,
pos,
});
}
});
self.finalize_progress(&progress, words.len());
self.update_progress(&progress, self.vocab_size, "Compute merges");
let mut merges: Vec<(Pair, u32)> = vec![];
loop {
if word_to_id.len() >= self.vocab_size {
break;
}
if queue.is_empty() {
break;
}
let mut top = queue.pop().unwrap();
if top.count != pair_counts[&top.pair] as u32 {
top.count = pair_counts[&top.pair] as u32;
queue.push(top);
continue;
}
if top.count < 1 || self.min_frequency > top.count {
break;
}
let part_a = &id_to_word[top.pair.0 as usize];
let mut part_b = id_to_word[top.pair.1 as usize].to_owned();
if let Some(prefix) = &self.continuing_subword_prefix {
if part_b.starts_with(prefix) {
let prefix_byte_len = prefix.chars().map(|c| c.len_utf8()).sum();
part_b = part_b[prefix_byte_len..].to_string();
}
}
let new_token = format!("{}{}", part_a, part_b);
let new_token_id = word_to_id
.get(&new_token)
.copied()
.unwrap_or(id_to_word.len() as u32);
if word_to_id.get(&new_token).is_none() {
id_to_word.push(new_token.clone());
word_to_id.insert(new_token.clone(), new_token_id);
}
merges.push((top.pair, new_token_id));
let changes = top
.pos
.maybe_par_iter()
.flat_map(|i| {
let w = &words[*i] as *const _ as *mut _;
unsafe {
let word: &mut Word = &mut (*w);
word.merge(top.pair.0, top.pair.1, new_token_id)
.into_iter()
.map(|c| (c, *i))
.collect::<Vec<_>>()
}
})
.collect::<Vec<_>>();
for ((pair, change), iw) in changes {
let count = change * counts[iw] as i32;
pair_counts
.entry(pair)
.and_modify(|c| *c += count)
.or_insert(count);
if change > 0 {
where_to_update
.entry(pair)
.and_modify(|h| {
h.insert(iw);
})
.or_insert_with(|| {
let mut h = HashSet::new();
h.insert(iw);
h
});
}
}
where_to_update.drain().for_each(|(pair, pos)| {
let count = pair_counts[&pair];
if count > 0 {
queue.push(Merge {
pair,
count: count as u32,
pos,
});
}
});
if let Some(p) = &progress {
p.inc(1);
}
}
self.finalize_progress(&progress, merges.len());
model.vocab = word_to_id;
model.vocab_r = model
.vocab
.iter()
.map(|(key, val)| (*val, key.to_owned()))
.collect();
model.merges = merges
.into_iter()
.enumerate()
.map(|(i, (pair, new_token_id))| (pair, (i as u32, new_token_id)))
.collect();
if let Some(prefix) = &self.continuing_subword_prefix {
model.continuing_subword_prefix = Some(prefix.to_owned());
} else {
model.continuing_subword_prefix = None;
}
if let Some(suffix) = &self.end_of_word_suffix {
model.end_of_word_suffix = Some(suffix.to_owned());
} else {
model.end_of_word_suffix = None;
}
Ok(self.special_tokens.clone())
}
}
impl Trainer for BpeTrainer {
    type Model = BPE;

    /// Trains `model` from the word counts gathered by `feed`.
    fn train(&self, model: &mut BPE) -> Result<Vec<AddedToken>> {
        let gathered = &self.words;
        self.do_train(gathered, model)
    }

    /// Reports whether this trainer displays progress.
    fn should_show_progress(&self) -> bool {
        self.show_progress
    }

    /// Replaces the stored word counts with those computed from `iterator`.
    ///
    /// `process` turns each sequence into words; the words are counted per
    /// sequence and the per-sequence tables are then summed. The first error
    /// returned by `process` is propagated and the stored counts are left
    /// untouched.
    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
    where
        I: Iterator<Item = S> + Send,
        S: AsRef<str> + Send,
        F: Fn(&str) -> Result<Vec<String>> + Sync,
    {
        let merged: Result<HashMap<String, u32>> = iterator
            .maybe_par_bridge()
            .map(|sequence| {
                let mut local: HashMap<String, u32> = HashMap::new();
                for word in process(sequence.as_ref())? {
                    *local.entry(word).or_insert(0) += 1;
                }
                Ok(local)
            })
            .reduce(
                || Ok(HashMap::new()),
                |left, right| {
                    let mut total = left?;
                    for (word, n) in right? {
                        *total.entry(word).or_insert(0) += n;
                    }
                    Ok(total)
                },
            );
        self.words = merged?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::{BpeTrainer, Pair, BPE};
    use std::collections::HashMap;

    /// Trains on a tiny corpus with `min_frequency = 2` and checks both the
    /// resulting vocabulary and the learned merges.
    #[test]
    fn test_train() {
        let word_counts: HashMap<String, u32> = [
            ("roses", 1),
            ("are", 2),
            ("red", 1),
            ("voilets", 1),
            ("blue", 1),
            ("BERT", 1),
            ("is", 2),
            ("big", 1),
            ("and", 1),
            ("so", 1),
            ("GPT-2", 1),
        ]
        .iter()
        .map(|&(word, count)| (word.to_string(), count))
        .collect();

        let trainer = BpeTrainer::builder()
            .show_progress(false)
            .min_frequency(2)
            .build();
        let mut model = BPE::default();
        trainer.do_train(&word_counts, &mut model).unwrap();

        // Tokens listed in id order: the alphabet sorted by code point,
        // followed by the merged tokens in the order they were learned.
        let expected_vocab: HashMap<String, u32> = [
            "-", "2", "B", "E", "G", "P", "R", "T", "a", "b", "d", "e", "g", "i", "l", "n", "o",
            "r", "s", "t", "u", "v", "re", "are", "is",
        ]
        .iter()
        .enumerate()
        .map(|(id, token)| (token.to_string(), id as u32))
        .collect();
        assert_eq!(model.vocab, expected_vocab);

        // (pair, new token id) listed in merge-rank order.
        let expected_merges: HashMap<Pair, (u32, u32)> =
            [((17, 11), 22), ((8, 22), 23), ((13, 18), 24)]
                .iter()
                .enumerate()
                .map(|(rank, &(pair, new_id))| (pair, (rank as u32, new_id)))
                .collect();
        assert_eq!(model.merges, expected_merges);
    }
}