use std::{hint::black_box, panic::AssertUnwindSafe, sync::Arc, time::Duration};

use crate::{
    api::{GrammarInit, ParserLimits, StopReason, TopLevelGrammar},
    earley::{BiasComputer, DefaultBiasComputer, Parser, ParserError, ParserStats},
    infoln, panic_utils, warn, Instant, Logger,
};
use anyhow::{ensure, Result};
use toktrie::{InferenceCapabilities, SimpleVob, TokEnv, TokenId, INVALID_TOKEN};

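/// Drives a grammar [`Parser`] over LLM tokens: it translates token ids to
/// bytes, computes per-step token masks, and tracks forced ("fast-forward")
/// bytes and tokens.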
#[derive(Clone)]
pub struct TokenParser {
    pub token_env: TokEnv,
    pub parser: Parser,
    pub compute_mask_start_time: Instant,
    pub last_bias_time: Duration,
    pub inference_caps: InferenceCapabilities,
    pub logger: Logger,
    pub limits: ParserLimits,
    pub bias_computer: Arc<dyn BiasComputer>,
    last_step_stats: ParserStats,
    max_step_stats: ParserStats,
    eos_token: TokenId,

    is_accepting_cache: Option<bool>,
    ff_tokens_cache: Option<(Vec<TokenId>, Vec<u8>)>,
    stop_reason: StopReason,
    error_message: Option<String>,
    max_tokens_total: usize,

    llm_tokens: Vec<TokenId>,
    llm_bytes: Vec<u8>,

    grm_prefix: Vec<u8>,
    is_fresh: bool,
}

impl TokenParser {
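    /// Builds a `TokenParser` from a `GrammarInit`, catching panics during
    /// grammar compilation and reporting them as errors.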
    pub fn from_init(
        token_env: TokEnv,
        grammar_init: GrammarInit,
        logger: Logger,
        inference_caps: InferenceCapabilities,
        limits: ParserLimits,
        extra_lexemes: Vec<String>,
    ) -> Result<Self> {
        panic_utils::catch_unwind(AssertUnwindSafe(|| {
            Self::init_inner(
                grammar_init,
                token_env,
                logger,
                inference_caps,
                limits,
                extra_lexemes,
            )
        }))
    }

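    /// Convenience wrapper around `from_init()` for an already-serialized
    /// top-level grammar.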
    pub fn from_grammar(
        token_env: TokEnv,
        top_grammar: TopLevelGrammar,
        logger: Logger,
        inference_caps: InferenceCapabilities,
        limits: ParserLimits,
        extra_lexemes: Vec<String>,
    ) -> Result<Self> {
        Self::from_init(
            token_env,
            GrammarInit::Serialized(top_grammar),
            logger,
            inference_caps,
            limits,
            extra_lexemes,
        )
    }

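    /// Checks capability prerequisites, compiles the grammar, and builds the
    /// initial parser state; `max_tokens` defaults to `usize::MAX` unless the
    /// serialized grammar specifies a limit.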
    fn init_inner(
        grammar_init: GrammarInit,
        token_env: TokEnv,
        mut logger: Logger,
        inference_caps: InferenceCapabilities,
        limits: ParserLimits,
        extra_lexemes: Vec<String>,
    ) -> Result<Self> {
        ensure!(
            token_env.tokenize_is_canonical() || !inference_caps.ff_tokens,
            "ff_tokens requires canonical tokenization"
        );
        ensure!(
            !inference_caps.backtrack || inference_caps.ff_tokens,
            "backtrack requires ff_tokens"
        );

        let compute_mask_start_time = Instant::now();
        let mut max_tokens = usize::MAX;
        if let GrammarInit::Serialized(input) = &grammar_init {
            if let Some(m) = input.max_tokens {
                max_tokens = m;
            }
        }
        let compiled_grammar = grammar_init.to_cgrammar(
            Some(token_env.clone()),
            &mut logger,
            limits.clone(),
            extra_lexemes,
        )?;
        let parser = Parser::new(token_env.clone(), compiled_grammar, limits.clone())?;
        let eos_token = token_env.tok_trie().eos_token();

        Ok(TokenParser {
            bias_computer: Arc::new(DefaultBiasComputer::new(token_env.clone())),
            logger,
            token_env,
            inference_caps,
            limits,
            max_step_stats: ParserStats::default(),
            last_step_stats: ParserStats::default(),
            compute_mask_start_time,
            is_accepting_cache: None,
            ff_tokens_cache: None,
            stop_reason: StopReason::NotStopped,
            error_message: None,
            parser,
            eos_token,
            llm_tokens: Vec::new(),
            llm_bytes: Vec::new(),
            grm_prefix: Vec::new(),
            max_tokens_total: max_tokens,
            last_bias_time: Duration::from_secs(0),
            is_fresh: true,
        })
    }

    pub fn get_capture(&self, name: &str) -> Option<&[u8]> {
        self.parser.get_capture(name)
    }

    pub fn deep_clone(&self) -> Self {
        let mut copy = self.clone();
        copy.parser = self.parser.deep_clone();
        copy
    }

    pub fn stop_reason(&self) -> StopReason {
        self.stop_reason
    }

    pub fn is_fresh(&self) -> bool {
        self.is_fresh
    }

    pub fn parser_stats(&self) -> &ParserStats {
        self.parser.stats()
    }

    pub fn last_step_stats(&self) -> &ParserStats {
        &self.last_step_stats
    }

    pub fn max_step_stats(&self) -> &ParserStats {
        &self.max_step_stats
    }

    pub fn num_tokens(&self) -> usize {
        self.llm_tokens.len()
    }

    pub fn final_bytes(&self) -> &[u8] {
        &self.llm_bytes[self.grm_prefix.len()..]
    }

    pub fn is_accepting(&mut self) -> bool {
        if let Some(acc) = self.is_accepting_cache {
            acc
        } else {
            let r = !self.has_ff_bytes() && self.parser.is_accepting();
            self.is_accepting_cache = Some(r);
            r
        }
    }

    pub fn bytes_since(&self, mut idx: usize) -> &[u8] {
        idx += self.grm_prefix.len();
        let endp = std::cmp::min(self.llm_bytes.len(), self.parser.hidden_start());
        if idx >= self.llm_bytes.len() || idx >= endp {
            return &[];
        }
        &self.llm_bytes[idx..endp]
    }

    pub fn start_without_prompt(&mut self) {
        infoln!(
            self,
            "initial lexer cost: {} (no prompt)",
            self.parser.lexer_stats()
        );

        assert!(self.is_fresh);
        self.is_fresh = false;
    }

    fn tokenize_and_chop(
        &mut self,
        mut tokens: Vec<TokenId>,
        num_fixed: usize,
    ) -> (Vec<TokenId>, usize) {
        let trie = self.token_env.tok_trie();
        let (chop_tokens, chop_bytes) = self
            .parser
            .with_recognizer(|r| trie.chop_tokens(r, &tokens[num_fixed..]));
        infoln!(
            self,
            "tokenize -> {}; chop: {} tokens, {} bytes",
            trie.tokens_dbg(&tokens),
            chop_tokens,
            chop_bytes
        );
        tokens.truncate(tokens.len() - chop_tokens);
        (tokens, chop_bytes)
    }

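    /// Appends grammar-forced bytes to the prompt, re-tokenizes, and returns
    /// the prompt that should actually be passed to the model. Bytes chopped
    /// at the token boundary are stored in `grm_prefix` and re-injected as
    /// tokens arrive.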
    pub fn process_prompt(&mut self, prompt: Vec<TokenId>) -> Vec<TokenId> {
        infoln!(self, "initial lexer cost: {}", self.parser.lexer_stats());

        assert!(self.token_env.tokenize_is_canonical());
        assert!(self.is_fresh);
        self.is_fresh = false;

        assert!(self.llm_tokens.is_empty());

        let trie = self.token_env.tok_trie();
        infoln!(self, "prompt: {}", trie.tokens_dbg(&prompt));
        let mut prompt_bytes = trie.decode_raw(&prompt);
        if self.can_force_bytes() {
            self.parser.force_bytes();
        }
        let grm_bytes = self.parser.get_bytes().to_vec();
        prompt_bytes.extend_from_slice(&grm_bytes);

        let (tokens, num_fixed) = self.token_env.tokenize_bytes_marker(&prompt_bytes);
        let (res_prompt, chop_bytes) = self.tokenize_and_chop(tokens, num_fixed);

        let trie = self.token_env.tok_trie();
        infoln!(
            self,
            "prompt+grm: {} {}",
            trie.tokens_dbg(&res_prompt),
            self.parser.grammar().lexer_spec().no_forcing
        );

        if chop_bytes <= grm_bytes.len() {
            self.llm_bytes = grm_bytes[0..grm_bytes.len() - chop_bytes].to_vec();
            self.llm_tokens = self.token_env.tokenize_bytes_marker(&self.llm_bytes).0;
            self.parser.apply_forced(self.llm_bytes.len());
            let decoded = self.tok_trie().decode_raw(&self.llm_tokens);
            // detect a tokenizer-inserted leading space and fold it into
            // grm_prefix so byte accounting stays consistent
            if !self.llm_bytes.is_empty()
                && !decoded.is_empty()
                && decoded[1..] == self.llm_bytes
                && decoded[0] == b' '
            {
                infoln!(self, "applying <s>space hack");
                self.grm_prefix = decoded[0..1].to_vec();
                self.llm_bytes = decoded;
            }
            infoln!(self, "ini_tokens: {}", trie.tokens_dbg(&self.llm_tokens));
        } else {
            // chopping reached into the prompt itself; remember those prompt
            // bytes so they are re-emitted ahead of the grammar bytes
            self.grm_prefix = prompt_bytes
                [prompt_bytes.len() - chop_bytes..prompt_bytes.len() - grm_bytes.len()]
                .to_vec();
            infoln!(
                self,
                "force_prefix: {:?}",
                String::from_utf8_lossy(&self.grm_prefix)
            );
        }

        infoln!(self, "res_prompt: {}", trie.tokens_dbg(&res_prompt));
        res_prompt
    }

    fn clear_caches(&mut self) {
        self.is_accepting_cache = None;
        self.ff_tokens_cache = None;
    }

    fn stop(&mut self, warn: &str, reason: StopReason) -> anyhow::Error {
        if !warn.is_empty() {
            self.error_message = Some(warn.to_string());
            warn!(self, "{}; stopping", warn);
        }
        self.stop_reason = reason;
        self.anyhow_error()
    }

    fn tok_trie(&self) -> &toktrie::TokTrie {
        self.token_env.tok_trie()
    }

    pub fn error_message(&self) -> Option<String> {
        self.error_message.clone()
    }

    fn check_initialized(&self, lbl: &str) -> Result<()> {
        ensure!(!self.is_fresh, "process_prompt() not called in {}", lbl);
        ensure!(
            self.stop_reason == StopReason::NotStopped,
            "parser stopped in {}; {}",
            lbl,
            self.error_message()
                .unwrap_or("no error message".to_string())
        );
        Ok(())
    }

    pub fn validate_token(&mut self, token: TokenId) -> Result<bool> {
        self.check_initialized("validate_token")?;
        self.validate_tokens_raw(&[token]).map(|n| n > 0)
    }

    pub fn reset(&mut self) -> Result<()> {
        self.rollback(self.llm_tokens.len())
    }

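    /// Removes the last `n_tokens` tokens from the parser state, rolling back
    /// the corresponding bytes (EOS contributes none) and restoring the token
    /// budget.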
    pub fn rollback(&mut self, n_tokens: usize) -> Result<()> {
        if n_tokens == 0 {
            return Ok(());
        }

        ensure!(
            n_tokens <= self.llm_tokens.len(),
            "rollback: {} > {}",
            n_tokens,
            self.llm_tokens.len()
        );

        if self.stop_reason.is_ok() {
            self.stop_reason = StopReason::NotStopped;
        }

        self.check_initialized("rollback")?;

        let new_len = self.llm_tokens.len() - n_tokens;
        let mut bytes_to_drop = 0;
        for tok in &self.llm_tokens[new_len..] {
            // EOS may have been recorded without contributing any bytes
            if *tok != self.eos_token {
                bytes_to_drop += self.tok_trie().token_len(*tok);
            }
        }
        ensure!(
            bytes_to_drop <= self.llm_bytes.len(),
            "rollback bytes: {} > {}",
            bytes_to_drop,
            self.llm_bytes.len()
        );

        self.parser.rollback(bytes_to_drop)?;

        self.max_tokens_total = self.max_tokens_total.saturating_add(n_tokens);
        self.llm_tokens.truncate(new_len);
        self.llm_bytes
            .truncate(self.llm_bytes.len() - bytes_to_drop);
        self.clear_caches();

        Ok(())
    }

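    /// Returns how many of the given tokens the parser could consume, without
    /// actually consuming them.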
    pub fn validate_tokens_raw(&mut self, tokens: &[TokenId]) -> Result<usize> {
        self.check_initialized("validate_tokens_raw")?;

        if tokens.is_empty() {
            return Ok(0);
        }

        let n_vocab = self.tok_trie().vocab_size();
        for &t in tokens {
            if t as usize >= n_vocab {
                return Err(self.stop(
                    &format!("token id {} out of range", t),
                    StopReason::InternalError,
                ));
            }
        }

        let n_valid = self.parser.validate_tokens(tokens);
        Ok(n_valid)
    }

    fn anyhow_error(&self) -> anyhow::Error {
        anyhow::anyhow!(self
            .error_message
            .clone()
            .unwrap_or(self.stop_reason.to_string()))
    }

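    /// Computes the set of tokens allowed next, recording timing in the
    /// parser's perf counters. If a fast-forward token is pending, the mask
    /// is a singleton set for that token.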
    pub fn compute_mask(&mut self) -> Result<SimpleVob> {
        self.compute_mask_start_time = Instant::now();
        let r = self.compute_mask_inner();
        self.parser
            .perf_counters()
            .compute_mask
            .record(self.compute_mask_start_time.elapsed());
        r
    }

    fn compute_mask_inner(&mut self) -> Result<SimpleVob> {
        self.check_initialized("compute_mask")?;

        infoln!(self, "compute_mask");

        let prefix = if self.can_force_bytes() {
            let (ff_tokens, token_prefix) = self
                .ff_tokens_cache
                .take()
                .unwrap_or_else(|| self.ff_tokens());
            if !ff_tokens.is_empty() {
                let t = ff_tokens[0];
                infoln!(self, "forcing ff_token by mask: {}", t);
                let mask = self.tok_trie().singleton_token_set(t);
                self.last_step_stats = ParserStats::default();
                return Ok(mask);
            } else {
                token_prefix
            }
        } else {
            let mut trg = Vec::new();
            self.compute_ff_bytes_to(&mut trg);
            trg
        };

        let mut allowed_tokens = self.compute_bias(&prefix);

        if let Some(s) = self.parser.get_error() {
            return Err(self.stop_for_parser_error("", s));
        }

        if self.eos_token != INVALID_TOKEN && self.is_accepting() {
            allowed_tokens.allow_token(self.eos_token);
        }

        self.log_final(&prefix, &allowed_tokens);

        if allowed_tokens.is_zero() {
            infoln!(self, "no tokens allowed, stopping");
            return Err(self.stop("", StopReason::NoExtensionBias));
        }

        Ok(allowed_tokens)
    }

    fn stop_for_parser_error(&mut self, pref: &str, err: ParserError) -> anyhow::Error {
        self.stop(&format!("{}{}", pref, err.message()), err.stop_reason())
    }

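    /// Applies a single token: any pending `grm_prefix` bytes are paid off
    /// first, the remaining bytes go to the parser, and the returned value is
    /// the number of tokens the caller must backtrack (0 on the normal path).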
    fn apply_token(&mut self, tok_id: TokenId) -> Result<usize> {
        self.clear_caches();

        let trie = self.token_env.tok_trie();

        if (tok_id as usize) >= trie.vocab_size() {
            return Err(self.stop(
                &format!("token id {} out of range", tok_id),
                StopReason::InternalError,
            ));
        }

        self.llm_tokens.push(tok_id);

        let tok_bytes = trie.decode_raw(&[tok_id]);

        // how many bytes of grm_prefix are still waiting to be covered
        let prefix_len = self.grm_prefix.len().saturating_sub(self.llm_bytes.len());

        infoln!(
            self,
            "consume_token: {} {} prefix={}",
            tok_id,
            trie.token_dbg(tok_id),
            prefix_len
        );

        let tok_bytes = if prefix_len > 0 {
            // the token first pays off the pending prefix bytes
            let to_apply = &tok_bytes[0..std::cmp::min(tok_bytes.len(), prefix_len)];
            self.llm_bytes.extend_from_slice(to_apply);

            if self.grm_prefix[0..self.llm_bytes.len()] != self.llm_bytes {
                return Err(self.stop(
                    &format!(
                        "prefix mismatch: applying {:?}; {:?} vs {:?}",
                        String::from_utf8_lossy(to_apply),
                        String::from_utf8_lossy(&self.grm_prefix),
                        String::from_utf8_lossy(&self.llm_bytes)
                    ),
                    StopReason::InternalError,
                ));
            }

            if prefix_len < tok_bytes.len() {
                // only the remainder goes to the parser
                &tok_bytes[prefix_len..]
            } else {
                // the whole token was consumed by the prefix
                return Ok(0);
            }
        } else {
            &tok_bytes
        };

        if let Some(err) = self.parser.get_error() {
            return Err(self.stop_for_parser_error("", err));
        }

        match self.parser.apply_token(tok_bytes) {
            Err(e) => {
                return Err(self.stop(
                    &format!("Parser Error: {}", e),
                    StopReason::ParserTooComplex,
                ));
            }
            Ok(backtrack_bytes0) => {
                self.llm_bytes.extend_from_slice(tok_bytes);

                if backtrack_bytes0 != 0 {
                    // convert the byte-level backtrack into whole tokens,
                    // walking tokens back until backtrack_bytes0 is covered
                    let mut backtrack_bytes: isize = backtrack_bytes0.try_into().unwrap();
                    let mut backtrack_tokens = 0;
                    while backtrack_bytes > 0 {
                        let tok_off = self.llm_tokens.len() - backtrack_tokens;
                        if tok_off == 0 {
                            break; // we can't backtrack any further
                        }
                        let tok = self.llm_tokens[tok_off - 1];
                        backtrack_bytes -= trie.token_len(tok) as isize;
                        backtrack_tokens += 1;
                    }
                    assert!(backtrack_tokens > 0);
                    // extra bytes dropped beyond what the parser asked for,
                    // since we can only backtrack whole tokens
                    let additional_backtrack_bytes: usize = (-backtrack_bytes).try_into().unwrap();
                    let full_backtrack_bytes = backtrack_bytes0 + additional_backtrack_bytes;

                    let byte_ptr = self.llm_bytes.len() - full_backtrack_bytes;
                    infoln!(
                        self,
                        "backtrack: {} tokens / {}+{} bytes (deletes: {:?})",
                        backtrack_tokens,
                        backtrack_bytes0,
                        additional_backtrack_bytes,
                        String::from_utf8_lossy(&self.llm_bytes[byte_ptr..])
                    );
                    self.llm_bytes.truncate(byte_ptr);

                    let token_ptr = self.llm_tokens.len() - backtrack_tokens;
                    if !self.inference_caps.backtrack {
                        warn!(
                            self,
                            "can't backtrack over {}; this may confuse the model",
                            trie.tokens_dbg(&self.llm_tokens[token_ptr..])
                        );
                        // pretend there's no backtrack
                        backtrack_tokens = 0;
                    } else {
                        // tell the parser about the extra bytes we dropped
                        self.parser
                            .additional_backtrack(additional_backtrack_bytes);
                    }
                    self.llm_tokens.truncate(token_ptr);
                    return Ok(backtrack_tokens);
                }
            }
        }

        Ok(0)
    }

    fn pending_grm_prefix(&self) -> &[u8] {
        &self.grm_prefix[std::cmp::min(self.grm_prefix.len(), self.llm_bytes.len())..]
    }

    fn has_ff_bytes(&self) -> bool {
        !self.pending_grm_prefix().is_empty() || !self.parser.currently_forced_bytes().is_empty()
    }

    fn can_force_bytes(&self) -> bool {
        !self.parser.grammar().lexer_spec().no_forcing && self.token_env.tokenize_is_canonical()
    }

    pub fn force_bytes(&mut self) -> Vec<u8> {
        self.parser.force_bytes();
        let mut trg = Vec::new();
        self.compute_ff_bytes_inner(&mut trg);
        trg
    }

    fn compute_ff_bytes_to(&mut self, trg: &mut Vec<u8>) {
        if self.can_force_bytes() {
            self.parser.force_bytes();
        }
        self.compute_ff_bytes_inner(trg);
    }

    fn compute_ff_bytes_inner(&mut self, trg: &mut Vec<u8>) {
        if self.llm_bytes.len() < self.grm_prefix.len() {
            let inject = &self.grm_prefix[self.llm_bytes.len()..];
            trg.extend_from_slice(inject);
            infoln!(
                self,
                "injecting prefix: {:?}",
                String::from_utf8_lossy(inject)
            );
        }

        trg.extend_from_slice(self.parser.currently_forced_bytes());
    }

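    /// Computes fast-forward tokens by re-tokenizing the forced bytes
    /// together with the last emitted token (token boundaries may shift).
    /// Returns the forced tokens and any leftover bytes that don't align
    /// with a token boundary.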
    fn ff_tokens(&mut self) -> (Vec<TokenId>, Vec<u8>) {
        let mut forced_bytes = Vec::new();
        // start from the last emitted token, so re-tokenization can merge
        // the forced bytes with it if needed
        let mut existing_tokens = if self.llm_tokens.is_empty() {
            Vec::new()
        } else {
            let r = self.llm_tokens[self.llm_tokens.len() - 1..].to_vec();
            let trie = self.token_env.tok_trie();
            forced_bytes = trie.decode_raw(&r);
            r
        };
        let num_existing_bytes = forced_bytes.len();

        self.compute_ff_bytes_to(&mut forced_bytes);

        let mut token_prefix = Vec::new();

        let do_force =
            forced_bytes.len() > num_existing_bytes && self.token_env.tokenize_is_canonical();
        if do_force {
            let t0 = Instant::now();
            let (mut tokens, mut num_fixed) = self.token_env.tokenize_bytes_marker(&forced_bytes);
            if !tokens.starts_with(&existing_tokens) {
                let trie = self.token_env.tok_trie();
                infoln!(
                    self,
                    "re-tokenizing without prefix: {}; because we got {}",
                    trie.tokens_dbg(&existing_tokens),
                    trie.tokens_dbg(&tokens),
                );
                (tokens, num_fixed) = self
                    .token_env
                    .tokenize_bytes_marker(&forced_bytes[num_existing_bytes..]);
                infoln!(
                    self,
                    "re-tokenized: {} from: {:?}",
                    trie.tokens_dbg(&tokens),
                    &forced_bytes[num_existing_bytes..]
                );
                existing_tokens.clear();
            } else {
                num_fixed = std::cmp::max(existing_tokens.len(), num_fixed);
            }

            let (mut grm_tokens, chop_bytes) = self.tokenize_and_chop(tokens, num_fixed);
            assert!(grm_tokens.starts_with(&existing_tokens));
            grm_tokens.drain(..existing_tokens.len());

            let trie = self.token_env.tok_trie();
            infoln!(
                self,
                "forced: {} bytes:{:?} tokens:{:?}",
                trie.tokens_dbg(&grm_tokens),
                &forced_bytes[num_existing_bytes..],
                grm_tokens
            );
            token_prefix = forced_bytes[forced_bytes.len() - chop_bytes..].to_vec();

            self.parser.perf_counters().tokenize_ff.record(t0.elapsed());

            if !grm_tokens.is_empty() {
                infoln!(
                    self,
                    "fixed_tokens: {}; prefix len {}",
                    trie.tokens_dbg(&grm_tokens),
                    token_prefix.len()
                );
                return (grm_tokens, token_prefix);
            } else {
                infoln!(self, "no fixed tokens; prefix len {}", token_prefix.len());
            }
        } else if forced_bytes.len() > num_existing_bytes {
            infoln!(self, "not-forcing {} bytes", forced_bytes.len());
            token_prefix = forced_bytes[num_existing_bytes..].to_vec();
        }

        (Vec::new(), token_prefix)
    }

    fn compute_bias(&mut self, token_prefix: &[u8]) -> SimpleVob {
        let pre_stats = self.parser.stats().clone();
        let set = self.parser.compute_bias(&*self.bias_computer, token_prefix);
        let p_stats = self.parser.stats().delta(&pre_stats);
        self.last_bias_time = Duration::from_micros(p_stats.compute_time_us);
        self.last_step_stats = p_stats.clone();
        self.max_step_stats = self.max_step_stats.max(&p_stats);
        set
    }

    fn log_final(&mut self, token_prefix: &[u8], allowed_tokens: &SimpleVob) {
        infoln!(
            self,
            "step-stats: {}us; {} lex fuel; {} items; {}",
            self.compute_mask_start_time.elapsed().as_micros(),
            self.last_step_stats.lexer_cost,
            self.last_step_stats.all_items,
            self.parser.lexer_stats(),
        );

        infoln!(
            self,
            "bias: (pref: {:?}; accpt: {}; temp: {:.3}) {}",
            String::from_utf8_lossy(token_prefix),
            self.is_accepting_cache.unwrap(),
            self.parser.temperature().unwrap_or(0.0),
            self.token_env.tok_trie().token_set_dbg(allowed_tokens)
        );
    }

    pub fn temperature(&self) -> Option<f32> {
        self.parser.temperature()
    }

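    /// Consumes one sampled token, with special handling for EOS, and returns
    /// the number of tokens to backtrack (usually 0).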
    pub fn consume_token(&mut self, token: TokenId) -> Result<usize> {
        self.check_initialized("consume_token")?;

        if self.max_tokens_total == 0 {
            return Err(self.stop("max_tokens_total reached", StopReason::MaxTokensTotal));
        }
        self.max_tokens_total -= 1;

        if token == self.eos_token {
            if self.parser.scan_eos() {
                // EOS was scanned by the parser as part of the grammar
                infoln!(self, "consume_token: scanned eos_token");
                return Ok(0);
            } else {
                let accepting = self.is_accepting();
                infoln!(
                    self,
                    "consume_token: eos_token not eaten by parser; accept={}",
                    accepting
                );
                if accepting {
                    // record EOS (it contributes no bytes) and finish
                    self.llm_tokens.push(token);
                    return Ok(0);
                }
            }
        }

        let apply_res = self.apply_token(token);
        self.parser.log_row_infos("post-apply");
        match apply_res {
            Err(_) => Err(self.anyhow_error()),
            Ok(n) => Ok(n),
        }
    }

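    /// Returns `true` (recording the stop reason) when the parser is in an
    /// accepting state and either cannot advance or has just seen EOS.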
    pub fn check_stop(&mut self) -> Result<bool> {
        let empty_token_prefix = !self.has_ff_bytes();
        let pending_eos = self.llm_tokens.last() == Some(&self.eos_token);
        let lexer_bytes = self.parser.has_pending_lexeme_bytes();
        let is_accepting = self.is_accepting();
        let can_advance = self.parser.can_advance();
        let parser_done = is_accepting && (!can_advance || pending_eos);
        infoln!(
            self,
            "parser_done: {parser_done}; lexer_bytes: {lexer_bytes}; \
             can_advance: {can_advance} (eos:{pending_eos}); \
             accept: {is_accepting}; \
             empty_token_prefix: {empty_token_prefix}"
        );
        assert!(!is_accepting || empty_token_prefix);

        if parser_done {
            infoln!(
                self,
                "only eos token allowed, stopping; accepting: {}",
                is_accepting
            );
            let reason = if pending_eos {
                StopReason::EndOfSentence
            } else {
                StopReason::NoExtension
            };
            self.stop("", reason);
            Ok(true)
        } else {
            Ok(false)
        }
    }

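    /// Computes fast-forward tokens, caching the result so that a following
    /// `compute_mask()` can reuse it.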
    pub fn compute_ff_tokens(&mut self) -> Vec<TokenId> {
        let r = self.ff_tokens();
        if self.can_force_bytes() {
            self.ff_tokens_cache = Some(r.clone());
        }
        r.0
    }

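    /// Computes the fast-forward tokens and immediately consumes them; a
    /// backtrack at this point is an internal error.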
    pub fn consume_ff_tokens(&mut self) -> Result<Vec<TokenId>> {
        let ff_tokens = self.compute_ff_tokens();
        for &t in &ff_tokens {
            let num_backtrack = self.consume_token(t)?;
            if num_backtrack > 0 {
                return Err(self.stop(
                    &format!("backtrack required after ff_token: {}", t),
                    StopReason::InternalError,
                ));
            }
        }
        Ok(ff_tokens)
    }

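    /// Reference usage of the API, kept as compiled-but-unused code. The
    /// `black_box` calls and the constant token id stand in for a real
    /// inference engine and sampler.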
    #[allow(dead_code)]
    fn typical_use(&mut self, prompt: Vec<TokenId>) -> Result<()> {
        let new_prompt = self.process_prompt(prompt);
        // the new prompt would be sent to the inference engine here
        black_box(new_prompt);

        let mut tokens = vec![];

        loop {
            let temp = self.temperature();
            let mask = self.compute_mask()?;
            // run the forward pass and sample under `mask` and `temp` here
            black_box((temp, mask));
            // placeholder for the sampled token id
            let sampled_token = 42;

            let num_backtrack = self.consume_token(sampled_token)?;

            if num_backtrack == 0 {
                // normal path: the token was accepted
                tokens.push(sampled_token);
            } else {
                assert!(self.inference_caps.backtrack);
                if num_backtrack == 1 {
                    // just don't add the sampled token
                } else if num_backtrack > 1 {
                    // drop the last num_backtrack - 1 tokens
                    // and don't add the sampled token
                    tokens.truncate(tokens.len() - num_backtrack + 1);
                }
            }

            if self.check_stop()? {
                break;
            }

            // append any tokens the grammar forces
            let forced = self.consume_ff_tokens()?;
            tokens.extend_from_slice(&forced);
        }

        Ok(())
    }
}