use super::error::{GrammarError, GrammarResult};
use super::parser::{Grammar, GrammarNode};
/// Depth limit for the recursive grammar walks in this module (the
/// nullability check and the byte-level simulator).
const MAX_DEPTH: usize = 128;
/// Longest token, in bytes, that `GrammarState::allows_token` simulates;
/// longer tokens are allowed without being checked.
const MAX_SIM_BYTES: usize = 64;
/// Incremental matcher over a parsed [`Grammar`].
///
/// The state is a single pending continuation: an ordered list of grammar
/// nodes that still have to be matched, front first. `allows_token` checks a
/// candidate token without mutating the state; `advance` consumes bytes and
/// replaces the continuation.
#[derive(Debug, Clone)]
pub struct GrammarState {
    /// Nodes still to be matched; index 0 is matched next.
    continuation: Vec<ContNode>,
    /// Grammar whose rule table resolves `GrammarNode::RuleRef` nodes.
    grammar: Grammar,
}
/// One pending entry of a continuation: a thin wrapper around a
/// [`GrammarNode`] that still has to be matched.
#[derive(Debug, Clone)]
struct ContNode {
    /// The grammar node this entry stands for.
    node: GrammarNode,
}
impl ContNode {
fn new(node: GrammarNode) -> Self {
Self { node }
}
}
impl GrammarState {
    /// Creates a state positioned at the start of `grammar`'s root rule.
    pub(super) fn new(grammar: Grammar) -> Self {
        let start = ContNode::new(GrammarNode::RuleRef(grammar.root.clone()));
        Self {
            continuation: vec![start],
            grammar,
        }
    }

    /// Returns `true` if everything still pending can match the empty string,
    /// i.e. the bytes consumed so far form a complete match along the current
    /// derivation.
    ///
    /// Nullability checks that nest deeper than `MAX_DEPTH` are treated as
    /// "cannot match empty".
    pub fn is_complete(&self) -> bool {
        self.can_match_empty_continuation(&self.continuation, 0)
    }

    /// Returns `true` if `token_bytes` may be consumed next.
    ///
    /// Does not modify the state. Empty tokens and tokens longer than
    /// `MAX_SIM_BYTES` are allowed without simulation, and the simulator also
    /// answers `true` when it exhausts its recursion budget. A `true` result is
    /// therefore not a guarantee that [`advance`](Self::advance) will succeed.
    pub fn allows_token(&self, token_bytes: &[u8]) -> bool {
        if token_bytes.is_empty() {
            return true;
        }
        if token_bytes.len() > MAX_SIM_BYTES {
            return true;
        }
        let mut sim = SimState {
            grammar: &self.grammar,
            depth: 0,
        };
        sim.simulate_bytes(&self.continuation, token_bytes)
    }

    /// Consumes `token_bytes`, replacing the pending continuation on success.
    ///
    /// The first derivation found by the simulator's search that accepts all of
    /// the bytes is committed; other derivations are not retained, so later
    /// tokens are checked only against the committed one.
    ///
    /// # Errors
    /// Returns the simulator's error (e.g. `GrammarError::Stuck`) when the
    /// bytes cannot be matched. The state is left unchanged in that case.
    pub fn advance(&mut self, token_bytes: &[u8]) -> GrammarResult<()> {
        if token_bytes.is_empty() {
            return Ok(());
        }
        let mut sim = SimState {
            grammar: &self.grammar,
            depth: 0,
        };
        let new_cont = sim.advance_bytes(&self.continuation, token_bytes)?;
        self.continuation = new_cont;
        Ok(())
    }

    /// Returns `true` if every entry of `cont` can match the empty string.
    ///
    /// `depth` bounds rule-expansion nesting, not continuation length: all
    /// entries are checked at the same depth, so a long list of nullable items
    /// is still reported as nullable.
    fn can_match_empty_continuation(&self, cont: &[ContNode], depth: usize) -> bool {
        depth <= MAX_DEPTH
            && cont
                .iter()
                .all(|entry| self.node_can_match_empty(&entry.node, depth + 1))
    }

    /// Returns `true` if `node` can derive the empty string.
    ///
    /// Unknown rules and walks deeper than `MAX_DEPTH` answer `false`.
    fn node_can_match_empty(&self, node: &GrammarNode, depth: usize) -> bool {
        if depth > MAX_DEPTH {
            return false;
        }
        match node {
            GrammarNode::Literal(bytes) => bytes.is_empty(),
            GrammarNode::CharClass { .. } => false,
            GrammarNode::RuleRef(name) => self
                .grammar
                .rules
                .get(name)
                .is_some_and(|rule_node| self.node_can_match_empty(rule_node, depth + 1)),
            GrammarNode::Sequence(items) => items
                .iter()
                .all(|item| self.node_can_match_empty(item, depth + 1)),
            GrammarNode::Alternation(alts) => alts
                .iter()
                .any(|alt| self.node_can_match_empty(alt, depth + 1)),
            // Nullable when zero repetitions are allowed, or when the repeated
            // body is itself nullable (e.g. `+` applied to an optional item):
            // the byte simulator can satisfy `min` with empty matches, so the
            // completeness check must agree with it.
            GrammarNode::Repeat { node, min, .. } => {
                *min == 0 || self.node_can_match_empty(node.as_ref(), depth + 1)
            }
        }
    }
}
/// Scratch state for one `allows_token` / `advance` pass over a token's bytes.
struct SimState<'g> {
    /// Grammar used to resolve `GrammarNode::RuleRef` nodes.
    grammar: &'g Grammar,
    /// Nesting depth of the recursive matching calls; incremented on entry to
    /// a matching step and decremented on return.
    depth: usize,
}
impl<'g> SimState<'g> {
fn simulate_bytes(&mut self, cont: &[ContNode], bytes: &[u8]) -> bool {
if bytes.is_empty() {
return true;
}
self.try_consume_byte(cont, bytes[0], &bytes[1..])
}
fn try_consume_byte(&mut self, cont: &[ContNode], b: u8, rest: &[u8]) -> bool {
self.depth += 1;
if self.depth > MAX_DEPTH {
self.depth -= 1;
return true; }
let result = self.try_consume_byte_inner(cont, b, rest);
self.depth -= 1;
result
}
fn try_consume_byte_inner(&mut self, cont: &[ContNode], b: u8, rest: &[u8]) -> bool {
if cont.is_empty() {
return false; }
let Some((first, tail)) = cont.split_first() else {
return false;
};
match &first.node {
GrammarNode::Literal(bytes) => {
if bytes.is_empty() {
self.try_consume_byte(tail, b, rest)
} else if bytes[0] == b {
let remainder = &bytes[1..];
if remainder.is_empty() {
self.simulate_bytes(tail, rest)
} else {
let mut new_cont: Vec<ContNode> = Vec::with_capacity(tail.len() + 1);
new_cont.push(ContNode::new(GrammarNode::Literal(remainder.to_vec())));
new_cont.extend_from_slice(tail);
self.simulate_bytes(&new_cont, rest)
}
} else {
false
}
}
GrammarNode::CharClass { ranges, negated } => {
let in_class = ranges.iter().any(|r| r.contains(b));
let matches = if *negated { !in_class } else { in_class };
if matches {
self.simulate_bytes(tail, rest)
} else {
false
}
}
GrammarNode::RuleRef(name) => {
let rule_node = match self.grammar.rules.get(name) {
Some(n) => n.clone(),
None => return false,
};
let mut new_cont: Vec<ContNode> = Vec::with_capacity(tail.len() + 1);
new_cont.push(ContNode::new(rule_node));
new_cont.extend_from_slice(tail);
self.try_consume_byte(&new_cont, b, rest)
}
GrammarNode::Sequence(items) => {
if items.is_empty() {
self.try_consume_byte(tail, b, rest)
} else {
let mut new_cont: Vec<ContNode> = Vec::with_capacity(items.len() + tail.len());
for item in items {
new_cont.push(ContNode::new(item.clone()));
}
new_cont.extend_from_slice(tail);
self.try_consume_byte(&new_cont, b, rest)
}
}
GrammarNode::Alternation(alts) => {
for alt in alts {
let mut new_cont: Vec<ContNode> = Vec::with_capacity(tail.len() + 1);
new_cont.push(ContNode::new(alt.clone()));
new_cont.extend_from_slice(tail);
if self.try_consume_byte(&new_cont, b, rest) {
return true;
}
}
false
}
GrammarNode::Repeat { node, min, max } => {
let min = *min;
let max = *max;
if min == 0 && self.try_consume_byte(tail, b, rest) {
return true;
}
let can_take_more = max.is_none_or(|m| m > 0);
if can_take_more {
let new_min = min.saturating_sub(1);
let new_max = max.map(|m| m.saturating_sub(1));
let inner = node.as_ref().clone();
let repeat_rest = GrammarNode::Repeat {
node: Box::new(inner.clone()),
min: new_min,
max: new_max,
};
let mut new_cont: Vec<ContNode> = Vec::with_capacity(tail.len() + 2);
new_cont.push(ContNode::new(inner));
new_cont.push(ContNode::new(repeat_rest));
new_cont.extend_from_slice(tail);
if self.try_consume_byte(&new_cont, b, rest) {
return true;
}
}
false
}
}
}
fn advance_bytes(&mut self, cont: &[ContNode], bytes: &[u8]) -> GrammarResult<Vec<ContNode>> {
if bytes.is_empty() {
return Ok(cont.to_vec());
}
self.advance_one_byte(cont, bytes[0], &bytes[1..])
}
fn advance_one_byte(
&mut self,
cont: &[ContNode],
b: u8,
rest: &[u8],
) -> GrammarResult<Vec<ContNode>> {
self.depth += 1;
if self.depth > MAX_DEPTH {
self.depth -= 1;
return Err(GrammarError::RecursionLimit {
rule: "(advance)".to_string(),
});
}
let result = self.advance_one_byte_inner(cont, b, rest);
self.depth -= 1;
result
}
fn advance_one_byte_inner(
&mut self,
cont: &[ContNode],
b: u8,
rest: &[u8],
) -> GrammarResult<Vec<ContNode>> {
if cont.is_empty() {
return Err(GrammarError::Stuck);
}
let (first, tail) = cont.split_first().ok_or(GrammarError::Stuck)?;
match &first.node {
GrammarNode::Literal(bytes) => {
if bytes.is_empty() {
self.advance_one_byte(tail, b, rest)
} else if bytes[0] == b {
let remainder = &bytes[1..];
let mut new_cont: Vec<ContNode> = Vec::new();
if !remainder.is_empty() {
new_cont.push(ContNode::new(GrammarNode::Literal(remainder.to_vec())));
}
new_cont.extend_from_slice(tail);
self.advance_bytes(&new_cont, rest)
} else {
Err(GrammarError::Stuck)
}
}
GrammarNode::CharClass { ranges, negated } => {
let in_class = ranges.iter().any(|r| r.contains(b));
let matches = if *negated { !in_class } else { in_class };
if matches {
self.advance_bytes(tail, rest)
} else {
Err(GrammarError::Stuck)
}
}
GrammarNode::RuleRef(name) => {
let rule_node = self
.grammar
.rules
.get(name)
.ok_or_else(|| GrammarError::UnknownRule { rule: name.clone() })?
.clone();
let mut new_cont: Vec<ContNode> = Vec::with_capacity(tail.len() + 1);
new_cont.push(ContNode::new(rule_node));
new_cont.extend_from_slice(tail);
self.advance_one_byte(&new_cont, b, rest)
}
GrammarNode::Sequence(items) => {
if items.is_empty() {
self.advance_one_byte(tail, b, rest)
} else {
let mut new_cont: Vec<ContNode> = Vec::with_capacity(items.len() + tail.len());
for item in items {
new_cont.push(ContNode::new(item.clone()));
}
new_cont.extend_from_slice(tail);
self.advance_one_byte(&new_cont, b, rest)
}
}
GrammarNode::Alternation(alts) => {
for alt in alts {
let mut new_cont: Vec<ContNode> = Vec::with_capacity(tail.len() + 1);
new_cont.push(ContNode::new(alt.clone()));
new_cont.extend_from_slice(tail);
match self.advance_one_byte(&new_cont, b, rest) {
Ok(c) => return Ok(c),
Err(_) => continue,
}
}
Err(GrammarError::Stuck)
}
GrammarNode::Repeat { node, min, max } => {
let min = *min;
let max = *max;
if min == 0 {
if let Ok(c) = self.advance_one_byte(tail, b, rest) {
return Ok(c);
}
}
let can_take_more = max.is_none_or(|m| m > 0);
if can_take_more {
let new_min = min.saturating_sub(1);
let new_max = max.map(|m| m.saturating_sub(1));
let inner = node.as_ref().clone();
let repeat_rest = GrammarNode::Repeat {
node: Box::new(inner.clone()),
min: new_min,
max: new_max,
};
let mut new_cont: Vec<ContNode> = Vec::with_capacity(tail.len() + 2);
new_cont.push(ContNode::new(inner));
new_cont.push(ContNode::new(repeat_rest));
new_cont.extend_from_slice(tail);
if let Ok(c) = self.advance_one_byte(&new_cont, b, rest) {
return Ok(c);
}
}
Err(GrammarError::Stuck)
}
}
}
}
/// Masks every vocabulary token that `state` would reject.
///
/// For each `(token_id, token_bytes)` entry whose id indexes into `logits`,
/// the logit is set to `f32::NEG_INFINITY` when
/// [`GrammarState::allows_token`] returns `false`. Entries whose id is out of
/// range are skipped without being checked; allowed tokens keep their logit.
pub fn apply_grammar_mask(
    logits: &mut [f32],
    state: &GrammarState,
    token_vocab: &[(u32, Vec<u8>)],
) {
    for (token_id, token_bytes) in token_vocab.iter() {
        let Some(logit) = logits.get_mut(*token_id as usize) else {
            continue;
        };
        if !state.allows_token(token_bytes) {
            *logit = f32::NEG_INFINITY;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sampling::grammar::parser::Grammar;

    /// Parses `grammar_str` and returns it together with a fresh state for it.
    fn make_state(grammar_str: &str) -> (Grammar, GrammarState) {
        let grammar = Grammar::parse(grammar_str).unwrap();
        let state = GrammarState::new(grammar.clone());
        (grammar, state)
    }

    #[test]
    fn test_allows_yes_no() {
        let (_, state) = make_state(r#"root ::= "yes" | "no""#);
        for accepted in [b"yes".as_slice(), b"no".as_slice()] {
            assert!(state.allows_token(accepted));
        }
        // "maybe" diverges at once; "yes!" runs past the end of the grammar.
        for rejected in [b"maybe".as_slice(), b"yes!".as_slice()] {
            assert!(!state.allows_token(rejected));
        }
    }

    #[test]
    fn test_initial_state_not_complete() {
        let (_, state) = make_state(r#"root ::= "hello""#);
        assert!(!state.is_complete());
    }

    #[test]
    fn test_complete_after_full_match() {
        let (_, mut state) = make_state(r#"root ::= "hi""#);
        state.advance(b"hi").unwrap();
        assert!(state.is_complete());
    }

    #[test]
    fn test_partial_literal() {
        let (_, state) = make_state(r#"root ::= "hi""#);
        // A strict prefix of the literal is an acceptable token.
        assert!(state.allows_token(b"h"));
        assert!(!state.allows_token(b"x"));
    }

    #[test]
    fn test_advance_stuck_returns_error() {
        let (_, mut state) = make_state(r#"root ::= "yes""#);
        assert!(state.advance(b"no").is_err());
    }

    #[test]
    fn test_char_class() {
        let (_, state) = make_state("root ::= [a-z]+");
        assert!(state.allows_token(b"hello"));
        // Uppercase letters and digits are outside [a-z].
        assert!(!state.allows_token(b"Hello"));
        assert!(!state.allows_token(b"123"));
    }

    #[test]
    fn test_optional() {
        let (_, state) = make_state(r#"root ::= "a"? "b""#);
        assert!(state.allows_token(b"ab"));
        assert!(state.allows_token(b"b"));
        assert!(!state.allows_token(b"c"));
    }

    #[test]
    fn test_apply_grammar_mask() {
        let (_, state) = make_state(r#"root ::= "yes" | "no""#);
        let mut logits = vec![1.0f32, 2.0, 3.0, 4.0];
        let vocab: Vec<(u32, Vec<u8>)> = vec![
            (0, b"maybe".to_vec()),
            (1, b"yes".to_vec()),
            (2, b"no".to_vec()),
            (3, b"nope".to_vec()),
        ];
        apply_grammar_mask(&mut logits, &state, &vocab);
        // "maybe" and "nope" are rejected; "yes" and "no" keep their logits.
        assert_eq!(logits[0], f32::NEG_INFINITY);
        assert!(logits[1].is_finite());
        assert!(logits[2].is_finite());
        assert_eq!(logits[3], f32::NEG_INFINITY);
    }

    #[test]
    fn test_empty_token_always_allowed() {
        let (_, state) = make_state(r#"root ::= "hello""#);
        assert!(state.allows_token(b""));
    }

    #[test]
    fn test_sequence_advance() {
        let (_, mut state) = make_state(r#"root ::= "a" "b""#);
        assert!(state.allows_token(b"a"));
        state.advance(b"a").unwrap();
        // Only the second item of the sequence may follow.
        assert!(state.allows_token(b"b"));
        assert!(!state.allows_token(b"a"));
    }

    #[test]
    fn test_advance_through_rule_ref() {
        let (_, mut state) = make_state("root ::= greeting\ngreeting ::= \"hi\"");
        assert!(
            state.allows_token(b"hi"),
            "initial state should allow 'hi' via rule ref"
        );
        state
            .advance(b"hi")
            .expect("test: advancing 'hi' through rule ref should succeed");
        assert!(
            state.is_complete(),
            "state should be complete after consuming all expected bytes"
        );
    }

    #[test]
    fn test_rule_ref_allows_correct_bytes() {
        let (_, state) = make_state("root ::= num\nnum ::= [0-9]+");
        assert!(
            state.allows_token(b"42"),
            "rule ref should allow valid bytes"
        );
        assert!(
            !state.allows_token(b"abc"),
            "rule ref should reject invalid bytes"
        );
    }

    #[test]
    fn test_advance_negated_char_class() {
        let (_, mut state) = make_state("root ::= [^0-9]");
        assert!(
            state.allows_token(b"a"),
            "non-digit should be allowed by [^0-9]"
        );
        assert!(
            !state.allows_token(b"5"),
            "digit should not be allowed by [^0-9]"
        );
        state
            .advance(b"a")
            .expect("test: advancing a non-digit should succeed");
        assert!(
            state.is_complete(),
            "should be complete after consuming one [^0-9] char"
        );
    }

    #[test]
    fn test_advance_negated_char_class_rejects_digit() {
        let (_, mut state) = make_state("root ::= [^0-9]");
        assert!(
            state.advance(b"3").is_err(),
            "advancing a digit into [^0-9] should return Stuck error"
        );
    }

    #[test]
    fn test_is_complete_on_optional_grammar() {
        let (_, state) = make_state(r#"root ::= "a"?"#);
        assert!(
            state.is_complete(),
            "optional grammar should be complete in initial state"
        );
    }

    #[test]
    fn test_is_complete_on_star_grammar() {
        let (_, state) = make_state(r#"root ::= "a"*"#);
        assert!(
            state.is_complete(),
            "star grammar should be complete in initial state"
        );
    }

    #[test]
    fn test_is_not_complete_on_plus_grammar() {
        let (_, state) = make_state(r#"root ::= "a"+"#);
        assert!(
            !state.is_complete(),
            "plus grammar should NOT be complete in initial state"
        );
    }

    #[test]
    fn test_allows_very_long_token_conservatively() {
        let (_, state) = make_state(r#"root ::= "x""#);
        // One byte past the simulation cap, so the token is never simulated.
        let long_token = vec![b'z'; MAX_SIM_BYTES + 1];
        assert!(
            state.allows_token(&long_token),
            "tokens >64 bytes should be conservatively allowed"
        );
    }

    #[test]
    fn test_advance_empty_bytes_is_noop() {
        let (_, mut state) = make_state(r#"root ::= "hello""#);
        state
            .advance(b"")
            .expect("test: advancing empty bytes should succeed");
        assert!(
            !state.is_complete(),
            "state should not be complete after empty advance"
        );
        assert!(
            state.allows_token(b"hello"),
            "should still allow 'hello' after empty advance"
        );
    }

    #[test]
    fn test_apply_grammar_mask_empty_vocab() {
        let (_, state) = make_state(r#"root ::= "abc""#);
        let mut logits = vec![1.0f32, 2.0, 3.0];
        apply_grammar_mask(&mut logits, &state, &[]);
        assert_eq!(logits, vec![1.0f32, 2.0, 3.0]);
    }

    #[test]
    fn test_apply_grammar_mask_token_id_beyond_logit_len() {
        let (_, state) = make_state(r#"root ::= "yes""#);
        let mut logits = vec![1.0f32, 2.0];
        // Id 5 is out of range for `logits` and must be skipped, not indexed.
        let vocab: Vec<(u32, Vec<u8>)> = vec![(0, b"yes".to_vec()), (5, b"no".to_vec())];
        apply_grammar_mask(&mut logits, &state, &vocab);
        assert!(logits[0].is_finite(), "allowed token should not be masked");
        assert!(logits[1].is_finite(), "untouched logit should stay finite");
    }

    #[test]
    fn test_initial_state_via_grammar_method() {
        let grammar = Grammar::parse(r#"root ::= "ok""#).expect("test: should parse");
        let state = grammar.initial_state();
        assert!(
            state.allows_token(b"ok"),
            "initial state should allow matching token"
        );
        assert!(
            !state.allows_token(b"no"),
            "initial state should reject non-matching token"
        );
    }

    #[test]
    fn test_advance_with_alternation() {
        let (_, mut state) = make_state(r#"root ::= "yes" | "no""#);
        state
            .advance(b"yes")
            .expect("test: advancing 'yes' should succeed");
        assert!(
            state.is_complete(),
            "should be complete after consuming full 'yes' literal"
        );
    }

    #[test]
    fn test_advance_alternation_second_branch() {
        let (_, mut state) = make_state(r#"root ::= "yes" | "no""#);
        state
            .advance(b"no")
            .expect("test: advancing 'no' should succeed");
        assert!(
            state.is_complete(),
            "should be complete after consuming full 'no' literal"
        );
    }

    #[test]
    fn test_advance_stuck_on_char_class_mismatch() {
        let (_, mut state) = make_state("root ::= [a-z]");
        assert!(
            state.advance(b"3").is_err(),
            "advancing a digit into [a-z] should return Stuck error"
        );
    }
}