use anyhow::Result;
use lru::LruCache;
use parking_lot::Mutex;
use std::num::NonZeroUsize;
use std::sync::Arc;
use tiktoken_rs::{get_bpe_from_model, CoreBPE};

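/// Token counter backed by a tiktoken BPE encoder, with small LRU caches for
/// previously computed token counts and truncation results.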
pub struct TokenCounter {
    encoder: CoreBPE,
    token_cache: Arc<Mutex<LruCache<String, usize>>>,
    truncation_cache: Arc<Mutex<LruCache<(String, usize), String>>>,
}

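// Process-wide cache of built encoders, keyed by tiktoken model name, so that
// repeated `TokenCounter::new` calls do not rebuild the BPE tables.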
lazy_static::lazy_static! {
    static ref ENCODER_CACHE: Arc<Mutex<LruCache<String, CoreBPE>>> =
        Arc::new(Mutex::new(LruCache::new(NonZeroUsize::new(10).unwrap())));
}

impl TokenCounter {
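    /// Creates a counter for `model_name`, reusing a cached encoder when one
    /// has already been built for the mapped tiktoken model.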
    pub fn new(model_name: &str) -> Result<Self> {
        let tiktoken_model = map_model_to_tiktoken(model_name);

        let encoder = {
            let mut cache = ENCODER_CACHE.lock();
            if let Some(cached_encoder) = cache.get(&tiktoken_model) {
                cached_encoder.clone()
            } else {
                let new_encoder = get_bpe_from_model(&tiktoken_model).map_err(|e| {
                    anyhow::anyhow!(
                        "Failed to create token encoder for model '{}': {}",
                        model_name,
                        e
                    )
                })?;
                cache.put(tiktoken_model, new_encoder.clone());
                new_encoder
            }
        };

        Ok(Self {
            encoder,
            token_cache: Arc::new(Mutex::new(LruCache::new(NonZeroUsize::new(1000).unwrap()))),
            truncation_cache: Arc::new(Mutex::new(LruCache::new(NonZeroUsize::new(100).unwrap()))),
        })
    }

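    /// Counts the tokens in `text` (special tokens included), consulting the
    /// LRU cache before encoding.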
    pub fn count_tokens(&self, text: &str) -> usize {
        {
            let mut cache = self.token_cache.lock();
            if let Some(&cached_count) = cache.get(text) {
                return cached_count;
            }
        }

        let count = self.encoder.encode_with_special_tokens(text).len();

        {
            let mut cache = self.token_cache.lock();
            cache.put(text.to_string(), count);
        }

        count
    }

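    /// Estimates the token count of a full chat request (system prompt,
    /// history, and the current prompt). The per-message overheads and the
    /// final margin are deliberately rough, conservative estimates.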
    pub fn estimate_chat_tokens(
        &self,
        prompt: &str,
        system_prompt: Option<&str>,
        history: &[crate::database::ChatEntry],
    ) -> usize {
        let mut total_tokens = 0;

        if let Some(sys_prompt) = system_prompt {
            total_tokens += self.count_tokens(sys_prompt);
            total_tokens += 4; // rough per-message formatting overhead
        }

        for entry in history {
            total_tokens += self.count_tokens(&entry.question);
            total_tokens += self.count_tokens(&entry.response);
            total_tokens += 8; // rough overhead for the question/response pair
        }

        total_tokens += self.count_tokens(prompt);
        total_tokens += 4; // rough per-message formatting overhead
        total_tokens += 100; // safety margin so the estimate errs high

        total_tokens
    }

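    /// Returns `true` if the estimated request size exceeds `context_limit`.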
    pub fn exceeds_context_limit(
        &self,
        prompt: &str,
        system_prompt: Option<&str>,
        history: &[crate::database::ChatEntry],
        context_limit: u32,
    ) -> bool {
        let estimated_tokens = self.estimate_chat_tokens(prompt, system_prompt, history);
        estimated_tokens > context_limit as usize
    }

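    /// Trims the request to fit within `context_limit`, reserving room for the
    /// model's output (4096 tokens when `max_output_tokens` is `None`).
    /// History is dropped oldest-first; if the prompt and system prompt alone
    /// do not fit, the prompt itself is truncated and all history is dropped.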
    pub fn truncate_to_fit(
        &self,
        prompt: &str,
        system_prompt: Option<&str>,
        history: &[crate::database::ChatEntry],
        context_limit: u32,
        max_output_tokens: Option<u32>,
    ) -> (String, Vec<crate::database::ChatEntry>) {
        let max_output = max_output_tokens.unwrap_or(4096) as usize;
        let available_tokens = (context_limit as usize).saturating_sub(max_output);

        let mut used_tokens = self.count_tokens(prompt) + 4; // prompt plus rough message overhead
        if let Some(sys_prompt) = system_prompt {
            used_tokens += self.count_tokens(sys_prompt) + 4;
        }

        if used_tokens >= available_tokens {
            // Even without history the request is too large: truncate the
            // prompt itself, keeping a small reserve of 100 tokens.
            let max_prompt_tokens = available_tokens.saturating_sub(100);
            let truncated_prompt = self.truncate_text(prompt, max_prompt_tokens);
            return (truncated_prompt, Vec::new());
        }

        let mut truncated_history = Vec::new();
        let remaining_tokens = available_tokens - used_tokens;
        let mut history_tokens = 0;

        // Walk history from newest to oldest, keeping entries while they fit.
        for entry in history.iter().rev() {
            let entry_tokens =
                self.count_tokens(&entry.question) + self.count_tokens(&entry.response) + 8;
            if history_tokens + entry_tokens <= remaining_tokens {
                history_tokens += entry_tokens;
                truncated_history.insert(0, entry.clone());
            } else {
                break;
            }
        }

        (prompt.to_string(), truncated_history)
    }

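    /// Truncates `text` to at most `max_tokens` tokens, caching the result.
    /// If the truncated token slice does not decode to valid UTF-8, falls back
    /// to a rough character-based cut.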
    fn truncate_text(&self, text: &str, max_tokens: usize) -> String {
        let cache_key = (text.to_string(), max_tokens);

        {
            let mut cache = self.truncation_cache.lock();
            if let Some(cached_result) = cache.get(&cache_key) {
                return cached_result.clone();
            }
        }

        let tokens = self.encoder.encode_with_special_tokens(text);
        if tokens.len() <= max_tokens {
            return text.to_string();
        }

        let result = {
            let truncated_tokens = &tokens[..max_tokens];
            match self.encoder.decode(truncated_tokens.to_vec()) {
                Ok(decoded) => decoded,
                Err(_) => {
                    // Decoding failed (the cut likely split a multi-byte
                    // character); fall back to a rough estimate of about
                    // three characters per token.
                    let chars: Vec<char> = text.chars().collect();
                    let estimated_chars = max_tokens * 3;
                    if chars.len() > estimated_chars {
                        chars[..estimated_chars].iter().collect()
                    } else {
                        text.to_string()
                    }
                }
            }
        };

        {
            let mut cache = self.truncation_cache.lock();
            cache.put(cache_key, result.clone());
        }

        result
    }
}

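/// Maps an arbitrary model name onto a model name that tiktoken-rs recognizes.
/// Unrecognized models fall back to the `gpt-4` encoding.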
fn map_model_to_tiktoken(model_name: &str) -> String {
    let lower_name = model_name.to_lowercase();

    if lower_name.contains("gpt-4") {
        "gpt-4".to_string()
    } else if lower_name.contains("gpt-3.5") {
        "gpt-3.5-turbo".to_string()
    } else {
        "gpt-4".to_string()
    }
}
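
// A minimal usage sketch: it assumes the gpt-4 BPE data bundled with
// tiktoken-rs loads without network access, and it only exercises paths that
// do not require constructing `crate::database::ChatEntry` values.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn counts_and_truncates() {
        let counter = TokenCounter::new("gpt-4").expect("encoder should load");

        // Counting the same text twice hits the token LRU cache on the second call.
        let n = counter.count_tokens("hello world");
        assert!(n > 0);
        assert_eq!(n, counter.count_tokens("hello world"));

        // With no history and a generous limit, the prompt comes back unchanged.
        let (prompt, history) =
            counter.truncate_to_fit("hello world", None, &[], 8192, Some(256));
        assert_eq!(prompt, "hello world");
        assert!(history.is_empty());
    }
}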