oxibonsai_runtime/grammar/constraint.rs
//! [`GrammarConstraint`] — implements [`TokenConstraint`] using the Earley
//! chart-parser recognizer backed by a BNF context-free grammar.
//!
//! The `allowed_tokens` method speculatively feeds each token's byte sequence
//! through a **clone** of the current recognizer state and marks the token
//! allowed if and only if none of the bytes are rejected.
//!
//! **Phase 16B optimization:** Token byte sequences are precomputed once during
//! construction via `tokenizer_decode_fn` and stored in `token_bytes: Vec<Vec<u8>>`.
//! A `first_byte_index: Box<[Vec<u32>; 256]>` maps each first byte to the list of
//! token IDs that start with that byte. During `allowed_tokens`, only tokens whose
//! first byte is in `next_byte_set` are probed — all others are skipped without
//! invoking the decode function or any recognizer cloning.
//!
//! This reduces the per-step work from O(vocab) decode calls + O(vocab) first-byte
//! checks + O(filtered_vocab × token_len) recognizer probes to just
//! O(|next_byte_set| × avg_matching_tokens × avg_token_len) recognizer probes.
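//!
//! A minimal end-to-end sketch (assumes `arithmetic_grammar` and the
//! `TokenConstraint` trait are publicly importable, as in the examples below):
//!
//! ```rust,no_run
//! use oxibonsai_runtime::constrained_decoding::TokenConstraint;
//! use oxibonsai_runtime::grammar::{arithmetic_grammar, GrammarConstraint};
//!
//! // Byte-level vocab: token id == ASCII code point.
//! let constraint = GrammarConstraint::new(arithmetic_grammar(), |id| vec![id as u8], 128);
//! let mask = constraint.allowed_tokens(&[], 128).unwrap();
//! assert!(mask[b'7' as usize]); // a digit may start an expression
//! assert!(!mask[b'+' as usize]); // an operator may not
//! ```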

use std::sync::{Arc, Mutex};

use super::ast::Grammar;
use super::cache::AllowedTokensCache;
use super::earley::EarleyRecognizer;
use crate::constrained_decoding::TokenConstraint;

// ─────────────────────────────────────────────────────────────────────────────
// GrammarConstraint
// ─────────────────────────────────────────────────────────────────────────────

/// A [`TokenConstraint`] that enforces a context-free grammar on the generated
/// byte stream, using the Earley chart-parser as the underlying recognizer.
///
/// # Construction
///
/// ```rust,no_run
/// use oxibonsai_runtime::grammar::{arithmetic_grammar, GrammarConstraint};
///
/// let grammar = arithmetic_grammar();
/// // Map each token id to its byte sequence; single-byte ASCII vocab here.
/// let decode_fn = |token_id: u32| -> Vec<u8> {
///     if token_id < 128 { vec![token_id as u8] } else { vec![] }
/// };
/// let constraint = GrammarConstraint::new(grammar, decode_fn, 128);
/// ```
///
/// # Token decode function
///
/// The `tokenizer_decode_fn` maps a token id to the **byte sequence** it
/// represents. For an ASCII byte-level vocabulary it is simply
/// `|id| vec![id as u8]`. For a real LLM tokenizer it should call into
/// `tokenizer.id_to_bytes(id)`. Unknown / special tokens can return an empty
/// `Vec<u8>`; they will be allowed iff the current recognizer state is
/// accepting (which allows a graceful end-of-sequence).
///
/// # Phase 16B: Precomputed byte index
///
/// At construction time, `GrammarConstraint` eagerly calls `tokenizer_decode_fn`
/// for every token ID in `0..vocab_size`, storing the results in `token_bytes`.
/// Simultaneously, `first_byte_index[b]` accumulates the list of token IDs whose
/// first byte is `b`, and `empty_token_ids` collects IDs with empty byte sequences
/// (EOS, padding, special tokens).
///
/// This eliminates O(vocab) decode calls during each `allowed_tokens` call and
/// allows the inner loop to skip entire byte classes not present in
/// `next_byte_set` — often reducing the probed token count by 90–99%.
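///
/// For the byte-level ASCII vocab above the index is trivially diagonal:
/// `first_byte_index[b'7' as usize]` holds exactly `[55]`, and
/// `empty_token_ids` is empty. A real tokenizer vocabulary puts many token
/// IDs into each first-byte bucket.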
pub struct GrammarConstraint {
    /// Original grammar (kept for potential future reset/inspection).
    #[allow(dead_code)]
    grammar: Arc<Grammar>,
    /// Live Earley recognizer tracking the bytes generated so far.
    recognizer: EarleyRecognizer,
    /// Decodes a token id to its raw byte sequence.
    ///
    /// Retained for potential out-of-range token handling or future callers that
    /// need to decode tokens not covered by the initial `0..vocab_size` range.
    #[allow(dead_code)]
    tokenizer_decode_fn: Arc<dyn Fn(u32) -> Vec<u8> + Send + Sync>,
    /// Total vocabulary size used to allocate the precomputed index.
    vocab_size: usize,
    /// LRU memoization cache for `allowed_tokens` results keyed by Earley state hash.
    ///
    /// Wrapped in `Mutex` because `TokenConstraint::allowed_tokens` takes `&self`,
    /// yet cache mutation requires `&mut`. A poisoned `Mutex` (a panic occurred
    /// while the lock was held) is handled gracefully: lookups degrade to cache
    /// misses and nothing ever panics.
    cache: Mutex<AllowedTokensCache>,

    // ── Phase 16B: Precomputed token index ──────────────────────────────────
    /// Precomputed byte sequences for every token in `0..vocab_size`.
    ///
    /// `token_bytes[id]` is the byte sequence for token `id`, precomputed once
    /// at construction time. This is the primary data consumed by `allowed_tokens`
    /// and `advance`.
    token_bytes: Vec<Vec<u8>>,

    /// First-byte index: `first_byte_index[b]` is the list of token IDs
    /// (in `0..vocab_size`) whose first byte equals `b`.
    ///
    /// Boxed so the 256 `Vec<u32>` headers (256 × 24 bytes ≈ 6 KB) live on the
    /// heap rather than inflating the struct and every stack frame that moves it.
    first_byte_index: Box<[Vec<u32>; 256]>,

    /// Token IDs (in `0..vocab_size`) whose byte sequence is empty.
    ///
    /// These represent EOS tokens, padding tokens, and other special tokens
    /// that do not contribute bytes to the grammar stream. They are allowed
    /// only when the recognizer is in an accepting state.
    empty_token_ids: Vec<u32>,
}

// ─────────────────────────────────────────────────────────────────────────────
// Private construction helper
// ─────────────────────────────────────────────────────────────────────────────

/// Type alias for the first-byte index (avoids `clippy::type_complexity`).
type FirstByteIndex = Box<[Vec<u32>; 256]>;

/// Aggregate result of `build_token_index`.
struct TokenIndex {
    token_bytes: Vec<Vec<u8>>,
    first_byte_index: FirstByteIndex,
    empty_token_ids: Vec<u32>,
}

/// Build the three precomputed structures from a decode function and vocab size.
fn build_token_index(decode_fn: &dyn Fn(u32) -> Vec<u8>, vocab_size: usize) -> TokenIndex {
    let mut token_bytes: Vec<Vec<u8>> = Vec::with_capacity(vocab_size);

    // Use a Vec<Vec<u32>> of length 256 to avoid constructing 256 Vecs on the
    // stack before boxing — the std::array::from_fn approach would stack-allocate
    // [Vec<u32>; 256] ≈ 6 KB (256 × 24-byte headers), which is usually fine, but
    // building it in a Vec and converting avoids any platform-specific stack pressure.
    let mut raw_index: Vec<Vec<u32>> = (0..256_usize).map(|_| Vec::new()).collect();
    let mut empty_token_ids: Vec<u32> = Vec::new();

    for id in 0..vocab_size as u32 {
        let bytes = decode_fn(id);
        match bytes.first() {
            Some(&b) => raw_index[b as usize].push(id),
            None => empty_token_ids.push(id),
        }
        token_bytes.push(bytes);
    }

    // Convert Vec<Vec<u32>> (length 256) into Box<[Vec<u32>; 256]>.
    // We built `raw_index` with exactly 256 elements, so the try_into cannot fail.
    let first_byte_index: FirstByteIndex = raw_index
        .into_boxed_slice()
        .try_into()
        .expect("raw_index must have exactly 256 elements");

    TokenIndex {
        token_bytes,
        first_byte_index,
        empty_token_ids,
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// Public API
// ─────────────────────────────────────────────────────────────────────────────

impl GrammarConstraint {
    /// Create a new `GrammarConstraint`.
    ///
    /// The `grammar` is normalised (multi-byte terminals split into chains)
    /// and wrapped in an `Arc` before being handed to the recognizer.
    ///
    /// **Phase 16B:** This eagerly calls `tokenizer_decode_fn(id)` for every
    /// `id` in `0..vocab_size`, building `token_bytes` and `first_byte_index`.
    /// Construction cost is O(vocab_size × avg_decode_cost); subsequent
    /// `allowed_tokens` calls no longer call the decode function at all.
    ///
    /// # Parameters
    ///
    /// * `grammar` — the context-free grammar to enforce
    /// * `tokenizer_decode_fn` — maps token id → byte sequence
    /// * `vocab_size` — total vocabulary size
    pub fn new(
        grammar: Grammar,
        tokenizer_decode_fn: impl Fn(u32) -> Vec<u8> + Send + Sync + 'static,
        vocab_size: usize,
    ) -> Self {
        // Delegate to `with_cache_capacity` with the default 256-entry cache so
        // the normalisation / index-building logic lives in exactly one place.
        Self::with_cache_capacity(grammar, tokenizer_decode_fn, vocab_size, 256)
    }

    /// Create a new `GrammarConstraint` with a custom cache capacity.
    ///
    /// Identical to [`new`](Self::new) except that the LRU cache is initialised
    /// with `capacity` entries rather than the default 256. Use a larger value
    /// when the grammar has many distinct parse states; use a smaller value to
    /// bound memory at the cost of more cache misses.
    ///
    /// **Phase 16B:** Same eager precomputation as [`new`](Self::new).
    ///
    /// # Parameters
    ///
    /// * `grammar` — the context-free grammar to enforce
    /// * `tokenizer_decode_fn` — maps token id → byte sequence
    /// * `vocab_size` — total vocabulary size
    /// * `capacity` — LRU cache capacity (clamped to ≥ 1)
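    ///
    /// # Example
    ///
    /// A minimal sketch (same byte-level vocab as the constructor example in
    /// the struct docs):
    ///
    /// ```rust,no_run
    /// use oxibonsai_runtime::grammar::{arithmetic_grammar, GrammarConstraint};
    ///
    /// // 1024-entry LRU cache instead of the default 256.
    /// let constraint = GrammarConstraint::with_cache_capacity(
    ///     arithmetic_grammar(),
    ///     |id| if id < 128 { vec![id as u8] } else { vec![] },
    ///     128,
    ///     1024,
    /// );
    /// ```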
    pub fn with_cache_capacity(
        mut grammar: Grammar,
        tokenizer_decode_fn: impl Fn(u32) -> Vec<u8> + Send + Sync + 'static,
        vocab_size: usize,
        capacity: usize,
    ) -> Self {
        grammar.normalise_terminals();
        let grammar = Arc::new(grammar);
        let recognizer = EarleyRecognizer::new(Arc::clone(&grammar));
        let tokenizer_decode_fn: Arc<dyn Fn(u32) -> Vec<u8> + Send + Sync> =
            Arc::new(tokenizer_decode_fn);

        let idx = build_token_index(tokenizer_decode_fn.as_ref(), vocab_size);

        Self {
            grammar,
            recognizer,
            tokenizer_decode_fn,
            vocab_size,
            cache: Mutex::new(AllowedTokensCache::with_capacity(capacity)),
            token_bytes: idx.token_bytes,
            first_byte_index: idx.first_byte_index,
            empty_token_ids: idx.empty_token_ids,
        }
    }

    /// Return cache hit/miss statistics as `(hits, misses)`.
    ///
    /// Useful for testing and for monitoring cache effectiveness in production.
    /// Returns `(0, 0)` if the internal `Mutex` has been poisoned (never panics).
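    ///
    /// A minimal sketch (assumes the `TokenConstraint` trait is publicly
    /// importable, and that the cache counts one miss then one hit for two
    /// lookups from the same Earley state):
    ///
    /// ```rust,no_run
    /// use oxibonsai_runtime::constrained_decoding::TokenConstraint;
    /// use oxibonsai_runtime::grammar::{arithmetic_grammar, GrammarConstraint};
    ///
    /// let c = GrammarConstraint::new(arithmetic_grammar(), |id| vec![id as u8], 128);
    /// let _ = c.allowed_tokens(&[], 128); // first call: cache miss
    /// let _ = c.allowed_tokens(&[], 128); // same Earley state: cache hit
    /// let (hits, misses) = c.cache_stats();
    /// assert!(hits >= 1 && misses >= 1);
    /// ```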
    pub fn cache_stats(&self) -> (u64, u64) {
        self.cache
            .lock()
            .map(|c| (c.hits(), c.misses()))
            .unwrap_or((0, 0))
    }

    /// Return the current number of bytes consumed by the recognizer.
    pub fn bytes_consumed(&self) -> usize {
        self.recognizer.input_pos
    }

    /// Return `true` if the recognizer is still in a live (non-dead) state.
    pub fn is_live(&self) -> bool {
        self.recognizer.is_live()
    }

    /// Return the set of bytes valid as the next byte in the stream.
    ///
    /// This is a low-level utility; prefer `allowed_tokens` for normal use.
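    ///
    /// A minimal sketch (byte-level ASCII vocab; the arithmetic grammar starts
    /// with a digit or `'('`, as the unit tests below confirm):
    ///
    /// ```rust,no_run
    /// use oxibonsai_runtime::grammar::{arithmetic_grammar, GrammarConstraint};
    ///
    /// let c = GrammarConstraint::new(arithmetic_grammar(), |id| vec![id as u8], 128);
    /// let nbs = c.next_byte_set();
    /// assert!(nbs.contains(&b'3') && nbs.contains(&b'('));
    /// assert!(!nbs.contains(&b'+'));
    /// ```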
    pub fn next_byte_set(&self) -> std::collections::HashSet<u8> {
        self.recognizer.next_byte_set()
    }

    /// Return the vocabulary size passed to the constructor.
    ///
    /// This equals `self.token_bytes.len()`.
    pub fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    /// Return an estimate of the heap memory (in bytes) occupied by the
    /// precomputed token index built during construction.
    ///
    /// The estimate accounts for:
    /// * `token_bytes`: 24-byte `Vec` header + heap-allocated byte storage per
    ///   token.
    /// * `first_byte_index`: 24-byte `Vec` header + 4-byte u32 per entry,
    ///   for all 256 first-byte buckets.
    /// * `empty_token_ids`: 4 bytes per entry.
    ///
    /// This is a lower bound (it excludes allocator overhead and padding).
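    ///
    /// Worked example for the 128-token byte-level ASCII vocab used in the
    /// docs above (one byte per token, so no empty tokens):
    /// `token_bytes` = 128 × (24 + 1) = 3200, `first_byte_index` =
    /// 256 × 24 + 128 × 4 = 6656, `empty_token_ids` = 0; total 9856 bytes.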
    pub fn index_memory_bytes(&self) -> usize {
        // 24 = size_of::<Vec<u8>>() on 64-bit platforms (ptr + len + cap).
        let token_bytes_mem: usize = self.token_bytes.iter().map(|b| b.len() + 24).sum();
        // 24 = size_of::<Vec<u32>>(); 4 = size_of::<u32>().
        let index_mem: usize = self.first_byte_index.iter().map(|v| v.len() * 4 + 24).sum();
        token_bytes_mem + index_mem + self.empty_token_ids.len() * 4
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// TokenConstraint implementation
// ─────────────────────────────────────────────────────────────────────────────

impl TokenConstraint for GrammarConstraint {
    /// Compute a per-token mask using the precomputed first-byte index.
    ///
    /// **Phase 16B algorithm:**
    ///
    /// 1. If the recognizer is dead, return all-false immediately.
    /// 2. Compute `next_byte_set` (NBS) and `is_accepting`.
    /// 3. If NBS is empty and not accepting, return all-false immediately.
    /// 4. Check the LRU cache keyed by `state_hash()`.
    /// 5. On cache miss: start with an all-false mask.
    ///    * For each `first_byte` in NBS, iterate `first_byte_index[first_byte]`
    ///      and probe only those tokens via `recognizer.clone_state()`.
    ///    * For empty-byte tokens (EOS/special), allow them iff `is_accepting`.
    /// 6. Insert the result into the LRU cache.
    ///
    /// The inner loop never calls `tokenizer_decode_fn` — it reads precomputed
    /// `token_bytes` instead. Tokens whose first byte is NOT in NBS are never
    /// visited at all.
    fn allowed_tokens(&self, _generated: &[u32], vocab_size: usize) -> Option<Vec<bool>> {
        // ── Early exits ─────────────────────────────────────────────────────
        if !self.recognizer.is_live() {
            return Some(vec![false; vocab_size]);
        }

        let nbs = self.recognizer.next_byte_set();
        let currently_accepting = self.recognizer.is_accepting();

        if nbs.is_empty() && !currently_accepting {
            return Some(vec![false; vocab_size]);
        }

        // ── Cache lookup ─────────────────────────────────────────────────────
        let state_hash = self.recognizer.state_hash();
        if let Ok(mut cache) = self.cache.lock() {
            if let Some(cached) = cache.get(state_hash) {
                return Some(cached.to_vec());
            }
        }

        // ── Cache miss: build mask using first-byte index ────────────────────
        let mut mask = vec![false; vocab_size];

        // Empty-byte tokens (EOS, special): allowed only when accepting.
        if currently_accepting {
            for &id in &self.empty_token_ids {
                if (id as usize) < vocab_size {
                    mask[id as usize] = true;
                }
            }
        }

        // Tokens grouped by first byte: iterate only over bytes that are in NBS.
        for &first_byte in &nbs {
            for &token_id in &self.first_byte_index[first_byte as usize] {
                let token_idx = token_id as usize;
                if token_idx >= vocab_size {
                    continue;
                }
                let bytes = &self.token_bytes[token_idx];
                if bytes.is_empty() {
                    // Should not happen (empties are in empty_token_ids), but
                    // handle defensively.
                    if currently_accepting {
                        mask[token_idx] = true;
                    }
                    continue;
                }
                // `bytes[0] == first_byte` by construction, so no membership
                // re-check is needed. The full sequence (first byte included)
                // is fed through a cloned recognizer state below.
                let mut probe = self.recognizer.clone_state();
                let mut ok = true;
                for &b in bytes {
                    if !probe.feed_byte(b) {
                        ok = false;
                        break;
                    }
                }
                if ok {
                    mask[token_idx] = true;
                }
            }
        }

        // ── Store in cache ───────────────────────────────────────────────────
        if let Ok(mut cache) = self.cache.lock() {
            cache.insert(state_hash, mask.clone());
        }

        Some(mask)
    }

    /// Commit `token` to the recognizer by feeding its precomputed byte sequence.
    ///
    /// Uses the precomputed `token_bytes` slice instead of calling
    /// `tokenizer_decode_fn`, avoiding one decode call per accepted token.
    ///
    /// Returns `false` if any byte in the token's sequence is rejected by the
    /// grammar. A token ID that is out of range for the precomputed index is
    /// treated like an empty-byte token: allowed only when the recognizer is
    /// already in an accepting state.
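    ///
    /// A minimal sketch (byte-level ASCII vocab; mirrors the unit tests below):
    ///
    /// ```rust,no_run
    /// use oxibonsai_runtime::constrained_decoding::TokenConstraint;
    /// use oxibonsai_runtime::grammar::{arithmetic_grammar, GrammarConstraint};
    ///
    /// let mut c = GrammarConstraint::new(arithmetic_grammar(), |id| vec![id as u8], 128);
    /// assert!(c.advance(b'1' as u32)); // '1' is a valid start
    /// assert!(c.advance(b'+' as u32)); // '+' is valid after a digit
    /// assert!(!c.is_complete());       // "1+" is not a complete expression
    /// ```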
    fn advance(&mut self, token: u32) -> bool {
        let Some(bytes) = self.token_bytes.get(token as usize) else {
            // Token ID is beyond the precomputed vocab range.
            // Treat as empty → allowed only if currently accepting.
            return self.recognizer.is_accepting();
        };
        if bytes.is_empty() {
            return self.recognizer.is_accepting();
        }
        for &b in bytes {
            if !self.recognizer.feed_byte(b) {
                return false;
            }
        }
        true
    }

    /// Returns `true` when the recognizer is in an accepting state.
    fn is_complete(&self) -> bool {
        self.recognizer.is_accepting()
    }

    /// Reset the recognizer to the initial state.
    fn reset(&mut self) {
        self.recognizer.reset();
    }

    fn name(&self) -> &str {
        "GrammarConstraint"
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// Unit tests
// ─────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use crate::constrained_decoding::TokenConstraint;
    use crate::grammar::{arithmetic_grammar, csv_row_grammar, simple_ab_grammar};

    // ── Minimal ASCII byte-level vocab helper ───────────────────────────────

    /// Build a `GrammarConstraint` with a simple byte-level vocabulary
    /// where token id == ASCII code point (0..128).
    fn ascii_constraint(grammar: Grammar) -> GrammarConstraint {
        GrammarConstraint::new(
            grammar,
            |id| {
                if id < 128 {
                    vec![id as u8]
                } else {
                    vec![]
                }
            },
            128,
        )
    }

    // ── Arithmetic grammar ──────────────────────────────────────────────────

    #[test]
    fn grammar_constraint_name() {
        let c = ascii_constraint(arithmetic_grammar());
        assert_eq!(c.name(), "GrammarConstraint");
    }

    #[test]
    fn grammar_constraint_not_complete_initially() {
        let c = ascii_constraint(arithmetic_grammar());
        assert!(!c.is_complete());
    }

    #[test]
    fn grammar_constraint_arithmetic_allows_digits_at_start() {
        let c = ascii_constraint(arithmetic_grammar());
        let mask = c.allowed_tokens(&[], 128).unwrap();
        for d in b'0'..=b'9' {
            assert!(mask[d as usize], "digit {d} should be allowed at start");
        }
        assert!(mask[b'(' as usize], "'(' should be allowed at start");
        assert!(!mask[b'+' as usize], "'+' should not be allowed at start");
    }

    #[test]
    fn grammar_constraint_advance_digit_and_operator() {
        let mut c = ascii_constraint(arithmetic_grammar());
        assert!(c.advance(b'1' as u32), "advancing '1' should succeed");
        assert!(
            c.advance(b'+' as u32),
            "advancing '+' after '1' should succeed"
        );
    }

    #[test]
    fn grammar_constraint_advance_violation() {
        let mut c = ascii_constraint(arithmetic_grammar());
        let ok = c.advance(b'+' as u32);
        assert!(!ok, "'+' at start should be rejected");
    }

    #[test]
    fn grammar_constraint_complete_after_full_expression() {
        let mut c = ascii_constraint(arithmetic_grammar());
        c.advance(b'1' as u32);
        assert!(c.is_complete(), "single digit is a complete expression");
    }

    #[test]
    fn grammar_constraint_not_complete_after_operator() {
        let mut c = ascii_constraint(arithmetic_grammar());
        c.advance(b'1' as u32);
        c.advance(b'+' as u32);
        assert!(!c.is_complete(), "after '1+' the expression is incomplete");
    }

    #[test]
    fn grammar_constraint_reset() {
        let mut c = ascii_constraint(arithmetic_grammar());
        c.advance(b'5' as u32);
        assert!(c.is_complete());
        c.reset();
        assert!(!c.is_complete());
        assert_eq!(c.bytes_consumed(), 0);
    }

    #[test]
    fn grammar_constraint_full_sequence_1plus2() {
        let mut c = ascii_constraint(arithmetic_grammar());
        assert!(c.advance(b'1' as u32));
        assert!(c.is_complete());
        assert!(c.advance(b'+' as u32));
        assert!(!c.is_complete());
        assert!(c.advance(b'2' as u32));
        assert!(c.is_complete());
    }

    #[test]
    fn grammar_constraint_disallows_after_rejection() {
        let mut c = ascii_constraint(arithmetic_grammar());
        let ok = c.advance(b'+' as u32);
        // After a rejection the recognizer is dead.
        if !ok {
            let mask = c.allowed_tokens(&[], 128).unwrap();
            assert!(
                mask.iter().all(|&b| !b),
                "all tokens should be blocked after rejection"
            );
        }
    }

    #[test]
    fn grammar_constraint_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<GrammarConstraint>();
    }

    // ── Simple a^n b^n grammar ──────────────────────────────────────────────

    #[test]
    fn grammar_constraint_ab_sequence() {
        let mut c = ascii_constraint(simple_ab_grammar());
        // "ab" should be accepted.
        assert!(c.advance(b'a' as u32));
        assert!(!c.is_complete(), "after 'a' not yet complete");
        assert!(c.advance(b'b' as u32));
        assert!(c.is_complete(), "after 'ab' should be complete");
    }

    #[test]
    fn grammar_constraint_ab_sequence_longer() {
        let mut c = ascii_constraint(simple_ab_grammar());
        // "aabb" should be accepted.
        assert!(c.advance(b'a' as u32));
        assert!(c.advance(b'a' as u32));
        assert!(c.advance(b'b' as u32));
        assert!(c.advance(b'b' as u32));
        assert!(c.is_complete());
    }

    // ── CSV grammar ─────────────────────────────────────────────────────────

    #[test]
    fn grammar_constraint_csv_row() {
        let mut c = ascii_constraint(csv_row_grammar());
        // "a,b" is a valid two-field CSV row.
        for b in b"a,b" {
            assert!(c.advance(*b as u32), "byte {b} should be accepted");
        }
        assert!(c.is_complete());
    }

    #[test]
    fn grammar_constraint_csv_row_single_field() {
        let mut c = ascii_constraint(csv_row_grammar());
        for b in b"hello" {
            assert!(c.advance(*b as u32));
        }
        assert!(c.is_complete());
    }

    // ── Trait object safety ─────────────────────────────────────────────────

    #[test]
    fn grammar_constraint_implements_token_constraint_trait() {
        let c: Box<dyn TokenConstraint> = Box::new(ascii_constraint(arithmetic_grammar()));
        assert_eq!(c.name(), "GrammarConstraint");
        assert!(!c.is_complete());
    }

    // ── Empty byte token ────────────────────────────────────────────────────

    #[test]
    fn grammar_constraint_empty_token_only_when_accepting() {
        // Build a vocab where token 200 maps to empty bytes (special token).
        let g = arithmetic_grammar();
        let c = GrammarConstraint::new(
            g,
            |id| {
                if id < 128 {
                    vec![id as u8]
                } else {
                    vec![] // id == 200 is EOS; all non-ASCII ids map to empty
                }
            },
            201,
        );

        // Initially not accepting, so token 200 should be blocked.
        let mask = c.allowed_tokens(&[], 201).unwrap();
        assert!(
            !mask[200],
            "EOS token should not be allowed when not accepting"
        );
    }

    #[test]
    fn grammar_constraint_empty_token_allowed_when_accepting() {
        let g = arithmetic_grammar();
        let mut c = GrammarConstraint::new(
            g,
            |id| {
                if id < 128 {
                    vec![id as u8]
                } else {
                    vec![] // id == 200 is EOS; all non-ASCII ids map to empty
                }
            },
            201,
        );

        // After generating "9" (a complete expression) we are accepting.
        c.advance(b'9' as u32);
        assert!(c.is_complete());

        let mask = c.allowed_tokens(&[], 201).unwrap();
        assert!(mask[200], "EOS token should be allowed when accepting");
    }

    // ── Phase 16B: vocab_size accessor ──────────────────────────────────────

    #[test]
    fn grammar_constraint_vocab_size_accessor() {
        let c = ascii_constraint(arithmetic_grammar());
        assert_eq!(c.vocab_size(), 128);

        let c2 = GrammarConstraint::new(arithmetic_grammar(), |id| vec![id as u8], 512);
        assert_eq!(c2.vocab_size(), 512);
    }

    // ── Phase 16B: index_memory_bytes ───────────────────────────────────────

    #[test]
    fn grammar_constraint_index_memory_nonzero() {
        let c = ascii_constraint(arithmetic_grammar());
        assert!(
            c.index_memory_bytes() > 0,
            "index_memory_bytes must be > 0 for vocab_size > 0"
        );
    }

    #[test]
    fn grammar_constraint_index_memory_zero_vocab() {
        // vocab_size == 0 → token_bytes is empty, but first_byte_index still
        // holds 256 empty Vecs (each 24 bytes header).
        let c = GrammarConstraint::new(arithmetic_grammar(), |_id| vec![], 0);
        // 256 empty Vec<u32> × 24 bytes each = 6144 bytes minimum.
        assert_eq!(c.index_memory_bytes(), 256 * 24);
    }

    // ── Phase 16B: first-byte index correctness ──────────────────────────────

    #[test]
    fn grammar_constraint_digits_allowed_at_start_via_index() {
        // The arithmetic grammar starts with digits and '('.
        // Verify that the index path produces the same mask as the old path.
        let c = ascii_constraint(arithmetic_grammar());
        let mask = c.allowed_tokens(&[], 128).unwrap();

        for d in b'0'..=b'9' {
            assert!(
                mask[d as usize],
                "digit token {} should be allowed at start",
                d as char
            );
        }
        assert!(mask[b'(' as usize], "'(' should be allowed at start");
        // Non-first-byte tokens must be blocked.
        assert!(!mask[b'+' as usize], "'+' not valid at start");
        assert!(!mask[b' ' as usize], "space not valid at start");
        assert!(!mask[b'z' as usize], "'z' not valid at start");
    }

    #[test]
    fn grammar_constraint_advance_uses_cached_bytes() {
        // Verify that advance() via cached bytes works identically to the
        // old tokenizer_decode_fn path by checking recognizer state advancement.
        let mut c = ascii_constraint(arithmetic_grammar());

        // Feed "1+2" token by token.
        assert!(c.advance(b'1' as u32), "'1' should advance");
        assert!(c.is_complete(), "single digit is complete");
        assert!(c.advance(b'+' as u32), "'+' should advance after digit");
        assert!(!c.is_complete(), "incomplete after '+'");
        assert!(c.advance(b'2' as u32), "'2' should advance");
        assert!(c.is_complete(), "'1+2' is a complete expression");

        // Verify bytes_consumed reflects all bytes fed.
        assert_eq!(c.bytes_consumed(), 3, "3 bytes should have been consumed");
    }

    #[test]
    fn grammar_constraint_advance_out_of_range_token() {
        // Token ID beyond vocab_size (128) uses the "treat as accepting" fallback.
        // At the initial state the recognizer is NOT accepting, so an
        // out-of-range token returns false.
        let mut c = ascii_constraint(arithmetic_grammar());
        let ok = c.advance(999); // well beyond vocab_size=128
        assert!(
            !ok,
            "out-of-range token should return false when not accepting"
        );

        // After advancing to an accepting state, out-of-range token returns true.
        let mut c2 = ascii_constraint(arithmetic_grammar());
        c2.advance(b'5' as u32); // now accepting
        assert!(c2.is_complete());
        let ok2 = c2.advance(999);
        assert!(ok2, "out-of-range token should return true when accepting");
    }

    // ── Phase 16B: precomputed bytes match decode fn ─────────────────────────

    #[test]
    fn grammar_constraint_precomputed_bytes_match_decode_fn() {
        // Verify token_bytes[id] == direct decode for all ids 0..128.
        let decode_fn = |id: u32| -> Vec<u8> {
            if id < 128 {
                vec![id as u8]
            } else {
                vec![]
            }
        };
        let c = GrammarConstraint::new(arithmetic_grammar(), decode_fn, 128);

        for id in 0u32..128 {
            let precomputed = &c.token_bytes[id as usize];
            // The closure captures nothing, so it is `Copy` and can still be
            // called directly after being moved into the constructor.
            let direct = decode_fn(id);
            assert_eq!(
                precomputed, &direct,
                "precomputed bytes for token {id} must match direct decode"
            );
        }
    }
}
777}