//! Hugging Face `tokenizer.json` loading and the GPT-2 style byte-to-unicode
//! table used by byte-level BPE tokenizers.

use std::collections::HashMap;
use serde_json::Value;
use crate::{
bpe::BpeMerges,
error::{TokenizerError, TokenizerResult},
tokenizer::{OxiTokenizer, TokenizerConfig},
vocab::Vocabulary,
};
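
/// Builds the GPT-2 style byte-to-unicode table: each of the 188 "printable"
/// bytes maps to itself, and the remaining 68 bytes are assigned code points
/// starting at U+0100 so every byte has a distinct, printable representation.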
fn build_bytes_to_unicode() -> [char; 256] {
    // The 188 bytes the GPT-2 byte-level scheme keeps as-is: printable ASCII
    // plus two Latin-1 ranges (0xAD, the soft hyphen, is deliberately skipped).
    let mut printable: Vec<u8> = Vec::with_capacity(188);
    printable.extend(0x21u8..=0x7Eu8);
    printable.extend(0xA1u8..=0xACu8);
    printable.extend(0xAEu8..=0xFFu8);
    let mut table: [char; 256] = ['\0'; 256];
    let mut n: u32 = 0;
    for b in 0u8..=255u8 {
        if printable.contains(&b) {
            // Printable bytes map to the identically numbered code point.
            table[b as usize] = char::from_u32(u32::from(b)).unwrap_or('\u{FFFD}');
        } else {
            // All other bytes are shifted into U+0100.. in order of appearance,
            // mirroring GPT-2's `bytes_to_unicode`.
            let cp = 0x100u32 + n;
            table[b as usize] = char::from_u32(cp).unwrap_or('\u{FFFD}');
            n += 1;
        }
    }
table
}
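
/// Returns the byte-to-unicode table as a 256-entry array indexed by byte value.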
pub fn bytes_to_unicode_map() -> [char; 256] {
    // The table is cheap to build, but caching it avoids redoing the work on
    // every per-byte lookup.
    static TABLE: std::sync::OnceLock<[char; 256]> = std::sync::OnceLock::new();
    *TABLE.get_or_init(build_bytes_to_unicode)
}
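
/// Maps a single byte to its byte-level BPE character, e.g. `0x20` (space)
/// becomes `'\u{0120}'` ("Ġ").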
pub fn byte_to_unicode(b: u8) -> char {
bytes_to_unicode_map()[b as usize]
}
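
/// Inverse of [`byte_to_unicode`]; returns `None` for characters outside the table.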
pub fn unicode_to_byte(ch: char) -> Option<u8> {
    bytes_to_unicode_map()
        .iter()
        .position(|&c| c == ch)
        .map(|idx| idx as u8)
}
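
/// Builds the full inverse mapping from unicode character back to the original byte.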
pub fn bytes_to_unicode_inverse() -> HashMap<char, u8> {
let table = build_bytes_to_unicode();
let mut out = HashMap::with_capacity(256);
for (idx, &ch) in table.iter().enumerate() {
out.insert(ch, idx as u8);
}
out
}
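
/// Applies the byte-to-unicode mapping to every byte of `s`, producing the
/// form in which byte-level BPE vocabularies store their tokens.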
pub fn bytes_to_unicode_string(s: &str) -> String {
    let table = bytes_to_unicode_map();
    // Remapped bytes take two UTF-8 bytes each, so this capacity is only a
    // lower bound; it still avoids most reallocations.
    let mut out = String::with_capacity(s.len());
    for b in s.bytes() {
        out.push(table[b as usize]);
    }
    out
}
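
/// Relevant fields extracted from a Hugging Face `tokenizer.json` file.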
#[derive(Debug, Clone)]
pub struct HfTokenizerJson {
    /// Token string to id, including entries merged in from `added_tokens`.
    pub vocab: HashMap<String, u32>,
    /// BPE merge rules in priority order.
    pub merges: Vec<(String, String)>,
    /// Added tokens flagged `"special": true`.
    pub special_tokens: HashMap<String, u32>,
    pub bos_token: Option<String>,
    pub eos_token: Option<String>,
    pub unk_token: Option<String>,
    pub pad_token: Option<String>,
    /// Whether a `ByteLevel` pre-tokenizer or decoder was detected.
    pub byte_level: bool,
}
impl HfTokenizerJson {
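    /// Parses a `tokenizer.json` document, extracting the vocabulary, merges,
    /// special tokens, and byte-level flag. Returns `TokenizerError::HfFormat`
    /// when required fields are missing or malformed.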
pub fn parse(json: &str) -> TokenizerResult<Self> {
let root: Value = serde_json::from_str(json)
.map_err(|e| TokenizerError::HfFormat(format!("invalid JSON: {e}")))?;
let model = root
.get("model")
.ok_or_else(|| TokenizerError::HfFormat("missing `model` field".to_owned()))?;
let vocab_val = model
.get("vocab")
.ok_or_else(|| TokenizerError::HfFormat("missing `model.vocab` field".to_owned()))?;
let mut vocab: HashMap<String, u32> = HashMap::new();
match vocab_val {
Value::Object(map) => {
for (token, id_val) in map {
                    let id = id_val
                        .as_u64()
                        .and_then(|n| u32::try_from(n).ok())
                        .ok_or_else(|| {
                            TokenizerError::HfFormat(format!(
                                "vocab entry {token:?} has a non-integer or out-of-range id"
                            ))
                        })?;
vocab.insert(token.clone(), id);
}
}
_ => {
return Err(TokenizerError::HfFormat(
"`model.vocab` must be an object".to_owned(),
));
}
}
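        // `model.merges` holds the BPE merge rules, either as "a b" strings or
        // as ["a", "b"] pairs; `parse_merge_entry` accepts both forms.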
let merges_val = model
.get("merges")
.ok_or_else(|| TokenizerError::HfFormat("missing `model.merges` field".to_owned()))?;
let mut merges: Vec<(String, String)> = Vec::new();
match merges_val {
Value::Array(list) => {
for (idx, entry) in list.iter().enumerate() {
let pair = parse_merge_entry(entry).ok_or_else(|| {
TokenizerError::HfFormat(format!("malformed merge entry #{idx}: {entry:?}"))
})?;
merges.push(pair);
}
}
_ => {
return Err(TokenizerError::HfFormat(
"`model.merges` must be an array".to_owned(),
));
}
}
let mut special_tokens: HashMap<String, u32> = HashMap::new();
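        // `added_tokens` can introduce tokens absent from `model.vocab`; entries
        // marked `"special": true` are also recorded as special tokens.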
if let Some(added) = root.get("added_tokens").and_then(Value::as_array) {
for token_obj in added {
let content = token_obj
.get("content")
.and_then(Value::as_str)
.map(|s| s.to_owned());
                let id = token_obj
                    .get("id")
                    .and_then(Value::as_u64)
                    .and_then(|n| u32::try_from(n).ok());
let is_special = token_obj
.get("special")
.and_then(Value::as_bool)
.unwrap_or(false);
if let (Some(content), Some(id)) = (content, id) {
vocab.entry(content.clone()).or_insert(id);
if is_special {
special_tokens.insert(content, id);
}
}
}
}
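        // Named special tokens may appear at the top level as plain strings or
        // as objects with a `content` field; `unk_token` may also live under `model`.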
let bos_token = extract_special_token(&root, "bos_token");
let eos_token = extract_special_token(&root, "eos_token");
let unk_token = extract_special_token(&root, "unk_token").or_else(|| {
model
.get("unk_token")
.and_then(Value::as_str)
.map(str::to_owned)
});
let pad_token = extract_special_token(&root, "pad_token");
let byte_level = detect_byte_level(&root);
Ok(Self {
vocab,
merges,
special_tokens,
bos_token,
eos_token,
unk_token,
pad_token,
byte_level,
})
}
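    /// Converts the parsed data into an [`OxiTokenizer`], registering special
    /// tokens separately and dropping any merge whose combined form is not in
    /// the vocabulary.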
pub fn into_tokenizer(self) -> TokenizerResult<OxiTokenizer> {
let mut vocabulary = Vocabulary::new();
for (token, id) in &self.vocab {
if self.special_tokens.contains_key(token) {
vocabulary.add_special(token, *id);
} else {
vocabulary.insert(token, *id);
}
}
let mut merges = BpeMerges::new();
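        // A merge is only usable if its combined form has an id in the
        // vocabulary; merges without one are skipped rather than treated as errors.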
for (a, b) in &self.merges {
let merged = format!("{a}{b}");
let merged_id = match self.vocab.get(&merged) {
Some(&id) => id,
None => {
continue;
}
};
merges.add_merge(a, b, merged_id);
}
let bos_id = self
.bos_token
.as_ref()
.and_then(|t| self.vocab.get(t).copied());
let eos_id = self
.eos_token
.as_ref()
.and_then(|t| self.vocab.get(t).copied());
let unk_id = self
.unk_token
.as_ref()
.and_then(|t| self.vocab.get(t).copied());
let pad_id = self
.pad_token
.as_ref()
.and_then(|t| self.vocab.get(t).copied());
let mut config = TokenizerConfig {
byte_level_decode: self.byte_level,
..Default::default()
};
if let Some(id) = bos_id {
config.bos_token_id = id;
}
if let Some(id) = eos_id {
config.eos_token_id = id;
}
if let Some(id) = unk_id {
config.unk_token_id = id;
}
if let Some(id) = pad_id {
config.pad_token_id = id;
}
Ok(OxiTokenizer::new(vocabulary, merges, config))
}
}
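
/// Parses one merge entry, which is either a single `"a b"` string or a
/// two-element `["a", "b"]` array.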
fn parse_merge_entry(entry: &Value) -> Option<(String, String)> {
match entry {
Value::String(s) => {
let mut parts = s.splitn(2, ' ');
let a = parts.next()?.to_owned();
let b = parts.next()?.to_owned();
Some((a, b))
}
Value::Array(arr) if arr.len() == 2 => {
let a = arr[0].as_str()?.to_owned();
let b = arr[1].as_str()?.to_owned();
Some((a, b))
}
_ => None,
}
}
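
/// Reads a named special token that may be stored either as a plain string or
/// as an object with a `content` field.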
fn extract_special_token(root: &Value, key: &str) -> Option<String> {
if let Some(v) = root.get(key) {
if let Some(s) = v.as_str() {
return Some(s.to_owned());
}
if let Some(inner) = v.get("content").and_then(Value::as_str) {
return Some(inner.to_owned());
}
}
None
}
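
/// Returns `true` when the `pre_tokenizer` or `decoder` section declares a
/// `ByteLevel` component, either directly or as an entry of a list. Nested
/// `Sequence` wrappers are not inspected.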
fn detect_byte_level(root: &Value) -> bool {
let has_bl = |field: &str| -> bool {
match root.get(field) {
Some(Value::Object(map)) => map
.get("type")
.and_then(Value::as_str)
.map(|t| t == "ByteLevel")
.unwrap_or(false),
Some(Value::Array(list)) => list.iter().any(|entry| {
entry
.get("type")
.and_then(Value::as_str)
.map(|t| t == "ByteLevel")
.unwrap_or(false)
}),
_ => false,
}
};
has_bl("pre_tokenizer") || has_bl("decoder")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn map_is_256_entries() {
let table = bytes_to_unicode_map();
let mut seen = std::collections::HashSet::new();
for &ch in table.iter() {
assert_ne!(ch, '\0', "map entry must not be NUL");
assert!(seen.insert(ch), "map entries must be distinct");
}
assert_eq!(seen.len(), 256);
}
#[test]
fn map_printable_ascii_passthrough() {
for b in 0x21u8..=0x7Eu8 {
assert_eq!(byte_to_unicode(b), char::from(b));
}
}
#[test]
fn map_latin1_passthrough() {
for b in 0xA1u8..=0xACu8 {
assert_eq!(byte_to_unicode(b), char::from(b));
}
for b in 0xAEu8..=0xFFu8 {
assert_eq!(byte_to_unicode(b), char::from(b));
}
}
#[test]
fn map_space_remapped() {
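        // 0x20 (space) is not "printable" in the GPT-2 scheme, so it is
        // remapped to U+0120 ("Ġ").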
assert_eq!(byte_to_unicode(0x20), '\u{0120}');
}
#[test]
fn map_newline_remapped() {
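        // 0x0A (newline) is remapped to U+010A ("Ċ").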
assert_eq!(byte_to_unicode(0x0A), '\u{010A}');
}
#[test]
fn map_inverse_roundtrip() {
        for b in 0u8..=255u8 {
            let ch = byte_to_unicode(b);
assert_eq!(
unicode_to_byte(ch),
Some(b),
"roundtrip failed for byte {b:#x}"
);
}
}
#[test]
fn bytes_to_unicode_string_basic() {
let out = bytes_to_unicode_string(" hello");
assert!(out.starts_with('\u{0120}'));
assert!(out.contains("hello"));
}
#[test]
fn parse_minimal_tokenizer_json() {
let json = r#"{
"model": {
"type": "BPE",
"vocab": {"<unk>": 0, "a": 1, "b": 2, "ab": 3},
"merges": ["a b"]
}
}"#;
let parsed = HfTokenizerJson::parse(json).expect("minimal parse ok");
assert_eq!(parsed.vocab.len(), 4);
assert_eq!(parsed.merges.len(), 1);
assert_eq!(parsed.merges[0], ("a".to_owned(), "b".to_owned()));
}
#[test]
fn parse_array_merges() {
let json = r#"{
"model": {
"vocab": {"a": 0, "b": 1, "ab": 2},
"merges": [["a", "b"]]
}
}"#;
let parsed = HfTokenizerJson::parse(json).expect("array merges ok");
assert_eq!(parsed.merges[0], ("a".to_owned(), "b".to_owned()));
}
#[test]
fn parse_detects_byte_level() {
let json = r#"{
"pre_tokenizer": {"type": "ByteLevel"},
"model": {
"vocab": {"a": 0},
"merges": []
}
}"#;
let parsed = HfTokenizerJson::parse(json).expect("parse ok");
assert!(parsed.byte_level);
}
#[test]
fn parse_missing_model_errors() {
let json = r#"{"foo": "bar"}"#;
let err = HfTokenizerJson::parse(json).expect_err("should fail");
match err {
TokenizerError::HfFormat(msg) => assert!(msg.contains("model")),
other => panic!("expected HfFormat, got {other:?}"),
}
}
#[test]
fn parse_picks_up_special_tokens() {
let json = r#"{
"added_tokens": [
{"id": 100, "content": "<|im_start|>", "special": true},
{"id": 101, "content": "foo", "special": false}
],
"model": {
"vocab": {"a": 0},
"merges": []
}
}"#;
let parsed = HfTokenizerJson::parse(json).expect("parse ok");
assert!(parsed.special_tokens.contains_key("<|im_start|>"));
assert!(!parsed.special_tokens.contains_key("foo"));
assert_eq!(parsed.vocab.get("foo"), Some(&101));
}
#[test]
fn into_tokenizer_roundtrip() {
let json = r#"{
"pre_tokenizer": {"type": "ByteLevel"},
"model": {
"vocab": {"a": 0, "b": 1, "ab": 2, "c": 3},
"merges": ["a b"]
}
}"#;
let parsed = HfTokenizerJson::parse(json).expect("parse ok");
let tok = parsed.into_tokenizer().expect("to tokenizer ok");
assert!(tok.vocab_size() >= 4);
}
#[test]
fn malformed_merge_entry_errors() {
let json = r#"{
"model": {
"vocab": {"a": 0},
"merges": [{"not": "a pair"}]
}
}"#;
let err = HfTokenizerJson::parse(json).expect_err("should fail");
assert!(matches!(err, TokenizerError::HfFormat(_)));
}
#[test]
fn vocab_non_integer_id_errors() {
let json = r#"{
"model": {
"vocab": {"a": "not an int"},
"merges": []
}
}"#;
let err = HfTokenizerJson::parse(json).expect_err("should fail");
assert!(matches!(err, TokenizerError::HfFormat(_)));
}
#[test]
fn inverse_map_len() {
let inv = bytes_to_unicode_inverse();
assert_eq!(inv.len(), 256);
}
}