Skip to main content

aft_tokenizer/
claude.rs

1use fancy_regex::Regex;
2use std::collections::HashMap;
3use std::sync::OnceLock;
4
// Generated by build.rs from ai-tokenizer Claude encoding; vendored in-tree so
// the crate builds from a clean clone without fetching `tmp/ai-tokenizer-pkg/`.
// To refresh: see crates/aft-tokenizer/build.rs.
// NOTE(review): `STRING_ENCODER_ENTRIES` and `BINARY_ENCODER_ENTRIES` are not
// defined in this file and presumably come from this include — confirm.
include!("claude_data.rs");

/// Pre-tokenizer pattern, tried left to right: the contractions `'s`, `'t`,
/// `'re`, `'ve`, `'m`, `'ll`, `'d`; a run of letters, a run of digits, or a run
/// of other non-whitespace characters (each optionally preceded by one space);
/// a whitespace run not followed by a non-space character (`\s+(?!\S)`, which
/// backs off before the next word so one space can lead it); any other
/// whitespace run.
const PAT_STR: &str = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";
/// Sentinel rank meaning "no vocabulary entry / no mergeable pair".
const NO_RANK: u32 = u32::MAX;

/// Compiled [`PAT_STR`], built on first use by `pattern_regex`.
static PATTERN_REGEX: OnceLock<Regex> = OnceLock::new();
/// Token text -> rank, built on first use by `string_encoder`.
static STRING_ENCODER: OnceLock<HashMap<&'static str, u32>> = OnceLock::new();
/// Raw token bytes -> rank, built on first use by `binary_encoder`.
static BINARY_ENCODER: OnceLock<HashMap<Vec<u8>, u32>> = OnceLock::new();
16
17/// Count Claude tokens in `text`.
18pub fn count_tokens(text: &str) -> usize {
19    if text.is_empty() {
20        return 0;
21    }
22    encode_internal(text).len()
23}
24
25/// Encode `text` into Claude token IDs.
26pub fn encode(text: &str) -> Vec<u32> {
27    if text.is_empty() {
28        return Vec::new();
29    }
30    encode_internal(text)
31}
32
33fn encode_internal(text: &str) -> Vec<u32> {
34    // ai-tokenizer checks direct rank before pre-tokenizing for short strings.
35    if text.encode_utf16().take(10).count() < 10 {
36        if let Some(&rank) = string_encoder().get(text) {
37            return vec![rank];
38        }
39    }
40
41    let mut result = Vec::new();
42    for piece in pattern_regex().find_iter(text) {
43        let piece = piece
44            .expect("Claude pre-tokenizer regex should not fail")
45            .as_str();
46
47        if let Some(&rank) = string_encoder().get(piece) {
48            result.push(rank);
49            continue;
50        }
51
52        result.extend(byte_pair_merge(piece.as_bytes()));
53    }
54    result
55}
56
57fn pattern_regex() -> &'static Regex {
58    PATTERN_REGEX.get_or_init(|| Regex::new(PAT_STR).expect("Claude pre-tokenizer regex"))
59}
60
61fn string_encoder() -> &'static HashMap<&'static str, u32> {
62    STRING_ENCODER.get_or_init(|| STRING_ENCODER_ENTRIES.iter().copied().collect())
63}
64
65fn binary_encoder() -> &'static HashMap<Vec<u8>, u32> {
66    BINARY_ENCODER.get_or_init(|| {
67        BINARY_ENCODER_ENTRIES
68            .iter()
69            .map(|(bytes, rank)| (bytes.to_vec(), *rank))
70            .collect()
71    })
72}
73
/// Split `piece` into vocabulary tokens by greedy byte-pair merging.
///
/// Starts with one part per byte and repeatedly merges the adjacent pair whose
/// combined bytes have the lowest rank (the leftmost pair wins ties), until no
/// adjacent pair has a rank. Returns the rank of each remaining part, in order.
///
/// Parts whose bytes have no rank are silently dropped from the output.
/// NOTE(review): merged parts always have a rank (they were created from a
/// ranked pair), so only an unmerged single byte can be dropped. This should be
/// unreachable if every byte value is in the vocabulary — confirm against the
/// generated tables.
fn byte_pair_merge(piece: &[u8]) -> Vec<u32> {
    let len = piece.len();
    // `starts[i]` is the byte offset where part `i` begins. The last entry is
    // `len` (an end sentinel), so part `i` spans `starts[i]..starts[i + 1]`.
    let mut starts = Vec::with_capacity(len + 1);
    // `ranks[i]` is the rank of merging part `i` with part `i + 1`, or `NO_RANK`
    // when that pair is unknown or does not exist. Kept the same length as
    // `starts`.
    let mut ranks = Vec::with_capacity(len + 1);

    for i in 0..=len {
        starts.push(i);
        // Only positions followed by another byte form an initial pair; the
        // last real byte and the end sentinel get `NO_RANK`.
        if i < len.saturating_sub(1) {
            ranks.push(get_rank_for_slice(&piece[i..i + 2]));
        } else {
            ranks.push(NO_RANK);
        }
    }

    while starts.len() > 1 {
        // Find the lowest-ranked pair. The sentinel slot is excluded, and the
        // strict `<` keeps the leftmost pair on ties.
        let mut min_rank = NO_RANK;
        let mut min_idx = None;
        for (idx, &rank) in ranks.iter().take(ranks.len().saturating_sub(1)).enumerate() {
            if rank < min_rank {
                min_rank = rank;
                min_idx = Some(idx);
            }
        }

        let Some(min_idx) = min_idx else { break };
        if min_rank == NO_RANK {
            break;
        }

        // Merge part `min_idx` with its right neighbour by dropping the
        // neighbour's start offset. Removing one `ranks` slot keeps `ranks`
        // aligned with `starts`; slot `min_idx` is recomputed just below.
        starts.remove(min_idx + 1);
        ranks.remove(min_idx);
        // Rank of (merged part, next part).
        if min_idx < ranks.len() {
            ranks[min_idx] = get_rank(piece, &starts, min_idx);
        }
        // Rank of (previous part, merged part).
        if min_idx > 0 {
            ranks[min_idx - 1] = get_rank(piece, &starts, min_idx - 1);
        }
    }

    // Each adjacent pair of offsets delimits one final part.
    let mut output = Vec::with_capacity(starts.len().saturating_sub(1));
    for window in starts.windows(2) {
        let rank = get_rank_for_slice(&piece[window[0]..window[1]]);
        if rank != NO_RANK {
            output.push(rank);
        }
    }
    output
}
122
123fn get_rank(piece: &[u8], starts: &[usize], start_index: usize) -> u32 {
124    let Some(&pair_start) = starts.get(start_index) else {
125        return NO_RANK;
126    };
127    let Some(&pair_end) = starts.get(start_index + 2) else {
128        return NO_RANK;
129    };
130    get_rank_for_slice(&piece[pair_start..pair_end])
131}
132
133fn get_rank_for_slice(slice: &[u8]) -> u32 {
134    if let Ok(as_string) = std::str::from_utf8(slice) {
135        if let Some(&rank) = string_encoder().get(as_string) {
136            return rank;
137        }
138    }
139
140    binary_encoder().get(slice).copied().unwrap_or(NO_RANK)
141}
142
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_input_has_no_tokens() {
        assert!(encode("").is_empty());
        assert_eq!(count_tokens(""), 0);
    }

    #[test]
    fn direct_ascii_token() {
        assert_eq!(encode("!"), vec![5]);
        assert_eq!(count_tokens("!"), 1);
    }

    /// Degenerate piece sizes: an empty piece yields nothing, and a single
    /// byte yields that byte's own rank.
    #[test]
    fn byte_pair_merge_boundaries() {
        assert!(byte_pair_merge(b"").is_empty());
        assert_eq!(byte_pair_merge(b"!"), vec![5]);
    }

    /// Inputs of 10 or more UTF-16 code units bypass the whole-string lookup
    /// and exercise the pre-tokenizer and byte-pair-merge path.
    #[test]
    fn long_input_uses_pretokenizer_path() {
        let text = "Hello, world! 12345 tokens\n\n  and   spaces.";
        let ids = encode(text);
        assert!(!ids.is_empty());
        assert!(ids.iter().all(|&id| id != NO_RANK));
        assert_eq!(count_tokens(text), ids.len());
        assert_eq!(encode(text), ids, "encoding must be deterministic");
    }

    /// `count_tokens` must agree with `encode().len()`, including for
    /// multi-byte UTF-8 and whitespace-only input.
    #[test]
    fn count_matches_encode_len() {
        let samples = [
            "héllo wörld",
            "日本語のテキストです。",
            "emoji 🎉🎉 party",
            " \t\r\n ",
            "!!!!!!!!!!!!",
        ];
        for text in samples {
            assert_eq!(count_tokens(text), encode(text).len(), "input: {text:?}");
        }
    }
}