use std::path::Path;

use crate::vocab::EXTENDED_FAIRSEQ_LANGUAGE_CODES;
use crate::{
    error::TokenizerError,
    vocab::{NLLBVocab, SentencePieceBpeModel, Vocab},
    Mask, Offset, OffsetSize, Token, TokenIdsWithOffsets, TokenIdsWithSpecialTokens, TokenRef,
};

use super::{
    tokenization_utils::{clean_text, decompose_nfkc, is_whitespace, split_on_language_code},
    MultiThreadedTokenizer, Tokenizer,
};
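
/// # NLLB tokenizer
/// A tokenizer for NLLB (No Language Left Behind) models, based on a SentencePiece
/// BPE model. Tokenization performs:
/// - splitting of an optional leading language code (e.g. `eng_Latn`)
/// - text cleaning and NFKC decomposition
/// - SentencePiece BPE tokenization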
pub struct NLLBTokenizer {
    model: SentencePieceBpeModel,
    vocab: NLLBVocab,
    src_lang: String,
}

impl NLLBTokenizer {
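    /// Creates an `NLLBTokenizer` from a vocabulary file, a SentencePiece model file
    /// and a special-token mapping file. The source language defaults to `eng_Latn`.
    ///
    /// # Parameters
    /// - vocab_path (`V`): path to the vocabulary file
    /// - model_path (`M`): path to the SentencePiece model file
    /// - special_tokens (`S`): path to the special token mapping file
    ///
    /// # Example
    ///
    /// A minimal sketch; the `rust_tokenizers::tokenizer` import path and the file
    /// paths below are illustrative assumptions.
    ///
    /// ```no_run
    /// use rust_tokenizers::tokenizer::NLLBTokenizer;
    ///
    /// let tokenizer = NLLBTokenizer::from_files_with_special_token_map(
    ///     "path/to/vocab.json",
    ///     "path/to/sentencepiece.bpe.model",
    ///     "path/to/special_tokens_map.json",
    /// )
    /// .unwrap();
    /// ```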
    pub fn from_files_with_special_token_map<V: AsRef<Path>, M: AsRef<Path>, S: AsRef<Path>>(
        vocab_path: V,
        model_path: M,
        special_tokens: S,
    ) -> Result<Self, TokenizerError> {
        let model = SentencePieceBpeModel::from_file(model_path)?;
        let vocab = NLLBVocab::from_file_with_special_token_mapping(vocab_path, special_tokens)?;
        let src_lang = String::from("eng_Latn");
        Ok(Self {
            model,
            vocab,
            src_lang,
        })
    }
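
    /// Creates an `NLLBTokenizer` from a vocabulary file and a SentencePiece model
    /// file, using the default special token mapping. The source language defaults
    /// to `eng_Latn`.
    ///
    /// # Example
    ///
    /// A minimal sketch; the `rust_tokenizers::tokenizer` import path and the file
    /// paths below are illustrative assumptions.
    ///
    /// ```no_run
    /// use rust_tokenizers::tokenizer::NLLBTokenizer;
    ///
    /// let tokenizer =
    ///     NLLBTokenizer::from_files("path/to/vocab.json", "path/to/sentencepiece.bpe.model")
    ///         .unwrap();
    /// ```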
    pub fn from_files<V: AsRef<Path>, M: AsRef<Path>>(
        vocab_path: V,
        model_path: M,
    ) -> Result<Self, TokenizerError> {
        let model = SentencePieceBpeModel::from_file(model_path)?;
        let vocab = NLLBVocab::from_file(vocab_path)?;
        let src_lang = String::from("eng_Latn");
        Ok(Self {
            model,
            vocab,
            src_lang,
        })
    }
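
    /// Sets the source language used when building inputs with special tokens.
    /// Returns an error if the provided code is not contained in
    /// `EXTENDED_FAIRSEQ_LANGUAGE_CODES`.
    ///
    /// # Example
    ///
    /// A minimal sketch; the `rust_tokenizers::tokenizer` import path and the file
    /// paths below are illustrative assumptions.
    ///
    /// ```no_run
    /// use rust_tokenizers::tokenizer::NLLBTokenizer;
    ///
    /// let mut tokenizer =
    ///     NLLBTokenizer::from_files("path/to/vocab.json", "path/to/sentencepiece.bpe.model")
    ///         .unwrap();
    /// tokenizer.set_src_lang("fra_Latn").unwrap();
    /// ```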
    pub fn set_src_lang(&mut self, src_lang: &str) -> Result<(), TokenizerError> {
        if !EXTENDED_FAIRSEQ_LANGUAGE_CODES.contains(&src_lang) {
            Err(TokenizerError::TokenNotFound(format!(
                "{src_lang} is not a valid language code."
            )))
        } else {
            self.src_lang = src_lang.to_string();
            Ok(())
        }
    }
}

impl Tokenizer<NLLBVocab> for NLLBTokenizer {
    fn vocab(&self) -> &NLLBVocab {
        &self.vocab
    }
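
    // Tokenization pipeline:
    // 1. split an optional leading language code (8 bytes, e.g. `eng_Latn`) from the text
    // 2. clean the text and apply NFKC decomposition
    // 3. replace whitespace with the SentencePiece whitespace marker `\u{2581}` and
    //    ensure the text starts with it
    // 4. run the SentencePiece BPE model, keeping the language code (if any) as the
    //    first output token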
    fn tokenize_to_tokens(&self, text: TokenRef) -> Vec<Token> {
        let tokens = split_on_language_code(text, 8, &self.vocab.language_codes_bytes);
        let (code_token, mut token) = match tokens.len() {
            0 => {
                return vec![];
            }
            1 => (None, tokens[0].to_owned()),
            _ => (Some(tokens[0].to_owned()), tokens[1].to_owned()),
        };
        clean_text(&mut token, true);
        decompose_nfkc(&mut token);
        token.text = token.text.replace(|c: char| is_whitespace(&c), "\u{2581}");
        if !token.text.starts_with('\u{2581}') {
            token.text.insert(0, '\u{2581}');
            token
                .reference_offsets
                .insert(0, token.reference_offsets[0]);
        }
        let mut output: Vec<Token> = Vec::new();
        if let Some(code) = code_token {
            output.push(code);
        }
        output.extend(self.model.tokenize_to_tokens(token.as_ref()));
        output
    }
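
    // Decoding: SentencePiece marks word boundaries with `\u{2581}`; mapping the
    // marker back to a plain space and concatenating recovers the surface string.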
    fn convert_tokens_to_string(&self, tokens: Vec<String>) -> String {
        tokens
            .into_iter()
            .map(|v| v.replace('\u{2581}', " "))
            .collect::<String>()
    }
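
    // NLLB appends the EOS token followed by the source language code: the
    // single-sequence layout is `tokens </s> src_lang`, and a pair of sequences
    // yields `tokens_1 tokens_2 </s> src_lang`.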
    fn build_input_with_special_tokens(
        &self,
        tokens_ids_with_offsets_1: TokenIdsWithOffsets,
        tokens_ids_with_offsets_2: Option<TokenIdsWithOffsets>,
    ) -> TokenIdsWithSpecialTokens {
        let mut output: Vec<i64> = vec![];
        let mut token_segment_ids: Vec<i8> = vec![];
        let mut special_tokens_mask: Vec<i8> = vec![];
        let mut offsets: Vec<Option<Offset>> = vec![];
        let mut original_offsets: Vec<Vec<OffsetSize>> = vec![];
        let mut mask: Vec<Mask> = vec![];
        special_tokens_mask.extend(vec![0; tokens_ids_with_offsets_1.ids.len()]);
        token_segment_ids.extend(vec![0; tokens_ids_with_offsets_1.ids.len()]);
        output.extend(tokens_ids_with_offsets_1.ids);
        offsets.extend(tokens_ids_with_offsets_1.offsets);
        original_offsets.extend(tokens_ids_with_offsets_1.reference_offsets);
        mask.extend(tokens_ids_with_offsets_1.masks);
        if let Some(tokens_ids_with_offsets_2_value) = tokens_ids_with_offsets_2 {
            let length = tokens_ids_with_offsets_2_value.ids.len();
            special_tokens_mask.extend(vec![0; length]);
            // Two special tokens (`</s>` and the language code) are appended below,
            // so two extra segment ids are needed to keep `segment_ids` aligned
            // with `token_ids`.
            token_segment_ids.extend(vec![1; length + 2]);
            output.extend(tokens_ids_with_offsets_2_value.ids);
            offsets.extend(tokens_ids_with_offsets_2_value.offsets);
            original_offsets.extend(tokens_ids_with_offsets_2_value.reference_offsets);
            mask.extend(tokens_ids_with_offsets_2_value.masks);
        } else {
            token_segment_ids.extend(vec![0; 2]);
        }
        special_tokens_mask.push(1);
        special_tokens_mask.push(1);
        output.push(self.vocab.token_to_id(self.vocab.get_eos_value()));
        output.push(self.vocab.token_to_id(&self.src_lang));
        offsets.push(None);
        offsets.push(None);
        original_offsets.push(vec![]);
        original_offsets.push(vec![]);
        mask.push(Mask::Special);
        mask.push(Mask::Special);
        TokenIdsWithSpecialTokens {
            token_ids: output,
            segment_ids: token_segment_ids,
            special_tokens_mask,
            token_offsets: offsets,
            reference_offsets: original_offsets,
            mask,
        }
    }
}
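
// `MultiThreadedTokenizer` only adds default (parallel) methods on top of
// `Tokenizer`, so the empty impl block relies entirely on those defaults.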
impl MultiThreadedTokenizer<NLLBVocab> for NLLBTokenizer {}