use std::path::Path;
use crate::encoder::{BacktrackingBytePairEncoder, BytePairEncoder, Encoder, EncoderType, SentencePieceBPE, UnigramEncoder, WordPieceEncoder};
use crate::decoder::Decoder;
use crate::normalizer::Normalizer;
use crate::postprocessor::PostProcessor;
use crate::pretok::{PretokType, Pretokenizer};
use crate::tokenizer::Tokenizer;
use crate::types::TokenId;
/// Error type returned by the tokenizer JSON loading functions in this file.
#[derive(Debug)]
pub enum JsonLoadError {
/// The tokenizer file could not be read.
Io(std::io::Error),
/// The input text was not valid JSON.
Json(serde_json::Error),
/// The JSON parsed, but its structure was not what the loader expects
/// (for example `"vocab should be object"`).
InvalidFormat(&'static str),
}
impl std::fmt::Display for JsonLoadError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Io(e) => write!(f, "IO error: {}", e),
Self::Json(e) => write!(f, "JSON error: {}", e),
Self::InvalidFormat(msg) => write!(f, "Invalid format: {}", msg),
}
}
}
impl std::error::Error for JsonLoadError {}
impl From<std::io::Error> for JsonLoadError {
fn from(e: std::io::Error) -> Self {
Self::Io(e)
}
}
impl From<serde_json::Error> for JsonLoadError {
fn from(e: serde_json::Error) -> Self {
Self::Json(e)
}
}
pub fn from_json(path: impl AsRef<Path>) -> Result<Tokenizer, JsonLoadError> {
let json_str = std::fs::read_to_string(path)?;
from_json_str(&json_str)
}
pub fn from_json_with_pretokenizer(
path: impl AsRef<Path>,
pretokenizer_type: PretokType,
) -> Result<Tokenizer, JsonLoadError> {
let json_str = std::fs::read_to_string(path)?;
from_json_str_with_pretokenizer(&json_str, pretokenizer_type)
}
pub fn from_json_with_encoder(
path: impl AsRef<Path>,
encoder_type: EncoderType,
) -> Result<Tokenizer, JsonLoadError> {
let json_str = std::fs::read_to_string(path)?;
let data: serde_json::Value = serde_json::from_str(&json_str)?;
let detected = detect_pretokenizer_type(&data);
let mut tok = load_from_json_value_with_encoder(&data, detected.pretok_type, encoder_type)?;
apply_regex_fallback(&mut tok, &detected);
Ok(tok)
}
pub fn from_json_with_options(
path: impl AsRef<Path>,
encoder_type: EncoderType,
pretokenizer_type: PretokType,
) -> Result<Tokenizer, JsonLoadError> {
let json_str = std::fs::read_to_string(path)?;
let data: serde_json::Value = serde_json::from_str(&json_str)?;
load_from_json_value_with_encoder(&data, pretokenizer_type, encoder_type)
}
pub fn from_json_str(json_str: &str) -> Result<Tokenizer, JsonLoadError> {
let data: serde_json::Value = serde_json::from_str(json_str)?;
let detected = detect_pretokenizer_type(&data);
let mut tok = load_from_json_value(&data, detected.pretok_type)?;
apply_regex_fallback(&mut tok, &detected);
Ok(tok)
}
pub fn from_json_str_with_pretokenizer(
json_str: &str,
pretokenizer_type: PretokType,
) -> Result<Tokenizer, JsonLoadError> {
let data: serde_json::Value = serde_json::from_str(json_str)?;
load_from_json_value(&data, pretokenizer_type)
}
/// Loads a tokenizer from parsed JSON using the default encoder choice
/// (`EncoderType::Backtracking`).
fn load_from_json_value(
    data: &serde_json::Value,
    pretokenizer_type: PretokType,
) -> Result<Tokenizer, JsonLoadError> {
    let default_encoder = EncoderType::Backtracking;
    load_from_json_value_with_encoder(data, pretokenizer_type, default_encoder)
}
/// Dispatches parsed tokenizer JSON to the loader matching its model kind.
///
/// Order of checks: Unigram, then WordPiece, then BPE. For BPE the merge list
/// decides between the byte-level loader (merges reference only tokens that
/// already "exist" at that point) and the vocab-defined loader.
///
/// # Errors
/// Returns [`JsonLoadError::InvalidFormat`] when a BPE model lacks an object
/// `vocab` or an array `merges`; otherwise whatever the chosen loader reports.
fn load_from_json_value_with_encoder(
    data: &serde_json::Value,
    pretokenizer_type: PretokType,
    encoder_type: EncoderType,
) -> Result<Tokenizer, JsonLoadError> {
    let model = &data["model"];
    let normalizer = detect_normalizer(data);
    let model_type = model["type"].as_str();
    let has_no_merges = model["merges"].is_null();

    // Unigram: declared explicitly, or recognised by an array vocab without merges.
    if model_type == Some("Unigram") || (model["vocab"].is_array() && has_no_merges) {
        return load_unigram(data, pretokenizer_type, normalizer);
    }

    // WordPiece: declared on the model or decoder, or recognised by a
    // continuation prefix without merges.
    if model_type == Some("WordPiece")
        || data["decoder"]["type"].as_str() == Some("WordPiece")
        || (has_no_merges && model["continuing_subword_prefix"].is_string())
    {
        return load_wordpiece(data, pretokenizer_type, normalizer);
    }

    let Some(vocab_map) = model["vocab"].as_object() else {
        return Err(JsonLoadError::InvalidFormat("vocab should be object"));
    };
    let Some(merges_arr) = model["merges"].as_array() else {
        return Err(JsonLoadError::InvalidFormat("merges should be array"));
    };

    // Entries whose id is not an unsigned integer are kept with id 0.
    let mut vocab: Vec<(String, u32)> = Vec::with_capacity(vocab_map.len());
    for (token, id) in vocab_map {
        vocab.push((token.clone(), id.as_u64().unwrap_or(0) as u32));
    }
    vocab.sort_by_key(|entry| entry.1);

    let num_base_tokens = detect_num_base_tokens(vocab_map, merges_arr);
    if are_merges_topological(vocab_map, merges_arr, num_base_tokens) {
        load_byte_level_bpe(
            data,
            &vocab,
            vocab_map,
            merges_arr,
            pretokenizer_type,
            encoder_type,
            normalizer,
        )
    } else {
        load_vocab_defined_bpe(
            data,
            &vocab,
            vocab_map,
            merges_arr,
            pretokenizer_type,
            encoder_type,
            normalizer,
        )
    }
}
/// Estimates how many base (non-merged) tokens the vocabulary starts with.
///
/// The id of the token produced by the first merge is taken as the count.
/// Falls back to 256 when there are no merges, the first merge cannot be
/// parsed, or the merged token is not in `vocab_map` with an integer id.
///
/// A merge is either a two-element array of strings or a single string with
/// the two halves separated by a space.
fn detect_num_base_tokens(
    vocab_map: &serde_json::Map<String, serde_json::Value>,
    merges_arr: &[serde_json::Value],
) -> usize {
    const DEFAULT_BASE_TOKENS: usize = 256;

    let Some(first_merge) = merges_arr.first() else {
        return DEFAULT_BASE_TOKENS;
    };

    let halves = match first_merge {
        serde_json::Value::Array(items) if items.len() >= 2 => {
            items[0].as_str().zip(items[1].as_str())
        }
        serde_json::Value::String(text) => {
            let mut pieces = text.split(' ');
            pieces.next().zip(pieces.next())
        }
        _ => None,
    };
    let Some((left_str, right_str)) = halves else {
        return DEFAULT_BASE_TOKENS;
    };

    let merged = [left_str, right_str].concat();
    vocab_map
        .get(merged.as_str())
        .and_then(|id| id.as_u64())
        .map_or(DEFAULT_BASE_TOKENS, |id| id as usize)
}
/// Checks that every merge only references tokens whose id is below
/// `num_base_tokens + merge_index`.
///
/// Merges that cannot be parsed are skipped. A token missing from `vocab_map`
/// (or with a non-integer id) is treated as id 0.
fn are_merges_topological(
    vocab_map: &serde_json::Map<String, serde_json::Value>,
    merges_arr: &[serde_json::Value],
    num_base_tokens: usize,
) -> bool {
    let id_of = |token: &str| -> usize {
        vocab_map
            .get(token)
            .and_then(|id| id.as_u64())
            .unwrap_or(0) as usize
    };

    merges_arr.iter().enumerate().all(|(merge_idx, merge)| {
        let halves = match merge {
            serde_json::Value::Array(items) if items.len() >= 2 => {
                items[0].as_str().zip(items[1].as_str())
            }
            serde_json::Value::String(text) => {
                let mut pieces = text.split(' ');
                pieces.next().zip(pieces.next())
            }
            _ => None,
        };
        match halves {
            // Unparseable entries do not affect the verdict.
            None => true,
            Some((left_str, right_str)) => {
                let left_id = id_of(left_str);
                let right_id = id_of(right_str);
                let max_available = num_base_tokens + merge_idx;
                left_id < max_available && right_id < max_available
            }
        }
    })
}
/// Builds a BPE tokenizer whose vocabulary strings use the byte-level
/// character mapping (decoded with `decode_bytelevel_token`).
///
/// Merges that cannot be parsed, or that reference tokens missing from
/// `vocab_map`, are dropped. `EncoderType::Simple` selects `BytePairEncoder`;
/// every other encoder type selects `BacktrackingBytePairEncoder`.
fn load_byte_level_bpe(
    data: &serde_json::Value,
    vocab: &[(String, u32)],
    vocab_map: &serde_json::Map<String, serde_json::Value>,
    merges_arr: &[serde_json::Value],
    pretokenizer_type: PretokType,
    encoder_type: EncoderType,
    normalizer: Normalizer,
) -> Result<Tokenizer, JsonLoadError> {
    let num_base_tokens = detect_num_base_tokens(vocab_map, merges_arr);

    let mut full_vocab: Vec<(u32, Vec<u8>)> = Vec::with_capacity(vocab.len());
    for (token, id) in vocab {
        full_vocab.push((*id, decode_bytelevel_token(token)));
    }

    let token_id =
        |token: &str| -> Option<u32> { Some(vocab_map.get(token)?.as_u64()? as u32) };
    let mut merges: Vec<(u32, u32)> = Vec::new();
    for entry in merges_arr {
        // A merge is either ["left", "right"] or "left right".
        let halves = match entry {
            serde_json::Value::Array(items) if items.len() >= 2 => {
                items[0].as_str().zip(items[1].as_str())
            }
            serde_json::Value::String(text) => {
                let mut pieces = text.split(' ');
                pieces.next().zip(pieces.next())
            }
            _ => None,
        };
        let Some((left_str, right_str)) = halves else {
            continue;
        };
        if let (Some(left), Some(right)) = (token_id(left_str), token_id(right_str)) {
            merges.push((left, right));
        }
    }

    let (encoder, token_bytes) = match encoder_type {
        EncoderType::Simple => {
            let (enc, bytes) =
                BytePairEncoder::from_vocab_and_merges(&full_vocab, &merges, num_base_tokens);
            (Encoder::Simple(enc), bytes)
        }
        EncoderType::Backtracking
        | EncoderType::WordPiece
        | EncoderType::SentencePiece
        | EncoderType::Unigram => {
            let (enc, bytes) = BacktrackingBytePairEncoder::from_vocab_and_merges(
                &full_vocab,
                &merges,
                num_base_tokens,
            );
            (Encoder::Backtracking(enc), bytes)
        }
    };

    let decoder = Decoder::for_encoder(token_bytes, encoder.encoder_type());
    let post_processor = detect_post_processor(data);
    let mut tokenizer =
        Tokenizer::new(encoder, decoder, pretokenizer_type, normalizer, post_processor);
    if let Some(pad_id) = extract_pad_token_id(data) {
        tokenizer.set_pad_token_id(pad_id);
    }
    setup_added_tokens(&mut tokenizer, data);
    Ok(tokenizer)
}
/// Builds a BPE tokenizer for vocabularies whose merge order is not
/// topological with respect to token ids.
///
/// Token strings are decoded as byte-level text when the JSON indicates a
/// ByteLevel decoder/pre-tokenizer; otherwise `<0xNN>` tokens become single
/// bytes (their ids are recorded as byte-fallback ids) and everything else is
/// decoded with `decode_sentencepiece_token`.
///
/// Encoder choice: SentencePiece BPE when requested or when the normalizer is
/// `Metaspace`/`MetaspaceReplace`; otherwise the simple encoder for byte-level
/// vocabularies or when requested; otherwise the backtracking encoder.
fn load_vocab_defined_bpe(
    data: &serde_json::Value,
    vocab: &[(String, u32)],
    vocab_map: &serde_json::Map<String, serde_json::Value>,
    merges_arr: &[serde_json::Value],
    pretokenizer_type: PretokType,
    encoder_type: EncoderType,
    normalizer: Normalizer,
) -> Result<Tokenizer, JsonLoadError> {
    let uses_bytelevel = is_bytelevel_decoder(data);

    let mut byte_fallback_ids = foldhash::HashSet::default();
    let mut full_vocab: Vec<(u32, Vec<u8>)> = Vec::with_capacity(vocab.len());
    for (token, id) in vocab {
        let bytes = if uses_bytelevel {
            decode_bytelevel_token(token)
        } else {
            match parse_byte_fallback_token(token) {
                Some(byte_val) => {
                    byte_fallback_ids.insert(*id);
                    vec![byte_val]
                }
                None => decode_sentencepiece_token(token),
            }
        };
        full_vocab.push((*id, bytes));
    }

    // Length of the leading run of single-byte tokens, but never below 256.
    let single_byte_prefix = full_vocab
        .iter()
        .position(|(_, bytes)| bytes.len() != 1)
        .unwrap_or(full_vocab.len());
    let num_base_tokens = std::cmp::max(single_byte_prefix, 256);

    let token_id =
        |token: &str| -> Option<u32> { Some(vocab_map.get(token)?.as_u64()? as u32) };
    let mut merges: Vec<(u32, u32)> = Vec::new();
    for entry in merges_arr {
        // A merge is either ["left", "right"] or "left right".
        let halves = match entry {
            serde_json::Value::Array(items) if items.len() >= 2 => {
                items[0].as_str().zip(items[1].as_str())
            }
            serde_json::Value::String(text) => {
                let mut pieces = text.split(' ');
                pieces.next().zip(pieces.next())
            }
            _ => None,
        };
        let Some((left_str, right_str)) = halves else {
            continue;
        };
        if let (Some(left), Some(right)) = (token_id(left_str), token_id(right_str)) {
            merges.push((left, right));
        }
    }

    let wants_sentencepiece = encoder_type == EncoderType::SentencePiece;
    let use_sentencepiece = wants_sentencepiece
        || matches!(normalizer, Normalizer::Metaspace | Normalizer::MetaspaceReplace);
    let use_simple = uses_bytelevel && !wants_sentencepiece;

    let (encoder, token_bytes) = if use_sentencepiece {
        let (enc, bytes) = SentencePieceBPE::from_vocab_and_merges(
            &full_vocab,
            &merges,
            num_base_tokens,
            &byte_fallback_ids,
        );
        (Encoder::SentencePiece(enc), bytes)
    } else if use_simple || encoder_type == EncoderType::Simple {
        let (enc, bytes) =
            BytePairEncoder::from_vocab_and_merges(&full_vocab, &merges, num_base_tokens);
        (Encoder::Simple(enc), bytes)
    } else {
        let (enc, bytes) = BacktrackingBytePairEncoder::from_vocab_and_merges(
            &full_vocab,
            &merges,
            num_base_tokens,
        );
        (Encoder::Backtracking(enc), bytes)
    };

    let decoder = Decoder::for_encoder(token_bytes, encoder.encoder_type());
    let post_processor = detect_post_processor(data);
    let mut tokenizer =
        Tokenizer::new(encoder, decoder, pretokenizer_type, normalizer, post_processor);
    if let Some(pad_id) = extract_pad_token_id(data) {
        tokenizer.set_pad_token_id(pad_id);
    }
    setup_added_tokens(&mut tokenizer, data);
    Ok(tokenizer)
}
/// Parses a byte-fallback token of the exact form `<0xNN>` (two hex digits,
/// either case) and returns the byte it stands for.
///
/// Returns `None` for anything else. The two characters are checked to be
/// ASCII hex digits before parsing, because `u8::from_str_radix` on its own
/// also accepts a leading `+` sign and would treat `<0x+F>` as byte 15.
#[inline]
fn parse_byte_fallback_token(s: &str) -> Option<u8> {
    let hex = s.strip_prefix("<0x")?.strip_suffix('>')?;
    if hex.len() != 2 || !hex.bytes().all(|b| b.is_ascii_hexdigit()) {
        return None;
    }
    u8::from_str_radix(hex, 16).ok()
}
/// Installs a regex pre-tokenizer built from the detected fallback pattern,
/// but only when the tokenizer currently has none.
///
/// Patterns containing a `(?!\S)` / `(?!\s)` negative lookahead are rewritten:
/// the `\s+(?!…)` piece becomes `\s+$`, and two extra patterns (`\s+\s` with
/// its flag set to `true`, and `\s+`) are appended. If the regex fails to
/// compile the tokenizer is left untouched.
fn apply_regex_fallback(tok: &mut Tokenizer, detected: &DetectedPretokenizer) {
    if tok.pretokenizer().is_some() {
        return;
    }
    let Some(pattern) = detected.fallback_pattern.as_ref() else {
        return;
    };

    let has_lookahead = pattern.contains("(?!\\S)") || pattern.contains("(?!\\s)");
    let mut patterns: Vec<(String, bool)> = Vec::new();
    if has_lookahead {
        let rewritten = pattern
            .replace("\\s+(?!\\S)", "\\s+$")
            .replace("\\s+(?!\\s)", "\\s+$");
        patterns.push((rewritten, false));
        patterns.push(("\\s+\\s".to_string(), true));
        patterns.push(("\\s+".to_string(), false));
    } else {
        patterns.push((pattern.clone(), false));
    }

    let mut pat_refs: Vec<(&str, bool)> = Vec::with_capacity(patterns.len());
    for (text, flag) in &patterns {
        pat_refs.push((text.as_str(), *flag));
    }

    // Best effort: a pattern that does not compile is silently ignored.
    if let Ok(regex) = pretokie::Regex::new(&pat_refs) {
        tok.set_pretokenizer(Some(Pretokenizer::from_regex(regex)));
    }
}
/// Result of inspecting the `pre_tokenizer` section of a tokenizer JSON.
struct DetectedPretokenizer {
/// Built-in pre-tokenizer kind chosen for the document.
pretok_type: PretokType,
/// Regex source (Split patterns, joined with `|` when there are several)
/// that `apply_regex_fallback` may install when the loaded tokenizer has no
/// pre-tokenizer.
fallback_pattern: Option<String>,
}
/// Chooses a pre-tokenizer for the document from `data["pre_tokenizer"]`.
///
/// Decision order:
/// * `ByteLevel` -> `Gpt2`.
/// * `Sequence` containing a `ByteLevel` entry: each `Split` regex is tested,
///   in document order, against a set of substring heuristics and the first
///   hit wins (`O200k`, `DeepSeek`, `Qwen35`, `Voyage`, `Cl100k`). With no hit:
///   `SmolLM` if a `Digits` entry exists, otherwise `Gpt2` (with the collected
///   Split regexes joined by `|` as fallback pattern, if any).
/// * `Sequence` without `ByteLevel`: `None` plus the joined Split regexes as
///   fallback pattern, when there are any.
/// * `Metaspace` -> `None`.
/// * `Split` with a `Regex` pattern -> `None` plus that regex as fallback.
/// * Anything else -> `None` with no fallback.
fn detect_pretokenizer_type(data: &serde_json::Value) -> DetectedPretokenizer {
let pre_tokenizer = &data["pre_tokenizer"];
if let Some(typ) = pre_tokenizer["type"].as_str() {
if typ == "ByteLevel" {
return DetectedPretokenizer { pretok_type: PretokType::Gpt2, fallback_pattern: None };
}
if typ == "Sequence" {
if let Some(pretokenizers) = pre_tokenizer["pretokenizers"].as_array() {
let has_byte_level = pretokenizers
.iter()
.any(|p| p["type"].as_str() == Some("ByteLevel"));
if has_byte_level {
let has_digits = pretokenizers
.iter()
.any(|p| p["type"].as_str() == Some("Digits"));
let mut split_patterns: Vec<String> = Vec::new();
for p in pretokenizers {
if p["type"].as_str() == Some("Split") {
if let Some(pattern) = p["pattern"]["Regex"].as_str() {
split_patterns.push(pattern.to_string());
// Patterns that distinguish upper/title/lower-case letters select O200k.
let is_case_aware = pattern.contains("\\p{Lu}")
|| pattern.contains("\\p{Lt}")
|| pattern.contains("\\p{Ll}");
if is_case_aware {
return DetectedPretokenizer { pretok_type: PretokType::O200k, fallback_pattern: None };
}
// Letter-or-mark runs: a counted digit group (`\p{N}{`) selects DeepSeek, otherwise Qwen35.
if pattern.contains("[\\p{L}\\p{M}]+") {
if pattern.contains("\\p{N}{") {
return DetectedPretokenizer { pretok_type: PretokType::DeepSeek, fallback_pattern: None };
}
return DetectedPretokenizer { pretok_type: PretokType::Qwen35, fallback_pattern: None };
}
// Plain letter runs or case-insensitive contractions: an uncounted `\p{N}|`
// alternative selects Voyage, otherwise Cl100k.
if pattern.contains("\\p{L}+") || pattern.contains("(?i:'s|'t|'re") {
if pattern.contains("\\p{N}|") && !pattern.contains("\\p{N}{") {
return DetectedPretokenizer { pretok_type: PretokType::Voyage, fallback_pattern: None };
}
return DetectedPretokenizer { pretok_type: PretokType::Cl100k, fallback_pattern: None };
}
}
}
}
// No Split regex matched a known family.
if has_digits {
return DetectedPretokenizer { pretok_type: PretokType::SmolLM, fallback_pattern: None };
}
if !split_patterns.is_empty() {
return DetectedPretokenizer {
pretok_type: PretokType::Gpt2, fallback_pattern: Some(split_patterns.join("|")),
};
}
return DetectedPretokenizer { pretok_type: PretokType::Gpt2, fallback_pattern: None };
}
// Sequence without ByteLevel: only the Split regexes are carried over, as a fallback pattern.
let mut split_patterns: Vec<String> = Vec::new();
for p in pretokenizers {
if p["type"].as_str() == Some("Split") {
if let Some(pattern) = p["pattern"]["Regex"].as_str() {
split_patterns.push(pattern.to_string());
}
}
}
if !split_patterns.is_empty() {
return DetectedPretokenizer {
pretok_type: PretokType::None,
fallback_pattern: Some(split_patterns.join("|")),
};
}
}
}
if typ == "Metaspace" {
return DetectedPretokenizer { pretok_type: PretokType::None, fallback_pattern: None };
}
if typ == "Split" {
if let Some(pattern) = pre_tokenizer["pattern"]["Regex"].as_str() {
return DetectedPretokenizer {
pretok_type: PretokType::None,
fallback_pattern: Some(pattern.to_string()),
};
}
}
}
DetectedPretokenizer { pretok_type: PretokType::None, fallback_pattern: None }
}
/// Chooses a `Normalizer` from `data["normalizer"]`, also taking into account
/// whether the pre-tokenizer section mentions `Metaspace` / `WhitespaceSplit`.
///
/// Summary of the mapping:
/// * no normalizer: `Metaspace` if a Metaspace pre-tokenizer exists, else `None`.
/// * `Precompiled` with a Metaspace pre-tokenizer: `SentencePiece`.
/// * `BertNormalizer`: `BertUncased` when `lowercase` is true, else `BertCased`.
/// * `NFC`: `Nfc`.
/// * `Sequence`: checked for SentencePiece / Metaspace combinations first, then
///   the first `NFC` or `BertNormalizer` member decides.
/// * `Lowercase`: `SentencePieceLowercase` with Metaspace + WhitespaceSplit
///   pre-tokenizers, else `BertUncased`.
/// * `Replace` of `" "` by `"▁"`: `MetaspaceReplace`.
/// * Any case that does not return above falls through to: `Metaspace` if a
///   Metaspace pre-tokenizer exists, else `None`.
fn detect_normalizer(data: &serde_json::Value) -> Normalizer {
let normalizer = &data["normalizer"];
let has_metaspace = is_metaspace_pretokenizer(data);
let has_whitespace_split = is_whitespace_split_pretokenizer(data);
if normalizer.is_null() {
if has_metaspace {
return Normalizer::Metaspace;
}
return Normalizer::None;
}
if let Some(typ) = normalizer["type"].as_str() {
match typ {
"Precompiled" => {
// Without a Metaspace pre-tokenizer this arm falls through to the end of the function.
if has_metaspace {
return Normalizer::SentencePiece;
}
}
"BertNormalizer" => {
// A missing `lowercase` flag is treated as false.
let lowercase = normalizer["lowercase"].as_bool().unwrap_or(false);
if lowercase {
return Normalizer::BertUncased;
} else {
return Normalizer::BertCased;
}
}
"NFC" => {
return Normalizer::Nfc;
}
"Sequence" => {
if let Some(normalizers) = normalizer["normalizers"].as_array() {
let has_lowercase = normalizers.iter().any(|n| {
n["type"].as_str() == Some("Lowercase")
});
let has_precompiled = normalizers.iter().any(|n| {
n["type"].as_str() == Some("Precompiled")
});
let has_prepend_metaspace = normalizers.iter().any(|n| {
n["type"].as_str() == Some("Prepend")
&& n["prepend"].as_str() == Some("▁")
});
let has_replace_space_metaspace = normalizers.iter().any(|n| {
n["type"].as_str() == Some("Replace")
&& n["pattern"]["String"].as_str() == Some(" ")
&& n["content"].as_str() == Some("▁")
});
// The most specific combination is tested first.
if has_precompiled && has_lowercase && has_metaspace && has_whitespace_split {
return Normalizer::SentencePieceLowercase;
}
if has_precompiled && has_metaspace {
return Normalizer::SentencePiece;
}
if has_prepend_metaspace && has_replace_space_metaspace {
return Normalizer::Metaspace;
}
// Otherwise the first NFC or BertNormalizer member of the sequence decides.
for n in normalizers {
if let Some(n_type) = n["type"].as_str() {
if n_type == "NFC" {
return Normalizer::Nfc;
}
if n_type == "BertNormalizer" {
let lowercase = n["lowercase"].as_bool().unwrap_or(false);
if lowercase {
return Normalizer::BertUncased;
} else {
return Normalizer::BertCased;
}
}
}
}
}
}
"Lowercase" => {
if has_metaspace && has_whitespace_split {
return Normalizer::SentencePieceLowercase;
}
return Normalizer::BertUncased;
}
"Replace" => {
let pattern = normalizer["pattern"]["String"].as_str();
let content = normalizer["content"].as_str();
if pattern == Some(" ") && content == Some("▁") {
return Normalizer::MetaspaceReplace;
}
}
_ => {}
}
}
// Reached by unknown types and by arms above that did not return.
if has_metaspace {
return Normalizer::Metaspace;
}
Normalizer::None
}
fn is_whitespace_split_pretokenizer(data: &serde_json::Value) -> bool {
let pre_tokenizer = &data["pre_tokenizer"];
if let Some(typ) = pre_tokenizer["type"].as_str() {
if typ == "WhitespaceSplit" {
return true;
}
}
if let Some(pretokenizers) = pre_tokenizer["pretokenizers"].as_array() {
for p in pretokenizers {
if let Some(typ) = p["type"].as_str() {
if typ == "WhitespaceSplit" {
return true;
}
}
}
}
false
}
fn is_metaspace_pretokenizer(data: &serde_json::Value) -> bool {
let pre_tokenizer = &data["pre_tokenizer"];
if let Some(typ) = pre_tokenizer["type"].as_str() {
if typ == "Metaspace" {
return true;
}
}
if let Some(pretokenizers) = pre_tokenizer["pretokenizers"].as_array() {
for p in pretokenizers {
if let Some(typ) = p["type"].as_str() {
if typ == "Metaspace" {
return true;
}
}
}
}
false
}
/// Converts a byte-level vocabulary string back into raw bytes.
///
/// Per character:
/// * code points 0..=255 map to the byte of the same value;
/// * code points 256..=323 map to the 68 remapped bytes, in order:
///   0..=32, then 127..=160, then 173;
/// * anything above is kept as its UTF-8 encoding.
fn decode_bytelevel_token(s: &str) -> Vec<u8> {
    let mut out = Vec::with_capacity(s.len());
    let mut utf8_buf = [0u8; 4];
    for ch in s.chars() {
        match ch as u32 {
            cp @ 0..=255 => out.push(cp as u8),
            // 256..=288 -> bytes 0..=32
            cp @ 256..=288 => out.push((cp - 256) as u8),
            // 289..=322 -> bytes 127..=160
            cp @ 289..=322 => out.push((cp - 289 + 127) as u8),
            // 323 -> byte 173
            323 => out.push(173),
            _ => out.extend_from_slice(ch.encode_utf8(&mut utf8_buf).as_bytes()),
        }
    }
    out
}
fn is_bytelevel_decoder(data: &serde_json::Value) -> bool {
let decoder = &data["decoder"];
if let Some(typ) = decoder["type"].as_str() {
if typ == "ByteLevel" {
return true;
}
}
if let Some(decoders) = decoder["decoders"].as_array() {
for d in decoders {
if let Some(typ) = d["type"].as_str() {
if typ == "ByteLevel" {
return true;
}
}
}
}
if let Some(pretoks) = data["pre_tokenizer"]["pretokenizers"].as_array() {
for p in pretoks {
if let Some(typ) = p["type"].as_str() {
if typ == "ByteLevel" {
return true;
}
}
}
}
false
}
/// Converts a SentencePiece-style vocabulary string into bytes.
///
/// A token of the exact form `<0xNN>` (two hex digits, either case) becomes
/// the single byte `NN`; every other token is returned as its UTF-8 bytes.
///
/// The two characters are checked to be ASCII hex digits before parsing,
/// because `u8::from_str_radix` on its own also accepts a leading `+` sign and
/// would decode `<0x+F>` as byte 15 instead of keeping it literal.
fn decode_sentencepiece_token(s: &str) -> Vec<u8> {
    let hex = s
        .strip_prefix("<0x")
        .and_then(|rest| rest.strip_suffix('>'))
        .filter(|hex| hex.len() == 2 && hex.bytes().all(|b| b.is_ascii_hexdigit()));
    if let Some(hex) = hex {
        if let Ok(byte) = u8::from_str_radix(hex, 16) {
            return vec![byte];
        }
    }
    s.as_bytes().to_vec()
}
/// Builds a WordPiece tokenizer from the `model` section of the JSON.
///
/// Defaults when a field is absent: unknown token `[UNK]`, continuation prefix
/// `##`, at most 100 input characters per word. When `pretokenizer_type` is
/// `PretokType::None` the Bert pre-tokenizer is used instead.
///
/// # Errors
/// Returns [`JsonLoadError::InvalidFormat`] when `model.vocab` is not an object.
fn load_wordpiece(
    data: &serde_json::Value,
    pretokenizer_type: PretokType,
    normalizer: Normalizer,
) -> Result<Tokenizer, JsonLoadError> {
    let model = &data["model"];
    let Some(vocab_map) = model["vocab"].as_object() else {
        return Err(JsonLoadError::InvalidFormat("vocab should be object"));
    };

    let unk_token_str = model["unk_token"].as_str().unwrap_or("[UNK]");
    let continuation_prefix = model["continuing_subword_prefix"].as_str().unwrap_or("##");
    let max_input_chars_per_word = model["max_input_chars_per_word"]
        .as_u64()
        .map_or(100, |limit| limit as usize);

    // Entries whose id is not an unsigned integer are kept with id 0.
    let mut entries: Vec<(String, u32)> = Vec::with_capacity(vocab_map.len());
    for (token, id) in vocab_map {
        entries.push((token.clone(), id.as_u64().unwrap_or(0) as u32));
    }
    entries.sort_by_key(|entry| entry.1);

    let unk_token = vocab_map
        .get(unk_token_str)
        .and_then(|id| id.as_u64())
        .unwrap_or(0) as u32;

    // NOTE(review): the encoder receives each token's position in id order,
    // not the id stored in the JSON; the two agree only when ids are dense
    // from 0 — confirm that is guaranteed for supported files.
    let mut token_bytes: Vec<Vec<u8>> = Vec::with_capacity(entries.len());
    let mut vocab_pairs: Vec<(Vec<u8>, u32)> = Vec::with_capacity(entries.len());
    for (position, (token, _)) in entries.iter().enumerate() {
        let bytes = token.as_bytes().to_vec();
        vocab_pairs.push((bytes.clone(), position as u32));
        token_bytes.push(bytes);
    }

    let encoder = WordPieceEncoder::from_vocab(
        &vocab_pairs,
        unk_token,
        continuation_prefix.as_bytes(),
        max_input_chars_per_word,
    );
    let decoder = Decoder::for_encoder(token_bytes, EncoderType::WordPiece);

    let pretok = if pretokenizer_type == PretokType::None {
        PretokType::Bert
    } else {
        pretokenizer_type
    };

    let post_processor = detect_post_processor(data);
    let mut tokenizer = Tokenizer::new(
        Encoder::WordPiece(encoder),
        decoder,
        pretok,
        normalizer,
        post_processor,
    );
    if let Some(pad_id) = extract_pad_token_id(data) {
        tokenizer.set_pad_token_id(pad_id);
    }
    setup_added_tokens(&mut tokenizer, data);
    Ok(tokenizer)
}
/// Builds a Unigram tokenizer from the `model` section of the JSON.
///
/// `model.vocab` must be an array of `[token, score]` pairs; a token's id is
/// its index in that array. Malformed entries are skipped (their index is not
/// reused). `model.unk_id` defaults to 0 when absent.
///
/// The caller's `pretokenizer_type` is used as-is. The previous code wrapped
/// it in a nested conditional whose arms were all equivalent to the input
/// value, so that dead logic has been removed.
///
/// # Errors
/// Returns [`JsonLoadError::InvalidFormat`] when `model.vocab` is not an
/// array, or when no usable vocabulary entry remains.
fn load_unigram(
    data: &serde_json::Value,
    pretokenizer_type: PretokType,
    normalizer: Normalizer,
) -> Result<Tokenizer, JsonLoadError> {
    let model = &data["model"];
    let vocab_arr = model["vocab"]
        .as_array()
        .ok_or(JsonLoadError::InvalidFormat("Unigram vocab should be array"))?;
    let unk_id = model["unk_id"].as_u64().unwrap_or(0) as u32;

    let vocab: Vec<(u32, Vec<u8>, f32)> = vocab_arr
        .iter()
        .enumerate()
        .filter_map(|(id, entry)| {
            let arr = entry.as_array()?;
            if arr.len() < 2 {
                return None;
            }
            let token_str = arr[0].as_str()?;
            let score = arr[1].as_f64()? as f32;
            let bytes = decode_sentencepiece_token(token_str);
            Some((id as u32, bytes, score))
        })
        .collect();
    if vocab.is_empty() {
        return Err(JsonLoadError::InvalidFormat("Unigram vocab is empty"));
    }

    let (encoder, token_bytes) = UnigramEncoder::from_vocab_with_scores(&vocab, unk_id);
    let decoder = Decoder::for_encoder(token_bytes, EncoderType::Unigram);
    let post_processor = detect_post_processor(data);
    let mut tokenizer = Tokenizer::new(
        Encoder::Unigram(encoder),
        decoder,
        pretokenizer_type,
        normalizer,
        post_processor,
    );
    if let Some(pad_id) = extract_pad_token_id(data) {
        tokenizer.set_pad_token_id(pad_id);
    }
    setup_added_tokens(&mut tokenizer, data);
    Ok(tokenizer)
}
/// Derives a `PostProcessor` from `data["post_processor"]`.
///
/// Only `TemplateProcessing` is understood, either directly or as a member of
/// a `Sequence`; in a sequence the first `TemplateProcessing` member with an
/// array `single` template is used. Everything else yields
/// `PostProcessor::None`.
fn detect_post_processor(data: &serde_json::Value) -> PostProcessor {
    let pp = &data["post_processor"];
    let template = match pp["type"].as_str() {
        Some("TemplateProcessing") => pp["single"].as_array(),
        Some("Sequence") => pp["processors"].as_array().and_then(|processors| {
            processors
                .iter()
                .filter(|processor| processor["type"].as_str() == Some("TemplateProcessing"))
                .find_map(|processor| processor["single"].as_array())
        }),
        _ => None,
    };
    match template {
        Some(single) => parse_template_post_processor(data, single),
        None => PostProcessor::None,
    }
}
/// Interprets the `single` template of a `TemplateProcessing` post-processor.
///
/// Special tokens named `[CLS]` / `[SEP]` produce `PostProcessor::Bert` when
/// both resolve to an id; otherwise a resolved `<|begin_of_text|>`, `<s>` or
/// `<bos>` produces `PostProcessor::Prefix`. If a name appears several times,
/// the last occurrence wins.
fn parse_template_post_processor(
    data: &serde_json::Value,
    single: &[serde_json::Value],
) -> PostProcessor {
    let (mut cls_token, mut sep_token, mut bos_token) = (None, None, None);

    let special_names = single
        .iter()
        .filter_map(|item| item.get("SpecialToken"))
        .filter_map(|special| special["id"].as_str());
    for name in special_names {
        let resolved = lookup_special_token_id(data, name);
        match name {
            "[CLS]" => cls_token = resolved,
            "[SEP]" => sep_token = resolved,
            "<|begin_of_text|>" | "<s>" | "<bos>" => bos_token = resolved,
            _ => {}
        }
    }

    match (cls_token, sep_token, bos_token) {
        (Some(cls), Some(sep), _) => PostProcessor::Bert { cls_token: cls, sep_token: sep },
        (_, _, Some(bos)) => PostProcessor::Prefix { bos_token: bos },
        _ => PostProcessor::None,
    }
}
/// Finds the padding token id.
///
/// `padding.pad_id` takes precedence. Otherwise the first added token whose
/// content is `[PAD]`, `<pad>` or `<|pad|>` decides: its integer `id` is
/// returned, or `None` if that entry has no integer id.
fn extract_pad_token_id(data: &serde_json::Value) -> Option<TokenId> {
    if let Some(pad_id) = data["padding"]["pad_id"].as_u64() {
        return Some(pad_id as TokenId);
    }
    data["added_tokens"]
        .as_array()?
        .iter()
        .find(|token| {
            matches!(
                token["content"].as_str(),
                Some("[PAD]" | "<pad>" | "<|pad|>")
            )
        })
        .and_then(|token| token["id"].as_u64())
        .map(|id| id as TokenId)
}
/// Collects `(id, content bytes)` for every entry of `added_tokens` that has an
/// integer id and a string content longer than one byte.
///
/// Empty and single-byte contents are skipped.
fn extract_added_tokens(data: &serde_json::Value) -> Vec<(TokenId, Vec<u8>)> {
    let mut collected = Vec::new();
    if let Some(entries) = data["added_tokens"].as_array() {
        for entry in entries {
            let (Some(id), Some(content)) = (entry["id"].as_u64(), entry["content"].as_str())
            else {
                continue;
            };
            if content.len() > 1 {
                collected.push((id as TokenId, content.as_bytes().to_vec()));
            }
        }
    }
    collected
}
/// Registers the JSON's added tokens with `tokenizer`.
///
/// The byte form of the added tokens is installed first (when there are any),
/// then the `(content, id)` pairs of the entries flagged `special: true`
/// (again only when there are any).
fn setup_added_tokens(tokenizer: &mut Tokenizer, data: &serde_json::Value) {
    let added = extract_added_tokens(data);
    if !added.is_empty() {
        tokenizer.set_added_tokens(&added);
    }

    let Some(entries) = data["added_tokens"].as_array() else {
        return;
    };
    let mut special: Vec<(String, TokenId)> = Vec::new();
    for entry in entries {
        // A missing or non-boolean `special` flag counts as not special.
        if entry["special"].as_bool() != Some(true) {
            continue;
        }
        if let (Some(id), Some(content)) = (entry["id"].as_u64(), entry["content"].as_str()) {
            special.push((content.to_string(), id as TokenId));
        }
    }
    if !special.is_empty() {
        tokenizer.set_special_tokens(special);
    }
}
/// Resolves a special token string to its id.
///
/// The first `added_tokens` entry with matching content decides (yielding
/// `None` if it has no integer id); only when no entry matches is
/// `model.vocab` consulted.
fn lookup_special_token_id(data: &serde_json::Value, token_str: &str) -> Option<u32> {
    let added_match = data["added_tokens"].as_array().and_then(|entries| {
        entries
            .iter()
            .find(|entry| entry["content"].as_str() == Some(token_str))
    });
    if let Some(entry) = added_match {
        return entry["id"].as_u64().map(|id| id as u32);
    }

    data["model"]["vocab"]
        .as_object()
        .and_then(|vocab| vocab.get(token_str))
        .and_then(|id| id.as_u64())
        .map(|id| id as u32)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `token` decodes to exactly `expected` under the byte-level mapping.
    fn check(token: &str, expected: &[u8]) {
        assert_eq!(decode_bytelevel_token(token), expected.to_vec());
    }

    #[test]
    fn test_decode_bytelevel_token_ascii() {
        check("Hello", b"Hello");
        check("world", b"world");
    }

    #[test]
    fn test_decode_bytelevel_token_space() {
        check("Ġ", &[32]);
        check("Ġhello", &[32, 104, 101, 108, 108, 111]);
    }

    #[test]
    fn test_decode_bytelevel_token_newline() {
        check("Ċ", &[10]);
    }

    #[test]
    fn test_decode_bytelevel_token_tab() {
        check("ĉ", &[9]);
    }

    #[test]
    fn test_decode_bytelevel_token_punctuation() {
        check(",", &[44]);
        check(".", &[46]);
        check("!", &[33]);
    }
}