use std::sync::LazyLock;
use bpe::byte_pair_encoding::BytePairEncoding;
use either::Either;
use regex_automata::{
meta::{BuildError, Regex},
util::captures::Captures,
Anchored, Input,
};
pub mod normalizer;
pub use bpe::*;
pub use normalizer::{Normalizable, NormalizedString};
static BPE_CL100K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_cl100k_base.dict"));
let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
let pat1 = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+$";
let pat2 = "\\s+\\s";
let pat3 = "\\s+";
Tokenizer::new_lookahead(bpe, &[(pat1, false), (pat2, true), (pat3, false)], false)
.expect("valid regex")
});
static BPE_O200K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_o200k_base.dict"));
let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
let pat1 = [
"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
"\\p{N}{1,3}",
" ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*",
"\\s*[\\r\\n]+",
"\\s+$",
].join("|");
let pat2 = "\\s+\\s";
let pat3 = "\\s+";
Tokenizer::new_lookahead(bpe, &[(&pat1, false), (pat2, true), (pat3, false)], false)
.expect("valid regex")
});
static BPE_VOYAGE3_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_voyage3_base.dict"));
let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
let pat1 = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+$";
let pat2 = "\\s+\\s";
let pat3 = "\\s+";
Tokenizer::new_lookahead(bpe, &[(pat1, false), (pat2, true), (pat3, false)], true)
.expect("valid regex")
});
/// A byte pair encoding combined with an optional regex pre-tokenizer and a
/// normalization flag.
///
/// Text is first normalized, then split into pieces by the pre-tokenizer (if
/// any), and each piece is handed to the byte pair encoding separately.
pub struct Tokenizer {
/// The byte pair encoding applied to every piece of the split text.
pub bpe: BytePairEncoding,
/// Optional pre-tokenizer; when `None`, the whole text is treated as a
/// single piece.
pub pre: Option<Pretokenizer>,
/// Flag forwarded to `Normalizable::normalize`.
/// NOTE(review): presumably enables Unicode NFC normalization — confirm in
/// the `normalizer` module.
nfc: bool,
}
/// Regex-based splitter that cuts a text into consecutive pieces.
pub struct Pretokenizer {
/// Compiled regex (one or several patterns) that is matched anchored at the
/// current split position.
pat: Regex,
/// One flag per pattern of `pat`, indexed by pattern id. When the flag is
/// `true`, the last character of a match of that pattern is treated as
/// look-ahead and is excluded from the produced piece.
lookahead: Vec<bool>,
}
impl Tokenizer {
    /// Creates a tokenizer from a byte pair encoding and an optional single
    /// pre-tokenization pattern.
    ///
    /// When `pat` is `None`, no pre-tokenizer is installed and every text is
    /// processed as one piece. `nfc` is the flag later forwarded to
    /// [`Normalizable::normalize`].
    ///
    /// # Errors
    ///
    /// Returns the regex `BuildError` if `pat` cannot be compiled.
    // The error type comes from `regex_automata` and is passed through as-is,
    // hence the silenced lint.
    #[allow(clippy::result_large_err)]
    pub fn new(bpe: BytePairEncoding, pat: Option<&str>, nfc: bool) -> Result<Self, BuildError> {
        let pre = match pat {
            Some(pattern) => Some(Pretokenizer::new(pattern)?),
            None => None,
        };
        Ok(Self { bpe, pre, nfc })
    }

    /// Creates a tokenizer whose pre-tokenizer is built from several patterns.
    ///
    /// Every entry of `patterns` is `(pattern, is_lookahead)`. For patterns
    /// flagged as look-ahead, the last matched character is not part of the
    /// piece produced by the pre-tokenizer.
    ///
    /// # Errors
    ///
    /// Returns the regex `BuildError` if the patterns cannot be compiled.
    // The error type comes from `regex_automata` and is passed through as-is,
    // hence the silenced lint.
    #[allow(clippy::result_large_err)]
    pub fn new_lookahead(
        bpe: BytePairEncoding,
        patterns: &[(&str, bool)],
        nfc: bool,
    ) -> Result<Self, BuildError> {
        Pretokenizer::new_lookahead(patterns).map(|pre| Self {
            bpe,
            pre: Some(pre),
            nfc,
        })
    }

    /// Returns the number of tokens of `text`.
    ///
    /// The text is normalized, split into pieces, and the per-piece counts of
    /// the byte pair encoding are added up.
    pub fn count<'a, I: Normalizable<'a>>(&self, text: I) -> usize {
        let normalized = self.normalize(text);
        let mut total = 0;
        for piece in self.split(normalized.as_str()) {
            total += self.bpe.count(piece.as_bytes());
        }
        total
    }

    /// Counts the tokens of an already normalized `text`, giving up once a
    /// piece no longer fits into the remaining budget.
    ///
    /// Each piece is counted with the part of `token_limit` that previous
    /// pieces have not used yet. Returns `Some(total)` when every piece was
    /// counted, and `None` as soon as the byte pair encoding reports `None`
    /// for a piece.
    pub fn count_till_limit(&self, text: &NormalizedString, token_limit: usize) -> Option<usize> {
        let mut consumed = 0;
        for piece in self.split(text.as_str()) {
            // NOTE(review): this subtraction relies on the per-piece count never
            // exceeding the limit it was given — confirm in `BytePairEncoding`.
            let remaining = token_limit - consumed;
            consumed += self.bpe.count_till_limit(piece.as_bytes(), remaining)?;
        }
        Some(consumed)
    }

    /// Encodes `text` into token ids.
    ///
    /// The text is normalized, split into pieces, and each piece is encoded
    /// with `encode_via_backtracking`; the results are concatenated in order.
    pub fn encode<'a, I: Normalizable<'a>>(&self, text: I) -> Vec<u32> {
        let normalized: NormalizedString<'_> = self.normalize(text);
        let mut tokens = Vec::new();
        for piece in self.split(normalized.as_str()) {
            tokens.extend(self.bpe.encode_via_backtracking(piece.as_bytes()));
        }
        tokens
    }

    /// Decodes `tokens` back into a string.
    ///
    /// Returns `None` when the decoded bytes are not valid UTF-8.
    pub fn decode(&self, tokens: &[u32]) -> Option<String> {
        let bytes = self.bpe.decode_tokens(tokens);
        match String::from_utf8(bytes) {
            Ok(decoded) => Some(decoded),
            Err(_) => None,
        }
    }

    /// Splits `text` into the pieces that are tokenized independently.
    ///
    /// Without a pre-tokenizer the iterator yields `text` itself exactly once.
    pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &'a str> {
        if let Some(pre) = &self.pre {
            Either::Left(pre.split(text))
        } else {
            Either::Right(std::iter::once(text))
        }
    }

    /// Normalizes `text` using this tokenizer's normalization flag.
    pub fn normalize<'a, I: Normalizable<'a>>(&self, text: I) -> NormalizedString<'a> {
        let nfc = self.nfc;
        text.normalize(nfc)
    }
}
impl Pretokenizer {
    /// Builds a pre-tokenizer from a single pattern that is not treated as a
    /// look-ahead pattern.
    ///
    /// # Errors
    ///
    /// Returns the regex `BuildError` if `pat` cannot be compiled.
    // The error type comes from `regex_automata` and is passed through as-is,
    // hence the silenced lint.
    #[allow(clippy::result_large_err)]
    fn new(pat: &str) -> Result<Self, BuildError> {
        Regex::new(pat).map(|pat| Self {
            pat,
            lookahead: vec![false],
        })
    }

    /// Builds a pre-tokenizer from `(pattern, is_lookahead)` pairs.
    ///
    /// The patterns are compiled into one multi-pattern regex; the flags are
    /// kept in the same order so they can be looked up by pattern id.
    ///
    /// # Errors
    ///
    /// Returns the regex `BuildError` if the patterns cannot be compiled.
    // The error type comes from `regex_automata` and is passed through as-is,
    // hence the silenced lint.
    #[allow(clippy::result_large_err)]
    fn new_lookahead(pats: &[(&str, bool)]) -> Result<Self, BuildError> {
        let sources: Vec<&str> = pats.iter().map(|&(source, _)| source).collect();
        let lookahead: Vec<bool> = pats.iter().map(|&(_, flag)| flag).collect();
        Regex::new_many(&sources).map(|pat| Self { pat, lookahead })
    }

    /// Returns an iterator over the consecutive pieces of `text`, starting at
    /// byte offset 0.
    pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &'a str> {
        // Capture storage is allocated once per split and reused for every piece.
        let caps = Captures::matches(self.pat.group_info().clone());
        Splits {
            text,
            last: 0,
            pat: &self.pat,
            lookahead: &self.lookahead,
            caps,
        }
    }
}
/// Iterator state for splitting a text into pieces with a `Pretokenizer`.
struct Splits<'a> {
/// Regex matched anchored at `last` to find the next piece.
pat: &'a Regex,
/// Per-pattern look-ahead flags, indexed by the id of the matching pattern.
lookahead: &'a [bool],
/// The complete text being split.
text: &'a str,
/// Byte offset into `text` at which the next piece starts.
last: usize,
/// Capture storage that is cleared and reused for every match attempt.
caps: Captures,
}
impl<'a> Iterator for Splits<'a> {
    type Item = &'a str;

    /// Yields the next piece of the text, or `None` when no further piece can
    /// be produced.
    ///
    /// The regex is matched anchored at the current offset. Iteration ends
    /// when no pattern matches there, or when the match is empty: a
    /// zero-length match would not advance the offset, so the iterator would
    /// otherwise yield `""` forever.
    ///
    /// # Panics
    ///
    /// Panics if a look-ahead pattern matched nothing besides its look-ahead
    /// character.
    fn next(&mut self) -> Option<Self::Item> {
        let start = self.last;
        let input = Input::new(&self.text[start..]).anchored(Anchored::Yes);
        self.caps.clear();
        self.pat.captures(input, &mut self.caps);
        let m = self.caps.get_match()?;
        // An empty match makes no progress; treat it like "no match" so that
        // patterns which can match the empty string cannot loop endlessly.
        if m.is_empty() {
            return None;
        }
        // The match range is relative to the sub-slice starting at `start`.
        let mut end = start + m.range().end;
        if self.lookahead[m.pattern().as_usize()] {
            // The final character of a look-ahead match only served as context:
            // give it back so that it becomes part of the following piece.
            let last = self.text[start..end]
                .chars()
                .next_back()
                .expect("Expected at least a look-ahead character!");
            end -= last.len_utf8();
            assert_ne!(end, start, "a look-ahead pattern must ALWAYS consume at least one character excluding the look-ahead character!");
        }
        self.last = end;
        Some(&self.text[start..end])
    }
}
pub fn cl100k_base() -> &'static Tokenizer {
&BPE_CL100K_BASE
}
pub fn o200k_base() -> &'static Tokenizer {
&BPE_O200K_BASE
}
pub fn voyage3_base() -> &'static Tokenizer {
&BPE_VOYAGE3_BASE
}
#[cfg(test)]
mod tests {
    use bpe::byte_pair_encoding::{create_test_string, select_test_string};
    use tiktoken_rs::{cl100k_base_singleton, o200k_base_singleton, CoreBPE};

    use super::*;

    /// Checks that `tok` and the tiktoken reference encoder produce the same
    /// tokens for samples of several sizes taken from a generated test string.
    #[track_caller]
    fn test_equivalence(tok: &Tokenizer, tiktoken: &CoreBPE) {
        let corpus = create_test_string(&tok.bpe, 80_000);
        for sample_size in [10, 100, 1000, 10_000] {
            for _ in 0..32 {
                let text = select_test_string(&corpus, sample_size);
                let ours = tok.encode(text);
                let reference = tiktoken.encode_ordinary(text).to_vec();
                assert_eq!(ours, reference, "encoding mismatch for {text:?}");
            }
        }
    }

    #[test]
    fn test_cl100k() {
        let reference = cl100k_base_singleton().lock();
        test_equivalence(cl100k_base(), &reference);
    }

    #[test]
    fn test_o200k() {
        let reference = o200k_base_singleton().lock();
        test_equivalence(o200k_base(), &reference);
    }

    /// With a limit of 3, inputs of up to three tokens report their count and
    /// a longer input reports `None`.
    #[test]
    fn test_count_till_limit() {
        let tok = cl100k_base();
        let cases: [(&str, Option<usize>); 4] = [
            ("abc", Some(1)),
            ("abcabc", Some(2)),
            ("abcabcabc", Some(3)),
            ("abcabcabcabc", None),
        ];
        for (input, expected) in cases {
            assert_eq!(tok.count_till_limit(&tok.normalize(input), 3), expected);
        }
    }
}