use std::collections::HashMap;
use serde::{Deserialize, Serialize};
/// Errors produced while loading or checking a [`TokenizerMap`].
#[derive(Debug, thiserror::Error)]
pub enum TokenizerMapError {
/// A map violated one of the rules enforced by `TokenizerMap::validate`; the
/// payload is a human-readable description of the first violated rule.
#[error("TokenizerMap validation failed: {0}")]
Validation(String),
/// The input could not be deserialized into a `TokenizerMap`. The `#[from]`
/// attribute lets `?` convert a `serde_json::Error` into this variant.
#[error("TokenizerMap parse failed: {0}")]
Parse(#[from] serde_json::Error),
}
/// A tokenizer description loaded from JSON.
///
/// Build validated instances with [`TokenizerMap::from_json`] or
/// [`TokenizerMap::from_json_str`]; deserializing with serde directly does not
/// run [`TokenizerMap::validate`].
///
/// Note: when deserializing, a missing `version` falls back to
/// `default_version()` (`"2"`), but the derived `Default` impl leaves `version`
/// (and `id`) empty and `vocab_size` at 0, so `TokenizerMap::default()` does not
/// pass `validate` on its own.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenizerMap {
    /// Identifier of this map; `validate` rejects an empty string.
    #[serde(default)]
    pub id: String,
    /// Format version; falls back to `"2"` when absent from the JSON.
    #[serde(default = "default_version")]
    pub version: String,
    /// Declared vocabulary size; `validate` requires it to be at least 1.
    #[serde(default)]
    pub vocab_size: i64,
    /// v2-style vocabulary (presumably token string -> id; confirm with producers).
    /// Either this or `tokens` must be present and non-empty.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub vocab: Option<HashMap<String, u32>>,
    /// v1-style token table. Either this or `vocab` must be present and non-empty.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tokens: Option<HashMap<String, String>>,
    /// Encoder kind; `validate` accepts only `"byte_level"` or `"metaspace"`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub encoder: Option<String>,
    /// Merge list (presumably BPE merge rules; not checked by `validate`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub merges: Option<Vec<String>>,
    /// Pre-tokenizer pattern string (presumably a regex; not checked by `validate`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub pre_tokenizer_pattern: Option<String>,
    /// Optional pre-tokenizer program, defined in `crate::pretok_program`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub pre_tokenizer_program: Option<crate::pretok_program::PreTokProgram>,
    /// First bound of the byte-fallback range; must be set together with
    /// `byte_fallback_end`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub byte_fallback_start: Option<i64>,
    /// Second bound of the byte-fallback range; must be set together with
    /// `byte_fallback_start`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub byte_fallback_end: Option<i64>,
    /// Special-token name -> id. Tool-calling markers must appear as keys here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub special_tokens: Option<HashMap<String, u32>>,
    /// Optional tool-calling configuration.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calling: Option<ToolCallingBlock>,
    /// Publication timestamp as an opaque string (format not checked here).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub published_at: Option<String>,
}
/// Tool-calling configuration attached to a [`TokenizerMap`].
///
/// None of the fields has a serde default, so all four must be present
/// whenever a `tool_calling` object appears in the JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallingBlock {
/// Named convention this map follows.
pub convention: ToolCallingConvention,
/// Start/end marker strings; `TokenizerMap::validate` requires both to be
/// non-empty keys of `special_tokens`.
pub markers: ToolCallingMarkers,
/// Argument format selector (`json` or `python_args`).
pub args_format: ToolCallingArgsFormat,
/// Result format selector (`text` or `json`).
pub result_format: ToolCallingResultFormat,
}
/// Start and end marker strings used for tool calling.
///
/// `TokenizerMap::validate` requires both to be non-empty and to exist as keys
/// in `TokenizerMap::special_tokens`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallingMarkers {
/// Opening marker string.
pub start: String,
/// Closing marker string.
pub end: String,
}
/// Named tool-calling convention.
///
/// Serialized in snake_case via `rename_all`, e.g. `"llama3"`, `"qwen25"`,
/// `"mistral_nemo"`, `"deepseek_v3"`, `"custom"`. Variant names look like model
/// families (presumably each selects that family's format; confirm with consumers).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolCallingConvention {
Llama3,
Qwen25,
Phi4,
MistralNemo,
DeepseekV3,
DeepseekR1,
Custom,
}
/// Format selector for tool-call arguments.
///
/// Serialized in snake_case: `"json"` or `"python_args"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolCallingArgsFormat {
Json,
PythonArgs,
}
/// Format selector for tool results.
///
/// Serialized in snake_case: `"text"` or `"json"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolCallingResultFormat {
Text,
Json,
}
/// Serde fallback for `TokenizerMap::version` when the JSON omits the field.
fn default_version() -> String {
    String::from("2")
}
impl TokenizerMap {
    /// Parses a tokenizer map from JSON bytes and validates it.
    ///
    /// # Errors
    ///
    /// Returns [`TokenizerMapError::Parse`] if `json` cannot be deserialized
    /// into a `TokenizerMap`, and [`TokenizerMapError::Validation`] if the
    /// parsed map fails [`TokenizerMap::validate`].
    pub fn from_json(json: &[u8]) -> Result<Self, TokenizerMapError> {
        let map: TokenizerMap = serde_json::from_slice(json)?;
        Self::validate(&map)?;
        Ok(map)
    }

    /// String convenience wrapper around [`TokenizerMap::from_json`].
    ///
    /// # Errors
    ///
    /// Same as [`TokenizerMap::from_json`].
    pub fn from_json_str(json: &str) -> Result<Self, TokenizerMapError> {
        Self::from_json(json.as_bytes())
    }

    /// Checks that the SHA-256 digest of `bytes` matches `expected`.
    ///
    /// `expected` may be bare hex or `sha256:<hex>`. The comparison ignores
    /// ASCII case and surrounding whitespace. On success, returns the hex
    /// digest computed from `bytes`.
    ///
    /// # Errors
    ///
    /// Returns `Err((expected, actual))` when the digests differ. If `expected`
    /// carries a label for a different algorithm (e.g. `md5:<hex>`), the first
    /// element is the trimmed, unnormalized `expected` string, so the mismatched
    /// label is visible to the caller.
    pub fn verify_sha256(bytes: &[u8], expected: &str) -> Result<String, (String, String)> {
        use sha2::{Digest, Sha256};

        let mut hasher = Sha256::new();
        hasher.update(bytes);
        let actual = hex::encode(hasher.finalize());

        // A digest labelled with another algorithm is a manifest error; refuse it
        // instead of silently comparing its hex as if it were SHA-256.
        if let Some((algo, _)) = expected.split_once(':') {
            if !algo.trim().eq_ignore_ascii_case("sha256") {
                return Err((expected.trim().to_string(), actual));
            }
        }

        // Trim so a trailing newline (e.g. read from a checksum file) still matches.
        let want = parse_hash(expected).trim().to_string();
        if actual.eq_ignore_ascii_case(&want) {
            Ok(actual)
        } else {
            Err((want, actual))
        }
    }

    /// Checks invariants that serde's derive cannot express.
    ///
    /// Rules, checked in this order:
    /// - `id` and `version` are non-empty;
    /// - `vocab_size >= 1`;
    /// - at least one of `vocab` (v2) or `tokens` (v1) is present and non-empty;
    /// - `encoder`, if present, is `"byte_level"` or `"metaspace"`;
    /// - `byte_fallback_start` / `byte_fallback_end` are both set or both
    ///   omitted, and when set satisfy `0 <= start <= end`;
    /// - if `tool_calling` is present, both markers are non-empty and exist as
    ///   keys in `special_tokens`.
    ///
    /// # Errors
    ///
    /// Returns [`TokenizerMapError::Validation`] describing the first violated rule.
    pub fn validate(map: &Self) -> Result<(), TokenizerMapError> {
        if map.id.is_empty() {
            return Err(TokenizerMapError::Validation(
                "id must be a non-empty string".into(),
            ));
        }
        if map.version.is_empty() {
            return Err(TokenizerMapError::Validation(
                "version must be a non-empty string".into(),
            ));
        }
        if map.vocab_size < 1 {
            return Err(TokenizerMapError::Validation(
                "vocab_size must be a positive integer".into(),
            ));
        }

        let has_vocab = map.vocab.as_ref().is_some_and(|v| !v.is_empty());
        let has_tokens = map.tokens.as_ref().is_some_and(|v| !v.is_empty());
        if !has_vocab && !has_tokens {
            return Err(TokenizerMapError::Validation(
                "one of `vocab` (v2) or `tokens` (v1) is required".into(),
            ));
        }

        match map.encoder.as_deref() {
            None | Some("byte_level") | Some("metaspace") => {}
            Some(other) => {
                return Err(TokenizerMapError::Validation(format!(
                    "encoder must be \"byte_level\" or \"metaspace\" if present, got \"{other}\""
                )));
            }
        }

        match (map.byte_fallback_start, map.byte_fallback_end) {
            (None, None) => {}
            (Some(start), Some(end)) => {
                // Both bounds present: they must describe a non-negative, ordered range.
                if start < 0 || start > end {
                    return Err(TokenizerMapError::Validation(format!(
                        "byte_fallback_start ({start}) and byte_fallback_end ({end}) must satisfy 0 <= start <= end"
                    )));
                }
            }
            _ => {
                return Err(TokenizerMapError::Validation(
                    "byte_fallback_start and byte_fallback_end must both be set or both omitted"
                        .into(),
                ));
            }
        }

        if let Some(tc) = &map.tool_calling {
            if tc.markers.start.is_empty() || tc.markers.end.is_empty() {
                return Err(TokenizerMapError::Validation(
                    "tool_calling.markers.start/.end must both be non-empty strings".into(),
                ));
            }
            let st = map.special_tokens.as_ref();
            let in_st = |name: &str| st.is_some_and(|m| m.contains_key(name));
            if !in_st(&tc.markers.start) || !in_st(&tc.markers.end) {
                return Err(TokenizerMapError::Validation(format!(
                    "tool_calling.markers.start (\"{}\") and .end (\"{}\") must both exist as keys in special_tokens",
                    tc.markers.start, tc.markers.end,
                )));
            }
        }

        Ok(())
    }
}
/// Normalizes an expected-digest string to bare lowercase hex.
///
/// Accepts bare hex (`"ABCD..."`) or an `algorithm:hex` form such as
/// `"sha256:abcd..."`. Surrounding whitespace, for example a trailing newline
/// from a checksum file, is ignored.
///
/// Only the text after the first `:` is kept. The algorithm label is **not**
/// checked here, so callers that require a specific algorithm must inspect the
/// label themselves.
pub(crate) fn parse_hash(hash: &str) -> String {
    let hash = hash.trim();
    // Drop an optional `algo:` prefix; validating the label is the caller's job.
    let digest = hash.split_once(':').map_or(hash, |(_, digest)| digest);
    digest.trim().to_ascii_lowercase()
}