use crate::HashMap;
use anyhow::{anyhow, bail, Result};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use toktrie::TokTrie;
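
/// One entry of the `added_tokens` array in a `tokenizer.json` file
/// (HuggingFace `tokenizers` format).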
#[derive(Debug, Deserialize, Serialize)]
struct AddedToken {
    id: usize,
    content: String,
    special: bool,
}
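
/// Store `bytes` at token index `idx`, growing the vector with empty
/// entries as needed.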
fn add_bytes(tokens: &mut Vec<Vec<u8>>, idx: usize, bytes: Vec<u8>) {
    if tokens.len() <= idx {
        tokens.resize(idx + 1, vec![]);
    }
    tokens[idx] = bytes;
}
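
/// True for the byte values that the GPT-2 byte-level encoding maps to
/// themselves (printable ASCII and most of Latin-1); every other byte is
/// remapped to a code point at or above U+0100.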
fn is_self_mapped(c: char) -> bool {
    matches!(
        c,
        '!'..='~' | '\u{00A1}'..='\u{00AC}' | '\u{00AE}'..='\u{00FF}'
    )
}
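
/// Build the inverse of the GPT-2 byte-level encoding: a map from the
/// character that appears in the vocabulary back to the byte it stands for.
/// Self-mapped bytes keep their own code point; the rest are assigned
/// consecutive code points starting at U+0100.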
fn build_char_map() -> HashMap<char, u8> {
    let mut res = HashMap::default();
    let mut k = 0x100u32;
    for byte in 0..=255u8 {
        let c = byte as char;
        if is_self_mapped(c) {
            res.insert(c, byte);
        } else {
            res.insert(char::from_u32(k).unwrap(), byte);
            k += 1;
        }
    }
    res
}
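
/// Reconstruct the raw byte representation of every token, indexed by token
/// id, from a parsed `tokenizer.json` (HuggingFace `tokenizers` format).
///
/// Two decoder layouts are supported: GPT-2-style `ByteLevel` vocabularies
/// and SentencePiece-style vocabularies that use `ByteFallback` tokens.
///
/// Sketch of intended usage (the file path is a placeholder):
///
/// ```ignore
/// let json: serde_json::Value =
///     serde_json::from_str(&std::fs::read_to_string("tokenizer.json")?)?;
/// let token_bytes = token_bytes_from_tokenizer_json(&json)?;
/// ```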
pub fn token_bytes_from_tokenizer_json(tokenizer_json: &Value) -> Result<Vec<Vec<u8>>> {
    let mut is_byte_level = false;
    let mut is_byte_fallback = false;
    let mut space_ch = ' ';
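
    // Figure out how the tokenizer maps tokens back to bytes: either a
    // GPT-2-style "ByteLevel" decoder, or a "Sequence" that contains a
    // "ByteFallback" decoder (SentencePiece style), optionally together with
    // a "Replace" decoder that maps a space-marker character (typically
    // '\u{2581}') back to a plain space.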
    let decoder = &tokenizer_json["decoder"];
    if decoder["type"].as_str() == Some("ByteLevel") {
        is_byte_level = true;
    } else if decoder["type"].as_str() == Some("Sequence") {
        if let Some(decoders) = decoder["decoders"].as_array() {
            for decoder in decoders {
                if decoder["type"].as_str() == Some("ByteFallback") {
                    is_byte_fallback = true;
                } else if decoder["type"].as_str() == Some("Replace")
                    && decoder["content"].as_str() == Some(" ")
                {
                    if let Some(s) = decoder["pattern"]["String"].as_str() {
                        let s: Vec<char> = s.chars().collect();
                        if s.len() == 1 {
                            space_ch = s[0];
                        }
                    }
                }
            }
        }
    }

    if !is_byte_fallback && !is_byte_level {
        bail!("can't determine decoder type: {:?}", decoder["type"]);
    }
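
    // Added tokens come first; special tokens are prefixed with a marker
    // byte so that TokTrie can distinguish them from ordinary text.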
    let mut token_bytes = vec![];
    let added_tokens: Vec<AddedToken> =
        serde_json::from_value(tokenizer_json["added_tokens"].clone())
            .map_err(|e| anyhow!("error parsing added_tokens: {}", e))?;
    for info in added_tokens.iter() {
        let mut bytes = info.content.as_bytes().to_vec();
        if info.special {
            bytes.insert(0, TokTrie::SPECIAL_TOKEN_MARKER);
        }
        add_bytes(&mut token_bytes, info.id, bytes);
    }
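
    // Fill in the regular vocabulary. Token names are stored either
    // byte-level encoded (GPT-2 style) or as plain text with "<0xNN>"
    // byte-fallback entries (SentencePiece style).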
    let char_map = build_char_map();
    let vocab: HashMap<String, usize> =
        serde_json::from_value(tokenizer_json["model"]["vocab"].clone())
            .map_err(|e| anyhow!("error parsing vocab: {}", e))?;
    for (tok_name, &tok_id) in vocab.iter() {
        // Skip ids already filled in from added_tokens.
        if tok_id < token_bytes.len() && !token_bytes[tok_id].is_empty() {
            continue;
        }
        let bytes = if is_byte_fallback {
            if tok_name.len() == 6 && tok_name.starts_with("<0x") && tok_name.ends_with('>') {
                // Byte-fallback tokens look like "<0x0A>" and represent a single byte.
                let hex_str = &tok_name[3..5];
                let byte = u8::from_str_radix(hex_str, 16)
                    .map_err(|e| anyhow!("invalid byte-fallback token {:?}: {}", tok_name, e))?;
                vec![byte]
            } else {
                if tok_name.starts_with("<0x") {
                    bail!("malformed byte-fallback token: {:?}", tok_name);
                }
                // Undo the space-marker substitution used by the tokenizer.
                let tok_name = tok_name.replace(space_ch, " ");
                tok_name.as_bytes().to_vec()
            }
        } else if is_byte_level {
            // Map each character of the token name back to the byte it encodes.
            let bytes: Result<Vec<u8>> = tok_name
                .chars()
                .map(|c| {
                    char_map
                        .get(&c)
                        .copied()
                        .ok_or_else(|| anyhow!("missing char: {}", c))
                })
                .collect();
            bytes.map_err(|e| anyhow!("error: {} decoding {:?}", e, tok_name))?
        } else {
            unreachable!("decoder type was validated above");
        };
        add_bytes(&mut token_bytes, tok_id, bytes);
    }

    Ok(token_bytes)
}