// rust_tokenizers 6.2.3
//
// High-performance tokenizers for Rust
// Documentation
// Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
// Copyright 2019-2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::error::TokenizerError;
use crate::tokenizer::base_tokenizer::{
    Mask, Offset, OffsetSize, Token, TokenIdsWithOffsets, TokenIdsWithSpecialTokens, TokenRef,
};
use crate::tokenizer::tokenization_utils::{clean_text, decompose_nfkc, is_whitespace, lowercase};
use crate::tokenizer::{MultiThreadedTokenizer, Tokenizer};
use crate::vocab::{MBart50Vocab, SentencePieceModel, Vocab};

/// # MBart50 tokenizer
/// MBart50 tokenizer performing:
/// - Splitting of an optional leading language code (special token) from the rest of the text
/// - text cleaning
/// - NFKC decomposition
/// - (optional) lower casing
/// - SentencePiece decomposition
// NOTE(review): `MBart50Tokenizer` does not appear to contain an all-caps acronym, so this
// allow may be unnecessary — confirm before removing.
#[allow(clippy::upper_case_acronyms)]
pub struct MBart50Tokenizer {
    /// SentencePiece model used to segment the normalized text into sub-word pieces
    model: SentencePieceModel,
    /// Vocabulary, including the language codes recognized at the start of the text
    vocab: MBart50Vocab,
    /// Whether the text is lower-cased before SentencePiece decomposition
    lower_case: bool,
}

impl MBart50Tokenizer {
    /// Create a new instance of a `MBart50Tokenizer`
    /// Expects a SentencePiece protobuf file as an input: the same file is used to build both
    /// the SentencePiece model and the vocabulary.
    ///
    /// # Parameters
    /// - path (`&str`): path to the SentencePiece model file
    /// - lower_case (`bool`): flag indicating if the text should be lower-cased as part of the tokenization
    ///
    /// # Errors
    /// Returns a `TokenizerError` if either the model or the vocabulary cannot be loaded from `path`.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_tokenizers::tokenizer::{Tokenizer, MBart50Tokenizer};
    /// let lower_case = false;
    /// let tokenizer = MBart50Tokenizer::from_file("path/to/model/file", lower_case).unwrap();
    /// ```
    pub fn from_file(path: &str, lower_case: bool) -> Result<MBart50Tokenizer, TokenizerError> {
        let model = SentencePieceModel::from_file(path)?;
        let vocab = MBart50Vocab::from_file(path)?;
        Ok(MBart50Tokenizer {
            model,
            vocab,
            lower_case,
        })
    }

    /// Create a new instance of a `MBart50Tokenizer` from an existing vocabulary and model
    ///
    /// # Parameters
    /// - vocab (`MBart50Vocab`): vocabulary
    /// - model (`SentencePieceModel`): SentencePiece model
    /// - lower_case (`bool`): flag indicating if the text should be lower-cased as part of the tokenization
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_tokenizers::tokenizer::{Tokenizer, MBart50Tokenizer};
    /// use rust_tokenizers::vocab::{SentencePieceModel, Vocab, MBart50Vocab};
    /// let lower_case = false;
    /// let vocab = MBart50Vocab::from_file("path/to/vocab/file").unwrap();
    /// let model = SentencePieceModel::from_file("path/to/model/file").unwrap();
    ///
    /// let tokenizer = MBart50Tokenizer::from_existing_vocab_and_model(vocab, model, lower_case);
    /// ```
    pub fn from_existing_vocab_and_model(
        vocab: MBart50Vocab,
        model: SentencePieceModel,
        lower_case: bool,
    ) -> MBart50Tokenizer {
        MBart50Tokenizer {
            model,
            vocab,
            lower_case,
        }
    }

    /// Returns the `(byte, char)` position of the first non-whitespace character of `text`
    /// at or after byte `start_byte`, where `start_byte` corresponds to character `start_char`.
    /// `start_byte` must lie on a character boundary.
    fn skip_whitespace(text: &str, start_byte: usize, start_char: usize) -> (usize, usize) {
        text[start_byte..]
            .chars()
            .take_while(|c| c.is_whitespace())
            .fold((start_byte, start_char), |(byte, chars), c| {
                (byte + c.len_utf8(), chars + 1)
            })
    }

    /// Splits an optional leading language code from the rest of the token.
    ///
    /// Leading whitespace is skipped before looking for the code, and the whitespace separating
    /// a recognized code from the following text is skipped as well.
    ///
    /// # Parameters
    /// - token (`TokenRef`): token to split
    /// - code_length (`usize`): length in bytes of the language codes to look for
    ///
    /// # Returns
    /// - `[remainder]` if the text does not start with a language code from the vocabulary
    /// - `[code, remainder]` otherwise. The remainder is always present (possibly empty).
    fn split_on_language_code<'a>(
        &self,
        token: TokenRef<'a>,
        code_length: usize,
    ) -> Vec<TokenRef<'a>> {
        let mut tokens: Vec<TokenRef<'a>> = Vec::with_capacity(2);
        let (mut start_byte, mut begin_char) = Self::skip_whitespace(token.text, 0, 0);

        // `get` yields `None` instead of panicking when the text is shorter than a language
        // code or when the range does not end on a character boundary.
        if let Some(candidate) = token.text.get(start_byte..start_byte + code_length) {
            let leading_bytes = candidate.as_bytes();
            let is_language_code = self
                .vocab
                .language_codes_bytes
                .iter()
                .any(|language_code| leading_bytes == language_code);
            if is_language_code {
                let code_chars = candidate.chars().count();
                tokens.push(TokenRef {
                    text: candidate,
                    offset: Offset::new(
                        token.offset.begin + begin_char as OffsetSize,
                        token.offset.begin + (begin_char + code_chars) as OffsetSize,
                    ),
                    reference_offsets: &token.reference_offsets
                        [begin_char..begin_char + code_chars],
                    mask: Mask::None,
                });
                // Skip the whitespace separating the language code from the text
                let (next_byte, next_char) = Self::skip_whitespace(
                    token.text,
                    start_byte + code_length,
                    begin_char + code_chars,
                );
                start_byte = next_byte;
                begin_char = next_char;
            }
        }

        tokens.push(TokenRef {
            text: &token.text[start_byte..],
            offset: Offset::new(
                token.offset.begin + begin_char as OffsetSize,
                token.offset.begin + token.text.chars().count() as OffsetSize,
            ),
            reference_offsets: &token.reference_offsets[begin_char..],
            mask: Mask::None,
        });
        tokens
    }
}

impl Tokenizer<MBart50Vocab> for MBart50Tokenizer {
    fn vocab(&self) -> &MBart50Vocab {
        &self.vocab
    }

    /// Tokenizes a piece of text, keeping an optional leading language code as its own token.
    ///
    /// The text following the language code is cleaned, NFKC-decomposed, optionally
    /// lower-cased and has its whitespace replaced by `▁` (U+2581). It is prefixed with `▁`
    /// when needed and finally segmented by the SentencePiece model.
    /// If nothing remains after the language code (or after cleaning), only the language code
    /// token (if any) is returned.
    fn tokenize_to_tokens(&self, text: TokenRef) -> Vec<Token> {
        // NOTE(review): assumes every language code in the vocabulary is 5 bytes long — confirm.
        let tokens = self.split_on_language_code(text, 5);
        let (code_token, mut token) = match tokens.len() {
            0 => return vec![],
            1 => (None, tokens[0].to_owned()),
            _ => (Some(tokens[0].to_owned()), tokens[1].to_owned()),
        };

        clean_text(&mut token, true);
        decompose_nfkc(&mut token);
        if self.lower_case {
            lowercase(&mut token);
        }
        token.text = token.text.replace(|c: char| is_whitespace(&c), "\u{2581}");

        let mut output: Vec<Token> = Vec::new();
        output.extend(code_token);
        // An empty remainder (e.g. an input made of a language code only) has no reference
        // offset to copy for the `▁` prefix and nothing to decode.
        if !token.text.is_empty() {
            if !token.text.starts_with('\u{2581}') {
                // The inserted marker borrows the reference offset of the first character
                let first_offset = token.reference_offsets[0];
                token.text.insert(0, '\u{2581}');
                token.reference_offsets.insert(0, first_offset);
            }
            let nodes = self.model.decode_forward_token_ref(token.as_ref());
            let decoded = self.model.decode_backward(&nodes);
            output.extend(self.model.parse_nodes_to_tokens(decoded));
        }

        self.model.populate_masks(output.as_mut_slice(), '\u{2581}');
        output
    }

    /// Concatenates the tokens and turns every `▁` (U+2581) marker back into a space.
    fn convert_tokens_to_string(&self, tokens: Vec<String>) -> String {
        tokens.concat().replace('\u{2581}', " ")
    }

    /// Appends a separator token (`MBart50Vocab::sep_value()`) after each sequence.
    ///
    /// MBart50 expects the language code to be provided as the first token of the input text
    /// (similar to Marian, where the target language may be passed before the sentence to
    /// translate). The first token of the first sequence is therefore flagged as special and
    /// gets no offsets.
    /// Segment ids are 0 for the first sequence (including its separator) and 1 for the
    /// second one (including its separator). All output vectors have the same length.
    fn build_input_with_special_tokens(
        &self,
        tokens_ids_with_offsets_1: TokenIdsWithOffsets,
        tokens_ids_with_offsets_2: Option<TokenIdsWithOffsets>,
    ) -> TokenIdsWithSpecialTokens {
        let sep_id = self.vocab.token_to_id(MBart50Vocab::sep_value());
        let length_1 = tokens_ids_with_offsets_1.ids.len();

        let mut output: Vec<i64> = tokens_ids_with_offsets_1.ids;
        output.push(sep_id);

        // One segment id per token of the first sequence plus its separator
        let mut token_segment_ids: Vec<i8> = vec![0; length_1 + 1];

        let mut special_tokens_mask: Vec<i8> = vec![0; length_1];
        if let Some(first) = special_tokens_mask.first_mut() {
            *first = 1;
        }
        special_tokens_mask.push(1);

        let mut offsets: Vec<Option<Offset>> = tokens_ids_with_offsets_1.offsets;
        if let Some(first) = offsets.first_mut() {
            *first = None;
        }
        offsets.push(None);

        let mut original_offsets: Vec<Vec<OffsetSize>> =
            tokens_ids_with_offsets_1.reference_offsets;
        if let Some(first) = original_offsets.first_mut() {
            *first = vec![];
        }
        original_offsets.push(vec![]);

        let mut mask: Vec<Mask> = tokens_ids_with_offsets_1.masks;
        if let Some(first) = mask.first_mut() {
            *first = Mask::Special;
        }
        mask.push(Mask::Special);

        if let Some(tokens_ids_with_offsets_2_value) = tokens_ids_with_offsets_2 {
            let length = tokens_ids_with_offsets_2_value.ids.len();
            special_tokens_mask.extend(vec![0; length]);
            special_tokens_mask.push(1);
            token_segment_ids.extend(vec![1; length + 1]);
            output.extend(tokens_ids_with_offsets_2_value.ids);
            output.push(sep_id);
            offsets.extend(tokens_ids_with_offsets_2_value.offsets);
            offsets.push(None);
            original_offsets.extend(tokens_ids_with_offsets_2_value.reference_offsets);
            original_offsets.push(vec![]);
            mask.extend(tokens_ids_with_offsets_2_value.masks);
            mask.push(Mask::Special);
        }

        TokenIdsWithSpecialTokens {
            token_ids: output,
            segment_ids: token_segment_ids,
            special_tokens_mask,
            token_offsets: offsets,
            reference_offsets: original_offsets,
            mask,
        }
    }
}

// Opts `MBart50Tokenizer` into `MultiThreadedTokenizer`; no method is overridden, so the
// trait's provided implementations are used as-is.
impl MultiThreadedTokenizer<MBart50Vocab> for MBart50Tokenizer {}