1use std::collections::HashMap;
6
7use serde::{Deserialize, Serialize};
8
/// Errors returned when loading or validating a [`TokenizerMap`].
#[derive(Debug, thiserror::Error)]
pub enum TokenizerMapError {
    /// A structural rule checked by [`TokenizerMap::validate`] was violated.
    /// The payload is the human-readable reason.
    #[error("TokenizerMap validation failed: {0}")]
    Validation(String),
    /// The input could not be deserialized as JSON. Wraps the underlying
    /// `serde_json` error; `#[from]` lets `?` convert it automatically.
    #[error("TokenizerMap parse failed: {0}")]
    Parse(#[from] serde_json::Error),
}
17
18#[derive(Debug, Clone, Default, Serialize, Deserialize)]
30pub struct TokenizerMap {
31 #[serde(default)]
33 pub id: String,
34 #[serde(default = "default_version")]
36 pub version: String,
37 #[serde(default, rename = "vocab_size")]
39 pub vocab_size: i64,
40 #[serde(default, skip_serializing_if = "Option::is_none")]
42 pub vocab: Option<HashMap<String, u32>>,
43 #[serde(default, skip_serializing_if = "Option::is_none")]
45 pub tokens: Option<HashMap<String, String>>,
46 #[serde(default, skip_serializing_if = "Option::is_none")]
48 pub encoder: Option<String>,
49 #[serde(default, skip_serializing_if = "Option::is_none")]
51 pub merges: Option<Vec<String>>,
52 #[serde(default, skip_serializing_if = "Option::is_none", rename = "pre_tokenizer_pattern")]
55 pub pre_tokenizer_pattern: Option<String>,
56 #[serde(default, skip_serializing_if = "Option::is_none", rename = "pre_tokenizer_program")]
63 pub pre_tokenizer_program: Option<crate::pretok_program::PreTokProgram>,
64 #[serde(default, skip_serializing_if = "Option::is_none", rename = "byte_fallback_start")]
66 pub byte_fallback_start: Option<i64>,
67 #[serde(default, skip_serializing_if = "Option::is_none", rename = "byte_fallback_end")]
69 pub byte_fallback_end: Option<i64>,
70 #[serde(default, skip_serializing_if = "Option::is_none", rename = "special_tokens")]
72 pub special_tokens: Option<HashMap<String, u32>>,
73 #[serde(default, skip_serializing_if = "Option::is_none", rename = "tool_calling")]
79 pub tool_calling: Option<ToolCallingBlock>,
80 #[serde(default, skip_serializing_if = "Option::is_none", rename = "published_at")]
82 pub published_at: Option<String>,
83}
84
/// Tool-calling description attached to a [`TokenizerMap`].
///
/// No field carries a serde default, so all four keys must be present when
/// the block is deserialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallingBlock {
    /// Which tool-calling convention this block follows.
    pub convention: ToolCallingConvention,
    /// Start/end marker strings. `TokenizerMap::validate` requires both to be
    /// non-empty and to exist as keys in `special_tokens`.
    pub markers: ToolCallingMarkers,
    /// Format used for tool-call arguments.
    pub args_format: ToolCallingArgsFormat,
    /// Format used for tool results.
    pub result_format: ToolCallingResultFormat,
}
104
/// Pair of marker strings used by a [`ToolCallingBlock`].
///
/// `TokenizerMap::validate` rejects empty markers and markers that are not
/// keys of the map's `special_tokens`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallingMarkers {
    /// Marker string for the start position.
    pub start: String,
    /// Marker string for the end position.
    pub end: String,
}
111
/// Identifies the tool-calling convention of a [`ToolCallingBlock`].
///
/// Variants are (de)serialized in `snake_case`, e.g. `MistralNemo` is written
/// as `"mistral_nemo"`.
// NOTE(review): the variant names look like model families; what each
// convention implies is decided by consumers outside this file — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolCallingConvention {
    Llama3,
    Qwen25,
    Phi4,
    MistralNemo,
    DeepseekV3,
    DeepseekR1,
    Custom,
}
126
/// Format of tool-call arguments in a [`ToolCallingBlock`].
///
/// Serialized in `snake_case`: `"json"` or `"python_args"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolCallingArgsFormat {
    /// Written on the wire as `"json"`.
    Json,
    /// Written on the wire as `"python_args"`.
    PythonArgs,
}
135
/// Format of tool results in a [`ToolCallingBlock`].
///
/// Serialized in `snake_case`: `"text"` or `"json"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolCallingResultFormat {
    /// Written on the wire as `"text"`.
    Text,
    /// Written on the wire as `"json"`.
    Json,
}
144
/// Serde fallback for `TokenizerMap::version` when the key is missing from
/// the input: the string `"2"`.
fn default_version() -> String {
    String::from("2")
}
148
149impl TokenizerMap {
150 pub fn from_json(json: &[u8]) -> Result<Self, TokenizerMapError> {
152 let map: TokenizerMap = serde_json::from_slice(json)?;
153 Self::validate(&map)?;
154 Ok(map)
155 }
156
157 pub fn from_json_str(json: &str) -> Result<Self, TokenizerMapError> {
159 Self::from_json(json.as_bytes())
160 }
161
162 pub fn verify_sha256(bytes: &[u8], expected: &str) -> Result<String, (String, String)> {
165 use sha2::{Digest, Sha256};
166 let mut hasher = Sha256::new();
167 hasher.update(bytes);
168 let actual = hex::encode(hasher.finalize());
169 let want = parse_hash(expected);
170 if actual.eq_ignore_ascii_case(&want) {
171 Ok(actual)
172 } else {
173 Err((want, actual))
174 }
175 }
176
177 pub fn validate(map: &Self) -> Result<(), TokenizerMapError> {
179 if map.id.is_empty() {
180 return Err(TokenizerMapError::Validation(
181 "id must be a non-empty string".into(),
182 ));
183 }
184 if map.version.is_empty() {
185 return Err(TokenizerMapError::Validation(
186 "version must be a non-empty string".into(),
187 ));
188 }
189 if map.vocab_size < 1 {
190 return Err(TokenizerMapError::Validation(
191 "vocab_size must be a positive integer".into(),
192 ));
193 }
194 let has_vocab = map.vocab.as_ref().is_some_and(|v| !v.is_empty());
195 let has_tokens = map.tokens.as_ref().is_some_and(|v| !v.is_empty());
196 if !has_vocab && !has_tokens {
197 return Err(TokenizerMapError::Validation(
198 "one of `vocab` (v2) or `tokens` (v1) is required".into(),
199 ));
200 }
201 match map.encoder.as_deref() {
202 None | Some("byte_level") | Some("metaspace") => {}
203 Some(other) => {
204 return Err(TokenizerMapError::Validation(format!(
205 "encoder must be \"byte_level\" or \"metaspace\" if present, got \"{other}\""
206 )));
207 }
208 }
209 if map.byte_fallback_start.is_some() != map.byte_fallback_end.is_some() {
210 return Err(TokenizerMapError::Validation(
211 "byte_fallback_start and byte_fallback_end must both be set or both omitted"
212 .into(),
213 ));
214 }
215 if let Some(tc) = &map.tool_calling {
216 if tc.markers.start.is_empty() || tc.markers.end.is_empty() {
217 return Err(TokenizerMapError::Validation(
218 "tool_calling.markers.start/.end must both be non-empty strings".into(),
219 ));
220 }
221 let st = map.special_tokens.as_ref();
223 let in_st = |name: &str| st.is_some_and(|m| m.contains_key(name));
224 if !in_st(&tc.markers.start) || !in_st(&tc.markers.end) {
225 return Err(TokenizerMapError::Validation(format!(
226 "tool_calling.markers.start (\"{}\") and .end (\"{}\") must both exist as keys in special_tokens",
227 tc.markers.start, tc.markers.end,
228 )));
229 }
230 }
231 Ok(())
232 }
233}
234
/// Normalizes a hash string to a lowercase hex digest.
///
/// Accepts either a bare digest (`"ABCD…"`) or an `algo:digest` form
/// (`"sha256:ABCD…"`). Only the first `:` is treated as the separator; the
/// text after it is returned lowercased.
///
/// The algorithm prefix is NOT validated: any prefix is stripped, so
/// `"md5:AB"` yields `"ab"`. Callers that need to reject non-SHA-256 hashes
/// must check the prefix themselves.
pub(crate) fn parse_hash(hash: &str) -> String {
    let digest = match hash.split_once(':') {
        Some((_algo, hex)) => hex,
        None => hash,
    };
    digest.to_ascii_lowercase()
}