1use once_cell::sync::Lazy;
25use parking_lot::RwLock;
26use serde_json::Value;
27use std::collections::HashMap;
28use std::sync::Arc;
29use tracing::{debug, trace, warn};
30
31#[cfg(feature = "tiktoken")]
32use tiktoken_rs::{cl100k_base, o200k_base, p50k_base, CoreBPE};
33
/// The tiktoken BPE vocabularies this manager can count with.
///
/// The variant used for a given request is chosen by
/// `TiktokenManager::encoding_for_model`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TiktokenEncoding {
    /// `o200k_base` — selected for GPT-4o-family model names.
    O200kBase,
    /// `cl100k_base` — selected for GPT-4 / GPT-3.5 / embedding models,
    /// and the default for unrecognized names.
    Cl100kBase,
    /// `p50k_base` — selected for Codex / davinci-era model names.
    P50kBase,
}
44
45impl TiktokenEncoding {
46 pub fn name(&self) -> &'static str {
48 match self {
49 Self::O200kBase => "o200k_base",
50 Self::Cl100kBase => "cl100k_base",
51 Self::P50kBase => "p50k_base",
52 }
53 }
54}
55
/// Process-wide singleton; constructed lazily on first access.
static TIKTOKEN_MANAGER: Lazy<TiktokenManager> = Lazy::new(TiktokenManager::new);
58
59pub fn tiktoken_manager() -> &'static TiktokenManager {
61 &TIKTOKEN_MANAGER
62}
63
/// Counts tokens for LLM request bodies, caching initialized tokenizers.
///
/// With the `tiktoken` feature enabled, counts are exact per-encoding BPE
/// token counts; without it, every count falls back to a chars/4 estimate.
pub struct TiktokenManager {
    /// Lazily-populated tokenizer cache keyed by encoding; `Arc` lets
    /// callers keep a tokenizer without holding the lock.
    #[cfg(feature = "tiktoken")]
    encodings: RwLock<HashMap<TiktokenEncoding, Arc<CoreBPE>>>,
    /// Placeholder so the struct is well-formed when the feature is off.
    #[cfg(not(feature = "tiktoken"))]
    _marker: std::marker::PhantomData<()>,
}
73
impl TiktokenManager {
    /// Creates a manager with an empty tokenizer cache.
    ///
    /// When the `tiktoken` feature is compiled out the manager carries no
    /// state and all counting uses the character-based estimate.
    pub fn new() -> Self {
        #[cfg(feature = "tiktoken")]
        {
            Self {
                encodings: RwLock::new(HashMap::new()),
            }
        }
        #[cfg(not(feature = "tiktoken"))]
        {
            Self {
                _marker: std::marker::PhantomData,
            }
        }
    }

    /// Maps a model name to the encoding used to count its tokens.
    ///
    /// Matching is case-insensitive substring matching, most specific
    /// family first. Unknown models fall back to `Cl100kBase`.
    /// NOTE(review): Claude does not actually use cl100k_base; mapping it
    /// there is an approximation — confirm the accuracy is acceptable.
    pub fn encoding_for_model(&self, model: &str) -> TiktokenEncoding {
        let model_lower = model.to_lowercase();

        // "gpt-4o" must be checked before the broader "gpt-4" match below.
        if model_lower.contains("gpt-4o") || model_lower.contains("gpt4o") {
            return TiktokenEncoding::O200kBase;
        }

        if model_lower.contains("gpt-4")
            || model_lower.contains("gpt-3.5")
            || model_lower.contains("gpt-35")
            || model_lower.contains("text-embedding")
            || model_lower.contains("claude")
        {
            return TiktokenEncoding::Cl100kBase;
        }

        // Codex / davinci-era completion models.
        if model_lower.contains("code-")
            || model_lower.contains("codex")
            || model_lower.contains("text-davinci-003")
            || model_lower.contains("text-davinci-002")
        {
            return TiktokenEncoding::P50kBase;
        }

        // Default for anything unrecognized.
        TiktokenEncoding::Cl100kBase
    }

    /// Counts tokens in `text`, inferring the encoding from `model`
    /// (cl100k_base when `model` is `None`).
    #[cfg(feature = "tiktoken")]
    pub fn count_tokens(&self, model: Option<&str>, text: &str) -> u64 {
        let encoding = model
            .map(|m| self.encoding_for_model(m))
            .unwrap_or(TiktokenEncoding::Cl100kBase);

        self.count_tokens_with_encoding(encoding, text)
    }

    /// Heuristic fallback when tiktoken is compiled out: roughly 4
    /// characters per token, with a floor of 1.
    #[cfg(not(feature = "tiktoken"))]
    pub fn count_tokens(&self, _model: Option<&str>, text: &str) -> u64 {
        (text.chars().count() / 4).max(1) as u64
    }

    /// Counts tokens in `text` with a specific encoding, falling back to
    /// the chars/4 estimate if the tokenizer could not be initialized.
    #[cfg(feature = "tiktoken")]
    pub fn count_tokens_with_encoding(&self, encoding: TiktokenEncoding, text: &str) -> u64 {
        match self.get_or_create_bpe(encoding) {
            Some(bpe) => {
                // Encode with special tokens permitted so arbitrary input
                // text is tokenized rather than rejected.
                let tokens = bpe.encode_with_special_tokens(text);
                tokens.len() as u64
            }
            None => {
                // Tokenizer init failed (already warned); estimate instead.
                (text.chars().count() / 4).max(1) as u64
            }
        }
    }

    /// Same chars/4 estimate as `count_tokens` when tiktoken is compiled out.
    #[cfg(not(feature = "tiktoken"))]
    pub fn count_tokens_with_encoding(&self, _encoding: TiktokenEncoding, text: &str) -> u64 {
        (text.chars().count() / 4).max(1) as u64
    }

    /// Estimates the token cost of a chat-completions request body.
    ///
    /// Parses `body` as JSON and sums, per message: a fixed framing
    /// overhead, the role, string or multi-part content (flat charge per
    /// image part), the optional `name`, and tool-call function
    /// names/arguments. Bodies without a `messages` array are delegated to
    /// `count_non_chat_request`; unparseable bodies are counted as raw text.
    pub fn count_chat_request(&self, body: &[u8], model: Option<&str>) -> u64 {
        let json: Value = match serde_json::from_slice(body) {
            Ok(v) => v,
            Err(_) => {
                // Not JSON: count the raw bytes as text rather than failing.
                let text = String::from_utf8_lossy(body);
                return self.count_tokens(model, &text);
            }
        };

        // Caller-supplied model wins; otherwise use the body's "model" field.
        let model_name = model.or_else(|| json.get("model").and_then(|m| m.as_str()));

        let messages = match json.get("messages").and_then(|m| m.as_array()) {
            Some(msgs) => msgs,
            None => {
                // No "messages": treat as completions/embeddings-style body.
                return self.count_non_chat_request(&json, model_name);
            }
        };

        let mut total_tokens: u64 = 0;

        // Approximate per-message framing cost in the chat format.
        const MESSAGE_OVERHEAD: u64 = 4;

        for message in messages {
            total_tokens += MESSAGE_OVERHEAD;

            if let Some(role) = message.get("role").and_then(|r| r.as_str()) {
                total_tokens += self.count_tokens(model_name, role);
            }

            if let Some(content) = message.get("content") {
                match content {
                    Value::String(text) => {
                        total_tokens += self.count_tokens(model_name, text);
                    }
                    Value::Array(parts) => {
                        // Multi-part (multimodal) content: tokenize text
                        // parts; charge a flat estimate per image part.
                        for part in parts {
                            if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
                                total_tokens += self.count_tokens(model_name, text);
                            }
                            if part.get("image_url").is_some() {
                                total_tokens += 170;
                            }
                        }
                    }
                    // Null / other content shapes contribute nothing.
                    _ => {}
                }
            }

            if let Some(name) = message.get("name").and_then(|n| n.as_str()) {
                total_tokens += self.count_tokens(model_name, name);
            }

            // Assistant tool calls: count function names and raw arguments.
            if let Some(tool_calls) = message.get("tool_calls").and_then(|t| t.as_array()) {
                for tool_call in tool_calls {
                    if let Some(function) = tool_call.get("function") {
                        if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
                            total_tokens += self.count_tokens(model_name, name);
                        }
                        if let Some(args) = function.get("arguments").and_then(|a| a.as_str()) {
                            total_tokens += self.count_tokens(model_name, args);
                        }
                    }
                }
            }
        }

        // Fixed per-request overhead (the "+3" reply-priming term from
        // OpenAI's token-counting recipe — presumably; confirm source).
        total_tokens += 3;

        // If the client capped the completion, budget half the cap as the
        // expected response size. NOTE(review): this folds an *output*
        // estimate into the request count — confirm callers expect that.
        if let Some(max_tokens) = json.get("max_tokens").and_then(|m| m.as_u64()) {
            total_tokens += max_tokens / 2;
        }

        trace!(
            message_count = messages.len(),
            total_tokens = total_tokens,
            model = ?model_name,
            "Counted tokens in chat request"
        );

        total_tokens
    }

    /// Counts tokens for non-chat bodies by summing the "prompt"
    /// (legacy completions) and "input" (embeddings) fields, each of which
    /// may be a string or an array of strings. If neither yields anything,
    /// counts the serialized JSON itself as a last resort.
    fn count_non_chat_request(&self, json: &Value, model: Option<&str>) -> u64 {
        let mut total_tokens: u64 = 0;

        if let Some(prompt) = json.get("prompt") {
            match prompt {
                Value::String(text) => {
                    total_tokens += self.count_tokens(model, text);
                }
                Value::Array(prompts) => {
                    // Non-string array elements are skipped.
                    for p in prompts {
                        if let Some(text) = p.as_str() {
                            total_tokens += self.count_tokens(model, text);
                        }
                    }
                }
                _ => {}
            }
        }

        if let Some(input) = json.get("input") {
            match input {
                Value::String(text) => {
                    total_tokens += self.count_tokens(model, text);
                }
                Value::Array(inputs) => {
                    for i in inputs {
                        if let Some(text) = i.as_str() {
                            total_tokens += self.count_tokens(model, text);
                        }
                    }
                }
                _ => {}
            }
        }

        // Nothing recognized: fall back to counting the whole body.
        if total_tokens == 0 {
            let body_text = json.to_string();
            total_tokens = self.count_tokens(model, &body_text);
        }

        total_tokens
    }

    /// Returns the cached tokenizer for `encoding`, initializing it on
    /// first use.
    ///
    /// Fast path takes only the read lock; on a miss it re-checks under the
    /// write lock (double-checked) so concurrent callers don't build the
    /// same BPE twice. Returns `None` — after logging a warning — if
    /// initialization fails; callers then fall back to estimation.
    #[cfg(feature = "tiktoken")]
    fn get_or_create_bpe(&self, encoding: TiktokenEncoding) -> Option<Arc<CoreBPE>> {
        // Scoped so the read guard drops before we take the write lock.
        {
            let cache = self.encodings.read();
            if let Some(bpe) = cache.get(&encoding) {
                return Some(Arc::clone(bpe));
            }
        }

        let mut cache = self.encodings.write();

        // Re-check: another thread may have inserted between the locks.
        if let Some(bpe) = cache.get(&encoding) {
            return Some(Arc::clone(bpe));
        }

        // Build the requested vocabulary; errors become `None`.
        let bpe = match encoding {
            TiktokenEncoding::O200kBase => {
                debug!(encoding = "o200k_base", "Initializing tiktoken encoding");
                o200k_base().ok()
            }
            TiktokenEncoding::Cl100kBase => {
                debug!(encoding = "cl100k_base", "Initializing tiktoken encoding");
                cl100k_base().ok()
            }
            TiktokenEncoding::P50kBase => {
                debug!(encoding = "p50k_base", "Initializing tiktoken encoding");
                p50k_base().ok()
            }
        };

        match bpe {
            Some(bpe) => {
                // Cache under the still-held write lock, then hand out a clone.
                let arc_bpe = Arc::new(bpe);
                cache.insert(encoding, Arc::clone(&arc_bpe));
                Some(arc_bpe)
            }
            None => {
                warn!(
                    encoding = encoding.name(),
                    "Failed to initialize tiktoken encoding"
                );
                None
            }
        }
    }

    /// True when exact tiktoken counting is compiled in; false when only
    /// the character-based estimate is available.
    pub fn is_available(&self) -> bool {
        cfg!(feature = "tiktoken")
    }
}
368
369impl Default for TiktokenManager {
370 fn default() -> Self {
371 Self::new()
372 }
373}
374
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encoding_for_model() {
        let manager = TiktokenManager::new();

        // Table-driven: model name -> expected encoding.
        let cases = [
            ("gpt-4o", TiktokenEncoding::O200kBase),
            ("gpt-4o-mini", TiktokenEncoding::O200kBase),
            ("gpt-4", TiktokenEncoding::Cl100kBase),
            ("gpt-4-turbo", TiktokenEncoding::Cl100kBase),
            ("gpt-3.5-turbo", TiktokenEncoding::Cl100kBase),
            ("claude-3-opus", TiktokenEncoding::Cl100kBase),
            ("code-davinci-002", TiktokenEncoding::P50kBase),
            ("unknown-model", TiktokenEncoding::Cl100kBase),
        ];
        for &(model, expected) in &cases {
            assert_eq!(manager.encoding_for_model(model), expected, "model: {}", model);
        }
    }

    #[test]
    fn test_count_tokens_basic() {
        let manager = TiktokenManager::new();

        // Non-empty text yields a positive count with or without a model.
        assert!(manager.count_tokens(Some("gpt-4"), "Hello, world!") > 0);
        assert!(manager.count_tokens(None, "Hello, world!") > 0);
    }

    #[test]
    fn test_count_chat_request() {
        let manager = TiktokenManager::new();

        let body = br#"{"model":"gpt-4","messages":[{"role":"system","content":"You are a helpful assistant."},{"role":"user","content":"Hello!"}]}"#;

        let tokens = manager.count_chat_request(body, None);
        assert!(tokens > 0);
        // Two messages of real text plus framing overhead.
        assert!(tokens >= 10);
    }

    #[test]
    fn test_count_chat_request_with_tools() {
        let manager = TiktokenManager::new();

        // Tool-call names and arguments contribute to the total.
        let body = br#"{"model":"gpt-4","messages":[{"role":"user","content":"What's the weather?"},{"role":"assistant","tool_calls":[{"function":{"name":"get_weather","arguments":"{\"city\": \"NYC\"}"}}]}]}"#;

        assert!(manager.count_chat_request(body, None) > 0);
    }

    #[test]
    fn test_count_embeddings_request() {
        let manager = TiktokenManager::new();

        // No "messages" array: falls through to the non-chat path.
        let body = br#"{"model":"text-embedding-ada-002","input":"Hello, world!"}"#;

        assert!(manager.count_chat_request(body, None) > 0);
    }

    #[test]
    fn test_count_invalid_json() {
        let manager = TiktokenManager::new();

        // Unparseable bodies are counted as raw text, never zero.
        assert!(manager.count_chat_request(b"not valid json at all", Some("gpt-4")) > 0);
    }

    #[test]
    #[cfg(feature = "tiktoken")]
    fn test_tiktoken_accurate_hello_world() {
        let manager = TiktokenManager::new();

        // "Hello world" is exactly two cl100k_base tokens.
        assert_eq!(
            manager.count_tokens_with_encoding(TiktokenEncoding::Cl100kBase, "Hello world"),
            2
        );
    }

    #[test]
    #[cfg(feature = "tiktoken")]
    fn test_tiktoken_caching() {
        let manager = TiktokenManager::new();

        // The second call hits the cached encoder and must agree.
        let first = manager.count_tokens(Some("gpt-4"), "Test message");
        let second = manager.count_tokens(Some("gpt-4"), "Test message");
        assert_eq!(first, second);
    }
}