1use fancy_regex::Regex;
2use std::collections::HashMap;
3use std::sync::OnceLock;
4
5include!("claude_data.rs");
9
/// Sentinel rank meaning "no vocabulary entry" for a byte sequence. It is also used for
/// merge slots in `byte_pair_merge` that have no right-hand neighbour.
const NO_RANK: u32 = u32::MAX;
/// Pre-tokenizer pattern applied before byte-pair merging. It splits text into:
/// - English contractions;
/// - optionally space-prefixed runs of letters, digits, or other symbols;
/// - whitespace runs.
///
/// The `(?!\S)` look-ahead leaves the last whitespace character before a non-space out of the
/// run, so it can prefix the next piece. Look-around requires a backtracking engine such as
/// `fancy_regex`.
const PAT_STR: &str = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";
12
13static PATTERN_REGEX: OnceLock<Regex> = OnceLock::new();
14static STRING_ENCODER: OnceLock<HashMap<&'static str, u32>> = OnceLock::new();
15static BINARY_ENCODER: OnceLock<HashMap<Vec<u8>, u32>> = OnceLock::new();
16
17pub fn count_tokens(text: &str) -> usize {
19 if text.is_empty() {
20 return 0;
21 }
22 encode_internal(text).len()
23}
24
25pub fn encode(text: &str) -> Vec<u32> {
27 if text.is_empty() {
28 return Vec::new();
29 }
30 encode_internal(text)
31}
32
33fn encode_internal(text: &str) -> Vec<u32> {
34 if text.encode_utf16().take(10).count() < 10 {
36 if let Some(&rank) = string_encoder().get(text) {
37 return vec![rank];
38 }
39 }
40
41 let mut result = Vec::new();
42 for piece in pattern_regex().find_iter(text) {
43 let piece = piece
44 .expect("Claude pre-tokenizer regex should not fail")
45 .as_str();
46
47 if let Some(&rank) = string_encoder().get(piece) {
48 result.push(rank);
49 continue;
50 }
51
52 result.extend(byte_pair_merge(piece.as_bytes()));
53 }
54 result
55}
56
57fn pattern_regex() -> &'static Regex {
58 PATTERN_REGEX.get_or_init(|| Regex::new(PAT_STR).expect("Claude pre-tokenizer regex"))
59}
60
61fn string_encoder() -> &'static HashMap<&'static str, u32> {
62 STRING_ENCODER.get_or_init(|| STRING_ENCODER_ENTRIES.iter().copied().collect())
63}
64
65fn binary_encoder() -> &'static HashMap<Vec<u8>, u32> {
66 BINARY_ENCODER.get_or_init(|| {
67 BINARY_ENCODER_ENTRIES
68 .iter()
69 .map(|(bytes, rank)| (bytes.to_vec(), *rank))
70 .collect()
71 })
72}
73
74fn byte_pair_merge(piece: &[u8]) -> Vec<u32> {
75 let len = piece.len();
76 let mut starts = Vec::with_capacity(len + 1);
77 let mut ranks = Vec::with_capacity(len + 1);
78
79 for i in 0..=len {
80 starts.push(i);
81 if i < len.saturating_sub(1) {
82 ranks.push(get_rank_for_slice(&piece[i..i + 2]));
83 } else {
84 ranks.push(NO_RANK);
85 }
86 }
87
88 while starts.len() > 1 {
89 let mut min_rank = NO_RANK;
90 let mut min_idx = None;
91 for (idx, &rank) in ranks.iter().take(ranks.len().saturating_sub(1)).enumerate() {
92 if rank < min_rank {
93 min_rank = rank;
94 min_idx = Some(idx);
95 }
96 }
97
98 let Some(min_idx) = min_idx else { break };
99 if min_rank == NO_RANK {
100 break;
101 }
102
103 starts.remove(min_idx + 1);
104 ranks.remove(min_idx);
105 if min_idx < ranks.len() {
106 ranks[min_idx] = get_rank(piece, &starts, min_idx);
107 }
108 if min_idx > 0 {
109 ranks[min_idx - 1] = get_rank(piece, &starts, min_idx - 1);
110 }
111 }
112
113 let mut output = Vec::with_capacity(starts.len().saturating_sub(1));
114 for window in starts.windows(2) {
115 let rank = get_rank_for_slice(&piece[window[0]..window[1]]);
116 if rank != NO_RANK {
117 output.push(rank);
118 }
119 }
120 output
121}
122
123fn get_rank(piece: &[u8], starts: &[usize], start_index: usize) -> u32 {
124 let Some(&pair_start) = starts.get(start_index) else {
125 return NO_RANK;
126 };
127 let Some(&pair_end) = starts.get(start_index + 2) else {
128 return NO_RANK;
129 };
130 get_rank_for_slice(&piece[pair_start..pair_end])
131}
132
133fn get_rank_for_slice(slice: &[u8]) -> u32 {
134 if let Ok(as_string) = std::str::from_utf8(slice) {
135 if let Some(&rank) = string_encoder().get(as_string) {
136 return rank;
137 }
138 }
139
140 binary_encoder().get(slice).copied().unwrap_or(NO_RANK)
141}
142
#[cfg(test)]
mod tests {
    use super::{count_tokens, encode};

    /// Empty input short-circuits to zero tokens in both public entry points.
    #[test]
    fn empty_input_has_no_tokens() {
        let tokens = encode("");
        assert_eq!(tokens.len(), 0);
        assert_eq!(count_tokens(""), 0);
    }

    /// `"!"` encodes to the single rank 5, and the count agrees.
    #[test]
    fn direct_ascii_token() {
        let expected: Vec<u32> = vec![5];
        assert_eq!(encode("!"), expected);
        assert_eq!(count_tokens("!"), 1);
    }
}