use std::path::Path;
#[cfg(any(
feature = "tokenizer-config",
feature = "tokenizer-spm",
feature = "tokenizer-bpe"
))]
use serde_json::Value;
use tokenizers::Tokenizer as HfTokenizer;
use super::encode_options::{EncodeOptions, Encoded};
use crate::Error;
#[cfg(feature = "tokenizer-chat")]
use super::chat;
#[cfg(feature = "tokenizer-deepseek-v32")]
use super::chat::ChatTemplateOverride;
#[cfg(feature = "tokenizer-bpe")]
use super::stream::BpeStreamingDetokenizer;
#[cfg(feature = "tokenizer-spm")]
use super::stream::SpmStreamingDetokenizer;
#[cfg(any(feature = "tokenizer-spm", feature = "tokenizer-bpe"))]
use super::stream::infer_detokenizer_class;
#[cfg(feature = "tokenizer-stream")]
use super::stream::{Detokenizer, DetokenizerClass, NaiveHfDetokenizer};
#[cfg(feature = "tokenizer-tools")]
use super::tools::{self, ToolParser};
/// Streaming detokenizer handed out by `Tokenizer::detokenizer`.
///
/// A plain type alias for [`Detokenizer`].
#[cfg(feature = "tokenizer-stream")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-stream")))]
pub type BoxedDetokenizer = Detokenizer;
/// Reasoning ("thinking") delimiters detected in a tokenizer's vocabulary.
///
/// The derived `Default` (every field `None`) means "no thinking markers".
#[derive(Debug, Clone, Default)]
struct Thinking {
    /// Text that opens a thinking section.
    start: Option<String>,
    /// Text that closes a thinking section.
    end: Option<String>,
    /// Token ids corresponding to `start`.
    start_tokens: Option<Vec<u32>>,
    /// Token ids corresponding to `end`.
    end_tokens: Option<Vec<u32>>,
}
/// Wrapper around a HF `tokenizers::Tokenizer` plus model-specific metadata:
/// EOS ids, special-token strings, chat template, tool-call markers and
/// thinking markers.
///
/// Most fields exist only when the matching `tokenizer-*` feature is enabled.
pub struct Tokenizer {
    /// Underlying HF tokenizer used for encode/decode and vocabulary lookups.
    hf: HfTokenizer,
    /// Raw tokenizer config JSON supplied at construction (`from_path` uses an
    /// empty object when `tokenizer_config.json` is absent).
    #[cfg(feature = "tokenizer-config")]
    config: Value,
    /// Streaming-detokenizer flavour chosen at load time.
    #[cfg(feature = "tokenizer-stream")]
    detok_class: DetokenizerClass,
    /// Config `clean_up_tokenization_spaces` (defaults to `true`), forwarded to
    /// the BPE and naive streaming detokenizers.
    #[cfg(all(feature = "tokenizer-config", feature = "tokenizer-stream"))]
    clean_up_spaces: bool,
    /// Every token id treated as end-of-sequence (a sorted set).
    eos_token_ids: std::collections::BTreeSet<u32>,
    /// EOS id appended by `encode_with` / `encode_batch_with` when `add_eos` is set.
    primary_eos: Option<u32>,
    /// Jinja chat template string from the config, if it was a JSON string.
    #[cfg(feature = "tokenizer-chat")]
    chat_template: Option<String>,
    /// Whether a chat template, or a built-in chat-template override, is available.
    #[cfg(feature = "tokenizer-config")]
    has_chat_template: bool,
    /// Built-in chat formatter selected by config `chat_template_type`; when set
    /// it is used instead of `chat_template` in `apply_chat_template`.
    #[cfg(feature = "tokenizer-deepseek-v32")]
    chat_override: Option<Box<dyn ChatTemplateOverride>>,
    /// Parser for tool calls in generated text.
    #[cfg(feature = "tokenizer-tools")]
    tool_parser: Option<Box<dyn ToolParser>>,
    /// Marker that opens a tool call (copied from `tool_parser`).
    #[cfg(feature = "tokenizer-tools")]
    tool_call_start: Option<String>,
    /// Marker that closes a tool call (copied from `tool_parser`).
    #[cfg(feature = "tokenizer-tools")]
    tool_call_end: Option<String>,
    /// Reasoning delimiters detected from the vocabulary.
    thinking: Thinking,
    /// Config `bos_token` (plain string or an object's `content`).
    #[cfg(feature = "tokenizer-config")]
    bos_token: Option<String>,
    /// Config `eos_token` (plain string or an object's `content`).
    #[cfg(feature = "tokenizer-config")]
    eos_token: Option<String>,
    /// Config `unk_token` (plain string or an object's `content`).
    #[cfg(feature = "tokenizer-config")]
    unk_token: Option<String>,
    /// Config `pad_token` (plain string or an object's `content`).
    #[cfg(feature = "tokenizer-config")]
    pad_token: Option<String>,
}
/// Reads a special-token string from a tokenizer config.
///
/// HF configs store such tokens either as a plain string
/// (`"eos_token": "</s>"`) or as an object whose text lives under `content`
/// (`"eos_token": {"content": "</s>", ...}`). Returns `None` when the key is
/// missing, when an object has no string `content`, or when the value is any
/// other JSON type.
#[cfg(feature = "tokenizer-config")]
fn cfg_str(cfg: &Value, key: &str) -> Option<String> {
    let entry = cfg.get(key)?;
    let text = match entry.as_object() {
        // Object form: the token text is the string stored under "content".
        Some(obj) => obj.get("content")?.as_str(),
        // Otherwise only a bare JSON string counts.
        None => entry.as_str(),
    };
    text.map(String::from)
}
// Construction, encoding/decoding, special-token and EOS queries, chat-template
// rendering, tool-call parsing, thinking markers and streaming detokenizers.
impl Tokenizer {
    /// Loads a tokenizer from a model directory.
    ///
    /// Loads `<model_path>/tokenizer.json` with the HF `tokenizers` crate. When
    /// `tokenizer-spm` or `tokenizer-bpe` is enabled, the same file is also parsed
    /// as raw JSON so its `"decoder"` section can select the streaming
    /// detokenizer class; with only `tokenizer-stream` the class is
    /// `DetokenizerClass::Naive`. With `tokenizer-config`,
    /// `<model_path>/tokenizer_config.json` is parsed if it exists, otherwise an
    /// empty JSON object is used.
    ///
    /// `eos_token_ids` is forwarded unchanged to [`Tokenizer::from_loaded`].
    ///
    /// # Errors
    /// Returns a tokenizer error if `tokenizer.json` cannot be loaded (or, with
    /// `tokenizer-spm`/`tokenizer-bpe`, read or parsed as JSON), or if an
    /// existing `tokenizer_config.json` cannot be read or parsed. Errors from
    /// `from_loaded` are propagated.
    pub fn from_path(
        model_path: impl AsRef<Path>,
        eos_token_ids: Option<&[u32]>,
    ) -> Result<Self, Error> {
        let dir = model_path.as_ref();
        let tok_file = dir.join("tokenizer.json");
        let hf = HfTokenizer::from_file(&tok_file)
            .map_err(|e| Error::tokenizer(format!("load tokenizer.json: {e}")))?;
        // Only the raw "decoder" section is needed here, to pick the streaming
        // detokenizer. NOTE(review): this reads and parses tokenizer.json a second
        // time after `HfTokenizer::from_file` above.
        #[cfg(any(feature = "tokenizer-spm", feature = "tokenizer-bpe"))]
        let detok_class = {
            let bytes = std::fs::read(&tok_file)
                .map_err(|e| Error::tokenizer(format!("read tokenizer.json: {e}")))?;
            let raw: Value = serde_json::from_slice(&bytes)
                .map_err(|e| Error::tokenizer(format!("parse tokenizer.json: {e}")))?;
            infer_detokenizer_class(raw.get("decoder"))
        };
        // No specialised streaming detokenizer compiled in: always stream naively.
        #[cfg(all(
            feature = "tokenizer-stream",
            not(any(feature = "tokenizer-spm", feature = "tokenizer-bpe"))
        ))]
        let detok_class = DetokenizerClass::Naive;
        // A missing tokenizer_config.json is not an error; an empty object stands in.
        #[cfg(feature = "tokenizer-config")]
        let config: Value = {
            let cfg_file = dir.join("tokenizer_config.json");
            if cfg_file.exists() {
                let bytes = std::fs::read(&cfg_file)
                    .map_err(|e| Error::tokenizer(format!("read tokenizer_config.json: {e}")))?;
                serde_json::from_slice(&bytes)
                    .map_err(|e| Error::tokenizer(format!("parse tokenizer_config.json: {e}")))?
            } else {
                Value::Object(Default::default())
            }
        };
        Self::from_loaded(
            hf,
            #[cfg(feature = "tokenizer-config")]
            config,
            #[cfg(feature = "tokenizer-stream")]
            detok_class,
            eos_token_ids,
        )
    }

    /// Builds a tokenizer from an already-loaded HF tokenizer plus the
    /// feature-gated config JSON and detokenizer class.
    ///
    /// EOS handling:
    /// * `Some(ids)`: every id becomes an EOS id, and the first one (if any)
    ///   becomes the primary EOS used by `encode_with` / `encode_batch_with`.
    /// * `None` (with `tokenizer-config`): the config's `eos_token` is looked up
    ///   in the vocabulary and, if found, used as both.
    /// * `Some(&[])` yields no EOS ids; the config fallback is skipped.
    ///
    /// With `tokenizer-config` it also reads `clean_up_tokenization_spaces`
    /// (default `true`), the `bos`/`eos`/`unk`/`pad` tokens and a string
    /// `chat_template`. With the matching features it also reads
    /// `chat_template_type` (built-in chat override) and `tool_parser_type`; the
    /// latter falls back to a parser inferred from the chat template. Thinking
    /// markers are inferred from the vocabulary.
    ///
    /// # Errors
    /// No path in this body returns `Err`.
    pub fn from_loaded(
        hf: HfTokenizer,
        #[cfg(feature = "tokenizer-config")] config: Value,
        #[cfg(feature = "tokenizer-stream")] detok_class: DetokenizerClass,
        eos_token_ids: Option<&[u32]>,
    ) -> Result<Self, Error> {
        #[cfg(all(feature = "tokenizer-config", feature = "tokenizer-stream"))]
        let clean_up_spaces = config
            .get("clean_up_tokenization_spaces")
            .and_then(Value::as_bool)
            .unwrap_or(true);
        #[cfg(feature = "tokenizer-config")]
        let bos_token = cfg_str(&config, "bos_token");
        #[cfg(feature = "tokenizer-config")]
        let eos_token = cfg_str(&config, "eos_token");
        #[cfg(feature = "tokenizer-config")]
        let unk_token = cfg_str(&config, "unk_token");
        #[cfg(feature = "tokenizer-config")]
        let pad_token = cfg_str(&config, "pad_token");
        let mut eos_set = std::collections::BTreeSet::new();
        let mut primary_eos: Option<u32> = None;
        // Explicit ids win; the first one becomes the primary EOS.
        if let Some(ids) = eos_token_ids {
            if let Some(&first) = ids.first() {
                primary_eos = Some(first);
            }
            eos_set.extend(ids.iter().copied());
        }
        // Fall back to the config's `eos_token` only when no ids were passed at
        // all (`None`, not `Some(&[])`) and the token exists in the vocabulary.
        #[cfg(feature = "tokenizer-config")]
        if eos_token_ids.is_none()
            && let Some(ref e) = eos_token
            && let Some(id) = hf.token_to_id(e)
        {
            primary_eos = Some(id);
            eos_set.insert(id);
        }
        // Only a JSON-string template is used. NOTE(review): any other value
        // (e.g. a list of named templates) is treated as "no template" — confirm
        // that's intended.
        #[cfg(feature = "tokenizer-config")]
        let chat_template = match config.get("chat_template") {
            Some(Value::String(s)) => Some(s.clone()),
            _ => None,
        };
        #[cfg(feature = "tokenizer-deepseek-v32")]
        let chat_override = config
            .get("chat_template_type")
            .and_then(Value::as_str)
            .and_then(chat::override_by_name);
        #[cfg(all(feature = "tokenizer-config", feature = "tokenizer-deepseek-v32"))]
        let has_chat_template = chat_template.is_some() || chat_override.is_some();
        #[cfg(all(feature = "tokenizer-config", not(feature = "tokenizer-deepseek-v32")))]
        let has_chat_template = chat_template.is_some();
        // An explicit `tool_parser_type` wins; otherwise infer one from the chat
        // template. The start/end markers are copied from the chosen parser.
        #[cfg(feature = "tokenizer-tools")]
        let (tool_parser, tool_call_start, tool_call_end) = {
            let parser_name = config
                .get("tool_parser_type")
                .and_then(Value::as_str)
                .map(str::to_owned)
                .or_else(|| tools::infer_tool_parser(chat_template.as_deref()).map(str::to_owned));
            let tool_parser = parser_name.as_deref().and_then(tools::parser_by_name);
            let (s, e) = match &tool_parser {
                Some(p) => (
                    Some(p.tool_call_start().to_owned()),
                    Some(p.tool_call_end().to_owned()),
                ),
                None => (None, None),
            };
            (tool_parser, s, e)
        };
        let thinking = infer_thinking(&hf);
        Ok(Self {
            hf,
            #[cfg(feature = "tokenizer-config")]
            config,
            #[cfg(feature = "tokenizer-stream")]
            detok_class,
            #[cfg(all(feature = "tokenizer-config", feature = "tokenizer-stream"))]
            clean_up_spaces,
            eos_token_ids: eos_set,
            primary_eos,
            #[cfg(feature = "tokenizer-chat")]
            chat_template,
            #[cfg(feature = "tokenizer-config")]
            has_chat_template,
            #[cfg(feature = "tokenizer-deepseek-v32")]
            chat_override,
            #[cfg(feature = "tokenizer-tools")]
            tool_parser,
            #[cfg(feature = "tokenizer-tools")]
            tool_call_start,
            #[cfg(feature = "tokenizer-tools")]
            tool_call_end,
            thinking,
            #[cfg(feature = "tokenizer-config")]
            bos_token,
            #[cfg(feature = "tokenizer-config")]
            eos_token,
            #[cfg(feature = "tokenizer-config")]
            unk_token,
            #[cfg(feature = "tokenizer-config")]
            pad_token,
        })
    }

    /// Same as [`Tokenizer::from_loaded`] with both the config and the
    /// detokenizer class supplied.
    ///
    /// `_raw` (presumably the raw `tokenizer.json` JSON — unverified) is
    /// accepted but ignored.
    #[cfg(all(feature = "tokenizer-config", feature = "tokenizer-stream"))]
    #[cfg_attr(
        docsrs,
        doc(cfg(all(feature = "tokenizer-config", feature = "tokenizer-stream")))
    )]
    pub fn from_parts(
        hf: HfTokenizer,
        _raw: Value,
        config: Value,
        detok_class: DetokenizerClass,
        eos_token_ids: Option<&[u32]>,
    ) -> Result<Self, Error> {
        Self::from_loaded(hf, config, detok_class, eos_token_ids)
    }

    /// Encodes `text` and returns only the token ids.
    ///
    /// # Errors
    /// Returns a tokenizer error if the HF tokenizer fails to encode.
    pub fn encode(&self, text: &str, add_special_tokens: bool) -> Result<Vec<u32>, Error> {
        let enc = self
            .hf
            .encode(text, add_special_tokens)
            .map_err(|e| Error::tokenizer(format!("encode: {e}")))?;
        Ok(enc.get_ids().to_vec())
    }

    /// Encodes `text` according to `opts` (special tokens, optional appended
    /// EOS, truncation, attention mask); see `finalize_encoding`.
    ///
    /// # Errors
    /// Fails if `opts.add_eos()` is set but no primary EOS id is configured, if
    /// HF encoding fails, or if the HF ids and attention mask differ in length.
    pub fn encode_with(&self, text: &str, opts: &EncodeOptions) -> Result<Encoded, Error> {
        // Resolve the EOS first so a missing EOS fails before any encoding work.
        let eos = Self::resolve_eos(opts.add_eos(), self.primary_eos)?;
        let enc = self
            .hf
            .encode(text, opts.add_special())
            .map_err(|e| Error::tokenizer(format!("hf.encode: {e}")))?;
        finalize_encoding(&enc, opts, eos)
    }

    /// Encodes several texts and returns each one's token ids.
    ///
    /// # Errors
    /// Returns a tokenizer error if the HF batch encode fails.
    pub fn encode_batch(
        &self,
        texts: Vec<String>,
        add_special_tokens: bool,
    ) -> Result<Vec<Vec<u32>>, Error> {
        let encs = self
            .hf
            .encode_batch(texts, add_special_tokens)
            .map_err(|e| Error::tokenizer(format!("encode_batch: {e}")))?;
        Ok(encs.iter().map(|e| e.get_ids().to_vec()).collect())
    }

    /// Batch counterpart of [`Tokenizer::encode_with`]; every text uses the
    /// same `opts`.
    ///
    /// # Errors
    /// Same conditions as `encode_with`. The missing-EOS check runs before
    /// encoding, so it fires even for an empty batch. The first failing
    /// `finalize_encoding` aborts the whole batch.
    pub fn encode_batch_with(
        &self,
        texts: Vec<String>,
        opts: &EncodeOptions,
    ) -> Result<Vec<Encoded>, Error> {
        let eos = Self::resolve_eos(opts.add_eos(), self.primary_eos)?;
        let encs = self
            .hf
            .encode_batch(texts, opts.add_special())
            .map_err(|e| Error::tokenizer(format!("hf.encode_batch: {e}")))?;
        let mut out = Vec::with_capacity(encs.len());
        for enc in &encs {
            out.push(finalize_encoding(enc, opts, eos)?);
        }
        Ok(out)
    }

    /// Returns the EOS id to append: `Ok(None)` when `add_eos` is false,
    /// otherwise the primary EOS id.
    ///
    /// # Errors
    /// Fails when `add_eos` is true but `primary_eos` is `None`.
    /// NOTE(review): the message names `encode_with` even when the caller is
    /// `encode_batch_with`.
    fn resolve_eos(add_eos: bool, primary_eos: Option<u32>) -> Result<Option<u32>, Error> {
        if add_eos {
            Ok(Some(primary_eos.ok_or_else(|| {
                Error::tokenizer("encode_with(add_eos=true) requires a configured eos token id")
            })?))
        } else {
            Ok(None)
        }
    }

    /// Decodes token ids back into text.
    ///
    /// # Errors
    /// Returns a tokenizer error if the HF decode fails.
    pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String, Error> {
        self
            .hf
            .decode(ids, skip_special_tokens)
            .map_err(|e| Error::tokenizer(format!("decode: {e}")))
    }

    /// Decodes several id sequences, one output string per input sequence.
    ///
    /// # Errors
    /// Returns a tokenizer error if the HF batch decode fails.
    pub fn decode_batch(
        &self,
        sequences: &[&[u32]],
        skip_special_tokens: bool,
    ) -> Result<Vec<String>, Error> {
        self
            .hf
            .decode_batch(sequences, skip_special_tokens)
            .map_err(|e| Error::tokenizer(format!("decode_batch: {e}")))
    }

    /// Vocabulary id of `token`, or `None` if it is not in the vocabulary.
    pub fn convert_token_to_id(&self, token: &str) -> Option<u32> {
        self.hf.token_to_id(token)
    }

    /// Token string for `id`, or `None` if the id is unknown.
    pub fn convert_id_to_token(&self, id: u32) -> Option<String> {
        self.hf.id_to_token(id)
    }

    /// The config's `bos_token` string, if any.
    #[cfg(feature = "tokenizer-config")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-config")))]
    pub fn bos_token(&self) -> Option<&str> {
        self.bos_token.as_deref()
    }

    /// The config's `eos_token` string, if any.
    #[cfg(feature = "tokenizer-config")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-config")))]
    pub fn eos_token(&self) -> Option<&str> {
        self.eos_token.as_deref()
    }

    /// The config's `unk_token` string, if any.
    #[cfg(feature = "tokenizer-config")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-config")))]
    pub fn unk_token(&self) -> Option<&str> {
        self.unk_token.as_deref()
    }

    /// The config's `pad_token` string, if any.
    #[cfg(feature = "tokenizer-config")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-config")))]
    pub fn pad_token(&self) -> Option<&str> {
        self.pad_token.as_deref()
    }

    /// Vocabulary id of the BOS token; `None` if unset or not in the vocabulary.
    #[cfg(feature = "tokenizer-config")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-config")))]
    pub fn bos_token_id(&self) -> Option<u32> {
        self
            .bos_token
            .as_deref()
            .and_then(|t| self.hf.token_to_id(t))
    }

    /// Vocabulary id of the config's EOS token string. This can differ from the
    /// primary EOS when explicit `eos_token_ids` were supplied.
    #[cfg(feature = "tokenizer-config")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-config")))]
    pub fn eos_token_id(&self) -> Option<u32> {
        self
            .eos_token
            .as_deref()
            .and_then(|t| self.hf.token_to_id(t))
    }

    /// Vocabulary id of the UNK token; `None` if unset or not in the vocabulary.
    #[cfg(feature = "tokenizer-config")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-config")))]
    pub fn unk_token_id(&self) -> Option<u32> {
        self
            .unk_token
            .as_deref()
            .and_then(|t| self.hf.token_to_id(t))
    }

    /// Vocabulary id of the PAD token; `None` if unset or not in the vocabulary.
    #[cfg(feature = "tokenizer-config")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-config")))]
    pub fn pad_token_id(&self) -> Option<u32> {
        self
            .pad_token
            .as_deref()
            .and_then(|t| self.hf.token_to_id(t))
    }

    /// Ids of the config's `additional_special_tokens` entries, in order.
    ///
    /// Entries may be strings or objects with a string `content`. Any other
    /// entry, or a token not in the vocabulary, is silently skipped. A missing
    /// or non-array key yields an empty list.
    #[cfg(feature = "tokenizer-config")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-config")))]
    pub fn additional_special_token_ids(&self) -> Vec<u32> {
        let Some(arr) = self.config.get("additional_special_tokens") else {
            return Vec::new();
        };
        let Some(items) = arr.as_array() else {
            return Vec::new();
        };
        let mut out = Vec::with_capacity(items.len());
        for item in items {
            let token: Option<&str> = match item {
                Value::String(s) => Some(s.as_str()),
                Value::Object(o) => o.get("content").and_then(Value::as_str),
                _ => None,
            };
            if let Some(tok) = token
                && let Some(id) = self.hf.token_to_id(tok)
            {
                out.push(id);
            }
        }
        out
    }

    /// Iterates every EOS id in ascending order.
    pub fn eos_token_ids_iter(&self) -> impl Iterator<Item = u32> + '_ {
        self.eos_token_ids.iter().copied()
    }

    /// Whether `id` is one of the EOS ids.
    pub fn contains_eos_id(&self, id: u32) -> bool {
        self.eos_token_ids.contains(&id)
    }

    /// Registers an additional EOS id.
    ///
    /// A `token` that parses as `u32` is taken as a raw id and is not checked
    /// against the vocabulary. Anything else is looked up as a token string.
    /// The new id also becomes the primary EOS if none was set yet.
    ///
    /// # Errors
    /// Fails if `token` is neither numeric nor present in the vocabulary.
    pub fn add_eos_token(&mut self, token: &str) -> Result<(), Error> {
        let id = match token.parse::<u32>() {
            Ok(i) => Some(i),
            Err(_) => self.hf.token_to_id(token),
        };
        let id = id.ok_or_else(|| Error::tokenizer(format!("'{token}' is not a token")))?;
        self.eos_token_ids.insert(id);
        if self.primary_eos.is_none() {
            self.primary_eos = Some(id);
        }
        Ok(())
    }

    /// Whether a chat template (or, with `tokenizer-deepseek-v32`, a built-in
    /// chat-template override) is available.
    #[cfg(feature = "tokenizer-config")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-config")))]
    pub fn has_chat_template(&self) -> bool {
        self.has_chat_template
    }

    /// Marker text that opens a tool call, from the configured tool parser.
    #[cfg(feature = "tokenizer-tools")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-tools")))]
    pub fn tool_call_start(&self) -> Option<&str> {
        self.tool_call_start.as_deref()
    }

    /// Marker text that closes a tool call, from the configured tool parser.
    #[cfg(feature = "tokenizer-tools")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-tools")))]
    pub fn tool_call_end(&self) -> Option<&str> {
        self.tool_call_end.as_deref()
    }

    /// Whether a tool parser was configured or inferred.
    #[cfg(feature = "tokenizer-tools")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-tools")))]
    pub fn has_tool_calling(&self) -> bool {
        self.tool_parser.is_some()
    }

    /// The configured tool parser, if any.
    #[cfg(feature = "tokenizer-tools")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-tools")))]
    pub fn tool_parser(&self) -> Option<&dyn ToolParser> {
        self.tool_parser.as_deref()
    }

    /// Parses tool calls out of generated `text` with the configured parser.
    ///
    /// # Errors
    /// Fails when no tool parser is configured; otherwise returns whatever the
    /// parser returns.
    #[cfg(feature = "tokenizer-tools")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-tools")))]
    pub fn parse_tool_call(
        &self,
        text: &str,
        tools: Option<&Value>,
    ) -> Result<Vec<tools::ToolCall>, Error> {
        let p = self
            .tool_parser
            .as_ref()
            .ok_or_else(|| Error::tokenizer("no tool parser configured"))?;
        p.parse(text, tools)
    }

    /// Whether a thinking start marker was detected in the vocabulary.
    pub fn has_thinking(&self) -> bool {
        self.thinking.start.is_some()
    }

    /// Text that opens a thinking section, if detected.
    pub fn think_start(&self) -> Option<&str> {
        self.thinking.start.as_deref()
    }

    /// Text that closes a thinking section, if detected.
    pub fn think_end(&self) -> Option<&str> {
        self.thinking.end.as_deref()
    }

    /// Token ids of the thinking start marker, if detected.
    pub fn think_start_tokens(&self) -> Option<&[u32]> {
        self.thinking.start_tokens.as_deref()
    }

    /// Token ids of the thinking end marker, if detected.
    pub fn think_end_tokens(&self) -> Option<&[u32]> {
        self.thinking.end_tokens.as_deref()
    }

    /// Creates a fresh streaming detokenizer matching `detokenizer_class()`.
    ///
    /// SPM and BPE detokenizers are built from the full vocabulary
    /// (`get_vocab(true)`) on every call. When the matching feature is compiled
    /// out, this falls back to the naive detokenizer and prints a one-time
    /// warning to stderr. `clean` comes from the config's
    /// `clean_up_tokenization_spaces`, or is `false` without `tokenizer-config`.
    #[cfg(feature = "tokenizer-stream")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-stream")))]
    pub fn detokenizer(&self) -> BoxedDetokenizer {
        #[cfg(feature = "tokenizer-config")]
        let clean = self.clean_up_spaces;
        #[cfg(not(feature = "tokenizer-config"))]
        let clean = false;
        match self.detok_class {
            #[cfg(feature = "tokenizer-spm")]
            DetokenizerClass::Spm | DetokenizerClass::SpmNoSpace => {
                let vocab = self.hf.get_vocab(true);
                // `trim` is on for `Spm` and off for `SpmNoSpace`.
                let trim = self.detok_class == DetokenizerClass::Spm;
                Detokenizer::Spm(SpmStreamingDetokenizer::new(vocab, trim))
            }
            #[cfg(feature = "tokenizer-bpe")]
            DetokenizerClass::Bpe => {
                let vocab = self.hf.get_vocab(true);
                Detokenizer::Bpe(BpeStreamingDetokenizer::new(vocab, clean))
            }
            // Feature compiled out: degrade to the naive detokenizer.
            #[cfg(not(feature = "tokenizer-spm"))]
            DetokenizerClass::Spm | DetokenizerClass::SpmNoSpace => {
                warn_detok_fallback("spm");
                self.naive_detokenizer(clean)
            }
            #[cfg(not(feature = "tokenizer-bpe"))]
            DetokenizerClass::Bpe => {
                warn_detok_fallback("bpe");
                self.naive_detokenizer(clean)
            }
            DetokenizerClass::Naive => self.naive_detokenizer(clean),
        }
    }

    /// Naive streaming detokenizer wrapping a clone of the HF tokenizer.
    #[cfg(feature = "tokenizer-stream")]
    fn naive_detokenizer(&self, clean: bool) -> BoxedDetokenizer {
        Detokenizer::Naive(Box::new(NaiveHfDetokenizer::new(self.hf.clone(), clean)))
    }

    /// The streaming-detokenizer flavour chosen at load time.
    #[cfg(feature = "tokenizer-stream")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-stream")))]
    pub fn detokenizer_class(&self) -> DetokenizerClass {
        self.detok_class
    }

    /// Renders `messages` (and optional `tools`) into a prompt string.
    ///
    /// `enable_thinking` is read from `additional_context["enable_thinking"]`
    /// when that is a bool; otherwise it defaults to `has_thinking()`. With
    /// `tokenizer-deepseek-v32` and a configured override, the override renders
    /// the prompt and `messages` must be a JSON array. Otherwise the Jinja chat
    /// template is rendered with the BOS/EOS token strings and
    /// `additional_context` (`Null` if absent).
    ///
    /// # Errors
    /// Fails if both `add_generation_prompt` and `continue_final_message` are
    /// set, if the override path gets non-array `messages`, if there is no chat
    /// template, or if rendering fails.
    #[cfg(feature = "tokenizer-chat")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-chat")))]
    pub fn apply_chat_template(
        &self,
        messages: &Value,
        tools: Option<&Value>,
        add_generation_prompt: bool,
        continue_final_message: bool,
        additional_context: Option<&Value>,
    ) -> Result<String, Error> {
        if add_generation_prompt && continue_final_message {
            return Err(Error::tokenizer(
                "continue_final_message is not compatible with add_generation_prompt \
(only one may be set)",
            ));
        }
        let enable_thinking = additional_context
            .and_then(|c| c.get("enable_thinking"))
            .and_then(Value::as_bool)
            .unwrap_or(self.has_thinking());
        // The override replaces Jinja rendering entirely; apart from
        // `enable_thinking`, `additional_context` is not passed to it.
        #[cfg(feature = "tokenizer-deepseek-v32")]
        if let Some(ovr) = &self.chat_override {
            let msgs = messages
                .as_array()
                .cloned()
                .ok_or_else(|| Error::tokenizer("messages must be a list"))?;
            return ovr.apply(
                &msgs,
                tools,
                add_generation_prompt,
                continue_final_message,
                enable_thinking,
            );
        }
        let template = self
            .chat_template
            .as_deref()
            .ok_or_else(|| Error::tokenizer("this tokenizer does not have a chat template"))?;
        let extra = additional_context.cloned().unwrap_or(Value::Null);
        chat::render_jinja(
            template,
            messages,
            tools,
            add_generation_prompt,
            continue_final_message,
            self.bos_token.as_deref(),
            self.eos_token.as_deref(),
            enable_thinking,
            &extra,
        )
    }

    /// Renders the chat template (see [`Tokenizer::apply_chat_template`]) and
    /// encodes the result without adding special tokens.
    ///
    /// # Errors
    /// Propagates rendering and encoding errors.
    #[cfg(feature = "tokenizer-chat")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-chat")))]
    pub fn apply_chat_template_ids(
        &self,
        messages: &Value,
        tools: Option<&Value>,
        add_generation_prompt: bool,
        continue_final_message: bool,
        additional_context: Option<&Value>,
    ) -> Result<Vec<u32>, Error> {
        let text = self.apply_chat_template(
            messages,
            tools,
            add_generation_prompt,
            continue_final_message,
            additional_context,
        )?;
        self.encode(&text, false)
    }

    /// The raw tokenizer config JSON this tokenizer was built from.
    #[cfg(feature = "tokenizer-config")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-config")))]
    pub fn config(&self) -> &Value {
        &self.config
    }

    /// The underlying HF tokenizer.
    pub fn hf(&self) -> &HfTokenizer {
        &self.hf
    }
}
/// Prints, at most once per process and per `kind`, a stderr warning that the
/// model's preferred streaming detokenizer was compiled out.
///
/// `"spm"` has its own one-shot guard; every other `kind` shares the BPE guard.
#[cfg(all(
    feature = "tokenizer-stream",
    not(all(feature = "tokenizer-spm", feature = "tokenizer-bpe"))
))]
fn warn_detok_fallback(kind: &'static str) {
    use std::sync::Once;

    static SPM_WARNED: Once = Once::new();
    static BPE_WARNED: Once = Once::new();

    let guard = match kind {
        "spm" => &SPM_WARNED,
        _ => &BPE_WARNED,
    };
    guard.call_once(|| {
        eprintln!(
            "mlxrs: model wants the {kind} streaming detokenizer but the \
             `tokenizer-{kind}` feature is disabled; falling back to naive \
             (less precise streaming)"
        );
    });
}
/// Detects reasoning ("thinking") delimiters supported by `hf`'s vocabulary.
///
/// Probes, in order:
/// 1. single-token pairs `<think>`/`</think>`, then
///    `<longcat_think>`/`</longcat_think>`;
/// 2. the channel-style pair, when both `<|channel>` and `<channel|>` are
///    tokens. Its start marker `<|channel>thought` spans several tokens, so the
///    id sequences come from encoding the marker text without special tokens.
///
/// If either channel marker fails to encode, or encodes to no ids, that pair is
/// treated as unavailable. An empty id sequence would match anywhere, so it is
/// never reported.
///
/// Returns `Thinking::default()` (everything `None`) when no pair is found.
fn infer_thinking(hf: &HfTokenizer) -> Thinking {
    // `token_to_id` checks the added tokens and then the model vocabulary — the
    // same answer as `get_vocab(true)` — without materialising the entire
    // vocabulary as a `HashMap<String, u32>` just to probe a few markers.
    const SINGLE_TOKEN_PAIRS: [(&str, &str); 2] = [
        ("<think>", "</think>"),
        ("<longcat_think>", "</longcat_think>"),
    ];
    for (ts, te) in SINGLE_TOKEN_PAIRS {
        if let (Some(sid), Some(eid)) = (hf.token_to_id(ts), hf.token_to_id(te)) {
            return Thinking {
                start: Some(ts.to_owned()),
                end: Some(te.to_owned()),
                start_tokens: Some(vec![sid]),
                end_tokens: Some(vec![eid]),
            };
        }
    }
    if hf.token_to_id("<|channel>").is_some() && hf.token_to_id("<channel|>").is_some() {
        let ts = "<|channel>thought";
        let te = "<channel|>";
        // Encoded ids of a marker, or `None` if encoding fails or yields nothing.
        let ids_of = |text: &str| -> Option<Vec<u32>> {
            hf.encode(text, false)
                .ok()
                .map(|enc| enc.get_ids().to_vec())
                .filter(|ids| !ids.is_empty())
        };
        if let (Some(st), Some(et)) = (ids_of(ts), ids_of(te)) {
            return Thinking {
                start: Some(ts.to_owned()),
                end: Some(te.to_owned()),
                start_tokens: Some(st),
                end_tokens: Some(et),
            };
        }
    }
    Thinking::default()
}
/// Returns a copy of `sequence` with at most one leading `bos` and at most one
/// trailing `eos` removed.
///
/// The BOS strip happens first; the EOS check then looks at what remains. So
/// `[x]` with `x == bos == eos` becomes empty, and an empty input stays empty.
pub fn no_bos_or_eos(sequence: &[u32], bos: u32, eos: u32) -> Vec<u32> {
    let trimmed = sequence.strip_prefix(&[bos]).unwrap_or(sequence);
    let trimmed = trimmed.strip_suffix(&[eos]).unwrap_or(trimmed);
    trimmed.to_vec()
}
/// Converts a raw HF [`tokenizers::Encoding`] into an [`Encoded`] according to
/// `opts`.
///
/// * Only positions whose HF attention-mask entry is non-zero are kept; masked
///   positions (presumably padding) are dropped.
/// * If `eos` is `Some`, that id is appended after the kept ids.
/// * `opts.truncate_to()` caps the total length. The EOS slot is reserved
///   first, so truncation shortens the body and keeps the EOS — unless the cap
///   is 0, in which case the result is empty.
/// * The returned attention mask is all ones (one per id) when
///   `opts.return_attention_mask()` is set, and empty otherwise.
///
/// # Errors
/// Returns a tokenizer error if the HF ids and attention mask differ in length.
fn finalize_encoding(
    enc: &tokenizers::Encoding,
    opts: &EncodeOptions,
    eos: Option<u32>,
) -> Result<Encoded, Error> {
    let (hf_ids, hf_mask) = (enc.get_ids(), enc.get_attention_mask());
    if hf_ids.len() != hf_mask.len() {
        return Err(Error::tokenizer(format!(
            "HF Encoding shape mismatch: ids.len()={} attention_mask.len()={}",
            hf_ids.len(),
            hf_mask.len(),
        )));
    }

    // Ids at unmasked positions, in order; re-creatable so we can count and take.
    let unmasked = || {
        hf_ids
            .iter()
            .zip(hf_mask)
            .filter(|&(_, &m)| m != 0)
            .map(|(&id, _)| id)
    };
    let kept = unmasked().count();
    let eos_slots = usize::from(eos.is_some());
    let untruncated = kept + eos_slots;
    let total = match opts.truncate_to() {
        Some(cap) => cap.min(untruncated),
        None => untruncated,
    };
    // Leave room for the EOS (if any) before filling the body.
    let body_len = total.saturating_sub(eos_slots).min(kept);

    let mut ids: Vec<u32> = Vec::with_capacity(total);
    ids.extend(unmasked().take(body_len));
    match eos {
        Some(e) if ids.len() < total => ids.push(e),
        _ => {}
    }

    let attention_mask = opts
        .return_attention_mask()
        .then(|| vec![1u8; ids.len()])
        .unwrap_or_default();
    Ok(Encoded::new(ids, attention_mask))
}
#[cfg(test)]
mod tests;