use super::{
    normalizer::Range, Model, NormalizedString, Normalizer, Offsets, PreTokenizedString, Token,
};
use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind};
use regex::Regex;
use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer};
use std::collections::{HashMap, HashSet};

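/// Represents a token added to the vocabulary on top of the model's own vocabulary.
/// The options control how the token is matched in the input text:
/// - `content`: the textual content of the token
/// - `single_word`: only match when surrounded by word boundaries
/// - `lstrip` / `rstrip`: swallow the whitespace on the left / right of the match
/// - `normalized`: match against the normalized text instead of the raw input
/// - `special`: mark the token as a special token (e.g. `[CLS]`, `[SEP]`)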
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AddedToken {
    pub content: String,
    pub single_word: bool,
    pub lstrip: bool,
    pub rstrip: bool,
    pub normalized: bool,
    pub special: bool,
}

impl AddedToken {
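    /// Build an `AddedToken` from its content, specifying whether it is a special token.
    /// Special tokens are not normalized by default; regular tokens are.
    ///
    /// A minimal usage sketch of the builder-style options:
    /// ```ignore
    /// let mask = AddedToken::from("<mask>", true).lstrip(true).single_word(true);
    /// ```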
    pub fn from<S: Into<String>>(content: S, special: bool) -> Self {
        Self {
            content: content.into(),
            normalized: !special,
            special,
            ..Default::default()
        }
    }
    /// Specify whether this token should only match whole single words, never sub-words.
    #[must_use]
    pub fn single_word(mut self, single_word: bool) -> Self {
        self.single_word = single_word;
        self
    }
    /// Specify whether this token should include the whitespace on its left side.
    #[must_use]
    pub fn lstrip(mut self, lstrip: bool) -> Self {
        self.lstrip = lstrip;
        self
    }
    /// Specify whether this token should include the whitespace on its right side.
    #[must_use]
    pub fn rstrip(mut self, rstrip: bool) -> Self {
        self.rstrip = rstrip;
        self
    }
    /// Specify whether this token should be matched against the normalized version of the
    /// input text.
    #[must_use]
    pub fn normalized(mut self, normalized: bool) -> Self {
        self.normalized = normalized;
        self
    }
    /// Specify whether this token is special.
    #[must_use]
    pub fn special(mut self, special: bool) -> Self {
        self.special = special;
        self
    }
}
impl Default for AddedToken {
    fn default() -> Self {
        Self {
            content: String::new(),
            single_word: false,
            lstrip: false,
            rstrip: false,
            normalized: true,
            special: false,
        }
    }
}
// An `AddedToken` is hashed by its content only.
impl std::hash::Hash for AddedToken {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.content.hash(state);
    }
}

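/// A `MatchingSet` pairs an Aho-Corasick automaton with the ids of the added tokens used as
/// patterns, in the same order as the patterns were inserted.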
type MatchingSet = (AhoCorasick, Vec<u32>);

lazy_static! {
    static ref STARTS_WITH_WORD: Regex = Regex::new(r"^\w").unwrap();
    static ref ENDS_WITH_WORD: Regex = Regex::new(r"\w$").unwrap();
    static ref RIGHTMOST_SPACE_AT_START: Regex = Regex::new(r"^\s*").unwrap();
    static ref LEFTMOST_SPACE_AT_END: Regex = Regex::new(r"\s*$").unwrap();
}

fn ends_with_word(sentence: &str) -> bool {
    ENDS_WITH_WORD.is_match(sentence)
}

fn starts_with_word(sentence: &str) -> bool {
    STARTS_WITH_WORD.is_match(sentence)
}

/// Byte index at which the trailing whitespace of `sentence` starts
/// (equals `sentence.len()` when there is none).
fn space_leftmost_at_end(sentence: &str) -> usize {
    if let Some(match_) = LEFTMOST_SPACE_AT_END.find(sentence) {
        match_.start()
    } else {
        sentence.len()
    }
}

/// Byte length of the leading whitespace of `sentence`.
fn space_rightmost_at_start(sentence: &str) -> usize {
    if let Some(match_) = RIGHTMOST_SPACE_AT_START.find(sentence) {
        match_.end()
    } else {
        0
    }
}
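/// The `AddedVocabulary` tracks tokens added on top of an existing `Model`. It assigns them
/// ids above the model's vocabulary and takes care of extracting them from the input before
/// the model runs: special and non-normalized tokens are matched on the raw text, while
/// normalized tokens are matched after the `Normalizer` has been applied.
///
/// Illustrative sketch of how it is typically driven (assuming some `model` implementing
/// `Model` and an optional `normalizer`):
/// ```ignore
/// let mut added = AddedVocabulary::new();
/// added.add_special_tokens(&[AddedToken::from("[CLS]", true)], &model, normalizer);
/// let pretokenized = added.extract_and_normalize(normalizer, "[CLS] hello");
/// ```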
#[derive(Clone, Debug)]
pub struct AddedVocabulary {
    /// Map from the content of an added token to its id.
    added_tokens_map: HashMap<String, u32>,
    /// Map from an id to the corresponding `AddedToken`.
    added_tokens_map_r: HashMap<u32, AddedToken>,

    /// Non-special added tokens, in the order the user provided them.
    added_tokens: Vec<AddedToken>,
    /// Special added tokens, in the order the user provided them.
    special_tokens: Vec<AddedToken>,

    /// Contents of all special tokens, for O(1) membership checks.
    special_tokens_set: HashSet<String>,

    /// Aho-Corasick automaton matching the non-normalized added tokens.
    split_trie: MatchingSet,
    /// Aho-Corasick automaton matching the normalized added tokens.
    split_normalized_trie: MatchingSet,

    /// Whether special tokens should be left for the model to tokenize instead of being
    /// extracted here.
    encode_special_tokens: bool,
}

impl AddedVocabulary {
    pub fn new() -> Self {
        let trie = AhoCorasickBuilder::new()
            .match_kind(MatchKind::LeftmostLongest)
            .build::<_, &&[u8]>([])
            .expect("The trie should build correctly");
        let normalized_trie = AhoCorasickBuilder::new()
            .match_kind(MatchKind::LeftmostLongest)
            .build::<_, &&[u8]>([])
            .expect("The normalized trie should build correctly");
        Self {
            added_tokens_map: HashMap::new(),
            added_tokens_map_r: HashMap::new(),
            added_tokens: vec![],
            special_tokens: vec![],
            special_tokens_set: HashSet::new(),
            split_trie: (trie, vec![]),
            split_normalized_trie: (normalized_trie, vec![]),
            encode_special_tokens: false,
        }
    }

    /// Size of the additional vocabulary
    #[allow(dead_code)] // Suppress the "method is never used" warning
    pub fn len(&self) -> usize {
        self.added_tokens_map.len()
    }

    /// Whether this additional vocabulary is empty
    pub fn is_empty(&self) -> bool {
        self.added_tokens_map.is_empty()
    }

    /// Get the additional vocabulary
    pub fn get_vocab(&self) -> &HashMap<String, u32> {
        &self.added_tokens_map
    }

    /// Get the additional vocabulary with the `AddedToken`s
    pub fn get_added_tokens_decoder(&self) -> &HashMap<u32, AddedToken> {
        &self.added_tokens_map_r
    }

    /// Get the id matching one of our tokens if it exists, falling back to the model
    pub fn token_to_id(&self, token: &str, model: &impl Model) -> Option<u32> {
        self.added_tokens_map
            .get(token)
            .copied()
            .or_else(|| model.token_to_id(token))
    }

    /// Get the token matching the given id if it exists
    #[deprecated(
        since = "0.19.0",
        note = "please use `added_vocabulary.simple_id_to_token(id).or_else(|| model.id_to_token(id))` instead"
    )]
    pub fn id_to_token(&self, id: u32, model: &impl Model) -> Option<String> {
        self.added_tokens_map_r
            .get(&id)
            .map(|t| t.content.clone())
            .or_else(|| model.id_to_token(id))
    }

    /// Get the content of the added token matching the given id, if any
    pub fn simple_id_to_token(&self, id: u32) -> Option<String> {
        self.added_tokens_map_r.get(&id).map(|t| t.content.clone())
    }

    /// Set whether special tokens should be left for the model to encode instead of being
    /// extracted here
    pub fn set_encode_special_tokens(&mut self, value: bool) {
        self.encode_special_tokens = value;
    }

    /// Get whether special tokens are left for the model to encode
    pub fn get_encode_special_tokens(&self) -> bool {
        self.encode_special_tokens
    }

    /// Check if a token is a special token
    pub fn is_special_token(&self, token: &str) -> bool {
        self.special_tokens_set.contains(token)
    }

    /// Add some special tokens to the vocabulary
    pub fn add_special_tokens<N: Normalizer>(
        &mut self,
        tokens: &[AddedToken],
        model: &impl Model,
        normalizer: Option<&N>,
    ) -> usize {
        self.add_tokens(tokens, model, normalizer)
    }

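    /// Add some tokens to the vocabulary.
    ///
    /// Special tokens are recorded in the special-token set first. Each token is then
    /// assigned an id: if it already exists in the model or in this vocabulary, the
    /// existing id is reused; otherwise it receives the next id above the current maximum
    /// (or the model's vocabulary size when nothing has been added yet). Returns the number
    /// of tokens effectively added (empty contents and exact duplicates are ignored).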
    pub fn add_tokens<N: Normalizer>(
        &mut self,
        tokens: &[AddedToken],
        model: &impl Model,
        normalizer: Option<&N>,
    ) -> usize {
        // Register the special tokens first
        for token in tokens {
            if token.special
                && !token.content.is_empty()
                && !self.special_tokens_set.contains(&token.content)
            {
                self.special_tokens.push(token.to_owned());
                self.special_tokens_set.insert(token.content.clone());
            }
        }

        let mut ignored = 0;
        for token in tokens {
            if token.content.is_empty() || self.added_tokens_map_r.values().any(|val| val == token)
            {
                ignored += 1;
                continue;
            }
            // Reuse the existing id if this token is already known, otherwise assign the
            // next id above the current maximum (or the model's vocabulary size)
            let new_id = if let Some(new_id) = self.token_to_id(&token.content, model) {
                new_id
            } else {
                self.added_tokens_map.values().cloned().max().map_or(
                    model.get_vocab_size() as u32,
                    |max| {
                        if (max >= model.get_vocab_size() as u32) || model.get_vocab_size() == 0 {
                            max + 1
                        } else {
                            model.get_vocab_size() as u32
                        }
                    },
                )
            };
            self.added_tokens_map
                .entry(token.content.clone())
                .and_modify(|old_id| *old_id = new_id)
                .or_insert_with(|| new_id);
            self.added_tokens_map_r
                .entry(new_id)
                .and_modify(|t| *t = token.clone())
                .or_insert_with(|| token.clone());
            if !self.special_tokens_set.contains(&token.content) {
                self.added_tokens.push(token.clone());
            }
        }

        self.refresh_added_tokens(model, normalizer);

        tokens.len() - ignored
    }

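    /// Rebuild the two Aho-Corasick automata used to split incoming text on added tokens.
    ///
    /// Tokens are partitioned by their `normalized` flag: non-normalized tokens are used as
    /// patterns on the raw input, while normalized tokens are first passed through the
    /// `Normalizer` (when one is provided) before being used as patterns.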
    fn refresh_added_tokens<N: Normalizer>(&mut self, model: &impl Model, normalizer: Option<&N>) {
        type TupleTokenId<'a> = (&'a AddedToken, u32);
        let (normalized, non_normalized): (Vec<TupleTokenId>, Vec<TupleTokenId>) = self
            .special_tokens
            .iter()
            .chain(self.added_tokens.iter())
            .map(|token| {
                (
                    token,
                    self.token_to_id(&token.content, model)
                        .expect("Missing additional token"),
                )
            })
            .partition(|(token, _)| token.normalized);

        let (tokens, ids): (Vec<&AddedToken>, Vec<u32>) = non_normalized.into_iter().unzip();
        let trie = AhoCorasickBuilder::new()
            .match_kind(MatchKind::LeftmostLongest)
            .build(tokens.iter().map(|token| &token.content))
            .expect("Failed to build trie when refreshing tokens");
        self.split_trie = (trie, ids);

        let (ntokens, nids): (Vec<&AddedToken>, Vec<u32>) = normalized.into_iter().unzip();
        let patterns: Vec<_> = ntokens
            .iter()
            .map(|token| {
                let mut content = NormalizedString::from(token.content.as_ref());
                if let Some(n) = normalizer {
                    n.normalize(&mut content).unwrap();
                }
                content
            })
            .collect();
        let normalized_trie = AhoCorasickBuilder::new()
            .match_kind(MatchKind::LeftmostLongest)
            .build(patterns.iter().map(|content| content.get()))
            .expect("Failed to build trie when refreshing tokens (normalized)");
        self.split_normalized_trie = (normalized_trie, nids);
    }

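    /// Find any `AddedToken` in the given sentence, using the provided `MatchingSet`.
    /// The whole sentence is returned as a sequence of byte-offset ranges, each tagged with
    /// the id of the added token it matched, or `None` for the text in between. The
    /// `single_word`, `lstrip` and `rstrip` options of each token are honored here.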
    fn find_matches(&self, sentence: &str, split_re: &MatchingSet) -> Vec<(Option<u32>, Offsets)> {
        if sentence.is_empty() {
            return vec![(None, (0, 0))];
        }

        let mut start_offset = 0;
        let mut splits = vec![];

        for mat in split_re.0.find_iter(sentence) {
            let mut start = mat.start();
            let mut stop = mat.end();
            let aho_id = mat.pattern();
            let id = split_re.1[aho_id];
            let added_token = &self.added_tokens_map_r.get(&id).unwrap();

            // When `encode_special_tokens` is on, special tokens are left in place for the
            // model to tokenize
            if self.encode_special_tokens && self.special_tokens_set.contains(&added_token.content)
            {
                continue;
            }

            if added_token.single_word {
                // Only accept the match if it is not glued to another word character
                let start_space = start == 0 || !ends_with_word(&sentence[..start]);
                let stop_space = stop == sentence.len() || !starts_with_word(&sentence[stop..]);

                if !stop_space || !start_space {
                    continue;
                }
            }
            if added_token.lstrip {
                // Extend the match to the left over any whitespace, without overlapping a
                // previous match
                let newstart = space_leftmost_at_end(&sentence[..start]);
                start = std::cmp::max(newstart, start_offset);
            }
            if added_token.rstrip {
                // Extend the match to the right over any whitespace
                stop += space_rightmost_at_start(&sentence[stop..]);
            }
            if start_offset < start {
                splits.push((None, (start_offset, start)));
            }
            splits.push((Some(id), (start, stop)));
            start_offset = stop;
        }

        let total_byte_len = sentence.len();
        if start_offset != total_byte_len {
            splits.push((None, (start_offset, total_byte_len)));
        }

        splits
    }

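    /// Split the given `NormalizedString` according to the matches found by `find_matches`,
    /// attaching a single `Token` to every slice that corresponds to an added token.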
    fn split_with_indices(
        &self,
        sentence: NormalizedString,
        split_re: &MatchingSet,
    ) -> Vec<(NormalizedString, Option<Vec<Token>>)> {
        self.find_matches(sentence.get(), split_re)
            .into_iter()
            .map(|(id, byte_offsets)| {
                let slice = sentence
                    .slice(Range::Normalized(byte_offsets.0..byte_offsets.1))
                    .expect("AddedVocabulary bad split");
                if let Some(id) = id {
                    let value = slice.get().to_owned();
                    let len = value.len();
                    (slice, Some(vec![Token::new(id, value, (0, len))]))
                } else {
                    (slice, None)
                }
            })
            .collect()
    }

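    /// Extract the added tokens from the given sequence, normalizing everything that is not
    /// part of them, and return the result as a `PreTokenizedString`.
    ///
    /// This is done in two passes: the non-normalized added tokens are first extracted from
    /// the raw text, then the remaining pieces are normalized and the normalized added
    /// tokens are extracted from the result.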
    pub fn extract_and_normalize<N: Normalizer>(
        &self,
        normalizer: Option<&N>,
        sequence: &str,
    ) -> PreTokenizedString {
        let mut pretokenized: PreTokenizedString = sequence.into();

        // 1. Split on the non-normalized (and special) added tokens
        pretokenized
            .split(|_, sequence| Ok(self.split_with_indices(sequence, &self.split_trie)))
            .expect("AddedVocabulary bad split");

        // 2. Normalize the remaining pieces and split on the normalized added tokens
        pretokenized
            .split(|_, mut sequence| {
                normalizer.map(|n| n.normalize(&mut sequence));
                Ok(self.split_with_indices(sequence, &self.split_normalized_trie))
            })
            .expect("AddedVocabulary bad split");

        pretokenized
    }
}

impl Default for AddedVocabulary {
    fn default() -> Self {
        Self::new()
    }
}

#[derive(Debug, Serialize, Deserialize)]
pub(super) struct AddedTokenWithId {
    /// The id assigned to this token
    pub id: u32,
    /// The token itself, flattened alongside the id during (de)serialization
    #[serde(flatten)]
    pub token: AddedToken,
}

impl Serialize for AddedVocabulary {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Serialize the added tokens as a sequence ordered by ascending id
        let mut added_tokens = self
            .added_tokens_map_r
            .iter()
            .map(|(id, token)| AddedTokenWithId {
                id: *id,
                token: token.clone(),
            })
            .collect::<Vec<_>>();
        added_tokens.sort_unstable_by_key(|o| o.id);

        let mut vocabulary = serializer.serialize_seq(Some(added_tokens.len()))?;
        for token in added_tokens {
            vocabulary.serialize_element(&token)?;
        }

        vocabulary.end()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::normalizers::byte_level::ByteLevel as ByteLevelNormalizer;
    use crate::normalizers::utils::Lowercase;
    use crate::normalizers::NormalizerWrapper;
    use crate::{OffsetReferential, OffsetType, Result, Token, Trainer};
    use std::path::{Path, PathBuf};

    #[derive(Serialize, Deserialize)]
    struct ModelMock {
        vocab: HashMap<String, u32>,
        vocab_r: HashMap<u32, String>,
    }
    impl ModelMock {
        pub fn new<I>(iter: I) -> Self
        where
            I: IntoIterator<Item = &'static (&'static str, u32)>,
        {
            let vocab: HashMap<String, u32> = iter
                .into_iter()
                .map(|&(tok, id)| (tok.to_string(), id))
                .collect();
            Self {
                vocab_r: vocab
                    .iter()
                    .map(|(tok, id)| (*id, tok.to_owned()))
                    .collect(),
                vocab,
            }
        }
    }

    fn simplify_output(result: &'_ PreTokenizedString) -> Vec<(&'_ str, Option<Vec<u32>>)> {
        result
            .get_splits(OffsetReferential::Original, OffsetType::Byte)
            .into_iter()
            .map(|(s, _, tokens)| {
                (
                    s,
                    tokens
                        .as_ref()
                        .map(|t| t.iter().map(|t| t.id).collect::<Vec<_>>()),
                )
            })
            .collect::<Vec<_>>()
    }

    struct TrainerMock;
    impl Trainer for TrainerMock {
        type Model = ModelMock;
        fn should_show_progress(&self) -> bool {
            true
        }
        fn train(&self, _model: &mut ModelMock) -> Result<Vec<AddedToken>> {
            unimplemented!()
        }
        fn feed<I, S, F>(&mut self, _iterator: I, _process: F) -> Result<()>
        where
            I: Iterator<Item = S> + Send,
            S: AsRef<str> + Send,
            F: Fn(&str) -> Result<Vec<String>> + Sync,
        {
            unimplemented!()
        }
    }

    impl Model for ModelMock {
        type Trainer = TrainerMock;

        fn tokenize(&self, _sequence: &str) -> Result<Vec<Token>> {
            unimplemented!()
        }
        fn token_to_id(&self, token: &str) -> Option<u32> {
            self.vocab.get(token).copied()
        }
        fn id_to_token(&self, id: u32) -> Option<String> {
            self.vocab_r.get(&id).cloned()
        }
        fn get_vocab(&self) -> HashMap<String, u32> {
            self.vocab.clone()
        }
        fn get_vocab_size(&self) -> usize {
            self.vocab.len()
        }
        fn save(&self, _folder: &Path, _name: Option<&str>) -> Result<Vec<PathBuf>> {
            unimplemented!()
        }
        fn get_trainer(&self) -> Self::Trainer {
            TrainerMock
        }
    }

    #[test]
    fn can_add_tokens() {
        let model = ModelMock::new(&[("test", 0), ("tost", 1)]);
        let mut vocab = AddedVocabulary::new();
        let normalizer: Option<&NormalizerWrapper> = None;

        assert_eq!(
            vocab.add_tokens(
                &[AddedToken::from("added_token_1", false)],
                &model,
                normalizer
            ),
            1
        );

        let vocab_len: usize = vocab.len();
        assert_eq!(vocab_len, 1);

        // The same token is not added multiple times
        assert_eq!(
            vocab.add_tokens(
                &[
                    AddedToken::from("added_token_2", false),
                    AddedToken::from("added_token_2", false)
                ],
                &model,
                normalizer
            ),
            1
        );
        assert_eq!(vocab.len(), 2);

        // Also adds tokens already covered by the model, reusing the model's id
        let added_token = AddedToken::from("test", false);
        assert_eq!(
            vocab.add_tokens(&[added_token.clone()], &model, normalizer),
            1
        );
        assert_eq!(vocab.len(), 3);

        assert_eq!(vocab.get_added_tokens_decoder()[&0], added_token);
    }

    #[test]
    fn can_add_special_tokens() {
        let model = ModelMock::new(&[("test", 0), ("tost", 1)]);
        let mut vocab = AddedVocabulary::new();
        let normalizer: Option<&NormalizerWrapper> = None;
        assert_eq!(
            vocab.add_special_tokens(
                &[AddedToken::from("added_token_1", true)],
                &model,
                normalizer
            ),
            1
        );
        assert_eq!(vocab.len(), 1);

        assert_eq!(
            vocab.add_special_tokens(
                &[
                    AddedToken::from("added_token_2", true),
                    AddedToken::from("added_token_2", true)
                ],
                &model,
                normalizer
            ),
            1
        );
        assert_eq!(vocab.len(), 2);

        assert_eq!(
            vocab.add_special_tokens(&[AddedToken::from("test", true)], &model, normalizer),
            1
        );
        assert_eq!(vocab.len(), 3);
        assert!(vocab.is_special_token("test"));
        assert_eq!(
            *vocab.get_added_tokens_decoder(),
            HashMap::from([
                (0, AddedToken::from("test", true)),
                (2, AddedToken::from("added_token_1", true)),
                (3, AddedToken::from("added_token_2", true)),
            ])
        );
        assert!(vocab.added_tokens_map.contains_key("test"));
        assert!(vocab.added_tokens_map_r.contains_key(&0));

        vocab.add_tokens(
            &[
                AddedToken::from("tost", true),
                AddedToken::from("another_two", false),
            ],
            &model,
            normalizer,
        );
        assert_eq!(vocab.len(), 5);
        assert_eq!(vocab.get_vocab()["another_two"], 4);

        // Adding the same token again as a special token keeps its id
        assert_eq!(
            vocab.add_special_tokens(&[AddedToken::from("another_two", true)], &model, normalizer),
            1
        );
        assert_eq!(vocab.len(), 5);
        assert_eq!(vocab.get_vocab()["another_two"], 4);

        // The fields of an `AddedToken` are public and can be mutated directly
        let mut token: AddedToken = AddedToken::from("Hey", false);
        token.content = "hey".to_string();
        assert_eq!(token.content, "hey");
        token.special = true;
        assert!(token.special);
    }

    #[test]
    fn can_extract_added_tokens() {
        // Is able to extract both normal and special tokens
        let model = ModelMock::new(&[]);
        let mut vocab = AddedVocabulary::new();
        let normalizer: Option<&NormalizerWrapper> = None;

        vocab.add_tokens(
            &[
                AddedToken::from("my", false),
                AddedToken::from("name", false),
            ],
            &model,
            normalizer,
        );
        vocab.add_special_tokens(
            &[
                AddedToken::from("[CLS]", true),
                AddedToken::from("[SEP]", true),
            ],
            &model,
            normalizer,
        );

        let result = vocab.extract_and_normalize(normalizer, "[CLS] My name is Anthony [SEP]");
        assert_eq!(
            result
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, _, tokens)| (
                    s,
                    tokens
                        .as_ref()
                        .map(|t| t.iter().map(|t| t.id).collect::<Vec<_>>())
                ))
                .collect::<Vec<_>>(),
            vec![
                ("[CLS]", Some(vec![2])),
                (" My ", None),
                ("name", Some(vec![1])),
                (" is Anthony ", None),
                ("[SEP]", Some(vec![3]))
            ]
        );
    }

    #[test]
    fn options_use_cases() {
        // Is able to extract both normal and special tokens, with the lstrip, rstrip and
        // single_word options
        let model = ModelMock::new(&[]);
        let normalizer = Lowercase;
        let mut vocab = AddedVocabulary::new();

        vocab.add_tokens(
            &[
                AddedToken::from("my", false).lstrip(true).rstrip(true),
                AddedToken::from("name", false),
                AddedToken::from("ony", false).single_word(true),
            ],
            &model,
            Some(&normalizer),
        );
        vocab.add_special_tokens(
            &[
                AddedToken::from("[CLS]", true),
                AddedToken::from("[SEP]", true),
            ],
            &model,
            Some(&normalizer),
        );

        let result =
            vocab.extract_and_normalize(Some(&normalizer), "[CLS] My name is Anthony [SEP]");

        assert_eq!(
            simplify_output(&result),
            vec![
                ("[CLS]", Some(vec![3])),
                // `my` grabs the surrounding whitespace because of lstrip/rstrip
                (" my ", Some(vec![0])),
                ("name", Some(vec![1])),
                // `ony` does not match inside "anthony" because it is single_word
                (" is anthony ", None),
                ("[SEP]", Some(vec![4])),
            ]
        );
    }

    #[test]
    fn empty_matches() {
        let vocab = AddedVocabulary::new();
        let matches = vocab.find_matches("", &vocab.split_trie);
        assert_eq!(matches, vec![(None, (0, 0))]);
    }

    #[test]
    fn test_single_word_is_correct() {
        let model = ModelMock::new(&[]);
        let mut vocab = AddedVocabulary::new();
        let normalizer = Lowercase;

        vocab.add_tokens(
            &[AddedToken::from("<mask>", false).single_word(true)],
            &model,
            Some(&normalizer),
        );
        let result = vocab.extract_and_normalize(
            Some(&normalizer),
            "<mask> My name <mask> A<mask> <mask>ony <mask>",
        );
        // `<mask>` is not extracted when glued to a word character, because it is single_word
        assert_eq!(
            simplify_output(&result),
            vec![
                ("<mask>", Some(vec![0])),
                (" my name ", None),
                ("<mask>", Some(vec![0])),
                (" a<mask> <mask>ony ", None),
                ("<mask>", Some(vec![0]))
            ]
        );
    }

    #[test]
    fn test_single_word_is_unicode_correct() {
        let model = ModelMock::new(&[]);
        let mut vocab = AddedVocabulary::new();
        let normalizer = Lowercase;

        assert_eq!(vocab.len(), 0);

        vocab.add_tokens(
            &[AddedToken::from("<mask>", false).single_word(true)],
            &model,
            Some(&normalizer),
        );
        let result = vocab.extract_and_normalize(Some(&normalizer), "<mask>, <mask>- ◌̰<mask>");
        assert_eq!(
            simplify_output(&result),
            vec![
                ("<mask>", Some(vec![0])),
                (", ", None),
                ("<mask>", Some(vec![0])),
                // The last `<mask>` is glued to the preceding combining mark, so it is not
                // extracted
                ("- ◌̰<mask>", None),
            ]
        );
    }

    #[test]
    fn test_lstrip_unicode_space() {
        let model = ModelMock::new(&[]);
        let mut vocab = AddedVocabulary::new();
        let normalizer = Lowercase;

        vocab.add_tokens(
            &[AddedToken::from("<mask>", false)
                .lstrip(true)
                .rstrip(true)
                .single_word(true)],
            &model,
            Some(&normalizer),
        );
        let result = vocab
            .extract_and_normalize(Some(&normalizer), "Hi <mask> there\t<mask>\t<mask>\u{2000}");
        assert_eq!(
            simplify_output(&result),
            vec![
                ("hi", None),
                (" <mask> ", Some(vec![0])),
                ("there", None),
                ("\t<mask>\t", Some(vec![0])),
                // lstrip/rstrip also swallow tabs and unicode whitespace such as \u{2000}
                ("<mask>\u{2000}", Some(vec![0])),
            ]
        );
    }

    #[test]
    fn test_encode_special_tokens() {
        let model = ModelMock::new(&[]);
        let mut vocab = AddedVocabulary::new();
        let normalizer = Lowercase;

        vocab.add_tokens(
            &[
                AddedToken::from("<mask>", true)
                    .lstrip(true)
                    .rstrip(true)
                    .single_word(true),
                AddedToken::from("ask>", false),
                AddedToken::from("<pad>", true),
            ],
            &model,
            Some(&normalizer),
        );
        vocab.set_encode_special_tokens(true);

        // With `encode_special_tokens` on, `<mask>` and `<pad>` are left in the text and
        // only the non-special `ask>` token is extracted
        let result = vocab.extract_and_normalize(
            Some(&normalizer),
            "Hi <mask> there\t<mask>\t<mask>\u{2000} <pad> <mask><pad><pad>",
        );

        assert_eq!(
            simplify_output(&result),
            vec![
                ("hi <m", None),
                ("ask>", Some(vec![1])),
                (" there\t<m", None),
                ("ask>", Some(vec![1])),
                ("\t<m", None),
                ("ask>", Some(vec![1])),
                ("\u{2000} <pad> <m", None),
                ("ask>", Some(vec![1])),
                ("<pad><pad>", None)
            ]
        );

        vocab.set_encode_special_tokens(false);

        // With `encode_special_tokens` off, the special tokens are extracted as usual
        let result = vocab.extract_and_normalize(
            Some(&normalizer),
            "Hi <mask> there\t<mask>\t<mask>\u{2000} <pad> <mask><pad><pad>",
        );
        assert_eq!(
            simplify_output(&result),
            vec![
                ("hi", None),
                (" <mask> ", Some(vec![0])),
                ("there", None),
                ("\t<mask>\t", Some(vec![0])),
                ("<mask>\u{2000} ", Some(vec![0])),
                ("<pad>", Some(vec![2])),
                (" <mask>", Some(vec![0])),
                ("<pad>", Some(vec![2])),
                ("<pad>", Some(vec![2]))
            ]
        );
    }

    #[test]
    fn byte_level_normalizer() {
        let model = ModelMock::new(&[]);
        let mut vocab = AddedVocabulary::new();
        let from = NormalizerWrapper::from(ByteLevelNormalizer::new());
        let normalizer: Option<&NormalizerWrapper> = Some(&from);

        vocab.add_tokens(
            &[AddedToken::from("my", false), AddedToken::from("今", false)],
            &model,
            normalizer,
        );
        let result = vocab.extract_and_normalize(normalizer, "my今");
        assert_eq!(
            result
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, _, tokens)| (
                    s,
                    tokens
                        .as_ref()
                        .map(|t| t.iter().map(|t| t.id).collect::<Vec<_>>())
                ))
                .collect::<Vec<_>>(),
            // "ä»Ĭ" is the byte-level representation of "今" after normalization
            vec![("my", Some(vec![0])), ("ä»Ĭ", Some(vec![1])),]
        );
    }
}