oxibonsai_runtime/grammar/constraint.rs

//! [`GrammarConstraint`] — implements [`TokenConstraint`] using the Earley
//! chart-parser recognizer backed by a BNF context-free grammar.
//!
//! The `allowed_tokens` method speculatively feeds each token's byte sequence
//! through a **clone** of the current recognizer state and marks the token
//! allowed if and only if none of the bytes are rejected.
//!
//! **Phase 16B optimization:** Token byte sequences are precomputed once during
//! construction via `tokenizer_decode_fn` and stored in `token_bytes: Vec<Vec<u8>>`.
//! A `first_byte_index: Box<[Vec<u32>; 256]>` maps each first byte to the list of
//! token IDs that start with that byte. During `allowed_tokens`, only tokens whose
//! first byte is in `next_byte_set` are probed — all others are skipped without
//! invoking the decode function or any recognizer cloning.
//!
//! This reduces the per-step work from O(vocab) decode calls + O(vocab) first-byte
//! checks + O(filtered_vocab × token_len) recognizer probes to just
//! O(|next_byte_set| × avg_matching_tokens × avg_token_len) recognizer probes.
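//!
//! Conceptually, the per-step scan has this shape (a sketch, not the exact code):
//!
//! ```text
//! for b in next_byte_set {
//!     for id in &first_byte_index[b] {
//!         // feed token_bytes[id] through a cloned recognizer state
//!         if probe_accepts_all_bytes { mask[id] = true; }
//!     }
//! }
//! ```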

use std::sync::{Arc, Mutex};

use super::ast::Grammar;
use super::cache::AllowedTokensCache;
use super::earley::EarleyRecognizer;
use crate::constrained_decoding::TokenConstraint;

// ─────────────────────────────────────────────────────────────────────────────
// GrammarConstraint
// ─────────────────────────────────────────────────────────────────────────────

/// A [`TokenConstraint`] that enforces a context-free grammar on the generated
/// byte stream, using the Earley chart-parser as the underlying recognizer.
///
/// # Construction
///
/// ```rust,no_run
/// use oxibonsai_runtime::grammar::{arithmetic_grammar, GrammarConstraint};
///
/// let grammar = arithmetic_grammar();
/// // Map each token id to its byte sequence; single-byte ASCII vocab here.
/// let decode_fn = |token_id: u32| -> Vec<u8> {
///     if token_id < 128 { vec![token_id as u8] } else { vec![] }
/// };
/// let constraint = GrammarConstraint::new(grammar, decode_fn, 128);
/// ```
///
/// # Token decode function
///
/// The `tokenizer_decode_fn` maps a token id to the **byte sequence** it
/// represents.  For an ASCII byte-level vocabulary it is simply
/// `|id| vec![id as u8]`.  For a real LLM tokenizer it should call into
/// `tokenizer.id_to_bytes(id)`.  Unknown / special tokens can return an empty
/// `Vec<u8>`; they will be allowed iff the current recognizer state is
/// accepting (which allows a graceful end-of-sequence).
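///
/// For example, a decode function backed by a per-model byte table built once
/// up front (a sketch; `byte_table` is a hypothetical lookup, not part of this
/// crate):
///
/// ```rust,no_run
/// // Hypothetical: one byte sequence per token id, built from your tokenizer.
/// let byte_table: Vec<Vec<u8>> = vec![b"the".to_vec(), b" cat".to_vec(), vec![]];
/// let decode_fn = move |id: u32| -> Vec<u8> {
///     byte_table.get(id as usize).cloned().unwrap_or_default()
/// };
/// ```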
///
/// # Phase 16B: Precomputed byte index
///
/// At construction time, `GrammarConstraint` eagerly calls `tokenizer_decode_fn`
/// for every token ID in `0..vocab_size`, storing the results in `token_bytes`.
/// Simultaneously, `first_byte_index[b]` accumulates the list of token IDs whose
/// first byte is `b`, and `empty_token_ids` collects IDs with empty byte sequences
/// (EOS, padding, special tokens).
///
/// This eliminates O(vocab) decode calls during each `allowed_tokens` call and
/// allows the inner loop to skip entire byte classes not present in
/// `next_byte_set` — often reducing the probed token count by 90–99%.
pub struct GrammarConstraint {
    /// Original grammar (kept for potential future reset/inspection).
    #[allow(dead_code)]
    grammar: Arc<Grammar>,
    /// Live Earley recognizer tracking the bytes generated so far.
    recognizer: EarleyRecognizer,
    /// Decodes a token id to its raw byte sequence.
    ///
    /// Retained for potential out-of-range token handling or future callers that
    /// need to decode tokens not covered by the initial `0..vocab_size` range.
    #[allow(dead_code)]
    tokenizer_decode_fn: Arc<dyn Fn(u32) -> Vec<u8> + Send + Sync>,
    /// Total vocabulary size used to allocate the precomputed index.
    vocab_size: usize,
    /// LRU memoization cache for `allowed_tokens` results keyed by Earley state hash.
    ///
    /// Wrapped in `Mutex` because `TokenConstraint::allowed_tokens` takes `&self`,
    /// yet cache mutation requires `&mut`.  A poisoned `Mutex` (a panic while the
    /// lock was held) is handled gracefully: the failed lock is treated as a
    /// cache miss, so this never panics.
    cache: Mutex<AllowedTokensCache>,

    // ── Phase 16B: Precomputed token index ──────────────────────────────────
    /// Precomputed byte sequences for every token in `0..vocab_size`.
    ///
    /// `token_bytes[id]` is the byte sequence for token `id`, precomputed once
    /// at construction time.  This is the primary data consumed by `allowed_tokens`
    /// and `advance`.
    token_bytes: Vec<Vec<u8>>,

    /// First-byte index: `first_byte_index[b]` is the list of token IDs
    /// (in `0..vocab_size`) whose first byte equals `b`.
    ///
    /// Boxed so the 256 `Vec<u32>` headers (~6 KB on 64-bit platforms) live on
    /// the heap instead of being built on the stack and moved.
    first_byte_index: Box<[Vec<u32>; 256]>,

    /// Token IDs (in `0..vocab_size`) whose byte sequence is empty.
    ///
    /// These represent EOS tokens, padding tokens, and other special tokens
    /// that do not contribute bytes to the grammar stream.  They are allowed
    /// only when the recognizer is in an accepting state.
    empty_token_ids: Vec<u32>,
}

// ─────────────────────────────────────────────────────────────────────────────
// Private construction helper
// ─────────────────────────────────────────────────────────────────────────────

/// Type alias for the first-byte index (avoids clippy::type_complexity).
type FirstByteIndex = Box<[Vec<u32>; 256]>;

/// Aggregate result of `build_token_index`.
struct TokenIndex {
    token_bytes: Vec<Vec<u8>>,
    first_byte_index: FirstByteIndex,
    empty_token_ids: Vec<u32>,
}

/// Build the three precomputed structures from a decode function and vocab size.
fn build_token_index(decode_fn: &dyn Fn(u32) -> Vec<u8>, vocab_size: usize) -> TokenIndex {
    let mut token_bytes: Vec<Vec<u8>> = Vec::with_capacity(vocab_size);

    // Use a Vec<Vec<u32>> of length 256 to avoid constructing 256 Vecs on the
    // stack before boxing — the std::array::from_fn approach would stack-allocate
    // [Vec<u32>; 256] ≈ 6 KB of headers, which is usually fine, but building it
    // element-by-element in a Vec before converting avoids any platform-specific
    // stack pressure.
    let mut raw_index: Vec<Vec<u32>> = (0..256_usize).map(|_| Vec::new()).collect();
    let mut empty_token_ids: Vec<u32> = Vec::new();

    for id in 0..vocab_size as u32 {
        let bytes = decode_fn(id);
        match bytes.first() {
            Some(&b) => raw_index[b as usize].push(id),
            None => empty_token_ids.push(id),
        }
        token_bytes.push(bytes);
    }

    // Convert Vec<Vec<u32>> (length 256) into Box<[Vec<u32>; 256]>.
    // We built `raw_index` with exactly 256 elements, so the try_into cannot fail.
    let first_byte_index: FirstByteIndex = raw_index
        .into_boxed_slice()
        .try_into()
        .expect("raw_index must have exactly 256 elements");

    TokenIndex {
        token_bytes,
        first_byte_index,
        empty_token_ids,
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// Public API
// ─────────────────────────────────────────────────────────────────────────────

impl GrammarConstraint {
    /// Create a new `GrammarConstraint`.
    ///
    /// The `grammar` is normalised (multi-byte terminals split into chains)
    /// and wrapped in an `Arc` before being handed to the recognizer.
    ///
    /// **Phase 16B:** This eagerly calls `tokenizer_decode_fn(id)` for every
    /// `id` in `0..vocab_size`, building `token_bytes` and `first_byte_index`.
    /// Construction cost is O(vocab_size × avg_decode_cost); subsequent
    /// `allowed_tokens` calls no longer call the decode function at all.
    ///
    /// # Parameters
    ///
    /// * `grammar`               — the context-free grammar to enforce
    /// * `tokenizer_decode_fn`   — maps token id → byte sequence
    /// * `vocab_size`            — total vocabulary size
    pub fn new(
        mut grammar: Grammar,
        tokenizer_decode_fn: impl Fn(u32) -> Vec<u8> + Send + Sync + 'static,
        vocab_size: usize,
    ) -> Self {
        grammar.normalise_terminals();
        let grammar = Arc::new(grammar);
        let recognizer = EarleyRecognizer::new(Arc::clone(&grammar));
        let tokenizer_decode_fn: Arc<dyn Fn(u32) -> Vec<u8> + Send + Sync> =
            Arc::new(tokenizer_decode_fn);

        let idx = build_token_index(tokenizer_decode_fn.as_ref(), vocab_size);

        Self {
            grammar,
            recognizer,
            tokenizer_decode_fn,
            vocab_size,
            cache: Mutex::new(AllowedTokensCache::with_capacity(256)),
            token_bytes: idx.token_bytes,
            first_byte_index: idx.first_byte_index,
            empty_token_ids: idx.empty_token_ids,
        }
    }

    /// Create a new `GrammarConstraint` with a custom cache capacity.
    ///
    /// Identical to [`new`](Self::new) except that the LRU cache is initialised
    /// with `capacity` entries rather than the default 256.  Use a larger value
    /// when the grammar has many distinct parse states; use a smaller value to
    /// bound memory at the cost of more cache misses.
    ///
    /// **Phase 16B:** Same eager precomputation as [`new`](Self::new).
    ///
    /// # Parameters
    ///
    /// * `grammar`               — the context-free grammar to enforce
    /// * `tokenizer_decode_fn`   — maps token id → byte sequence
    /// * `vocab_size`            — total vocabulary size
    /// * `capacity`              — LRU cache capacity (clamped to ≥ 1)
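    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use oxibonsai_runtime::grammar::{arithmetic_grammar, GrammarConstraint};
    ///
    /// // Same ASCII vocab as in the struct-level example, with a larger cache.
    /// let c = GrammarConstraint::with_cache_capacity(
    ///     arithmetic_grammar(),
    ///     |id| if id < 128 { vec![id as u8] } else { vec![] },
    ///     128,
    ///     1024,
    /// );
    /// ```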
    pub fn with_cache_capacity(
        mut grammar: Grammar,
        tokenizer_decode_fn: impl Fn(u32) -> Vec<u8> + Send + Sync + 'static,
        vocab_size: usize,
        capacity: usize,
    ) -> Self {
        grammar.normalise_terminals();
        let grammar = Arc::new(grammar);
        let recognizer = EarleyRecognizer::new(Arc::clone(&grammar));
        let tokenizer_decode_fn: Arc<dyn Fn(u32) -> Vec<u8> + Send + Sync> =
            Arc::new(tokenizer_decode_fn);

        let idx = build_token_index(tokenizer_decode_fn.as_ref(), vocab_size);

        Self {
            grammar,
            recognizer,
            tokenizer_decode_fn,
            vocab_size,
            cache: Mutex::new(AllowedTokensCache::with_capacity(capacity)),
            token_bytes: idx.token_bytes,
            first_byte_index: idx.first_byte_index,
            empty_token_ids: idx.empty_token_ids,
        }
    }

    /// Return cache hit/miss statistics as `(hits, misses)`.
    ///
    /// Useful for testing and for monitoring cache effectiveness in production.
    /// Returns `(0, 0)` if the internal `Mutex` has been poisoned (never panics).
    pub fn cache_stats(&self) -> (u64, u64) {
        self.cache
            .lock()
            .map(|c| (c.hits(), c.misses()))
            .unwrap_or((0, 0))
    }

    /// Return the current number of bytes consumed by the recognizer.
    pub fn bytes_consumed(&self) -> usize {
        self.recognizer.input_pos
    }

    /// Return `true` if the recognizer is still in a live (non-dead) state.
    pub fn is_live(&self) -> bool {
        self.recognizer.is_live()
    }

    /// Return the set of bytes valid as the next byte in the stream.
    ///
    /// This is a low-level utility; prefer `allowed_tokens` for normal use.
    pub fn next_byte_set(&self) -> std::collections::HashSet<u8> {
        self.recognizer.next_byte_set()
    }

    /// Return the vocabulary size passed to the constructor.
    ///
    /// This equals `self.token_bytes.len()`.
    pub fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    /// Return an estimate of the heap memory (in bytes) occupied by the
    /// precomputed token index built during construction.
    ///
    /// The estimate accounts for:
    /// * `token_bytes`: 24-byte `Vec` header + inline byte storage per token.
    /// * `first_byte_index`: 24-byte `Vec` header + 4-byte u32 per entry,
    ///   for all 256 first-byte buckets.
    /// * `empty_token_ids`: 4 bytes per entry.
    ///
    /// This is a lower bound (does not include allocator overhead or padding).
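    ///
    /// As a worked example: for the 128-token single-byte ASCII vocabulary used
    /// in the tests below, this is `128 × (24 + 1)` bytes for `token_bytes` plus
    /// `256 × 24 + 128 × 4` bytes for `first_byte_index`, with no empty tokens —
    /// 9,856 bytes in total.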
    pub fn index_memory_bytes(&self) -> usize {
        // 24 = size_of::<Vec<u8>>() on 64-bit platforms (ptr + len + cap).
        let token_bytes_mem: usize = self.token_bytes.iter().map(|b| b.len() + 24).sum();
        // 24 = size_of::<Vec<u32>>(); 4 = size_of::<u32>().
        let index_mem: usize = self.first_byte_index.iter().map(|v| v.len() * 4 + 24).sum();
        token_bytes_mem + index_mem + self.empty_token_ids.len() * 4
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// TokenConstraint implementation
// ─────────────────────────────────────────────────────────────────────────────

impl TokenConstraint for GrammarConstraint {
    /// Compute a per-token mask using the precomputed first-byte index.
    ///
    /// **Phase 16B algorithm:**
    ///
    /// 1. If the recognizer is dead, return all-false immediately.
    /// 2. Compute `next_byte_set` (NBS) and `is_accepting`.
    /// 3. If NBS is empty and not accepting, return all-false immediately.
    /// 4. Check the LRU cache keyed by `state_hash()`.
    /// 5. On cache miss: start with an all-false mask.
    ///    * For each `first_byte` in NBS, iterate `first_byte_index[first_byte]`
    ///      and probe only those tokens via `recognizer.clone_state()`.
    ///    * For empty-byte tokens (EOS/special), allow them iff `is_accepting`.
    /// 6. Insert the result into the LRU cache.
    ///
    /// The inner loop never calls `tokenizer_decode_fn` — it reads precomputed
    /// `token_bytes` instead.  Tokens whose first byte is NOT in NBS are never
    /// visited at all.
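    ///
    /// A minimal usage sketch (assuming `TokenConstraint` is importable from the
    /// crate's `constrained_decoding` module, as the `use` paths here suggest):
    ///
    /// ```rust,ignore
    /// use oxibonsai_runtime::constrained_decoding::TokenConstraint;
    /// use oxibonsai_runtime::grammar::{arithmetic_grammar, GrammarConstraint};
    ///
    /// let c = GrammarConstraint::new(arithmetic_grammar(), |id| vec![id as u8], 128);
    /// let mask = c.allowed_tokens(&[], 128).unwrap();
    /// // Mask off disallowed logits before sampling the next token.
    /// ```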
    fn allowed_tokens(&self, _generated: &[u32], vocab_size: usize) -> Option<Vec<bool>> {
        // ── Early exits ─────────────────────────────────────────────────────
        if !self.recognizer.is_live() {
            return Some(vec![false; vocab_size]);
        }

        let nbs = self.recognizer.next_byte_set();
        let currently_accepting = self.recognizer.is_accepting();

        if nbs.is_empty() && !currently_accepting {
            return Some(vec![false; vocab_size]);
        }

        // ── Cache lookup ─────────────────────────────────────────────────────
        let state_hash = self.recognizer.state_hash();
        if let Ok(mut cache) = self.cache.lock() {
            if let Some(cached) = cache.get(state_hash) {
                return Some(cached.to_vec());
            }
        }

        // ── Cache miss: build mask using first-byte index ────────────────────
        let mut mask = vec![false; vocab_size];

        // Empty-byte tokens (EOS, special): allowed only when accepting.
        if currently_accepting {
            for &id in &self.empty_token_ids {
                if (id as usize) < vocab_size {
                    mask[id as usize] = true;
                }
            }
        }

        // Tokens grouped by first byte: iterate only over bytes that are in NBS.
        for &first_byte in &nbs {
            for &token_id in &self.first_byte_index[first_byte as usize] {
                let token_idx = token_id as usize;
                if token_idx >= vocab_size {
                    continue;
                }
                let bytes = &self.token_bytes[token_idx];
                if bytes.is_empty() {
                    // Should not happen (empties are in empty_token_ids), but
                    // handle defensively.
                    if currently_accepting {
                        mask[token_idx] = true;
                    }
                    continue;
                }
                // bytes[0] == first_byte by construction, so membership in NBS
                // needs no re-check; feed the full sequence (first byte included)
                // through a cloned recognizer state to advance the probe.
                let mut probe = self.recognizer.clone_state();
                let mut ok = true;
                for &b in bytes {
                    if !probe.feed_byte(b) {
                        ok = false;
                        break;
                    }
                }
                if ok {
                    mask[token_idx] = true;
                }
            }
        }

        // ── Store in cache ───────────────────────────────────────────────────
        if let Ok(mut cache) = self.cache.lock() {
            cache.insert(state_hash, mask.clone());
        }

        Some(mask)
    }

    /// Commit `token` to the recognizer by feeding its precomputed byte sequence.
    ///
    /// Uses the precomputed `token_bytes` slice instead of calling
    /// `tokenizer_decode_fn`, avoiding one decode call per accepted token.
    ///
    /// Returns `false` if any byte in the token's sequence is rejected by the
    /// grammar.  Out-of-range token IDs (beyond the precomputed index) and
    /// empty-byte tokens are treated alike: they succeed only if the recognizer
    /// is currently in an accepting state.
    fn advance(&mut self, token: u32) -> bool {
        let Some(bytes) = self.token_bytes.get(token as usize) else {
            // Token ID is beyond the precomputed vocab range.
            // Treat as empty → allowed only if currently accepting.
            return self.recognizer.is_accepting();
        };
        if bytes.is_empty() {
            return self.recognizer.is_accepting();
        }
        for &b in bytes {
            if !self.recognizer.feed_byte(b) {
                return false;
            }
        }
        true
    }

    /// Returns `true` when the recognizer is in an accepting state.
    fn is_complete(&self) -> bool {
        self.recognizer.is_accepting()
    }

    /// Reset the recognizer to the initial state.
    fn reset(&mut self) {
        self.recognizer.reset();
    }

    fn name(&self) -> &str {
        "GrammarConstraint"
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// Unit tests
// ─────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use crate::constrained_decoding::TokenConstraint;
    use crate::grammar::{arithmetic_grammar, csv_row_grammar, simple_ab_grammar};

    // ── Minimal ASCII byte-level vocab helper ───────────────────────────────

    /// Build a `GrammarConstraint` with a simple byte-level vocabulary
    /// where token id == ASCII code point (0..128).
    fn ascii_constraint(grammar: Grammar) -> GrammarConstraint {
        GrammarConstraint::new(
            grammar,
            |id| {
                if id < 128 {
                    vec![id as u8]
                } else {
                    vec![]
                }
            },
            128,
        )
    }

    // ── Arithmetic grammar ──────────────────────────────────────────────────

    #[test]
    fn grammar_constraint_name() {
        let c = ascii_constraint(arithmetic_grammar());
        assert_eq!(c.name(), "GrammarConstraint");
    }

    #[test]
    fn grammar_constraint_not_complete_initially() {
        let c = ascii_constraint(arithmetic_grammar());
        assert!(!c.is_complete());
    }

    #[test]
    fn grammar_constraint_arithmetic_allows_digits_at_start() {
        let c = ascii_constraint(arithmetic_grammar());
        let mask = c.allowed_tokens(&[], 128).unwrap();
        for d in b'0'..=b'9' {
            assert!(mask[d as usize], "digit {d} should be allowed at start");
        }
        assert!(mask[b'(' as usize], "'(' should be allowed at start");
        assert!(!mask[b'+' as usize], "'+' should not be allowed at start");
    }

    #[test]
    fn grammar_constraint_advance_digit_and_operator() {
        let mut c = ascii_constraint(arithmetic_grammar());
        assert!(c.advance(b'1' as u32), "advancing '1' should succeed");
        assert!(
            c.advance(b'+' as u32),
            "advancing '+' after '1' should succeed"
        );
    }

    #[test]
    fn grammar_constraint_advance_violation() {
        let mut c = ascii_constraint(arithmetic_grammar());
        let ok = c.advance(b'+' as u32);
        assert!(!ok, "'+' at start should be rejected");
    }

    #[test]
    fn grammar_constraint_complete_after_full_expression() {
        let mut c = ascii_constraint(arithmetic_grammar());
        c.advance(b'1' as u32);
        assert!(c.is_complete(), "single digit is a complete expression");
    }

    #[test]
    fn grammar_constraint_not_complete_after_operator() {
        let mut c = ascii_constraint(arithmetic_grammar());
        c.advance(b'1' as u32);
        c.advance(b'+' as u32);
        assert!(!c.is_complete(), "after '1+' the expression is incomplete");
    }

    #[test]
    fn grammar_constraint_reset() {
        let mut c = ascii_constraint(arithmetic_grammar());
        c.advance(b'5' as u32);
        assert!(c.is_complete());
        c.reset();
        assert!(!c.is_complete());
        assert_eq!(c.bytes_consumed(), 0);
    }

    #[test]
    fn grammar_constraint_full_sequence_1plus2() {
        let mut c = ascii_constraint(arithmetic_grammar());
        assert!(c.advance(b'1' as u32));
        assert!(c.is_complete());
        assert!(c.advance(b'+' as u32));
        assert!(!c.is_complete());
        assert!(c.advance(b'2' as u32));
        assert!(c.is_complete());
    }

    #[test]
    fn grammar_constraint_disallows_after_rejection() {
        let mut c = ascii_constraint(arithmetic_grammar());
        let ok = c.advance(b'+' as u32);
        // After a rejection the recognizer is dead.
        if !ok {
            let mask = c.allowed_tokens(&[], 128).unwrap();
            assert!(
                mask.iter().all(|&b| !b),
                "all tokens should be blocked after rejection"
            );
        }
    }

    #[test]
    fn grammar_constraint_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<GrammarConstraint>();
    }

    // ── Simple a^n b^n grammar ──────────────────────────────────────────────

    #[test]
    fn grammar_constraint_ab_sequence() {
        let mut c = ascii_constraint(simple_ab_grammar());
        // "ab" should be accepted.
        assert!(c.advance(b'a' as u32));
        assert!(!c.is_complete(), "after 'a' not yet complete");
        assert!(c.advance(b'b' as u32));
        assert!(c.is_complete(), "after 'ab' should be complete");
    }

    #[test]
    fn grammar_constraint_ab_sequence_longer() {
        let mut c = ascii_constraint(simple_ab_grammar());
        // "aabb" should be accepted.
        assert!(c.advance(b'a' as u32));
        assert!(c.advance(b'a' as u32));
        assert!(c.advance(b'b' as u32));
        assert!(c.advance(b'b' as u32));
        assert!(c.is_complete());
    }

    // ── CSV grammar ─────────────────────────────────────────────────────────

    #[test]
    fn grammar_constraint_csv_row() {
        let mut c = ascii_constraint(csv_row_grammar());
        // "a,b" is a valid two-field CSV row.
        for b in b"a,b" {
            assert!(c.advance(*b as u32), "byte {b} should be accepted");
        }
        assert!(c.is_complete());
    }

    #[test]
    fn grammar_constraint_csv_row_single_field() {
        let mut c = ascii_constraint(csv_row_grammar());
        for b in b"hello" {
            assert!(c.advance(*b as u32));
        }
        assert!(c.is_complete());
    }

    // ── Trait object safety ─────────────────────────────────────────────────

    #[test]
    fn grammar_constraint_implements_token_constraint_trait() {
        let c: Box<dyn TokenConstraint> = Box::new(ascii_constraint(arithmetic_grammar()));
        assert_eq!(c.name(), "GrammarConstraint");
        assert!(!c.is_complete());
    }

    // ── Empty byte token ────────────────────────────────────────────────────

    #[test]
    fn grammar_constraint_empty_token_only_when_accepting() {
        // Build a vocab where token 200 maps to empty bytes (special token).
        let g = arithmetic_grammar();
        let c = GrammarConstraint::new(
            g,
            |id| {
                if id < 128 {
                    vec![id as u8]
                } else {
                    vec![] // id == 200 is EOS; all non-ASCII ids map to empty
                }
            },
            201,
        );

        // Initially not accepting, so token 200 should be blocked.
        let mask = c.allowed_tokens(&[], 201).unwrap();
        assert!(
            !mask[200],
            "EOS token should not be allowed when not accepting"
        );
    }

    #[test]
    fn grammar_constraint_empty_token_allowed_when_accepting() {
        let g = arithmetic_grammar();
        let mut c = GrammarConstraint::new(
            g,
            |id| {
                if id < 128 {
                    vec![id as u8]
                } else {
                    vec![] // id == 200 is EOS; all non-ASCII ids map to empty
                }
            },
            201,
        );

        // After generating "9" (a complete expression) we are accepting.
        c.advance(b'9' as u32);
        assert!(c.is_complete());

        let mask = c.allowed_tokens(&[], 201).unwrap();
        assert!(mask[200], "EOS token should be allowed when accepting");
    }


    // ── Phase 16B: vocab_size accessor ──────────────────────────────────────

    #[test]
    fn grammar_constraint_vocab_size_accessor() {
        let c = ascii_constraint(arithmetic_grammar());
        assert_eq!(c.vocab_size(), 128);

        let c2 = GrammarConstraint::new(arithmetic_grammar(), |id| vec![id as u8], 512);
        assert_eq!(c2.vocab_size(), 512);
    }

    // ── Phase 16B: index_memory_bytes ───────────────────────────────────────

    #[test]
    fn grammar_constraint_index_memory_nonzero() {
        let c = ascii_constraint(arithmetic_grammar());
        assert!(
            c.index_memory_bytes() > 0,
            "index_memory_bytes must be > 0 for vocab_size > 0"
        );
    }

    #[test]
    fn grammar_constraint_index_memory_zero_vocab() {
        // vocab_size == 0 → token_bytes is empty, but first_byte_index still
        // holds 256 empty Vecs (each with a 24-byte header).
        let c = GrammarConstraint::new(arithmetic_grammar(), |_id| vec![], 0);
        // 256 empty Vec<u32> × 24 bytes each = 6144 bytes exactly.
        assert_eq!(c.index_memory_bytes(), 256 * 24);
    }

    // ── Phase 16B: first-byte index correctness ──────────────────────────────

    #[test]
    fn grammar_constraint_digits_allowed_at_start_via_index() {
        // The arithmetic grammar starts with digits and '('.
        // Verify that the index path produces the same mask as the old path.
        let c = ascii_constraint(arithmetic_grammar());
        let mask = c.allowed_tokens(&[], 128).unwrap();

        for d in b'0'..=b'9' {
            assert!(
                mask[d as usize],
                "digit token {} should be allowed at start",
                d as char
            );
        }
        assert!(mask[b'(' as usize], "'(' should be allowed at start");
        // Non-first-byte tokens must be blocked.
        assert!(!mask[b'+' as usize], "'+' not valid at start");
        assert!(!mask[b' ' as usize], "space not valid at start");
        assert!(!mask[b'z' as usize], "'z' not valid at start");
    }

    #[test]
    fn grammar_constraint_advance_uses_cached_bytes() {
        // Verify that advance() via cached bytes works identically to the
        // old tokenizer_decode_fn path by checking recognizer state advancement.
        let mut c = ascii_constraint(arithmetic_grammar());

        // Feed "1+2" token by token.
        assert!(c.advance(b'1' as u32), "'1' should advance");
        assert!(c.is_complete(), "single digit is complete");
        assert!(c.advance(b'+' as u32), "'+' should advance after digit");
        assert!(!c.is_complete(), "incomplete after '+'");
        assert!(c.advance(b'2' as u32), "'2' should advance");
        assert!(c.is_complete(), "'1+2' is a complete expression");

        // Verify bytes_consumed reflects all bytes fed.
        assert_eq!(c.bytes_consumed(), 3, "3 bytes should have been consumed");
    }

    #[test]
    fn grammar_constraint_advance_out_of_range_token() {
        // Token IDs beyond vocab_size (128) use the "treat as accepting" fallback.
        // At the initial state the recognizer is NOT accepting, so an
        // out-of-range token must return false.
        let mut c = ascii_constraint(arithmetic_grammar());
        let ok = c.advance(999); // well beyond vocab_size=128
        assert!(
            !ok,
            "out-of-range token should return false when not accepting"
        );

        // After advancing to an accepting state, an out-of-range token returns true.
        let mut c2 = ascii_constraint(arithmetic_grammar());
        c2.advance(b'5' as u32); // now accepting
        assert!(c2.is_complete());
        let ok2 = c2.advance(999);
        assert!(ok2, "out-of-range token should return true when accepting");
    }

    // ── Phase 16B: precomputed bytes match decode fn ─────────────────────────

    #[test]
    fn grammar_constraint_precomputed_bytes_match_decode_fn() {
        // Verify token_bytes[id] == direct decode for all ids 0..128.
        let decode_fn = |id: u32| -> Vec<u8> {
            if id < 128 {
                vec![id as u8]
            } else {
                vec![]
            }
        };
        let c = GrammarConstraint::new(arithmetic_grammar(), decode_fn, 128);

        for id in 0u32..128 {
            let precomputed = &c.token_bytes[id as usize];
            let direct = vec![id as u8]; // every id in 0..128 decodes to one byte
            assert_eq!(
                precomputed, &direct,
                "precomputed bytes for token {id} must match direct decode"
            );
        }
    }
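
    // ── End-to-end sketch: mask → pick → advance ─────────────────────────────

    #[test]
    fn grammar_constraint_mask_then_advance_loop() {
        // A sketch of the intended decode loop: consult the mask, commit a
        // token from it, repeat.  Drives "1+2" while checking that each token
        // is allowed by the mask before advancing; illustrative, not exhaustive.
        let mut c = ascii_constraint(arithmetic_grammar());
        for &byte in b"1+2" {
            let mask = c.allowed_tokens(&[], 128).unwrap();
            assert!(mask[byte as usize], "token {} must be in the mask", byte as char);
            assert!(c.advance(byte as u32), "mask-allowed token must advance");
        }
        assert!(c.is_complete(), "'1+2' is a complete expression");
    }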
}