1use crate::config::EstimationOptions;
7
/// Returns `true` when `c` lies in one of the CJK-related Unicode ranges listed below.
///
/// Characters in these ranges make the surrounding word segment count as one token
/// per character.
///
/// NOTE(review): Katakana is covered but Hiragana (U+3040–U+309F) is not in the list —
/// confirm whether that omission is intended.
#[inline(always)]
fn is_cjk(c: char) -> bool {
    match c {
        // Han ideographs: CJK Unified Ideographs and Extension A.
        '\u{4E00}'..='\u{9FFF}' | '\u{3400}'..='\u{4DBF}' => true,
        // CJK Symbols and Punctuation, Halfwidth and Fullwidth Forms.
        '\u{3000}'..='\u{303F}' | '\u{FF00}'..='\u{FFEF}' => true,
        // Katakana.
        '\u{30A0}'..='\u{30FF}' => true,
        // CJK Radicals Supplement, CJK Strokes.
        '\u{2E80}'..='\u{2EFF}' | '\u{31C0}'..='\u{31EF}' => true,
        // Enclosed CJK Letters and Months, CJK Compatibility.
        '\u{3200}'..='\u{32FF}' | '\u{3300}'..='\u{33FF}' => true,
        // Hangul Syllables, Hangul Jamo, Hangul Compatibility Jamo.
        '\u{AC00}'..='\u{D7AF}' | '\u{1100}'..='\u{11FF}' | '\u{3130}'..='\u{318F}' => true,
        // Hangul Jamo Extended-A and Extended-B.
        '\u{A960}'..='\u{A97F}' | '\u{D7B0}'..='\u{D7FF}' => true,
        _ => false,
    }
}
29
/// Returns `true` when `c` is one of the punctuation characters that split text
/// into segments.
///
/// The accepted set is every ASCII punctuation character except `_`, plus the
/// typographic quotation marks U+2018, U+2019, U+201C, U+201D and U+201E.
/// Because `_` is excluded, it is not reported as punctuation by this function.
#[inline(always)]
fn is_punctuation(c: char) -> bool {
    let ascii_mark = c.is_ascii_punctuation() && c != '_';
    let curly_quote = matches!(
        c,
        '\u{2018}' | '\u{2019}' | '\u{201C}' | '\u{201D}' | '\u{201E}'
    );
    ascii_mark || curly_quote
}
71
/// Returns `true` for ASCII letters and digits and for the accented Latin-1
/// letters U+00C0–U+00FF, excluding the multiplication sign (U+00D7) and the
/// division sign (U+00F7).
#[inline(always)]
fn is_alphanumeric_latin(c: char) -> bool {
    match c {
        '0'..='9' | 'A'..='Z' | 'a'..='z' => true,
        '\u{00C0}'..='\u{00D6}' | '\u{00D8}'..='\u{00F6}' | '\u{00F8}'..='\u{00FF}' => true,
        _ => false,
    }
}
77
/// Category assigned to each character while cutting text into segments.
///
/// Consecutive characters of the same category form one segment.
#[derive(Copy, Clone, Eq, PartialEq)]
enum SplitKind {
    /// Whitespace characters.
    Whitespace,
    /// Characters accepted by `is_punctuation`.
    Punctuation,
    /// Everything else.
    Word,
}
86
87#[inline(always)]
88fn split_classify(c: char) -> SplitKind {
89 if c.is_whitespace() {
90 SplitKind::Whitespace
91 } else if c.is_ascii() {
92 if is_punctuation(c) {
93 SplitKind::Punctuation
94 } else {
95 SplitKind::Word
96 }
97 } else if is_punctuation(c) {
98 SplitKind::Punctuation
99 } else {
100 SplitKind::Word
101 }
102}
103
/// Scores one word segment (a run of characters that are neither whitespace nor
/// punctuation) and returns its estimated token count.
///
/// The rules are tried in order and the first match wins:
/// 1. the segment contains a CJK character: one token per character (`char_count`);
/// 2. the segment is all ASCII digits, or is at most 3 bytes long: 1 token;
/// 3. the segment is purely Latin alphanumeric, or a language-specific ratio was
///    detected: `ceil(byte_len / ratio)`, preferring `lang_cpt` over `default_cpt`;
/// 4. anything else: one token per character.
///
/// A ratio that is zero, negative, NaN or infinite is ignored. If `lang_cpt` is
/// unusable the default ratio is tried, and if that is unusable too the segment
/// falls back to one token per character. Without this guard the division could
/// produce `usize::MAX` (overflowing the caller's running sum) or 0.
///
/// NOTE(review): the length used in rules 2 and 3 is the UTF-8 byte length, so an
/// accented Latin letter weighs two units against a ratio named "chars per token" —
/// confirm that this is intended.
#[inline(always)]
fn score_word(
    byte_len: usize,
    char_count: usize,
    has_cjk: bool,
    all_alphanum: bool,
    all_digits: bool,
    lang_cpt: Option<f64>,
    default_cpt: f64,
) -> usize {
    // A ratio is only meaningful as a divisor when it is a positive, finite number.
    fn usable(ratio: &f64) -> bool {
        ratio.is_finite() && *ratio > 0.0
    }

    if has_cjk {
        return char_count;
    }
    if all_digits || byte_len <= 3 {
        return 1;
    }
    if !all_alphanum && lang_cpt.is_none() {
        return char_count;
    }

    let ratio = lang_cpt
        .filter(usable)
        .or_else(|| Some(default_cpt).filter(usable));
    match ratio {
        Some(cpt) => (byte_len as f64 / cpt).ceil() as usize,
        None => char_count,
    }
}
131
/// Scores a punctuation segment from its UTF-8 byte length.
///
/// Runs of up to 3 bytes cost a single token; longer runs cost half their byte
/// length, rounded up.
#[inline(always)]
fn score_punctuation(byte_len: usize) -> usize {
    match byte_len {
        0..=3 => 1,
        longer => (longer + 1) / 2,
    }
}
140
141pub fn estimate_token_count(text: &str) -> usize {
154 estimate_token_count_with_options(text, &EstimationOptions::default())
155}
156
157pub fn estimate_token_count_with_options(text: &str, options: &EstimationOptions) -> usize {
169 if text.is_empty() {
170 return 0;
171 }
172
173 let mut total_tokens: usize = 0;
174
175 let mut seg_split_kind = SplitKind::Word;
176 let mut seg_byte_len: usize = 0;
177 let mut seg_char_count: usize = 0;
178 let mut seg_has_cjk = false;
179 let mut seg_all_alphanum = true;
180 let mut seg_all_digits = true;
181 let mut seg_lang_cpt: Option<f64> = None;
182 let mut in_segment = false;
183
184 let default_cpt = options.default_chars_per_token;
185
186 macro_rules! flush {
187 () => {
188 total_tokens += match seg_split_kind {
189 SplitKind::Whitespace => 0,
190 SplitKind::Punctuation => score_punctuation(seg_byte_len),
191 SplitKind::Word => score_word(
192 seg_byte_len,
193 seg_char_count,
194 seg_has_cjk,
195 seg_all_alphanum,
196 seg_all_digits,
197 seg_lang_cpt,
198 default_cpt,
199 ),
200 };
201 };
202 }
203
204 for c in text.chars() {
205 let kind = split_classify(c);
206
207 if in_segment && kind == seg_split_kind {
208 seg_byte_len += c.len_utf8();
209 seg_char_count += 1;
210 if kind == SplitKind::Word {
211 if is_cjk(c) {
212 seg_has_cjk = true;
213 }
214 if !is_alphanumeric_latin(c) {
215 seg_all_alphanum = false;
216 seg_all_digits = false;
217 } else if !c.is_ascii_digit() {
218 seg_all_digits = false;
219 }
220 if seg_lang_cpt.is_none() {
221 seg_lang_cpt = detect_language_cpt(c, options);
222 }
223 }
224 } else {
225 if in_segment {
226 flush!();
227 }
228 seg_split_kind = kind;
229 seg_byte_len = c.len_utf8();
230 seg_char_count = 1;
231 seg_has_cjk = kind == SplitKind::Word && is_cjk(c);
232 seg_all_alphanum = kind != SplitKind::Word || is_alphanumeric_latin(c);
233 seg_all_digits = kind == SplitKind::Word && c.is_ascii_digit();
234 seg_lang_cpt = if kind == SplitKind::Word {
235 detect_language_cpt(c, options)
236 } else {
237 None
238 };
239 in_segment = true;
240 }
241 }
242
243 if in_segment {
244 flush!();
245 }
246
247 total_tokens
248}
249
250#[inline(always)]
251fn detect_language_cpt(c: char, options: &EstimationOptions) -> Option<f64> {
252 for lc in &options.language_configs {
253 if (lc.matcher)(c) {
254 return Some(lc.chars_per_token);
255 }
256 }
257 None
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::DEFAULT_CHARS_PER_TOKEN;

    // Empty input produces no tokens.
    #[test]
    fn empty_string() {
        let count = estimate_token_count("");
        assert_eq!(count, 0);
    }

    // Whitespace-only input is free.
    #[test]
    fn pure_whitespace() {
        for blank in [" ", "\t\n"] {
            assert_eq!(estimate_token_count(blank), 0);
        }
    }

    // CJK text costs one token per character.
    #[test]
    fn pure_cjk() {
        let count = estimate_token_count("你好世界");
        assert_eq!(count, 4);
    }

    // Short punctuation runs cost a single token.
    #[test]
    fn pure_punctuation() {
        for marks in ["...", ","] {
            assert_eq!(estimate_token_count(marks), 1);
        }
    }

    // A digit run is one token; a decimal point splits the number into three segments.
    #[test]
    fn numeric_string() {
        let cases = [("12345", 1), ("3.14", 3)];
        for (input, expected) in cases {
            assert_eq!(estimate_token_count(input), expected);
        }
    }

    // Words of at most three bytes cost one token each.
    #[test]
    fn short_words() {
        let count = estimate_token_count("Hi Bob");
        assert_eq!(count, 2);
    }

    #[test]
    fn mixed_content() {
        let count = estimate_token_count("Hello, world!");
        assert!(count >= 2, "Expected at least 2 tokens, got {count}");
    }

    #[test]
    fn german_text() {
        assert!(estimate_token_count("Ärgerlich") > 0);
    }

    #[test]
    fn french_text() {
        assert!(estimate_token_count("résumé") > 0);
    }

    #[test]
    fn english_sentence() {
        let count = estimate_token_count("The quick brown fox jumps over the lazy dog");
        assert!(count >= 9, "Expected at least 9 tokens, got {count}");
    }

    #[test]
    fn default_chars_per_token_constant() {
        assert_eq!(DEFAULT_CHARS_PER_TOKEN, 6.0);
    }

    // Underscores are word characters but not alphanumeric, so the identifier is
    // scored one token per character.
    #[test]
    fn underscore_identifiers() {
        let count = estimate_token_count("process_items");
        assert_eq!(count, 13);
    }

    // Eight ASCII letters at four characters per token yield two tokens.
    #[test]
    fn custom_options() {
        let opts = EstimationOptions {
            default_chars_per_token: 4.0,
            language_configs: vec![],
        };
        assert_eq!(estimate_token_count_with_options("abcdefgh", &opts), 2);
    }
}