use fancy_regex::Regex;
use std::collections::HashMap;
use std::sync::OnceLock;
include!("claude_data.rs");
/// Pre-tokenizer pattern used by `pattern_regex` to split input into pieces.
///
/// Alternatives, tried left to right: English contractions (`'s`, `'t`, ...), an
/// optional space followed by letters, an optional space followed by numbers, an
/// optional space followed by other non-whitespace symbols, whitespace not followed
/// by a non-whitespace character, and finally any remaining whitespace. The
/// `(?!\S)` look-ahead is presumably why this file uses `fancy_regex` rather than a
/// regex engine without look-around support.
const PAT_STR: &str = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";
/// Sentinel rank meaning "not in the vocabulary / no merge possible".
const NO_RANK: u32 = u32::MAX;
/// Lazily compiled `PAT_STR`; initialised on first call to `pattern_regex`.
static PATTERN_REGEX: OnceLock<Regex> = OnceLock::new();
/// Lazily built string-token -> rank table (filled from `STRING_ENCODER_ENTRIES`,
/// presumably provided by the `include!`d claude_data.rs).
static STRING_ENCODER: OnceLock<HashMap<&'static str, u32>> = OnceLock::new();
/// Lazily built byte-token -> rank table (filled from `BINARY_ENCODER_ENTRIES`,
/// presumably provided by the `include!`d claude_data.rs).
static BINARY_ENCODER: OnceLock<HashMap<Vec<u8>, u32>> = OnceLock::new();
/// Counts how many Claude tokens `text` encodes to.
///
/// Always equal to `encode(text).len()`; the empty string is answered as `0`
/// without touching the regex or the vocabulary tables.
///
/// # Panics
///
/// Panics if the pre-tokenizer regex fails to compile or reports a matching error.
pub fn count_tokens(text: &str) -> usize {
    match text {
        "" => 0,
        non_empty => encode_internal(non_empty).len(),
    }
}
/// Encodes `text` into a sequence of Claude token ranks.
///
/// Returns an empty vector for the empty string; any other input is handed to the
/// tokenizer core, which pre-tokenizes it and byte-pair-merges pieces that are not
/// vocabulary entries on their own.
///
/// # Panics
///
/// Panics if the pre-tokenizer regex fails to compile or reports a matching error.
pub fn encode(text: &str) -> Vec<u32> {
    if text.is_empty() {
        Vec::new()
    } else {
        encode_internal(text)
    }
}
/// Tokenizes `text` into vocabulary ranks (the callers in this file only pass
/// non-empty text).
///
/// Whole-text shortcut: when `text` is shorter than ten UTF-16 code units and is
/// itself an entry of the string vocabulary, that single rank is returned.
/// Otherwise `text` is split with the pre-tokenizer regex. Each piece found verbatim
/// in the string vocabulary contributes one rank; every other piece is
/// byte-pair-merged.
///
/// # Panics
///
/// Panics if the regex engine reports an error while matching.
fn encode_internal(text: &str) -> Vec<u32> {
    // Counting UTF-16 units presumably mirrors a JavaScript reference's
    // `text.length < 10` check. NOTE(review): if the vocabulary contains a string
    // that the regex would split into several pieces, this shortcut yields a
    // different result than the piecewise path below — confirm that is intended.
    let under_ten_units = text.encode_utf16().nth(9).is_none();
    let whole_text_rank = if under_ten_units {
        string_encoder().get(text).copied()
    } else {
        None
    };
    if let Some(rank) = whole_text_rank {
        return vec![rank];
    }

    let mut tokens = Vec::new();
    for found in pattern_regex().find_iter(text) {
        let piece = found
            .expect("Claude pre-tokenizer regex should not fail")
            .as_str();
        match string_encoder().get(piece) {
            Some(&rank) => tokens.push(rank),
            None => tokens.extend(byte_pair_merge(piece.as_bytes())),
        }
    }
    tokens
}
/// Returns the pre-tokenizer regex, compiling `PAT_STR` on first use and caching it
/// in `PATTERN_REGEX`.
///
/// # Panics
///
/// Panics if `PAT_STR` cannot be compiled by `fancy_regex`.
fn pattern_regex() -> &'static Regex {
    PATTERN_REGEX.get_or_init(|| {
        let compiled = Regex::new(PAT_STR);
        compiled.expect("Claude pre-tokenizer regex")
    })
}
/// Returns the string-token -> rank table, building it from
/// `STRING_ENCODER_ENTRIES` on first use.
///
/// Entries are inserted in order, so a later duplicate key overwrites an earlier one.
fn string_encoder() -> &'static HashMap<&'static str, u32> {
    STRING_ENCODER.get_or_init(|| {
        let mut table = HashMap::new();
        for (token, rank) in STRING_ENCODER_ENTRIES.iter() {
            table.insert(*token, *rank);
        }
        table
    })
}
/// Returns the byte-token -> rank table, building it from
/// `BINARY_ENCODER_ENTRIES` on first use.
///
/// Each key is copied into an owned `Vec<u8>`; a later duplicate key overwrites an
/// earlier one.
fn binary_encoder() -> &'static HashMap<Vec<u8>, u32> {
    BINARY_ENCODER.get_or_init(|| {
        let mut table = HashMap::new();
        for (bytes, rank) in BINARY_ENCODER_ENTRIES.iter() {
            table.insert(bytes.to_vec(), *rank);
        }
        table
    })
}
/// Byte-pair-merges `piece` and returns the rank of each resulting part.
///
/// The piece starts as one part per byte. On each round, the adjacent pair of parts
/// whose concatenation has the lowest rank is merged; on a tie, the leftmost pair
/// wins. Merging stops once no adjacent pair has a rank. Each round scans and
/// shifts the part list, so the worst case is quadratic in `piece.len()`.
///
/// NOTE(review): parts that end up with no rank are silently left out of the
/// output. That only matters if some single byte is missing from the vocabulary,
/// in which case token counts would come out low — confirm the vocabulary covers
/// all 256 bytes.
fn byte_pair_merge(piece: &[u8]) -> Vec<u32> {
    let byte_count = piece.len();

    // Part `k` spans `boundaries[k]..boundaries[k + 1]`; the final entry is the end offset.
    let mut boundaries: Vec<usize> = (0..=byte_count).collect();

    // `pair_ranks[k]` is the rank of merging part `k` with part `k + 1`. The last two
    // slots never describe a real pair and hold `NO_RANK`.
    let mut pair_ranks: Vec<u32> = boundaries
        .iter()
        .map(|&offset| {
            if offset + 1 < byte_count {
                get_rank_for_slice(&piece[offset..offset + 2])
            } else {
                NO_RANK
            }
        })
        .collect();

    while boundaries.len() > 1 {
        // `min_by_key` keeps the first of equal minima, i.e. the leftmost pair.
        let best = pair_ranks[..pair_ranks.len().saturating_sub(1)]
            .iter()
            .copied()
            .enumerate()
            .filter(|&(_, rank)| rank != NO_RANK)
            .min_by_key(|&(_, rank)| rank);
        let Some((merge_at, _)) = best else { break };

        // Fuse part `merge_at` with its right neighbour, then refresh the two pair
        // ranks that now involve the fused part.
        boundaries.remove(merge_at + 1);
        pair_ranks.remove(merge_at);
        if merge_at < pair_ranks.len() {
            pair_ranks[merge_at] = get_rank(piece, &boundaries, merge_at);
        }
        if let Some(left) = merge_at.checked_sub(1) {
            pair_ranks[left] = get_rank(piece, &boundaries, left);
        }
    }

    boundaries
        .windows(2)
        .map(|span| get_rank_for_slice(&piece[span[0]..span[1]]))
        .filter(|&rank| rank != NO_RANK)
        .collect()
}
/// Returns the rank of the bytes covered by parts `start_index` and
/// `start_index + 1` together, i.e. `piece[starts[start_index]..starts[start_index + 2]]`.
///
/// Returns `NO_RANK` when either boundary index is past the end of `starts`.
fn get_rank(piece: &[u8], starts: &[usize], start_index: usize) -> u32 {
    starts
        .get(start_index)
        .and_then(|&from| starts.get(start_index + 2).map(|&to| from..to))
        .map_or(NO_RANK, |span| get_rank_for_slice(&piece[span]))
}
/// Looks up `slice` as a single token and returns its rank.
///
/// If `slice` is valid UTF-8, the string table is consulted first. Whenever that
/// lookup is skipped or misses, the binary table is consulted. Returns `NO_RANK`
/// when neither table contains `slice`.
fn get_rank_for_slice(slice: &[u8]) -> u32 {
    std::str::from_utf8(slice)
        .ok()
        .and_then(|text| string_encoder().get(text).copied())
        .or_else(|| binary_encoder().get(slice).copied())
        .unwrap_or(NO_RANK)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Both public entry points short-circuit on the empty string.
    #[test]
    fn empty_input_has_no_tokens() {
        let tokens: Vec<u32> = encode("");
        assert_eq!(tokens.len(), 0);
        assert_eq!(count_tokens(""), 0);
    }

    /// A lone `!` is a single vocabulary entry with rank 5.
    #[test]
    fn direct_ascii_token() {
        let expected: Vec<u32> = vec![5];
        assert_eq!(encode("!"), expected);
        assert_eq!(count_tokens("!"), 1);
    }
}