use std::collections::HashMap;
use crate::error::{Result, TextError};
use crate::tokenization::wordpiece::WordPieceTokenizer;
use crate::gpt_bpe::Gpt2BpeTokenizer;
/// Kind of tokenizer model declared in the `model.type` field of a
/// HuggingFace `tokenizer.json`. `Unknown` preserves unrecognised type
/// names verbatim so they can be reported back to the caller.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum HfModelType {
    WordPiece,
    Bpe,
    Unigram,
    Unknown(String),
}

impl HfModelType {
    /// Canonical HF spelling of this model type ("WordPiece", "BPE",
    /// "Unigram"); `Unknown` echoes the stored string unchanged.
    pub fn as_str(&self) -> &str {
        match self {
            HfModelType::WordPiece => "WordPiece",
            HfModelType::Bpe => "BPE",
            HfModelType::Unigram => "Unigram",
            HfModelType::Unknown(s) => s.as_str(),
        }
    }

    /// Parse a model-type name, ignoring ASCII case (accepts "WordPiece",
    /// "wordpiece", "WORDPIECE", "WordPIECE", ...). Every spelling matched
    /// by the previous exact-list implementation is still matched; names
    /// that match no known type are preserved in `Unknown`.
    pub fn from_str(s: &str) -> Self {
        if s.eq_ignore_ascii_case("wordpiece") {
            HfModelType::WordPiece
        } else if s.eq_ignore_ascii_case("bpe") {
            HfModelType::Bpe
        } else if s.eq_ignore_ascii_case("unigram") {
            HfModelType::Unigram
        } else {
            HfModelType::Unknown(s.to_string())
        }
    }
}
/// One entry of the HF `added_tokens` array. The boolean flags mirror the
/// HuggingFace schema; this module only serialises and parses them and does
/// not interpret their matching semantics.
#[derive(Debug, Clone)]
pub struct HfAddedToken {
/// Vocabulary id assigned to this token.
pub id: u32,
/// Literal token text, e.g. "[CLS]".
pub content: String,
/// Marks the token as special (set by `HfAddedToken::special`).
pub special: bool,
/// HF `single_word` flag (stored/serialised only, not interpreted here).
pub single_word: bool,
/// HF `lstrip` flag (stored/serialised only).
pub lstrip: bool,
/// HF `rstrip` flag (stored/serialised only).
pub rstrip: bool,
/// HF `normalized` flag (stored/serialised only).
pub normalized: bool,
}
impl HfAddedToken {
    /// Construct a special token (e.g. `[CLS]`) with the conventional flag
    /// settings: `special` set, all matching flags cleared.
    pub fn special(id: u32, content: impl Into<String>) -> Self {
        Self {
            id,
            content: content.into(),
            special: true,
            single_word: false,
            lstrip: false,
            rstrip: false,
            normalized: false,
        }
    }

    /// Render this token as one JSON object in the HF `added_tokens` shape.
    fn to_json_object(&self) -> String {
        // Destructure so a newly added field can't be forgotten silently.
        let HfAddedToken {
            id,
            content,
            special,
            single_word,
            lstrip,
            rstrip,
            normalized,
        } = self;
        format!(
            r#"{{"id":{},"content":{},"single_word":{},"lstrip":{},"rstrip":{},"normalized":{},"special":{}}}"#,
            id,
            json_string(content),
            single_word,
            lstrip,
            rstrip,
            normalized,
            special,
        )
    }

    /// Parse one `added_tokens` entry. `id` and `content` are required;
    /// every boolean flag defaults to `false` when absent.
    fn from_json_object(obj: &str) -> Option<Self> {
        let flag = |name: &str| parse_bool_field(obj, name).unwrap_or(false);
        Some(HfAddedToken {
            id: parse_u32_field(obj, "id")?,
            content: parse_string_field(obj, "content")?,
            special: flag("special"),
            single_word: flag("single_word"),
            lstrip: flag("lstrip"),
            rstrip: flag("rstrip"),
            normalized: flag("normalized"),
        })
    }
}
/// The `model` section of a HF `tokenizer.json`: the vocabulary plus
/// algorithm-specific options. `None` fields are omitted from the
/// serialised JSON.
#[derive(Debug, Clone)]
pub struct HfModel {
/// Model algorithm name, e.g. "WordPiece" or "BPE".
pub model_type: String,
/// Token -> id mapping.
pub vocab: HashMap<String, u32>,
/// BPE merge rules as "left right" pairs; `None` for non-BPE models.
pub merges: Option<Vec<String>>,
/// Unknown-token literal, e.g. "[UNK]" (used by WordPiece).
pub unk_token: Option<String>,
/// Subword continuation prefix, e.g. "##" (used by WordPiece).
pub continuing_subword_prefix: Option<String>,
/// WordPiece limit on characters per input word.
pub max_input_chars_per_word: Option<u32>,
}
impl HfModel {
    /// Serialise this model section as a JSON object string. Optional fields
    /// are emitted only when present; vocab entries are ordered by id so the
    /// output is deterministic.
    fn to_json_string(&self) -> String {
        let mut fields = vec![format!(r#""type":{}"#, json_string(&self.model_type))];
        if let Some(unk) = self.unk_token.as_deref() {
            fields.push(format!(r#""unk_token":{}"#, json_string(unk)));
        }
        if let Some(prefix) = self.continuing_subword_prefix.as_deref() {
            fields.push(format!(r#""continuing_subword_prefix":{}"#, json_string(prefix)));
        }
        if let Some(limit) = self.max_input_chars_per_word {
            fields.push(format!(r#""max_input_chars_per_word":{}"#, limit));
        }
        let mut entries: Vec<(&String, &u32)> = self.vocab.iter().collect();
        entries.sort_by_key(|&(_, &id)| id);
        let vocab_body = entries
            .into_iter()
            .map(|(token, id)| format!("{}:{}", json_string(token), id))
            .collect::<Vec<_>>()
            .join(",");
        fields.push(format!(r#""vocab":{{{}}}"#, vocab_body));
        if let Some(merges) = self.merges.as_ref() {
            let merge_body = merges
                .iter()
                .map(|m| json_string(m))
                .collect::<Vec<_>>()
                .join(",");
            fields.push(format!(r#""merges":[{}]"#, merge_body));
        }
        format!("{{{}}}", fields.join(","))
    }

    /// Parse a model section from its JSON object string. Only `type` and
    /// `vocab` are mandatory; everything else is optional.
    fn from_json_str(s: &str) -> Result<Self> {
        let model_type = match parse_string_field(s, "type") {
            Some(t) => t,
            None => {
                return Err(TextError::InvalidInput(
                    "HF JSON: missing model.type field".to_string(),
                ))
            }
        };
        Ok(HfModel {
            model_type,
            vocab: parse_vocab_object(s)?,
            merges: parse_string_array_field(s, "merges"),
            unk_token: parse_string_field(s, "unk_token"),
            continuing_subword_prefix: parse_string_field(s, "continuing_subword_prefix"),
            max_input_chars_per_word: parse_u32_field(s, "max_input_chars_per_word"),
        })
    }
}
/// In-memory form of a HF `tokenizer.json` document. The pipeline sections
/// (normalizer, pre_tokenizer, post_processor, decoder) are kept as raw JSON
/// strings rather than parsed structures.
#[derive(Debug, Clone)]
pub struct HfTokenizerJson {
/// Format version string, serialised as `"version"`.
pub version: String,
/// The `model` section.
pub model: HfModel,
/// Entries of the `added_tokens` array.
pub added_tokens: Vec<HfAddedToken>,
/// Raw JSON of the `normalizer` section, if present (absent/`null` -> None).
pub normalizer_json: Option<String>,
/// Raw JSON of the `pre_tokenizer` section, if present.
pub pre_tokenizer_json: Option<String>,
/// Raw JSON of the `post_processor` section, if present.
pub post_processor_json: Option<String>,
/// Raw JSON of the `decoder` section, if present.
pub decoder_json: Option<String>,
}
impl HfTokenizerJson {
pub fn from_wordpiece(wp: &WordPieceTokenizer) -> Self {
let vocab: HashMap<String, u32> = wp.vocab_snapshot();
let get = |tok: &str, fallback: u32| -> u32 {
vocab.get(tok).copied().unwrap_or(fallback)
};
let added_tokens = vec![
HfAddedToken::special(get("[PAD]", 0), "[PAD]"),
HfAddedToken::special(get("[UNK]", 1), "[UNK]"),
HfAddedToken::special(get("[CLS]", 101), "[CLS]"),
HfAddedToken::special(get("[SEP]", 102), "[SEP]"),
HfAddedToken::special(get("[MASK]", 103), "[MASK]"),
];
let model = HfModel {
model_type: "WordPiece".to_string(),
vocab,
merges: None,
unk_token: Some("[UNK]".to_string()),
continuing_subword_prefix: Some("##".to_string()),
max_input_chars_per_word: Some(100),
};
HfTokenizerJson {
version: "1.0".to_string(),
model,
added_tokens,
normalizer_json: None,
pre_tokenizer_json: None,
post_processor_json: None,
decoder_json: None,
}
}
pub fn from_gpt2_bpe(bpe: &Gpt2BpeTokenizer) -> Self {
let vocab: HashMap<String, u32> = bpe.vocab_snapshot();
let merges: Vec<String> = bpe
.merges()
.iter()
.map(|(a, b)| format!("{} {}", a, b))
.collect();
let model = HfModel {
model_type: "BPE".to_string(),
vocab,
merges: Some(merges),
unk_token: None,
continuing_subword_prefix: None,
max_input_chars_per_word: None,
};
HfTokenizerJson {
version: "1.0".to_string(),
model,
added_tokens: vec![],
normalizer_json: None,
pre_tokenizer_json: None,
post_processor_json: None,
decoder_json: None,
}
}
pub fn to_json_string(&self) -> String {
let added_tokens_str = self
.added_tokens
.iter()
.map(|t| t.to_json_object())
.collect::<Vec<_>>()
.join(",");
let null_or = |opt: &Option<String>| -> String {
opt.as_deref().unwrap_or("null").to_string()
};
format!(
r#"{{"version":{},"truncation":null,"padding":null,"added_tokens":[{}],"normalizer":{},"pre_tokenizer":{},"post_processor":{},"decoder":{},"model":{}}}"#,
json_string(&self.version),
added_tokens_str,
null_or(&self.normalizer_json),
null_or(&self.pre_tokenizer_json),
null_or(&self.post_processor_json),
null_or(&self.decoder_json),
self.model.to_json_string(),
)
}
pub fn from_json_str(s: &str) -> Result<Self> {
let version = parse_string_field(s, "version").unwrap_or_else(|| "1.0".to_string());
let model_str = extract_object_field(s, "model").ok_or_else(|| {
TextError::InvalidInput("HF JSON: missing 'model' object".to_string())
})?;
let model = HfModel::from_json_str(&model_str)?;
let added_tokens = extract_array_field(s, "added_tokens")
.unwrap_or_default()
.iter()
.filter_map(|obj| HfAddedToken::from_json_object(obj))
.collect();
let normalizer_json = extract_object_field(s, "normalizer").map(|o| o.to_string());
let pre_tokenizer_json = extract_object_field(s, "pre_tokenizer").map(|o| o.to_string());
let post_processor_json = extract_object_field(s, "post_processor").map(|o| o.to_string());
let decoder_json = extract_object_field(s, "decoder").map(|o| o.to_string());
Ok(HfTokenizerJson {
version,
model,
added_tokens,
normalizer_json,
pre_tokenizer_json,
post_processor_json,
decoder_json,
})
}
pub fn wordpiece_roundtrip_check(wp: &WordPieceTokenizer) -> bool {
let original = Self::from_wordpiece(wp);
let json = original.to_json_string();
match Self::from_json_str(&json) {
Ok(restored) => {
restored.model.vocab.len() == original.model.vocab.len()
&& restored.model.model_type == original.model.model_type
}
Err(_) => false,
}
}
}
/// Identify the tokenizer algorithm declared in a raw `tokenizer.json`
/// string by reading `model.type`. Errors if the `model` object or its
/// `type` field cannot be located.
pub fn detect_model_type(json: &str) -> Result<HfModelType> {
    let model = extract_object_field(json, "model").ok_or_else(|| {
        TextError::InvalidInput("HF JSON: could not locate 'model' object".to_string())
    })?;
    let type_name = parse_string_field(model, "type").ok_or_else(|| {
        TextError::InvalidInput("HF JSON: missing model.type field".to_string())
    })?;
    Ok(HfModelType::from_str(&type_name))
}
/// Quote and escape `s` as a JSON string literal. The common escapes get
/// their short forms; any other control character (U+0000..U+001F) is
/// emitted as a `\uXXXX` escape, as required by the JSON grammar.
fn json_string(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('"');
    for c in s.chars() {
        match c {
            '"' => quoted.push_str("\\\""),
            '\\' => quoted.push_str("\\\\"),
            '\n' => quoted.push_str("\\n"),
            '\r' => quoted.push_str("\\r"),
            '\t' => quoted.push_str("\\t"),
            _ if (c as u32) < 0x20 => quoted.push_str(&format!("\\u{:04x}", c as u32)),
            _ => quoted.push(c),
        }
    }
    quoted.push('"');
    quoted
}
/// Locate the raw JSON value for `key` inside `json`.
///
/// This is a lightweight scanner, not a full JSON parser: it finds the first
/// occurrence of `"key":` anywhere in the input, so a key embedded in a
/// string value or a nested object can shadow the intended field. Returns
/// the remainder of the input starting at the value, or `None` when the key
/// is missing or its value is the literal `null`.
fn extract_json_value<'a>(json: &'a str, key: &str) -> Option<&'a str> {
    let needle = format!("\"{}\":", key);
    let pos = json.find(needle.as_str())?;
    let value = json[pos + needle.len()..].trim_start();
    if value.starts_with("null") {
        return None;
    }
    Some(value)
}

/// Parse a JSON string value for `key`, decoding the standard escapes
/// (`\"`, `\\`, `\n`, `\r`, `\t`, `\b`, `\f`, `\uXXXX`). Unknown escapes
/// keep the escaped character; `\u` values that are not valid Unicode
/// scalar values (e.g. lone surrogates) are dropped because
/// `char::from_u32` rejects them.
fn parse_string_field(json: &str, key: &str) -> Option<String> {
    let raw = extract_json_value(json, key)?;
    if !raw.starts_with('"') {
        return None;
    }
    let mut chars = raw.char_indices().skip(1);
    let mut result = String::new();
    loop {
        match chars.next() {
            None => return None, // unterminated string literal
            Some((_, '"')) => break,
            Some((_, '\\')) => match chars.next() {
                Some((_, '"')) => result.push('"'),
                Some((_, '\\')) => result.push('\\'),
                Some((_, 'n')) => result.push('\n'),
                Some((_, 'r')) => result.push('\r'),
                Some((_, 't')) => result.push('\t'),
                Some((_, 'b')) => result.push('\u{0008}'),
                Some((_, 'f')) => result.push('\u{000C}'),
                Some((_, 'u')) => {
                    // Collect exactly four hex digits following \u.
                    let mut hex = String::new();
                    for _ in 0..4 {
                        if let Some((_, c)) = chars.next() {
                            hex.push(c);
                        }
                    }
                    if let Ok(code) = u32::from_str_radix(&hex, 16) {
                        if let Some(c) = char::from_u32(code) {
                            result.push(c);
                        }
                    }
                }
                Some((_, c)) => result.push(c),
                None => return None,
            },
            Some((_, c)) => result.push(c),
        }
    }
    Some(result)
}

/// Parse a JSON boolean for `key`; `None` when missing, `null`, or not a
/// boolean literal.
fn parse_bool_field(json: &str, key: &str) -> Option<bool> {
    let raw = extract_json_value(json, key)?;
    if raw.starts_with("true") {
        Some(true)
    } else if raw.starts_with("false") {
        Some(false)
    } else {
        None
    }
}

/// Parse an unsigned integer for `key`; `None` when missing, `null`,
/// negative, out of `u32` range, or not a number.
fn parse_u32_field(json: &str, key: &str) -> Option<u32> {
    let raw = extract_json_value(json, key)?;
    let digits: String = raw.chars().take_while(|c| c.is_ascii_digit()).collect();
    digits.parse().ok()
}
/// Extract the raw `{...}` object value for `key`, braces included.
/// `None` when the key is absent, null, or its value is not an object.
fn extract_object_field<'a>(json: &'a str, key: &str) -> Option<&'a str> {
    let raw = extract_json_value(json, key)?;
    if !raw.starts_with('{') {
        return None;
    }
    find_matching_brace(raw, '{', '}').map(|end| &raw[..=end])
}
/// Extract the `[...]` array value for `key` and split its contents into the
/// raw strings of its top-level objects. `None` when the key is absent,
/// null, or the value is not an array.
fn extract_array_field(json: &str, key: &str) -> Option<Vec<String>> {
    let raw = extract_json_value(json, key)?;
    if !raw.starts_with('[') {
        return None;
    }
    let end = find_matching_brace(raw, '[', ']')?;
    Some(split_json_array_objects(&raw[1..end]))
}
fn parse_string_array_field(json: &str, key: &str) -> Option<Vec<String>> {
let raw = extract_json_value(json, key)?;
if !raw.starts_with('[') {
return None;
}
let end = find_matching_brace(raw, '[', ']')?;
let inner = &raw[1..end];
let mut result = Vec::new();
let mut remainder = inner.trim();
while !remainder.is_empty() {
if remainder.starts_with('"') {
let mut chars = remainder.char_indices().skip(1);
let mut s = String::new();
let mut end_pos = 0;
let mut found = false;
loop {
match chars.next() {
None => break,
Some((i, '"')) => {
end_pos = i;
found = true;
break;
}
Some((_, '\\')) => {
match chars.next() {
Some((_, c)) => s.push(c),
None => break,
}
}
Some((_, c)) => s.push(c),
}
}
if found {
result.push(s);
remainder = remainder[end_pos + 1..].trim_start_matches(',').trim();
} else {
break;
}
} else {
let skip = remainder.find(',').map(|i| i + 1).unwrap_or(remainder.len());
remainder = &remainder[skip..];
}
}
Some(result)
}
/// Parse the `vocab` object (token -> id) out of a model JSON string.
///
/// Tolerant scanner: malformed entries terminate parsing early, ids that do
/// not fit in `u32` are skipped, and unexpected characters are stepped over
/// one `char` at a time — the previous `&remainder[1..]` byte-slice could
/// panic in the middle of a multi-byte UTF-8 sequence.
fn parse_vocab_object(json: &str) -> Result<HashMap<String, u32>> {
    let vocab_raw = extract_object_field(json, "vocab").ok_or_else(|| {
        TextError::InvalidInput("HF JSON: missing model.vocab object".to_string())
    })?;
    // extract_object_field guarantees a leading '{' and trailing '}'.
    let inner = &vocab_raw[1..vocab_raw.len() - 1];
    let mut map = HashMap::new();
    let mut remainder = inner.trim();
    while !remainder.is_empty() {
        if !remainder.starts_with('"') {
            // Unexpected character: advance by exactly one char, never
            // splitting a UTF-8 sequence.
            let mut chars = remainder.chars();
            chars.next();
            remainder = chars.as_str();
            continue;
        }
        let (key, consumed) = match parse_json_string_at_start(remainder) {
            Some(parsed) => parsed,
            None => break, // unterminated key string
        };
        remainder = remainder[consumed..].trim_start();
        if !remainder.starts_with(':') {
            break; // malformed entry: missing key/value separator
        }
        remainder = remainder[1..].trim_start();
        let digit_len = remainder
            .bytes()
            .take_while(|b| b.is_ascii_digit())
            .count();
        if digit_len == 0 {
            break; // malformed entry: id is not an unsigned integer
        }
        // Ids overflowing u32 fail to parse and are silently dropped,
        // matching the original lenient behaviour.
        if let Ok(id) = remainder[..digit_len].parse::<u32>() {
            map.insert(key, id);
        }
        remainder = remainder[digit_len..].trim_start();
        if remainder.starts_with(',') {
            remainder = remainder[1..].trim_start();
        }
    }
    Ok(map)
}
/// Decode a JSON string literal at the start of `s`.
///
/// Returns the decoded contents plus the number of bytes consumed,
/// including both quotes. Handles the standard escapes (`\"`, `\\`, `\n`,
/// `\r`, `\t`, `\b`, `\f`, `\uXXXX`); unknown escapes keep the escaped
/// character, and `\u` values that are not valid Unicode scalar values
/// (e.g. lone surrogates) are dropped because `char::from_u32` rejects
/// them. Returns `None` if `s` does not start with `"` or the literal is
/// unterminated.
fn parse_json_string_at_start(s: &str) -> Option<(String, usize)> {
    if !s.starts_with('"') {
        return None;
    }
    let mut result = String::new();
    let mut chars = s.char_indices().skip(1);
    loop {
        match chars.next() {
            None => return None, // unterminated literal
            Some((i, '"')) => return Some((result, i + '"'.len_utf8())),
            Some((_, '\\')) => match chars.next() {
                Some((_, '"')) => result.push('"'),
                Some((_, '\\')) => result.push('\\'),
                Some((_, 'n')) => result.push('\n'),
                Some((_, 'r')) => result.push('\r'),
                Some((_, 't')) => result.push('\t'),
                Some((_, 'b')) => result.push('\u{0008}'),
                Some((_, 'f')) => result.push('\u{000C}'),
                Some((_, 'u')) => {
                    // Collect exactly four hex digits following \u.
                    let mut hex = String::new();
                    for _ in 0..4 {
                        if let Some((_, c)) = chars.next() {
                            hex.push(c);
                        }
                    }
                    if let Ok(code) = u32::from_str_radix(&hex, 16) {
                        if let Some(c) = char::from_u32(code) {
                            result.push(c);
                        }
                    }
                }
                Some((_, other)) => result.push(other),
                None => return None,
            },
            Some((_, c)) => result.push(c),
        }
    }
}
/// Byte index of the `close` delimiter matching the `open` delimiter the
/// input is expected to begin with. String literals are consumed wholesale
/// (honouring backslash escapes) so delimiters inside them never affect
/// nesting depth. `None` when the input is unbalanced.
fn find_matching_brace(s: &str, open: char, close: char) -> Option<usize> {
    let mut depth: i32 = 0;
    let mut it = s.char_indices();
    while let Some((i, ch)) = it.next() {
        if ch == '"' {
            // Skip the entire string literal.
            while let Some((_, c)) = it.next() {
                match c {
                    '\\' => {
                        it.next(); // escaped char: never terminates the string
                    }
                    '"' => break,
                    _ => {}
                }
            }
        } else if ch == open {
            depth += 1;
        } else if ch == close {
            depth -= 1;
            if depth == 0 {
                return Some(i);
            }
        }
    }
    None
}

/// Split the interior of a JSON array into its top-level `{...}` objects,
/// returned as raw substrings. Text between objects is ignored; a trailing
/// unbalanced object is discarded.
fn split_json_array_objects(inner: &str) -> Vec<String> {
    let mut objects = Vec::new();
    let mut rest = inner.trim();
    loop {
        // Jump to the next object start; stop when none remains.
        let start = match rest.find('{') {
            Some(i) => i,
            None => break,
        };
        rest = &rest[start..];
        match find_matching_brace(rest, '{', '}') {
            Some(end) => {
                objects.push(rest[..=end].to_string());
                rest = rest[end + 1..].trim_start_matches(',').trim();
            }
            None => break,
        }
    }
    objects
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tokenization::wordpiece::WordPieceTokenizer;
// Fixture: the five standard BERT special tokens plus a few word pieces and
// one continuation piece ("##ing").
fn minimal_wp() -> WordPieceTokenizer {
let tokens = vec![
"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]",
"hello", "world", "##ing", "foo",
];
WordPieceTokenizer::from_vocab_list(&tokens)
}
// Exported model section must advertise itself as WordPiece.
#[test]
fn from_wordpiece_model_type() {
let wp = minimal_wp();
let hf = HfTokenizerJson::from_wordpiece(&wp);
assert_eq!(hf.model.model_type, "WordPiece");
}
// Serialised output must include the mandatory vocab key.
#[test]
fn to_json_string_contains_vocab() {
let wp = minimal_wp();
let hf = HfTokenizerJson::from_wordpiece(&wp);
let s = hf.to_json_string();
assert!(s.contains("\"vocab\""), "JSON must contain vocab key");
}
// serialise -> parse must preserve the model type.
#[test]
fn roundtrip_from_json_str() {
let wp = minimal_wp();
let hf = HfTokenizerJson::from_wordpiece(&wp);
let json = hf.to_json_string();
let restored = HfTokenizerJson::from_json_str(&json).expect("parse failed");
assert_eq!(restored.model.model_type, "WordPiece");
}
// detect_model_type on our own serialised output.
#[test]
fn detect_model_type_wordpiece() {
let wp = minimal_wp();
let hf = HfTokenizerJson::from_wordpiece(&wp);
let json = hf.to_json_string();
let mt = detect_model_type(&json).expect("detect failed");
assert_eq!(mt, HfModelType::WordPiece);
}
// detect_model_type on a hand-written BPE document.
#[test]
fn detect_model_type_bpe() {
let json = r#"{"version":"1.0","model":{"type":"BPE","vocab":{"hello":0},"merges":["h e"]},"added_tokens":[]}"#;
let mt = detect_model_type(json).expect("detect failed");
assert_eq!(mt, HfModelType::Bpe);
}
// from_wordpiece always declares the BERT special tokens.
#[test]
fn added_tokens_contains_cls() {
let wp = minimal_wp();
let hf = HfTokenizerJson::from_wordpiece(&wp);
let has_cls = hf.added_tokens.iter().any(|t| t.content == "[CLS]");
assert!(has_cls, "added_tokens must contain [CLS]");
}
// Exported vocab must be the full tokenizer vocab, nothing dropped.
#[test]
fn vocab_size_matches_input() {
let wp = minimal_wp();
let hf = HfTokenizerJson::from_wordpiece(&wp);
assert_eq!(hf.model.vocab.len(), wp.vocab_size());
}
// Degenerate input: an empty vocabulary must still serialise.
#[test]
fn empty_vocab_serialises_without_panic() {
let tokens: &[&str] = &[];
let wp = WordPieceTokenizer::from_vocab_list(tokens);
let hf = HfTokenizerJson::from_wordpiece(&wp);
let json = hf.to_json_string();
assert!(json.contains("WordPiece"));
}
// Compile-time-ish check that all public variants remain constructible
// (guards against accidental removals behind #[non_exhaustive]).
#[test]
fn hf_model_type_variants_accessible() {
let _ = HfModelType::WordPiece;
let _ = HfModelType::Bpe;
let _ = HfModelType::Unigram;
let _ = HfModelType::Unknown("X".to_string());
}
// Input with no "model" object must produce an error, not a panic.
#[test]
fn invalid_json_returns_err() {
let result = HfTokenizerJson::from_json_str("not json at all }{");
assert!(result.is_err());
}
// The built-in roundtrip helper must succeed on a well-formed tokenizer.
#[test]
fn roundtrip_check_helper() {
let wp = minimal_wp();
assert!(HfTokenizerJson::wordpiece_roundtrip_check(&wp));
}
// serialise -> parse must preserve the version string.
#[test]
fn version_field_preserved() {
let wp = minimal_wp();
let hf = HfTokenizerJson::from_wordpiece(&wp);
let json = hf.to_json_string();
let restored = HfTokenizerJson::from_json_str(&json).unwrap();
assert_eq!(restored.version, "1.0");
}
}