use std::{fmt::Display, hint::black_box, panic::AssertUnwindSafe, sync::Arc, time::Duration};

use crate::{
    api::{GrammarInit, ParserLimits, StopReason},
    earley::{BiasComputer, Parser, ParserError, ParserStats},
    infoln, panic_utils, warn, Instant, Logger, ParserFactory,
};
use anyhow::{ensure, Result};
use toktrie::{InferenceCapabilities, SimpleVob, TokEnv, TokenId, INVALID_TOKEN};

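/// Wraps the Earley [`Parser`] for use with an LLM: tracks the tokens and
/// bytes generated so far, computes per-step token masks, and handles
/// fast-forward tokens, backtracking, and rollback.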
#[derive(Clone)]
pub struct TokenParser {
    pub token_env: TokEnv,
    pub parser: Parser,
    pub compute_mask_start_time: Instant,
    pub last_bias_time: Duration,
    pub inference_caps: InferenceCapabilities,
    pub logger: Logger,
    pub limits: ParserLimits,
    pub bias_computer: Arc<dyn BiasComputer>,
    pub dbg_grammar: String,
    last_step_stats: ParserStats,
    max_step_stats: ParserStats,
    eos_token: TokenId,

    had_rollback: bool,
    had_backtrack: bool,

    is_accepting_cache: Option<bool>,
    ff_tokens_cache: Option<(Vec<TokenId>, Vec<u8>)>,
    stop_reason: StopReason,
    error_message: Option<String>,
    max_tokens_total: usize,

    llm_tokens: Vec<TokenId>,
    llm_bytes: Vec<u8>,

    grm_prefix: Vec<u8>,
    is_fresh: bool,
}

impl TokenParser {
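    /// Build a `TokenParser` from grammar-initialization data, converting any
    /// panic raised during grammar compilation into a regular `Err`.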
    pub(crate) fn from_init(
        factory: &ParserFactory,
        grammar_init: GrammarInit,
        logger: Logger,
        inference_caps: InferenceCapabilities,
        limits: ParserLimits,
    ) -> Result<Self> {
        panic_utils::catch_unwind(AssertUnwindSafe(|| {
            Self::init_inner(factory, grammar_init, logger, inference_caps, limits)
        }))
    }

    fn init_inner(
        factory: &ParserFactory,
        grammar_init: GrammarInit,
        mut logger: Logger,
        inference_caps: InferenceCapabilities,
        limits: ParserLimits,
    ) -> Result<Self> {
        let token_env = factory.tok_env().clone();
        ensure!(
            token_env.tokenize_is_canonical() || !inference_caps.ff_tokens,
            "ff_tokens requires canonical tokenization"
        );
        ensure!(
            !inference_caps.backtrack || inference_caps.ff_tokens,
            "backtrack requires ff_tokens"
        );

        let compute_mask_start_time = Instant::now();
        let mut max_tokens = usize::MAX;
        if let GrammarInit::Serialized(input) = &grammar_init {
            if let Some(m) = input.max_tokens {
                max_tokens = m;
            }
        }
        let compiled_grammar = grammar_init.to_cgrammar(
            Some(token_env.clone()),
            &mut logger,
            limits.clone(),
            factory.extra_lexemes(),
        )?;
        let mut parser = Parser::new(
            token_env.clone(),
            compiled_grammar,
            limits.clone(),
            factory.perf_counters(),
        )?;
        parser.metrics_mut().rand = factory.next_rng();
        let eos_token = token_env.tok_trie().eos_token();

        Ok(TokenParser {
            bias_computer: factory.slicer().clone(),
            logger,
            token_env,
            inference_caps,
            limits,
            max_step_stats: ParserStats::default(),
            last_step_stats: ParserStats::default(),
            compute_mask_start_time,
            is_accepting_cache: None,
            ff_tokens_cache: None,
            stop_reason: StopReason::NotStopped,
            error_message: None,
            parser,
            dbg_grammar: String::new(),
            eos_token,
            llm_tokens: Vec::new(),
            llm_bytes: Vec::new(),
            grm_prefix: Vec::new(),
            max_tokens_total: max_tokens,
            last_bias_time: Duration::from_secs(0),
            is_fresh: true,
            had_backtrack: false,
            had_rollback: false,
        })
    }

    pub fn grammar_warnings(&mut self) -> Vec<String> {
        self.parser.grammar_warnings()
    }

    pub fn get_capture(&self, name: &str) -> Option<&[u8]> {
        self.parser.get_capture(name)
    }

    pub fn deep_clone(&self) -> Self {
        let mut copy = self.clone();
        copy.parser = self.parser.deep_clone();
        copy
    }

    pub fn stop_reason(&self) -> StopReason {
        self.stop_reason
    }

    pub fn stopped(&self) -> bool {
        self.stop_reason != StopReason::NotStopped
    }

    pub fn is_fresh(&self) -> bool {
        self.is_fresh
    }

    pub fn parser_stats(&self) -> &ParserStats {
        self.parser.stats()
    }

    pub fn last_step_stats(&self) -> &ParserStats {
        &self.last_step_stats
    }

    pub fn max_step_stats(&self) -> &ParserStats {
        &self.max_step_stats
    }

    pub fn num_tokens(&self) -> usize {
        self.llm_tokens.len()
    }

    pub fn final_bytes(&self) -> &[u8] {
        &self.llm_bytes[self.grm_prefix.len()..]
    }

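    /// Whether the grammar can complete at the current position. The result
    /// is cached until the next token is consumed; pending forced bytes mean
    /// we are not accepting yet.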
    pub fn is_accepting(&mut self) -> bool {
        if let Some(acc) = self.is_accepting_cache {
            acc
        } else {
            let r = !self.has_ff_bytes() && self.parser.is_accepting();
            self.is_accepting_cache = Some(r);
            r
        }
    }

    pub fn bytes_since(&self, mut idx: usize) -> &[u8] {
        idx += self.grm_prefix.len();
        let endp = std::cmp::min(self.llm_bytes.len(), self.parser.hidden_start());
        if idx >= self.llm_bytes.len() || idx >= endp {
            return &[];
        }
        &self.llm_bytes[idx..endp]
    }

    pub fn start_without_prompt(&mut self) {
        infoln!(
            self,
            "initial lexer cost: {} (no prompt)",
            self.parser.lexer_stats()
        );

        assert!(self.is_fresh);
        self.is_fresh = false;
    }

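    /// Drop trailing tokens whose byte boundaries the grammar might still
    /// re-tokenize differently, returning the kept tokens and the number of
    /// bytes chopped off.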
    fn tokenize_and_chop(
        &mut self,
        mut tokens: Vec<TokenId>,
        num_fixed: usize,
    ) -> (Vec<TokenId>, usize) {
        let trie = self.token_env.tok_trie();
        let (chop_tokens, chop_bytes) = self
            .parser
            .with_recognizer(|r| trie.chop_tokens(r, &tokens[num_fixed..]));
        infoln!(
            self,
            "tokenize -> {}; chop: {} tokens, {} bytes",
            trie.tokens_dbg(&tokens),
            chop_tokens,
            chop_bytes
        );
        tokens.truncate(tokens.len() - chop_tokens);
        (tokens, chop_bytes)
    }

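    /// Append the grammar's forced bytes to the prompt, re-tokenize the
    /// result, and return the tokens the model should actually be fed;
    /// requires canonical tokenization.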
    pub fn process_prompt(&mut self, prompt: Vec<TokenId>) -> Vec<TokenId> {
        infoln!(self, "initial lexer cost: {}", self.parser.lexer_stats());

        assert!(self.token_env.tokenize_is_canonical());
        assert!(self.is_fresh);
        self.is_fresh = false;

        assert!(self.llm_tokens.is_empty());

        let trie = self.token_env.tok_trie();
        infoln!(self, "prompt: {}", trie.tokens_dbg(&prompt));
        let mut prompt_bytes = trie.decode_raw(&prompt);
        if self.can_force_bytes() {
            self.parser.force_bytes();
        }
        let grm_bytes = self.parser.get_bytes().to_vec();
        prompt_bytes.extend_from_slice(&grm_bytes);

        let (tokens, num_fixed) = self.token_env.tokenize_bytes_marker(&prompt_bytes);
        let (res_prompt, chop_bytes) = self.tokenize_and_chop(tokens, num_fixed);

        let trie = self.token_env.tok_trie();
        infoln!(
            self,
            "prompt+grm: {} {}",
            trie.tokens_dbg(&res_prompt),
            self.parser.grammar().lexer_spec().no_forcing
        );

        if chop_bytes <= grm_bytes.len() {
            self.llm_bytes = grm_bytes[0..grm_bytes.len() - chop_bytes].to_vec();
            self.llm_tokens = self.token_env.tokenize_bytes_marker(&self.llm_bytes).0;
            self.parser.apply_forced(self.llm_bytes.len());
            let decoded = self.tok_trie().decode_raw(&self.llm_tokens);
            if !self.llm_bytes.is_empty()
                && !decoded.is_empty()
                && decoded[1..] == self.llm_bytes
                && decoded[0] == b' '
            {
                infoln!(self, "applying <s>space hack");
                self.grm_prefix = decoded[0..1].to_vec();
                self.llm_bytes = decoded;
            }
            infoln!(self, "ini_tokens: {}", trie.tokens_dbg(&self.llm_tokens));
        } else {
            self.grm_prefix = prompt_bytes
                [prompt_bytes.len() - chop_bytes..prompt_bytes.len() - grm_bytes.len()]
                .to_vec();
            infoln!(
                self,
                "force_prefix: {:?}",
                String::from_utf8_lossy(&self.grm_prefix)
            );
        }

        infoln!(self, "res_prompt: {}", trie.tokens_dbg(&res_prompt));
        res_prompt
    }

    pub fn augment_err(&self, e: impl Display) -> String {
        if self.limits.verbose_errors {
            format!(
                "{e}\n<state>\n{}\n</state><grammar>\n{}\n</grammar>",
                self.dump_state(),
                self.dbg_grammar
            )
        } else {
            format!("{e}\n<non-verbose/>")
        }
    }

    pub fn dump_state(&self) -> String {
        format!(
            "Tokens: {}\n{} tokens, {} bytes; grm_prefix: {:?}\nFlags:{}{}\nParser: {}\nStop: {}\nError: {}",
            self.tok_trie().tokens_dbg(&self.llm_tokens),
            self.llm_tokens.len(),
            self.llm_bytes.len(),
            String::from_utf8_lossy(&self.grm_prefix),
            if self.had_backtrack {
                " had_backtrack"
            } else {
                ""
            },
            if self.had_rollback {
                " had_rollback"
            } else {
                ""
            },
            self.parser.stats(),
            self.stop_reason,
            self.error_message.as_deref().unwrap_or("None"),
        )
    }

    fn clear_caches(&mut self) {
        self.is_accepting_cache = None;
        self.ff_tokens_cache = None;
    }

    fn stop(&mut self, warn: &str, reason: StopReason) -> anyhow::Error {
        if !warn.is_empty() {
            self.error_message = Some(warn.to_string());
            warn!(self, "{}; stopping", warn);
        }
        self.stop_reason = reason;
        self.anyhow_error()
    }

    fn tok_trie(&self) -> &toktrie::TokTrie {
        self.token_env.tok_trie()
    }

    pub fn error_message(&self) -> Option<String> {
        self.error_message.clone()
    }

    fn check_initialized(&self, lbl: &str) -> Result<()> {
        ensure!(!self.is_fresh, "process_prompt() not called in {}", lbl);
        ensure!(
            !self.stopped(),
            "parser stopped in {}; {}",
            lbl,
            self.error_message()
                .unwrap_or_else(|| "no error message".to_string())
        );
        Ok(())
    }

    pub fn validate_token(&mut self, token: TokenId) -> Result<bool> {
        if self.stopped() {
            return Ok(false);
        }
        self.check_initialized("validate_token")?;
        self.validate_tokens_raw(&[token]).map(|n| n > 0)
    }

    pub fn reset(&mut self) -> Result<()> {
        self.rollback(self.llm_tokens.len())
    }

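    /// Undo the last `n_tokens` tokens: rewinds the parser by the matching
    /// number of bytes, restores the token budget, and clears a previous
    /// non-error stop.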
    pub fn rollback(&mut self, n_tokens: usize) -> Result<()> {
        if n_tokens == 0 {
            return Ok(());
        }

        ensure!(
            n_tokens <= self.llm_tokens.len(),
            "rollback: {} > {}",
            n_tokens,
            self.llm_tokens.len()
        );

        if self.stop_reason.is_ok() {
            self.stop_reason = StopReason::NotStopped;
        }

        self.check_initialized("rollback")?;

        self.had_rollback = true;

        let new_len = self.llm_tokens.len() - n_tokens;
        let mut bytes_to_drop = 0;
        for tok in &self.llm_tokens[new_len..] {
            // the EOS token contributes no bytes to llm_bytes
            if *tok != self.eos_token {
                bytes_to_drop += self.tok_trie().token_len(*tok);
            }
        }
        ensure!(
            bytes_to_drop <= self.llm_bytes.len(),
            "rollback bytes: {} > {}",
            bytes_to_drop,
            self.llm_bytes.len()
        );

        self.parser.rollback(bytes_to_drop)?;

        self.max_tokens_total = self.max_tokens_total.saturating_add(n_tokens);
        self.llm_tokens.truncate(new_len);
        self.llm_bytes
            .truncate(self.llm_bytes.len() - bytes_to_drop);
        self.clear_caches();

        Ok(())
    }

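    /// Return how many tokens from the front of `tokens` the grammar would
    /// accept, without actually consuming any of them.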
    pub fn validate_tokens_raw(&mut self, tokens: &[TokenId]) -> Result<usize> {
        if self.stopped() {
            return Ok(0);
        }
        self.check_initialized("validate_tokens_raw")?;

        if tokens.is_empty() {
            return Ok(0);
        }

        let n_vocab = self.tok_trie().vocab_size();
        for &t in tokens {
            if t as usize >= n_vocab {
                return Err(self.stop(
                    &format!("token id {t} out of range"),
                    StopReason::InternalError,
                ));
            }
        }

        let n_valid = self.parser.validate_tokens(tokens);
        Ok(n_valid)
    }

    fn anyhow_error(&self) -> anyhow::Error {
        anyhow::anyhow!(self
            .error_message
            .clone()
            .unwrap_or_else(|| self.stop_reason.to_string()))
    }

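    /// Compute the mask of tokens allowed at the next sampling step,
    /// recording the elapsed time in the factory's perf counters.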
    pub fn compute_mask(&mut self) -> Result<SimpleVob> {
        self.compute_mask_start_time = Instant::now();
        let r = self.compute_mask_inner();
        self.parser
            .perf_counters()
            .compute_mask
            .record(self.compute_mask_start_time.elapsed());
        r
    }

    fn compute_mask_inner(&mut self) -> Result<SimpleVob> {
        self.check_initialized("compute_mask")?;

        infoln!(self, "compute_mask");

        let prefix = if self.can_force_bytes() {
            let (ff_tokens, token_prefix) = self
                .ff_tokens_cache
                .take()
                .unwrap_or_else(|| self.ff_tokens());
            if !ff_tokens.is_empty() {
                let t = ff_tokens[0];
                infoln!(self, "forcing ff_token by mask: {}", t);
                let mask = self.tok_trie().singleton_token_set(t);
                self.last_step_stats = ParserStats::default();
                return Ok(mask);
            } else {
                token_prefix
            }
        } else {
            let mut trg = Vec::new();
            self.compute_ff_bytes_to(&mut trg);
            trg
        };

        let mut allowed_tokens = self.compute_bias(&prefix);

        if let Some(s) = self.parser.get_error() {
            return Err(self.stop_for_parser_error("", s));
        }

        if self.eos_token != INVALID_TOKEN && self.is_accepting() {
            allowed_tokens.allow_token(self.eos_token);
        }

        self.log_final(&prefix, &allowed_tokens);

        if allowed_tokens.is_zero() {
            infoln!(self, "no tokens allowed, stopping");
            return Err(self.stop("", StopReason::NoExtensionBias));
        }

        Ok(allowed_tokens)
    }

    fn stop_for_parser_error(&mut self, pref: &str, err: ParserError) -> anyhow::Error {
        self.stop(&format!("{}{}", pref, err.message()), err.stop_reason())
    }

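    /// Feed one token's bytes into the parser, routing them through any
    /// pending `grm_prefix` first. Returns the number of tokens to backtrack
    /// (0 in the common case).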
    fn apply_token(&mut self, tok_id: TokenId) -> Result<usize> {
        self.clear_caches();

        let trie = self.token_env.tok_trie();

        if (tok_id as usize) >= trie.vocab_size() {
            return Err(self.stop(
                &format!("token id {tok_id} out of range"),
                StopReason::InternalError,
            ));
        }

        self.llm_tokens.push(tok_id);

        let tok_bytes = trie.decode_raw(&[tok_id]);

        let prefix_len = self.grm_prefix.len().saturating_sub(self.llm_bytes.len());

        infoln!(
            self,
            "consume_token: {} {} prefix={}",
            tok_id,
            trie.token_dbg(tok_id),
            prefix_len
        );

        let tok_bytes = if prefix_len > 0 {
            let to_apply = &tok_bytes[0..std::cmp::min(tok_bytes.len(), prefix_len)];
            self.llm_bytes.extend_from_slice(to_apply);

            if self.grm_prefix[0..self.llm_bytes.len()] != self.llm_bytes {
                return Err(self.stop(
                    &format!(
                        "prefix mismatch: applying {:?}; {:?} vs {:?}",
                        String::from_utf8_lossy(to_apply),
                        String::from_utf8_lossy(&self.grm_prefix),
                        String::from_utf8_lossy(&self.llm_bytes)
                    ),
                    StopReason::InternalError,
                ));
            }

            if prefix_len < tok_bytes.len() {
                &tok_bytes[prefix_len..]
            } else {
                return Ok(0);
            }
        } else {
            &tok_bytes
        };

        if let Some(err) = self.parser.get_error() {
            return Err(self.stop_for_parser_error("", err));
        }

        match self.parser.apply_token(tok_bytes, tok_id) {
            Err(e) => {
                return Err(self.stop(
                    &format!("Parser Error: {e}"),
                    StopReason::ParserTooComplex,
                ));
            }
            Ok(backtrack_bytes0) => {
                self.llm_bytes.extend_from_slice(tok_bytes);

                if backtrack_bytes0 != 0 {
                    self.had_backtrack = true;
                    let mut backtrack_bytes: isize = backtrack_bytes0.try_into().unwrap();
                    let mut backtrack_tokens = 0;
                    while backtrack_bytes > 0 {
                        let tok_off = self.llm_tokens.len() - backtrack_tokens;
                        if tok_off == 0 {
                            break;
                        }
                        let tok = self.llm_tokens[tok_off - 1];
                        backtrack_bytes -= trie.token_len(tok) as isize;
                        backtrack_tokens += 1;
                    }
                    assert!(backtrack_tokens > 0);
                    let additional_backtrack_bytes: usize =
                        (-backtrack_bytes).try_into().unwrap();
                    let full_backtrack_bytes = backtrack_bytes0 + additional_backtrack_bytes;

                    let byte_ptr = self.llm_bytes.len() - full_backtrack_bytes;
                    infoln!(
                        self,
                        "backtrack: {} tokens / {}+{} bytes (deletes: {:?})",
                        backtrack_tokens,
                        backtrack_bytes0,
                        additional_backtrack_bytes,
                        String::from_utf8_lossy(&self.llm_bytes[byte_ptr..])
                    );
                    self.llm_bytes.truncate(byte_ptr);

                    let token_ptr = self.llm_tokens.len() - backtrack_tokens;
                    if !self.inference_caps.backtrack {
                        warn!(
                            self,
                            "can't backtrack over {}; this may confuse the model",
                            trie.tokens_dbg(&self.llm_tokens[token_ptr..])
                        );
                        backtrack_tokens = 0;
                    } else {
                        self.parser
                            .additional_backtrack(additional_backtrack_bytes);
                    }
                    self.llm_tokens.truncate(token_ptr);
                    return Ok(backtrack_tokens);
                }
            }
        }

        Ok(0)
    }

    fn pending_grm_prefix(&self) -> &[u8] {
        &self.grm_prefix[std::cmp::min(self.grm_prefix.len(), self.llm_bytes.len())..]
    }

    fn has_ff_bytes(&self) -> bool {
        !self.pending_grm_prefix().is_empty() || !self.parser.currently_forced_bytes().is_empty()
    }

    fn can_force_bytes(&self) -> bool {
        !self.parser.grammar().lexer_spec().no_forcing && self.token_env.tokenize_is_canonical()
    }

    pub fn force_bytes(&mut self) -> Vec<u8> {
        self.parser.force_bytes();
        let mut trg = Vec::new();
        self.compute_ff_bytes_inner(&mut trg);
        trg
    }

    fn compute_ff_bytes_to(&mut self, trg: &mut Vec<u8>) {
        if self.can_force_bytes() {
            self.parser.force_bytes();
        }
        self.compute_ff_bytes_inner(trg);
    }

    fn compute_ff_bytes_inner(&mut self, trg: &mut Vec<u8>) {
        if self.llm_bytes.len() < self.grm_prefix.len() {
            let inject = &self.grm_prefix[self.llm_bytes.len()..];
            trg.extend_from_slice(inject);
            infoln!(
                self,
                "injecting prefix: {:?}",
                String::from_utf8_lossy(inject)
            );
        }

        trg.extend_from_slice(self.parser.currently_forced_bytes());
    }

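    /// Compute fast-forward tokens by re-tokenizing the forced bytes,
    /// prefixed with the last generated token for context. Returns the new
    /// tokens plus any trailing bytes that did not tokenize cleanly
    /// (the "token prefix").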
    fn ff_tokens(&mut self) -> (Vec<TokenId>, Vec<u8>) {
        let mut forced_bytes = Vec::new();
        let mut existing_tokens = if self.llm_tokens.is_empty() {
            Vec::new()
        } else {
            let r = self.llm_tokens[self.llm_tokens.len() - 1..].to_vec();
            let trie = self.token_env.tok_trie();
            forced_bytes = trie.decode_raw(&r);
            r
        };
        let num_existing_bytes = forced_bytes.len();

        self.compute_ff_bytes_to(&mut forced_bytes);

        let mut token_prefix = Vec::new();

        let do_force =
            forced_bytes.len() > num_existing_bytes && self.token_env.tokenize_is_canonical();
        if do_force {
            let t0 = Instant::now();
            let (mut tokens, mut num_fixed) = self.token_env.tokenize_bytes_marker(&forced_bytes);
            if !tokens.starts_with(&existing_tokens) {
                let trie = self.token_env.tok_trie();
                infoln!(
                    self,
                    "re-tokenizing without prefix: {}; because we got {}",
                    trie.tokens_dbg(&existing_tokens),
                    trie.tokens_dbg(&tokens),
                );
                (tokens, num_fixed) = self
                    .token_env
                    .tokenize_bytes_marker(&forced_bytes[num_existing_bytes..]);
                infoln!(
                    self,
                    "re-tokenized: {} from: {:?}",
                    trie.tokens_dbg(&tokens),
                    &forced_bytes[num_existing_bytes..]
                );
                existing_tokens.clear();
            } else {
                num_fixed = std::cmp::max(existing_tokens.len(), num_fixed);
            }

            let (mut grm_tokens, chop_bytes) = self.tokenize_and_chop(tokens, num_fixed);
            assert!(grm_tokens.starts_with(&existing_tokens));
            grm_tokens.drain(..existing_tokens.len());

            let trie = self.token_env.tok_trie();
            infoln!(
                self,
                "forced: {} bytes:{:?} tokens:{:?}",
                trie.tokens_dbg(&grm_tokens),
                &forced_bytes[num_existing_bytes..],
                grm_tokens
            );
            token_prefix = forced_bytes[forced_bytes.len() - chop_bytes..].to_vec();

            self.parser.perf_counters().tokenize_ff.record(t0.elapsed());

            if !grm_tokens.is_empty() {
                infoln!(
                    self,
                    "fixed_tokens: {}; prefix len {}",
                    trie.tokens_dbg(&grm_tokens),
                    token_prefix.len()
                );
                return (grm_tokens, token_prefix);
            } else {
                infoln!(self, "no fixed tokens; prefix len {}", token_prefix.len());
            }
        } else if forced_bytes.len() > num_existing_bytes {
            infoln!(self, "not-forcing {} bytes", forced_bytes.len());
            token_prefix = forced_bytes[num_existing_bytes..].to_vec();
        }

        (Vec::new(), token_prefix)
    }

    fn compute_bias(&mut self, token_prefix: &[u8]) -> SimpleVob {
        let pre_stats = self.parser.stats().clone();
        let set = self.parser.compute_bias(&*self.bias_computer, token_prefix);
        let p_stats = self.parser.stats().delta(&pre_stats);
        self.last_bias_time = Duration::from_micros(p_stats.compute_time_us);
        self.last_step_stats = p_stats.clone();
        self.max_step_stats = self.max_step_stats.max(&p_stats);
        set
    }

    fn log_final(&mut self, token_prefix: &[u8], allowed_tokens: &SimpleVob) {
        infoln!(
            self,
            "step-stats: {}us; {} lex fuel; {} items; {}",
            self.compute_mask_start_time.elapsed().as_micros(),
            self.last_step_stats.lexer_cost,
            self.last_step_stats.all_items,
            self.parser.lexer_stats(),
        );

        infoln!(
            self,
            "bias: (pref: {:?}; accpt: {}; temp: {:.3}) {}",
            String::from_utf8_lossy(token_prefix),
            self.is_accepting_cache.unwrap(),
            self.parser.temperature().unwrap_or(0.0),
            self.token_env.tok_trie().token_set_dbg(allowed_tokens)
        );
    }

    pub fn temperature(&self) -> Option<f32> {
        self.parser.temperature()
    }

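    /// Advance the parser by one sampled token. Returns the number of tokens
    /// the caller must remove from the end of its sequence (0 means the token
    /// was consumed normally).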
    pub fn consume_token(&mut self, token: TokenId) -> Result<usize> {
        self.check_initialized("consume_token")?;

        if self.max_tokens_total == 0 {
            return Err(self.stop("max_tokens_total reached", StopReason::MaxTokensTotal));
        }
        self.max_tokens_total -= 1;

        if token == self.eos_token {
            if self.parser.scan_eos() {
                infoln!(self, "consume_token: scanned eos_token");
                return Ok(0);
            } else {
                let accepting = self.is_accepting();
                infoln!(
                    self,
                    "consume_token: eos_token not eaten by parser; accept={}",
                    accepting
                );
                if accepting {
                    self.llm_tokens.push(token);
                    return Ok(0);
                }
            }
        }

        let apply_res = self.apply_token(token);
        self.parser.log_row_infos("post-apply");
        match apply_res {
            Err(_) => Err(self.anyhow_error()),
            Ok(n) => Ok(n),
        }
    }

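    /// Check whether generation should stop: true when the parser is in an
    /// accepting state and either cannot advance any further or the last
    /// token was EOS.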
    pub fn check_stop(&mut self) -> Result<bool> {
        let empty_token_prefix = !self.has_ff_bytes();
        let pending_eos = self.llm_tokens.last() == Some(&self.eos_token);
        let lexer_bytes = self.parser.has_pending_lexeme_bytes();
        let is_accepting = self.is_accepting();
        let can_advance = self.parser.can_advance();
        let parser_done = is_accepting && (!can_advance || pending_eos);
        infoln!(
            self,
            "parser_done: {parser_done}; lexer_bytes: {lexer_bytes}; \
             can_advance: {can_advance} (eos:{pending_eos}); \
             accept: {is_accepting}; \
             empty_token_prefix: {empty_token_prefix}"
        );
        assert!(!is_accepting || empty_token_prefix);

        if parser_done {
            infoln!(
                self,
                "only eos token allowed, stopping; accepting: {}",
                is_accepting
            );
            let reason = if pending_eos {
                StopReason::EndOfSentence
            } else {
                StopReason::NoExtension
            };
            self.stop("", reason);
            Ok(true)
        } else {
            Ok(false)
        }
    }

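    /// Compute the tokens the grammar forces next, if any; when byte-forcing
    /// is enabled, the result is cached for the following `compute_mask`.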
    pub fn compute_ff_tokens(&mut self) -> Vec<TokenId> {
        let r = self.ff_tokens();
        if self.can_force_bytes() {
            self.ff_tokens_cache = Some(r.clone());
        }
        r.0
    }

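    /// Compute fast-forward tokens and immediately consume them, failing if
    /// any of them would require backtracking.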
    pub fn consume_ff_tokens(&mut self) -> Result<Vec<TokenId>> {
        let ff_tokens = self.compute_ff_tokens();
        for &t in &ff_tokens {
            let num_backtrack = self.consume_token(t)?;
            if num_backtrack > 0 {
                return Err(self.stop(
                    &format!("backtrack required after ff_token: {t}"),
                    StopReason::InternalError,
                ));
            }
        }
        Ok(ff_tokens)
    }

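    /// Illustrates the intended calling sequence: `process_prompt`, then in a
    /// loop `compute_mask`, `consume_token` (handling backtracking),
    /// `check_stop`, and `consume_ff_tokens`. `black_box` stands in for the
    /// inference engine, and the sampled token is a placeholder.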
    #[allow(dead_code)]
    fn typical_use(&mut self, prompt: Vec<TokenId>) -> Result<()> {
        let new_prompt = self.process_prompt(prompt);

        // the inference engine would feed new_prompt to the model here
        black_box(new_prompt);

        let mut tokens = vec![];

        loop {
            let temp = self.temperature();
            let mask = self.compute_mask()?;

            // the inference engine would sample a token from mask at temp;
            // we use a placeholder
            black_box((temp, mask));
            let sampled_token = 42;

            let num_backtrack = self.consume_token(sampled_token)?;

            if num_backtrack == 0 {
                tokens.push(sampled_token);
            } else {
                assert!(self.inference_caps.backtrack);
                if num_backtrack == 1 {
                    // just don't add the sampled token
                } else if num_backtrack > 1 {
                    // the backtrack count includes the (un-added) sampled
                    // token, so remove num_backtrack - 1 earlier tokens
                    tokens.truncate(tokens.len() - (num_backtrack - 1));
                }
            }

            if self.check_stop()? {
                break;
            }

            let forced = self.consume_ff_tokens()?;
            tokens.extend_from_slice(&forced);
        }

        Ok(())
    }
}