// toktrie/toktree.rs

1// use 8:24 encoding - num_ch:tok_id (ch_byte:ch_off)* - 8 bytes per tree node
2// special case num_ch=0xff -> num_ch=0x100
3
4use core::str;
5
6use bytemuck_derive::{Pod, Zeroable};
7
8use crate::{bytes::to_hex_string, tokenv::parse_numeric_token, SimpleVob};
9
10/// Numeric identifier for a single token in a tokenizer's vocabulary.
11pub type TokenId = u32;
12
/// Binary-serializable (`Pod`) subset of tokenizer metadata; see [`TokRxInfo`]
/// for the full in-memory form.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroable, Pod)]
#[repr(C)]
pub struct BinTokRxInfo {
    /// Number of tokens in the vocabulary.
    pub vocab_size: u32,
    /// End-of-sequence token id.
    pub tok_eos: TokenId,
}
19
/// Tokenizer metadata: vocabulary size plus the ids of well-known special
/// tokens (only `tok_eos` is required; the rest are optional).
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct TokRxInfo {
    /// Number of tokens in the vocabulary.
    pub vocab_size: u32,
    /// End-of-sequence token id.
    pub tok_eos: TokenId,
    /// Beginning-of-sequence token id, if the tokenizer has one.
    pub tok_bos: Option<TokenId>,
    /// Padding token id, if any.
    pub tok_pad: Option<TokenId>,
    /// Unknown-token id, if any.
    pub tok_unk: Option<TokenId>,
    /// End-of-turn token id (chat models), if any.
    pub tok_end_of_turn: Option<TokenId>,
}
29
30impl TokRxInfo {
31    pub fn new(vocab_size: u32, tok_eos: TokenId) -> Self {
32        TokRxInfo {
33            vocab_size,
34            tok_eos,
35            tok_bos: None,
36            tok_pad: None,
37            tok_unk: None,
38            tok_end_of_turn: None,
39        }
40    }
41
42    pub fn from_bin(info: &BinTokRxInfo) -> Self {
43        TokRxInfo {
44            vocab_size: info.vocab_size,
45            tok_eos: info.tok_eos,
46            tok_bos: None,
47            tok_pad: None,
48            tok_unk: None,
49            tok_end_of_turn: None,
50        }
51    }
52
53    pub fn to_bin(&self) -> BinTokRxInfo {
54        BinTokRxInfo {
55            vocab_size: self.vocab_size,
56            tok_eos: self.tok_eos,
57        }
58    }
59}
60
/// Byte-level constraint interface used for trie-based token filtering.
///
/// Implementations maintain a stack of states. The trie walker pushes bytes
/// onto the stack as it descends, queries [`Recognizer::byte_allowed`] or
/// [`Recognizer::try_push_byte`] to test transitions, and pops bytes when
/// backtracking. This lets [`TokTrie`] efficiently compute the set of
/// tokens that satisfy the constraint.
pub trait Recognizer {
    /// Remove the top `num` states from the stack:
    /// for _ in 0..num { stack.pop() }
    fn pop_bytes(&mut self, num: usize);
    /// "Collapse" the stack so that it consists only of its former
    /// top element.
    /// X = stack.top(); stack.empty(); stack.push(X)
    fn collapse(&mut self);
    /// check if stack.top() transitions via byte to a viable state
    fn byte_allowed(&mut self, byte: u8) -> bool {
        // default implementation: probe with try_push_byte and undo on success
        if self.try_push_byte(byte) {
            self.pop_bytes(1);
            true
        } else {
            false
        }
    }
    /// Called when iteration over the trie is finished
    /// Stack has exactly one element then, except when iteration started from non-root node.
    /// In that case, the stack may have more than one element, and trie_finished() needs to pop the excessive elements.
    fn trie_finished(&mut self);
    /// Called when iteration over the trie is started
    fn trie_started(&mut self, _dbg_lbl: &str) {}
    /// This combines `push_byte` and `byte_allowed` into one function for performance.
    fn try_push_byte(&mut self, byte: u8) -> bool;
    /// Check if there are any errors to be reported to the user.
    fn get_error(&mut self) -> Option<String> {
        None
    }
    /// Optional hook: record how many trie nodes were visited in the last walk.
    fn save_stats(&mut self, _nodes_walked: usize) {}
}
98
/// Location of a single token's bytes inside `TokTrie::token_data`.
#[derive(Clone, Copy)]
struct TokDesc {
    // length of the token in bytes
    len: u32,
    // offset of the first byte in `token_data`
    off: u32,
}
104
/// A prefix tree (trie) of every token in a tokenizer's vocabulary.
///
/// The trie maps byte sequences to [`TokenId`]s and supports efficient
/// constrained-decoding queries: given a [`Recognizer`] that accepts or
/// rejects byte sequences, [`TokTrie::add_bias`] walks the trie and
/// returns the set of tokens whose byte representations are accepted.
#[derive(Clone)]
pub struct TokTrie {
    // tokenizer metadata (vocab size, special token ids)
    info: TokRxInfo,
    // (len, off) into `token_data`, indexed by token id
    token_offsets: Vec<TokDesc>,
    // concatenated byte representations of all tokens
    token_data: Vec<u8>,
    // serialized trie in pre-order; nodes[0] is the root
    nodes: Vec<TrieNode>,
    // length in bytes of the longest token
    max_token_len: usize,
    // all token ids treated as end-of-sequence; first is `info.tok_eos`
    eos_tokens: Vec<TokenId>,
}
120
/// One node of the serialized trie, packed into 8 bytes.
#[derive(Clone, Copy, Zeroable, Pod)]
#[repr(C)]
pub struct TrieNode {
    // byte:token — low 8 bits: byte on the edge into this node;
    // upper 24 bits: token id ending here, or NO_TOKEN
    bits: u32,
    // low PARENT_BITS bits: (number of parents - 1);
    // remaining bits: subtree size in nodes, including this node
    bits2: u32,
}
128
/// Sentinel token id meaning "no token" in public APIs.
pub const INVALID_TOKEN: TokenId = 0xffff_ffff;

/// Sentinel stored in the upper 24 bits of `TrieNode::bits` when the node
/// does not terminate a token.
const NO_TOKEN: u32 = 0xffffff;

// PARENT_BITS=10 allows for up to 1024 parents, which is likely enough for tokens up to 2k bytes
// this leaves 32-10 = 22 bits for subtree size, which allows for up to ~2M tokens
// (4M trie nodes)
// GLM4 tokenizer has a token with 1024 spaces - it requires PARENT_BITS >= 9
// Note that because of the ~2M limit, we have ~3 bits left free in 'bits' field
const PARENT_BITS: u32 = 10;
const PARENT_MASK: u32 = (1 << PARENT_BITS) - 1;
140
141impl TrieNode {
142    fn new(byte: u8, token_id: u32, num_parents: usize) -> TrieNode {
143        assert!(num_parents > 0);
144        assert!(num_parents <= (1 << PARENT_BITS) as usize);
145        TrieNode {
146            bits: (token_id << 8) | byte as u32,
147            bits2: (num_parents - 1) as u32,
148        }
149    }
150
151    #[inline(always)]
152    pub fn byte(&self) -> u8 {
153        (self.bits & 0xff) as u8
154    }
155
156    #[inline(always)]
157    pub fn subtree_size(&self) -> usize {
158        (self.bits2 >> PARENT_BITS) as usize
159    }
160
161    fn set_subtree_size(&mut self, size: usize) {
162        assert!(size < (1 << (32 - PARENT_BITS)));
163        self.bits2 = (self.bits2 & PARENT_MASK) | ((size as u32) << PARENT_BITS);
164    }
165
166    #[inline(always)]
167    pub fn num_parents(&self) -> usize {
168        ((self.bits2 & PARENT_MASK) + 1) as usize
169    }
170
171    #[inline(always)]
172    pub fn token_id(&self) -> Option<u32> {
173        let r = self.bits >> 8;
174        if r == NO_TOKEN {
175            None
176        } else {
177            Some(r)
178        }
179    }
180}
181
182impl TokTrie {
183    // see https://github.com/microsoft/llguidance/blob/main/docs/special_tokens.md
184    pub const SPECIAL_TOKEN_MARKER: u8 = 0xff;
185
    /// Build a trie from `words`, where `words[i]` is the byte representation
    /// of token id `i`. Empty entries keep their slot in the offset table but
    /// are not inserted into the trie. Panics if `words.len()` does not match
    /// `info.vocab_size` or if the result fails validation.
    pub fn from(info: &TokRxInfo, words: &[Vec<u8>]) -> Self {
        let mut trie = TrieHash::new(0xff);
        let mut token_offsets = Vec::new();
        let mut token_data = Vec::new();
        assert!(info.vocab_size == words.len() as u32);
        let mut max_token_len = 0;
        for (idx, word) in words.iter().enumerate() {
            if !word.is_empty() {
                trie.insert(word, idx as u32);
                max_token_len = std::cmp::max(max_token_len, word.len());
            }
            // record where this token's bytes live in the flat buffer
            let desc = TokDesc {
                len: word.len().try_into().unwrap(),
                off: token_data.len().try_into().unwrap(),
            };
            token_offsets.push(desc);
            token_data.extend_from_slice(word);
        }
        let mut nodes = Vec::new();
        trie.serialize(&mut nodes, 0);
        let r = TokTrie {
            info: *info,
            token_offsets,
            token_data,
            nodes,
            max_token_len,
            eos_tokens: vec![info.tok_eos],
        };
        r.validate();
        r
    }
217
218    pub fn filter(&self, filter: &SimpleVob) -> Self {
219        let mut words = vec![];
220        for n in 0..(self.vocab_size() as TokenId) {
221            let b = if filter.is_allowed(n) {
222                self.token(n)
223            } else {
224                &[]
225            };
226            words.push(b.to_vec());
227        }
228        let mut r = Self::from(self.info(), &words);
229        r.eos_tokens = self.eos_tokens.clone();
230        r
231    }
232
    /// Clone the trie with a single replacement EOS token.
    pub fn with_eos_token(&self, eos_token: TokenId) -> Self {
        self.with_eos_tokens(&[eos_token])
    }
236
237    pub fn with_eos_tokens(&self, eos_tokens: &[TokenId]) -> Self {
238        assert!(!eos_tokens.is_empty(), "eos_tokens must not be empty");
239        let vocab = self.vocab_size() as u32;
240        for &tok in eos_tokens {
241            assert!(
242                tok < vocab,
243                "EOS token ID {tok} is out of range (vocab_size={vocab})"
244            );
245        }
246        let mut r = self.clone();
247        r.info.tok_eos = eos_tokens[0];
248        r.eos_tokens = eos_tokens.to_vec();
249        r
250    }
251
252    pub fn with_info(&self, info: TokRxInfo) -> Self {
253        let mut r = self.clone();
254        r.info = info;
255        r.eos_tokens = vec![info.tok_eos];
256        r
257    }
258
    /// Clone the trie for chat mode: end-of-turn becomes the EOS token
    /// (falling back to the regular EOS when no end-of-turn id is known).
    pub fn build_chat_mode_trie(&self) -> Self {
        self.with_eos_token(self.info.tok_end_of_turn.unwrap_or(self.info.tok_eos))
    }
262
    /// Index of node `n` within `self.nodes`, derived from its address;
    /// `n` must be a reference into `self.nodes`.
    fn node_offset(&self, n: &TrieNode) -> usize {
        let off = (n as *const _ as usize - self.root() as *const _ as usize)
            / std::mem::size_of::<TrieNode>();
        assert!(off < self.nodes.len());
        off
    }

    /// Index of the first node after `n`'s subtree (nodes are laid out
    /// in pre-order, so a subtree occupies a contiguous range).
    fn next_node(&self, n: &TrieNode) -> usize {
        self.node_offset(n) + n.subtree_size()
    }
273
    /// Tokenizer metadata (vocab size, special token ids).
    pub fn info(&self) -> &TokRxInfo {
        &self.info
    }

    /// Primary end-of-sequence token id.
    pub fn eos_token(&self) -> TokenId {
        self.info.tok_eos
    }

    /// All token ids treated as end-of-sequence.
    pub fn eos_tokens(&self) -> &[TokenId] {
        &self.eos_tokens
    }

    /// Number of tokens in the vocabulary.
    pub fn vocab_size(&self) -> usize {
        self.info.vocab_size as usize
    }

    /// Allocate an empty token set sized for this vocabulary, with one
    /// extra slot of capacity (used as scratch by add_bias()).
    pub fn alloc_token_set(&self) -> SimpleVob {
        SimpleVob::alloc_with_capacity(self.vocab_size(), self.vocab_size() + 1)
    }

    /// Token set with exactly one token allowed.
    pub fn singleton_token_set(&self, tok: TokenId) -> SimpleVob {
        let mut r = self.alloc_token_set();
        r.allow_token(tok);
        r
    }
299
300    /// Returns a token set containing all EOS tokens.
301    pub fn eos_token_set(&self) -> SimpleVob {
302        let mut r = self.alloc_token_set();
303        let vocab = self.vocab_size() as u32;
304        for &eos in self.eos_tokens() {
305            if eos != INVALID_TOKEN && eos < vocab {
306                r.allow_token(eos);
307            }
308        }
309        r
310    }
311
    /// Render a human-readable summary of a token set: allowed/total counts
    /// plus up to 50 example tokens. When the set is nearly full, its
    /// complement is shown instead, prefixed with "ALL EXCEPT".
    pub fn token_set_dbg(&self, ts: &SimpleVob) -> String {
        let max_examples = 50;

        let ts_neg = ts.negated();
        // use the complement when it is more than 10x smaller
        let use_neg = ts_neg.num_set() * 10 < ts.num_set();
        let ts1 = if use_neg { &ts_neg } else { ts };
        let num_set = ts1.num_set();
        let max_tok = std::cmp::min(max_examples, num_set);
        let mut token_names = Vec::new();
        // make sure we include EOS first if it's allowed
        if self.info.tok_eos != INVALID_TOKEN && ts1.is_allowed(self.info.tok_eos) {
            token_names.push("EOS".to_string());
        }
        for idx in 0..self.vocab_size() {
            if idx as TokenId != self.info.tok_eos && ts1.is_allowed(idx as TokenId) {
                token_names.push(self.token_dbg(idx as TokenId));
                if token_names.len() >= max_tok {
                    break;
                }
            }
        }
        // mark truncation
        if token_names.len() < num_set {
            token_names.push("...".to_string());
        }
        format!(
            "TokenSet: {}/{}; {}{}",
            ts.num_set(),
            self.vocab_size(),
            if use_neg { "ALL EXCEPT " } else { "" },
            token_names.join(" ")
        )
    }
344
    /// Allocate a zeroed logits buffer; one extra slot matches the scratch
    /// token used by add_bias().
    pub fn alloc_logits(&self) -> Vec<f32> {
        vec![0.0; self.vocab_size() + 1]
    }

    /// Debug-render tokens without the surrounding ⟦…⟧ quoting (test helper).
    pub fn test_trace_tokens(&self, toks: &[u32]) -> String {
        self.tokens_dbg_ext(toks, false)
    }

    /// Maximum number of tokens rendered by the debug helpers.
    pub const MAX_DBG_TOKENS: usize = 200;

    /// Debug-render tokens, wrapped in ⟦…⟧.
    pub fn tokens_dbg(&self, toks: &[u32]) -> String {
        self.tokens_dbg_ext(toks, true)
    }
358
    /// Join per-token debug strings with '‧', keeping only the last
    /// MAX_DBG_TOKENS tokens (prefixed with '…' when truncated); wraps the
    /// result in ⟦…⟧ when `quote` is set.
    fn tokens_dbg_ext(&self, toks: &[u32], quote: bool) -> String {
        // if the token list is too long, we are typically interested in the most recent ones
        let (limited, toks) = if toks.len() > Self::MAX_DBG_TOKENS {
            ("…", &toks[toks.len() - Self::MAX_DBG_TOKENS..])
        } else {
            ("", toks)
        };

        let joined = toks
            .iter()
            .map(|t| self.token_dbg_ext(*t, false))
            .collect::<Vec<_>>()
            .join("‧");

        if quote {
            format!("⟦{limited}{joined}⟧")
        } else if limited.is_empty() {
            joined
        } else {
            format!("{limited}{joined}")
        }
    }
381
    /// Debug-render a single token, quoted as ⟨…⟩ when it is plain text.
    pub fn token_dbg(&self, idx: u32) -> String {
        self.token_dbg_ext(idx, true)
    }
385
386    fn token_dbg_ext(&self, idx: u32, quote: bool) -> String {
387        if idx == self.info.tok_eos {
388            "≺EOS≻".to_string()
389        } else if idx as usize >= self.vocab_size() {
390            format!("≺OOB[{idx}]≻")
391        } else {
392            // format!("{:?}[{}]", self.token_str(idx), idx)
393            let bytes = self.token(idx);
394            if bytes.len() > 1 && bytes[0] == TokTrie::SPECIAL_TOKEN_MARKER {
395                String::from_utf8_lossy(&bytes[1..]).to_string()
396            } else {
397                let s = String::from_utf8_lossy(bytes);
398                if s.is_empty() {
399                    format!("≺EMPTY[{idx}]≻")
400                } else if !s.contains('\u{fffd}') {
401                    let mut s = format!("{s:?}").replace("\\\"", "\"");
402                    s.remove(0);
403                    s.pop();
404                    if quote {
405                        format!("⟨{s}⟩")
406                    } else {
407                        s
408                    }
409                } else {
410                    let bytes = self.token(idx);
411                    format!("≺HEX[{}]≻", to_hex_string(bytes))
412                }
413            }
414        }
415    }
416
    /// Lossy UTF-8 rendering of token `idx`'s bytes.
    pub fn token_str(&self, idx: u32) -> String {
        String::from_utf8_lossy(self.token(idx)).to_string()
    }
420
421    pub fn token_len(&self, idx: u32) -> usize {
422        let t = self.token(idx);
423        if t.is_empty() || t[0] == TokTrie::SPECIAL_TOKEN_MARKER {
424            let mut idx = idx;
425            let mut len = 1;
426            while idx >= 10 {
427                idx /= 10;
428                len += 1;
429            }
430            // token 1234 -> \xff [ 1234 ]
431            len + 3
432        } else {
433            t.len()
434        }
435    }
436
    /// Byte representation of token `idx`; empty slice for out-of-range ids.
    pub fn token(&self, idx: u32) -> &[u8] {
        if idx >= self.token_offsets.len() as u32 {
            return &[];
        }
        let desc = self.token_offsets[idx as usize];
        let len = desc.len as usize;
        let off = desc.off as usize;
        &self.token_data[off..(off + len)]
    }
446
    /// Decode `tokens` to bytes, including special-token placeholders.
    pub fn decode(&self, tokens: &[TokenId]) -> Vec<u8> {
        self.decode_ext(tokens, true)
    }
450
451    pub fn decode_ext(&self, tokens: &[TokenId], include_special: bool) -> Vec<u8> {
452        let mut res = Vec::with_capacity(tokens.len() * 6 + 32); // approximately
453        for &tok in tokens {
454            let t = self.token(tok);
455            if t.is_empty() {
456                if include_special {
457                    res.extend_from_slice(format!("<[{tok}]>").as_bytes());
458                }
459            } else if t[0] == TokTrie::SPECIAL_TOKEN_MARKER {
460                if include_special {
461                    res.extend_from_slice(&t[1..]);
462                }
463            } else {
464                res.extend_from_slice(t);
465            }
466        }
467        res
468    }
469
470    pub fn decode_as_special(&self, tok: TokenId) -> Vec<u8> {
471        let mut res = Vec::with_capacity(9);
472        res.push(TokTrie::SPECIAL_TOKEN_MARKER);
473        res.extend_from_slice(format!("[{tok}]").as_bytes());
474        res
475    }
476
477    pub fn decode_raw(&self, tokens: &[TokenId]) -> Vec<u8> {
478        let mut res = Vec::with_capacity(tokens.len() * 6 + 32); // approximately
479        for &tok in tokens {
480            let t = self.token(tok);
481            if t.is_empty() || t[0] == TokTrie::SPECIAL_TOKEN_MARKER {
482                res.push(TokTrie::SPECIAL_TOKEN_MARKER);
483                res.extend_from_slice(format!("[{tok}]").as_bytes());
484            } else {
485                res.extend_from_slice(t);
486            }
487        }
488        res
489    }
490
    /// Decode `tokens` to a lossy UTF-8 string (special tokens included).
    pub fn decode_str(&self, tokens: &[TokenId]) -> String {
        String::from_utf8_lossy(&self.decode(tokens)).to_string()
    }
494
    /// Convert decode_raw() output into decode() output: marker-escaped
    /// `\xff[1234]` sequences are replaced with the decoded bytes of that
    /// token id; everything else (including stray markers) is copied verbatim.
    pub fn decode_raw_to_decode(&self, bytes: &[u8]) -> Vec<u8> {
        let mut res = Vec::new();
        let mut idx = 0;
        while idx < bytes.len() {
            if bytes[idx] == TokTrie::SPECIAL_TOKEN_MARKER {
                if let Some((len, tok)) = parse_numeric_token(&bytes[(idx + 1)..]) {
                    res.extend_from_slice(&self.decode(&[tok]));
                    // skip the marker plus the "[1234]" payload
                    idx += len + 1;
                } else {
                    res.push(bytes[idx]);
                    idx += 1;
                }
            } else {
                res.push(bytes[idx]);
                idx += 1;
            }
        }
        res
    }
514
515    pub fn is_special_token(&self, tok: TokenId) -> bool {
516        let bytes = self.token(tok);
517        !bytes.is_empty() && bytes[0] == TokTrie::SPECIAL_TOKEN_MARKER
518    }
519
520    pub fn get_special_token(&self, name: &str) -> Option<TokenId> {
521        self.child_at_byte(self.root(), TokTrie::SPECIAL_TOKEN_MARKER)
522            .and_then(|n| {
523                self.child_at_bytes(n, name.as_bytes())
524                    .and_then(|n| n.token_id())
525            })
526    }
527
    /// Collect the ids of special tokens (those stored under the
    /// special-token marker prefix) by walking that subtree.
    /// Panics if the trie has no marker prefix node or collects no tokens.
    pub fn get_special_tokens(&self) -> Vec<TokenId> {
        let mut res = Vec::new();
        let pref_node = self
            .child_at_byte(self.root(), TokTrie::SPECIAL_TOKEN_MARKER)
            .expect("missing special token prefix");
        // depth-first walk of the special-token subtree
        let mut stack = vec![pref_node];
        while let Some(n) = stack.pop() {
            for c in self.node_children(n) {
                if let Some(tok) = c.token_id() {
                    res.push(tok);
                    // NOTE(review): only breaks the children loop, not the
                    // whole walk, so MAX_DBG_TOKENS is a per-node cap — confirm intent
                    if res.len() > Self::MAX_DBG_TOKENS + 1 {
                        break;
                    }
                }
                stack.push(c);
            }
        }
        // NOTE(review): drops the first collected id (and panics when none
        // were found) — presumably a sentinel entry; verify against callers
        res.remove(0);
        res
    }
548
    /// Greedy (longest-match-first) tokenization of `bytes`.
    /// At each position the longest token starting there is taken; bytes
    /// not covered by any token are skipped.
    pub fn greedy_tokenize(&self, bytes: &[u8]) -> Vec<TokenId> {
        let mut tokens = Vec::new();
        let mut i = 0;
        while i < bytes.len() {
            let mut node = self.root();
            let mut last_tok = None;
            let mut last_idx = i;
            // descend as far as possible, remembering the last complete token
            #[allow(clippy::needless_range_loop)]
            for j in i..bytes.len() {
                if let Some(child) = self.child_at_byte(node, bytes[j]) {
                    node = child;
                    if let Some(tok) = node.token_id() {
                        last_tok = Some(tok);
                        last_idx = j;
                    }
                } else {
                    break;
                }
            }
            if let Some(t) = last_tok {
                tokens.push(t);
            } else {
                // whoops, there is a byte missing from the tokenizer
                // just carry on...
                // https://github.com/guidance-ai/llguidance/issues/138
            }
            i = last_idx + 1;
        }
        tokens
    }
579
    /// Tokenize a string, interpreting `<name>` as special tokens.
    /// Segments that are not recognized special-token names are passed to
    /// `str_tokenize` unchanged (including the angle brackets).
    pub fn tokenize_with_special<F>(&self, s: &str, str_tokenize: F) -> Vec<TokenId>
    where
        F: Fn(&str) -> Vec<TokenId>,
    {
        // longest special-token name we are willing to scan for
        let max_len = 100;

        let bytes = s.as_bytes();
        let mut out = Vec::new();
        let mut last = 0; // byte‐offset of the next “raw” segment
        let mut i = 0; // current byte index

        while i < bytes.len() {
            if bytes[i] != b'<' {
                i += 1;
                continue;
            }
            // Potential start of `<...>`
            let mut valid = true;
            let mut j = i + 1;
            let mut len_inside = 0;
            // scan up to max_len chars or until we hit `>` or `<`
            while j < bytes.len() && len_inside < max_len {
                match bytes[j] {
                    b'<' => {
                        valid = false;
                        break;
                    }
                    b'>' => break,
                    _ => {
                        len_inside += 1;
                        j += 1;
                    }
                }
            }
            if !valid || j >= bytes.len() || bytes[j] != b'>' || len_inside == 0 {
                // treat this `<` as literal
                i += 1;
                continue;
            }

            let name = &s[i..=j];
            if let Some(special_tok) = self.get_special_token(name) {
                // flush the raw text preceding the special token
                if last < i {
                    out.extend(str_tokenize(&s[last..i]));
                }
                out.push(special_tok);
            } else {
                // fallback: tokenize `<name>` literally
                out.extend(str_tokenize(&s[last..=j]));
            }
            // advance past the `>`
            i = j + 1;
            last = i;
        }
        // any trailing text:
        if last < bytes.len() {
            out.extend(str_tokenize(&s[last..]));
        }
        out
    }
641
642    pub fn tokenize_with_greedy_fallback(
643        &self,
644        bytes: &[u8],
645        str_tokenize: impl Fn(&str) -> Vec<TokenId>,
646    ) -> Vec<TokenId> {
647        match str::from_utf8(bytes) {
648            Ok(s) => {
649                // fast path
650                str_tokenize(s)
651            }
652            Err(_) => {
653                let mut res = vec![];
654                for chunk in bytes.utf8_chunks() {
655                    if !chunk.valid().is_empty() {
656                        res.extend(str_tokenize(chunk.valid()));
657                    }
658                    if !chunk.invalid().is_empty() {
659                        res.extend(self.greedy_tokenize(chunk.invalid()));
660                    }
661                }
662                res
663            }
664        }
665    }
666
667    pub fn has_extensions(&self, bytes: &[u8]) -> bool {
668        match self.child_at_bytes(self.root(), bytes) {
669            None => false,
670            Some(n) => n.subtree_size() > 1,
671        }
672    }
673
674    pub fn token_id(&self, bytes: &[u8]) -> Option<TokenId> {
675        let (tok, len) = self.prefix_token_id(bytes);
676        // println!("tok_id {:?} {:?} {:?} ", bytes, tok, len);
677        if len == bytes.len() {
678            Some(tok)
679        } else {
680            None
681        }
682    }
683
    /// Longest token that is a prefix of `bytes`, as
    /// `(token_id, matched_length)`; returns `(0, 0)` when nothing matches.
    /// Panics on empty input.
    pub fn prefix_token_id(&self, bytes: &[u8]) -> (TokenId, usize) {
        assert!(!bytes.is_empty());
        let mut last = (0, 0);
        let mut n = self.root();
        for (idx, byte) in bytes.iter().enumerate() {
            n = match self.child_at_byte(n, *byte) {
                Some(n) => n,
                None => break,
            };
            // remember the deepest node that completes a token
            if let Some(tok) = n.token_id() {
                last = (tok, idx + 1);
            }
        }
        last
    }
699
    /// Length in bytes of the longest token in the vocabulary.
    pub fn max_token_len(&self) -> usize {
        self.max_token_len
    }
703
    /// Recursively check subtree invariants: token ids in range and unique,
    /// and each child subtree ending no later than its parent's end `ep`.
    fn validate_node(&self, n: &TrieNode, ep: usize, used: &mut [bool]) {
        if let Some(tok) = n.token_id() {
            assert!(tok < self.info.vocab_size);
            assert!(!used[tok as usize]);
            used[tok as usize] = true;
        }
        let endp = self.next_node(n);
        assert!(endp <= ep);
        for child in self.node_children(n) {
            self.validate_node(child, endp, used);
        }
    }

    /// Validate the whole trie, and exercise the offset table for every
    /// token id (token() panics on inconsistent offsets).
    fn validate(&self) {
        self.validate_node(
            self.root(),
            self.next_node(self.root()),
            &mut vec![false; self.info.vocab_size as usize],
        );
        for idx in 0..self.info.vocab_size {
            let _ = self.token(idx);
        }
    }
727
    /// Root node of the trie (always `nodes[0]`).
    pub fn root(&self) -> &TrieNode {
        &self.nodes[0]
    }
731
    /// Consistency check: every token's bytes round-trip through the trie
    /// back to its id, or — for duplicate byte sequences — a single-node
    /// leaf sibling carrying this id exists.
    pub fn check_against(&self, tokens: &[Vec<u8>]) {
        for (idx, bytes) in tokens.iter().enumerate() {
            let tid = idx as TokenId;
            assert!(bytes == self.token(tid));
            let root = self.root();
            if !bytes.is_empty() {
                let tid2 = self
                    .child_at_bytes(root, bytes)
                    .unwrap()
                    .token_id()
                    .unwrap();
                if tid != tid2 {
                    // duplicate byte sequence: the trie stores one canonical
                    // id; look for a leaf child of the parent with ours
                    let par = self
                        .child_at_bytes(root, &bytes[0..bytes.len() - 1])
                        .unwrap();
                    let has_it = self.node_children(par).any(|n| {
                        n.subtree_size() == 1
                            && n.byte() == bytes[bytes.len() - 1]
                            && n.token_id() == Some(tid)
                    });
                    assert!(has_it);
                }
            }
        }
    }
757
    /// Direct child of `n` reached via edge `byte`, if any.
    pub fn child_at_byte<'a>(&'a self, n: &'a TrieNode, byte: u8) -> Option<&'a TrieNode> {
        self.node_children(n).find(|&child| child.byte() == byte)
    }
761
    /// Ids of all tokens whose bytes occur as a substring of `bytes`
    /// (checked from every start position; may contain duplicates).
    pub fn all_subtokens(&self, bytes: &[u8]) -> Vec<TokenId> {
        let mut r = Vec::new();
        for i in 0..bytes.len() {
            // walk the trie from each start position
            let mut n = self.root();
            for &b in &bytes[i..] {
                n = match self.child_at_byte(n, b) {
                    Some(n) => n,
                    None => break,
                };
                if let Some(tok) = n.token_id() {
                    r.push(tok);
                }
            }
        }
        r
    }
778
    /// Iterator over the direct children of `n`. With the pre-order layout,
    /// a node's descendants start right after it and end at its subtree end.
    pub fn node_children(&self, n: &TrieNode) -> NodeChildren<'_> {
        let off = self.node_offset(n);
        NodeChildren {
            trie: self,
            current_offset: off + 1,
            end_offset: off + n.subtree_size(),
        }
    }
787
788    pub fn child_at_bytes<'a>(&'a self, mut n: &'a TrieNode, bytes: &[u8]) -> Option<&'a TrieNode> {
789        for &byte in bytes {
790            n = self.child_at_byte(n, byte)?
791        }
792        Some(n)
793    }
794
    /// Token id whose byte representation is exactly `bytes`, if present
    /// in the trie.
    pub fn token_id_at_bytes(&self, bytes: &[u8]) -> Option<TokenId> {
        self.child_at_bytes(self.root(), bytes)
            .and_then(|n| n.token_id())
    }
799
    /// Return how many tokens and bytes need to chopped off tokens,
    /// so that we do not limit all possible future tokenizations matching the recognizer.
    pub fn chop_tokens(&self, r: &mut impl Recognizer, tokens: &[TokenId]) -> (usize, usize) {
        let max_token_lookback = 4;
        // raw bytes of the last few tokens, trimmed to at most max_token_len
        let suff_bytes =
            self.decode_raw(&tokens[tokens.len().saturating_sub(max_token_lookback)..]);
        let suff_bytes = &suff_bytes[suff_bytes.len().saturating_sub(self.max_token_len())..];

        // find the longest suffix that still has viable extensions
        for idx in 0..suff_bytes.len() {
            let suff = &suff_bytes[idx..];
            if self.has_valid_extensions(r, suff) {
                let chop_bytes = suff.len();
                assert!(chop_bytes > 0);
                // count whole tokens (from the end) covering at least chop_bytes
                let mut curr_len = 0;
                for chop_idx in 1..=tokens.len() {
                    curr_len += self.token_len(tokens[tokens.len() - chop_idx]);
                    if curr_len >= chop_bytes {
                        return (chop_idx, curr_len);
                    }
                }
                // token_len() sums must reach chop_bytes since the suffix
                // came from these very tokens
                unreachable!();
            }
        }

        (0, 0)
    }
830
    /// Check if add_bias() would have returned any tokens.
    #[inline(never)]
    pub fn has_valid_extensions(&self, r: &mut impl Recognizer, start: &[u8]) -> bool {
        let n = self.child_at_bytes(self.root(), start);
        if n.is_none() {
            return false;
        }
        let n = n.unwrap();
        r.trie_started("has_valid_extensions");
        let off = self.node_offset(n);
        // walk the subtree in pre-order; stop at the first accepted token
        let mut p = off + 1;
        let endp = off + n.subtree_size();
        let mut ok = false;
        let mut next_pop = 0;
        while p < endp {
            r.pop_bytes(next_pop);
            let n = &self.nodes[p];
            let b = n.byte();
            if r.try_push_byte(b) {
                if n.token_id().is_some() {
                    ok = true;
                    break;
                }
                // at a leaf, pop back up to the level of the next sibling
                next_pop = if n.subtree_size() == 1 {
                    n.num_parents()
                } else {
                    0
                };
                p += 1;
            } else {
                // byte rejected: skip this node's entire subtree
                p += n.subtree_size();
                next_pop = n.num_parents() - 1;
            }
        }
        r.trie_finished();
        ok
    }
868
869    pub fn all_prefixes(&self, bytes: &[u8]) -> Vec<TokenId> {
870        let mut r = Vec::new();
871        let mut n = self.root();
872        for &b in bytes {
873            if let Some(c) = self.child_at_byte(n, b) {
874                n = c;
875                if let Some(tok) = n.token_id() {
876                    r.push(tok);
877                }
878            } else {
879                break;
880            }
881        }
882        r
883    }
884
    /// Walk the trie below prefix `start`, allowing in `toks` every token
    /// the recognizer `r` accepts; tokens forming prefixes of `start`
    /// itself are allowed as well.
    pub fn add_bias(&self, r: &mut impl Recognizer, toks: &mut SimpleVob, start: &[u8]) {
        // all prefixes of 'start' are also allowed
        if !start.is_empty() {
            let mut fixed = FixedRecognizer::new(start);
            self.add_bias(&mut fixed, toks, &[]);
        }

        let n = self.child_at_bytes(self.root(), start);
        if n.is_none() {
            return;
        }
        let n = n.unwrap();
        r.trie_started("add_bias");
        let (next_pop, nodes_walked) = self.add_bias_inner(r, toks, n);
        if start.is_empty() {
            // if start was non-empty, trie_finished() is supposed to clean this up
            r.pop_bytes(next_pop);
        }
        r.trie_finished();
        r.save_stats(nodes_walked);
        // Clean up the fake token set by add_bias_inner for nodes without token_id.
        // Note: If add_bias_inner panics, this cleanup won't run, leaving the fake token set.
        // This is acceptable since panics indicate unrecoverable errors and the program state
        // is likely already corrupted.
        let defl_tok = self.vocab_size() as u32;
        toks.disallow_token(defl_tok);
    }
912
    /// Hot inner loop of [`TokTrie::add_bias`]: linear scan over the subtree
    /// rooted at `n`, allowing in `toks` every token whose byte path `r`
    /// accepts, and skipping whole subtrees whose first byte is rejected.
    ///
    /// Returns `(next_pop, nodes_walked)`: the number of bytes still to be
    /// popped from `r`'s stack after the walk, and the number of nodes
    /// actually visited (passed to `save_stats`).
    #[inline(never)]
    fn add_bias_inner(
        &self,
        r: &mut impl Recognizer,
        toks: &mut SimpleVob,
        n: &TrieNode,
    ) -> (usize, usize) {
        // Use a fake token at vocab_size to avoid branching in the hot loop.
        // This is safe because alloc_token_set() allocates capacity for vocab_size + 1 tokens.
        // The fake token is cleaned up in add_bias() after the walk completes.
        let defl_tok = self.vocab_size() as u32;
        let off = self.node_offset(n);
        let total_nodes = n.subtree_size();
        let mut p = off + 1;
        let endp = off + total_nodes;
        let nodes = &self.nodes[..endp];
        let mut next_pop = 0;
        let mut num_skip = 0;
        while p < endp {
            // Undo the bytes pushed for the previously visited node.
            r.pop_bytes(next_pop);
            // SAFETY: `nodes` has length `endp` and the loop condition
            // guarantees p < endp; the debug_assert re-checks in debug builds.
            let n = unsafe {
                debug_assert!(
                    p < nodes.len(),
                    "node index {} out of bounds (len: {})",
                    p,
                    nodes.len()
                );
                nodes.get_unchecked(p)
            };
            let b = n.byte();
            if r.try_push_byte(b) {
                // Avoid branching: always set a token (either real or fake at defl_tok)
                let tok = n.token_id().unwrap_or(defl_tok);
                debug_assert!(
                    tok <= self.vocab_size() as u32,
                    "token {} out of valid range (vocab_size: {})",
                    tok,
                    self.vocab_size()
                );
                // SAFETY: tok <= vocab_size, and the token set has capacity
                // for vocab_size + 1 entries (see comment above).
                unsafe { toks.allow_token_unchecked(tok) };
                // Leaf (subtree of size 1): schedule popping its parent chain;
                // interior node: we descend, so nothing to pop yet.
                next_pop = if n.subtree_size() == 1 {
                    n.num_parents()
                } else {
                    0
                };
                p += 1;
            } else {
                // Byte rejected: jump over this node's entire subtree.
                let subtree_size = n.subtree_size();
                p += subtree_size;
                // it's slightly faster to count skipped nodes, than walked nodes
                num_skip += subtree_size - 1;
                next_pop = n.num_parents() - 1;
            }
        }
        (next_pop, total_nodes - num_skip)
    }
969
970    pub fn all_tokens(&self) -> Vec<Vec<u8>> {
971        (0..self.vocab_size())
972            .map(|idx| self.token(idx as u32).to_vec())
973            .collect()
974    }
975
976    pub fn sorted_tokens(&self) -> Vec<(u32, Vec<u8>)> {
977        let mut res = vec![];
978        let n = self.root();
979        let off = self.node_offset(n);
980        let mut p = off + 1;
981        let endp = off + n.subtree_size();
982        let mut next_pop = 0;
983        let mut bytes = vec![];
984        while p < endp {
985            bytes.drain(bytes.len() - next_pop..);
986            let n = &self.nodes[p];
987            let b = n.byte();
988            bytes.push(b);
989            if let Some(t) = n.token_id() {
990                res.push((t, bytes.clone()));
991            }
992            next_pop = if n.subtree_size() == 1 {
993                n.num_parents()
994            } else {
995                0
996            };
997            p += 1;
998        }
999        res
1000    }
1001
1002    fn count_until_depth(&self, depth: usize) -> (usize, usize) {
1003        let mut count = 0;
1004        let mut num_tokens = 0;
1005        let mut stack = vec![(self.root(), 0)];
1006        while let Some((n, d)) = stack.pop() {
1007            if d == depth {
1008                continue;
1009            } else {
1010                for c in self.node_children(n) {
1011                    count += 1;
1012                    if c.token_id().is_some() {
1013                        num_tokens += 1;
1014                    }
1015                    stack.push((c, d + 1));
1016                }
1017            }
1018        }
1019        (count, num_tokens)
1020    }
1021
    /// Human-readable summary of the trie: node count, token-node count,
    /// total token bytes and maximum token length.
    ///
    /// The `if false` sections below are debug visualizations (per-node child
    /// histogram, root fan-out, per-depth counts) kept for manual profiling;
    /// flip one to `true` to include it in the output.
    pub fn trie_stats(&self) -> String {
        // Histogram of direct-child counts per node; only buckets 0..=9 are
        // ever incremented (see the min(9, ...) below), the rest stay zero.
        let mut nodes_histogram = vec![0; 256];

        let mut token_nodes = 0;

        // Linear scan over the whole root subtree (every node except root).
        let n = self.root();
        let off = self.node_offset(n);
        let mut p = off + 1;
        let endp = off + n.subtree_size();
        while p < endp {
            let n = &self.nodes[p];

            if n.token_id().is_some() {
                token_nodes += 1;
            }

            // Count direct children by hopping over each child's subtree.
            let last_ch = self.next_node(n);
            let mut ch_p = p + 1;
            let mut num_children = 0;

            while ch_p < last_ch {
                let ch = &self.nodes[ch_p];
                ch_p += ch.subtree_size();
                num_children += 1;
            }

            nodes_histogram[std::cmp::min(9, num_children)] += 1;

            p += 1;
        }

        let mut histogram = String::new();

        // Debug: print the child-count histogram as "children:count" pairs.
        if false {
            for (idx, num) in nodes_histogram.iter().enumerate() {
                if *num > 0 {
                    if !histogram.is_empty() {
                        histogram.push_str(", ");
                    }
                    histogram.push_str(&format!("{idx}:{num}"));
                }
            }
        }

        // Debug: print fan-out and subtree size for each root child.
        if false {
            for n in self.node_children(self.root()) {
                histogram.push_str(&format!(
                    "\n{} => {} {}",
                    n.byte(),
                    self.node_children(n).count(),
                    n.subtree_size()
                ));
            }
        }

        // Debug: print cumulative node/token counts per depth level.
        if false {
            for depth in 0..30 {
                let (count, num_tokens) = self.count_until_depth(depth);
                histogram.push_str(&format!(
                    "\ndepth {depth}: {count} nodes {num_tokens} tokens"
                ));
            }
        }

        if !histogram.is_empty() {
            histogram = format!("\n{histogram}");
        }

        format!(
            "{}{} nodes, {} token nodes, {} token bytes, {} max len",
            histogram,
            self.nodes.len(),
            token_nodes,
            self.token_data.len(),
            self.max_token_len,
        )
    }
1099}
1100
/// Iterator over the direct children of a trie node.
///
/// Children are stored consecutively after their parent; advancing by each
/// child's `subtree_size()` lands on the next sibling.
pub struct NodeChildren<'a> {
    // Backing trie whose `nodes` array is being walked.
    trie: &'a TokTrie,
    // Offset of the next child to yield.
    current_offset: usize,
    // Offset one past the parent's subtree; iteration stops here.
    end_offset: usize,
}
1106
1107impl<'a> Iterator for NodeChildren<'a> {
1108    type Item = &'a TrieNode;
1109
1110    fn next(&mut self) -> Option<Self::Item> {
1111        if self.current_offset < self.end_offset {
1112            let node = &self.trie.nodes[self.current_offset];
1113            self.current_offset += node.subtree_size();
1114            Some(node)
1115        } else {
1116            None
1117        }
1118    }
1119}
1120
/// Mutable, pointer-based trie used only while building the packed
/// [`TrieNode`] array (see `serialize`).
struct TrieHash {
    // Token id ending at this node, or NO_TOKEN if none.
    token_id: u32,
    // The byte labeling the edge into this node.
    byte: u8,
    // Child nodes; sorted by byte only during serialize().
    children: Vec<TrieHash>,
}
1126
1127impl TrieHash {
1128    fn new(byte: u8) -> TrieHash {
1129        TrieHash {
1130            token_id: NO_TOKEN,
1131            byte,
1132            children: Vec::new(),
1133        }
1134    }
1135    fn insert(&mut self, word: &[u8], token_id: u32) {
1136        if word.is_empty() {
1137            // Some tokenizers have duplicate tokens...
1138            // we just override
1139            assert!(self.token_id == NO_TOKEN);
1140            self.token_id = token_id;
1141        } else {
1142            // if self.children.len() == 0x100 {
1143            //     // assert!(self.children[word[0] as usize].byte == word[0]);
1144            //     self.children[word[0] as usize].insert(&word[1..], token_id);
1145            //     return;
1146            // }
1147
1148            for ch in &mut self.children {
1149                if ch.byte == word[0] {
1150                    if word.len() == 1 && ch.token_id != NO_TOKEN {
1151                        // this is duplicate token, proceed with adding a duplicate node
1152                    } else {
1153                        ch.insert(&word[1..], token_id);
1154                        return;
1155                    }
1156                }
1157            }
1158
1159            let mut ch = TrieHash::new(word[0]);
1160            ch.insert(&word[1..], token_id);
1161            self.children.push(ch);
1162
1163            // if it's getting dense, make it full
1164            // for cl100k threshold 60->15 nodes, 50->22, 40->45 30->94
1165            // for llama (32k) 50->5, 40->15
1166            // TODO remove this?
1167            // if self.children.len() > 250 {
1168            //     let mut v2 = (0..=255).map(TrieHash::new).collect::<Vec<_>>();
1169            //     for ch in self.children.drain(..) {
1170            //         let idx = ch.byte as usize;
1171            //         v2[idx] = ch;
1172            //     }
1173            //     self.children = v2;
1174            // }
1175        }
1176    }
1177
1178    fn serialize(&mut self, data: &mut Vec<TrieNode>, num_parents: usize) {
1179        let idx = data.len();
1180        let mut num_ch = self.children.len();
1181        data.push(TrieNode::new(
1182            self.byte,
1183            self.token_id,
1184            if num_parents == 0 { 1 } else { num_parents },
1185        ));
1186        //self.children.reverse();
1187        self.children.sort_by_key(|e| e.byte);
1188        for entry in &mut self.children {
1189            num_ch -= 1;
1190            entry.serialize(data, if num_ch == 0 { num_parents + 1 } else { 1 });
1191        }
1192        let subtree_size = data.len() - idx;
1193        data[idx].set_subtree_size(subtree_size);
1194    }
1195}
1196
/// Recognizer that accepts exactly one fixed byte sequence (and its
/// prefixes). Used by `add_bias` to bias tokens that are prefixes of the
/// pending `start` bytes.
struct FixedRecognizer {
    // The only byte sequence accepted.
    bytes: Vec<u8>,
    // Number of bytes matched so far (acts as the byte-stack depth).
    bytes_ptr: usize,
}
1201
1202impl FixedRecognizer {
1203    fn new(bytes: &[u8]) -> FixedRecognizer {
1204        FixedRecognizer {
1205            bytes: bytes.to_vec(),
1206            bytes_ptr: 0,
1207        }
1208    }
1209}
1210
1211impl Recognizer for FixedRecognizer {
1212    fn collapse(&mut self) {}
1213    fn trie_finished(&mut self) {}
1214
1215    fn pop_bytes(&mut self, num: usize) {
1216        self.bytes_ptr -= num;
1217    }
1218
1219    fn try_push_byte(&mut self, byte: u8) -> bool {
1220        if self.bytes_ptr < self.bytes.len() && self.bytes[self.bytes_ptr] == byte {
1221            self.bytes_ptr += 1;
1222            true
1223        } else {
1224            false
1225        }
1226    }
1227}
1228
/// Recognizer that accepts every byte sequence (no constraint at all).
pub struct AnythingGoes;
1230
// Trivial Recognizer: every byte is allowed and all state tracking is a no-op.
impl Recognizer for AnythingGoes {
    fn collapse(&mut self) {}
    fn trie_finished(&mut self) {}
    fn pop_bytes(&mut self, _num: usize) {}
    fn try_push_byte(&mut self, _byte: u8) -> bool {
        true
    }
}
1239
#[cfg(test)]
mod tests {
    use super::*;

    // Build a tiny trie with vocab size 4 and single-byte tokens
    // "a".."d" (token ids 0..=3), using `eos` as the EOS token id.
    fn make_test_trie(eos: TokenId) -> TokTrie {
        let info = TokRxInfo::new(4, eos);
        let words = vec![b"a".to_vec(), b"b".to_vec(), b"c".to_vec(), b"d".to_vec()];
        TokTrie::from(&info, &words)
    }

    // A freshly built trie reports exactly the EOS token it was given.
    #[test]
    fn test_default_single_eos() {
        let trie = make_test_trie(2);
        assert_eq!(trie.eos_token(), 2);
        assert_eq!(trie.eos_tokens(), &[2]);
    }

    // with_eos_tokens replaces the EOS set; eos_token() and info() report
    // the first entry.
    #[test]
    fn test_with_eos_tokens_multiple() {
        let trie = make_test_trie(0).with_eos_tokens(&[1, 3]);
        assert_eq!(trie.eos_token(), 1);
        assert_eq!(trie.eos_tokens(), &[1, 3]);
        assert_eq!(trie.info().tok_eos, 1);
    }

    // Single-token with_eos_token behaves like the old API.
    #[test]
    fn test_with_eos_token_backwards_compat() {
        let trie = make_test_trie(0).with_eos_token(2);
        assert_eq!(trie.eos_token(), 2);
        assert_eq!(trie.eos_tokens(), &[2]);
    }

    // with_info discards a previously set multi-EOS list in favor of the
    // new info's single EOS.
    #[test]
    fn test_with_info_resets_eos_tokens() {
        let trie = make_test_trie(0).with_eos_tokens(&[1, 2]);
        let trie2 = trie.with_info(TokRxInfo::new(4, 3));
        assert_eq!(trie2.eos_token(), 3);
        assert_eq!(trie2.eos_tokens(), &[3]);
    }

    // filter() keeps the EOS token list of the source trie.
    #[test]
    fn test_filter_preserves_eos_tokens() {
        let trie = make_test_trie(0).with_eos_tokens(&[1, 2]);
        let mut filter = trie.alloc_token_set();
        for i in 0..4 {
            filter.allow_token(i);
        }
        let filtered = trie.filter(&filter);
        assert_eq!(filtered.eos_tokens(), &[1, 2]);
    }

    // An empty EOS list is rejected with a panic.
    #[test]
    #[should_panic(expected = "eos_tokens must not be empty")]
    fn test_with_eos_tokens_empty_panics() {
        make_test_trie(0).with_eos_tokens(&[]);
    }

    // eos_token_set() contains exactly the single EOS token.
    #[test]
    fn test_eos_token_set_single() {
        let trie = make_test_trie(2);
        let set = trie.eos_token_set();
        assert!(set.is_allowed(2));
        assert!(!set.is_allowed(0));
        assert!(!set.is_allowed(1));
        assert_eq!(set.num_set(), 1);
    }

    // eos_token_set() contains exactly the configured multi-EOS tokens.
    #[test]
    fn test_eos_token_set_multiple() {
        let trie = make_test_trie(0).with_eos_tokens(&[1, 3]);
        let set = trie.eos_token_set();
        assert!(set.is_allowed(1));
        assert!(set.is_allowed(3));
        assert!(!set.is_allowed(0));
        assert!(!set.is_allowed(2));
        assert_eq!(set.num_set(), 2);
    }
}