1#![allow(clippy::map_entry)]
2
3use super::{Pair, WithFirstLastIterator, Word, BPE};
4use crate::parallelism::*;
5use crate::tokenizer::{AddedToken, Result, Trainer};
6use crate::utils::progress::{ProgressBar, ProgressStyle};
7use serde::{Deserialize, Serialize};
8use std::cmp::Ordering;
9use std::collections::{BinaryHeap, HashMap, HashSet};
10
/// A candidate merge kept in the training priority queue: the pair of token
/// ids to join, its occurrence count, and the indices of the words in which
/// the pair occurs.
#[derive(Debug, Eq)]
struct Merge {
    // The pair of token ids this merge would join into a new token.
    pair: Pair,
    // Total occurrence count of the pair. May become stale while queued;
    // `do_train` revalidates it against `pair_counts` when popped.
    count: u64,
    // Indices into the `words` vector where this pair occurs.
    pos: HashSet<usize>,
}
17impl PartialEq for Merge {
18 fn eq(&self, other: &Self) -> bool {
19 self.count == other.count && self.pair == other.pair
20 }
21}
22impl PartialOrd for Merge {
23 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
24 Some(self.cmp(other))
25 }
26}
27impl Ord for Merge {
28 fn cmp(&self, other: &Self) -> Ordering {
29 if self.count != other.count {
30 self.count.cmp(&other.count)
31 } else {
32 other.pair.cmp(&self.pair)
34 }
35 }
36}
37
/// Internal bag of training options accumulated by `BpeTrainerBuilder` and
/// materialized into a `BpeTrainer` by `build`.
struct Config {
    // Minimum count a pair must reach to produce a merge.
    min_frequency: u64,
    // Target vocabulary size; training stops once it is reached.
    vocab_size: usize,
    // Whether to display a progress bar while training.
    show_progress: bool,
    // Special tokens inserted at the front of the vocabulary.
    special_tokens: Vec<AddedToken>,
    // Optional cap on the number of initial alphabet characters kept.
    limit_alphabet: Option<usize>,
    // Characters that must be part of the alphabet even if unseen in data.
    initial_alphabet: HashSet<char>,
    // Optional prefix marking subwords that continue a word (e.g. "##").
    continuing_subword_prefix: Option<String>,
    // Optional suffix marking the last subword of a word (e.g. "</w>").
    end_of_word_suffix: Option<String>,
    // Optional cap on the character length of any produced token.
    max_token_length: Option<usize>,
}
49
/// A builder that configures and constructs a [`BpeTrainer`].
pub struct BpeTrainerBuilder {
    // Accumulated configuration, turned into a `BpeTrainer` by `build`.
    config: Config,
}
55
56impl Default for BpeTrainerBuilder {
57 fn default() -> Self {
58 Self {
59 config: Config {
60 min_frequency: 0,
61 vocab_size: 30000,
62 show_progress: true,
63 special_tokens: vec![],
64 limit_alphabet: None,
65 initial_alphabet: HashSet::new(),
66 continuing_subword_prefix: None,
67 end_of_word_suffix: None,
68 max_token_length: None,
69 },
70 }
71 }
72}
73
74impl BpeTrainerBuilder {
75 pub fn new() -> Self {
77 Self::default()
78 }
79
80 #[must_use]
82 pub fn min_frequency(mut self, frequency: u64) -> Self {
83 self.config.min_frequency = frequency;
84 self
85 }
86
87 #[must_use]
89 pub fn vocab_size(mut self, size: usize) -> Self {
90 self.config.vocab_size = size;
91 self
92 }
93
94 #[must_use]
96 pub fn show_progress(mut self, show: bool) -> Self {
97 self.config.show_progress = show;
98 self
99 }
100
101 #[must_use]
103 pub fn special_tokens(mut self, tokens: Vec<AddedToken>) -> Self {
104 self.config.special_tokens = tokens;
105 self
106 }
107
108 #[must_use]
110 pub fn limit_alphabet(mut self, limit: usize) -> Self {
111 self.config.limit_alphabet = Some(limit);
112 self
113 }
114
115 #[must_use]
117 pub fn initial_alphabet(mut self, alphabet: HashSet<char>) -> Self {
118 self.config.initial_alphabet = alphabet;
119 self
120 }
121
122 #[must_use]
124 pub fn continuing_subword_prefix(mut self, prefix: String) -> Self {
125 self.config.continuing_subword_prefix = Some(prefix);
126 self
127 }
128
129 #[must_use]
131 pub fn end_of_word_suffix(mut self, suffix: String) -> Self {
132 self.config.end_of_word_suffix = Some(suffix);
133 self
134 }
135 #[must_use]
137 pub fn max_token_length(mut self, max_token_length: Option<usize>) -> Self {
138 self.config.max_token_length = max_token_length;
139 self
140 }
141
142 pub fn build(self) -> BpeTrainer {
144 BpeTrainer {
145 min_frequency: self.config.min_frequency,
146 vocab_size: self.config.vocab_size,
147 show_progress: self.config.show_progress,
148 special_tokens: self.config.special_tokens,
149 limit_alphabet: self.config.limit_alphabet,
150 initial_alphabet: self.config.initial_alphabet,
151 continuing_subword_prefix: self.config.continuing_subword_prefix,
152 end_of_word_suffix: self.config.end_of_word_suffix,
153 max_token_length: self.config.max_token_length,
154 words: HashMap::new(),
155 }
156 }
157}
158
/// Trains a BPE model from a corpus of words and their counts.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
pub struct BpeTrainer {
    /// The minimum frequency a pair must have to produce a merge operation.
    pub min_frequency: u64,
    /// The target vocabulary size.
    pub vocab_size: usize,
    /// Whether to show progress while training.
    pub show_progress: bool,
    /// A list of special tokens the model should know of.
    pub special_tokens: Vec<AddedToken>,
    /// An optional cap on the number of initial alphabet characters kept
    /// before computing merges.
    pub limit_alphabet: Option<usize>,
    /// Characters we absolutely want in the alphabet, even if they never
    /// appear in the training data.
    pub initial_alphabet: HashSet<char>,
    /// An optional prefix for subwords that only exist behind another one.
    pub continuing_subword_prefix: Option<String>,
    /// An optional suffix to characterize end-of-word subwords.
    pub end_of_word_suffix: Option<String>,
    /// An optional cap on the character length of any single token.
    pub max_token_length: Option<usize>,

    // Word counts accumulated by `Trainer::feed`, consumed by `train`.
    words: HashMap<String, u64>,
}
200
201impl Default for BpeTrainer {
202 fn default() -> Self {
203 Self::builder().build()
204 }
205}
206
impl BpeTrainer {
    /// Create a trainer with the given `min_frequency` and `vocab_size`;
    /// every other option keeps its default value.
    pub fn new(min_frequency: u64, vocab_size: usize) -> Self {
        Self {
            min_frequency,
            vocab_size,
            ..Default::default()
        }
    }

    /// Entry point of the builder API.
    pub fn builder() -> BpeTrainerBuilder {
        BpeTrainerBuilder::new()
    }

    /// Set up a progress bar if `show_progress` is enabled, otherwise `None`.
    fn setup_progress(&self) -> Option<ProgressBar> {
        if self.show_progress {
            // Length 0 for now; each phase sets its own length via
            // `update_progress`.
            let p = ProgressBar::new(0);
            p.set_style(
                ProgressStyle::default_bar()
                    .template("[{elapsed_precise}] {msg:<30!} {wide_bar} {pos:<9!}/{len:>9!}")
                    .expect("Invalid progress template"),
            );
            Some(p)
        } else {
            None
        }
    }

    /// Put the progress bar (if any) in its finished state with `final_len`.
    fn finalize_progress(&self, p: &Option<ProgressBar>, final_len: usize) {
        if let Some(p) = p {
            p.set_length(final_len as u64);
            p.finish();
            // Move to the next line so the next phase's bar starts cleanly.
            println!();
        }
    }

    /// Reset the progress bar (if any) for a new phase with the given
    /// length and message.
    fn update_progress(&self, p: &Option<ProgressBar>, len: usize, message: &'static str) {
        if let Some(p) = p {
            p.set_message(message);
            p.set_length(len as u64);
            p.reset();
        }
    }

    /// Add the configured special tokens to the vocabulary, first, in order.
    /// Duplicates already present in `w2id` are skipped.
    fn add_special_tokens(&self, w2id: &mut HashMap<String, u32>, id2w: &mut Vec<String>) {
        for token in &self.special_tokens {
            if !w2id.contains_key(&token.content) {
                id2w.push(token.content.to_owned());
                w2id.insert(token.content.to_owned(), (id2w.len() - 1) as u32);
            }
        }
    }

    /// Compute the initial alphabet from the word counts, force-include the
    /// configured `initial_alphabet`, prune to `limit_alphabet` if set, and
    /// register the surviving characters in the vocabulary.
    fn compute_alphabet(
        &self,
        wc: &HashMap<String, u64>,
        w2id: &mut HashMap<String, u32>,
        id2w: &mut Vec<String>,
    ) {
        // Tally how often each character occurs across the corpus.
        let mut alphabet: HashMap<char, usize> = HashMap::new();
        for (word, count) in wc {
            for c in word.chars() {
                alphabet
                    .entry(c)
                    .and_modify(|cnt| *cnt += *count as usize)
                    .or_insert(*count as usize);
            }
        }

        // Forced characters get a usize::MAX count so the pruning below can
        // never remove them.
        for c in &self.initial_alphabet {
            alphabet
                .entry(*c)
                .and_modify(|cnt| *cnt = usize::MAX)
                .or_insert(usize::MAX);
        }

        let mut kept = alphabet.iter().collect::<Vec<_>>();

        // Number of characters to drop to respect `limit_alphabet` (0 if no
        // limit or the alphabet already fits).
        let to_remove = self
            .limit_alphabet
            .map(|limit| {
                if alphabet.len() > limit {
                    alphabet.len() - limit
                } else {
                    0
                }
            })
            .unwrap_or(0);

        // Drop the lowest-count characters first.
        if to_remove > 0 {
            kept.sort_unstable_by_key(|k| *k.1);
            kept.drain(..to_remove);
        }

        // Register survivors sorted by codepoint so ids are deterministic.
        kept.sort_unstable_by_key(|k| (*k.0) as u32);
        kept.into_iter().for_each(|(c, _)| {
            let s = c.to_string();
            if !w2id.contains_key(&s) {
                id2w.push(s.clone());
                w2id.insert(s, (id2w.len() - 1) as u32);
            }
        });
    }

    /// Turn each word into a `Word` of token ids, applying the configured
    /// continuing-subword prefix / end-of-word suffix per character, and
    /// registering any new affixed symbols in the vocabulary.
    /// Returns the tokenized words and their counts, index-aligned.
    fn tokenize_words(
        &self,
        wc: &HashMap<String, u64>,
        w2id: &mut HashMap<String, u32>,
        id2w: &mut Vec<String>,
        p: &Option<ProgressBar>,
    ) -> (Vec<Word>, Vec<u64>) {
        let mut words: Vec<Word> = Vec::with_capacity(wc.len());
        let mut counts: Vec<u64> = Vec::with_capacity(wc.len());

        for (word, count) in wc {
            let mut current_word = Word::new();
            counts.push(*count);

            for (is_first, is_last, c) in word.chars().with_first_and_last() {
                let mut s = c.to_string();
                // Characters pruned from the alphabet (limit_alphabet) are
                // not in w2id and are silently skipped here.
                if w2id.contains_key(&s) {
                    // Add the `continuing_subword_prefix` if relevant.
                    if !is_first {
                        if let Some(prefix) = &self.continuing_subword_prefix {
                            s = format!("{prefix}{s}");
                        }
                    }
                    // Add the `end_of_word_suffix` if relevant.
                    if is_last {
                        if let Some(suffix) = &self.end_of_word_suffix {
                            s = format!("{s}{suffix}");
                        }
                    }

                    // Register the (possibly affixed) symbol if new.
                    if !w2id.contains_key(&s) {
                        id2w.push(s.clone());
                        w2id.insert(s.clone(), (id2w.len() - 1) as u32);
                    }
                    current_word.add(w2id[&s], 1);
                }
            }
            words.push(current_word);

            if let Some(p) = p {
                p.inc(1);
            }
        }

        (words, counts)
    }

    /// Count all adjacent symbol pairs across `words` (possibly in
    /// parallel), returning for each pair its weighted count and the set of
    /// word indices in which it appears.
    fn count_pairs(
        &self,
        words: &[Word],
        counts: &[u64],
        p: &Option<ProgressBar>,
    ) -> (HashMap<Pair, i32>, HashMap<Pair, HashSet<usize>>) {
        words
            .maybe_par_iter()
            .enumerate()
            .map(|(i, word)| {
                // Per-word partial maps; merged in the reduce step below.
                let mut pair_counts = HashMap::new();
                let mut where_to_update: HashMap<Pair, HashSet<usize>> = HashMap::new();

                for window in word.get_chars().windows(2) {
                    let cur_pair: Pair = (window[0], window[1]);

                    // Initialize pair_counts for this pair if we just saw it.
                    if !pair_counts.contains_key(&cur_pair) {
                        pair_counts.insert(cur_pair, 0);
                    }

                    // Remember that this pair occurs in word `i`, and weight
                    // the pair by the word's corpus count.
                    let count = counts[i];
                    where_to_update
                        .entry(cur_pair)
                        .and_modify(|h| {
                            h.insert(i);
                        })
                        .or_insert_with(|| {
                            let mut h = HashSet::new();
                            h.insert(i);
                            h
                        });
                    *pair_counts.get_mut(&cur_pair).unwrap() += count as i32;
                }

                if let Some(p) = &p {
                    p.inc(1);
                }

                (pair_counts, where_to_update)
            })
            .reduce(
                || (HashMap::new(), HashMap::new()),
                // Merge partial results: sum counts, union occurrence sets.
                |(mut pair_counts, mut where_to_update), (pc, wtu)| {
                    for (k, v) in pc {
                        pair_counts.entry(k).and_modify(|c| *c += v).or_insert(v);
                    }
                    for (k, v) in wtu {
                        where_to_update
                            .entry(k)
                            .and_modify(|set| *set = set.union(&v).copied().collect())
                            .or_insert(v);
                    }
                    (pair_counts, where_to_update)
                },
            )
    }

    /// Run the full BPE training loop over `word_counts` and write the
    /// resulting vocabulary and merges into `model`.
    ///
    /// Returns the configured special tokens so the caller can register them.
    pub fn do_train(
        &self,
        word_counts: &HashMap<String, u64>,
        model: &mut BPE,
    ) -> Result<Vec<AddedToken>> {
        let mut word_to_id: HashMap<String, u32> = HashMap::with_capacity(self.vocab_size);
        let mut id_to_word: Vec<String> = Vec::with_capacity(self.vocab_size);
        // No cap configured means "no limit".
        let max_token_length: usize = self.max_token_length.unwrap_or(usize::MAX);

        let progress = self.setup_progress();

        // 1. Add all the special tokens to the vocabulary.
        self.add_special_tokens(&mut word_to_id, &mut id_to_word);

        // 2. Compute the initial alphabet.
        self.compute_alphabet(word_counts, &mut word_to_id, &mut id_to_word);

        // 3. Tokenize words into sequences of symbol ids.
        self.update_progress(&progress, word_counts.len(), "Tokenize words");
        let (mut words, counts) =
            self.tokenize_words(word_counts, &mut word_to_id, &mut id_to_word, &progress);
        self.finalize_progress(&progress, words.len());

        // 4. Count adjacent pairs and seed the priority queue with every
        //    pair that has a positive count.
        self.update_progress(&progress, words.len(), "Count pairs");
        let (mut pair_counts, mut where_to_update) = self.count_pairs(&words, &counts, &progress);
        let mut queue = BinaryHeap::with_capacity(pair_counts.len());
        where_to_update.drain().for_each(|(pair, pos)| {
            let count = pair_counts[&pair];
            if count > 0 {
                queue.push(Merge {
                    pair,
                    count: count as u64,
                    pos,
                });
            }
        });
        self.finalize_progress(&progress, words.len());

        // 5. Repeatedly merge the most frequent pair until the vocabulary
        //    is large enough or no viable pair remains.
        self.update_progress(&progress, self.vocab_size, "Compute merges");
        let mut merges: Vec<(Pair, u32)> = vec![];
        loop {
            // Stop as soon as we have a big enough vocabulary.
            if word_to_id.len() >= self.vocab_size {
                break;
            }

            if queue.is_empty() {
                break;
            }

            let mut top = queue.pop().unwrap();
            // Lazy invalidation: queued counts can be stale because earlier
            // merges mutate `pair_counts`. Refresh and re-push instead of
            // acting on a stale entry.
            if top.count != pair_counts[&top.pair] as u64 {
                top.count = pair_counts[&top.pair] as u64;
                queue.push(top);
                continue;
            }

            if top.count < 1 || self.min_frequency > top.count {
                break;
            }

            let part_a = &id_to_word[top.pair.0 as usize];
            let mut part_b = id_to_word[top.pair.1 as usize].to_owned();

            // Strip the continuing-subword prefix from the right part, if
            // any, so the merged token carries it only once (on the left).
            if let Some(prefix) = &self.continuing_subword_prefix {
                if part_b.starts_with(prefix) {
                    let prefix_byte_len = prefix.chars().map(|c| c.len_utf8()).sum();
                    part_b = part_b[prefix_byte_len..].to_string();
                }
            }
            let new_token = format!("{part_a}{part_b}");

            // Reuse the existing id if this token already exists (e.g. it
            // was a special token); otherwise append a fresh one.
            let new_token_id = word_to_id
                .get(&new_token)
                .copied()
                .unwrap_or(id_to_word.len() as u32);
            if !word_to_id.contains_key(&new_token) {
                id_to_word.push(new_token.clone());
                word_to_id.insert(new_token.clone(), new_token_id);
            }
            merges.push((top.pair, new_token_id));

            // Apply the merge to every word containing the pair, possibly
            // in parallel, collecting the pair-count deltas it produces.
            let pos: &HashSet<usize> = &top.pos;

            let words_len = words.len();
            struct WordPtr(*mut Word);
            // SAFETY: `pos` is a set, so each index `i` is visited at most
            // once and the assert below keeps it in bounds — every thread
            // mutates a distinct `Word`, so sharing the raw pointer is sound.
            unsafe impl Sync for WordPtr {}
            let word_start = WordPtr(words.as_mut_ptr());

            let changes = pos
                .maybe_par_iter()
                .flat_map(|&i| {
                    unsafe {
                        assert!(i < words_len);
                        let word = word_start.0.add(i);
                        (*word)
                            .merge(top.pair.0, top.pair.1, new_token_id, max_token_length)
                            .into_iter()
                            .map(|c| (c, i))
                            .collect::<Vec<_>>()
                    }
                })
                .collect::<Vec<_>>();

            // Fold the deltas back into `pair_counts`, weighted by word
            // count; track positions only for pairs that gained occurrences.
            for ((pair, change), iw) in changes {
                let count = change * counts[iw] as i32;
                pair_counts
                    .entry(pair)
                    .and_modify(|c| *c += count)
                    .or_insert(count);
                if change > 0 {
                    where_to_update
                        .entry(pair)
                        .and_modify(|h| {
                            h.insert(iw);
                        })
                        .or_insert_with(|| {
                            let mut h = HashSet::new();
                            h.insert(iw);
                            h
                        });
                }
            }
            // Push the newly created / updated pairs onto the queue.
            where_to_update.drain().for_each(|(pair, pos)| {
                let count = pair_counts[&pair];
                if count > 0 {
                    queue.push(Merge {
                        pair,
                        count: count as u64,
                        pos,
                    });
                }
            });

            if let Some(p) = &progress {
                p.inc(1);
            }
        }
        self.finalize_progress(&progress, merges.len());

        // Transfer the learned vocabulary, reverse vocabulary, and ranked
        // merges into the model.
        model.vocab = word_to_id;
        model.vocab_r = model
            .vocab
            .iter()
            .map(|(key, val)| (*val, key.to_owned()))
            .collect();
        model.merges = merges
            .into_iter()
            .enumerate()
            .map(|(i, (pair, new_token_id))| (pair, (i as u32, new_token_id)))
            .collect();

        // Propagate the affix configuration so encoding matches training.
        if let Some(prefix) = &self.continuing_subword_prefix {
            model.continuing_subword_prefix = Some(prefix.to_owned());
        } else {
            model.continuing_subword_prefix = None;
        }
        if let Some(suffix) = &self.end_of_word_suffix {
            model.end_of_word_suffix = Some(suffix.to_owned());
        } else {
            model.end_of_word_suffix = None;
        }

        Ok(self.special_tokens.clone())
    }
}
630
impl Trainer for BpeTrainer {
    type Model = BPE;

    /// Train a BPE model using the word counts accumulated by `feed`.
    fn train(&self, model: &mut BPE) -> Result<Vec<AddedToken>> {
        self.do_train(&self.words, model)
    }

    /// Whether we should show progress during training.
    fn should_show_progress(&self) -> bool {
        self.show_progress
    }

    /// Consume a stream of sequences, pre-process each one into words with
    /// `process`, and accumulate per-word counts into `self.words`
    /// (possibly in parallel). Any error from `process` aborts the feed.
    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
    where
        I: Iterator<Item = S> + Send,
        S: AsRef<str> + Send,
        F: Fn(&str) -> Result<Vec<String>> + Sync,
    {
        let words: Result<HashMap<String, u64>> = iterator
            .maybe_par_bridge()
            .map(|sequence| {
                // Build a per-sequence word-count map.
                let words = process(sequence.as_ref())?;
                let mut map = HashMap::new();
                for word in words {
                    map.entry(word).and_modify(|c| *c += 1).or_insert(1);
                }
                Ok(map)
            })
            .reduce(
                || Ok(HashMap::new()),
                // Merge partial maps, short-circuiting on the first error.
                |acc, ws| {
                    let mut acc = acc?;
                    for (k, v) in ws? {
                        acc.entry(k).and_modify(|c| *c += v).or_insert(v);
                    }
                    Ok(acc)
                },
            );

        self.words = words?;
        Ok(())
    }
}
675
#[cfg(test)]
mod tests {
    use super::{BpeTrainer, Pair, BPE};
    use std::collections::HashMap;

    #[test]
    fn test_train() {
        // Small corpus; with min_frequency == 2 only pairs occurring at
        // least twice can produce a merge.
        let word_counts: HashMap<String, u64> = [
            ("roses".into(), 1),
            ("are".into(), 2),
            ("red".into(), 1),
            ("voilets".into(), 1),
            ("blue".into(), 1),
            ("BERT".into(), 1),
            ("is".into(), 2),
            ("big".into(), 1),
            ("and".into(), 1),
            ("so".into(), 1),
            ("GPT-2".into(), 1),
        ]
        .iter()
        .cloned()
        .collect();
        let trainer = BpeTrainer::builder()
            .show_progress(false)
            .min_frequency(2)
            .build();
        let mut model = BPE::default();
        trainer.do_train(&word_counts, &mut model).unwrap();

        // Expected: the base alphabet (ids assigned in codepoint order)
        // plus the three learned merge tokens "re", "are", "is".
        let expected_vocab: HashMap<String, u32> = [
            ("-".into(), 0),
            ("2".into(), 1),
            ("B".into(), 2),
            ("E".into(), 3),
            ("G".into(), 4),
            ("P".into(), 5),
            ("R".into(), 6),
            ("T".into(), 7),
            ("a".into(), 8),
            ("b".into(), 9),
            ("d".into(), 10),
            ("e".into(), 11),
            ("g".into(), 12),
            ("i".into(), 13),
            ("l".into(), 14),
            ("n".into(), 15),
            ("o".into(), 16),
            ("r".into(), 17),
            ("s".into(), 18),
            ("t".into(), 19),
            ("u".into(), 20),
            ("v".into(), 21),
            ("re".into(), 22),
            ("are".into(), 23),
            ("is".into(), 24),
        ]
        .iter()
        .cloned()
        .collect();
        assert_eq!(model.vocab, expected_vocab);

        // Merges map pair -> (rank, new_token_id).
        let expected_merges: HashMap<Pair, (u32, u32)> = [
            ((17, 11), (0, 22)), // 'r' + 'e'  -> "re"
            ((8, 22), (1, 23)),  // 'a' + "re" -> "are"
            ((13, 18), (2, 24)), // 'i' + 's'  -> "is"
        ]
        .iter()
        .cloned()
        .collect();
        assert_eq!(model.merges, expected_merges);
    }
    #[test]
    fn bpe_test_max_token_length_16() {
        // With max_token_length set, no vocabulary entry may exceed that
        // many characters (chars, not bytes — hence the multi-byte inputs).
        let max_token_length = 16;
        let long_word_counts: HashMap<String, u64> = [
            ("singlelongtokenwithoutcasechange", 2),
            ("singleLongTokenWithCamelCaseChange", 2),
            ("Longsingletokenwithpunctu@t!onwithin", 2),
            ("Anotherlongsingletokenwithnumberw1th1n", 2),
            ("짧은한글문자열짧은한", 2),
            ("긴한글문자열긴한글문자열긴한글문", 2),
            ("短字符串短字符串短字", 2),
            ("长字符串长字符串长字符串长字符串", 2),
            ("短い文字列短い文字列", 2),
            ("長い文字列長い文字列長い文字列長", 2),
            ("so", 2),
            ("GPT-2", 2),
        ]
        .iter()
        .map(|(key, value)| (key.to_string(), *value))
        .collect();
        let trainer = BpeTrainer::builder()
            .max_token_length(Some(max_token_length))
            .show_progress(false)
            .min_frequency(0)
            .build();
        let mut model = BPE::default();
        trainer.do_train(&long_word_counts, &mut model).unwrap();
        let vocab = model.get_vocab();
        for token in vocab.keys() {
            assert!(
                token.chars().count() <= max_token_length,
                "token too long : {} , chars().count() = {}",
                token,
                token.chars().count()
            )
        }
    }
    #[test]
    fn bpe_test_max_token_length_direct_assert() {
        // Snapshot test: with max_token_length == 2, the exact trained
        // vocabulary (base chars + 2-char merges only) must match.
        let long_word_counts: HashMap<String, u64> = [
            ("sin", 2),
            ("Sin", 2),
            ("Lon", 2),
            ("Ano", 2),
            ("짧은한", 2),
            ("긴한글", 2),
            ("短字符", 2),
            ("长字符", 2),
            ("短い文", 2),
            ("長い文", 2),
            ("so", 2),
            ("GP", 2),
        ]
        .iter()
        .map(|(key, value)| (key.to_string(), *value))
        .collect();
        let trainer = BpeTrainer::builder()
            .max_token_length(Some(2))
            .show_progress(false)
            .min_frequency(0)
            .build();
        let mut model = BPE::default();
        trainer.do_train(&long_word_counts, &mut model).unwrap();
        let trained_vocab: HashMap<String, u32> = model.get_vocab();
        let expected_vocab: HashMap<String, u32> = [
            ("短", 12),
            ("n", 6),
            ("i", 5),
            ("s", 8),
            ("字符", 23),
            ("長", 14),
            ("긴", 17),
            ("い文", 22),
            ("L", 2),
            ("in", 21),
            ("o", 7),
            ("은한", 29),
            ("S", 4),
            ("P", 3),
            ("so", 27),
            ("符", 13),
            ("文", 11),
            ("字", 10),
            ("짧", 19),
            ("GP", 25),
            ("글", 16),
            ("G", 1),
            ("An", 24),
            ("长", 15),
            ("A", 0),
            ("Lo", 26),
            ("긴한", 28),
            ("い", 9),
            ("한", 20),
            ("은", 18),
        ]
        .iter()
        .cloned()
        .map(|(k, v)| (k.to_string(), v))
        .collect();
        assert_eq!(trained_vocab, expected_vocab)
    }
}