#![cfg(all(feature = "lm", feature = "llguidance"))]
use std::{fs, io::Write, path::PathBuf, process};
use mlxrs::{
Array,
lm::{generate::LogitsProcessor, structured},
};
use serde_json::json;
fn build_byte_level_tokenizer_json() -> String {
let mut vocab_entries: Vec<String> = Vec::new();
vocab_entries.push("\"<unk>\": 0".to_string());
vocab_entries.push("\"<s>\": 1".to_string());
vocab_entries.push("\"</s>\": 2".to_string());
let mut next_id: u32 = 3;
let mut k: u32 = 0x100;
let mut char_map: Vec<char> = Vec::with_capacity(256);
for byte in 0..=255u8 {
let c = byte as char;
let mapped = match c {
'!'..='~' => c,
'\u{00A1}'..='\u{00AC}' => c,
'\u{00AE}'..='\u{00FF}' => c,
_ => {
let m = char::from_u32(k).unwrap();
k += 1;
m
}
};
char_map.push(mapped);
}
for byte in 0x20u8..=0x7Eu8 {
let glyph = char_map[byte as usize];
let escaped = match glyph {
'"' => "\\\"".to_string(),
'\\' => "\\\\".to_string(),
c => c.to_string(),
};
vocab_entries.push(format!("\"{}\": {}", escaped, next_id));
next_id += 1;
}
format!(
r#"{{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [
{{"id": 0, "content": "<unk>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true}},
{{"id": 1, "content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true}},
{{"id": 2, "content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true}}
],
"normalizer": null,
"pre_tokenizer": {{
"type": "ByteLevel",
"add_prefix_space": false,
"trim_offsets": true
}},
"post_processor": null,
"decoder": {{
"type": "ByteLevel",
"add_prefix_space": false,
"trim_offsets": true
}},
"model": {{
"type": "BPE",
"dropout": null,
"unk_token": "<unk>",
"continuing_subword_prefix": null,
"end_of_word_suffix": null,
"fuse_unk": false,
"vocab": {{
{}
}},
"merges": []
}}
}}"#,
vocab_entries.join(",\n ")
)
}
/// `tokenizer_config.json` written next to every fixture `tokenizer.json`.
///
/// It names the special tokens that the fixture vocabulary registers at ids
/// 0 (`<unk>`), 1 (`<s>`) and 2 (`</s>`). The tests expect `</s>` to act as
/// EOS when no override is passed to `Tokenizer::from_path`.
const TOKENIZER_CONFIG_JSON: &str = r#"{
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
"model_max_length": 2048
}"#;
/// Returns an empty scratch directory unique to this process and `name`.
///
/// Any leftover directory at the same path is removed first. That removal is
/// best-effort and its errors are ignored, since usually nothing is there yet.
///
/// # Panics
/// If the directory cannot be created.
fn temp_dir(name: &str) -> PathBuf {
    let leaf = format!("mlxrs_lm_structured_{}_{}", process::id(), name);
    let path = std::env::temp_dir().join(leaf);
    fs::remove_dir_all(&path).ok();
    fs::create_dir_all(&path).unwrap();
    path
}
/// Writes the default byte-level fixture into a fresh temp directory named
/// after `name` and loads it with no EOS override.
///
/// # Panics
/// On any I/O failure, or if `Tokenizer::from_path` rejects the fixture.
fn fixture_tokenizer(name: &str) -> mlxrs::tokenizer::Tokenizer {
    let dir = temp_dir(name);
    let files = [
        ("tokenizer.json", build_byte_level_tokenizer_json()),
        ("tokenizer_config.json", TOKENIZER_CONFIG_JSON.to_string()),
    ];
    for (file_name, contents) in &files {
        // Each handle is dropped (closed) as soon as its contents are written.
        fs::File::create(dir.join(file_name))
            .and_then(|mut file| file.write_all(contents.as_bytes()))
            .unwrap();
    }
    mlxrs::tokenizer::Tokenizer::from_path(&dir, None)
        .unwrap_or_else(|e| panic!("fixture tokenizer load failed: {e}"))
}
/// Returns the fixture token id for a printable-ASCII `byte`.
///
/// The fixture vocabulary stores the 95 printable bytes `0x20..=0x7E`
/// contiguously after the three special tokens, so the id is
/// `byte - 0x20 + 3`.
///
/// # Panics
/// If `byte` is outside `0x20..=0x7E`, because it has no token in the fixture.
fn id_for_byte(byte: u8) -> u32 {
    const FIRST_PRINTABLE: u8 = 0x20;
    const FIRST_BYTE_TOKEN_ID: u32 = 3;
    assert!(
        matches!(byte, 0x20..=0x7E),
        "byte {byte:#x} not in fixture vocab"
    );
    u32::from(byte - FIRST_PRINTABLE) + FIRST_BYTE_TOKEN_ID
}
/// A simple object schema with one string property must compile into a
/// processor against the byte-level fixture tokenizer.
#[test]
fn build_json_schema_logits_processor_constructs() {
    let schema = json!({
        "type": "object",
        "properties": { "name": { "type": "string" } }
    });
    let tok = fixture_tokenizer("build_json_schema_constructs");
    let built = structured::build_json_schema_logits_processor(schema, &tok, None);
    let _processor =
        built.expect("processor construction should succeed for a simple schema");
}
/// For `{"type": "object"}`, the first sampled token must open the object.
/// `{` stays finite, while `a` and `}` are masked to negative infinity.
#[test]
fn json_schema_processor_masks_invalid_first_tokens() {
    let tok = fixture_tokenizer("masks_invalid_first");
    let processor =
        structured::build_json_schema_logits_processor(json!({ "type": "object" }), &tok, None)
            .expect("processor construction should succeed");

    let vocab = tok.hf().get_vocab_size(true);
    let logits = Array::from_slice::<f32>(&vec![0.0f32; vocab], &(1, vocab)).unwrap();
    let mut masked = processor.apply(&[], &logits).expect("apply should succeed");
    let scores = masked.to_vec::<f32>().unwrap();
    assert_eq!(scores.len(), vocab);

    let open_brace = id_for_byte(b'{') as usize;
    assert!(
        scores[open_brace].is_finite(),
        "`{{` token (id {open_brace}) must remain finite, got {}",
        scores[open_brace]
    );

    let a = id_for_byte(b'a') as usize;
    assert!(
        scores[a] == f32::NEG_INFINITY,
        "`a` token (id {a}) must be masked to -inf, got {}",
        scores[a]
    );

    let close_brace = id_for_byte(b'}') as usize;
    assert!(
        scores[close_brace] == f32::NEG_INFINITY,
        "`}}` token (id {close_brace}) must be masked to -inf, got {}",
        scores[close_brace]
    );
}
/// A plain regex grammar (`[0-9]+`) must be accepted by the llguidance
/// processor constructor.
#[test]
fn llguidance_regex_grammar_constructs() {
    let grammar = structured::GrammarSpec::Regex(String::from("[0-9]+"));
    let tok = fixture_tokenizer("regex_constructs");
    let _proc = structured::LLGuidanceLogitsProcessor::new(grammar, &tok, None)
        .expect("regex grammar processor construction should succeed");
}
/// A two-rule Lark grammar (a `DIGITS` terminal defined by a regex) must be
/// accepted by the llguidance processor constructor.
#[test]
fn llguidance_lark_grammar_constructs() {
    let lark = String::from("start: DIGITS\nDIGITS: /[0-9]+/\n");
    let tok = fixture_tokenizer("lark_constructs");
    let _proc =
        structured::LLGuidanceLogitsProcessor::new(structured::GrammarSpec::Lark(lark), &tok, None)
            .expect("lark grammar processor construction should succeed");
}
/// `into_logits_processor` must wrap the structured processor in
/// `LogitsProcessor::Custom`. Applying the wrapped form must still enforce
/// the schema's first-token mask, not merely return `Ok`.
#[test]
fn llguidance_processor_implements_logits_processor_trait() {
    let tok = fixture_tokenizer("plug_into_chain");
    let proc = structured::build_json_schema_logits_processor(json!({"type": "object"}), &tok, None)
        .expect("processor construction should succeed");
    let boxed: LogitsProcessor = proc.into_logits_processor();
    assert!(matches!(boxed, LogitsProcessor::Custom(_)));

    let vocab = tok.hf().get_vocab_size(true);
    let zeros = vec![0.0f32; vocab];
    let logits = Array::from_slice::<f32>(&zeros, &(1, vocab)).unwrap();
    let mut out = boxed
        .apply(&[], &logits)
        .expect("Custom processor apply should succeed");

    // Inspect the output. A wrapper that forwarded the logits unmasked would
    // otherwise pass this test.
    let out_v = out.to_vec::<f32>().unwrap();
    assert_eq!(out_v.len(), vocab);
    let open_brace = id_for_byte(b'{') as usize;
    assert!(
        out_v[open_brace].is_finite(),
        "Custom wrapper: `{{` token (id {open_brace}) must remain finite, got {}",
        out_v[open_brace]
    );
    let a = id_for_byte(b'a') as usize;
    assert!(
        out_v[a].is_infinite() && out_v[a] < 0.0,
        "Custom wrapper: `a` token (id {a}) must be masked to -inf, got {}",
        out_v[a]
    );
}
/// Id of `</s>` in the default fixture vocabulary. It is the expected sole
/// EOS when no override is passed to `Tokenizer::from_path`.
const FIXTURE_EOS_ID: usize = 2;
/// A regex grammar matching exactly `a` is complete after one token.
/// The first mask must allow `a`. After `a` is consumed, the mask must allow
/// only the fixture EOS (`FIXTURE_EOS_ID`).
#[test]
fn llguidance_terminal_regex_grammar_returns_eos_only_mask_after_consume() {
    let tok = fixture_tokenizer("terminal_regex_eos_only");
    let proc = structured::LLGuidanceLogitsProcessor::new(
        structured::GrammarSpec::Regex(String::from("a")),
        &tok,
        None,
    )
    .expect("terminal regex processor construction should succeed");
    let vocab = tok.hf().get_vocab_size(true);
    let logits = Array::from_slice::<f32>(&vec![0.0f32; vocab], &(1, vocab)).unwrap();

    // Step 1: nothing consumed yet, so `a` must be permitted.
    let mut initial = proc
        .apply(&[], &logits)
        .expect("first apply should succeed (initial mask, not yet stopped)");
    let first_v = initial.to_vec::<f32>().unwrap();
    let a_id = id_for_byte(b'a') as usize;
    assert!(
        first_v[a_id].is_finite(),
        "first-step `a` token (id {a_id}) must remain finite, got {}",
        first_v[a_id]
    );

    // Step 2: after `a` the grammar is satisfied, so only EOS may follow.
    let mut terminal = proc
        .apply(&[a_id as u32], &logits)
        .expect("post-consume apply should succeed and return EOS-only mask");
    let out_v = terminal.to_vec::<f32>().unwrap();
    assert_eq!(out_v.len(), vocab);
    assert!(
        out_v[FIXTURE_EOS_ID].is_finite(),
        "EOS-only mask: eos token (id {FIXTURE_EOS_ID}) must remain finite, got {}",
        out_v[FIXTURE_EOS_ID]
    );
    for (i, &v) in out_v
        .iter()
        .enumerate()
        .filter(|&(i, _)| i != FIXTURE_EOS_ID)
    {
        assert!(
            v == f32::NEG_INFINITY,
            "EOS-only mask: non-eos token id {i} must be -inf, got {v}"
        );
    }
}
/// A Lark grammar whose start rule is the single literal `"x"` must, once `x`
/// has been consumed, mask everything except the fixture EOS.
#[test]
fn llguidance_terminal_lark_grammar_returns_eos_only_after_close() {
    let tok = fixture_tokenizer("terminal_lark_eos_only");
    let grammar = structured::GrammarSpec::Lark(String::from("start: \"x\"\n"));
    let proc = structured::LLGuidanceLogitsProcessor::new(grammar, &tok, None)
        .expect("terminal lark processor construction should succeed");
    let vocab = tok.hf().get_vocab_size(true);
    let logits = Array::from_slice::<f32>(&vec![0.0f32; vocab], &(1, vocab)).unwrap();

    // Prime the matcher with the initial (empty-history) step.
    let _ = proc
        .apply(&[], &logits)
        .expect("first apply should succeed (initial mask, not yet stopped)");

    let mut closed = proc
        .apply(&[id_for_byte(b'x')], &logits)
        .expect("post-close apply should succeed and return EOS-only mask");
    let out_v = closed.to_vec::<f32>().unwrap();
    assert_eq!(out_v.len(), vocab);
    assert!(
        out_v[FIXTURE_EOS_ID].is_finite(),
        "EOS-only lark mask: eos (id {FIXTURE_EOS_ID}) must remain finite, got {}",
        out_v[FIXTURE_EOS_ID]
    );
    for (i, &v) in out_v
        .iter()
        .enumerate()
        .filter(|&(i, _)| i != FIXTURE_EOS_ID)
    {
        assert!(
            v == f32::NEG_INFINITY,
            "EOS-only lark mask: non-eos token id {i} must be -inf, got {v}"
        );
    }
}
/// Renders `s` as a JSON string literal, including the surrounding quotes.
///
/// Escapes `"`, `\` and every control character below U+0020 as RFC 8259
/// requires, so arbitrary token text can be spliced into hand-built JSON.
fn json_string_literal(s: &str) -> String {
    let mut out = String::with_capacity(s.len() + 2);
    out.push('"');
    for c in s.chars() {
        match c {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            c if u32::from(c) < 0x20 => out.push_str(&format!("\\u{:04x}", u32::from(c))),
            c => out.push(c),
        }
    }
    out.push('"');
    out
}

/// Builds a byte-level BPE `tokenizer.json` for the test fixture, optionally
/// registering extra special tokens.
///
/// Vocabulary layout:
/// - ids 0, 1, 2: `<unk>`, `<s>`, `</s>` (also listed in `added_tokens`);
/// - ids 3..=97: one token per printable ASCII byte `0x20..=0x7E`, in byte
///   order, spelled with its GPT-2 byte-to-unicode glyph (space is `Ġ`);
/// - each `(id, content)` in `extra_added`: an extra vocab entry that is also
///   registered in `added_tokens` with `"special": true`.
///
/// All token text is JSON-escaped, so `content` may contain quotes,
/// backslashes or control characters. `merges` is empty.
///
/// # Panics
/// If an extra id collides with the base vocabulary (`id < 98`), or if the
/// same extra id appears more than once in `extra_added`.
fn build_byte_level_tokenizer_json_with_extras(extra_added: &[(u32, &str)]) -> String {
    const SPECIALS: [(u32, &str); 3] = [(0, "<unk>"), (1, "<s>"), (2, "</s>")];

    /// One `added_tokens` entry for a special token with the given id/content.
    fn added_token_json(id: u32, content: &str) -> String {
        format!(
            "{{\"id\": {id}, \"content\": {}, \"single_word\": false, \"lstrip\": false, \"rstrip\": false, \"normalized\": false, \"special\": true}}",
            json_string_literal(content)
        )
    }

    // GPT-2 `bytes_to_unicode`: printable Latin-1 bytes keep their own code
    // point, and every other byte is remapped in order onto U+0100, U+0101, ...
    // All 256 bytes are walked so the remapped code points line up; for
    // example, the space byte 0x20 lands on U+0120 `Ġ`.
    let mut remapped: u32 = 0x100;
    let mut char_map: Vec<char> = Vec::with_capacity(256);
    for byte in 0..=255u8 {
        let c = byte as char;
        let glyph = match c {
            '!'..='~' | '\u{00A1}'..='\u{00AC}' | '\u{00AE}'..='\u{00FF}' => c,
            _ => {
                let m = char::from_u32(remapped)
                    .expect("U+0100..=U+0143 are valid Unicode scalar values");
                remapped += 1;
                m
            }
        };
        char_map.push(glyph);
    }

    let mut vocab_entries: Vec<String> = Vec::new();
    let mut added_entries: Vec<String> = Vec::new();
    for &(id, content) in &SPECIALS {
        vocab_entries.push(format!("{}: {id}", json_string_literal(content)));
        added_entries.push(added_token_json(id, content));
    }

    // Printable ASCII 0x20..=0x7E gets one token each, with ids in byte order.
    let mut next_id = SPECIALS.len() as u32;
    for byte in 0x20u8..=0x7Eu8 {
        let glyph = char_map[usize::from(byte)];
        vocab_entries.push(format!("{}: {next_id}", json_string_literal(&glyph.to_string())));
        next_id += 1;
    }

    for (idx, &(id, content)) in extra_added.iter().enumerate() {
        assert!(
            id >= next_id,
            "extra-added id {id} must be > base vocab top id {}",
            next_id - 1
        );
        assert!(
            extra_added[..idx].iter().all(|&(earlier, _)| earlier != id),
            "extra-added id {id} appears more than once"
        );
        vocab_entries.push(format!("{}: {id}", json_string_literal(content)));
        added_entries.push(added_token_json(id, content));
    }

    format!(
        r#"{{
  "version": "1.0",
  "truncation": null,
  "padding": null,
  "added_tokens": [
    {}
  ],
  "normalizer": null,
  "pre_tokenizer": {{
    "type": "ByteLevel",
    "add_prefix_space": false,
    "trim_offsets": true
  }},
  "post_processor": null,
  "decoder": {{
    "type": "ByteLevel",
    "add_prefix_space": false,
    "trim_offsets": true
  }},
  "model": {{
    "type": "BPE",
    "dropout": null,
    "unk_token": "<unk>",
    "continuing_subword_prefix": null,
    "end_of_word_suffix": null,
    "fuse_unk": false,
    "vocab": {{
      {}
    }},
    "merges": []
  }}
}}"#,
        added_entries.join(",\n    "),
        vocab_entries.join(",\n      ")
    )
}
/// Writes a byte-level fixture with the given extra added tokens into a fresh
/// temp directory named after `name`. It then loads the fixture with
/// `eos_override` passed through to `Tokenizer::from_path`.
///
/// # Panics
/// On any I/O failure, or if `Tokenizer::from_path` rejects the fixture.
fn fixture_tokenizer_with_eos_override(
    name: &str,
    extra_added: &[(u32, &str)],
    eos_override: &[u32],
) -> mlxrs::tokenizer::Tokenizer {
    let dir = temp_dir(name);
    let files = [
        (
            "tokenizer.json",
            build_byte_level_tokenizer_json_with_extras(extra_added),
        ),
        ("tokenizer_config.json", TOKENIZER_CONFIG_JSON.to_string()),
    ];
    for (file_name, contents) in &files {
        // Each handle is dropped (closed) as soon as its contents are written.
        fs::File::create(dir.join(file_name))
            .and_then(|mut file| file.write_all(contents.as_bytes()))
            .unwrap();
    }
    mlxrs::tokenizer::Tokenizer::from_path(&dir, Some(eos_override))
        .unwrap_or_else(|e| panic!("fixture tokenizer load failed: {e}"))
}
/// Id of the extra `<|im_end|>` token that the EOS-override tests register.
/// It is the first id after the base fixture vocabulary, where the 3 specials
/// and the 95 printable-ASCII byte tokens occupy ids 0..=97.
const CUSTOM_EOS_ID: u32 = 98;
/// With an EOS override of `[CUSTOM_EOS_ID]`, the post-completion mask must
/// unmask only that id. Neither id 0 nor the fixture's `</s>` may stay open.
#[test]
fn llguidance_terminal_grammar_uses_mlxrs_configured_custom_eos_id() {
    let tok = fixture_tokenizer_with_eos_override(
        "terminal_custom_eos",
        &[(CUSTOM_EOS_ID, "<|im_end|>")],
        &[CUSTOM_EOS_ID],
    );
    assert_eq!(
        tok.eos_token_ids_iter().collect::<Vec<u32>>(),
        vec![CUSTOM_EOS_ID],
        "mlxrs Tokenizer::eos_token_ids() must reflect the from_path override"
    );

    let proc = structured::LLGuidanceLogitsProcessor::new(
        structured::GrammarSpec::Regex("a".to_string()),
        &tok,
        None,
    )
    .expect("processor construction with custom eos override should succeed");
    let vocab = tok.hf().get_vocab_size(true);
    let logits = Array::from_slice::<f32>(&vec![0.0f32; vocab], &(1, vocab)).unwrap();
    let _ = proc
        .apply(&[], &logits)
        .expect("first apply should succeed");
    let mut out = proc
        .apply(&[id_for_byte(b'a')], &logits)
        .expect("post-consume apply should succeed");
    let out_v = out.to_vec::<f32>().unwrap();
    assert_eq!(out_v.len(), vocab);

    let eos = CUSTOM_EOS_ID as usize;
    assert!(
        out_v[eos].is_finite(),
        "EOS-only mask: custom eos id {CUSTOM_EOS_ID} must remain finite, got {}",
        out_v[eos]
    );
    assert!(
        out_v[0] == f32::NEG_INFINITY,
        "EOS-only mask: id 0 (upstream's default `tok_eos`) must be -inf, got {}",
        out_v[0]
    );
    assert!(
        out_v[FIXTURE_EOS_ID] == f32::NEG_INFINITY,
        "EOS-only mask: hardcoded `</s>` id {FIXTURE_EOS_ID} must be -inf (override replaces), got {}",
        out_v[FIXTURE_EOS_ID]
    );
    for (i, &v) in out_v.iter().enumerate().filter(|&(i, _)| i != eos) {
        assert!(
            v == f32::NEG_INFINITY,
            "EOS-only mask: non-eos id {i} must be -inf, got {v}"
        );
    }
}
/// With three EOS ids configured (`<s>`, `</s>`, `<|im_end|>`), the
/// post-completion mask must leave all three finite and mask everything else.
#[test]
fn llguidance_terminal_grammar_multi_eos_unmasks_all_configured_ids() {
    const EOS_IDS: [u32; 3] = [1, 2, CUSTOM_EOS_ID];

    let tok = fixture_tokenizer_with_eos_override(
        "terminal_multi_eos",
        &[(CUSTOM_EOS_ID, "<|im_end|>")],
        &EOS_IDS,
    );
    assert_eq!(
        tok.eos_token_ids_iter().collect::<Vec<u32>>(),
        EOS_IDS.to_vec(),
        "mlxrs Tokenizer::eos_token_ids() must hold all three caller-supplied ids"
    );

    let proc = structured::LLGuidanceLogitsProcessor::new(
        structured::GrammarSpec::Regex("a".to_string()),
        &tok,
        None,
    )
    .expect("processor construction with multi-eos override should succeed");
    let vocab = tok.hf().get_vocab_size(true);
    let logits = Array::from_slice::<f32>(&vec![0.0f32; vocab], &(1, vocab)).unwrap();
    let _ = proc
        .apply(&[], &logits)
        .expect("first apply should succeed");
    let mut out = proc
        .apply(&[id_for_byte(b'a')], &logits)
        .expect("post-consume apply should succeed");
    let out_v = out.to_vec::<f32>().unwrap();
    assert_eq!(out_v.len(), vocab);

    for &eos in &EOS_IDS {
        assert!(
            out_v[eos as usize].is_finite(),
            "EOS-only mask: configured eos id {eos} must remain finite, got {}",
            out_v[eos as usize]
        );
    }
    for (i, &v) in out_v.iter().enumerate() {
        if EOS_IDS.contains(&(i as u32)) {
            continue;
        }
        assert!(
            v == f32::NEG_INFINITY,
            "EOS-only mask: non-eos id {i} must be -inf, got {v}"
        );
    }
}
/// If `Some(model_vocab)` is larger than the tokenizer vocab, the mask must
/// have the model's width. Grammar-valid tokens stay finite, and every padding
/// slot in `[tok_vocab, model_vocab)` is masked to -inf.
#[test]
fn llguidance_processor_accepts_padded_model_vocab() {
    const PADDING: usize = 8;

    let tok = fixture_tokenizer("padded_model_vocab");
    let tok_vocab = tok.hf().get_vocab_size(true);
    let model_vocab = tok_vocab + PADDING;
    let proc = structured::build_json_schema_logits_processor(
        json!({ "type": "object" }),
        &tok,
        Some(model_vocab),
    )
    .expect("processor construction should succeed with padded model vocab");

    let logits =
        Array::from_slice::<f32>(&vec![0.0f32; model_vocab], &(1, model_vocab)).unwrap();
    let mut out = proc
        .apply(&[], &logits)
        .expect("apply should succeed when model vocab is padded via Some(n)");
    let out_v = out.to_vec::<f32>().unwrap();
    assert_eq!(out_v.len(), model_vocab);

    let open_brace = id_for_byte(b'{') as usize;
    assert!(
        out_v[open_brace].is_finite(),
        "padded-vocab mask: `{{` (id {open_brace}) must remain finite, got {}",
        out_v[open_brace]
    );

    // The length was asserted to be `model_vocab` above, so this tail is
    // exactly the padding region.
    for (offset, &v) in out_v[tok_vocab..].iter().enumerate() {
        let i = tok_vocab + offset;
        assert!(
            v == f32::NEG_INFINITY,
            "padded placeholder id {i} must be -inf, got {v}"
        );
    }
}
/// An EOS override outside the tokenizer vocabulary, with no padded model
/// vocab given, must be rejected with `Error::OutOfRange` instead of a panic.
/// The payload must name the parameter, the offending id, the vocab bound and
/// the vocab rule.
#[test]
fn llguidance_terminal_grammar_rejects_out_of_range_eos_id_without_panic() {
    let tok = fixture_tokenizer_with_eos_override(
        "out_of_range_eos",
        &[(CUSTOM_EOS_ID, "<|im_end|>")],
        &[4242],
    );
    let bound = tok.hf().get_vocab_size(true);

    let payload = match structured::LLGuidanceLogitsProcessor::new(
        structured::GrammarSpec::Regex("a".to_string()),
        &tok,
        None,
    ) {
        Err(mlxrs::Error::OutOfRange(p)) => p,
        Err(other) => panic!("expected Error::OutOfRange, got different Err: {other:?}"),
        Ok(_) => panic!("out-of-range eos id must yield Err, not Ok"),
    };

    assert!(
        ["EOS token id", "llguidance"]
            .iter()
            .any(|&needle| payload.context().contains(needle)),
        "context names the offending parameter: {}",
        payload.context()
    );
    assert!(
        payload.value().contains("4242"),
        "value carries the offending id 4242, got: {}",
        payload.value()
    );
    let bound_text = bound.to_string();
    assert!(
        payload.value().contains(bound_text.as_str()),
        "value carries the vocab bound {bound}, got: {}",
        payload.value()
    );
    assert!(
        payload.requirement().contains("vocab"),
        "requirement names the vocab-bound rule: {}",
        payload.requirement()
    );
}
/// With `Some(model_vocab)` padding the vocab to 128, construction must accept
/// (a) an EOS id that exists only in the padded range, and (b) an ordinary
/// in-vocab EOS id. In both cases the post-completion mask must unmask exactly
/// the configured EOS.
#[test]
fn llguidance_terminal_grammar_accepts_padded_model_eos_id() {
    /// Feeds a zero logits row of width `model_vocab` through `proc` twice:
    /// once with no history, then once after consuming `a`. Returns the
    /// post-consume mask as host floats.
    fn post_consume_mask(
        proc: &structured::LLGuidanceLogitsProcessor,
        model_vocab: usize,
    ) -> Vec<f32> {
        let zeros = vec![0.0f32; model_vocab];
        let logits = Array::from_slice::<f32>(&zeros, &(1, model_vocab)).unwrap();
        let _ = proc
            .apply(&[], &logits)
            .expect("first apply should succeed");
        let mut out = proc
            .apply(&[id_for_byte(b'a')], &logits)
            .expect("post-consume apply should succeed and return EOS-only mask");
        out.to_vec::<f32>().unwrap()
    }

    // Case A: the only EOS id (120) lies in [tokenizer_vocab, model_vocab).
    let tok_padded_only = fixture_tokenizer_with_eos_override(
        "padded_eos_accepted_padded_range",
        &[(CUSTOM_EOS_ID, "<|im_end|>")],
        &[120],
    );
    let tok_vocab_a = tok_padded_only.hf().get_vocab_size(true);
    assert!(
        120 >= tok_vocab_a as u32,
        "test premise: eos id 120 must be in the PADDED range \
         (above the tokenizer vocab {tok_vocab_a})"
    );
    let model_vocab = 128usize;
    assert!(
        120usize < model_vocab,
        "test premise: eos id 120 must be inside model_vocab {model_vocab}"
    );
    let proc_a = structured::LLGuidanceLogitsProcessor::new(
        structured::GrammarSpec::Regex("a".to_string()),
        &tok_padded_only,
        Some(model_vocab),
    )
    .expect(
        "padded-model EOS id (in [tokenizer_vocab, model_vocab)) must be ACCEPTED — \
         constructor must not panic and must not return Err",
    );
    let out_va = post_consume_mask(&proc_a, model_vocab);
    assert_eq!(out_va.len(), model_vocab);
    assert!(
        out_va[120].is_finite(),
        "padded-range EOS-only mask: padded eos id 120 must remain finite, got {}",
        out_va[120]
    );
    assert!(
        out_va[0] == f32::NEG_INFINITY,
        "padded-range EOS-only mask: id 0 (upstream fallback default) must be -inf, got {}",
        out_va[0]
    );
    for (i, &v) in out_va.iter().enumerate().filter(|&(i, _)| i != 120) {
        assert!(
            v == f32::NEG_INFINITY,
            "padded-range EOS-only mask: non-eos id {i} must be -inf, got {v}"
        );
    }

    // Case B: the EOS id is inside the unpadded tokenizer vocab.
    let tok_unpadded = fixture_tokenizer_with_eos_override(
        "padded_eos_accepted_unpadded_range",
        &[(CUSTOM_EOS_ID, "<|im_end|>")],
        &[CUSTOM_EOS_ID],
    );
    let tok_vocab_b = tok_unpadded.hf().get_vocab_size(true);
    assert!(
        (CUSTOM_EOS_ID as usize) < tok_vocab_b,
        "test premise: CUSTOM_EOS_ID ({CUSTOM_EOS_ID}) must be in the unpadded \
         tokenizer vocab {tok_vocab_b}"
    );
    let proc_b = structured::LLGuidanceLogitsProcessor::new(
        structured::GrammarSpec::Regex("a".to_string()),
        &tok_unpadded,
        Some(model_vocab),
    )
    .expect("padded model_vocab with in-range EOS must be accepted");
    let out_v = post_consume_mask(&proc_b, model_vocab);
    assert_eq!(out_v.len(), model_vocab);
    let eos = CUSTOM_EOS_ID as usize;
    assert!(
        out_v[eos].is_finite(),
        "padded-vocab EOS-only mask: id {CUSTOM_EOS_ID} must remain finite, got {}",
        out_v[eos]
    );
    for (i, &v) in out_v.iter().enumerate().filter(|&(i, _)| i != eos) {
        assert!(
            v == f32::NEG_INFINITY,
            "padded-vocab EOS-only mask: non-eos id {i} must be -inf, got {v}"
        );
    }
}
/// An EOS id that exists only in the padded model range (id 120 with a
/// 128-wide model vocab) must actually be unmasked at runtime once the
/// grammar completes. Every other id must be masked, including id 0.
#[test]
fn llguidance_terminal_grammar_padded_eos_id_actually_unmasks_at_runtime() {
    const PADDED_EOS: u32 = 120;
    const MODEL_VOCAB: usize = 128;

    let tok = fixture_tokenizer_with_eos_override(
        "r4_padded_eos_runtime_mask",
        &[(CUSTOM_EOS_ID, "<|im_end|>")],
        &[PADDED_EOS],
    );
    let bt_vocab = tok.hf().get_vocab_size(true);
    assert!(
        PADDED_EOS >= bt_vocab as u32,
        "test premise: eos id 120 must be padded-range \
         (above the unpadded tokenizer vocab {bt_vocab})"
    );
    assert!(
        (PADDED_EOS as usize) < MODEL_VOCAB,
        "test premise: eos id 120 must be inside model_vocab {MODEL_VOCAB}"
    );

    let proc = structured::LLGuidanceLogitsProcessor::new(
        structured::GrammarSpec::Regex("a".to_string()),
        &tok,
        Some(MODEL_VOCAB),
    )
    .expect("padded-only EOS construction must succeed");
    let logits =
        Array::from_slice::<f32>(&vec![0.0f32; MODEL_VOCAB], &(1, MODEL_VOCAB)).unwrap();
    let _ = proc
        .apply(&[], &logits)
        .expect("first apply should succeed");
    let mut out = proc
        .apply(&[id_for_byte(b'a')], &logits)
        .expect("post-consume apply should succeed and return EOS-only mask");
    let out_v = out.to_vec::<f32>().unwrap();
    assert_eq!(out_v.len(), MODEL_VOCAB);

    let eos = PADDED_EOS as usize;
    assert!(
        out_v[eos].is_finite(),
        "padded-range EOS-only mask: padded eos id 120 must remain finite, got {}",
        out_v[eos]
    );
    assert!(
        out_v[0] == f32::NEG_INFINITY,
        "padded-range EOS-only mask: id 0 (upstream's fallback default) \
         must be -inf, got {}",
        out_v[0]
    );
    for (i, &v) in out_v.iter().enumerate().filter(|&(i, _)| i != eos) {
        assert!(
            v == f32::NEG_INFINITY,
            "padded-range EOS-only mask: non-eos id {i} must be -inf, got {v}"
        );
    }
}
/// One EOS id inside the tokenizer vocab (`CUSTOM_EOS_ID`) and one only in the
/// padded model range (120) must both be unmasked after the grammar
/// completes. Every other id in the 128-wide mask must be -inf.
#[test]
fn llguidance_terminal_grammar_mixed_in_range_plus_padded_eos() {
    let tok = fixture_tokenizer_with_eos_override(
        "r4_mixed_eos_in_range_plus_padded",
        &[(CUSTOM_EOS_ID, "<|im_end|>")],
        &[CUSTOM_EOS_ID, 120],
    );
    let bt_vocab = tok.hf().get_vocab_size(true);
    assert!(
        (CUSTOM_EOS_ID as usize) < bt_vocab,
        "test premise: CUSTOM_EOS_ID ({CUSTOM_EOS_ID}) must be in the unpadded \
         tokenizer vocab {bt_vocab}"
    );
    assert!(
        120u32 >= bt_vocab as u32,
        "test premise: eos id 120 must be padded-range \
         (above the unpadded tokenizer vocab {bt_vocab})"
    );
    let model_vocab = 128usize;
    // Same premise the sibling padded-EOS tests check. Without it, a smaller
    // `model_vocab` would surface as an opaque index panic on `out_v[120]`.
    assert!(
        120usize < model_vocab,
        "test premise: eos id 120 must be inside model_vocab {model_vocab}"
    );
    let grammar = structured::GrammarSpec::Regex("a".to_string());
    let proc = structured::LLGuidanceLogitsProcessor::new(grammar, &tok, Some(model_vocab))
        .expect("mixed in-range + padded EOS construction must succeed");
    let zeros = vec![0.0f32; model_vocab];
    let logits = Array::from_slice::<f32>(&zeros, &(1, model_vocab)).unwrap();
    let _ = proc
        .apply(&[], &logits)
        .expect("first apply should succeed");
    let a_token_id = id_for_byte(b'a');
    let mut out = proc
        .apply(&[a_token_id], &logits)
        .expect("post-consume apply should succeed and return EOS-only mask");
    let out_v = out.to_vec::<f32>().unwrap();
    assert_eq!(out_v.len(), model_vocab);
    assert!(
        out_v[CUSTOM_EOS_ID as usize].is_finite(),
        "mixed EOS-only mask: in-range eos id {CUSTOM_EOS_ID} must remain finite, got {}",
        out_v[CUSTOM_EOS_ID as usize]
    );
    assert!(
        out_v[120].is_finite(),
        "mixed EOS-only mask: padded-range eos id 120 must remain finite, got {}",
        out_v[120]
    );
    for (i, v) in out_v.iter().enumerate() {
        if i as u32 == CUSTOM_EOS_ID || i == 120 {
            continue;
        }
        assert!(
            v.is_infinite() && *v < 0.0,
            "mixed EOS-only mask: non-eos id {i} must be -inf, got {v}"
        );
    }
}