// tokenizers/models/bpe/model.rs

1use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word};
2use crate::tokenizer::{Model, Result, Token};
3use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY, MAX_LENGTH};
4use crate::utils::iter::ResultShunt;
5use serde_json::Value;
6use std::borrow::Cow;
7use std::{
8    collections::HashMap,
9    fs::File,
10    io::prelude::*,
11    io::{BufRead, BufReader},
12    path::{Path, PathBuf},
13};
14
15pub type Vocab = HashMap<String, u32>;
16type VocabR = HashMap<u32, String>;
17pub type MergeMap = HashMap<Pair, (u32, u32)>;
18pub type Merges = Vec<(String, String)>;
19
20struct Config {
21    files: Option<(String, String)>,
22    vocab: Vocab,
23    merges: Merges,
24    cache_capacity: usize,
25    dropout: Option<f32>,
26    unk_token: Option<String>,
27    continuing_subword_prefix: Option<String>,
28    end_of_word_suffix: Option<String>,
29    fuse_unk: bool,
30    byte_fallback: bool,
31    ignore_merges: bool,
32}
33
34/// A `BpeBuilder` can be used to create a `BPE` model with a custom configuration.
35pub struct BpeBuilder {
36    config: Config,
37}
38
39impl Default for BpeBuilder {
40    fn default() -> Self {
41        Self {
42            config: Config {
43                files: None,
44                vocab: HashMap::new(),
45                merges: vec![],
46                cache_capacity: DEFAULT_CACHE_CAPACITY,
47                dropout: None,
48                unk_token: None,
49                continuing_subword_prefix: None,
50                end_of_word_suffix: None,
51                fuse_unk: false,
52                byte_fallback: false,
53                ignore_merges: false,
54            },
55        }
56    }
57}
58
59impl BpeBuilder {
60    /// Constructs a new `BpeBuilder`.
61    pub fn new() -> Self {
62        Self::default()
63    }
64
65    /// Set the input files.
66    #[must_use]
67    pub fn files(mut self, vocab: String, merges: String) -> Self {
68        self.config.files = Some((vocab, merges));
69        self
70    }
71
72    /// Set the vocab (token -> ID) and merges mappings.
73    #[must_use]
74    pub fn vocab_and_merges(mut self, vocab: Vocab, merges: Merges) -> Self {
75        self.config.vocab = vocab;
76        self.config.merges = merges;
77        self
78    }
79
80    /// Set the cache's capacity. Set to 0 if you want to disable caching.
81    #[must_use]
82    pub fn cache_capacity(mut self, capacity: usize) -> Self {
83        self.config.cache_capacity = capacity;
84        self
85    }
86
87    /// Use [dropout](https://arxiv.org/abs/1910.13267) with the model.
88    #[must_use]
89    pub fn dropout(mut self, dropout: f32) -> Self {
90        self.config.dropout = Some(dropout);
91        self
92    }
93
94    /// Set the `UNK` token for the vocab.
95    #[must_use]
96    pub fn unk_token(mut self, unk_token: String) -> Self {
97        self.config.unk_token = Some(unk_token);
98        self
99    }
100
101    /// Set the `continuing_subword_prefix` option.
102    #[must_use]
103    pub fn continuing_subword_prefix(mut self, prefix: String) -> Self {
104        self.config.continuing_subword_prefix = Some(prefix);
105        self
106    }
107
108    /// Set the `end_of_word_suffix` option.
109    #[must_use]
110    pub fn end_of_word_suffix(mut self, prefix: String) -> Self {
111        self.config.end_of_word_suffix = Some(prefix);
112        self
113    }
114
115    /// Set the `fuse_unk` option.
116    #[must_use]
117    pub fn fuse_unk(mut self, fuse_unk: bool) -> Self {
118        self.config.fuse_unk = fuse_unk;
119        self
120    }
121
122    /// Set the `byte_fallback` option.
123    #[must_use]
124    pub fn byte_fallback(mut self, byte_fallback: bool) -> Self {
125        self.config.byte_fallback = byte_fallback;
126        self
127    }
128    /// Set the `ignore_merges` option.
129    #[must_use]
130    pub fn ignore_merges(mut self, ignore_merges: bool) -> Self {
131        self.config.ignore_merges = ignore_merges;
132        self
133    }
134
135    /// Returns a `BPE` model that uses the `BpeBuilder`'s configuration.
136    pub fn build(mut self) -> Result<BPE> {
137        // Validate dropout.
138        if let Some(p) = self.config.dropout {
139            if !(0.0..=1.0).contains(&p) {
140                return Err(Error::InvalidDropout.into());
141            }
142        }
143
144        // Read files if necessary
145        if let Some((vocab, merges)) = self.config.files {
146            let (v, m) = BPE::read_file(&vocab, &merges)?;
147            self.config.vocab = v;
148            self.config.merges = m;
149        }
150
151        let vocab_r = self
152            .config
153            .vocab
154            .iter()
155            .map(|(key, val)| (*val, key.to_owned()))
156            .collect();
157        let cache = match self.config.cache_capacity {
158            0 => None,
159            capacity => Some(Cache::new(capacity)),
160        };
161
162        let vocab = self.config.vocab;
163        let prefix_len = if let Some(prefix) = &self.config.continuing_subword_prefix {
164            prefix.len()
165        } else {
166            0
167        };
168        let merge_map: MergeMap = self
169            .config
170            .merges
171            .into_iter()
172            .enumerate()
173            .map(|(i, (a, b))| -> Result<(Pair, (u32, u32))> {
174                let a_id = vocab
175                    .get(&a)
176                    .ok_or_else(|| Error::MergeTokenOutOfVocabulary(a.to_owned()))?;
177                let b_id = vocab
178                    .get(&b)
179                    .ok_or_else(|| Error::MergeTokenOutOfVocabulary(b.to_owned()))?;
180                let new_token = format!("{}{}", a, &b[prefix_len..]);
181                let new_id = vocab
182                    .get(&new_token)
183                    .ok_or(Error::MergeTokenOutOfVocabulary(new_token))?;
184                Ok(((*a_id, *b_id), (i as u32, *new_id)))
185            })
186            .collect::<Result<MergeMap>>()?;
187
188        // merges.insert(pair, (rank as u32, *new_id));
189
190        Ok(BPE {
191            vocab,
192            vocab_r,
193            merges: merge_map,
194            cache,
195            dropout: self.config.dropout,
196            unk_token: self.config.unk_token,
197            continuing_subword_prefix: self.config.continuing_subword_prefix,
198            end_of_word_suffix: self.config.end_of_word_suffix,
199            fuse_unk: self.config.fuse_unk,
200            byte_fallback: self.config.byte_fallback,
201            ignore_merges: self.config.ignore_merges,
202        })
203    }
204}
205
206/// A [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model.
207#[derive(PartialEq)]
208pub struct BPE {
209    /// The vocabulary assigns a number to each token.
210    pub(crate) vocab: Vocab,
211    /// Reversed vocabulary, to rebuild sentences.
212    pub(crate) vocab_r: VocabR,
213    /// Contains the mapping between Pairs and their (rank, new_id).
214    pub(crate) merges: MergeMap,
215    /// Contains the cache for optimizing the encoding step.
216    cache: Option<Cache<String, Word>>,
217    /// Dropout probability for merges. 0.0 = no dropout is the default. At 1.0, tokenization will
218    /// perform no merges, so the result will just be characters.
219    pub dropout: Option<f32>,
220    /// The unknown token to be used when we encounter an unknown char
221    pub unk_token: Option<String>,
222    /// An optional prefix to use on any subword that exist only behind another one
223    pub continuing_subword_prefix: Option<String>,
224    /// An optional suffix to caracterize and end-of-word subword
225    pub end_of_word_suffix: Option<String>,
226    /// Do multiple unk tokens get fused
227    pub fuse_unk: bool,
228    /// Byte fallback from sentence pieces, instead of UNK, uses `"<0x00>"`
229    /// for each byte in the unk token
230    pub byte_fallback: bool,
231    /// Whether or not to direct output words if they are part of the vocab.
232    pub ignore_merges: bool,
233}
234
235impl std::fmt::Debug for BPE {
236    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
237        fmt.debug_struct("BPE")
238            .field("dropout", &self.dropout)
239            .field("unk_token", &self.unk_token)
240            .field("continuing_subword_prefix", &self.continuing_subword_prefix)
241            .field("end_of_word_suffix", &self.end_of_word_suffix)
242            .field("fuse_unk", &self.fuse_unk)
243            .field("byte_fallback", &self.byte_fallback)
244            .field("vocab", &self.vocab.len())
245            .field("merges", &self.merges.len())
246            .field("ignore_merges", &self.ignore_merges)
247            .finish()
248    }
249}
250
251impl Default for BPE {
252    fn default() -> Self {
253        Self::builder().build().unwrap()
254    }
255}
256
257impl Clone for BPE {
258    // `Clone` can't be derive because it's not implemented for `Cache`.
259    // To keep things simple when we clone, the new BPE will start with a fresh cache.
260    fn clone(&self) -> Self {
261        let fresh_cache = self.cache.as_ref().map(|cache| cache.fresh());
262        Self {
263            vocab: self.vocab.clone(),
264            vocab_r: self.vocab_r.clone(),
265            merges: self.merges.clone(),
266            cache: fresh_cache,
267            dropout: self.dropout,
268            unk_token: self.unk_token.clone(),
269            continuing_subword_prefix: self.continuing_subword_prefix.clone(),
270            end_of_word_suffix: self.end_of_word_suffix.clone(),
271            fuse_unk: self.fuse_unk,
272            byte_fallback: self.byte_fallback,
273            ignore_merges: self.ignore_merges,
274        }
275    }
276}
277
278/// Converts the merges strings (for example from `merges.txt` file) with the format
279/// "{pair_a} {pair_b}" into the format expected by the BPE struct
280pub(crate) fn convert_merges_to_hashmap<I: Iterator<Item = String>>(
281    iter: I,
282    _vocab: &Vocab,
283) -> Result<Merges> {
284    let mut merges = vec![];
285
286    let lines = iter.filter(|l| !l.starts_with("#version"));
287    for (rank, line) in lines.enumerate() {
288        let parts = line.split(' ').collect::<Vec<_>>();
289        if parts.len() != 2 {
290            return Err(Error::BadMerges(rank + 1).into());
291        }
292
293        merges.push((parts[0].to_string(), parts[1].to_string()));
294    }
295
296    Ok(merges)
297}
298
299impl BPE {
300    /// Initialize a `BpeBuilder`.
301    pub fn builder() -> BpeBuilder {
302        BpeBuilder::new()
303    }
304
305    /// Create a new BPE model with the given vocab and merges.
306    pub fn new(vocab: Vocab, merges: Merges) -> Self {
307        Self::builder()
308            .vocab_and_merges(vocab, merges)
309            .build()
310            .unwrap()
311    }
312
313    /// Initialize a BpeBuilder model from vocab and merges files
314    pub fn from_file(vocab: &str, merges: &str) -> BpeBuilder {
315        Self::builder().files(vocab.to_owned(), merges.to_owned())
316    }
317
318    /// Read the given files to extract the vocab and merges
319    pub fn read_file(vocab: &str, merges: &str) -> Result<(Vocab, Merges)> {
320        // Read vocab.json
321        let vocab_file = File::open(vocab)?;
322        let mut vocab_file = BufReader::new(vocab_file);
323
324        let mut buffer = String::new();
325        vocab_file.read_to_string(&mut buffer)?;
326        let json: Value = serde_json::from_str(&buffer)?;
327        let mut vocab = HashMap::new();
328        match json {
329            Value::Object(m) => {
330                for (token, id) in m {
331                    if let Value::Number(id) = id {
332                        let id = id.as_u64().ok_or(Error::BadVocabulary)? as u32;
333                        vocab.insert(token, id);
334                    }
335                }
336            }
337            _ => return Err(Box::new(Error::BadVocabulary)),
338        };
339
340        // Read merges file
341        let merge_file = File::open(merges)?;
342        let merge_file = BufReader::new(merge_file);
343        let merges = ResultShunt::process(merge_file.lines(), |iter| {
344            convert_merges_to_hashmap(iter, &vocab)
345        })??;
346
347        Ok((vocab, merges))
348    }
349
350    /// Reset the cache.
351    pub fn clear_cache(&self) {
352        if let Some(ref cache) = self.cache {
353            cache.clear()
354        }
355    }
356
357    /// Resize the cache
358    pub fn resize_cache(&mut self, capacity: usize) {
359        if let Some(ref mut cache) = self.cache {
360            cache.resize(capacity);
361        }
362    }
363
364    pub fn get_vocab(&self) -> Vocab {
365        self.vocab.clone()
366    }
367
368    pub fn get_unk_token(&self) -> &Option<String> {
369        &self.unk_token
370    }
371
372    pub fn get_continuing_subword_prefix(&self) -> &Option<String> {
373        &self.continuing_subword_prefix
374    }
375
376    fn merge_word(&self, w: &str) -> Result<Word> {
377        let mut indices = w.char_indices().map(|(idx, _)| idx).peekable();
378        let mut word = Word::with_capacity(w.len());
379        let mut unk: Option<(u32, usize)> = None;
380        while let Some(i) = indices.next() {
381            let end = indices.peek();
382            let is_first = i == 0;
383            let is_last = end.is_none();
384
385            let mut s = if let Some(e) = end {
386                Cow::Borrowed(&w[i..*e])
387            } else {
388                Cow::Borrowed(&w[i..])
389            };
390            let byte_len = s.len();
391
392            // Add the `continuing_subword_prefix` if relevant
393            if !is_first {
394                if let Some(ref prefix) = self.continuing_subword_prefix {
395                    s = format!("{prefix}{s}").into()
396                }
397            }
398            // Add the `end_of_word_suffix` if relevant
399            if is_last {
400                if let Some(ref suffix) = self.end_of_word_suffix {
401                    s = format!("{s}{suffix}").into()
402                }
403            }
404
405            if let Some(id) = self.vocab.get(s.as_ref()) {
406                if let Some((unk_id, unk_len)) = unk {
407                    word.add(unk_id, unk_len);
408                    unk = None;
409                }
410                word.add(*id, byte_len);
411            } else {
412                if self.byte_fallback {
413                    let tokens: Option<Vec<_>> = s
414                        .bytes()
415                        .map(|b| -> Option<&u32> {
416                            let code = format!("<{b:#04X}>");
417
418                            self.vocab.get(&code)
419                        })
420                        .collect();
421                    if let Some(tokens) = tokens {
422                        for t in tokens {
423                            word.add(*t, 1);
424                        }
425                        continue;
426                    }
427                }
428                if let Some(unk_token) = &self.unk_token {
429                    unk = match (unk, self.fuse_unk) {
430                        (Some((unk_id, unk_len)), true) => {
431                            // Fuse unk
432                            Some((unk_id, unk_len + byte_len))
433                        }
434                        (Some((unk_id, unk_len)), false) => {
435                            // Do not fuse unk, add the previous one
436                            word.add(unk_id, unk_len);
437                            Some((
438                                *self.vocab.get(unk_token).ok_or_else(|| {
439                                    Error::UnkTokenOutOfVocabulary(unk_token.to_owned())
440                                })?,
441                                byte_len,
442                            ))
443                        }
444                        _ => Some((
445                            *self.vocab.get(unk_token).ok_or_else(|| {
446                                Error::UnkTokenOutOfVocabulary(unk_token.to_owned())
447                            })?,
448                            byte_len,
449                        )),
450                    };
451                }
452            }
453        }
454        if let Some((unk_id, unk_len)) = unk {
455            word.add(unk_id, unk_len);
456        }
457
458        word.merge_all(&self.merges, self.dropout);
459
460        Ok(word)
461    }
462
463    fn word_to_tokens<'a, 'b: 'a>(&'a self, word: &'b Word) -> impl Iterator<Item = Token> + 'a {
464        word.get_chars_iter()
465            .zip(word.get_offsets_iter())
466            .map(move |(id, offsets)| Token::new(id, self.vocab_r[&id].clone(), offsets))
467    }
468
469    fn tokenize_with_cache(&self, sequence: &str) -> Result<Vec<Token>> {
470        if self.ignore_merges {
471            if let Some(id) = self.vocab.get(sequence) {
472                return Ok(vec![Token::new(
473                    *id,
474                    sequence.to_string().clone(),
475                    (0, sequence.len()),
476                )]);
477            }
478        }
479        if let Some(ref hit) = self.cache.as_ref().and_then(|c| c.get(sequence)) {
480            return Ok(self.word_to_tokens(hit).collect());
481        }
482        let word = self.merge_word(sequence)?;
483        let ret = self.word_to_tokens(&word).collect();
484        if let Some(ref cache) = self.cache {
485            if sequence.len() < MAX_LENGTH {
486                cache.set(sequence.to_owned(), word);
487            }
488        }
489        Ok(ret)
490    }
491}
492
493impl Model for BPE {
494    type Trainer = BpeTrainer;
495
496    fn get_vocab(&self) -> HashMap<String, u32> {
497        self.vocab.clone()
498    }
499
500    fn get_vocab_size(&self) -> usize {
501        self.vocab.len()
502    }
503
504    fn tokenize(&self, sequence: &str) -> Result<Vec<Token>> {
505        if sequence.is_empty() {
506            return Ok(vec![]);
507        }
508
509        if self.dropout.is_none() || self.dropout == Some(0.0) {
510            self.tokenize_with_cache(sequence)
511        } else {
512            let word = self.merge_word(sequence)?;
513            Ok(self.word_to_tokens(&word).collect())
514        }
515    }
516
517    fn token_to_id(&self, token: &str) -> Option<u32> {
518        self.vocab.get(token).copied()
519    }
520
521    fn id_to_token(&self, id: u32) -> Option<String> {
522        self.vocab_r.get(&id).cloned()
523    }
524
525    fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
526        let vocab_file_name = match name {
527            Some(name) => format!("{name}-vocab.json"),
528            None => "vocab.json".to_string(),
529        };
530
531        // Write vocab.json
532        let vocab_path: PathBuf = [folder, Path::new(vocab_file_name.as_str())]
533            .iter()
534            .collect();
535        let mut vocab_file = File::create(&vocab_path)?;
536        let order_vocab_iter = OrderedVocabIter::new(&self.vocab_r);
537        let serialized = serde_json::to_string(&order_vocab_iter)?;
538        vocab_file.write_all(serialized.as_bytes())?;
539
540        // Write merges.txt
541        let merges_file_name = match name {
542            Some(name) => format!("{name}-merges.txt"),
543            None => "merges.txt".to_string(),
544        };
545
546        let merges_path: PathBuf = [folder, Path::new(merges_file_name.as_str())]
547            .iter()
548            .collect();
549        let mut merges_file = File::create(&merges_path)?;
550        let mut merges: Vec<(&Pair, &u32)> = self
551            .merges
552            .iter()
553            .map(|(pair, (rank, _))| (pair, rank))
554            .collect();
555        merges.sort_unstable_by_key(|k| *k.1);
556        merges_file.write_all(b"#version: 0.2\n")?;
557        merges_file.write_all(
558            &merges
559                .into_iter()
560                .flat_map(|(pair, _)| {
561                    format!("{} {}\n", self.vocab_r[&pair.0], self.vocab_r[&pair.1]).into_bytes()
562                })
563                .collect::<Vec<_>>()[..],
564        )?;
565
566        Ok(vec![vocab_path, merges_path])
567    }
568
569    fn get_trainer(&self) -> BpeTrainer {
570        BpeTrainer::default()
571    }
572}
573
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    #[test]
    fn test_ordered_vocab_iter() {
        let vocab_r: VocabR = [
            (0, "a".into()),
            (1, "b".into()),
            (2, "c".into()),
            (3, "ab".into()),
        ]
        .iter()
        .cloned()
        .collect();
        let order_vocab_iter = OrderedVocabIter::new(&vocab_r);
        let serialized = serde_json::to_string(&order_vocab_iter).unwrap();
        assert_eq!(serialized, "{\"a\":0,\"b\":1,\"c\":2,\"ab\":3}");
    }

    #[test]
    fn test_unk_not_fused() {
        let vocab: Vocab = [("<unk>".into(), 0), ("a".into(), 1), ("b".into(), 2)]
            .iter()
            .cloned()
            .collect();
        let bpe = BpeBuilder::default()
            .vocab_and_merges(vocab, vec![])
            .unk_token("<unk>".to_string())
            .build()
            .unwrap();
        let tokens = bpe.tokenize("c").unwrap();
        assert_eq!(tokens, vec![Token::new(0u32, "<unk>".into(), (0, 1)),]);

        let tokens = bpe.tokenize("cc").unwrap();
        assert_eq!(
            tokens,
            vec![
                Token::new(0u32, "<unk>".into(), (0, 1)),
                Token::new(0u32, "<unk>".into(), (1, 2)),
            ]
        );

        let tokens = bpe.tokenize("accb").unwrap();
        assert_eq!(
            tokens,
            vec![
                Token::new(1u32, "a".into(), (0, 1)),
                Token::new(0u32, "<unk>".into(), (1, 2)),
                Token::new(0u32, "<unk>".into(), (2, 3)),
                Token::new(2u32, "b".into(), (3, 4)),
            ]
        );
    }

    #[test]
    fn test_unk_get_fused() {
        let vocab: Vocab = [("<unk>".into(), 0), ("a".into(), 1), ("b".into(), 2)]
            .iter()
            .cloned()
            .collect();
        let bpe = BpeBuilder::default()
            .vocab_and_merges(vocab, vec![])
            .unk_token("<unk>".to_string())
            .fuse_unk(true)
            .build()
            .unwrap();
        let tokens = bpe.tokenize("c").unwrap();
        assert_eq!(tokens, vec![Token::new(0u32, "<unk>".into(), (0, 1)),]);

        let tokens = bpe.tokenize("cc").unwrap();
        assert_eq!(tokens, vec![Token::new(0u32, "<unk>".into(), (0, 2)),]);

        let tokens = bpe.tokenize("accb").unwrap();
        assert_eq!(
            tokens,
            vec![
                Token::new(1u32, "a".into(), (0, 1)),
                Token::new(0u32, "<unk>".into(), (1, 3)),
                Token::new(2u32, "b".into(), (3, 4)),
            ]
        );
    }

    #[test]
    // Test tokenization. With dropout set to 0 tokenization is deterministic,
    // so we know exactly what the result should be.
    //
    // To test this, we'll build a simple model to tokenize the word 'unrelated'.
    fn test_tokenize_with_and_without_dropout() {
        let vocab: Vocab = [
            ("u".into(), 0),
            ("n".into(), 1),
            ("r".into(), 2),
            ("e".into(), 3),
            ("l".into(), 4),
            ("a".into(), 5),
            ("t".into(), 6),
            ("d".into(), 7),
            ("re".into(), 8),
            ("at".into(), 9),
            ("ed".into(), 10),
            ("un".into(), 11),
            ("ated".into(), 12),
            ("rel".into(), 13),
            ("related".into(), 14),
            ("unrelated".into(), 15),
        ]
        .iter()
        .cloned()
        .collect();
        let merges: Merges = vec![
            ("r".to_string(), "e".to_string()),
            ("a".to_string(), "t".to_string()),
            ("e".to_string(), "d".to_string()),
            ("u".to_string(), "n".to_string()),
            ("at".to_string(), "ed".to_string()),
            ("re".to_string(), "l".to_string()),
            ("rel".to_string(), "ated".to_string()),
            ("un".to_string(), "related".to_string()),
        ];
        let mut bpe = BPE::new(vocab, merges);

        // With no dropout:
        let tokens = bpe.tokenize("unrelated").unwrap();
        assert_eq!(tokens, vec![Token::new(15u32, "unrelated".into(), (0, 9))]);

        // With dropout = 0.0 (equivalent to dropout == none)
        bpe.dropout = Some(0.0);
        let tokens = bpe.tokenize("unrelated").unwrap();
        assert_eq!(tokens, vec![Token::new(15u32, "unrelated".into(), (0, 9))]);

        // Now set dropout to 1.0. Result should be no merges performed.
        bpe.dropout = Some(1.0);
        let tokens = bpe.tokenize("unrelated").unwrap();
        assert_eq!(
            tokens,
            vec![
                Token::new(0u32, "u".into(), (0, 1)),
                Token::new(1u32, "n".into(), (1, 2)),
                Token::new(2u32, "r".into(), (2, 3)),
                Token::new(3u32, "e".into(), (3, 4)),
                Token::new(4u32, "l".into(), (4, 5)),
                Token::new(5u32, "a".into(), (5, 6)),
                Token::new(6u32, "t".into(), (6, 7)),
                Token::new(3u32, "e".into(), (7, 8)),
                Token::new(7u32, "d".into(), (8, 9)),
            ]
        );

        // Now try with dropout between 0 and 1.
        bpe.dropout = Some(0.5);
        let tokens = bpe.tokenize("unrelated").unwrap();
        assert!(!tokens.is_empty() && tokens.len() <= 9);
    }

    #[test]
    // Ensure `BPE::from_file` works as expected.
    fn test_bpe_from_file() {
        // Set up vocab file.
        let mut vocab_file = NamedTempFile::new().unwrap();
        vocab_file
            .write_all(b"{\"a\": 0, \"b\": 1, \"c\": 2, \"ab\": 3}")
            .unwrap();

        // Set up merges file.
        let mut merges_file = NamedTempFile::new().unwrap();
        merges_file.write_all(b"#version: 0.2\na b").unwrap();

        // Make sure we can instantiate a BPE model from the files.
        let builder = BPE::from_file(
            vocab_file.path().to_str().unwrap(),
            merges_file.path().to_str().unwrap(),
        );
        let bpe = builder.build().unwrap();

        // Check merges.
        assert_eq!(bpe.merges.get(&(0, 1)).unwrap(), &(0u32, 3u32));

        // Check vocab.
        assert_eq!(bpe.vocab.get("a").unwrap(), &0u32);
        assert_eq!(bpe.vocab.get("b").unwrap(), &1u32);
        assert_eq!(bpe.vocab.get("c").unwrap(), &2u32);
        assert_eq!(bpe.vocab.get("ab").unwrap(), &3u32);
    }

    #[test]
    // Ensure a `BpeBuilder` with dropout = 0.0 doesn't error.
    fn test_bpe_with_dropout_0() {
        let bpe = BPE::builder().dropout(0.0).build().unwrap();
        assert_eq!(bpe.dropout, Some(0.0));
    }

    #[test]
    // Ensure merges are resolved and applied correctly when a `continuing_subword_prefix`
    // is set: merging `a` with `##b` must produce `ab`, not `a##b`.
    fn test_bpe_with_continuing_subword_prefix() {
        let vocab: Vocab = vec![
            ("a".to_string(), 0),
            ("##b".to_string(), 1),
            ("##c".to_string(), 2),
            ("ab".to_string(), 3),
            ("abc".to_string(), 4),
        ]
        .into_iter()
        .collect();

        let merges = vec![
            ("a".to_string(), "##b".to_string()),
            ("ab".to_string(), "##c".to_string()),
        ];

        let bpe = BPE::builder()
            .vocab_and_merges(vocab, merges)
            .unk_token("[UNK]".to_string())
            .continuing_subword_prefix("##".to_string())
            .build()
            .unwrap();

        let res = bpe.tokenize("ab");
        assert_eq!(
            res.unwrap(),
            vec![Token {
                id: 3,
                value: "ab".to_string(),
                offsets: (0, 2)
            }]
        );
        let res = bpe.tokenize("abc");
        assert_eq!(
            res.unwrap(),
            vec![Token {
                id: 4,
                value: "abc".to_string(),
                offsets: (0, 3)
            }]
        );
    }

    #[test]
    // Ensure `MergeTokenOutOfVocabulary` error is returned when it should be.
    fn test_bpe_from_file_merge_token_oov() {
        // Set up vocab file.
        let mut vocab_file = NamedTempFile::new().unwrap();
        vocab_file
            .write_all(b"{\"a\": 0, \"b\": 1, \"c\": 2, \"ab\": 3}")
            .unwrap();

        // Set up merges file.
        let mut merges_file = NamedTempFile::new().unwrap();
        merges_file.write_all(b"#version: 0.2\na b\na d").unwrap();

        // Ensure the result of BPE::from_file is a MergeTokenOutOfVocabulary error.
        match BPE::from_file(
            vocab_file.path().to_str().unwrap(),
            merges_file.path().to_str().unwrap(),
        )
        .build()
        {
            Ok(_) => unreachable!(),
            Err(err) => match err.downcast_ref::<Error>() {
                Some(Error::MergeTokenOutOfVocabulary(token)) => {
                    assert_eq!(*token, String::from("d"))
                }
                _ => unreachable!(),
            },
        }
    }

    #[test]
    // Ensure `BadMerges` error is returned when there is an invalid line in the
    // merges.txt file. The reported number counts lines after the `#version` header.
    fn test_bpe_from_file_bad_merges() {
        // Set up vocab file.
        let mut vocab_file = NamedTempFile::new().unwrap();
        vocab_file
            .write_all("{\"a\": 0, \"b\": 1, \"c\": 2, \"ab\": 3}".as_bytes())
            .unwrap();

        // Set up merges file with a bad line.
        let mut merges_file = NamedTempFile::new().unwrap();
        merges_file.write_all(b"#version: 0.2\na b\nc").unwrap();

        // Ensure the result of BPE::from_file is a BadMerges error.
        match BPE::from_file(
            vocab_file.path().to_str().unwrap(),
            merges_file.path().to_str().unwrap(),
        )
        .build()
        {
            Ok(_) => unreachable!(),
            Err(err) => match err.downcast_ref::<Error>() {
                Some(Error::BadMerges(line)) => assert_eq!(*line, 2),
                _ => unreachable!(),
            },
        }
    }

    #[test]
    fn test_bpe_byte_fallback() {
        // 0x61 == 'a' in bytes
        let vocab: Vocab = [("<unk>".into(), 0), ("<0x61>".into(), 1)]
            .iter()
            .cloned()
            .collect();
        let bpe = BpeBuilder::default()
            .vocab_and_merges(vocab, vec![])
            .unk_token("<unk>".to_string())
            .byte_fallback(true)
            .build()
            .unwrap();
        let tokens = bpe.tokenize("c").unwrap();
        assert_eq!(tokens, vec![Token::new(0u32, "<unk>".into(), (0, 1)),]);

        let tokens = bpe.tokenize("a").unwrap();
        assert_eq!(tokens, vec![Token::new(1u32, "<0x61>".into(), (0, 1)),]);
    }

    #[test]
    fn test_bpe_byte_fallback_newline() {
        // 0x0A == '\n' in bytes
        let vocab: Vocab = [("<unk>".into(), 0), ("<0x0A>".into(), 1)]
            .iter()
            .cloned()
            .collect();
        let bpe = BpeBuilder::default()
            .vocab_and_merges(vocab, vec![])
            .unk_token("<unk>".to_string())
            .byte_fallback(true)
            .build()
            .unwrap();
        let tokens = bpe.tokenize("\n").unwrap();
        assert_eq!(tokens, vec![Token::new(1u32, "<0x0A>".into(), (0, 1)),]);
    }

    #[test]
    fn test_ignore_merges() {
        // The whole words ".:.:" and "Ġbelirtilen" are single vocab entries, but the merges
        // alone split them into several tokens. `ignore_merges` must return the whole-word
        // tokens; turning it off must fall back to the merged pieces.
        let vocab: Vocab = [
            (".:.:".into(), 0),
            ("Ġbelirtilen".into(), 1),
            (".".into(), 2),
            (":".into(), 3),
            ("bel".into(), 4),
            ("irtilen".into(), 5),
            ("Ġ".into(), 6),
            (".:".into(), 7),
            ("belirtilen".into(), 8),
            (".:.".into(), 9),
            ("be".into(), 10),
            ("l".into(), 11),
            ("ir".into(), 12),
            ("ti".into(), 13),
            ("en".into(), 14),
            ("irtil".into(), 15),
            ("irti".into(), 16),
            ("i".into(), 17),
            ("r".into(), 18),
            ("t".into(), 19),
            ("b".into(), 20),
            ("e".into(), 21),
            ("n".into(), 22),
        ]
        .iter()
        .cloned()
        .collect();
        let mut bpe = BpeBuilder::default()
            .vocab_and_merges(
                vocab,
                vec![
                    (".".into(), ":".into()),
                    ("b".into(), "e".into()),
                    ("be".into(), "l".into()),
                    ("i".into(), "r".into()),
                    ("t".into(), "i".into()),
                    ("ir".into(), "ti".into()),
                    ("e".into(), "n".into()),
                    ("irti".into(), "l".into()),
                ],
            )
            .ignore_merges(true)
            .build()
            .unwrap();
        let tokens = bpe.tokenize(".:.:").unwrap();
        assert_eq!(tokens, vec![Token::new(0u32, ".:.:".into(), (0, 4))]);

        let tokens = bpe.tokenize("Ġbelirtilen").unwrap();
        assert_eq!(
            tokens,
            vec![Token::new(1u32, "Ġbelirtilen".into(), (0, 12))]
        );

        bpe.ignore_merges = false;

        let tokens = bpe.tokenize(".:.:").unwrap();
        assert_eq!(
            tokens,
            vec![
                Token::new(7u32, ".:".into(), (0, 2)),
                Token::new(7u32, ".:".into(), (2, 4))
            ]
        );

        let tokens = bpe.tokenize("Ġbelirtilen").unwrap();
        assert_eq!(
            tokens,
            vec![
                Token {
                    id: 6,
                    value: "Ġ".into(),
                    offsets: (0, 2)
                },
                Token {
                    id: 4,
                    value: "bel".into(),
                    offsets: (2, 5)
                },
                Token {
                    id: 15,
                    value: "irtil".into(),
                    offsets: (5, 10)
                },
                Token {
                    id: 14,
                    value: "en".into(),
                    offsets: (10, 12)
                }
            ]
        )
    }
}