// llguidance/tokenparser.rs

use std::{fmt::Display, hint::black_box, panic::AssertUnwindSafe, sync::Arc, time::Duration};

use crate::{
    api::{GrammarInit, ParserLimits, StopReason},
    earley::{BiasComputer, Parser, ParserError, ParserStats},
    infoln, panic_utils, warn, Instant, Logger, ParserFactory,
};
use anyhow::{ensure, Result};
use toktrie::{InferenceCapabilities, SimpleVob, TokEnv, TokenId, INVALID_TOKEN};

#[derive(Clone)]
pub struct TokenParser {
    pub token_env: TokEnv,
    pub parser: Parser,
    pub compute_mask_start_time: Instant,
    pub last_bias_time: Duration,
    pub inference_caps: InferenceCapabilities,
    pub logger: Logger,
    pub limits: ParserLimits,
    pub bias_computer: Arc<dyn BiasComputer>,
    pub dbg_grammar: String,
    last_step_stats: ParserStats,
    max_step_stats: ParserStats,
    eos_token: TokenId,

    had_rollback: bool,
    had_backtrack: bool,

    is_accepting_cache: Option<bool>,
    ff_tokens_cache: Option<(Vec<TokenId>, Vec<u8>)>,
    stop_reason: StopReason,
    error_message: Option<String>,
    max_tokens_total: usize,

    // tokens currently in KV cache
    llm_tokens: Vec<TokenId>,
    llm_bytes: Vec<u8>,

    grm_prefix: Vec<u8>,
    is_fresh: bool,
}
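
// Typical lifecycle (see typical_use() at the bottom of this file):
// process_prompt() once, then in a loop: compute_mask() -> sample a token
// -> consume_token() -> check_stop() -> optionally consume_ff_tokens().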

impl TokenParser {
    // use ParserFactory externally
    pub(crate) fn from_init(
        factory: &ParserFactory,
        grammar_init: GrammarInit,
        logger: Logger,
        inference_caps: InferenceCapabilities,
        limits: ParserLimits,
    ) -> Result<Self> {
        panic_utils::catch_unwind(AssertUnwindSafe(|| {
            Self::init_inner(factory, grammar_init, logger, inference_caps, limits)
        }))
    }

    fn init_inner(
        factory: &ParserFactory,
        grammar_init: GrammarInit,
        mut logger: Logger,
        inference_caps: InferenceCapabilities,
        limits: ParserLimits,
    ) -> Result<Self> {
        let token_env = factory.tok_env().clone();
        ensure!(
            token_env.tokenize_is_canonical() || !inference_caps.ff_tokens,
            "ff_tokens requires canonical tokenization"
        );
        ensure!(
            !inference_caps.backtrack || inference_caps.ff_tokens,
            "backtrack requires ff_tokens"
        );

        let compute_mask_start_time = Instant::now();
        let mut max_tokens = usize::MAX;
        if let GrammarInit::Serialized(input) = &grammar_init {
            if let Some(m) = input.max_tokens {
                max_tokens = m;
            }
        }
        let compiled_grammar = grammar_init.to_cgrammar(
            Some(token_env.clone()),
            &mut logger,
            limits.clone(),
            factory.extra_lexemes(),
        )?;
        let mut parser = Parser::new(
            token_env.clone(),
            compiled_grammar,
            limits.clone(),
            factory.perf_counters(),
        )?;
        parser.metrics_mut().rand = factory.next_rng();
        let eos_token = token_env.tok_trie().eos_token();

        Ok(TokenParser {
            bias_computer: factory.slicer().clone(),
            logger,
            token_env,
            inference_caps,
            limits,
            max_step_stats: ParserStats::default(),
            last_step_stats: ParserStats::default(),
            compute_mask_start_time,
            is_accepting_cache: None,
            ff_tokens_cache: None,
            stop_reason: StopReason::NotStopped,
            error_message: None,
            parser,
            dbg_grammar: String::new(),
            eos_token,
            llm_tokens: Vec::new(),
            llm_bytes: Vec::new(),
            grm_prefix: Vec::new(),
            max_tokens_total: max_tokens,
            last_bias_time: Duration::from_secs(0),
            is_fresh: true,
            had_backtrack: false,
            had_rollback: false,
        })
    }

    pub fn grammar_warnings(&mut self) -> Vec<String> {
        self.parser.grammar_warnings()
    }

    pub fn get_capture(&self, name: &str) -> Option<&[u8]> {
        self.parser.get_capture(name)
    }

    // A regular .clone() shares lexer state with the original;
    // deep_clone() gives the copy fully independent parser state.
    pub fn deep_clone(&self) -> Self {
        let mut copy = self.clone();
        copy.parser = self.parser.deep_clone();
        copy
    }

    pub fn stop_reason(&self) -> StopReason {
        self.stop_reason
    }

    pub fn stopped(&self) -> bool {
        self.stop_reason != StopReason::NotStopped
    }

    pub fn is_fresh(&self) -> bool {
        self.is_fresh
    }

    pub fn parser_stats(&self) -> &ParserStats {
        self.parser.stats()
    }

    pub fn last_step_stats(&self) -> &ParserStats {
        &self.last_step_stats
    }

    pub fn max_step_stats(&self) -> &ParserStats {
        &self.max_step_stats
    }

    pub fn num_tokens(&self) -> usize {
        self.llm_tokens.len()
    }

    pub fn final_bytes(&self) -> &[u8] {
        &self.llm_bytes[self.grm_prefix.len()..]
    }

    pub fn is_accepting(&mut self) -> bool {
        if let Some(acc) = self.is_accepting_cache {
            acc
        } else {
            let r = !self.has_ff_bytes() && self.parser.is_accepting();
            self.is_accepting_cache = Some(r);
            r
        }
    }

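    /// Returns the bytes generated since byte-index `idx` (counted after the
    /// grammar prefix), excluding any bytes past the parser's hidden-start point.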
    pub fn bytes_since(&self, mut idx: usize) -> &[u8] {
        idx += self.grm_prefix.len();
        let endp = std::cmp::min(self.llm_bytes.len(), self.parser.hidden_start());
        if idx >= self.llm_bytes.len() || idx >= endp {
            return &[];
        }
        &self.llm_bytes[idx..endp]
    }

    pub fn start_without_prompt(&mut self) {
        infoln!(
            self,
            "initial lexer cost: {} (no prompt)",
            self.parser.lexer_stats()
        );

        assert!(self.is_fresh);
        self.is_fresh = false;
    }

    fn tokenize_and_chop(
        &mut self,
        mut tokens: Vec<TokenId>,
        num_fixed: usize,
    ) -> (Vec<TokenId>, usize) {
        let trie = self.token_env.tok_trie();
        let (chop_tokens, chop_bytes) = self
            .parser
            .with_recognizer(|r| trie.chop_tokens(r, &tokens[num_fixed..]));
        infoln!(
            self,
            "tokenize -> {}; chop: {} tokens, {} bytes",
            trie.tokens_dbg(&tokens),
            chop_tokens,
            chop_bytes
        );
        // Remove a token suffix that might tokenize differently once more text is appended.
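        // Hypothetical example: if the bytes so far end in "wor", the
        // tokenizer may emit a token "wor" now, yet once the model appends
        // "ld" the canonical tokenization could be a single "world" token;
        // chopping lets those bytes be re-sampled as part of a longer token.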
        tokens.truncate(tokens.len() - chop_tokens);
        (tokens, chop_bytes)
    }

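    /// Tokenizes the prompt together with any bytes already forced by the
    /// grammar, healing tokens at the boundary; requires canonical
    /// tokenization and a fresh parser. Returns the prompt that should
    /// actually be passed to the inference engine.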
    pub fn process_prompt(&mut self, prompt: Vec<TokenId>) -> Vec<TokenId> {
        infoln!(self, "initial lexer cost: {}", self.parser.lexer_stats());

        assert!(self.token_env.tokenize_is_canonical());
        assert!(self.is_fresh);
        self.is_fresh = false;

        assert!(self.llm_tokens.is_empty());

        let trie = self.token_env.tok_trie();
        infoln!(self, "prompt: {}", trie.tokens_dbg(&prompt));
        let mut prompt_bytes = trie.decode_raw(&prompt);
        if self.can_force_bytes() {
            self.parser.force_bytes();
        }
        let grm_bytes = self.parser.get_bytes().to_vec();
        prompt_bytes.extend_from_slice(&grm_bytes);

        let (tokens, num_fixed) = self.token_env.tokenize_bytes_marker(&prompt_bytes);
        let (res_prompt, chop_bytes) = self.tokenize_and_chop(tokens, num_fixed);

        let trie = self.token_env.tok_trie();
        infoln!(
            self,
            "prompt+grm: {} {}",
            trie.tokens_dbg(&res_prompt),
            self.parser.grammar().lexer_spec().no_forcing
        );

        // if we moved a bunch of grammar to the prompt, update llm_tokens to reflect that
        if chop_bytes <= grm_bytes.len() {
            self.llm_bytes = grm_bytes[0..grm_bytes.len() - chop_bytes].to_vec();
            self.llm_tokens = self.token_env.tokenize_bytes_marker(&self.llm_bytes).0;
            self.parser.apply_forced(self.llm_bytes.len());
            let decoded = self.tok_trie().decode_raw(&self.llm_tokens);
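            // Some tokenizers (SentencePiece-style ones, for instance) decode
            // with an extra leading space; if that space is the only
            // difference, fold it into grm_prefix to keep byte accounting
            // consistent.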
            if !self.llm_bytes.is_empty()
                && !decoded.is_empty()
                && decoded[1..] == self.llm_bytes
                && decoded[0] == b' '
            {
                infoln!(self, "applying <s>space hack");
                self.grm_prefix = decoded[0..1].to_vec();
                self.llm_bytes = decoded;
            }
            infoln!(self, "ini_tokens: {}", trie.tokens_dbg(&self.llm_tokens));
        } else {
            // pretend the final bit of prompt was the prefix of the grammar
            self.grm_prefix = prompt_bytes
                [prompt_bytes.len() - chop_bytes..prompt_bytes.len() - grm_bytes.len()]
                .to_vec();
            infoln!(
                self,
                "force_prefix: {:?}",
                String::from_utf8_lossy(&self.grm_prefix)
            );
        }

        infoln!(self, "res_prompt: {}", trie.tokens_dbg(&res_prompt));
        res_prompt
    }

    pub fn augment_err(&self, e: impl Display) -> String {
        if self.limits.verbose_errors {
            format!(
                "{e}\n<state>\n{}\n</state><grammar>\n{}\n</grammar>",
                self.dump_state(),
                self.dbg_grammar
            )
        } else {
            format!("{e}\n<non-verbose/>")
        }
    }

    pub fn dump_state(&self) -> String {
        // Make sure not to take the self.parser.shared lock here; for example,
        // self.parser.lexer_stats() takes it, and taking it after a panic
        // would find it poisoned.
        format!(
            "Tokens: {}\n{} tokens, {} bytes; grm_prefix: {:?}\nFlags:{}{}\nParser: {}\nStop: {}\nError: {}",
            self.tok_trie().tokens_dbg(&self.llm_tokens),
            self.llm_tokens.len(),
            self.llm_bytes.len(),
            String::from_utf8_lossy(&self.grm_prefix),
            if self.had_backtrack {
                " had_backtrack"
            } else {
                ""
            },
            if self.had_rollback {
                " had_rollback"
            } else {
                ""
            },
            self.parser.stats(),
            self.stop_reason,
            self.error_message.as_deref().unwrap_or("None"),
        )
    }

    fn clear_caches(&mut self) {
        self.is_accepting_cache = None;
        self.ff_tokens_cache = None;
    }

    fn stop(&mut self, warn: &str, reason: StopReason) -> anyhow::Error {
        if !warn.is_empty() {
            self.error_message = Some(warn.to_string());
            warn!(self, "{}; stopping", warn);
        }
        self.stop_reason = reason;
        self.anyhow_error()
    }

    fn tok_trie(&self) -> &toktrie::TokTrie {
        self.token_env.tok_trie()
    }

    pub fn error_message(&self) -> Option<String> {
        self.error_message.clone()
    }

    fn check_initialized(&self, lbl: &str) -> Result<()> {
        ensure!(!self.is_fresh, "process_prompt() not called in {}", lbl);
        ensure!(
            !self.stopped(),
            "parser stopped in {}; {}",
            lbl,
            self.error_message()
                .unwrap_or("no error message".to_string())
        );
        Ok(())
    }

    pub fn validate_token(&mut self, token: TokenId) -> Result<bool> {
        if self.stopped() {
            return Ok(false);
        }
        self.check_initialized("validate_token")?;
        self.validate_tokens_raw(&[token]).map(|n| n > 0)
    }

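    /// Rolls back all tokens consumed so far (equivalent to
    /// rollback(num_tokens())).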
    pub fn reset(&mut self) -> Result<()> {
        self.rollback(self.llm_tokens.len())
    }

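    /// Undoes the last `n_tokens` tokens, restoring the parser state and the
    /// byte/token bookkeeping; the freed budget is added back to
    /// max_tokens_total. A "normal" stop (e.g. end of grammar) is cleared,
    /// while an error state still makes this fail.
    ///
    /// A minimal usage sketch (hypothetical token id; assumes no backtracking
    /// was triggered):
    ///
    /// ```ignore
    /// let num_backtrack = tp.consume_token(42)?; // accept one sampled token
    /// assert_eq!(num_backtrack, 0);
    /// tp.rollback(1)?; // undo it, as if it had never been consumed
    /// ```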
    pub fn rollback(&mut self, n_tokens: usize) -> Result<()> {
        if n_tokens == 0 {
            return Ok(());
        }

        ensure!(
            n_tokens <= self.llm_tokens.len(),
            "rollback: {} > {}",
            n_tokens,
            self.llm_tokens.len()
        );

        if self.stop_reason.is_ok() {
            // if we stopped in a "normal" way (e.g. the end of the grammar
            // was reached), pretend we never stopped
            self.stop_reason = StopReason::NotStopped;
        }

        // this will fail if we're in an error state or not initialized
        self.check_initialized("rollback")?;

        self.had_rollback = true;

        let new_len = self.llm_tokens.len() - n_tokens;
        let mut bytes_to_drop = 0;
        for tok in &self.llm_tokens[new_len..] {
            if *tok == self.eos_token {
                // the EOS token contributes no bytes; we hope it's the last one though...
                bytes_to_drop += 0;
            } else {
                bytes_to_drop += self.tok_trie().token_len(*tok);
            }
        }
        ensure!(
            bytes_to_drop <= self.llm_bytes.len(),
            "rollback bytes: {} > {}",
            bytes_to_drop,
            self.llm_bytes.len()
        );

        self.parser.rollback(bytes_to_drop)?;

        self.max_tokens_total = self.max_tokens_total.saturating_add(n_tokens);
        self.llm_tokens.truncate(new_len);
        self.llm_bytes
            .truncate(self.llm_bytes.len() - bytes_to_drop);
        self.clear_caches();

        Ok(())
    }

    /// Returns how many of the passed tokens can be accepted by the parser.
    /// It does not tokenize forced bytes, so it will accept non-canonical tokenizations.
    /// If called with more than one token, it may ignore max_tokens constraints.
    pub fn validate_tokens_raw(&mut self, tokens: &[TokenId]) -> Result<usize> {
        if self.stopped() {
            return Ok(0);
        }
        self.check_initialized("validate_tokens_raw")?;

        if tokens.is_empty() {
            return Ok(0);
        }

        let n_vocab = self.tok_trie().vocab_size();
        for &t in tokens {
            if t as usize >= n_vocab {
                return Err(self.stop(
                    &format!("token id {t} out of range"),
                    StopReason::InternalError,
                ));
            }
        }

        let n_valid = self.parser.validate_tokens(tokens);
        Ok(n_valid)
    }

    fn anyhow_error(&self) -> anyhow::Error {
        anyhow::anyhow!(self
            .error_message
            .clone()
            .unwrap_or(self.stop_reason.to_string()))
    }

    // compute_mask() is the top-level mask-computation entry point in this
    // file; it is called by Constraint::compute_mask().
    pub fn compute_mask(&mut self) -> Result<SimpleVob> {
        self.compute_mask_start_time = Instant::now();
        let r = self.compute_mask_inner();
        self.parser
            .perf_counters()
            .compute_mask
            .record(self.compute_mask_start_time.elapsed());
        r
    }

    fn compute_mask_inner(&mut self) -> Result<SimpleVob> {
        self.check_initialized("compute_mask")?;

        infoln!(self, "compute_mask");

        let prefix = if self.can_force_bytes() {
            let (ff_tokens, token_prefix) = self
                .ff_tokens_cache
                .take()
                .unwrap_or_else(|| self.ff_tokens());
            if !ff_tokens.is_empty() {
                let t = ff_tokens[0];
                infoln!(self, "forcing ff_token by mask: {}", t);
                let mask = self.tok_trie().singleton_token_set(t);
                self.last_step_stats = ParserStats::default();
                return Ok(mask);
            } else {
                // no tokens, so we got all our bytes back
                token_prefix
            }
        } else {
            let mut trg = Vec::new();
            self.compute_ff_bytes_to(&mut trg);
            trg
        };

        let mut allowed_tokens = self.compute_bias(&prefix);

        if let Some(s) = self.parser.get_error() {
            return Err(self.stop_for_parser_error("", s));
        }

        if self.eos_token != INVALID_TOKEN && self.is_accepting() {
            allowed_tokens.allow_token(self.eos_token);
        }

        self.log_final(&prefix, &allowed_tokens);

        if allowed_tokens.is_zero() {
            infoln!(self, "no tokens allowed, stopping");
            return Err(self.stop("", StopReason::NoExtensionBias));
        }

        Ok(allowed_tokens)
    }

    fn stop_for_parser_error(&mut self, pref: &str, err: ParserError) -> anyhow::Error {
        self.stop(&format!("{}{}", pref, err.message()), err.stop_reason())
    }

    fn apply_token(&mut self, tok_id: TokenId) -> Result<usize> {
        self.clear_caches();

        let trie = self.token_env.tok_trie();

        if (tok_id as usize) >= trie.vocab_size() {
            return Err(self.stop(
                &format!("token id {tok_id} out of range"),
                StopReason::InternalError,
            ));
        }

        self.llm_tokens.push(tok_id);

        let tok_bytes = trie.decode_raw(&[tok_id]);

        // First, check whether we're still inside grm_prefix.
        let prefix_len = self.grm_prefix.len().saturating_sub(self.llm_bytes.len());

        infoln!(
            self,
            "consume_token: {} {} prefix={}",
            tok_id,
            trie.token_dbg(tok_id),
            prefix_len
        );

        let tok_bytes = if prefix_len > 0 {
            let to_apply = &tok_bytes[0..std::cmp::min(tok_bytes.len(), prefix_len)];
            self.llm_bytes.extend_from_slice(to_apply);

            if self.grm_prefix[0..self.llm_bytes.len()] != self.llm_bytes {
                return Err(self.stop(
                    &format!(
                        "prefix mismatch: applying {:?}; {:?} vs {:?}",
                        String::from_utf8_lossy(to_apply),
                        String::from_utf8_lossy(&self.grm_prefix),
                        String::from_utf8_lossy(&self.llm_bytes)
                    ),
                    StopReason::InternalError,
                ));
            }

            if prefix_len < tok_bytes.len() {
                &tok_bytes[prefix_len..]
            } else {
                // still completely in prefix, nothing more to apply
                return Ok(0);
            }
        } else {
            &tok_bytes
        };

        if let Some(err) = self.parser.get_error() {
            return Err(self.stop_for_parser_error("", err));
        }

        // now apply normally
        match self.parser.apply_token(tok_bytes, tok_id) {
            Err(e) => {
                return Err(self.stop(
                    &format!("Parser Error: {e}"),
                    StopReason::ParserTooComplex, // TODO - there are other reasons
                ));
            }
            Ok(backtrack_bytes0) => {
                self.llm_bytes.extend_from_slice(tok_bytes);

                if backtrack_bytes0 != 0 {
                    self.had_backtrack = true;
                    let mut backtrack_bytes: isize = backtrack_bytes0.try_into().unwrap();
                    let mut backtrack_tokens = 0;
                    while backtrack_bytes > 0 {
                        let tok_off = self.llm_tokens.len() - backtrack_tokens;
                        if tok_off == 0 {
                            break; // we can't backtrack any further
                        }
                        let tok = self.llm_tokens[tok_off - 1];
                        backtrack_bytes -= trie.token_len(tok) as isize;
                        backtrack_tokens += 1;
                    }
                    assert!(backtrack_tokens > 0);
                    let additional_backtrack_bytes: usize = (-backtrack_bytes).try_into().unwrap();
                    let full_backtrack_bytes = backtrack_bytes0 + additional_backtrack_bytes;
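                    // Worked example (hypothetical token lengths): if
                    // backtrack_bytes0 == 5 and the last two tokens are 4 and
                    // 3 bytes long, the loop drops 2 tokens (7 bytes), ending
                    // with backtrack_bytes == -2; thus
                    // additional_backtrack_bytes == 2, full_backtrack_bytes == 7.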

                    let byte_ptr = self.llm_bytes.len() - full_backtrack_bytes;
                    infoln!(
                        self,
                        "backtrack: {} tokens / {}+{} bytes (deletes: {:?})",
                        backtrack_tokens,
                        backtrack_bytes0,
                        additional_backtrack_bytes,
                        String::from_utf8_lossy(&self.llm_bytes[byte_ptr..])
                    );
                    self.llm_bytes.truncate(byte_ptr);

                    let token_ptr = self.llm_tokens.len() - backtrack_tokens;
                    if !self.inference_caps.backtrack {
                        warn!(
                            self,
                            "can't backtrack over {}; this may confuse the model",
                            trie.tokens_dbg(&self.llm_tokens[token_ptr..])
                        );
                        // pretend there's no backtrack
                        backtrack_tokens = 0;
                    } else {
                        // make sure the parser knows we don't actually have
                        // the non-backtracked bytes of the backtracked token
                        self.parser.additional_backtrack(additional_backtrack_bytes);
                    }
                    self.llm_tokens.truncate(token_ptr);
                    return Ok(backtrack_tokens);
                }
            }
        }

        Ok(0)
    }

    fn pending_grm_prefix(&self) -> &[u8] {
        &self.grm_prefix[std::cmp::min(self.grm_prefix.len(), self.llm_bytes.len())..]
    }

    fn has_ff_bytes(&self) -> bool {
        !self.pending_grm_prefix().is_empty() || !self.parser.currently_forced_bytes().is_empty()
    }

    fn can_force_bytes(&self) -> bool {
        !self.parser.grammar().lexer_spec().no_forcing && self.token_env.tokenize_is_canonical()
    }

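    /// Forces any bytes currently determined uniquely by the grammar and
    /// returns them, together with any still-pending injected prefix bytes.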
    pub fn force_bytes(&mut self) -> Vec<u8> {
        self.parser.force_bytes();
        let mut trg = Vec::new();
        self.compute_ff_bytes_inner(&mut trg);
        trg
    }

    fn compute_ff_bytes_to(&mut self, trg: &mut Vec<u8>) {
        // PERF: in some cases, this may take a long time
        if self.can_force_bytes() {
            self.parser.force_bytes();
        }
        self.compute_ff_bytes_inner(trg);
    }

    fn compute_ff_bytes_inner(&mut self, trg: &mut Vec<u8>) {
        // handle grm_prefix we might have injected
        if self.llm_bytes.len() < self.grm_prefix.len() {
            let inject = &self.grm_prefix[self.llm_bytes.len()..];
            trg.extend_from_slice(inject);
            infoln!(
                self,
                "injecting prefix: {:?}",
                String::from_utf8_lossy(inject)
            );
        }

        trg.extend_from_slice(self.parser.currently_forced_bytes());
    }

    /// Converts forced bytes into tokens.
    /// Also returns any bytes that need to be a prefix of the
    /// next sampled token (token healing).
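    ///
    /// Hypothetical example: with last generated token `hel` and forced bytes
    /// `lo world`, we re-tokenize `hello world`; if that yields
    /// `hel` + `lo` + ` world`, the latter two become ff_tokens. If the result
    /// no longer starts with `hel`, the forced bytes are tokenized on their own.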
    fn ff_tokens(&mut self) -> (Vec<TokenId>, Vec<u8>) {
        let mut forced_bytes = Vec::new();
        let mut existing_tokens = if self.llm_tokens.is_empty() {
            Vec::new()
        } else {
            let r = self.llm_tokens[self.llm_tokens.len() - 1..].to_vec();
            let trie = self.token_env.tok_trie();
            forced_bytes = trie.decode_raw(&r);
            r
        };
        let num_existing_bytes = forced_bytes.len();

        self.compute_ff_bytes_to(&mut forced_bytes);

        let mut token_prefix = Vec::new();

        let do_force =
            forced_bytes.len() > num_existing_bytes && self.token_env.tokenize_is_canonical();
        if do_force {
            let t0 = Instant::now();
            let (mut tokens, mut num_fixed) = self.token_env.tokenize_bytes_marker(&forced_bytes);
            if !tokens.starts_with(&existing_tokens) {
                // whoops, re-tokenize without the prefix
                let trie = self.token_env.tok_trie();
                infoln!(
                    self,
                    "re-tokenizing without prefix: {}; because we got {}",
                    trie.tokens_dbg(&existing_tokens),
                    trie.tokens_dbg(&tokens),
                );
                (tokens, num_fixed) = self
                    .token_env
                    .tokenize_bytes_marker(&forced_bytes[num_existing_bytes..]);
                infoln!(
                    self,
                    "re-tokenized: {} from: {:?}",
                    trie.tokens_dbg(&tokens),
                    &forced_bytes[num_existing_bytes..]
                );
                existing_tokens.clear();
            } else {
                num_fixed = std::cmp::max(existing_tokens.len(), num_fixed);
            }

            let (mut grm_tokens, chop_bytes) = self.tokenize_and_chop(tokens, num_fixed);
            assert!(grm_tokens.starts_with(&existing_tokens));
            grm_tokens.drain(..existing_tokens.len());

            let trie = self.token_env.tok_trie();
            infoln!(
                self,
                "forced: {} bytes:{:?} tokens:{:?}",
                trie.tokens_dbg(&grm_tokens),
                &forced_bytes[num_existing_bytes..],
                grm_tokens
            );
            token_prefix = forced_bytes[forced_bytes.len() - chop_bytes..].to_vec();

            self.parser.perf_counters().tokenize_ff.record(t0.elapsed());

            if !grm_tokens.is_empty() {
                infoln!(
                    self,
                    "fixed_tokens: {}; prefix len {}",
                    trie.tokens_dbg(&grm_tokens),
                    token_prefix.len()
                );
                return (grm_tokens, token_prefix);
            } else {
                infoln!(self, "no fixed tokens; prefix len {}", token_prefix.len());
            }
        } else if forced_bytes.len() > num_existing_bytes {
            infoln!(self, "not-forcing {} bytes", forced_bytes.len());
            token_prefix = forced_bytes[num_existing_bytes..].to_vec();
        }

        (Vec::new(), token_prefix)
    }

    fn compute_bias(&mut self, token_prefix: &[u8]) -> SimpleVob {
        let pre_stats = self.parser.stats().clone();
        let set = self.parser.compute_bias(&*self.bias_computer, token_prefix);
        let p_stats = self.parser.stats().delta(&pre_stats);
        self.last_bias_time = Duration::from_micros(p_stats.compute_time_us);
        self.last_step_stats = p_stats.clone();
        self.max_step_stats = self.max_step_stats.max(&p_stats);
        set
    }

    fn log_final(&mut self, token_prefix: &[u8], allowed_tokens: &SimpleVob) {
        infoln!(
            self,
            "step-stats: {}us; {} lex fuel; {} items; {}",
            self.compute_mask_start_time.elapsed().as_micros(),
            self.last_step_stats.lexer_cost,
            self.last_step_stats.all_items,
            self.parser.lexer_stats(),
        );

        infoln!(
            self,
            "bias: (pref: {:?}; accpt: {}; temp: {:.3}) {}",
            String::from_utf8_lossy(token_prefix),
            self.is_accepting_cache.unwrap(),
            self.parser.temperature().unwrap_or(0.0),
            self.token_env.tok_trie().token_set_dbg(allowed_tokens)
        );
    }

    pub fn temperature(&self) -> Option<f32> {
        self.parser.temperature()
    }

    /// Extend the current state of the parser with the given token.
    /// Returns the number of tokens to backtrack, if any.
    pub fn consume_token(&mut self, token: TokenId) -> Result<usize> {
        self.check_initialized("consume_token")?;

        if self.max_tokens_total == 0 {
            return Err(self.stop("max_tokens_total reached", StopReason::MaxTokensTotal));
        }
        self.max_tokens_total -= 1;

        if token == self.eos_token {
            if self.parser.scan_eos() {
                // it got scanned correctly, so we remove it
                // this only happens for gen() terminated by EOS
                infoln!(self, "consume_token: scanned eos_token");
                // if self.inference_caps.backtrack {
                //     return Ok(1);
                // } else {
                //     warn!(self, "can't backtrack over eos_token");
                //     return Ok(0);
                // }
                // don't backtrack it for now; doing so currently fails tests
                return Ok(0);
            } else {
                let accepting = self.is_accepting();
                infoln!(
                    self,
                    "consume_token: eos_token not eaten by parser; accept={}",
                    accepting
                );
                if accepting {
                    self.llm_tokens.push(token);
                    return Ok(0);
                }
            }
        }

        let apply_res = self.apply_token(token);
        self.parser.log_row_infos("post-apply");
        match apply_res {
            Err(_) => Err(self.anyhow_error()),
            Ok(n) => Ok(n),
        }
    }

    /// Check whether the current parser state forces the sequence to stop.
    /// If so, puts the parser in stop state and returns true.
    /// Otherwise, returns false.
    /// This generally should be called after consume_token().
    pub fn check_stop(&mut self) -> Result<bool> {
        let empty_token_prefix = !self.has_ff_bytes();
        let pending_eos = self.llm_tokens.last() == Some(&self.eos_token);
        let lexer_bytes = self.parser.has_pending_lexeme_bytes();
        let is_accepting = self.is_accepting();
        let can_advance = self.parser.can_advance();
        let parser_done = is_accepting && (!can_advance || pending_eos);
        infoln!(
            self,
            "parser_done: {parser_done}; lexer_bytes: {lexer_bytes}; \
                can_advance: {can_advance} (eos:{pending_eos}); \
                accept: {is_accepting}; \
                empty_token_prefix: {empty_token_prefix}"
        );
        assert!(!is_accepting || empty_token_prefix);

        if parser_done {
            infoln!(
                self,
                "only eos token allowed, stopping; accepting: {}",
                is_accepting
            );
            let reason = if pending_eos {
                StopReason::EndOfSentence
            } else {
                StopReason::NoExtension
            };
            self.stop("", reason);
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Check if there are any tokens to fast-forward, forced by the current
    /// parser state.
    pub fn compute_ff_tokens(&mut self) -> Vec<TokenId> {
        let r = self.ff_tokens();
        if self.can_force_bytes() {
            self.ff_tokens_cache = Some(r.clone());
        }
        r.0
    }

    /// Compute and then consume fast-forward tokens.
    pub fn consume_ff_tokens(&mut self) -> Result<Vec<TokenId>> {
        let ff_tokens = self.compute_ff_tokens();
        for &t in &ff_tokens {
            let num_backtrack = self.consume_token(t)?;
            if num_backtrack > 0 {
                return Err(self.stop(
                    &format!("backtrack required after ff_token: {t}"),
                    StopReason::InternalError,
                ));
            }
        }
        Ok(ff_tokens)
    }

    /// This function documents typical use of this interface.
    /// The `tokens` array simulates tokens being sampled.
    #[allow(dead_code)]
    fn typical_use(&mut self, prompt: Vec<TokenId>) -> Result<()> {
        // First, check if we need to token-heal the prompt,
        // and if there are some tokens forced by the beginning
        // of the grammar.
        let new_prompt = self.process_prompt(prompt);

        // pass new prompt to inference engine
        black_box(new_prompt);

        let mut tokens = vec![];

        loop {
            let temp = self.temperature();
            let mask = self.compute_mask()?;

            // model forward pass in parallel with compute_mask() goes here

            // simulate sampling a token with given mask
            black_box((temp, mask));
            let sampled_token = 42;

            let num_backtrack = self.consume_token(sampled_token)?;

            if num_backtrack == 0 {
                // normal situation - the token was accepted
                tokens.push(sampled_token);
            } else {
                // this will only happen if you enable backtrack
                assert!(self.inference_caps.backtrack);
                if num_backtrack == 1 {
                    // don't add the token to the list
                } else if num_backtrack > 1 {
                    // backtrack: the sampled token was never added above, so
                    // only the remaining num_backtrack - 1 tokens are dropped
                    tokens.truncate(tokens.len() - (num_backtrack - 1));
                }
            }

            // This is optional; if you don't check, compute_mask() will
            // return an error when it cannot continue anymore.
            // If you check here, you can distinguish between a normal stop
            // and an error.
            if self.check_stop()? {
                break;
            }

            // This is optional - call it if you have the ability to append
            // several tokens at once.
            let forced = self.consume_ff_tokens()?;
            tokens.extend_from_slice(&forced);
        }

        Ok(())
    }
}