use std::{fmt::Display, hint::black_box, panic::AssertUnwindSafe, sync::Arc, time::Duration};

use crate::{
    api::{GrammarInit, ParserLimits, StopReason},
    earley::{BiasComputer, Parser, ParserError, ParserStats},
    infoln, panic_utils, warn, Instant, Logger, ParserFactory,
};
use anyhow::{ensure, Result};
use toktrie::{InferenceCapabilities, SimpleVob, TokEnv, TokenId, INVALID_TOKEN};

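/// Wraps an Earley [`Parser`] with token-level bookkeeping: decoding LLM
/// tokens to bytes, computing per-step token masks, handling fast-forward
/// (forced) tokens, backtracking, and rollback. See `typical_use()` at the
/// bottom of this file for the intended calling sequence.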
#[derive(Clone)]
pub struct TokenParser {
    pub token_env: TokEnv,
    pub parser: Parser,
    pub compute_mask_start_time: Instant,
    pub last_bias_time: Duration,
    pub inference_caps: InferenceCapabilities,
    pub logger: Logger,
    pub limits: ParserLimits,
    pub bias_computer: Arc<dyn BiasComputer>,
    pub dbg_grammar: String,
    last_step_stats: ParserStats,
    max_step_stats: ParserStats,
    eos_tokens: Vec<TokenId>,

    had_rollback: bool,
    had_backtrack: bool,

    is_accepting_cache: Option<bool>,
    ff_tokens_cache: Option<(Vec<TokenId>, Vec<u8>)>,
    stop_reason: StopReason,
    error_message: Option<String>,
    max_tokens_total: usize,

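    // tokens consumed so far and their decoded bytes (EOS tokens add no bytes)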
    llm_tokens: Vec<TokenId>,
    llm_bytes: Vec<u8>,

    grm_prefix: Vec<u8>,
    is_fresh: bool,
}

impl TokenParser {
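    /// Builds a `TokenParser` from a grammar, converting panics during
    /// construction into `Err` results.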
    pub(crate) fn from_init(
        factory: &ParserFactory,
        grammar_init: GrammarInit,
        logger: Logger,
        inference_caps: InferenceCapabilities,
        limits: ParserLimits,
    ) -> Result<Self> {
        panic_utils::catch_unwind(AssertUnwindSafe(|| {
            Self::init_inner(factory, grammar_init, logger, inference_caps, limits)
        }))
    }

    fn init_inner(
        factory: &ParserFactory,
        grammar_init: GrammarInit,
        mut logger: Logger,
        inference_caps: InferenceCapabilities,
        limits: ParserLimits,
    ) -> Result<Self> {
        let token_env = factory.tok_env().clone();
        ensure!(
            token_env.tokenize_is_canonical() || !inference_caps.ff_tokens,
            "ff_tokens requires canonical tokenization"
        );
        ensure!(
            !inference_caps.backtrack || inference_caps.ff_tokens,
            "backtrack requires ff_tokens"
        );

        let compute_mask_start_time = Instant::now();
        let mut max_tokens = usize::MAX;
        if let GrammarInit::Serialized(input) = &grammar_init {
            if let Some(m) = input.max_tokens {
                max_tokens = m;
            }
        }
        let compiled_grammar = grammar_init.to_cgrammar(
            Some(token_env.clone()),
            &mut logger,
            limits.clone(),
            factory.extra_lexemes(),
        )?;
        let mut parser = Parser::new(
            token_env.clone(),
            compiled_grammar,
            limits.clone(),
            factory.perf_counters(),
        )?;
        parser.metrics_mut().rand = factory.next_rng();
        let eos_tokens = token_env.tok_trie().eos_tokens().to_vec();

        Ok(TokenParser {
            bias_computer: factory.slicer().clone(),
            logger,
            token_env,
            inference_caps,
            limits,
            max_step_stats: ParserStats::default(),
            last_step_stats: ParserStats::default(),
            compute_mask_start_time,
            is_accepting_cache: None,
            ff_tokens_cache: None,
            stop_reason: StopReason::NotStopped,
            error_message: None,
            parser,
            dbg_grammar: String::new(),
            eos_tokens,
            llm_tokens: Vec::new(),
            llm_bytes: Vec::new(),
            grm_prefix: Vec::new(),
            max_tokens_total: max_tokens,
            last_bias_time: Duration::from_secs(0),
            is_fresh: true,
            had_backtrack: false,
            had_rollback: false,
        })
    }

    pub fn grammar_warnings(&mut self) -> Vec<String> {
        self.parser.grammar_warnings()
    }

    pub fn get_capture(&self, name: &str) -> Option<&[u8]> {
        self.parser.get_capture(name)
    }

    pub fn captures(&self) -> &[(String, Vec<u8>)] {
        self.parser.captures()
    }

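    /// Like `clone()`, but also deep-clones the inner parser state.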
    pub fn deep_clone(&self) -> Self {
        let mut copy = self.clone();
        copy.parser = self.parser.deep_clone();
        copy
    }

    pub fn stop_reason(&self) -> StopReason {
        self.stop_reason
    }

    pub fn stopped(&self) -> bool {
        self.stop_reason != StopReason::NotStopped
    }

    pub fn is_fresh(&self) -> bool {
        self.is_fresh
    }

    pub fn parser_stats(&self) -> &ParserStats {
        self.parser.stats()
    }

    pub fn last_step_stats(&self) -> &ParserStats {
        &self.last_step_stats
    }

    pub fn max_step_stats(&self) -> &ParserStats {
        &self.max_step_stats
    }

    pub fn num_tokens(&self) -> usize {
        self.llm_tokens.len()
    }

    pub fn final_bytes(&self) -> &[u8] {
        &self.llm_bytes[self.grm_prefix.len()..]
    }

    pub fn is_accepting(&mut self) -> bool {
        if let Some(acc) = self.is_accepting_cache {
            acc
        } else {
            let r = !self.has_ff_bytes() && self.parser.is_accepting();
            self.is_accepting_cache = Some(r);
            r
        }
    }

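    /// Returns the bytes generated after position `idx` (counted past the
    /// forced prefix), excluding any hidden bytes at the end.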
    pub fn bytes_since(&self, mut idx: usize) -> &[u8] {
        idx += self.grm_prefix.len();
        let endp = std::cmp::min(self.llm_bytes.len(), self.parser.hidden_start());
        if idx >= self.llm_bytes.len() || idx >= endp {
            return &[];
        }
        &self.llm_bytes[idx..endp]
    }

    pub fn start_without_prompt(&mut self) {
        infoln!(
            self,
            "initial lexer cost: {} (no prompt)",
            self.parser.lexer_stats()
        );

        assert!(self.is_fresh);
        self.is_fresh = false;
    }

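    /// Truncates freshly tokenized output, dropping trailing tokens whose
    /// bytes could still merge with future output into different tokens.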
    fn tokenize_and_chop(
        &mut self,
        mut tokens: Vec<TokenId>,
        num_fixed: usize,
    ) -> (Vec<TokenId>, usize) {
        let trie = self.token_env.tok_trie();
        let (chop_tokens, chop_bytes) = self
            .parser
            .with_recognizer(|r| trie.chop_tokens(r, &tokens[num_fixed..]));
        infoln!(
            self,
            "tokenize -> {}; chop: {} tokens, {} bytes",
            trie.tokens_dbg(&tokens),
            chop_tokens,
            chop_bytes
        );
        tokens.truncate(tokens.len() - chop_tokens);
        (tokens, chop_bytes)
    }

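    /// Appends any bytes forced by the grammar to the prompt, re-tokenizes,
    /// and returns the token sequence to actually feed to the model.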
    pub fn process_prompt(&mut self, prompt: Vec<TokenId>) -> Vec<TokenId> {
        infoln!(self, "initial lexer cost: {}", self.parser.lexer_stats());

        assert!(self.token_env.tokenize_is_canonical());
        assert!(self.is_fresh);
        self.is_fresh = false;

        assert!(self.llm_tokens.is_empty());

        let trie = self.token_env.tok_trie();
        infoln!(self, "prompt: {}", trie.tokens_dbg(&prompt));
        let mut prompt_bytes = trie.decode_raw(&prompt);
        if self.can_force_bytes() {
            self.parser.force_bytes();
        }
        let grm_bytes = self.parser.get_bytes().to_vec();
        prompt_bytes.extend_from_slice(&grm_bytes);

        let (tokens, num_fixed) = self.token_env.tokenize_bytes_marker(&prompt_bytes);
        let (res_prompt, chop_bytes) = self.tokenize_and_chop(tokens, num_fixed);

        let trie = self.token_env.tok_trie();
        infoln!(
            self,
            "prompt+grm: {} {}",
            trie.tokens_dbg(&res_prompt),
            self.parser.grammar().lexer_spec().no_forcing
        );

        if chop_bytes <= grm_bytes.len() {
            self.llm_bytes = grm_bytes[0..grm_bytes.len() - chop_bytes].to_vec();
            self.llm_tokens = self.token_env.tokenize_bytes_marker(&self.llm_bytes).0;
            self.parser.apply_forced(self.llm_bytes.len());
            let decoded = self.tok_trie().decode_raw(&self.llm_tokens);
            if !self.llm_bytes.is_empty()
                && !decoded.is_empty()
                && decoded[1..] == self.llm_bytes
                && decoded[0] == b' '
            {
                infoln!(self, "applying <s>space hack");
                self.grm_prefix = decoded[0..1].to_vec();
                self.llm_bytes = decoded;
            }
            infoln!(self, "ini_tokens: {}", trie.tokens_dbg(&self.llm_tokens));
        } else {
            // the chop reaches into the prompt itself; remember the chopped
            // prompt suffix so it gets re-emitted before grammar output
            self.grm_prefix = prompt_bytes
                [prompt_bytes.len() - chop_bytes..prompt_bytes.len() - grm_bytes.len()]
                .to_vec();
            infoln!(
                self,
                "force_prefix: {:?}",
                String::from_utf8_lossy(&self.grm_prefix)
            );
        }

        infoln!(self, "res_prompt: {}", trie.tokens_dbg(&res_prompt));
        res_prompt
    }

    pub fn augment_err(&self, e: impl Display) -> String {
        if self.limits.verbose_errors {
            format!(
                "{e}\n<state>\n{}\n</state><grammar>\n{}\n</grammar>",
                self.dump_state(),
                self.dbg_grammar
            )
        } else {
            format!("{e}\n<non-verbose/>")
        }
    }

    pub fn dump_state(&self) -> String {
        format!(
            "Tokens: {}\n{} tokens, {} bytes; grm_prefix: {:?}\nFlags:{}{}\nParser: {}\nStop: {}\nError: {}",
            self.tok_trie().tokens_dbg(&self.llm_tokens),
            self.llm_tokens.len(),
            self.llm_bytes.len(),
            String::from_utf8_lossy(&self.grm_prefix),
            if self.had_backtrack {
                " had_backtrack"
            } else {
                ""
            },
            if self.had_rollback {
                " had_rollback"
            } else {
                ""
            },
            self.parser.stats(),
            self.stop_reason,
            self.error_message.as_deref().unwrap_or("None"),
        )
    }

    fn clear_caches(&mut self) {
        self.is_accepting_cache = None;
        self.ff_tokens_cache = None;
    }

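    /// Records the stop reason (and optional warning message) and returns an
    /// error to propagate to the caller.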
    fn stop(&mut self, warn: &str, reason: StopReason) -> anyhow::Error {
        if !warn.is_empty() {
            self.error_message = Some(warn.to_string());
            warn!(self, "{}; stopping", warn);
        }
        self.stop_reason = reason;
        self.anyhow_error()
    }

    fn tok_trie(&self) -> &toktrie::TokTrie {
        self.token_env.tok_trie()
    }

    pub fn error_message(&self) -> Option<String> {
        self.error_message.clone()
    }

    fn check_initialized(&self, lbl: &str) -> Result<()> {
        ensure!(!self.is_fresh, "process_prompt() not called in {}", lbl);
        ensure!(
            !self.stopped(),
            "parser stopped in {}; {}",
            lbl,
            self.error_message()
                .unwrap_or_else(|| "no error message".to_string())
        );
        Ok(())
    }

    pub fn validate_token(&mut self, token: TokenId) -> Result<bool> {
        if self.stopped() {
            return Ok(false);
        }
        self.check_initialized("validate_token")?;
        self.validate_tokens_raw(&[token]).map(|n| n > 0)
    }

    pub fn reset(&mut self) -> Result<()> {
        self.rollback(self.llm_tokens.len())
    }

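    /// Undoes the last `n_tokens` tokens, restoring parser, token, and byte
    /// state, and re-crediting the token budget.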
    pub fn rollback(&mut self, n_tokens: usize) -> Result<()> {
        if n_tokens == 0 {
            return Ok(());
        }

        ensure!(
            n_tokens <= self.llm_tokens.len(),
            "rollback: {} > {}",
            n_tokens,
            self.llm_tokens.len()
        );

        if self.stop_reason.is_ok() {
            // a non-error stop can be undone by rolling back
            self.stop_reason = StopReason::NotStopped;
        }

        self.check_initialized("rollback")?;

        self.had_rollback = true;

        let new_len = self.llm_tokens.len() - n_tokens;
        let mut bytes_to_drop = 0;
        for tok in &self.llm_tokens[new_len..] {
            // EOS tokens are not reflected in llm_bytes, so they add no bytes
            if !self.eos_tokens.contains(tok) {
                bytes_to_drop += self.tok_trie().token_len(*tok);
            }
        }
        ensure!(
            bytes_to_drop <= self.llm_bytes.len(),
            "rollback bytes: {} > {}",
            bytes_to_drop,
            self.llm_bytes.len()
        );

        self.parser.rollback(bytes_to_drop)?;

        self.max_tokens_total = self.max_tokens_total.saturating_add(n_tokens);
        self.llm_tokens.truncate(new_len);
        self.llm_bytes
            .truncate(self.llm_bytes.len() - bytes_to_drop);
        self.clear_caches();

        Ok(())
    }

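    /// Returns the number of leading tokens from `tokens` that the parser
    /// can currently accept.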
    pub fn validate_tokens_raw(&mut self, tokens: &[TokenId]) -> Result<usize> {
        if self.stopped() {
            return Ok(0);
        }
        self.check_initialized("validate_tokens_raw")?;

        if tokens.is_empty() {
            return Ok(0);
        }

        let n_vocab = self.tok_trie().vocab_size();
        for &t in tokens {
            if t as usize >= n_vocab {
                return Err(self.stop(
                    &format!("token id {t} out of range"),
                    StopReason::InternalError,
                ));
            }
        }

        let n_valid = self.parser.validate_tokens(tokens);
        Ok(n_valid)
    }

    fn anyhow_error(&self) -> anyhow::Error {
        anyhow::anyhow!(self
            .error_message
            .clone()
            .unwrap_or_else(|| self.stop_reason.to_string()))
    }

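    /// Computes the token mask for the next sampling step and records the
    /// elapsed time in the parser's perf counters.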
    pub fn compute_mask(&mut self) -> Result<SimpleVob> {
        self.compute_mask_start_time = Instant::now();
        let r = self.compute_mask_inner();
        self.parser
            .perf_counters()
            .compute_mask
            .record(self.compute_mask_start_time.elapsed());
        r
    }

    fn compute_mask_inner(&mut self) -> Result<SimpleVob> {
        self.check_initialized("compute_mask")?;

        infoln!(self, "compute_mask");

        let prefix = if self.can_force_bytes() {
            let (ff_tokens, token_prefix) = self
                .ff_tokens_cache
                .take()
                .unwrap_or_else(|| self.ff_tokens());
            if !ff_tokens.is_empty() {
                // the grammar forces a specific token; return a singleton
                // mask instead of running the full bias computation
                let t = ff_tokens[0];
                infoln!(self, "forcing ff_token by mask: {}", t);
                let mask = self.tok_trie().singleton_token_set(t);
                self.last_step_stats = ParserStats::default();
                return Ok(mask);
            } else {
                token_prefix
            }
        } else {
            let mut trg = Vec::new();
            self.compute_ff_bytes_to(&mut trg);
            trg
        };

        let mut allowed_tokens = self.compute_bias(&prefix);

        if let Some(s) = self.parser.get_error() {
            return Err(self.stop_for_parser_error("", s));
        }

        if self.is_accepting() {
            for &eos in &self.eos_tokens {
                if eos != INVALID_TOKEN {
                    allowed_tokens.allow_token(eos);
                }
            }
        }

        self.log_final(&prefix, &allowed_tokens);

        if allowed_tokens.is_zero() {
            infoln!(self, "no tokens allowed, stopping");
            return Err(self.stop("", StopReason::NoExtensionBias));
        }

        Ok(allowed_tokens)
    }

    fn stop_for_parser_error(&mut self, pref: &str, err: ParserError) -> anyhow::Error {
        self.stop(&format!("{}{}", pref, err.message()), err.stop_reason())
    }

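    /// Applies a single token: bytes covered by the pending forced prefix
    /// are checked against it, and the remainder is fed to the parser.
    /// Returns the number of tokens to backtrack (usually 0).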
    fn apply_token(&mut self, tok_id: TokenId) -> Result<usize> {
        self.clear_caches();

        let trie = self.token_env.tok_trie();

        if (tok_id as usize) >= trie.vocab_size() {
            return Err(self.stop(
                &format!("token id {tok_id} out of range"),
                StopReason::InternalError,
            ));
        }

        self.llm_tokens.push(tok_id);

        let tok_bytes = trie.decode_raw(&[tok_id]);

        // how many bytes of the pending grm_prefix this token still covers
        let prefix_len = self.grm_prefix.len().saturating_sub(self.llm_bytes.len());

        infoln!(
            self,
            "consume_token: {} {} prefix={}",
            tok_id,
            trie.token_dbg(tok_id),
            prefix_len
        );

        let tok_bytes = if prefix_len > 0 {
            let to_apply = &tok_bytes[0..std::cmp::min(tok_bytes.len(), prefix_len)];
            self.llm_bytes.extend_from_slice(to_apply);

            if self.grm_prefix[0..self.llm_bytes.len()] != self.llm_bytes {
                return Err(self.stop(
                    &format!(
                        "prefix mismatch: applying {:?}; {:?} vs {:?}",
                        String::from_utf8_lossy(to_apply),
                        String::from_utf8_lossy(&self.grm_prefix),
                        String::from_utf8_lossy(&self.llm_bytes)
                    ),
                    StopReason::InternalError,
                ));
            }

            if prefix_len < tok_bytes.len() {
                // the token extends past the prefix; feed the rest to the parser
                &tok_bytes[prefix_len..]
            } else {
                // token consumed entirely by the prefix
                return Ok(0);
            }
        } else {
            &tok_bytes
        };

        if let Some(err) = self.parser.get_error() {
            return Err(self.stop_for_parser_error("", err));
        }

        match self.parser.apply_token(tok_bytes, tok_id) {
            Err(e) => {
                return Err(self.stop(
                    &format!("Parser Error: {e}"),
                    StopReason::ParserTooComplex,
                ));
            }
            Ok(backtrack_bytes0) => {
                self.llm_bytes.extend_from_slice(tok_bytes);

                if backtrack_bytes0 != 0 {
                    self.had_backtrack = true;
                    // convert the byte-level backtrack into whole tokens
                    let mut backtrack_bytes: isize = backtrack_bytes0.try_into().unwrap();
                    let mut backtrack_tokens = 0;
                    while backtrack_bytes > 0 {
                        let tok_off = self.llm_tokens.len() - backtrack_tokens;
                        if tok_off == 0 {
                            break; // no more tokens to drop
                        }
                        let tok = self.llm_tokens[tok_off - 1];
                        backtrack_bytes -= trie.token_len(tok) as isize;
                        backtrack_tokens += 1;
                    }
                    assert!(backtrack_tokens > 0);
                    // token boundaries may overshoot the requested byte count
                    let additional_backtrack_bytes: usize = (-backtrack_bytes).try_into().unwrap();
                    let full_backtrack_bytes = backtrack_bytes0 + additional_backtrack_bytes;

                    let byte_ptr = self.llm_bytes.len() - full_backtrack_bytes;
                    infoln!(
                        self,
                        "backtrack: {} tokens / {}+{} bytes (deletes: {:?})",
                        backtrack_tokens,
                        backtrack_bytes0,
                        additional_backtrack_bytes,
                        String::from_utf8_lossy(&self.llm_bytes[byte_ptr..])
                    );
                    self.llm_bytes.truncate(byte_ptr);

                    let token_ptr = self.llm_tokens.len() - backtrack_tokens;
                    if !self.inference_caps.backtrack {
                        warn!(
                            self,
                            "can't backtrack over {}; this may confuse the model",
                            trie.tokens_dbg(&self.llm_tokens[token_ptr..])
                        );
                        backtrack_tokens = 0;
                    } else {
                        self.parser
                            .additional_backtrack(additional_backtrack_bytes);
                    }
                    self.llm_tokens.truncate(token_ptr);
                    return Ok(backtrack_tokens);
                }
            }
        }

        Ok(0)
    }

    fn pending_grm_prefix(&self) -> &[u8] {
        &self.grm_prefix[std::cmp::min(self.grm_prefix.len(), self.llm_bytes.len())..]
    }

    fn has_ff_bytes(&self) -> bool {
        !self.pending_grm_prefix().is_empty() || !self.parser.currently_forced_bytes().is_empty()
    }

    fn can_force_bytes(&self) -> bool {
        !self.parser.grammar().lexer_spec().no_forcing && self.token_env.tokenize_is_canonical()
    }

    pub fn force_bytes(&mut self) -> Vec<u8> {
        self.parser.force_bytes();
        let mut trg = Vec::new();
        self.compute_ff_bytes_inner(&mut trg);
        trg
    }

    fn compute_ff_bytes_to(&mut self, trg: &mut Vec<u8>) {
        if self.can_force_bytes() {
            self.parser.force_bytes();
        }
        self.compute_ff_bytes_inner(trg);
    }

    fn compute_ff_bytes_inner(&mut self, trg: &mut Vec<u8>) {
        if self.llm_bytes.len() < self.grm_prefix.len() {
            let inject = &self.grm_prefix[self.llm_bytes.len()..];
            trg.extend_from_slice(inject);
            infoln!(
                self,
                "injecting prefix: {:?}",
                String::from_utf8_lossy(inject)
            );
        }

        trg.extend_from_slice(self.parser.currently_forced_bytes());
    }

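    /// Computes fast-forward tokens from grammar-forced bytes. The last
    /// generated token is included in re-tokenization so token boundaries
    /// stay canonical; returns the forced tokens plus any leftover bytes
    /// that don't end on a token boundary.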
    fn ff_tokens(&mut self) -> (Vec<TokenId>, Vec<u8>) {
        let mut forced_bytes = Vec::new();
        let mut existing_tokens = if self.llm_tokens.is_empty() {
            Vec::new()
        } else {
            // include the last generated token, so re-tokenization can merge
            // it with forced bytes if the tokenizer prefers a longer token
            let r = self.llm_tokens[self.llm_tokens.len() - 1..].to_vec();
            let trie = self.token_env.tok_trie();
            forced_bytes = trie.decode_raw(&r);
            r
        };
        let num_existing_bytes = forced_bytes.len();

        self.compute_ff_bytes_to(&mut forced_bytes);

        let mut token_prefix = Vec::new();

        let do_force =
            forced_bytes.len() > num_existing_bytes && self.token_env.tokenize_is_canonical();
        if do_force {
            let t0 = Instant::now();
            let (mut tokens, mut num_fixed) = self.token_env.tokenize_bytes_marker(&forced_bytes);
            if !tokens.starts_with(&existing_tokens) {
                let trie = self.token_env.tok_trie();
                infoln!(
                    self,
                    "re-tokenizing without prefix: {}; because we got {}",
                    trie.tokens_dbg(&existing_tokens),
                    trie.tokens_dbg(&tokens),
                );
                (tokens, num_fixed) = self
                    .token_env
                    .tokenize_bytes_marker(&forced_bytes[num_existing_bytes..]);
                infoln!(
                    self,
                    "re-tokenized: {} from: {:?}",
                    trie.tokens_dbg(&tokens),
                    &forced_bytes[num_existing_bytes..]
                );
                existing_tokens.clear();
            } else {
                num_fixed = std::cmp::max(existing_tokens.len(), num_fixed);
            }

            let (mut grm_tokens, chop_bytes) = self.tokenize_and_chop(tokens, num_fixed);
            assert!(grm_tokens.starts_with(&existing_tokens));
            grm_tokens.drain(..existing_tokens.len());

            let trie = self.token_env.tok_trie();
            infoln!(
                self,
                "forced: {} bytes:{:?} tokens:{:?}",
                trie.tokens_dbg(&grm_tokens),
                &forced_bytes[num_existing_bytes..],
                grm_tokens
            );
            token_prefix = forced_bytes[forced_bytes.len() - chop_bytes..].to_vec();

            self.parser.perf_counters().tokenize_ff.record(t0.elapsed());

            if !grm_tokens.is_empty() {
                infoln!(
                    self,
                    "fixed_tokens: {}; prefix len {}",
                    trie.tokens_dbg(&grm_tokens),
                    token_prefix.len()
                );
                return (grm_tokens, token_prefix);
            } else {
                infoln!(self, "no fixed tokens; prefix len {}", token_prefix.len());
            }
        } else if forced_bytes.len() > num_existing_bytes {
            infoln!(self, "not-forcing {} bytes", forced_bytes.len());
            token_prefix = forced_bytes[num_existing_bytes..].to_vec();
        }

        (Vec::new(), token_prefix)
    }

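    /// Runs the bias computer and updates per-step and maximum statistics.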
    fn compute_bias(&mut self, token_prefix: &[u8]) -> SimpleVob {
        let pre_stats = self.parser.stats().clone();
        let set = self.parser.compute_bias(&*self.bias_computer, token_prefix);
        let p_stats = self.parser.stats().delta(&pre_stats);
        self.last_bias_time = Duration::from_micros(p_stats.compute_time_us);
        self.last_step_stats = p_stats.clone();
        self.max_step_stats = self.max_step_stats.max(&p_stats);
        set
    }

    fn log_final(&mut self, token_prefix: &[u8], allowed_tokens: &SimpleVob) {
        infoln!(
            self,
            "step-stats: {}us; {} lex fuel; {} items; {}",
            self.compute_mask_start_time.elapsed().as_micros(),
            self.last_step_stats.lexer_cost,
            self.last_step_stats.all_items,
            self.parser.lexer_stats(),
        );

        infoln!(
            self,
            "bias: (pref: {:?}; accpt: {}; temp: {:.3}) {}",
            String::from_utf8_lossy(token_prefix),
            self.is_accepting_cache.unwrap(),
            self.parser.temperature().unwrap_or(0.0),
            self.token_env.tok_trie().token_set_dbg(allowed_tokens)
        );
    }

    pub fn temperature(&self) -> Option<f32> {
        self.parser.temperature()
    }

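    /// Consumes one sampled token, handling EOS specially. Returns the number
    /// of tokens to backtrack, or an error if the token is rejected.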
    pub fn consume_token(&mut self, token: TokenId) -> Result<usize> {
        self.check_initialized("consume_token")?;

        if self.max_tokens_total == 0 {
            return Err(self.stop("max_tokens_total reached", StopReason::MaxTokensTotal));
        }
        self.max_tokens_total -= 1;

        if self.eos_tokens.contains(&token) {
            if self.parser.scan_eos() {
                // EOS was consumed as part of the grammar itself
                infoln!(self, "consume_token: scanned eos_token");
                return Ok(0);
            } else {
                let accepting = self.is_accepting();
                infoln!(
                    self,
                    "consume_token: eos_token not eaten by parser; accept={}",
                    accepting
                );
                if accepting {
                    // EOS ends the generation; record the token, but there
                    // are no bytes to feed to the parser
                    self.llm_tokens.push(token);
                    return Ok(0);
                }
            }
        }

        let apply_res = self.apply_token(token);
        self.parser.log_row_infos("post-apply");
        match apply_res {
            Err(_) => Err(self.anyhow_error()),
            Ok(n) => Ok(n),
        }
    }

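    /// Returns `Ok(true)` and sets the stop reason when the parser is in an
    /// accepting state and either cannot advance or has seen EOS.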
    pub fn check_stop(&mut self) -> Result<bool> {
        let empty_token_prefix = !self.has_ff_bytes();
        let pending_eos = self
            .llm_tokens
            .last()
            .is_some_and(|t| self.eos_tokens.contains(t));
        let lexer_bytes = self.parser.has_pending_lexeme_bytes();
        let is_accepting = self.is_accepting();
        let can_advance = self.parser.can_advance();
        let parser_done = is_accepting && (!can_advance || pending_eos);
        infoln!(
            self,
            "parser_done: {parser_done}; lexer_bytes: {lexer_bytes}; \
             can_advance: {can_advance} (eos:{pending_eos}); \
             accept: {is_accepting}; \
             empty_token_prefix: {empty_token_prefix}"
        );
        assert!(!is_accepting || empty_token_prefix);

        if parser_done {
            infoln!(
                self,
                "only eos token allowed, stopping; accepting: {}",
                is_accepting
            );
            let reason = if pending_eos {
                StopReason::EndOfSentence
            } else {
                StopReason::NoExtension
            };
            self.stop("", reason);
            Ok(true)
        } else {
            Ok(false)
        }
    }

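    /// Returns the tokens the grammar currently forces, caching the result
    /// for the next `compute_mask()` call when byte-forcing is enabled.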
    pub fn compute_ff_tokens(&mut self) -> Vec<TokenId> {
        let r = self.ff_tokens();
        if self.can_force_bytes() {
            // cache so the upcoming compute_mask() doesn't redo the work
            self.ff_tokens_cache = Some(r.clone());
        }
        r.0
    }

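    /// Computes fast-forward tokens and immediately consumes them.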
    pub fn consume_ff_tokens(&mut self) -> Result<Vec<TokenId>> {
        let ff_tokens = self.compute_ff_tokens();
        for &t in &ff_tokens {
            let num_backtrack = self.consume_token(t)?;
            if num_backtrack > 0 {
                return Err(self.stop(
                    &format!("backtrack required after ff_token: {t}"),
                    StopReason::InternalError,
                ));
            }
        }
        Ok(ff_tokens)
    }

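    /// Never called; sketches the intended calling sequence for this API:
    /// process_prompt, then in a loop compute_mask -> sample -> consume_token
    /// -> check_stop -> consume_ff_tokens.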
    #[allow(dead_code)]
    fn typical_use(&mut self, prompt: Vec<TokenId>) -> Result<()> {
        let new_prompt = self.process_prompt(prompt);

        // the inference engine would send new_prompt to the model here
        black_box(new_prompt);

        let mut tokens = vec![];

        loop {
            let temp = self.temperature();
            let mask = self.compute_mask()?;

            // the inference engine would sample a token under `mask` and
            // `temp` here; 42 stands in for the sampled token id
            black_box((temp, mask));
            let sampled_token = 42;

            let num_backtrack = self.consume_token(sampled_token)?;

            if num_backtrack == 0 {
                tokens.push(sampled_token);
            } else {
                assert!(self.inference_caps.backtrack);
                if num_backtrack == 1 {
                    // just don't add the sampled token
                } else if num_backtrack > 1 {
                    // drop the last (num_backtrack - 1) tokens already in the
                    // list; the sampled token was never added
                    tokens.truncate(tokens.len() - (num_backtrack - 1));
                }
            }

            if self.check_stop()? {
                break;
            }

            let forced = self.consume_ff_tokens()?;
            tokens.extend_from_slice(&forced);
        }

        Ok(())
    }

    pub fn invalidate_bias_cache(&mut self) {
        self.parser.invalidate_bias_cache();
    }
}