tokenizers/tokenizer/
added_vocabulary.rs

1use super::{
2    normalizer::Range, Model, NormalizedString, Normalizer, Offsets, PreTokenizedString, Token,
3};
4use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind};
5use regex::Regex;
6use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer};
7use std::collections::{HashMap, HashSet};
8
/// Represent a token added by the user on top of the existing Model vocabulary.
///
/// An `AddedToken` can be configured to specify the behavior it should have in various
/// situations, like:
///   - Whether it should only match single words
///   - Whether to include any whitespace on its left or right
///   - Whether it is matched against the normalized or the non-normalized input
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AddedToken {
    /// The content of the added token
    pub content: String,
    /// Whether this token must be a single word or can break words. When set, a match
    /// directly preceded or followed by a word character is discarded.
    pub single_word: bool,
    /// Whether this token should strip whitespaces on its left: the whitespace found just
    /// before a match becomes part of the match.
    pub lstrip: bool,
    /// Whether this token should strip whitespaces on its right: the whitespace found just
    /// after a match becomes part of the match.
    pub rstrip: bool,
    /// Whether this token should be normalized, i.e. matched against the normalized text
    pub normalized: bool,
    /// Whether this token is special
    pub special: bool,
}
29
30impl AddedToken {
31    /// Build this token from the given content, specifying if it is intented to be a
32    /// special token. Special tokens are not normalized by default.
33    pub fn from<S: Into<String>>(content: S, special: bool) -> Self {
34        Self {
35            content: content.into(),
36            normalized: !special,
37            special,
38            ..Default::default()
39        }
40    }
41    /// Specify whether this token should only match on whole single words, and never
42    /// part of a word.
43    #[must_use]
44    pub fn single_word(mut self, single_word: bool) -> Self {
45        self.single_word = single_word;
46        self
47    }
48    /// Specify whether this token should include all the whitespaces on its left, in
49    /// order to strip them out.
50    #[must_use]
51    pub fn lstrip(mut self, lstrip: bool) -> Self {
52        self.lstrip = lstrip;
53        self
54    }
55    /// Specify whether this token should include all the whitespaces on its right, in
56    /// order to strip them out.
57    #[must_use]
58    pub fn rstrip(mut self, rstrip: bool) -> Self {
59        self.rstrip = rstrip;
60        self
61    }
62    /// Specify whether this token should be normalized and match against its normalized
63    /// version in the input text.
64    #[must_use]
65    pub fn normalized(mut self, normalized: bool) -> Self {
66        self.normalized = normalized;
67        self
68    }
69    /// Specify whether this token is special, meaning if it should be skipped when decoding
70    #[must_use]
71    pub fn special(mut self, special: bool) -> Self {
72        self.special = special;
73        self
74    }
75}
76impl Default for AddedToken {
77    fn default() -> Self {
78        Self {
79            content: String::new(),
80            single_word: false,
81            lstrip: false,
82            rstrip: false,
83            normalized: true,
84            special: false,
85        }
86    }
87}
88// AddedTokens can be updated if value changed
89impl std::hash::Hash for AddedToken {
90    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
91        self.content.hash(state);
92    }
93}
94
/// An Aho-Corasick automaton paired with the token ids of its patterns: the id stored
/// at index `i` is the one returned when the automaton reports its pattern `i`.
type MatchingSet = (AhoCorasick, Vec<u32>);
96
// Regexes backing the `single_word`, `lstrip` and `rstrip` handling of the matches.
lazy_static! {
    // A word character at the very start of the text
    static ref STARTS_WITH_WORD: Regex = Regex::new(r"^\w").unwrap();
    // A word character at the very end of the text
    static ref ENDS_WITH_WORD: Regex = Regex::new(r"\w$").unwrap();
    // The (possibly empty) run of whitespace at the start of the text
    static ref RIGHTMOST_SPACE_AT_START: Regex = Regex::new(r"^\s*").unwrap();
    // The (possibly empty) run of whitespace at the end of the text
    static ref LEFTMOST_SPACE_AT_END: Regex = Regex::new(r"\s*$").unwrap();
}
103
/// Whether the last character of `sentence` is a word character (`\w`).
fn ends_with_word(sentence: &str) -> bool {
    ENDS_WITH_WORD.is_match(sentence)
}
107
/// Whether the first character of `sentence` is a word character (`\w`).
fn starts_with_word(sentence: &str) -> bool {
    STARTS_WITH_WORD.is_match(sentence)
}
111
112fn space_leftmost_at_end(sentence: &str) -> usize {
113    if let Some(match_) = LEFTMOST_SPACE_AT_END.find(sentence) {
114        match_.start()
115    } else {
116        sentence.len()
117    }
118}
119fn space_rightmost_at_start(sentence: &str) -> usize {
120    if let Some(match_) = RIGHTMOST_SPACE_AT_START.find(sentence) {
121        match_.end()
122    } else {
123        0
124    }
125}
///
/// A vocabulary built on top of the Model
///
/// This provides a way to add new vocabulary to a Tokenizer that has already been trained,
/// in a previous process, maybe by someone else. This is especially interesting in the case
/// of fine-tunings, where we want to finetune a model while adding some new functionalities
/// using some new special tokens, or maybe add some tokens in the case of unknown tokens, etc.
///
/// One of the reasons we need to handle these tokens outside of the model is simply that
/// for many models, it is not possible to add new tokens after the training process. For example,
/// using BPE, the training process generates merge pairs along with the vocabulary, and any token
/// in the vocabulary can be decomposed into other tokens, down to the original alphabet. If we
/// were to add new tokens after this training process, we couldn't make sure the merge pairs
/// exist as required.
///
#[derive(Clone, Debug)]
pub struct AddedVocabulary {
    /// Contains the mapping from String (token content) to ID. This map contains both special
    /// tokens and classic added tokens that were added to this vocabulary.
    added_tokens_map: HashMap<String, u32>,
    /// Contains the mapping from ID to AddedToken for all the added tokens, both special
    /// and classic.
    added_tokens_map_r: HashMap<u32, AddedToken>,

    /// Contains only the classic AddedToken, in the specific order the user gave them.
    added_tokens: Vec<AddedToken>,
    /// Contains only the special AddedToken, in the specific order the user gave them.
    special_tokens: Vec<AddedToken>,

    /// A Set, containing all the special tokens for easy access while decoding. This lets
    /// us remove them easily with an O(1) complexity.
    special_tokens_set: HashSet<String>,

    /// An Aho-Corasick automaton (with the ids of its patterns) containing all the
    /// non-normalized patterns used to split on AddedTokens
    split_trie: MatchingSet,
    /// An Aho-Corasick automaton (with the ids of its patterns) containing all the
    /// normalized patterns used to split on AddedTokens
    split_normalized_trie: MatchingSet,

    /// Whether or not special tokens should be split when encoding. This is equivalent to ignoring them
    encode_special_tokens: bool,
}
167
168impl AddedVocabulary {
169    pub fn new() -> Self {
170        let trie = AhoCorasickBuilder::new()
171            .match_kind(MatchKind::LeftmostLongest)
172            .build::<_, &&[u8]>([])
173            .expect("The trie should build correctly");
174        let normalized_trie = AhoCorasickBuilder::new()
175            .match_kind(MatchKind::LeftmostLongest)
176            .build::<_, &&[u8]>([])
177            .expect("The normalized trie should build correctly");
178        Self {
179            added_tokens_map: HashMap::new(),
180            added_tokens_map_r: HashMap::new(),
181            added_tokens: vec![],
182            special_tokens: vec![],
183            special_tokens_set: HashSet::new(),
184            split_trie: (trie, vec![]),
185            split_normalized_trie: (normalized_trie, vec![]),
186            encode_special_tokens: false,
187        }
188    }
    /// Size of the additional vocabulary, i.e. the number of distinct token contents
    /// registered in it (special and classic tokens alike)
    #[allow(dead_code)] // Suppress the "method is never used" warning
    pub fn len(&self) -> usize {
        self.added_tokens_map.len()
    }
194
    /// Whether or not this vocabulary is empty, i.e. no token was ever added to it
    pub fn is_empty(&self) -> bool {
        self.added_tokens_map.is_empty()
    }
199
    /// Get the additional vocabulary, as a mapping from token content to id
    pub fn get_vocab(&self) -> &HashMap<String, u32> {
        &self.added_tokens_map
    }
204
    /// Get the additional vocabulary with the AddedTokens, as a mapping from id to the
    /// full `AddedToken` (content and options)
    pub fn get_added_tokens_decoder(&self) -> &HashMap<u32, AddedToken> {
        &self.added_tokens_map_r
    }
209
210    /// Get the id matching one of our token if it exists
211    pub fn token_to_id(&self, token: &str, model: &impl Model) -> Option<u32> {
212        self.added_tokens_map
213            .get(token)
214            .copied()
215            .or_else(|| model.token_to_id(token))
216    }
217
218    /// Get the token matching the given id if it exists
219    #[deprecated(
220        since = "0.19.0",
221        note = "please use `added_vocabulary.simple_id_to_token(id).or_else(|| model.id_to_token(id)` instead"
222    )]
223    pub fn id_to_token(&self, id: u32, model: &impl Model) -> Option<String> {
224        self.added_tokens_map_r
225            .get(&id)
226            .map(|t| t.content.clone())
227            .or_else(|| model.id_to_token(id))
228    }
229
230    pub fn simple_id_to_token(&self, id: u32) -> Option<String> {
231        self.added_tokens_map_r.get(&id).map(|t| t.content.clone())
232    }
233
    /// Set whether special tokens are left unmatched when encoding. With `true`, the
    /// matches of special tokens are skipped and their text stays in the regular splits.
    pub fn set_encode_special_tokens(&mut self, value: bool) {
        self.encode_special_tokens = value;
    }
238
    /// Whether special tokens are currently left unmatched when encoding
    pub fn get_encode_special_tokens(&self) -> bool {
        self.encode_special_tokens
    }
242
    /// Check if a token is a special token, i.e. if its exact content was registered
    /// through an `AddedToken` flagged as special
    pub fn is_special_token(&self, token: &str) -> bool {
        self.special_tokens_set.contains(token)
    }
247
    /// Add some special tokens to the vocabulary
    ///
    /// This forwards to `add_tokens` unchanged: a token is registered as special only
    /// when its own `special` flag is set, the flag is not forced here. Returns the
    /// number of tokens that were not ignored.
    pub fn add_special_tokens<N: Normalizer>(
        &mut self,
        tokens: &[AddedToken],
        model: &impl Model,
        normalizer: Option<&N>,
    ) -> usize {
        self.add_tokens(tokens, model, normalizer)
    }
257
258    /// Add some tokens to the vocabulary
259    pub fn add_tokens<N: Normalizer>(
260        &mut self,
261        tokens: &[AddedToken],
262        model: &impl Model,
263        normalizer: Option<&N>,
264    ) -> usize {
265        // Handle special tokens (if any)
266        for token in tokens {
267            if token.special
268                && !token.content.is_empty()
269                && !self.special_tokens_set.contains(&token.content)
270            {
271                self.special_tokens.push(token.to_owned());
272                self.special_tokens_set.insert(token.content.clone());
273            }
274        }
275
276        // Then we delegate to `add_tokens`, that will take care of refreshing added tokens too.
277        let mut ignored = 0;
278        for token in tokens {
279            if token.content.is_empty() || self.added_tokens_map_r.values().any(|val| val == token)
280            {
281                ignored += 1;
282                continue;
283            }
284            // If a token is already part of the vocabulary, we mark it as added
285            let new_id = if let Some(new_id) = self.token_to_id(&token.content, model) {
286                new_id
287            } else {
288                self.added_tokens_map.values().cloned().max().map_or(
289                    model.get_vocab_size() as u32,
290                    |max| {
291                        if (max >= model.get_vocab_size() as u32) || model.get_vocab_size() == 0 {
292                            max + 1
293                        } else {
294                            model.get_vocab_size() as u32
295                        }
296                    },
297                )
298            };
299            // Make sure we modify the previous entry
300            self.added_tokens_map
301                .entry(token.content.clone())
302                .and_modify(|old_id| *old_id = new_id)
303                .or_insert_with(|| new_id);
304            // Update the current revert operation
305            self.added_tokens_map_r
306                .entry(new_id)
307                .and_modify(|t| *t = token.clone())
308                .or_insert_with(|| token.clone());
309            // Make sure to remove previous entry (if the token gets a new id)
310
311            // Finally add the token to the classic set if special
312            if !self.special_tokens_set.contains(&token.content) {
313                self.added_tokens.push(token.clone());
314            }
315        }
316
317        self.refresh_added_tokens(model, normalizer);
318
319        // Return the number of added tokens
320        tokens.len() - ignored
321    }
322
    /// Rebuild our internal Aho-Corasick automatons when new tokens are added to the
    /// vocabulary.
    ///
    /// We keep two different automatons, one that will take care of matching against the
    /// non-normalized string, and one matching against the normalized one. Each of them is
    /// stored along with the ids of its patterns, in the same order.
    ///
    /// Panics if a listed token has no id, if the `normalizer` fails on the content of a
    /// token, or if an automaton cannot be built.
    fn refresh_added_tokens<N: Normalizer>(&mut self, model: &impl Model, normalizer: Option<&N>) {
        type TupleTokenId<'a> = (&'a AddedToken, u32);
        // Pair every token (special ones first) with its id, then separate the tokens
        // matched on the normalized text from the ones matched on the raw text.
        let (normalized, non_normalized): (Vec<TupleTokenId>, Vec<TupleTokenId>) = self
            .special_tokens
            .iter()
            .chain(self.added_tokens.iter())
            .map(|token| {
                (
                    token,
                    self.token_to_id(&token.content, model)
                        .expect("Missing additional token"),
                )
            })
            .partition(|(token, _)| token.normalized);

        // Non-normalized tokens: their content is used as the pattern, as is
        let (tokens, ids): (Vec<&AddedToken>, Vec<u32>) = non_normalized.into_iter().unzip();
        let trie = AhoCorasickBuilder::new()
            .match_kind(MatchKind::LeftmostLongest)
            .build(tokens.iter().map(|token| &token.content))
            .expect("Failed to build tried when refreshing tokens");
        self.split_trie = (trie, ids);

        // Normalized tokens: the pattern is the content once it went through the
        // normalizer, since it is searched in normalized text
        let (ntokens, nids): (Vec<&AddedToken>, Vec<u32>) = normalized.into_iter().unzip();
        let patterns: Vec<_> = ntokens
            .iter()
            .map(|token| {
                let mut content = NormalizedString::from(token.content.as_ref());
                if let Some(n) = normalizer {
                    n.normalize(&mut content).unwrap();
                }
                content
            })
            .collect();
        let normalized_trie = AhoCorasickBuilder::new()
            .match_kind(MatchKind::LeftmostLongest)
            .build(patterns.iter().map(|content| content.get()))
            .expect("Failed to build tried when refreshing tokens (normalized)");
        self.split_normalized_trie = (normalized_trie, nids);
    }
366
    /// Find any AddedToken in the given sentence, using the provided MatchingSet.
    /// This method returns a list of "splits", each of them being a pair made of an
    /// optional ID (set when the split is an AddedToken) and its byte Offsets in `sentence`.
    /// The list of splits is meant to cover the entire input string; an empty sentence
    /// gives a single empty split.
    ///
    /// NOTE(review): the automaton resumes after the raw end of a match, not after the
    /// whitespace added by `rstrip`, so a following pattern that starts with whitespace
    /// looks able to produce a split overlapping the previous one — confirm.
    fn find_matches(&self, sentence: &str, split_re: &MatchingSet) -> Vec<(Option<u32>, Offsets)> {
        if sentence.is_empty() {
            return vec![(None, (0, 0))];
        }

        // End of the last emitted split, i.e. where the next split has to start
        let mut start_offset = 0;
        let mut splits = vec![];

        for mat in split_re.0.find_iter(sentence) {
            let mut start = mat.start();
            let mut stop = mat.end();
            // Pattern index in the automaton -> token id -> options of the token
            let aho_id = mat.pattern();
            let id = split_re.1[aho_id];
            let added_token = &self.added_tokens_map_r.get(&id).unwrap();

            // Special tokens are treated as regular text when they must be encoded
            if self.encode_special_tokens && self.special_tokens_set.contains(&added_token.content)
            {
                continue;
            }

            if added_token.single_word {
                // Each side must be a sentence boundary or a non-word character
                let start_space = start == 0 || !ends_with_word(&sentence[..start]);
                let stop_space = stop == sentence.len() || !starts_with_word(&sentence[stop..]);

                if !stop_space || !start_space {
                    // Discard: the match is glued to a word
                    continue;
                }
            }
            if added_token.lstrip {
                // Offset (in the sentence) where the whitespace preceding the match
                // begins; never greater than `start`
                let newstart = space_leftmost_at_end(&sentence[..start]);

                // The previous match could have already matched those spaces
                // Ignore them if it's already matched
                start = std::cmp::max(newstart, start_offset);
            }
            if added_token.rstrip {
                // The helper counts from the beginning of `sentence[stop..]`, so its
                // result is added to the previous stop value
                stop += space_rightmost_at_start(&sentence[stop..])
            }
            // Text between the previous split and this match, if any
            if start_offset < start {
                splits.push((None, (start_offset, start)));
            }
            splits.push((Some(id), (start, stop)));
            start_offset = stop;
        }

        // Whatever remains after the last match
        let total_byte_len = sentence.len();
        if start_offset != total_byte_len {
            splits.push((None, (start_offset, total_byte_len)));
        }

        splits
    }
427
428    /// Split the input sentence to extract anything we found from the `MatchingSet`, as well as
429    /// the list of corresponding IDs
430    /// The list of IDs have the exact same number of elements than the Iterator.
431    fn split_with_indices(
432        &self,
433        sentence: NormalizedString,
434        split_re: &MatchingSet,
435    ) -> Vec<(NormalizedString, Option<Vec<Token>>)> {
436        self.find_matches(sentence.get(), split_re)
437            .into_iter()
438            .map(|(id, byte_offsets)| {
439                let slice = sentence
440                    .slice(Range::Normalized(byte_offsets.0..byte_offsets.1))
441                    .expect("AddedVocabulary bad split");
442                if let Some(id) = id {
443                    let value = slice.get().to_owned();
444                    let len = value.len();
445                    (slice, Some(vec![Token::new(id, value, (0, len))]))
446                } else {
447                    (slice, None)
448                }
449            })
450            .collect()
451    }
452
    /// Extract the additional vocabulary from the given sentence, normalizing it along the way.
    ///
    /// Some tokens should match against their normalized representation, as well as the
    /// non-normalized one. For example, when we expect to extract the token `yesterday` in the
    /// input sentence `I read a book Yesterday`, if the normalizer is supposed to lowercase
    /// everything, we expect a match.
    ///
    /// Panics with "AddedVocabulary bad split" if one of the two splitting passes fails.
    pub fn extract_and_normalize<N: Normalizer>(
        &self,
        normalizer: Option<&N>,
        sequence: &str,
    ) -> PreTokenizedString {
        let mut pretokenized: PreTokenizedString = sequence.into();

        // 1. We extract all the non-normalized tokens from the non-normalized string,
        //    e.g. a non-normalized `<s>` is found as typed in the raw input.
        pretokenized
            .split(|_, sequence| Ok(self.split_with_indices(sequence, &self.split_trie)))
            .expect("AddedVocabulary bad split");

        // 2. Then we normalize the pieces (when a normalizer is given) and extract the
        //    normalized tokens from them, e.g. `yesterday` in a lowercased `Yesterday`.
        //    Presumably `split` leaves the pieces already carrying tokens untouched —
        //    TODO confirm in `PreTokenizedString::split`.
        // NOTE(review): the result of `normalize` is discarded, so a failing normalizer
        // goes unnoticed here — confirm this is intended.
        pretokenized
            .split(|_, mut sequence| {
                normalizer.map(|n| n.normalize(&mut sequence));
                Ok(self.split_with_indices(sequence, &self.split_normalized_trie))
            })
            .expect("AddedVocabulary bad split");

        pretokenized
    }
501}
502
impl Default for AddedVocabulary {
    /// Same as [`AddedVocabulary::new`]: a vocabulary without any added token.
    fn default() -> Self {
        Self::new()
    }
}
508
/// An `AddedToken` together with its id. The fields of the token are flattened next to
/// `id` when (de)serializing.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct AddedTokenWithId {
    /// The id assigned to this token
    pub id: u32,
    #[serde(flatten)]
    /// The target AddedToken
    pub token: AddedToken,
}
517
518impl Serialize for AddedVocabulary {
519    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
520    where
521        S: Serializer,
522    {
523        let mut added_tokens = self
524            .added_tokens_map_r
525            .iter()
526            .map(|(id, token)| AddedTokenWithId {
527                id: *id,
528                token: token.clone(),
529            })
530            .collect::<Vec<_>>();
531        // We need to have these added tokens ordered by ascending ID
532        added_tokens.sort_unstable_by_key(|o| o.id);
533
534        let mut vocabulary = serializer.serialize_seq(Some(added_tokens.len()))?;
535        for token in added_tokens {
536            vocabulary.serialize_element(&token)?;
537        }
538
539        vocabulary.end()
540    }
541}
542
543#[cfg(test)]
544mod tests {
545    use super::*;
546    use crate::normalizers::byte_level::ByteLevel as ByteLevelNormalizer;
547    use crate::normalizers::utils::Lowercase;
548    use crate::normalizers::NormalizerWrapper;
549    use crate::{OffsetReferential, OffsetType, Result, Token, Trainer};
550    use std::path::{Path, PathBuf};
551
    /// Minimal `Model` used by the tests: a fixed vocabulary stored in both directions.
    #[derive(Serialize, Deserialize)]
    struct ModelMock {
        /// Token -> id
        vocab: HashMap<String, u32>,
        /// Id -> token
        vocab_r: HashMap<u32, String>,
    }
557    impl ModelMock {
558        pub fn new<I>(iter: I) -> Self
559        where
560            I: IntoIterator<Item = &'static (&'static str, u32)>,
561        {
562            let vocab: HashMap<String, u32> = iter
563                .into_iter()
564                .map(|&(tok, id)| (tok.to_string(), id))
565                .collect();
566            Self {
567                vocab_r: vocab
568                    .iter()
569                    .map(|(tok, id)| (*id, tok.to_owned()))
570                    .collect(),
571                vocab,
572            }
573        }
574    }
575
576    fn simplify_output(result: &'_ PreTokenizedString) -> Vec<(&'_ str, Option<Vec<u32>>)> {
577        result
578            .get_splits(OffsetReferential::Original, OffsetType::Byte)
579            .into_iter()
580            .map(|(s, _, tokens)| {
581                (
582                    s,
583                    tokens
584                        .as_ref()
585                        .map(|t| t.iter().map(|t| t.id).collect::<Vec<_>>()),
586                )
587            })
588            .collect::<Vec<_>>()
589    }
590
    /// Trainer stub, only there because `Model` requires an associated `Trainer`.
    /// Training and feeding are not implemented and panic if called.
    struct TrainerMock;
    impl Trainer for TrainerMock {
        type Model = ModelMock;
        fn should_show_progress(&self) -> bool {
            true
        }
        fn train(&self, _model: &mut ModelMock) -> Result<Vec<AddedToken>> {
            unimplemented!()
        }
        fn feed<I, S, F>(&mut self, _iterator: I, _process: F) -> Result<()>
        where
            I: Iterator<Item = S> + Send,
            S: AsRef<str> + Send,
            F: Fn(&str) -> Result<Vec<String>> + Sync,
        {
            unimplemented!()
        }
    }
609
    // Only the vocabulary lookups are implemented; `tokenize` and `save` panic if called.
    impl Model for ModelMock {
        type Trainer = TrainerMock;

        fn tokenize(&self, _sequence: &str) -> Result<Vec<Token>> {
            unimplemented!()
        }
        // Token -> id lookup in the fixed vocabulary
        fn token_to_id(&self, token: &str) -> Option<u32> {
            self.vocab.get(token).copied()
        }
        // Id -> token lookup in the fixed vocabulary
        fn id_to_token(&self, id: u32) -> Option<String> {
            self.vocab_r.get(&id).cloned()
        }
        fn get_vocab(&self) -> HashMap<String, u32> {
            self.vocab.clone()
        }
        // Number of entries of the vocabulary
        fn get_vocab_size(&self) -> usize {
            self.vocab.len()
        }
        fn save(&self, _folder: &Path, _name: Option<&str>) -> Result<Vec<PathBuf>> {
            unimplemented!()
        }
        fn get_trainer(&self) -> Self::Trainer {
            TrainerMock
        }
    }
635
    #[test]
    fn can_add_tokens() {
        // The model already knows two tokens, with ids 0 and 1
        let model = ModelMock::new(&[("test", 0), ("tost", 1)]);
        let mut vocab = AddedVocabulary::new();
        let normalizer: Option<&NormalizerWrapper> = None;

        // Add tokens normally
        assert_eq!(
            vocab.add_tokens(
                &[AddedToken::from("added_token_1", false)],
                &model,
                normalizer
            ),
            1
        );

        let vocab_len: usize = vocab.len();
        assert_eq!(vocab_len, 1);

        // Does not add the same token multiple times: the second copy is ignored
        assert_eq!(
            vocab.add_tokens(
                &[
                    AddedToken::from("added_token_2", false),
                    AddedToken::from("added_token_2", false)
                ],
                &model,
                normalizer
            ),
            1
        );
        assert_eq!(vocab.len(), 2);

        // Also adds tokens already covered by the model
        let added_token = AddedToken::from("test", false);
        assert_eq!(
            vocab.add_tokens(&[added_token.clone()], &model, normalizer),
            1
        );
        assert_eq!(vocab.len(), 3);

        // Such a token keeps the id it has in the model
        assert_eq!(vocab.get_added_tokens_decoder()[&0], added_token);
    }
679
    #[test]
    fn can_add_special_tokens() {
        // The model already knows two tokens, with ids 0 and 1
        let model = ModelMock::new(&[("test", 0), ("tost", 1)]);
        let mut vocab = AddedVocabulary::new();
        let normalizer: Option<&NormalizerWrapper> = None;
        // Add tokens normally
        assert_eq!(
            vocab.add_special_tokens(
                &[AddedToken::from("added_token_1", true)],
                &model,
                normalizer
            ),
            1
        );
        assert_eq!(vocab.len(), 1);

        // Does not add the same token multiple times: the second copy is ignored
        assert_eq!(
            vocab.add_special_tokens(
                &[
                    AddedToken::from("added_token_2", true),
                    AddedToken::from("added_token_2", true)
                ],
                &model,
                normalizer
            ),
            1
        );
        assert_eq!(vocab.len(), 2);

        // Can add tokens already covered by the model
        assert_eq!(
            vocab.add_special_tokens(&[AddedToken::from("test", true)], &model, normalizer),
            1
        );
        assert_eq!(vocab.len(), 3); // New token was added
        assert!(vocab.is_special_token("test"));
        // "test" keeps its model id, the new tokens come right after the model vocabulary
        assert_eq!(
            *vocab.get_added_tokens_decoder(),
            HashMap::from([
                (0, AddedToken::from("test", true)),
                (2, AddedToken::from("added_token_1", true)),
                (3, AddedToken::from("added_token_2", true)),
            ])
        );
        assert!(vocab.added_tokens_map.contains_key("test"));
        assert!(vocab.added_tokens_map_r.contains_key(&0));

        // Special and classic tokens can be mixed in a single call
        vocab.add_tokens(
            &[
                AddedToken::from("tost", true),
                AddedToken::from("another_two", false),
            ],
            &model,
            normalizer,
        );
        assert_eq!(vocab.len(), 5); // Both tokens were added
        assert_eq!(vocab.get_vocab()["another_two"], 4); // New token was added, but the index is not the length of the vocab

        // Let's add an already added token again, this time as a special token: its
        // options differ, so it counts as added
        assert_eq!(
            vocab.add_special_tokens(&[AddedToken::from("another_two", true)], &model, normalizer),
            1
        );
        assert_eq!(vocab.len(), 5); // Token was already there
        assert_eq!(vocab.get_vocab()["another_two"], 4); // Token idx not changed

        // Just checking that we can set the content of the string in rust
        let mut token: AddedToken = AddedToken::from("Hey", false);
        token.content = "hey".to_string();
        assert_eq!(token.content, "hey"); // Content was replaced

        // Same check for the `special` flag
        token.special = true;
        assert!(token.special); // Flag was set
    }
755
756    #[test]
757    fn can_extract_added_tokens() {
758        // Is able to extract both normal and special tokens
759        let model = ModelMock::new(&[]);
760        let mut vocab = AddedVocabulary::new();
761        let normalizer: Option<&NormalizerWrapper> = None;
762
763        vocab.add_tokens(
764            &[
765                AddedToken::from("my", false),
766                AddedToken::from("name", false),
767            ],
768            &model,
769            normalizer,
770        );
771        vocab.add_special_tokens(
772            &[
773                AddedToken::from("[CLS]", true),
774                AddedToken::from("[SEP]", true),
775            ],
776            &model,
777            normalizer,
778        );
779
780        let result = vocab.extract_and_normalize(normalizer, "[CLS] My name is Anthony [SEP]");
781        assert_eq!(
782            result
783                .get_splits(OffsetReferential::Original, OffsetType::Byte)
784                .into_iter()
785                .map(|(s, _, tokens)| (
786                    s,
787                    tokens
788                        .as_ref()
789                        .map(|t| t.iter().map(|t| t.id).collect::<Vec<_>>())
790                ))
791                .collect::<Vec<_>>(),
792            vec![
793                ("[CLS]", Some(vec![2])),
794                (" My ", None),
795                ("name", Some(vec![1])),
796                (" is Anthony ", None),
797                ("[SEP]", Some(vec![3]))
798            ]
799        );
800    }
801
    #[test]
    fn options_use_cases() {
        // Is able to extract both normal and special tokens, with various options (lstrip, rstrip,
        // single_word, normalized)
        let model = ModelMock::new(&[]);
        let normalizer = Lowercase;
        let mut vocab = AddedVocabulary::new();

        // The model is empty, so the ids are handed out in order, starting from 0
        vocab.add_tokens(
            &[
                AddedToken::from("my", false).lstrip(true).rstrip(true),
                AddedToken::from("name", false),
                AddedToken::from("ony", false).single_word(true),
            ],
            &model,
            Some(&normalizer),
        );
        vocab.add_special_tokens(
            &[
                AddedToken::from("[CLS]", true),
                AddedToken::from("[SEP]", true),
            ],
            &model,
            Some(&normalizer),
        );

        let result =
            vocab.extract_and_normalize(Some(&normalizer), "[CLS] My name is Anthony [SEP]");

        assert_eq!(
            simplify_output(&result),
            vec![
                ("[CLS]", Some(vec![3])),
                // This one includes both spaces because of the lstrip & rstrip
                // And it matches because normalized == true
                (" my ", Some(vec![0])),
                ("name", Some(vec![1])),
                // `ony` is not extracted here thanks to single_word
                (" is anthony ", None),
                ("[SEP]", Some(vec![4])),
            ]
        );
    }
845
846    #[test]
847    fn empty_matches() {
848        let vocab = AddedVocabulary::new();
849        let matches = vocab.find_matches("", &vocab.split_trie);
850        assert_eq!(matches, vec![(None, (0, 0))]);
851    }
852
853    #[test]
854    fn test_single_word_is_correct() {
855        // Is able to extract both normal and special tokens, with various options (lstrip, rstrip,
856        // single_word, normalized)
857        let model = ModelMock::new(&[]);
858        let mut vocab = AddedVocabulary::new();
859        let normalizer = Lowercase;
860
861        vocab.add_tokens(
862            &[AddedToken::from("<mask>", false).single_word(true)],
863            &model,
864            Some(&normalizer),
865        );
866        // Left, in the middle, non single world left, non single word right, end of sentence valid
867        let result = vocab.extract_and_normalize(
868            Some(&normalizer),
869            "<mask> My name <mask> A<mask> <mask>ony <mask>",
870        );
871        assert_eq!(
872            simplify_output(&result),
873            vec![
874                ("<mask>", Some(vec![0])),
875                (" my name ", None),
876                ("<mask>", Some(vec![0])),
877                (" a<mask> <mask>ony ", None),
878                ("<mask>", Some(vec![0]))
879            ]
880        );
881    }
882
883    #[test]
884    fn test_single_word_is_unicode_correct() {
885        let model = ModelMock::new(&[]);
886        let mut vocab = AddedVocabulary::new();
887        let normalizer = Lowercase;
888
889        assert_eq!(vocab.len(), 0);
890
891        vocab.add_tokens(
892            &[AddedToken::from("<mask>", false).single_word(true)],
893            &model,
894            Some(&normalizer),
895        );
896        let result = vocab.extract_and_normalize(Some(&normalizer), "<mask>, <mask>- ◌̰<mask>");
897        assert_eq!(
898            simplify_output(&result),
899            vec![
900                // Punctuation is not word
901                ("<mask>", Some(vec![0])),
902                (", ", None),
903                // dash is not word
904                ("<mask>", Some(vec![0])),
905                // This is unicode combining mark character and is word: https://en.wikipedia.org/wiki/Combining_Diacritical_Marks
906                ("- ◌̰<mask>", None),
907            ]
908        );
909    }
910
911    #[test]
912    fn test_lstrip_unicode_space() {
913        let model = ModelMock::new(&[]);
914        let mut vocab = AddedVocabulary::new();
915        let normalizer = Lowercase;
916
917        vocab.add_tokens(
918            &[AddedToken::from("<mask>", false)
919                .lstrip(true)
920                .rstrip(true)
921                .single_word(true)],
922            &model,
923            Some(&normalizer),
924        );
925        let result = vocab
926            .extract_and_normalize(Some(&normalizer), "Hi <mask> there\t<mask>\t<mask>\u{2000}");
927        assert_eq!(
928            simplify_output(&result),
929            vec![
930                ("hi", None),
931                // Regular space
932                (" <mask> ", Some(vec![0])),
933                ("there", None),
934                // \t is a spacing character
935                ("\t<mask>\t", Some(vec![0])),
936                // Non overlapping
937                // \u{2000} is mongolian vowel separator: https://jkorpela.fi/chars/spaces.html
938                ("<mask>\u{2000}", Some(vec![0])),
939            ]
940        );
941    }
942
943    #[test]
944    fn test_encode_special_tokens() {
945        let model = ModelMock::new(&[]);
946        let mut vocab = AddedVocabulary::new();
947        let normalizer = Lowercase;
948
949        vocab.add_tokens(
950            &[
951                AddedToken::from("<mask>", true)
952                    .lstrip(true)
953                    .rstrip(true)
954                    .single_word(true),
955                AddedToken::from("ask>", false),
956                AddedToken::from("<pad>", true),
957            ],
958            &model,
959            Some(&normalizer),
960        );
961        vocab.set_encode_special_tokens(true);
962
963        let result = vocab.extract_and_normalize(
964            Some(&normalizer),
965            "Hi <mask> there\t<mask>\t<mask>\u{2000} <pad> <mask><pad><pad>",
966        );
967
968        assert_eq!(
969            simplify_output(&result),
970            vec![
971                ("hi <m", None),
972                ("ask>", Some(vec![1])),
973                (" there\t<m", None),
974                ("ask>", Some(vec![1])),
975                ("\t<m", None),
976                ("ask>", Some(vec![1])),
977                ("\u{2000} <pad> <m", None),
978                ("ask>", Some(vec![1])),
979                ("<pad><pad>", None)
980            ]
981        );
982
983        vocab.set_encode_special_tokens(false);
984
985        let result = vocab.extract_and_normalize(
986            Some(&normalizer),
987            "Hi <mask> there\t<mask>\t<mask>\u{2000} <pad> <mask><pad><pad>",
988        );
989        assert_eq!(
990            simplify_output(&result),
991            vec![
992                ("hi", None),
993                (" <mask> ", Some(vec![0])),
994                ("there", None),
995                ("\t<mask>\t", Some(vec![0])),
996                ("<mask>\u{2000} ", Some(vec![0])),
997                ("<pad>", Some(vec![2])),
998                (" <mask>", Some(vec![0])),
999                ("<pad>", Some(vec![2])),
1000                ("<pad>", Some(vec![2]))
1001            ]
1002        );
1003    }
1004    #[test]
1005    fn byte_level_normalizer() {
1006        // Is able to extract both normal and special tokens
1007        let model = ModelMock::new(&[]);
1008        let mut vocab = AddedVocabulary::new();
1009        let from = NormalizerWrapper::from(ByteLevelNormalizer::new());
1010        let normalizer: Option<&NormalizerWrapper> = Some(&from);
1011
1012        vocab.add_tokens(
1013            &[AddedToken::from("my", false), AddedToken::from("今", false)],
1014            &model,
1015            normalizer,
1016        );
1017        let result = vocab.extract_and_normalize(normalizer, "my今");
1018        assert_eq!(
1019            result
1020                .get_splits(OffsetReferential::Original, OffsetType::Byte)
1021                .into_iter()
1022                .map(|(s, _, tokens)| (
1023                    s,
1024                    tokens
1025                        .as_ref()
1026                        .map(|t| t.iter().map(|t| t.id).collect::<Vec<_>>())
1027                ))
1028                .collect::<Vec<_>>(),
1029            vec![("my", Some(vec![0])), ("ä»Ĭ", Some(vec![1])),]
1030        );
1031    }
1032}