openai_mock/utils/token_counting.rs

use tiktoken_rs::{cl100k_base, o200k_base, p50k_base};

use crate::models::completion::Usage;

/// One chat message: a `role` (e.g. "user" or "assistant") plus its text content.
pub struct ChatMessage {
    pub role: String,
    pub content: String,
}

/// Counts tokens using the tiktoken encoding that matches a given model name.
pub struct TokenCounter {
    encoding: tiktoken_rs::CoreBPE,
}

impl TokenCounter {
    /// Creates a counter for `model`, selecting the matching tiktoken encoding.
    pub fn new(model: &str) -> Result<Self, Box<dyn std::error::Error>> {
        let encoding = match model {
            "gpt-4" | "gpt-3.5-turbo" | "text-embedding-ada-002" => cl100k_base()?,
            "gpt-4o" | "gpt-4o-mini" => o200k_base()?,
            "text-davinci-002" | "text-davinci-003" => p50k_base()?,
            // Unknown models fall back to cl100k_base rather than failing.
            _ => cl100k_base()?,
        };

        Ok(Self { encoding })
    }

    /// Counts the tokens in `text`, including any special tokens it contains.
    pub fn count_tokens(&self, text: &str) -> u32 {
        self.encoding.encode_with_special_tokens(text).len() as u32
    }

    /// Approximates the prompt token count for a chat request. Per OpenAI's
    /// published counting guide, each message carries roughly 3 tokens of
    /// formatting overhead, and each reply is primed with roughly 3 more.
    pub fn count_messages_tokens(&self, messages: &[ChatMessage]) -> u32 {
        const PER_MESSAGE_TOKENS: u32 = 3;
        const REPLY_PRIMING_TOKENS: u32 = 3;

        let message_tokens: u32 = messages
            .iter()
            .map(|msg| {
                self.count_tokens(&msg.content)
                    + self.count_tokens(&msg.role)
                    + PER_MESSAGE_TOKENS
            })
            .sum();

        message_tokens + REPLY_PRIMING_TOKENS
    }

    /// Builds a `Usage` record from the prompt and completion texts.
    pub fn calculate_usage(&self, prompt: &str, completion: &str) -> Usage {
        let prompt_tokens = self.count_tokens(prompt);
        let completion_tokens = self.count_tokens(completion);

        Usage {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
        }
    }

    /// Truncates `text` to at most `max_tokens` tokens.
    pub fn truncate_to_tokens(&self, text: &str, max_tokens: u32) -> String {
        let tokens = self.encoding.encode_with_special_tokens(text);
        if tokens.len() as u32 <= max_tokens {
            return text.to_string();
        }

        // Cutting at an arbitrary token boundary can split a multi-byte
        // character, in which case `decode` fails; instead of unwrapping
        // and panicking, back off one token at a time until the prefix
        // decodes cleanly (the empty prefix always does).
        let mut end = max_tokens as usize;
        loop {
            match self.encoding.decode(tokens[..end].to_vec()) {
                Ok(truncated) => return truncated,
                Err(_) => end = end.saturating_sub(1),
            }
        }
    }
}
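
// A minimal sketch of how the counter might be exercised, assuming the
// tiktoken vocabularies are available at test time. Exact token counts vary
// by encoding, so these tests assert relative properties rather than
// hard-coded numbers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn count_tokens_is_nonzero_for_text() {
        let counter = TokenCounter::new("gpt-3.5-turbo").unwrap();
        assert!(counter.count_tokens("Hello, world!") > 0);
    }

    #[test]
    fn message_count_includes_formatting_overhead() {
        let counter = TokenCounter::new("gpt-4").unwrap();
        let messages = vec![ChatMessage {
            role: "user".to_string(),
            content: "Hi".to_string(),
        }];
        // Content plus role tokens alone must be strictly less than the
        // total, which also includes per-message and reply-priming overhead.
        let bare = counter.count_tokens("Hi") + counter.count_tokens("user");
        assert!(counter.count_messages_tokens(&messages) > bare);
    }

    #[test]
    fn truncation_respects_the_token_budget() {
        let counter = TokenCounter::new("gpt-4").unwrap();
        let long_text = "token ".repeat(100);
        let truncated = counter.truncate_to_tokens(&long_text, 10);
        assert!(counter.count_tokens(&truncated) <= 10);
    }
}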