use super::*;
use serde_json::json;
/// Builds the `tokenizer.json` text for a tiny byte-level BPE fixture.
///
/// The vocabulary holds the special tokens `<unk>`, `<s>` and `</s>` (ids 0, 1, 2)
/// followed by one single-glyph token for each byte in `0x20..=0x7E`, numbered
/// consecutively from id 3. The merge list is empty.
fn byte_level_tokenizer_json() -> String {
    // Bytes in these ranges are represented by the code point of the same value;
    // every other byte takes the next unused code point, counting up from U+0100
    // in byte order (so the space byte 0x20 ends up as U+0120).
    let keeps_own_glyph = |b: u8| matches!(b, 0x21..=0x7E | 0xA1..=0xAC | 0xAE..=0xFF);
    let mut next_remap: u32 = 0x100;
    let glyphs: Vec<char> = (0..=u8::MAX)
        .map(|b| {
            if keeps_own_glyph(b) {
                char::from(b)
            } else {
                let remapped = char::from_u32(next_remap).unwrap();
                next_remap += 1;
                remapped
            }
        })
        .collect();

    let special_entries = ["\"<unk>\": 0", "\"<s>\": 1", "\"</s>\": 2"]
        .iter()
        .map(|entry| String::from(*entry));

    // Only the bytes 0x20..=0x7E become vocabulary entries; ids continue after the specials.
    let byte_entries = (0x20u8..=0x7E).zip(3u32..).map(|(b, id)| {
        // `"` and `\` must be escaped to stay valid inside a JSON string key.
        let key = match glyphs[usize::from(b)] {
            '"' => String::from("\\\""),
            '\\' => String::from("\\\\"),
            other => other.to_string(),
        };
        format!("\"{}\": {}", key, id)
    });

    let vocab_body = special_entries
        .chain(byte_entries)
        .collect::<Vec<String>>()
        .join(",\n ");

    // The template lines stay flush-left so the emitted text is unchanged.
    format!(
        r#"{{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [
{{"id": 0, "content": "<unk>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true}},
{{"id": 1, "content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true}},
{{"id": 2, "content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true}}
],
"normalizer": null,
"pre_tokenizer": {{ "type": "ByteLevel", "add_prefix_space": false, "trim_offsets": true }},
"post_processor": null,
"decoder": {{ "type": "ByteLevel", "add_prefix_space": false, "trim_offsets": true }},
"model": {{
"type": "BPE",
"dropout": null,
"unk_token": "<unk>",
"continuing_subword_prefix": null,
"end_of_word_suffix": null,
"fuse_unk": false,
"vocab": {{ {} }},
"merges": []
}}
}}"#,
        vocab_body
    )
}
/// Contents written to `tokenizer_config.json` by `fixture_tokenizer`.
///
/// Names `<s>`, `</s>` and `<unk>` as the BOS, EOS and unknown tokens and sets
/// `model_max_length` to 2048.
const TOKENIZER_CONFIG_JSON: &str = r#"{
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
"model_max_length": 2048
}"#;
/// Creates a fresh scratch directory under the system temp dir and returns its path.
///
/// The directory name is built from the process id, `name`, and a process-wide
/// counter, so repeated calls with the same `name` yield distinct paths.
///
/// # Panics
/// Panics if the directory cannot be created.
fn temp_dir(name: &str) -> std::path::PathBuf {
    use std::sync::atomic::{AtomicU64, Ordering};

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    let n = COUNTER.fetch_add(1, Ordering::Relaxed);
    let mut scratch = std::env::temp_dir();
    scratch.push(format!(
        "mlxrs_structured_inline_{}_{}_{n}",
        std::process::id(),
        name
    ));

    // Best-effort wipe of anything already at this path; the error is deliberately ignored.
    let _ = std::fs::remove_dir_all(&scratch);
    std::fs::create_dir_all(&scratch).unwrap();
    scratch
}
/// Writes the fixture `tokenizer.json` and `tokenizer_config.json` into a fresh
/// scratch directory and loads a `Tokenizer` from that directory.
///
/// `name` is forwarded to `temp_dir` to label the scratch directory.
///
/// # Panics
/// Panics if either file cannot be written or if `Tokenizer::from_path` returns an error.
fn fixture_tokenizer(name: &str) -> Tokenizer {
    let root = temp_dir(name);
    let model_file = root.join("tokenizer.json");
    let config_file = root.join("tokenizer_config.json");

    std::fs::write(model_file, byte_level_tokenizer_json()).unwrap();
    std::fs::write(config_file, TOKENIZER_CONFIG_JSON).unwrap();

    match Tokenizer::from_path(&root, None) {
        Ok(tokenizer) => tokenizer,
        Err(e) => panic!("fixture tokenizer load failed: {e}"),
    }
}
/// Returns the fixture vocabulary id of the single-byte token for `byte`.
///
/// Ids are assigned consecutively: byte `0x20` has id 3, `0x21` has id 4, and so on
/// up to `0x7E`.
///
/// # Panics
/// Panics if `byte` lies outside `0x20..=0x7E`.
fn id_for_byte(byte: u8) -> u32 {
    const FIRST_BYTE: u8 = 0x20;
    const FIRST_ID: u32 = 3;

    assert!(
        matches!(byte, 0x20..=0x7E),
        "byte {byte:#x} not in fixture vocab"
    );
    FIRST_ID + u32::from(byte - FIRST_BYTE)
}
/// A rank-1 `[V]` logits vector is accepted and masked by the grammar.
///
/// With the `{"type": "object"}` schema and no tokens consumed yet, `{` must stay
/// finite while `a` and `}` must come back as negative infinity.
#[test]
fn apply_single_rank_vector_logits_masks_by_grammar() {
    let tokenizer = fixture_tokenizer("single_rank_vector");
    let schema = json!({ "type": "object" });
    let processor = build_json_schema_logits_processor(schema, &tokenizer, None)
        .expect("processor construction should succeed");

    let vocab = tokenizer.hf().get_vocab_size(true);
    let zero_row = vec![0.0f32; vocab];
    let flat_logits = Array::from_slice::<f32>(&zero_row, &(vocab,)).unwrap();

    let mut masked = processor
        .apply(&[], &flat_logits)
        .expect("apply should succeed on a rank-1 `[V]` logits row");
    assert_eq!(masked.shape(), vec![vocab]);
    let values = masked.to_vec::<f32>().unwrap();
    assert_eq!(values.len(), vocab);

    // The only token checked as allowed here: the opening brace.
    let open_brace = id_for_byte(b'{') as usize;
    let open_brace_logit = values[open_brace];
    assert!(
        open_brace_logit.is_finite(),
        "`{{` (id {open_brace}) must remain finite in a `[V]` mask, got {}",
        open_brace_logit
    );

    // A bare letter must be masked out.
    let a = id_for_byte(b'a') as usize;
    let a_logit = values[a];
    assert!(
        a_logit.is_infinite() && a_logit < 0.0,
        "`a` (id {a}) must be -inf in a `[V]` mask, got {}",
        a_logit
    );

    // A closing brace before any opening brace must be masked out.
    let close_brace = id_for_byte(b'}') as usize;
    let close_brace_logit = values[close_brace];
    assert!(
        close_brace_logit.is_infinite() && close_brace_logit < 0.0,
        "`}}` (id {close_brace}) must be -inf in a `[V]` mask, got {}",
        close_brace_logit
    );
}
/// Batched `[2, V]` logits must be rejected with `Error::RankMismatch`.
///
/// The error payload must report rank 2, carry the full observed shape, and
/// mention both `[V]` and `[1, V]` in its context.
#[test]
fn apply_rejects_batched_rank2_logits_with_rank_mismatch() {
    let tokenizer = fixture_tokenizer("rank2_reject");
    let schema = json!({ "type": "object" });
    let processor = build_json_schema_logits_processor(schema, &tokenizer, None)
        .expect("processor construction should succeed");

    let vocab = tokenizer.hf().get_vocab_size(true);
    let batch_zeros = vec![0.0f32; 2 * vocab];
    let batched = Array::from_slice::<f32>(&batch_zeros, &(2, vocab)).unwrap();

    let outcome = processor.apply(&[], &batched);
    match outcome {
        Ok(_) => panic!("rank-2 `[2, V]` logits must be rejected, not accepted"),
        Err(Error::RankMismatch(p)) => {
            assert_eq!(p.actual(), 2, "observed rank must be 2");
            assert_eq!(
                p.actual_shape(),
                [2usize, vocab].as_slice(),
                "payload must carry the full observed shape"
            );
            let context = p.context();
            let names_vector_shape = context.contains("[V]");
            let names_row_shape = context.contains("[1, V]");
            assert!(
                names_vector_shape && names_row_shape,
                "context names the accepted shapes: {}",
                context
            );
        }
        Err(other) => panic!("expected Error::RankMismatch, got: {other:?}"),
    }
}
/// Rank-3 `[1, 1, V]` logits must be rejected with `Error::RankMismatch`.
///
/// The error payload must report rank 3 and the observed shape.
#[test]
fn apply_rejects_rank3_logits_with_rank_mismatch() {
    let tokenizer = fixture_tokenizer("rank3_reject");
    let schema = json!({ "type": "object" });
    let processor = build_json_schema_logits_processor(schema, &tokenizer, None)
        .expect("processor construction should succeed");

    let vocab = tokenizer.hf().get_vocab_size(true);
    let zero_row = vec![0.0f32; vocab];
    let cube = Array::from_slice::<f32>(&zero_row, &(1, 1, vocab)).unwrap();

    let outcome = processor.apply(&[], &cube);
    match outcome {
        Ok(_) => panic!("rank-3 logits must be rejected"),
        Err(Error::RankMismatch(p)) => {
            assert_eq!(p.actual(), 3, "observed rank must be 3");
            assert_eq!(p.actual_shape(), [1usize, 1, vocab].as_slice());
        }
        Err(other) => panic!("expected Error::RankMismatch, got: {other:?}"),
    }
}
/// Passing a token the grammar does not allow must surface as `Error::Parse`.
///
/// The processor is built from the regex `a`. After a first `apply` with no
/// tokens, a second `apply` is given the id of `b`; the resulting error must
/// mention `consume_token` in its context and the token id in its inner error text.
#[test]
fn apply_consume_disallowed_token_surfaces_parse_error() {
    let tokenizer = fixture_tokenizer("consume_disallowed");
    let grammar = GrammarSpec::Regex(String::from("a"));
    let processor = LLGuidanceLogitsProcessor::new(grammar, &tokenizer, None)
        .expect("regex grammar construction should succeed");

    let vocab = tokenizer.hf().get_vocab_size(true);
    let zero_row = vec![0.0f32; vocab];
    let row = Array::from_slice::<f32>(&zero_row, &(1, vocab)).unwrap();

    // First step with no history; the returned mask itself is not inspected.
    let _ = processor
        .apply(&[], &row)
        .expect("first apply should succeed");

    let b_id = id_for_byte(b'b');
    let outcome = processor.apply(&[b_id], &row);
    match outcome {
        Ok(_) => panic!("consuming a grammar-disallowed token must error, not succeed"),
        Err(Error::Parse(p)) => {
            let context = p.context();
            assert!(
                context.contains("consume_token"),
                "context must name consume_token: {}",
                context
            );

            let inner_text = p.inner().to_string();
            let id_text = b_id.to_string();
            assert!(
                inner_text.contains(&id_text),
                "the inner error should carry the offending token id {b_id}: {}",
                p.inner()
            );
        }
        Err(other) => panic!("expected Error::Parse(consume_token), got: {other:?}"),
    }
}
/// A Lark grammar that references an undefined rule must fail in the constructor.
///
/// The expected error is `Error::Parse` whose context mentions `grammar compile`
/// and whose input kind is `llguidance grammar`.
#[test]
fn new_with_uncompilable_grammar_errors_at_construction() {
    let tokenizer = fixture_tokenizer("bad_grammar_compile");
    let grammar = GrammarSpec::Lark(String::from("start: undefined_rule\n"));

    let built = LLGuidanceLogitsProcessor::new(grammar, &tokenizer, None);
    match built {
        Ok(_) => panic!("an uncompilable grammar must yield Err at construction, not Ok"),
        Err(Error::Parse(p)) => {
            let context = p.context();
            assert!(
                context.contains("grammar compile"),
                "context must name the grammar-compile stage: {}",
                context
            );
            assert_eq!(p.input_kind(), "llguidance grammar");
        }
        Err(other) => panic!("expected Error::Parse(grammar compile), got: {other:?}"),
    }
}
/// Logits wider than the tokenizer vocabulary must be rejected with `Error::LengthMismatch`.
///
/// The logits row is 64 entries wider than the tokenizer's vocabulary. The error
/// must report the logits width as `expected`, the tokenizer vocabulary size as
/// `actual`, and mention `mask vs logits vocab` in its context.
#[test]
fn apply_errors_when_mask_narrower_than_unpadded_logits() {
    let tokenizer = fixture_tokenizer("mask_narrower_than_logits");
    let tok_vocab = tokenizer.hf().get_vocab_size(true);
    let model_vocab = tok_vocab + 64;

    let schema = json!({ "type": "object" });
    let processor = build_json_schema_logits_processor(schema, &tokenizer, None)
        .expect("processor construction should succeed");

    let wide_zeros = vec![0.0f32; model_vocab];
    let wide_row = Array::from_slice::<f32>(&wide_zeros, &(1, model_vocab)).unwrap();

    let outcome = processor.apply(&[], &wide_row);
    match outcome {
        Ok(_) => panic!("a too-narrow matcher mask must error, not silently pass"),
        Err(Error::LengthMismatch(p)) => {
            assert_eq!(
                p.expected(),
                model_vocab,
                "expected = logits vocab width ({model_vocab})"
            );
            assert_eq!(
                p.actual(),
                tok_vocab,
                "actual = matcher mask length ({tok_vocab})"
            );
            let context = p.context();
            assert!(
                context.contains("mask vs logits vocab"),
                "context names the mask-vs-vocab comparison: {}",
                context
            );
        }
        Err(other) => panic!("expected Error::LengthMismatch, got: {other:?}"),
    }
}
/// `reset` must put the processor back into its first-step state.
///
/// With the regex `a`: after consuming `a`, the mask must leave the EOS id finite
/// and `a` at negative infinity. After `reset`, a fresh first `apply` must leave
/// `a` finite and every other id at negative infinity.
#[test]
fn reset_returns_to_initial_first_step_state() {
    let tokenizer = fixture_tokenizer("reset_first_step");
    let grammar = GrammarSpec::Regex(String::from("a"));
    let processor = LLGuidanceLogitsProcessor::new(grammar, &tokenizer, None)
        .expect("regex grammar construction should succeed");

    let vocab = tokenizer.hf().get_vocab_size(true);
    let zero_row = vec![0.0f32; vocab];
    let row = Array::from_slice::<f32>(&zero_row, &(1, vocab)).unwrap();

    let a_id = id_for_byte(b'a');
    let a_slot = a_id as usize;
    // Id 2 is `</s>` in the fixture vocabulary.
    let eos_id = 2usize;

    // Drive the matcher to its terminal state: first step, then consume `a`.
    let _ = processor
        .apply(&[], &row)
        .expect("first apply should succeed");
    let mut terminal = processor
        .apply(&[a_id], &row)
        .expect("post-consume apply should return the EOS-only mask");
    let terminal_v = terminal.to_vec::<f32>().unwrap();

    let eos_logit = terminal_v[eos_id];
    assert!(
        eos_logit.is_finite(),
        "pre-reset terminal mask: eos id {eos_id} must be finite, got {}",
        eos_logit
    );
    let spent_a_logit = terminal_v[a_slot];
    assert!(
        spent_a_logit.is_infinite() && spent_a_logit < 0.0,
        "pre-reset terminal mask: `a` (id {a_id}) must be -inf, got {}",
        spent_a_logit
    );

    processor.reset().expect("reset should succeed");

    let mut after = processor
        .apply(&[], &row)
        .expect("post-reset first apply should succeed");
    let after_v = after.to_vec::<f32>().unwrap();
    assert_eq!(after_v.len(), vocab);

    let fresh_a_logit = after_v[a_slot];
    assert!(
        fresh_a_logit.is_finite(),
        "post-reset initial mask: `a` (id {a_id}) must be finite again, got {}",
        fresh_a_logit
    );

    // Every id other than `a` must be masked out again.
    let non_a_entries = after_v
        .iter()
        .enumerate()
        .filter(|&(idx, _)| idx as u32 != a_id);
    for (i, v) in non_a_entries {
        assert!(
            v.is_infinite() && *v < 0.0,
            "post-reset initial mask: non-`a` id {i} must be -inf, got {v}"
        );
    }
}