#![cfg(feature = "tokenizer-config")]
use std::io::Write;
use mlxrs::{
Error,
tokenizer::{Tokenizer, no_bos_or_eos},
};
// Fixture payloads embedded at compile time; the helpers below write them into
// temporary directories so `Tokenizer::from_path` can load them from disk.

// Base tokenizer definition, written out as `tokenizer.json`.
const TOKENIZER_JSON: &str = include_str!("fixtures/tokenizer.json");
// Tokenizer configuration, written out as `tokenizer_config.json`.
const TOKENIZER_CONFIG_JSON: &str = include_str!("fixtures/tokenizer_config.json");
// Alternative tokenizer definition, written out as `tokenizer.json` by
// `template_fixture_dir`; the tests below expect encoding with special tokens
// enabled to add BOS/EOS when this variant is loaded.
const TOKENIZER_TEMPLATE_JSON: &str = include_str!("fixtures/tokenizer_template.json");
/// Returns a per-process temporary directory containing the base
/// `tokenizer.json` and `tokenizer_config.json` fixtures.
///
/// The directory is created and populated on the first call only; every later
/// call returns a clone of the cached path.
fn fixture_dir() -> std::path::PathBuf {
    static FIXTURE: std::sync::OnceLock<std::path::PathBuf> = std::sync::OnceLock::new();
    let cached = FIXTURE.get_or_init(|| {
        let folder = format!("mlxrs-tok-wrapper-fixture-{}", std::process::id());
        let root = std::env::temp_dir().join(folder);
        std::fs::create_dir_all(&root).unwrap();
        let payloads = [
            ("tokenizer.json", TOKENIZER_JSON),
            ("tokenizer_config.json", TOKENIZER_CONFIG_JSON),
        ];
        for (file_name, contents) in payloads {
            let mut handle = std::fs::File::create(root.join(file_name)).unwrap();
            handle.write_all(contents.as_bytes()).unwrap();
        }
        root
    });
    cached.clone()
}
/// Returns a per-process temporary directory whose `tokenizer.json` is the
/// template fixture, paired with the shared `tokenizer_config.json`.
///
/// Populated once on first use; later calls return a clone of the cached path.
fn template_fixture_dir() -> std::path::PathBuf {
    static FIXTURE: std::sync::OnceLock<std::path::PathBuf> = std::sync::OnceLock::new();
    let cached = FIXTURE.get_or_init(|| {
        let folder = format!("mlxrs-tok-wrapper-tmpl-fixture-{}", std::process::id());
        let root = std::env::temp_dir().join(folder);
        std::fs::create_dir_all(&root).unwrap();
        let payloads = [
            ("tokenizer.json", TOKENIZER_TEMPLATE_JSON),
            ("tokenizer_config.json", TOKENIZER_CONFIG_JSON),
        ];
        for (file_name, contents) in payloads {
            let mut handle = std::fs::File::create(root.join(file_name)).unwrap();
            handle.write_all(contents.as_bytes()).unwrap();
        }
        root
    });
    cached.clone()
}
/// Unwraps the error from a tokenizer load result, panicking if the load
/// unexpectedly succeeded.
fn expect_load_err(res: Result<Tokenizer, Error>) -> Error {
    let Err(load_error) = res else {
        panic!("expected Tokenizer::from_path to error, but it loaded");
    };
    load_error
}
/// Creates and returns a fresh temporary directory for a single test.
///
/// The directory name combines `tag`, the process id, the current wall-clock
/// time in nanoseconds, and a process-wide sequence number. The sequence
/// number guarantees distinct names within one process even when the clock is
/// coarse or unavailable (the timestamp falls back to `0` on clock error), so
/// two calls can never silently share a directory through `create_dir_all`.
///
/// # Panics
///
/// Panics if the directory cannot be created.
fn unique_dir(tag: &str) -> std::path::PathBuf {
    // Monotonic per-process counter; Relaxed is enough because only the
    // uniqueness of the returned value matters, not ordering with other memory.
    static SEQUENCE: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
    let seq = SEQUENCE.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_nanos())
        .unwrap_or(0);
    let dir = std::env::temp_dir().join(format!(
        "mlxrs-tok-wrapper-{tag}-{}-{nanos}-{seq}",
        std::process::id()
    ));
    std::fs::create_dir_all(&dir).unwrap();
    dir
}
/// A sequence framed by BOS (1) and EOS (2) loses one token from each end.
#[test]
fn no_bos_or_eos_strips_both_ends() {
    let framed = [1u32, 3, 4, 5, 2];
    let stripped = no_bos_or_eos(&framed, 1, 2);
    assert_eq!(stripped, vec![3, 4, 5]);
}
/// With no trailing EOS, only the leading BOS is removed.
#[test]
fn no_bos_or_eos_strips_only_leading_bos() {
    let bos_prefixed = [1u32, 3, 4, 5];
    let stripped = no_bos_or_eos(&bos_prefixed, 1, 2);
    assert_eq!(stripped, vec![3, 4, 5]);
}
/// With no leading BOS, only the trailing EOS is removed.
#[test]
fn no_bos_or_eos_strips_only_trailing_eos() {
    let eos_suffixed = [3u32, 4, 5, 2];
    let stripped = no_bos_or_eos(&eos_suffixed, 1, 2);
    assert_eq!(stripped, vec![3, 4, 5]);
}
/// A sequence carrying neither marker comes back unchanged.
#[test]
fn no_bos_or_eos_no_op_when_neither_present() {
    let bare = [3u32, 4, 5];
    let stripped = no_bos_or_eos(&bare, 1, 2);
    assert_eq!(stripped, vec![3, 4, 5]);
}
/// Doubled markers are only peeled once per end: one BOS and one EOS survive.
#[test]
fn no_bos_or_eos_strips_each_at_most_once() {
    let doubled = [1u32, 1, 3, 2, 2];
    let stripped = no_bos_or_eos(&doubled, 1, 2);
    assert_eq!(stripped, vec![1, 3, 2]);
}
/// A sequence that is just the BOS token strips down to nothing.
#[test]
fn no_bos_or_eos_single_bos_only_element() {
    let only_bos = [1u32];
    let stripped = no_bos_or_eos(&only_bos, 1, 2);
    assert_eq!(stripped, Vec::<u32>::new());
}
/// A sequence that is just the EOS token strips down to nothing.
///
/// BOS (1) is deliberately distinct from EOS (2) here, so the lone element can
/// only be removed as a trailing EOS; the case where both ids coincide has its
/// own test.
#[test]
fn no_bos_or_eos_single_eos_only_element() {
    let seq = [2u32];
    assert_eq!(no_bos_or_eos(&seq, 1, 2), Vec::<u32>::new());
}
/// Empty input yields empty output.
#[test]
fn no_bos_or_eos_empty_input() {
    let nothing: [u32; 0] = [];
    let stripped = no_bos_or_eos(&nothing, 1, 2);
    assert_eq!(stripped, Vec::<u32>::new());
}
/// When BOS and EOS share the same id (7) and the sequence is that single
/// token, the result is empty.
#[test]
fn no_bos_or_eos_same_bos_eos_single_element_strips_front_only() {
    let lone_marker = [7u32];
    let stripped = no_bos_or_eos(&lone_marker, 7, 7);
    assert_eq!(stripped, Vec::<u32>::new());
}
/// Frames a real encoding with the fixture tokenizer's own BOS/EOS ids and
/// checks that `no_bos_or_eos` recovers the unframed body.
#[test]
fn no_bos_or_eos_round_trips_with_real_tokenizer_bos_eos() {
    let tok = Tokenizer::from_path(fixture_dir(), None).unwrap();

    let bos = tok.bos_token_id().expect("fixture has a bos token");
    let eos = tok.eos_token_id().expect("fixture has an eos token");
    assert_eq!((bos, eos), (1, 2));

    let body = tok.encode("hello world the quick", false).unwrap();
    assert_eq!(body, vec![3, 4, 5, 6]);

    // Build `bos ++ body ++ eos` in one pass.
    let framed: Vec<_> = std::iter::once(bos)
        .chain(body.iter().copied())
        .chain(std::iter::once(eos))
        .collect();
    assert_eq!(no_bos_or_eos(&framed, bos, eos), body);
}
/// With the template fixture, `encode(.., true)` wraps the body in BOS/EOS
/// while `encode(.., false)` returns the bare body.
#[test]
fn encode_add_special_true_injects_bos_eos_false_omits() {
    let tok = Tokenizer::from_path(template_fixture_dir(), None).unwrap();
    let text = "hello world";

    let without = tok.encode(text, false).unwrap();
    assert_eq!(without, vec![3, 4], "bare body, no specials");

    let with_special = tok.encode(text, true).unwrap();
    assert_eq!(with_special, vec![1, 3, 4, 2]);
    assert_eq!(with_special.first(), Some(&1), "starts with BOS");
    assert_eq!(with_special.last(), Some(&2), "ends with EOS");
    assert!(with_special.len() > without.len());

    let bare_has_bos = without.contains(&1);
    let bare_has_eos = without.contains(&2);
    assert!(!bare_has_bos && !bare_has_eos);
}
/// Batch counterpart of the single-item add-special test: every item is framed
/// with BOS/EOS when the flag is set and left bare otherwise.
#[test]
fn encode_batch_add_special_true_injects_bos_eos_false_omits() {
    let tok = Tokenizer::from_path(template_fixture_dir(), None).unwrap();
    let texts: Vec<String> = vec!["hello world".to_string(), "fox".to_string()];

    let without = tok.encode_batch(texts.clone(), false).unwrap();
    assert_eq!(without, vec![vec![3u32, 4], vec![8]]);

    let with_special = tok.encode_batch(texts, true).unwrap();
    assert_eq!(with_special, vec![vec![1u32, 3, 4, 2], vec![1, 8, 2]]);

    for (framed, bare) in with_special.iter().zip(&without) {
        assert_eq!(framed.first(), Some(&1), "item starts with BOS");
        assert_eq!(framed.last(), Some(&2), "item ends with EOS");
        assert!(framed.len() > bare.len());
    }
}
/// With the base fixture the add-special flag changes nothing: both encodings
/// are the bare body and contain neither BOS nor EOS.
#[test]
fn encode_add_special_noop_without_post_processor() {
    let tok = Tokenizer::from_path(fixture_dir(), None).unwrap();
    let text = "hello world";

    let with_special = tok.encode(text, true).unwrap();
    let without = tok.encode(text, false).unwrap();

    assert_eq!(without, vec![3, 4]);
    assert_eq!(
        with_special, without,
        "null post_processor: flag is a no-op"
    );

    let has_bos = with_special.contains(&1);
    let has_eos = with_special.contains(&2);
    assert!(!has_bos && !has_eos);
}
/// Decoding keeps `<s>` / `</s>` when `skip_special_tokens` is false and drops
/// them when it is true; ordinary words survive either way.
#[test]
fn decode_skip_special_tokens_drops_specials() {
    let tok = Tokenizer::from_path(fixture_dir(), None).unwrap();
    let ids = [1u32, 3, 4, 2];

    let kept = tok.decode(&ids, false).unwrap();
    for marker in ["<s>", "</s>"] {
        assert!(kept.contains(marker), "kept='{kept}'");
    }
    assert!(kept.contains("hello") && kept.contains("world"));

    let skipped = tok.decode(&ids, true).unwrap();
    for marker in ["<s>", "</s>"] {
        assert!(!skipped.contains(marker), "skipped='{skipped}'");
    }
    assert!(skipped.contains("hello") && skipped.contains("world"));
}
/// Encoding, decoding with specials skipped, and re-encoding reproduces the
/// original ids.
#[test]
fn encode_decode_round_trip_via_skip_special() {
    let tok = Tokenizer::from_path(fixture_dir(), None).unwrap();

    let ids = tok.encode("the quick brown fox", false).unwrap();
    assert_eq!(ids, vec![5, 6, 7, 8]);

    let text = tok.decode(&ids, true).unwrap();
    let re_encoded = tok.encode(&text, false).unwrap();
    assert_eq!(re_encoded, ids);
}
/// A word absent from the vocabulary encodes to id 0 between its known
/// neighbours.
#[test]
fn encode_unknown_word_maps_to_unk_id() {
    let tok = Tokenizer::from_path(fixture_dir(), None).unwrap();
    let encoded = tok.encode("hello zzzznotaword world", false).unwrap();
    assert_eq!(encoded, vec![3, 0, 4]);
}
/// Token <-> id lookups agree for known entries and return `None` for unknown
/// tokens and out-of-range ids.
#[test]
fn convert_token_id_round_trip_and_misses() {
    let tok = Tokenizer::from_path(fixture_dir(), None).unwrap();

    // Hits.
    assert_eq!(tok.convert_token_to_id("quick"), Some(6));
    assert_eq!(tok.convert_id_to_token(6).as_deref(), Some("quick"));
    assert_eq!(tok.convert_token_to_id("</s>"), Some(2));

    // Misses.
    assert_eq!(tok.convert_token_to_id("zzzznotaword"), None);
    assert_eq!(tok.convert_id_to_token(9_999), None);
}
/// Batch encoding returns one exact id sequence per input, in input order.
#[test]
fn encode_batch_multi_item_exact_ids() {
    let tok = Tokenizer::from_path(fixture_dir(), None).unwrap();
    let inputs: Vec<String> = ["hello world", "the quick brown fox", "fox"]
        .into_iter()
        .map(String::from)
        .collect();

    let batch = tok.encode_batch(inputs, false).unwrap();

    assert_eq!(batch, vec![vec![3u32, 4], vec![5, 6, 7, 8], vec![8]]);
}
/// An empty batch encodes to an empty result.
#[test]
fn encode_batch_empty_input_is_empty_output() {
    let tok = Tokenizer::from_path(fixture_dir(), None).unwrap();
    let no_inputs = Vec::new();
    let batch = tok.encode_batch(no_inputs, false).unwrap();
    assert!(batch.is_empty());
}
/// `decode_batch` yields, per item, the same text as a standalone `decode`.
#[test]
fn decode_batch_matches_per_item_decode() {
    let tok = Tokenizer::from_path(fixture_dir(), None).unwrap();
    let a: &[u32] = &[3, 4];
    let b: &[u32] = &[5, 6, 7, 8];
    let seqs: &[&[u32]] = &[a, b];

    let out = tok.decode_batch(seqs, true).unwrap();
    assert_eq!(out.len(), 2);

    let solo_a = tok.decode(a, true).unwrap();
    assert_eq!(out[0], solo_a);
    let solo_b = tok.decode(b, true).unwrap();
    assert_eq!(out[1], solo_b);

    assert!(out[0].contains("hello") && out[0].contains("world"));
    assert!(out[1].contains("fox"));
}
/// Batch decoding honours `skip_special_tokens`: `<s>` is dropped when set and
/// retained when not.
#[test]
fn decode_batch_skip_special_tokens_drops_specials() {
    let tok = Tokenizer::from_path(fixture_dir(), None).unwrap();
    let a: &[u32] = &[1, 3, 2];
    let seqs: &[&[u32]] = &[a];

    let skipped = tok.decode_batch(seqs, true).unwrap();
    assert!(!skipped[0].contains("<s>"), "got '{}'", skipped[0]);
    assert!(skipped[0].contains("hello"));

    let kept = tok.decode_batch(seqs, false).unwrap();
    assert!(kept[0].contains("<s>"), "got '{}'", kept[0]);
}
/// Without caller-supplied ids, the EOS set holds exactly id 2.
#[test]
fn eos_set_from_config_contains_primary() {
    let tok = Tokenizer::from_path(fixture_dir(), None).unwrap();

    let eos_ids: Vec<u32> = tok.eos_token_ids_iter().collect();
    assert!(tok.contains_eos_id(2));
    assert!(!tok.contains_eos_id(3));
    assert_eq!(eos_ids, vec![2]);
}
/// Caller-supplied EOS ids are deduplicated and iterated in ascending order.
#[test]
fn eos_token_ids_iter_is_sorted_set() {
    let tok = Tokenizer::from_path(fixture_dir(), Some(&[8, 2, 8, 0])).unwrap();

    let eos_ids: Vec<u32> = tok.eos_token_ids_iter().collect();
    assert_eq!(eos_ids, vec![0, 2, 8]);

    for member in [0, 2, 8] {
        assert!(tok.contains_eos_id(member));
    }
    assert!(!tok.contains_eos_id(5));
}
/// `add_eos_token` accepts both a vocabulary token ("fox" -> 8) and a numeric
/// string ("5" -> 5), extending the existing EOS set.
#[test]
fn add_eos_token_by_string_and_numeric_string() {
    let mut tok = Tokenizer::from_path(fixture_dir(), None).unwrap();

    // By token text.
    tok.add_eos_token("fox").unwrap();
    assert!(tok.contains_eos_id(8));

    // By numeric string.
    tok.add_eos_token("5").unwrap();
    assert!(tok.contains_eos_id(5));

    let eos_ids: Vec<u32> = tok.eos_token_ids_iter().collect();
    assert_eq!(eos_ids, vec![2, 5, 8]);
}
/// Adding an EOS token that is not in the vocabulary fails with
/// `Error::Tokenizer` and a message containing "is not a token".
#[test]
fn add_eos_token_unknown_token_errors() {
    let mut tok = Tokenizer::from_path(fixture_dir(), None).unwrap();

    let err = tok.add_eos_token("definitely-not-a-token").unwrap_err();
    assert!(matches!(err, Error::Tokenizer(_)), "got {err:?}");

    let msg = err.to_string();
    assert!(msg.contains("is not a token"), "{msg}");
}
/// With no `tokenizer_config.json` there is no EOS, so requesting an appended
/// EOS errors; after `add_eos_token("fox")` the same request appends id 8.
#[test]
fn add_eos_token_establishes_primary_when_none_configured() {
    use mlxrs::tokenizer::EncodeOptions;

    // Only tokenizer.json is written: the config file is deliberately absent.
    let dir = unique_dir("addeos-primary");
    std::fs::write(dir.join("tokenizer.json"), TOKENIZER_JSON).unwrap();
    let mut tok = Tokenizer::from_path(&dir, None).unwrap();
    assert!(tok.eos_token_ids_iter().next().is_none());

    let err = tok
        .encode_with("hello", &EncodeOptions::new().with_add_eos(true))
        .unwrap_err();
    assert!(err.to_string().contains("eos"));

    tok.add_eos_token("fox").unwrap();
    let out = tok
        .encode_with(
            "hello world",
            &EncodeOptions::new()
                .with_add_special(false)
                .with_add_eos(true),
        )
        .unwrap();
    assert_eq!(out.ids(), &[3u32, 4, 8]);

    // Best-effort cleanup; a failure here must not fail the test.
    let _ = std::fs::remove_dir_all(&dir);
}
/// Special-token accessors report the fixture's text and id for BOS, EOS and
/// UNK, and `None` for the pad token.
#[test]
fn config_special_token_accessors() {
    let tok = Tokenizer::from_path(fixture_dir(), None).unwrap();

    // BOS
    assert_eq!(tok.bos_token(), Some("<s>"));
    assert_eq!(tok.bos_token_id(), Some(1));
    // EOS
    assert_eq!(tok.eos_token(), Some("</s>"));
    assert_eq!(tok.eos_token_id(), Some(2));
    // UNK
    assert_eq!(tok.unk_token(), Some("<unk>"));
    assert_eq!(tok.unk_token_id(), Some(0));
    // PAD is not configured.
    assert_eq!(tok.pad_token(), None);
    assert_eq!(tok.pad_token_id(), None);
}
/// Caller-supplied EOS ids replace the configured EOS id in the set, while the
/// configured EOS token text stays readable.
#[test]
fn caller_supplied_eos_replaces_config_default() {
    let tok = Tokenizer::from_path(fixture_dir(), Some(&[8])).unwrap();

    let has_supplied = tok.contains_eos_id(8);
    assert!(has_supplied);

    let has_config_default = tok.contains_eos_id(2);
    assert!(
        !has_config_default,
        "supplied set replaces, not unions, the config eos"
    );

    assert_eq!(tok.eos_token(), Some("</s>"));
}
/// Passing an explicit empty EOS slice leaves the EOS set empty rather than
/// falling back to the configured id.
#[test]
fn empty_eos_slice_suppresses_fallback() {
    let tok = Tokenizer::from_path(fixture_dir(), Some(&[])).unwrap();

    let mut eos_ids = tok.eos_token_ids_iter();
    assert!(eos_ids.next().is_none());
    assert!(!tok.contains_eos_id(2));
}
/// The fixture tokenizer reports thinking support, with `<think>` /
/// `</think>` resolving to ids 9 and 10.
#[test]
fn thinking_inferred_from_vocab() {
    let tok = Tokenizer::from_path(fixture_dir(), None).unwrap();

    assert!(tok.has_thinking());

    // Opening marker.
    assert_eq!(tok.think_start(), Some("<think>"));
    assert_eq!(tok.think_start_tokens(), Some(&[9u32][..]));

    // Closing marker.
    assert_eq!(tok.think_end(), Some("</think>"));
    assert_eq!(tok.think_end_tokens(), Some(&[10u32][..]));
}
/// `hf()` exposes the wrapped tokenizer: its vocabulary size and direct token
/// lookup are reachable through it.
#[test]
fn hf_accessor_exposes_underlying_tokenizer() {
    let tok = Tokenizer::from_path(fixture_dir(), None).unwrap();
    let inner = tok.hf();
    assert_eq!(inner.get_vocab_size(true), 11);
    assert_eq!(inner.token_to_id("hello"), Some(3));
}
/// Loading from a directory without `tokenizer.json` fails with
/// `Error::Tokenizer`, and the message names the missing file.
#[test]
fn from_path_missing_tokenizer_json_errors() {
    let dir = unique_dir("missing");

    let err = expect_load_err(Tokenizer::from_path(&dir, None));
    assert!(matches!(err, Error::Tokenizer(_)), "got {err:?}");

    let msg = err.to_string();
    assert!(
        msg.contains("tokenizer.json"),
        "expected tokenizer.json in error, got: {msg}"
    );

    // Best-effort cleanup.
    let _ = std::fs::remove_dir_all(&dir);
}
/// A syntactically invalid `tokenizer.json` makes loading fail with
/// `Error::Tokenizer`.
#[test]
fn from_path_malformed_tokenizer_json_errors() {
    let dir = unique_dir("malformed-tok");
    let broken_path = dir.join("tokenizer.json");
    std::fs::write(broken_path, b"{ this is not valid json").unwrap();

    let err = expect_load_err(Tokenizer::from_path(&dir, None));
    assert!(matches!(err, Error::Tokenizer(_)), "got {err:?}");

    // Best-effort cleanup.
    let _ = std::fs::remove_dir_all(&dir);
}
/// A valid `tokenizer.json` paired with an invalid `tokenizer_config.json`
/// fails with `Error::Tokenizer`, and the message names the config file.
#[test]
fn from_path_malformed_tokenizer_config_errors() {
    let dir = unique_dir("malformed-cfg");
    let tokenizer_path = dir.join("tokenizer.json");
    let config_path = dir.join("tokenizer_config.json");
    std::fs::write(tokenizer_path, TOKENIZER_JSON).unwrap();
    std::fs::write(config_path, b"{ not: valid").unwrap();

    let err = expect_load_err(Tokenizer::from_path(&dir, None));
    assert!(matches!(err, Error::Tokenizer(_)), "got {err:?}");

    let msg = err.to_string();
    assert!(
        msg.contains("tokenizer_config.json"),
        "expected tokenizer_config.json in error, got: {msg}"
    );

    // Best-effort cleanup.
    let _ = std::fs::remove_dir_all(&dir);
}
/// With no `tokenizer_config.json` present, loading succeeds, the special-token
/// text accessors are all `None`, and plain encoding still works.
#[test]
fn from_path_absent_config_uses_empty_defaults() {
    let dir = unique_dir("no-config");
    let tokenizer_path = dir.join("tokenizer.json");
    std::fs::write(tokenizer_path, TOKENIZER_JSON).unwrap();

    let tok = Tokenizer::from_path(&dir, None).unwrap();
    assert_eq!(tok.bos_token(), None);
    assert_eq!(tok.eos_token(), None);
    assert_eq!(tok.unk_token(), None);

    let encoded = tok.encode("hello world", false).unwrap();
    assert_eq!(encoded, vec![3, 4]);

    // Best-effort cleanup.
    let _ = std::fs::remove_dir_all(&dir);
}