openai_mock/utils/token_counting.rs
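
//! Token-counting utilities for the mock OpenAI API, backed by tiktoken-rs.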

use tiktoken_rs::{cl100k_base, o200k_base, p50k_base};
use crate::models::completion::Usage;

/// One chat message, reduced to the fields that matter for token accounting.
pub struct ChatMessage {
    pub role: String,
    pub content: String,
}

/// Counts tokens with the tiktoken encoding that matches a given model.
pub struct TokenCounter {
    encoding: tiktoken_rs::CoreBPE,
}

impl TokenCounter {
    /// Selects the encoding for `model`. Prefix matching is used so dated
    /// variants (e.g. "gpt-4-0613", "gpt-4o-2024-08-06") resolve to the same
    /// encoding as their base model; unknown models fall back to cl100k_base.
    pub fn new(model: &str) -> Result<Self, Box<dyn std::error::Error>> {
        let encoding = match model {
            // gpt-4o must be checked before the broader gpt-4 prefix
            m if m.starts_with("gpt-4o") => o200k_base()?,
            m if m.starts_with("gpt-4") || m.starts_with("gpt-3.5-turbo") => cl100k_base()?,
            "text-embedding-ada-002" => cl100k_base()?,
            "text-davinci-002" | "text-davinci-003" => p50k_base()?,
            // default to cl100k_base
            _ => cl100k_base()?,
        };

        Ok(Self { encoding })
    }

    /// Counts the tokens in `text`, with special tokens permitted.
    pub fn count_tokens(&self, text: &str) -> u32 {
        self.encoding.encode_with_special_tokens(text).len() as u32
    }

    /// Approximates ChatML accounting per OpenAI's counting guidance: each
    /// message costs its content and role tokens plus 3 tokens of framing,
    /// and every reply is primed with 3 extra tokens.
    pub fn count_messages_tokens(&self, messages: &[ChatMessage]) -> u32 {
        let per_message_tokens = 3;
        // Every reply is primed with <|start|>assistant<|message|>
        let reply_priming_tokens = 3;

        messages
            .iter()
            .map(|msg| {
                self.count_tokens(&msg.content)
                    + self.count_tokens(&msg.role)
                    + per_message_tokens
            })
            .sum::<u32>()
            + reply_priming_tokens
    }

    /// Creates a Usage struct with token counts for prompt and completion
    pub fn calculate_usage(&self, prompt: &str, completion: &str) -> Usage {
        let prompt_tokens = self.count_tokens(prompt);
        let completion_tokens = self.count_tokens(completion);

        Usage {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
        }
    }

    /// Truncates text to approximately fit within max_tokens
    pub fn truncate_to_tokens(&self, text: &str, max_tokens: u32) -> String {
        let tokens = self.encoding.encode_with_special_tokens(text);
        if tokens.len() as u32 <= max_tokens {
            return text.to_string();
        }

        // A token boundary can fall inside a multi-byte UTF-8 character,
        // which makes decode fail; back off until the prefix decodes cleanly
        // instead of unwrapping and panicking.
        let mut end = max_tokens as usize;
        while end > 0 {
            if let Ok(decoded) = self.encoding.decode(tokens[..end].to_vec()) {
                return decoded;
            }
            end -= 1;
        }
        String::new()
    }
}
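
// Illustrative tests (an addition, not part of the original file): exact
// token counts depend on the tokenizer data shipped with tiktoken-rs, so
// these assert invariants rather than hard-coded numbers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn usage_totals_are_consistent() {
        let counter = TokenCounter::new("gpt-3.5-turbo").unwrap();
        let usage = counter.calculate_usage("Hello, world!", "Hi there!");
        assert!(usage.prompt_tokens > 0);
        assert!(usage.completion_tokens > 0);
        assert_eq!(
            usage.total_tokens,
            usage.prompt_tokens + usage.completion_tokens
        );
    }

    #[test]
    fn truncation_returns_a_prefix_of_the_input() {
        let counter = TokenCounter::new("gpt-4").unwrap();
        let text = "The quick brown fox jumps over the lazy dog. ".repeat(50);
        let truncated = counter.truncate_to_tokens(&text, 10);
        // Decoding a token prefix yields a byte prefix of the original text.
        assert!(text.starts_with(&truncated));
        assert!(truncated.len() < text.len());
        // Text already within budget comes back unchanged.
        assert_eq!(counter.truncate_to_tokens("short", 100), "short");
    }
}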