use std::sync::{Arc, Mutex};
use super::ast::Grammar;
use super::cache::AllowedTokensCache;
use super::earley::EarleyRecognizer;
use crate::constrained_decoding::TokenConstraint;
/// A token-level decoding constraint that only permits tokens whose bytes
/// keep the generated text parseable by a context-free [`Grammar`].
///
/// Holds an incremental [`EarleyRecognizer`] over the bytes consumed so far,
/// plus a precomputed per-first-byte index of the tokenizer vocabulary so
/// allowed-token masks can be built by probing only plausible tokens. Masks
/// are memoized per recognizer state in `cache`.
pub struct GrammarConstraint {
    // Retained so the grammar's lifetime is visibly tied to the constraint,
    // even though the recognizer holds its own Arc clone.
    #[allow(dead_code)]
    grammar: Arc<Grammar>,
    // Byte-level recognizer; advanced by `advance`, probed (via cloned
    // state) by `allowed_tokens`.
    recognizer: EarleyRecognizer,
    // Kept for completeness; token decodings are precomputed into
    // `token_bytes` at construction time, so this is not called afterwards.
    #[allow(dead_code)]
    tokenizer_decode_fn: Arc<dyn Fn(u32) -> Vec<u8> + Send + Sync>,
    // Vocabulary size the token index below was built for.
    vocab_size: usize,
    // Memoized allowed-token masks keyed by recognizer state hash.
    cache: Mutex<AllowedTokensCache>,
    // `token_bytes[id]` is the byte sequence token `id` decodes to.
    token_bytes: Vec<Vec<u8>>,
    // For each possible first byte value, the token ids whose decoded bytes
    // start with that byte.
    first_byte_index: Box<[Vec<u32>; 256]>,
    // Token ids that decode to an empty byte sequence (specials such as EOS
    // in the tests).
    empty_token_ids: Vec<u32>,
}
/// Boxed 256-way dispatch table: slot `b` lists the token ids whose decoded
/// byte sequence begins with byte value `b`.
type FirstByteIndex = Box<[Vec<u32>; 256]>;

/// Precomputed per-vocabulary lookup structures used by the constraint.
struct TokenIndex {
    /// Decoded byte sequence for every token id in `0..vocab_size`.
    token_bytes: Vec<Vec<u8>>,
    /// Token ids grouped by the first byte of their decoding.
    first_byte_index: FirstByteIndex,
    /// Token ids whose decoding is the empty byte sequence.
    empty_token_ids: Vec<u32>,
}

/// Decodes every token id exactly once and builds the lookup structures that
/// let `allowed_tokens` prune the vocabulary by first byte.
fn build_token_index(decode_fn: &dyn Fn(u32) -> Vec<u8>, vocab_size: usize) -> TokenIndex {
    let mut index = TokenIndex {
        token_bytes: Vec::with_capacity(vocab_size),
        first_byte_index: Box::new(std::array::from_fn(|_| Vec::new())),
        empty_token_ids: Vec::new(),
    };
    for id in 0..vocab_size as u32 {
        let bytes = decode_fn(id);
        if let Some(&first) = bytes.first() {
            index.first_byte_index[usize::from(first)].push(id);
        } else {
            index.empty_token_ids.push(id);
        }
        index.token_bytes.push(bytes);
    }
    index
}
impl GrammarConstraint {
    /// Default capacity of the allowed-token mask cache used by [`Self::new`].
    const DEFAULT_CACHE_CAPACITY: usize = 256;

    /// Builds a constraint with the default mask-cache capacity.
    ///
    /// `tokenizer_decode_fn` maps a token id to its raw bytes; it is invoked
    /// once per id in `0..vocab_size` to precompute the token index.
    pub fn new(
        grammar: Grammar,
        tokenizer_decode_fn: impl Fn(u32) -> Vec<u8> + Send + Sync + 'static,
        vocab_size: usize,
    ) -> Self {
        // Delegate so construction logic lives in exactly one place.
        Self::with_cache_capacity(
            grammar,
            tokenizer_decode_fn,
            vocab_size,
            Self::DEFAULT_CACHE_CAPACITY,
        )
    }

    /// Builds a constraint whose mask cache holds up to `capacity` entries.
    pub fn with_cache_capacity(
        mut grammar: Grammar,
        tokenizer_decode_fn: impl Fn(u32) -> Vec<u8> + Send + Sync + 'static,
        vocab_size: usize,
        capacity: usize,
    ) -> Self {
        // Normalise terminals up front so the recognizer and the byte-level
        // probing in `allowed_tokens` agree on the terminal alphabet.
        grammar.normalise_terminals();
        let grammar = Arc::new(grammar);
        let recognizer = EarleyRecognizer::new(Arc::clone(&grammar));
        let tokenizer_decode_fn: Arc<dyn Fn(u32) -> Vec<u8> + Send + Sync> =
            Arc::new(tokenizer_decode_fn);
        let idx = build_token_index(tokenizer_decode_fn.as_ref(), vocab_size);
        Self {
            grammar,
            recognizer,
            tokenizer_decode_fn,
            vocab_size,
            cache: Mutex::new(AllowedTokensCache::with_capacity(capacity)),
            token_bytes: idx.token_bytes,
            first_byte_index: idx.first_byte_index,
            empty_token_ids: idx.empty_token_ids,
        }
    }

    /// Returns `(hits, misses)` of the mask cache, or `(0, 0)` if the cache
    /// mutex has been poisoned.
    pub fn cache_stats(&self) -> (u64, u64) {
        self.cache
            .lock()
            .map(|c| (c.hits(), c.misses()))
            .unwrap_or((0, 0))
    }

    /// Number of input bytes the recognizer has consumed so far.
    pub fn bytes_consumed(&self) -> usize {
        self.recognizer.input_pos
    }

    /// Whether the recognizer can still make progress on some future input.
    pub fn is_live(&self) -> bool {
        self.recognizer.is_live()
    }

    /// Set of byte values the grammar can accept next.
    pub fn next_byte_set(&self) -> std::collections::HashSet<u8> {
        self.recognizer.next_byte_set()
    }

    /// Vocabulary size the token index was built for.
    pub fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    /// Approximate heap footprint of the precomputed token index, in bytes.
    ///
    /// Uses 24 bytes as the per-`Vec` header estimate (three words on 64-bit
    /// targets) and 4 bytes per stored `u32` token id.
    pub fn index_memory_bytes(&self) -> usize {
        let token_bytes_mem: usize = self.token_bytes.iter().map(|b| b.len() + 24).sum();
        let index_mem: usize = self.first_byte_index.iter().map(|v| v.len() * 4 + 24).sum();
        token_bytes_mem + index_mem + self.empty_token_ids.len() * 4
    }
}
impl TokenConstraint for GrammarConstraint {
    /// Computes the boolean allow-mask over `vocab_size` tokens for the
    /// current recognizer state.
    ///
    /// A token is allowed iff feeding all of its bytes keeps the recognizer
    /// live; zero-byte tokens (specials such as EOS) are allowed only when
    /// the input so far is already a complete sentence. Masks are memoized
    /// per recognizer state hash.
    fn allowed_tokens(&self, _generated: &[u32], vocab_size: usize) -> Option<Vec<bool>> {
        // A dead recognizer can never accept anything again.
        if !self.recognizer.is_live() {
            return Some(vec![false; vocab_size]);
        }
        let nbs = self.recognizer.next_byte_set();
        let currently_accepting = self.recognizer.is_accepting();
        if nbs.is_empty() && !currently_accepting {
            return Some(vec![false; vocab_size]);
        }
        let state_hash = self.recognizer.state_hash();
        if let Ok(mut cache) = self.cache.lock() {
            if let Some(cached) = cache.get(state_hash) {
                // The cache is keyed on recognizer state only, so an entry
                // built for a different `vocab_size` would have the wrong
                // length; ignore it and rebuild in that case.
                if cached.len() == vocab_size {
                    return Some(cached.to_vec());
                }
            }
        }
        let mut mask = vec![false; vocab_size];
        if currently_accepting {
            // Zero-byte tokens consume no input, so they are valid exactly
            // when the text generated so far is already accepted.
            for &id in &self.empty_token_ids {
                if (id as usize) < vocab_size {
                    mask[id as usize] = true;
                }
            }
        }
        // Only tokens whose first byte is currently acceptable can possibly
        // succeed, so probe just those buckets instead of the whole vocab.
        for &first_byte in &nbs {
            for &token_id in &self.first_byte_index[first_byte as usize] {
                let token_idx = token_id as usize;
                if token_idx >= vocab_size {
                    continue;
                }
                let bytes = &self.token_bytes[token_idx];
                if bytes.is_empty() {
                    // Defensive only: build_token_index never places empty
                    // tokens in first_byte_index, so this should be dead.
                    if currently_accepting {
                        mask[token_idx] = true;
                    }
                    continue;
                }
                // Probe on a cloned recognizer state so the real state is
                // untouched; the token is allowed iff every byte feeds.
                let mut probe = self.recognizer.clone_state();
                let mut ok = true;
                for &b in bytes {
                    if !probe.feed_byte(b) {
                        ok = false;
                        break;
                    }
                }
                if ok {
                    mask[token_idx] = true;
                }
            }
        }
        if let Ok(mut cache) = self.cache.lock() {
            cache.insert(state_hash, mask.clone());
        }
        Some(mask)
    }

    /// Feeds the token's precomputed bytes into the recognizer.
    ///
    /// Out-of-range and zero-byte tokens consume no input; they succeed
    /// exactly when the recognizer is already accepting (EOS-like behavior).
    /// Returns `false` as soon as a byte is rejected.
    fn advance(&mut self, token: u32) -> bool {
        let Some(bytes) = self.token_bytes.get(token as usize) else {
            return self.recognizer.is_accepting();
        };
        if bytes.is_empty() {
            return self.recognizer.is_accepting();
        }
        for &b in bytes {
            if !self.recognizer.feed_byte(b) {
                return false;
            }
        }
        true
    }

    /// True when the bytes consumed so far form a complete sentence of the
    /// grammar.
    fn is_complete(&self) -> bool {
        self.recognizer.is_accepting()
    }

    /// Rewinds the recognizer to the grammar's start state. The mask cache
    /// is left intact — entries are keyed by state hash, so they stay valid.
    fn reset(&mut self) {
        self.recognizer.reset();
    }

    fn name(&self) -> &str {
        "GrammarConstraint"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constrained_decoding::TokenConstraint;
    use crate::grammar::{arithmetic_grammar, csv_row_grammar, simple_ab_grammar};

    /// Constraint over a 128-token vocabulary where token `id` decodes to
    /// the single ASCII byte `id`; ids >= 128 decode to nothing.
    fn ascii_constraint(grammar: Grammar) -> GrammarConstraint {
        GrammarConstraint::new(
            grammar,
            |id| {
                if id < 128 {
                    vec![id as u8]
                } else {
                    vec![]
                }
            },
            128,
        )
    }

    #[test]
    fn grammar_constraint_name() {
        let c = ascii_constraint(arithmetic_grammar());
        assert_eq!(c.name(), "GrammarConstraint");
    }

    #[test]
    fn grammar_constraint_not_complete_initially() {
        let c = ascii_constraint(arithmetic_grammar());
        assert!(!c.is_complete());
    }

    #[test]
    fn grammar_constraint_arithmetic_allows_digits_at_start() {
        let c = ascii_constraint(arithmetic_grammar());
        let mask = c.allowed_tokens(&[], 128).unwrap();
        for d in b'0'..=b'9' {
            assert!(mask[d as usize], "digit {d} should be allowed at start");
        }
        assert!(mask[b'(' as usize], "'(' should be allowed at start");
        assert!(!mask[b'+' as usize], "'+' should not be allowed at start");
    }

    #[test]
    fn grammar_constraint_advance_digit_and_operator() {
        let mut c = ascii_constraint(arithmetic_grammar());
        assert!(c.advance(b'1' as u32), "advancing '1' should succeed");
        assert!(
            c.advance(b'+' as u32),
            "advancing '+' after '1' should succeed"
        );
    }

    #[test]
    fn grammar_constraint_advance_violation() {
        let mut c = ascii_constraint(arithmetic_grammar());
        let ok = c.advance(b'+' as u32);
        assert!(!ok, "'+' at start should be rejected");
    }

    #[test]
    fn grammar_constraint_complete_after_full_expression() {
        let mut c = ascii_constraint(arithmetic_grammar());
        c.advance(b'1' as u32);
        assert!(c.is_complete(), "single digit is a complete expression");
    }

    #[test]
    fn grammar_constraint_not_complete_after_operator() {
        let mut c = ascii_constraint(arithmetic_grammar());
        c.advance(b'1' as u32);
        c.advance(b'+' as u32);
        assert!(!c.is_complete(), "after '1+' the expression is incomplete");
    }

    #[test]
    fn grammar_constraint_reset() {
        let mut c = ascii_constraint(arithmetic_grammar());
        c.advance(b'5' as u32);
        assert!(c.is_complete());
        c.reset();
        assert!(!c.is_complete());
        assert_eq!(c.bytes_consumed(), 0);
    }

    #[test]
    fn grammar_constraint_full_sequence_1plus2() {
        let mut c = ascii_constraint(arithmetic_grammar());
        assert!(c.advance(b'1' as u32));
        assert!(c.is_complete());
        assert!(c.advance(b'+' as u32));
        assert!(!c.is_complete());
        assert!(c.advance(b'2' as u32));
        assert!(c.is_complete());
    }

    #[test]
    fn grammar_constraint_disallows_after_rejection() {
        let mut c = ascii_constraint(arithmetic_grammar());
        let ok = c.advance(b'+' as u32);
        if !ok {
            let mask = c.allowed_tokens(&[], 128).unwrap();
            assert!(
                mask.iter().all(|&b| !b),
                "all tokens should be blocked after rejection"
            );
        }
    }

    #[test]
    fn grammar_constraint_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<GrammarConstraint>();
    }

    #[test]
    fn grammar_constraint_ab_sequence() {
        let mut c = ascii_constraint(simple_ab_grammar());
        assert!(c.advance(b'a' as u32));
        assert!(!c.is_complete(), "after 'a' not yet complete");
        assert!(c.advance(b'b' as u32));
        assert!(c.is_complete(), "after 'ab' should be complete");
    }

    #[test]
    fn grammar_constraint_ab_sequence_longer() {
        let mut c = ascii_constraint(simple_ab_grammar());
        assert!(c.advance(b'a' as u32));
        assert!(c.advance(b'a' as u32));
        assert!(c.advance(b'b' as u32));
        assert!(c.advance(b'b' as u32));
        assert!(c.is_complete());
    }

    #[test]
    fn grammar_constraint_csv_row() {
        let mut c = ascii_constraint(csv_row_grammar());
        for b in b"a,b" {
            assert!(c.advance(*b as u32), "byte {b} should be accepted");
        }
        assert!(c.is_complete());
    }

    #[test]
    fn grammar_constraint_csv_row_single_field() {
        let mut c = ascii_constraint(csv_row_grammar());
        for b in b"hello" {
            assert!(c.advance(*b as u32));
        }
        assert!(c.is_complete());
    }

    #[test]
    fn grammar_constraint_implements_token_constraint_trait() {
        let c: Box<dyn TokenConstraint> = Box::new(ascii_constraint(arithmetic_grammar()));
        assert_eq!(c.name(), "GrammarConstraint");
        assert!(!c.is_complete());
    }

    /// Constraint with a 201-token vocab; token 200 decodes to nothing and
    /// plays the role of EOS.
    fn constraint_with_eos(grammar: Grammar) -> GrammarConstraint {
        GrammarConstraint::new(
            grammar,
            |id| {
                if id < 128 {
                    vec![id as u8]
                } else {
                    vec![]
                }
            },
            201,
        )
    }

    #[test]
    fn grammar_constraint_empty_token_only_when_accepting() {
        let c = constraint_with_eos(arithmetic_grammar());
        let mask = c.allowed_tokens(&[], 201).unwrap();
        assert!(
            !mask[200],
            "EOS token should not be allowed when not accepting"
        );
    }

    #[test]
    fn grammar_constraint_empty_token_allowed_when_accepting() {
        let mut c = constraint_with_eos(arithmetic_grammar());
        c.advance(b'9' as u32);
        assert!(c.is_complete());
        let mask = c.allowed_tokens(&[], 201).unwrap();
        assert!(mask[200], "EOS token should be allowed when accepting");
    }

    #[test]
    fn grammar_constraint_vocab_size_accessor() {
        let c = ascii_constraint(arithmetic_grammar());
        assert_eq!(c.vocab_size(), 128);
        let c2 = GrammarConstraint::new(arithmetic_grammar(), |id| vec![id as u8], 512);
        assert_eq!(c2.vocab_size(), 512);
    }

    #[test]
    fn grammar_constraint_index_memory_nonzero() {
        let c = ascii_constraint(arithmetic_grammar());
        assert!(
            c.index_memory_bytes() > 0,
            "index_memory_bytes must be > 0 for vocab_size > 0"
        );
    }

    #[test]
    fn grammar_constraint_index_memory_zero_vocab() {
        let c = GrammarConstraint::new(arithmetic_grammar(), |_id| vec![], 0);
        // Only the 256 empty first-byte buckets remain (24 bytes each).
        assert_eq!(c.index_memory_bytes(), 256 * 24);
    }

    #[test]
    fn grammar_constraint_digits_allowed_at_start_via_index() {
        let c = ascii_constraint(arithmetic_grammar());
        let mask = c.allowed_tokens(&[], 128).unwrap();
        for d in b'0'..=b'9' {
            assert!(
                mask[d as usize],
                "digit token {} should be allowed at start",
                d as char
            );
        }
        assert!(mask[b'(' as usize], "'(' should be allowed at start");
        assert!(!mask[b'+' as usize], "'+' not valid at start");
        assert!(!mask[b' ' as usize], "space not valid at start");
        assert!(!mask[b'z' as usize], "'z' not valid at start");
    }

    #[test]
    fn grammar_constraint_advance_uses_cached_bytes() {
        let mut c = ascii_constraint(arithmetic_grammar());
        assert!(c.advance(b'1' as u32), "'1' should advance");
        assert!(c.is_complete(), "single digit is complete");
        assert!(c.advance(b'+' as u32), "'+' should advance after digit");
        assert!(!c.is_complete(), "incomplete after '+'");
        assert!(c.advance(b'2' as u32), "'2' should advance");
        assert!(c.is_complete(), "'1+2' is a complete expression");
        assert_eq!(c.bytes_consumed(), 3, "3 bytes should have been consumed");
    }

    #[test]
    fn grammar_constraint_advance_out_of_range_token() {
        let mut c = ascii_constraint(arithmetic_grammar());
        let ok = c.advance(999);
        assert!(
            !ok,
            "out-of-range token should return false when not accepting"
        );
        let mut c2 = ascii_constraint(arithmetic_grammar());
        c2.advance(b'5' as u32);
        assert!(c2.is_complete());
        let ok2 = c2.advance(999);
        assert!(ok2, "out-of-range token should return true when accepting");
    }

    #[test]
    fn grammar_constraint_precomputed_bytes_match_decode_fn() {
        let decode_fn = |id: u32| -> Vec<u8> {
            if id < 128 {
                vec![id as u8]
            } else {
                vec![]
            }
        };
        let c = GrammarConstraint::new(arithmetic_grammar(), decode_fn, 128);
        for id in 0u32..128 {
            let precomputed = &c.token_bytes[id as usize];
            let direct = if id < 128 { vec![id as u8] } else { vec![] };
            assert_eq!(
                precomputed, &direct,
                "precomputed bytes for token {id} must match direct decode"
            );
        }
    }
}