use std::collections::HashMap;
use super::{Tokenizer, HFTokenizer};
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug, Clone)]
/// Byte-pair-encoding tokenizer backed by a vendored HuggingFace tokenizer.
/// Built once via `Tokenizer::load` from the bundled vocab/merges resources.
pub struct BPETokenizer {
// The wrapped HuggingFace tokenizer that performs the actual BPE encoding.
hf_tokenizer: HFTokenizer,
}
impl Tokenizer for BPETokenizer {
fn load() -> Self {
use crate::tokenization::hf_tokenizers::models::bpe::BPE;
use serde_json::Value;
let json: HashMap<String, Value> = serde_json::from_str(&include_str!("../resources/bpe_vocab.json").replace("/", "").replace("Ġ", "")).expect("Error parsing BPE vocab file!");
let mut token_vec: Vec<String> = vec![String::from(""); 50265]; for token in json.keys() {
token_vec[json[token].as_u64().unwrap() as usize] = token.clone();
}
let mut token2index = HashMap::with_capacity(token_vec.len());
for token in token_vec {
if !token.is_empty() {
token2index.insert(token.to_string(), token2index.len() as u32);
}
}
let bpe_builder = BPE::builder();
let mut merges: Vec<(String, String)> = Vec::new();
let lines: Vec<&str> = include_str!("../resources/bpe_merges.txt").split('\n').collect();
for line in lines {
let line = String::from(line).replace("Ġ", "").replace("\n", "").replace("##", "");
if line.contains(' ') && !line.contains('#') {
let line: Vec<&str> = line.split(' ').collect();
if token2index.contains_key(&line[0].to_string()) && token2index.contains_key(&line[1].to_string()) && token2index.contains_key(&format!("{}{}", line[0].to_string(), line[1].to_string())) {
merges.push((line[0].to_string(), line[1].to_string()));
}
}
}
let bpe_builder = bpe_builder.vocab_and_merges(token2index, merges);
let bpe = bpe_builder
.unk_token("[UNK]".into())
.build().expect("BPE Tokenizer failed to build!");
BPETokenizer {
hf_tokenizer: HFTokenizer::new(bpe)
}
}
fn tokenize(&self, string: &str) -> Vec<String> {
super::hf_tokenizers::utils::parallelism::set_parallelism(true);
let string = string.to_lowercase();
let encoding = self.hf_tokenizer.encode(string, false).expect("BPE tokenization failed!");
encoding.get_tokens().to_vec()
}
fn batch_tokenize(&self, strings: Vec<String>) -> Vec<Vec<String>> {
super::hf_tokenizers::utils::parallelism::set_parallelism(true);
let strings = strings.iter().map(|a| {a.to_lowercase()}).collect();
let encodings = self.hf_tokenizer.encode_batch(strings, false).expect("BPE tokenization failed!");
let mut tokens: Vec<Vec<String>> = Vec::with_capacity(encodings.len());
for encoding in encodings {
tokens.push(encoding.get_tokens().to_vec());
};
tokens
}
fn untokenize(&self, tokens: Vec<String>) -> String {
tokens.join("")
}
fn batch_untokenize(&self, tokens: Vec<Vec<String>>) -> Vec<String> {
tokens.iter().map(|tokens| {
tokens.join("")
}).collect()
}
}