use regex::Regex;
/// A splitter that walks a text one piece at a time.
///
/// Implementors must be `Send + Sync` so a pre-tokenizer can be shared across threads.
pub trait PreTokenizer: Send + Sync {
/// Looks for the next piece of `text`, searching from byte offset `pos`.
///
/// Returns the piece as a `(start, end)` pair of byte offsets into `text`, or `None`
/// when no further piece is found. Callers in this file advance by passing the
/// returned `end` back in as the next `pos`.
fn next_match(&self, text: &str, pos: usize) -> Option<(usize, usize)>;
}
/// Selects which hand-written ASCII scanner [`RegexPreTokenizer`] tries before
/// falling back to its compiled regex.
///
/// `Debug` is derived so values can appear in `assert_eq!`, `dbg!` and `{:?}` output.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum FastPath {
    /// No scanner: every piece is produced by the regex.
    None,
    /// The cl100k scanner, allowing up to three digits per number piece.
    Cl100k,
    /// The o200k scanner.
    O200k,
    /// The cl100k scanner restricted to a single digit per number piece.
    Qwen2,
    /// The DeepSeek scanner.
    Deepseek,
}
/// A [`PreTokenizer`] backed by a compiled [`Regex`], with an optional ASCII
/// fast path that is tried first.
pub struct RegexPreTokenizer {
/// Compiled split pattern; consulted whenever the fast path yields no match.
regex: Regex,
/// Which ASCII scanner (if any) is attempted before the regex.
fast: FastPath,
}
impl RegexPreTokenizer {
    /// Compiles `pattern` and pairs it with the chosen fast path.
    ///
    /// # Panics
    ///
    /// Panics with "invalid regex pattern" if `pattern` does not compile.
    pub(crate) fn new(pattern: &str, fast: FastPath) -> Self {
        let regex = Regex::new(pattern).expect("invalid regex pattern");
        Self { regex, fast }
    }
}
impl PreTokenizer for RegexPreTokenizer {
    /// Returns the next piece starting the search at byte offset `pos`.
    ///
    /// The configured ASCII scanner gets the first try; if it declines (returns
    /// `None`), the regex is run from `pos` and the end of its match is passed
    /// through `adjust_whitespace_end`.
    #[inline]
    fn next_match(&self, text: &str, pos: usize) -> Option<(usize, usize)> {
        let haystack = text.as_bytes();
        let shortcut = match self.fast {
            FastPath::None => None,
            FastPath::Cl100k => cl100k_ascii_next(haystack, pos, 3),
            FastPath::O200k => o200k_ascii_next(haystack, pos),
            FastPath::Qwen2 => cl100k_ascii_next(haystack, pos, 1),
            FastPath::Deepseek => deepseek_ascii_next(haystack, pos),
        };
        shortcut.or_else(|| {
            let found = self.regex.find_at(text, pos)?;
            let begin = found.start();
            Some((begin, adjust_whitespace_end(haystack, begin, found.end())))
        })
    }
}
/// Skips a run of `\r` / `\n` bytes starting at index `k` and returns the index
/// of the first byte after the run (`k` itself when there is no run or `k` is
/// at or past the end of `b`).
#[inline]
fn take_crlf(b: &[u8], k: usize) -> usize {
    let run = b.get(k..).map_or(0, |rest| {
        rest.iter()
            .take_while(|&&c| matches!(c, b'\r' | b'\n'))
            .count()
    });
    k + run
}
/// Scans an ASCII number piece or an ASCII symbol piece beginning at `b[i]`.
///
/// * Number: up to `max_digits` ASCII digits. If the run ended early and is
///   directly followed by a non-ASCII byte, the scanner gives up (`None`).
/// * Symbols: an optional single leading space, then one or more ASCII bytes
///   that are neither whitespace, letters nor digits, then any trailing
///   `\r` / `\n` bytes. A non-ASCII byte right after the symbol run makes the
///   scanner give up.
///
/// Returns `Some((i, end))` on success and `None` when the caller should fall
/// back to the regex. Indexes `b[i]` directly, so `i` must be in bounds.
#[inline]
fn ascii_num_punct(b: &[u8], i: usize, max_digits: usize) -> Option<(usize, usize)> {
    // "Symbol" here means: ASCII, not whitespace, not a letter, not a digit.
    let is_symbol = |c: u8| c.is_ascii() && !(is_ascii_ws(c) || c.is_ascii_alphanumeric());
    let non_ascii_at = |at: usize| b.get(at).is_some_and(|c| !c.is_ascii());

    let head = b[i];
    if head.is_ascii_digit() {
        let taken = b[i..]
            .iter()
            .take(max_digits)
            .take_while(|c| c.is_ascii_digit())
            .count();
        let stop = i + taken;
        // Only a run that stopped short of the cap can be affected by what follows.
        if taken < max_digits && non_ascii_at(stop) {
            return None;
        }
        return Some((i, stop));
    }

    // A leading space is accepted only when a symbol follows it immediately.
    let run_start = if head == b' ' {
        match b.get(i + 1) {
            Some(&c) if is_symbol(c) => i + 1,
            _ => return None,
        }
    } else {
        i
    };

    if !is_symbol(b[run_start]) {
        return None;
    }
    let run_end = run_start
        + b[run_start..]
            .iter()
            .take_while(|&&c| is_symbol(c))
            .count();
    if non_ascii_at(run_end) {
        return None;
    }
    Some((i, take_crlf(b, run_end)))
}
#[inline]
fn cl100k_ascii_next(b: &[u8], i: usize, max_digits: usize) -> Option<(usize, usize)> {
let n = b.len();
if i >= n {
return None;
}
let c0 = b[i];
if c0 >= 0x80 {
return None;
}
if c0 == b'\''
&& let Some(len) = match_contraction(b, i)
{
return Some((i, i + len));
}
if c0 != b'\r'
&& c0 != b'\n'
&& !c0.is_ascii_alphabetic()
&& !c0.is_ascii_digit()
&& let Some(&c1) = b.get(i + 1)
&& c1 < 0x80
&& c1.is_ascii_alphabetic()
{
let mut j = i + 1;
while j < n && b[j] < 0x80 && b[j].is_ascii_alphabetic() {
j += 1;
}
if j < n && b[j] >= 0x80 {
return None;
}
return Some((i, j));
}
if c0.is_ascii_alphabetic() {
let mut j = i;
while j < n && b[j] < 0x80 && b[j].is_ascii_alphabetic() {
j += 1;
}
if j < n && b[j] >= 0x80 {
return None;
}
return Some((i, j));
}
ascii_num_punct(b, i, max_digits)
}
#[inline]
fn o200k_ascii_next(b: &[u8], i: usize) -> Option<(usize, usize)> {
let n = b.len();
if i >= n {
return None;
}
let c0 = b[i];
if c0 >= 0x80 {
return None;
}
let p = if c0.is_ascii_alphabetic() {
i
} else if c0 != b'\r' && c0 != b'\n' && !c0.is_ascii_digit() {
match b.get(i + 1) {
Some(&c1) if c1 < 0x80 && c1.is_ascii_alphabetic() => i + 1,
_ => return ascii_num_punct(b, i, 3),
}
} else {
return ascii_num_punct(b, i, 3);
};
let mut q = p;
while q < n && b[q] < 0x80 && b[q].is_ascii_uppercase() {
q += 1;
}
if q < n && b[q] >= 0x80 {
return None;
}
let letters_end = if q > p {
if q < n && b[q].is_ascii_lowercase() {
let mut r = q;
while r < n && b[r] < 0x80 && b[r].is_ascii_lowercase() {
r += 1;
}
if r < n && b[r] >= 0x80 {
return None;
}
r
} else {
q
}
} else {
let mut r = p;
while r < n && b[r] < 0x80 && b[r].is_ascii_lowercase() {
r += 1;
}
if r < n && b[r] >= 0x80 {
return None;
}
r
};
let mut end = letters_end;
if end < n
&& b[end] == b'\''
&& let Some(len) = match_contraction(b, end)
{
end += len;
}
Some((i, end))
}
#[inline]
fn deepseek_ascii_next(b: &[u8], i: usize) -> Option<(usize, usize)> {
let n = b.len();
if i >= n {
return None;
}
let c0 = b[i];
if c0 >= 0x80 {
return None; }
if c0.is_ascii_digit() {
let mut j = i;
let mut k = 0;
while j < n && k < 3 && b[j].is_ascii_digit() {
j += 1;
k += 1;
}
if k < 3 && j < n && b[j] >= 0x80 {
return None; }
return Some((i, j));
}
if c0.is_ascii_punctuation() {
if let Some(&c1) = b.get(i + 1)
&& c1 < 0x80
&& c1.is_ascii_alphabetic()
{
let mut j = i + 1;
while j < n && b[j].is_ascii_alphabetic() {
j += 1;
}
return Some((i, j));
}
let mut k = i;
while k < n && b[k] < 0x80 && b[k].is_ascii_punctuation() {
k += 1;
}
if k < n && b[k] >= 0x80 {
return None; }
k = take_crlf(b, k);
return Some((i, k));
}
if c0.is_ascii_alphabetic() {
let mut j = i;
while j < n && b[j].is_ascii_alphabetic() {
j += 1;
}
if j < n && b[j] >= 0x80 {
return None; }
return Some((i, j));
}
if c0 == b' ' {
match b.get(i + 1) {
Some(&c1) if c1 >= 0x80 => return None, Some(&c1) if c1.is_ascii_alphabetic() => {
let mut j = i + 1;
while j < n && b[j].is_ascii_alphabetic() {
j += 1;
}
if j < n && b[j] >= 0x80 {
return None;
}
return Some((i, j));
}
Some(&c1) if c1.is_ascii_punctuation() => {
let mut k = i + 1;
while k < n && b[k] < 0x80 && b[k].is_ascii_punctuation() {
k += 1;
}
if k < n && b[k] >= 0x80 {
return None;
}
k = take_crlf(b, k);
return Some((i, k));
}
_ => return None,
}
}
None
}
/// Recognises a contraction suffix whose apostrophe sits at `b[i]`.
///
/// Only the bytes after `i` are inspected (case-insensitively); the caller is
/// responsible for checking that `b[i]` is an apostrophe. Returns the suffix
/// length including the apostrophe: 2 for `s`, `t`, `m`, `d`; 3 for `re`, `ve`,
/// `ll`; `None` otherwise.
#[inline]
fn match_contraction(b: &[u8], i: usize) -> Option<usize> {
    let first = b.get(i + 1)?.to_ascii_lowercase();
    let second = b.get(i + 2).map(u8::to_ascii_lowercase);
    match (first, second) {
        (b's' | b't' | b'm' | b'd', _) => Some(2),
        (b'r' | b'v', Some(b'e')) | (b'l', Some(b'l')) => Some(3),
        _ => None,
    }
}
/// Shortens an all-whitespace match by its last character when the match is
/// immediately followed by a non-whitespace character.
///
/// `bytes[start..end]` is the regex match. The end is returned unchanged when
/// the match is a single byte, reaches the end of the input, contains any
/// non-whitespace, is followed by more whitespace, or consists of a single
/// (possibly multi-byte) whitespace character.
///
/// NOTE(review): this appears to stand in for a `\s+(?!\S)` look-ahead that the
/// `regex` crate cannot express — confirm against the pattern definitions.
///
/// Only the single character following the match is decoded; the rest of the
/// input is never scanned, so the cost is bounded by the match length rather
/// than by the remaining input.
///
/// # Panics
///
/// Panics if `bytes[start..end]`, or the character starting at `end`, is not
/// valid UTF-8. Offsets taken from a regex match over a `&str` always are.
#[inline]
fn adjust_whitespace_end(bytes: &[u8], start: usize, end: usize) -> usize {
    if end - start <= 1 || end >= bytes.len() {
        return end;
    }
    // A printable, non-space ASCII first byte rules out a whitespace run.
    let first = bytes[start];
    if first > 0x20 && first < 0x7F {
        return end;
    }

    let piece = &bytes[start..end];

    // Pure-ASCII whitespace: byte-level check, no UTF-8 decoding needed.
    if piece.iter().all(|&b| is_ascii_ws(b)) {
        return if is_ascii_ws(bytes[end]) { end } else { end - 1 };
    }

    let matched =
        std::str::from_utf8(piece).expect("match must cover whole UTF-8 characters");
    if !matched.chars().all(char::is_whitespace) {
        return end;
    }

    // Decode just the next scalar: its width follows from the lead byte, so
    // there is no need to validate everything up to the end of the input.
    let width = match bytes[end] {
        0x00..=0x7F => 1,
        0xC0..=0xDF => 2,
        0xE0..=0xEF => 3,
        _ => 4,
    };
    let window = &bytes[end..bytes.len().min(end + width)];
    let next_char = std::str::from_utf8(window)
        .expect("match must end on a UTF-8 character boundary")
        .chars()
        .next();
    match next_char {
        Some(c) if !c.is_whitespace() => {}
        _ => return end,
    }

    let Some(last) = matched.chars().next_back() else {
        return end;
    };
    let trimmed = end - last.len_utf8();
    // Never trim the match down to nothing.
    if trimmed <= start {
        end
    } else {
        trimmed
    }
}
/// Returns `true` for the six ASCII whitespace bytes: space, and the contiguous
/// range tab, line feed, vertical tab, form feed, carriage return (0x09–0x0D).
#[inline(always)]
const fn is_ascii_ws(b: u8) -> bool {
    matches!(b, b'\t'..=b'\r' | b' ')
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encoding::{
        CL100K_PATTERN, DEEPSEEK_V3_PATTERN, O200K_PATTERN, P50K_PATTERN, QWEN2_PATTERN,
    };

    /// Drives `pt` over `text`, feeding each match end back in as the next
    /// start position, and returns every `(start, end)` pair produced.
    fn collect_matches(pt: &dyn PreTokenizer, text: &str) -> Vec<(usize, usize)> {
        let mut result = vec![];
        let mut pos = 0;
        while let Some((start, end)) = pt.next_match(text, pos) {
            result.push((start, end));
            pos = end;
        }
        result
    }

    /// Reference splitter: the regex plus `adjust_whitespace_end`, with no fast
    /// path involved. Fast-path output is compared against this.
    fn v2_collect_matches(pattern: &str, text: &str) -> Vec<(usize, usize)> {
        let regex = Regex::new(pattern).unwrap();
        let bytes = text.as_bytes();
        let mut result = vec![];
        let mut pos = 0;
        while pos < text.len() {
            let mat = match regex.find_at(text, pos) {
                Some(m) => m,
                None => break,
            };
            let start = mat.start();
            let end = adjust_whitespace_end(bytes, start, mat.end());
            result.push((start, end));
            pos = end;
        }
        result
    }

    /// Asserts that the pre-tokenizer built from `pattern` and `fast` splits
    /// `text` exactly like the regex-only reference.
    fn assert_matches_reference(pattern: &str, fast: FastPath, text: &str) {
        let pt = RegexPreTokenizer::new(pattern, fast);
        let v2 = v2_collect_matches(pattern, text);
        let v3 = collect_matches(&pt, text);
        assert_eq!(v2, v3, "mismatch for pattern {pattern:?} / text: {text:?}");
    }

    #[test]
    fn test_cl100k_english() {
        assert_matches_reference(CL100K_PATTERN, FastPath::Cl100k, "Hello, world!");
    }

    #[test]
    fn test_cl100k_cjk() {
        assert_matches_reference(CL100K_PATTERN, FastPath::Cl100k, "你好世界");
    }

    #[test]
    fn test_cl100k_contractions() {
        assert_matches_reference(
            CL100K_PATTERN,
            FastPath::Cl100k,
            "I'm don't they're we've she'll it'd",
        );
    }

    #[test]
    fn test_o200k_english() {
        assert_matches_reference(
            O200K_PATTERN,
            FastPath::O200k,
            "Hello, world! CamelCase mixedScript123",
        );
    }

    #[test]
    fn test_p50k_english() {
        assert_matches_reference(P50K_PATTERN, FastPath::None, "Hello world, I'm testing!");
    }

    #[test]
    fn test_empty_input() {
        let pt = RegexPreTokenizer::new(CL100K_PATTERN, FastPath::Cl100k);
        assert_eq!(collect_matches(&pt, ""), vec![]);
    }

    #[test]
    fn test_only_whitespace() {
        assert_matches_reference(CL100K_PATTERN, FastPath::Cl100k, " \n \t ");
    }

    #[test]
    fn test_emoji() {
        assert_matches_reference(CL100K_PATTERN, FastPath::Cl100k, "🎉🚀💡");
    }

    #[test]
    fn test_mixed_script() {
        assert_matches_reference(CL100K_PATTERN, FastPath::Cl100k, "Hello 你好 World 🌍");
    }

    #[test]
    fn test_adjust_whitespace_single_byte() {
        assert_eq!(adjust_whitespace_end(b"a b", 0, 1), 1);
    }

    #[test]
    fn test_adjust_whitespace_at_end_of_input() {
        // Two spaces: the match [0, 2) ends exactly at the end of the input.
        assert_eq!(adjust_whitespace_end(b"\x20\x20", 0, 2), 2);
    }

    #[test]
    fn test_adjust_whitespace_non_ws_piece() {
        assert_eq!(adjust_whitespace_end(b"hello world", 0, 5), 5);
    }

    #[test]
    fn test_adjust_whitespace_trim_before_nonws() {
        // Two spaces then `x`: the match [0, 2) is followed by non-whitespace,
        // so its last space is given back.
        let bytes = b"\x20\x20x";
        assert_eq!(adjust_whitespace_end(bytes, 0, 2), 1);
    }

    #[test]
    fn test_adjust_whitespace_no_trim_before_ws() {
        // Three spaces: the match [0, 2) is followed by more whitespace.
        let bytes = b"\x20\x20\x20";
        assert_eq!(adjust_whitespace_end(bytes, 0, 2), 2);
    }

    #[test]
    fn test_adjust_whitespace_unicode_slow_path() {
        let input = "\u{3000}\u{3000}x";
        let bytes = input.as_bytes();
        assert_eq!(adjust_whitespace_end(bytes, 0, 6), 3);
    }

    #[test]
    fn test_adjust_whitespace_unicode_followed_by_unicode_ws() {
        let input = "\u{3000}\u{3000}\u{3000}";
        let bytes = input.as_bytes();
        assert_eq!(adjust_whitespace_end(bytes, 0, 6), 6);
    }

    #[test]
    fn test_adjust_whitespace_single_multibyte_ws_before_nonws() {
        let input = "\u{3000}x";
        let bytes = input.as_bytes();
        assert_eq!(adjust_whitespace_end(bytes, 0, 3), 3);
    }

    #[test]
    fn test_all_patterns_match_v2() {
        let texts = vec![
            "Hello, world!",
            "你好世界",
            "fn main() { }",
            " hello ",
            "line1\nline2\n",
            "café résumé",
            "100% of $1,000",
            "a@b.com",
            " \t\n ",
            "",
            "a",
            "hello world! 你好 🚀 test 123",
        ];
        // Every pattern the module knows about, each with its own fast path.
        for &(pattern, fast) in &[
            (CL100K_PATTERN, FastPath::Cl100k),
            (O200K_PATTERN, FastPath::O200k),
            (QWEN2_PATTERN, FastPath::Qwen2),
            (DEEPSEEK_V3_PATTERN, FastPath::Deepseek),
            (P50K_PATTERN, FastPath::None),
        ] {
            for text in &texts {
                assert_matches_reference(pattern, fast, text);
            }
        }
    }

    // Property tests: for arbitrary text (and for printable-ASCII-plus-whitespace
    // text, which stays on the fast path far more often) the fast path must agree
    // with the regex-only reference.
    proptest::proptest! {
        #![proptest_config(proptest::prelude::ProptestConfig::with_cases(20000))]

        #[test]
        fn prop_cl100k_fast_matches_regex(text in ".*") {
            let pt = RegexPreTokenizer::new(CL100K_PATTERN, FastPath::Cl100k);
            let fast = collect_matches(&pt, &text);
            let reference = v2_collect_matches(CL100K_PATTERN, &text);
            proptest::prop_assert_eq!(fast, reference, "fast/regex mismatch for {:?}", text);
        }

        #[test]
        fn prop_cl100k_fast_matches_regex_ascii(text in "[ -~ \t\r\n]*") {
            let pt = RegexPreTokenizer::new(CL100K_PATTERN, FastPath::Cl100k);
            let fast = collect_matches(&pt, &text);
            let reference = v2_collect_matches(CL100K_PATTERN, &text);
            proptest::prop_assert_eq!(fast, reference, "fast/regex mismatch for {:?}", text);
        }

        #[test]
        fn prop_o200k_fast_matches_regex(text in ".*") {
            let pt = RegexPreTokenizer::new(O200K_PATTERN, FastPath::O200k);
            let fast = collect_matches(&pt, &text);
            let reference = v2_collect_matches(O200K_PATTERN, &text);
            proptest::prop_assert_eq!(fast, reference, "fast/regex mismatch for {:?}", text);
        }

        #[test]
        fn prop_o200k_fast_matches_regex_ascii(text in "[ -~ \t\r\n]*") {
            let pt = RegexPreTokenizer::new(O200K_PATTERN, FastPath::O200k);
            let fast = collect_matches(&pt, &text);
            let reference = v2_collect_matches(O200K_PATTERN, &text);
            proptest::prop_assert_eq!(fast, reference, "fast/regex mismatch for {:?}", text);
        }

        #[test]
        fn prop_qwen2_fast_matches_regex(text in ".*") {
            let pt = RegexPreTokenizer::new(QWEN2_PATTERN, FastPath::Qwen2);
            let fast = collect_matches(&pt, &text);
            let reference = v2_collect_matches(QWEN2_PATTERN, &text);
            proptest::prop_assert_eq!(fast, reference, "fast/regex mismatch for {:?}", text);
        }

        #[test]
        fn prop_qwen2_fast_matches_regex_ascii(text in "[ -~ \t\r\n]*") {
            let pt = RegexPreTokenizer::new(QWEN2_PATTERN, FastPath::Qwen2);
            let fast = collect_matches(&pt, &text);
            let reference = v2_collect_matches(QWEN2_PATTERN, &text);
            proptest::prop_assert_eq!(fast, reference, "fast/regex mismatch for {:?}", text);
        }

        #[test]
        fn prop_deepseek_fast_matches_regex(text in ".*") {
            let pt = RegexPreTokenizer::new(DEEPSEEK_V3_PATTERN, FastPath::Deepseek);
            let fast = collect_matches(&pt, &text);
            let reference = v2_collect_matches(DEEPSEEK_V3_PATTERN, &text);
            proptest::prop_assert_eq!(fast, reference, "fast/regex mismatch for {:?}", text);
        }

        #[test]
        fn prop_deepseek_fast_matches_regex_ascii(text in "[ -~ \t\r\n]*") {
            let pt = RegexPreTokenizer::new(DEEPSEEK_V3_PATTERN, FastPath::Deepseek);
            let fast = collect_matches(&pt, &text);
            let reference = v2_collect_matches(DEEPSEEK_V3_PATTERN, &text);
            proptest::prop_assert_eq!(fast, reference, "fast/regex mismatch for {:?}", text);
        }
    }
}