use std::sync::LazyLock;

use bpe::byte_pair_encoding::BytePairEncoding;
use either::Either;
use regex_automata::{
    meta::{BuildError, Regex},
    util::captures::Captures,
    Anchored, Input,
};

pub mod normalizer;

pub use bpe::*;
pub use normalizer::{Normalizable, NormalizedString};

static BPE_CL100K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_cl100k_base.dict"));
    let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
    let pat1 = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+$";
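    // pat2/pat3 below presumably stand in for the `\s+(?!\S)|\s+` tail of the original
    // OpenAI pattern, since regex_automata does not support look-around: pat2 matches a
    // whitespace run plus one extra whitespace character, and its `true` flag passed to
    // `new_lookahead` marks that final character as look-ahead only, so it is not
    // consumed (see `Splits::next`). pat3 covers any remaining whitespace.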
    let pat2 = "\\s+\\s";
    let pat3 = "\\s+";
    Tokenizer::new_lookahead(bpe, &[(pat1, false), (pat2, true), (pat3, false)], false)
        .expect("valid regex")
});

static BPE_O200K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_o200k_base.dict"));
    let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
    let pat1 = [
        "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
        "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
        "\\p{N}{1,3}",
        " ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*",
        "\\s*[\\r\\n]+",
        "\\s+$",
    ]
    .join("|");
    // Same look-ahead emulation as for cl100k above.
    let pat2 = "\\s+\\s";
    let pat3 = "\\s+";
    Tokenizer::new_lookahead(bpe, &[(&pat1, false), (pat2, true), (pat3, false)], false)
        .expect("valid regex")
});

static BPE_VOYAGE3_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_voyage3_base.dict"));
    let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
    let pat1 = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+$";
    let pat2 = "\\s+\\s";
    let pat3 = "\\s+";
    Tokenizer::new_lookahead(bpe, &[(pat1, false), (pat2, true), (pat3, false)], true)
        .expect("valid regex")
});

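/// A tokenizer that pairs a byte-level BPE with an optional regex pre-tokenizer and an
/// optional NFC normalization step.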
pub struct Tokenizer {
    /// The byte-pair encoding used to turn pieces into token ids.
    pub bpe: BytePairEncoding,
    /// The pre-tokenizer used to split the input before BPE encoding, if any.
    pub pre: Option<Pretokenizer>,
    /// Whether the input is NFC-normalized before tokenization.
    nfc: bool,
}

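/// Splits text into pieces with one or more regex patterns before BPE encoding, so that
/// token boundaries never cross piece boundaries.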
pub struct Pretokenizer {
    /// The combined multi-pattern regex used for splitting.
    pat: Regex,
    /// For each pattern, whether the last matched character is look-ahead only.
    lookahead: Vec<bool>,
}

impl Tokenizer {
    /// Creates a tokenizer from a byte-pair encoding, an optional pre-tokenization regex,
    /// and a flag indicating whether input should be NFC-normalized first.
    #[allow(clippy::result_large_err)]
    pub fn new(bpe: BytePairEncoding, pat: Option<&str>, nfc: bool) -> Result<Self, BuildError> {
        let pre = pat.map(Pretokenizer::new).transpose()?;
        Ok(Self { nfc, bpe, pre })
    }

    /// Creates a tokenizer with multiple pre-tokenization patterns. For every pattern, the
    /// boolean indicates whether the last character of a match is look-ahead only and must
    /// not be consumed (see `Splits`).
    #[allow(clippy::result_large_err)]
    pub fn new_lookahead(
        bpe: BytePairEncoding,
        patterns: &[(&str, bool)],
        nfc: bool,
    ) -> Result<Self, BuildError> {
        let pre = Some(Pretokenizer::new_lookahead(patterns)?);
        Ok(Self { nfc, bpe, pre })
    }

    /// Counts the number of tokens produced for the given text.
    pub fn count<'a, I: Normalizable<'a>>(&self, text: I) -> usize {
        let text = self.normalize(text);
        self.split(text.as_str())
            .map(|piece| self.bpe.count(piece.as_bytes()))
            .sum()
    }

    /// Counts tokens for an already normalized text, but stops early and returns `None`
    /// once the count exceeds `token_limit`.
    pub fn count_till_limit(&self, text: &NormalizedString, token_limit: usize) -> Option<usize> {
        let res: Option<usize> = self.split(text.as_str()).try_fold(0, |consumed, piece| {
            self.bpe
                .count_till_limit(piece.as_bytes(), token_limit - consumed)
                .map(|piece_count| consumed + piece_count)
        });
        res
    }

    /// Encodes the given text into its token ids.
    pub fn encode<'a, I: Normalizable<'a>>(&self, text: I) -> Vec<u32> {
        let text: NormalizedString<'_> = self.normalize(text);
        self.split(text.as_str())
            .flat_map(|piece| self.bpe.encode_via_backtracking(piece.as_bytes()))
            .collect()
    }

    /// Decodes the given token ids back into a string. Returns `None` if the token
    /// sequence does not decode to valid UTF-8.
    pub fn decode(&self, tokens: &[u32]) -> Option<String> {
        String::from_utf8(self.bpe.decode_tokens(tokens)).ok()
    }

    /// Splits the text into pre-tokenization pieces. Without a pre-tokenizer, the whole
    /// text is returned as a single piece.
    pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &'a str> {
        match &self.pre {
            Some(pre) => Either::Left(pre.split(text)),
            None => Either::Right(std::iter::once(text)),
        }
    }

    /// Normalizes the text (NFC) if this tokenizer requires it.
    pub fn normalize<'a, I: Normalizable<'a>>(&self, text: I) -> NormalizedString<'a> {
        text.normalize(self.nfc)
    }
}

impl Pretokenizer {
    /// Creates a pre-tokenizer from a single pattern without look-ahead.
    #[allow(clippy::result_large_err)]
    fn new(pat: &str) -> Result<Self, BuildError> {
        let pat = Regex::new(pat)?;
        Ok(Self {
            pat,
            lookahead: vec![false],
        })
    }

    /// Creates a pre-tokenizer from several patterns. For every pattern, the boolean
    /// indicates whether the last character of a match is look-ahead only.
    #[allow(clippy::result_large_err)]
    fn new_lookahead(pats: &[(&str, bool)]) -> Result<Self, BuildError> {
        let (pats, lookahead): (Vec<_>, _) = pats.iter().copied().unzip();
        let pat = Regex::new_many(&pats)?;
        Ok(Self { pat, lookahead })
    }

    /// Returns an iterator over the pieces of the given text.
    pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &'a str> {
        Splits {
            pat: &self.pat,
            lookahead: &self.lookahead,
            text,
            last: 0,
            caps: Captures::matches(self.pat.group_info().clone()),
        }
    }
}

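/// Iterator over the pre-tokenization pieces of a text.
///
/// Each call to `next` runs the multi-pattern regex anchored at the current position. If
/// the matching pattern is flagged as a look-ahead pattern, the final character of the
/// match is excluded from the returned piece and the search resumes at that character.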
struct Splits<'a> {
    /// The multi-pattern regex used for splitting.
    pat: &'a Regex,
    /// For each pattern, whether its last matched character is look-ahead only.
    lookahead: &'a [bool],
    /// The full text being split.
    text: &'a str,
    /// Offset of the first byte that has not been consumed yet.
    last: usize,
    /// Reusable capture buffer for the regex searches.
    caps: Captures,
}

impl<'a> Iterator for Splits<'a> {
    type Item = &'a str;

    fn next(&mut self) -> Option<Self::Item> {
        // Match the pattern set anchored at the current position; no match ends the iteration.
        let input = Input::new(&self.text[self.last..]).anchored(Anchored::Yes);
        self.caps.clear();
        self.pat.captures(input, &mut self.caps);
        let m = self.caps.get_match()?;
        let start = self.last;
        let mut end = self.last + m.range().end;
        if self.lookahead[m.pattern().as_usize()] {
            // Drop the final character of the match: it only served as look-ahead and
            // must be matched again on the next call.
            let last = self.text[start..end]
                .chars()
                .next_back()
                .expect("Expected at least a look-ahead character!");
            end -= last.len_utf8();
            assert_ne!(end, start, "a look-ahead pattern must ALWAYS consume at least one character excluding the look-ahead character!");
        }
        self.last = end;
        Some(&self.text[start..end])
    }
}

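/// Returns the shared tokenizer instance for OpenAI's `cl100k_base` token set.
///
/// A minimal usage sketch (marked `ignore` because the crate path `bpe_openai` is an
/// assumption here; token ids depend on the vocabulary, so only the count and the decode
/// round-trip are asserted):
///
/// ```ignore
/// let tok = bpe_openai::cl100k_base();
/// let tokens = tok.encode("Hello, world!");
/// assert_eq!(tok.count("Hello, world!"), tokens.len());
/// assert_eq!(tok.decode(&tokens).as_deref(), Some("Hello, world!"));
/// ```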
pub fn cl100k_base() -> &'static Tokenizer {
    &BPE_CL100K_BASE
}

/// Returns the shared tokenizer instance for OpenAI's `o200k_base` token set.
pub fn o200k_base() -> &'static Tokenizer {
    &BPE_O200K_BASE
}

/// Returns the shared tokenizer instance for the `voyage3_base` token set.
pub fn voyage3_base() -> &'static Tokenizer {
    &BPE_VOYAGE3_BASE
}

#[cfg(test)]
mod tests {
    use bpe::byte_pair_encoding::{create_test_string, select_test_string};
    use tiktoken_rs::{cl100k_base_singleton, o200k_base_singleton, CoreBPE};

    use super::*;

    #[test]
    fn test_cl100k() {
        test_equivalence(cl100k_base(), &cl100k_base_singleton().lock());
    }

    #[test]
    fn test_o200k() {
        test_equivalence(o200k_base(), &o200k_base_singleton().lock());
    }

    #[track_caller]
    fn test_equivalence(tok: &Tokenizer, tiktoken: &CoreBPE) {
        let text = create_test_string(&tok.bpe, 80_000);
        for bytes in [10, 100, 1000, 10_000] {
            for _ in 0..32 {
                let text = select_test_string(&text, bytes);
                let tokens = tok.encode(text);
                let tiktokens = tiktoken.encode_ordinary(text).to_vec();
                assert_eq!(tokens, tiktokens, "encoding mismatch for {text:?}");
            }
        }
    }

    #[test]
    fn test_count_till_limit() {
        assert_eq!(
            cl100k_base().count_till_limit(&cl100k_base().normalize("abc"), 3),
            Some(1)
        );
        assert_eq!(
            cl100k_base().count_till_limit(&cl100k_base().normalize("abcabc"), 3),
            Some(2)
        );
        assert_eq!(
            cl100k_base().count_till_limit(&cl100k_base().normalize("abcabcabc"), 3),
            Some(3)
        );
        assert_eq!(
            cl100k_base().count_till_limit(&cl100k_base().normalize("abcabcabcabc"), 3),
            None
        );
    }
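
    // An added sanity check (not part of the upstream test suite): byte-level BPE should
    // round-trip losslessly through encode/decode for already NFC-normalized UTF-8 input.
    #[test]
    fn test_encode_decode_roundtrip() {
        let text = "Hello, world! \u{1F600} tabs\tand\nnewlines   ";
        for tok in [cl100k_base(), o200k_base(), voyage3_base()] {
            let tokens = tok.encode(text);
            assert_eq!(tok.decode(&tokens).as_deref(), Some(text));
        }
    }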
}