#![cfg(feature = "lm")]
use std::io::Write;
use mlxrs::tokenizer::{
StreamingDetokenizer, Tokenizer,
stream::{BpeStreamingDetokenizer, NaiveStreamingDetokenizer, SpmStreamingDetokenizer},
tools::{Glm47, JsonTools, Pythonic, ToolParser},
};
use serde_json::json;
/// Contents of `fixtures/tokenizer.json` (path relative to this file), embedded at compile time.
const TOKENIZER_JSON: &str = include_str!("fixtures/tokenizer.json");
/// Contents of `fixtures/tokenizer_config.json` (path relative to this file), embedded at compile time.
const TOKENIZER_CONFIG_JSON: &str = include_str!("fixtures/tokenizer_config.json");
/// Writes the embedded tokenizer fixtures into a temp directory and returns
/// that directory's path.
///
/// The directory name includes the current process id. It is populated only
/// once per process because the `OnceLock` caches the path, and every later
/// call returns a clone of the cached path. The directory is never deleted.
///
/// # Panics
///
/// Panics if the directory or either fixture file cannot be created or written.
fn fixture_dir() -> std::path::PathBuf {
    static CACHED_DIR: std::sync::OnceLock<std::path::PathBuf> = std::sync::OnceLock::new();

    let populated = CACHED_DIR.get_or_init(|| {
        let root = std::env::temp_dir()
            .join(format!("mlxrs-tok-fixture-{}", std::process::id()));
        std::fs::create_dir_all(&root).unwrap();

        // Same file order as before: tokenizer.json first, then the config.
        let files = [
            ("tokenizer.json", TOKENIZER_JSON),
            ("tokenizer_config.json", TOKENIZER_CONFIG_JSON),
        ];
        for (file_name, contents) in files {
            let mut handle = std::fs::File::create(root.join(file_name)).unwrap();
            handle.write_all(contents.as_bytes()).unwrap();
        }
        root
    });
    populated.clone()
}
/// Encoding a fixed phrase yields the expected fixture ids. Decoding those ids
/// keeps the original words, and re-encoding the decoded text reproduces the
/// same ids.
#[test]
fn encode_decode_round_trip() {
    let tokenizer = Tokenizer::from_path(fixture_dir(), None).unwrap();
    let phrase = "hello world the quick brown fox";

    let encoded = tokenizer.encode(phrase, false).unwrap();
    assert_eq!(encoded, vec![3, 4, 5, 6, 7, 8]);

    let decoded = tokenizer.decode(&encoded, false).unwrap();
    for word in ["hello", "fox"] {
        assert!(decoded.contains(word));
    }

    let re_encoded = tokenizer.encode(&decoded, false).unwrap();
    assert_eq!(re_encoded, encoded);
}
/// Checks the special-token accessors against the fixture files: the
/// BOS/EOS/UNK strings and ids, the chat-template flag, and the `<think>`
/// thinking marker with its token id.
#[test]
fn special_token_properties() {
    let tok = Tokenizer::from_path(fixture_dir(), None).unwrap();

    // String forms of the named special tokens.
    let named = [
        (tok.bos_token(), "<s>"),
        (tok.eos_token(), "</s>"),
        (tok.unk_token(), "<unk>"),
    ];
    for (actual, expected) in named {
        assert_eq!(actual, Some(expected));
    }

    // Numeric ids.
    assert_eq!(tok.bos_token_id(), Some(1));
    assert_eq!(tok.eos_token_id(), Some(2));
    assert!(tok.contains_eos_id(2));

    // Template and thinking support.
    assert!(tok.has_chat_template());
    assert!(tok.has_thinking());
    assert_eq!(tok.think_start(), Some("<think>"));
    assert_eq!(tok.think_start_tokens(), Some(&[9u32][..]));
}
/// Feeding ids one at a time into the tokenizer's streaming detokenizer ends
/// with the same text as a one-shot `decode`. The detokenizer also records
/// every id it was given.
#[test]
fn naive_detokenizer_reconstructs_full_decode() {
    let tok = Tokenizer::from_path(fixture_dir(), None).unwrap();
    let ids = tok.encode("hello world the quick", false).unwrap();
    let expected = tok.decode(&ids, false).unwrap();

    let mut detok = tok.detokenizer();
    detok.reset();
    ids.iter().copied().for_each(|id| {
        detok.add_token(id);
    });
    detok.finalize();

    assert_eq!(detok.text(), expected);
    assert_eq!(detok.tokens(), ids.as_slice());
}
/// SentencePiece pieces that use the `▁` (U+2581) word marker stream into
/// space-separated text. The expected output has no leading space.
#[test]
fn spm_detokenizer_incremental_matches_join() {
    // Pieces get ids 0, 1, 2 in listing order.
    let vocab: Vec<(String, u32)> = ["\u{2581}Hello", "\u{2581}world", "!"]
        .iter()
        .zip(0u32..)
        .map(|(piece, id)| (piece.to_string(), id))
        .collect();

    let mut detok = SpmStreamingDetokenizer::new(vocab, true);
    detok.reset();
    for id in 0u32..3 {
        detok.add_token(id);
    }
    detok.finalize();

    assert_eq!(detok.text(), "Hello world!");
}
/// In byte-level BPE pieces, the `Ġ` (U+0120) marker must come out as a
/// space in the streamed text.
#[test]
fn bpe_detokenizer_incremental_matches_join() {
    let mut detok = BpeStreamingDetokenizer::new(
        vec![("Hello".to_owned(), 0u32), ("\u{0120}world".to_owned(), 1u32)],
        false,
    );
    detok.reset();
    for id in 0u32..2 {
        detok.add_token(id);
    }
    detok.finalize();

    assert_eq!(detok.text(), "Hello world");
}
/// A `NaiveStreamingDetokenizer` built from a caller-supplied decode closure
/// ends with exactly what that closure returns for the whole id sequence.
#[test]
fn naive_detokenizer_standalone_matches() {
    // Renders each id as "t{id} ", concatenated.
    let render = |ids: &[u32]| {
        ids.iter().fold(String::new(), |mut acc, id| {
            acc.push_str(&format!("t{id} "));
            acc
        })
    };
    let expected = render(&[1, 2, 3]);

    let mut detok = NaiveStreamingDetokenizer::new(render, false);
    detok.reset();
    for id in 1u32..=3 {
        detok.add_token(id);
    }
    detok.finalize();

    assert_eq!(detok.text(), expected);
}
/// Renders the fixture chat template twice. The first call passes a tool list
/// and `true` as the third argument, which appends a trailing `<|assistant|>`
/// (presumably the add-generation-prompt flag; TODO confirm). The second call
/// passes no tools and `false`.
#[test]
fn chat_template_render_with_generation_prompt_and_tool() {
    let tok = Tokenizer::from_path(fixture_dir(), None).unwrap();
    let messages = json!([
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "yo"}
    ]);
    let tools = json!([
        {"type": "function", "function": {"name": "get_weather", "parameters": {}}}
    ]);

    let with_tools = tok
        .apply_chat_template(&messages, Some(&tools), true, false, None)
        .unwrap();
    let expected_with_tools =
        "<s><|user|>hi<|assistant|>yo<tool>get_weather</tool><|assistant|>";
    assert_eq!(with_tools, expected_with_tools);

    let plain = tok
        .apply_chat_template(&messages, None, false, false, None)
        .unwrap();
    let expected_plain = "<s><|user|>hi<|assistant|>yo";
    assert_eq!(plain, expected_plain);
}
/// The JSON tool parser turns a `{"name": ..., "arguments": {...}}` object
/// into exactly one call with that name and those arguments.
#[test]
fn json_tools_parser_parses_assistant_tool_call() {
    let parser = JsonTools;
    let raw = r#"{"name": "get_weather", "arguments": {"city": "Paris"}}"#;
    let calls = parser.parse(raw, None).unwrap();

    assert_eq!(calls.len(), 1);
    let call = &calls[0];
    assert_eq!(call.name(), "get_weather");
    assert_eq!(*call.arguments(), json!({"city": "Paris"}));
}
/// The pythonic parser handles input wrapped in
/// `<|tool_call_start|>[...]<|tool_call_end|>`. It yields one call whose
/// keyword arguments become JSON values: a string for `city` and a number for
/// `days`.
#[test]
fn pythonic_parser_parses_assistant_tool_call() {
    let parser = Pythonic;
    let text = r#"<|tool_call_start|>[get_weather(city="Paris", days=3)]<|tool_call_end|>"#;
    let calls = parser.parse(text, None).unwrap();

    assert_eq!(calls.len(), 1);
    let call = &calls[0];
    let args = call.arguments();
    assert_eq!(call.name(), "get_weather");
    assert_eq!(args["city"], json!("Paris"));
    assert_eq!(args["days"], json!(3));
}
/// The GLM-4.7 parser accepts a function name followed by
/// `<arg_key>`/`<arg_value>` pairs and yields one call with those arguments.
#[test]
fn glm47_parser_parses_xml_style() {
    let p = Glm47;
    let calls = p
        .parse(
            "get_weather<arg_key>city</arg_key><arg_value>Paris</arg_value>",
            None,
        )
        .unwrap();
    // Check the count before indexing, as the other parser tests do. An empty
    // result then fails with a clear message instead of an index-out-of-bounds
    // panic, and spurious extra calls are caught.
    assert_eq!(calls.len(), 1);
    assert_eq!(calls[0].name(), "get_weather");
    assert_eq!(calls[0].arguments()["city"], json!("Paris"));
}
/// `additional_special_token_ids` accepts both the plain-string form and the
/// object form (`{"content": ...}`) of `additional_special_tokens` entries. It
/// returns ids only for entries found in the vocabulary, so `"not-in-vocab"`
/// is dropped.
#[test]
fn additional_special_token_ids_unions_string_and_object_forms() {
    // Removes the wrapped directory on drop. The scratch directory is then
    // cleaned up on every exit path, including a panicking assertion or
    // unwrap. The original cleanup line only ran on success.
    struct RemoveDirOnDrop(std::path::PathBuf);
    impl Drop for RemoveDirOnDrop {
        fn drop(&mut self) {
            // Best-effort: a leftover temp dir is not worth a second panic.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    // Pid + timestamp keeps the directory unique across processes and runs,
    // unlike the shared per-process `fixture_dir`.
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_nanos())
        .unwrap_or(0);
    let dir = std::env::temp_dir().join(format!(
        "mlxrs-tok-addl-{}-{}",
        std::process::id(),
        nanos
    ));
    std::fs::create_dir_all(&dir).unwrap();
    // Declared before `tok`, so it drops after the tokenizer.
    let _cleanup = RemoveDirOnDrop(dir.clone());

    std::fs::write(dir.join("tokenizer.json"), TOKENIZER_JSON).unwrap();
    let cfg = json!({
        "bos_token": "<s>",
        "eos_token": "</s>",
        "unk_token": "<unk>",
        "additional_special_tokens": [
            "<think>",
            { "content": "</think>", "lstrip": false, "rstrip": false, "single_word": false, "normalized": false, "special": true },
            "not-in-vocab" ]
    });
    std::fs::write(
        dir.join("tokenizer_config.json"),
        serde_json::to_string(&cfg).unwrap(),
    )
    .unwrap();

    let tok = Tokenizer::from_path(&dir, None).unwrap();
    let mut ids = tok.additional_special_token_ids();
    // The API does not promise an order, so sort before comparing.
    ids.sort_unstable();
    // Expected: 9 for "<think>" (string form) and 10 for "</think>" (object
    // form). "not-in-vocab" contributes nothing.
    assert_eq!(ids, vec![9u32, 10u32]);
}
/// `infer_tool_parser` picks a parser name from marker substrings in a chat
/// template. It returns `None` when no marker is present or there is no
/// template.
#[test]
fn infer_tool_parser_selects_correctly() {
    use mlxrs::tokenizer::infer_tool_parser;

    // (template, expected parser name)
    let cases: [(Option<&str>, Option<&str>); 6] = [
        (Some("... [TOOL_CALLS] ..."), Some("mistral")),
        (Some("uses <minimax:tool_call> here"), Some("minimax_m2")),
        (
            Some("calls <tool_call> with tool_call.name field"),
            Some("json_tools"),
        ),
        (Some("<|tool_list_start|>"), Some("pythonic")),
        (Some("no markers"), None),
        (None, None),
    ];

    for (template, expected) in cases {
        assert_eq!(infer_tool_parser(template), expected, "template: {template:?}");
    }
}