bpe_openai/lib.rs

use std::sync::LazyLock;

use bpe::byte_pair_encoding::BytePairEncoding;
use either::Either;
use regex_automata::{
    meta::{BuildError, Regex},
    util::captures::Captures,
    Anchored, Input,
};

pub mod normalizer;

pub use bpe::*;
pub use normalizer::{Normalizable, NormalizedString};

// Note: Below we rewrite the negative look-ahead with a positive pseudo look-ahead.
// The look-ahead character is dropped from the match by the Pretokenizer iterator.
// Note: The negative look-ahead `\\s+(?!\\S)` requires `\\s+\\s` but also `\\s+$` to handle end of file without dropping a character!
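// For example, tiktoken's trailing alternatives `\\s+(?!\\S)|\\s+` map to the pattern pairs
// (`\\s+\\s`, look-ahead) and (`\\s+`, no look-ahead) below, with `\\s+$` added to the main
// pattern so that a trailing whitespace run at the end of the text is matched as a single
// piece without dropping a character.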

static BPE_CL100K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_cl100k_base.dict"));
    let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
    let pat1 = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+$";
    let pat2 = "\\s+\\s";
    let pat3 = "\\s+";
    Tokenizer::new_lookahead(bpe, &[(pat1, false), (pat2, true), (pat3, false)], false)
        .expect("valid regex")
});

static BPE_O200K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_o200k_base.dict"));
    let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
    let pat1 = [
        "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
        "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
        "\\p{N}{1,3}",
        " ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*",
        "\\s*[\\r\\n]+",
        "\\s+$",
    ].join("|");
    let pat2 = "\\s+\\s";
    let pat3 = "\\s+";
    Tokenizer::new_lookahead(bpe, &[(&pat1, false), (pat2, true), (pat3, false)], false)
        .expect("valid regex")
});

static BPE_VOYAGE3_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_voyage3_base.dict"));
    let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
    let pat1 = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+$";
    let pat2 = "\\s+\\s";
    let pat3 = "\\s+";
    Tokenizer::new_lookahead(bpe, &[(pat1, false), (pat2, true), (pat3, false)], true)
        .expect("valid regex")
});

/// A byte-pair encoding tokenizer that supports a pre-tokenization regex.
/// The direct methods on this type pre-tokenize the input text and should
/// produce the same output as the tiktoken tokenizers. The type gives access
/// to the regex and underlying byte-pair encoding if needed. Note that using
/// the byte-pair encoding directly does not take the regex into account and
/// may result in output that differs from tiktoken.
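///
/// A minimal usage sketch (the crate name `bpe_openai` is assumed from this file's path):
///
/// ```ignore
/// let tok = bpe_openai::cl100k_base();
/// let tokens = tok.encode("Hello, world!");
/// assert_eq!(tok.count("Hello, world!"), tokens.len());
/// assert_eq!(tok.decode(&tokens).as_deref(), Some("Hello, world!"));
/// ```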
pub struct Tokenizer {
    /// The byte-pair encoding for this tokenizer.
    pub bpe: BytePairEncoding,
    /// The pattern regex used to split the input.
    pub pre: Option<Pretokenizer>,
    /// Indicates whether the input should be normalized with NFC.
    nfc: bool,
}

pub struct Pretokenizer {
    /// The pattern regex used to split the input.
    pat: Regex,
    /// For each pattern in the regex, a boolean indicating whether its last character is a look-ahead.
    lookahead: Vec<bool>,
}

impl Tokenizer {
    /// Build a tokenizer with an optional pretokenization regex pattern.
    #[allow(clippy::result_large_err)]
    pub fn new(bpe: BytePairEncoding, pat: Option<&str>, nfc: bool) -> Result<Self, BuildError> {
        let pre = pat.map(Pretokenizer::new).transpose()?;
        Ok(Self { nfc, bpe, pre })
    }

    /// Build a tokenizer with pretokenization regex patterns. If the boolean for a pattern is true,
    /// the pattern is assumed to be a look-ahead pattern with exactly one look-ahead character!
    #[allow(clippy::result_large_err)]
    pub fn new_lookahead(
        bpe: BytePairEncoding,
        patterns: &[(&str, bool)],
        nfc: bool,
    ) -> Result<Self, BuildError> {
        let pre = Some(Pretokenizer::new_lookahead(patterns)?);
        Ok(Self { nfc, bpe, pre })
    }

    /// Count the number of tokens produced when encoding the text. Applies pre-tokenization
    /// before counting.
    pub fn count<'a, I: Normalizable<'a>>(&self, text: I) -> usize {
        let text = self.normalize(text);
        self.split(text.as_str())
            .map(|piece| self.bpe.count(piece.as_bytes()))
            .sum()
    }

    /// Returns the token count iff the total token count stays within the specified `token_limit`.
    /// Otherwise, it returns `None`. This function can be faster than [`Self::count`] when the
    /// token limit is much smaller than the provided text. Applies pre-tokenization before counting.
    ///
    /// Note: This function assumes that the text is already normalized, so that it can run in
    /// roughly O(token_limit) time.
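    ///
    /// A sketch mirroring the unit test at the bottom of this file (crate name `bpe_openai`
    /// assumed from the file path):
    ///
    /// ```ignore
    /// let tok = bpe_openai::cl100k_base();
    /// assert_eq!(tok.count_till_limit(&tok.normalize("abcabc"), 3), Some(2));
    /// assert_eq!(tok.count_till_limit(&tok.normalize("abcabcabcabc"), 3), None);
    /// ```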
    pub fn count_till_limit(&self, text: &NormalizedString, token_limit: usize) -> Option<usize> {
        let res: Option<usize> = self.split(text.as_str()).try_fold(0, |consumed, piece| {
            self.bpe
                .count_till_limit(piece.as_bytes(), token_limit - consumed)
                .map(|piece_count| consumed + piece_count)
        });
        res
    }

    /// Returns the tokens for the encoding of the given text. Applies pre-tokenization before
    /// encoding.
    pub fn encode<'a, I: Normalizable<'a>>(&self, text: I) -> Vec<u32> {
        let text: NormalizedString<'_> = self.normalize(text);
        self.split(text.as_str())
            .flat_map(|piece| self.bpe.encode_via_backtracking(piece.as_bytes()))
            .collect()
    }

    /// Returns the text corresponding to the given encoding if it is valid UTF-8. Otherwise,
    /// returns `None`.
    pub fn decode(&self, tokens: &[u32]) -> Option<String> {
        String::from_utf8(self.bpe.decode_tokens(tokens)).ok()
    }

    /// Returns an iterator with the text pieces resulting from pre-tokenization. If this
    /// tokenizer does not have pre-tokenization, the iterator returns the full text.
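    ///
    /// A minimal sketch (crate name `bpe_openai` assumed from the file path); the pieces
    /// concatenate back to the input text:
    ///
    /// ```ignore
    /// let tok = bpe_openai::cl100k_base();
    /// let pieces: Vec<_> = tok.split("Hello, world!").collect();
    /// // With the cl100k pattern this yields ["Hello", ",", " world", "!"].
    /// assert_eq!(pieces.concat(), "Hello, world!");
    /// ```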
    pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &'a str> {
        match &self.pre {
            Some(pre) => Either::Left(pre.split(text)),
            None => Either::Right(std::iter::once(text)),
        }
    }

    /// Returns the normalized text if the tokenizer requires normalization.
    /// If the input was already normalized, this function is a noop.
    pub fn normalize<'a, I: Normalizable<'a>>(&self, text: I) -> NormalizedString<'a> {
        text.normalize(self.nfc)
    }
}

impl Pretokenizer {
    /// Build a pretokenizer from the given regex pattern.
    #[allow(clippy::result_large_err)]
    fn new(pat: &str) -> Result<Self, BuildError> {
        let pat = Regex::new(pat)?;
        Ok(Self {
            pat,
            lookahead: vec![false],
        })
    }

    /// Build a pretokenizer from the given regex patterns. If the boolean for a pattern is true,
    /// the pattern is assumed to be a look-ahead pattern with exactly one look-ahead character!
    #[allow(clippy::result_large_err)]
    fn new_lookahead(pats: &[(&str, bool)]) -> Result<Self, BuildError> {
        let (pats, lookahead): (Vec<_>, _) = pats.iter().copied().unzip();
        let pat = Regex::new_many(&pats)?;
        Ok(Self { pat, lookahead })
    }

    /// Returns an iterator with the text pieces after splitting with the regular expression.
    pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &'a str> {
        Splits {
            pat: &self.pat,
            lookahead: &self.lookahead,
            text,
            last: 0,
            caps: Captures::matches(self.pat.group_info().clone()),
        }
    }
}

/// This is a small wrapper around the regex which emulates the behaviour of look-ahead by
/// dropping the look-ahead character from the match. The assumption here is that a pattern
/// marked as look-ahead always ends in exactly one look-ahead character, which is dropped
/// from the returned piece. With this little hack, we can keep most of the regex patterns
/// as they are, but achieve a >3x speedup.
///
/// Alternatively, this could have been implemented with capture groups, but those were ~30%
/// slower than this approach with multiple patterns.
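///
/// For example, with the cl100k patterns above, the input "a  b" (two spaces) matches the
/// look-ahead pattern `\\s+\\s` at the first space; the trailing space is dropped from the
/// match, so only a single space is emitted and the next match starts at the second space,
/// which is then consumed together with "b". This mirrors the behaviour of tiktoken's
/// negative look-ahead `\\s+(?!\\S)`.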
struct Splits<'a> {
    pat: &'a Regex,
    lookahead: &'a [bool],
    text: &'a str,
    last: usize,
    caps: Captures,
}

impl<'a> Iterator for Splits<'a> {
    type Item = &'a str;

    fn next(&mut self) -> Option<Self::Item> {
        let input = Input::new(&self.text[self.last..]).anchored(Anchored::Yes);
        self.caps.clear();
        self.pat.captures(input, &mut self.caps);
        let m = self.caps.get_match()?;
        let start = self.last;
        let mut end = self.last + m.range().end;
        if self.lookahead[m.pattern().as_usize()] {
            let last = self.text[start..end]
                .chars()
                .next_back()
                .expect("Expected at least a look-ahead character!");
            end -= last.len_utf8();
            assert_ne!(end, start, "a look-ahead pattern must ALWAYS consume at least one character excluding the look-ahead character!");
        }
        self.last = end;
        Some(&self.text[start..end])
    }
}

pub fn cl100k_base() -> &'static Tokenizer {
    &BPE_CL100K_BASE
}

pub fn o200k_base() -> &'static Tokenizer {
    &BPE_O200K_BASE
}

pub fn voyage3_base() -> &'static Tokenizer {
    &BPE_VOYAGE3_BASE
}

#[cfg(test)]
mod tests {
    use bpe::byte_pair_encoding::{create_test_string, select_test_string};
    use tiktoken_rs::{cl100k_base_singleton, o200k_base_singleton, CoreBPE};

    use super::*;

    #[test]
    fn test_cl100k() {
        test_equivalence(cl100k_base(), &cl100k_base_singleton().lock());
    }

    #[test]
    fn test_o200k() {
        test_equivalence(o200k_base(), &o200k_base_singleton().lock());
    }

    #[track_caller]
    fn test_equivalence(tok: &Tokenizer, tiktoken: &CoreBPE) {
        let text = create_test_string(&tok.bpe, 80_000);
        for bytes in [10, 100, 1000, 10_000] {
            for _ in 0..32 {
                let text = select_test_string(&text, bytes);
                let tokens = tok.encode(text);
                let tiktokens = tiktoken.encode_ordinary(text).to_vec();
                assert_eq!(tokens, tiktokens, "encoding mismatch for {text:?}");
            }
        }
    }

    #[test]
    fn test_count_till_limit() {
        assert_eq!(
            cl100k_base().count_till_limit(&cl100k_base().normalize("abc"), 3),
            Some(1)
        );
        assert_eq!(
            cl100k_base().count_till_limit(&cl100k_base().normalize("abcabc"), 3),
            Some(2)
        );
        assert_eq!(
            cl100k_base().count_till_limit(&cl100k_base().normalize("abcabcabc"), 3),
            Some(3)
        );
        assert_eq!(
            cl100k_base().count_till_limit(&cl100k_base().normalize("abcabcabcabc"), 3),
            None
        );
    }
}