use std::sync::LazyLock;

use bpe::byte_pair_encoding::BytePairEncoding;
use either::Either;
use regex_automata::{
    meta::{BuildError, Regex},
    util::captures::Captures,
    Anchored, Input,
};

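// Lazily-constructed tokenizers for OpenAI's cl100k_base and o200k_base vocabularies.
// The serialized BPE dictionaries are embedded from OUT_DIR, i.e. produced at build time.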
static BPE_CL100K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_cl100k_base.dict"));
    let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
    let pat1 = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+$";
    // `pat2` stands in for the negative look-ahead `\s+(?!\S)` of the upstream pattern,
    // which regex_automata does not support: the trailing whitespace character only
    // serves as look-ahead and is dropped again by the tokenizer.
    let pat2 = "\\s+\\s";
    let pat3 = "\\s+";
    Tokenizer::new_lookahead(bpe, &[(pat1, false), (pat2, true), (pat3, false)])
        .expect("valid regex")
});

static BPE_O200K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_o200k_base.dict"));
    let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
    let pat1 = [
        "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
        "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
        "\\p{N}{1,3}",
        " ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*",
        "\\s*[\\r\\n]+",
        "\\s+$",
    ].join("|");
    // Same pseudo look-ahead trick as for cl100k_base above.
    let pat2 = "\\s+\\s";
    let pat3 = "\\s+";
    Tokenizer::new_lookahead(bpe, &[(&pat1, false), (pat2, true), (pat3, false)])
        .expect("valid regex")
});

pub use bpe::*;

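/// A tokenizer that couples a byte-pair encoding with an optional pre-tokenization
/// regex: the input is first split into pieces by the regex, and each piece is then
/// encoded independently with the BPE.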
pub struct Tokenizer {
    /// The byte-pair encoding used to encode the (pre-tokenized) pieces.
    pub bpe: BytePairEncoding,
    /// The optional pre-tokenizer that splits the input before BPE encoding.
    pub pre: Option<Pretokenizer>,
}

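/// A regex-based pre-tokenizer that splits text into pieces which are encoded
/// independently by the BPE.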
pub struct Pretokenizer {
    /// The combined regex containing all patterns.
    pat: Regex,
    /// For each pattern, whether its last matched character is look-ahead only.
    lookahead: Vec<bool>,
}

impl Tokenizer {
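    /// Creates a tokenizer with an optional pre-tokenization pattern.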
    #[allow(clippy::result_large_err)]
    pub fn new(bpe: BytePairEncoding, pat: Option<&str>) -> Result<Self, BuildError> {
        let pre = pat.map(Pretokenizer::new).transpose()?;
        Ok(Self { bpe, pre })
    }

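    /// Creates a tokenizer from several patterns. Patterns flagged with `true` are
    /// pseudo look-ahead patterns: their last matched character is treated as
    /// look-ahead only and is not consumed.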
    #[allow(clippy::result_large_err)]
    pub fn new_lookahead(
        bpe: BytePairEncoding,
        patterns: &[(&str, bool)],
    ) -> Result<Self, BuildError> {
        let pre = Some(Pretokenizer::new_lookahead(patterns)?);
        Ok(Self { bpe, pre })
    }

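    /// Counts the number of tokens produced when encoding the given text.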
    pub fn count(&self, text: &str) -> usize {
        self.split(text)
            .map(|piece| self.bpe.count(piece.as_bytes()))
            .sum()
    }

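    /// Counts tokens, but stops as soon as `token_limit` would be exceeded.
    /// Returns `None` in that case, and the total count otherwise.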
    pub fn count_till_limit(&self, text: &str, token_limit: usize) -> Option<usize> {
        self.split(text).try_fold(0, |consumed, piece| {
            self.bpe
                .count_till_limit(piece.as_bytes(), token_limit - consumed)
                .map(|piece_count| consumed + piece_count)
        })
    }

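    /// Encodes the text into its token ids.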
    pub fn encode(&self, text: &str) -> Vec<u32> {
        self.split(text)
            .flat_map(|piece| self.bpe.encode_via_backtracking(piece.as_bytes()))
            .collect()
    }
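
    /// Decodes token ids back into a string. Returns `None` if the decoded bytes
    /// are not valid UTF-8.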
    pub fn decode(&self, tokens: &[u32]) -> Option<String> {
        String::from_utf8(self.bpe.decode_tokens(tokens)).ok()
    }

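    /// Splits the text with the pre-tokenizer, or yields the whole text as a single
    /// piece if no pre-tokenizer is configured.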
    pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &'a str> + 'a {
        match &self.pre {
            Some(pre) => Either::Left(pre.split(text)),
            None => Either::Right(std::iter::once(text)),
        }
    }
}

impl Pretokenizer {
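    /// Creates a pre-tokenizer from a single pattern without look-ahead.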
    #[allow(clippy::result_large_err)]
    fn new(pat: &str) -> Result<Self, BuildError> {
        let pat = Regex::new(pat)?;
        Ok(Self {
            pat,
            lookahead: vec![false],
        })
    }

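    /// Creates a pre-tokenizer from several patterns. For patterns flagged with
    /// `true`, the last matched character is treated as look-ahead and not consumed.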
    #[allow(clippy::result_large_err)]
    fn new_lookahead(pats: &[(&str, bool)]) -> Result<Self, BuildError> {
        let (pats, lookahead): (Vec<_>, _) = pats.iter().copied().unzip();
        let pat = Regex::new_many(&pats)?;
        Ok(Self { pat, lookahead })
    }

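    /// Returns an iterator over the pieces of `text` matched by the patterns.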
    pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &'a str> + 'a {
        Splits {
            pat: &self.pat,
            lookahead: &self.lookahead,
            text,
            last: 0,
            caps: Captures::matches(self.pat.group_info().clone()),
        }
    }
}

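/// Iterator over the pieces of a pre-tokenized text. Each step matches the anchored
/// regex at the current position and yields the matched piece, minus the final
/// character when the matching pattern is a look-ahead pattern.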
struct Splits<'a> {
    pat: &'a Regex,
    lookahead: &'a [bool],
    text: &'a str,
    last: usize,
    caps: Captures,
}

impl<'a> Iterator for Splits<'a> {
    type Item = &'a str;

    fn next(&mut self) -> Option<Self::Item> {
        // Match anchored at the current position; no match means we are done.
        let input = Input::new(&self.text[self.last..]).anchored(Anchored::Yes);
        self.caps.clear();
        self.pat.captures(input, &mut self.caps);
        let m = self.caps.get_match()?;
        let start = self.last;
        let mut end = self.last + m.range().end;
        if self.lookahead[m.pattern().as_usize()] {
            // Drop the final character again: it only served as look-ahead.
            let last = self.text[start..end]
                .chars()
                .next_back()
                .expect("Expected at least a look-ahead character!");
            end -= last.len_utf8();
            assert_ne!(end, start, "a look-ahead pattern must ALWAYS consume at least one character excluding the look-ahead character!");
        }
        self.last = end;
        Some(&self.text[start..end])
    }
}

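/// Returns the lazily-initialized cl100k_base tokenizer.
///
/// A minimal usage sketch (token ids depend on the vocabulary and are not shown here):
///
/// ```ignore
/// let tok = cl100k_base();
/// let tokens = tok.encode("Hello, world!");
/// assert_eq!(tok.count("Hello, world!"), tokens.len());
/// assert_eq!(tok.decode(&tokens).as_deref(), Some("Hello, world!"));
/// ```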
pub fn cl100k_base() -> &'static Tokenizer {
    &BPE_CL100K_BASE
}

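/// Returns the lazily-initialized o200k_base tokenizer.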
pub fn o200k_base() -> &'static Tokenizer {
    &BPE_O200K_BASE
}

#[cfg(test)]
mod tests {
    use bpe::byte_pair_encoding::{create_test_string, select_test_string};
    use tiktoken_rs::{cl100k_base_singleton, o200k_base_singleton, CoreBPE};

    use super::*;

    #[test]
    fn test_cl100k() {
        test_equivalence(cl100k_base(), &cl100k_base_singleton().lock());
    }

    #[test]
    fn test_o200k() {
        test_equivalence(o200k_base(), &o200k_base_singleton().lock());
    }

    /// Checks that this tokenizer produces exactly the same token ids as tiktoken-rs
    /// on randomly selected substrings of a generated test string.
    #[track_caller]
    fn test_equivalence(tok: &Tokenizer, tiktoken: &CoreBPE) {
        let text = create_test_string(&tok.bpe, 80_000);
        for bytes in [10, 100, 1000, 10_000] {
            for _ in 0..32 {
                let text = select_test_string(&text, bytes);
                let tokens = tok.encode(text);
                let tiktokens = tiktoken.encode_ordinary(text).to_vec();
                assert_eq!(tokens, tiktokens, "encoding mismatch for {text:?}");
            }
        }
    }

    #[test]
    fn test_count_till_limit() {
        assert_eq!(cl100k_base().count_till_limit("abc", 3), Some(1));
        assert_eq!(cl100k_base().count_till_limit("abcabc", 3), Some(2));
        assert_eq!(cl100k_base().count_till_limit("abcabcabc", 3), Some(3));
        assert_eq!(cl100k_base().count_till_limit("abcabcabcabc", 3), None);
    }
}