sentinel_proxy/inference/tiktoken.rs

//! Tiktoken integration for accurate token counting.
//!
//! Provides a cached tokenizer manager that:
//! - Caches BPE instances for different encodings (avoiding recreation overhead)
//! - Maps model names to the correct tokenizer encoding
//! - Extracts and tokenizes just the message content from chat completions
//!
//! # Encodings
//!
//! | Encoding | Models |
//! |----------|--------|
//! | `o200k_base` | GPT-4o, GPT-4o-mini |
//! | `cl100k_base` | GPT-4, GPT-4-turbo, GPT-3.5-turbo, text-embedding-* |
//! | `p50k_base` | Codex, text-davinci-003 |
//!
//! # Usage
//!
//! ```ignore
//! let manager = TiktokenManager::new();
//! let tokens = manager.count_tokens(Some("gpt-4"), "Hello, world!");
//! let request_tokens = manager.count_chat_request(body, Some("gpt-4o"));
//! ```

use once_cell::sync::Lazy;
use parking_lot::RwLock;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, trace, warn};

#[cfg(feature = "tiktoken")]
use tiktoken_rs::{cl100k_base, o200k_base, p50k_base, CoreBPE};

/// Tiktoken encoding types
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TiktokenEncoding {
    /// GPT-4o, GPT-4o-mini (newest)
    O200kBase,
    /// GPT-4, GPT-4-turbo, GPT-3.5-turbo
    Cl100kBase,
    /// Codex, text-davinci-003
    P50kBase,
}

impl TiktokenEncoding {
    /// Get the encoding name for logging
    pub fn name(&self) -> &'static str {
        match self {
            Self::O200kBase => "o200k_base",
            Self::Cl100kBase => "cl100k_base",
            Self::P50kBase => "p50k_base",
        }
    }
}

/// Global tiktoken manager instance
static TIKTOKEN_MANAGER: Lazy<TiktokenManager> = Lazy::new(TiktokenManager::new);

/// Get the global tiktoken manager
pub fn tiktoken_manager() -> &'static TiktokenManager {
    &TIKTOKEN_MANAGER
}

/// Manages cached tiktoken BPE instances for different encodings.
///
/// Thread-safe and lazily initialized - encodings are only loaded when first used.
pub struct TiktokenManager {
    #[cfg(feature = "tiktoken")]
    encodings: RwLock<HashMap<TiktokenEncoding, Arc<CoreBPE>>>,
    #[cfg(not(feature = "tiktoken"))]
    _marker: std::marker::PhantomData<()>,
}

impl TiktokenManager {
    /// Create a new tiktoken manager
    pub fn new() -> Self {
        #[cfg(feature = "tiktoken")]
        {
            Self {
                encodings: RwLock::new(HashMap::new()),
            }
        }
        #[cfg(not(feature = "tiktoken"))]
        {
            Self {
                _marker: std::marker::PhantomData,
            }
        }
    }

    /// Get the appropriate encoding for a model name
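    ///
    /// For example (illustrative values, per the mapping implemented below):
    ///
    /// ```ignore
    /// let manager = TiktokenManager::new();
    /// assert_eq!(manager.encoding_for_model("gpt-4o-mini"), TiktokenEncoding::O200kBase);
    /// assert_eq!(manager.encoding_for_model("unknown-model"), TiktokenEncoding::Cl100kBase);
    /// ```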
    pub fn encoding_for_model(&self, model: &str) -> TiktokenEncoding {
        let model_lower = model.to_lowercase();

        // GPT-4o family uses o200k_base
        if model_lower.contains("gpt-4o") || model_lower.contains("gpt4o") {
            return TiktokenEncoding::O200kBase;
        }

        // GPT-4 and GPT-3.5-turbo use cl100k_base
        if model_lower.contains("gpt-4")
            || model_lower.contains("gpt-3.5")
            || model_lower.contains("gpt-35")
            || model_lower.contains("text-embedding")
            || model_lower.contains("claude") // Claude approximation
        {
            return TiktokenEncoding::Cl100kBase;
        }

        // Codex and older models use p50k_base
        if model_lower.contains("code-")
            || model_lower.contains("codex")
            || model_lower.contains("text-davinci-003")
            || model_lower.contains("text-davinci-002")
        {
            return TiktokenEncoding::P50kBase;
        }

        // Default to cl100k_base (most common)
        TiktokenEncoding::Cl100kBase
    }

    /// Count tokens in text using the appropriate encoding for the model
    #[cfg(feature = "tiktoken")]
    pub fn count_tokens(&self, model: Option<&str>, text: &str) -> u64 {
        let encoding = model
            .map(|m| self.encoding_for_model(m))
            .unwrap_or(TiktokenEncoding::Cl100kBase);

        self.count_tokens_with_encoding(encoding, text)
    }

    #[cfg(not(feature = "tiktoken"))]
    pub fn count_tokens(&self, _model: Option<&str>, text: &str) -> u64 {
        // Fallback to character-based estimation
        (text.chars().count() / 4).max(1) as u64
    }

    /// Count tokens using a specific encoding
    #[cfg(feature = "tiktoken")]
    pub fn count_tokens_with_encoding(&self, encoding: TiktokenEncoding, text: &str) -> u64 {
        match self.get_or_create_bpe(encoding) {
            Some(bpe) => {
                let tokens = bpe.encode_with_special_tokens(text);
                tokens.len() as u64
            }
            None => {
                // Fallback to character estimation
                (text.chars().count() / 4).max(1) as u64
            }
        }
    }

    #[cfg(not(feature = "tiktoken"))]
    pub fn count_tokens_with_encoding(&self, _encoding: TiktokenEncoding, text: &str) -> u64 {
        (text.chars().count() / 4).max(1) as u64
    }

    /// Count tokens in a chat completion request body
    ///
    /// Parses the JSON and extracts message content for accurate token counting.
    /// Returns estimated tokens including overhead for message formatting.
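    ///
    /// A minimal sketch of a call (request shape follows the OpenAI chat completions format):
    ///
    /// ```ignore
    /// let body = br#"{"model":"gpt-4o","messages":[{"role":"user","content":"Hi"}]}"#;
    /// let tokens = tiktoken_manager().count_chat_request(body, None);
    /// ```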
    pub fn count_chat_request(&self, body: &[u8], model: Option<&str>) -> u64 {
        // Try to parse as JSON
        let json: Value = match serde_json::from_slice(body) {
            Ok(v) => v,
            Err(_) => {
                // If not valid JSON, count the whole body as text
                let text = String::from_utf8_lossy(body);
                return self.count_tokens(model, &text);
            }
        };

        // Extract model from body if not provided
        let model_name = model.or_else(|| json.get("model").and_then(|m| m.as_str()));

        // Extract messages array
        let messages = match json.get("messages").and_then(|m| m.as_array()) {
            Some(msgs) => msgs,
            None => {
                // Not a chat completion request, try other formats
                return self.count_non_chat_request(&json, model_name);
            }
        };

        // Count tokens in messages
        let mut total_tokens: u64 = 0;

        // Per-message overhead (role, separators, etc.)
        // OpenAI uses ~4 tokens overhead per message
        const MESSAGE_OVERHEAD: u64 = 4;

        for message in messages {
            // Add message overhead
            total_tokens += MESSAGE_OVERHEAD;

            // Count role tokens (typically 1 token)
            if let Some(role) = message.get("role").and_then(|r| r.as_str()) {
                total_tokens += self.count_tokens(model_name, role);
            }

            // Count content tokens
            if let Some(content) = message.get("content") {
                match content {
                    Value::String(text) => {
                        total_tokens += self.count_tokens(model_name, text);
                    }
                    Value::Array(parts) => {
                        // Multi-modal content (text + images)
                        for part in parts {
                            if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
                                total_tokens += self.count_tokens(model_name, text);
                            }
                            // Image tokens are estimated separately (not text)
                            if part.get("image_url").is_some() {
                                // Rough estimate: 85 tokens for low detail, 765 for high detail
                                total_tokens += 170; // Medium estimate
                            }
                        }
                    }
                    _ => {}
                }
            }

            // Count name tokens if present
            if let Some(name) = message.get("name").and_then(|n| n.as_str()) {
                total_tokens += self.count_tokens(model_name, name);
            }

            // Count tool calls if present
            if let Some(tool_calls) = message.get("tool_calls").and_then(|t| t.as_array()) {
                for tool_call in tool_calls {
                    if let Some(function) = tool_call.get("function") {
                        if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
                            total_tokens += self.count_tokens(model_name, name);
                        }
                        if let Some(args) = function.get("arguments").and_then(|a| a.as_str()) {
                            total_tokens += self.count_tokens(model_name, args);
                        }
                    }
                }
            }
        }

        // Add conversation overhead (typically 3 tokens)
        total_tokens += 3;

        // Account for max_tokens in response (estimate output)
        if let Some(max_tokens) = json.get("max_tokens").and_then(|m| m.as_u64()) {
            // Add estimated output tokens (assume ~50% utilization)
            total_tokens += max_tokens / 2;
        }

        trace!(
            message_count = messages.len(),
            total_tokens = total_tokens,
            model = ?model_name,
            "Counted tokens in chat request"
        );

        total_tokens
    }

    /// Count tokens for non-chat requests (completions, embeddings)
    fn count_non_chat_request(&self, json: &Value, model: Option<&str>) -> u64 {
        let mut total_tokens: u64 = 0;

        // Legacy completions API: { "prompt": "..." }
        if let Some(prompt) = json.get("prompt") {
            match prompt {
                Value::String(text) => {
                    total_tokens += self.count_tokens(model, text);
                }
                Value::Array(prompts) => {
                    for p in prompts {
                        if let Some(text) = p.as_str() {
                            total_tokens += self.count_tokens(model, text);
                        }
                    }
                }
                _ => {}
            }
        }

        // Embeddings API: { "input": "..." }
        if let Some(input) = json.get("input") {
            match input {
                Value::String(text) => {
                    total_tokens += self.count_tokens(model, text);
                }
                Value::Array(inputs) => {
                    for i in inputs {
                        if let Some(text) = i.as_str() {
                            total_tokens += self.count_tokens(model, text);
                        }
                    }
                }
                _ => {}
            }
        }

        // If still zero, count the whole body
        if total_tokens == 0 {
            let body_text = json.to_string();
            total_tokens = self.count_tokens(model, &body_text);
        }

        total_tokens
    }

    /// Get or create a BPE instance for the given encoding
    #[cfg(feature = "tiktoken")]
    fn get_or_create_bpe(&self, encoding: TiktokenEncoding) -> Option<Arc<CoreBPE>> {
        // Try read lock first
        {
            let cache = self.encodings.read();
            if let Some(bpe) = cache.get(&encoding) {
                return Some(Arc::clone(bpe));
            }
        }

        // Need to create - acquire write lock
        let mut cache = self.encodings.write();

        // Double-check after acquiring write lock
        if let Some(bpe) = cache.get(&encoding) {
            return Some(Arc::clone(bpe));
        }

        // Create the encoding
        let bpe = match encoding {
            TiktokenEncoding::O200kBase => {
                debug!(encoding = "o200k_base", "Initializing tiktoken encoding");
                o200k_base().ok()
            }
            TiktokenEncoding::Cl100kBase => {
                debug!(encoding = "cl100k_base", "Initializing tiktoken encoding");
                cl100k_base().ok()
            }
            TiktokenEncoding::P50kBase => {
                debug!(encoding = "p50k_base", "Initializing tiktoken encoding");
                p50k_base().ok()
            }
        };

        match bpe {
            Some(bpe) => {
                let arc_bpe = Arc::new(bpe);
                cache.insert(encoding, Arc::clone(&arc_bpe));
                Some(arc_bpe)
            }
            None => {
                warn!(
                    encoding = encoding.name(),
                    "Failed to initialize tiktoken encoding"
                );
                None
            }
        }
    }

    /// Check if tiktoken is available (feature enabled)
    pub fn is_available(&self) -> bool {
        cfg!(feature = "tiktoken")
    }
}

impl Default for TiktokenManager {
    fn default() -> Self {
        Self::new()
    }
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encoding_for_model() {
        let manager = TiktokenManager::new();

        // GPT-4o uses o200k_base
        assert_eq!(
            manager.encoding_for_model("gpt-4o"),
            TiktokenEncoding::O200kBase
        );
        assert_eq!(
            manager.encoding_for_model("gpt-4o-mini"),
            TiktokenEncoding::O200kBase
        );

        // GPT-4 uses cl100k_base
        assert_eq!(
            manager.encoding_for_model("gpt-4"),
            TiktokenEncoding::Cl100kBase
        );
        assert_eq!(
            manager.encoding_for_model("gpt-4-turbo"),
            TiktokenEncoding::Cl100kBase
        );
        assert_eq!(
            manager.encoding_for_model("gpt-3.5-turbo"),
            TiktokenEncoding::Cl100kBase
        );

        // Claude uses cl100k approximation
        assert_eq!(
            manager.encoding_for_model("claude-3-opus"),
            TiktokenEncoding::Cl100kBase
        );

        // Codex uses p50k_base
        assert_eq!(
            manager.encoding_for_model("code-davinci-002"),
            TiktokenEncoding::P50kBase
        );

        // Unknown defaults to cl100k_base
        assert_eq!(
            manager.encoding_for_model("unknown-model"),
            TiktokenEncoding::Cl100kBase
        );
    }

    #[test]
    fn test_count_tokens_basic() {
        let manager = TiktokenManager::new();

        // Basic text counting
        let tokens = manager.count_tokens(Some("gpt-4"), "Hello, world!");
        assert!(tokens > 0);

        // Without model (uses default)
        let tokens = manager.count_tokens(None, "Hello, world!");
        assert!(tokens > 0);
    }

    #[test]
    fn test_count_chat_request() {
        let manager = TiktokenManager::new();

        let body = br#"{
            "model": "gpt-4",
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello!"}
            ]
        }"#;

        let tokens = manager.count_chat_request(body, None);
        assert!(tokens > 0);
        // Should be roughly: system message (~10) + user message (~5) + overhead (~10)
        assert!(tokens >= 10);
    }

    #[test]
    fn test_count_chat_request_with_tools() {
        let manager = TiktokenManager::new();

        let body = br#"{
            "model": "gpt-4",
            "messages": [
                {"role": "user", "content": "What's the weather?"},
                {"role": "assistant", "tool_calls": [
                    {"function": {"name": "get_weather", "arguments": "{\"city\": \"NYC\"}"}}
                ]}
            ]
        }"#;

        let tokens = manager.count_chat_request(body, None);
        assert!(tokens > 0);
    }
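
    // Sketch of a multi-modal request check: count_chat_request adds a fixed
    // 170-token estimate for each image_url part, so the total should exceed that.
    // The URL and model here are illustrative values.
    #[test]
    fn test_count_multimodal_request() {
        let manager = TiktokenManager::new();

        let body = br#"{
            "model": "gpt-4o",
            "messages": [
                {"role": "user", "content": [
                    {"type": "text", "text": "Describe this image."},
                    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}
                ]}
            ]
        }"#;

        let tokens = manager.count_chat_request(body, None);
        assert!(tokens > 170);
    }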

    #[test]
    fn test_count_embeddings_request() {
        let manager = TiktokenManager::new();

        let body = br#"{
            "model": "text-embedding-ada-002",
            "input": "Hello, world!"
        }"#;

        let tokens = manager.count_chat_request(body, None);
        assert!(tokens > 0);
    }

    #[test]
    fn test_count_invalid_json() {
        let manager = TiktokenManager::new();

        let body = b"not valid json at all";
        let tokens = manager.count_chat_request(body, Some("gpt-4"));
        // Should fall back to text counting
        assert!(tokens > 0);
    }
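
    // Sketch of a check on the max_tokens handling: count_chat_request adds an
    // estimated max_tokens / 2 output tokens when the field is present, so two
    // otherwise identical requests should differ by exactly that amount.
    #[test]
    fn test_count_chat_request_includes_max_tokens_estimate() {
        let manager = TiktokenManager::new();

        let with_max = br#"{"model":"gpt-4","max_tokens":200,"messages":[{"role":"user","content":"Hi"}]}"#;
        let without_max = br#"{"model":"gpt-4","messages":[{"role":"user","content":"Hi"}]}"#;

        let tokens_with = manager.count_chat_request(with_max, None);
        let tokens_without = manager.count_chat_request(without_max, None);
        assert_eq!(tokens_with, tokens_without + 100);
    }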

    #[test]
    #[cfg(feature = "tiktoken")]
    fn test_tiktoken_accurate_hello_world() {
        let manager = TiktokenManager::new();

        // "Hello world" is typically 2 tokens with cl100k_base
        let tokens = manager.count_tokens_with_encoding(TiktokenEncoding::Cl100kBase, "Hello world");
        assert_eq!(tokens, 2);
    }

    #[test]
    #[cfg(feature = "tiktoken")]
    fn test_tiktoken_caching() {
        let manager = TiktokenManager::new();

        // First call creates the encoding
        let tokens1 = manager.count_tokens(Some("gpt-4"), "Test message");
        // Second call should use cached encoding
        let tokens2 = manager.count_tokens(Some("gpt-4"), "Test message");

        assert_eq!(tokens1, tokens2);
    }
}