// llguidance/tokenparser.rs

use std::{fmt::Display, hint::black_box, panic::AssertUnwindSafe, sync::Arc, time::Duration};

use crate::{
    api::{GrammarInit, ParserLimits, StopReason},
    earley::{BiasComputer, Parser, ParserError, ParserStats},
    infoln, panic_utils, warn, Instant, Logger, ParserFactory,
};
use anyhow::{ensure, Result};
use toktrie::{InferenceCapabilities, SimpleVob, TokEnv, TokenId, INVALID_TOKEN};

/// Token-level parser that drives a single constrained-generation session.
///
/// Created by [`ParserFactory::create_parser()`] and typically wrapped in a
/// [`crate::Constraint`] for the sampling loop. Maintains the grammar state,
/// computes token masks, and processes sampled tokens.
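///
/// A minimal sketch of the raw sampling loop (hedged: the `factory` setup and
/// the `sample_with_mask` sampler are assumptions, not part of this crate):
///
/// ```ignore
/// let mut tp = factory.create_parser(grammar)?;
/// let prompt = tp.process_prompt(prompt_tokens);
/// // ... run the model on `prompt` ...
/// loop {
///     let mask = tp.compute_mask()?;
///     let token = sample_with_mask(&mask); // hypothetical sampler
///     let _backtrack = tp.consume_token(token)?;
///     if tp.check_stop()? {
///         break;
///     }
/// }
/// ```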
#[derive(Clone)]
pub struct TokenParser {
    pub token_env: TokEnv,
    pub parser: Parser,
    pub compute_mask_start_time: Instant,
    pub last_bias_time: Duration,
    pub inference_caps: InferenceCapabilities,
    pub logger: Logger,
    pub limits: ParserLimits,
    pub bias_computer: Arc<dyn BiasComputer>,
    pub dbg_grammar: String,
    last_step_stats: ParserStats,
    max_step_stats: ParserStats,
    eos_tokens: Vec<TokenId>,

    had_rollback: bool,
    had_backtrack: bool,

    is_accepting_cache: Option<bool>,
    ff_tokens_cache: Option<(Vec<TokenId>, Vec<u8>)>,
    stop_reason: StopReason,
    error_message: Option<String>,
    max_tokens_total: usize,

    // tokens currently in KV cache
    llm_tokens: Vec<TokenId>,
    llm_bytes: Vec<u8>,

    grm_prefix: Vec<u8>,
    is_fresh: bool,
}

impl TokenParser {
    // use ParserFactory externally
    pub(crate) fn from_init(
        factory: &ParserFactory,
        grammar_init: GrammarInit,
        logger: Logger,
        inference_caps: InferenceCapabilities,
        limits: ParserLimits,
    ) -> Result<Self> {
        panic_utils::catch_unwind(AssertUnwindSafe(|| {
            Self::init_inner(factory, grammar_init, logger, inference_caps, limits)
        }))
    }

    fn init_inner(
        factory: &ParserFactory,
        grammar_init: GrammarInit,
        mut logger: Logger,
        inference_caps: InferenceCapabilities,
        limits: ParserLimits,
    ) -> Result<Self> {
        let token_env = factory.tok_env().clone();
        ensure!(
            token_env.tokenize_is_canonical() || !inference_caps.ff_tokens,
            "ff_tokens requires canonical tokenization"
        );
        ensure!(
            !inference_caps.backtrack || inference_caps.ff_tokens,
            "backtrack requires ff_tokens"
        );

        let compute_mask_start_time = Instant::now();
        let mut max_tokens = usize::MAX;
        if let GrammarInit::Serialized(input) = &grammar_init {
            if let Some(m) = input.max_tokens {
                max_tokens = m;
            }
        }
        let compiled_grammar = grammar_init.to_cgrammar(
            Some(token_env.clone()),
            &mut logger,
            limits.clone(),
            factory.extra_lexemes(),
        )?;
        let mut parser = Parser::new(
            token_env.clone(),
            compiled_grammar,
            limits.clone(),
            factory.perf_counters(),
        )?;
        parser.metrics_mut().rand = factory.next_rng();
        let eos_tokens = token_env.tok_trie().eos_tokens().to_vec();

        Ok(TokenParser {
            bias_computer: factory.slicer().clone(),
            logger,
            token_env,
            inference_caps,
            limits,
            max_step_stats: ParserStats::default(),
            last_step_stats: ParserStats::default(),
            compute_mask_start_time,
            is_accepting_cache: None,
            ff_tokens_cache: None,
            stop_reason: StopReason::NotStopped,
            error_message: None,
            parser,
            dbg_grammar: String::new(),
            eos_tokens,
            llm_tokens: Vec::new(),
            llm_bytes: Vec::new(),
            grm_prefix: Vec::new(),
            max_tokens_total: max_tokens,
            last_bias_time: Duration::from_secs(0),
            is_fresh: true,
            had_backtrack: false,
            had_rollback: false,
        })
    }

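    // For reference, a minimal capability configuration that satisfies the
    // two invariants checked in init_inner() above (a sketch; assumes the
    // toktrie InferenceCapabilities fields shown and a Default impl):
    //
    //     let caps = InferenceCapabilities {
    //         ff_tokens: true,  // requires canonical tokenization
    //         backtrack: false, // may only be enabled together with ff_tokens
    //         ..Default::default()
    //     };
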
    pub fn grammar_warnings(&mut self) -> Vec<String> {
        self.parser.grammar_warnings()
    }

    pub fn get_capture(&self, name: &str) -> Option<&[u8]> {
        self.parser.get_capture(name)
    }

    pub fn captures(&self) -> &[(String, Vec<u8>)] {
        self.parser.captures()
    }

    // a regular .clone() shares the lexer state; deep_clone() does not
    pub fn deep_clone(&self) -> Self {
        let mut copy = self.clone();
        copy.parser = self.parser.deep_clone();
        copy
    }

    pub fn stop_reason(&self) -> StopReason {
        self.stop_reason
    }

    pub fn stopped(&self) -> bool {
        self.stop_reason != StopReason::NotStopped
    }

    pub fn is_fresh(&self) -> bool {
        self.is_fresh
    }

    pub fn parser_stats(&self) -> &ParserStats {
        self.parser.stats()
    }

    pub fn last_step_stats(&self) -> &ParserStats {
        &self.last_step_stats
    }

    pub fn max_step_stats(&self) -> &ParserStats {
        &self.max_step_stats
    }

    pub fn num_tokens(&self) -> usize {
        self.llm_tokens.len()
    }

    pub fn final_bytes(&self) -> &[u8] {
        &self.llm_bytes[self.grm_prefix.len()..]
    }

    pub fn is_accepting(&mut self) -> bool {
        if let Some(acc) = self.is_accepting_cache {
            acc
        } else {
            let r = !self.has_ff_bytes() && self.parser.is_accepting();
            self.is_accepting_cache = Some(r);
            r
        }
    }

    pub fn bytes_since(&self, mut idx: usize) -> &[u8] {
        idx += self.grm_prefix.len();
        let endp = std::cmp::min(self.llm_bytes.len(), self.parser.hidden_start());
        if idx >= self.llm_bytes.len() || idx >= endp {
            return &[];
        }
        &self.llm_bytes[idx..endp]
    }

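    // Sketch: bytes_since() lets the caller stream newly generated bytes to
    // the client after each step (the `sent` cursor and `emit` sink are
    // hypothetical, owned by the caller):
    //
    //     let new_bytes = tp.bytes_since(sent).to_vec();
    //     sent += new_bytes.len();
    //     emit(&new_bytes);
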
    pub fn start_without_prompt(&mut self) {
        infoln!(
            self,
            "initial lexer cost: {} (no prompt)",
            self.parser.lexer_stats()
        );

        assert!(self.is_fresh);
        self.is_fresh = false;
    }

    fn tokenize_and_chop(
        &mut self,
        mut tokens: Vec<TokenId>,
        num_fixed: usize,
    ) -> (Vec<TokenId>, usize) {
        let trie = self.token_env.tok_trie();
        let (chop_tokens, chop_bytes) = self
            .parser
            .with_recognizer(|r| trie.chop_tokens(r, &tokens[num_fixed..]));
        infoln!(
            self,
            "tokenize -> {}; chop: {} tokens, {} bytes",
            trie.tokens_dbg(&tokens),
            chop_tokens,
            chop_bytes
        );
        // here we remove a suffix of tokens that could possibly be tokenized differently
        tokens.truncate(tokens.len() - chop_tokens);
        (tokens, chop_bytes)
    }

    pub fn process_prompt(&mut self, prompt: Vec<TokenId>) -> Vec<TokenId> {
        infoln!(self, "initial lexer cost: {}", self.parser.lexer_stats());

        assert!(self.token_env.tokenize_is_canonical());
        assert!(self.is_fresh);
        self.is_fresh = false;

        assert!(self.llm_tokens.is_empty());

        let trie = self.token_env.tok_trie();
        infoln!(self, "prompt: {}", trie.tokens_dbg(&prompt));
        let mut prompt_bytes = trie.decode_raw(&prompt);
        if self.can_force_bytes() {
            self.parser.force_bytes();
        }
        let grm_bytes = self.parser.get_bytes().to_vec();
        prompt_bytes.extend_from_slice(&grm_bytes);

        let (tokens, num_fixed) = self.token_env.tokenize_bytes_marker(&prompt_bytes);
        let (res_prompt, chop_bytes) = self.tokenize_and_chop(tokens, num_fixed);

        let trie = self.token_env.tok_trie();
        infoln!(
            self,
            "prompt+grm: {} {}",
            trie.tokens_dbg(&res_prompt),
            self.parser.grammar().lexer_spec().no_forcing
        );

        // if we moved a bunch of grammar to the prompt, update llm_tokens to reflect that
        if chop_bytes <= grm_bytes.len() {
            self.llm_bytes = grm_bytes[0..grm_bytes.len() - chop_bytes].to_vec();
            self.llm_tokens = self.token_env.tokenize_bytes_marker(&self.llm_bytes).0;
            self.parser.apply_forced(self.llm_bytes.len());
            let decoded = self.tok_trie().decode_raw(&self.llm_tokens);
            if !self.llm_bytes.is_empty()
                && !decoded.is_empty()
                && decoded[1..] == self.llm_bytes
                && decoded[0] == b' '
            {
                infoln!(self, "applying <s>space hack");
                self.grm_prefix = decoded[0..1].to_vec();
                self.llm_bytes = decoded;
            }
            infoln!(self, "ini_tokens: {}", trie.tokens_dbg(&self.llm_tokens));
        } else {
            // pretend the final bit of prompt was the prefix of the grammar
            self.grm_prefix = prompt_bytes
                [prompt_bytes.len() - chop_bytes..prompt_bytes.len() - grm_bytes.len()]
                .to_vec();
            infoln!(
                self,
                "force_prefix: {:?}",
                String::from_utf8_lossy(&self.grm_prefix)
            );
        }

        infoln!(self, "res_prompt: {}", trie.tokens_dbg(&res_prompt));
        res_prompt
    }

    pub fn augment_err(&self, e: impl Display) -> String {
        if self.limits.verbose_errors {
            format!(
                "{e}\n<state>\n{}\n</state><grammar>\n{}\n</grammar>",
                self.dump_state(),
                self.dbg_grammar
            )
        } else {
            format!("{e}\n<non-verbose/>")
        }
    }

    pub fn dump_state(&self) -> String {
        // make sure not to take the self.parser.shared lock here;
        // for example, self.parser.lexer_stats() takes it,
        // and if we take it after a panic, it will be poisoned
        format!(
            "Tokens: {}\n{} tokens, {} bytes; grm_prefix: {:?}\nFlags:{}{}\nParser: {}\nStop: {}\nError: {}",
            self.tok_trie().tokens_dbg(&self.llm_tokens),
            self.llm_tokens.len(),
            self.llm_bytes.len(),
            String::from_utf8_lossy(&self.grm_prefix),
            if self.had_backtrack {
                " had_backtrack"
            } else {
                ""
            },
            if self.had_rollback {
                " had_rollback"
            } else {
                ""
            },
            self.parser.stats(),
            self.stop_reason,
            self.error_message.as_deref().unwrap_or("None"),
        )
    }

    fn clear_caches(&mut self) {
        self.is_accepting_cache = None;
        self.ff_tokens_cache = None;
    }

    fn stop(&mut self, warn: &str, reason: StopReason) -> anyhow::Error {
        if !warn.is_empty() {
            self.error_message = Some(warn.to_string());
            warn!(self, "{}; stopping", warn);
        }
        self.stop_reason = reason;
        self.anyhow_error()
    }

    fn tok_trie(&self) -> &toktrie::TokTrie {
        self.token_env.tok_trie()
    }

    pub fn error_message(&self) -> Option<String> {
        self.error_message.clone()
    }

    fn check_initialized(&self, lbl: &str) -> Result<()> {
        ensure!(!self.is_fresh, "process_prompt() not called in {}", lbl);
        ensure!(
            !self.stopped(),
            "parser stopped in {}; {}",
            lbl,
            self.error_message()
                .unwrap_or("no error message".to_string())
        );
        Ok(())
    }

    pub fn validate_token(&mut self, token: TokenId) -> Result<bool> {
        if self.stopped() {
            return Ok(false);
        }
        self.check_initialized("validate_token")?;
        self.validate_tokens_raw(&[token]).map(|n| n > 0)
    }

    pub fn reset(&mut self) -> Result<()> {
        self.rollback(self.llm_tokens.len())
    }

    pub fn rollback(&mut self, n_tokens: usize) -> Result<()> {
        if n_tokens == 0 {
            return Ok(());
        }

        ensure!(
            n_tokens <= self.llm_tokens.len(),
            "rollback: {} > {}",
            n_tokens,
            self.llm_tokens.len()
        );

        if self.stop_reason.is_ok() {
            // if we're stopped in a "normal" way (e.g. the end of the grammar
            // was reached), pretend we're not stopped
            self.stop_reason = StopReason::NotStopped;
        }

        // this will fail in case we're in an error state or not initialized
        self.check_initialized("rollback")?;

        self.had_rollback = true;

        let new_len = self.llm_tokens.len() - n_tokens;
        let mut bytes_to_drop = 0;
        for tok in &self.llm_tokens[new_len..] {
            if self.eos_tokens.contains(tok) {
                // EOS doesn't add any bytes; we hope it's the last token though...
                bytes_to_drop += 0;
            } else {
                bytes_to_drop += self.tok_trie().token_len(*tok);
            }
        }
        ensure!(
            bytes_to_drop <= self.llm_bytes.len(),
            "rollback bytes: {} > {}",
            bytes_to_drop,
            self.llm_bytes.len()
        );

        self.parser.rollback(bytes_to_drop)?;

        self.max_tokens_total = self.max_tokens_total.saturating_add(n_tokens);
        self.llm_tokens.truncate(new_len);
        self.llm_bytes
            .truncate(self.llm_bytes.len() - bytes_to_drop);
        self.clear_caches();

        Ok(())
    }

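    // Sketch: rollback() enables retry loops, e.g. undoing the last few
    // committed tokens and sampling again (`n` and the surrounding loop are
    // hypothetical; the tokens must still be in llm_tokens):
    //
    //     tp.rollback(n)?;               // drop the last n tokens
    //     let mask = tp.compute_mask()?; // then resume sampling as usual
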
    /// Returns how many of the passed tokens can be accepted by the parser.
    /// It does not tokenize forced bytes, so it will accept non-canonical tokenizations.
    /// If called with more than one token, it may ignore max_tokens constraints.
    pub fn validate_tokens_raw(&mut self, tokens: &[TokenId]) -> Result<usize> {
        if self.stopped() {
            return Ok(0);
        }
        self.check_initialized("validate_tokens_raw")?;

        if tokens.is_empty() {
            return Ok(0);
        }

        let n_vocab = self.tok_trie().vocab_size();
        for &t in tokens {
            if t as usize >= n_vocab {
                return Err(self.stop(
                    &format!("token id {t} out of range"),
                    StopReason::InternalError,
                ));
            }
        }

        let n_valid = self.parser.validate_tokens(tokens);
        Ok(n_valid)
    }

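    // Sketch: validating a speculative draft before committing it (the
    // draft_tokens come from a hypothetical draft model; only the accepted
    // prefix is then consumed, one token at a time):
    //
    //     let n_ok = tp.validate_tokens_raw(&draft_tokens)?;
    //     for &t in &draft_tokens[..n_ok] {
    //         let _ = tp.consume_token(t)?;
    //     }
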
    fn anyhow_error(&self) -> anyhow::Error {
        anyhow::anyhow!(self
            .error_message
            .clone()
            .unwrap_or(self.stop_reason.to_string()))
    }

    // compute_mask() is the top-level entry point in this file;
    // it is called by Constraint::compute_mask().
    pub fn compute_mask(&mut self) -> Result<SimpleVob> {
        self.compute_mask_start_time = Instant::now();
        let r = self.compute_mask_inner();
        self.parser
            .perf_counters()
            .compute_mask
            .record(self.compute_mask_start_time.elapsed());
        r
    }

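    // Sketch: a typical way to apply the returned mask to model logits before
    // sampling (assumes f32 logits indexed by token id and that SimpleVob
    // exposes is_allowed(); adjust to your engine's API):
    //
    //     let mask = tp.compute_mask()?;
    //     for (tok, logit) in logits.iter_mut().enumerate() {
    //         if !mask.is_allowed(tok as TokenId) {
    //             *logit = f32::NEG_INFINITY;
    //         }
    //     }
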
    fn compute_mask_inner(&mut self) -> Result<SimpleVob> {
        self.check_initialized("compute_mask")?;

        infoln!(self, "compute_mask");

        let prefix = if self.can_force_bytes() {
            let (ff_tokens, token_prefix) = self
                .ff_tokens_cache
                .take()
                .unwrap_or_else(|| self.ff_tokens());
            if !ff_tokens.is_empty() {
                let t = ff_tokens[0];
                infoln!(self, "forcing ff_token by mask: {}", t);
                let mask = self.tok_trie().singleton_token_set(t);
                self.last_step_stats = ParserStats::default();
                return Ok(mask);
            } else {
                // no tokens, so we got all our bytes back
                token_prefix
            }
        } else {
            let mut trg = Vec::new();
            self.compute_ff_bytes_to(&mut trg);
            trg
        };

        let mut allowed_tokens = self.compute_bias(&prefix);

        if let Some(s) = self.parser.get_error() {
            return Err(self.stop_for_parser_error("", s));
        }

        if self.is_accepting() {
            for &eos in &self.eos_tokens {
                if eos != INVALID_TOKEN {
                    allowed_tokens.allow_token(eos);
                }
            }
        }

        self.log_final(&prefix, &allowed_tokens);

        if allowed_tokens.is_zero() {
            infoln!(self, "no tokens allowed, stopping");
            return Err(self.stop("", StopReason::NoExtensionBias));
        }

        Ok(allowed_tokens)
    }

    fn stop_for_parser_error(&mut self, pref: &str, err: ParserError) -> anyhow::Error {
        self.stop(&format!("{}{}", pref, err.message()), err.stop_reason())
    }

    fn apply_token(&mut self, tok_id: TokenId) -> Result<usize> {
        self.clear_caches();

        let trie = self.token_env.tok_trie();

        if (tok_id as usize) >= trie.vocab_size() {
            return Err(self.stop(
                &format!("token id {tok_id} out of range"),
                StopReason::InternalError,
            ));
        }

        self.llm_tokens.push(tok_id);

        let tok_bytes = trie.decode_raw(&[tok_id]);

        // first, check we're still in grm_prefix
        let prefix_len = self.grm_prefix.len().saturating_sub(self.llm_bytes.len());

        infoln!(
            self,
            "consume_token: {} {} prefix={}",
            tok_id,
            trie.token_dbg(tok_id),
            prefix_len
        );

        let tok_bytes = if prefix_len > 0 {
            let to_apply = &tok_bytes[0..std::cmp::min(tok_bytes.len(), prefix_len)];
            self.llm_bytes.extend_from_slice(to_apply);

            if self.grm_prefix[0..self.llm_bytes.len()] != self.llm_bytes {
                return Err(self.stop(
                    &format!(
                        "prefix mismatch: applying {:?}; {:?} vs {:?}",
                        String::from_utf8_lossy(to_apply),
                        String::from_utf8_lossy(&self.grm_prefix),
                        String::from_utf8_lossy(&self.llm_bytes)
                    ),
                    StopReason::InternalError,
                ));
            }

            if prefix_len < tok_bytes.len() {
                &tok_bytes[prefix_len..]
            } else {
                // still completely in prefix, nothing more to apply
                return Ok(0);
            }
        } else {
            &tok_bytes
        };

        if let Some(err) = self.parser.get_error() {
            return Err(self.stop_for_parser_error("", err));
        }

        // now apply normally
        match self.parser.apply_token(tok_bytes, tok_id) {
            Err(e) => {
                return Err(self.stop(
                    &format!("Parser Error: {e}"),
                    StopReason::ParserTooComplex, // TODO - there are other reasons
                ));
            }
            Ok(backtrack_bytes0) => {
                self.llm_bytes.extend_from_slice(tok_bytes);

                if backtrack_bytes0 != 0 {
                    self.had_backtrack = true;
                    let mut backtrack_bytes: isize = backtrack_bytes0.try_into().unwrap();
                    let mut backtrack_tokens = 0;
                    while backtrack_bytes > 0 {
                        let tok_off = self.llm_tokens.len() - backtrack_tokens;
                        if tok_off == 0 {
                            break; // we can't backtrack any further
                        }
                        let tok = self.llm_tokens[tok_off - 1];
                        backtrack_bytes -= trie.token_len(tok) as isize;
                        backtrack_tokens += 1;
                    }
                    assert!(backtrack_tokens > 0);
                    let additional_backtrack_bytes: usize = (-backtrack_bytes).try_into().unwrap();
                    let full_backtrack_bytes = backtrack_bytes0 + additional_backtrack_bytes;

                    let byte_ptr = self.llm_bytes.len() - full_backtrack_bytes;
                    infoln!(
                        self,
                        "backtrack: {} tokens / {}+{} bytes (deletes: {:?})",
                        backtrack_tokens,
                        backtrack_bytes0,
                        additional_backtrack_bytes,
                        String::from_utf8_lossy(&self.llm_bytes[byte_ptr..])
                    );
                    self.llm_bytes.truncate(byte_ptr);

                    let token_ptr = self.llm_tokens.len() - backtrack_tokens;
                    if !self.inference_caps.backtrack {
                        warn!(
                            self,
                            "can't backtrack over {}; this may confuse the model",
                            trie.tokens_dbg(&self.llm_tokens[token_ptr..])
                        );
                        // pretend there's no backtrack
                        backtrack_tokens = 0;
                    } else {
                        // make sure the parser knows we don't actually have
                        // the non-backtracked bytes of the backtracked token
                        self.parser.additional_backtrack(additional_backtrack_bytes);
                    }
                    self.llm_tokens.truncate(token_ptr);
                    return Ok(backtrack_tokens);
                }
            }
        }

        Ok(0)
    }

    fn pending_grm_prefix(&self) -> &[u8] {
        &self.grm_prefix[std::cmp::min(self.grm_prefix.len(), self.llm_bytes.len())..]
    }

    fn has_ff_bytes(&self) -> bool {
        !self.pending_grm_prefix().is_empty() || !self.parser.currently_forced_bytes().is_empty()
    }

    fn can_force_bytes(&self) -> bool {
        !self.parser.grammar().lexer_spec().no_forcing && self.token_env.tokenize_is_canonical()
    }

    pub fn force_bytes(&mut self) -> Vec<u8> {
        self.parser.force_bytes();
        let mut trg = Vec::new();
        self.compute_ff_bytes_inner(&mut trg);
        trg
    }

    fn compute_ff_bytes_to(&mut self, trg: &mut Vec<u8>) {
        // PERF: in some cases, this may be long
        if self.can_force_bytes() {
            self.parser.force_bytes();
        }
        self.compute_ff_bytes_inner(trg);
    }

    fn compute_ff_bytes_inner(&mut self, trg: &mut Vec<u8>) {
        // handle grm_prefix we might have injected
        if self.llm_bytes.len() < self.grm_prefix.len() {
            let inject = &self.grm_prefix[self.llm_bytes.len()..];
            trg.extend_from_slice(inject);
            infoln!(
                self,
                "injecting prefix: {:?}",
                String::from_utf8_lossy(inject)
            );
        }

        trg.extend_from_slice(self.parser.currently_forced_bytes());
    }

    /// Converts forced bytes into tokens.
    /// Also returns any bytes that need to be a prefix of the
    /// next sampled token (token healing).
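    ///
    /// An illustrative example (byte strings made up, not from a real
    /// tokenizer): if the forced bytes tokenize as ["foo", "bar"], but "bar"
    /// could also be the start of a longer token once more text follows, the
    /// trailing token is chopped and its bytes b"bar" are returned as the
    /// prefix that the next sampled token must start with.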
    fn ff_tokens(&mut self) -> (Vec<TokenId>, Vec<u8>) {
        let mut forced_bytes = Vec::new();
        let mut existing_tokens = if self.llm_tokens.is_empty() {
            Vec::new()
        } else {
            let r = self.llm_tokens[self.llm_tokens.len() - 1..].to_vec();
            let trie = self.token_env.tok_trie();
            forced_bytes = trie.decode_raw(&r);
            r
        };
        let num_existing_bytes = forced_bytes.len();

        self.compute_ff_bytes_to(&mut forced_bytes);

        let mut token_prefix = Vec::new();

        let do_force =
            forced_bytes.len() > num_existing_bytes && self.token_env.tokenize_is_canonical();
        if do_force {
            let t0 = Instant::now();
            let (mut tokens, mut num_fixed) = self.token_env.tokenize_bytes_marker(&forced_bytes);
            if !tokens.starts_with(&existing_tokens) {
                // whoops, re-tokenize without the prefix
                let trie = self.token_env.tok_trie();
                infoln!(
                    self,
                    "re-tokenizing without prefix: {}; because we got {}",
                    trie.tokens_dbg(&existing_tokens),
                    trie.tokens_dbg(&tokens),
                );
                (tokens, num_fixed) = self
                    .token_env
                    .tokenize_bytes_marker(&forced_bytes[num_existing_bytes..]);
                infoln!(
                    self,
                    "re-tokenized: {} from: {:?}",
                    trie.tokens_dbg(&tokens),
                    &forced_bytes[num_existing_bytes..]
                );
                existing_tokens.clear();
            } else {
                num_fixed = std::cmp::max(existing_tokens.len(), num_fixed);
            }

            let (mut grm_tokens, chop_bytes) = self.tokenize_and_chop(tokens, num_fixed);
            assert!(grm_tokens.starts_with(&existing_tokens));
            grm_tokens.drain(..existing_tokens.len());

            let trie = self.token_env.tok_trie();
            infoln!(
                self,
                "forced: {} bytes:{:?} tokens:{:?}",
                trie.tokens_dbg(&grm_tokens),
                &forced_bytes[num_existing_bytes..],
                grm_tokens
            );
            token_prefix = forced_bytes[forced_bytes.len() - chop_bytes..].to_vec();

            self.parser.perf_counters().tokenize_ff.record(t0.elapsed());

            if !grm_tokens.is_empty() {
                infoln!(
                    self,
                    "fixed_tokens: {}; prefix len {}",
                    trie.tokens_dbg(&grm_tokens),
                    token_prefix.len()
                );
                return (grm_tokens, token_prefix);
            } else {
                infoln!(self, "no fixed tokens; prefix len {}", token_prefix.len());
            }
        } else if forced_bytes.len() > num_existing_bytes {
            infoln!(self, "not-forcing {} bytes", forced_bytes.len());
            token_prefix = forced_bytes[num_existing_bytes..].to_vec();
        }

        (Vec::new(), token_prefix)
    }

    fn compute_bias(&mut self, token_prefix: &[u8]) -> SimpleVob {
        let pre_stats = self.parser.stats().clone();
        let set = self.parser.compute_bias(&*self.bias_computer, token_prefix);
        let p_stats = self.parser.stats().delta(&pre_stats);
        self.last_bias_time = Duration::from_micros(p_stats.compute_time_us);
        self.last_step_stats = p_stats.clone();
        self.max_step_stats = self.max_step_stats.max(&p_stats);
        set
    }

    fn log_final(&mut self, token_prefix: &[u8], allowed_tokens: &SimpleVob) {
        infoln!(
            self,
            "step-stats: {}us; {} lex fuel; {} items; {}",
            self.compute_mask_start_time.elapsed().as_micros(),
            self.last_step_stats.lexer_cost,
            self.last_step_stats.all_items,
            self.parser.lexer_stats(),
        );

        infoln!(
            self,
            "bias: (pref: {:?}; accpt: {}; temp: {:.3}) {}",
            String::from_utf8_lossy(token_prefix),
            self.is_accepting_cache.unwrap(),
            self.parser.temperature().unwrap_or(0.0),
            self.token_env.tok_trie().token_set_dbg(allowed_tokens)
        );
    }

    pub fn temperature(&self) -> Option<f32> {
        self.parser.temperature()
    }

    /// Extend the current state of the parser with the given token.
    /// Returns the number of tokens to backtrack, if any.
    pub fn consume_token(&mut self, token: TokenId) -> Result<usize> {
        self.check_initialized("consume_token")?;

        if self.max_tokens_total == 0 {
            return Err(self.stop("max_tokens_total reached", StopReason::MaxTokensTotal));
        }
        self.max_tokens_total -= 1;

        if self.eos_tokens.contains(&token) {
            if self.parser.scan_eos() {
                // it got scanned correctly, so we remove it
                // this only happens for gen() terminated by EOS
                infoln!(self, "consume_token: scanned eos_token");
                // if self.inference_caps.backtrack {
                //     return Ok(1);
                // } else {
                //     warn!(self, "can't backtrack over eos_token");
                //     return Ok(0);
                // }
                // don't backtrack it for now, fails tests
                return Ok(0);
            } else {
                let accepting = self.is_accepting();
                infoln!(
                    self,
                    "consume_token: eos_token not eaten by parser; accept={}",
                    accepting
                );
                if accepting {
                    self.llm_tokens.push(token);
                    return Ok(0);
                }
            }
        }

        let apply_res = self.apply_token(token);
        self.parser.log_row_infos("post-apply");
        match apply_res {
            Err(_) => Err(self.anyhow_error()),
            Ok(n) => Ok(n),
        }
    }

    /// Check whether the current parser state forces the sequence to stop.
    /// If so, puts the parser in stop state and returns true.
    /// Otherwise, returns false.
    /// This generally should be called after consume_token().
    pub fn check_stop(&mut self) -> Result<bool> {
        let empty_token_prefix = !self.has_ff_bytes();
        let pending_eos = self
            .llm_tokens
            .last()
            .is_some_and(|t| self.eos_tokens.contains(t));
        let lexer_bytes = self.parser.has_pending_lexeme_bytes();
        let is_accepting = self.is_accepting();
        let can_advance = self.parser.can_advance();
        let parser_done = is_accepting && (!can_advance || pending_eos);
        infoln!(
            self,
            "parser_done: {parser_done}; lexer_bytes: {lexer_bytes}; \
                can_advance: {can_advance} (eos:{pending_eos}); \
                accept: {is_accepting}; \
                empty_token_prefix: {empty_token_prefix}"
        );
        assert!(!is_accepting || empty_token_prefix);

        if parser_done {
            infoln!(
                self,
                "only eos token allowed, stopping; accepting: {}",
                is_accepting
            );
            let reason = if pending_eos {
                StopReason::EndOfSentence
            } else {
                StopReason::NoExtension
            };
            self.stop("", reason);
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Check if there are any tokens to fast-forward, forced by the current
    /// parser state.
    pub fn compute_ff_tokens(&mut self) -> Vec<TokenId> {
        let r = self.ff_tokens();
        if self.can_force_bytes() {
            self.ff_tokens_cache = Some(r.clone());
        }
        r.0
    }

    /// Compute and then consume fast-forward tokens.
    pub fn consume_ff_tokens(&mut self) -> Result<Vec<TokenId>> {
        let ff_tokens = self.compute_ff_tokens();
        for &t in &ff_tokens {
            let num_backtrack = self.consume_token(t)?;
            if num_backtrack > 0 {
                return Err(self.stop(
                    &format!("backtrack required after ff_token: {t}"),
                    StopReason::InternalError,
                ));
            }
        }
        Ok(ff_tokens)
    }

    /// This function documents typical use of this interface.
    /// The `tokens` array simulates tokens being sampled.
    #[allow(dead_code)]
    fn typical_use(&mut self, prompt: Vec<TokenId>) -> Result<()> {
        // First, check if we need to token-heal the prompt,
        // and if there are some tokens forced by the beginning
        // of the grammar.
        let new_prompt = self.process_prompt(prompt);

        // pass new prompt to inference engine
        black_box(new_prompt);

        let mut tokens = vec![];

        loop {
            let temp = self.temperature();
            let mask = self.compute_mask()?;

            // model forward pass in parallel with compute_mask() goes here

            // simulate sampling a token with given mask
            black_box((temp, mask));
            let sampled_token = 42;

            let num_backtrack = self.consume_token(sampled_token)?;

            if num_backtrack == 0 {
                // normal situation - the token was accepted
                tokens.push(sampled_token);
            } else {
                // this will only happen if you enable backtrack
                assert!(self.inference_caps.backtrack);
                if num_backtrack == 1 {
                    // don't add the token to the list
                } else if num_backtrack > 1 {
                    // backtrack
                    // the sampled token was never pushed; drop the remaining
                    // (num_backtrack - 1) tokens from the end of the list
                    tokens.truncate(tokens.len() - (num_backtrack - 1));
                }
            }

            // This is optional; if you don't check, compute_mask() will
            // return an error when it cannot continue anymore.
            // If you check here, you can distinguish between a normal stop
            // and an error.
            if self.check_stop()? {
                break;
            }

            // This is optional - call it if you have the ability to append
            // several tokens at once.
            let forced = self.consume_ff_tokens()?;
            tokens.extend_from_slice(&forced);
        }

        Ok(())
    }

    pub fn invalidate_bias_cache(&mut self) {
        self.parser.invalidate_bias_cache();
    }
}
974}