Skip to main content

codec_rs/
longest_match.rs

1// SPDX-License-Identifier: MIT
2//! Vocab-only longest-prefix-match tokenizer.
3//!
4//! Walks input left-to-right, emitting the ID of the longest vocab
5//! fragment that matches at each position. Suitable for canonical-IR /
6//! synthetic test maps. NOT BPE-correct for real model vocabs — use
7//! [`crate::BPETokenizer`] for those.
8
9use std::collections::HashMap;
10
11use crate::map::TokenizerMap;
12use crate::tokenize::{BPETokenizer, ITokenizer};
13
/// Vocab-only fallback tokenizer.
///
/// Holds the lookup tables needed for a greedy longest-prefix walk over the
/// input: a fragment → ID map for ordinary vocabulary entries and a list of
/// literal special-token spellings that are tried before ordinary fragments.
#[derive(Debug, Clone)]
pub struct LongestMatchTokenizer {
    /// Identifier copied from the source map; exposed through `ITokenizer::id`.
    id: String,
    /// Non-empty vocabulary fragments mapped to their token IDs.
    fragment_to_id: HashMap<String, u32>,
    /// Byte length of the longest fragment in `fragment_to_id` (never below 1).
    /// Bounds how far ahead `encode` looks at each position.
    max_fragment_length: usize,
    /// Literal special-token spellings paired with their IDs. `encode` checks
    /// these before consulting `fragment_to_id`.
    special_fragment_to_id: Vec<(String, u32)>,
}
21
22impl LongestMatchTokenizer {
23    pub fn new(map: &TokenizerMap) -> Self {
24        let id = map.id.clone();
25        let mut max_len = 1usize;
26        let mut fragment_to_id: HashMap<String, u32> = HashMap::new();
27
28        if let Some(vocab) = &map.vocab {
29            for (frag, &fid) in vocab {
30                if frag.is_empty() {
31                    continue;
32                }
33                fragment_to_id.insert(frag.clone(), fid);
34                if frag.len() > max_len {
35                    max_len = frag.len();
36                }
37            }
38        }
39        if let Some(tokens) = &map.tokens {
40            for (id_str, frag) in tokens {
41                if frag.is_empty() {
42                    continue;
43                }
44                let Ok(fid) = id_str.parse::<u32>() else {
45                    continue;
46                };
47                fragment_to_id.insert(frag.clone(), fid);
48                if frag.len() > max_len {
49                    max_len = frag.len();
50                }
51            }
52        }
53
54        let mut special_fragment_to_id: Vec<(String, u32)> = Vec::new();
55        if let Some(specials) = &map.special_tokens {
56            for (name, &sid) in specials {
57                special_fragment_to_id.push((name.clone(), sid));
58                if !name.starts_with('<') {
59                    special_fragment_to_id.push((format!("<|{name}|>"), sid));
60                }
61            }
62        }
63
64        Self {
65            id,
66            fragment_to_id,
67            max_fragment_length: max_len,
68            special_fragment_to_id,
69        }
70    }
71
72    pub fn encode(&self, text: &str) -> Vec<u32> {
73        // Operate on bytes to mirror .NET String.Substring/CompareOrdinal
74        // semantics in a way that's safe for arbitrary content. We only
75        // emit boundaries on UTF-8 char boundaries, but advance by raw
76        // bytes for the longest-prefix walk.
77        let bytes = text.as_bytes();
78        let mut output: Vec<u32> = Vec::new();
79        let mut pos = 0usize;
80        let n = bytes.len();
81
82        while pos < n {
83            // Specials win.
84            let mut consumed = false;
85            for (frag, sid) in &self.special_fragment_to_id {
86                let fb = frag.as_bytes();
87                if pos + fb.len() <= n && &bytes[pos..pos + fb.len()] == fb {
88                    output.push(*sid);
89                    pos += fb.len();
90                    consumed = true;
91                    break;
92                }
93            }
94            if consumed {
95                continue;
96            }
97
98            let remaining = n - pos;
99            let try_up_to = self.max_fragment_length.min(remaining);
100            let mut matched_id: Option<u32> = None;
101            let mut matched_len = 0usize;
102            for len in (1..=try_up_to).rev() {
103                // Only attempt if the slice is on a char boundary
104                // (otherwise it can't equal any real fragment string).
105                if !text.is_char_boundary(pos + len) || !text.is_char_boundary(pos) {
106                    continue;
107                }
108                let candidate = &text[pos..pos + len];
109                if let Some(&fid) = self.fragment_to_id.get(candidate) {
110                    matched_id = Some(fid);
111                    matched_len = len;
112                    break;
113                }
114            }
115
116            match matched_id {
117                None => {
118                    output.push(0); // UNK
119                    // Advance by one char (or one byte if mid-codepoint, defensive).
120                    let advance = next_char_boundary(text, pos).max(1);
121                    pos += advance;
122                }
123                Some(fid) => {
124                    output.push(fid);
125                    pos += matched_len;
126                }
127            }
128        }
129        output
130    }
131}
132
/// Returns how many bytes separate `pos` from the next UTF-8 char boundary
/// of `s` that lies strictly after `pos`.
///
/// The end of the string counts as a boundary. The result is always at
/// least 1, including when `pos` is at or beyond the end of `s`.
fn next_char_boundary(s: &str, pos: usize) -> usize {
    let len = s.len();
    let start = pos + 1;
    // Look for an interior boundary first; otherwise stop at the end of the
    // string, or at `start` itself when `start` already lies past the end.
    let stop = (start..len)
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or_else(|| len.max(start));
    stop - pos
}
142
143impl ITokenizer for LongestMatchTokenizer {
144    fn id(&self) -> &str {
145        &self.id
146    }
147    fn encode(&self, text: &str) -> Vec<u32> {
148        Self::encode(self, text)
149    }
150}
151
152/// Top-level tokenizer factory.
153pub struct Tokenize;
154
155impl Tokenize {
156    /// Build the right tokenizer for the map. [`BPETokenizer`] when the
157    /// map has BPE data; otherwise [`LongestMatchTokenizer`].
158    pub fn pick(map: &TokenizerMap) -> Box<dyn ITokenizer> {
159        if BPETokenizer::supports(map) {
160            // We've already verified support; unwrap is safe.
161            Box::new(BPETokenizer::new(map).expect("supports() succeeded"))
162        } else {
163            Box::new(LongestMatchTokenizer::new(map))
164        }
165    }
166
167    /// One-shot encode using [`Tokenize::pick`].
168    pub fn encode(map: &TokenizerMap, text: &str) -> Vec<u32> {
169        Self::pick(map).encode(text)
170    }
171}