1use core::str;
5
6use bytemuck_derive::{Pod, Zeroable};
7
8use crate::{bytes::to_hex_string, tokenv::parse_numeric_token, SimpleVob};
9
/// Numeric identifier of a token in the vocabulary.
pub type TokenId = u32;
12
/// Fixed-size (`Pod`) subset of [`TokRxInfo`], suitable for direct binary
/// (de)serialization via bytemuck.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroable, Pod)]
#[repr(C)]
pub struct BinTokRxInfo {
    pub vocab_size: u32,
    pub tok_eos: TokenId,
}
19
/// Tokenizer metadata: vocabulary size plus the well-known special tokens.
/// Only EOS is mandatory; the rest are optional.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct TokRxInfo {
    pub vocab_size: u32,
    pub tok_eos: TokenId,
    pub tok_bos: Option<TokenId>,
    pub tok_pad: Option<TokenId>,
    pub tok_unk: Option<TokenId>,
    pub tok_end_of_turn: Option<TokenId>,
}
29
30impl TokRxInfo {
31 pub fn new(vocab_size: u32, tok_eos: TokenId) -> Self {
32 TokRxInfo {
33 vocab_size,
34 tok_eos,
35 tok_bos: None,
36 tok_pad: None,
37 tok_unk: None,
38 tok_end_of_turn: None,
39 }
40 }
41
42 pub fn from_bin(info: &BinTokRxInfo) -> Self {
43 TokRxInfo {
44 vocab_size: info.vocab_size,
45 tok_eos: info.tok_eos,
46 tok_bos: None,
47 tok_pad: None,
48 tok_unk: None,
49 tok_end_of_turn: None,
50 }
51 }
52
53 pub fn to_bin(&self) -> BinTokRxInfo {
54 BinTokRxInfo {
55 vocab_size: self.vocab_size,
56 tok_eos: self.tok_eos,
57 }
58 }
59}
60
/// Incremental byte-level acceptor driven by the trie walkers
/// ([`TokTrie::add_bias`], [`TokTrie::has_valid_extensions`]): bytes are
/// speculatively pushed while exploring the trie and popped when backtracking.
pub trait Recognizer {
    /// Undo the last `num` bytes pushed via `try_push_byte`.
    fn pop_bytes(&mut self, num: usize);
    /// Compact internal state.
    /// NOTE(review): exact contract not visible in this file — all impls
    /// here are no-ops; confirm against other implementors.
    fn collapse(&mut self);
    /// Non-committing check: push `byte`, record the verdict, pop it again.
    fn byte_allowed(&mut self, byte: u8) -> bool {
        if self.try_push_byte(byte) {
            self.pop_bytes(1);
            true
        } else {
            false
        }
    }
    /// Called when a trie walk finishes (pairs with `trie_started`).
    fn trie_finished(&mut self);
    /// Called when a trie walk begins; `_dbg_lbl` names the walker for debugging.
    fn trie_started(&mut self, _dbg_lbl: &str) {}
    /// Try to extend the current byte string by `byte`; `true` if still viable.
    fn try_push_byte(&mut self, byte: u8) -> bool;
    /// Last error, if any; the default reports none.
    fn get_error(&mut self) -> Option<String> {
        None
    }
    /// Statistics hook; `_nodes_walked` is the node count of the last walk.
    fn save_stats(&mut self, _nodes_walked: usize) {}
}
98
/// Location of one token's bytes inside `TokTrie::token_data`.
#[derive(Clone, Copy)]
struct TokDesc {
    /// Byte length of the token.
    len: u32,
    /// Offset of the token's first byte in `token_data`.
    off: u32,
}
104
/// Byte-level trie over a tokenizer vocabulary, supporting fast
/// "which tokens are allowed next" queries against a [`Recognizer`].
#[derive(Clone)]
pub struct TokTrie {
    info: TokRxInfo,
    /// Per-token (length, offset) into `token_data`, indexed by token id.
    token_offsets: Vec<TokDesc>,
    /// Concatenated bytes of all tokens.
    token_data: Vec<u8>,
    /// Trie nodes laid out in DFS pre-order; node 0 is the root.
    nodes: Vec<TrieNode>,
    /// Byte length of the longest token.
    max_token_len: usize,
    /// All end-of-sequence tokens; `info.tok_eos` is always the first.
    eos_tokens: Vec<TokenId>,
}
120
/// Packed trie node (8 bytes).
/// `bits`: low 8 bits = byte label, high 24 bits = token id (or `NO_TOKEN`).
/// `bits2`: low `PARENT_BITS` bits = `num_parents - 1`, rest = subtree size.
#[derive(Clone, Copy, Zeroable, Pod)]
#[repr(C)]
pub struct TrieNode {
    bits: u32,
    bits2: u32,
}
128
/// Sentinel for "no such token" in APIs that take/return a token id.
pub const INVALID_TOKEN: TokenId = 0xffff_ffff;

/// Sentinel stored in the 24-bit token field of `TrieNode` when the node
/// does not terminate a token.
const NO_TOKEN: u32 = 0xffffff;

/// Number of low bits of `TrieNode::bits2` that hold `num_parents - 1`;
/// the remaining high bits hold the subtree size.
const PARENT_BITS: u32 = 10;
const PARENT_MASK: u32 = (1 << PARENT_BITS) - 1;
140
141impl TrieNode {
142 fn new(byte: u8, token_id: u32, num_parents: usize) -> TrieNode {
143 assert!(num_parents > 0);
144 assert!(num_parents <= (1 << PARENT_BITS) as usize);
145 TrieNode {
146 bits: (token_id << 8) | byte as u32,
147 bits2: (num_parents - 1) as u32,
148 }
149 }
150
151 #[inline(always)]
152 pub fn byte(&self) -> u8 {
153 (self.bits & 0xff) as u8
154 }
155
156 #[inline(always)]
157 pub fn subtree_size(&self) -> usize {
158 (self.bits2 >> PARENT_BITS) as usize
159 }
160
161 fn set_subtree_size(&mut self, size: usize) {
162 assert!(size < (1 << (32 - PARENT_BITS)));
163 self.bits2 = (self.bits2 & PARENT_MASK) | ((size as u32) << PARENT_BITS);
164 }
165
166 #[inline(always)]
167 pub fn num_parents(&self) -> usize {
168 ((self.bits2 & PARENT_MASK) + 1) as usize
169 }
170
171 #[inline(always)]
172 pub fn token_id(&self) -> Option<u32> {
173 let r = self.bits >> 8;
174 if r == NO_TOKEN {
175 None
176 } else {
177 Some(r)
178 }
179 }
180}
181
182impl TokTrie {
    /// First byte of every "special" token's stored byte string; the rest
    /// of the bytes are the token's name.
    pub const SPECIAL_TOKEN_MARKER: u8 = 0xff;

    /// Build a trie from tokenizer metadata and the byte strings of all
    /// tokens (`words[i]` is the bytes of token id `i`; empty entries are
    /// stored but not inserted into the trie).
    pub fn from(info: &TokRxInfo, words: &[Vec<u8>]) -> Self {
        // Root byte label is arbitrary; the root is never matched.
        let mut trie = TrieHash::new(0xff);
        let mut token_offsets = Vec::new();
        let mut token_data = Vec::new();
        assert!(info.vocab_size == words.len() as u32);
        let mut max_token_len = 0;
        for (idx, word) in words.iter().enumerate() {
            if !word.is_empty() {
                trie.insert(word, idx as u32);
                max_token_len = std::cmp::max(max_token_len, word.len());
            }
            // Every token (even empty) gets a descriptor so `token(id)`
            // works for the whole vocabulary.
            let desc = TokDesc {
                len: word.len().try_into().unwrap(),
                off: token_data.len().try_into().unwrap(),
            };
            token_offsets.push(desc);
            token_data.extend_from_slice(word);
        }
        let mut nodes = Vec::new();
        trie.serialize(&mut nodes, 0);
        let r = TokTrie {
            info: *info,
            token_offsets,
            token_data,
            nodes,
            max_token_len,
            eos_tokens: vec![info.tok_eos],
        };
        // Consistency check of the freshly built structure.
        r.validate();
        r
    }
217
218 pub fn filter(&self, filter: &SimpleVob) -> Self {
219 let mut words = vec![];
220 for n in 0..(self.vocab_size() as TokenId) {
221 let b = if filter.is_allowed(n) {
222 self.token(n)
223 } else {
224 &[]
225 };
226 words.push(b.to_vec());
227 }
228 let mut r = Self::from(self.info(), &words);
229 r.eos_tokens = self.eos_tokens.clone();
230 r
231 }
232
    /// Clone of this trie whose (single) EOS token is `eos_token`.
    pub fn with_eos_token(&self, eos_token: TokenId) -> Self {
        self.with_eos_tokens(&[eos_token])
    }
236
237 pub fn with_eos_tokens(&self, eos_tokens: &[TokenId]) -> Self {
238 assert!(!eos_tokens.is_empty(), "eos_tokens must not be empty");
239 let vocab = self.vocab_size() as u32;
240 for &tok in eos_tokens {
241 assert!(
242 tok < vocab,
243 "EOS token ID {tok} is out of range (vocab_size={vocab})"
244 );
245 }
246 let mut r = self.clone();
247 r.info.tok_eos = eos_tokens[0];
248 r.eos_tokens = eos_tokens.to_vec();
249 r
250 }
251
    /// Clone of this trie with different metadata; resets `eos_tokens` to
    /// the single `info.tok_eos`.
    pub fn with_info(&self, info: TokRxInfo) -> Self {
        let mut r = self.clone();
        r.info = info;
        r.eos_tokens = vec![info.tok_eos];
        r
    }

    /// Trie for chat mode: EOS becomes the end-of-turn token when one is
    /// configured, otherwise the regular EOS is kept.
    pub fn build_chat_mode_trie(&self) -> Self {
        self.with_eos_token(self.info.tok_end_of_turn.unwrap_or(self.info.tok_eos))
    }
262
    /// Index of node `n` within `self.nodes`, recovered from its address.
    /// Only valid for references that actually point into `self.nodes`.
    fn node_offset(&self, n: &TrieNode) -> usize {
        let off = (n as *const _ as usize - self.root() as *const _ as usize)
            / std::mem::size_of::<TrieNode>();
        assert!(off < self.nodes.len());
        off
    }

    /// Index one past `n`'s subtree, i.e. `n`'s next sibling in DFS order.
    fn next_node(&self, n: &TrieNode) -> usize {
        self.node_offset(n) + n.subtree_size()
    }
273
    /// Tokenizer metadata for this trie.
    pub fn info(&self) -> &TokRxInfo {
        &self.info
    }

    /// Primary end-of-sequence token.
    pub fn eos_token(&self) -> TokenId {
        self.info.tok_eos
    }

    /// All configured end-of-sequence tokens (primary first).
    pub fn eos_tokens(&self) -> &[TokenId] {
        &self.eos_tokens
    }

    /// Number of tokens in the vocabulary.
    pub fn vocab_size(&self) -> usize {
        self.info.vocab_size as usize
    }

    /// Empty token set sized for this vocabulary; the extra slot of
    /// capacity serves as a scratch token id (see `add_bias`).
    pub fn alloc_token_set(&self) -> SimpleVob {
        SimpleVob::alloc_with_capacity(self.vocab_size(), self.vocab_size() + 1)
    }

    /// Token set containing exactly `tok`.
    pub fn singleton_token_set(&self, tok: TokenId) -> SimpleVob {
        let mut r = self.alloc_token_set();
        r.allow_token(tok);
        r
    }

    /// Token set containing all configured EOS tokens; invalid or
    /// out-of-range entries are skipped rather than panicking.
    pub fn eos_token_set(&self) -> SimpleVob {
        let mut r = self.alloc_token_set();
        let vocab = self.vocab_size() as u32;
        for &eos in self.eos_tokens() {
            if eos != INVALID_TOKEN && eos < vocab {
                r.allow_token(eos);
            }
        }
        r
    }
311
    /// Debug-render a token set, listing up to 50 example tokens. When the
    /// set is nearly full, the complement is listed instead ("ALL EXCEPT").
    pub fn token_set_dbg(&self, ts: &SimpleVob) -> String {
        let max_examples = 50;

        let ts_neg = ts.negated();
        // Prefer the negated view when it is at least 10x smaller.
        let use_neg = ts_neg.num_set() * 10 < ts.num_set();
        let ts1 = if use_neg { &ts_neg } else { ts };
        let num_set = ts1.num_set();
        let max_tok = std::cmp::min(max_examples, num_set);
        let mut token_names = Vec::new();
        // EOS is always listed first when present.
        if self.info.tok_eos != INVALID_TOKEN && ts1.is_allowed(self.info.tok_eos) {
            token_names.push("EOS".to_string());
        }
        for idx in 0..self.vocab_size() {
            if idx as TokenId != self.info.tok_eos && ts1.is_allowed(idx as TokenId) {
                token_names.push(self.token_dbg(idx as TokenId));
                if token_names.len() >= max_tok {
                    break;
                }
            }
        }
        // Ellipsis when the example list was truncated.
        if token_names.len() < num_set {
            token_names.push("...".to_string());
        }
        format!(
            "TokenSet: {}/{}; {}{}",
            ts.num_set(),
            self.vocab_size(),
            if use_neg { "ALL EXCEPT " } else { "" },
            token_names.join(" ")
        )
    }
344
    /// Zeroed logits buffer; the extra slot mirrors the scratch token slot
    /// of `alloc_token_set`.
    pub fn alloc_logits(&self) -> Vec<f32> {
        vec![0.0; self.vocab_size() + 1]
    }

    /// Debug-render `toks` without surrounding quotes (test helper).
    pub fn test_trace_tokens(&self, toks: &[u32]) -> String {
        self.tokens_dbg_ext(toks, false)
    }

    /// Cap on how many tokens the debug helpers render.
    pub const MAX_DBG_TOKENS: usize = 200;

    /// Debug-render `toks`, quoted with `⟦…⟧`.
    pub fn tokens_dbg(&self, toks: &[u32]) -> String {
        self.tokens_dbg_ext(toks, true)
    }
358
    /// Shared implementation of the token-sequence debug renderers: keeps
    /// only the last `MAX_DBG_TOKENS` tokens and joins them with `‧`.
    fn tokens_dbg_ext(&self, toks: &[u32], quote: bool) -> String {
        // When truncating, prefix the output with an ellipsis.
        let (limited, toks) = if toks.len() > Self::MAX_DBG_TOKENS {
            ("…", &toks[toks.len() - Self::MAX_DBG_TOKENS..])
        } else {
            ("", toks)
        };

        let joined = toks
            .iter()
            .map(|t| self.token_dbg_ext(*t, false))
            .collect::<Vec<_>>()
            .join("‧");

        if quote {
            format!("⟦{limited}{joined}⟧")
        } else if limited.is_empty() {
            joined
        } else {
            format!("{limited}{joined}")
        }
    }
381
    /// Human-readable rendering of a single token (quoted form).
    pub fn token_dbg(&self, idx: u32) -> String {
        self.token_dbg_ext(idx, true)
    }

    /// Render one token: `≺EOS≻` for EOS, `≺OOB[..]≻` for out-of-range
    /// ids, special tokens by name, printable tokens escaped (optionally
    /// `⟨…⟩`-quoted), and `≺HEX[..]≻` for non-UTF-8 bytes.
    fn token_dbg_ext(&self, idx: u32, quote: bool) -> String {
        if idx == self.info.tok_eos {
            "≺EOS≻".to_string()
        } else if idx as usize >= self.vocab_size() {
            format!("≺OOB[{idx}]≻")
        } else {
            let bytes = self.token(idx);
            if bytes.len() > 1 && bytes[0] == TokTrie::SPECIAL_TOKEN_MARKER {
                // Special token: show its name (bytes after the marker).
                String::from_utf8_lossy(&bytes[1..]).to_string()
            } else {
                let s = String::from_utf8_lossy(bytes);
                if s.is_empty() {
                    format!("≺EMPTY[{idx}]≻")
                } else if !s.contains('\u{fffd}') {
                    // Valid UTF-8: use the escaped Debug form, stripped of
                    // its surrounding quotes and with `"` left unescaped.
                    let mut s = format!("{s:?}").replace("\\\"", "\"");
                    s.remove(0);
                    s.pop();
                    if quote {
                        format!("⟨{s}⟩")
                    } else {
                        s
                    }
                } else {
                    // Lossy decode hit invalid UTF-8: show raw bytes as hex.
                    let bytes = self.token(idx);
                    format!("≺HEX[{}]≻", to_hex_string(bytes))
                }
            }
        }
    }
416
    /// Lossy UTF-8 decode of a single token's bytes.
    pub fn token_str(&self, idx: u32) -> String {
        String::from_utf8_lossy(self.token(idx)).to_string()
    }

    /// Byte length of the token as produced by [`Self::decode_raw`]: real
    /// tokens use their byte length; empty/special tokens use the
    /// `<marker>[<id>]` escape, i.e. decimal-digit count of `idx` plus 3.
    pub fn token_len(&self, idx: u32) -> usize {
        let t = self.token(idx);
        if t.is_empty() || t[0] == TokTrie::SPECIAL_TOKEN_MARKER {
            // Count decimal digits of idx.
            let mut idx = idx;
            let mut len = 1;
            while idx >= 10 {
                idx /= 10;
                len += 1;
            }
            // marker + '[' + digits + ']'
            len + 3
        } else {
            t.len()
        }
    }
436
437 pub fn token(&self, idx: u32) -> &[u8] {
438 if idx >= self.token_offsets.len() as u32 {
439 return &[];
440 }
441 let desc = self.token_offsets[idx as usize];
442 let len = desc.len as usize;
443 let off = desc.off as usize;
444 &self.token_data[off..(off + len)]
445 }
446
    /// Decode `tokens` to bytes, including special tokens (see
    /// [`Self::decode_ext`] with `include_special = true`).
    pub fn decode(&self, tokens: &[TokenId]) -> Vec<u8> {
        self.decode_ext(tokens, true)
    }
450
451 pub fn decode_ext(&self, tokens: &[TokenId], include_special: bool) -> Vec<u8> {
452 let mut res = Vec::with_capacity(tokens.len() * 6 + 32); for &tok in tokens {
454 let t = self.token(tok);
455 if t.is_empty() {
456 if include_special {
457 res.extend_from_slice(format!("<[{tok}]>").as_bytes());
458 }
459 } else if t[0] == TokTrie::SPECIAL_TOKEN_MARKER {
460 if include_special {
461 res.extend_from_slice(&t[1..]);
462 }
463 } else {
464 res.extend_from_slice(t);
465 }
466 }
467 res
468 }
469
470 pub fn decode_as_special(&self, tok: TokenId) -> Vec<u8> {
471 let mut res = Vec::with_capacity(9);
472 res.push(TokTrie::SPECIAL_TOKEN_MARKER);
473 res.extend_from_slice(format!("[{tok}]").as_bytes());
474 res
475 }
476
477 pub fn decode_raw(&self, tokens: &[TokenId]) -> Vec<u8> {
478 let mut res = Vec::with_capacity(tokens.len() * 6 + 32); for &tok in tokens {
480 let t = self.token(tok);
481 if t.is_empty() || t[0] == TokTrie::SPECIAL_TOKEN_MARKER {
482 res.push(TokTrie::SPECIAL_TOKEN_MARKER);
483 res.extend_from_slice(format!("[{tok}]").as_bytes());
484 } else {
485 res.extend_from_slice(t);
486 }
487 }
488 res
489 }
490
    /// Decode to a lossy UTF-8 string (special tokens included).
    pub fn decode_str(&self, tokens: &[TokenId]) -> String {
        String::from_utf8_lossy(&self.decode(tokens)).to_string()
    }

    /// Convert [`Self::decode_raw`] output into [`Self::decode`]-style
    /// output: each `<marker>[<id>]` escape is replaced by that token's
    /// decoded bytes; malformed escapes pass through unchanged.
    pub fn decode_raw_to_decode(&self, bytes: &[u8]) -> Vec<u8> {
        let mut res = Vec::new();
        let mut idx = 0;
        while idx < bytes.len() {
            if bytes[idx] == TokTrie::SPECIAL_TOKEN_MARKER {
                if let Some((len, tok)) = parse_numeric_token(&bytes[(idx + 1)..]) {
                    res.extend_from_slice(&self.decode(&[tok]));
                    // Skip the marker plus the parsed `[<id>]` span.
                    idx += len + 1;
                } else {
                    res.push(bytes[idx]);
                    idx += 1;
                }
            } else {
                res.push(bytes[idx]);
                idx += 1;
            }
        }
        res
    }
514
515 pub fn is_special_token(&self, tok: TokenId) -> bool {
516 let bytes = self.token(tok);
517 !bytes.is_empty() && bytes[0] == TokTrie::SPECIAL_TOKEN_MARKER
518 }
519
520 pub fn get_special_token(&self, name: &str) -> Option<TokenId> {
521 self.child_at_byte(self.root(), TokTrie::SPECIAL_TOKEN_MARKER)
522 .and_then(|n| {
523 self.child_at_bytes(n, name.as_bytes())
524 .and_then(|n| n.token_id())
525 })
526 }
527
    /// Collect ids of all special tokens (those stored under the
    /// `SPECIAL_TOKEN_MARKER` prefix), capped at `MAX_DBG_TOKENS + 1`.
    /// Panics if the trie has no special-token prefix node.
    pub fn get_special_tokens(&self) -> Vec<TokenId> {
        let mut res = Vec::new();
        let pref_node = self
            .child_at_byte(self.root(), TokTrie::SPECIAL_TOKEN_MARKER)
            .expect("missing special token prefix");
        let mut stack = vec![pref_node];
        while let Some(n) = stack.pop() {
            for c in self.node_children(n) {
                if let Some(tok) = c.token_id() {
                    res.push(tok);
                    if res.len() > Self::MAX_DBG_TOKENS + 1 {
                        break;
                    }
                }
                stack.push(c);
            }
        }
        // NOTE(review): drops the first collected token (and panics when
        // none were found) — presumably a sentinel entry stored under the
        // marker prefix; confirm against the vocabulary construction.
        res.remove(0);
        res
    }
548
549 pub fn greedy_tokenize(&self, bytes: &[u8]) -> Vec<TokenId> {
550 let mut tokens = Vec::new();
551 let mut i = 0;
552 while i < bytes.len() {
553 let mut node = self.root();
554 let mut last_tok = None;
555 let mut last_idx = i;
556 #[allow(clippy::needless_range_loop)]
557 for j in i..bytes.len() {
558 if let Some(child) = self.child_at_byte(node, bytes[j]) {
559 node = child;
560 if let Some(tok) = node.token_id() {
561 last_tok = Some(tok);
562 last_idx = j;
563 }
564 } else {
565 break;
566 }
567 }
568 if let Some(t) = last_tok {
569 tokens.push(t);
570 } else {
571 }
575 i = last_idx + 1;
576 }
577 tokens
578 }
579
    /// Tokenize `s`, recognizing `<name>` spans as special tokens when the
    /// vocabulary defines them; everything else goes through `str_tokenize`.
    pub fn tokenize_with_special<F>(&self, s: &str, str_tokenize: F) -> Vec<TokenId>
    where
        F: Fn(&str) -> Vec<TokenId>,
    {
        // Longest name (between the angle brackets) considered special.
        let max_len = 100;

        let bytes = s.as_bytes();
        let mut out = Vec::new();
        // `last` marks the start of the pending plain-text run; `i` scans.
        let mut last = 0;
        let mut i = 0;
        while i < bytes.len() {
            if bytes[i] != b'<' {
                i += 1;
                continue;
            }
            // Find the matching '>', rejecting nested '<' and names that
            // are empty or too long.
            let mut valid = true;
            let mut j = i + 1;
            let mut len_inside = 0;
            while j < bytes.len() && len_inside < max_len {
                match bytes[j] {
                    b'<' => {
                        valid = false;
                        break;
                    }
                    b'>' => break,
                    _ => {
                        len_inside += 1;
                        j += 1;
                    }
                }
            }
            if !valid || j >= bytes.len() || bytes[j] != b'>' || len_inside == 0 {
                i += 1;
                continue;
            }

            // `name` includes the surrounding angle brackets.
            let name = &s[i..=j];
            if let Some(special_tok) = self.get_special_token(name) {
                // Flush the pending plain text, then emit the special token.
                if last < i {
                    out.extend(str_tokenize(&s[last..i]));
                }
                out.push(special_tok);
            } else {
                // Not a known special token: tokenize the whole run
                // including the bracketed span as ordinary text.
                out.extend(str_tokenize(&s[last..=j]));
            }
            i = j + 1;
            last = i;
        }
        // Trailing plain text.
        if last < bytes.len() {
            out.extend(str_tokenize(&s[last..]));
        }
        out
    }
641
642 pub fn tokenize_with_greedy_fallback(
643 &self,
644 bytes: &[u8],
645 str_tokenize: impl Fn(&str) -> Vec<TokenId>,
646 ) -> Vec<TokenId> {
647 match str::from_utf8(bytes) {
648 Ok(s) => {
649 str_tokenize(s)
651 }
652 Err(_) => {
653 let mut res = vec![];
654 for chunk in bytes.utf8_chunks() {
655 if !chunk.valid().is_empty() {
656 res.extend(str_tokenize(chunk.valid()));
657 }
658 if !chunk.invalid().is_empty() {
659 res.extend(self.greedy_tokenize(chunk.invalid()));
660 }
661 }
662 res
663 }
664 }
665 }
666
667 pub fn has_extensions(&self, bytes: &[u8]) -> bool {
668 match self.child_at_bytes(self.root(), bytes) {
669 None => false,
670 Some(n) => n.subtree_size() > 1,
671 }
672 }
673
674 pub fn token_id(&self, bytes: &[u8]) -> Option<TokenId> {
675 let (tok, len) = self.prefix_token_id(bytes);
676 if len == bytes.len() {
678 Some(tok)
679 } else {
680 None
681 }
682 }
683
    /// Longest token that is a prefix of `bytes`: returns `(token, len)`,
    /// or `(0, 0)` when no token matches. Panics on empty input.
    pub fn prefix_token_id(&self, bytes: &[u8]) -> (TokenId, usize) {
        assert!(!bytes.is_empty());
        let mut last = (0, 0);
        let mut n = self.root();
        for (idx, byte) in bytes.iter().enumerate() {
            n = match self.child_at_byte(n, *byte) {
                Some(n) => n,
                None => break,
            };
            // Remember the most recent (i.e. longest) complete token.
            if let Some(tok) = n.token_id() {
                last = (tok, idx + 1);
            }
        }
        last
    }

    /// Byte length of the longest token in the vocabulary.
    pub fn max_token_len(&self) -> usize {
        self.max_token_len
    }
703
    /// Recursively check node invariants: token ids in range and unique
    /// across the trie, and each subtree nested within its parent's end
    /// bound `ep`.
    fn validate_node(&self, n: &TrieNode, ep: usize, used: &mut [bool]) {
        if let Some(tok) = n.token_id() {
            assert!(tok < self.info.vocab_size);
            assert!(!used[tok as usize]);
            used[tok as usize] = true;
        }
        let endp = self.next_node(n);
        assert!(endp <= ep);
        for child in self.node_children(n) {
            self.validate_node(child, endp, used);
        }
    }

    /// Consistency check run after construction; also exercises `token()`
    /// for every id so bad offsets fail fast.
    fn validate(&self) {
        self.validate_node(
            self.root(),
            self.next_node(self.root()),
            &mut vec![false; self.info.vocab_size as usize],
        );
        for idx in 0..self.info.vocab_size {
            let _ = self.token(idx);
        }
    }
727
    /// Root node of the trie (always stored at index 0).
    pub fn root(&self) -> &TrieNode {
        &self.nodes[0]
    }

    /// Assert that this trie exactly represents `tokens`. Duplicate byte
    /// strings are allowed: the id may live on a leaf sibling that shares
    /// the final byte.
    pub fn check_against(&self, tokens: &[Vec<u8>]) {
        for (idx, bytes) in tokens.iter().enumerate() {
            let tid = idx as TokenId;
            assert!(bytes == self.token(tid));
            let root = self.root();
            if !bytes.is_empty() {
                let tid2 = self
                    .child_at_bytes(root, bytes)
                    .unwrap()
                    .token_id()
                    .unwrap();
                if tid != tid2 {
                    // Duplicate token bytes: look for a token-only leaf
                    // sibling carrying this id.
                    let par = self
                        .child_at_bytes(root, &bytes[0..bytes.len() - 1])
                        .unwrap();
                    let has_it = self.node_children(par).any(|n| {
                        n.subtree_size() == 1
                            && n.byte() == bytes[bytes.len() - 1]
                            && n.token_id() == Some(tid)
                    });
                    assert!(has_it);
                }
            }
        }
    }
757
    /// Child of `n` labeled `byte`, if any (linear scan over the children).
    pub fn child_at_byte<'a>(&'a self, n: &'a TrieNode, byte: u8) -> Option<&'a TrieNode> {
        self.node_children(n).find(|&child| child.byte() == byte)
    }
761
762 pub fn all_subtokens(&self, bytes: &[u8]) -> Vec<TokenId> {
763 let mut r = Vec::new();
764 for i in 0..bytes.len() {
765 let mut n = self.root();
766 for &b in &bytes[i..] {
767 n = match self.child_at_byte(n, b) {
768 Some(n) => n,
769 None => break,
770 };
771 if let Some(tok) = n.token_id() {
772 r.push(tok);
773 }
774 }
775 }
776 r
777 }
778
    /// Iterator over the direct children of `n`. Exploits the DFS layout:
    /// `n`'s subtree occupies `[off + 1, off + subtree_size)`.
    pub fn node_children(&self, n: &TrieNode) -> NodeChildren<'_> {
        let off = self.node_offset(n);
        NodeChildren {
            trie: self,
            current_offset: off + 1,
            end_offset: off + n.subtree_size(),
        }
    }
787
788 pub fn child_at_bytes<'a>(&'a self, mut n: &'a TrieNode, bytes: &[u8]) -> Option<&'a TrieNode> {
789 for &byte in bytes {
790 n = self.child_at_byte(n, byte)?
791 }
792 Some(n)
793 }
794
    /// Token whose bytes are exactly `bytes`, resolved via the trie.
    pub fn token_id_at_bytes(&self, bytes: &[u8]) -> Option<TokenId> {
        self.child_at_bytes(self.root(), bytes)
            .and_then(|n| n.token_id())
    }
799
    /// Decide how many trailing tokens to chop so re-tokenization could
    /// still extend the text under `r`: returns `(num_tokens, num_bytes)`
    /// covered, or `(0, 0)` when no recent suffix has a valid extension.
    pub fn chop_tokens(&self, r: &mut impl Recognizer, tokens: &[TokenId]) -> (usize, usize) {
        // Only the last few tokens can matter, and of their bytes only the
        // last `max_token_len` are relevant.
        let max_token_lookback = 4;
        let suff_bytes =
            self.decode_raw(&tokens[tokens.len().saturating_sub(max_token_lookback)..]);
        let suff_bytes = &suff_bytes[suff_bytes.len().saturating_sub(self.max_token_len())..];
        // Try the longest suffix first.
        for idx in 0..suff_bytes.len() {
            let suff = &suff_bytes[idx..];
            if self.has_valid_extensions(r, suff) {
                let chop_bytes = suff.len();
                assert!(chop_bytes > 0);
                // Count whole tokens from the end until they cover the
                // suffix bytes (using decode_raw-compatible lengths).
                let mut curr_len = 0;
                for chop_idx in 1..=tokens.len() {
                    curr_len += self.token_len(tokens[tokens.len() - chop_idx]);
                    if curr_len >= chop_bytes {
                        return (chop_idx, curr_len);
                    }
                }
                unreachable!();
            }
        }

        (0, 0)
    }
830
    /// Does any token strictly extend `start` while being accepted
    /// byte-by-byte by `r`? Walks the subtree below `start`, pruning
    /// branches `r` rejects, and stops at the first complete token.
    #[inline(never)]
    pub fn has_valid_extensions(&self, r: &mut impl Recognizer, start: &[u8]) -> bool {
        let n = self.child_at_bytes(self.root(), start);
        if n.is_none() {
            return false;
        }
        let n = n.unwrap();
        r.trie_started("has_valid_extensions");
        let off = self.node_offset(n);
        let mut p = off + 1;
        let endp = off + n.subtree_size();
        let mut ok = false;
        // Bytes to pop from `r` before visiting the next node (derived from
        // num_parents(), which encodes how many subtrees end here).
        let mut next_pop = 0;
        while p < endp {
            r.pop_bytes(next_pop);
            let n = &self.nodes[p];
            let b = n.byte();
            if r.try_push_byte(b) {
                if n.token_id().is_some() {
                    // Found a complete token extending `start`.
                    ok = true;
                    break;
                }
                // Leaf: pop its whole parent chain before the next sibling;
                // inner node: descend (next array element in DFS order).
                next_pop = if n.subtree_size() == 1 {
                    n.num_parents()
                } else {
                    0
                };
                p += 1;
            } else {
                // Byte rejected: skip the entire subtree.
                p += n.subtree_size();
                next_pop = n.num_parents() - 1;
            }
        }
        r.trie_finished();
        ok
    }
868
869 pub fn all_prefixes(&self, bytes: &[u8]) -> Vec<TokenId> {
870 let mut r = Vec::new();
871 let mut n = self.root();
872 for &b in bytes {
873 if let Some(c) = self.child_at_byte(n, b) {
874 n = c;
875 if let Some(tok) = n.token_id() {
876 r.push(tok);
877 }
878 } else {
879 break;
880 }
881 }
882 r
883 }
884
    /// Enable in `toks` every token that, appended after the fixed prefix
    /// `start`, is accepted byte-by-byte by `r`.
    pub fn add_bias(&self, r: &mut impl Recognizer, toks: &mut SimpleVob, start: &[u8]) {
        if !start.is_empty() {
            // Also allow tokens that are a prefix of `start` itself, using
            // a recognizer that accepts exactly the bytes of `start`.
            let mut fixed = FixedRecognizer::new(start);
            self.add_bias(&mut fixed, toks, &[]);
        }

        let n = self.child_at_bytes(self.root(), start);
        if n.is_none() {
            return;
        }
        let n = n.unwrap();
        r.trie_started("add_bias");
        let (next_pop, nodes_walked) = self.add_bias_inner(r, toks, n);
        if start.is_empty() {
            // Rebalance `r` by undoing the bytes the walk left pushed.
            // NOTE(review): skipped when `start` is non-empty — presumably
            // the caller manages that case; confirm.
            r.pop_bytes(next_pop);
        }
        r.trie_finished();
        r.save_stats(nodes_walked);
        // The walk uses index `vocab_size` as a scratch token id for nodes
        // without a token; make sure it ends up cleared.
        let defl_tok = self.vocab_size() as u32;
        toks.disallow_token(defl_tok);
    }
912
    /// DFS over `n`'s subtree, enabling in `toks` every token whose bytes
    /// `r` accepts. Returns `(bytes_left_pushed_on_r, nodes_walked)`.
    #[inline(never)]
    fn add_bias_inner(
        &self,
        r: &mut impl Recognizer,
        toks: &mut SimpleVob,
        n: &TrieNode,
    ) -> (usize, usize) {
        // Nodes without a token map to this scratch id; the caller
        // (add_bias) clears that slot afterwards.
        let defl_tok = self.vocab_size() as u32;
        let off = self.node_offset(n);
        let total_nodes = n.subtree_size();
        let mut p = off + 1;
        let endp = off + total_nodes;
        let nodes = &self.nodes[..endp];
        let mut next_pop = 0;
        let mut num_skip = 0;
        while p < endp {
            r.pop_bytes(next_pop);
            // SAFETY: the loop condition guarantees `p < endp`, and
            // `nodes.len() == endp`, so `p` is always in bounds.
            let n = unsafe {
                debug_assert!(
                    p < nodes.len(),
                    "node index {} out of bounds (len: {})",
                    p,
                    nodes.len()
                );
                nodes.get_unchecked(p)
            };
            let b = n.byte();
            if r.try_push_byte(b) {
                let tok = n.token_id().unwrap_or(defl_tok);
                debug_assert!(
                    tok <= self.vocab_size() as u32,
                    "token {} out of valid range (vocab_size: {})",
                    tok,
                    self.vocab_size()
                );
                // SAFETY: `tok <= vocab_size` and sets from
                // `alloc_token_set` have capacity `vocab_size + 1`.
                // NOTE(review): only sound for vobs of that capacity —
                // confirm callers never pass smaller ones.
                unsafe { toks.allow_token_unchecked(tok) };
                // Leaf: pop its parent chain before the next sibling;
                // inner node: descend.
                next_pop = if n.subtree_size() == 1 {
                    n.num_parents()
                } else {
                    0
                };
                p += 1;
            } else {
                // Byte rejected: skip the subtree and account for the
                // nodes not visited.
                let subtree_size = n.subtree_size();
                p += subtree_size;
                num_skip += subtree_size - 1;
                next_pop = n.num_parents() - 1;
            }
        }
        (next_pop, total_nodes - num_skip)
    }
969
970 pub fn all_tokens(&self) -> Vec<Vec<u8>> {
971 (0..self.vocab_size())
972 .map(|idx| self.token(idx as u32).to_vec())
973 .collect()
974 }
975
    /// All `(token_id, bytes)` pairs in trie DFS order — i.e. sorted by
    /// byte string, since children are serialized sorted by byte.
    pub fn sorted_tokens(&self) -> Vec<(u32, Vec<u8>)> {
        let mut res = vec![];
        let n = self.root();
        let off = self.node_offset(n);
        let mut p = off + 1;
        let endp = off + n.subtree_size();
        let mut next_pop = 0;
        // Current path from the root; maintained like the recognizer's
        // byte stack in the other walkers.
        let mut bytes = vec![];
        while p < endp {
            bytes.drain(bytes.len() - next_pop..);
            let n = &self.nodes[p];
            let b = n.byte();
            bytes.push(b);
            if let Some(t) = n.token_id() {
                res.push((t, bytes.clone()));
            }
            // Leaving a leaf pops its whole parent chain; otherwise descend.
            next_pop = if n.subtree_size() == 1 {
                n.num_parents()
            } else {
                0
            };
            p += 1;
        }
        res
    }
1001
1002 fn count_until_depth(&self, depth: usize) -> (usize, usize) {
1003 let mut count = 0;
1004 let mut num_tokens = 0;
1005 let mut stack = vec![(self.root(), 0)];
1006 while let Some((n, d)) = stack.pop() {
1007 if d == depth {
1008 continue;
1009 } else {
1010 for c in self.node_children(n) {
1011 count += 1;
1012 if c.token_id().is_some() {
1013 num_tokens += 1;
1014 }
1015 stack.push((c, d + 1));
1016 }
1017 }
1018 }
1019 (count, num_tokens)
1020 }
1021
    /// Summary statistics string for the trie. The `if false` sections are
    /// intentionally disabled debug dumps kept for manual investigation.
    pub fn trie_stats(&self) -> String {
        let mut nodes_histogram = vec![0; 256];

        let mut token_nodes = 0;

        let n = self.root();
        let off = self.node_offset(n);
        let mut p = off + 1;
        let endp = off + n.subtree_size();
        while p < endp {
            let n = &self.nodes[p];

            if n.token_id().is_some() {
                token_nodes += 1;
            }

            // Count direct children by stepping over each child's subtree.
            let last_ch = self.next_node(n);
            let mut ch_p = p + 1;
            let mut num_children = 0;

            while ch_p < last_ch {
                let ch = &self.nodes[ch_p];
                ch_p += ch.subtree_size();
                num_children += 1;
            }

            // Child counts of 9 or more share one bucket.
            nodes_histogram[std::cmp::min(9, num_children)] += 1;

            p += 1;
        }

        let mut histogram = String::new();

        // Disabled dump: children-count histogram.
        if false {
            for (idx, num) in nodes_histogram.iter().enumerate() {
                if *num > 0 {
                    if !histogram.is_empty() {
                        histogram.push_str(", ");
                    }
                    histogram.push_str(&format!("{idx}:{num}"));
                }
            }
        }

        // Disabled dump: per-root-child fanout and subtree sizes.
        if false {
            for n in self.node_children(self.root()) {
                histogram.push_str(&format!(
                    "\n{} => {} {}",
                    n.byte(),
                    self.node_children(n).count(),
                    n.subtree_size()
                ));
            }
        }

        // Disabled dump: cumulative counts per depth.
        if false {
            for depth in 0..30 {
                let (count, num_tokens) = self.count_until_depth(depth);
                histogram.push_str(&format!(
                    "\ndepth {depth}: {count} nodes {num_tokens} tokens"
                ));
            }
        }

        if !histogram.is_empty() {
            histogram = format!("\n{histogram}");
        }

        format!(
            "{}{} nodes, {} token nodes, {} token bytes, {} max len",
            histogram,
            self.nodes.len(),
            token_nodes,
            self.token_data.len(),
            self.max_token_len,
        )
    }
1099}
1100
/// Iterator over a node's direct children (see `TokTrie::node_children`).
pub struct NodeChildren<'a> {
    trie: &'a TokTrie,
    /// Offset of the next child to yield.
    current_offset: usize,
    /// One past the last node of the parent's subtree.
    end_offset: usize,
}
1106
impl<'a> Iterator for NodeChildren<'a> {
    type Item = &'a TrieNode;

    /// Yield the child at `current_offset`, then jump over that child's
    /// entire subtree to land on its next sibling.
    fn next(&mut self) -> Option<Self::Item> {
        if self.current_offset < self.end_offset {
            let node = &self.trie.nodes[self.current_offset];
            self.current_offset += node.subtree_size();
            Some(node)
        } else {
            None
        }
    }
}
1120
/// Mutable pointer-based trie used only during construction; converted to
/// the packed `TrieNode` array by `TrieHash::serialize`.
struct TrieHash {
    token_id: u32,
    byte: u8,
    children: Vec<TrieHash>,
}
1126
impl TrieHash {
    /// Node labeled `byte`, with no token and no children.
    fn new(byte: u8) -> TrieHash {
        TrieHash {
            token_id: NO_TOKEN,
            byte,
            children: Vec::new(),
        }
    }
    /// Insert the remaining bytes of `word` with `token_id` below this node.
    fn insert(&mut self, word: &[u8], token_id: u32) {
        if word.is_empty() {
            // A node may directly terminate at most one token.
            assert!(self.token_id == NO_TOKEN);
            self.token_id = token_id;
        } else {
            for ch in &mut self.children {
                if ch.byte == word[0] {
                    if word.len() == 1 && ch.token_id != NO_TOKEN {
                        // Duplicate byte string: fall through so a fresh
                        // sibling with the same byte holds the new token.
                    } else {
                        ch.insert(&word[1..], token_id);
                        return;
                    }
                }
            }

            let mut ch = TrieHash::new(word[0]);
            ch.insert(&word[1..], token_id);
            self.children.push(ch);

        }
    }

    /// Append this subtree to `data` in DFS pre-order. `num_parents` is the
    /// length of the "last child" chain ending at this node (used for pop
    /// bookkeeping by the walkers); subtree sizes are patched in after the
    /// children have been written.
    fn serialize(&mut self, data: &mut Vec<TrieNode>, num_parents: usize) {
        let idx = data.len();
        let mut num_ch = self.children.len();
        data.push(TrieNode::new(
            self.byte,
            self.token_id,
            if num_parents == 0 { 1 } else { num_parents },
        ));
        // Sort children by byte so walks and lookups see a fixed order.
        self.children.sort_by_key(|e| e.byte);
        for entry in &mut self.children {
            num_ch -= 1;
            // Only the last child extends the parent chain.
            entry.serialize(data, if num_ch == 0 { num_parents + 1 } else { 1 });
        }
        let subtree_size = data.len() - idx;
        data[idx].set_subtree_size(subtree_size);
    }
}
1196
/// Recognizer accepting exactly the prefixes of a fixed byte string
/// (used by `TokTrie::add_bias` to handle a forced prefix).
struct FixedRecognizer {
    bytes: Vec<u8>,
    // Number of bytes matched so far (acts as the push/pop stack depth).
    bytes_ptr: usize,
}

impl FixedRecognizer {
    /// Recognizer for `bytes`, positioned at the start.
    fn new(bytes: &[u8]) -> FixedRecognizer {
        FixedRecognizer {
            bytes: bytes.to_vec(),
            bytes_ptr: 0,
        }
    }
}
1210
impl Recognizer for FixedRecognizer {
    fn collapse(&mut self) {}
    fn trie_finished(&mut self) {}

    /// NOTE(review): underflows (panics in debug builds) when `num`
    /// exceeds the pushed count — callers must keep push/pop balanced.
    fn pop_bytes(&mut self, num: usize) {
        self.bytes_ptr -= num;
    }

    /// Accept `byte` only when it is the next byte of the fixed string.
    fn try_push_byte(&mut self, byte: u8) -> bool {
        if self.bytes_ptr < self.bytes.len() && self.bytes[self.bytes_ptr] == byte {
            self.bytes_ptr += 1;
            true
        } else {
            false
        }
    }
}
1228
/// Recognizer that accepts every byte string (no constraints at all).
pub struct AnythingGoes;

impl Recognizer for AnythingGoes {
    fn collapse(&mut self) {}
    fn trie_finished(&mut self) {}
    fn pop_bytes(&mut self, _num: usize) {}
    fn try_push_byte(&mut self, _byte: u8) -> bool {
        true
    }
}
1239
#[cfg(test)]
mod tests {
    use super::*;

    /// Tiny 4-token vocabulary ("a".."d") with a configurable EOS id.
    fn make_test_trie(eos: TokenId) -> TokTrie {
        let info = TokRxInfo::new(4, eos);
        let words = vec![b"a".to_vec(), b"b".to_vec(), b"c".to_vec(), b"d".to_vec()];
        TokTrie::from(&info, &words)
    }

    #[test]
    fn test_default_single_eos() {
        let trie = make_test_trie(2);
        assert_eq!(trie.eos_token(), 2);
        assert_eq!(trie.eos_tokens(), &[2]);
    }

    #[test]
    fn test_with_eos_tokens_multiple() {
        // First entry of the list becomes the primary EOS.
        let trie = make_test_trie(0).with_eos_tokens(&[1, 3]);
        assert_eq!(trie.eos_token(), 1);
        assert_eq!(trie.eos_tokens(), &[1, 3]);
        assert_eq!(trie.info().tok_eos, 1);
    }

    #[test]
    fn test_with_eos_token_backwards_compat() {
        let trie = make_test_trie(0).with_eos_token(2);
        assert_eq!(trie.eos_token(), 2);
        assert_eq!(trie.eos_tokens(), &[2]);
    }

    #[test]
    fn test_with_info_resets_eos_tokens() {
        let trie = make_test_trie(0).with_eos_tokens(&[1, 2]);
        let trie2 = trie.with_info(TokRxInfo::new(4, 3));
        assert_eq!(trie2.eos_token(), 3);
        assert_eq!(trie2.eos_tokens(), &[3]);
    }

    #[test]
    fn test_filter_preserves_eos_tokens() {
        let trie = make_test_trie(0).with_eos_tokens(&[1, 2]);
        let mut filter = trie.alloc_token_set();
        for i in 0..4 {
            filter.allow_token(i);
        }
        let filtered = trie.filter(&filter);
        assert_eq!(filtered.eos_tokens(), &[1, 2]);
    }

    #[test]
    #[should_panic(expected = "eos_tokens must not be empty")]
    fn test_with_eos_tokens_empty_panics() {
        make_test_trie(0).with_eos_tokens(&[]);
    }

    #[test]
    fn test_eos_token_set_single() {
        let trie = make_test_trie(2);
        let set = trie.eos_token_set();
        assert!(set.is_allowed(2));
        assert!(!set.is_allowed(0));
        assert!(!set.is_allowed(1));
        assert_eq!(set.num_set(), 1);
    }

    #[test]
    fn test_eos_token_set_multiple() {
        let trie = make_test_trie(0).with_eos_tokens(&[1, 3]);
        let set = trie.eos_token_set();
        assert!(set.is_allowed(1));
        assert!(set.is_allowed(3));
        assert!(!set.is_allowed(0));
        assert!(!set.is_allowed(2));
        assert_eq!(set.num_set(), 2);
    }
}