1use crate::types::Message;
6use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct TokenEstimate {
11 pub tokens: usize,
12 pub characters: usize,
13 pub words: usize,
14 pub method: EstimationMethod,
15}
16
17#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
19pub enum EstimationMethod {
20 CharacterRatio,
22 WordBased,
24 TikToken,
26}
27
/// Estimates the token count of `content` as its UTF-8 byte length divided by
/// `bytes_per_token`, rounded to the nearest integer.
///
/// `bytes_per_token` is expected to be a positive, finite number. Because the
/// float-to-integer cast saturates, a ratio of `0.0` yields `usize::MAX` for
/// non-empty content, and a NaN quotient (for example `0 / 0`) yields `0`.
pub fn rough_token_count_estimation(content: &str, bytes_per_token: f64) -> usize {
    let byte_len = content.len() as f64;
    (byte_len / bytes_per_token).round() as usize
}
37
/// Returns the bytes-per-token ratio to use for a file with the given
/// extension.
///
/// `file_extension` is matched exactly (case-sensitive, no leading dot):
/// the JSON family (`json`, `jsonl`, `jsonc`) maps to `2.0`, and every other
/// extension maps to `4.0`.
pub fn bytes_per_token_for_file_type(file_extension: &str) -> f64 {
    let is_json_family = matches!(file_extension, "json" | "jsonl" | "jsonc");
    if is_json_family {
        2.0
    } else {
        4.0
    }
}
48
49pub fn rough_token_count_estimation_for_file_type(content: &str, file_extension: &str) -> usize {
53 rough_token_count_estimation(content, bytes_per_token_for_file_type(file_extension))
54}
55
56pub fn rough_token_count_estimation_for_message(message: &Message) -> usize {
59 rough_token_count_estimation_for_content(&message.content)
60}
61
62pub fn rough_token_count_estimation_for_content(content: &str) -> usize {
65 if content.is_empty() {
66 return 0;
67 }
68 rough_token_count_estimation(content, 4.0)
69}
70
71pub fn rough_token_count_estimation_for_messages(messages: &[Message]) -> usize {
74 messages
75 .iter()
76 .map(|msg| rough_token_count_estimation_for_message(msg))
77 .sum()
78}
79
80pub fn estimate_tokens_characters(text: &str) -> TokenEstimate {
87 let characters = text.len();
88 let words = text.split_whitespace().count();
89
90 let ratio = if text.contains("```") {
93 5.5
95 } else if words > 0 {
96 let avg_word_len = characters as f64 / words as f64;
97 if avg_word_len > 8.0 {
98 5.0
100 } else if avg_word_len < 3.0 {
101 3.5
103 } else {
104 4.0
105 }
106 } else {
107 4.0
108 };
109
110 let tokens = (characters as f64 / ratio).ceil() as usize;
111
112 TokenEstimate {
113 tokens,
114 characters,
115 words,
116 method: EstimationMethod::CharacterRatio,
117 }
118}
119
120pub fn estimate_tokens_words(text: &str) -> TokenEstimate {
122 let words = text.split_whitespace().count();
123 let characters = text.len();
124
125 let tokens = (words as f64 / 1.3).ceil() as usize;
127
128 TokenEstimate {
129 tokens,
130 characters,
131 words,
132 method: EstimationMethod::WordBased,
133 }
134}
135
136pub fn estimate_tokens(text: &str) -> TokenEstimate {
138 let char_estimate = estimate_tokens_characters(text);
139 let word_estimate = estimate_tokens_words(text);
140
141 let tokens = (char_estimate.tokens + word_estimate.tokens) / 2;
143
144 TokenEstimate {
145 tokens,
146 characters: char_estimate.characters,
147 words: char_estimate.words,
148 method: EstimationMethod::CharacterRatio,
149 }
150}
151
152pub fn estimate_message_tokens<T: MessageContent>(messages: &[T]) -> usize {
154 messages
155 .iter()
156 .map(|m| {
157 let content = m.content();
158 let role_overhead = 4;
160 estimate_tokens(content).tokens + role_overhead
161 })
162 .sum()
163}
164
165pub fn estimate_conversation(conversation: &str) -> TokenEstimate {
167 let turns = conversation
169 .matches("User:")
170 .count()
171 .max(conversation.matches("Assistant:").count());
172
173 let turn_overhead = turns * 10;
175
176 let base = estimate_tokens(conversation);
177 TokenEstimate {
178 tokens: base.tokens + turn_overhead,
179 characters: base.characters,
180 words: base.words,
181 method: base.method,
182 }
183}
184
185pub fn estimate_tool_definitions(tools: &[ToolDefinition]) -> usize {
187 tools
188 .iter()
189 .map(|t| {
190 let name_tokens = estimate_tokens(&t.name).tokens;
191 let desc_tokens = t
192 .description
193 .as_ref()
194 .map(|d| estimate_tokens(d).tokens)
195 .unwrap_or(0);
196 let params_tokens = estimate_tokens(&t.input_schema).tokens;
197 name_tokens + desc_tokens + params_tokens + 20 })
199 .sum()
200}
201
/// Abstraction over anything that can expose its textual content as a `&str`,
/// so that [`estimate_message_tokens`] can accept different message types.
pub trait MessageContent {
    /// Returns the text whose tokens should be counted.
    fn content(&self) -> &str;
}
206
207impl MessageContent for String {
208 fn content(&self) -> &str {
209 self.as_str()
210 }
211}
212
213impl MessageContent for &str {
214 fn content(&self) -> &str {
215 self
216 }
217}
218
/// Minimal chat message consisting of a role label and its text.
#[derive(Debug, Clone)]
pub struct ChatMessage {
    /// Role label as free-form text (the unit test in this file uses `"user"`).
    pub role: String,
    /// Message body; this is what the `MessageContent` impl exposes.
    pub content: String,
}
225
226impl MessageContent for ChatMessage {
227 fn content(&self) -> &str {
228 &self.content
229 }
230}
231
/// Description of a tool, as consumed by [`estimate_tool_definitions`].
#[derive(Debug, Clone)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Optional human-readable description; contributes no tokens when `None`.
    pub description: Option<String>,
    /// Input schema kept as raw text (the unit test in this file passes a
    /// JSON Schema document).
    pub input_schema: String,
}
239
/// Returns how many more input tokens fit once `max_tokens` have been
/// reserved out of `context_limit` and `input_tokens` are already used.
///
/// The result is `0` when the input already fills or exceeds the space left
/// for input, including when `max_tokens` alone exceeds `context_limit`.
pub fn calculate_padding(input_tokens: usize, max_tokens: usize, context_limit: usize) -> usize {
    // Both subtractions saturate at zero, which subsumes the explicit
    // `input_tokens < available_for_input` check in the original.
    context_limit
        .saturating_sub(max_tokens)
        .saturating_sub(input_tokens)
}
251
/// Reports whether `content_tokens` plus the `max_tokens` reservation fit
/// within `context_limit`.
///
/// The sum is computed with overflow checking: if it does not fit in a
/// `usize` it cannot fit in any context window, so the answer is `false`.
/// The plain `+` used previously panicked in debug builds and wrapped around
/// in release builds, where it could wrongly report `true`.
pub fn fits_in_context(content_tokens: usize, max_tokens: usize, context_limit: usize) -> bool {
    match content_tokens.checked_add(max_tokens) {
        Some(required) => required <= context_limit,
        None => false,
    }
}
256
/// Characters-per-token ratios and the text classifiers used to pick one.
pub mod encoding {
    /// Ratio used for text that is neither classified as code nor as CJK.
    pub const CHARS_PER_TOKEN_EN: f64 = 4.0;
    /// Ratio used for text classified as source code.
    pub const CHARS_PER_TOKEN_CODE: f64 = 5.5;
    /// Ratio used for text containing CJK characters.
    pub const CHARS_PER_TOKEN_CJK: f64 = 2.0;

    /// Substrings whose presence marks a text as code. Matching is a plain,
    /// case-sensitive substring search, so ordinary prose that happens to
    /// contain one of them (for example "functionality") is classified as
    /// code too.
    const CODE_INDICATORS: [&str; 8] = [
        "```", "function", "class ", "def ", "const ", "let ", "var ", "import ",
    ];

    /// Returns `true` when `text` contains any of the code indicator
    /// substrings.
    pub fn is_code(text: &str) -> bool {
        CODE_INDICATORS
            .iter()
            .any(|indicator| text.contains(*indicator))
    }

    /// Returns `true` when `text` contains at least one character in the
    /// ranges U+4E00–U+9FFF (CJK Unified Ideographs), U+3040–U+309F
    /// (Hiragana), U+30A0–U+30FF (Katakana) or U+AC00–U+D7AF (Hangul
    /// Syllables).
    pub fn is_cjk(text: &str) -> bool {
        text.chars().any(|c| {
            matches!(
                c,
                '\u{4E00}'..='\u{9FFF}'
                    | '\u{3040}'..='\u{309F}'
                    | '\u{30A0}'..='\u{30FF}'
                    | '\u{AC00}'..='\u{D7AF}'
            )
        })
    }

    /// Picks the characters-per-token ratio for `text`.
    ///
    /// The code check runs first, so text that looks like code and also
    /// contains CJK characters gets the code ratio.
    pub fn chars_per_token(text: &str) -> f64 {
        if is_code(text) {
            CHARS_PER_TOKEN_CODE
        } else if is_cjk(text) {
            CHARS_PER_TOKEN_CJK
        } else {
            CHARS_PER_TOKEN_EN
        }
    }
}
293
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MessageRole;

    /// Builds a message with the given role and content, leaving every other
    /// field at its default.
    fn message(role: MessageRole, content: &str) -> crate::types::Message {
        crate::types::Message {
            role,
            content: content.to_string(),
            ..Default::default()
        }
    }

    #[test]
    fn test_rough_token_count_estimation() {
        // 11 bytes / 4 = 2.75, which rounds to 3.
        assert_eq!(rough_token_count_estimation("Hello world", 4.0), 3);
        assert_eq!(rough_token_count_estimation(&"a".repeat(100), 4.0), 25);
    }

    #[test]
    fn test_bytes_per_token_for_file_type() {
        assert_eq!(bytes_per_token_for_file_type("json"), 2.0);
        assert_eq!(bytes_per_token_for_file_type("jsonl"), 2.0);
        assert_eq!(bytes_per_token_for_file_type("rs"), 4.0);
        assert_eq!(bytes_per_token_for_file_type("txt"), 4.0);
    }

    #[test]
    fn test_rough_token_count_estimation_for_file_type() {
        let content = "a".repeat(100);
        assert_eq!(
            rough_token_count_estimation_for_file_type(&content, "json"),
            50
        );
        assert_eq!(
            rough_token_count_estimation_for_file_type(&content, "rs"),
            25
        );
    }

    #[test]
    fn test_rough_token_count_estimation_for_content() {
        assert_eq!(rough_token_count_estimation_for_content(""), 0);
        // 5 bytes / 4 = 1.25, which rounds to 1.
        assert_eq!(rough_token_count_estimation_for_content("Hello"), 1);
    }

    #[test]
    fn test_rough_token_count_estimation_for_message() {
        let msg = message(MessageRole::User, "Hello world");
        assert_eq!(rough_token_count_estimation_for_message(&msg), 3);
    }

    #[test]
    fn test_rough_token_count_estimation_for_messages() {
        let messages = vec![
            message(MessageRole::User, "Hello"),
            message(MessageRole::Assistant, "Hi there"),
        ];
        // "Hello" -> 1 and "Hi there" -> 2.
        assert_eq!(rough_token_count_estimation_for_messages(&messages), 3);
    }

    #[test]
    fn test_estimate_tokens_characters() {
        let result = estimate_tokens_characters("Hello, world!");
        assert!(result.tokens >= 3);
        assert_eq!(result.characters, 13);
    }

    #[test]
    fn test_estimate_tokens_words() {
        let result = estimate_tokens_words("Hello world this is a test");
        assert!(result.tokens > 0);
        assert_eq!(result.words, 6);
    }

    #[test]
    fn test_estimate_tokens() {
        let result = estimate_tokens("The quick brown fox jumps over the lazy dog");
        assert!(result.tokens > 0);
    }

    #[test]
    fn test_estimate_conversation() {
        let conv = "User: Hello\nAssistant: Hi there!\nUser: How are you?";
        let result = estimate_conversation(conv);
        assert!(result.tokens > 0);
    }

    #[test]
    fn test_estimate_tool_definitions() {
        let tools = vec![ToolDefinition {
            name: "Read".to_string(),
            description: Some("Read a file".to_string()),
            input_schema: r#"{"type":"object","properties":{"path":{"type":"string"}}}"#
                .to_string(),
        }];
        let tokens = estimate_tool_definitions(&tools);
        assert!(tokens > 0);
    }

    #[test]
    fn test_calculate_padding() {
        assert_eq!(calculate_padding(1000, 500, 2000), 500);
        assert_eq!(calculate_padding(1500, 500, 2000), 0);
    }

    #[test]
    fn test_fits_in_context() {
        assert!(fits_in_context(1000, 500, 2000));
        assert!(!fits_in_context(1600, 500, 2000));
    }

    #[test]
    fn test_encoding_chars_per_token() {
        assert_eq!(
            encoding::chars_per_token("Hello world"),
            encoding::CHARS_PER_TOKEN_EN
        );
        assert_eq!(
            encoding::chars_per_token("function test() {}"),
            encoding::CHARS_PER_TOKEN_CODE
        );
    }

    #[test]
    fn test_is_code() {
        assert!(encoding::is_code("function foo() { return 1; }"));
        assert!(!encoding::is_code("Hello world"));
    }

    #[test]
    fn test_is_cjk() {
        assert!(encoding::is_cjk("你好世界"));
        assert!(!encoding::is_cjk("Hello world"));
    }

    #[test]
    fn test_message_content_trait() {
        let msg = ChatMessage {
            role: "user".to_string(),
            content: "Hello".to_string(),
        };
        assert_eq!(msg.content(), "Hello");
    }
}