use super::{Tokenizer, HFTokenizer};
use serde::{Deserialize, Serialize};
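
/// A WordPiece tokenizer backed by the HuggingFace `tokenizers` crate, with
/// the vocabulary embedded at compile time from `resources/wordpiece_vocab.txt`.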
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct WordpieceTokenizer {
    hf_tokenizer: HFTokenizer,
}
impl Tokenizer for WordpieceTokenizer {
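    /// Builds the tokenizer: a `WordPiece` model over the embedded vocabulary
    /// with whitespace pre-tokenization.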
    fn load() -> Self {
        use super::hf_tokenizers::models::wordpiece::WordPiece;
        use super::hf_tokenizers::pre_tokenizers::whitespace::Whitespace;
        use std::collections::HashMap;

        super::hf_tokenizers::utils::parallelism::set_parallelism(true);

        // Build the vocabulary from the embedded file: one token per line,
        // with the line number as the token id. `lines()` (rather than
        // splitting on '\n') avoids inserting a spurious empty token from a
        // trailing newline.
        let mut vocab: HashMap<String, u32> = HashMap::new();
        for (i, line) in include_str!("../resources/wordpiece_vocab.txt").lines().enumerate() {
            vocab.insert(line.to_string(), i as u32);
        }

        let wordpiece = WordPiece::builder()
            .vocab(vocab)
            .build()
            .expect("WordPiece tokenizer failed to build!");
        let mut tokenizer = HFTokenizer::new(wordpiece);
        tokenizer.with_pre_tokenizer(Whitespace::default());

        WordpieceTokenizer {
            hf_tokenizer: tokenizer,
        }
    }
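
    /// Lowercases `string` and encodes it, returning the token strings.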
    fn tokenize(&self, string: &str) -> Vec<String> {
        super::hf_tokenizers::utils::parallelism::set_parallelism(true);
        // Lowercase the input to match the vocabulary.
        let string = string.to_lowercase();
        let encoding = self
            .hf_tokenizer
            .encode(string, false)
            .expect("WordPiece tokenization failed!");
        encoding.get_tokens().to_vec()
    }
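
    /// Lowercases and tokenizes a batch of strings, one token vector per input.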
    fn batch_tokenize(&self, strings: Vec<String>) -> Vec<Vec<String>> {
        super::hf_tokenizers::utils::parallelism::set_parallelism(true);
        let strings: Vec<String> = strings.iter().map(|s| s.to_lowercase()).collect();
        let encodings = self
            .hf_tokenizer
            .encode_batch(strings, false)
            .expect("WordPiece tokenization failed!");
        encodings
            .iter()
            .map(|encoding| encoding.get_tokens().to_vec())
            .collect()
    }
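
    /// Reassembles tokens into text: `##` continuation pieces and punctuation
    /// attach without a space, and `[PAD]`/`[EOS]` tokens are dropped.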
    fn untokenize(&self, tokens: Vec<String>) -> String {
        let mut untokenized_string = String::new();
        for (i, token) in tokens.iter().enumerate() {
            if token == "[PAD]" || token == "[EOS]" {
                continue;
            }
            let is_punctuation = matches!(token.as_str(), "." | "?" | "!" | "," | "'" | "\"");
            if token.starts_with("##") || is_punctuation || i == 0 {
                // WordPiece continuation pieces start with "##"; strip the
                // marker and join directly, as with punctuation and the
                // first token.
                untokenized_string.push_str(token.trim_start_matches("##"));
            } else {
                untokenized_string.push(' ');
                untokenized_string.push_str(token);
            }
        }
        untokenized_string
    }
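
    /// Untokenizes each token sequence in the batch.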
    fn batch_untokenize(&self, tokens: Vec<Vec<String>>) -> Vec<String> {
        // Same joining rules as `untokenize`, applied per sequence.
        tokens
            .into_iter()
            .map(|sequence| self.untokenize(sequence))
            .collect()
    }
}
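
// A minimal round-trip sketch, not shipped behavior: it assumes the
// `Tokenizer` trait is in scope via `super::*` and that "hello", "world",
// and "!" all appear in the embedded vocabulary.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn round_trips_a_simple_sentence() {
        let tokenizer = WordpieceTokenizer::load();
        let tokens = tokenizer.tokenize("Hello world!");
        // Tokenization lowercases, so the round trip compares against the
        // lowercased input; "!" reattaches without a space in `untokenize`.
        assert_eq!(tokenizer.untokenize(tokens), "hello world!");
    }
}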