rten_text/
tokenizer.rs

1//! Defines the [`Tokenizer`] type that implements the tokenization pipeline.
2//!
3//! There are two ways to construct a tokenizer:
4//!
5//! 1. Load a preconfigured tokenizer from JSON, using [`Tokenizer::from_json`].
6//!    This crate supports a subset of the `tokenizer.json` format that
7//!    Hugging Face Tokenizers generates.
8//!
9//! 2. Manually configure a [`Tokenizer`] by creating a [`Model`] implementation,
10//!    such as [`WordPiece`], and then wrapping it in a tokenizer using
11//!    [`Tokenizer::new`].
12
13use std::borrow::Cow;
14use std::error::Error;
15use std::fmt;
16use std::iter::repeat;
17use std::ops::Range;
18use std::path::Path;
19
20use rustc_hash::FxHashMap;
21
22use crate::models::{
23    Bpe, BpeError, BpeOptions, DecodeError, EncodeError, Model, WordPiece, merge_pairs_from_lines,
24};
25use crate::normalizers::{NormalizeError, Normalizer};
26use crate::pre_tokenizers::{PreTokenizeError, PreTokenizer};
27use crate::split::SliceExt;
28use crate::{normalizers, pre_tokenizers};
29
30mod json;
31
/// Input sequences for [`Tokenizer::encode`].
///
/// Values are usually created via the `From` impls, from a `&str`, a
/// `&String` or a `(&str, &str)` tuple.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum EncoderInput<'a> {
    /// Input with a single sequence.
    Item(&'a str),

    /// Input with a pair of sequences. Used in tasks such as extractive
    /// question answering, where the sequence is `(query, context)`.
    Pair((&'a str, &'a str)),
}

/// Construct a tokenizer input with a single sequence.
impl<'a> From<&'a str> for EncoderInput<'a> {
    fn from(val: &'a str) -> EncoderInput<'a> {
        EncoderInput::Item(val)
    }
}

/// Construct a tokenizer input with a single sequence from a borrowed
/// `String`.
impl<'a> From<&'a String> for EncoderInput<'a> {
    fn from(val: &'a String) -> EncoderInput<'a> {
        EncoderInput::Item(val)
    }
}

/// Construct a tokenizer input with a pair of sequences.
impl<'a> From<(&'a str, &'a str)> for EncoderInput<'a> {
    fn from(val: (&'a str, &'a str)) -> EncoderInput<'a> {
        EncoderInput::Pair(val)
    }
}

/// Integer type used to represent token IDs.
pub type TokenId = u32;
65
66/// Output produced by a [`Tokenizer::encode`] implementation.
67///
68/// Use [`Encoded::token_ids`] to get the token IDs to feed to a model, and
69/// [`Encoded::text_for_token_range`] to map token ID ranges back to the
70/// corresponding input text.
71#[derive(Debug)]
72pub struct Encoded<'a> {
73    input: EncoderInput<'a>,
74    token_ids: Vec<TokenId>,
75
76    /// Number of tokens in `token_ids` that were generated from the first
77    /// sequence in the input. This includes the `[CLS]` and `[SEP]` tokens
78    /// which come before and after the sequence respectively.
79    first_seq_tokens: usize,
80
81    /// Offsets of text corresponding to tokens in the input string. When the
82    /// input contains two sentences, the offsets are relative to the string
83    /// that a particular input that a token comes from.
84    token_offsets: Vec<usize>,
85}
86
87impl<'a> Encoded<'a> {
88    fn new(
89        input: EncoderInput<'a>,
90        ids: Vec<TokenId>,
91        offsets: Vec<usize>,
92        first_seq_tokens: usize,
93    ) -> Encoded<'a> {
94        Encoded {
95            input,
96            token_ids: ids,
97            token_offsets: offsets,
98            first_seq_tokens,
99        }
100    }
101
102    /// Return the sequence of token IDs that the input was tokenized into.
103    pub fn token_ids(&self) -> &[TokenId] {
104        &self.token_ids
105    }
106
107    /// Consume `self` and return a list of token IDs.
108    ///
109    /// This is a convenient way to discard other information from the encoded
110    /// output and get the token IDs as an owned vector.
111    pub fn into_token_ids(self) -> Vec<TokenId> {
112        self.token_ids
113    }
114
115    /// Return the byte offsets of the start of each token in the input
116    /// sequence. If the input contained two sequences, the offsets are assigned
117    /// as if the two sequences were concatenated.
118    pub fn token_offsets(&self) -> &[usize] {
119        &self.token_offsets
120    }
121
122    /// Return an iterator of the inputs for the `token_type_ids` input field
123    /// in the model, if it has one.
124    pub fn token_type_ids(&self) -> impl Iterator<Item = usize> {
125        let second_seq_tokens = self.token_ids.len() - self.first_seq_tokens;
126        repeat(0)
127            .take(self.first_seq_tokens)
128            .chain(repeat(1).take(second_seq_tokens))
129    }
130
131    /// Return the text from the input sequence(s) that corresponds to a range
132    /// of token indices. If the input contained two sequences, the range must
133    /// lie entirely within one of them.
134    pub fn text_for_token_range(&self, range: Range<usize>) -> Option<&'a str> {
135        let start_offset = self.token_offsets.get(range.start).copied()?;
136        let input_len = match self.input {
137            EncoderInput::Item(item) => item.len(),
138            EncoderInput::Pair((query, context)) => query.len() + context.len(),
139        };
140
141        let end_offset = if range.end == self.token_offsets.len() {
142            input_len
143        } else {
144            self.token_offsets.get(range.end).copied()?
145        };
146
147        match self.input {
148            EncoderInput::Item(item) => item.get(start_offset..end_offset),
149            EncoderInput::Pair((query, context)) => {
150                if end_offset <= query.len() {
151                    query.get(start_offset..end_offset)
152                } else {
153                    let offset = query.len();
154                    context.get(start_offset - offset..end_offset - offset)
155                }
156            }
157        }
158    }
159}
160
/// Options that control chunking and truncation by [`Tokenizer::encode`] and
/// [`Tokenizer::encode_chunks`].
///
/// [`Tokenizer::encode`] keeps only the first chunk, so setting
/// `max_chunk_len` there effectively truncates the output.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct EncodeOptions {
    /// Maximum number of tokens in each chunk, including any special tokens
    /// (eg. `[CLS]`, `[SEP]`) that are added. If `None`, the whole input is
    /// placed in a single chunk.
    pub max_chunk_len: Option<usize>,

    /// The number of tokens that a chunk will overlap with the previous chunk.
    pub overlap: usize,
}
169    /// The number of tokens that a chunk will overlap with the previous chunk.
170    pub overlap: usize,
171}
172
173/// Errors returned by [`Tokenizer::from_json`].
174#[derive(Debug)]
175pub enum FromJsonError {
176    /// There was an error reading the JSON data from a file.
177    IoError(std::io::Error),
178    /// There was an error decoding the JSON data.
179    JsonError(serde_json::Error),
180    /// Could not instantiate a normalizer.
181    NormalizerError(NormalizeError),
182    /// Could not instantiate a pre-tokenizer.
183    PreTokenizerError(PreTokenizeError),
184    /// There was an error loading a BPE tokenizer.
185    BpeError(BpeError),
186    /// The model type isn't supported by this crate.
187    UnsupportedModel,
188}
189
190impl fmt::Display for FromJsonError {
191    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
192        match self {
193            Self::IoError(err) => fmt::Display::fmt(err, f),
194            Self::JsonError(err) => write!(f, "JSON error {}", err),
195            Self::NormalizerError(err) => write!(f, "failed to construct normalizer: {}", err),
196            Self::PreTokenizerError(err) => write!(f, "failed to construct pre-tokenizer: {}", err),
197            Self::BpeError(err) => write!(f, "BPE tokenizer error: {}", err),
198            Self::UnsupportedModel => write!(f, "unsupported model type"),
199        }
200    }
201}
202
203impl From<NormalizeError> for FromJsonError {
204    fn from(val: NormalizeError) -> Self {
205        FromJsonError::NormalizerError(val)
206    }
207}
208
209impl From<PreTokenizeError> for FromJsonError {
210    fn from(val: PreTokenizeError) -> Self {
211        FromJsonError::PreTokenizerError(val)
212    }
213}
214
215impl Error for FromJsonError {
216    fn source(&self) -> Option<&(dyn Error + 'static)> {
217        match self {
218            Self::IoError(err) => Some(err),
219            Self::JsonError(err) => Some(err),
220            Self::NormalizerError(err) => Some(err),
221            Self::PreTokenizerError(err) => Some(err),
222            Self::BpeError(err) => Some(err),
223            Self::UnsupportedModel => None,
224        }
225    }
226}
227
/// Configuration for a [`Tokenizer`].
///
/// Special tokens left as `None` are not added to the encoded output.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct TokenizerOptions<'a> {
    /// Token added at the start of the output. For BERT models, this is the
    /// `[CLS]` token.
    pub cls_token: Option<&'a str>,

    /// Token added after each encoded sequence in the output. For BERT models,
    /// this is the `[SEP]` token.
    pub sep_token: Option<&'a str>,
}
236    /// this is the `[SEP]` token.
237    pub sep_token: Option<&'a str>,
238}
239
240/// Tokenizes text inputs into sequences of token IDs that can be fed to a
241/// machine learning model.
242///
243/// `Tokenizer` wraps a [`Model`] which handles specific methods of encoding of
244/// individual sequences (eg. WordPiece, Byte Pair Encoding, Unigram) and adds
245/// common functionality such as injecting special tokens, splitting sequences
246/// into overlapping chunks and truncating long sequences.
247pub struct Tokenizer {
248    normalizer: Option<Box<dyn Normalizer>>,
249    pre_tokenizer: Option<Box<dyn PreTokenizer>>,
250    model: Box<dyn Model>,
251
252    /// Token added at start of output.
253    cls_token: Option<String>,
254
255    /// Token added after end of each sequence.
256    sep_token: Option<String>,
257}
258
259impl Tokenizer {
260    /// Create a new tokenizer which wraps the given model.
261    pub fn new<M: Model + 'static>(model: M, options: TokenizerOptions) -> Tokenizer {
262        Tokenizer {
263            model: Box::new(model),
264            pre_tokenizer: None,
265            normalizer: None,
266            cls_token: options.cls_token.map(|t| t.to_string()),
267            sep_token: options.sep_token.map(|t| t.to_string()),
268        }
269    }
270
271    /// Configure the normalizer used by this tokenizer.
272    pub fn with_normalizer(mut self, normalizer: Box<dyn Normalizer>) -> Self {
273        self.normalizer = Some(normalizer);
274        self
275    }
276
277    /// Configure the pre-tokenizer used by this tokenizer.
278    pub fn with_pre_tokenizer(mut self, pre_tokenizer: Box<dyn PreTokenizer>) -> Self {
279        self.pre_tokenizer = Some(pre_tokenizer);
280        self
281    }
282
283    /// Load a tokenizer from the contents of a Hugging Face `tokenizer.json`
284    /// file.
285    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Tokenizer, FromJsonError> {
286        let content = std::fs::read_to_string(path).map_err(FromJsonError::IoError)?;
287        Self::from_json(&content)
288    }
289
290    /// Load a tokenizer from the contents of a Hugging Face `tokenizer.json`
291    /// file.
292    pub fn from_json(json: &str) -> Result<Tokenizer, FromJsonError> {
293        let tokenizer_json = json::from_json(json).map_err(FromJsonError::JsonError)?;
294        Self::from_parsed_json(tokenizer_json)
295    }
296
297    fn from_parsed_json(json: json::Tokenizer) -> Result<Tokenizer, FromJsonError> {
298        fn regex_pattern(pattern: &json::Pattern) -> Cow<'_, str> {
299            match pattern {
300                json::Pattern::Regex(pat) => Cow::Borrowed(pat.as_str()),
301                json::Pattern::String(delim) => fancy_regex::escape(delim),
302            }
303        }
304
305        fn create_normalizer(
306            config: json::Normalizer,
307        ) -> Result<Box<dyn Normalizer>, FromJsonError> {
308            let normalizer: Box<dyn Normalizer> = match config {
309                json::Normalizer::Bert(bert_norm) => {
310                    Box::new(normalizers::Bert::new(normalizers::BertOptions {
311                        lowercase: bert_norm.lowercase,
312                        strip_accents: bert_norm.strip_accents.unwrap_or(bert_norm.lowercase),
313                    }))
314                }
315                json::Normalizer::Lowercase => {
316                    Box::new(normalizers::Bert::new(normalizers::BertOptions {
317                        lowercase: true,
318                        strip_accents: false,
319                    }))
320                }
321                json::Normalizer::Nfc => Box::new(normalizers::Unicode::Nfc),
322                json::Normalizer::Nfd => Box::new(normalizers::Unicode::Nfd),
323                json::Normalizer::Nfkc => Box::new(normalizers::Unicode::Nfkc),
324                json::Normalizer::Nfkd => Box::new(normalizers::Unicode::Nfkd),
325                json::Normalizer::Replace(replace) => {
326                    let pattern = regex_pattern(&replace.pattern);
327                    Box::new(normalizers::Replace::new(&pattern, replace.content)?)
328                }
329                json::Normalizer::Sequence(seq) => {
330                    let normalizers = seq
331                        .normalizers
332                        .into_iter()
333                        .map(create_normalizer)
334                        .collect::<Result<Vec<_>, _>>()?;
335                    Box::new(normalizers::Sequence::from_vec(normalizers))
336                }
337            };
338            Ok(normalizer)
339        }
340
341        let normalizer: Option<Box<dyn Normalizer>> =
342            json.normalizer.map(create_normalizer).transpose()?;
343
344        fn create_pre_tokenizer(
345            config: json::PreTokenizer,
346        ) -> Result<Box<dyn PreTokenizer>, FromJsonError> {
347            let pre_tokenizer: Box<dyn PreTokenizer> = match config {
348                json::PreTokenizer::Bert => Box::new(pre_tokenizers::Bert::new()),
349                json::PreTokenizer::ByteLevel(byte_level) => {
350                    if byte_level.use_regex {
351                        Box::new(pre_tokenizers::Split::gpt2())
352                    } else {
353                        let noop_split = pre_tokenizers::SplitOptions {
354                            pattern: r".*",
355                            invert: true,
356                            ..Default::default()
357                        };
358                        Box::new(pre_tokenizers::Split::new(noop_split)?)
359                    }
360                }
361                json::PreTokenizer::Digits(digits) => {
362                    Box::new(pre_tokenizers::Digits::new(digits.individual_digits))
363                }
364                json::PreTokenizer::Sequence(seq) => {
365                    let pre_tokenizers = seq
366                        .pretokenizers
367                        .into_iter()
368                        .map(create_pre_tokenizer)
369                        .collect::<Result<Vec<_>, _>>()?;
370                    Box::new(pre_tokenizers::Sequence::from_vec(pre_tokenizers))
371                }
372                json::PreTokenizer::Split(split) => {
373                    let pattern = regex_pattern(&split.pattern);
374                    let opts = pre_tokenizers::SplitOptions {
375                        pattern: &pattern,
376                        invert: split.invert,
377                        delimiter: match split.behavior {
378                            json::pre_tokenizers::SplitDelimiter::Isolated => {
379                                pre_tokenizers::SplitDelimiterBehavior::Isolate
380                            }
381                            json::pre_tokenizers::SplitDelimiter::Removed => {
382                                pre_tokenizers::SplitDelimiterBehavior::Remove
383                            }
384                        },
385                    };
386                    Box::new(pre_tokenizers::Split::new(opts)?)
387                }
388            };
389            Ok(pre_tokenizer)
390        }
391
392        let pre_tokenizer: Option<Box<dyn PreTokenizer>> =
393            json.pre_tokenizer.map(create_pre_tokenizer).transpose()?;
394
395        let mut tokenizer = match json.model {
396            json::Model::Bpe(model) => {
397                let added_tokens: FxHashMap<TokenId, String> = json
398                    .added_tokens
399                    .as_ref()
400                    .map(|tokens| {
401                        tokens
402                            .iter()
403                            .map(|token| (token.id, token.content.clone()))
404                            .collect()
405                    })
406                    .unwrap_or_default();
407                let merges: Vec<(Cow<str>, Cow<str>)> = match model.merges {
408                    json::models::MergeList::Legacy(lines) => merge_pairs_from_lines(&lines),
409                    json::models::MergeList::Tuple(pairs) => {
410                        pairs.into_iter().map(|(a, b)| (a.0, b.0)).collect()
411                    }
412                };
413                let bpe_opts = BpeOptions {
414                    merges: &merges,
415                    vocab: Some(model.vocab),
416                    added_tokens,
417                    end_of_word_suffix: model.end_of_word_suffix,
418                    ignore_merges: model.ignore_merges,
419                };
420                let model = Bpe::new(bpe_opts).map_err(FromJsonError::BpeError)?;
421
422                let tokenizer = Tokenizer::new(
423                    model,
424                    TokenizerOptions {
425                        cls_token: None,
426                        sep_token: None,
427                    },
428                );
429
430                Ok::<_, FromJsonError>(tokenizer)
431            }
432            json::Model::WordPiece(model) => {
433                let model = WordPiece::from_vocab(model.vocab, Default::default());
434                let tokenizer = Tokenizer::new(
435                    model,
436                    TokenizerOptions {
437                        cls_token: Some("[CLS]"),
438                        sep_token: Some("[SEP]"),
439                    },
440                );
441
442                Ok::<_, FromJsonError>(tokenizer)
443            }
444        }?;
445
446        if let Some(normalizer) = normalizer {
447            tokenizer = tokenizer.with_normalizer(normalizer);
448        }
449
450        if let Some(pre_tokenizer) = pre_tokenizer {
451            tokenizer = tokenizer.with_pre_tokenizer(pre_tokenizer);
452        }
453
454        Ok(tokenizer)
455    }
456
457    #[deprecated = "`encoder` was renamed to `model`"]
458    pub fn encoder(&self) -> &dyn Model {
459        self.model()
460    }
461
462    /// Return the model used to convert string pieces to token IDs.
463    pub fn model(&self) -> &dyn Model {
464        self.model.as_ref()
465    }
466
467    /// Return the ID of a token given its canonical string representation.
468    ///
469    /// This is usually used for looking up the IDs of special/added tokens.
470    ///
471    /// This wraps [`Model::get_token_id`] but returns a `Result` rather than
472    /// an `Option`, assuming the token is expected to be valid.
473    pub fn get_token_id(&self, text: &str) -> Result<TokenId, TokenizerError> {
474        self.model
475            .get_token_id(text)
476            .ok_or(TokenizerError::EncodeError(EncodeError::TokenIdNotFound(
477                text.to_string(),
478            )))
479    }
480
481    fn cls_token(&self) -> Result<Option<TokenId>, TokenizerError> {
482        self.cls_token
483            .as_deref()
484            .map(|cls| self.get_token_id(cls))
485            .transpose()
486    }
487
488    fn sep_token(&self) -> Result<Option<TokenId>, TokenizerError> {
489        self.sep_token
490            .as_deref()
491            .map(|sep| self.get_token_id(sep))
492            .transpose()
493    }
494
495    /// Encode one or two sequences into a sequence of tokens.
496    ///
497    /// The input can be an `&str` or tuple of `(&str, &str)`.
498    ///
499    /// In addition to token IDs, the result also includes information about
500    /// the corresponding offsets in the source text.
501    pub fn encode<'a, I: Into<EncoderInput<'a>>>(
502        &self,
503        input: I,
504        options: Option<EncodeOptions>,
505    ) -> Result<Encoded<'a>, TokenizerError> {
506        let options = options.unwrap_or_default();
507        let input: EncoderInput = input.into();
508
509        let cls_token = self.cls_token()?;
510        let sep_token = self.sep_token()?;
511
512        // To simplify the implementation, we tokenize the whole input and
513        // just discard all chunks except the first. This could be optimized
514        // to only generate one chunk.
515        let chunks = self.encode_chunks(input, options)?;
516
517        let chunk = chunks.into_iter().next().unwrap_or_else(|| {
518            // If the input is empty after tokenization, generate a single
519            // empty chunk.
520            let mut tokens = Vec::new();
521            let mut offsets = Vec::new();
522            let mut first_seq_tokens = 0;
523
524            if let Some(cls_token) = cls_token {
525                tokens.push(cls_token);
526                offsets.push(0);
527                first_seq_tokens += 1;
528            }
529            if let Some(sep_token) = sep_token {
530                tokens.push(sep_token);
531                offsets.push(0);
532                first_seq_tokens += 1;
533
534                if matches!(input, EncoderInput::Pair(_)) {
535                    tokens.push(sep_token);
536                    offsets.push(0);
537                }
538            }
539
540            Encoded::new(input, tokens, offsets, first_seq_tokens)
541        });
542
543        Ok(chunk)
544    }
545
546    /// Encode a single string into tokens and return a `(tokens, offsets)`
547    /// tuple.
548    fn encode_str(
549        &self,
550        text: &str,
551        start_offset: usize,
552    ) -> Result<(Vec<TokenId>, Vec<usize>), TokenizerError> {
553        let (normalized, offset_map) = match &self.normalizer {
554            None => (text.to_string(), None),
555            Some(normalizer) => {
556                let (normalized_text, offsets) = normalizer.normalize(text)?;
557                (normalized_text, Some(offsets))
558            }
559        };
560
561        let chunks = self
562            .pre_tokenizer
563            .as_ref()
564            .map(|pt| pt.pre_tokenize(&normalized))
565            .transpose()
566            .map_err(TokenizerError::PreTokenizeError)?
567            .unwrap_or(Vec::from([normalized.as_str()]));
568
569        // Map an offset into the normalized string into an offset in the source
570        // string.
571        let map_offset = |offset: usize| {
572            if let Some(mappings) = &offset_map {
573                mappings
574                    .get(offset)
575                    .copied()
576                    .expect("invalid normalized offset")
577            } else {
578                offset
579            }
580        };
581
582        let mut tokens = Vec::new();
583        let mut offsets = Vec::new();
584
585        for chunk in chunks {
586            let base_offset = normalized
587                .as_bytes()
588                .subslice_offsets(chunk.as_bytes())
589                .expect("should be a subslice")
590                .start;
591            self.model
592                .encode_with_offsets(chunk, &mut |offset, token| {
593                    offsets.push(start_offset + base_offset + map_offset(offset));
594                    tokens.push(token);
595                })?;
596        }
597
598        Ok((tokens, offsets))
599    }
600
601    /// Encode one or two sequences into a sequence of tokens.
602    ///
603    /// The output is split into chunks such that the number of tokens in
604    /// each chunk is less than the limit specified in [`EncodeOptions`].
605    pub fn encode_chunks<'a>(
606        &self,
607        input: EncoderInput<'a>,
608        options: EncodeOptions,
609    ) -> Result<Vec<Encoded<'a>>, TokenizerError> {
610        let cls_token = self.cls_token()?;
611        let sep_token = self.sep_token()?;
612
613        let has_cls = cls_token.is_some() as usize;
614        let has_sep = sep_token.is_some() as usize;
615
616        // Number of non-content tokens added to each chunk.
617        let non_content_tokens_per_chunk = has_cls
618            + match input {
619                EncoderInput::Item(_) => has_sep,     // [CLS] .. [SEP]
620                EncoderInput::Pair(_) => has_sep * 2, // [CLS] .. [SEP] .. [SEP]
621            };
622
623        // Encode the full input sequences.
624        let mut tokens = Vec::new();
625        let mut offsets = Vec::new();
626        let (first_seq, second_seq) = match input {
627            EncoderInput::Item(first) => (first, None),
628            EncoderInput::Pair((first, second)) => (first, Some(second)),
629        };
630
631        let (first_seq_tokens, first_seq_offsets) = self.encode_str(first_seq, 0)?;
632        tokens.extend(first_seq_tokens);
633        offsets.extend(first_seq_offsets);
634        let first_seq_tokens = tokens.len();
635
636        if let Some(second_seq) = second_seq {
637            let (second_seq_tokens, second_seq_offsets) =
638                self.encode_str(second_seq, first_seq.len())?;
639            tokens.extend(second_seq_tokens);
640            offsets.extend(second_seq_offsets);
641        }
642
643        let max_tokens_per_chunk = options
644            .max_chunk_len
645            .unwrap_or(tokens.len() + non_content_tokens_per_chunk)
646            .saturating_sub(non_content_tokens_per_chunk);
647
648        if max_tokens_per_chunk == 0 {
649            // We can't "consume" tokens from the input in each chunk, so just
650            // return an empty output.
651            return Ok(vec![]);
652        }
653
654        // Split into chunks.
655        let mut chunks = Vec::new();
656
657        match input {
658            // For single sequence inputs, create chunks with a maximum of
659            // `max_seq_len` tokens each.
660            EncoderInput::Item(item) => {
661                let all_offsets = &offsets;
662                for (chunk_idx, (tokens_chunk, offsets_chunk)) in tokens
663                    .chunks_with_overlap(max_tokens_per_chunk, options.overlap)
664                    .zip(offsets.chunks_with_overlap(max_tokens_per_chunk, options.overlap))
665                    .enumerate()
666                {
667                    let mut tokens = Vec::new();
668                    let mut offsets = Vec::new();
669
670                    if let Some(cls_token) = cls_token {
671                        tokens.push(cls_token);
672                        offsets.push(offsets_chunk.first().copied().unwrap());
673                    }
674
675                    tokens.extend_from_slice(tokens_chunk);
676                    offsets.extend_from_slice(offsets_chunk);
677
678                    if let Some(sep_token) = sep_token {
679                        tokens.push(sep_token);
680                    }
681
682                    // The offset for the final token is the offset of the first
683                    // token in the next chunk, or the input length if this
684                    // is the final chunk.
685                    let chunk_start = chunk_idx * max_tokens_per_chunk;
686                    offsets.push(
687                        all_offsets
688                            .get(chunk_start + offsets_chunk.len())
689                            .copied()
690                            .unwrap_or(item.len()),
691                    );
692
693                    let n_tokens = tokens.len();
694                    chunks.push(Encoded::new(input, tokens, offsets, n_tokens));
695                }
696            }
697
698            // For input sequence pairs, create chunks where the first part is
699            // the same for each chunk and has a maximum of `max_seq_len` tokens,
700            // and the second part contains chunks of the second sequence,
701            // taking up the remaining available space in the chunk.
702            EncoderInput::Pair((first, second)) => {
703                let (first_tokens, second_tokens) = tokens.split_at(first_seq_tokens);
704                let (first_offsets, second_offsets) = offsets.split_at(first_seq_tokens);
705
706                let first_len = first_tokens.len().min(max_tokens_per_chunk);
707                let second_len = second_tokens.len().min(max_tokens_per_chunk - first_len);
708
709                if second_len == 0 {
710                    // We can't "consume" tokens from the second sequence in
711                    // each chunk, so just return an empty output.
712                    return Ok(vec![]);
713                }
714
715                for (chunk_idx, (tokens_chunk, offsets_chunk)) in second_tokens
716                    .chunks_with_overlap(second_len, options.overlap)
717                    .zip(second_offsets.chunks_with_overlap(second_len, options.overlap))
718                    .enumerate()
719                {
720                    let mut tokens = Vec::new();
721                    let mut offsets = Vec::new();
722
723                    // Add the first sequence. This is the same for every chunk.
724                    if let Some(cls_token) = cls_token {
725                        tokens.push(cls_token);
726                        offsets.push(0);
727                    }
728
729                    tokens.extend_from_slice(&first_tokens[..first_len]);
730                    offsets.extend_from_slice(&first_offsets[..first_len]);
731
732                    if let Some(sep_token) = sep_token {
733                        tokens.push(sep_token);
734                        offsets.push(first.len());
735                    }
736
737                    let first_seq_len = tokens.len();
738
739                    // Add the second sequence, which changes in each chunk.
740                    tokens.extend_from_slice(tokens_chunk);
741                    offsets.extend_from_slice(offsets_chunk);
742
743                    // The offset for the final token is the offset of the first
744                    // token from the second sequence in the next chunk, or
745                    // the concatenated input length if this is the final chunk.
746                    if let Some(sep_token) = sep_token {
747                        tokens.push(sep_token);
748                    }
749                    let chunk_start = chunk_idx * second_len;
750                    offsets.push(
751                        second_offsets
752                            .get(chunk_start + offsets_chunk.len())
753                            .copied()
754                            .unwrap_or(first.len() + second.len()),
755                    );
756
757                    chunks.push(Encoded::new(input, tokens, offsets, first_seq_len));
758                }
759            }
760        }
761
762        Ok(chunks)
763    }
764
765    /// Decode a sequence of token IDs to a text string.
766    ///
767    /// For tokenizers which operate on byte sequences (eg. [`Bpe`]) this can
768    /// fail if the token IDs don't correspond to a complete UTF-8 sequence.
769    /// In that case the solution is to accumulate more token IDs and then
770    /// retry decoding.
771    ///
772    /// Special tokens are decoded into their canonical string representations
773    /// as returned by [`Model::get_token_str`].
774    pub fn decode(&self, ids: &[TokenId]) -> Result<String, TokenizerError> {
775        self.model.decode(ids).map_err(TokenizerError::DecodeError)
776    }
777}
778
779/// Error type returned when tokenizing a string.
780#[derive(Clone, Debug)]
781pub enum TokenizerError {
782    NormalizeError(NormalizeError),
783
784    /// An error occurred while performing pre-tokenization to split the input.
785    PreTokenizeError(PreTokenizeError),
786
787    /// Encoding of text pieces after pre-tokenization failed.
788    EncodeError(EncodeError),
789
790    /// Decoding token IDs into text failed.
791    DecodeError(DecodeError),
792}
793
794impl fmt::Display for TokenizerError {
795    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
796        match self {
797            Self::NormalizeError(err) => write!(f, "normalization error: {}", err),
798            Self::PreTokenizeError(err) => write!(f, "pretokenization error: {}", err),
799            Self::EncodeError(err) => write!(f, "encoding with model failed: {}", err),
800            Self::DecodeError(err) => write!(f, "decoding failed: {}", err),
801        }
802    }
803}
804
805impl From<NormalizeError> for TokenizerError {
806    fn from(err: NormalizeError) -> Self {
807        TokenizerError::NormalizeError(err)
808    }
809}
810
811impl From<EncodeError> for TokenizerError {
812    fn from(err: EncodeError) -> Self {
813        TokenizerError::EncodeError(err)
814    }
815}
816
817impl Error for TokenizerError {
818    fn source(&self) -> Option<&(dyn Error + 'static)> {
819        match self {
820            Self::NormalizeError(e) => Some(e),
821            Self::PreTokenizeError(e) => Some(e),
822            Self::EncodeError(e) => Some(e),
823            Self::DecodeError(e) => Some(e),
824        }
825    }
826}
827
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::error::Error;
    use std::fs::read_to_string;
    use std::ops::Range;
    use std::path::PathBuf;

    use rten_testing::TestCases;

    use super::{EncodeOptions, EncoderInput, TokenId, Tokenizer, TokenizerOptions, WordPiece};
    use crate::normalizers::Normalizer;
    use crate::{normalizers, pre_tokenizers};
    use serde_derive::Deserialize;

    /// Create a WordPiece model which maps each token in `vocab` to its index.
    fn make_wordpiece(vocab: &[&str]) -> WordPiece {
        let vocab: HashMap<_, _> = vocab
            .iter()
            .enumerate()
            .map(|(i, token)| (token.to_string(), i as u32))
            .collect();
        WordPiece::from_vocab(vocab, Default::default())
    }

    /// Create a BERT normalizer with lowercasing enabled.
    fn lowercase_normalizer() -> Box<dyn Normalizer> {
        Box::new(normalizers::Bert::new(normalizers::BertOptions {
            lowercase: true,
            ..Default::default()
        }))
    }

    // The tests below use the WordPiece model to exercise common Tokenizer
    // functionality. This is convenient as WordPiece is simple.

    #[test]
    fn test_encode_two_sequences() {
        let vocab = &[
            "[CLS]", "[SEP]", "[UNK]", "This", "is", "a", "test", "sequence",
        ];
        let model = make_wordpiece(vocab);
        let tokenizer = Tokenizer::new(
            model,
            TokenizerOptions {
                cls_token: Some("[CLS]"),
                sep_token: Some("[SEP]"),
            },
        )
        .with_pre_tokenizer(Box::new(pre_tokenizers::Bert::new()));

        // Two sequences, no subwords.
        let encoded = tokenizer
            .encode(("This is", "a test sequence"), None)
            .unwrap();
        assert_eq!(
            tokenizer.model().get_tokens(encoded.token_ids()).unwrap(),
            &[
                "[CLS]", "This", "is", "[SEP]", "a", "test", "sequence", "[SEP]"
            ]
        );

        let token_type_ids: Vec<_> = encoded.token_type_ids().collect();
        assert_eq!(token_type_ids, &[0, 0, 0, 0, 1, 1, 1, 1]);
    }

    #[test]
    fn test_text_for_token_range() {
        #[derive(Debug)]
        struct Case<'a> {
            input: EncoderInput<'a>,
            range: Range<usize>,
            expected: Option<&'a str>,
        }

        let vocab = &[
            "[CLS]", "[SEP]", "[UNK]", "This", "is", "a", "test", "sequence", "Word", "##Piece",
            "Piece", "of", "pie", ".", "!", "?", "Hey", "Hello",
        ];

        let cases = [
            // Part of a single sequence
            Case {
                input: "This is a test sequence".into(),
                range: 4..6,
                expected: Some("test sequence"),
            },
            // Whole of a single sequence
            Case {
                input: "This is a test sequence".into(),
                range: 1..6,
                expected: Some("This is a test sequence"),
            },
            // Part of first item in a pair
            Case {
                input: ("This is a test sequence", "Hey Hello").into(),
                range: 4..6,
                expected: Some("test sequence"),
            },
            // Whole of first item in a pair. Previously this case used a
            // single-sequence input, duplicating the case above and leaving
            // the pair path untested.
            Case {
                input: ("This is a test sequence", "Hey Hello").into(),
                range: 1..6,
                expected: Some("This is a test sequence"),
            },
            // Part of second item in a pair
            Case {
                input: ("This is a test sequence", "Hey Hello").into(),
                range: 8..9,
                expected: Some("Hello"),
            },
            // Whole of second item in a pair
            Case {
                input: ("This is a test sequence", "Hey Hello").into(),
                range: 7..9,
                expected: Some("Hey Hello"),
            },
            // Out of bounds range for a single sequence
            Case {
                input: "This is a test sequence".into(),
                range: 4..8,
                expected: None,
            },
            // Out of bounds range for a pair
            Case {
                input: ("This is a test sequence", "Hey Hello").into(),
                range: 7..12,
                expected: None,
            },
            // Range that spans first and second sequences in a pair.
            //
            // NOTE(review): the input here is a single sequence, so this case
            // currently only exercises an out-of-bounds range (the encoding
            // has 7 tokens). Switching to a pair input depends on how
            // `text_for_token_range` treats ranges that cross the separator.
            // Confirm that behavior before changing the input.
            Case {
                input: "This is a test sequence".into(),
                range: 1..8,
                expected: None,
            },
            // Range that intersects special tokens
            Case {
                input: "This is a test sequence".into(),
                range: 0..7,
                expected: Some("This is a test sequence"),
            },
        ];

        cases.test_each(|case| {
            let model = make_wordpiece(vocab);
            let tokenizer = Tokenizer::new(
                model,
                TokenizerOptions {
                    cls_token: Some("[CLS]"),
                    sep_token: Some("[SEP]"),
                },
            )
            .with_pre_tokenizer(Box::new(pre_tokenizers::Bert::new()));

            let encoded = tokenizer.encode(case.input, None).unwrap();
            let text = encoded.text_for_token_range(case.range.clone());
            assert_eq!(
                text, case.expected,
                "mismatch for input {:?} with range {:?}",
                case.input, case.range
            );
        })
    }

    #[test]
    fn test_encode_chunks_single_sequence() {
        let vocab = &[
            "[CLS]", "[SEP]", "[UNK]", "This", "is", "a", "test", "sequence",
        ];

        #[derive(Debug)]
        struct Case<'a> {
            text: &'a str,
            max_chunk_len: Option<usize>,
            overlap: usize,
            tokens: Vec<&'a [&'a str]>,
            use_cls_sep: bool,
            lowercase: bool,
        }

        let cases = [
            // Unbounded chunk size
            Case {
                text: "This is a test sequence",
                max_chunk_len: None,
                overlap: 0,
                tokens: vec![&["[CLS]", "This", "is", "a", "test", "sequence", "[SEP]"]],
                use_cls_sep: true,
                lowercase: false,
            },
            // Encode with a normalizer
            Case {
                text: "A TEST SEQUENCE",
                max_chunk_len: None,
                overlap: 0,
                tokens: vec![&["[CLS]", "a", "test", "sequence", "[SEP]"]],
                use_cls_sep: true,
                lowercase: true,
            },
            // Two chunks
            Case {
                text: "This is a test sequence",
                max_chunk_len: Some(5),
                overlap: 0,
                tokens: vec![
                    &["[CLS]", "This", "is", "a", "[SEP]"],
                    &["[CLS]", "test", "sequence", "[SEP]"],
                ],
                use_cls_sep: true,
                lowercase: false,
            },
            // Three chunks
            Case {
                text: "This is a test sequence",
                max_chunk_len: Some(4),
                overlap: 0,
                tokens: vec![
                    &["[CLS]", "This", "is", "[SEP]"],
                    &["[CLS]", "a", "test", "[SEP]"],
                    &["[CLS]", "sequence", "[SEP]"],
                ],
                use_cls_sep: true,
                lowercase: false,
            },
            // Chunk size that is small enough that there is no room for
            // any content tokens in each chunk.
            Case {
                text: "This is a test sequence",
                max_chunk_len: Some(0),
                overlap: 0,
                tokens: vec![],
                use_cls_sep: true,
                lowercase: false,
            },
            // Overlap between chunks
            Case {
                text: "This is a test sequence",
                max_chunk_len: Some(5),
                overlap: 2,
                tokens: vec![
                    &["[CLS]", "This", "is", "a", "[SEP]"],
                    &["[CLS]", "is", "a", "test", "[SEP]"],
                    &["[CLS]", "a", "test", "sequence", "[SEP]"],
                ],
                use_cls_sep: true,
                lowercase: false,
            },
            // No special tokens
            Case {
                text: "This is a test sequence",
                max_chunk_len: None,
                overlap: 0,
                tokens: vec![&["This", "is", "a", "test", "sequence"]],
                use_cls_sep: false,
                lowercase: false,
            },
        ];

        let model = make_wordpiece(vocab);

        cases.test_each(|case| {
            let Case {
                text,
                max_chunk_len,
                overlap,
                tokens,
                use_cls_sep,
                lowercase,
            } = case;

            let mut tokenizer = Tokenizer::new(
                model.clone(),
                TokenizerOptions {
                    cls_token: use_cls_sep.then_some("[CLS]"),
                    sep_token: use_cls_sep.then_some("[SEP]"),
                },
            )
            .with_pre_tokenizer(Box::new(pre_tokenizers::Bert::new()));

            if *lowercase {
                tokenizer = tokenizer.with_normalizer(lowercase_normalizer());
            }

            let options = EncodeOptions {
                max_chunk_len: *max_chunk_len,
                overlap: *overlap,
            };
            let chunks = tokenizer.encode_chunks((*text).into(), options).unwrap();
            let chunk_tokens: Vec<_> = chunks
                .into_iter()
                .map(|c| tokenizer.model().get_tokens(c.token_ids()).unwrap())
                .collect();
            assert_eq!(chunk_tokens, *tokens);
        })
    }

    #[test]
    fn test_encode_chunks_sequence_pair() {
        let vocab = &[
            "[CLS]",
            "[SEP]",
            "[UNK]",
            "What",
            "is",
            "Rust",
            "?",
            "a",
            "programming",
            "language",
            ".",
            "Its",
            "mascot",
            "Ferris",
        ];

        let model = make_wordpiece(vocab);

        #[derive(Debug)]
        struct Case<'a> {
            query: &'a str,
            context: &'a str,
            max_chunk_len: Option<usize>,
            overlap: usize,
            tokens: Vec<&'a [&'a str]>,
            use_sep_cls: bool,
            lowercase: bool,
        }

        let cases = [
            // Unbounded chunk size
            Case {
                query: "What is Rust?",
                context: "Rust is a programming language",
                max_chunk_len: None,
                overlap: 0,
                use_sep_cls: true,
                tokens: vec![&[
                    "[CLS]",
                    "What",
                    "is",
                    "Rust",
                    "?",
                    "[SEP]",
                    "Rust",
                    "is",
                    "a",
                    "programming",
                    "language",
                    "[SEP]",
                ]],
                lowercase: false,
            },
            // Apply normalization to both sequences
            Case {
                query: "PROGRAMMING",
                context: "LANGUAGE",
                max_chunk_len: None,
                overlap: 0,
                use_sep_cls: true,
                tokens: vec![&["[CLS]", "programming", "[SEP]", "language", "[SEP]"]],
                lowercase: true,
            },
            // Multiple chunks, no overlap
            Case {
                query: "What is Rust?",
                context: "Rust is a programming language. Its mascot is Ferris.",
                max_chunk_len: Some(13),
                overlap: 0,
                use_sep_cls: true,
                tokens: vec![
                    &[
                        "[CLS]",
                        "What",
                        "is",
                        "Rust",
                        "?",
                        "[SEP]",
                        "Rust",
                        "is",
                        "a",
                        "programming",
                        "language",
                        ".",
                        "[SEP]",
                    ],
                    &[
                        "[CLS]", "What", "is", "Rust", "?", "[SEP]", "Its", "mascot", "is",
                        "Ferris", ".", "[SEP]",
                    ],
                ],
                lowercase: false,
            },
            // Multiple chunks with overlap
            Case {
                query: "What is Rust?",
                context: "Rust is a programming language. Its mascot is Ferris",
                max_chunk_len: Some(13),
                overlap: 2,
                use_sep_cls: true,
                tokens: vec![
                    &[
                        "[CLS]",
                        "What",
                        "is",
                        "Rust",
                        "?",
                        "[SEP]",
                        "Rust",
                        "is",
                        "a",
                        "programming",
                        "language",
                        ".",
                        "[SEP]",
                    ],
                    &[
                        "[CLS]", "What", "is", "Rust", "?", "[SEP]", "language", ".", "Its",
                        "mascot", "is", "Ferris", "[SEP]",
                    ],
                ],
                lowercase: false,
            },
            // Chunk size too small for any tokens from the second sequence
            Case {
                query: "What is Rust?",
                context: "Rust is a programming language",
                max_chunk_len: Some(7), // Tokens in query + special tokens (3)
                overlap: 0,
                use_sep_cls: true,
                tokens: vec![],
                lowercase: false,
            },
            // No special tokens
            Case {
                query: "What is Rust?",
                context: "Rust is a programming language",
                max_chunk_len: None,
                overlap: 0,
                use_sep_cls: false,
                tokens: vec![&[
                    "What",
                    "is",
                    "Rust",
                    "?",
                    "Rust",
                    "is",
                    "a",
                    "programming",
                    "language",
                ]],
                lowercase: false,
            },
        ];

        cases.test_each(|case| {
            let Case {
                query,
                context,
                max_chunk_len,
                overlap,
                tokens,
                use_sep_cls,
                lowercase,
            } = case;

            let mut tokenizer = Tokenizer::new(
                model.clone(),
                TokenizerOptions {
                    cls_token: use_sep_cls.then_some("[CLS]"),
                    sep_token: use_sep_cls.then_some("[SEP]"),
                },
            )
            .with_pre_tokenizer(Box::new(pre_tokenizers::Bert::new()));

            if *lowercase {
                tokenizer = tokenizer.with_normalizer(lowercase_normalizer());
            }

            // All fields are specified, so no `..Default::default()` base is
            // needed (it would trigger clippy's `needless_update`).
            let options = EncodeOptions {
                max_chunk_len: *max_chunk_len,
                overlap: *overlap,
            };
            let chunks = tokenizer
                .encode_chunks((*query, *context).into(), options)
                .unwrap();
            let chunk_tokens: Vec<_> = chunks
                .iter()
                .map(|c| tokenizer.model().get_tokens(c.token_ids()).unwrap())
                .collect();
            assert_eq!(chunk_tokens, *tokens);

            // Check that the generated offsets are correct. None of the tokens
            // are subwords, and the only normalization used is lowercasing, so
            // the source text for every non-special token (lowercased when the
            // normalizer is enabled) should equal the token's canonical string.
            for (chunk, chunk_tokens) in chunks.iter().zip(chunk_tokens) {
                for (i, token) in chunk_tokens.into_iter().enumerate() {
                    if !token.starts_with('[') {
                        let text = chunk
                            .text_for_token_range(i..i + 1)
                            .map(|t| t.trim())
                            .unwrap();
                        let text = if *lowercase {
                            text.to_lowercase()
                        } else {
                            text.to_string()
                        };
                        assert_eq!(text, token);
                    }
                }
            }
        })
    }

    #[derive(Deserialize)]
    struct TokenizerJsonCase {
        text: String,
        token_ids: Vec<TokenId>,
    }

    #[derive(Deserialize)]
    struct TokenizerJsonTest<'a> {
        #[serde(borrow)]
        tokenizer: super::json::Tokenizer<'a>,
        cases: Vec<TokenizerJsonCase>,
    }

    /// Read a file from the crate's `test-data/tokenizer-json/` directory.
    fn read_test_json(path: &str) -> Result<String, Box<dyn Error>> {
        let mut abs_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        abs_path.push("test-data/tokenizer-json/");
        abs_path.push(path);
        let content = read_to_string(abs_path)?;
        Ok(content)
    }

    #[test]
    fn test_from_json() {
        let paths = ["wordpiece.json", "wordpiece-lower.json"];

        for path in paths {
            let json = read_test_json(path).unwrap();
            let config: TokenizerJsonTest = serde_json::from_str(&json).unwrap();

            let tokenizer = Tokenizer::from_parsed_json(config.tokenizer).unwrap();
            for case in config.cases {
                let encoded = tokenizer.encode(case.text.as_str(), None).unwrap();
                assert_eq!(
                    encoded.token_ids(),
                    case.token_ids,
                    "token ID mismatch for text {:?} in {}",
                    case.text,
                    path
                );
            }
        }
    }
}