// toktrie/toktree.rs

1// use 8:24 encoding - num_ch:tok_id (ch_byte:ch_off)* - 8 bytes per tree node
2// special case num_ch=0xff -> num_ch=0x100
3
4use core::str;
5use std::sync::Arc;
6
7use bytemuck_derive::{Pod, Zeroable};
8
9use crate::{bytes::to_hex_string, SimpleVob};
10
/// Numeric identifier of a token in the vocabulary.
pub type TokenId = u32;
12
/// Binary-serializable (Pod) subset of the tokenizer info:
/// just the vocabulary size and the EOS token.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroable, Pod)]
#[repr(C)]
pub struct BinTokRxInfo {
    pub vocab_size: u32,
    pub tok_eos: TokenId,
}
19
/// Tokenizer metadata: vocabulary size, EOS token, and optional
/// well-known special tokens (None when the tokenizer has no such token).
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct TokRxInfo {
    pub vocab_size: u32,
    pub tok_eos: TokenId,
    pub tok_bos: Option<TokenId>,
    pub tok_pad: Option<TokenId>,
    pub tok_unk: Option<TokenId>,
    pub tok_end_of_turn: Option<TokenId>,
}
29
30impl TokRxInfo {
31    pub fn new(vocab_size: u32, tok_eos: TokenId) -> Self {
32        TokRxInfo {
33            vocab_size,
34            tok_eos,
35            tok_bos: None,
36            tok_pad: None,
37            tok_unk: None,
38            tok_end_of_turn: None,
39        }
40    }
41
42    pub fn from_bin(info: &BinTokRxInfo) -> Self {
43        TokRxInfo {
44            vocab_size: info.vocab_size,
45            tok_eos: info.tok_eos,
46            tok_bos: None,
47            tok_pad: None,
48            tok_unk: None,
49            tok_end_of_turn: None,
50        }
51    }
52
53    pub fn to_bin(&self) -> BinTokRxInfo {
54        BinTokRxInfo {
55            vocab_size: self.vocab_size,
56            tok_eos: self.tok_eos,
57        }
58    }
59}
60
pub trait Recognizer {
    /// Pop `num` bytes: for _ in 0..num { stack.pop() }
    fn pop_bytes(&mut self, num: usize);
    /// "Collapse" the stack so that it consists only of its former
    /// top element.
    /// X = stack.top(); stack.empty(); stack.push(X)
    fn collapse(&mut self);
    /// Check if stack.top() transitions via `byte` to a viable state.
    fn byte_allowed(&mut self, byte: u8) -> bool {
        let viable = self.try_push_byte(byte);
        if viable {
            self.pop_bytes(1);
        }
        viable
    }
    /// Called when iteration over the trie is finished.
    /// Stack has exactly one element then, except when iteration started from a non-root node;
    /// in that case the stack may have more than one element, and trie_finished() needs to pop
    /// the excessive elements.
    fn trie_finished(&mut self);
    /// Called when iteration over the trie is started.
    fn trie_started(&mut self, _dbg_lbl: &str) {}
    /// Combines `push_byte` and `byte_allowed` into one function for performance.
    fn try_push_byte(&mut self, byte: u8) -> bool;
    /// Check if there are any errors to be reported to the user.
    fn get_error(&mut self) -> Option<String> {
        None
    }
    /// Record how many trie nodes were visited (default: ignore).
    fn save_stats(&mut self, _nodes_walked: usize) {}
}
91
92/// Parse a special token of the form \xFF [ 1 2 3 4 ]
93/// The initial \xFF is not included in the input.
94/// Returns the number of bytes consumed and the token id.
95pub fn parse_numeric_token(s: &[u8]) -> Option<(usize, TokenId)> {
96    let spec_len = s[0..std::cmp::min(s.len(), 20)]
97        .iter()
98        .position(|&x| x == ']' as u8);
99    if let Some(spec_len) = spec_len {
100        if s[0] != b'[' {
101            return None;
102        }
103        let inner_bytes = &s[1..spec_len];
104        if let Ok(inner_str) = std::str::from_utf8(inner_bytes) {
105            if let Ok(id) = u32::from_str_radix(inner_str, 10) {
106                return Some((spec_len + 1, id as TokenId));
107            }
108        }
109    }
110    None
111}
112
pub trait TokenizerEnv: Send {
    /// Associated trie.
    fn tok_trie(&self) -> &TokTrie;

    /// Tokenize a given byte sequence.
    /// It may or may not interpret <|special_tokens|> as special.
    fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId>;

    /// Tokenize a given byte sequence.
    /// It will interpret text starting with SPECIAL_TOKEN_MARKER as special tokens.
    /// Returns the tokens, and the number of leading tokens that should never be
    /// re-tokenized (because they were specified using the special token marker).
    fn tokenize_bytes_marker(&self, s: &[u8]) -> (Vec<TokenId>, usize) {
        let mut idx = 0;
        let ff = TokTrie::SPECIAL_TOKEN_MARKER;
        let mut result = Vec::new();
        let trie = self.tok_trie();
        let mut num_fixed_tokens = 0;
        while idx < s.len() {
            // tokenize the run of ordinary bytes up to the next marker
            let normal_len = s[idx..]
                .iter()
                .position(|&x| x == ff)
                .unwrap_or(s.len() - idx);
            if normal_len != 0 {
                result.extend_from_slice(&self.tokenize_bytes(&s[idx..idx + normal_len]));
                idx += normal_len;
            }
            idx += 1; // skip ff
            if idx + 2 < s.len() && s[idx] == '<' as u8 {
                // tokenize \xff<foobar> as special token <foobar>
                let spec_len = s[idx..std::cmp::min(s.len(), idx + 100)]
                    .iter()
                    .position(|&x| x == '>' as u8);
                if let Some(mut spec_len) = spec_len {
                    spec_len += 1;
                    // include the \xff marker byte in the trie lookup key
                    let spec_token = &s[idx - 1..idx + spec_len];
                    if let Some(id) = trie.token_id_at_bytes(spec_token) {
                        result.push(id);
                        num_fixed_tokens = result.len();
                        idx += spec_len;
                    }
                }
            } else if idx < s.len() {
                // tokenize \xff[1234] as token 1234
                if let Some((n_bytes, tok_id)) = parse_numeric_token(&s[idx..]) {
                    if tok_id < trie.vocab_size() as u32 {
                        result.push(tok_id);
                        num_fixed_tokens = result.len();
                        idx += n_bytes;
                    }
                }
            }
        }

        (result, num_fixed_tokens)
    }

    /// Tokenize a string coming from user. It may or may not interpret <|special_tokens|> as special.
    fn tokenize(&self, s: &str) -> Vec<TokenId> {
        self.tokenize_bytes(s.as_bytes())
    }

    /// Tokenize a string. It will interpret <|special_tokens|> as special.
    fn tokenize_special(&self, s: &str) -> Vec<TokenId> {
        self.tokenize(s)
    }

    /// End of sentence token
    fn eos_token(&self) -> TokenId {
        self.tok_trie().eos_token()
    }

    /// If this returns true, this tokenizer always returns canonical tokenizations
    /// and can be used for forcing tokens.
    /// Non-canonical tokenizers will typically just use TokTrie::greedy_tokenize().
    fn tokenize_is_canonical(&self) -> bool {
        true
    }
}
192
/// Shared, thread-safe handle to a tokenizer environment.
pub type TokEnv = Arc<dyn TokenizerEnv + Sync + 'static>;
194
/// Pairs an existing TokEnv with a replacement TokTrie; tokenization is
/// delegated to the base environment while the trie is overridden.
pub struct TokEnvWithTrie {
    base_env: TokEnv,
    tok_trie: TokTrie,
}
199
impl TokEnvWithTrie {
    /// Combine `base_env` (used for tokenization) with `tok_trie`
    /// (exposed via `tok_trie()`).
    pub fn new(base_env: TokEnv, tok_trie: TokTrie) -> Self {
        Self { base_env, tok_trie }
    }
}
205
impl TokenizerEnv for TokEnvWithTrie {
    /// Expose the replacement trie instead of the base environment's.
    fn tok_trie(&self) -> &TokTrie {
        &self.tok_trie
    }

    /// Delegate tokenization to the base environment.
    fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId> {
        self.base_env.tokenize_bytes(s)
    }
}
215
/// Byte-trie over a tokenizer vocabulary, stored as a flattened
/// pre-order array of nodes.
#[derive(Clone)]
pub struct TokTrie {
    info: TokRxInfo,
    // per-token (offset << LEN_BITS) | length entries into token_data
    token_offsets: Vec<u32>,
    // concatenated byte representations of all tokens
    token_data: Vec<u8>,
    // flattened trie nodes; nodes[0] is the root
    nodes: Vec<TrieNode>,
    // length in bytes of the longest token
    max_token_len: usize,
}
224
#[derive(Clone, Copy, Zeroable, Pod)]
#[repr(C)]
pub struct TrieNode {
    // byte:token
    // low 8 bits: byte label; high 24 bits: token id, or NO_TOKEN
    bits: u32,
    // low 8 bits: num_parents; high 24 bits: subtree size (new() stores
    // only num_parents — size is presumably filled during serialization)
    bits2: u32,
}
232
/// Sentinel token id meaning "no token" (e.g. an unset EOS token).
pub const INVALID_TOKEN: TokenId = 0xffff_ffff;

/// Value of TrieNode's 24-bit token field when the node ends no token.
const NO_TOKEN: u32 = 0xffffff;
236
237impl TrieNode {
238    fn new(byte: u8, token_id: u32, num_parents: u8) -> TrieNode {
239        TrieNode {
240            bits: (token_id << 8) | byte as u32,
241            bits2: num_parents as u32,
242        }
243    }
244
245    #[inline(always)]
246    pub fn byte(&self) -> u8 {
247        (self.bits & 0xff) as u8
248    }
249
250    #[inline(always)]
251    pub fn subtree_size(&self) -> usize {
252        (self.bits2 >> 8) as usize
253    }
254
255    #[inline(always)]
256    pub fn num_parents(&self) -> usize {
257        (self.bits2 & 0xff) as usize
258    }
259
260    #[inline(always)]
261    pub fn token_id(&self) -> Option<u32> {
262        let r = self.bits >> 8;
263        if r == NO_TOKEN {
264            None
265        } else {
266            Some(r)
267        }
268    }
269}
270
// max length of token is 255 bytes
/// Low LEN_BITS bits of a `token_offsets` entry hold the token length;
/// the remaining high bits hold the offset into `token_data`.
const LEN_BITS: u32 = 8;
273
274impl TokTrie {
    // see https://github.com/microsoft/llguidance/blob/main/docs/special_tokens.md
    /// Byte that prefixes the byte representation of special tokens.
    pub const SPECIAL_TOKEN_MARKER: u8 = 0xff;
277
278    pub fn from(info: &TokRxInfo, words: &Vec<Vec<u8>>) -> Self {
279        let mut trie = TrieHash::new(0xff);
280        let mut token_offsets = Vec::new();
281        let mut token_data = Vec::new();
282        assert!(info.vocab_size == words.len() as u32);
283        let mut max_token_len = 0;
284        for (idx, word) in words.iter().enumerate() {
285            if word.len() > 0 {
286                trie.insert(word, idx as u32);
287                max_token_len = std::cmp::max(max_token_len, word.len());
288            }
289            assert!(word.len() < (1 << LEN_BITS));
290            assert!(token_data.len() < (1 << (32 - LEN_BITS)));
291            let desc = (word.len() as u32) | ((token_data.len() as u32) << LEN_BITS);
292            token_offsets.push(desc);
293            token_data.extend_from_slice(word);
294        }
295        let mut nodes = Vec::new();
296        trie.serialize(&mut nodes, 0);
297        let r = TokTrie {
298            info: info.clone(),
299            token_offsets,
300            token_data,
301            nodes,
302            max_token_len,
303        };
304        r.validate();
305        r
306    }
307
308    pub fn with_eos_token(&self, eos_token: TokenId) -> Self {
309        self.with_info(TokRxInfo {
310            tok_eos: eos_token,
311            ..self.info.clone()
312        })
313    }
314
315    pub fn with_info(&self, info: TokRxInfo) -> Self {
316        let mut r = self.clone();
317        r.info = info.clone();
318        r
319    }
320
321    pub fn build_chat_mode_trie(&self) -> Self {
322        self.with_eos_token(self.info.tok_end_of_turn.unwrap_or(self.info.tok_eos))
323    }
324
    /// Index of node `n` inside `self.nodes`, recovered from its address;
    /// `n` must borrow from this trie's `nodes` vector.
    fn node_offset(&self, n: &TrieNode) -> usize {
        let off = (n as *const _ as usize - self.root() as *const _ as usize)
            / std::mem::size_of::<TrieNode>();
        assert!(off < self.nodes.len());
        off
    }
331
332    fn next_node(&self, n: &TrieNode) -> usize {
333        return self.node_offset(n) + n.subtree_size();
334    }
335
    /// Tokenizer metadata for this trie.
    pub fn info(&self) -> &TokRxInfo {
        &self.info
    }

    /// End-of-sequence token id.
    pub fn eos_token(&self) -> TokenId {
        self.info.tok_eos
    }

    /// Number of tokens in the vocabulary.
    pub fn vocab_size(&self) -> usize {
        self.info.vocab_size as usize
    }
347
    /// Allocate an empty token set sized for this vocabulary, with one
    /// extra capacity slot (used for the "default" pseudo-token in add_bias).
    pub fn alloc_token_set(&self) -> SimpleVob {
        SimpleVob::alloc_with_capacity(self.vocab_size(), self.vocab_size() + 1)
    }

    /// Token set with exactly one token allowed.
    pub fn singleton_token_set(&self, tok: TokenId) -> SimpleVob {
        let mut r = self.alloc_token_set();
        r.allow_token(tok);
        r
    }
357
358    pub fn token_set_dbg(&self, ts: &SimpleVob) -> String {
359        let max_examples = 50;
360
361        let ts_neg = ts.negated();
362        let use_neg = ts_neg.num_set() * 10 < ts.num_set();
363        let ts1 = if use_neg { &ts_neg } else { &ts };
364        let num_set = ts1.num_set();
365        let max_tok = std::cmp::min(max_examples, num_set);
366        let mut token_names = Vec::new();
367        // make sure we include EOS first if it's allowed
368        if self.info.tok_eos != INVALID_TOKEN && ts1.is_allowed(self.info.tok_eos) {
369            token_names.push("EOS".to_string());
370        }
371        for idx in 0..self.vocab_size() {
372            if idx as TokenId != self.info.tok_eos && ts1.is_allowed(idx as TokenId) {
373                token_names.push(self.token_dbg(idx as TokenId));
374                if token_names.len() >= max_tok {
375                    break;
376                }
377            }
378        }
379        if token_names.len() < num_set {
380            token_names.push("...".to_string());
381        }
382        format!(
383            "TokenSet: {}/{}; {}{}",
384            ts.num_set(),
385            self.vocab_size(),
386            if use_neg { "ALL EXCEPT " } else { "" },
387            token_names.join(" ")
388        )
389    }
390
    /// Zero-filled logits buffer: one slot per token plus the extra
    /// "default" pseudo-token slot.
    pub fn alloc_logits(&self) -> Vec<f32> {
        vec![0.0; self.vocab_size() + 1]
    }

    /// Debug-print tokens without the surrounding ⟦…⟧ quoting (for traces/tests).
    pub fn test_trace_tokens(&self, toks: &[u32]) -> String {
        self.tokens_dbg_ext(toks, false)
    }

    /// Maximum number of tokens rendered by the debug helpers.
    pub const MAX_DBG_TOKENS: usize = 200;

    /// Debug-print tokens, ⟦…⟧-quoted and truncated to MAX_DBG_TOKENS.
    pub fn tokens_dbg(&self, toks: &[u32]) -> String {
        self.tokens_dbg_ext(toks, true)
    }
404
405    fn tokens_dbg_ext(&self, toks: &[u32], quote: bool) -> String {
406        let (limited, toks) = if toks.len() > Self::MAX_DBG_TOKENS {
407            (true, &toks[0..Self::MAX_DBG_TOKENS])
408        } else {
409            (false, toks)
410        };
411
412        let mut joined = toks
413            .iter()
414            .map(|t| self.token_dbg_ext(*t, false))
415            .collect::<Vec<_>>()
416            .join("‧");
417
418        if limited {
419            joined.push_str("…");
420        }
421
422        if quote {
423            format!("⟦{}⟧", joined)
424        } else {
425            joined
426        }
427    }
428
    /// Debug representation of a single token, quoted (e.g. ⟨foo⟩).
    pub fn token_dbg(&self, idx: u32) -> String {
        self.token_dbg_ext(idx, true)
    }
432
433    fn token_dbg_ext(&self, idx: u32, quote: bool) -> String {
434        if idx == self.info.tok_eos {
435            "≺EOS≻".to_string()
436        } else if idx as usize >= self.vocab_size() {
437            format!("≺OOB[{}]≻", idx)
438        } else {
439            // format!("{:?}[{}]", self.token_str(idx), idx)
440            let bytes = self.token(idx);
441            if bytes.len() > 1 && bytes[0] == TokTrie::SPECIAL_TOKEN_MARKER {
442                String::from_utf8_lossy(&bytes[1..]).to_string()
443            } else {
444                let s = String::from_utf8_lossy(bytes);
445                if s.len() == 0 {
446                    format!("≺EMPTY[{}]≻", idx)
447                } else if !s.contains('\u{fffd}') {
448                    let mut s = format!("{:?}", s).replace("\\\"", "\"");
449                    s.remove(0);
450                    s.pop();
451                    if quote {
452                        format!("⟨{}⟩", s)
453                    } else {
454                        s
455                    }
456                } else {
457                    let bytes = self.token(idx);
458                    format!("≺HEX[{}]≻", to_hex_string(bytes))
459                }
460            }
461        }
462    }
463
    /// Lossy UTF-8 decoding of the token's bytes.
    pub fn token_str(&self, idx: u32) -> String {
        String::from_utf8_lossy(self.token(idx)).to_string()
    }
467
468    pub fn token_len(&self, idx: u32) -> usize {
469        let t = self.token(idx);
470        if t.len() == 0 || t[0] == TokTrie::SPECIAL_TOKEN_MARKER {
471            let mut idx = idx;
472            let mut len = 1;
473            while idx >= 10 {
474                idx /= 10;
475                len += 1;
476            }
477            // token 1234 -> \xff [ 1234 ]
478            len + 3
479        } else {
480            t.len()
481        }
482    }
483
    /// Byte representation of token `idx`; empty slice for out-of-range ids.
    pub fn token(&self, idx: u32) -> &[u8] {
        if idx >= self.token_offsets.len() as u32 {
            return &[];
        }
        // entry layout: (offset << LEN_BITS) | length
        let off = self.token_offsets[idx as usize];
        let len = off & ((1 << LEN_BITS) - 1);
        let off = (off >> LEN_BITS) as usize;
        &self.token_data[off..(off + len as usize)]
    }
493
494    pub fn decode(&self, tokens: &[TokenId]) -> Vec<u8> {
495        let mut res = Vec::new();
496        res.reserve(tokens.len() * 6 + 32); // approximately
497        for &tok in tokens {
498            let t = self.token(tok);
499            if t.len() == 0 {
500                res.extend_from_slice(format!("<[{}]>", tok).as_bytes());
501            } else if t[0] == TokTrie::SPECIAL_TOKEN_MARKER {
502                res.extend_from_slice(&t[1..]);
503            } else {
504                res.extend_from_slice(t);
505            }
506        }
507        res
508    }
509
510    pub fn decode_as_special(&self, tok: TokenId) -> Vec<u8> {
511        let mut res = Vec::new();
512        res.reserve(9);
513        res.push(TokTrie::SPECIAL_TOKEN_MARKER);
514        res.extend_from_slice(format!("[{}]", tok).as_bytes());
515        res
516    }
517
518    pub fn decode_raw(&self, tokens: &[TokenId]) -> Vec<u8> {
519        let mut res = Vec::new();
520        res.reserve(tokens.len() * 6 + 32); // approximately
521        for &tok in tokens {
522            let t = self.token(tok);
523            if t.len() == 0 || t[0] == TokTrie::SPECIAL_TOKEN_MARKER {
524                res.push(TokTrie::SPECIAL_TOKEN_MARKER);
525                res.extend_from_slice(format!("[{}]", tok).as_bytes());
526            } else {
527                res.extend_from_slice(t);
528            }
529        }
530        res
531    }
532
    /// Decode tokens and lossily convert the bytes to a String.
    pub fn decode_str(&self, tokens: &[TokenId]) -> String {
        String::from_utf8_lossy(&self.decode(tokens)).to_string()
    }
536
    /// Rewrite decode_raw() output into decode() output: \xFF[1234]
    /// sequences are replaced with the decoded bytes of the referenced
    /// token; all other bytes pass through unchanged.
    pub fn decode_raw_to_decode(&self, bytes: &[u8]) -> Vec<u8> {
        let mut res = Vec::new();
        let mut idx = 0;
        while idx < bytes.len() {
            if bytes[idx] == TokTrie::SPECIAL_TOKEN_MARKER {
                if let Some((len, tok)) = parse_numeric_token(&bytes[(idx + 1)..]) {
                    res.extend_from_slice(&self.decode(&[tok]));
                    idx += len + 1;
                } else {
                    // marker not followed by [1234] — keep the byte as-is
                    res.push(bytes[idx]);
                    idx += 1;
                }
            } else {
                res.push(bytes[idx]);
                idx += 1;
            }
        }
        res
    }
556
557    pub fn is_special_token(&self, tok: TokenId) -> bool {
558        let bytes = self.token(tok);
559        bytes.len() > 0 && bytes[0] == TokTrie::SPECIAL_TOKEN_MARKER
560    }
561
562    pub fn get_special_token(&self, name: &str) -> Option<TokenId> {
563        self.child_at_byte(self.root(), TokTrie::SPECIAL_TOKEN_MARKER)
564            .and_then(|n| {
565                self.child_at_bytes(n, name.as_bytes())
566                    .and_then(|n| n.token_id())
567            })
568    }
569
    /// Collect ids of special tokens (everything under the \xFF marker
    /// prefix), capped at roughly MAX_DBG_TOKENS + 1 entries.
    /// Panics if the trie has no marker prefix node or no special tokens.
    pub fn get_special_tokens(&self) -> Vec<TokenId> {
        let mut res = Vec::new();
        let pref_node = self
            .child_at_byte(self.root(), TokTrie::SPECIAL_TOKEN_MARKER)
            .expect("missing special token prefix");
        let mut stack = vec![pref_node];
        while let Some(n) = stack.pop() {
            for c in self.node_children(n) {
                if let Some(tok) = c.token_id() {
                    res.push(tok);
                    // NOTE(review): this break only exits the inner `for`;
                    // the outer `while` keeps draining the stack, so the
                    // cap is soft — confirm this is intended.
                    if res.len() > Self::MAX_DBG_TOKENS + 1 {
                        break;
                    }
                }
                stack.push(c);
            }
        }
        // drop the first collected entry — presumably a sentinel/marker
        // token; panics if res is empty. TODO confirm intent.
        res.remove(0);
        res
    }
590
    /// Greedy longest-match tokenization of `bytes` against the trie.
    /// Panics (on unwrap) if some position cannot match any token at all,
    /// i.e. a byte does not even start a single-byte token.
    pub fn greedy_tokenize(&self, bytes: &[u8]) -> Vec<TokenId> {
        let mut r = Vec::new();
        if bytes.len() == 0 {
            return r;
        }

        let mut n = self.root();
        // longest complete token seen on the current path, and where it ended
        let mut last_tok = None;
        let mut last_idx = 0;
        let mut idx = 0;
        while idx < bytes.len() {
            match self.child_at_byte(n, bytes[idx]) {
                Some(c) => {
                    if let Some(tok) = c.token_id() {
                        last_tok = Some(tok);
                        last_idx = idx;
                    }
                    n = c;
                }
                None => {
                    // dead end: emit the longest token found, restart
                    // scanning right after it (idx is bumped below)
                    r.push(last_tok.unwrap());
                    idx = last_idx;
                    n = self.root();
                }
            }
            idx = idx + 1;
        }
        r.push(last_tok.unwrap());
        r
    }
621
622    pub fn tokenize_with_greedy_fallback(
623        &self,
624        bytes: &[u8],
625        str_tokenize: impl Fn(&str) -> Vec<TokenId>,
626    ) -> Vec<TokenId> {
627        match str::from_utf8(bytes) {
628            Ok(s) => {
629                // fast path
630                str_tokenize(s)
631            }
632            Err(_) => {
633                let mut res = vec![];
634                for chunk in bytes.utf8_chunks() {
635                    if !chunk.valid().is_empty() {
636                        res.extend(str_tokenize(chunk.valid()));
637                    }
638                    if !chunk.invalid().is_empty() {
639                        res.extend(self.greedy_tokenize(chunk.invalid()));
640                    }
641                }
642                res
643            }
644        }
645    }
646
647    pub fn has_extensions(&self, bytes: &[u8]) -> bool {
648        match self.child_at_bytes(self.root(), bytes) {
649            None => false,
650            Some(n) => n.subtree_size() > 1,
651        }
652    }
653
654    pub fn token_id(&self, bytes: &[u8]) -> Option<TokenId> {
655        let (tok, len) = self.prefix_token_id(bytes);
656        // println!("tok_id {:?} {:?} {:?} ", bytes, tok, len);
657        if len == bytes.len() {
658            Some(tok)
659        } else {
660            None
661        }
662    }
663
664    pub fn prefix_token_id(&self, bytes: &[u8]) -> (TokenId, usize) {
665        assert!(bytes.len() > 0);
666        let mut last = (0, 0);
667        let mut n = self.root();
668        for (idx, byte) in bytes.iter().enumerate() {
669            n = match self.child_at_byte(n, *byte) {
670                Some(n) => n,
671                None => break,
672            };
673            if let Some(tok) = n.token_id() {
674                last = (tok, idx + 1);
675            }
676        }
677        return last;
678    }
679
    /// Length in bytes of the longest token in the vocabulary.
    pub fn max_token_len(&self) -> usize {
        self.max_token_len
    }
683
    /// Recursively validate the subtree rooted at `n`: token ids must be
    /// in range and unique (tracked in `used`), and each subtree must end
    /// at or before its parent's end index `ep`.
    fn validate_node(&self, n: &TrieNode, ep: usize, used: &mut [bool]) {
        if let Some(tok) = n.token_id() {
            assert!(tok < self.info.vocab_size);
            assert!(!used[tok as usize]);
            used[tok as usize] = true;
        }
        let endp = self.next_node(n);
        assert!(endp <= ep);
        for child in self.node_children(n) {
            self.validate_node(child, endp, used);
        }
    }
696
    /// Check trie invariants after construction: node structure via
    /// validate_node(), plus a data-slice access for every token id
    /// (token() panics if an offset/length entry is out of bounds).
    fn validate(&self) {
        self.validate_node(
            self.root(),
            self.next_node(self.root()),
            &mut vec![false; self.info.vocab_size as usize],
        );
        for idx in 0..self.info.vocab_size {
            let _ = self.token(idx);
        }
    }
707
    /// Root node of the trie (first element of the flattened array).
    pub fn root(&self) -> &TrieNode {
        &self.nodes[0]
    }
711
    /// Sanity-check the trie against a raw token list: each token's bytes
    /// must round-trip through token(), and looking the bytes up in the
    /// trie must yield the same id — or, when ids differ (apparently
    /// duplicate byte sequences), the expected id must exist as a leaf
    /// child of the parent node with the same final byte.
    pub fn check_against(&self, tokens: &Vec<Vec<u8>>) {
        let vocab_size = tokens.len();
        for idx in 0..vocab_size {
            let bytes = &tokens[idx];
            let tid = idx as TokenId;
            assert!(bytes == self.token(tid));
            let root = self.root();
            if bytes.len() > 0 {
                let tid2 = self
                    .child_at_bytes(root, &bytes)
                    .unwrap()
                    .token_id()
                    .unwrap();
                if tid != tid2 {
                    let par = self
                        .child_at_bytes(root, &bytes[0..bytes.len() - 1])
                        .unwrap();
                    let has_it = self.node_children(par).any(|n| {
                        n.subtree_size() == 1
                            && n.byte() == bytes[bytes.len() - 1]
                            && n.token_id() == Some(tid)
                    });
                    assert!(has_it);
                }
            }
        }
    }
739
740    pub fn child_at_byte<'a>(&'a self, n: &'a TrieNode, byte: u8) -> Option<&'a TrieNode> {
741        for child in self.node_children(n) {
742            if child.byte() == byte {
743                return Some(child);
744            }
745        }
746        None
747    }
748
    /// Ids of every token occurring as a substring of `bytes`
    /// (matches from every start position; duplicates included).
    pub fn all_subtokens(&self, bytes: &[u8]) -> Vec<TokenId> {
        let mut r = Vec::new();
        for i in 0..bytes.len() {
            // walk the trie starting at position i until it dead-ends
            let mut n = self.root();
            for j in i..bytes.len() {
                n = match self.child_at_byte(n, bytes[j]) {
                    Some(n) => n,
                    None => break,
                };
                if let Some(tok) = n.token_id() {
                    r.push(tok);
                }
            }
        }
        r
    }
765
    /// Iterator over direct children of `n`. In the flattened pre-order
    /// array the children live between the node and the end of its
    /// subtree; NodeChildren presumably skips grandchild subtrees using
    /// subtree_size — it is defined elsewhere in this file.
    pub fn node_children(&self, n: &TrieNode) -> NodeChildren {
        let off = self.node_offset(n);
        NodeChildren {
            trie: self,
            current_offset: off + 1,
            end_offset: off + n.subtree_size(),
        }
    }
774
775    pub fn child_at_bytes<'a>(&'a self, mut n: &'a TrieNode, bytes: &[u8]) -> Option<&'a TrieNode> {
776        for &byte in bytes {
777            n = match self.child_at_byte(n, byte) {
778                Some(n) => n,
779                None => return None,
780            }
781        }
782        Some(n)
783    }
784
785    pub fn token_id_at_bytes(&self, bytes: &[u8]) -> Option<TokenId> {
786        self.child_at_bytes(self.root(), bytes)
787            .and_then(|n| n.token_id())
788    }
789
    /// Return how many tokens and bytes need to be chopped off `tokens`,
    /// so that we do not limit all possible future tokenizations matching the recognizer.
    pub fn chop_tokens(&self, r: &mut impl Recognizer, tokens: &[TokenId]) -> (usize, usize) {
        // only the last few tokens can affect re-tokenization
        let max_token_lookback = 4;
        let suff_bytes =
            self.decode_raw(&tokens[tokens.len().saturating_sub(max_token_lookback)..]);
        // no suffix longer than the longest token can matter
        let suff_bytes = &suff_bytes[suff_bytes.len().saturating_sub(self.max_token_len())..];

        // find the longest byte suffix that can still be extended into a
        // token accepted by the recognizer
        for idx in 0..suff_bytes.len() {
            let suff = &suff_bytes[idx..];
            if self.has_valid_extensions(r, suff) {
                let chop_bytes = suff.len();
                assert!(chop_bytes > 0);
                // translate the byte count into whole tokens, from the end
                let mut curr_len = 0;
                for chop_idx in 1..=tokens.len() {
                    curr_len += self.token_len(tokens[tokens.len() - chop_idx]);
                    if curr_len >= chop_bytes {
                        return (chop_idx, curr_len);
                    }
                }
                unreachable!();
            }
        }

        (0, 0)
    }
816
    /// Check if add_bias() would have returned any tokens.
    #[inline(never)]
    pub fn has_valid_extensions(&self, r: &mut impl Recognizer, start: &[u8]) -> bool {
        let n = self.child_at_bytes(self.root(), start);
        if n.is_none() {
            return false;
        }
        let n = n.unwrap();
        r.trie_started("has_valid_extensions");
        // walk the flattened subtree of `n`, keeping the recognizer's byte
        // stack in step with the current trie path
        let off = self.node_offset(n);
        let mut p = off + 1;
        let endp = off + n.subtree_size();
        let mut ok = false;
        let mut next_pop = 0;
        while p < endp {
            r.pop_bytes(next_pop);
            let n = &self.nodes[p];
            let b = n.byte();
            if r.try_push_byte(b) {
                if n.token_id().is_some() {
                    // at least one reachable token — done
                    ok = true;
                    break;
                }
                // leaves pop all their parents at once; interior nodes keep
                // the byte on the stack for their children
                next_pop = if n.subtree_size() == 1 {
                    n.num_parents()
                } else {
                    0
                };
                p += 1;
            } else {
                // byte rejected: skip the whole subtree, pop back up
                p += n.subtree_size();
                next_pop = n.num_parents() - 1;
            }
        }
        r.trie_finished();
        ok
    }
854
855    pub fn all_prefixes(&self, bytes: &[u8]) -> Vec<TokenId> {
856        let mut r = Vec::new();
857        let mut n = self.root();
858        for &b in bytes {
859            if let Some(c) = self.child_at_byte(n, b) {
860                n = c;
861                if let Some(tok) = n.token_id() {
862                    r.push(tok);
863                }
864            } else {
865                break;
866            }
867        }
868        r
869    }
870
    /// Allow in `toks` every token the recognizer `r` accepts when the
    /// bytes `start` have already been consumed; prefixes of `start`
    /// itself are also allowed (via a FixedRecognizer pass).
    pub fn add_bias(&self, r: &mut impl Recognizer, toks: &mut SimpleVob, start: &[u8]) {
        // all prefixes of 'start' are also allowed
        if start.len() > 0 {
            let mut fixed = FixedRecognizer::new(start);
            self.add_bias(&mut fixed, toks, &[]);
        }

        let n = self.child_at_bytes(self.root(), start);
        if n.is_none() {
            return;
        }
        let n = n.unwrap();
        r.trie_started("add_bias");
        let (next_pop, nodes_walked) = self.add_bias_inner(r, toks, n);
        if start.len() == 0 {
            // if start was non-empty, trie_finished() is supposed to clean this up
            r.pop_bytes(next_pop);
        }
        r.trie_finished();
        r.save_stats(nodes_walked);
        // revert the fake token
        let defl_tok = self.vocab_size() as u32;
        toks.disallow_token(defl_tok);
    }
895
    /// Core of add_bias(): walk the flattened subtree of `n`, allowing in
    /// `toks` every token the recognizer accepts. Returns the number of
    /// bytes left on the recognizer stack and the number of nodes walked.
    #[inline(never)]
    fn add_bias_inner(
        &self,
        r: &mut impl Recognizer,
        toks: &mut SimpleVob,
        n: &TrieNode,
    ) -> (usize, usize) {
        // pseudo-token recorded for token-less nodes; the caller removes it
        let defl_tok = self.vocab_size() as u32;
        let off = self.node_offset(n);
        let total_nodes = n.subtree_size();
        let mut p = off + 1;
        let endp = off + total_nodes;
        let mut next_pop = 0;
        let mut num_skip = 0;
        while p < endp {
            r.pop_bytes(next_pop);
            let n = &self.nodes[p];
            let b = n.byte();
            if r.try_push_byte(b) {
                toks.allow_token(n.token_id().unwrap_or(defl_tok));
                // leaves pop all their parents at once; interior nodes keep
                // the byte on the stack for their children
                next_pop = if n.subtree_size() == 1 {
                    n.num_parents()
                } else {
                    0
                };
                p += 1;
            } else {
                // byte rejected: skip this node's whole subtree
                let subtree_size = n.subtree_size();
                p += subtree_size;
                // it's slightly faster to count skipped nodes, than walked nodes
                num_skip += subtree_size - 1;
                next_pop = n.num_parents() - 1;
            }
        }
        (next_pop, total_nodes - num_skip)
    }
932
933    pub fn all_tokens(&self) -> Vec<Vec<u8>> {
934        (0..self.vocab_size())
935            .map(|idx| self.token(idx as u32).to_vec())
936            .collect()
937    }
938
    /// All (token id, bytes) pairs in the trie's depth-first order,
    /// i.e. ordered by byte sequence.
    pub fn sorted_tokens(&self) -> Vec<(u32, Vec<u8>)> {
        let mut res = vec![];
        let n = self.root();
        let off = self.node_offset(n);
        let mut p = off + 1;
        let endp = off + n.subtree_size();
        let mut next_pop = 0;
        // current path from the root, maintained as a byte stack
        let mut bytes = vec![];
        while p < endp {
            bytes.drain(bytes.len() - next_pop..);
            let n = &self.nodes[p];
            let b = n.byte();
            bytes.push(b);
            if let Some(t) = n.token_id() {
                res.push((t, bytes.clone()));
            }
            // leaves pop all their parents at once; interior nodes keep
            // the path for their children
            next_pop = if n.subtree_size() == 1 {
                n.num_parents()
            } else {
                0
            };
            p += 1;
        }
        res
    }
964
965    fn count_until_depth(&self, depth: usize) -> (usize, usize) {
966        let mut count = 0;
967        let mut num_tokens = 0;
968        let mut stack = vec![(self.root(), 0)];
969        while let Some((n, d)) = stack.pop() {
970            if d == depth {
971                continue;
972            } else {
973                for c in self.node_children(n) {
974                    count += 1;
975                    if c.token_id().is_some() {
976                        num_tokens += 1;
977                    }
978                    stack.push((c, d + 1));
979                }
980            }
981        }
982        (count, num_tokens)
983    }
984
985    pub fn trie_stats(&self) -> String {
986        let mut nodes_histogram = vec![0; 256];
987
988        let mut token_nodes = 0;
989
990        let n = self.root();
991        let off = self.node_offset(n);
992        let mut p = off + 1;
993        let endp = off + n.subtree_size();
994        while p < endp {
995            let n = &self.nodes[p];
996
997            if n.token_id().is_some() {
998                token_nodes += 1;
999            }
1000
1001            let last_ch = self.next_node(n);
1002            let mut ch_p = p + 1;
1003            let mut num_children = 0;
1004
1005            while ch_p < last_ch {
1006                let ch = &self.nodes[ch_p];
1007                ch_p += ch.subtree_size();
1008                num_children += 1;
1009            }
1010
1011            nodes_histogram[std::cmp::min(9, num_children)] += 1;
1012
1013            p += 1;
1014        }
1015
1016        let mut histogram = String::new();
1017
1018        if false {
1019            for (idx, num) in nodes_histogram.iter().enumerate() {
1020                if *num > 0 {
1021                    if !histogram.is_empty() {
1022                        histogram.push_str(", ");
1023                    }
1024                    histogram.push_str(&format!("{}:{}", idx, num));
1025                }
1026            }
1027        }
1028
1029        if false {
1030            for n in self.node_children(self.root()) {
1031                histogram.push_str(&format!(
1032                    "\n{} => {} {}",
1033                    n.byte(),
1034                    self.node_children(n).count(),
1035                    n.subtree_size()
1036                ));
1037            }
1038        }
1039
1040        if false {
1041            for depth in 0..30 {
1042                let (count, num_tokens) = self.count_until_depth(depth);
1043                histogram.push_str(&format!(
1044                    "\ndepth {}: {} nodes {} tokens",
1045                    depth, count, num_tokens
1046                ));
1047            }
1048        }
1049
1050        if histogram.len() > 0 {
1051            histogram = format!("\n{}", histogram);
1052        }
1053
1054        format!(
1055            "{}{} nodes, {} token nodes, {} token bytes, {} max len",
1056            histogram,
1057            self.nodes.len(),
1058            token_nodes,
1059            self.token_data.len(),
1060            self.max_token_len,
1061        )
1062    }
1063}
1064
/// Iterator over the *direct* children of a trie node (grandchildren are
/// skipped by hopping over each child's whole subtree).
pub struct NodeChildren<'a> {
    // Trie that owns the flattened node array.
    trie: &'a TokTrie,
    // Offset (into trie.nodes) of the next child to yield.
    current_offset: usize,
    // One past the last node of the parent's subtree.
    end_offset: usize,
}
1070
1071impl<'a> Iterator for NodeChildren<'a> {
1072    type Item = &'a TrieNode;
1073
1074    fn next(&mut self) -> Option<Self::Item> {
1075        if self.current_offset < self.end_offset {
1076            let node = &self.trie.nodes[self.current_offset];
1077            self.current_offset += node.subtree_size();
1078            Some(node)
1079        } else {
1080            None
1081        }
1082    }
1083}
1084
/// Mutable trie node used while building the token trie, before it is
/// serialized into the flat `TrieNode` array.
struct TrieHash {
    // Token ending at this node, or NO_TOKEN if none.
    token_id: u32,
    // Byte labeling the edge from the parent to this node.
    byte: u8,
    // Child nodes; sorted by byte only at serialization time.
    children: Vec<TrieHash>,
}
1090
1091impl TrieHash {
1092    fn new(byte: u8) -> TrieHash {
1093        TrieHash {
1094            token_id: NO_TOKEN,
1095            byte,
1096            children: Vec::new(),
1097        }
1098    }
1099    fn insert(&mut self, word: &[u8], token_id: u32) {
1100        if word.len() == 0 {
1101            // Some tokenizers have duplicate tokens...
1102            // we just override
1103            assert!(self.token_id == NO_TOKEN);
1104            self.token_id = token_id;
1105        } else {
1106            // if self.children.len() == 0x100 {
1107            //     // assert!(self.children[word[0] as usize].byte == word[0]);
1108            //     self.children[word[0] as usize].insert(&word[1..], token_id);
1109            //     return;
1110            // }
1111
1112            for ch in &mut self.children {
1113                if ch.byte == word[0] {
1114                    if word.len() == 1 && ch.token_id != NO_TOKEN {
1115                        // this is duplicate token, proceed with adding a duplicate node
1116                    } else {
1117                        ch.insert(&word[1..], token_id);
1118                        return;
1119                    }
1120                }
1121            }
1122
1123            let mut ch = TrieHash::new(word[0]);
1124            ch.insert(&word[1..], token_id);
1125            self.children.push(ch);
1126
1127            // if it's getting dense, make it full
1128            // for cl100k threshold 60->15 nodes, 50->22, 40->45 30->94
1129            // for llama (32k) 50->5, 40->15
1130            // TODO remove this?
1131            // if self.children.len() > 250 {
1132            //     let mut v2 = (0..=255).map(TrieHash::new).collect::<Vec<_>>();
1133            //     for ch in self.children.drain(..) {
1134            //         let idx = ch.byte as usize;
1135            //         v2[idx] = ch;
1136            //     }
1137            //     self.children = v2;
1138            // }
1139        }
1140    }
1141    fn serialize(&mut self, data: &mut Vec<TrieNode>, num_parents: u8) {
1142        let idx = data.len();
1143        let mut num_ch = self.children.len();
1144        data.push(TrieNode::new(self.byte, self.token_id, num_parents));
1145        //self.children.reverse();
1146        self.children.sort_by_key(|e| e.byte);
1147        for entry in &mut self.children {
1148            num_ch -= 1;
1149            entry.serialize(data, if num_ch == 0 { num_parents + 1 } else { 1 });
1150        }
1151        data[idx].bits2 |= ((data.len() - idx) as u32) << 8;
1152    }
1153}
1154
/// Recognizer that accepts exactly one fixed byte sequence, tracking how
/// much of it has been matched so far.
struct FixedRecognizer {
    // The byte sequence to match.
    bytes: Vec<u8>,
    // Number of bytes matched so far (index of the next expected byte).
    bytes_ptr: usize,
}
1159
1160impl FixedRecognizer {
1161    fn new(bytes: &[u8]) -> FixedRecognizer {
1162        FixedRecognizer {
1163            bytes: bytes.to_vec(),
1164            bytes_ptr: 0,
1165        }
1166    }
1167}
1168
1169impl Recognizer for FixedRecognizer {
1170    fn collapse(&mut self) {}
1171    fn trie_finished(&mut self) {}
1172
1173    fn pop_bytes(&mut self, num: usize) {
1174        self.bytes_ptr -= num;
1175    }
1176
1177    fn try_push_byte(&mut self, byte: u8) -> bool {
1178        if self.bytes_ptr < self.bytes.len() && self.bytes[self.bytes_ptr] == byte {
1179            self.bytes_ptr += 1;
1180            true
1181        } else {
1182            false
1183        }
1184    }
1185}
1186
/// Tokenizer environment backed only by a `TokTrie`, tokenizing greedily;
/// results are not canonical (see `tokenize_is_canonical`).
pub struct ApproximateTokEnv {
    trie: TokTrie,
}
1190
1191impl ApproximateTokEnv {
1192    pub fn new(trie: TokTrie) -> Self {
1193        Self { trie }
1194    }
1195}
1196
impl TokenizerEnv for ApproximateTokEnv {
    fn tok_trie(&self) -> &TokTrie {
        &self.trie
    }

    /// Tokenize by greedy longest-match over the trie; this may differ
    /// from what the underlying tokenizer would produce.
    fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId> {
        self.trie.greedy_tokenize(s)
    }

    /// Greedy tokenization is only an approximation, so the produced
    /// token sequences are not canonical.
    fn tokenize_is_canonical(&self) -> bool {
        false
    }
}