use crate::vocab::Vocab;
use std::collections::HashMap;
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::traits::{TokenizedInput, Tokenizer};
use unicode_normalization::UnicodeNormalization;
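/// A BERT-style WordPiece tokenizer.
///
/// Text is first split on whitespace, then each word is greedily matched
/// against the longest vocabulary entry; word-internal continuation pieces
/// carry a `##` prefix (so "hellos" may become `hello` + `##s`).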
#[derive(Debug, Clone)]
pub struct WordPieceTokenizer {
vocab: Vocab,
unk_token: String,
sep_token: String,
pad_token: String,
cls_token: String,
#[allow(dead_code)]
mask_token: String,
do_lower_case: bool,
max_input_chars_per_word: usize,
}
impl WordPieceTokenizer {
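    /// Builds a tokenizer from an in-memory vocabulary, using the standard
    /// BERT special tokens and a 100-character-per-word limit.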
pub fn new(vocab: HashMap<String, u32>, do_lower_case: bool) -> Self {
Self {
vocab: Vocab::from_map(vocab),
unk_token: "[UNK]".to_string(),
sep_token: "[SEP]".to_string(),
pad_token: "[PAD]".to_string(),
cls_token: "[CLS]".to_string(),
mask_token: "[MASK]".to_string(),
do_lower_case,
max_input_chars_per_word: 100,
}
}
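    /// Loads a tokenizer by model name: a few conventional vocab-file paths
    /// are tried first, and if none exists a small built-in stand-in
    /// vocabulary is used. Lowercasing is inferred from the model name.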
pub fn from_pretrained(model_name: &str) -> Result<Self> {
let potential_paths = vec![
format!("{}/vocab.txt", model_name),
format!("{}-vocab.txt", model_name),
format!("models/{}/vocab.txt", model_name),
format!("./vocab/{}.txt", model_name),
];
for path in potential_paths {
if let Ok(vocab) = Self::load_vocab_from_file(&path) {
return Ok(Self::new(vocab, true));
}
}
let vocab = match model_name {
"bert-base-uncased" | "bert-large-uncased" => Self::create_bert_base_vocab(),
"bert-base-cased" | "bert-large-cased" => Self::create_bert_cased_vocab(),
"distilbert-base-uncased" => Self::create_distilbert_vocab(),
_ => Self::create_basic_vocab(),
};
Ok(Self::new(vocab, model_name.contains("uncased")))
}
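    /// Builds a tokenizer from a plain-text vocab file with one token per
    /// line, where the 0-based line number becomes the token id.
    ///
    /// ```ignore
    /// // Sketch: assumes a BERT-style vocab.txt exists at this path.
    /// let tokenizer = WordPieceTokenizer::from_vocab_file("vocab.txt", true)?;
    /// let encoded = tokenizer.encode("hello world")?;
    /// ```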
pub fn from_vocab_file(vocab_path: &str, do_lower_case: bool) -> Result<Self> {
let vocab = Self::load_vocab_from_file(vocab_path)?;
Ok(Self::new(vocab, do_lower_case))
}
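    /// Reads a vocabulary file, assigning each non-empty line its 0-based
    /// line number as the token id.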
fn load_vocab_from_file(path: &str) -> Result<HashMap<String, u32>> {
use std::fs::File;
use std::io::{BufRead, BufReader};
let file = File::open(path).map_err(|e| {
TrustformersError::io_error(format!("Failed to open vocab file {}: {}", path, e))
})?;
let reader = BufReader::new(file);
let mut vocab = HashMap::new();
for (id, line) in reader.lines().enumerate() {
let token = line
.map_err(|e| TrustformersError::io_error(format!("Failed to read line: {}", e)))?;
let token = token.trim().to_string();
if !token.is_empty() {
vocab.insert(token, id as u32);
}
}
        if vocab.is_empty() {
            return Err(TrustformersError::other(format!(
                "Empty vocabulary file: {}",
                path
            )));
        }
Ok(vocab)
}
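    /// Builds a small stand-in vocabulary: the BERT special tokens followed
    /// by common English words. This is an offline fallback, not the real
    /// ~30k-entry BERT vocabulary.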
fn create_bert_base_vocab() -> HashMap<String, u32> {
let mut vocab = HashMap::new();
let special_tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"];
for (id, token) in special_tokens.iter().enumerate() {
vocab.insert(token.to_string(), id as u32);
}
let common_tokens = vec![
"the", "be", "to", "of", "and", "a", "in", "that", "have", "i", "it", "for", "not",
"on", "with", "he", "as", "you", "do", "at", "this", "but", "his", "by", "from",
"they", "we", "say", "her", "she", "or", "an", "will", "my", "one", "all", "would",
"there", "their", "what", "so", "up", "out", "if", "about", "who", "get", "which",
"go", "me", "when", "make", "can", "like", "time", "no", "just", "him", "know", "take",
"people", "into", "year", "your", "good", "some", "could", "them", "see", "other",
"than", "then", "now", "look", "only", "come", "its", "over", "think", "also", "back",
"after", "use", "two", "how", "our", "work", "first", "well", "way", "even", "new",
"want", "because", "any", "these", "give", "day", "most", "us",
];
for (id, token) in (special_tokens.len() as u32..).zip(common_tokens) {
vocab.insert(token.to_string(), id);
}
vocab
}
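    /// Extends the uncased stand-in vocabulary with a handful of
    /// capitalized forms for the cased model variants.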
fn create_bert_cased_vocab() -> HashMap<String, u32> {
let mut vocab = Self::create_bert_base_vocab();
let base_size = vocab.len() as u32;
let capitalized = [
"The", "Be", "To", "Of", "And", "A", "In", "That", "Have", "I",
];
for (i, token) in capitalized.iter().enumerate() {
vocab.insert(token.to_string(), base_size + i as u32);
}
vocab
}
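    /// Derives a smaller stand-in vocabulary for DistilBERT by keeping the
    /// lowest-id three quarters of the BERT stand-in vocabulary.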
    fn create_distilbert_vocab() -> HashMap<String, u32> {
        let bert_vocab = Self::create_bert_base_vocab();
        let target_size = (bert_vocab.len() * 3) / 4;
        // Truncate by id so the result is deterministic and the special
        // tokens (ids 0..5) always survive; taking entries in HashMap
        // iteration order could drop [UNK], which `encode` relies on.
        let mut entries: Vec<(String, u32)> = bert_vocab.into_iter().collect();
        entries.sort_unstable_by_key(|(_, id)| *id);
        entries.truncate(target_size);
        entries.into_iter().collect()
    }
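    /// Builds a minimal character-level fallback vocabulary: special tokens,
    /// lowercase ASCII letters, and common punctuation.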
fn create_basic_vocab() -> HashMap<String, u32> {
let mut vocab = HashMap::new();
vocab.insert("[UNK]".to_string(), 0);
vocab.insert("[CLS]".to_string(), 1);
vocab.insert("[SEP]".to_string(), 2);
vocab.insert("[PAD]".to_string(), 3);
vocab.insert("[MASK]".to_string(), 4);
let mut id = 5;
for c in 'a'..='z' {
vocab.insert(c.to_string(), id);
id += 1;
}
for punct in [".", ",", "!", "?", ";", ":", "'", "\"", "-"] {
vocab.insert(punct.to_string(), id);
id += 1;
}
vocab
}
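    /// Whitespace tokenization with optional lowercasing and NFC
    /// normalization. Note: unlike BERT's full basic tokenizer, punctuation
    /// is not split off here, so "hello," only matches a vocabulary entry
    /// for the whole string.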
    fn basic_tokenize(&self, text: &str) -> Vec<String> {
        let text = if self.do_lower_case {
            text.to_lowercase()
        } else {
            text.to_string()
        };
        // NFC-normalize so composed and decomposed forms of the same
        // character match the same vocabulary entries.
        let text = text.nfc().collect::<String>();
        text.split_whitespace().map(|s| s.to_string()).collect()
    }
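    /// Greedy longest-match-first WordPiece segmentation of a single word.
    /// Continuation pieces are looked up with a `##` prefix; if any suffix
    /// cannot be matched, the entire word maps to [UNK].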
fn wordpiece_tokenize(&self, word: &str) -> Vec<String> {
if word.chars().count() > self.max_input_chars_per_word {
return vec![self.unk_token.clone()];
}
let mut output = Vec::new();
let mut start = 0;
let chars: Vec<char> = word.chars().collect();
while start < chars.len() {
let mut end = chars.len();
let mut cur_substr = None;
while start < end {
let substr = if start > 0 {
format!("##{}", chars[start..end].iter().collect::<String>())
} else {
chars[start..end].iter().collect::<String>()
};
if self.vocab.contains(&substr) {
cur_substr = Some(substr);
break;
}
end -= 1;
}
            if let Some(substr) = cur_substr {
                output.push(substr);
                start = end;
            } else {
                // No prefix of the remaining suffix is in the vocabulary.
                // Following the reference WordPiece algorithm, the whole
                // word becomes [UNK]; pieces matched so far are discarded.
                return vec![self.unk_token.clone()];
            }
}
output
}
}
impl Tokenizer for WordPieceTokenizer {
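    /// Encodes a single sequence as `[CLS] ... [SEP]`. Tokens missing from
    /// the vocabulary fall back to the [UNK] id, the attention mask is all
    /// ones, and every token gets token_type_id 0.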
fn encode(&self, text: &str) -> Result<TokenizedInput> {
let mut tokens = vec![self.cls_token.clone()];
for word in self.basic_tokenize(text) {
tokens.extend(self.wordpiece_tokenize(&word));
}
tokens.push(self.sep_token.clone());
let mut input_ids = Vec::with_capacity(tokens.len());
for token in &tokens {
let id = self.vocab.get_id(token).unwrap_or_else(|| {
self.vocab.get_id(&self.unk_token).expect("UNK token must exist in vocabulary")
});
input_ids.push(id);
}
let attention_mask = vec![1u8; input_ids.len()];
let input_ids_len = input_ids.len();
Ok(TokenizedInput {
input_ids,
attention_mask,
token_type_ids: Some(vec![0u32; input_ids_len]),
special_tokens_mask: None,
offset_mapping: None,
overflowing_tokens: None,
})
}
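    /// Encodes a sentence pair as `[CLS] a ... [SEP] b ... [SEP]`, with
    /// token_type_id 0 for the first segment (including [CLS] and the first
    /// [SEP]) and 1 for the second segment.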
fn encode_pair(&self, text: &str, text2: &str) -> Result<TokenizedInput> {
let mut tokens = vec![self.cls_token.clone()];
for word in self.basic_tokenize(text) {
tokens.extend(self.wordpiece_tokenize(&word));
}
tokens.push(self.sep_token.clone());
let first_seg_len = tokens.len();
for word in self.basic_tokenize(text2) {
tokens.extend(self.wordpiece_tokenize(&word));
}
tokens.push(self.sep_token.clone());
let mut input_ids = Vec::with_capacity(tokens.len());
for token in &tokens {
let id = self.vocab.get_id(token).unwrap_or_else(|| {
self.vocab.get_id(&self.unk_token).expect("UNK token must exist in vocabulary")
});
input_ids.push(id);
}
let attention_mask = vec![1u8; input_ids.len()];
let mut token_type_ids = vec![0u32; first_seg_len];
token_type_ids.extend(vec![1u32; input_ids.len() - first_seg_len]);
Ok(TokenizedInput {
input_ids,
attention_mask,
token_type_ids: Some(token_type_ids),
special_tokens_mask: None,
offset_mapping: None,
overflowing_tokens: None,
})
}
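    /// Decodes ids back to text: special tokens are dropped, remaining
    /// tokens are joined with spaces, and `##` continuation pieces are
    /// re-attached to the preceding word.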
    fn decode(&self, ids: &[u32]) -> Result<String> {
        // Filter special tokens out before joining; replacing " [CLS] " on
        // the joined string would miss special tokens at the start or end
        // of the sequence, where they are not surrounded by spaces.
        let tokens: Vec<String> = ids
            .iter()
            .filter_map(|&id| self.vocab.get_token(id))
            .filter(|t| *t != self.pad_token && *t != self.cls_token && *t != self.sep_token)
            .collect();
        // Deleting " ##" re-attaches continuation pieces to their word.
        Ok(tokens.join(" ").replace(" ##", ""))
    }
fn vocab_size(&self) -> usize {
self.vocab.size()
}
fn get_vocab(&self) -> HashMap<String, u32> {
self.vocab.get_token_to_id_map().clone()
}
fn token_to_id(&self, token: &str) -> Option<u32> {
self.vocab.get_id(token)
}
fn id_to_token(&self, id: u32) -> Option<String> {
self.vocab.get_token(id)
}
}
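// A minimal sketch of expected behavior, written as a unit test. The tiny
// vocabulary below is hypothetical; ids follow insertion order and are only
// meaningful within this test.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encode_wraps_with_special_tokens_and_decode_round_trips() {
        let mut vocab = HashMap::new();
        for (id, token) in ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "hello", "##s"]
            .iter()
            .enumerate()
        {
            vocab.insert(token.to_string(), id as u32);
        }
        let tokenizer = WordPieceTokenizer::new(vocab, true);

        let encoded = tokenizer.encode("Hellos").unwrap();
        // [CLS] hello ##s [SEP] -> ids 2, 5, 6, 3
        assert_eq!(encoded.input_ids, vec![2, 5, 6, 3]);
        assert_eq!(encoded.attention_mask, vec![1u8; 4]);

        // Decoding drops the special tokens and merges the ## piece.
        assert_eq!(tokenizer.decode(&encoded.input_ids).unwrap(), "hellos");
    }
}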