use std::collections::HashMap;
use serde_json::Value;
use super::schema::{Schema, SchemaError};
use super::state::{JsonGrammar, StepError};
/// Errors produced while building a [`JsonSchemaProcessor`] or advancing it by a token.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum GrammarError {
/// The JSON Schema document could not be compiled (propagated from
/// `Schema::from_json_schema` via `?`).
#[error("schema compile error: {0}")]
Schema(#[from] SchemaError),
/// The grammar rejected a character while a token was being applied
/// (propagated from `JsonGrammar::step_char` via `?`).
#[error("grammar step error: {0}")]
Step(#[from] StepError),
/// The token id has no corresponding entry in the processor's vocabulary.
#[error("token id {0} out of range")]
InvalidTokenId(u32),
}
/// Per-token allow mask over a vocabulary.
///
/// Entry `i` of [`TokenMask::allow`] refers to token id `i`; any non-zero value
/// means that token is allowed. Masks built in this module only store `0` or `1`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenMask {
    /// One flag per vocabulary entry; non-zero means "allowed".
    pub allow: Vec<u32>,
}

impl TokenMask {
    /// Returns a mask of `vocab_size` entries, every one set to `1` (allowed).
    pub fn allow_all(vocab_size: usize) -> Self {
        let allow = vec![1_u32; vocab_size];
        TokenMask { allow }
    }

    /// Counts the entries whose flag is non-zero.
    pub fn num_allowed(&self) -> usize {
        self.allow
            .iter()
            .map(|&flag| usize::from(flag != 0))
            .sum()
    }
}
/// Memo table of token-acceptance results, plus hit/miss counters.
///
/// Entries are keyed by a `(String, usize)` pair whose meaning is chosen by the
/// caller that fills the cache (`JsonSchemaProcessor::compute_mask_cached`), which
/// also updates the counters.
#[derive(Debug, Default, Clone)]
pub struct TokenTransitionCache {
    entries: HashMap<(String, usize), bool>,
    hits: u64,
    misses: u64,
}

impl TokenTransitionCache {
    /// Creates an empty cache with both counters at zero.
    pub fn new() -> Self {
        Default::default()
    }

    /// Number of lookups answered from the cache since the counters were last zeroed.
    pub fn hits(&self) -> u64 {
        self.hits
    }

    /// Number of lookups that had to be computed since the counters were last zeroed.
    pub fn misses(&self) -> u64 {
        self.misses
    }

    /// Zeroes the hit/miss counters while keeping every cached entry.
    pub fn reset_counters(&mut self) {
        (self.hits, self.misses) = (0, 0);
    }

    /// Drops every cached entry and zeroes both counters.
    pub fn clear(&mut self) {
        self.entries.clear();
        self.reset_counters();
    }

    /// Number of cached entries.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when no entries are cached.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Constrains token-by-token generation so the emitted text follows a compiled JSON Schema.
///
/// Holds the current grammar state and the token vocabulary; token id `i`
/// corresponds to `vocab[i]`.
#[derive(Debug)]
pub struct JsonSchemaProcessor {
/// Current position in the schema's grammar; advanced by `step_token`.
grammar: JsonGrammar,
/// Token texts indexed by token id.
vocab: Vec<String>,
}
impl JsonSchemaProcessor {
    /// Compiles a JSON Schema document and builds a processor over `vocab`.
    ///
    /// `vocab[i]` is the text of token id `i`.
    ///
    /// # Errors
    ///
    /// Returns [`GrammarError::Schema`] when `Schema::from_json_schema` rejects `schema`.
    pub fn new(schema: &Value, vocab: Vec<String>) -> Result<Self, GrammarError> {
        let schema = Schema::from_json_schema(schema)?;
        Ok(Self::from_compiled(schema, vocab))
    }

    /// Builds a processor from an already-compiled [`Schema`], starting at the
    /// grammar's initial state.
    pub fn from_compiled(schema: Schema, vocab: Vec<String>) -> Self {
        Self {
            grammar: JsonGrammar::new(schema),
            vocab,
        }
    }

    /// Number of tokens in the vocabulary, i.e. the length of every mask produced.
    pub fn vocab_len(&self) -> usize {
        self.vocab.len()
    }

    /// Returns whether every character of `tok` can be consumed, in order, from the
    /// current grammar state.
    ///
    /// Works on a clone, so `self.grammar` is never modified. An empty `tok` is
    /// trivially accepted; callers filter empty tokens themselves.
    fn accepts_token(&self, tok: &str) -> bool {
        let mut probe = self.grammar.clone();
        // `all` stops at the first rejected character, like an early `break`.
        tok.chars().all(|c| probe.step_char(c).is_ok())
    }

    /// Returns the only `char` of `tok` when it is exactly one character long.
    fn single_char(tok: &str) -> Option<char> {
        let mut it = tok.chars();
        match (it.next(), it.next()) {
            (Some(c), None) => Some(c),
            _ => None,
        }
    }

    /// Computes which tokens may be emitted next from the current state.
    ///
    /// A token is allowed (`1`) when all of its characters are accepted in
    /// sequence by the grammar. Empty tokens are never allowed. Once the grammar
    /// reports completion, the mask is all zeros.
    pub fn compute_mask(&self) -> TokenMask {
        let mut allow = vec![0u32; self.vocab.len()];
        if self.grammar.is_complete() {
            return TokenMask { allow };
        }
        for (slot, tok) in allow.iter_mut().zip(&self.vocab) {
            if !tok.is_empty() && self.accepts_token(tok) {
                *slot = 1;
            }
        }
        TokenMask { allow }
    }

    /// Advances the grammar by the characters of token `token_id`.
    ///
    /// The step is all-or-nothing. The characters are applied to a copy of the
    /// grammar state, and the copy is committed only if every character is
    /// accepted. On error the processor is left exactly as it was.
    ///
    /// # Errors
    ///
    /// - [`GrammarError::InvalidTokenId`] if `token_id` is not a vocabulary index.
    /// - [`GrammarError::Step`] if the grammar rejects one of the token's characters.
    pub fn step_token(&mut self, token_id: u32) -> Result<(), GrammarError> {
        let tok = usize::try_from(token_id)
            .ok()
            .and_then(|idx| self.vocab.get(idx))
            .ok_or(GrammarError::InvalidTokenId(token_id))?;
        let mut next = self.grammar.clone();
        for c in tok.chars() {
            next.step_char(c)?;
        }
        self.grammar = next;
        Ok(())
    }

    /// Whether the grammar has reached a complete value.
    pub fn is_complete(&self) -> bool {
        self.grammar.is_complete()
    }

    /// Read-only access to the current grammar state.
    pub fn grammar(&self) -> &JsonGrammar {
        &self.grammar
    }

    /// Same result as [`compute_mask`](Self::compute_mask), reusing `cache`
    /// where that is sound.
    ///
    /// The current state is summarised by its sorted, deduplicated
    /// `valid_next_chars()` set. That summary only determines which *single*
    /// character may come next. So only single-character tokens are cached,
    /// keyed on (summary, code point). Keying on the character rather than the
    /// token index keeps the cache correct when it is shared between processors
    /// with different vocabularies. Multi-character tokens depend on deeper
    /// state and are always probed directly.
    ///
    /// NOTE(review): this assumes `valid_next_chars()` lists exactly the
    /// characters `step_char` accepts next — confirm against `JsonGrammar`.
    pub fn compute_mask_cached(&self, cache: &mut TokenTransitionCache) -> TokenMask {
        let mut allow = vec![0u32; self.vocab.len()];
        if self.grammar.is_complete() {
            return TokenMask { allow };
        }
        let mut chars = self.grammar.valid_next_chars();
        chars.sort_unstable();
        chars.dedup();
        // One reusable lookup key: only the second component changes per token,
        // so the signature string is cloned only when a new entry is inserted.
        let mut key: (String, usize) = (chars.iter().collect(), 0);
        for (slot, tok) in allow.iter_mut().zip(&self.vocab) {
            if tok.is_empty() {
                continue;
            }
            let accept = match Self::single_char(tok) {
                Some(c) => {
                    key.1 = u32::from(c) as usize;
                    if let Some(&hit) = cache.entries.get(&key) {
                        cache.hits += 1;
                        hit
                    } else {
                        cache.misses += 1;
                        let ok = self.accepts_token(tok);
                        cache.entries.insert(key.clone(), ok);
                        ok
                    }
                }
                None => self.accepts_token(tok),
            };
            if accept {
                *slot = 1;
            }
        }
        TokenMask { allow }
    }
}
#[cfg(test)]
mod tests {
// Behavioural tests for JsonSchemaProcessor, driven one character per token.
use super::*;
use serde_json::json;
/// One single-character token per printable ASCII byte (0x20 space .. 0x7E `~`).
fn ascii_char_vocab() -> Vec<String> {
(0x20u8..=0x7Eu8).map(|b| (b as char).to_string()).collect()
}
/// At the start of a boolean value, only the first letters of `true`/`false` are allowed.
#[test]
fn boolean_schema_only_allows_t_or_f_at_start() {
let processor =
JsonSchemaProcessor::new(&json!({"type": "boolean"}), ascii_char_vocab()).unwrap();
let mask = processor.compute_mask();
// Map allowed token ids back to their (single) characters.
let allowed_chars: Vec<char> = (0..mask.allow.len())
.filter(|&i| mask.allow[i] != 0)
.map(|i| processor.vocab[i].chars().next().unwrap())
.collect();
assert!(allowed_chars.contains(&'t'));
assert!(allowed_chars.contains(&'f'));
assert!(!allowed_chars.contains(&'a'));
assert!(!allowed_chars.contains(&'1'));
}
/// After stepping `t`, the mask allows `r` (continuing `true`) and rejects `x`.
#[test]
fn step_token_advances_state() {
let vocab = ascii_char_vocab();
let mut processor =
JsonSchemaProcessor::new(&json!({"type": "boolean"}), vocab.clone()).unwrap();
let t_id = vocab.iter().position(|s| s == "t").unwrap() as u32;
processor.step_token(t_id).unwrap();
let mask = processor.compute_mask();
let r_id = vocab.iter().position(|s| s == "r").unwrap() as u32;
assert_eq!(mask.allow[r_id as usize], 1);
let x_id = vocab.iter().position(|s| s == "x").unwrap() as u32;
assert_eq!(mask.allow[x_id as usize], 0);
}
/// A token id past the end of the vocabulary yields `InvalidTokenId` carrying that id.
#[test]
fn invalid_token_id_returns_error() {
let mut processor =
JsonSchemaProcessor::new(&json!({"type": "boolean"}), ascii_char_vocab()).unwrap();
let err = processor.step_token(99999).unwrap_err();
assert!(matches!(err, GrammarError::InvalidTokenId(99999)));
}
/// The cached mask must equal the uncached one, and a repeat call from the same
/// state must be served entirely from the cache (no new misses).
#[test]
fn token_transition_cache_byte_equal_and_hits() {
let processor =
JsonSchemaProcessor::new(&json!({"type": "boolean"}), ascii_char_vocab()).unwrap();
let baseline = processor.compute_mask();
let mut cache = TokenTransitionCache::new();
let cached1 = processor.compute_mask_cached(&mut cache);
assert_eq!(baseline.allow, cached1.allow);
let miss_after_first = cache.misses();
let cached2 = processor.compute_mask_cached(&mut cache);
assert_eq!(baseline.allow, cached2.allow);
assert!(cache.hits() > 0);
assert_eq!(cache.misses(), miss_after_first);
}
/// Deterministically drives the processor by always taking the highest allowed
/// token id, for at most `max_steps` tokens; returns the concatenated text.
/// Stops early when the grammar completes or no token is allowed.
fn greedy_complete(processor: &mut JsonSchemaProcessor, max_steps: usize) -> String {
let mut emitted = String::new();
for _ in 0..max_steps {
if processor.is_complete() {
break;
}
let mask = processor.compute_mask();
let choice = mask.allow.iter().rposition(|x| *x != 0);
let Some(idx) = choice else { break };
emitted.push_str(&processor.vocab[idx]);
processor.step_token(idx as u32).unwrap();
}
emitted
}
/// Feeds a known-valid document one character at a time. Every character must
/// be allowed at the moment it is emitted, and the grammar must finish complete.
#[test]
fn extraction_response_shaped_schema_step_by_step() {
let schema = json!({
"type": "object",
"properties": {
"value": {"type": "number"},
"confidence": {"enum": ["high", "medium", "low"]},
"notes": {"type": ["string", "null"]}
},
"required": ["value", "confidence"]
});
let vocab = ascii_char_vocab();
let mut p = JsonSchemaProcessor::new(&schema, vocab.clone()).unwrap();
let target = "{\"confidence\":\"high\",\"value\":-3.14}";
for c in target.chars() {
let tok = c.to_string();
let id = vocab.iter().position(|s| s == &tok).unwrap();
let mask = p.compute_mask();
assert_eq!(
mask.allow[id],
1,
"char {c:?} masked at point of emitting {target:?}; \
emitted-so-far valid_next from grammar: {:?}",
p.grammar().valid_next_chars()
);
p.step_token(id as u32).unwrap();
}
assert!(p.is_complete(), "did not complete after {target:?}");
// Cross-check the emitted text with an independent JSON parser.
let parsed: serde_json::Value = serde_json::from_str(target).unwrap();
let obj = parsed.as_object().unwrap();
assert_eq!(obj.get("confidence").unwrap().as_str(), Some("high"));
// -3.14 resembles -PI, which clippy::approx_constant flags; here it is plain
// test data, so the lint is silenced on purpose.
#[allow(clippy::approx_constant)]
let expected_value = -3.14_f64;
assert_eq!(obj.get("value").unwrap().as_f64(), Some(expected_value));
}
/// Right after `{"`, only first letters of declared property names are allowed.
#[test]
fn extraction_response_rejects_unknown_key() {
let schema = json!({
"type": "object",
"properties": {
"value": {"type": "number"},
"confidence": {"enum": ["high", "medium", "low"]},
"notes": {"type": ["string", "null"]}
},
"required": ["value", "confidence"]
});
let vocab = ascii_char_vocab();
let mut p = JsonSchemaProcessor::new(&schema, vocab.clone()).unwrap();
for c in "{\"".chars() {
let id = vocab.iter().position(|s| s == &c.to_string()).unwrap();
p.step_token(id as u32).unwrap();
}
let mask = p.compute_mask();
let b_id = vocab.iter().position(|s| s == "b").unwrap();
let c_id = vocab.iter().position(|s| s == "c").unwrap();
let n_id = vocab.iter().position(|s| s == "n").unwrap();
let v_id = vocab.iter().position(|s| s == "v").unwrap();
assert_eq!(mask.allow[b_id], 0, "bogus prefix should be masked");
assert_eq!(mask.allow[c_id], 1);
assert_eq!(mask.allow[n_id], 1);
assert_eq!(mask.allow[v_id], 1);
}
/// Greedy generation for a nested-object schema must yield parseable JSON with
/// a boolean at `outer.inner`.
#[test]
fn nested_object_schema_completes() {
let schema = json!({
"type": "object",
"properties": {
"outer": {
"type": "object",
"properties": {"inner": {"type": "boolean"}},
"required": ["inner"]
}
},
"required": ["outer"]
});
let mut p = JsonSchemaProcessor::new(&schema, ascii_char_vocab()).unwrap();
let out = greedy_complete(&mut p, 256);
let parsed: serde_json::Value = serde_json::from_str(&out).expect("valid nested JSON");
let outer = parsed.as_object().unwrap().get("outer").unwrap();
let inner = outer.as_object().unwrap().get("inner").unwrap();
assert!(inner.is_boolean());
}
/// Both the empty array and a short integer list must be emittable char by char
/// and leave the grammar complete.
#[test]
fn array_of_integers_step_by_step() {
let schema = json!({"type": "array", "items": {"type": "integer"}});
let vocab = ascii_char_vocab();
let mut p = JsonSchemaProcessor::new(&schema, vocab.clone()).unwrap();
for c in "[]".chars() {
let id = vocab.iter().position(|s| s == &c.to_string()).unwrap();
assert_eq!(p.compute_mask().allow[id], 1, "char {c:?} masked");
p.step_token(id as u32).unwrap();
}
assert!(p.is_complete());
// Fresh processor for the non-empty case.
let mut p = JsonSchemaProcessor::new(&schema, vocab.clone()).unwrap();
for c in "[1,2,3]".chars() {
let id = vocab.iter().position(|s| s == &c.to_string()).unwrap();
assert_eq!(p.compute_mask().allow[id], 1);
p.step_token(id as u32).unwrap();
}
assert!(p.is_complete());
}
/// A `["string", "null"]` union must allow both `n` (null) and `"` (string) first.
#[test]
fn nullable_string_can_emit_null_or_string() {
let schema = json!({"type": ["string", "null"]});
let p = JsonSchemaProcessor::new(&schema, ascii_char_vocab()).unwrap();
let mask = p.compute_mask();
let n_id = p.vocab.iter().position(|s| s == "n").unwrap();
let q_id = p.vocab.iter().position(|s| s == "\"").unwrap();
assert_eq!(mask.allow[n_id], 1);
assert_eq!(mask.allow[q_id], 1);
}
/// After the opening quote of an enum value, only first letters of listed values are allowed.
#[test]
fn enum_schema_only_allows_listed_values() {
let schema = json!({"enum": ["high", "low"]});
let mut p = JsonSchemaProcessor::new(&schema, ascii_char_vocab()).unwrap();
let q_id = p.vocab.iter().position(|s| s == "\"").unwrap();
p.step_token(q_id as u32).unwrap();
let mask = p.compute_mask();
let h_id = p.vocab.iter().position(|s| s == "h").unwrap();
let l_id = p.vocab.iter().position(|s| s == "l").unwrap();
let m_id = p.vocab.iter().position(|s| s == "m").unwrap();
assert_eq!(mask.allow[h_id], 1);
assert_eq!(mask.allow[l_id], 1);
assert_eq!(mask.allow[m_id], 0);
}
/// Random-walk trials per schema; fewer in debug builds to keep `cargo test` quick.
const SAMPLED_COMPLETIONS_PER_SCHEMA: usize =
if cfg!(debug_assertions) { 1000 } else { 10_000 };
/// Property test: random walks that follow the mask must, whenever they reach a
/// complete state, have produced text that serde_json parses.
#[test]
fn sampled_completions_always_validate() {
let schemas = [
json!({
"type": "object",
"properties": {
"value": {"type": "number"},
"confidence": {"enum": ["high", "medium", "low"]}
},
"required": ["value", "confidence"]
}),
json!({
"type": "object",
"properties": {
"inner": {
"type": "object",
"properties": {"v": {"type": "boolean"}},
"required": ["v"]
}
},
"required": ["inner"]
}),
json!({"type": "array", "items": {"type": "integer"}}),
json!({"enum": ["red", "green", "blue"]}),
json!({"type": ["string", "null"]}),
];
let vocab = ascii_char_vocab();
for (i, schema) in schemas.iter().enumerate() {
// Deterministic per-schema LCG so any failure reproduces exactly.
// NOTE(review): `% allowed.len()` below uses the low bits, which have short
// periods in a power-of-two-modulus LCG; sampling is therefore not very uniform.
let mut state: u32 = 0x1234_5678 ^ (i as u32);
let mut next = || {
state = state.wrapping_mul(1_103_515_245).wrapping_add(12345);
state
};
for trial in 0..SAMPLED_COMPLETIONS_PER_SCHEMA {
let mut p = JsonSchemaProcessor::new(schema, vocab.clone()).unwrap();
let mut emitted = String::new();
// At most 256 tokens per trial; open-ended values (strings, numbers) may not finish.
for _ in 0..256 {
if p.is_complete() {
break;
}
let mask = p.compute_mask();
let allowed: Vec<usize> = mask
.allow
.iter()
.enumerate()
.filter_map(|(idx, a)| (*a != 0).then_some(idx))
.collect();
if allowed.is_empty() {
break;
}
let pick = allowed[(next() as usize) % allowed.len()];
emitted.push_str(&p.vocab[pick]);
p.step_token(pick as u32).unwrap();
}
if p.is_complete() {
let parsed: Result<serde_json::Value, _> = serde_json::from_str(&emitted);
if let Err(e) = &parsed {
// Tolerated: serde_json reports "number out of range" for numerals it
// cannot represent (e.g. huge exponents) even though they are valid JSON
// syntax. Any other parse error means the grammar admitted invalid JSON.
let msg = e.to_string();
let is_numeric_range_err = msg.contains("number out of range");
assert!(
is_numeric_range_err,
"schema {i} trial {trial}: emitted invalid JSON: {emitted:?} (err: {msg})"
);
}
} else {
// NOTE(review): no-op; trials that stop before completion (step budget
// exhausted or no allowed token) are not checked at all.
let _ = emitted;
}
}
}
}
}