// sentinel_proxy/inference/tiktoken.rs

//! Token counting for inference requests, backed by `tiktoken-rs` when
//! the `tiktoken` feature is enabled and by a character-length heuristic
//! otherwise.

use once_cell::sync::Lazy;
use parking_lot::RwLock;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, trace, warn};

#[cfg(feature = "tiktoken")]
use tiktoken_rs::{cl100k_base, o200k_base, p50k_base, CoreBPE};

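/// The BPE vocabularies this module can count with, mirroring the
/// encodings shipped by `tiktoken-rs`.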
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TiktokenEncoding {
    /// Used by the gpt-4o family.
    O200kBase,
    /// Used by gpt-4, gpt-3.5-turbo, and text-embedding models.
    Cl100kBase,
    /// Used by older code/davinci-era models.
    P50kBase,
}

impl TiktokenEncoding {
    /// The canonical tiktoken name of this encoding.
    pub fn name(&self) -> &'static str {
        match self {
            Self::O200kBase => "o200k_base",
            Self::Cl100kBase => "cl100k_base",
            Self::P50kBase => "p50k_base",
        }
    }
}

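/// Process-wide manager, initialized lazily on first access.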
static TIKTOKEN_MANAGER: Lazy<TiktokenManager> = Lazy::new(TiktokenManager::new);

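/// Returns the shared [`TiktokenManager`].
///
/// A minimal usage sketch (the model name is illustrative; any string
/// works, and unknown names fall back to `cl100k_base`):
///
/// ```ignore
/// let tokens = tiktoken_manager().count_tokens(Some("gpt-4o"), "Hello, world!");
/// assert!(tokens >= 1);
/// ```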
pub fn tiktoken_manager() -> &'static TiktokenManager {
    &TIKTOKEN_MANAGER
}

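/// Caches initialized BPE encoders so each vocabulary is loaded at most
/// once. With the `tiktoken` feature disabled, the struct is a zero-sized
/// shell and every count comes from the fallback heuristic.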
pub struct TiktokenManager {
    #[cfg(feature = "tiktoken")]
    encodings: RwLock<HashMap<TiktokenEncoding, Arc<CoreBPE>>>,
    #[cfg(not(feature = "tiktoken"))]
    _marker: std::marker::PhantomData<()>,
}

impl TiktokenManager {
    pub fn new() -> Self {
        #[cfg(feature = "tiktoken")]
        {
            Self {
                encodings: RwLock::new(HashMap::new()),
            }
        }
        #[cfg(not(feature = "tiktoken"))]
        {
            Self {
                _marker: std::marker::PhantomData,
            }
        }
    }

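    /// Picks the most likely encoding for a model name via substring
    /// matching; unknown models fall back to `cl100k_base`. Claude uses
    /// Anthropic's own tokenizer, so `cl100k_base` is only a rough
    /// stand-in there.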
    pub fn encoding_for_model(&self, model: &str) -> TiktokenEncoding {
        let model_lower = model.to_lowercase();

        // gpt-4o family uses the newer o200k vocabulary.
        if model_lower.contains("gpt-4o") || model_lower.contains("gpt4o") {
            return TiktokenEncoding::O200kBase;
        }

        // gpt-4 / gpt-3.5 / embeddings (and Claude, approximately) use cl100k.
        if model_lower.contains("gpt-4")
            || model_lower.contains("gpt-3.5")
            || model_lower.contains("gpt-35")
            || model_lower.contains("text-embedding")
            || model_lower.contains("claude")
        {
            return TiktokenEncoding::Cl100kBase;
        }

        // Older codex / davinci completions models use p50k.
        if model_lower.contains("code-")
            || model_lower.contains("codex")
            || model_lower.contains("text-davinci-003")
            || model_lower.contains("text-davinci-002")
        {
            return TiktokenEncoding::P50kBase;
        }

        TiktokenEncoding::Cl100kBase
    }

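    /// Counts tokens in `text`, deriving the encoding from `model` when
    /// one is given.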
    #[cfg(feature = "tiktoken")]
    pub fn count_tokens(&self, model: Option<&str>, text: &str) -> u64 {
        let encoding = model
            .map(|m| self.encoding_for_model(m))
            .unwrap_or(TiktokenEncoding::Cl100kBase);

        self.count_tokens_with_encoding(encoding, text)
    }

    /// Without the `tiktoken` feature: a rough ~4-characters-per-token
    /// estimate, never less than 1.
    #[cfg(not(feature = "tiktoken"))]
    pub fn count_tokens(&self, _model: Option<&str>, text: &str) -> u64 {
        (text.chars().count() / 4).max(1) as u64
    }

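    /// Counts tokens with an explicit encoding, falling back to the
    /// character heuristic if the encoder fails to initialize.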
    #[cfg(feature = "tiktoken")]
    pub fn count_tokens_with_encoding(&self, encoding: TiktokenEncoding, text: &str) -> u64 {
        match self.get_or_create_bpe(encoding) {
            Some(bpe) => {
                let tokens = bpe.encode_with_special_tokens(text);
                tokens.len() as u64
            }
            None => {
                // Encoder unavailable: fall back to ~4 chars per token.
                (text.chars().count() / 4).max(1) as u64
            }
        }
    }

    #[cfg(not(feature = "tiktoken"))]
    pub fn count_tokens_with_encoding(&self, _encoding: TiktokenEncoding, text: &str) -> u64 {
        (text.chars().count() / 4).max(1) as u64
    }

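    /// Estimates prompt tokens for an OpenAI-style chat completion body.
    /// Bodies that are not valid JSON are counted as raw text; bodies
    /// without a `messages` array are routed to the non-chat estimator.
    ///
    /// A sketch of the expected input shape:
    ///
    /// ```ignore
    /// let body = br#"{"model":"gpt-4","messages":[{"role":"user","content":"Hi"}]}"#;
    /// let estimate = tiktoken_manager().count_chat_request(body, None);
    /// ```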
    pub fn count_chat_request(&self, body: &[u8], model: Option<&str>) -> u64 {
        let json: Value = match serde_json::from_slice(body) {
            Ok(v) => v,
            Err(_) => {
                // Not JSON: count the raw body as plain text.
                let text = String::from_utf8_lossy(body);
                return self.count_tokens(model, &text);
            }
        };

        let model_name = model.or_else(|| json.get("model").and_then(|m| m.as_str()));

        let messages = match json.get("messages").and_then(|m| m.as_array()) {
            Some(msgs) => msgs,
            None => {
                return self.count_non_chat_request(&json, model_name);
            }
        };

        let mut total_tokens: u64 = 0;

        // Approximate fixed framing overhead per message (role markers
        // and separators).
        const MESSAGE_OVERHEAD: u64 = 4;

        for message in messages {
            total_tokens += MESSAGE_OVERHEAD;

            if let Some(role) = message.get("role").and_then(|r| r.as_str()) {
                total_tokens += self.count_tokens(model_name, role);
            }

            if let Some(content) = message.get("content") {
                match content {
                    Value::String(text) => {
                        total_tokens += self.count_tokens(model_name, text);
                    }
                    Value::Array(parts) => {
                        // Multimodal content: count text parts; images get
                        // a flat per-image estimate.
                        for part in parts {
                            if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
                                total_tokens += self.count_tokens(model_name, text);
                            }
                            if part.get("image_url").is_some() {
                                total_tokens += 170;
                            }
                        }
                    }
                    _ => {}
                }
            }

            if let Some(name) = message.get("name").and_then(|n| n.as_str()) {
                total_tokens += self.count_tokens(model_name, name);
            }

            if let Some(tool_calls) = message.get("tool_calls").and_then(|t| t.as_array()) {
                for tool_call in tool_calls {
                    if let Some(function) = tool_call.get("function") {
                        if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
                            total_tokens += self.count_tokens(model_name, name);
                        }
                        if let Some(args) = function.get("arguments").and_then(|a| a.as_str()) {
                            total_tokens += self.count_tokens(model_name, args);
                        }
                    }
                }
            }
        }

        // Every reply is primed with <|start|>assistant<|message|>.
        total_tokens += 3;

        // Count half of the requested completion budget toward the
        // estimate, since the response is not yet known.
        if let Some(max_tokens) = json.get("max_tokens").and_then(|m| m.as_u64()) {
            total_tokens += max_tokens / 2;
        }

        trace!(
            message_count = messages.len(),
            total_tokens = total_tokens,
            model = ?model_name,
            "Counted tokens in chat request"
        );

        total_tokens
    }

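    /// Token estimate for non-chat bodies: completions (`prompt`) and
    /// embeddings (`input`). When neither field is present, the whole
    /// JSON body is counted as text as a last resort.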
    fn count_non_chat_request(&self, json: &Value, model: Option<&str>) -> u64 {
        let mut total_tokens: u64 = 0;

        if let Some(prompt) = json.get("prompt") {
            match prompt {
                Value::String(text) => {
                    total_tokens += self.count_tokens(model, text);
                }
                Value::Array(prompts) => {
                    for p in prompts {
                        if let Some(text) = p.as_str() {
                            total_tokens += self.count_tokens(model, text);
                        }
                    }
                }
                _ => {}
            }
        }

        if let Some(input) = json.get("input") {
            match input {
                Value::String(text) => {
                    total_tokens += self.count_tokens(model, text);
                }
                Value::Array(inputs) => {
                    for i in inputs {
                        if let Some(text) = i.as_str() {
                            total_tokens += self.count_tokens(model, text);
                        }
                    }
                }
                _ => {}
            }
        }

        if total_tokens == 0 {
            let body_text = json.to_string();
            total_tokens = self.count_tokens(model, &body_text);
        }

        total_tokens
    }

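    /// Returns the cached encoder for `encoding`, initializing it on
    /// first use. Double-checked locking: probe under a read lock, then
    /// re-check under the write lock so a racing thread doesn't build
    /// the vocabulary twice.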
    #[cfg(feature = "tiktoken")]
    fn get_or_create_bpe(&self, encoding: TiktokenEncoding) -> Option<Arc<CoreBPE>> {
        // Fast path: shared read lock.
        {
            let cache = self.encodings.read();
            if let Some(bpe) = cache.get(&encoding) {
                return Some(Arc::clone(bpe));
            }
        }

        let mut cache = self.encodings.write();

        // Re-check: another thread may have initialized it meanwhile.
        if let Some(bpe) = cache.get(&encoding) {
            return Some(Arc::clone(bpe));
        }

        let bpe = match encoding {
            TiktokenEncoding::O200kBase => {
                debug!(encoding = "o200k_base", "Initializing tiktoken encoding");
                o200k_base().ok()
            }
            TiktokenEncoding::Cl100kBase => {
                debug!(encoding = "cl100k_base", "Initializing tiktoken encoding");
                cl100k_base().ok()
            }
            TiktokenEncoding::P50kBase => {
                debug!(encoding = "p50k_base", "Initializing tiktoken encoding");
                p50k_base().ok()
            }
        };

        match bpe {
            Some(bpe) => {
                let arc_bpe = Arc::new(bpe);
                cache.insert(encoding, Arc::clone(&arc_bpe));
                Some(arc_bpe)
            }
            None => {
                warn!(
                    encoding = encoding.name(),
                    "Failed to initialize tiktoken encoding"
                );
                None
            }
        }
    }

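    /// True when the crate was built with the `tiktoken` feature;
    /// otherwise every count is a character-length estimate.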
    pub fn is_available(&self) -> bool {
        cfg!(feature = "tiktoken")
    }
}

impl Default for TiktokenManager {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encoding_for_model() {
        let manager = TiktokenManager::new();

        // gpt-4o family → o200k_base.
        assert_eq!(
            manager.encoding_for_model("gpt-4o"),
            TiktokenEncoding::O200kBase
        );
        assert_eq!(
            manager.encoding_for_model("gpt-4o-mini"),
            TiktokenEncoding::O200kBase
        );

        // gpt-4 / gpt-3.5 → cl100k_base.
        assert_eq!(
            manager.encoding_for_model("gpt-4"),
            TiktokenEncoding::Cl100kBase
        );
        assert_eq!(
            manager.encoding_for_model("gpt-4-turbo"),
            TiktokenEncoding::Cl100kBase
        );
        assert_eq!(
            manager.encoding_for_model("gpt-3.5-turbo"),
            TiktokenEncoding::Cl100kBase
        );

        // Claude is approximated with cl100k_base.
        assert_eq!(
            manager.encoding_for_model("claude-3-opus"),
            TiktokenEncoding::Cl100kBase
        );

        // Codex-era models → p50k_base.
        assert_eq!(
            manager.encoding_for_model("code-davinci-002"),
            TiktokenEncoding::P50kBase
        );

        // Unknown models default to cl100k_base.
        assert_eq!(
            manager.encoding_for_model("unknown-model"),
            TiktokenEncoding::Cl100kBase
        );
    }

    #[test]
    fn test_count_tokens_basic() {
        let manager = TiktokenManager::new();

        let tokens = manager.count_tokens(Some("gpt-4"), "Hello, world!");
        assert!(tokens > 0);

        let tokens = manager.count_tokens(None, "Hello, world!");
        assert!(tokens > 0);
    }

    #[test]
    fn test_count_chat_request() {
        let manager = TiktokenManager::new();

        let body = br#"{
            "model": "gpt-4",
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello!"}
            ]
        }"#;

        let tokens = manager.count_chat_request(body, None);
        assert!(tokens > 0);
        // Two messages plus overheads should land well above 10 tokens.
        assert!(tokens >= 10);
    }

    #[test]
    fn test_count_chat_request_with_tools() {
        let manager = TiktokenManager::new();

        let body = br#"{
            "model": "gpt-4",
            "messages": [
                {"role": "user", "content": "What's the weather?"},
                {"role": "assistant", "tool_calls": [
                    {"function": {"name": "get_weather", "arguments": "{\"city\": \"NYC\"}"}}
                ]}
            ]
        }"#;

        let tokens = manager.count_chat_request(body, None);
        assert!(tokens > 0);
    }

    #[test]
    fn test_count_embeddings_request() {
        let manager = TiktokenManager::new();

        let body = br#"{
            "model": "text-embedding-ada-002",
            "input": "Hello, world!"
        }"#;

        let tokens = manager.count_chat_request(body, None);
        assert!(tokens > 0);
    }

    #[test]
    fn test_count_invalid_json() {
        let manager = TiktokenManager::new();

        let body = b"not valid json at all";
        let tokens = manager.count_chat_request(body, Some("gpt-4"));
        // Invalid JSON is counted as raw text rather than rejected.
        assert!(tokens > 0);
    }

    #[test]
    #[cfg(feature = "tiktoken")]
    fn test_tiktoken_accurate_hello_world() {
        let manager = TiktokenManager::new();

        // "Hello world" is exactly two cl100k_base tokens.
        let tokens =
            manager.count_tokens_with_encoding(TiktokenEncoding::Cl100kBase, "Hello world");
        assert_eq!(tokens, 2);
    }

    #[test]
    #[cfg(feature = "tiktoken")]
    fn test_tiktoken_caching() {
        let manager = TiktokenManager::new();

        // Second call hits the cached encoder and must agree with the first.
        let tokens1 = manager.count_tokens(Some("gpt-4"), "Test message");
        let tokens2 = manager.count_tokens(Some("gpt-4"), "Test message");

        assert_eq!(tokens1, tokens2);
    }
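
    #[test]
    #[cfg(not(feature = "tiktoken"))]
    fn test_fallback_estimate() {
        // Sketch of the fallback path (only compiled without the
        // `tiktoken` feature): counts use the ~4-chars-per-token
        // heuristic and never drop below 1.
        let manager = TiktokenManager::new();
        assert_eq!(manager.count_tokens(None, ""), 1);
        assert_eq!(manager.count_tokens(None, "12345678"), 2);
    }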
}