lc/utils/token.rs

use anyhow::Result;
use lru::LruCache;
use parking_lot::Mutex;
use std::num::NonZeroUsize;
use std::sync::Arc;
use tiktoken_rs::{get_bpe_from_model, CoreBPE};

/// Token counter for various models with caching
pub struct TokenCounter {
    encoder: CoreBPE,
    // LRU cache for token counts to avoid repeated tokenization
    token_cache: Arc<Mutex<LruCache<String, usize>>>,
    // Cache for truncated text to avoid repeated truncation
    truncation_cache: Arc<Mutex<LruCache<(String, usize), String>>>,
}

// Global cache for encoder instances to avoid repeated creation
lazy_static::lazy_static! {
    static ref ENCODER_CACHE: Arc<Mutex<LruCache<String, CoreBPE>>> =
        Arc::new(Mutex::new(LruCache::new(NonZeroUsize::new(10).unwrap())));
}

impl TokenCounter {
    /// Create a new token counter for the given model with caching
    pub fn new(model_name: &str) -> Result<Self> {
        // Map model names to tiktoken model names
        let tiktoken_model = map_model_to_tiktoken(model_name);

        // Try to get encoder from cache first
        let encoder = {
            let mut cache = ENCODER_CACHE.lock();
            if let Some(cached_encoder) = cache.get(&tiktoken_model) {
                cached_encoder.clone()
            } else {
                let new_encoder = get_bpe_from_model(&tiktoken_model).map_err(|e| {
                    anyhow::anyhow!(
                        "Failed to create token encoder for model '{}': {}",
                        model_name,
                        e
                    )
                })?;
                cache.put(tiktoken_model, new_encoder.clone());
                new_encoder
            }
        };

        Ok(Self {
            encoder,
            token_cache: Arc::new(Mutex::new(LruCache::new(NonZeroUsize::new(1000).unwrap()))),
            truncation_cache: Arc::new(Mutex::new(LruCache::new(NonZeroUsize::new(100).unwrap()))),
        })
    }

    /// Count tokens in the given text with caching
    pub fn count_tokens(&self, text: &str) -> usize {
        // Check cache first
        {
            let mut cache = self.token_cache.lock();
            if let Some(&cached_count) = cache.get(text) {
                return cached_count;
            }
        }

        // Calculate token count
        let count = self.encoder.encode_with_special_tokens(text).len();

        // Store in cache
        {
            let mut cache = self.token_cache.lock();
            cache.put(text.to_string(), count);
        }

        count
    }

    /// Estimate tokens for a chat request including system prompt and history
    pub fn estimate_chat_tokens(
        &self,
        prompt: &str,
        system_prompt: Option<&str>,
        history: &[crate::database::ChatEntry],
    ) -> usize {
        let mut total_tokens = 0;

        // Count system prompt tokens
        if let Some(sys_prompt) = system_prompt {
            total_tokens += self.count_tokens(sys_prompt);
            total_tokens += 4; // Overhead for system message formatting
        }

        // Count history tokens
        for entry in history {
            total_tokens += self.count_tokens(&entry.question);
            total_tokens += self.count_tokens(&entry.response);
            total_tokens += 8; // Overhead for message formatting (role, etc.)
        }

        // Count current prompt tokens
        total_tokens += self.count_tokens(prompt);
        total_tokens += 4; // Overhead for user message formatting

        // Add some buffer for response generation
        total_tokens += 100; // Reserve space for response start

        total_tokens
    }

    /// Check if the estimated tokens exceed the context limit
    pub fn exceeds_context_limit(
        &self,
        prompt: &str,
        system_prompt: Option<&str>,
        history: &[crate::database::ChatEntry],
        context_limit: u32,
    ) -> bool {
        let estimated_tokens = self.estimate_chat_tokens(prompt, system_prompt, history);
        estimated_tokens > context_limit as usize
    }

    /// Truncate input to fit within context limit
    pub fn truncate_to_fit(
        &self,
        prompt: &str,
        system_prompt: Option<&str>,
        history: &[crate::database::ChatEntry],
        context_limit: u32,
        max_output_tokens: Option<u32>,
    ) -> (String, Vec<crate::database::ChatEntry>) {
        let max_output = max_output_tokens.unwrap_or(4096) as usize;
        let available_tokens = (context_limit as usize).saturating_sub(max_output);

        // Always preserve the current prompt and system prompt
        let mut used_tokens = self.count_tokens(prompt) + 4; // User message overhead
        if let Some(sys_prompt) = system_prompt {
            used_tokens += self.count_tokens(sys_prompt) + 4; // System message overhead
        }

        if used_tokens >= available_tokens {
            // Even the prompt alone is too large, truncate it
            let max_prompt_tokens = available_tokens.saturating_sub(100); // Leave some buffer
            let truncated_prompt = self.truncate_text(prompt, max_prompt_tokens);
            return (truncated_prompt, Vec::new());
        }

        // Include as much history as possible
        let mut truncated_history = Vec::new();
        let remaining_tokens = available_tokens - used_tokens;
        let mut history_tokens = 0;

        // Include history from most recent to oldest
        for entry in history.iter().rev() {
            let entry_tokens =
                self.count_tokens(&entry.question) + self.count_tokens(&entry.response) + 8;
            if history_tokens + entry_tokens <= remaining_tokens {
                history_tokens += entry_tokens;
                truncated_history.insert(0, entry.clone());
            } else {
                break;
            }
        }

        (prompt.to_string(), truncated_history)
    }

    /// Truncate text to fit within token limit with caching
    fn truncate_text(&self, text: &str, max_tokens: usize) -> String {
        let cache_key = (text.to_string(), max_tokens);

        // Check cache first
        {
            let mut cache = self.truncation_cache.lock();
            if let Some(cached_result) = cache.get(&cache_key) {
                return cached_result.clone();
            }
        }

        let tokens = self.encoder.encode_with_special_tokens(text);
        if tokens.len() <= max_tokens {
            return text.to_string();
        }

        let result = {
            let truncated_tokens = &tokens[..max_tokens];
            match self.encoder.decode(truncated_tokens.to_vec()) {
                Ok(decoded) => decoded,
                Err(_) => {
                    // Fallback: truncate by characters (rough approximation)
                    let chars: Vec<char> = text.chars().collect();
                    let estimated_chars = max_tokens * 3; // Rough estimate: 1 token ≈ 3-4 chars
                    if chars.len() > estimated_chars {
                        chars[..estimated_chars].iter().collect()
                    } else {
                        text.to_string()
                    }
                }
            }
        };

        // Store in cache
        {
            let mut cache = self.truncation_cache.lock();
            cache.put(cache_key, result.clone());
        }

        result
    }
}

/// Map model names to tiktoken-compatible model names
/// This is a simplified fallback approach - ideally tokenizer mappings should be
/// configured per provider in configuration files for accuracy
fn map_model_to_tiktoken(model_name: &str) -> String {
    let lower_name = model_name.to_lowercase();

    // Only handle actual OpenAI models with their correct tokenizers
    if lower_name.contains("gpt-4") {
        "gpt-4".to_string()
    } else if lower_name.contains("gpt-3.5") {
        "gpt-3.5-turbo".to_string()
    } else {
        // For all non-OpenAI models, use GPT-4 as a rough approximation
        // NOTE: This is inaccurate but necessary since tiktoken only supports OpenAI models
        // TODO: Move to provider-specific tokenizer configuration or disable token counting
        // for non-OpenAI models to avoid misleading estimates
        "gpt-4".to_string()
    }
}
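
// ---------------------------------------------------------------------------
// Illustrative usage sketch: a minimal exercise of the public API above, not a
// full test suite. It uses an empty chat history so it makes no assumptions
// about the fields of crate::database::ChatEntry; the model name "gpt-4" and
// the numeric limits are example values only.
// ---------------------------------------------------------------------------
#[cfg(test)]
mod usage_sketch {
    use super::*;

    #[test]
    fn counts_tokens_and_checks_context_limit() {
        // Construction maps the model name to a tiktoken encoder; subsequent
        // calls for the same model reuse the entry in the global ENCODER_CACHE.
        let counter = TokenCounter::new("gpt-4").expect("tiktoken encoder should be available");

        // Repeated calls for identical text are served from the LRU token cache.
        let n = counter.count_tokens("Hello, world!");
        assert!(n > 0);
        assert_eq!(counter.count_tokens("Hello, world!"), n);

        // With no history, the estimate is the prompt plus fixed formatting
        // overhead plus the 100-token response buffer, so a tiny context
        // limit is reported as exceeded.
        assert!(counter.exceeds_context_limit("Hello, world!", None, &[], 16));
    }

    #[test]
    fn truncates_an_oversized_prompt() {
        let counter = TokenCounter::new("gpt-4").expect("tiktoken encoder should be available");
        let long_prompt = "word ".repeat(2_000);

        // With a 256-token context and 128 tokens reserved for output, only
        // 128 tokens remain, so the prompt is cut down and the history (here
        // already empty) is dropped.
        let (truncated, history) = counter.truncate_to_fit(&long_prompt, None, &[], 256, Some(128));
        assert!(history.is_empty());
        assert!(counter.count_tokens(&truncated) < counter.count_tokens(&long_prompt));
    }
}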