use oxibonsai_runtime::{
constrained_decoding::TokenConstraint,
grammar::{
arithmetic_grammar, json_lite_grammar, simple_ab_grammar, AllowedTokensCache,
GrammarConstraint,
},
};
/// Decodes a token id into its byte sequence for the ASCII test vocabulary.
///
/// Ids `0..=127` decode to the single byte with the same numeric value; every
/// other id decodes to an empty byte sequence.
fn ascii_decode(id: u32) -> Vec<u8> {
    match id {
        // The arm bounds the id to seven bits, so the narrowing cast is lossless.
        0..=127 => vec![id as u8],
        _ => Vec::new(),
    }
}
/// Builds a `GrammarConstraint` for `grammar` that decodes tokens with
/// `ascii_decode` and uses a vocabulary size of 128.
fn ascii_constraint(grammar: oxibonsai_runtime::grammar::Grammar) -> GrammarConstraint {
    // One token per 7-bit id, matching the range `ascii_decode` maps to a byte.
    let vocab_size = 128;
    GrammarConstraint::new(grammar, ascii_decode, vocab_size)
}
/// Builds a `GrammarConstraint` for `grammar` like `ascii_constraint`, but
/// forwards `capacity` to `GrammarConstraint::with_cache_capacity`.
fn ascii_constraint_with_cap(
    grammar: oxibonsai_runtime::grammar::Grammar,
    capacity: usize,
) -> GrammarConstraint {
    // Same 128-entry vocabulary as the default helper; only the cache size varies.
    let vocab_size = 128;
    GrammarConstraint::with_cache_capacity(grammar, ascii_decode, vocab_size, capacity)
}
/// Querying the allowed-token mask twice without advancing must return the
/// same mask and be counted as one miss followed by one hit.
#[test]
fn cache_hit_on_repeated_state() {
    let constraint = ascii_constraint(arithmetic_grammar());

    let first_mask = constraint.allowed_tokens(&[], 128).unwrap();
    let second_mask = constraint.allowed_tokens(&[], 128).unwrap();
    assert_eq!(
        first_mask, second_mask,
        "repeated call must return identical mask"
    );

    let (hit_count, miss_count) = constraint.cache_stats();
    assert_eq!(miss_count, 1, "should be exactly one miss (first call)");
    assert_eq!(hit_count, 1, "should be exactly one hit (second call)");
}
/// Advancing the constraint by one token must make the next mask query a
/// fresh cache miss.
#[test]
fn cache_miss_on_different_states() {
    let digit_one = u32::from(b'1');
    let mut constraint = ascii_constraint(arithmetic_grammar());

    // First query at the initial state: the only miss so far.
    constraint.allowed_tokens(&[], 128).unwrap();
    let (_, misses_before) = constraint.cache_stats();
    assert_eq!(misses_before, 1);

    // Move to a new state and query again.
    constraint.advance(digit_one);
    constraint.allowed_tokens(&[digit_one], 128).unwrap();
    let (_, misses_after) = constraint.cache_stats();
    assert_eq!(misses_after, 2, "new state must produce a second miss");
}
/// With a cache capacity of 2, querying the mask at three successive states
/// (initial, after `1`, after `1+`) must record three misses and no hits.
///
/// Each `advance` is asserted so that a rejected token fails here with a clear
/// message instead of surfacing later as an unexpected miss count.
#[test]
fn lru_eviction_at_capacity() {
    let mut c = ascii_constraint_with_cap(arithmetic_grammar(), 2);

    // State 1: nothing consumed yet.
    c.allowed_tokens(&[], 128).unwrap();
    assert!(c.advance(b'1' as u32), "advancing '1' should succeed");

    // State 2: after "1".
    c.allowed_tokens(&[], 128).unwrap();
    assert!(
        c.advance(b'+' as u32),
        "advancing '+' after '1' should succeed"
    );

    // State 3: after "1+" — one more state than the cache capacity.
    c.allowed_tokens(&[], 128).unwrap();

    let (hits, misses) = c.cache_stats();
    assert_eq!(misses, 3, "each of the 3 distinct states is a miss");
    assert_eq!(hits, 0, "no state was queried twice, so no hits are expected");
}
/// Two constraints built independently by the same helper must report the
/// same allowed-token mask at the initial state.
#[test]
fn mask_correctness_arithmetic() {
    let (cached, uncached) = (
        ascii_constraint(arithmetic_grammar()),
        ascii_constraint(arithmetic_grammar()),
    );

    let mask_cached = cached.allowed_tokens(&[], 128).unwrap();
    let mask_uncached = uncached.allowed_tokens(&[], 128).unwrap();

    assert_eq!(
        mask_cached, mask_uncached,
        "cached and uncached masks must be identical at initial state"
    );
}
/// Walks two independently built constraints through `1+2` in lock-step and
/// checks that their masks agree before every step and at the final state.
#[test]
fn mask_correctness_vs_uncached_multi_step() {
    let mut cached = ascii_constraint(arithmetic_grammar());
    let mut reference = ascii_constraint(arithmetic_grammar());

    for byte in b"1+2".iter().copied() {
        let m_cached = cached.allowed_tokens(&[], 128).unwrap();
        let m_ref = reference.allowed_tokens(&[], 128).unwrap();
        assert_eq!(
            m_cached, m_ref,
            "masks differ before advancing '{}'",
            char::from(byte)
        );

        // Feed the same token to both so they stay in the same state.
        let token = u32::from(byte);
        cached.advance(token);
        reference.advance(token);
    }

    let final_cached = cached.allowed_tokens(&[], 128).unwrap();
    let final_ref = reference.allowed_tokens(&[], 128).unwrap();
    assert_eq!(final_cached, final_ref, "masks differ at final state");
}
/// Same lock-step comparison as the arithmetic multi-step test, but using
/// `simple_ab_grammar` and the input `aabb`.
#[test]
fn grammar_cache_ab_grammar() {
    let mut cached = ascii_constraint(simple_ab_grammar());
    let mut reference = ascii_constraint(simple_ab_grammar());

    for byte in b"aabb".iter().copied() {
        let cached_mask = cached.allowed_tokens(&[], 128).unwrap();
        let reference_mask = reference.allowed_tokens(&[], 128).unwrap();
        assert_eq!(
            cached_mask, reference_mask,
            "ab grammar masks differ before advancing '{}'",
            char::from(byte)
        );

        let token = u32::from(byte);
        cached.advance(token);
        reference.advance(token);
    }
}
/// For `json_lite_grammar`, two independently built constraints must agree on
/// the initial mask, and that mask must allow `{`, `[` and `"`.
#[test]
fn grammar_cache_json_lite() {
    let (cached, reference) = (
        ascii_constraint(json_lite_grammar()),
        ascii_constraint(json_lite_grammar()),
    );
    let m1 = cached.allowed_tokens(&[], 128).unwrap();
    let m2 = reference.allowed_tokens(&[], 128).unwrap();
    assert_eq!(m1, m2, "json_lite initial masks must match");

    // Token ids equal their ASCII byte value in this test vocabulary.
    let open_brace = usize::from(b'{');
    let open_bracket = usize::from(b'[');
    let double_quote = usize::from(b'"');

    assert!(
        m1[open_brace],
        "{{}} must be allowed at start of json_lite"
    );
    assert!(
        m1[open_bracket],
        "[] must be allowed at start of json_lite"
    );
    assert!(
        m1[double_quote],
        "quote must be allowed at start of json_lite"
    );
}
/// Interleaves mask queries with `advance` calls and checks that parsing
/// still progresses as expected: `1` is accepted and complete, `+`/`-` are
/// then allowed, and `1+` is accepted but incomplete.
#[test]
fn cache_does_not_affect_advance() {
    let digit = b'1';
    let plus = b'+';
    let minus = b'-';
    let mut constraint = ascii_constraint(arithmetic_grammar());

    let initial_mask = constraint.allowed_tokens(&[], 128).unwrap();
    assert!(initial_mask[usize::from(digit)]);

    assert!(
        constraint.advance(u32::from(digit)),
        "advancing '1' should succeed"
    );
    assert!(
        constraint.is_complete(),
        "single digit is a complete expression"
    );

    let mask_after_digit = constraint.allowed_tokens(&[], 128).unwrap();
    assert!(
        mask_after_digit[usize::from(plus)],
        "'+' should be allowed after '1'"
    );
    assert!(
        mask_after_digit[usize::from(minus)],
        "'-' should be allowed after '1'"
    );

    assert!(
        constraint.advance(u32::from(plus)),
        "advancing '+' after '1' should succeed"
    );
    assert!(!constraint.is_complete(), "incomplete after '1+'");
}
/// Tracks the `(hits, misses)` counters across repeated queries at one state,
/// an `advance`, and repeated queries at the next state.
#[test]
fn cache_stats_hit_count() {
    let mut constraint = ascii_constraint(arithmetic_grammar());

    // Initial state, first query: miss.
    constraint.allowed_tokens(&[], 128).unwrap();
    let (hits, misses) = constraint.cache_stats();
    assert_eq!(hits, 0);
    assert_eq!(misses, 1);

    // Initial state, second query: hit.
    constraint.allowed_tokens(&[], 128).unwrap();
    let (hits, misses) = constraint.cache_stats();
    assert_eq!(hits, 1);
    assert_eq!(misses, 1);

    // Initial state, third query: another hit.
    constraint.allowed_tokens(&[], 128).unwrap();
    let (hits, misses) = constraint.cache_stats();
    assert_eq!(hits, 2);
    assert_eq!(misses, 1);

    // New state after consuming '5', first query: miss.
    constraint.advance(u32::from(b'5'));
    constraint.allowed_tokens(&[], 128).unwrap();
    let (hits, misses) = constraint.cache_stats();
    assert_eq!(hits, 2);
    assert_eq!(misses, 2);

    // Same new state, second query: hit.
    constraint.allowed_tokens(&[], 128).unwrap();
    let (hits, misses) = constraint.cache_stats();
    assert_eq!(hits, 3);
    assert_eq!(misses, 2);
}
/// Shares one constraint between the main thread and a spawned thread behind
/// `Arc<Mutex<_>>`: the spawned thread's query is the single miss, and the
/// main thread's later query at the same state is a hit.
#[test]
fn cache_thread_safety() {
    use std::sync::{Arc, Mutex};
    use std::thread;

    let shared = Arc::new(Mutex::new(ascii_constraint(arithmetic_grammar())));

    let worker = thread::spawn({
        let shared = Arc::clone(&shared);
        move || {
            // A poisoned lock is recovered rather than treated as a failure.
            let guard = shared
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let mask = guard.allowed_tokens(&[], 128).unwrap();
            for d in b'0'..=b'9' {
                assert!(mask[d as usize], "digit {d} should be allowed");
            }
            let (_, worker_misses) = guard.cache_stats();
            worker_misses
        }
    });

    let misses_from_thread = worker.join().expect("thread should not panic");
    assert_eq!(misses_from_thread, 1, "one miss in the spawned thread");

    let guard = shared
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    guard.allowed_tokens(&[], 128).unwrap();
    let (hits, _) = guard.cache_stats();
    assert_eq!(hits, 1, "main thread call is a cache hit");
}
/// Inserts 32 keys into an `AllowedTokensCache` of capacity 8 and checks that
/// the length never exceeds the capacity, ends at exactly the capacity, and
/// that the eight most recently inserted keys (24..32) are still retrievable.
#[test]
fn cache_direct_lru_stress() {
    let cap = 8usize;
    let total_inserts = 32u64;
    let mut cache = AllowedTokensCache::with_capacity(cap);

    for key in 0..total_inserts {
        cache.insert(key, vec![true; 16]);
        assert!(
            cache.len() <= cap,
            "cache len {} exceeded capacity {cap}",
            cache.len()
        );
    }
    assert_eq!(cache.len(), cap);

    // The last `cap` inserted keys are the ones expected to remain.
    let first_surviving = total_inserts - cap as u64;
    for i in first_surviving..total_inserts {
        assert!(
            cache.get(i).is_some(),
            "key {i} should still be present in LRU"
        );
    }
}
/// A constraint built with a cache capacity of 1 must still hit on a repeated
/// query at the same state and miss after advancing to a new state.
#[test]
fn with_cache_capacity_constructor() {
    let mut constraint = ascii_constraint_with_cap(arithmetic_grammar(), 1);

    // Two queries at the initial state: one miss, then one hit.
    constraint.allowed_tokens(&[], 128).unwrap();
    constraint.allowed_tokens(&[], 128).unwrap();
    let (hits, misses) = constraint.cache_stats();
    assert_eq!(hits, 1, "second call at state 0 must be a hit");
    assert_eq!(misses, 1);

    // One query at the state after '3': a second miss, hits unchanged.
    constraint.advance(u32::from(b'3'));
    constraint.allowed_tokens(&[], 128).unwrap();
    let (hits, misses) = constraint.cache_stats();
    assert_eq!(misses, 2);
    assert_eq!(hits, 1);
    assert_eq!(hits + misses, 3, "total calls must equal hits + misses");
}
/// At the start of the arithmetic grammar the mask must allow every digit and
/// `(`, and must reject each byte of `+-*/) az\n`.
#[test]
fn test_byte_index_skips_non_matching_tokens() {
    let constraint = ascii_constraint(arithmetic_grammar());
    let mask = constraint.allowed_tokens(&[], 128).unwrap();

    for digit in b'0'..=b'9' {
        assert!(
            mask[usize::from(digit)],
            "digit token '{}' (0x{:02x}) should be allowed at start",
            char::from(digit),
            digit
        );
    }
    assert!(mask[usize::from(b'(')], "'(' should be allowed at start");

    for rejected in b"+-*/) az\n".iter().copied() {
        assert!(
            !mask[usize::from(rejected)],
            "byte '{}' (0x{:02x}) should not be allowed at start of arithmetic",
            char::from(rejected),
            rejected
        );
    }
}
/// For every token id in `0..128`, the initial arithmetic mask must allow the
/// token exactly when the first byte produced by an independent copy of the
/// decode logic is an ASCII digit or `(`.
#[test]
fn test_precomputed_bytes_match_decode_fn() {
    use oxibonsai_runtime::grammar::GrammarConstraint;

    // Independent re-statement of the decode rule, used as the oracle.
    let decode_fn_direct = |id: u32| -> Vec<u8> {
        if id < 128 {
            vec![id as u8]
        } else {
            vec![]
        }
    };

    let c = GrammarConstraint::new(arithmetic_grammar(), ascii_decode, 128);
    assert_eq!(c.vocab_size(), 128);

    // The constraint is never advanced in this test, so the mask is the same
    // for every id; compute it once instead of once per loop iteration.
    let mask = c.allowed_tokens(&[], 128).unwrap();

    for id in 0u32..128 {
        let direct = decode_fn_direct(id);
        let allowed = mask[id as usize];
        let first_byte_opt = direct.first().copied();
        let should_be_allowed = match first_byte_opt {
            Some(b) => b.is_ascii_digit() || b == b'(',
            None => false,
        };
        assert_eq!(
            allowed, should_be_allowed,
            "token {id}: allowed={allowed} but expected={should_be_allowed} \
             (first_byte={first_byte_opt:?})"
        );
    }
}
/// `index_memory_bytes()` for a 128-token constraint must be non-zero and at
/// least 9344.
///
/// NOTE(review): the derivation of the 9344 lower bound is not visible in this
/// file — confirm against the index layout before changing it.
#[test]
fn test_index_memory_usage_nonzero() {
    use oxibonsai_runtime::grammar::GrammarConstraint;

    let constraint = GrammarConstraint::new(arithmetic_grammar(), ascii_decode, 128);
    let mem = constraint.index_memory_bytes();

    // Checked separately so a zero result gets its own, more specific message.
    assert!(
        mem > 0,
        "index_memory_bytes() must be > 0 for vocab_size=128, got {mem}"
    );
    assert!(
        mem >= 9344,
        "index_memory_bytes() = {mem} is below the expected lower bound of 9344"
    );
}
/// `vocab_size()` must return the vocabulary size passed at construction, for
/// both constructors and including a size of zero.
#[test]
fn test_vocab_size_accessor() {
    use oxibonsai_runtime::grammar::GrammarConstraint;

    let ascii_sized = GrammarConstraint::new(arithmetic_grammar(), ascii_decode, 128);
    assert_eq!(ascii_sized.vocab_size(), 128);

    let large = GrammarConstraint::new(arithmetic_grammar(), ascii_decode, 4096);
    assert_eq!(large.vocab_size(), 4096);

    // Built through the explicit-cache-capacity constructor.
    let custom_cache =
        GrammarConstraint::with_cache_capacity(arithmetic_grammar(), ascii_decode, 512, 64);
    assert_eq!(custom_cache.vocab_size(), 512);

    let empty = GrammarConstraint::new(arithmetic_grammar(), ascii_decode, 0);
    assert_eq!(empty.vocab_size(), 0);
}
/// Feeds `1+2*3` to two independently built constraints and checks after each
/// token that `advance`, `is_complete`, `bytes_consumed` and `next_byte_set`
/// agree; both must end complete with five bytes consumed.
#[test]
fn test_advance_uses_cached_bytes() {
    let mut c1 = ascii_constraint(arithmetic_grammar());
    let mut c2 = ascii_constraint(arithmetic_grammar());

    for (i, byte) in b"1+2*3".iter().copied().enumerate() {
        let token = u32::from(byte);
        let ok1 = c1.advance(token);
        let ok2 = c2.advance(token);
        assert_eq!(
            ok1, ok2,
            "advance({}) disagreed at step {i}: c1={ok1} c2={ok2}",
            char::from(byte)
        );
        assert_eq!(
            c1.is_complete(),
            c2.is_complete(),
            "is_complete() disagreed after step {i}"
        );
        assert_eq!(
            c1.bytes_consumed(),
            c2.bytes_consumed(),
            "bytes_consumed() disagreed after step {i}"
        );
        assert_eq!(
            c1.next_byte_set(),
            c2.next_byte_set(),
            "next_byte_set() disagreed after step {i}"
        );
    }

    // Final state: the whole expression has been consumed by both.
    assert!(c1.is_complete(), "c1 should be complete after '1+2*3'");
    assert!(c2.is_complete(), "c2 should be complete after '1+2*3'");
    assert_eq!(c1.bytes_consumed(), 5);
    assert_eq!(c2.bytes_consumed(), 5);
}