1use core::str;
5use std::sync::Arc;
6
7use bytemuck_derive::{Pod, Zeroable};
8
9use crate::{bytes::to_hex_string, SimpleVob};
10
/// Numeric identifier of a token in the vocabulary.
pub type TokenId = u32;
12
/// Fixed-layout (`repr(C)`, `Pod`) tokenizer info, suitable for binary
/// serialization; holds only the required fields of [`TokRxInfo`].
#[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroable, Pod)]
#[repr(C)]
pub struct BinTokRxInfo {
    /// Number of tokens in the vocabulary.
    pub vocab_size: u32,
    /// End-of-sequence token id.
    pub tok_eos: TokenId,
}
19
/// Tokenizer metadata: vocabulary size plus well-known special tokens.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct TokRxInfo {
    /// Number of tokens in the vocabulary.
    pub vocab_size: u32,
    /// End-of-sequence token id (required).
    pub tok_eos: TokenId,
    /// Beginning-of-sequence token, if the tokenizer defines one.
    pub tok_bos: Option<TokenId>,
    /// Padding token, if defined.
    pub tok_pad: Option<TokenId>,
    /// Unknown-token id, if defined.
    pub tok_unk: Option<TokenId>,
    /// End-of-turn token used by chat models, if defined.
    pub tok_end_of_turn: Option<TokenId>,
}
29
30impl TokRxInfo {
31 pub fn new(vocab_size: u32, tok_eos: TokenId) -> Self {
32 TokRxInfo {
33 vocab_size,
34 tok_eos,
35 tok_bos: None,
36 tok_pad: None,
37 tok_unk: None,
38 tok_end_of_turn: None,
39 }
40 }
41
42 pub fn from_bin(info: &BinTokRxInfo) -> Self {
43 TokRxInfo {
44 vocab_size: info.vocab_size,
45 tok_eos: info.tok_eos,
46 tok_bos: None,
47 tok_pad: None,
48 tok_unk: None,
49 tok_end_of_turn: None,
50 }
51 }
52
53 pub fn to_bin(&self) -> BinTokRxInfo {
54 BinTokRxInfo {
55 vocab_size: self.vocab_size,
56 tok_eos: self.tok_eos,
57 }
58 }
59}
60
/// Incremental byte-level recognizer driven while walking the token trie.
/// Implementations maintain a stack of accepted bytes; the trie walker
/// pushes bytes as it descends and pops them when it backtracks.
pub trait Recognizer {
    /// Pop the given number of bytes from the internal stack.
    fn pop_bytes(&mut self, num: usize);
    /// Collapse/commit internal state (no callers in this file;
    /// presumably invoked by external drivers between steps).
    fn collapse(&mut self);
    /// Check whether `byte` would be accepted, leaving the stack unchanged.
    /// Default: push, then immediately pop on success.
    fn byte_allowed(&mut self, byte: u8) -> bool {
        if self.try_push_byte(byte) {
            self.pop_bytes(1);
            true
        } else {
            false
        }
    }
    /// Called when a trie walk finishes.
    fn trie_finished(&mut self);
    /// Called when a trie walk starts; `_dbg_lbl` names the walk for debugging.
    fn trie_started(&mut self, _dbg_lbl: &str) {}
    /// Try to accept `byte`, pushing it onto the stack on success.
    fn try_push_byte(&mut self, byte: u8) -> bool;
    /// Optional error reported by the recognizer, if any.
    fn get_error(&mut self) -> Option<String> {
        None
    }
    /// Optional hook to record how many trie nodes a walk visited.
    fn save_stats(&mut self, _nodes_walked: usize) {}
}
91
92pub fn parse_numeric_token(s: &[u8]) -> Option<(usize, TokenId)> {
96 let spec_len = s[0..std::cmp::min(s.len(), 20)]
97 .iter()
98 .position(|&x| x == ']' as u8);
99 if let Some(spec_len) = spec_len {
100 if s[0] != b'[' {
101 return None;
102 }
103 let inner_bytes = &s[1..spec_len];
104 if let Ok(inner_str) = std::str::from_utf8(inner_bytes) {
105 if let Ok(id) = u32::from_str_radix(inner_str, 10) {
106 return Some((spec_len + 1, id as TokenId));
107 }
108 }
109 }
110 None
111}
112
/// Abstraction over a tokenizer together with its token trie.
pub trait TokenizerEnv: Send {
    /// The token trie for this tokenizer's vocabulary.
    fn tok_trie(&self) -> &TokTrie;

    /// Tokenize raw bytes (no special-token marker processing).
    fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId>;

    /// Tokenize bytes that may embed special-token markers.
    ///
    /// A `TokTrie::SPECIAL_TOKEN_MARKER` (0xff) byte may be followed either
    /// by `<name>` (looked up, marker included, as a special token in the
    /// trie) or by `[1234]` (a literal token id, validated against the vocab
    /// size). Returns the tokens plus the number of "fixed" leading tokens,
    /// i.e. the result length as of the last marker-injected token.
    fn tokenize_bytes_marker(&self, s: &[u8]) -> (Vec<TokenId>, usize) {
        let mut idx = 0;
        let ff = TokTrie::SPECIAL_TOKEN_MARKER;
        let mut result = Vec::new();
        let trie = self.tok_trie();
        let mut num_fixed_tokens = 0;
        while idx < s.len() {
            // Tokenize the run of ordinary bytes up to the next marker
            // (or to the end when no marker remains).
            let normal_len = s[idx..]
                .iter()
                .position(|&x| x == ff)
                .unwrap_or(s.len() - idx);
            if normal_len != 0 {
                result.extend_from_slice(&self.tokenize_bytes(&s[idx..idx + normal_len]));
                idx += normal_len;
            }
            // Step over the marker byte (or past the end when none was found).
            idx += 1; if idx + 2 < s.len() && s[idx] == '<' as u8 {
                // `<name>` form: scan a bounded window for the closing '>'.
                let spec_len = s[idx..std::cmp::min(s.len(), idx + 100)]
                    .iter()
                    .position(|&x| x == '>' as u8);
                if let Some(mut spec_len) = spec_len {
                    spec_len += 1;
                    // Include the 0xff marker (at idx - 1) in the lookup key,
                    // matching how special tokens are stored in the trie.
                    let spec_token = &s[idx - 1..idx + spec_len];
                    if let Some(id) = trie.token_id_at_bytes(spec_token) {
                        result.push(id);
                        num_fixed_tokens = result.len();
                        idx += spec_len;
                    }
                    // On lookup failure idx is not advanced, so the `<...>`
                    // text is re-tokenized as ordinary bytes next iteration.
                }
            } else if idx < s.len() {
                // `[1234]` form: literal token id.
                if let Some((n_bytes, tok_id)) = parse_numeric_token(&s[idx..]) {
                    if tok_id < trie.vocab_size() as u32 {
                        result.push(tok_id);
                        num_fixed_tokens = result.len();
                        idx += n_bytes;
                    }
                }
            }
        }

        (result, num_fixed_tokens)
    }

    /// Tokenize a string (no marker processing).
    fn tokenize(&self, s: &str) -> Vec<TokenId> {
        self.tokenize_bytes(s.as_bytes())
    }

    /// Tokenize a string allowing special tokens; the default ignores them.
    fn tokenize_special(&self, s: &str) -> Vec<TokenId> {
        self.tokenize(s)
    }

    /// End-of-sequence token of the underlying trie.
    fn eos_token(&self) -> TokenId {
        self.tok_trie().eos_token()
    }

    /// Whether `tokenize_bytes` produces the tokenizer's canonical output.
    fn tokenize_is_canonical(&self) -> bool {
        true
    }
}
192
/// Shared, thread-safe handle to a tokenizer environment.
pub type TokEnv = Arc<dyn TokenizerEnv + Sync + 'static>;
194
/// Wrapper that overrides the trie of an existing [`TokEnv`] (e.g. one with
/// a different EOS token) while delegating tokenization to the base env.
pub struct TokEnvWithTrie {
    base_env: TokEnv,
    tok_trie: TokTrie,
}

impl TokEnvWithTrie {
    /// Wrap `base_env`, substituting `tok_trie` as its trie.
    pub fn new(base_env: TokEnv, tok_trie: TokTrie) -> Self {
        Self { base_env, tok_trie }
    }
}
205
impl TokenizerEnv for TokEnvWithTrie {
    fn tok_trie(&self) -> &TokTrie {
        // Report the overriding trie, not the base env's.
        &self.tok_trie
    }

    fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId> {
        // Tokenization itself is still performed by the wrapped env.
        self.base_env.tokenize_bytes(s)
    }
}
215
/// Flattened byte trie over a tokenizer vocabulary.
#[derive(Clone)]
pub struct TokTrie {
    info: TokRxInfo,
    // Per token id: packed (byte_offset << LEN_BITS) | byte_len into `token_data`.
    token_offsets: Vec<u32>,
    // Concatenated bytes of all tokens.
    token_data: Vec<u8>,
    // Trie nodes in DFS order; each node's subtree is a contiguous range.
    nodes: Vec<TrieNode>,
    // Length in bytes of the longest token.
    max_token_len: usize,
}
224
/// Packed trie node: `bits` holds the token id (high 24 bits) and the edge
/// byte (low 8 bits); `bits2` holds the subtree size (high 24 bits) and the
/// backtrack parent count (low 8 bits).
#[derive(Clone, Copy, Zeroable, Pod)]
#[repr(C)]
pub struct TrieNode {
    bits: u32,
    bits2: u32,
}

/// Sentinel meaning "no token" in public APIs (all bits set).
pub const INVALID_TOKEN: TokenId = 0xffff_ffff;

/// Sentinel stored in a node's 24-bit token field when no token ends there.
const NO_TOKEN: u32 = 0xffffff;
236
237impl TrieNode {
238 fn new(byte: u8, token_id: u32, num_parents: u8) -> TrieNode {
239 TrieNode {
240 bits: (token_id << 8) | byte as u32,
241 bits2: num_parents as u32,
242 }
243 }
244
245 #[inline(always)]
246 pub fn byte(&self) -> u8 {
247 (self.bits & 0xff) as u8
248 }
249
250 #[inline(always)]
251 pub fn subtree_size(&self) -> usize {
252 (self.bits2 >> 8) as usize
253 }
254
255 #[inline(always)]
256 pub fn num_parents(&self) -> usize {
257 (self.bits2 & 0xff) as usize
258 }
259
260 #[inline(always)]
261 pub fn token_id(&self) -> Option<u32> {
262 let r = self.bits >> 8;
263 if r == NO_TOKEN {
264 None
265 } else {
266 Some(r)
267 }
268 }
269}
270
/// Bits of a token-offset descriptor used to store the token's byte length.
const LEN_BITS: u32 = 8;
273
274impl TokTrie {
    /// Byte prefix that marks special tokens in token byte strings.
    pub const SPECIAL_TOKEN_MARKER: u8 = 0xff;

    /// Build a trie from the vocabulary `words` (indexed by token id).
    ///
    /// Empty words are not inserted into the trie but still get an entry in
    /// the offset table. Panics if `info.vocab_size` doesn't match
    /// `words.len()`, or if a word is too long / total data too large for
    /// the packed offset encoding.
    pub fn from(info: &TokRxInfo, words: &Vec<Vec<u8>>) -> Self {
        let mut trie = TrieHash::new(0xff);
        let mut token_offsets = Vec::new();
        let mut token_data = Vec::new();
        assert!(info.vocab_size == words.len() as u32);
        let mut max_token_len = 0;
        for (idx, word) in words.iter().enumerate() {
            if word.len() > 0 {
                trie.insert(word, idx as u32);
                max_token_len = std::cmp::max(max_token_len, word.len());
            }
            // Pack (offset, len) into one u32: the low LEN_BITS bits hold
            // the length, the rest the byte offset into `token_data`.
            assert!(word.len() < (1 << LEN_BITS));
            assert!(token_data.len() < (1 << (32 - LEN_BITS)));
            let desc = (word.len() as u32) | ((token_data.len() as u32) << LEN_BITS);
            token_offsets.push(desc);
            token_data.extend_from_slice(word);
        }
        let mut nodes = Vec::new();
        trie.serialize(&mut nodes, 0);
        let r = TokTrie {
            info: info.clone(),
            token_offsets,
            token_data,
            nodes,
            max_token_len,
        };
        // Debug sanity check of trie invariants and offset table.
        r.validate();
        r
    }
307
308 pub fn with_eos_token(&self, eos_token: TokenId) -> Self {
309 self.with_info(TokRxInfo {
310 tok_eos: eos_token,
311 ..self.info.clone()
312 })
313 }
314
315 pub fn with_info(&self, info: TokRxInfo) -> Self {
316 let mut r = self.clone();
317 r.info = info.clone();
318 r
319 }
320
321 pub fn build_chat_mode_trie(&self) -> Self {
322 self.with_eos_token(self.info.tok_end_of_turn.unwrap_or(self.info.tok_eos))
323 }
324
    /// Index of node `n` within `self.nodes`, recovered from its address.
    /// `n` must be a reference into this trie's `nodes` vector.
    fn node_offset(&self, n: &TrieNode) -> usize {
        let off = (n as *const _ as usize - self.root() as *const _ as usize)
            / std::mem::size_of::<TrieNode>();
        assert!(off < self.nodes.len());
        off
    }

    /// Index one past `n`'s subtree, i.e. its next sibling in DFS order.
    fn next_node(&self, n: &TrieNode) -> usize {
        return self.node_offset(n) + n.subtree_size();
    }

    /// Tokenizer metadata for this trie.
    pub fn info(&self) -> &TokRxInfo {
        &self.info
    }

    /// End-of-sequence token id.
    pub fn eos_token(&self) -> TokenId {
        self.info.tok_eos
    }

    /// Number of tokens in the vocabulary.
    pub fn vocab_size(&self) -> usize {
        self.info.vocab_size as usize
    }
347
    /// Allocate an empty token set sized for this vocabulary, with one extra
    /// slot for the out-of-vocabulary sentinel used by `add_bias`.
    pub fn alloc_token_set(&self) -> SimpleVob {
        SimpleVob::alloc_with_capacity(self.vocab_size(), self.vocab_size() + 1)
    }

    /// Token set containing exactly `tok`.
    pub fn singleton_token_set(&self, tok: TokenId) -> SimpleVob {
        let mut r = self.alloc_token_set();
        r.allow_token(tok);
        r
    }
357
    /// Human-readable summary of a token set: count, then up to 50 example
    /// tokens. When the set is nearly full, the complement is printed
    /// instead, prefixed with "ALL EXCEPT".
    pub fn token_set_dbg(&self, ts: &SimpleVob) -> String {
        let max_examples = 50;

        let ts_neg = ts.negated();
        // Print the negated set if it is at least 10x smaller.
        let use_neg = ts_neg.num_set() * 10 < ts.num_set();
        let ts1 = if use_neg { &ts_neg } else { &ts };
        let num_set = ts1.num_set();
        let max_tok = std::cmp::min(max_examples, num_set);
        let mut token_names = Vec::new();
        // EOS is listed first, by name.
        if self.info.tok_eos != INVALID_TOKEN && ts1.is_allowed(self.info.tok_eos) {
            token_names.push("EOS".to_string());
        }
        for idx in 0..self.vocab_size() {
            if idx as TokenId != self.info.tok_eos && ts1.is_allowed(idx as TokenId) {
                token_names.push(self.token_dbg(idx as TokenId));
                if token_names.len() >= max_tok {
                    break;
                }
            }
        }
        // Indicate truncation when not all members were printed.
        if token_names.len() < num_set {
            token_names.push("...".to_string());
        }
        format!(
            "TokenSet: {}/{}; {}{}",
            ts.num_set(),
            self.vocab_size(),
            if use_neg { "ALL EXCEPT " } else { "" },
            token_names.join(" ")
        )
    }
390
    /// Allocate a zeroed logits buffer (vocab size plus the sentinel slot).
    pub fn alloc_logits(&self) -> Vec<f32> {
        vec![0.0; self.vocab_size() + 1]
    }

    /// Debug-print tokens without the surrounding quotes (for test traces).
    pub fn test_trace_tokens(&self, toks: &[u32]) -> String {
        self.tokens_dbg_ext(toks, false)
    }

    /// Maximum number of tokens rendered by the debug helpers.
    pub const MAX_DBG_TOKENS: usize = 200;

    /// Debug-print tokens, quoted.
    pub fn tokens_dbg(&self, toks: &[u32]) -> String {
        self.tokens_dbg_ext(toks, true)
    }
404
405 fn tokens_dbg_ext(&self, toks: &[u32], quote: bool) -> String {
406 let (limited, toks) = if toks.len() > Self::MAX_DBG_TOKENS {
407 (true, &toks[0..Self::MAX_DBG_TOKENS])
408 } else {
409 (false, toks)
410 };
411
412 let mut joined = toks
413 .iter()
414 .map(|t| self.token_dbg_ext(*t, false))
415 .collect::<Vec<_>>()
416 .join("‧");
417
418 if limited {
419 joined.push_str("…");
420 }
421
422 if quote {
423 format!("⟦{}⟧", joined)
424 } else {
425 joined
426 }
427 }
428
    /// Debug rendering of a single token (quoted form).
    pub fn token_dbg(&self, idx: u32) -> String {
        self.token_dbg_ext(idx, true)
    }

    /// Debug rendering of a single token.
    ///
    /// EOS renders as ≺EOS≻, out-of-range ids as ≺OOB[id]≻, special tokens
    /// by their name, empty tokens as ≺EMPTY[id]≻, invalid UTF-8 as a hex
    /// dump, and ordinary tokens as (optionally ⟨…⟩-quoted) escaped text.
    fn token_dbg_ext(&self, idx: u32, quote: bool) -> String {
        if idx == self.info.tok_eos {
            "≺EOS≻".to_string()
        } else if idx as usize >= self.vocab_size() {
            format!("≺OOB[{}]≻", idx)
        } else {
            let bytes = self.token(idx);
            if bytes.len() > 1 && bytes[0] == TokTrie::SPECIAL_TOKEN_MARKER {
                // Special token: show its name (the bytes after the marker).
                String::from_utf8_lossy(&bytes[1..]).to_string()
            } else {
                let s = String::from_utf8_lossy(bytes);
                if s.len() == 0 {
                    format!("≺EMPTY[{}]≻", idx)
                } else if !s.contains('\u{fffd}') {
                    // Valid UTF-8: use Debug escaping, un-escape inner
                    // quotes, then strip the surrounding quotes it adds.
                    let mut s = format!("{:?}", s).replace("\\\"", "\"");
                    s.remove(0);
                    s.pop();
                    if quote {
                        format!("⟨{}⟩", s)
                    } else {
                        s
                    }
                } else {
                    // Lossy conversion hit invalid UTF-8: fall back to hex.
                    let bytes = self.token(idx);
                    format!("≺HEX[{}]≻", to_hex_string(bytes))
                }
            }
        }
    }
463
464 pub fn token_str(&self, idx: u32) -> String {
465 String::from_utf8_lossy(self.token(idx)).to_string()
466 }
467
468 pub fn token_len(&self, idx: u32) -> usize {
469 let t = self.token(idx);
470 if t.len() == 0 || t[0] == TokTrie::SPECIAL_TOKEN_MARKER {
471 let mut idx = idx;
472 let mut len = 1;
473 while idx >= 10 {
474 idx /= 10;
475 len += 1;
476 }
477 len + 3
479 } else {
480 t.len()
481 }
482 }
483
    /// Byte string of token `idx`; empty slice for out-of-range ids.
    pub fn token(&self, idx: u32) -> &[u8] {
        if idx >= self.token_offsets.len() as u32 {
            return &[];
        }
        // Unpack the (offset, len) descriptor packed by `from`.
        let off = self.token_offsets[idx as usize];
        let len = off & ((1 << LEN_BITS) - 1);
        let off = (off >> LEN_BITS) as usize;
        &self.token_data[off..(off + len as usize)]
    }
493
494 pub fn decode(&self, tokens: &[TokenId]) -> Vec<u8> {
495 let mut res = Vec::new();
496 res.reserve(tokens.len() * 6 + 32); for &tok in tokens {
498 let t = self.token(tok);
499 if t.len() == 0 {
500 res.extend_from_slice(format!("<[{}]>", tok).as_bytes());
501 } else if t[0] == TokTrie::SPECIAL_TOKEN_MARKER {
502 res.extend_from_slice(&t[1..]);
503 } else {
504 res.extend_from_slice(t);
505 }
506 }
507 res
508 }
509
510 pub fn decode_as_special(&self, tok: TokenId) -> Vec<u8> {
511 let mut res = Vec::new();
512 res.reserve(9);
513 res.push(TokTrie::SPECIAL_TOKEN_MARKER);
514 res.extend_from_slice(format!("[{}]", tok).as_bytes());
515 res
516 }
517
518 pub fn decode_raw(&self, tokens: &[TokenId]) -> Vec<u8> {
519 let mut res = Vec::new();
520 res.reserve(tokens.len() * 6 + 32); for &tok in tokens {
522 let t = self.token(tok);
523 if t.len() == 0 || t[0] == TokTrie::SPECIAL_TOKEN_MARKER {
524 res.push(TokTrie::SPECIAL_TOKEN_MARKER);
525 res.extend_from_slice(format!("[{}]", tok).as_bytes());
526 } else {
527 res.extend_from_slice(t);
528 }
529 }
530 res
531 }
532
533 pub fn decode_str(&self, tokens: &[TokenId]) -> String {
534 String::from_utf8_lossy(&self.decode(tokens)).to_string()
535 }
536
    /// Convert `decode_raw` output into `decode` output: every marker
    /// sequence 0xff + "[id]" is replaced by the decoded bytes of that
    /// token; all other bytes (including stray markers) pass through.
    pub fn decode_raw_to_decode(&self, bytes: &[u8]) -> Vec<u8> {
        let mut res = Vec::new();
        let mut idx = 0;
        while idx < bytes.len() {
            if bytes[idx] == TokTrie::SPECIAL_TOKEN_MARKER {
                if let Some((len, tok)) = parse_numeric_token(&bytes[(idx + 1)..]) {
                    res.extend_from_slice(&self.decode(&[tok]));
                    // Skip the marker plus the "[id]" text.
                    idx += len + 1;
                } else {
                    res.push(bytes[idx]);
                    idx += 1;
                }
            } else {
                res.push(bytes[idx]);
                idx += 1;
            }
        }
        res
    }
556
557 pub fn is_special_token(&self, tok: TokenId) -> bool {
558 let bytes = self.token(tok);
559 bytes.len() > 0 && bytes[0] == TokTrie::SPECIAL_TOKEN_MARKER
560 }
561
562 pub fn get_special_token(&self, name: &str) -> Option<TokenId> {
563 self.child_at_byte(self.root(), TokTrie::SPECIAL_TOKEN_MARKER)
564 .and_then(|n| {
565 self.child_at_bytes(n, name.as_bytes())
566 .and_then(|n| n.token_id())
567 })
568 }
569
    /// Collect the ids of all special tokens (those stored under the
    /// marker-prefix node).
    ///
    /// Panics if the trie has no marker prefix node.
    pub fn get_special_tokens(&self) -> Vec<TokenId> {
        let mut res = Vec::new();
        let pref_node = self
            .child_at_byte(self.root(), TokTrie::SPECIAL_TOKEN_MARKER)
            .expect("missing special token prefix");
        let mut stack = vec![pref_node];
        while let Some(n) = stack.pop() {
            for c in self.node_children(n) {
                if let Some(tok) = c.token_id() {
                    res.push(tok);
                    // Bound the result size; `break` only exits this level.
                    if res.len() > Self::MAX_DBG_TOKENS + 1 {
                        break;
                    }
                }
                stack.push(c);
            }
        }
        // NOTE(review): drops the first collected entry -- presumably a
        // sentinel token directly under the marker prefix; panics when
        // `res` is empty. TODO confirm this invariant.
        res.remove(0);
        res
    }
590
    /// Greedy (longest-match) tokenization of `bytes` using the trie.
    ///
    /// At each position takes the longest token matching the remaining
    /// input, then restarts right after it. Panics (via `unwrap`) if some
    /// prefix of `bytes` matches no token at all -- callers must pass byte
    /// sequences coverable by the vocabulary.
    pub fn greedy_tokenize(&self, bytes: &[u8]) -> Vec<TokenId> {
        let mut r = Vec::new();
        if bytes.len() == 0 {
            return r;
        }

        let mut n = self.root();
        let mut last_tok = None; // longest token matched so far
        let mut last_idx = 0; // index where `last_tok` ended
        let mut idx = 0;
        while idx < bytes.len() {
            match self.child_at_byte(n, bytes[idx]) {
                Some(c) => {
                    if let Some(tok) = c.token_id() {
                        last_tok = Some(tok);
                        last_idx = idx;
                    }
                    n = c;
                }
                None => {
                    // Dead end: emit the longest match and restart the walk
                    // from the byte right after it.
                    r.push(last_tok.unwrap());
                    idx = last_idx;
                    n = self.root();
                }
            }
            idx = idx + 1;
        }
        // Emit the final pending match.
        r.push(last_tok.unwrap());
        r
    }
621
622 pub fn tokenize_with_greedy_fallback(
623 &self,
624 bytes: &[u8],
625 str_tokenize: impl Fn(&str) -> Vec<TokenId>,
626 ) -> Vec<TokenId> {
627 match str::from_utf8(bytes) {
628 Ok(s) => {
629 str_tokenize(s)
631 }
632 Err(_) => {
633 let mut res = vec![];
634 for chunk in bytes.utf8_chunks() {
635 if !chunk.valid().is_empty() {
636 res.extend(str_tokenize(chunk.valid()));
637 }
638 if !chunk.invalid().is_empty() {
639 res.extend(self.greedy_tokenize(chunk.invalid()));
640 }
641 }
642 res
643 }
644 }
645 }
646
647 pub fn has_extensions(&self, bytes: &[u8]) -> bool {
648 match self.child_at_bytes(self.root(), bytes) {
649 None => false,
650 Some(n) => n.subtree_size() > 1,
651 }
652 }
653
654 pub fn token_id(&self, bytes: &[u8]) -> Option<TokenId> {
655 let (tok, len) = self.prefix_token_id(bytes);
656 if len == bytes.len() {
658 Some(tok)
659 } else {
660 None
661 }
662 }
663
664 pub fn prefix_token_id(&self, bytes: &[u8]) -> (TokenId, usize) {
665 assert!(bytes.len() > 0);
666 let mut last = (0, 0);
667 let mut n = self.root();
668 for (idx, byte) in bytes.iter().enumerate() {
669 n = match self.child_at_byte(n, *byte) {
670 Some(n) => n,
671 None => break,
672 };
673 if let Some(tok) = n.token_id() {
674 last = (tok, idx + 1);
675 }
676 }
677 return last;
678 }
679
    /// Length in bytes of the longest token.
    pub fn max_token_len(&self) -> usize {
        self.max_token_len
    }

    /// Recursively check subtree invariants: token ids are in range and
    /// unique, and every subtree ends at or before its parent's end (`ep`).
    fn validate_node(&self, n: &TrieNode, ep: usize, used: &mut [bool]) {
        if let Some(tok) = n.token_id() {
            assert!(tok < self.info.vocab_size);
            assert!(!used[tok as usize]);
            used[tok as usize] = true;
        }
        let endp = self.next_node(n);
        assert!(endp <= ep);
        for child in self.node_children(n) {
            self.validate_node(child, endp, used);
        }
    }

    /// Debug validation of the whole trie plus the token offset table.
    fn validate(&self) {
        self.validate_node(
            self.root(),
            self.next_node(self.root()),
            &mut vec![false; self.info.vocab_size as usize],
        );
        // Touch every token so malformed packed offsets panic here.
        for idx in 0..self.info.vocab_size {
            let _ = self.token(idx);
        }
    }
707
    /// Root node of the trie (serialized at offset 0).
    pub fn root(&self) -> &TrieNode {
        &self.nodes[0]
    }

    /// Assert this trie is consistent with the raw vocabulary `tokens`.
    ///
    /// When two tokens share the same byte string, the trie stores only one
    /// id on the shared node; the other id must appear on a duplicated
    /// single-node sibling (the layout produced by `TrieHash::insert`).
    pub fn check_against(&self, tokens: &Vec<Vec<u8>>) {
        let vocab_size = tokens.len();
        for idx in 0..vocab_size {
            let bytes = &tokens[idx];
            let tid = idx as TokenId;
            assert!(bytes == self.token(tid));
            let root = self.root();
            if bytes.len() > 0 {
                let tid2 = self
                    .child_at_bytes(root, &bytes)
                    .unwrap()
                    .token_id()
                    .unwrap();
                if tid != tid2 {
                    // Duplicate byte string: look for `tid` on a leaf
                    // sibling under the same parent, same last byte.
                    let par = self
                        .child_at_bytes(root, &bytes[0..bytes.len() - 1])
                        .unwrap();
                    let has_it = self.node_children(par).any(|n| {
                        n.subtree_size() == 1
                            && n.byte() == bytes[bytes.len() - 1]
                            && n.token_id() == Some(tid)
                    });
                    assert!(has_it);
                }
            }
        }
    }
739
740 pub fn child_at_byte<'a>(&'a self, n: &'a TrieNode, byte: u8) -> Option<&'a TrieNode> {
741 for child in self.node_children(n) {
742 if child.byte() == byte {
743 return Some(child);
744 }
745 }
746 None
747 }
748
749 pub fn all_subtokens(&self, bytes: &[u8]) -> Vec<TokenId> {
750 let mut r = Vec::new();
751 for i in 0..bytes.len() {
752 let mut n = self.root();
753 for j in i..bytes.len() {
754 n = match self.child_at_byte(n, bytes[j]) {
755 Some(n) => n,
756 None => break,
757 };
758 if let Some(tok) = n.token_id() {
759 r.push(tok);
760 }
761 }
762 }
763 r
764 }
765
766 pub fn node_children(&self, n: &TrieNode) -> NodeChildren {
767 let off = self.node_offset(n);
768 NodeChildren {
769 trie: self,
770 current_offset: off + 1,
771 end_offset: off + n.subtree_size(),
772 }
773 }
774
775 pub fn child_at_bytes<'a>(&'a self, mut n: &'a TrieNode, bytes: &[u8]) -> Option<&'a TrieNode> {
776 for &byte in bytes {
777 n = match self.child_at_byte(n, byte) {
778 Some(n) => n,
779 None => return None,
780 }
781 }
782 Some(n)
783 }
784
785 pub fn token_id_at_bytes(&self, bytes: &[u8]) -> Option<TokenId> {
786 self.child_at_bytes(self.root(), bytes)
787 .and_then(|n| n.token_id())
788 }
789
    /// Determine how many trailing tokens to chop off so that the recognizer
    /// could still accept an extension of the chopped byte suffix.
    ///
    /// Considers the byte suffix of up to the last 4 tokens (capped at
    /// `max_token_len` bytes) and finds the longest suffix with valid
    /// extensions per `r`. Returns `(num_tokens_to_chop, num_bytes)`, or
    /// `(0, 0)` when no suffix has valid extensions.
    pub fn chop_tokens(&self, r: &mut impl Recognizer, tokens: &[TokenId]) -> (usize, usize) {
        let max_token_lookback = 4;
        let suff_bytes =
            self.decode_raw(&tokens[tokens.len().saturating_sub(max_token_lookback)..]);
        let suff_bytes = &suff_bytes[suff_bytes.len().saturating_sub(self.max_token_len())..];

        // Try suffixes from longest to shortest.
        for idx in 0..suff_bytes.len() {
            let suff = &suff_bytes[idx..];
            if self.has_valid_extensions(r, suff) {
                let chop_bytes = suff.len();
                assert!(chop_bytes > 0);
                // Count whole tokens from the end until they cover the
                // suffix; `token_len` matches the `decode_raw` byte count.
                let mut curr_len = 0;
                for chop_idx in 1..=tokens.len() {
                    curr_len += self.token_len(tokens[tokens.len() - chop_idx]);
                    if curr_len >= chop_bytes {
                        return (chop_idx, curr_len);
                    }
                }
                unreachable!();
            }
        }

        (0, 0)
    }
816
    /// Check whether any token strictly extending `start` would be accepted
    /// by the recognizer `r`.
    ///
    /// Walks the subtree below `start` in serialized (DFS) order, pushing
    /// bytes into `r` and skipping whole subtrees on rejection; stops at the
    /// first accepted node carrying a token id.
    #[inline(never)]
    pub fn has_valid_extensions(&self, r: &mut impl Recognizer, start: &[u8]) -> bool {
        let n = self.child_at_bytes(self.root(), start);
        if n.is_none() {
            return false;
        }
        let n = n.unwrap();
        r.trie_started("has_valid_extensions");
        let off = self.node_offset(n);
        let mut p = off + 1; // first descendant in serialized order
        let endp = off + n.subtree_size();
        let mut ok = false;
        let mut next_pop = 0; // bytes to pop before visiting the next node
        while p < endp {
            r.pop_bytes(next_pop);
            let n = &self.nodes[p];
            let b = n.byte();
            if r.try_push_byte(b) {
                if n.token_id().is_some() {
                    ok = true;
                    break;
                }
                // Leaf: backtrack by its parent count; inner node: descend.
                next_pop = if n.subtree_size() == 1 {
                    n.num_parents()
                } else {
                    0
                };
                p += 1;
            } else {
                // Rejected byte: skip this node's entire subtree.
                p += n.subtree_size();
                next_pop = n.num_parents() - 1;
            }
        }
        r.trie_finished();
        ok
    }
854
855 pub fn all_prefixes(&self, bytes: &[u8]) -> Vec<TokenId> {
856 let mut r = Vec::new();
857 let mut n = self.root();
858 for &b in bytes {
859 if let Some(c) = self.child_at_byte(n, b) {
860 n = c;
861 if let Some(tok) = n.token_id() {
862 r.push(tok);
863 }
864 } else {
865 break;
866 }
867 }
868 r
869 }
870
    /// Compute the set of tokens allowed by recognizer `r` after the fixed
    /// byte prefix `start`, OR-ing results into `toks`.
    ///
    /// When `start` is non-empty, tokens that are prefixes of `start`
    /// itself are also allowed (via a recursive pass with a
    /// `FixedRecognizer` over `start`).
    pub fn add_bias(&self, r: &mut impl Recognizer, toks: &mut SimpleVob, start: &[u8]) {
        if start.len() > 0 {
            // Allow tokens that consume only part of the fixed prefix.
            let mut fixed = FixedRecognizer::new(start);
            self.add_bias(&mut fixed, toks, &[]);
        }

        let n = self.child_at_bytes(self.root(), start);
        if n.is_none() {
            return;
        }
        let n = n.unwrap();
        r.trie_started("add_bias");
        let (next_pop, nodes_walked) = self.add_bias_inner(r, toks, n);
        if start.len() == 0 {
            // Pop bytes left over from the walk; only done when the walk
            // started at the root (empty prefix).
            r.pop_bytes(next_pop);
        }
        r.trie_finished();
        r.save_stats(nodes_walked);
        // Clear the sentinel slot used by `add_bias_inner` for tokenless
        // nodes -- it must never remain allowed.
        let defl_tok = self.vocab_size() as u32;
        toks.disallow_token(defl_tok);
    }
895
    /// Hot loop of `add_bias`: DFS over the subtree of `n`, allowing in
    /// `toks` every token whose bytes the recognizer accepts.
    ///
    /// Nodes without a token id map to the sentinel `defl_tok` slot (cleared
    /// afterwards by `add_bias`), avoiding a branch per node.
    /// Returns `(bytes_left_to_pop, nodes_walked)`.
    #[inline(never)]
    fn add_bias_inner(
        &self,
        r: &mut impl Recognizer,
        toks: &mut SimpleVob,
        n: &TrieNode,
    ) -> (usize, usize) {
        let defl_tok = self.vocab_size() as u32;
        let off = self.node_offset(n);
        let total_nodes = n.subtree_size();
        let mut p = off + 1; // first descendant in serialized order
        let endp = off + total_nodes;
        let mut next_pop = 0; // bytes to pop before the next node
        let mut num_skip = 0; // nodes skipped due to rejected subtrees
        while p < endp {
            r.pop_bytes(next_pop);
            let n = &self.nodes[p];
            let b = n.byte();
            if r.try_push_byte(b) {
                toks.allow_token(n.token_id().unwrap_or(defl_tok));
                // Leaf: backtrack by its parent count; inner node: descend.
                next_pop = if n.subtree_size() == 1 {
                    n.num_parents()
                } else {
                    0
                };
                p += 1;
            } else {
                // Rejected byte: skip the whole subtree.
                let subtree_size = n.subtree_size();
                p += subtree_size;
                num_skip += subtree_size - 1;
                next_pop = n.num_parents() - 1;
            }
        }
        (next_pop, total_nodes - num_skip)
    }
932
933 pub fn all_tokens(&self) -> Vec<Vec<u8>> {
934 (0..self.vocab_size())
935 .map(|idx| self.token(idx as u32).to_vec())
936 .collect()
937 }
938
    /// All `(token_id, bytes)` pairs in serialized (byte-lexicographic DFS)
    /// order, reconstructed by walking the flat node array while keeping an
    /// explicit stack of path bytes.
    pub fn sorted_tokens(&self) -> Vec<(u32, Vec<u8>)> {
        let mut res = vec![];
        let n = self.root();
        let off = self.node_offset(n);
        let mut p = off + 1;
        let endp = off + n.subtree_size();
        let mut next_pop = 0; // bytes to drop before visiting the next node
        let mut bytes = vec![]; // path from the root to the current node
        while p < endp {
            bytes.drain(bytes.len() - next_pop..);
            let n = &self.nodes[p];
            let b = n.byte();
            bytes.push(b);
            if let Some(t) = n.token_id() {
                res.push((t, bytes.clone()));
            }
            // Leaf: backtrack by its parent count; inner node: descend.
            next_pop = if n.subtree_size() == 1 {
                n.num_parents()
            } else {
                0
            };
            p += 1;
        }
        res
    }
964
965 fn count_until_depth(&self, depth: usize) -> (usize, usize) {
966 let mut count = 0;
967 let mut num_tokens = 0;
968 let mut stack = vec![(self.root(), 0)];
969 while let Some((n, d)) = stack.pop() {
970 if d == depth {
971 continue;
972 } else {
973 for c in self.node_children(n) {
974 count += 1;
975 if c.token_id().is_some() {
976 num_tokens += 1;
977 }
978 stack.push((c, d + 1));
979 }
980 }
981 }
982 (count, num_tokens)
983 }
984
    /// Human-readable statistics about the trie (node and token-node counts,
    /// token byte total, max token length). The `if false` sections are
    /// debugging probes left for manual enabling.
    pub fn trie_stats(&self) -> String {
        let mut nodes_histogram = vec![0; 256];

        let mut token_nodes = 0;

        let n = self.root();
        let off = self.node_offset(n);
        let mut p = off + 1;
        let endp = off + n.subtree_size();
        while p < endp {
            let n = &self.nodes[p];

            if n.token_id().is_some() {
                token_nodes += 1;
            }

            // Count direct children by hopping over each child's subtree.
            let last_ch = self.next_node(n);
            let mut ch_p = p + 1;
            let mut num_children = 0;

            while ch_p < last_ch {
                let ch = &self.nodes[ch_p];
                ch_p += ch.subtree_size();
                num_children += 1;
            }

            // Bucket child counts; 9 is the catch-all for "9 or more".
            nodes_histogram[std::cmp::min(9, num_children)] += 1;

            p += 1;
        }

        let mut histogram = String::new();

        // Debug-only: render the child-count histogram.
        if false {
            for (idx, num) in nodes_histogram.iter().enumerate() {
                if *num > 0 {
                    if !histogram.is_empty() {
                        histogram.push_str(", ");
                    }
                    histogram.push_str(&format!("{}:{}", idx, num));
                }
            }
        }

        // Debug-only: per-root-child fanout and subtree sizes.
        if false {
            for n in self.node_children(self.root()) {
                histogram.push_str(&format!(
                    "\n{} => {} {}",
                    n.byte(),
                    self.node_children(n).count(),
                    n.subtree_size()
                ));
            }
        }

        // Debug-only: cumulative node/token counts per depth.
        if false {
            for depth in 0..30 {
                let (count, num_tokens) = self.count_until_depth(depth);
                histogram.push_str(&format!(
                    "\ndepth {}: {} nodes {} tokens",
                    depth, count, num_tokens
                ));
            }
        }

        if histogram.len() > 0 {
            histogram = format!("\n{}", histogram);
        }

        format!(
            "{}{} nodes, {} token nodes, {} token bytes, {} max len",
            histogram,
            self.nodes.len(),
            token_nodes,
            self.token_data.len(),
            self.max_token_len,
        )
    }
1063}
1064
/// Iterator over the direct children of a trie node, stepping through the
/// flat serialized node array by subtree size.
pub struct NodeChildren<'a> {
    trie: &'a TokTrie,
    // Offset of the next child to yield.
    current_offset: usize,
    // One past the last node of the parent's subtree.
    end_offset: usize,
}

impl<'a> Iterator for NodeChildren<'a> {
    type Item = &'a TrieNode;

    fn next(&mut self) -> Option<Self::Item> {
        if self.current_offset < self.end_offset {
            let node = &self.trie.nodes[self.current_offset];
            // Jump over the child's subtree to land on its next sibling.
            self.current_offset += node.subtree_size();
            Some(node)
        } else {
            None
        }
    }
}
1084
/// Mutable trie used only during construction; flattened into the packed
/// `TrieNode` array by `serialize`.
struct TrieHash {
    token_id: u32, // NO_TOKEN when no token ends at this node
    byte: u8,
    children: Vec<TrieHash>,
}

impl TrieHash {
    fn new(byte: u8) -> TrieHash {
        TrieHash {
            token_id: NO_TOKEN,
            byte,
            children: Vec::new(),
        }
    }
    /// Insert the remaining suffix `word` with the given token id.
    ///
    /// Duplicate byte strings are handled by falling through the matching
    /// child (empty branch below) and creating a second single-node child
    /// with the same byte -- the layout `TokTrie::check_against` expects.
    fn insert(&mut self, word: &[u8], token_id: u32) {
        if word.len() == 0 {
            assert!(self.token_id == NO_TOKEN);
            self.token_id = token_id;
        } else {
            for ch in &mut self.children {
                if ch.byte == word[0] {
                    if word.len() == 1 && ch.token_id != NO_TOKEN {
                        // Duplicate token bytes: don't overwrite the
                        // existing id; fall through to add a sibling.
                    } else {
                        ch.insert(&word[1..], token_id);
                        return;
                    }
                }
            }

            let mut ch = TrieHash::new(word[0]);
            ch.insert(&word[1..], token_id);
            self.children.push(ch);
        }
    }
    /// Flatten into `data` in DFS order, sorting children by byte so the
    /// serialized order is byte-lexicographic.
    ///
    /// `num_parents` is the backtrack count stored on this node: every
    /// child gets 1 except the last, which accumulates the parent's count
    /// plus one. The subtree size is patched into `bits2` once known.
    fn serialize(&mut self, data: &mut Vec<TrieNode>, num_parents: u8) {
        let idx = data.len();
        let mut num_ch = self.children.len();
        data.push(TrieNode::new(self.byte, self.token_id, num_parents));
        self.children.sort_by_key(|e| e.byte);
        for entry in &mut self.children {
            num_ch -= 1;
            entry.serialize(data, if num_ch == 0 { num_parents + 1 } else { 1 });
        }
        data[idx].bits2 |= ((data.len() - idx) as u32) << 8;
    }
}
1154
/// Recognizer accepting exactly one fixed byte string; used by
/// `TokTrie::add_bias` to allow tokens that are prefixes of a forced text.
struct FixedRecognizer {
    bytes: Vec<u8>,
    // Number of bytes currently matched (the stack depth).
    bytes_ptr: usize,
}

impl FixedRecognizer {
    fn new(bytes: &[u8]) -> FixedRecognizer {
        FixedRecognizer {
            bytes: bytes.to_vec(),
            bytes_ptr: 0,
        }
    }
}
1168
1169impl Recognizer for FixedRecognizer {
1170 fn collapse(&mut self) {}
1171 fn trie_finished(&mut self) {}
1172
1173 fn pop_bytes(&mut self, num: usize) {
1174 self.bytes_ptr -= num;
1175 }
1176
1177 fn try_push_byte(&mut self, byte: u8) -> bool {
1178 if self.bytes_ptr < self.bytes.len() && self.bytes[self.bytes_ptr] == byte {
1179 self.bytes_ptr += 1;
1180 true
1181 } else {
1182 false
1183 }
1184 }
1185}
1186
/// `TokenizerEnv` backed purely by greedy trie tokenization -- cheap, but
/// not guaranteed to match the real tokenizer's output.
pub struct ApproximateTokEnv {
    trie: TokTrie,
}

impl ApproximateTokEnv {
    pub fn new(trie: TokTrie) -> Self {
        Self { trie }
    }
}
1196
impl TokenizerEnv for ApproximateTokEnv {
    fn tok_trie(&self) -> &TokTrie {
        &self.trie
    }

    fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId> {
        // Greedy longest-match over the trie.
        self.trie.greedy_tokenize(s)
    }

    // Greedy tokenization is only an approximation of the real tokenizer,
    // so its output is explicitly flagged as non-canonical.
    fn tokenize_is_canonical(&self) -> bool {
        false
    }
}