tokenizers/tokenizer/
mod.rs

1//! Represents a tokenization pipeline.
2//!
3//! A [`Tokenizer`](struct.Tokenizer.html) is composed of some of the following parts.
4//!   - [`Normalizer`](trait.Normalizer.html): Takes care of the text normalization (like unicode normalization).
//!   - [`PreTokenizer`](trait.PreTokenizer.html): Takes care of the pre-tokenization (i.e. how to split tokens and
//!     pre-process them).
7//!   - [`Model`](trait.Model.html): A model encapsulates the tokenization algorithm (like BPE, Word base, character
8//!     based, ...).
9//!   - [`PostProcessor`](trait.PostProcessor.html): Takes care of the processing after tokenization (like truncating, padding,
10//!     ...).
11
12use std::{
13    collections::HashMap,
14    fs::{read_to_string, File},
15    io::{prelude::*, BufReader},
16    ops::{Deref, DerefMut},
17    path::{Path, PathBuf},
18};
19
20use serde::de::DeserializeOwned;
21use serde::{Deserialize, Serialize};
22
23use crate::utils::iter::ResultShunt;
24use crate::utils::parallelism::*;
25use crate::utils::progress::{ProgressBar, ProgressStyle};
26
27mod added_vocabulary;
28mod encoding;
29pub mod normalizer;
30pub mod pattern;
31pub mod pre_tokenizer;
32mod serialization;
33
34// Re-export wrappers
35pub use crate::decoders::DecoderWrapper;
36pub use crate::models::ModelWrapper;
37pub use crate::normalizers::NormalizerWrapper;
38pub use crate::pre_tokenizers::PreTokenizerWrapper;
39pub use crate::processors::PostProcessorWrapper;
40// And some other types
41pub use crate::utils::iter::LinesWithEnding;
42pub use crate::utils::padding::{pad_encodings, PaddingDirection, PaddingParams, PaddingStrategy};
43pub use crate::utils::truncation::{
44    truncate_encodings, TruncationDirection, TruncationParams, TruncationStrategy,
45};
46pub use added_vocabulary::*;
47pub use encoding::*;
48pub use normalizer::{NormalizedString, OffsetReferential, SplitDelimiterBehavior};
49pub use pre_tokenizer::*;
50
51pub type Error = Box<dyn std::error::Error + Send + Sync>;
52pub type Result<T> = std::result::Result<T, Error>;
53pub type Offsets = (usize, usize);
54
/// Takes care of pre-processing strings.
///
/// Implementors receive a mutable [`NormalizedString`] and report failure through
/// the returned `Result`.
pub trait Normalizer {
    /// Apply this normalizer to `normalized`, which is passed by mutable reference.
    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()>;
}
59
/// The `PreTokenizer` is in charge of doing the pre-segmentation step. It splits the given string
/// in multiple substrings, keeping track of the offsets of said substrings from the
/// `NormalizedString`. In some occasions, the `PreTokenizer` might need to modify the given
/// `NormalizedString` to ensure we can entirely keep track of the offsets and the mapping with
/// the original string.
pub trait PreTokenizer {
    /// Apply the pre-tokenization to `pretokenized`, which is passed by mutable reference.
    /// Failure is reported through the returned `Result`.
    fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()>;
}
68
/// Represents a model used during Tokenization (like BPE or Word or Unigram).
pub trait Model {
    /// The `Trainer` type returned by [`Model::get_trainer`]. It must be `Sync`.
    type Trainer: Trainer + Sync;
    /// Tokenize the given sequence into multiple underlying `Token`. The `offsets` on the `Token`
    /// are expected to be relative to the given sequence.
    fn tokenize(&self, sequence: &str) -> Result<Vec<Token>>;
    /// Find the ID associated to a string token
    fn token_to_id(&self, token: &str) -> Option<u32>;
    /// Find the string token associated to an ID
    fn id_to_token(&self, id: u32) -> Option<String>;
    /// Retrieve the entire vocabulary mapping (token -> ID)
    fn get_vocab(&self) -> HashMap<String, u32>;
    /// Retrieve the size of the vocabulary
    fn get_vocab_size(&self) -> usize;
    /// Save the current `Model` in the given folder, using the given `prefix` for the various
    /// files that need to be saved.
    ///
    /// Returns a list of paths (presumably the files that were written — confirm against
    /// implementors).
    fn save(&self, folder: &Path, prefix: Option<&str>) -> Result<Vec<PathBuf>>;
    /// Get an instance of a Trainer capable of training this Model
    fn get_trainer(&self) -> <Self as Model>::Trainer;
}
89
90/// A `PostProcessor` has the responsibility to post process an encoded output of the `Tokenizer`.
91/// It adds any special tokens that a language model would require.
92pub trait PostProcessor {
93    /// Returns the number of tokens that will be added during the processing step
94    fn added_tokens(&self, is_pair: bool) -> usize;
95    /// Process both encodings and returns a new merged one
96    fn process(
97        &self,
98        encoding: Encoding,
99        pair_encoding: Option<Encoding>,
100        add_special_tokens: bool,
101    ) -> Result<Encoding> {
102        let mut encodings = if let Some(pair_encoding) = pair_encoding {
103            vec![encoding, pair_encoding]
104        } else {
105            vec![encoding]
106        };
107        encodings.iter_mut().enumerate().for_each(|(i, encoding)| {
108            encoding.set_sequence_id(i);
109            encoding
110                .get_overflowing_mut()
111                .iter_mut()
112                .for_each(|encoding| encoding.set_sequence_id(i));
113            encoding.set_type_ids(vec![i as u32; encoding.len()]);
114        });
115
116        let encodings = self.process_encodings(encodings, add_special_tokens)?;
117        Ok(Encoding::merge(encodings, false))
118    }
119
120    /// Process any amount of encodings and returns a series of encoding (might merge them)
121    fn process_encodings(
122        &self,
123        encodings: Vec<Encoding>,
124        add_special_tokens: bool,
125    ) -> Result<Vec<Encoding>>;
126}
127impl dyn PostProcessor {
128    pub fn default_process(
129        encodings: Vec<Encoding>,
130        _add_special_tokens: bool,
131    ) -> Result<Vec<Encoding>> {
132        match encodings.len() {
133            1 => Ok(encodings),
134            _ => {
135                let mut final_encoding = Encoding::default();
136                for (i, mut encoding) in encodings.into_iter().enumerate() {
137                    encoding.set_sequence_id(i);
138                    final_encoding.merge_with(encoding, false);
139                }
140                Ok(vec![final_encoding])
141            }
142        }
143    }
144}
145
/// Errors that a post-processor can report.
#[derive(thiserror::Error, Debug)]
pub enum ProcessorError {
    /// The number of encodings handed to the processor was neither 1 nor 2.
    #[error("encodings vector length must be either 1 or 2")]
    InvalidEncodingsVecLength,
}
151
152/// A `Decoder` changes the raw tokens into its more readable form.
153pub trait Decoder {
154    fn decode(&self, tokens: Vec<String>) -> Result<String> {
155        let results = self.decode_chain(tokens)?;
156        Ok(results.join(""))
157    }
158    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>>;
159}
160
/// A `Trainer` has the responsibility to train a model. We feed it with lines/sentences
/// and then it can train the given `Model`.
pub trait Trainer {
    /// The `Model` type this trainer is able to train.
    type Model: Model + Sized;
    /// Whether we should show progress during the training.
    fn should_show_progress(&self) -> bool;
    /// The actual training method. It receives the `Model` to train by mutable reference and
    /// returns a list of `special_tokens` to be added directly to the tokenizer along with
    /// the model.
    fn train(&self, model: &mut Self::Model) -> Result<Vec<AddedToken>>;
    /// Process an iterator of sequences, calling `process` for each of them in order to
    /// pre-process the said sequence as relevant.
    ///
    /// `process` maps a sequence to a list of strings (or an error).
    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
    where
        I: Iterator<Item = S> + Send,
        S: AsRef<str> + Send,
        F: Fn(&str) -> Result<Vec<String>> + Sync;
}
178
/// A single token: its id, its string value, and a pair of offsets.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Token {
    /// Id of the token.
    pub id: u32,
    /// String value of the token.
    pub value: String,
    /// Offsets of the token, as a `(usize, usize)` pair.
    pub offsets: (usize, usize),
}

impl Token {
    /// Build a `Token` from its three parts.
    pub fn new(id: u32, value: String, offsets: (usize, usize)) -> Self {
        Token { id, value, offsets }
    }
}
190
191use std::borrow::Cow;
/// One input sequence: either a raw string or a list of string pieces.
///
/// The three `PreTokenized*` variants differ only in how the pieces are stored.
#[derive(Debug, Clone)]
pub enum InputSequence<'s> {
    /// A raw string.
    Raw(Cow<'s, str>),
    /// A list of borrowed `&str` pieces.
    PreTokenized(Cow<'s, [&'s str]>),
    /// A list of owned `String` pieces.
    PreTokenizedOwned(Cow<'s, [String]>),
    /// A list of `Cow<str>` pieces.
    PreTokenizedCow(Cow<'s, [Cow<'s, str>]>),
}

/// An existing `Cow<str>` is stored as-is in `Raw`.
impl<'s> From<Cow<'s, str>> for InputSequence<'s> {
    fn from(input: Cow<'s, str>) -> Self {
        InputSequence::Raw(input)
    }
}

/// A string slice becomes a borrowed `Raw` sequence.
impl<'s> From<&'s str> for InputSequence<'s> {
    fn from(input: &'s str) -> Self {
        InputSequence::Raw(Cow::from(input))
    }
}

/// An owned `String` becomes an owned `Raw` sequence.
impl From<String> for InputSequence<'_> {
    fn from(input: String) -> Self {
        InputSequence::Raw(Cow::from(input))
    }
}

/// A slice of `&str` becomes a borrowed `PreTokenized` sequence.
impl<'s> From<&'s [&'s str]> for InputSequence<'s> {
    fn from(input: &'s [&'s str]) -> Self {
        InputSequence::PreTokenized(Cow::from(input))
    }
}

/// A vector of `&str` becomes an owned `PreTokenized` sequence.
impl<'s> From<Vec<&'s str>> for InputSequence<'s> {
    fn from(input: Vec<&'s str>) -> Self {
        InputSequence::PreTokenized(Cow::from(input))
    }
}

/// A slice of `String` becomes a borrowed `PreTokenizedOwned` sequence.
impl<'s> From<&'s [String]> for InputSequence<'s> {
    fn from(input: &'s [String]) -> Self {
        InputSequence::PreTokenizedOwned(Cow::from(input))
    }
}

/// A vector of `String` becomes an owned `PreTokenizedOwned` sequence.
impl<'s> From<Vec<String>> for InputSequence<'s> {
    fn from(input: Vec<String>) -> Self {
        InputSequence::PreTokenizedOwned(Cow::from(input))
    }
}

/// A vector of `Cow<str>` becomes an owned `PreTokenizedCow` sequence.
impl<'s> From<Vec<Cow<'s, str>>> for InputSequence<'s> {
    fn from(input: Vec<Cow<'s, str>>) -> Self {
        InputSequence::PreTokenizedCow(Cow::from(input))
    }
}

/// A slice of `Cow<str>` becomes a borrowed `PreTokenizedCow` sequence.
impl<'s> From<&'s [Cow<'s, str>]> for InputSequence<'s> {
    fn from(input: &'s [Cow<'s, str>]) -> Self {
        InputSequence::PreTokenizedCow(Cow::from(input))
    }
}
253
254#[derive(Debug, Clone)]
255pub enum EncodeInput<'s> {
256    Single(InputSequence<'s>),
257    Dual(InputSequence<'s>, InputSequence<'s>),
258}
259
260impl<'s, I: Into<InputSequence<'s>>> From<I> for EncodeInput<'s> {
261    fn from(input: I) -> Self {
262        Self::Single(input.into())
263    }
264}
265
266impl<'s, I1, I2> From<(I1, I2)> for EncodeInput<'s>
267where
268    I1: Into<InputSequence<'s>>,
269    I2: Into<InputSequence<'s>>,
270{
271    fn from(input: (I1, I2)) -> Self {
272        Self::Dual(input.0.into(), input.1.into())
273    }
274}
275
/// Error carrying a plain message; its `Display` output is that message.
/// Used by `TokenizerBuilder::build` when the model is missing.
#[derive(thiserror::Error, Debug)]
#[error("{0}")]
pub struct BuilderError(String);
279
/// Builder for Tokenizer structs.
///
/// `build()` fails if the `model` is missing.
pub struct TokenizerBuilder<M, N, PT, PP, D> {
    // Pipeline parts; every one of them is optional while building.
    model: Option<M>,
    normalizer: Option<N>,
    pre_tokenizer: Option<PT>,
    post_processor: Option<PP>,
    decoder: Option<D>,

    // Added vocabulary handed over to the built tokenizer.
    added_vocabulary: AddedVocabulary,

    // General processing parameters.
    truncation: Option<TruncationParams>,
    padding: Option<PaddingParams>,
}
295
296impl<M, N, PT, PP, D> Default for TokenizerBuilder<M, N, PT, PP, D>
297where
298    M: Model,
299    N: Normalizer,
300    PT: PreTokenizer,
301    PP: PostProcessor,
302    D: Decoder,
303{
304    fn default() -> Self {
305        Self::new()
306    }
307}
308
309impl<M, N, PT, PP, D> TokenizerBuilder<M, N, PT, PP, D>
310where
311    M: Model,
312    N: Normalizer,
313    PT: PreTokenizer,
314    PP: PostProcessor,
315    D: Decoder,
316{
317    /// Get an empty TokenizerBuilder.
318    pub fn new() -> Self {
319        Self {
320            model: None,
321            normalizer: None,
322            pre_tokenizer: None,
323            post_processor: None,
324            decoder: None,
325            added_vocabulary: AddedVocabulary::new(),
326            truncation: None,
327            padding: None,
328        }
329    }
330
331    /// Convert the TokenizerBuilder to a Tokenizer.
332    ///
333    /// Conversion fails if the `model` is missing.
334    pub fn build(self) -> Result<TokenizerImpl<M, N, PT, PP, D>> {
335        let model = self
336            .model
337            .ok_or_else(|| Box::new(BuilderError("Model missing.".into())))?;
338        Ok(TokenizerImpl {
339            normalizer: self.normalizer,
340            pre_tokenizer: self.pre_tokenizer,
341            model,
342
343            post_processor: self.post_processor,
344            decoder: self.decoder,
345            added_vocabulary: self.added_vocabulary,
346            truncation: self.truncation,
347            padding: self.padding,
348        })
349    }
350
351    /// Set the model.
352    #[must_use]
353    pub fn with_model(mut self, model: M) -> Self {
354        self.model = Some(model);
355        self
356    }
357
358    /// Set the normalizer.
359    #[must_use]
360    pub fn with_normalizer(mut self, normalizer: Option<N>) -> Self {
361        self.normalizer = normalizer;
362        self
363    }
364
365    /// Set the pre-tokenizer.
366    #[must_use]
367    pub fn with_pre_tokenizer(mut self, pretokenizer: Option<PT>) -> Self {
368        self.pre_tokenizer = pretokenizer;
369        self
370    }
371
372    /// Set the post-processor.
373    #[must_use]
374    pub fn with_post_processor(mut self, post_processor: Option<PP>) -> Self {
375        self.post_processor = post_processor;
376        self
377    }
378
379    /// Set the decoder.
380    #[must_use]
381    pub fn with_decoder(mut self, decoder: Option<D>) -> Self {
382        self.decoder = decoder;
383        self
384    }
385
386    /// Set the added vocabulary.
387    pub fn with_added_vocabulary(mut self, added_vocabulary: AddedVocabulary) -> Self {
388        self.added_vocabulary = added_vocabulary;
389        self
390    }
391
392    /// Set the trunaction parameters.
393    #[must_use]
394    pub fn with_truncation(mut self, trunc: Option<TruncationParams>) -> Self {
395        self.truncation = trunc;
396        self
397    }
398
399    /// Set the padding parameters.
400    #[must_use]
401    pub fn with_padding(mut self, padding: Option<PaddingParams>) -> Self {
402        self.padding = padding;
403        self
404    }
405}
406
/// A `TokenizerImpl` whose five parts are the `*Wrapper` types re-exported by this
/// module. It derives `Serialize` and `Deserialize`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Tokenizer(
    TokenizerImpl<
        ModelWrapper,
        NormalizerWrapper,
        PreTokenizerWrapper,
        PostProcessorWrapper,
        DecoderWrapper,
    >,
);
417
418impl Tokenizer {
419    /// Construct a new Tokenizer based on the model.
420    pub fn new(model: impl Into<ModelWrapper>) -> Self {
421        Self(TokenizerImpl::new(model.into()))
422    }
423
424    /// Unwrap the TokenizerImpl.
425    pub fn into_inner(
426        self,
427    ) -> TokenizerImpl<
428        ModelWrapper,
429        NormalizerWrapper,
430        PreTokenizerWrapper,
431        PostProcessorWrapper,
432        DecoderWrapper,
433    > {
434        self.0
435    }
436    pub fn from_file<P: AsRef<Path>>(file: P) -> Result<Self> {
437        let content = read_to_string(file)?;
438        let tokenizer = serde_json::from_str(&content)?;
439        Ok(tokenizer)
440    }
441    pub fn from_bytes<P: AsRef<[u8]>>(bytes: P) -> Result<Self> {
442        let tokenizer = serde_json::from_slice(bytes.as_ref())?;
443        Ok(tokenizer)
444    }
445    #[cfg(feature = "http")]
446    pub fn from_pretrained<S: AsRef<str>>(
447        identifier: S,
448        params: Option<crate::utils::from_pretrained::FromPretrainedParameters>,
449    ) -> Result<Self> {
450        let tokenizer_file = crate::utils::from_pretrained::from_pretrained(identifier, params)?;
451        Tokenizer::from_file(tokenizer_file)
452    }
453}
454
455impl std::str::FromStr for Tokenizer {
456    type Err = Box<dyn std::error::Error + Send + Sync>;
457
458    fn from_str(s: &str) -> Result<Self> {
459        Ok(serde_json::from_str(s)?)
460    }
461}
462
463impl<M, N, PT, PP, D> From<TokenizerImpl<M, N, PT, PP, D>> for Tokenizer
464where
465    M: Into<ModelWrapper>,
466    N: Into<NormalizerWrapper>,
467    PT: Into<PreTokenizerWrapper>,
468    PP: Into<PostProcessorWrapper>,
469    D: Into<DecoderWrapper>,
470{
471    fn from(t: TokenizerImpl<M, N, PT, PP, D>) -> Self {
472        Self(TokenizerImpl {
473            model: t.model.into(),
474            normalizer: t.normalizer.map(Into::into),
475            pre_tokenizer: t.pre_tokenizer.map(Into::into),
476            post_processor: t.post_processor.map(Into::into),
477            decoder: t.decoder.map(Into::into),
478            added_vocabulary: t.added_vocabulary,
479            padding: t.padding,
480            truncation: t.truncation,
481        })
482    }
483}
484
485impl Deref for Tokenizer {
486    type Target = TokenizerImpl<
487        ModelWrapper,
488        NormalizerWrapper,
489        PreTokenizerWrapper,
490        PostProcessorWrapper,
491        DecoderWrapper,
492    >;
493
494    fn deref(&self) -> &Self::Target {
495        &self.0
496    }
497}
498
499impl DerefMut for Tokenizer {
500    fn deref_mut(&mut self) -> &mut Self::Target {
501        &mut self.0
502    }
503}
504
/// Error carrying a plain message; its `Display` output is that message.
/// Returned by `TokenizerImpl::with_truncation` when the parameters are rejected.
#[derive(thiserror::Error, Debug)]
#[error("{0}")]
pub struct TruncationParamError(String);
508
/// A `Tokenizer` is capable of encoding/decoding any text.
///
/// Only the `model` is mandatory; every other part is optional.
#[derive(Clone, Debug)]
pub struct TokenizerImpl<M, N, PT, PP, D> {
    // Tokenizer parts
    normalizer: Option<N>,
    pre_tokenizer: Option<PT>,
    model: M,
    post_processor: Option<PP>,
    decoder: Option<D>,

    // Added Vocabulary capabilities
    added_vocabulary: AddedVocabulary,

    // General processing parameters
    truncation: Option<TruncationParams>,
    padding: Option<PaddingParams>,
}
526
527impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
528where
529    M: Model,
530    N: Normalizer,
531    PT: PreTokenizer,
532    PP: PostProcessor,
533    D: Decoder,
534{
535    /// Instantiate a new Tokenizer, with the given Model
536    pub fn new(model: M) -> Self {
537        Self {
538            normalizer: None,
539            pre_tokenizer: None,
540            model,
541            post_processor: None,
542            decoder: None,
543
544            added_vocabulary: AddedVocabulary::new(),
545
546            truncation: None,
547            padding: None,
548        }
549    }
550
551    /// Set the normalizer
552    pub fn with_normalizer(&mut self, normalizer: Option<impl Into<N>>) -> &mut Self {
553        self.normalizer = normalizer.map(|norm| norm.into());
554        self
555    }
    /// Get the normalizer, or `None` when none is set
    pub fn get_normalizer(&self) -> Option<&N> {
        self.normalizer.as_ref()
    }
560
561    /// Set the pre tokenizer
562    pub fn with_pre_tokenizer(&mut self, pre_tokenizer: Option<impl Into<PT>>) -> &mut Self {
563        self.pre_tokenizer = pre_tokenizer.map(|tok| tok.into());
564        self
565    }
566
    /// Get the pre tokenizer, or `None` when none is set
    pub fn get_pre_tokenizer(&self) -> Option<&PT> {
        self.pre_tokenizer.as_ref()
    }
571
572    /// Set the post processor
573    pub fn with_post_processor(&mut self, post_processor: Option<impl Into<PP>>) -> &mut Self {
574        self.post_processor = post_processor.map(|post_proc| post_proc.into());
575        self
576    }
577
    /// Get the post processor, or `None` when none is set
    pub fn get_post_processor(&self) -> Option<&PP> {
        self.post_processor.as_ref()
    }
582
583    /// Set the decoder
584    pub fn with_decoder(&mut self, decoder: Option<impl Into<D>>) -> &mut Self {
585        self.decoder = decoder.map(|dec| dec.into());
586        self
587    }
588
    /// Get the decoder, or `None` when none is set
    pub fn get_decoder(&self) -> Option<&D> {
        self.decoder.as_ref()
    }
593
    /// Set the model, replacing the current one. Returns `self` for chaining.
    pub fn with_model(&mut self, model: impl Into<M>) -> &mut Self {
        self.model = model.into();
        self
    }
599
    /// Get a reference to the model
    pub fn get_model(&self) -> &M {
        &self.model
    }
604
    /// Set the added vocabulary, replacing the current one. Returns `self` for chaining.
    pub fn with_added_vocabulary(&mut self, added_vocabulary: AddedVocabulary) -> &mut Self {
        self.added_vocabulary = added_vocabulary;
        self
    }
610
    /// Get a reference to the added vocabulary
    pub fn get_added_vocabulary(&self) -> &AddedVocabulary {
        &self.added_vocabulary
    }
615
616    /// Set the truncation parameters
617    ///
618    /// Fails if `stride` is too high relative to `max_length` and `post_processor.added_tokens()`
619    pub fn with_truncation(&mut self, trunc: Option<TruncationParams>) -> Result<&mut Self> {
620        if let Some(trunc_params) = &trunc {
621            let n_added_tokens = self.get_n_added_tokens(false);
622            let effective_max_length = trunc_params.max_length - n_added_tokens;
623            if effective_max_length < trunc_params.stride {
624                return Err(Box::new(TruncationParamError(format!(
625                    "tokenizer stride set to {}, which is greater than or equal to its effective max length of {} (= {} original max length - {} added special tokens), ",
626                    trunc_params.stride, effective_max_length, trunc_params.max_length, n_added_tokens
627                ))));
628            }
629        }
630        self.truncation = trunc;
631        Ok(self)
632    }
633
    /// Get the currently set truncation parameters, or `None` when truncation is not set
    pub fn get_truncation(&self) -> Option<&TruncationParams> {
        self.truncation.as_ref()
    }
638
    /// Get a mutable reference to the currently set truncation parameters
    ///
    /// NOTE(review): changes made through this reference do not go through the
    /// validation done by `with_truncation`.
    pub fn get_truncation_mut(&mut self) -> Option<&mut TruncationParams> {
        self.truncation.as_mut()
    }
643
    /// Set the padding parameters (`None` clears them). Returns `self` for chaining.
    pub fn with_padding(&mut self, padding: Option<PaddingParams>) -> &mut Self {
        self.padding = padding;
        self
    }
649
    /// Get the currently set padding parameters, or `None` when padding is not set
    pub fn get_padding(&self) -> Option<&PaddingParams> {
        self.padding.as_ref()
    }
654
    /// Get a mutable reference to the currently set padding parameters,
    /// or `None` when padding is not set
    pub fn get_padding_mut(&mut self) -> Option<&mut PaddingParams> {
        self.padding.as_mut()
    }
659
660    /// Get the vocabulary
661    pub fn get_vocab(&self, with_added_tokens: bool) -> HashMap<String, u32> {
662        let mut final_vocab = self.model.get_vocab();
663
664        if with_added_tokens {
665            let added_vocab = self.added_vocabulary.get_vocab();
666            if !added_vocab.is_empty() {
667                final_vocab.reserve(added_vocab.len());
668                for (token, id) in added_vocab {
669                    final_vocab.insert(token.clone(), *id);
670                }
671            }
672        }
673
674        final_vocab
675    }
676
    /// Get the added tokens decoder
    ///
    /// Returns a clone of the `id -> AddedToken` map held by the added vocabulary.
    pub fn get_added_tokens_decoder(&self) -> HashMap<u32, AddedToken> {
        self.added_vocabulary.get_added_tokens_decoder().clone()
    }
681
682    /// Get the size of the vocabulary
683    pub fn get_vocab_size(&self, with_added_tokens: bool) -> usize {
684        // TODO ArthurZ THIS IS WRONG! We need to measure the length of the `set` because
685        // now some tokens can be both in the added_tokens_encoder and in the vocab
686        if with_added_tokens {
687            self.get_vocab(true).len()
688        } else {
689            self.model.get_vocab_size()
690        }
691    }
692
    /// Converts a token in the corresponding id.
    ///
    /// The lookup is delegated to the added vocabulary, which is handed the model.
    pub fn token_to_id(&self, token: &str) -> Option<u32> {
        self.added_vocabulary.token_to_id(token, &self.model)
    }
697
698    /// Converts an id to the corresponding token.
699    pub fn id_to_token(&self, id: u32) -> Option<String> {
700        self.added_vocabulary
701            .simple_id_to_token(id)
702            .or_else(|| self.model.id_to_token(id))
703    }
704
    /// Set the `encode_special_tokens` flag of the added vocabulary
    pub fn set_encode_special_tokens(&mut self, value: bool) {
        self.added_vocabulary.set_encode_special_tokens(value);
    }
709
    /// Get the `encode_special_tokens` flag of the added vocabulary
    pub fn get_encode_special_tokens(&self) -> bool {
        self.added_vocabulary.get_encode_special_tokens()
    }
714
    /// Encode a single sequence
    ///
    /// Every piece of the sequence goes through the same three steps: extraction and
    /// normalization by the added vocabulary, pre-tokenization, then tokenization with
    /// the given `type_id` and `offsets_type`. A raw sequence is one piece; a
    /// pre-tokenized sequence is encoded piece by piece and the results are collected
    /// into one `Encoding`.
    fn encode_single_sequence(
        &self,
        sequence: InputSequence,
        type_id: u32,
        offsets_type: OffsetType,
    ) -> Result<Encoding> {
        // Shared per-piece pipeline. `subseq_idx` is only forwarded to `do_tokenize`
        // when the input was pre-tokenized.
        let encode = |is_pre_tokenized, subseq_idx, subseq| -> Result<Encoding> {
            let normalized = self
                .added_vocabulary
                .extract_and_normalize(self.normalizer.as_ref(), subseq);
            let pre_tokenized = self.do_pre_tokenize(normalized)?;
            let subseq_encoding = self.do_tokenize(
                pre_tokenized,
                type_id,
                if is_pre_tokenized {
                    Some(subseq_idx as u32)
                } else {
                    None
                },
                offsets_type,
            )?;

            Ok(subseq_encoding)
        };

        // The three pre-tokenized variants only differ by element type. Collecting into
        // `Result<Encoding>` presumably stops at the first failing piece — confirm
        // against the `FromIterator` impl of `Encoding`.
        match sequence {
            InputSequence::PreTokenized(seq) => seq
                .iter()
                .enumerate()
                .map(|(i, sequence)| encode(true, i, sequence))
                .collect(),
            InputSequence::PreTokenizedOwned(seq) => seq
                .iter()
                .enumerate()
                .map(|(i, sequence)| encode(true, i, sequence))
                .collect(),
            InputSequence::PreTokenizedCow(seq) => seq
                .iter()
                .enumerate()
                .map(|(i, sequence)| encode(true, i, sequence))
                .collect(),
            InputSequence::Raw(seq) => encode(false, 0, seq.as_ref()),
        }
    }
760
761    /// Encode the given input. This method accepts both single sequences, as well as pair
762    /// sequences. Also, a sequence can be a string, or already pre-tokenized input directly:
763    /// Contrarily to `encode`, it does not compute offsets
764    /// ```
765    /// # use tokenizers::Tokenizer;
766    /// # use tokenizers::models::bpe::BPE;
767    /// # let mut tokenizer = Tokenizer::new(BPE::default());
768    /// #
769    /// // Sequences:
770    /// tokenizer.encode_fast("Single sequence", false);
771    /// tokenizer.encode_fast(("Sequence A", "Sequence B"), false);
772    ///
773    /// // Pre-tokenized sequences:
774    /// tokenizer.encode_fast(&["Single", "sequence"][..], false);
775    /// tokenizer.encode_fast((
776    ///     &["Sequence", "A"][..],
777    ///     &["Sequence", "B"][..]
778    /// ), false);
779    ///
780    /// // or even both types together:
781    /// tokenizer.encode_fast(("A complete sequence", &["And", "a", "tokenized"][..]), false);
782    /// ```
783    pub fn encode_fast<'s, E>(&self, input: E, add_special_tokens: bool) -> Result<Encoding>
784    where
785        E: Into<EncodeInput<'s>>,
786    {
787        // Extract sequences from the EncodeInput
788        let (sequence, pair) = match input.into() {
789            EncodeInput::Single(s1) => (s1, None),
790            EncodeInput::Dual(s1, s2) => (s1, Some(s2)),
791        };
792
793        // Encode each sequence
794        let encoding = self.encode_single_sequence(sequence, 0, OffsetType::None)?;
795        let pair_encoding = pair
796            .map(|sequence| self.encode_single_sequence(sequence, 1, OffsetType::None))
797            .transpose()?;
798
799        // And finally post process
800        self.post_process(encoding, pair_encoding, add_special_tokens)
801    }
802
803    /// Encode the given input. This method accepts both single sequences, as well as pair
804    /// sequences. Also, a sequence can be a string, or already pre-tokenized input directly:
805    ///
806    /// ```
807    /// # use tokenizers::Tokenizer;
808    /// # use tokenizers::models::bpe::BPE;
809    /// # let mut tokenizer = Tokenizer::new(BPE::default());
810    /// #
811    /// // Sequences:
812    /// tokenizer.encode("Single sequence", false);
813    /// tokenizer.encode(("Sequence A", "Sequence B"), false);
814    ///
815    /// // Pre-tokenized sequences:
816    /// tokenizer.encode(&["Single", "sequence"][..], false);
817    /// tokenizer.encode((
818    ///     &["Sequence", "A"][..],
819    ///     &["Sequence", "B"][..]
820    /// ), false);
821    ///
822    /// // or even both types together:
823    /// tokenizer.encode(("A complete sequence", &["And", "a", "tokenized"][..]), false);
824    /// ```
825    pub fn encode<'s, E>(&self, input: E, add_special_tokens: bool) -> Result<Encoding>
826    where
827        E: Into<EncodeInput<'s>>,
828    {
829        // Extract sequences from the EncodeInput
830        let (sequence, pair) = match input.into() {
831            EncodeInput::Single(s1) => (s1, None),
832            EncodeInput::Dual(s1, s2) => (s1, Some(s2)),
833        };
834
835        // Encode each sequence
836        let encoding = self.encode_single_sequence(sequence, 0, OffsetType::Byte)?;
837        let pair_encoding = pair
838            .map(|sequence| self.encode_single_sequence(sequence, 1, OffsetType::Byte))
839            .transpose()?;
840
841        // And finally post process
842        self.post_process(encoding, pair_encoding, add_special_tokens)
843    }
844
845    /// Encode the given input, using offsets relative to chars instead of bytes.
846    /// This method accepts both single sequences, as well as pair sequences. Also,
847    /// a sequence can be a string, or already pre-tokenized input directly:
848    ///
849    /// ```
850    /// # use tokenizers::Tokenizer;
851    /// # use tokenizers::models::bpe::BPE;
852    /// # let mut tokenizer = Tokenizer::new(BPE::default());
853    /// #
854    /// // Sequences:
855    /// tokenizer.encode("Single sequence", false);
856    /// tokenizer.encode(("Sequence A", "Sequence B"), false);
857    ///
858    /// // Pre-tokenized sequences:
859    /// tokenizer.encode(&["Single", "sequence"][..], false);
860    /// tokenizer.encode((
861    ///     &["Sequence", "A"][..],
862    ///     &["Sequence", "B"][..]
863    /// ), false);
864    ///
865    /// // or even both types together:
866    /// tokenizer.encode(("A complete sequence", &["And", "a", "tokenized"][..]), false);
867    /// ```
868    pub fn encode_char_offsets<'s, E>(&self, input: E, add_special_tokens: bool) -> Result<Encoding>
869    where
870        E: Into<EncodeInput<'s>>,
871    {
872        // Extract sequences from the EncodeInput
873        let (sequence, pair) = match input.into() {
874            EncodeInput::Single(s1) => (s1, None),
875            EncodeInput::Dual(s1, s2) => (s1, Some(s2)),
876        };
877
878        // Encode each sequence
879        let encoding = self.encode_single_sequence(sequence, 0, OffsetType::Char)?;
880        let pair_encoding = pair
881            .map(|sequence| self.encode_single_sequence(sequence, 1, OffsetType::Char))
882            .transpose()?;
883
884        // And finally post process
885        self.post_process(encoding, pair_encoding, add_special_tokens)
886    }
887
888    /// Decode the given ids, back to a String
889    pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> {
890        let tokens = ids
891            .iter()
892            .filter_map(|id| {
893                self.added_vocabulary
894                    .simple_id_to_token(*id)
895                    .or_else(|| self.model.id_to_token(*id))
896                    .filter(|token| {
897                        !skip_special_tokens || !self.added_vocabulary.is_special_token(token)
898                    })
899            })
900            .collect::<Vec<_>>();
901
902        if let Some(decoder) = &self.decoder {
903            decoder.decode(tokens)
904        } else {
905            Ok(tokens.join(" "))
906        }
907    }
908
    /// Create a [`DecodeStream`] borrowing this tokenizer, built with the given
    /// `skip_special_tokens` setting.
    /// See [`DecodeStream`]
    pub fn decode_stream(&self, skip_special_tokens: bool) -> DecodeStream<'_, M, N, PT, PP, D> {
        DecodeStream::new(self, skip_special_tokens)
    }
914}
915
916/// DecodeStream will keep the state necessary to produce individual chunks of
917/// strings given an input stream of token_ids.
918///
/// This is necessary because decoding ids one at a time cannot achieve that: the
/// string produced for an id depends on the surrounding ids (typically, a decoder
/// may strip extra spaces depending on what comes before a token).
921///
922/// Example:
923///
924/// ```
925/// # #[cfg(not(target_os = "windows"))]
926/// # {
927/// use tokenizers::Tokenizer;
928/// let tokenizer = Tokenizer::from_file("data/roberta.json").unwrap();
929///
930/// let mut decode_stream = tokenizer.decode_stream(false);
931/// assert_eq!(decode_stream.step(713).unwrap(), Some("This".to_string()));
932/// assert_eq!(decode_stream.step(16).unwrap(), Some(" is".to_string()));
933/// assert_eq!(decode_stream.step(41).unwrap(), Some(" an".to_string()));
934/// assert_eq!(
935///     decode_stream.step(1246).unwrap(),
936///     Some(" example".to_string())
937/// );
938/// # }
939/// ```
940///
941/// Returning `None` means the given id is not enough to produce a chunk.
942/// This typically happens with `byte_fallback` options where some tokens do
943/// not represent valid utf-8, and only follow-up token_ids will help produce
944/// a valid chunk.
945/// ```
946/// use tokenizers::{Tokenizer, TokenizerBuilder, models::bpe::BPE, decoders::byte_fallback::ByteFallback, pre_tokenizers::byte_level::ByteLevel, normalizers::unicode::NFC};
947/// use std::collections::HashMap;
948/// use std::iter::FromIterator;
949///
950/// let vocab = HashMap::from_iter([
951///     ("<0x20>".to_string(), 0),
952///     ("<0xC3>".to_string(), 1),
953///     ("<0xA9>".to_string(), 2),
954///     (" This".to_string(), 3),
955/// ]);
956/// let merges = vec![];
957/// let bpe = BPE::builder()
958///     .vocab_and_merges(vocab, merges)
959///     .byte_fallback(true)
960///     .build()
961///     .unwrap();
962/// let tokenizer = TokenizerBuilder::default()
963///     .with_model(bpe)
964///     .with_decoder(Some(ByteFallback::default()))
965///     .with_normalizer(Some(NFC))
966///     .with_pre_tokenizer(Some(ByteLevel::default()))
967///     .with_post_processor(Some(ByteLevel::default()))
968///     .build().unwrap();
969///
970/// let mut decode_stream = tokenizer.decode_stream(false);
971/// // Single byte_fallback is valid utf-8
972/// assert_eq!(decode_stream.step(0).unwrap(), Some(" ".to_string()));
973/// // Invalid utf-8
974/// assert_eq!(decode_stream.step(1).unwrap(), None);
975/// // Valid utf-8 again, this corresponds to both tokens: [1, 2]
976/// assert_eq!(decode_stream.step(2).unwrap(), Some("é".to_string()));
977/// ```
978///
979/// To see how [`DecodeStream`] is necessary, let's show how using raw [`TokenizerImpl::decode`] would
980/// fail.
981///
982/// ```
983/// use tokenizers::{Tokenizer, TokenizerBuilder, models::bpe::BPE, pre_tokenizers::{byte_level::ByteLevel, metaspace::Metaspace}, normalizers::unicode::NFC};
984/// use std::collections::HashMap;
985/// use std::iter::FromIterator;
986///
987/// let vocab = HashMap::from_iter([
988///     ("▁This".to_string(), 0),
989/// ]);
990/// let merges = vec![];
991/// let bpe = BPE::builder()
992///     .vocab_and_merges(vocab, merges)
993///     .byte_fallback(true)
994///     .build()
995///     .unwrap();
996/// let tokenizer = TokenizerBuilder::new()
997///     .with_model(bpe)
998///     .with_decoder(Some(Metaspace::default()))
999///     .with_normalizer(Some(NFC))
1000///     .with_pre_tokenizer(Some(ByteLevel::default()))
1001///     .with_post_processor(Some(ByteLevel::default()))
1002///     .build()
1003///     .unwrap();
1004///
/// // The Metaspace decoder removes the extra initial space
1006/// assert_eq!(tokenizer.decode(&[0, 0], false).unwrap(), "This This");
1007/// // Decoding one token at a time would produce "ThisThis"
1008/// assert_eq!(tokenizer.decode(&[0], false).unwrap(), "This");
1009///
1010/// // Using a stream fixes it by keeping the necessary state.
1011/// let mut decode_stream = tokenizer.decode_stream(false);
1012/// assert_eq!(decode_stream.step(0).unwrap(), Some("This".to_string()));
1013/// assert_eq!(decode_stream.step(0).unwrap(), Some(" This".to_string()));
1014/// ```
pub struct DecodeStream<'tok, M, N, PT, PP, D> {
    /// A reference to the tokenizer, used for every decoding call made by the stream.
    tokenizer: &'tok TokenizerImpl<M, N, PT, PP, D>,
    /// Regular decode option that is kept throughout: it is set at creation and
    /// forwarded unchanged on every step.
    skip_special_tokens: bool,
    /// A temporary buffer of the token ids needed to produce valid string chunks.
    /// Every stepped id is pushed here, and the front of the buffer is dropped each
    /// time a chunk is emitted.
    /// This typically contains 3 parts:
    ///  - read
    ///  - prefix
    ///  - rest
    ///
    /// Read is the bit necessary to surround the prefix
    /// so decoding the whole ids produces a valid prefix.
    /// Prefix is the part that was already emitted, kept around so it can be trimmed
    /// off of the next valid chunk.
    /// NOTE(review): "rest" is not described by the original author; presumably the ids
    /// received since the last emitted chunk.
    ids: Vec<u32>,
    /// The decoding of `ids` as computed right after the last chunk was emitted.
    /// It is stripped from the front of the next decoding to obtain the next chunk.
    prefix: String,
    /// The index within the ids corresponding to the prefix so we can drain
    /// correctly. It is recomputed every time a chunk is emitted.
    prefix_index: usize,
    /// We need to keep 2 prefixes.
    /// Prefix is the second one, already emitted, used to discard the part of the
    /// text that was already returned.
    /// Read is the first one, kept only for its starting side effects on the prefix:
    /// `read_index` is the position from which `ids` is kept when a chunk is emitted,
    /// and it then takes the previous value of `prefix_index`.
    read_index: usize,
}
1044
/// Errors raised while stepping through a [`DecodeStream`].
#[derive(thiserror::Error, Debug)]
pub enum DecodeStreamError {
    /// The freshly decoded string does not start with the prefix recorded at the
    /// previous step, so the new chunk cannot be isolated.
    #[error("Invalid prefix encountered")]
    InvalidPrefix,
}
1050
1051impl<'tok, M, N, PT, PP, D> DecodeStream<'tok, M, N, PT, PP, D>
1052where
1053    M: Model,
1054    N: Normalizer,
1055    PT: PreTokenizer,
1056    PP: PostProcessor,
1057    D: Decoder,
1058{
1059    fn new(tokenizer: &'tok TokenizerImpl<M, N, PT, PP, D>, skip_special_tokens: bool) -> Self {
1060        Self {
1061            tokenizer,
1062            ids: vec![],
1063            skip_special_tokens,
1064            prefix: "".to_string(),
1065            prefix_index: 0,
1066            read_index: 0,
1067        }
1068    }
1069
1070    /// See [`DecodeStream`]
1071    pub fn step(&mut self, id: u32) -> Result<Option<String>> {
1072        step_decode_stream(
1073            self.tokenizer,
1074            id,
1075            self.skip_special_tokens,
1076            &mut self.ids,
1077            &mut self.prefix,
1078            &mut self.prefix_index,
1079            &mut self.read_index,
1080        )
1081    }
1082}
1083
1084/// Internal function exposed only to bypass python limitations
1085pub fn step_decode_stream<M, N, PT, PP, D>(
1086    tokenizer: &TokenizerImpl<M, N, PT, PP, D>,
1087    id: u32,
1088    skip_special_tokens: bool,
1089    ids: &mut Vec<u32>,
1090    prefix: &mut String,
1091    prefix_index: &mut usize,
1092    read_index: &mut usize,
1093) -> Result<Option<String>>
1094where
1095    M: Model,
1096    N: Normalizer,
1097    PT: PreTokenizer,
1098    PP: PostProcessor,
1099    D: Decoder,
1100{
1101    ids.push(id);
1102    let string = tokenizer.decode(ids.as_slice(), skip_special_tokens)?;
1103    if string.len() > prefix.len() && !string.ends_with('�') {
1104        if !(string.starts_with(&*prefix)) {
1105            return Err(Box::new(DecodeStreamError::InvalidPrefix));
1106        }
1107        let new_text = &string[prefix.len()..].to_string();
1108        let new_prefix_index = ids.len() - *prefix_index;
1109        *ids = ids.drain(*read_index..).collect();
1110        *prefix = tokenizer.decode(ids, skip_special_tokens)?;
1111        *read_index = *prefix_index;
1112        *prefix_index = new_prefix_index;
1113        Ok(Some(new_text.to_string()))
1114    } else {
1115        Ok(None)
1116    }
1117}
1118
1119impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
1120where
1121    M: Model,
1122{
1123    /// Tokenization logic, makes the bridge between the pre-tokenization phase and the real
1124    /// tokenization phase, and converting offsets back to the original referential.
1125    fn do_tokenize<P: Into<PreTokenizedString>>(
1126        &self,
1127        pretokenized: P,
1128        type_id: u32,
1129        word_idx: Option<u32>,
1130        offsets_type: OffsetType,
1131    ) -> Result<Encoding> {
1132        let mut pretokenized: PreTokenizedString = pretokenized.into();
1133        pretokenized.tokenize(|normalized| self.model.tokenize(normalized.get()))?;
1134        pretokenized.into_encoding(word_idx, type_id, offsets_type)
1135    }
1136}
1137
1138impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
1139where
1140    N: Normalizer,
1141{
1142    /// Normalization logic, go through all normalizers
1143    fn do_normalize<V: Into<NormalizedString>>(&self, normalized: V) -> Result<NormalizedString> {
1144        let mut normalized: NormalizedString = normalized.into();
1145
1146        if let Some(ref normalizer) = self.normalizer {
1147            normalizer.normalize(&mut normalized)?;
1148        }
1149
1150        Ok(normalized)
1151    }
1152}
1153
1154impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
1155where
1156    N: Normalizer,
1157    M: Model,
1158{
1159    /// Register the given tokens as special tokens. This is especially useful for removing
1160    /// these special tokens while decoding
1161    pub fn add_special_tokens(&mut self, tokens: &[AddedToken]) -> usize {
1162        self.added_vocabulary
1163            .add_special_tokens(tokens, &self.model, self.normalizer.as_ref())
1164    }
1165
1166    /// Add the given tokens to the added vocabulary
1167    pub fn add_tokens(&mut self, tokens: &[AddedToken]) -> usize {
1168        self.added_vocabulary
1169            .add_tokens(tokens, &self.model, self.normalizer.as_ref())
1170    }
1171}
1172
1173impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
1174where
1175    PT: PreTokenizer,
1176{
1177    /// PreTokenization logic, handling the case where there is no PreTokenizer set
1178    fn do_pre_tokenize<P: Into<PreTokenizedString>>(
1179        &self,
1180        pretokenized: P,
1181    ) -> Result<PreTokenizedString> {
1182        let mut pretokenized: PreTokenizedString = pretokenized.into();
1183        if let Some(ref pretok) = self.pre_tokenizer {
1184            pretok.pre_tokenize(&mut pretokenized)?;
1185        }
1186
1187        Ok(pretokenized)
1188    }
1189}
1190
1191impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
1192where
1193    PP: PostProcessor,
1194{
1195    /// Post processing logic, handling the case where there is no PostProcessor set
1196    pub fn post_process(
1197        &self,
1198        encoding: Encoding,
1199        pair_encoding: Option<Encoding>,
1200        add_special_tokens: bool,
1201    ) -> Result<Encoding> {
1202        // 1. First we truncate if needed
1203        let (encoding, pair_encoding) = {
1204            if let Some(trunc) = &self.truncation {
1205                let n_added_tokens = self.get_n_added_tokens(pair_encoding.is_some());
1206
1207                if add_special_tokens && n_added_tokens > 0 {
1208                    let params = TruncationParams {
1209                        max_length: trunc.max_length - n_added_tokens,
1210                        ..*trunc
1211                    };
1212                    truncate_encodings(encoding, pair_encoding, &params)?
1213                } else {
1214                    truncate_encodings(encoding, pair_encoding, trunc)?
1215                }
1216            } else {
1217                (encoding, pair_encoding)
1218            }
1219        };
1220
1221        // 2. Then We post process
1222        let final_encoding = if let Some(processor) = &self.post_processor {
1223            processor.process(encoding, pair_encoding, add_special_tokens)?
1224        } else {
1225            let encodings = if let Some(pair_encoding) = pair_encoding {
1226                vec![encoding, pair_encoding]
1227            } else {
1228                vec![encoding]
1229            };
1230            let mut encodings =
1231                <dyn PostProcessor>::default_process(encodings, add_special_tokens)?;
1232            if encodings.len() != 1 {
1233                panic!("We haven't reduced the encodings like we should have");
1234            }
1235            encodings.pop().unwrap()
1236        };
1237
1238        // 3. Then we pad if needed
1239        let [final_encoding] = if let Some(params) = &self.padding {
1240            let mut arr = [final_encoding];
1241            pad_encodings(&mut arr, params)?;
1242            arr
1243        } else {
1244            [final_encoding]
1245        };
1246
1247        Ok(final_encoding)
1248    }
1249
1250    fn get_n_added_tokens(&self, is_pair: bool) -> usize {
1251        if let Some(processor) = &self.post_processor {
1252            processor.added_tokens(is_pair)
1253        } else {
1254            0
1255        }
1256    }
1257}
1258
1259impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
1260where
1261    M: Model + Send + Sync,
1262    N: Normalizer + Send + Sync,
1263    PT: PreTokenizer + Send + Sync,
1264    PP: PostProcessor + Send + Sync,
1265    D: Decoder + Send + Sync,
1266{
1267    /// Encode all the sentences in parallel, using multiple threads
1268    pub fn encode_batch<'s, E>(
1269        &self,
1270        inputs: Vec<E>,
1271        add_special_tokens: bool,
1272    ) -> Result<Vec<Encoding>>
1273    where
1274        E: Into<EncodeInput<'s>> + Send,
1275    {
1276        let mut encodings = inputs
1277            .into_maybe_par_iter()
1278            .map(|input| self.encode(input, add_special_tokens))
1279            .collect::<Result<Vec<Encoding>>>()?;
1280
1281        if let Some(params) = &self.padding {
1282            // We do the padding here to make sure we handle the batch padding
1283            pad_encodings(&mut encodings, params)?;
1284        }
1285
1286        Ok(encodings)
1287    }
1288
1289    /// Encode all the sentences in parallel, using multiple threads.
1290    /// The offsets on each `Encoding` will be relative to chars instead of bytes.
1291    pub fn encode_batch_char_offsets<'s, E>(
1292        &self,
1293        inputs: Vec<E>,
1294        add_special_tokens: bool,
1295    ) -> Result<Vec<Encoding>>
1296    where
1297        E: Into<EncodeInput<'s>> + Send,
1298    {
1299        let mut encodings = inputs
1300            .into_maybe_par_iter()
1301            .map(|input| self.encode_char_offsets(input, add_special_tokens))
1302            .collect::<Result<Vec<Encoding>>>()?;
1303
1304        if let Some(params) = &self.padding {
1305            // We do the padding here to make sure we handle the batch padding
1306            pad_encodings(&mut encodings, params)?;
1307        }
1308
1309        Ok(encodings)
1310    }
1311
1312    /// Encode all the sentences in parallel, using multiple threads
1313    pub fn encode_batch_fast<'s, E>(
1314        &self,
1315        inputs: Vec<E>,
1316        add_special_tokens: bool,
1317    ) -> Result<Vec<Encoding>>
1318    where
1319        E: Into<EncodeInput<'s>> + Send,
1320    {
1321        let mut encodings = inputs
1322            .into_maybe_par_iter()
1323            .map(|input| self.encode_fast(input, add_special_tokens))
1324            .collect::<Result<Vec<Encoding>>>()?;
1325
1326        if let Some(params) = &self.padding {
1327            // We do the padding here to make sure we handle the batch padding
1328            pad_encodings(&mut encodings, params)?;
1329        }
1330
1331        Ok(encodings)
1332    }
1333
1334    /// Decode all sentences in parallel
1335    pub fn decode_batch(
1336        &self,
1337        sentences: &[&[u32]],
1338        skip_special_tokens: bool,
1339    ) -> Result<Vec<String>>
1340    where
1341        M: Send + Sync,
1342    {
1343        sentences
1344            .into_maybe_par_iter()
1345            .map(|sentence| self.decode(sentence, skip_special_tokens))
1346            .collect()
1347    }
1348
1349    /// Train our Model from files
1350    pub fn train_from_files<T>(&mut self, trainer: &mut T, files: Vec<String>) -> Result<&mut Self>
1351    where
1352        T: Trainer<Model = M> + Sync,
1353    {
1354        let mut len = 0;
1355        for file in files.iter() {
1356            len += File::open(file)
1357                .and_then(|f| f.metadata())
1358                .map(|m| m.len())?;
1359        }
1360
1361        let max_read = 1_000_000;
1362
1363        ResultShunt::process(
1364            files.into_iter().flat_map(|filename| {
1365                match File::open(filename) {
1366                    Ok(file) => {
1367                        let file = BufReader::with_capacity(max_read, file);
1368                        // We read new lines using this API instead of the Lines Iterator
1369                        // on purpose. We want to keep the `\n` and potential `\r` between each lines
1370                        // We use an iterator to be able to chain with par_bridge.
1371                        itertools::Either::Left(file.lines_with_ending())
1372                    }
1373                    Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
1374                }
1375            }),
1376            |sequences| -> Result<()> {
1377                let progress = if trainer.should_show_progress() {
1378                    let progress = ProgressBar::new(len);
1379                    progress.set_style(
1380                        ProgressStyle::default_bar()
1381                            .template("[{elapsed_precise}] {msg:<30!} {wide_bar} {percent:>18!}%")
1382                            .expect("Invalid progress template"),
1383                    );
1384                    progress
1385                        .set_message(format!("Pre-processing files ({:.2} Mo)", len / 1_000_000));
1386                    Some(progress)
1387                } else {
1388                    None
1389                };
1390
1391                trainer.feed(
1392                    sequences.inspect(|s| {
1393                        if let Some(progress) = &progress {
1394                            progress.inc(s.len() as u64)
1395                        }
1396                    }),
1397                    |seq| {
1398                        let normalized = self.do_normalize(seq.as_ref())?;
1399                        let pre_tokenized = self.do_pre_tokenize(normalized)?;
1400                        Ok(pre_tokenized
1401                            .get_splits(OffsetReferential::Original, OffsetType::Byte)
1402                            .into_iter()
1403                            .map(|(s, _, _)| s.to_owned())
1404                            .collect())
1405                    },
1406                )?;
1407
1408                if let Some(pbar) = progress {
1409                    pbar.finish();
1410                }
1411                let special_tokens = trainer.train(&mut self.model)?;
1412                self.add_special_tokens(&special_tokens);
1413
1414                Ok(())
1415            },
1416        )??;
1417        Ok(self)
1418    }
1419
1420    /// Train our Model, using the given Trainer and iterator
1421    pub fn train<T, I, S>(&mut self, trainer: &mut T, sequences: I) -> Result<&mut Self>
1422    where
1423        T: Trainer<Model = M> + Sync,
1424        I: Iterator<Item = S> + Send,
1425        S: AsRef<str> + Send,
1426    {
1427        let (lower, upper) = sequences.size_hint();
1428        let len = upper.unwrap_or(lower) as u64;
1429        let progress = if trainer.should_show_progress() {
1430            let progress = ProgressBar::new(len);
1431            progress.set_style(
1432                ProgressStyle::default_bar()
1433                    .template("[{elapsed_precise}] {msg:<30!} {wide_bar} {pos:<9!}/{len:>9!}")
1434                    .expect("Invalid progress template"),
1435            );
1436            progress.set_message("Pre-processing sequences");
1437            Some(progress)
1438        } else {
1439            None
1440        };
1441
1442        trainer.feed(
1443            sequences.inspect(|_s| {
1444                if let Some(progress) = &progress {
1445                    progress.inc(1)
1446                }
1447            }),
1448            |seq| {
1449                let normalized = self.do_normalize(seq.as_ref())?;
1450                let pre_tokenized = self.do_pre_tokenize(normalized)?;
1451                Ok(pre_tokenized
1452                    .get_splits(OffsetReferential::Original, OffsetType::Byte)
1453                    .into_iter()
1454                    .map(|(s, _, _)| s.to_owned())
1455                    .collect())
1456            },
1457        )?;
1458        if let Some(pbar) = progress {
1459            pbar.finish();
1460        }
1461
1462        let special_tokens = trainer.train(&mut self.model)?;
1463        self.add_special_tokens(&special_tokens);
1464
1465        Ok(self)
1466    }
1467}
1468
1469impl<M, N, PT, PP, D> std::str::FromStr for TokenizerImpl<M, N, PT, PP, D>
1470where
1471    M: for<'de> Deserialize<'de> + Model,
1472    N: for<'de> Deserialize<'de> + Normalizer,
1473    PT: for<'de> Deserialize<'de> + PreTokenizer,
1474    PP: for<'de> Deserialize<'de> + PostProcessor,
1475    D: for<'de> Deserialize<'de> + Decoder,
1476{
1477    type Err = Error;
1478
1479    fn from_str(s: &str) -> Result<Self> {
1480        Ok(serde_json::from_str(s)?)
1481    }
1482}
1483
1484impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
1485where
1486    M: DeserializeOwned + Model,
1487    N: DeserializeOwned + Normalizer,
1488    PT: DeserializeOwned + PreTokenizer,
1489    PP: DeserializeOwned + PostProcessor,
1490    D: DeserializeOwned + Decoder,
1491{
1492    /// Instantiate a new Tokenizer from the given file
1493    pub fn from_file<P: AsRef<Path>>(file: P) -> Result<Self> {
1494        let content = read_to_string(file)?;
1495        let tokenizer = serde_json::from_str(&content)?;
1496        Ok(tokenizer)
1497    }
1498}
1499
1500impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
1501where
1502    M: DeserializeOwned + Model,
1503    N: DeserializeOwned + Normalizer,
1504    PT: DeserializeOwned + PreTokenizer,
1505    PP: DeserializeOwned + PostProcessor,
1506    D: DeserializeOwned + Decoder,
1507{
1508    /// Instantiate a new Tokenizer from bytes
1509    pub fn from_bytes<P: AsRef<[u8]>>(bytes: P) -> Result<Self> {
1510        let tokenizer = serde_json::from_slice(bytes.as_ref())?;
1511        Ok(tokenizer)
1512    }
1513}
1514
1515impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
1516where
1517    M: DeserializeOwned + Model,
1518    N: DeserializeOwned + Normalizer,
1519    PT: DeserializeOwned + PreTokenizer,
1520    PP: DeserializeOwned + PostProcessor,
1521    D: DeserializeOwned + Decoder,
1522{
1523    #[deprecated(
1524        since = "0.14.0",
1525        note = "Users should download the file separately using https://github.com/huggingface/hf-hub instead, which splits concerns of accessing the web, and should use the new cache layout"
1526    )]
1527    #[cfg(feature = "http")]
1528    /// Instantiate a new Tokenizer from a file hosted on the Hugging Face Hub.
1529    /// It expects the `identifier` of a model that includes a `tokenizer.json` file.
1530    pub fn from_pretrained<S: AsRef<str>>(
1531        identifier: S,
1532        params: Option<crate::utils::from_pretrained::FromPretrainedParameters>,
1533    ) -> Result<Self> {
1534        let tokenizer_file = crate::utils::from_pretrained::from_pretrained(identifier, params)?;
1535        TokenizerImpl::from_file(tokenizer_file)
1536    }
1537}
1538
1539impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
1540where
1541    M: Serialize,
1542    N: Serialize,
1543    PT: Serialize,
1544    PP: Serialize,
1545    D: Serialize,
1546{
1547    /// Serialize the current tokenizer as a String
1548    pub fn to_string(&self, pretty: bool) -> Result<String> {
1549        Ok(if pretty {
1550            serde_json::to_string_pretty(self)?
1551        } else {
1552            serde_json::to_string(self)?
1553        })
1554    }
1555
1556    /// Save the current tokenizer at the given path
1557    pub fn save<P: AsRef<Path>>(&self, path: P, pretty: bool) -> Result<()> {
1558        let serialized = self.to_string(pretty)?;
1559
1560        let mut file = File::create(path)?;
1561        file.write_all(serialized.as_bytes())?;
1562
1563        Ok(())
1564    }
1565}
1566
#[cfg(test)]
mod test {
    /// Round-trips text through encode/decode after adding tokens to a pretrained
    /// byte-level BPE tokenizer, once with a non-normalized added token and once with a
    /// normalized one.
    ///
    /// Only built with the `http` feature, since the tokenizer is fetched with
    /// `from_pretrained`.
    #[cfg(feature = "http")]
    #[test]
    fn test_decoding_with_added_bpe() {
        use crate::{
            normalizers,
            pre_tokenizers::split::{Split, SplitPattern},
            AddedToken, NormalizerWrapper, PreTokenizerWrapper, SplitDelimiterBehavior, Tokenizer,
        };

        // Replace the normalizer and pre-tokenizer of the pretrained tokenizer with a
        // byte-level normalizer and a regex-based split.
        let mut tokenizer = Tokenizer::from_pretrained("meta-llama/Meta-Llama-3-8B", None).unwrap();
        tokenizer.normalizer = Some(NormalizerWrapper::from(normalizers::ByteLevel::new()));
        tokenizer.pre_tokenizer = Some(PreTokenizerWrapper::Split(
            Split::new(
                SplitPattern::Regex(r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+".into()),
                SplitDelimiterBehavior::Isolated,
                false,
            )
            .unwrap(),
        ));

        // First case: an added token that is NOT normalized.
        tokenizer.add_tokens(&[AddedToken::from("嗎", false).normalized(false)]);
        let encoded = tokenizer
            .encode("Hey! how is this token: 嗎", false)
            .unwrap();
        // The added token is expected as the last id, 128256.
        assert_eq!(
            encoded.get_ids(),
            [19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128256]
        );
        assert_eq!(
            encoded.get_tokens(),
            ["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "嗎"]
        );

        // Decoding must give back the original text.
        let decoded = tokenizer.decode(encoded.get_ids(), false);
        assert_eq!(decoded.unwrap(), "Hey! how is this token: 嗎");

        // Second case: an added token that IS normalized.
        tokenizer.add_tokens(&[AddedToken::from("д", false).normalized(true)]);
        let encoded = tokenizer
            .encode("Hey! how is this token: д", false)
            .unwrap();
        // The second added token is expected as the last id, 128257.
        assert_eq!(
            encoded.get_ids(),
            [19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128257]
        );
        assert_eq!(
            encoded.get_tokens(),
            ["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "д"]
        );
        let decoded = tokenizer.decode(encoded.get_ids(), false);
        assert_eq!(decoded.unwrap(), "Hey! how is this token: д")
    }
}