use std::collections::HashMap;
use serde_json::Value;
use crate::{
bpe::BpeMerges,
error::{TokenizerError, TokenizerResult},
tokenizer::{OxiTokenizer, TokenizerConfig},
vocab::Vocabulary,
wordpiece::WordPieceVocab,
};
/// Builds the 256-entry lookup table that maps every byte value to a distinct `char`.
///
/// Bytes in `0x21..=0x7E`, `0xA1..=0xAC` and `0xAE..=0xFF` map to the code point
/// with the same numeric value. Every other byte is assigned the next free code
/// point starting at U+0100, in ascending byte order.
fn build_bytes_to_unicode() -> [char; 256] {
    let mut table = ['\0'; 256];
    // Number of remapped (non-passthrough) bytes seen so far.
    let mut remapped: u32 = 0;
    for (byte, slot) in table.iter_mut().enumerate() {
        let keeps_own_code_point = matches!(byte, 0x21..=0x7E | 0xA1..=0xAC | 0xAE..=0xFF);
        let code_point = if keeps_own_code_point {
            byte as u32
        } else {
            let shifted = 0x100u32 + remapped;
            remapped += 1;
            shifted
        };
        *slot = char::from_u32(code_point).unwrap_or('\u{FFFD}');
    }
    table
}
/// Returns the byte → `char` lookup table, indexed by byte value.
///
/// The table is rebuilt on every call, so callers performing many lookups
/// should keep the returned array around instead of calling this repeatedly.
pub fn bytes_to_unicode_map() -> [char; 256] {
build_bytes_to_unicode()
}
/// Returns the `char` that the byte → Unicode table assigns to `b`.
///
/// Builds a fresh table for the lookup.
pub fn byte_to_unicode(b: u8) -> char {
    let table = bytes_to_unicode_map();
    table[usize::from(b)]
}
/// Returns the byte whose table entry is `ch`, or `None` when `ch` is not
/// produced by the byte → Unicode table.
///
/// Performs a linear scan over a freshly built table.
pub fn unicode_to_byte(ch: char) -> Option<u8> {
    let table = bytes_to_unicode_map();
    table
        .iter()
        .position(|&mapped| mapped == ch)
        // The table has exactly 256 slots, so every index fits in a `u8`.
        .map(|index| index as u8)
}
/// Builds the inverse of the byte → Unicode table as a `char` → byte map.
pub fn bytes_to_unicode_inverse() -> HashMap<char, u8> {
    let table = build_bytes_to_unicode();
    table
        .iter()
        .enumerate()
        // The table has exactly 256 slots, so every index fits in a `u8`.
        .map(|(byte, &mapped)| (mapped, byte as u8))
        .collect()
}
/// Re-encodes `s` by replacing each of its UTF-8 bytes with the `char` the
/// byte → Unicode table assigns to that byte.
///
/// The result has one `char` per input byte.
pub fn bytes_to_unicode_string(s: &str) -> String {
    let table = bytes_to_unicode_map();
    s.bytes().map(|byte| table[usize::from(byte)]).collect()
}
/// Kind of tokenization model named by the `model.type` string of a
/// `tokenizer.json` document.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HfModelType {
    /// The type string `"BPE"`.
    Bpe,
    /// The type string `"Unigram"`.
    Unigram,
    /// The type string `"WordPiece"`.
    WordPiece,
    /// Any other type string, stored verbatim.
    Other(String),
}

impl HfModelType {
    /// Maps a `model.type` string to its variant.
    ///
    /// Matching is exact and case-sensitive; unrecognised strings are kept
    /// in [`HfModelType::Other`].
    fn from_str(s: &str) -> Self {
        if s == "BPE" {
            return Self::Bpe;
        }
        if s == "Unigram" {
            return Self::Unigram;
        }
        if s == "WordPiece" {
            return Self::WordPiece;
        }
        Self::Other(s.to_owned())
    }
}
/// Parsed contents of a `tokenizer.json` document, as produced by
/// `HfTokenizerJson::parse`.
#[derive(Debug, Clone)]
pub struct HfTokenizerJson {
/// Model kind read from `model.type`; BPE when the field is absent or not a string.
pub model_type: HfModelType,
/// Token → id map from the model vocabulary, extended with `added_tokens`
/// entries whose content was not already present.
pub vocab: HashMap<String, u32>,
/// Merge pairs from `model.merges` in file order; empty for Unigram and WordPiece models.
pub merges: Vec<(String, String)>,
/// `[token, score]` entries in file order; `Some` only for Unigram models.
pub unigram_vocab: Option<Vec<(String, f64)>>,
/// Value of `model.unk_id` when a Unigram model provides one.
pub unigram_unk_id: Option<u32>,
/// Value of `model.max_input_chars_per_word` when a WordPiece model provides one.
pub wordpiece_max_chars: Option<usize>,
/// Content → id for the `added_tokens` entries flagged `"special": true`.
pub special_tokens: HashMap<String, u32>,
/// Top-level `bos_token`, given either as a string or as an object with a `content` string.
pub bos_token: Option<String>,
/// Top-level `eos_token`, given either as a string or as an object with a `content` string.
pub eos_token: Option<String>,
/// Top-level `unk_token`, falling back to the string at `model.unk_token`.
pub unk_token: Option<String>,
/// Top-level `pad_token`, given either as a string or as an object with a `content` string.
pub pad_token: Option<String>,
/// `true` when a `ByteLevel` component was detected in `pre_tokenizer` or `decoder`.
pub byte_level: bool,
}
impl HfTokenizerJson {
    /// Parses the text of a Hugging Face `tokenizer.json` document.
    ///
    /// Reads `model.type` (treated as BPE when absent or not a string), the
    /// model vocabulary and merges, the top-level `added_tokens` list, the
    /// optional `bos_token` / `eos_token` / `unk_token` / `pad_token` entries
    /// and the `pre_tokenizer` / `decoder` sections. Model types other than
    /// Unigram and WordPiece are parsed with the BPE layout.
    ///
    /// `added_tokens` entries lacking a string `content` or an `id` that is a
    /// non-negative integer fitting in `u32` are ignored.
    ///
    /// # Errors
    ///
    /// Returns [`TokenizerError::HfFormat`] when the input is not valid JSON,
    /// when the `model` object is missing, or when the model section is
    /// malformed for its type (including vocabulary ids that do not fit in `u32`).
    pub fn parse(json: &str) -> TokenizerResult<Self> {
        let root: Value = serde_json::from_str(json)
            .map_err(|e| TokenizerError::HfFormat(format!("invalid JSON: {e}")))?;
        let model = root
            .get("model")
            .ok_or_else(|| TokenizerError::HfFormat("missing `model` field".to_owned()))?;
        let model_type = model
            .get("type")
            .and_then(Value::as_str)
            .map(HfModelType::from_str)
            .unwrap_or(HfModelType::Bpe);

        let (mut vocab, merges, unigram_vocab, unigram_unk_id) = match &model_type {
            HfModelType::Unigram => parse_unigram_model(model)?,
            HfModelType::WordPiece => (Self::parse_wordpiece_vocab(model)?, vec![], None, None),
            HfModelType::Bpe | HfModelType::Other(_) => {
                let (v, m) = parse_bpe_model(model)?;
                (v, m, None, None)
            }
        };

        let mut special_tokens: HashMap<String, u32> = HashMap::new();
        if let Some(added) = root.get("added_tokens").and_then(Value::as_array) {
            for token_obj in added {
                let content = token_obj.get("content").and_then(Value::as_str);
                // An id that does not fit in `u32` is treated like any other
                // malformed id: the entry is skipped rather than truncated.
                let id = token_obj
                    .get("id")
                    .and_then(Value::as_u64)
                    .filter(|&n| n <= u64::from(u32::MAX))
                    .map(|n| n as u32);
                let is_special = token_obj
                    .get("special")
                    .and_then(Value::as_bool)
                    .unwrap_or(false);
                if let (Some(content), Some(id)) = (content, id) {
                    // An id already assigned by the model vocabulary wins.
                    vocab.entry(content.to_owned()).or_insert(id);
                    if is_special {
                        special_tokens.insert(content.to_owned(), id);
                    }
                }
            }
        }

        let bos_token = extract_special_token(&root, "bos_token");
        let eos_token = extract_special_token(&root, "eos_token");
        let unk_token = extract_special_token(&root, "unk_token").or_else(|| {
            model
                .get("unk_token")
                .and_then(Value::as_str)
                .map(str::to_owned)
        });
        let pad_token = extract_special_token(&root, "pad_token");
        let byte_level = detect_byte_level(&root);

        let wordpiece_max_chars = if model_type == HfModelType::WordPiece {
            model
                .get("max_input_chars_per_word")
                .and_then(Value::as_u64)
                // Ignore values that cannot be represented as `usize` on this target.
                .filter(|&n| n <= usize::MAX as u64)
                .map(|n| n as usize)
        } else {
            None
        };

        Ok(Self {
            model_type,
            vocab,
            merges,
            unigram_vocab,
            unigram_unk_id,
            wordpiece_max_chars,
            special_tokens,
            bos_token,
            eos_token,
            unk_token,
            pad_token,
            byte_level,
        })
    }

    /// Reads the token → id object at `model.vocab` of a WordPiece model.
    ///
    /// Fails when the field is missing, is not an object, or holds an id that
    /// is not a non-negative integer fitting in `u32`.
    fn parse_wordpiece_vocab(model: &Value) -> TokenizerResult<HashMap<String, u32>> {
        let vocab_val = model
            .get("vocab")
            .ok_or_else(|| TokenizerError::HfFormat("WordPiece model.vocab missing".into()))?;
        let map = match vocab_val {
            Value::Object(map) => map,
            _ => {
                return Err(TokenizerError::HfFormat(
                    "WordPiece model.vocab must be an object".into(),
                ));
            }
        };
        let mut wp_vocab: HashMap<String, u32> = HashMap::new();
        for (token, id_val) in map {
            let raw_id = id_val.as_u64().ok_or_else(|| {
                TokenizerError::HfFormat(format!(
                    "WordPiece vocab id for '{token}' is not an integer"
                ))
            })?;
            if raw_id > u64::from(u32::MAX) {
                return Err(TokenizerError::HfFormat(format!(
                    "WordPiece vocab id {raw_id} for '{token}' does not fit in u32"
                )));
            }
            // Lossless: range-checked just above.
            wp_vocab.insert(token.clone(), raw_id as u32);
        }
        Ok(wp_vocab)
    }

    /// Converts the parsed document into an [`OxiTokenizer`].
    ///
    /// Every vocabulary entry is registered, with tokens listed in
    /// `special_tokens` added as special. The BOS / EOS / UNK / PAD ids in the
    /// configuration are overridden only when the corresponding token string
    /// is present in the vocabulary. For BPE (and unrecognised model types),
    /// merges whose concatenated token is absent from the vocabulary are skipped.
    ///
    /// # Errors
    ///
    /// Returns [`TokenizerError::HfFormat`] when a Unigram model has no vocab
    /// entries recorded, or when building the Unigram or WordPiece vocabulary fails.
    pub fn into_tokenizer(self) -> TokenizerResult<OxiTokenizer> {
        let mut vocabulary = Vocabulary::new();
        for (token, id) in &self.vocab {
            if self.special_tokens.contains_key(token) {
                vocabulary.add_special(token, *id);
            } else {
                vocabulary.insert(token, *id);
            }
        }

        // Resolves an optional token string to its id in the vocabulary.
        let id_of = |token: Option<&str>| token.and_then(|t| self.vocab.get(t).copied());
        let bos_id = id_of(self.bos_token.as_deref());
        let eos_id = id_of(self.eos_token.as_deref());
        let unk_id_from_token = id_of(self.unk_token.as_deref());
        let pad_id = id_of(self.pad_token.as_deref());

        let mut config = TokenizerConfig {
            byte_level_decode: self.byte_level,
            ..Default::default()
        };
        if let Some(id) = bos_id {
            config.bos_token_id = id;
        }
        if let Some(id) = eos_id {
            config.eos_token_id = id;
        }
        if let Some(id) = unk_id_from_token {
            config.unk_token_id = id;
        }
        if let Some(id) = pad_id {
            config.pad_token_id = id;
        }

        match self.model_type {
            HfModelType::Bpe | HfModelType::Other(_) => {
                let mut merges = BpeMerges::new();
                for (a, b) in &self.merges {
                    let merged = format!("{a}{b}");
                    // A merge is only usable when its result has an id.
                    if let Some(&merged_id) = self.vocab.get(&merged) {
                        merges.add_merge(a, b, merged_id);
                    }
                }
                Ok(OxiTokenizer::new(vocabulary, merges, config))
            }
            HfModelType::Unigram => {
                let entries = self.unigram_vocab.ok_or_else(|| {
                    TokenizerError::HfFormat(
                        "Unigram model requires `model.vocab` array".to_owned(),
                    )
                })?;
                // `model.unk_id` takes precedence over the id derived from `unk_token`.
                let effective_unk_id = self.unigram_unk_id.unwrap_or(config.unk_token_id);
                let unigram_vocab = crate::unigram::UnigramVocab::new(entries, effective_unk_id)
                    .map_err(|e| TokenizerError::HfFormat(format!("invalid Unigram vocab: {e}")))?;
                Ok(OxiTokenizer::with_unigram(
                    vocabulary,
                    unigram_vocab,
                    config,
                ))
            }
            HfModelType::WordPiece => {
                let wp_vocab = build_wordpiece_vocab_from_map(
                    &self.vocab,
                    self.unk_token.as_deref(),
                    self.wordpiece_max_chars,
                    config.unk_token_id,
                )?;
                Ok(OxiTokenizer::with_wordpiece(vocabulary, wp_vocab, config))
            }
        }
    }
}
/// Builds a [`WordPieceVocab`] from a token → id map.
///
/// The ids must form the contiguous range `0..vocab_map.len()`; tokens are
/// handed to [`WordPieceVocab::new`] ordered by id. The unknown-token id is the
/// id of `unk_token_str` when that token is in the map, otherwise
/// `fallback_unk_id`. When `max_chars` is `Some`, it is applied through
/// `with_max_input_chars`.
///
/// # Errors
///
/// Returns [`TokenizerError::HfFormat`] when the ids are not contiguous or
/// when [`WordPieceVocab::new`] rejects the vocabulary.
fn build_wordpiece_vocab_from_map(
    vocab_map: &HashMap<String, u32>,
    unk_token_str: Option<&str>,
    max_chars: Option<usize>,
    fallback_unk_id: u32,
) -> TokenizerResult<WordPieceVocab> {
    let mut by_id: Vec<(u32, &str)> = vocab_map
        .iter()
        .map(|(token, &id)| (id, token.as_str()))
        .collect();
    by_id.sort_by_key(|&(id, _)| id);

    // After sorting, position and id must coincide for a contiguous id range.
    let first_gap = by_id
        .iter()
        .enumerate()
        .find(|&(pos, &(id, _))| id as usize != pos);
    if let Some((i, &(id, _))) = first_gap {
        return Err(TokenizerError::HfFormat(format!(
            "WordPiece vocab IDs are not contiguous: expected {i}, found {id}"
        )));
    }

    let tokens: Vec<String> = by_id
        .into_iter()
        .map(|(_, token)| token.to_owned())
        .collect();
    let unk_id: u32 = match unk_token_str.and_then(|s| vocab_map.get(s)) {
        Some(&found) => found,
        None => fallback_unk_id,
    };
    let wp = WordPieceVocab::new(tokens, unk_id)
        .map_err(|e| TokenizerError::HfFormat(format!("invalid WordPiece vocab: {e}")))?;
    Ok(match max_chars {
        Some(max) => wp.with_max_input_chars(max),
        None => wp,
    })
}
/// Reads the vocabulary and merge list of a BPE `model` object.
///
/// `model.vocab` must be an object mapping each token to a non-negative
/// integer id that fits in `u32`; `model.merges` must be an array whose
/// entries are accepted by `parse_merge_entry`. Merges are returned in file order.
///
/// # Errors
///
/// Returns [`TokenizerError::HfFormat`] when either field is missing or has
/// the wrong JSON type, when a vocab id is not an integer or exceeds
/// `u32::MAX`, or when a merge entry is malformed.
// The tuple return type mirrors what the caller destructures directly.
#[allow(clippy::type_complexity)]
fn parse_bpe_model(
    model: &Value,
) -> TokenizerResult<(HashMap<String, u32>, Vec<(String, String)>)> {
    let vocab_val = model
        .get("vocab")
        .ok_or_else(|| TokenizerError::HfFormat("missing `model.vocab` field".to_owned()))?;
    let vocab_map = match vocab_val {
        Value::Object(map) => map,
        _ => {
            return Err(TokenizerError::HfFormat(
                "`model.vocab` must be an object".to_owned(),
            ));
        }
    };
    let mut vocab: HashMap<String, u32> = HashMap::with_capacity(vocab_map.len());
    for (token, id_val) in vocab_map {
        let raw_id = id_val.as_u64().ok_or_else(|| {
            TokenizerError::HfFormat(format!("vocab entry {token:?} has non-integer id"))
        })?;
        // Reject instead of silently wrapping an oversized id into the u32 range.
        if raw_id > u64::from(u32::MAX) {
            return Err(TokenizerError::HfFormat(format!(
                "vocab entry {token:?} has id {raw_id}, which does not fit in u32"
            )));
        }
        vocab.insert(token.clone(), raw_id as u32);
    }

    let merges_val = model
        .get("merges")
        .ok_or_else(|| TokenizerError::HfFormat("missing `model.merges` field".to_owned()))?;
    let merge_list = match merges_val {
        Value::Array(list) => list,
        _ => {
            return Err(TokenizerError::HfFormat(
                "`model.merges` must be an array".to_owned(),
            ));
        }
    };
    let merges = merge_list
        .iter()
        .enumerate()
        .map(|(idx, entry)| {
            parse_merge_entry(entry).ok_or_else(|| {
                TokenizerError::HfFormat(format!("malformed merge entry #{idx}: {entry:?}"))
            })
        })
        .collect::<Result<Vec<_>, _>>()?;

    Ok((vocab, merges))
}
/// Reads the vocabulary of a Unigram `model` object.
///
/// `model.vocab` must be an array of `[token, score]` pairs. Each token's id
/// is its position in that array. Returns, in order: the token → id map, an
/// empty merge list, the `(token, score)` entries in file order, and
/// `model.unk_id` when it is a non-negative integer that fits in `u32`
/// (any other value yields `None`).
///
/// # Errors
///
/// Returns [`TokenizerError::HfFormat`] when `model.vocab` is missing, is not
/// an array, contains an entry that is not a two-element `[string, number]`
/// array, or has more entries than `u32` can index.
// The tuple return type mirrors what the caller destructures directly.
#[allow(clippy::type_complexity)]
fn parse_unigram_model(
    model: &Value,
) -> TokenizerResult<(
    HashMap<String, u32>,
    Vec<(String, String)>,
    Option<Vec<(String, f64)>>,
    Option<u32>,
)> {
    let vocab_val = model
        .get("vocab")
        .ok_or_else(|| TokenizerError::HfFormat("missing `model.vocab` field".to_owned()))?;
    let arr = vocab_val.as_array().ok_or_else(|| {
        TokenizerError::HfFormat(
            "Unigram `model.vocab` must be an array of [token, score] pairs".to_owned(),
        )
    })?;

    let mut entries: Vec<(String, f64)> = Vec::with_capacity(arr.len());
    let mut vocab_map: HashMap<String, u32> = HashMap::with_capacity(arr.len());
    for (idx, item) in arr.iter().enumerate() {
        let pair = item.as_array().ok_or_else(|| {
            TokenizerError::HfFormat(format!(
                "Unigram vocab entry #{idx} must be a [token, score] array"
            ))
        })?;
        let (token_val, score_val) = match pair.as_slice() {
            [token_val, score_val] => (token_val, score_val),
            _ => {
                return Err(TokenizerError::HfFormat(format!(
                    "Unigram vocab entry #{idx} must have exactly 2 elements, got {}",
                    pair.len()
                )));
            }
        };
        let token = token_val.as_str().ok_or_else(|| {
            TokenizerError::HfFormat(format!(
                "Unigram vocab entry #{idx}: first element must be a string"
            ))
        })?;
        let score = score_val.as_f64().ok_or_else(|| {
            TokenizerError::HfFormat(format!(
                "Unigram vocab entry #{idx}: second element must be a number"
            ))
        })?;
        // The array position is the token id; refuse positions u32 cannot hold
        // instead of letting the cast wrap around.
        if idx as u64 > u64::from(u32::MAX) {
            return Err(TokenizerError::HfFormat(format!(
                "Unigram vocab entry #{idx} exceeds the maximum token id"
            )));
        }
        vocab_map.insert(token.to_owned(), idx as u32);
        entries.push((token.to_owned(), score));
    }

    // An out-of-range `unk_id` is treated like a missing one rather than truncated.
    let unk_id = model
        .get("unk_id")
        .and_then(Value::as_u64)
        .filter(|&n| n <= u64::from(u32::MAX))
        .map(|n| n as u32);

    Ok((vocab_map, vec![], Some(entries), unk_id))
}
/// Decodes one `model.merges` entry into its `(left, right)` token pair.
///
/// Two encodings are accepted: a single string split at its first space, or
/// a two-element array of strings. Anything else yields `None`.
fn parse_merge_entry(entry: &Value) -> Option<(String, String)> {
    let (left, right) = match entry {
        Value::String(joined) => joined.split_once(' ')?,
        Value::Array(parts) if parts.len() == 2 => (parts[0].as_str()?, parts[1].as_str()?),
        _ => return None,
    };
    Some((left.to_owned(), right.to_owned()))
}
/// Reads the token string stored under the top-level `key`.
///
/// The value may be a plain string or an object carrying the string in its
/// `content` field; any other shape, or a missing key, yields `None`.
fn extract_special_token(root: &Value, key: &str) -> Option<String> {
    let value = root.get(key)?;
    value
        .as_str()
        .or_else(|| value.get("content").and_then(Value::as_str))
        .map(str::to_owned)
}
/// Reports whether the document's `pre_tokenizer` or `decoder` section uses a
/// `ByteLevel` component.
///
/// A section matches when it is an object whose `type` is `"ByteLevel"`, an
/// array containing a matching element, or a wrapper object (such as a
/// `Sequence`) whose `pretokenizers` / `decoders` list contains a matching
/// element at any nesting depth.
fn detect_byte_level(root: &Value) -> bool {
    /// Returns `true` when `node` is, or transitively wraps, a `ByteLevel` component.
    fn contains_byte_level(node: &Value) -> bool {
        match node {
            Value::Array(items) => items.iter().any(contains_byte_level),
            Value::Object(map) => {
                if map.get("type").and_then(Value::as_str) == Some("ByteLevel") {
                    return true;
                }
                // Sequence wrappers keep their children under one of these keys.
                ["pretokenizers", "decoders"]
                    .iter()
                    .any(|key| map.get(*key).map_or(false, contains_byte_level))
            }
            _ => false,
        }
    }

    ["pre_tokenizer", "decoder"]
        .iter()
        .any(|field| root.get(*field).map_or(false, contains_byte_level))
}
// Unit tests for the byte → Unicode table and the `tokenizer.json` parser.
#[cfg(test)]
mod tests {
use super::*;
// Every byte maps to a distinct, non-NUL character.
#[test]
fn map_is_256_entries() {
let table = bytes_to_unicode_map();
let mut seen = std::collections::HashSet::new();
for &ch in table.iter() {
assert_ne!(ch, '\0', "map entry must not be NUL");
assert!(seen.insert(ch), "map entries must be distinct");
}
assert_eq!(seen.len(), 256);
}
// Bytes 0x21..=0x7E map to the character with the same code point.
#[test]
fn map_printable_ascii_passthrough() {
for b in 0x21u8..=0x7Eu8 {
assert_eq!(byte_to_unicode(b), char::from(b));
}
}
// Bytes 0xA1..=0xAC and 0xAE..=0xFF map to the character with the same code point.
#[test]
fn map_latin1_passthrough() {
for b in 0xA1u8..=0xACu8 {
assert_eq!(byte_to_unicode(b), char::from(b));
}
for b in 0xAEu8..=0xFFu8 {
assert_eq!(byte_to_unicode(b), char::from(b));
}
}
// The space byte is remapped to U+0120.
#[test]
fn map_space_remapped() {
assert_eq!(byte_to_unicode(0x20), '\u{0120}');
}
// The newline byte is remapped to U+010A.
#[test]
fn map_newline_remapped() {
assert_eq!(byte_to_unicode(0x0A), '\u{010A}');
}
// `unicode_to_byte` inverts `byte_to_unicode` for every byte value.
#[test]
fn map_inverse_roundtrip() {
for b in 0u16..=255u16 {
let b = b as u8;
let ch = byte_to_unicode(b);
assert_eq!(
unicode_to_byte(ch),
Some(b),
"roundtrip failed for byte {b:#x}"
);
}
}
// A leading space becomes U+0120 while the letters pass through unchanged.
#[test]
fn bytes_to_unicode_string_basic() {
let out = bytes_to_unicode_string(" hello");
assert!(out.starts_with('\u{0120}'));
assert!(out.contains("hello"));
}
// A minimal BPE document with string-encoded merges parses.
#[test]
fn parse_minimal_tokenizer_json() {
let json = r#"{
"model": {
"type": "BPE",
"vocab": {"<unk>": 0, "a": 1, "b": 2, "ab": 3},
"merges": ["a b"]
}
}"#;
let parsed = HfTokenizerJson::parse(json).expect("minimal parse ok");
assert_eq!(parsed.vocab.len(), 4);
assert_eq!(parsed.merges.len(), 1);
assert_eq!(parsed.merges[0], ("a".to_owned(), "b".to_owned()));
}
// Merges given as two-element arrays are accepted; `model.type` is omitted here.
#[test]
fn parse_array_merges() {
let json = r#"{
"model": {
"vocab": {"a": 0, "b": 1, "ab": 2},
"merges": [["a", "b"]]
}
}"#;
let parsed = HfTokenizerJson::parse(json).expect("array merges ok");
assert_eq!(parsed.merges[0], ("a".to_owned(), "b".to_owned()));
}
// A `ByteLevel` pre-tokenizer sets the `byte_level` flag.
#[test]
fn parse_detects_byte_level() {
let json = r#"{
"pre_tokenizer": {"type": "ByteLevel"},
"model": {
"vocab": {"a": 0},
"merges": []
}
}"#;
let parsed = HfTokenizerJson::parse(json).expect("parse ok");
assert!(parsed.byte_level);
}
// A document without a `model` object is rejected with an `HfFormat` error.
#[test]
fn parse_missing_model_errors() {
let json = r#"{"foo": "bar"}"#;
let err = HfTokenizerJson::parse(json).expect_err("should fail");
match err {
TokenizerError::HfFormat(msg) => assert!(msg.contains("model")),
other => panic!("expected HfFormat, got {other:?}"),
}
}
// Only `added_tokens` flagged special land in `special_tokens`; all of them join the vocab.
#[test]
fn parse_picks_up_special_tokens() {
let json = r#"{
"added_tokens": [
{"id": 100, "content": "<|im_start|>", "special": true},
{"id": 101, "content": "foo", "special": false}
],
"model": {
"vocab": {"a": 0},
"merges": []
}
}"#;
let parsed = HfTokenizerJson::parse(json).expect("parse ok");
assert!(parsed.special_tokens.contains_key("<|im_start|>"));
assert!(!parsed.special_tokens.contains_key("foo"));
assert_eq!(parsed.vocab.get("foo"), Some(&101));
}
// A parsed document converts into a tokenizer exposing at least the parsed vocabulary.
#[test]
fn into_tokenizer_roundtrip() {
let json = r#"{
"pre_tokenizer": {"type": "ByteLevel"},
"model": {
"vocab": {"a": 0, "b": 1, "ab": 2, "c": 3},
"merges": ["a b"]
}
}"#;
let parsed = HfTokenizerJson::parse(json).expect("parse ok");
let tok = parsed.into_tokenizer().expect("to tokenizer ok");
assert!(tok.vocab_size() >= 4);
}
// A merge entry that is neither a string nor a two-element array is rejected.
#[test]
fn malformed_merge_entry_errors() {
let json = r#"{
"model": {
"vocab": {"a": 0},
"merges": [{"not": "a pair"}]
}
}"#;
let err = HfTokenizerJson::parse(json).expect_err("should fail");
assert!(matches!(err, TokenizerError::HfFormat(_)));
}
// A vocab id that is not an integer is rejected.
#[test]
fn vocab_non_integer_id_errors() {
let json = r#"{
"model": {
"vocab": {"a": "not an int"},
"merges": []
}
}"#;
let err = HfTokenizerJson::parse(json).expect_err("should fail");
assert!(matches!(err, TokenizerError::HfFormat(_)));
}
// The inverse map has one entry per byte value.
#[test]
fn inverse_map_len() {
let inv = bytes_to_unicode_inverse();
assert_eq!(inv.len(), 256);
}
}