tokenizers/models/bpe/trainer.rs

1#![allow(clippy::map_entry)]
2
3use super::{Pair, WithFirstLastIterator, Word, BPE};
4use crate::parallelism::*;
5use crate::tokenizer::{AddedToken, Result, Trainer};
6use crate::utils::progress::{ProgressBar, ProgressStyle};
7use serde::{Deserialize, Serialize};
8use std::cmp::Ordering;
9use std::collections::{BinaryHeap, HashMap, HashSet};
10
/// A candidate merge held in the training priority queue: a symbol pair,
/// its current (possibly stale) corpus count, and the indices of the words
/// in which it occurs.
#[derive(Debug, Eq)]
struct Merge {
    /// The pair of symbol ids that would be merged
    pair: Pair,
    /// Total weighted number of occurrences of this pair
    count: u64,
    /// Indices (into the training words vector) of words containing the pair
    pos: HashSet<usize>,
}
17impl PartialEq for Merge {
18    fn eq(&self, other: &Self) -> bool {
19        self.count == other.count && self.pair == other.pair
20    }
21}
22impl PartialOrd for Merge {
23    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
24        Some(self.cmp(other))
25    }
26}
27impl Ord for Merge {
28    fn cmp(&self, other: &Self) -> Ordering {
29        if self.count != other.count {
30            self.count.cmp(&other.count)
31        } else {
32            // Here we want ascending order
33            other.pair.cmp(&self.pair)
34        }
35    }
36}
37
/// Internal configuration accumulated by `BpeTrainerBuilder` and copied into
/// the `BpeTrainer` it builds. Field meanings mirror the public fields of
/// `BpeTrainer`.
struct Config {
    /// Minimum frequency a pair must reach to produce a merge
    min_frequency: u64,
    /// Target vocabulary size
    vocab_size: usize,
    /// Whether to display a progress bar during training
    show_progress: bool,
    /// Special tokens to insert into the vocabulary first
    special_tokens: Vec<AddedToken>,
    /// Optional cap on the number of initial characters kept
    limit_alphabet: Option<usize>,
    /// Characters that must be included regardless of the training data
    initial_alphabet: HashSet<char>,
    /// Optional prefix marking non-word-initial subwords
    continuing_subword_prefix: Option<String>,
    /// Optional suffix marking word-final subwords
    end_of_word_suffix: Option<String>,
    /// Optional cap on the character length of any learned token
    max_token_length: Option<usize>,
}
49
/// A `BpeTrainerBuilder` can be used to create a `BpeTrainer` with a custom
/// configuration.
pub struct BpeTrainerBuilder {
    /// The configuration being accumulated; consumed by `build`
    config: Config,
}
55
impl Default for BpeTrainerBuilder {
    /// Default configuration: no minimum frequency, a 30000-token target
    /// vocabulary, progress display enabled, and no special tokens, alphabet
    /// restrictions, subword prefix/suffix, or token-length limit.
    fn default() -> Self {
        Self {
            config: Config {
                min_frequency: 0,
                vocab_size: 30000,
                show_progress: true,
                special_tokens: vec![],
                limit_alphabet: None,
                initial_alphabet: HashSet::new(),
                continuing_subword_prefix: None,
                end_of_word_suffix: None,
                max_token_length: None,
            },
        }
    }
}
73
74impl BpeTrainerBuilder {
75    /// Constructs a new `BpeTrainerBuilder`
76    pub fn new() -> Self {
77        Self::default()
78    }
79
80    /// Set the expected minimum frequency
81    #[must_use]
82    pub fn min_frequency(mut self, frequency: u64) -> Self {
83        self.config.min_frequency = frequency;
84        self
85    }
86
87    /// Set the vocabulary size
88    #[must_use]
89    pub fn vocab_size(mut self, size: usize) -> Self {
90        self.config.vocab_size = size;
91        self
92    }
93
94    /// Set whether to show progress
95    #[must_use]
96    pub fn show_progress(mut self, show: bool) -> Self {
97        self.config.show_progress = show;
98        self
99    }
100
101    /// Set the special tokens
102    #[must_use]
103    pub fn special_tokens(mut self, tokens: Vec<AddedToken>) -> Self {
104        self.config.special_tokens = tokens;
105        self
106    }
107
108    /// Set whether to limit the alphabet
109    #[must_use]
110    pub fn limit_alphabet(mut self, limit: usize) -> Self {
111        self.config.limit_alphabet = Some(limit);
112        self
113    }
114
115    /// Set the initial alphabet
116    #[must_use]
117    pub fn initial_alphabet(mut self, alphabet: HashSet<char>) -> Self {
118        self.config.initial_alphabet = alphabet;
119        self
120    }
121
122    /// Set the continuing_subword_prefix
123    #[must_use]
124    pub fn continuing_subword_prefix(mut self, prefix: String) -> Self {
125        self.config.continuing_subword_prefix = Some(prefix);
126        self
127    }
128
129    /// Set the end_of_word_suffix
130    #[must_use]
131    pub fn end_of_word_suffix(mut self, suffix: String) -> Self {
132        self.config.end_of_word_suffix = Some(suffix);
133        self
134    }
135    /// Set max_token_length
136    #[must_use]
137    pub fn max_token_length(mut self, max_token_length: Option<usize>) -> Self {
138        self.config.max_token_length = max_token_length;
139        self
140    }
141
142    /// Constructs the final BpeTrainer
143    pub fn build(self) -> BpeTrainer {
144        BpeTrainer {
145            min_frequency: self.config.min_frequency,
146            vocab_size: self.config.vocab_size,
147            show_progress: self.config.show_progress,
148            special_tokens: self.config.special_tokens,
149            limit_alphabet: self.config.limit_alphabet,
150            initial_alphabet: self.config.initial_alphabet,
151            continuing_subword_prefix: self.config.continuing_subword_prefix,
152            end_of_word_suffix: self.config.end_of_word_suffix,
153            max_token_length: self.config.max_token_length,
154            words: HashMap::new(),
155        }
156    }
157}
158
/// In charge of training a `BPE` model
///
/// # Examples
///
/// ```
/// use tokenizers::tokenizer::Trainer;
/// use tokenizers::models::bpe::{BPE, BpeTrainer};
///
/// let sequences = vec![ "Hello", "World" ];
///
/// let mut trainer = BpeTrainer::default();
/// trainer.feed(sequences.iter(), |s| Ok(vec![s.to_owned()]));
///
/// let mut model = BPE::default();
/// let special_tokens = trainer.train(&mut model).unwrap();
/// ```
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
pub struct BpeTrainer {
    /// The minimum frequency a pair must have to produce a merge operation
    pub min_frequency: u64,
    /// The target vocabulary size
    pub vocab_size: usize,
    /// Whether to show progress while training
    pub show_progress: bool,
    /// A list of special tokens that the model should know of
    pub special_tokens: Vec<AddedToken>,
    /// Whether to limit the number of initial tokens that can be kept before computing merges
    pub limit_alphabet: Option<usize>,
    /// The initial alphabet we want absolutely to include. This allows to cover
    /// some characters that are not necessarily in the training set
    pub initial_alphabet: HashSet<char>,
    /// An optional prefix to use on any subword that exists only behind another one
    pub continuing_subword_prefix: Option<String>,
    /// An optional suffix to characterize an end-of-word subword
    pub end_of_word_suffix: Option<String>,
    /// An optional parameter to limit the max length of any single token
    pub max_token_length: Option<usize>,

    /// Word counts accumulated by `feed` and consumed by `train`
    words: HashMap<String, u64>,
}
200
201impl Default for BpeTrainer {
202    fn default() -> Self {
203        Self::builder().build()
204    }
205}
206
impl BpeTrainer {
    /// Create a trainer with the given `min_frequency` and `vocab_size`,
    /// leaving every other option at its default value.
    pub fn new(min_frequency: u64, vocab_size: usize) -> Self {
        Self {
            min_frequency,
            vocab_size,
            ..Default::default()
        }
    }

    /// Get a `BpeTrainerBuilder` to configure every trainer option.
    pub fn builder() -> BpeTrainerBuilder {
        BpeTrainerBuilder::new()
    }

    /// Setup a progress bar if asked to show progress
    fn setup_progress(&self) -> Option<ProgressBar> {
        if self.show_progress {
            // Length starts at 0; each phase sets it via `update_progress`.
            let p = ProgressBar::new(0);
            p.set_style(
                ProgressStyle::default_bar()
                    .template("[{elapsed_precise}] {msg:<30!} {wide_bar} {pos:<9!}/{len:>9!}")
                    .expect("Invalid progress template"),
            );
            Some(p)
        } else {
            None
        }
    }

    /// Set the progress bar in the finish state
    fn finalize_progress(&self, p: &Option<ProgressBar>, final_len: usize) {
        if let Some(p) = p {
            p.set_length(final_len as u64);
            p.finish();
            // Move to a fresh line so the next phase's bar starts cleanly
            println!();
        }
    }

    /// Update the progress bar with the new provided length and message
    fn update_progress(&self, p: &Option<ProgressBar>, len: usize, message: &'static str) {
        if let Some(p) = p {
            p.set_message(message);
            p.set_length(len as u64);
            p.reset();
        }
    }

    /// Add the provided special tokens to the initial vocabulary.
    /// Tokens already present keep their existing id.
    fn add_special_tokens(&self, w2id: &mut HashMap<String, u32>, id2w: &mut Vec<String>) {
        for token in &self.special_tokens {
            if !w2id.contains_key(&token.content) {
                // The new token's id is the current vocabulary size
                id2w.push(token.content.to_owned());
                w2id.insert(token.content.to_owned(), (id2w.len() - 1) as u32);
            }
        }
    }

    /// Compute the initial alphabet and limit it if relevant
    fn compute_alphabet(
        &self,
        wc: &HashMap<String, u64>,
        w2id: &mut HashMap<String, u32>,
        id2w: &mut Vec<String>,
    ) {
        // Compute the alphabet from seen words, accumulating per-character
        // frequencies weighted by word counts
        let mut alphabet: HashMap<char, usize> = HashMap::new();
        for (word, count) in wc {
            for c in word.chars() {
                alphabet
                    .entry(c)
                    .and_modify(|cnt| *cnt += *count as usize)
                    .or_insert(*count as usize);
            }
        }

        // Also include anything from the provided initial alphabet.
        // Using usize::MAX makes these characters sort last when pruning by
        // frequency below, so they survive unless limit_alphabet is smaller
        // than the initial alphabet itself.
        for c in &self.initial_alphabet {
            alphabet
                .entry(*c)
                .and_modify(|cnt| *cnt = usize::MAX)
                .or_insert(usize::MAX);
        }

        let mut kept = alphabet.iter().collect::<Vec<_>>();

        // Compute the number of chars to remove from the alphabet
        // If `limit_alphabet < initial_alphabet.len()`, some of these initial characters
        // will be removed
        let to_remove = self
            .limit_alphabet
            .map(|limit| {
                if alphabet.len() > limit {
                    alphabet.len() - limit
                } else {
                    0
                }
            })
            .unwrap_or(0);

        // Remove the unwanted chars: sort by ascending frequency and drop
        // the least frequent ones
        if to_remove > 0 {
            kept.sort_unstable_by_key(|k| *k.1);
            kept.drain(..to_remove);
        }

        // Keep the initial alphabet (sorted for determinism)
        kept.sort_unstable_by_key(|k| (*k.0) as u32);
        kept.into_iter().for_each(|(c, _)| {
            let s = c.to_string();
            if !w2id.contains_key(&s) {
                id2w.push(s.clone());
                w2id.insert(s, (id2w.len() - 1) as u32);
            }
        });
    }

    /// Tokenize words and add subwords to the vocabulary when relevant.
    ///
    /// Each word is split into its characters (restricted to the kept
    /// alphabet), decorated with the continuing-subword prefix and/or
    /// end-of-word suffix when configured. Returns the symbol sequences
    /// alongside the matching word counts (parallel vectors).
    fn tokenize_words(
        &self,
        wc: &HashMap<String, u64>,
        w2id: &mut HashMap<String, u32>,
        id2w: &mut Vec<String>,
        p: &Option<ProgressBar>,
    ) -> (Vec<Word>, Vec<u64>) {
        let mut words: Vec<Word> = Vec::with_capacity(wc.len());
        let mut counts: Vec<u64> = Vec::with_capacity(wc.len());

        for (word, count) in wc {
            let mut current_word = Word::new();
            counts.push(*count);

            for (is_first, is_last, c) in word.chars().with_first_and_last() {
                let mut s = c.to_string();
                if w2id.contains_key(&s) {
                    // Found the initial char in the authorized alphabet

                    // Add the `continuing_subword_prefix` if relevant
                    if !is_first {
                        if let Some(prefix) = &self.continuing_subword_prefix {
                            s = format!("{prefix}{s}");
                        }
                    }
                    // Add the `end_of_word_suffix` if relevant
                    if is_last {
                        if let Some(suffix) = &self.end_of_word_suffix {
                            s = format!("{s}{suffix}");
                        }
                    }

                    // Insert the new formed string if necessary
                    if !w2id.contains_key(&s) {
                        id2w.push(s.clone());
                        w2id.insert(s.clone(), (id2w.len() - 1) as u32);
                    }
                    current_word.add(w2id[&s], 1); // We do not care about the len here
                }
                // Characters outside the kept alphabet are silently dropped
            }
            words.push(current_word);

            if let Some(p) = p {
                p.inc(1);
            }
        }

        (words, counts)
    }

    /// Count every adjacent symbol pair across `words` (in parallel when
    /// enabled), returning for each pair its total weighted count and the
    /// set of word indices where it occurs.
    fn count_pairs(
        &self,
        words: &[Word],
        counts: &[u64],
        p: &Option<ProgressBar>,
    ) -> (HashMap<Pair, i32>, HashMap<Pair, HashSet<usize>>) {
        words
            .maybe_par_iter()
            .enumerate()
            .map(|(i, word)| {
                let mut pair_counts = HashMap::new();
                let mut where_to_update: HashMap<Pair, HashSet<usize>> = HashMap::new();

                for window in word.get_chars().windows(2) {
                    let cur_pair: Pair = (window[0], window[1]);

                    // Initialize pair_counts and where_to_update for this pair if we just saw it
                    if !pair_counts.contains_key(&cur_pair) {
                        pair_counts.insert(cur_pair, 0);
                    }

                    // Then update counts
                    let count = counts[i];
                    where_to_update
                        .entry(cur_pair)
                        .and_modify(|h| {
                            h.insert(i);
                        })
                        .or_insert_with(|| {
                            let mut h = HashSet::new();
                            h.insert(i);
                            h
                        });
                    *pair_counts.get_mut(&cur_pair).unwrap() += count as i32;
                }

                if let Some(p) = &p {
                    p.inc(1);
                }

                (pair_counts, where_to_update)
            })
            // Combine the per-word maps: sum counts, union position sets
            .reduce(
                || (HashMap::new(), HashMap::new()),
                |(mut pair_counts, mut where_to_update), (pc, wtu)| {
                    for (k, v) in pc {
                        pair_counts.entry(k).and_modify(|c| *c += v).or_insert(v);
                    }
                    for (k, v) in wtu {
                        where_to_update
                            .entry(k)
                            .and_modify(|set| *set = set.union(&v).copied().collect())
                            .or_insert(v);
                    }
                    (pair_counts, where_to_update)
                },
            )
    }

    /// Train a BPE model from the provided word counts, writing the learned
    /// vocabulary and merges into `model`.
    ///
    /// Returns the configured special tokens so the caller can register them
    /// with the tokenizer.
    pub fn do_train(
        &self,
        word_counts: &HashMap<String, u64>,
        model: &mut BPE,
    ) -> Result<Vec<AddedToken>> {
        let mut word_to_id: HashMap<String, u32> = HashMap::with_capacity(self.vocab_size);
        let mut id_to_word: Vec<String> = Vec::with_capacity(self.vocab_size);
        let max_token_length: usize = self.max_token_length.unwrap_or(usize::MAX);

        let progress = self.setup_progress();

        //
        // 1. Add all special tokens to the vocabulary
        //
        self.add_special_tokens(&mut word_to_id, &mut id_to_word);

        //
        // 2. Compute the initial alphabet
        //
        self.compute_alphabet(word_counts, &mut word_to_id, &mut id_to_word);

        //
        // 3. Tokenize words
        //
        self.update_progress(&progress, word_counts.len(), "Tokenize words");
        let (mut words, counts) =
            self.tokenize_words(word_counts, &mut word_to_id, &mut id_to_word, &progress);
        self.finalize_progress(&progress, words.len());

        //
        // 4. Count pairs in words
        //
        self.update_progress(&progress, words.len(), "Count pairs");
        let (mut pair_counts, mut where_to_update) = self.count_pairs(&words, &counts, &progress);
        // Insert them in the queue
        let mut queue = BinaryHeap::with_capacity(pair_counts.len());
        where_to_update.drain().for_each(|(pair, pos)| {
            let count = pair_counts[&pair];
            if count > 0 {
                queue.push(Merge {
                    pair,
                    count: count as u64,
                    pos,
                });
            }
        });
        self.finalize_progress(&progress, words.len());

        //
        // 5. Do merges
        //
        self.update_progress(&progress, self.vocab_size, "Compute merges");
        let mut merges: Vec<(Pair, u32)> = vec![];
        loop {
            // Stop as soon as we have a big enough vocabulary
            if word_to_id.len() >= self.vocab_size {
                break;
            }

            if queue.is_empty() {
                break;
            }

            // The heap is updated lazily: entries may carry stale counts.
            // If the popped entry's count no longer matches the live count,
            // refresh it and push it back instead of processing it.
            let mut top = queue.pop().unwrap();
            if top.count != pair_counts[&top.pair] as u64 {
                top.count = pair_counts[&top.pair] as u64;
                queue.push(top);
                continue;
            }

            if top.count < 1 || self.min_frequency > top.count {
                break;
            }

            let part_a = &id_to_word[top.pair.0 as usize];
            let mut part_b = id_to_word[top.pair.1 as usize].to_owned();

            // Build new token: strip the continuation prefix from the right
            // part since the merged token inherits the left part's position
            if let Some(prefix) = &self.continuing_subword_prefix {
                if part_b.starts_with(prefix) {
                    let prefix_byte_len = prefix.chars().map(|c| c.len_utf8()).sum();
                    part_b = part_b[prefix_byte_len..].to_string();
                }
            }
            let new_token = format!("{part_a}{part_b}");
            // implement sentencepiece-like merge.
            // if this code were to be merged, integrate a way in the python bindings to communicate this variable
            // default should be 0/None to maintain previous behavior. 16 is the spm default.

            // Insert new token if it does not already exist
            let new_token_id = word_to_id
                .get(&new_token)
                .copied()
                .unwrap_or(id_to_word.len() as u32);
            if !word_to_id.contains_key(&new_token) {
                id_to_word.push(new_token.clone());
                word_to_id.insert(new_token.clone(), new_token_id);
            }
            merges.push((top.pair, new_token_id));

            // Merge the new pair in every words
            // Safety: This is just a type assertion, the code below may no longer be safe
            // if the type of `pos` changes
            let pos: &HashSet<usize> = &top.pos;

            let words_len = words.len();
            struct WordPtr(*mut Word);
            // Safety: We do not actually use this for concurrent access to the same memory,
            // only to different chunks within the same allocation.
            unsafe impl Sync for WordPtr {}
            let word_start = WordPtr(words.as_mut_ptr());

            let changes = pos
                .maybe_par_iter()
                .flat_map(|&i| {
                    // Safety:
                    // We are producing a valid pointer since we are indexing in bounds
                    //
                    // We can access each `word` here in parallel because each position
                    // can be there only once (pos is a HashSet).
                    unsafe {
                        assert!(i < words_len);
                        // This is words[i], but avoids needing to go through &T (which triggers UB)
                        let word = word_start.0.add(i);
                        // let word: &mut Word = &mut (*word);
                        (*word)
                            .merge(top.pair.0, top.pair.1, new_token_id, max_token_length)
                            .into_iter()
                            .map(|c| (c, i))
                            .collect::<Vec<_>>()
                    }
                })
                .collect::<Vec<_>>();

            // Introduce new formed pairs: `change` is a signed delta in
            // occurrences, weighted here by the word's count
            for ((pair, change), iw) in changes {
                let count = change * counts[iw] as i32;
                pair_counts
                    .entry(pair)
                    .and_modify(|c| *c += count)
                    .or_insert(count);
                if change > 0 {
                    where_to_update
                        .entry(pair)
                        .and_modify(|h| {
                            h.insert(iw);
                        })
                        .or_insert_with(|| {
                            let mut h = HashSet::new();
                            h.insert(iw);
                            h
                        });
                }
            }
            // Push every pair touched by this merge back into the queue
            where_to_update.drain().for_each(|(pair, pos)| {
                let count = pair_counts[&pair];
                if count > 0 {
                    queue.push(Merge {
                        pair,
                        count: count as u64,
                        pos,
                    });
                }
            });

            if let Some(p) = &progress {
                p.inc(1);
            }
        }
        self.finalize_progress(&progress, merges.len());

        // Transfer new vocab & options to model
        model.vocab = word_to_id;
        model.vocab_r = model
            .vocab
            .iter()
            .map(|(key, val)| (*val, key.to_owned()))
            .collect();
        // Merge rank is the order in which the merge was learned
        model.merges = merges
            .into_iter()
            .enumerate()
            .map(|(i, (pair, new_token_id))| (pair, (i as u32, new_token_id)))
            .collect();

        if let Some(prefix) = &self.continuing_subword_prefix {
            model.continuing_subword_prefix = Some(prefix.to_owned());
        } else {
            model.continuing_subword_prefix = None;
        }
        if let Some(suffix) = &self.end_of_word_suffix {
            model.end_of_word_suffix = Some(suffix.to_owned());
        } else {
            model.end_of_word_suffix = None;
        }

        Ok(self.special_tokens.clone())
    }
}
630
impl Trainer for BpeTrainer {
    type Model = BPE;

    /// Train a BPE model from the word counts accumulated by `feed`
    fn train(&self, model: &mut BPE) -> Result<Vec<AddedToken>> {
        self.do_train(&self.words, model)
    }

    /// Whether we should show progress
    fn should_show_progress(&self) -> bool {
        self.show_progress
    }

    /// Feed the trainer with sequences. Each sequence is split into words by
    /// `process`, and the word counts are accumulated (in parallel when
    /// enabled) into `self.words`, replacing any previously fed counts.
    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
    where
        I: Iterator<Item = S> + Send,
        S: AsRef<str> + Send,
        F: Fn(&str) -> Result<Vec<String>> + Sync,
    {
        let words: Result<HashMap<String, u64>> = iterator
            .maybe_par_bridge()
            .map(|sequence| {
                // Count words within this single sequence
                let words = process(sequence.as_ref())?;
                let mut map = HashMap::new();
                for word in words {
                    map.entry(word).and_modify(|c| *c += 1).or_insert(1);
                }
                Ok(map)
            })
            // Merge the per-sequence maps; the first `?` propagates any error
            .reduce(
                || Ok(HashMap::new()),
                |acc, ws| {
                    let mut acc = acc?;
                    for (k, v) in ws? {
                        acc.entry(k).and_modify(|c| *c += v).or_insert(v);
                    }
                    Ok(acc)
                },
            );

        self.words = words?;
        Ok(())
    }
}
675
#[cfg(test)]
mod tests {
    use super::{BpeTrainer, Pair, BPE};
    use std::collections::HashMap;

    /// End-to-end training on a tiny corpus: checks both the learned vocab
    /// and the learned merges, including their ranks.
    #[test]
    fn test_train() {
        let word_counts: HashMap<String, u64> = [
            ("roses".into(), 1),
            ("are".into(), 2),
            ("red".into(), 1),
            ("voilets".into(), 1),
            ("blue".into(), 1),
            ("BERT".into(), 1),
            ("is".into(), 2),
            ("big".into(), 1),
            ("and".into(), 1),
            ("so".into(), 1),
            ("GPT-2".into(), 1),
        ]
        .iter()
        .cloned()
        .collect();
        let trainer = BpeTrainer::builder()
            .show_progress(false)
            .min_frequency(2)
            .build();
        let mut model = BPE::default();
        trainer.do_train(&word_counts, &mut model).unwrap();

        // Vocab should contain all of the characters from the `word_counts` mapping
        // as well as three merges: 're', 'are', and 'is'.
        let expected_vocab: HashMap<String, u32> = [
            ("-".into(), 0),
            ("2".into(), 1),
            ("B".into(), 2),
            ("E".into(), 3),
            ("G".into(), 4),
            ("P".into(), 5),
            ("R".into(), 6),
            ("T".into(), 7),
            ("a".into(), 8),
            ("b".into(), 9),
            ("d".into(), 10),
            ("e".into(), 11),
            ("g".into(), 12),
            ("i".into(), 13),
            ("l".into(), 14),
            ("n".into(), 15),
            ("o".into(), 16),
            ("r".into(), 17),
            ("s".into(), 18),
            ("t".into(), 19),
            ("u".into(), 20),
            ("v".into(), 21),
            ("re".into(), 22),
            ("are".into(), 23),
            ("is".into(), 24),
        ]
        .iter()
        .cloned()
        .collect();
        assert_eq!(model.vocab, expected_vocab);

        // The keys in `merges` are pairs of symbols, the values are tuples of (rank, id),
        // where 'rank' determines the order in which this merge will be applied during
        // tokenization, and 'id' is the vocab id of the symbol resulting from merging
        // the pair of symbols in the corresponding key.
        let expected_merges: HashMap<Pair, (u32, u32)> = [
            ((17, 11), (0, 22)), // 'r' + 'e'  -> 're'
            ((8, 22), (1, 23)),  // 'a' + 're' -> 'are'
            ((13, 18), (2, 24)), // 'i' + 's'  -> 'is'
        ]
        .iter()
        .cloned()
        .collect();
        assert_eq!(model.merges, expected_merges);
    }

    /// The bpe_test_max_token_length series of tests exercise the
    /// max_token_length flag of BpeTrainer. This is the more robust version
    /// that only checks the max length of learned tokens; (pre)tokenizer
    /// settings or vocab can easily be modified when necessary.
    #[test]
    fn bpe_test_max_token_length_16() {
        let max_token_length = 16;
        let long_word_counts: HashMap<String, u64> = [
            ("singlelongtokenwithoutcasechange", 2),
            ("singleLongTokenWithCamelCaseChange", 2),
            ("Longsingletokenwithpunctu@t!onwithin", 2),
            ("Anotherlongsingletokenwithnumberw1th1n", 2),
            ("짧은한글문자열짧은한", 2),             // korean 10 char
            ("긴한글문자열긴한글문자열긴한글문", 2), // korean 16 char
            ("短字符串短字符串短字", 2),             //simplified chinese 10 char
            ("长字符串长字符串长字符串长字符串", 2), // simp. chinese 16 char
            ("短い文字列短い文字列", 2),             // japanese 10 char
            ("長い文字列長い文字列長い文字列長", 2), // japanese 16 char
            ("so", 2),
            ("GPT-2", 2),
        ]
        .iter()
        .map(|(key, value)| (key.to_string(), *value))
        .collect();
        let trainer = BpeTrainer::builder()
            .max_token_length(Some(max_token_length))
            .show_progress(false)
            .min_frequency(0)
            .build();
        let mut model = BPE::default();
        trainer.do_train(&long_word_counts, &mut model).unwrap();
        let vocab = model.get_vocab();
        // No learned token may exceed the configured character limit
        for token in vocab.keys() {
            assert!(
                token.chars().count() <= max_token_length,
                "token too long : {} , chars().count() = {}",
                token,
                token.chars().count()
            )
        }
    }

    /// More direct version of bpe_test_max_token_length: directly compares
    /// tokens with known expected values. Maybe unstable depending on
    /// specific settings or changes.
    #[test]
    fn bpe_test_max_token_length_direct_assert() {
        let long_word_counts: HashMap<String, u64> = [
            ("sin", 2),
            ("Sin", 2),
            ("Lon", 2),
            ("Ano", 2),
            ("짧은한", 2),
            ("긴한글", 2),
            ("短字符", 2),
            ("长字符", 2),
            ("短い文", 2),
            ("長い文", 2),
            ("so", 2),
            ("GP", 2),
        ]
        .iter()
        .map(|(key, value)| (key.to_string(), *value))
        .collect();
        let trainer = BpeTrainer::builder()
            .max_token_length(Some(2))
            .show_progress(false)
            .min_frequency(0)
            .build();
        let mut model = BPE::default();
        trainer.do_train(&long_word_counts, &mut model).unwrap();
        let trained_vocab: HashMap<String, u32> = model.get_vocab();
        let expected_vocab: HashMap<String, u32> = [
            ("短", 12),
            ("n", 6),
            ("i", 5),
            ("s", 8),
            ("字符", 23),
            ("長", 14),
            ("긴", 17),
            ("い文", 22),
            ("L", 2),
            ("in", 21),
            ("o", 7),
            ("은한", 29),
            ("S", 4),
            ("P", 3),
            ("so", 27),
            ("符", 13),
            ("文", 11),
            ("字", 10),
            ("짧", 19),
            ("GP", 25),
            ("글", 16),
            ("G", 1),
            ("An", 24),
            ("长", 15),
            ("A", 0),
            ("Lo", 26),
            ("긴한", 28),
            ("い", 9),
            ("한", 20),
            ("은", 18),
        ]
        .iter()
        .cloned()
        .map(|(k, v)| (k.to_string(), v))
        .collect();
        assert_eq!(trained_vocab, expected_vocab)
    }
}