#![cfg(feature = "tokenizer-config")]
use std::io::Write;
use mlxrs::tokenizer::{EncodeOptions, Encoded, Tokenizer};
const TOKENIZER_JSON: &str = include_str!("fixtures/tokenizer.json");
const TOKENIZER_CONFIG_JSON: &str = include_str!("fixtures/tokenizer_config.json");
/// Returns the directory holding the on-disk tokenizer fixture.
///
/// On first call the directory is created under the system temp dir (its name
/// embeds the current process id) and both embedded fixture files are written
/// into it. The path is cached in a `OnceLock`, so later calls only clone it.
///
/// # Panics
///
/// Panics if the directory or either fixture file cannot be created/written.
fn fixture_dir() -> std::path::PathBuf {
    use std::{fs, path::PathBuf, sync::OnceLock};

    static FIXTURE: OnceLock<PathBuf> = OnceLock::new();

    let cached = FIXTURE.get_or_init(|| {
        let root = std::env::temp_dir().join(format!(
            "mlxrs-tok-encode-opts-fixture-{}",
            std::process::id()
        ));
        fs::create_dir_all(&root).unwrap();

        // Same write order as before: tokenizer first, then its config.
        for (file_name, contents) in [
            ("tokenizer.json", TOKENIZER_JSON),
            ("tokenizer_config.json", TOKENIZER_CONFIG_JSON),
        ] {
            let mut file = fs::File::create(root.join(file_name)).unwrap();
            file.write_all(contents.as_bytes()).unwrap();
        }
        root
    });
    cached.clone()
}
/// `encode_with` under `EncodeOptions::default()` must produce the same ids as
/// the legacy `encode(text, true)` call, and an empty attention mask.
#[test]
fn encode_with_defaults_matches_encode_true() {
    let tokenizer = Tokenizer::from_path(fixture_dir(), None).unwrap();
    let input = "hello world the quick brown fox";

    let expected = tokenizer.encode(input, true).unwrap();
    let defaults = EncodeOptions::default();
    let encoded = tokenizer.encode_with(input, &defaults).unwrap();

    assert_eq!(encoded.ids(), expected);
    // The default options are expected not to request a mask.
    assert!(encoded.attention_mask().is_empty());
}
/// Turning special tokens off through `EncodeOptions` must match the legacy
/// `encode(text, false)` output.
#[test]
fn encode_with_add_special_false_matches_legacy_false() {
    let tokenizer = Tokenizer::from_path(fixture_dir(), None).unwrap();
    let input = "hello world";
    let expected = tokenizer.encode(input, false).unwrap();

    let no_special = EncodeOptions::new().with_add_special(false);
    let encoded = tokenizer.encode_with(input, &no_special).unwrap();

    assert_eq!(encoded.ids(), expected);
}
/// With the ids `[2, 0]` passed to `from_path` (presumably an EOS override
/// list), both must be recognised as EOS, and `add_eos` must append `2` (the
/// first listed) rather than `0` (the numerically smallest).
#[test]
fn encode_with_add_eos_uses_primary_not_smallest_id() {
    let tokenizer = Tokenizer::from_path(fixture_dir(), Some(&[2, 0])).unwrap();
    assert!(tokenizer.contains_eos_id(0));
    assert!(tokenizer.contains_eos_id(2));

    let eos_only = EncodeOptions::new()
        .with_add_special(false)
        .with_add_eos(true);
    let encoded = tokenizer.encode_with("hello world", &eos_only).unwrap();

    // Tail must be the primary EOS id ...
    assert_eq!(encoded.ids().last().copied(), Some(2u32));
    // ... and that id must not show up anywhere before the tail.
    assert!(
        !encoded.ids()[..encoded.ids().len() - 1].contains(&2),
        "primary eos should appear only at the tail; got {:?}",
        encoded.ids()
    );
}
/// `add_eos` must append exactly one id — the tokenizer's EOS id — to the
/// otherwise unchanged encoding.
#[test]
fn encode_with_add_eos_appends_eos_id() {
    let tokenizer = Tokenizer::from_path(fixture_dir(), None).unwrap();
    let eos_id = tokenizer.eos_token_id().expect("fixture has an eos token");
    let input = "hello world";

    let plain_opts = EncodeOptions::new().with_add_special(false);
    let plain = tokenizer.encode_with(input, &plain_opts).unwrap();

    let eos_opts = EncodeOptions::new()
        .with_add_special(false)
        .with_add_eos(true);
    let terminated = tokenizer.encode_with(input, &eos_opts).unwrap();

    assert_eq!(terminated.ids().last().copied(), Some(eos_id));
    assert_eq!(terminated.ids().len(), plain.ids().len() + 1);
    // Everything before the appended EOS is the plain encoding, untouched.
    assert_eq!(&terminated.ids()[..plain.ids().len()], plain.ids());
}
/// Requesting `add_eos` on a tokenizer that reports no EOS ids must fail with
/// an error whose message mentions "eos".
///
/// The tokenizer is loaded from a scratch directory containing only
/// `tokenizer.json` (no `tokenizer_config.json`); the test asserts that this
/// yields a tokenizer with no EOS ids before exercising `encode_with`.
#[test]
fn encode_with_add_eos_errors_when_no_eos_configured() {
    let dir = std::env::temp_dir().join(format!(
        "mlxrs-tok-encode-opts-noeos-{}",
        std::process::id()
    ));
    std::fs::create_dir_all(&dir).unwrap();
    std::fs::write(dir.join("tokenizer.json"), TOKENIZER_JSON).unwrap();

    let tok = Tokenizer::from_path(&dir, None).unwrap();
    assert!(tok.eos_token_ids_iter().next().is_none());

    let err = tok
        .encode_with("hi", &EncodeOptions::new().with_add_eos(true))
        .unwrap_err();
    let msg = format!("{err}");
    assert!(
        msg.contains("eos"),
        "expected eos-related error, got: {msg}"
    );

    // Best-effort cleanup so repeated runs do not pile up scratch directories
    // in the temp dir. A failed assertion above skips this on purpose, leaving
    // the directory behind for inspection.
    let _ = std::fs::remove_dir_all(&dir);
}
/// `truncate_to(Some(3))` must keep exactly the first three ids of the
/// untruncated encoding.
#[test]
fn encode_with_truncate_to_caps_length() {
    let tokenizer = Tokenizer::from_path(fixture_dir(), None).unwrap();
    let input = "hello world the quick brown fox";

    let untruncated = tokenizer.encode(input, false).unwrap();
    // Precondition: the cap below is only meaningful if there is more to cut.
    assert!(
        untruncated.len() > 3,
        "fixture must encode to >3 ids for this test"
    );

    let capped_opts = EncodeOptions::new()
        .with_add_special(false)
        .with_truncate_to(Some(3));
    let capped = tokenizer.encode_with(input, &capped_opts).unwrap();

    assert_eq!(capped.ids().len(), 3);
    assert_eq!(capped.ids(), &untruncated[..3]);
}
/// A truncation cap larger than the encoded length must leave the ids
/// unchanged.
#[test]
fn encode_with_truncate_to_above_length_is_noop() {
    let tokenizer = Tokenizer::from_path(fixture_dir(), None).unwrap();
    let input = "hello world";
    let untruncated = tokenizer.encode(input, false).unwrap();

    // Cap comfortably above the real length.
    let generous_cap = untruncated.len() + 100;
    let opts = EncodeOptions::new()
        .with_add_special(false)
        .with_truncate_to(Some(generous_cap));
    let encoded = tokenizer.encode_with(input, &opts).unwrap();

    assert_eq!(encoded.ids(), untruncated);
}
/// When a mask is requested it must be non-empty, as long as the ids, and
/// consist solely of `1`s.
#[test]
fn encode_with_return_attention_mask_matches_ids_len() {
    let tokenizer = Tokenizer::from_path(fixture_dir(), None).unwrap();
    let input = "hello world the quick brown fox";

    let with_mask = EncodeOptions::new().with_return_attention_mask(true);
    let encoded = tokenizer.encode_with(input, &with_mask).unwrap();

    let mask = encoded.attention_mask();
    assert_eq!(mask.len(), encoded.ids().len());
    assert!(!mask.is_empty(), "mask requested");
    assert!(
        mask.iter().all(|&bit| bit == 1),
        "non-padded mask is all 1s"
    );
}
/// A truncation cap of zero must empty both the ids and the requested mask.
#[test]
fn encode_with_truncate_zero_yields_empty_ids_and_mask() {
    let tokenizer = Tokenizer::from_path(fixture_dir(), None).unwrap();

    let zero_cap = EncodeOptions::new()
        .with_truncate_to(Some(0))
        .with_return_attention_mask(true);
    let encoded = tokenizer.encode_with("hello world", &zero_cap).unwrap();

    assert!(encoded.ids().is_empty());
    assert!(
        encoded.attention_mask().is_empty(),
        "mask requested + empty in lock-step with ids"
    );
}
/// Combining `add_eos` with a zero truncation cap must still succeed and
/// return empty ids and an empty mask — the cap wins over the appended EOS.
#[test]
fn encode_with_truncate_zero_dominates_add_eos() {
    let tokenizer = Tokenizer::from_path(fixture_dir(), None).unwrap();

    let opts = EncodeOptions::new()
        .with_add_eos(true)
        .with_truncate_to(Some(0))
        .with_return_attention_mask(true);
    let encoded = tokenizer
        .encode_with("hello world", &opts)
        .expect("eos is configured in the fixture, so add_eos does not error");

    assert!(
        encoded.ids().is_empty(),
        "n=0 cap dominates add_eos → empty ids"
    );
    assert!(
        encoded.attention_mask().is_empty(),
        "mask empty in lock-step with ids (mask requested)"
    );
}
/// With the underlying HF tokenizer configured to right-pad to a fixed length
/// of 16 (pad id 0), `encode_with` must still return only the two real ids,
/// place EOS directly after them, contain no pad id, and report an all-ones
/// mask of matching length.
#[cfg(feature = "tokenizer-stream")]
#[test]
fn encode_with_padded_tokenizer_strips_pads_and_eos_lands_after_real() {
    use tokenizers::{
        Tokenizer as HfTokenizer,
        utils::padding::{PaddingDirection, PaddingParams, PaddingStrategy},
    };

    let fixture = fixture_dir();

    // Build the HF tokenizer by hand so padding can be switched on.
    let mut hf = HfTokenizer::from_file(fixture.join("tokenizer.json")).unwrap();
    let right_pad_16 = PaddingParams {
        strategy: PaddingStrategy::Fixed(16),
        direction: PaddingDirection::Right,
        pad_to_multiple_of: None,
        pad_id: 0,
        pad_type_id: 0,
        pad_token: "<unk>".into(),
    };
    hf.with_padding(Some(right_pad_16));

    let cfg_bytes = std::fs::read(fixture.join("tokenizer_config.json")).unwrap();
    let cfg: serde_json::Value = serde_json::from_slice(&cfg_bytes).unwrap();
    let tokenizer =
        Tokenizer::from_loaded(hf, cfg, mlxrs::tokenizer::DetokenizerClass::Naive, None).unwrap();
    let eos_id = tokenizer.eos_token_id().expect("fixture has eos");

    // Without EOS: just the two real tokens, mask all ones.
    let masked_opts = EncodeOptions::new()
        .with_add_special(false)
        .with_return_attention_mask(true);
    let plain = tokenizer.encode_with("hello world", &masked_opts).unwrap();
    assert_eq!(plain.ids().len(), 2);
    let plain_mask = plain.attention_mask();
    assert!(!plain_mask.is_empty(), "mask requested");
    assert!(plain_mask.iter().all(|&bit| bit == 1));

    // With EOS: one extra id at the tail, still no pads anywhere.
    let masked_eos_opts = EncodeOptions::new()
        .with_add_special(false)
        .with_add_eos(true)
        .with_return_attention_mask(true);
    let terminated = tokenizer
        .encode_with("hello world", &masked_eos_opts)
        .unwrap();
    assert_eq!(terminated.ids().len(), plain.ids().len() + 1);
    assert_eq!(terminated.ids().last().copied(), Some(eos_id));
    assert!(
        !terminated.ids().contains(&0),
        "result must not contain pad id 0; got {:?}",
        terminated.ids()
    );
    let terminated_mask = terminated.attention_mask();
    assert!(!terminated_mask.is_empty(), "mask requested");
    assert_eq!(terminated_mask.len(), terminated.ids().len());
    assert!(terminated_mask.iter().all(|&bit| bit == 1));
}
/// With the underlying HF tokenizer configured to left-pad to a fixed length
/// of 8 (pad id 0), `encode_with` must return exactly `[3, 4]` for
/// "hello world", and `[3, 4, eos]` once `add_eos` is set.
#[cfg(feature = "tokenizer-stream")]
#[test]
fn encode_with_left_padded_tokenizer_drops_leading_pads() {
    use tokenizers::{
        Tokenizer as HfTokenizer,
        utils::padding::{PaddingDirection, PaddingParams, PaddingStrategy},
    };

    let fixture = fixture_dir();

    let mut hf = HfTokenizer::from_file(fixture.join("tokenizer.json")).unwrap();
    let left_pad_8 = PaddingParams {
        strategy: PaddingStrategy::Fixed(8),
        direction: PaddingDirection::Left,
        pad_to_multiple_of: None,
        pad_id: 0,
        pad_type_id: 0,
        pad_token: "<unk>".into(),
    };
    hf.with_padding(Some(left_pad_8));

    let cfg_bytes = std::fs::read(fixture.join("tokenizer_config.json")).unwrap();
    let cfg: serde_json::Value = serde_json::from_slice(&cfg_bytes).unwrap();
    let tokenizer =
        Tokenizer::from_loaded(hf, cfg, mlxrs::tokenizer::DetokenizerClass::Naive, None).unwrap();

    // Mask requested, no EOS: only the real ids survive.
    let masked_opts = EncodeOptions::new()
        .with_add_special(false)
        .with_return_attention_mask(true);
    let plain = tokenizer.encode_with("hello world", &masked_opts).unwrap();
    assert_eq!(plain.ids(), &[3u32, 4]);
    let plain_mask = plain.attention_mask();
    assert!(!plain_mask.is_empty(), "mask requested");
    assert!(plain_mask.iter().all(|&bit| bit == 1));

    // EOS requested: it follows the real ids directly.
    let eos = tokenizer.eos_token_id().expect("fixture has eos");
    let eos_opts = EncodeOptions::new()
        .with_add_special(false)
        .with_add_eos(true);
    let terminated = tokenizer.encode_with("hello world", &eos_opts).unwrap();
    assert_eq!(terminated.ids(), &[3u32, 4, eos]);
}
/// The legacy `encode` call must pass the HF padding through unchanged: with
/// right padding to a fixed length of 8, "hello world" yields eight ids — the
/// two real ones followed by pad id 0.
#[cfg(feature = "tokenizer-stream")]
#[test]
fn legacy_encode_preserves_hf_padding_layout() {
    use tokenizers::{
        Tokenizer as HfTokenizer,
        utils::padding::{PaddingDirection, PaddingParams, PaddingStrategy},
    };

    let fixture = fixture_dir();

    let mut hf = HfTokenizer::from_file(fixture.join("tokenizer.json")).unwrap();
    let right_pad_8 = PaddingParams {
        strategy: PaddingStrategy::Fixed(8),
        direction: PaddingDirection::Right,
        pad_to_multiple_of: None,
        pad_id: 0,
        pad_type_id: 0,
        pad_token: "<unk>".into(),
    };
    hf.with_padding(Some(right_pad_8));

    let cfg_bytes = std::fs::read(fixture.join("tokenizer_config.json")).unwrap();
    let cfg: serde_json::Value = serde_json::from_slice(&cfg_bytes).unwrap();
    let tokenizer =
        Tokenizer::from_loaded(hf, cfg, mlxrs::tokenizer::DetokenizerClass::Naive, None).unwrap();

    let ids = tokenizer.encode("hello world", false).unwrap();
    assert_eq!(ids.len(), 8);

    // The length check above guarantees the split point is in range.
    let (real, pads) = ids.split_at(2);
    assert_eq!(real, &[3u32, 4][..]);
    assert!(
        pads.iter().all(|&id| id == 0),
        "trailing pads must be id=0; got {ids:?}"
    );
}
/// `add_eos` on a tokenizer that reports no EOS ids must fail with an
/// "eos"-mentioning error even for a large input.
///
/// The tokenizer is loaded from a scratch directory containing only
/// `tokenizer.json`. The ~6 KB input is presumably there so that an accidental
/// full encode before the EOS check would be noticeable; the test itself only
/// asserts that the call errors.
#[test]
fn encode_with_add_eos_errors_without_calling_hf_encode() {
    let dir = std::env::temp_dir().join(format!(
        "mlxrs-tok-encode-opts-noeos-fastfail-{}",
        std::process::id()
    ));
    std::fs::create_dir_all(&dir).unwrap();
    std::fs::write(dir.join("tokenizer.json"), TOKENIZER_JSON).unwrap();

    let tok = Tokenizer::from_path(&dir, None).unwrap();
    assert!(tok.eos_token_ids_iter().next().is_none());

    let big = "hello ".repeat(1024);
    let err = tok
        .encode_with(&big, &EncodeOptions::new().with_add_eos(true))
        .unwrap_err();
    let msg = format!("{err}");
    assert!(msg.contains("eos"), "{msg}");

    // Best-effort cleanup so repeated runs do not pile up scratch directories
    // in the temp dir. A failed assertion above skips this on purpose, leaving
    // the directory behind for inspection.
    let _ = std::fs::remove_dir_all(&dir);
}
/// Encoding 10 000 space-separated "hello"s with a cap of 8 must return
/// exactly eight ids, all equal to 3 (the id the other tests in this file
/// expect for "hello").
#[test]
fn encode_with_truncate_far_below_input_is_bounded_alloc() {
    let tokenizer = Tokenizer::from_path(fixture_dir(), None).unwrap();

    // "hello hello ... hello": single spaces between words, none trailing.
    let long_input = vec!["hello"; 10_000].join(" ");

    let capped_opts = EncodeOptions::new()
        .with_add_special(false)
        .with_truncate_to(Some(8));
    let encoded = tokenizer.encode_with(&long_input, &capped_opts).unwrap();

    assert_eq!(encoded.ids().len(), 8);
    assert!(encoded.ids().iter().all(|&id| id == 3));
}
/// With both `add_eos` and a cap of 2, the result must be two ids long: the
/// first real id followed by EOS — i.e. the cap counts the EOS slot.
#[test]
fn encode_with_add_eos_then_truncate_caps_including_eos() {
    let tokenizer = Tokenizer::from_path(fixture_dir(), None).unwrap();
    let input = "hello world the quick brown fox";

    let plain_opts = EncodeOptions::new().with_add_special(false);
    let plain = tokenizer.encode_with(input, &plain_opts).unwrap();
    assert!(plain.ids().len() >= 2);

    let eos = tokenizer.eos_token_id().expect("fixture has eos");
    let capped_eos_opts = EncodeOptions::new()
        .with_add_special(false)
        .with_add_eos(true)
        .with_truncate_to(Some(2));
    let capped = tokenizer.encode_with(input, &capped_eos_opts).unwrap();

    assert_eq!(capped.ids().len(), 2);
    assert_eq!(capped.ids(), &[plain.ids()[0], eos]);
}
/// `EncodeOptions` and `Encoded` must implement `Debug` and `Clone`, and
/// clones must be independent of / equal to their source.
///
/// The trait bounds are enforced at compile time by `assert_debug_clone`; the
/// rest of the test exercises the impls at runtime.
#[test]
fn encoded_and_options_are_debug_clone() {
    fn assert_debug_clone<T: std::fmt::Debug + Clone>() {}
    assert_debug_clone::<EncodeOptions>();
    assert_debug_clone::<Encoded>();

    let opts = EncodeOptions::new()
        .with_add_special(false)
        .with_add_eos(true)
        .with_truncate_to(Some(128))
        .with_return_attention_mask(true);
    // Changing the clone must not leak back into the original.
    let opts_cloned = opts.clone().with_add_eos(false);
    assert!(opts.add_eos(), "original must keep add_eos = true");
    assert!(!opts_cloned.add_eos(), "clone must carry the override");
    let s = format!("{opts:?}");
    assert!(s.contains("add_eos"));
    assert!(s.contains("truncate_to"));

    let encoded = Encoded::new(vec![1, 2, 3], vec![1, 1, 1]);
    // Actually call `Encoded::clone` and check the copy carries the same data;
    // previously only the ids slice was copied, so the Clone impl never ran.
    let encoded_cloned = encoded.clone();
    assert_eq!(encoded_cloned.ids(), encoded.ids());
    assert_eq!(encoded_cloned.attention_mask(), encoded.attention_mask());
    // Growing a vector copied out of the clone leaves the original untouched.
    let mut encoded_cloned_ids = encoded_cloned.ids().to_vec();
    encoded_cloned_ids.push(4);
    assert_eq!(encoded.ids().len(), 3);
    assert_eq!(encoded_cloned_ids.len(), 4);
    let es = format!("{encoded:?}");
    assert!(es.contains("ids"));
}
/// `encode_batch_with` must return, for each input, exactly what
/// `encode_with` returns for that input under the same options — ids and mask
/// alike — with EOS at the tail and the truncation cap of 3 respected.
#[test]
fn encode_batch_with_matches_encode_with_per_item() {
    let tokenizer = Tokenizer::from_path(fixture_dir(), None).unwrap();
    let texts = vec![
        "hello world".to_owned(),
        "hello world the quick brown fox".to_owned(),
        "hello".to_owned(),
    ];
    let opts = EncodeOptions::new()
        .with_add_special(false)
        .with_add_eos(true)
        .with_truncate_to(Some(3))
        .with_return_attention_mask(true);

    let batched = tokenizer.encode_batch_with(texts.clone(), &opts).unwrap();
    assert_eq!(batched.len(), texts.len());

    // Lengths were just checked equal, so zipping visits every item.
    for (i, (text, from_batch)) in texts.iter().zip(batched.iter()).enumerate() {
        let single = tokenizer.encode_with(text, &opts).unwrap();
        assert_eq!(
            from_batch.ids(),
            single.ids(),
            "item {i} ids must match encode_with"
        );
        assert_eq!(
            from_batch.attention_mask(),
            single.attention_mask(),
            "item {i} mask must match encode_with"
        );
        let eos = tokenizer.eos_token_id().expect("fixture has eos");
        assert_eq!(from_batch.ids().last().copied(), Some(eos));
        assert!(from_batch.ids().len() <= 3, "truncate cap honored");
    }
}
/// `encode_batch_with` with `add_eos` on a tokenizer that reports no EOS ids
/// must fail with an "eos"-mentioning error, even for a batch of 32 large
/// inputs.
///
/// The tokenizer is loaded from a scratch directory containing only
/// `tokenizer.json`. The large batch is presumably there so that an accidental
/// full batch encode before the EOS check would be noticeable; the test itself
/// only asserts that the call errors.
#[test]
fn encode_batch_with_add_eos_errors_without_calling_hf_encode_batch() {
    let dir = std::env::temp_dir().join(format!(
        "mlxrs-tok-encode-opts-batch-noeos-{}",
        std::process::id()
    ));
    std::fs::create_dir_all(&dir).unwrap();
    std::fs::write(dir.join("tokenizer.json"), TOKENIZER_JSON).unwrap();

    let tok = Tokenizer::from_path(&dir, None).unwrap();
    assert!(tok.eos_token_ids_iter().next().is_none());

    let big = "hello ".repeat(1024);
    let texts: Vec<String> = vec![big; 32];
    let err = tok
        .encode_batch_with(texts, &EncodeOptions::new().with_add_eos(true))
        .unwrap_err();
    let msg = format!("{err}");
    assert!(
        msg.contains("eos"),
        "expected eos-related error, got: {msg}"
    );

    // Best-effort cleanup so repeated runs do not pile up scratch directories
    // in the temp dir. A failed assertion above skips this on purpose, leaving
    // the directory behind for inspection.
    let _ = std::fs::remove_dir_all(&dir);
}
/// An empty batch must encode successfully to an empty result.
#[test]
fn encode_batch_with_empty_input_is_empty_output() {
    let tokenizer = Tokenizer::from_path(fixture_dir(), None).unwrap();

    let opts = EncodeOptions::new();
    let encoded = tokenizer.encode_batch_with(Vec::new(), &opts).unwrap();

    assert!(encoded.is_empty());
}