use std::{
path::{Path, PathBuf},
sync::atomic::{AtomicU64, Ordering},
};
use super::*;
/// Returns a freshly created, empty scratch directory under the system temp dir.
///
/// The leaf name is built from `tag`, the current process id, and a
/// process-wide counter that is bumped on every call. Two calls in the same
/// process therefore never return the same path. A leftover directory with the
/// same name is removed first, so the result starts out empty.
///
/// This helper does not delete the directory afterwards.
///
/// # Panics
/// Panics if the directory cannot be created.
fn fresh_dir(tag: &str) -> PathBuf {
    static COUNTER: AtomicU64 = AtomicU64::new(0);
    let n = COUNTER.fetch_add(1, Ordering::Relaxed);
    let pid = std::process::id();
    let leaf = format!("mlxrs-tokenizer-wrapper-{tag}-{}-{n}", pid);
    let dir = std::env::temp_dir().join(leaf);
    // Best-effort: usually nothing exists here yet, so the error is ignored.
    let _ = std::fs::remove_dir_all(&dir);
    std::fs::create_dir_all(&dir).unwrap();
    dir
}
/// Writes a minimal WordLevel `tokenizer.json` into `dir`.
///
/// `vocab` pairs each word with its token id, and `unk` names the token used
/// for unknown words. When `with_whitespace` is true, a `Whitespace`
/// pre-tokenizer is attached; otherwise no pre-tokenizer is set. The file is
/// saved non-pretty.
///
/// # Panics
/// Panics if the model cannot be built or the file cannot be saved.
fn write_wordlevel(dir: &Path, vocab: &[(&str, u32)], unk: &str, with_whitespace: bool) {
    use tokenizers::{
        Tokenizer as HfTokenizer, models::wordlevel::WordLevel, pre_tokenizers::whitespace::Whitespace,
    };
    let word_to_id = vocab
        .iter()
        .map(|&(word, id)| (word.to_owned(), id))
        .collect();
    let model = WordLevel::builder()
        .vocab(word_to_id)
        .unk_token(unk.to_owned())
        .build()
        .unwrap();
    let mut tokenizer = HfTokenizer::new(model);
    if with_whitespace {
        tokenizer.with_pre_tokenizer(Some(Whitespace {}));
    }
    let target = dir.join("tokenizer.json");
    tokenizer.save(target, false).unwrap();
}
/// Shared test vocabulary with ids 0 through 4, in order:
/// the unknown token, two plain words, and the tool-call delimiter pair.
const BASIC_VOCAB: &[(&str, u32)] = &[
    ("<unk>", 0),
    // plain words
    ("hello", 1),
    ("world", 2),
    // tool-call delimiters
    ("<tool_call>", 3),
    ("</tool_call>", 4),
];
#[cfg(feature = "tokenizer-config")]
#[test]
fn cfg_str_reads_object_content_form() {
    // `bos_token` uses the `{ "content": ... }` object form and `eos_token` the
    // plain-string form. `unk_token` is an object without `content`, and the
    // accessor is expected to report it as absent.
    let dir = fresh_dir("cfg_str_obj");
    write_wordlevel(&dir, BASIC_VOCAB, "<unk>", true);
    let config_json = serde_json::json!({
        "bos_token": {"content": "<s>", "lstrip": false},
        "eos_token": "</s>",
        "unk_token": {"not_content": "x"},
    })
    .to_string();
    std::fs::write(dir.join("tokenizer_config.json"), config_json).unwrap();

    let tokenizer = Tokenizer::from_path(&dir, None).unwrap();
    assert_eq!(tokenizer.bos_token(), Some("<s>"));
    assert_eq!(tokenizer.eos_token(), Some("</s>"));
    assert_eq!(tokenizer.unk_token(), None);
}
#[cfg(all(
    feature = "tokenizer-stream",
    not(any(feature = "tokenizer-spm", feature = "tokenizer-bpe"))
))]
#[test]
fn from_path_naive_class_without_spm_bpe() {
    // With neither the SPM nor the BPE feature enabled, loading from disk is
    // expected to select the naive detokenizer.
    let dir = fresh_dir("naive_noinfer");
    write_wordlevel(&dir, BASIC_VOCAB, "<unk>", true);
    let tokenizer = Tokenizer::from_path(&dir, None).unwrap();
    assert_eq!(tokenizer.detokenizer_class(), DetokenizerClass::Naive);
    let detok = tokenizer.detokenizer();
    assert!(detok.is_naive());
}
#[cfg(feature = "tokenizer-tools")]
#[test]
fn tool_parser_from_config_sets_delimiters_and_accessors() {
    let dir = fresh_dir("tool_parser_cfg");
    write_wordlevel(&dir, BASIC_VOCAB, "<unk>", true);
    // Only the parser type is configured. The start/end delimiters asserted
    // below are not set explicitly in the config.
    let config_path = dir.join("tokenizer_config.json");
    std::fs::write(
        config_path,
        serde_json::json!({ "tool_parser_type": "json_tools" }).to_string(),
    )
    .unwrap();

    let tokenizer = Tokenizer::from_path(&dir, None).unwrap();
    assert!(tokenizer.has_tool_calling());
    assert_eq!(tokenizer.tool_call_start(), Some("<tool_call>"));
    assert_eq!(tokenizer.tool_call_end(), Some("</tool_call>"));
    let parser_name = tokenizer.tool_parser().map(|parser| parser.name());
    assert_eq!(parser_name, Some("json_tools"));
}
#[cfg(feature = "tokenizer-tools")]
#[test]
fn tool_parser_absent_accessors_are_none() {
    let dir = fresh_dir("tool_parser_none");
    write_wordlevel(&dir, BASIC_VOCAB, "<unk>", true);
    // The config has no `tool_parser_type`, so every tool-calling accessor
    // must report "not configured", even though the vocab holds the delimiters.
    let config = serde_json::json!({ "chat_template": "{{ messages }}" });
    let config_path = dir.join("tokenizer_config.json");
    std::fs::write(&config_path, config.to_string()).unwrap();

    let tokenizer = Tokenizer::from_path(&dir, None).unwrap();
    assert!(!tokenizer.has_tool_calling());
    assert_eq!(tokenizer.tool_call_start(), None);
    assert_eq!(tokenizer.tool_call_end(), None);
    assert!(tokenizer.tool_parser().is_none());
}
#[cfg(feature = "tokenizer-tools")]
#[test]
fn parse_tool_call_with_configured_parser() {
    let dir = fresh_dir("parse_tc_ok");
    write_wordlevel(&dir, BASIC_VOCAB, "<unk>", true);
    let config = serde_json::json!({ "tool_parser_type": "json_tools" });
    std::fs::write(dir.join("tokenizer_config.json"), config.to_string()).unwrap();
    let tokenizer = Tokenizer::from_path(&dir, None).unwrap();

    // A single JSON call object is expected to yield exactly one parsed call.
    let payload = r#"{"name": "get_time", "arguments": {}}"#;
    let calls = tokenizer.parse_tool_call(payload, None).unwrap();
    assert_eq!(calls.len(), 1);
    let first = &calls[0];
    assert_eq!(first.name(), "get_time");
    assert_eq!(*first.arguments(), serde_json::json!({}));
}
#[cfg(feature = "tokenizer-tools")]
#[test]
fn parse_tool_call_without_parser_errors() {
    let dir = fresh_dir("parse_tc_err");
    write_wordlevel(&dir, BASIC_VOCAB, "<unk>", true);
    // An empty config configures no tool parser, so parsing must fail with a
    // tokenizer error that says so.
    std::fs::write(dir.join("tokenizer_config.json"), "{}").unwrap();
    let tokenizer = Tokenizer::from_path(&dir, None).unwrap();

    let outcome = tokenizer.parse_tool_call("{}", None);
    match outcome.unwrap_err() {
        Error::Tokenizer(m) => assert!(m.contains("no tool parser"), "msg: {m}"),
        other => panic!("expected Error::Tokenizer, got {other:?}"),
    }
}
#[cfg(all(feature = "tokenizer-config", feature = "tokenizer-stream"))]
#[test]
fn from_parts_forwards_to_from_loaded() {
    // Import the HF tokenizer alias locally, as `write_wordlevel` does, so this
    // test does not rely on the parent module happening to bring it into scope.
    use tokenizers::{Tokenizer as HfTokenizer, models::wordlevel::WordLevel};

    // Same vocabulary and unk token as the on-disk fixtures, built in memory
    // and without a pre-tokenizer.
    let map = BASIC_VOCAB
        .iter()
        .map(|(w, i)| ((*w).to_string(), *i))
        .collect();
    let wl = WordLevel::builder()
        .vocab(map)
        .unk_token("<unk>".to_string())
        .build()
        .unwrap();
    let hf = HfTokenizer::new(wl);
    // Stand-in for the raw tokenizer.json (explicitly no decoder), plus a
    // config naming `</tool_call>` (id 4 in BASIC_VOCAB) as the EOS token.
    let raw = serde_json::json!({ "decoder": null });
    let config = serde_json::json!({ "eos_token": "</tool_call>" });
    // Explicit EOS ids are passed out of order; the assertion below expects
    // them back in ascending order.
    let tok = Tokenizer::from_parts(
        hf,
        raw,
        config,
        DetokenizerClass::Naive,
        Some(&[4u32, 3u32]),
    )
    .unwrap();
    assert_eq!(tok.detokenizer_class(), DetokenizerClass::Naive);
    let eos: Vec<u32> = tok.eos_token_ids_iter().collect();
    assert_eq!(eos, vec![3, 4]);
    // With `add_eos` enabled and `add_special` disabled, the encoding is
    // expected to end in id 4.
    let out = tok
        .encode_with(
            "hello",
            &EncodeOptions::new()
                .with_add_eos(true)
                .with_add_special(false),
        )
        .unwrap();
    assert_eq!(out.ids().last(), Some(&4u32));
}
#[cfg(feature = "tokenizer-config")]
#[test]
fn additional_special_token_ids_non_array_is_empty() {
    // `additional_special_tokens` given as a bare string rather than a list is
    // expected to produce no ids, even though the string is in the vocab.
    let dir = fresh_dir("addl_non_array");
    write_wordlevel(&dir, BASIC_VOCAB, "<unk>", true);
    let config = serde_json::json!({ "additional_special_tokens": "<tool_call>" });
    let config_path = dir.join("tokenizer_config.json");
    std::fs::write(config_path, config.to_string()).unwrap();
    let tokenizer = Tokenizer::from_path(&dir, None).unwrap();
    assert!(tokenizer.additional_special_token_ids().is_empty());
}
#[cfg(feature = "tokenizer-config")]
#[test]
fn additional_special_token_ids_mixed_entries() {
    let dir = fresh_dir("addl_mixed");
    write_wordlevel(&dir, BASIC_VOCAB, "<unk>", true);
    // The list mixes every entry shape: a plain string, a `{content}` object,
    // a number, a string missing from the vocab, and an object without
    // `content`. Only the two in-vocab entries are expected to yield ids.
    let config = serde_json::json!({
        "additional_special_tokens": [
            "<tool_call>",
            {"content": "</tool_call>"},
            42,
            "unknown_tok",
            {"no_content": true},
        ]
    });
    std::fs::write(dir.join("tokenizer_config.json"), config.to_string()).unwrap();
    let tokenizer = Tokenizer::from_path(&dir, None).unwrap();
    assert_eq!(tokenizer.additional_special_token_ids(), vec![3, 4]);
}
#[cfg(feature = "tokenizer-spm")]
#[test]
fn detokenizer_spm_branch() {
    let dir = fresh_dir("spm_detok");
    write_wordlevel(&dir, BASIC_VOCAB, "<unk>", true);

    // Patch the saved tokenizer.json with a SentencePiece-style decoder chain:
    // Replace "▁" with " ", ByteFallback, Fuse, then Strip a leading " ".
    let tokenizer_json = dir.join("tokenizer.json");
    let raw_bytes = std::fs::read(&tokenizer_json).unwrap();
    let mut doc: serde_json::Value = serde_json::from_slice(&raw_bytes).unwrap();
    doc["decoder"] = serde_json::json!({
        "type": "Sequence",
        "decoders": [
            {"type": "Replace", "pattern": {"String": "▁"}, "content": " "},
            {"type": "ByteFallback"},
            {"type": "Fuse"},
            {"type": "Strip", "content": " ", "start": 1, "stop": 0}
        ]
    });
    std::fs::write(&tokenizer_json, doc.to_string()).unwrap();

    let tokenizer = Tokenizer::from_path(&dir, None).unwrap();
    assert_eq!(tokenizer.detokenizer_class(), DetokenizerClass::Spm);
    assert!(tokenizer.detokenizer().is_spm());
}
#[cfg(feature = "tokenizer-deepseek-v32")]
#[test]
fn apply_chat_template_uses_registered_override() {
    let dir = fresh_dir("ds_override");
    write_wordlevel(&dir, BASIC_VOCAB, "<unk>", true);
    // `chat_template_type` selects the registered `deepseek_v32` override. The
    // Jinja `chat_template` is a sentinel that must not end up in the output.
    let config = serde_json::json!({
        "chat_template_type": "deepseek_v32",
        "chat_template": "JINJA-SHOULD-NOT-BE-USED",
    });
    std::fs::write(dir.join("tokenizer_config.json"), config.to_string()).unwrap();
    let tokenizer = Tokenizer::from_path(&dir, None).unwrap();
    assert!(tokenizer.has_chat_template());
    assert!(!tokenizer.has_thinking());

    let conversation = serde_json::json!([{ "role": "user", "content": "hello" }]);
    let rendered = tokenizer
        .apply_chat_template(&conversation, None, true, false, None)
        .unwrap();
    let expected = "<|begin▁of▁sentence|><|User|>hello<|Assistant|></think>";
    assert_eq!(rendered, expected);
}
#[cfg(feature = "tokenizer-deepseek-v32")]
#[test]
fn apply_chat_template_override_rejects_non_list_messages() {
    let dir = fresh_dir("ds_override_err");
    write_wordlevel(&dir, BASIC_VOCAB, "<unk>", true);
    let config = serde_json::json!({ "chat_template_type": "deepseek_v32" });
    std::fs::write(dir.join("tokenizer_config.json"), config.to_string()).unwrap();
    let tokenizer = Tokenizer::from_path(&dir, None).unwrap();

    // A single message object instead of an array must be rejected.
    let lone_message = serde_json::json!({ "role": "user", "content": "hello" });
    let outcome = tokenizer.apply_chat_template(&lone_message, None, true, false, None);
    match outcome.unwrap_err() {
        Error::Tokenizer(m) => assert!(m.contains("messages must be a list"), "msg: {m}"),
        other => panic!("expected Error::Tokenizer, got {other:?}"),
    }
}
#[cfg(feature = "tokenizer-config")]
#[test]
fn config_accessor_returns_parsed_config() {
    let dir = fresh_dir("config_accessor");
    write_wordlevel(&dir, BASIC_VOCAB, "<unk>", true);
    let written = serde_json::json!({ "model_max_length": 4096, "eos_token": "</tool_call>" });
    std::fs::write(dir.join("tokenizer_config.json"), written.to_string()).unwrap();
    let tokenizer = Tokenizer::from_path(&dir, None).unwrap();

    // The raw keys written above must be readable back through `config()`.
    let parsed = tokenizer.config();
    let max_len = parsed.get("model_max_length").and_then(|value| value.as_u64());
    assert_eq!(max_len, Some(4096));
    let eos = parsed.get("eos_token").and_then(|value| value.as_str());
    assert_eq!(eos, Some("</tool_call>"));
}
#[test]
fn infer_thinking_channel_branch() {
    // A vocab with channel-style markers, written without a whitespace
    // pre-tokenizer. The combined `<|channel>thought` token, not the bare
    // `<|channel>`, is expected to become the thinking start marker.
    let channel_vocab: &[(&str, u32)] = &[
        ("<unk>", 0),
        ("hello", 1),
        ("<|channel>", 10),
        ("<channel|>", 11),
        ("<|channel>thought", 12),
    ];
    let dir = fresh_dir("channel_think");
    write_wordlevel(&dir, channel_vocab, "<unk>", false);
    let tokenizer = Tokenizer::from_path(&dir, None).unwrap();

    assert!(tokenizer.has_thinking());
    assert_eq!(tokenizer.think_start(), Some("<|channel>thought"));
    assert_eq!(tokenizer.think_end(), Some("<channel|>"));
    assert_eq!(tokenizer.think_start_tokens(), Some(&[12u32][..]));
    assert_eq!(tokenizer.think_end_tokens(), Some(&[11u32][..]));
}
#[test]
fn infer_thinking_absent_when_no_markers() {
    // BASIC_VOCAB contains no thinking markers, so thinking is expected to be
    // disabled and every marker accessor to be empty.
    let dir = fresh_dir("no_think");
    write_wordlevel(&dir, BASIC_VOCAB, "<unk>", true);
    let tokenizer = Tokenizer::from_path(&dir, None).unwrap();

    assert!(!tokenizer.has_thinking());
    // Neither the marker strings nor their token-id sequences are populated.
    assert_eq!(tokenizer.think_start(), None);
    assert_eq!(tokenizer.think_end(), None);
    assert_eq!(tokenizer.think_start_tokens(), None);
    assert_eq!(tokenizer.think_end_tokens(), None);
}