bpe_openai/lib.rs

use std::sync::LazyLock;

use bpe::byte_pair_encoding::BytePairEncoding;
use either::Either;
use regex_automata::{
    meta::{BuildError, Regex},
    util::captures::Captures,
    Anchored, Input,
};

// Note: Below we rewrite the negative look-ahead with a positive pseudo look-ahead.
// The look-ahead character is dropped from the match by the Pretokenizer iterator.
// Note: The negative look-ahead `\\s+(?!\\S)` is rewritten as `\\s+\\s`, but `\\s+$` is also
// needed to handle whitespace at the end of the text without dropping a character!
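//
// For example, in the input "a   b" (three spaces), "a" is matched first; then `\\s+\\s`
// matches all three spaces and the final space is dropped again as the look-ahead character,
// yielding the piece "  "; finally " b" is matched. This is the same split that the negative
// look-ahead `\\s+(?!\\S)` would have produced.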

static BPE_CL100K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_cl100k_base.dict"));
    let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
    let pat1 = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+$";
    let pat2 = "\\s+\\s";
    let pat3 = "\\s+";
    Tokenizer::new_lookahead(bpe, &[(pat1, false), (pat2, true), (pat3, false)])
        .expect("valid regex")
});

static BPE_O200K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_o200k_base.dict"));
    let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
    let pat1 = [
        "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
        "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
        "\\p{N}{1,3}",
        " ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*",
        "\\s*[\\r\\n]+",
        "\\s+$",
    ].join("|");
    let pat2 = "\\s+\\s";
    let pat3 = "\\s+";
    Tokenizer::new_lookahead(bpe, &[(&pat1, false), (pat2, true), (pat3, false)])
        .expect("valid regex")
});

pub use bpe::*;

/// A byte-pair encoding tokenizer that supports a pre-tokenization regex.
/// The direct methods on this type pre-tokenize the input text and should
/// produce the same output as the tiktoken tokenizers. The type gives access
/// to the regex and underlying byte-pair encoding if needed. Note that using
/// the byte-pair encoding directly does not take the regex into account and
/// may result in output that differs from tiktoken.
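///
/// A minimal usage sketch (the crate name `bpe_openai` is assumed from the file path above):
/// ```
/// // Encode and decode round-trip through the pre-tokenizing tokenizer.
/// let tok = bpe_openai::cl100k_base();
/// let tokens = tok.encode("Hello, world!");
/// assert_eq!(tok.decode(&tokens), Some("Hello, world!".to_string()));
/// ```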
pub struct Tokenizer {
    /// The byte-pair encoding for this tokenizer.
    pub bpe: BytePairEncoding,
    /// The optional pre-tokenizer used to split the input before encoding.
    pub pre: Option<Pretokenizer>,
}

pub struct Pretokenizer {
    /// The pattern regex used to split the input.
    pat: Regex,
    /// For each pattern in the regex, whether its last matched character is a look-ahead
    /// character that must be dropped from the piece.
    lookahead: Vec<bool>,
}

impl Tokenizer {
    /// Build a tokenizer with an optional pretokenization regex pattern.
    #[allow(clippy::result_large_err)]
    pub fn new(bpe: BytePairEncoding, pat: Option<&str>) -> Result<Self, BuildError> {
        let pre = pat.map(Pretokenizer::new).transpose()?;
        Ok(Self { bpe, pre })
    }

    /// Build a tokenizer with pretokenization regex patterns. If the boolean for a pattern is true,
    /// the pattern is assumed to be a look-ahead pattern with exactly one look-ahead character!
    #[allow(clippy::result_large_err)]
    pub fn new_lookahead(
        bpe: BytePairEncoding,
        patterns: &[(&str, bool)],
    ) -> Result<Self, BuildError> {
        let pre = Some(Pretokenizer::new_lookahead(patterns)?);
        Ok(Self { bpe, pre })
    }

    /// Count the number of tokens produced when encoding the text. Applies pre-tokenization
    /// before counting.
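    ///
    /// A small sketch (counting yields the same number as encoding and taking the length;
    /// the crate name `bpe_openai` is assumed from the file path above):
    /// ```
    /// let tok = bpe_openai::cl100k_base();
    /// assert_eq!(tok.count("Hello, world!"), tok.encode("Hello, world!").len());
    /// ```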
    pub fn count(&self, text: &str) -> usize {
        self.split(text)
            .map(|piece| self.bpe.count(piece.as_bytes()))
            .sum()
    }

    /// Returns the token count iff the total token count does not exceed the specified token_limit.
    /// Otherwise, it returns `None`. This function can be faster than [`Self::count`] when the
    /// token limit is much smaller than the provided text. Applies pre-tokenization before counting.
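    ///
    /// A small sketch (mirroring the cases exercised in the tests below):
    /// ```
    /// let tok = bpe_openai::cl100k_base();
    /// assert_eq!(tok.count_till_limit("abc", 3), Some(1));
    /// assert_eq!(tok.count_till_limit("abcabcabcabc", 3), None);
    /// ```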
    pub fn count_till_limit(&self, text: &str, token_limit: usize) -> Option<usize> {
        self.split(text).try_fold(0, |consumed, piece| {
            self.bpe
                .count_till_limit(piece.as_bytes(), token_limit - consumed)
                .map(|piece_count| consumed + piece_count)
        })
    }

    /// Returns the tokens for the encoding of the given text. Applies pre-tokenization before
    /// encoding.
    pub fn encode(&self, text: &str) -> Vec<u32> {
        self.split(text)
            .flat_map(|piece| self.bpe.encode_via_backtracking(piece.as_bytes()))
            .collect()
    }

    /// Returns the text corresponding to the given encoding if it is valid UTF-8. Otherwise,
    /// returns `None`.
    pub fn decode(&self, tokens: &[u32]) -> Option<String> {
        String::from_utf8(self.bpe.decode_tokens(tokens)).ok()
    }

    /// Returns an iterator with the text pieces resulting from pre-tokenization. If this
    /// tokenizer does not have pre-tokenization, the iterator returns the full text.
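    ///
    /// A small sketch of the produced pieces (for the cl100k pre-tokenizer; the crate name
    /// `bpe_openai` is assumed from the file path above):
    /// ```
    /// let tok = bpe_openai::cl100k_base();
    /// let pieces: Vec<_> = tok.split("Hello, world!").collect();
    /// assert_eq!(pieces, vec!["Hello", ",", " world", "!"]);
    /// ```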
    pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &'a str> + 'a {
        match &self.pre {
            Some(pre) => Either::Left(pre.split(text)),
            None => Either::Right(std::iter::once(text)),
        }
    }
}

impl Pretokenizer {
    /// Build a pretokenizer from the given regex pattern.
    #[allow(clippy::result_large_err)]
    fn new(pat: &str) -> Result<Self, BuildError> {
        let pat = Regex::new(pat)?;
        Ok(Self {
            pat,
            lookahead: vec![false],
        })
    }

    /// Build a pretokenizer from the given regex patterns. If the boolean for a pattern is true,
    /// the pattern is assumed to be a look-ahead pattern with exactly one look-ahead character!
    #[allow(clippy::result_large_err)]
    fn new_lookahead(pats: &[(&str, bool)]) -> Result<Self, BuildError> {
        let (pats, lookahead): (Vec<_>, _) = pats.iter().copied().unzip();
        let pat = Regex::new_many(&pats)?;
        Ok(Self { pat, lookahead })
    }

    /// Returns an iterator with the text pieces after splitting with the regular expression.
    pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &'a str> + 'a {
        Splits {
            pat: &self.pat,
            lookahead: &self.lookahead,
            text,
            last: 0,
            caps: Captures::matches(self.pat.group_info().clone()),
        }
    }
}

/// This is a small wrapper around the regex which emulates the behaviour of look-ahead by
/// dropping the look-ahead character from the match. The assumption here is that a pattern
/// flagged as look-ahead (see `Pretokenizer::new_lookahead`) ends in exactly one look-ahead
/// character that needs to be dropped. With this little hack, we can keep most of the regex
/// patterns as they are, but achieve a >3x speedup.
///
/// Alternatively, this could have been implemented with capture groups, but those were ~30%
/// slower than this approach with multiple patterns.
struct Splits<'a> {
    pat: &'a Regex,
    lookahead: &'a [bool],
    text: &'a str,
    last: usize,
    caps: Captures,
}

impl<'a> Iterator for Splits<'a> {
    type Item = &'a str;

    fn next(&mut self) -> Option<Self::Item> {
        let input = Input::new(&self.text[self.last..]).anchored(Anchored::Yes);
        self.caps.clear();
        self.pat.captures(input, &mut self.caps);
        let m = self.caps.get_match()?;
        let start = self.last;
        let mut end = self.last + m.range().end;
        if self.lookahead[m.pattern().as_usize()] {
            let last = self.text[start..end]
                .chars()
                .next_back()
                .expect("Expected at least a look-ahead character!");
            end -= last.len_utf8();
            assert_ne!(end, start, "a look-ahead pattern must ALWAYS consume at least one character excluding the look-ahead character!");
        }
        self.last = end;
        Some(&self.text[start..end])
    }
}

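/// Returns a reference to the lazily initialized tokenizer for OpenAI's `cl100k_base` encoding.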
pub fn cl100k_base() -> &'static Tokenizer {
    &BPE_CL100K_BASE
}

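/// Returns a reference to the lazily initialized tokenizer for OpenAI's `o200k_base` encoding.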
pub fn o200k_base() -> &'static Tokenizer {
    &BPE_O200K_BASE
}

#[cfg(test)]
mod tests {
    use bpe::byte_pair_encoding::{create_test_string, select_test_string};
    use tiktoken_rs::{cl100k_base_singleton, o200k_base_singleton, CoreBPE};

    use super::*;

    #[test]
    fn test_cl100k() {
        test_equivalence(cl100k_base(), &cl100k_base_singleton().lock());
    }

    #[test]
    fn test_o200k() {
        test_equivalence(o200k_base(), &o200k_base_singleton().lock());
    }

    #[track_caller]
    fn test_equivalence(tok: &Tokenizer, tiktoken: &CoreBPE) {
        let text = create_test_string(&tok.bpe, 80_000);
        for bytes in [10, 100, 1000, 10_000] {
            for _ in 0..32 {
                let text = select_test_string(&text, bytes);
                let tokens = tok.encode(text);
                let tiktokens = tiktoken.encode_ordinary(text).to_vec();
                assert_eq!(tokens, tiktokens, "encoding mismatch for {text:?}");
            }
        }
    }

    #[test]
    fn test_count_till_limit() {
        assert_eq!(cl100k_base().count_till_limit("abc", 3), Some(1));
        assert_eq!(cl100k_base().count_till_limit("abcabc", 3), Some(2));
        assert_eq!(cl100k_base().count_till_limit("abcabcabc", 3), Some(3));
        assert_eq!(cl100k_base().count_till_limit("abcabcabcabc", 3), None);
    }
}