use std::{collections::HashMap, path::Path};

use ahash::AHashMap;
use anyhow::{anyhow, Result};
use base64::Engine;
use serde::Deserialize;
use serde_json::Value;
use tokenizers::{
    decoders::byte_level::ByteLevel as ByteLevelDecoder,
    models::bpe::BpeBuilder,
    pre_tokenizers::{
        byte_level::ByteLevel,
        sequence::Sequence,
        split::{Split, SplitPattern},
        PreTokenizerWrapper,
    },
    tokenizer::{self, normalizer::SplitDelimiterBehavior, Tokenizer},
};

use super::tiktoken::token_bytes_to_string;

/// Entry of the `added_tokens` array in a `tokenizer.json` file.
#[derive(Deserialize)]
struct AddedToken {
    id: usize,
    content: String,
}

/// Regular vocabulary entry in a `tekken.json` file: a rank plus the
/// base64-encoded token bytes.
#[derive(Deserialize)]
struct TekkenVocabEntry {
    rank: usize,
    token_bytes: String,
}

/// Special token entry in a `tekken.json` file, stored as a plain string.
#[derive(Deserialize)]
struct TekkenSpecialToken {
    rank: usize,
    token_str: String,
}

/// The `config` section of a `tekken.json` file.
#[derive(Deserialize)]
struct TekkenConfig {
    pattern: String,
    default_vocab_size: usize,
    default_num_special_tokens: usize,
}

/// Top-level layout of a `tekken.json` tokenizer file.
#[derive(Deserialize)]
struct TekkenTokenizer {
    config: TekkenConfig,
    vocab: Vec<TekkenVocabEntry>,
    #[serde(default)]
    special_tokens: Vec<TekkenSpecialToken>,
}
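
/// Loads a `tekken.json` tokenizer file (Mistral's tekken format) and converts
/// it into an equivalent byte-level BPE [`Tokenizer`].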
fn load_tekken_tokenizer<P: AsRef<Path>>(path: P) -> Result<Tokenizer> {
    let raw = std::fs::read(path)?;
    let tekken: TekkenTokenizer = serde_json::from_slice(&raw)?;
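
    // Special tokens occupy IDs 0..num_special; the BPE vocabulary is shifted
    // up by num_special to make room for them.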
    let num_special = tekken.config.default_num_special_tokens;
    let inner_vocab_size = tekken.config.default_vocab_size - num_special;
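
    // Decode the base64-encoded token bytes, keeping only ranks that fall
    // inside the inner (non-special) vocabulary.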
    let mut bpe_ranks: HashMap<Vec<u8>, u32> = HashMap::new();
    for entry in &tekken.vocab {
        if entry.rank >= inner_vocab_size {
            continue;
        }
        let token_bytes = base64::engine::general_purpose::STANDARD.decode(&entry.token_bytes)?;
        #[allow(clippy::cast_possible_truncation)]
        bpe_ranks.insert(token_bytes, entry.rank as u32);
    }
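
    // Convert each byte sequence to its printable byte-level representation
    // and offset its ID past the special tokens.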
    let mut vocab = AHashMap::new();
    for (token_bytes, rank) in &bpe_ranks {
        let token_str = token_bytes_to_string(token_bytes);
        #[allow(clippy::cast_possible_truncation)]
        vocab.insert(token_str, *rank + num_special as u32);
    }
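
    // Reconstruct the BPE merges: every split of a token into two pieces that
    // are both in the vocabulary is a candidate merge. Candidates are ordered
    // by the ranks of their pieces within a token, and by the merged token's
    // rank overall.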
    let mut merges = Vec::new();
    for (token, rank) in &bpe_ranks {
        if token.len() == 1 {
            continue;
        }
        let mut local = Vec::new();
        for index in 1..token.len() {
            let piece_l = &token[..index];
            let piece_r = &token[index..];
            if let (Some(&rank_l), Some(&rank_r)) = (bpe_ranks.get(piece_l), bpe_ranks.get(piece_r))
            {
                local.push((piece_l.to_vec(), piece_r.to_vec(), *rank, rank_l, rank_r));
            }
        }
        local.sort_by_key(|(_, _, _, rank_l, rank_r)| (*rank_l, *rank_r));
        for (piece_l, piece_r, rank, _, _) in local {
            merges.push((piece_l, piece_r, rank));
        }
    }
    merges.sort_by_key(|(_, _, rank)| *rank);
    let merges: Vec<(String, String)> = merges
        .into_iter()
        .map(|(l, r, _)| (token_bytes_to_string(&l), token_bytes_to_string(&r)))
        .collect();
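
    // Fill in the special token names for IDs 0..num_special, falling back to
    // placeholder names of the form <SPECIAL_{id}> for ranks not declared in
    // the file.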
    let mut special_token_map: HashMap<usize, String> = HashMap::new();
    for st in &tekken.special_tokens {
        special_token_map.insert(st.rank, st.token_str.clone());
    }
    #[allow(clippy::cast_possible_truncation)]
    let special_token_names: Vec<String> = (0..num_special)
        .map(|id| {
            special_token_map
                .get(&id)
                .cloned()
                .unwrap_or_else(|| format!("<SPECIAL_{id}>"))
        })
        .collect();
    for (id, name) in special_token_names.iter().enumerate() {
        #[allow(clippy::cast_possible_truncation)]
        vocab.insert(name.clone(), id as u32);
    }
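
    // Build the BPE model from the reconstructed vocabulary and merge list.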
    let bpe_model = BpeBuilder::new()
        .vocab_and_merges(vocab, merges)
        .build()
        .map_err(|e| anyhow!("Failed to build BPE model: {}", e))?;
    let mut tokenizer = Tokenizer::new(bpe_model);
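
    // Pre-tokenize with the tekken regex pattern, then apply byte-level
    // encoding so the input matches the byte-level vocabulary built above; the
    // byte-level decoder reverses that mapping when decoding.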
    let split = Split::new(
        SplitPattern::Regex(tekken.config.pattern),
        SplitDelimiterBehavior::Isolated,
        false,
    )
    .map_err(|e| anyhow!("Failed to create split pre-tokenizer: {}", e))?;
    let byte_level = ByteLevel::new(false, true, true);
    let pre_tokenizer = Sequence::new(vec![
        PreTokenizerWrapper::Split(split),
        PreTokenizerWrapper::ByteLevel(byte_level),
    ]);
    tokenizer.with_pre_tokenizer(Some(pre_tokenizer));
    let decoder = ByteLevelDecoder::new(true, false, false);
    tokenizer.with_decoder(Some(decoder));
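
    // Register the special tokens so they are matched verbatim during encoding.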
    let special_tokens: Vec<tokenizer::AddedToken> = special_token_names
        .into_iter()
        .map(|name| tokenizer::AddedToken::from(name, true))
        .collect();
    tokenizer.add_special_tokens(&special_tokens);

    Ok(tokenizer)
}
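
/// Loads a tokenizer from `p`, which is either a Mistral `tekken.json` file or
/// a `tokenizer.json` file. Any `processor_added_tokens` are registered as
/// additional special tokens.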
pub(crate) fn get_tokenizer<P: AsRef<Path> + Clone>(
    p: P,
    processor_added_tokens: Option<&[&str]>,
) -> Result<Tokenizer> {
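    // Files named `tekken.json` use Mistral's tekken format and are converted
    // into a BPE tokenizer by `load_tekken_tokenizer`.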
    if p.as_ref()
        .file_name()
        .and_then(|name| name.to_str())
        .map(|name| name == "tekken.json")
        .unwrap_or(false)
    {
        let mut tokenizer = load_tekken_tokenizer(p)?;
        if let Some(added_tokens) = processor_added_tokens {
            tokenizer.add_special_tokens(
                &added_tokens
                    .iter()
                    .map(|x| tokenizer::AddedToken::from(x.to_string(), true))
                    .collect::<Vec<_>>(),
            );
        }
        return Ok(tokenizer);
    }

    let mut tokenizer = {
        let raw = std::fs::read(p.clone()).map_err(anyhow::Error::msg)?;
        let mut tokenizer: Value = serde_json::from_slice(&raw).unwrap();
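        // Make sure every entry in `added_tokens` also appears in the model
        // vocabulary under its declared ID before handing the JSON to the
        // `tokenizers` crate.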
        let added_tokens: Vec<AddedToken> =
            serde_json::from_value(tokenizer["added_tokens"].clone()).unwrap();
        let vocab: HashMap<String, usize> =
            serde_json::from_value(tokenizer["model"]["vocab"].clone()).unwrap();
        for token in added_tokens {
            if !vocab.contains_key(&token.content) {
                let previous = tokenizer["model"]["vocab"]
                    .as_object_mut()
                    .unwrap()
                    .insert(token.content, token.id.into());
                // The token was just checked to be absent, so nothing may be overwritten.
                assert!(previous.is_none());
            }
        }
        let raw_fixed = serde_json::to_vec_pretty(&tokenizer).unwrap();
        Tokenizer::from_bytes(&raw_fixed).map_err(anyhow::Error::msg)?
    };
    if let Some(added_tokens) = processor_added_tokens {
        tokenizer.add_special_tokens(
            &added_tokens
                .iter()
                .map(|x| tokenizer::AddedToken::from(x.to_string(), true))
                .collect::<Vec<_>>(),
        );
    }
    Ok(tokenizer)
}