use std::borrow::Cow;
use std::error::Error;
use std::fmt;
use std::iter::repeat;
use std::ops::Range;
use std::path::Path;

use rustc_hash::FxHashMap;

use crate::models::{
    Bpe, BpeError, BpeOptions, DecodeError, EncodeError, Model, WordPiece, merge_pairs_from_lines,
};
use crate::normalizers::{NormalizeError, Normalizer};
use crate::pre_tokenizers::{PreTokenizeError, PreTokenizer};
use crate::split::SliceExt;
use crate::{normalizers, pre_tokenizers};

mod json;

/// Input sequence(s) for [`Tokenizer::encode`].
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum EncoderInput<'a> {
    /// Encode a single sequence.
    Item(&'a str),

    /// Encode a pair of sequences (eg. a query and a context).
    Pair((&'a str, &'a str)),
}

impl<'a> From<&'a str> for EncoderInput<'a> {
    fn from(val: &'a str) -> EncoderInput<'a> {
        EncoderInput::Item(val)
    }
}

impl<'a> From<&'a String> for EncoderInput<'a> {
    fn from(val: &'a String) -> EncoderInput<'a> {
        EncoderInput::Item(val)
    }
}

impl<'a> From<(&'a str, &'a str)> for EncoderInput<'a> {
    fn from(val: (&'a str, &'a str)) -> EncoderInput<'a> {
        EncoderInput::Pair(val)
    }
}

/// Integer ID of a token in a model's vocabulary.
pub type TokenId = u32;

/// The output of [`Tokenizer::encode`]: the token IDs plus the information
/// needed to map tokens back to the source text.
#[derive(Debug)]
pub struct Encoded<'a> {
    input: EncoderInput<'a>,
    token_ids: Vec<TokenId>,

    /// Number of tokens in `token_ids` that belong to the first sequence,
    /// including any special tokens (eg. `[CLS]`, `[SEP]`).
    first_seq_tokens: usize,

    /// Byte offsets into the input text at which each token starts.
    token_offsets: Vec<usize>,
}

impl<'a> Encoded<'a> {
    fn new(
        input: EncoderInput<'a>,
        ids: Vec<TokenId>,
        offsets: Vec<usize>,
        first_seq_tokens: usize,
    ) -> Encoded<'a> {
        Encoded {
            input,
            token_ids: ids,
            token_offsets: offsets,
            first_seq_tokens,
        }
    }

    /// Return the sequence of token IDs.
    pub fn token_ids(&self) -> &[TokenId] {
        &self.token_ids
    }

    /// Consume this result and return the sequence of token IDs.
    pub fn into_token_ids(self) -> Vec<TokenId> {
        self.token_ids
    }

    /// Return the byte offset in the input text at which each token starts.
    pub fn token_offsets(&self) -> &[usize] {
        &self.token_offsets
    }

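    /// Return an iterator of token type IDs: `0` for tokens of the first
    /// sequence (including its special tokens) and `1` for tokens of the
    /// second sequence, if any.
    ///
    /// Illustrative sketch (ignored doctest) assuming `tokenizer` was built
    /// with `[CLS]`/`[SEP]` tokens, as in the tests below:
    ///
    /// ```ignore
    /// let encoded = tokenizer.encode(("This is", "a test"), None)?;
    /// // 0s for "[CLS] This is [SEP]", 1s for "a test [SEP]".
    /// let type_ids: Vec<usize> = encoded.token_type_ids().collect();
    /// ```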
    pub fn token_type_ids(&self) -> impl Iterator<Item = usize> {
        let second_seq_tokens = self.token_ids.len() - self.first_seq_tokens;
        repeat(0)
            .take(self.first_seq_tokens)
            .chain(repeat(1).take(second_seq_tokens))
    }

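    /// Return the slice of the source text covered by the tokens in `range`,
    /// or `None` if the range is out of bounds or does not fall on valid
    /// character boundaries.
    ///
    /// Illustrative sketch (ignored doctest) assuming `tokenizer` splits on
    /// whitespace and adds `[CLS]`/`[SEP]` tokens:
    ///
    /// ```ignore
    /// let encoded = tokenizer.encode("This is a test", None)?;
    /// // Text spanned by tokens 1..5, ie. the words between "[CLS]" and "[SEP]".
    /// let text: Option<&str> = encoded.text_for_token_range(1..5);
    /// ```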
    pub fn text_for_token_range(&self, range: Range<usize>) -> Option<&'a str> {
        let start_offset = self.token_offsets.get(range.start).copied()?;
        let input_len = match self.input {
            EncoderInput::Item(item) => item.len(),
            EncoderInput::Pair((query, context)) => query.len() + context.len(),
        };

        let end_offset = if range.end == self.token_offsets.len() {
            input_len
        } else {
            self.token_offsets.get(range.end).copied()?
        };

        match self.input {
            EncoderInput::Item(item) => item.get(start_offset..end_offset),
            EncoderInput::Pair((query, context)) => {
                if end_offset <= query.len() {
                    query.get(start_offset..end_offset)
                } else {
                    let offset = query.len();
                    context.get(start_offset - offset..end_offset - offset)
                }
            }
        }
    }
}

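/// Options for [`Tokenizer::encode`] and [`Tokenizer::encode_chunks`].
///
/// Illustrative sketch (ignored doctest); the values shown are arbitrary:
///
/// ```ignore
/// // Limit chunks to 512 tokens (including special tokens), with 32 tokens
/// // repeated between consecutive chunks.
/// let options = EncodeOptions {
///     max_chunk_len: Some(512),
///     overlap: 32,
/// };
/// ```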
#[derive(Clone, Default)]
pub struct EncodeOptions {
    /// Maximum number of tokens per chunk, including any special tokens
    /// (eg. `[CLS]`, `[SEP]`). If `None`, the input is encoded as one chunk.
    pub max_chunk_len: Option<usize>,

    /// Number of tokens to repeat between consecutive chunks when the input
    /// is split into multiple chunks.
    pub overlap: usize,
}

/// Errors that can occur when constructing a [`Tokenizer`] from JSON
/// configuration.
#[derive(Debug)]
pub enum FromJsonError {
    /// The configuration file could not be read.
    IoError(std::io::Error),
    /// The JSON data could not be parsed.
    JsonError(serde_json::Error),
    /// Constructing the normalizer from its configuration failed.
    NormalizerError(NormalizeError),
    /// Constructing the pre-tokenizer from its configuration failed.
    PreTokenizerError(PreTokenizeError),
    /// Constructing the BPE model failed.
    BpeError(BpeError),
    /// The model type is not supported.
    UnsupportedModel,
}

impl fmt::Display for FromJsonError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::IoError(err) => fmt::Display::fmt(err, f),
            Self::JsonError(err) => write!(f, "JSON error: {}", err),
            Self::NormalizerError(err) => write!(f, "failed to construct normalizer: {}", err),
            Self::PreTokenizerError(err) => write!(f, "failed to construct pre-tokenizer: {}", err),
            Self::BpeError(err) => write!(f, "BPE tokenizer error: {}", err),
            Self::UnsupportedModel => write!(f, "unsupported model type"),
        }
    }
}

impl From<NormalizeError> for FromJsonError {
    fn from(val: NormalizeError) -> Self {
        FromJsonError::NormalizerError(val)
    }
}

impl From<PreTokenizeError> for FromJsonError {
    fn from(val: PreTokenizeError) -> Self {
        FromJsonError::PreTokenizerError(val)
    }
}

impl Error for FromJsonError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::IoError(err) => Some(err),
            Self::JsonError(err) => Some(err),
            Self::NormalizerError(err) => Some(err),
            Self::PreTokenizerError(err) => Some(err),
            Self::BpeError(err) => Some(err),
            Self::UnsupportedModel => None,
        }
    }
}

/// Configuration for a [`Tokenizer`].
#[derive(Clone, Default)]
pub struct TokenizerOptions<'a> {
    /// Classification token (eg. `"[CLS]"`) added at the start of the encoded
    /// output, if the model uses one.
    pub cls_token: Option<&'a str>,

    /// Separator token (eg. `"[SEP]"`) added after each sequence in the
    /// encoded output, if the model uses one.
    pub sep_token: Option<&'a str>,
}

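/// Tokenizes text into a sequence of token IDs (and back), by combining an
/// optional [`Normalizer`], an optional [`PreTokenizer`] and a [`Model`].
///
/// Illustrative sketch (ignored doctest); the `tokenizer.json` path is a
/// placeholder and the resulting IDs depend on the loaded vocabulary:
///
/// ```ignore
/// let tokenizer = Tokenizer::from_file("path/to/tokenizer.json")?;
/// let encoded = tokenizer.encode("Some input text", None)?;
/// println!("token IDs: {:?}", encoded.token_ids());
/// ```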
pub struct Tokenizer {
    normalizer: Option<Box<dyn Normalizer>>,
    pre_tokenizer: Option<Box<dyn PreTokenizer>>,
    model: Box<dyn Model>,

    /// Token added at the start of each encoded chunk (eg. `[CLS]`).
    cls_token: Option<String>,

    /// Token added after each sequence in an encoded chunk (eg. `[SEP]`).
    sep_token: Option<String>,
}

impl Tokenizer {
    /// Create a tokenizer that encodes text using `model`.
    pub fn new<M: Model + 'static>(model: M, options: TokenizerOptions) -> Tokenizer {
        Tokenizer {
            model: Box::new(model),
            pre_tokenizer: None,
            normalizer: None,
            cls_token: options.cls_token.map(|t| t.to_string()),
            sep_token: options.sep_token.map(|t| t.to_string()),
        }
    }

    /// Set the normalizer applied to the text before pre-tokenization.
    pub fn with_normalizer(mut self, normalizer: Box<dyn Normalizer>) -> Self {
        self.normalizer = Some(normalizer);
        self
    }

    /// Set the pre-tokenizer used to split normalized text into pieces before
    /// each piece is encoded with the model.
    pub fn with_pre_tokenizer(mut self, pre_tokenizer: Box<dyn PreTokenizer>) -> Self {
        self.pre_tokenizer = Some(pre_tokenizer);
        self
    }

    /// Load a tokenizer from a JSON configuration file (eg. a Hugging Face
    /// `tokenizer.json` file).
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Tokenizer, FromJsonError> {
        let content = std::fs::read_to_string(path).map_err(FromJsonError::IoError)?;
        Self::from_json(&content)
    }

    /// Load a tokenizer from a JSON configuration string.
    pub fn from_json(json: &str) -> Result<Tokenizer, FromJsonError> {
        let tokenizer_json = json::from_json(json).map_err(FromJsonError::JsonError)?;
        Self::from_parsed_json(tokenizer_json)
    }

    fn from_parsed_json(json: json::Tokenizer) -> Result<Tokenizer, FromJsonError> {
        fn regex_pattern(pattern: &json::Pattern) -> Cow<'_, str> {
            match pattern {
                json::Pattern::Regex(pat) => Cow::Borrowed(pat.as_str()),
                json::Pattern::String(delim) => fancy_regex::escape(delim),
            }
        }

        fn create_normalizer(
            config: json::Normalizer,
        ) -> Result<Box<dyn Normalizer>, FromJsonError> {
            let normalizer: Box<dyn Normalizer> = match config {
                json::Normalizer::Bert(bert_norm) => {
                    Box::new(normalizers::Bert::new(normalizers::BertOptions {
                        lowercase: bert_norm.lowercase,
                        strip_accents: bert_norm.strip_accents.unwrap_or(bert_norm.lowercase),
                    }))
                }
                json::Normalizer::Lowercase => {
                    Box::new(normalizers::Bert::new(normalizers::BertOptions {
                        lowercase: true,
                        strip_accents: false,
                    }))
                }
                json::Normalizer::Nfc => Box::new(normalizers::Unicode::Nfc),
                json::Normalizer::Nfd => Box::new(normalizers::Unicode::Nfd),
                json::Normalizer::Nfkc => Box::new(normalizers::Unicode::Nfkc),
                json::Normalizer::Nfkd => Box::new(normalizers::Unicode::Nfkd),
                json::Normalizer::Replace(replace) => {
                    let pattern = regex_pattern(&replace.pattern);
                    Box::new(normalizers::Replace::new(&pattern, replace.content)?)
                }
                json::Normalizer::Sequence(seq) => {
                    let normalizers = seq
                        .normalizers
                        .into_iter()
                        .map(create_normalizer)
                        .collect::<Result<Vec<_>, _>>()?;
                    Box::new(normalizers::Sequence::from_vec(normalizers))
                }
            };
            Ok(normalizer)
        }

        let normalizer: Option<Box<dyn Normalizer>> =
            json.normalizer.map(create_normalizer).transpose()?;

        fn create_pre_tokenizer(
            config: json::PreTokenizer,
        ) -> Result<Box<dyn PreTokenizer>, FromJsonError> {
            let pre_tokenizer: Box<dyn PreTokenizer> = match config {
                json::PreTokenizer::Bert => Box::new(pre_tokenizers::Bert::new()),
                json::PreTokenizer::ByteLevel(byte_level) => {
                    if byte_level.use_regex {
                        Box::new(pre_tokenizers::Split::gpt2())
                    } else {
                        let noop_split = pre_tokenizers::SplitOptions {
                            pattern: r".*",
                            invert: true,
                            ..Default::default()
                        };
                        Box::new(pre_tokenizers::Split::new(noop_split)?)
                    }
                }
                json::PreTokenizer::Digits(digits) => {
                    Box::new(pre_tokenizers::Digits::new(digits.individual_digits))
                }
                json::PreTokenizer::Sequence(seq) => {
                    let pre_tokenizers = seq
                        .pretokenizers
                        .into_iter()
                        .map(create_pre_tokenizer)
                        .collect::<Result<Vec<_>, _>>()?;
                    Box::new(pre_tokenizers::Sequence::from_vec(pre_tokenizers))
                }
                json::PreTokenizer::Split(split) => {
                    let pattern = regex_pattern(&split.pattern);
                    let opts = pre_tokenizers::SplitOptions {
                        pattern: &pattern,
                        invert: split.invert,
                        delimiter: match split.behavior {
                            json::pre_tokenizers::SplitDelimiter::Isolated => {
                                pre_tokenizers::SplitDelimiterBehavior::Isolate
                            }
                            json::pre_tokenizers::SplitDelimiter::Removed => {
                                pre_tokenizers::SplitDelimiterBehavior::Remove
                            }
                        },
                    };
                    Box::new(pre_tokenizers::Split::new(opts)?)
                }
            };
            Ok(pre_tokenizer)
        }

        let pre_tokenizer: Option<Box<dyn PreTokenizer>> =
            json.pre_tokenizer.map(create_pre_tokenizer).transpose()?;

        let mut tokenizer = match json.model {
            json::Model::Bpe(model) => {
                let added_tokens: FxHashMap<TokenId, String> = json
                    .added_tokens
                    .as_ref()
                    .map(|tokens| {
                        tokens
                            .iter()
                            .map(|token| (token.id, token.content.clone()))
                            .collect()
                    })
                    .unwrap_or_default();
                let merges: Vec<(Cow<str>, Cow<str>)> = match model.merges {
                    json::models::MergeList::Legacy(lines) => merge_pairs_from_lines(&lines),
                    json::models::MergeList::Tuple(pairs) => {
                        pairs.into_iter().map(|(a, b)| (a.0, b.0)).collect()
                    }
                };
                let bpe_opts = BpeOptions {
                    merges: &merges,
                    vocab: Some(model.vocab),
                    added_tokens,
                    end_of_word_suffix: model.end_of_word_suffix,
                    ignore_merges: model.ignore_merges,
                };
                let model = Bpe::new(bpe_opts).map_err(FromJsonError::BpeError)?;

                let tokenizer = Tokenizer::new(
                    model,
                    TokenizerOptions {
                        cls_token: None,
                        sep_token: None,
                    },
                );

                Ok::<_, FromJsonError>(tokenizer)
            }
            json::Model::WordPiece(model) => {
                let model = WordPiece::from_vocab(model.vocab, Default::default());
                let tokenizer = Tokenizer::new(
                    model,
                    TokenizerOptions {
                        cls_token: Some("[CLS]"),
                        sep_token: Some("[SEP]"),
                    },
                );

                Ok::<_, FromJsonError>(tokenizer)
            }
        }?;

        if let Some(normalizer) = normalizer {
            tokenizer = tokenizer.with_normalizer(normalizer);
        }

        if let Some(pre_tokenizer) = pre_tokenizer {
            tokenizer = tokenizer.with_pre_tokenizer(pre_tokenizer);
        }

        Ok(tokenizer)
    }

    #[deprecated = "`encoder` was renamed to `model`"]
    pub fn encoder(&self) -> &dyn Model {
        self.model()
    }

    /// Return the model used to convert pieces of text to and from token IDs.
    pub fn model(&self) -> &dyn Model {
        self.model.as_ref()
    }

    /// Return the ID of `text` in the model's vocabulary, or an error if it
    /// is not a known token.
    pub fn get_token_id(&self, text: &str) -> Result<TokenId, TokenizerError> {
        self.model
            .get_token_id(text)
            .ok_or(TokenizerError::EncodeError(EncodeError::TokenIdNotFound(
                text.to_string(),
            )))
    }

    fn cls_token(&self) -> Result<Option<TokenId>, TokenizerError> {
        self.cls_token
            .as_deref()
            .map(|cls| self.get_token_id(cls))
            .transpose()
    }

    fn sep_token(&self) -> Result<Option<TokenId>, TokenizerError> {
        self.sep_token
            .as_deref()
            .map(|sep| self.get_token_id(sep))
            .transpose()
    }

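    /// Encode a single sequence or a pair of sequences into tokens, adding
    /// any configured special tokens (eg. `[CLS]`, `[SEP]`) around each
    /// sequence.
    ///
    /// If the input does not fit in `options.max_chunk_len` tokens, only the
    /// first chunk is returned; use [`Tokenizer::encode_chunks`] to get all
    /// chunks.
    ///
    /// Illustrative sketch (ignored doctest) assuming `tokenizer` was
    /// constructed elsewhere:
    ///
    /// ```ignore
    /// // Encode a single sequence.
    /// let encoded = tokenizer.encode("How are you?", None)?;
    ///
    /// // Encode a (query, context) pair, as used by eg. extractive QA models.
    /// let encoded = tokenizer.encode(("What is Rust?", "Rust is a language"), None)?;
    /// ```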
    pub fn encode<'a, I: Into<EncoderInput<'a>>>(
        &self,
        input: I,
        options: Option<EncodeOptions>,
    ) -> Result<Encoded<'a>, TokenizerError> {
        let options = options.unwrap_or_default();
        let input: EncoderInput = input.into();

        let cls_token = self.cls_token()?;
        let sep_token = self.sep_token()?;

        let chunks = self.encode_chunks(input, options)?;

        // If chunking produced no output (eg. the input was empty), return a
        // chunk containing only the special tokens.
        let chunk = chunks.into_iter().next().unwrap_or_else(|| {
            let mut tokens = Vec::new();
            let mut offsets = Vec::new();
            let mut first_seq_tokens = 0;

            if let Some(cls_token) = cls_token {
                tokens.push(cls_token);
                offsets.push(0);
                first_seq_tokens += 1;
            }
            if let Some(sep_token) = sep_token {
                tokens.push(sep_token);
                offsets.push(0);
                first_seq_tokens += 1;

                if matches!(input, EncoderInput::Pair(_)) {
                    tokens.push(sep_token);
                    offsets.push(0);
                }
            }

            Encoded::new(input, tokens, offsets, first_seq_tokens)
        });

        Ok(chunk)
    }

    /// Encode a single string, returning the token IDs and the byte offset
    /// into `text` (plus `start_offset`) at which each token starts.
    fn encode_str(
        &self,
        text: &str,
        start_offset: usize,
    ) -> Result<(Vec<TokenId>, Vec<usize>), TokenizerError> {
        let (normalized, offset_map) = match &self.normalizer {
            None => (text.to_string(), None),
            Some(normalizer) => {
                let (normalized_text, offsets) = normalizer.normalize(text)?;
                (normalized_text, Some(offsets))
            }
        };

        let chunks = self
            .pre_tokenizer
            .as_ref()
            .map(|pt| pt.pre_tokenize(&normalized))
            .transpose()
            .map_err(TokenizerError::PreTokenizeError)?
            .unwrap_or(Vec::from([normalized.as_str()]));

        // Map an offset into the normalized text back to an offset into the
        // original input text.
        let map_offset = |offset: usize| {
            if let Some(mappings) = &offset_map {
                mappings
                    .get(offset)
                    .copied()
                    .expect("invalid normalized offset")
            } else {
                offset
            }
        };

        let mut tokens = Vec::new();
        let mut offsets = Vec::new();

        for chunk in chunks {
            let base_offset = normalized
                .as_bytes()
                .subslice_offsets(chunk.as_bytes())
                .expect("should be a subslice")
                .start;
            self.model
                .encode_with_offsets(chunk, &mut |offset, token| {
                    offsets.push(start_offset + base_offset + map_offset(offset));
                    tokens.push(token);
                })?;
        }

        Ok((tokens, offsets))
    }

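    /// Encode the input into chunks of at most `options.max_chunk_len` tokens
    /// (including special tokens), with `options.overlap` tokens repeated
    /// between consecutive chunks. For pair inputs, the first sequence is
    /// included in every chunk and only the second sequence is split.
    ///
    /// Illustrative sketch (ignored doctest) assuming `tokenizer` was
    /// constructed elsewhere:
    ///
    /// ```ignore
    /// let options = EncodeOptions {
    ///     max_chunk_len: Some(512),
    ///     overlap: 64,
    /// };
    /// for chunk in tokenizer.encode_chunks("Some long document".into(), options)? {
    ///     println!("{} tokens", chunk.token_ids().len());
    /// }
    /// ```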
    pub fn encode_chunks<'a>(
        &self,
        input: EncoderInput<'a>,
        options: EncodeOptions,
    ) -> Result<Vec<Encoded<'a>>, TokenizerError> {
        let cls_token = self.cls_token()?;
        let sep_token = self.sep_token()?;

        let has_cls = cls_token.is_some() as usize;
        let has_sep = sep_token.is_some() as usize;

        // Number of special tokens added to each chunk in addition to the
        // content tokens.
        let non_content_tokens_per_chunk = has_cls
            + match input {
                EncoderInput::Item(_) => has_sep,
                EncoderInput::Pair(_) => has_sep * 2,
            };

        let mut tokens = Vec::new();
        let mut offsets = Vec::new();
        let (first_seq, second_seq) = match input {
            EncoderInput::Item(first) => (first, None),
            EncoderInput::Pair((first, second)) => (first, Some(second)),
        };

        let (first_seq_tokens, first_seq_offsets) = self.encode_str(first_seq, 0)?;
        tokens.extend(first_seq_tokens);
        offsets.extend(first_seq_offsets);
        let first_seq_tokens = tokens.len();

        if let Some(second_seq) = second_seq {
            let (second_seq_tokens, second_seq_offsets) =
                self.encode_str(second_seq, first_seq.len())?;
            tokens.extend(second_seq_tokens);
            offsets.extend(second_seq_offsets);
        }

        let max_tokens_per_chunk = options
            .max_chunk_len
            .unwrap_or(tokens.len() + non_content_tokens_per_chunk)
            .saturating_sub(non_content_tokens_per_chunk);

        if max_tokens_per_chunk == 0 {
            // There is no room for any content tokens in a chunk.
            return Ok(vec![]);
        }

        let mut chunks = Vec::new();

        match input {
            EncoderInput::Item(item) => {
                let all_offsets = &offsets;
                for (chunk_idx, (tokens_chunk, offsets_chunk)) in tokens
                    .chunks_with_overlap(max_tokens_per_chunk, options.overlap)
                    .zip(offsets.chunks_with_overlap(max_tokens_per_chunk, options.overlap))
                    .enumerate()
                {
                    let mut tokens = Vec::new();
                    let mut offsets = Vec::new();

                    if let Some(cls_token) = cls_token {
                        tokens.push(cls_token);
                        offsets.push(offsets_chunk.first().copied().unwrap());
                    }

                    tokens.extend_from_slice(tokens_chunk);
                    offsets.extend_from_slice(offsets_chunk);

                    if let Some(sep_token) = sep_token {
                        tokens.push(sep_token);
                    }

                    // Push an offset marking where the text covered by this
                    // chunk ends.
                    let chunk_start = chunk_idx * max_tokens_per_chunk;
                    offsets.push(
                        all_offsets
                            .get(chunk_start + offsets_chunk.len())
                            .copied()
                            .unwrap_or(item.len()),
                    );

                    let n_tokens = tokens.len();
                    chunks.push(Encoded::new(input, tokens, offsets, n_tokens));
                }
            }

            EncoderInput::Pair((first, second)) => {
                // Each chunk contains the first sequence (truncated to fit if
                // needed) followed by a slice of the second sequence.
                let (first_tokens, second_tokens) = tokens.split_at(first_seq_tokens);
                let (first_offsets, second_offsets) = offsets.split_at(first_seq_tokens);

                let first_len = first_tokens.len().min(max_tokens_per_chunk);
                let second_len = second_tokens.len().min(max_tokens_per_chunk - first_len);

                if second_len == 0 {
                    return Ok(vec![]);
                }

                for (chunk_idx, (tokens_chunk, offsets_chunk)) in second_tokens
                    .chunks_with_overlap(second_len, options.overlap)
                    .zip(second_offsets.chunks_with_overlap(second_len, options.overlap))
                    .enumerate()
                {
                    let mut tokens = Vec::new();
                    let mut offsets = Vec::new();

                    if let Some(cls_token) = cls_token {
                        tokens.push(cls_token);
                        offsets.push(0);
                    }

                    tokens.extend_from_slice(&first_tokens[..first_len]);
                    offsets.extend_from_slice(&first_offsets[..first_len]);

                    if let Some(sep_token) = sep_token {
                        tokens.push(sep_token);
                        offsets.push(first.len());
                    }

                    let first_seq_len = tokens.len();

                    tokens.extend_from_slice(tokens_chunk);
                    offsets.extend_from_slice(offsets_chunk);

                    if let Some(sep_token) = sep_token {
                        tokens.push(sep_token);
                    }
                    let chunk_start = chunk_idx * second_len;
                    offsets.push(
                        second_offsets
                            .get(chunk_start + offsets_chunk.len())
                            .copied()
                            .unwrap_or(first.len() + second.len()),
                    );

                    chunks.push(Encoded::new(input, tokens, offsets, first_seq_len));
                }
            }
        }

        Ok(chunks)
    }

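    /// Decode a sequence of token IDs back into text using the model.
    ///
    /// Illustrative sketch (ignored doctest) assuming `tokenizer` was
    /// constructed elsewhere:
    ///
    /// ```ignore
    /// let encoded = tokenizer.encode("Hello world", None)?;
    /// let text = tokenizer.decode(encoded.token_ids())?;
    /// ```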
    pub fn decode(&self, ids: &[TokenId]) -> Result<String, TokenizerError> {
        self.model.decode(ids).map_err(TokenizerError::DecodeError)
    }
}

/// Errors that can occur when encoding or decoding text with a [`Tokenizer`].
#[derive(Clone, Debug)]
pub enum TokenizerError {
    /// Normalizing the text failed.
    NormalizeError(NormalizeError),

    /// Splitting the text with the pre-tokenizer failed.
    PreTokenizeError(PreTokenizeError),

    /// Encoding the text with the model failed.
    EncodeError(EncodeError),

    /// Decoding token IDs with the model failed.
    DecodeError(DecodeError),
}

impl fmt::Display for TokenizerError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NormalizeError(err) => write!(f, "normalization error: {}", err),
            Self::PreTokenizeError(err) => write!(f, "pretokenization error: {}", err),
            Self::EncodeError(err) => write!(f, "encoding with model failed: {}", err),
            Self::DecodeError(err) => write!(f, "decoding failed: {}", err),
        }
    }
}

impl From<NormalizeError> for TokenizerError {
    fn from(err: NormalizeError) -> Self {
        TokenizerError::NormalizeError(err)
    }
}

impl From<EncodeError> for TokenizerError {
    fn from(err: EncodeError) -> Self {
        TokenizerError::EncodeError(err)
    }
}

impl Error for TokenizerError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::NormalizeError(e) => Some(e),
            Self::PreTokenizeError(e) => Some(e),
            Self::EncodeError(e) => Some(e),
            Self::DecodeError(e) => Some(e),
        }
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::error::Error;
    use std::fs::read_to_string;
    use std::ops::Range;
    use std::path::PathBuf;

    use rten_testing::TestCases;

    use super::{EncodeOptions, EncoderInput, TokenId, Tokenizer, TokenizerOptions, WordPiece};
    use crate::normalizers::Normalizer;
    use crate::{normalizers, pre_tokenizers};
    use serde_derive::Deserialize;

    fn make_wordpiece(vocab: &[&str]) -> WordPiece {
        let vocab: HashMap<_, _> = vocab
            .iter()
            .enumerate()
            .map(|(i, token)| (token.to_string(), i as u32))
            .collect();
        WordPiece::from_vocab(vocab, Default::default())
    }

    fn lowercase_normalizer() -> Box<dyn Normalizer> {
        Box::new(normalizers::Bert::new(normalizers::BertOptions {
            lowercase: true,
            ..Default::default()
        }))
    }

    #[test]
    fn test_encode_two_sequences() {
        let vocab = &[
            "[CLS]", "[SEP]", "[UNK]", "This", "is", "a", "test", "sequence",
        ];
        let model = make_wordpiece(vocab);
        let tokenizer = Tokenizer::new(
            model,
            TokenizerOptions {
                cls_token: Some("[CLS]"),
                sep_token: Some("[SEP]"),
            },
        )
        .with_pre_tokenizer(Box::new(pre_tokenizers::Bert::new()));

        let encoded = tokenizer
            .encode(("This is", "a test sequence"), None)
            .unwrap();
        assert_eq!(
            tokenizer.model().get_tokens(encoded.token_ids()).unwrap(),
            &[
                "[CLS]", "This", "is", "[SEP]", "a", "test", "sequence", "[SEP]"
            ]
        );

        let token_type_ids: Vec<_> = encoded.token_type_ids().collect();
        assert_eq!(token_type_ids, &[0, 0, 0, 0, 1, 1, 1, 1]);
    }

    #[test]
    fn test_text_for_token_range() {
        #[derive(Debug)]
        struct Case<'a> {
            input: EncoderInput<'a>,
            range: Range<usize>,
            expected: Option<&'a str>,
        }

        let vocab = &[
            "[CLS]", "[SEP]", "[UNK]", "This", "is", "a", "test", "sequence", "Word", "##Piece",
            "Piece", "of", "pie", ".", "!", "?", "Hey", "Hello",
        ];

        let cases = [
            Case {
                input: "This is a test sequence".into(),
                range: 4..6,
                expected: Some("test sequence"),
            },
            Case {
                input: "This is a test sequence".into(),
                range: 1..6,
                expected: Some("This is a test sequence"),
            },
            Case {
                input: ("This is a test sequence", "Hey Hello").into(),
                range: 4..6,
                expected: Some("test sequence"),
            },
            Case {
                input: "This is a test sequence".into(),
                range: 1..6,
                expected: Some("This is a test sequence"),
            },
            Case {
                input: ("This is a test sequence", "Hey Hello").into(),
                range: 8..9,
                expected: Some("Hello"),
            },
            Case {
                input: ("This is a test sequence", "Hey Hello").into(),
                range: 7..9,
                expected: Some("Hey Hello"),
            },
            Case {
                input: "This is a test sequence".into(),
                range: 4..8,
                expected: None,
            },
            Case {
                input: ("This is a test sequence", "Hey Hello").into(),
                range: 7..12,
                expected: None,
            },
            Case {
                input: "This is a test sequence".into(),
                range: 1..8,
                expected: None,
            },
            Case {
                input: "This is a test sequence".into(),
                range: 0..7,
                expected: Some("This is a test sequence"),
            },
        ];

        cases.test_each(|case| {
            let model = make_wordpiece(vocab);
            let tokenizer = Tokenizer::new(
                model,
                TokenizerOptions {
                    cls_token: Some("[CLS]"),
                    sep_token: Some("[SEP]"),
                },
            )
            .with_pre_tokenizer(Box::new(pre_tokenizers::Bert::new()));

            let encoded = tokenizer.encode(case.input, None).unwrap();
            let text = encoded.text_for_token_range(case.range.clone());
            assert_eq!(
                text, case.expected,
                "mismatch for input {:?} with range {:?}",
                case.input, case.range
            );
        })
    }

    #[test]
    fn test_encode_chunks_single_sequence() {
        let vocab = &[
            "[CLS]", "[SEP]", "[UNK]", "This", "is", "a", "test", "sequence",
        ];

        #[derive(Debug)]
        struct Case<'a> {
            text: &'a str,
            max_chunk_len: Option<usize>,
            overlap: usize,
            tokens: Vec<&'a [&'a str]>,
            use_cls_sep: bool,
            lowercase: bool,
        }

        let cases = [
            Case {
                text: "This is a test sequence",
                max_chunk_len: None,
                overlap: 0,
                tokens: vec![&["[CLS]", "This", "is", "a", "test", "sequence", "[SEP]"]],
                use_cls_sep: true,
                lowercase: false,
            },
            Case {
                text: "A TEST SEQUENCE",
                max_chunk_len: None,
                overlap: 0,
                tokens: vec![&["[CLS]", "a", "test", "sequence", "[SEP]"]],
                use_cls_sep: true,
                lowercase: true,
            },
            Case {
                text: "This is a test sequence",
                max_chunk_len: Some(5),
                overlap: 0,
                tokens: vec![
                    &["[CLS]", "This", "is", "a", "[SEP]"],
                    &["[CLS]", "test", "sequence", "[SEP]"],
                ],
                use_cls_sep: true,
                lowercase: false,
            },
            Case {
                text: "This is a test sequence",
                max_chunk_len: Some(4),
                overlap: 0,
                tokens: vec![
                    &["[CLS]", "This", "is", "[SEP]"],
                    &["[CLS]", "a", "test", "[SEP]"],
                    &["[CLS]", "sequence", "[SEP]"],
                ],
                use_cls_sep: true,
                lowercase: false,
            },
            Case {
                text: "This is a test sequence",
                max_chunk_len: Some(0),
                overlap: 0,
                tokens: vec![],
                use_cls_sep: true,
                lowercase: false,
            },
            Case {
                text: "This is a test sequence",
                max_chunk_len: Some(5),
                overlap: 2,
                tokens: vec![
                    &["[CLS]", "This", "is", "a", "[SEP]"],
                    &["[CLS]", "is", "a", "test", "[SEP]"],
                    &["[CLS]", "a", "test", "sequence", "[SEP]"],
                ],
                use_cls_sep: true,
                lowercase: false,
            },
            Case {
                text: "This is a test sequence",
                max_chunk_len: None,
                overlap: 0,
                tokens: vec![&["This", "is", "a", "test", "sequence"]],
                use_cls_sep: false,
                lowercase: false,
            },
        ];

        let model = make_wordpiece(vocab);

        cases.test_each(|case| {
            let Case {
                text,
                max_chunk_len,
                overlap,
                tokens,
                use_cls_sep,
                lowercase,
            } = case;

            let mut tokenizer = Tokenizer::new(
                model.clone(),
                TokenizerOptions {
                    cls_token: use_cls_sep.then_some("[CLS]"),
                    sep_token: use_cls_sep.then_some("[SEP]"),
                },
            )
            .with_pre_tokenizer(Box::new(pre_tokenizers::Bert::new()));

            if *lowercase {
                tokenizer = tokenizer.with_normalizer(lowercase_normalizer());
            }

            let options = EncodeOptions {
                max_chunk_len: *max_chunk_len,
                overlap: *overlap,
            };
            let chunks = tokenizer.encode_chunks((*text).into(), options).unwrap();
            let chunk_tokens: Vec<_> = chunks
                .into_iter()
                .map(|c| tokenizer.model().get_tokens(c.token_ids()).unwrap())
                .collect();
            assert_eq!(chunk_tokens, *tokens);
        })
    }

    #[test]
    fn test_encode_chunks_sequence_pair() {
        let vocab = &[
            "[CLS]",
            "[SEP]",
            "[UNK]",
            "What",
            "is",
            "Rust",
            "?",
            "a",
            "programming",
            "language",
            ".",
            "Its",
            "mascot",
            "is",
            "Ferris",
        ];

        let model = make_wordpiece(vocab);

        #[derive(Debug)]
        struct Case<'a> {
            query: &'a str,
            context: &'a str,
            max_chunk_len: Option<usize>,
            overlap: usize,
            tokens: Vec<&'a [&'a str]>,
            use_sep_cls: bool,
            lowercase: bool,
        }

        let cases = [
            Case {
                query: "What is Rust?",
                context: "Rust is a programming language",
                max_chunk_len: None,
                overlap: 0,
                use_sep_cls: true,
                tokens: vec![&[
                    "[CLS]",
                    "What",
                    "is",
                    "Rust",
                    "?",
                    "[SEP]",
                    "Rust",
                    "is",
                    "a",
                    "programming",
                    "language",
                    "[SEP]",
                ]],
                lowercase: false,
            },
            Case {
                query: "PROGRAMMING",
                context: "LANGUAGE",
                max_chunk_len: None,
                overlap: 0,
                use_sep_cls: true,
                tokens: vec![&["[CLS]", "programming", "[SEP]", "language", "[SEP]"]],
                lowercase: true,
            },
            Case {
                query: "What is Rust?",
                context: "Rust is a programming language. Its mascot is Ferris.",
                max_chunk_len: Some(13),
                overlap: 0,
                use_sep_cls: true,
                tokens: vec![
                    &[
                        "[CLS]",
                        "What",
                        "is",
                        "Rust",
                        "?",
                        "[SEP]",
                        "Rust",
                        "is",
                        "a",
                        "programming",
                        "language",
                        ".",
                        "[SEP]",
                    ],
                    &[
                        "[CLS]", "What", "is", "Rust", "?", "[SEP]", "Its", "mascot", "is",
                        "Ferris", ".", "[SEP]",
                    ],
                ],
                lowercase: false,
            },
            Case {
                query: "What is Rust?",
                context: "Rust is a programming language. Its mascot is Ferris",
                max_chunk_len: Some(13),
                overlap: 2,
                use_sep_cls: true,
                tokens: vec![
                    &[
                        "[CLS]",
                        "What",
                        "is",
                        "Rust",
                        "?",
                        "[SEP]",
                        "Rust",
                        "is",
                        "a",
                        "programming",
                        "language",
                        ".",
                        "[SEP]",
                    ],
                    &[
                        "[CLS]", "What", "is", "Rust", "?", "[SEP]", "language", ".", "Its",
                        "mascot", "is", "Ferris", "[SEP]",
                    ],
                ],
                lowercase: false,
            },
            Case {
                query: "What is Rust?",
                context: "Rust is a programming language",
                max_chunk_len: Some(7),
                overlap: 0,
                use_sep_cls: true,
                tokens: vec![],
                lowercase: false,
            },
            Case {
                query: "What is Rust?",
                context: "Rust is a programming language",
                max_chunk_len: None,
                overlap: 0,
                use_sep_cls: false,
                tokens: vec![&[
                    "What",
                    "is",
                    "Rust",
                    "?",
                    "Rust",
                    "is",
                    "a",
                    "programming",
                    "language",
                ]],
                lowercase: false,
            },
        ];

        cases.test_each(|case| {
            let Case {
                query,
                context,
                max_chunk_len,
                overlap,
                tokens,
                use_sep_cls,
                lowercase,
            } = case;

            let mut tokenizer = Tokenizer::new(
                model.clone(),
                TokenizerOptions {
                    cls_token: use_sep_cls.then_some("[CLS]"),
                    sep_token: use_sep_cls.then_some("[SEP]"),
                },
            )
            .with_pre_tokenizer(Box::new(pre_tokenizers::Bert::new()));

            if *lowercase {
                tokenizer = tokenizer.with_normalizer(lowercase_normalizer());
            }

            let options = EncodeOptions {
                max_chunk_len: *max_chunk_len,
                overlap: *overlap,
                ..Default::default()
            };
            let chunks = tokenizer
                .encode_chunks((*query, *context).into(), options)
                .unwrap();
            let chunk_tokens: Vec<_> = chunks
                .iter()
                .map(|c| tokenizer.model().get_tokens(c.token_ids()).unwrap())
                .collect();
            assert_eq!(chunk_tokens, *tokens);

            for (chunk, chunk_tokens) in chunks.iter().zip(chunk_tokens.into_iter()) {
                for (i, token) in chunk_tokens.into_iter().enumerate() {
                    if !token.starts_with("[") {
                        let text = chunk
                            .text_for_token_range(i..i + 1)
                            .map(|t| t.trim())
                            .unwrap();
                        let text = if *lowercase {
                            text.to_lowercase()
                        } else {
                            text.to_string()
                        };
                        assert_eq!(text, token);
                    }
                }
            }
        })
    }

    #[derive(Deserialize)]
    struct TokenizerJsonCase {
        text: String,
        token_ids: Vec<TokenId>,
    }

    #[derive(Deserialize)]
    struct TokenizerJsonTest<'a> {
        #[serde(borrow)]
        tokenizer: super::json::Tokenizer<'a>,
        cases: Vec<TokenizerJsonCase>,
    }

    fn read_test_json(path: &str) -> Result<String, Box<dyn Error>> {
        let mut abs_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        abs_path.push("test-data/tokenizer-json/");
        abs_path.push(path);
        let content = read_to_string(abs_path)?;
        Ok(content)
    }

    #[test]
    fn test_from_json() {
        let paths = ["wordpiece.json", "wordpiece-lower.json"];

        for path in paths.iter() {
            let json = read_test_json(path).unwrap();
            let config: TokenizerJsonTest = serde_json::from_str(&json).unwrap();

            let tokenizer = Tokenizer::from_parsed_json(config.tokenizer).unwrap();
            for case in config.cases {
                let encoded = tokenizer.encode(case.text.as_str(), None).unwrap();
                assert_eq!(encoded.token_ids(), case.token_ids);
            }
        }
    }
}