llguidance/
tokenparser.rs

use std::{hint::black_box, panic::AssertUnwindSafe, sync::Arc, time::Duration};

use crate::{
    api::{GrammarInit, ParserLimits, StopReason, TopLevelGrammar},
    earley::{BiasComputer, DefaultBiasComputer, Parser, ParserError, ParserStats},
    infoln, panic_utils, warn, Instant, Logger,
};
use anyhow::{ensure, Result};
use toktrie::{InferenceCapabilities, SimpleVob, TokEnv, TokenId, INVALID_TOKEN};

#[derive(Clone)]
pub struct TokenParser {
    pub token_env: TokEnv,
    pub parser: Parser,
    pub compute_mask_start_time: Instant,
    pub last_bias_time: Duration,
    pub inference_caps: InferenceCapabilities,
    pub logger: Logger,
    pub limits: ParserLimits,
    pub bias_computer: Arc<dyn BiasComputer>,
    last_step_stats: ParserStats,
    max_step_stats: ParserStats,
    eos_token: TokenId,

    is_accepting_cache: Option<bool>,
    ff_tokens_cache: Option<(Vec<TokenId>, Vec<u8>)>,
    stop_reason: StopReason,
    error_message: Option<String>,
    max_tokens_total: usize,

    // tokens currently in KV cache
    llm_tokens: Vec<TokenId>,
    llm_bytes: Vec<u8>,

    grm_prefix: Vec<u8>,
    is_fresh: bool,
}

impl TokenParser {
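    /// Builds a TokenParser, catching panics from grammar compilation.
    /// A minimal construction sketch (`tok_env`, `grammar`, and `logger` are
    /// assumed to be supplied by the embedding engine; the `Default` impls
    /// are assumptions):
    /// ```ignore
    /// let parser = TokenParser::from_init(
    ///     tok_env,
    ///     GrammarInit::Serialized(grammar),
    ///     logger,
    ///     InferenceCapabilities::default(),
    ///     ParserLimits::default(),
    ///     vec![], // no extra lexemes
    /// )?;
    /// ```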
    pub fn from_init(
        token_env: TokEnv,
        grammar_init: GrammarInit,
        logger: Logger,
        inference_caps: InferenceCapabilities,
        limits: ParserLimits,
        extra_lexemes: Vec<String>,
    ) -> Result<Self> {
        panic_utils::catch_unwind(AssertUnwindSafe(|| {
            Self::init_inner(
                grammar_init,
                token_env,
                logger,
                inference_caps,
                limits,
                extra_lexemes,
            )
        }))
    }

    pub fn from_grammar(
        token_env: TokEnv,
        top_grammar: TopLevelGrammar,
        logger: Logger,
        inference_caps: InferenceCapabilities,
        limits: ParserLimits,
        extra_lexemes: Vec<String>,
    ) -> Result<Self> {
        Self::from_init(
            token_env,
            GrammarInit::Serialized(top_grammar),
            logger,
            inference_caps,
            limits,
            extra_lexemes,
        )
    }

    fn init_inner(
        grammar_init: GrammarInit,
        token_env: TokEnv,
        mut logger: Logger,
        inference_caps: InferenceCapabilities,
        limits: ParserLimits,
        extra_lexemes: Vec<String>,
    ) -> Result<Self> {
        ensure!(
            token_env.tokenize_is_canonical() || !inference_caps.ff_tokens,
            "ff_tokens requires canonical tokenization"
        );
        ensure!(
            !inference_caps.backtrack || inference_caps.ff_tokens,
            "backtrack requires ff_tokens"
        );

        let compute_mask_start_time = Instant::now();
        let mut max_tokens = usize::MAX;
        if let GrammarInit::Serialized(input) = &grammar_init {
            if let Some(m) = input.max_tokens {
                max_tokens = m;
            }
        }
        let compiled_grammar = grammar_init.to_cgrammar(
            Some(token_env.clone()),
            &mut logger,
            limits.clone(),
            extra_lexemes,
        )?;
        let parser = Parser::new(token_env.clone(), compiled_grammar, limits.clone())?;
        let eos_token = token_env.tok_trie().eos_token();

        Ok(TokenParser {
            bias_computer: Arc::new(DefaultBiasComputer::new(token_env.clone())),
            logger,
            token_env,
            inference_caps,
            limits,
            max_step_stats: ParserStats::default(),
            last_step_stats: ParserStats::default(),
            compute_mask_start_time,
            is_accepting_cache: None,
            ff_tokens_cache: None,
            stop_reason: StopReason::NotStopped,
            error_message: None,
            parser,
            eos_token,
            llm_tokens: Vec::new(),
            llm_bytes: Vec::new(),
            grm_prefix: Vec::new(),
            max_tokens_total: max_tokens,
            last_bias_time: Duration::from_secs(0),
            is_fresh: true,
        })
    }

    pub fn get_capture(&self, name: &str) -> Option<&[u8]> {
        self.parser.get_capture(name)
    }

    // the regular .clone() shares the lexer state; deep_clone() copies it too
    pub fn deep_clone(&self) -> Self {
        let mut copy = self.clone();
        copy.parser = self.parser.deep_clone();
        copy
    }

    pub fn stop_reason(&self) -> StopReason {
        self.stop_reason
    }

    pub fn is_fresh(&self) -> bool {
        self.is_fresh
    }

    pub fn parser_stats(&self) -> &ParserStats {
        self.parser.stats()
    }

    pub fn last_step_stats(&self) -> &ParserStats {
        &self.last_step_stats
    }

    pub fn max_step_stats(&self) -> &ParserStats {
        &self.max_step_stats
    }

    pub fn num_tokens(&self) -> usize {
        self.llm_tokens.len()
    }

    pub fn final_bytes(&self) -> &[u8] {
        &self.llm_bytes[self.grm_prefix.len()..]
    }

    pub fn is_accepting(&mut self) -> bool {
        if let Some(acc) = self.is_accepting_cache {
            acc
        } else {
            let r = !self.has_ff_bytes() && self.parser.is_accepting();
            self.is_accepting_cache = Some(r);
            r
        }
    }

    pub fn bytes_since(&self, mut idx: usize) -> &[u8] {
        idx += self.grm_prefix.len();
        let endp = std::cmp::min(self.llm_bytes.len(), self.parser.hidden_start());
        if idx >= self.llm_bytes.len() || idx >= endp {
            return &[];
        }
        &self.llm_bytes[idx..endp]
    }

    pub fn start_without_prompt(&mut self) {
        infoln!(
            self,
            "initial lexer cost: {} (no prompt)",
            self.parser.lexer_stats()
        );

        assert!(self.is_fresh);
        self.is_fresh = false;
    }

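    // Tokenizes forced bytes, then chops off a token suffix that might merge
    // with whatever bytes come next. Illustrative example (hypothetical
    // tokenizer): if the forced bytes end in "hel", the trailing tokens could
    // re-tokenize as part of a longer token once "lo" arrives, so they are
    // dropped here and handled via a byte prefix instead.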
    fn tokenize_and_chop(
        &mut self,
        mut tokens: Vec<TokenId>,
        num_fixed: usize,
    ) -> (Vec<TokenId>, usize) {
        let trie = self.token_env.tok_trie();
        let (chop_tokens, chop_bytes) = self
            .parser
            .with_recognizer(|r| trie.chop_tokens(r, &tokens[num_fixed..]));
        infoln!(
            self,
            "tokenize -> {}; chop: {} tokens, {} bytes",
            trie.tokens_dbg(&tokens),
            chop_tokens,
            chop_bytes
        );
        // remove the token suffix that could possibly be tokenized differently
        tokens.truncate(tokens.len() - chop_tokens);
        (tokens, chop_bytes)
    }

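    /// Tokenizes the prompt together with any bytes forced by the start of
    /// the grammar, and returns the token sequence to actually feed to the
    /// model (it may differ from the input due to token healing).
    /// Requires canonical tokenization and must be called exactly once,
    /// on a fresh parser.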
    pub fn process_prompt(&mut self, prompt: Vec<TokenId>) -> Vec<TokenId> {
        infoln!(self, "initial lexer cost: {}", self.parser.lexer_stats());

        assert!(self.token_env.tokenize_is_canonical());
        assert!(self.is_fresh);
        self.is_fresh = false;

        assert!(self.llm_tokens.is_empty());

        let trie = self.token_env.tok_trie();
        infoln!(self, "prompt: {}", trie.tokens_dbg(&prompt));
        let mut prompt_bytes = trie.decode_raw(&prompt);
        if self.can_force_bytes() {
            self.parser.force_bytes();
        }
        let grm_bytes = self.parser.get_bytes().to_vec();
        prompt_bytes.extend_from_slice(&grm_bytes);

        let (tokens, num_fixed) = self.token_env.tokenize_bytes_marker(&prompt_bytes);
        let (res_prompt, chop_bytes) = self.tokenize_and_chop(tokens, num_fixed);

        let trie = self.token_env.tok_trie();
        infoln!(
            self,
            "prompt+grm: {} {}",
            trie.tokens_dbg(&res_prompt),
            self.parser.grammar().lexer_spec().no_forcing
        );

        // if we moved a chunk of the grammar into the prompt, update llm_tokens to reflect that
        if chop_bytes <= grm_bytes.len() {
            self.llm_bytes = grm_bytes[0..grm_bytes.len() - chop_bytes].to_vec();
            self.llm_tokens = self.token_env.tokenize_bytes_marker(&self.llm_bytes).0;
            self.parser.apply_forced(self.llm_bytes.len());
            let decoded = self.tok_trie().decode_raw(&self.llm_tokens);
            if !self.llm_bytes.is_empty()
                && !decoded.is_empty()
                && decoded[1..] == self.llm_bytes
                && decoded[0] == b' '
            {
                infoln!(self, "applying <s>space hack");
                self.grm_prefix = decoded[0..1].to_vec();
                self.llm_bytes = decoded;
            }
            infoln!(self, "ini_tokens: {}", trie.tokens_dbg(&self.llm_tokens));
        } else {
            // pretend the final bit of the prompt was the prefix of the grammar
            self.grm_prefix = prompt_bytes
                [prompt_bytes.len() - chop_bytes..prompt_bytes.len() - grm_bytes.len()]
                .to_vec();
            infoln!(
                self,
                "force_prefix: {:?}",
                String::from_utf8_lossy(&self.grm_prefix)
            );
        }

        infoln!(self, "res_prompt: {}", trie.tokens_dbg(&res_prompt));
        res_prompt
    }

    fn clear_caches(&mut self) {
        self.is_accepting_cache = None;
        self.ff_tokens_cache = None;
    }

    fn stop(&mut self, warn: &str, reason: StopReason) -> anyhow::Error {
        if !warn.is_empty() {
            self.error_message = Some(warn.to_string());
            warn!(self, "{}; stopping", warn);
        }
        self.stop_reason = reason;
        self.anyhow_error()
    }

    fn tok_trie(&self) -> &toktrie::TokTrie {
        self.token_env.tok_trie()
    }

    pub fn error_message(&self) -> Option<String> {
        self.error_message.clone()
    }

    fn check_initialized(&self, lbl: &str) -> Result<()> {
        ensure!(!self.is_fresh, "process_prompt() not called in {}", lbl);
        ensure!(
            self.stop_reason == StopReason::NotStopped,
            "parser stopped in {}; {}",
            lbl,
            self.error_message()
                .unwrap_or("no error message".to_string())
        );
        Ok(())
    }

    pub fn validate_token(&mut self, token: TokenId) -> Result<bool> {
        self.check_initialized("validate_token")?;
        self.validate_tokens_raw(&[token]).map(|n| n > 0)
    }

    pub fn reset(&mut self) -> Result<()> {
        self.rollback(self.llm_tokens.len())
    }

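    /// Rolls back the last `n_tokens` tokens (and their bytes) from both the
    /// parser state and the KV-cache mirror. A usage sketch, assuming `tp` is
    /// an initialized TokenParser and `tok_a`/`tok_b` are valid token ids:
    /// ```ignore
    /// tp.consume_token(tok_a)?;
    /// tp.consume_token(tok_b)?;
    /// tp.rollback(1)?; // forget tok_b; tok_a stays applied
    /// ```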
    pub fn rollback(&mut self, n_tokens: usize) -> Result<()> {
        if n_tokens == 0 {
            return Ok(());
        }

        ensure!(
            n_tokens <= self.llm_tokens.len(),
            "rollback: {} > {}",
            n_tokens,
            self.llm_tokens.len()
        );

        if self.stop_reason.is_ok() {
            // if we stopped in a "normal" way (e.g. the end of the grammar
            // was reached), pretend we're not stopped
            self.stop_reason = StopReason::NotStopped;
        }

        // this will fail if we're in an error state or not initialized
        self.check_initialized("rollback")?;

        let new_len = self.llm_tokens.len() - n_tokens;
        let mut bytes_to_drop = 0;
        for tok in &self.llm_tokens[new_len..] {
            if *tok == self.eos_token {
                // contributes no bytes; we expect it to be the last token though...
                bytes_to_drop += 0;
            } else {
                bytes_to_drop += self.tok_trie().token_len(*tok);
            }
        }
        ensure!(
            bytes_to_drop <= self.llm_bytes.len(),
            "rollback bytes: {} > {}",
            bytes_to_drop,
            self.llm_bytes.len()
        );

        self.parser.rollback(bytes_to_drop)?;

        self.max_tokens_total = self.max_tokens_total.saturating_add(n_tokens);
        self.llm_tokens.truncate(new_len);
        self.llm_bytes
            .truncate(self.llm_bytes.len() - bytes_to_drop);
        self.clear_caches();

        Ok(())
    }

    /// Returns how many of the passed tokens can be accepted by the parser.
    /// It does not tokenize forced bytes, so it will accept non-canonical
    /// tokenizations. If called with more than one token, it may ignore
    /// max_tokens constraints.
    pub fn validate_tokens_raw(&mut self, tokens: &[TokenId]) -> Result<usize> {
        self.check_initialized("validate_tokens_raw")?;

        if tokens.is_empty() {
            return Ok(0);
        }

        let n_vocab = self.tok_trie().vocab_size();
        for &t in tokens {
            if t as usize >= n_vocab {
                return Err(self.stop(
                    &format!("token id {} out of range", t),
                    StopReason::InternalError,
                ));
            }
        }

        let n_valid = self.parser.validate_tokens(tokens);
        Ok(n_valid)
    }

    fn anyhow_error(&self) -> anyhow::Error {
        anyhow::anyhow!(self
            .error_message
            .clone()
            .unwrap_or(self.stop_reason.to_string()))
    }

    // compute_mask() is the top-level method in this file;
    // it is called by Constraint::compute_mask().
    pub fn compute_mask(&mut self) -> Result<SimpleVob> {
        self.compute_mask_start_time = Instant::now();
        let r = self.compute_mask_inner();
        self.parser
            .perf_counters()
            .compute_mask
            .record(self.compute_mask_start_time.elapsed());
        r
    }

    fn compute_mask_inner(&mut self) -> Result<SimpleVob> {
        self.check_initialized("compute_mask")?;

        infoln!(self, "compute_mask");

        let prefix = if self.can_force_bytes() {
            let (ff_tokens, token_prefix) = self
                .ff_tokens_cache
                .take()
                .unwrap_or_else(|| self.ff_tokens());
            if !ff_tokens.is_empty() {
                let t = ff_tokens[0];
                infoln!(self, "forcing ff_token by mask: {}", t);
                let mask = self.tok_trie().singleton_token_set(t);
                self.last_step_stats = ParserStats::default();
                return Ok(mask);
            } else {
                // no tokens, so we got all our bytes back
                token_prefix
            }
        } else {
            let mut trg = Vec::new();
            self.compute_ff_bytes_to(&mut trg);
            trg
        };

        let mut allowed_tokens = self.compute_bias(&prefix);

        if let Some(s) = self.parser.get_error() {
            return Err(self.stop_for_parser_error("", s));
        }

        if self.eos_token != INVALID_TOKEN && self.is_accepting() {
            allowed_tokens.allow_token(self.eos_token);
        }

        self.log_final(&prefix, &allowed_tokens);

        if allowed_tokens.is_zero() {
            infoln!(self, "no tokens allowed, stopping");
            return Err(self.stop("", StopReason::NoExtensionBias));
        }

        Ok(allowed_tokens)
    }

    fn stop_for_parser_error(&mut self, pref: &str, err: ParserError) -> anyhow::Error {
        self.stop(&format!("{}{}", pref, err.message()), err.stop_reason())
    }

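    // Applies a single token to the parser and to the llm_tokens/llm_bytes
    // mirrors. Returns the number of tokens to backtrack (including the
    // token just applied), or 0 in the normal case.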
    fn apply_token(&mut self, tok_id: TokenId) -> Result<usize> {
        self.clear_caches();

        let trie = self.token_env.tok_trie();

        if (tok_id as usize) >= trie.vocab_size() {
            return Err(self.stop(
                &format!("token id {} out of range", tok_id),
                StopReason::InternalError,
            ));
        }

        self.llm_tokens.push(tok_id);

        let tok_bytes = trie.decode_raw(&[tok_id]);

        // first, check whether we're still inside grm_prefix
        let prefix_len = self.grm_prefix.len().saturating_sub(self.llm_bytes.len());

        infoln!(
            self,
            "consume_token: {} {} prefix={}",
            tok_id,
            trie.token_dbg(tok_id),
            prefix_len
        );

        let tok_bytes = if prefix_len > 0 {
            let to_apply = &tok_bytes[0..std::cmp::min(tok_bytes.len(), prefix_len)];
            self.llm_bytes.extend_from_slice(to_apply);

            if self.grm_prefix[0..self.llm_bytes.len()] != self.llm_bytes {
                return Err(self.stop(
                    &format!(
                        "prefix mismatch: applying {:?}; {:?} vs {:?}",
                        String::from_utf8_lossy(to_apply),
                        String::from_utf8_lossy(&self.grm_prefix),
                        String::from_utf8_lossy(&self.llm_bytes)
                    ),
                    StopReason::InternalError,
                ));
            }

            if prefix_len < tok_bytes.len() {
                &tok_bytes[prefix_len..]
            } else {
                // still completely inside the prefix, nothing more to apply
                return Ok(0);
            }
        } else {
            &tok_bytes
        };

        if let Some(err) = self.parser.get_error() {
            return Err(self.stop_for_parser_error("", err));
        }

        // now apply normally
        match self.parser.apply_token(tok_bytes) {
            Err(e) => {
                return Err(self.stop(
                    &format!("Parser Error: {}", e),
                    StopReason::ParserTooComplex, // TODO - there are other reasons
                ));
            }
            Ok(backtrack_bytes0) => {
                self.llm_bytes.extend_from_slice(tok_bytes);

                if backtrack_bytes0 != 0 {
                    let mut backtrack_bytes: isize = backtrack_bytes0.try_into().unwrap();
                    let mut backtrack_tokens = 0;
                    while backtrack_bytes > 0 {
                        let tok_off = self.llm_tokens.len() - backtrack_tokens;
                        if tok_off == 0 {
                            break; // we can't backtrack any further
                        }
                        let tok = self.llm_tokens[tok_off - 1];
                        backtrack_bytes -= trie.token_len(tok) as isize;
                        backtrack_tokens += 1;
                    }
                    assert!(backtrack_tokens > 0);
                    let additional_backtrack_bytes: usize = (-backtrack_bytes).try_into().unwrap();
                    let full_backtrack_bytes = backtrack_bytes0 + additional_backtrack_bytes;

                    let byte_ptr = self.llm_bytes.len() - full_backtrack_bytes;
                    infoln!(
                        self,
                        "backtrack: {} tokens / {}+{} bytes (deletes: {:?})",
                        backtrack_tokens,
                        backtrack_bytes0,
                        additional_backtrack_bytes,
                        String::from_utf8_lossy(&self.llm_bytes[byte_ptr..])
                    );
                    self.llm_bytes.truncate(byte_ptr);

                    let token_ptr = self.llm_tokens.len() - backtrack_tokens;
                    if !self.inference_caps.backtrack {
                        warn!(
                            self,
                            "can't backtrack over {}; this may confuse the model",
                            trie.tokens_dbg(&self.llm_tokens[token_ptr..])
                        );
                        // pretend there's no backtrack
                        backtrack_tokens = 0;
                    } else {
                        // make sure the parser knows we don't actually have
                        // the non-backtracked bytes of the backtracked token
                        self.parser.additional_backtrack(additional_backtrack_bytes);
                    }
                    self.llm_tokens.truncate(token_ptr);
                    return Ok(backtrack_tokens);
                }
            }
        }

        Ok(0)
    }

    fn pending_grm_prefix(&self) -> &[u8] {
        &self.grm_prefix[std::cmp::min(self.grm_prefix.len(), self.llm_bytes.len())..]
    }

    fn has_ff_bytes(&self) -> bool {
        !self.pending_grm_prefix().is_empty() || !self.parser.currently_forced_bytes().is_empty()
    }

    fn can_force_bytes(&self) -> bool {
        !self.parser.grammar().lexer_spec().no_forcing && self.token_env.tokenize_is_canonical()
    }

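    /// Forces any pending grammar bytes in the parser and returns them,
    /// together with any not-yet-consumed grm_prefix bytes.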
    pub fn force_bytes(&mut self) -> Vec<u8> {
        self.parser.force_bytes();
        let mut trg = Vec::new();
        self.compute_ff_bytes_inner(&mut trg);
        trg
    }

    fn compute_ff_bytes_to(&mut self, trg: &mut Vec<u8>) {
        // PERF: in some cases, this may take a long time
        if self.can_force_bytes() {
            self.parser.force_bytes();
        }
        self.compute_ff_bytes_inner(trg);
    }

    fn compute_ff_bytes_inner(&mut self, trg: &mut Vec<u8>) {
        // handle the grm_prefix we might have injected
        if self.llm_bytes.len() < self.grm_prefix.len() {
            let inject = &self.grm_prefix[self.llm_bytes.len()..];
            trg.extend_from_slice(inject);
            infoln!(
                self,
                "injecting prefix: {:?}",
                String::from_utf8_lossy(inject)
            );
        }

        trg.extend_from_slice(self.parser.currently_forced_bytes());
    }

    /// Converts forced bytes into tokens.
    /// Also returns any bytes that need to be a prefix of the
    /// next sampled token (token healing).
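    /// Illustrative example (hypothetical values): if the grammar forces the
    /// bytes `"hello wor"`, the tokens covering `"hello "` may be returned as
    /// fixed tokens, while `"wor"` comes back as the byte prefix that the
    /// next sampled token must start with.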
    fn ff_tokens(&mut self) -> (Vec<TokenId>, Vec<u8>) {
        let mut forced_bytes = Vec::new();
        let mut existing_tokens = if self.llm_tokens.is_empty() {
            Vec::new()
        } else {
            let r = self.llm_tokens[self.llm_tokens.len() - 1..].to_vec();
            let trie = self.token_env.tok_trie();
            forced_bytes = trie.decode_raw(&r);
            r
        };
        let num_existing_bytes = forced_bytes.len();

        self.compute_ff_bytes_to(&mut forced_bytes);

        let mut token_prefix = Vec::new();

        let do_force =
            forced_bytes.len() > num_existing_bytes && self.token_env.tokenize_is_canonical();
        if do_force {
            let t0 = Instant::now();
            let (mut tokens, mut num_fixed) = self.token_env.tokenize_bytes_marker(&forced_bytes);
            if !tokens.starts_with(&existing_tokens) {
                // whoops, re-tokenize without the prefix
                let trie = self.token_env.tok_trie();
                infoln!(
                    self,
                    "re-tokenizing without prefix: {}; because we got {}",
                    trie.tokens_dbg(&existing_tokens),
                    trie.tokens_dbg(&tokens),
                );
                (tokens, num_fixed) = self
                    .token_env
                    .tokenize_bytes_marker(&forced_bytes[num_existing_bytes..]);
                infoln!(
                    self,
                    "re-tokenized: {} from: {:?}",
                    trie.tokens_dbg(&tokens),
                    &forced_bytes[num_existing_bytes..]
                );
                existing_tokens.clear();
            } else {
                num_fixed = std::cmp::max(existing_tokens.len(), num_fixed);
            }

            let (mut grm_tokens, chop_bytes) = self.tokenize_and_chop(tokens, num_fixed);
            assert!(grm_tokens.starts_with(&existing_tokens));
            grm_tokens.drain(..existing_tokens.len());

            let trie = self.token_env.tok_trie();
            infoln!(
                self,
                "forced: {} bytes:{:?} tokens:{:?}",
                trie.tokens_dbg(&grm_tokens),
                &forced_bytes[num_existing_bytes..],
                grm_tokens
            );
            token_prefix = forced_bytes[forced_bytes.len() - chop_bytes..].to_vec();

            self.parser.perf_counters().tokenize_ff.record(t0.elapsed());

            if !grm_tokens.is_empty() {
                infoln!(
                    self,
                    "fixed_tokens: {}; prefix len {}",
                    trie.tokens_dbg(&grm_tokens),
                    token_prefix.len()
                );
                return (grm_tokens, token_prefix);
            } else {
                infoln!(self, "no fixed tokens; prefix len {}", token_prefix.len());
            }
        } else if forced_bytes.len() > num_existing_bytes {
            infoln!(self, "not-forcing {} bytes", forced_bytes.len());
            token_prefix = forced_bytes[num_existing_bytes..].to_vec();
        }

        (Vec::new(), token_prefix)
    }

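    // Runs the bias computation over the token trie with the given byte
    // prefix and records per-step and running-max statistics.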
    fn compute_bias(&mut self, token_prefix: &[u8]) -> SimpleVob {
        let pre_stats = self.parser.stats().clone();
        let set = self.parser.compute_bias(&*self.bias_computer, token_prefix);
        let p_stats = self.parser.stats().delta(&pre_stats);
        self.last_bias_time = Duration::from_micros(p_stats.compute_time_us);
        self.last_step_stats = p_stats.clone();
        self.max_step_stats = self.max_step_stats.max(&p_stats);
        set
    }

    fn log_final(&mut self, token_prefix: &[u8], allowed_tokens: &SimpleVob) {
        infoln!(
            self,
            "step-stats: {}us; {} lex fuel; {} items; {}",
            self.compute_mask_start_time.elapsed().as_micros(),
            self.last_step_stats.lexer_cost,
            self.last_step_stats.all_items,
            self.parser.lexer_stats(),
        );

        infoln!(
            self,
            "bias: (pref: {:?}; accpt: {}; temp: {:.3}) {}",
            String::from_utf8_lossy(token_prefix),
            self.is_accepting_cache.unwrap(),
            self.parser.temperature().unwrap_or(0.0),
            self.token_env.tok_trie().token_set_dbg(allowed_tokens)
        );
    }

    pub fn temperature(&self) -> Option<f32> {
        self.parser.temperature()
    }

    /// Extends the current state of the parser with the given token.
    /// Returns the number of tokens to backtrack, if any.
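    /// A sketch of the backtrack contract (illustrative): a return value of
    /// `0` means the token was accepted as-is; `1` means only the
    /// just-consumed token should be dropped; `n > 1` additionally drops
    /// `n - 1` previously accepted tokens.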
    pub fn consume_token(&mut self, token: TokenId) -> Result<usize> {
        self.check_initialized("consume_token")?;

        if self.max_tokens_total == 0 {
            return Err(self.stop("max_tokens_total reached", StopReason::MaxTokensTotal));
        }
        self.max_tokens_total -= 1;

        if token == self.eos_token {
            if self.parser.scan_eos() {
                // it got scanned correctly, so we remove it;
                // this only happens for gen() terminated by EOS
                infoln!(self, "consume_token: scanned eos_token");
                // if self.inference_caps.backtrack {
                //     return Ok(1);
                // } else {
                //     warn!(self, "can't backtrack over eos_token");
                //     return Ok(0);
                // }
                // don't backtrack it for now; it fails tests
                return Ok(0);
            } else {
                let accepting = self.is_accepting();
                infoln!(
                    self,
                    "consume_token: eos_token not eaten by parser; accept={}",
                    accepting
                );
                if accepting {
                    self.llm_tokens.push(token);
                    return Ok(0);
                }
            }
        }

        let apply_res = self.apply_token(token);
        self.parser.log_row_infos("post-apply");
        match apply_res {
            Err(_) => Err(self.anyhow_error()),
            Ok(n) => Ok(n),
        }
    }

    /// Checks whether the current parser state forces the sequence to stop.
    /// If so, puts the parser in the stop state and returns true.
    /// Otherwise, returns false.
    /// This should generally be called after consume_token().
    pub fn check_stop(&mut self) -> Result<bool> {
        let empty_token_prefix = !self.has_ff_bytes();
        let pending_eos = self.llm_tokens.last() == Some(&self.eos_token);
        let lexer_bytes = self.parser.has_pending_lexeme_bytes();
        let is_accepting = self.is_accepting();
        let can_advance = self.parser.can_advance();
        let parser_done = is_accepting && (!can_advance || pending_eos);
        infoln!(
            self,
            "parser_done: {parser_done}; lexer_bytes: {lexer_bytes}; \
                can_advance: {can_advance} (eos:{pending_eos}); \
                accept: {is_accepting}; \
                empty_token_prefix: {empty_token_prefix}"
        );
        assert!(!is_accepting || empty_token_prefix);

        if parser_done {
            infoln!(
                self,
                "only eos token allowed, stopping; accepting: {}",
                is_accepting
            );
            let reason = if pending_eos {
                StopReason::EndOfSentence
            } else {
                StopReason::NoExtension
            };
            self.stop("", reason);
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Checks if there are any tokens to fast-forward, forced by the current
    /// parser state.
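    /// When byte-forcing is enabled, the result is cached and reused by the
    /// next compute_mask() call.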
    pub fn compute_ff_tokens(&mut self) -> Vec<TokenId> {
        let r = self.ff_tokens();
        if self.can_force_bytes() {
            self.ff_tokens_cache = Some(r.clone());
        }
        r.0
    }

    /// Computes and then consumes fast-forward tokens.
    pub fn consume_ff_tokens(&mut self) -> Result<Vec<TokenId>> {
        let ff_tokens = self.compute_ff_tokens();
        for &t in &ff_tokens {
            let num_backtrack = self.consume_token(t)?;
            if num_backtrack > 0 {
                return Err(self.stop(
                    &format!("backtrack required after ff_token: {}", t),
                    StopReason::InternalError,
                ));
            }
        }
        Ok(ff_tokens)
    }

    /// This function documents typical use of this interface.
    /// The `tokens` array simulates tokens being sampled.
    #[allow(dead_code)]
    fn typical_use(&mut self, prompt: Vec<TokenId>) -> Result<()> {
        // First, check if we need to token-heal the prompt,
        // and whether some tokens are forced by the beginning
        // of the grammar.
        let new_prompt = self.process_prompt(prompt);

        // pass the new prompt to the inference engine
        black_box(new_prompt);

        let mut tokens = vec![];

        loop {
            let temp = self.temperature();
            let mask = self.compute_mask()?;

            // the model forward pass goes here, in parallel with compute_mask()

            // simulate sampling a token with the given mask
            black_box((temp, mask));
            let sampled_token = 42;

            let num_backtrack = self.consume_token(sampled_token)?;

            if num_backtrack == 0 {
                // normal situation - the token was accepted
                tokens.push(sampled_token);
            } else {
                // this will only happen if you enable backtrack
                assert!(self.inference_caps.backtrack);
                if num_backtrack == 1 {
                    // just don't add the sampled token to the list
                } else if num_backtrack > 1 {
                    // additionally drop num_backtrack - 1 previously
                    // accepted tokens (the sampled one was never added)
                    tokens.truncate(tokens.len() - (num_backtrack - 1));
                }
            }

            // This is optional; if you don't check, compute_mask() will
            // return an error when it cannot continue anymore.
            // If you check here, you can distinguish between a normal stop
            // and an error.
            if self.check_stop()? {
                break;
            }

            // This is optional - call it if you have the ability to append
            // several tokens at once.
            let forced = self.consume_ff_tokens()?;
            tokens.extend_from_slice(&forced);
        }

        Ok(())
    }
}