//! Tiktoken integration for accurate token counting.
//!
//! Provides a cached tokenizer manager that:
//! - Caches BPE instances for different encodings (avoid recreation overhead)
//! - Maps model names to the correct tokenizer encoding
//! - Extracts and tokenizes just the message content from chat completions
//!
//! # Encodings
//!
//! | Encoding | Models |
//! |----------|--------|
//! | `o200k_base` | GPT-4o, GPT-4o-mini |
//! | `cl100k_base` | GPT-4, GPT-4-turbo, GPT-3.5-turbo, text-embedding-* |
//! | `p50k_base` | Codex, text-davinci-003 |
//!
//! # Usage
//!
//! ```ignore
//! let manager = TiktokenManager::new();
//! let tokens = manager.count_tokens(Some("gpt-4"), "Hello, world!");
//! let request_tokens = manager.count_chat_request(body, Some("gpt-4o"));
//! ```

use once_cell::sync::Lazy;
use parking_lot::RwLock;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, trace, warn};

#[cfg(feature = "tiktoken")]
use tiktoken_rs::{cl100k_base, o200k_base, p50k_base, CoreBPE};

34/// Tiktoken encoding types
35#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
36pub enum TiktokenEncoding {
37    /// GPT-4o, GPT-4o-mini (newest)
38    O200kBase,
39    /// GPT-4, GPT-4-turbo, GPT-3.5-turbo
40    Cl100kBase,
41    /// Codex, text-davinci-003
42    P50kBase,
43}
44
45impl TiktokenEncoding {
46    /// Get the encoding name for logging
47    pub fn name(&self) -> &'static str {
48        match self {
49            Self::O200kBase => "o200k_base",
50            Self::Cl100kBase => "cl100k_base",
51            Self::P50kBase => "p50k_base",
52        }
53    }
54}
55
56/// Global tiktoken manager instance
57static TIKTOKEN_MANAGER: Lazy<TiktokenManager> = Lazy::new(TiktokenManager::new);
58
59/// Get the global tiktoken manager
60pub fn tiktoken_manager() -> &'static TiktokenManager {
61    &TIKTOKEN_MANAGER
62}
63
64/// Manages cached tiktoken BPE instances for different encodings.
65///
66/// Thread-safe and lazily initialized - encodings are only loaded when first used.
67pub struct TiktokenManager {
68    #[cfg(feature = "tiktoken")]
69    encodings: RwLock<HashMap<TiktokenEncoding, Arc<CoreBPE>>>,
70    #[cfg(not(feature = "tiktoken"))]
71    _marker: std::marker::PhantomData<()>,
72}
73
74impl TiktokenManager {
75    /// Create a new tiktoken manager
76    pub fn new() -> Self {
77        #[cfg(feature = "tiktoken")]
78        {
79            Self {
80                encodings: RwLock::new(HashMap::new()),
81            }
82        }
83        #[cfg(not(feature = "tiktoken"))]
84        {
85            Self {
86                _marker: std::marker::PhantomData,
87            }
88        }
89    }
90
91    /// Get the appropriate encoding for a model name
92    pub fn encoding_for_model(&self, model: &str) -> TiktokenEncoding {
93        let model_lower = model.to_lowercase();
94
95        // GPT-4o family uses o200k_base
96        if model_lower.contains("gpt-4o") || model_lower.contains("gpt4o") {
97            return TiktokenEncoding::O200kBase;
98        }
99
100        // GPT-4 and GPT-3.5-turbo use cl100k_base
101        if model_lower.contains("gpt-4")
102            || model_lower.contains("gpt-3.5")
103            || model_lower.contains("gpt-35")
104            || model_lower.contains("text-embedding")
105            || model_lower.contains("claude")
106        // Claude approximation
107        {
108            return TiktokenEncoding::Cl100kBase;
109        }
110
111        // Codex and older models use p50k_base
112        if model_lower.contains("code-")
113            || model_lower.contains("codex")
114            || model_lower.contains("text-davinci-003")
115            || model_lower.contains("text-davinci-002")
116        {
117            return TiktokenEncoding::P50kBase;
118        }
119
120        // Default to cl100k_base (most common)
121        TiktokenEncoding::Cl100kBase
122    }
123
124    /// Count tokens in text using the appropriate encoding for the model
125    #[cfg(feature = "tiktoken")]
126    pub fn count_tokens(&self, model: Option<&str>, text: &str) -> u64 {
127        let encoding = model
128            .map(|m| self.encoding_for_model(m))
129            .unwrap_or(TiktokenEncoding::Cl100kBase);
130
131        self.count_tokens_with_encoding(encoding, text)
132    }
133
134    #[cfg(not(feature = "tiktoken"))]
135    pub fn count_tokens(&self, _model: Option<&str>, text: &str) -> u64 {
136        // Fallback to character-based estimation
137        (text.chars().count() / 4).max(1) as u64
138    }
139
140    /// Count tokens using a specific encoding
141    #[cfg(feature = "tiktoken")]
142    pub fn count_tokens_with_encoding(&self, encoding: TiktokenEncoding, text: &str) -> u64 {
143        match self.get_or_create_bpe(encoding) {
144            Some(bpe) => {
145                let tokens = bpe.encode_with_special_tokens(text);
146                tokens.len() as u64
147            }
148            None => {
149                // Fallback to character estimation
150                (text.chars().count() / 4).max(1) as u64
151            }
152        }
153    }
154
155    #[cfg(not(feature = "tiktoken"))]
156    pub fn count_tokens_with_encoding(&self, _encoding: TiktokenEncoding, text: &str) -> u64 {
157        (text.chars().count() / 4).max(1) as u64
158    }
159
160    /// Count tokens in a chat completion request body
161    ///
162    /// Parses the JSON and extracts message content for accurate token counting.
163    /// Returns estimated tokens including overhead for message formatting.
164    pub fn count_chat_request(&self, body: &[u8], model: Option<&str>) -> u64 {
165        // Try to parse as JSON
166        let json: Value = match serde_json::from_slice(body) {
167            Ok(v) => v,
168            Err(_) => {
169                // If not valid JSON, count the whole body as text
170                let text = String::from_utf8_lossy(body);
171                return self.count_tokens(model, &text);
172            }
173        };
174
175        // Extract model from body if not provided
176        let model_name = model.or_else(|| json.get("model").and_then(|m| m.as_str()));
177
178        // Extract messages array
179        let messages = match json.get("messages").and_then(|m| m.as_array()) {
180            Some(msgs) => msgs,
181            None => {
182                // Not a chat completion request, try other formats
183                return self.count_non_chat_request(&json, model_name);
184            }
185        };
186
187        // Count tokens in messages
188        let mut total_tokens: u64 = 0;
189
190        // Per-message overhead (role, separators, etc.)
191        // OpenAI uses ~4 tokens overhead per message
192        const MESSAGE_OVERHEAD: u64 = 4;
193
194        for message in messages {
195            // Add message overhead
196            total_tokens += MESSAGE_OVERHEAD;
197
198            // Count role tokens (typically 1 token)
199            if let Some(role) = message.get("role").and_then(|r| r.as_str()) {
200                total_tokens += self.count_tokens(model_name, role);
201            }
202
203            // Count content tokens
204            if let Some(content) = message.get("content") {
205                match content {
206                    Value::String(text) => {
207                        total_tokens += self.count_tokens(model_name, text);
208                    }
209                    Value::Array(parts) => {
210                        // Multi-modal content (text + images)
211                        for part in parts {
212                            if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
213                                total_tokens += self.count_tokens(model_name, text);
214                            }
215                            // Image tokens are estimated separately (not text)
216                            if part.get("image_url").is_some() {
217                                // Rough estimate: 85 tokens for low detail, 765 for high detail
218                                total_tokens += 170; // Medium estimate
219                            }
220                        }
221                    }
222                    _ => {}
223                }
224            }
225
226            // Count name tokens if present
227            if let Some(name) = message.get("name").and_then(|n| n.as_str()) {
228                total_tokens += self.count_tokens(model_name, name);
229            }
230
231            // Count tool calls if present
232            if let Some(tool_calls) = message.get("tool_calls").and_then(|t| t.as_array()) {
233                for tool_call in tool_calls {
234                    if let Some(function) = tool_call.get("function") {
235                        if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
236                            total_tokens += self.count_tokens(model_name, name);
237                        }
238                        if let Some(args) = function.get("arguments").and_then(|a| a.as_str()) {
239                            total_tokens += self.count_tokens(model_name, args);
240                        }
241                    }
242                }
243            }
244        }
245
246        // Add conversation overhead (typically 3 tokens)
247        total_tokens += 3;
248
249        // Account for max_tokens in response (estimate output)
250        if let Some(max_tokens) = json.get("max_tokens").and_then(|m| m.as_u64()) {
251            // Add estimated output tokens (assume ~50% utilization)
252            total_tokens += max_tokens / 2;
253        }
254
255        trace!(
256            message_count = messages.len(),
257            total_tokens = total_tokens,
258            model = ?model_name,
259            "Counted tokens in chat request"
260        );
261
262        total_tokens
263    }
264
265    /// Count tokens for non-chat requests (completions, embeddings)
266    fn count_non_chat_request(&self, json: &Value, model: Option<&str>) -> u64 {
267        let mut total_tokens: u64 = 0;
268
269        // Legacy completions API: { "prompt": "..." }
270        if let Some(prompt) = json.get("prompt") {
271            match prompt {
272                Value::String(text) => {
273                    total_tokens += self.count_tokens(model, text);
274                }
275                Value::Array(prompts) => {
276                    for p in prompts {
277                        if let Some(text) = p.as_str() {
278                            total_tokens += self.count_tokens(model, text);
279                        }
280                    }
281                }
282                _ => {}
283            }
284        }
285
286        // Embeddings API: { "input": "..." }
287        if let Some(input) = json.get("input") {
288            match input {
289                Value::String(text) => {
290                    total_tokens += self.count_tokens(model, text);
291                }
292                Value::Array(inputs) => {
293                    for i in inputs {
294                        if let Some(text) = i.as_str() {
295                            total_tokens += self.count_tokens(model, text);
296                        }
297                    }
298                }
299                _ => {}
300            }
301        }
302
303        // If still zero, count the whole body
304        if total_tokens == 0 {
305            let body_text = json.to_string();
306            total_tokens = self.count_tokens(model, &body_text);
307        }
308
309        total_tokens
310    }
311
312    /// Get or create a BPE instance for the given encoding
313    #[cfg(feature = "tiktoken")]
314    fn get_or_create_bpe(&self, encoding: TiktokenEncoding) -> Option<Arc<CoreBPE>> {
315        // Try read lock first
316        {
317            let cache = self.encodings.read();
318            if let Some(bpe) = cache.get(&encoding) {
319                return Some(Arc::clone(bpe));
320            }
321        }
322
323        // Need to create - acquire write lock
324        let mut cache = self.encodings.write();
325
326        // Double-check after acquiring write lock
327        if let Some(bpe) = cache.get(&encoding) {
328            return Some(Arc::clone(bpe));
329        }
330
331        // Create the encoding
332        let bpe = match encoding {
333            TiktokenEncoding::O200kBase => {
334                debug!(encoding = "o200k_base", "Initializing tiktoken encoding");
335                o200k_base().ok()
336            }
337            TiktokenEncoding::Cl100kBase => {
338                debug!(encoding = "cl100k_base", "Initializing tiktoken encoding");
339                cl100k_base().ok()
340            }
341            TiktokenEncoding::P50kBase => {
342                debug!(encoding = "p50k_base", "Initializing tiktoken encoding");
343                p50k_base().ok()
344            }
345        };
346
347        match bpe {
348            Some(bpe) => {
349                let arc_bpe = Arc::new(bpe);
350                cache.insert(encoding, Arc::clone(&arc_bpe));
351                Some(arc_bpe)
352            }
353            None => {
354                warn!(
355                    encoding = encoding.name(),
356                    "Failed to initialize tiktoken encoding"
357                );
358                None
359            }
360        }
361    }
362
363    /// Check if tiktoken is available (feature enabled)
364    pub fn is_available(&self) -> bool {
365        cfg!(feature = "tiktoken")
366    }
367}
368
369impl Default for TiktokenManager {
370    fn default() -> Self {
371        Self::new()
372    }
373}
374
// ============================================================================
// Tests
// ============================================================================

379#[cfg(test)]
380mod tests {
381    use super::*;
382
383    #[test]
384    fn test_encoding_for_model() {
385        let manager = TiktokenManager::new();
386
387        // GPT-4o uses o200k_base
388        assert_eq!(
389            manager.encoding_for_model("gpt-4o"),
390            TiktokenEncoding::O200kBase
391        );
392        assert_eq!(
393            manager.encoding_for_model("gpt-4o-mini"),
394            TiktokenEncoding::O200kBase
395        );
396
397        // GPT-4 uses cl100k_base
398        assert_eq!(
399            manager.encoding_for_model("gpt-4"),
400            TiktokenEncoding::Cl100kBase
401        );
402        assert_eq!(
403            manager.encoding_for_model("gpt-4-turbo"),
404            TiktokenEncoding::Cl100kBase
405        );
406        assert_eq!(
407            manager.encoding_for_model("gpt-3.5-turbo"),
408            TiktokenEncoding::Cl100kBase
409        );
410
411        // Claude uses cl100k approximation
412        assert_eq!(
413            manager.encoding_for_model("claude-3-opus"),
414            TiktokenEncoding::Cl100kBase
415        );
416
417        // Codex uses p50k_base
418        assert_eq!(
419            manager.encoding_for_model("code-davinci-002"),
420            TiktokenEncoding::P50kBase
421        );
422
423        // Unknown defaults to cl100k_base
424        assert_eq!(
425            manager.encoding_for_model("unknown-model"),
426            TiktokenEncoding::Cl100kBase
427        );
428    }
429
430    #[test]
431    fn test_count_tokens_basic() {
432        let manager = TiktokenManager::new();
433
434        // Basic text counting
435        let tokens = manager.count_tokens(Some("gpt-4"), "Hello, world!");
436        assert!(tokens > 0);
437
438        // Without model (uses default)
439        let tokens = manager.count_tokens(None, "Hello, world!");
440        assert!(tokens > 0);
441    }
442
443    #[test]
444    fn test_count_chat_request() {
445        let manager = TiktokenManager::new();
446
447        let body = br#"{
448            "model": "gpt-4",
449            "messages": [
450                {"role": "system", "content": "You are a helpful assistant."},
451                {"role": "user", "content": "Hello!"}
452            ]
453        }"#;
454
455        let tokens = manager.count_chat_request(body, None);
456        assert!(tokens > 0);
457        // Should be roughly: system message (~10) + user message (~5) + overhead (~10)
458        assert!(tokens >= 10);
459    }
460
461    #[test]
462    fn test_count_chat_request_with_tools() {
463        let manager = TiktokenManager::new();
464
465        let body = br#"{
466            "model": "gpt-4",
467            "messages": [
468                {"role": "user", "content": "What's the weather?"},
469                {"role": "assistant", "tool_calls": [
470                    {"function": {"name": "get_weather", "arguments": "{\"city\": \"NYC\"}"}}
471                ]}
472            ]
473        }"#;
474
475        let tokens = manager.count_chat_request(body, None);
476        assert!(tokens > 0);
477    }
478
479    #[test]
480    fn test_count_embeddings_request() {
481        let manager = TiktokenManager::new();
482
483        let body = br#"{
484            "model": "text-embedding-ada-002",
485            "input": "Hello, world!"
486        }"#;
487
488        let tokens = manager.count_chat_request(body, None);
489        assert!(tokens > 0);
490    }
491
492    #[test]
493    fn test_count_invalid_json() {
494        let manager = TiktokenManager::new();
495
496        let body = b"not valid json at all";
497        let tokens = manager.count_chat_request(body, Some("gpt-4"));
498        // Should fall back to text counting
499        assert!(tokens > 0);
500    }
501
502    #[test]
503    #[cfg(feature = "tiktoken")]
504    fn test_tiktoken_accurate_hello_world() {
505        let manager = TiktokenManager::new();
506
507        // "Hello world" is typically 2 tokens with cl100k_base
508        let tokens =
509            manager.count_tokens_with_encoding(TiktokenEncoding::Cl100kBase, "Hello world");
510        assert_eq!(tokens, 2);
511    }
512
513    #[test]
514    #[cfg(feature = "tiktoken")]
515    fn test_tiktoken_caching() {
516        let manager = TiktokenManager::new();
517
518        // First call creates the encoding
519        let tokens1 = manager.count_tokens(Some("gpt-4"), "Test message");
520        // Second call should use cached encoding
521        let tokens2 = manager.count_tokens(Some("gpt-4"), "Test message");
522
523        assert_eq!(tokens1, tokens2);
524    }
525}