1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
use anyhow::Result;
use lru::LruCache;
use parking_lot::Mutex;
use std::num::NonZeroUsize;
use std::sync::Arc;
use tiktoken_rs::{get_bpe_from_model, CoreBPE};
/// Token counter for various models with caching
pub struct TokenCounter {
encoder: CoreBPE,
// LRU cache for token counts to avoid repeated tokenization
token_cache: Arc<Mutex<LruCache<String, usize>>>,
// Cache for truncated text to avoid repeated truncation
truncation_cache: Arc<Mutex<LruCache<(String, usize), String>>>,
}
// Global cache for encoder instances to avoid repeated creation
lazy_static::lazy_static! {
static ref ENCODER_CACHE: Arc<Mutex<LruCache<String, CoreBPE>>> =
Arc::new(Mutex::new(LruCache::new(NonZeroUsize::new(10).unwrap())));
}
impl TokenCounter {
/// Create a new token counter for the given model with caching
pub fn new(model_name: &str) -> Result<Self> {
// Map model names to tiktoken model names
let tiktoken_model = map_model_to_tiktoken(model_name);
// Try to get encoder from cache first
let encoder = {
let mut cache = ENCODER_CACHE.lock();
if let Some(cached_encoder) = cache.get(&tiktoken_model) {
cached_encoder.clone()
} else {
let new_encoder = get_bpe_from_model(&tiktoken_model).map_err(|e| {
anyhow::anyhow!(
"Failed to create token encoder for model '{}': {}",
model_name,
e
)
})?;
cache.put(tiktoken_model, new_encoder.clone());
new_encoder
}
};
Ok(Self {
encoder,
token_cache: Arc::new(Mutex::new(LruCache::new(NonZeroUsize::new(1000).unwrap()))),
truncation_cache: Arc::new(Mutex::new(LruCache::new(NonZeroUsize::new(100).unwrap()))),
})
}
/// Count tokens in the given text with caching
pub fn count_tokens(&self, text: &str) -> usize {
// Check cache first
{
let mut cache = self.token_cache.lock();
if let Some(&cached_count) = cache.get(text) {
return cached_count;
}
}
// Calculate token count
let count = self.encoder.encode_with_special_tokens(text).len();
// Store in cache
{
let mut cache = self.token_cache.lock();
cache.put(text.to_string(), count);
}
count
}
/// Estimate tokens for a chat request including system prompt and history
pub fn estimate_chat_tokens(
&self,
prompt: &str,
system_prompt: Option<&str>,
history: &[crate::database::ChatEntry],
) -> usize {
let mut total_tokens = 0;
// Count system prompt tokens
if let Some(sys_prompt) = system_prompt {
total_tokens += self.count_tokens(sys_prompt);
total_tokens += 4; // Overhead for system message formatting
}
// Count history tokens
for entry in history {
total_tokens += self.count_tokens(&entry.question);
total_tokens += self.count_tokens(&entry.response);
total_tokens += 8; // Overhead for message formatting (role, etc.)
}
// Count current prompt tokens
total_tokens += self.count_tokens(prompt);
total_tokens += 4; // Overhead for user message formatting
// Add some buffer for response generation
total_tokens += 100; // Reserve space for response start
total_tokens
}
/// Check if the estimated tokens exceed the context limit
pub fn exceeds_context_limit(
&self,
prompt: &str,
system_prompt: Option<&str>,
history: &[crate::database::ChatEntry],
context_limit: u32,
) -> bool {
let estimated_tokens = self.estimate_chat_tokens(prompt, system_prompt, history);
estimated_tokens > context_limit as usize
}
/// Truncate input to fit within context limit
pub fn truncate_to_fit(
&self,
prompt: &str,
system_prompt: Option<&str>,
history: &[crate::database::ChatEntry],
context_limit: u32,
max_output_tokens: Option<u32>,
) -> (String, Vec<crate::database::ChatEntry>) {
let max_output = max_output_tokens.unwrap_or(4096) as usize;
let available_tokens = (context_limit as usize).saturating_sub(max_output);
// Always preserve the current prompt and system prompt
let mut used_tokens = self.count_tokens(prompt) + 4; // User message overhead
if let Some(sys_prompt) = system_prompt {
used_tokens += self.count_tokens(sys_prompt) + 4; // System message overhead
}
if used_tokens >= available_tokens {
// Even the prompt alone is too large, truncate it
let max_prompt_tokens = available_tokens.saturating_sub(100); // Leave some buffer
let truncated_prompt = self.truncate_text(prompt, max_prompt_tokens);
return (truncated_prompt, Vec::new());
}
// Include as much history as possible
let mut truncated_history = Vec::new();
let remaining_tokens = available_tokens - used_tokens;
let mut history_tokens = 0;
// Include history from most recent to oldest
for entry in history.iter().rev() {
let entry_tokens =
self.count_tokens(&entry.question) + self.count_tokens(&entry.response) + 8;
if history_tokens + entry_tokens <= remaining_tokens {
history_tokens += entry_tokens;
truncated_history.insert(0, entry.clone());
} else {
break;
}
}
(prompt.to_string(), truncated_history)
}
/// Truncate text to fit within token limit with caching
fn truncate_text(&self, text: &str, max_tokens: usize) -> String {
let cache_key = (text.to_string(), max_tokens);
// Check cache first
{
let mut cache = self.truncation_cache.lock();
if let Some(cached_result) = cache.get(&cache_key) {
return cached_result.clone();
}
}
let tokens = self.encoder.encode_with_special_tokens(text);
if tokens.len() <= max_tokens {
return text.to_string();
}
let result = {
let truncated_tokens = &tokens[..max_tokens];
match self.encoder.decode(truncated_tokens.to_vec()) {
Ok(decoded) => decoded,
Err(_) => {
// Fallback: truncate by characters (rough approximation)
let chars: Vec<char> = text.chars().collect();
let estimated_chars = max_tokens * 3; // Rough estimate: 1 token ≈ 3-4 chars
if chars.len() > estimated_chars {
chars[..estimated_chars].iter().collect()
} else {
text.to_string()
}
}
}
};
// Store in cache
{
let mut cache = self.truncation_cache.lock();
cache.put(cache_key, result.clone());
}
result
}
}
/// Map model names to tiktoken-compatible model names
/// This is a simplified fallback approach - ideally tokenizer mappings should be
/// configured per provider in configuration files for accuracy
fn map_model_to_tiktoken(model_name: &str) -> String {
let lower_name = model_name.to_lowercase();
// Only handle actual OpenAI models with their correct tokenizers
if lower_name.contains("gpt-4") {
"gpt-4".to_string()
} else if lower_name.contains("gpt-3.5") {
"gpt-3.5-turbo".to_string()
} else {
// For all non-OpenAI models, use GPT-4 as a rough approximation
// NOTE: This is inaccurate but necessary since tiktoken only supports OpenAI models
// TODO: Move to provider-specific tokenizer configuration or disable token counting
// for non-OpenAI models to avoid misleading estimates
"gpt-4".to_string()
}
}