use once_cell::sync::Lazy;
use serde_json::Value;
use tracing::trace;
#[cfg(feature = "tiktoken")]
use parking_lot::RwLock;
#[cfg(feature = "tiktoken")]
use std::collections::HashMap;
#[cfg(feature = "tiktoken")]
use std::sync::Arc;
#[cfg(feature = "tiktoken")]
use tiktoken_rs::{cl100k_base, o200k_base, p50k_base, CoreBPE};
#[cfg(feature = "tiktoken")]
use tracing::{debug, warn};
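/// The tiktoken vocabularies that model names can resolve to.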
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TiktokenEncoding {
O200kBase,
Cl100kBase,
P50kBase,
}
impl TiktokenEncoding {
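    /// The canonical tiktoken name for this encoding.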
pub fn name(&self) -> &'static str {
match self {
Self::O200kBase => "o200k_base",
Self::Cl100kBase => "cl100k_base",
Self::P50kBase => "p50k_base",
}
}
}
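/// Process-wide singleton so each BPE table is built at most once.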
static TIKTOKEN_MANAGER: Lazy<TiktokenManager> = Lazy::new(TiktokenManager::new);
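/// Returns the shared [`TiktokenManager`].
///
/// A minimal usage sketch (exact counts depend on whether the `tiktoken`
/// feature is enabled):
///
/// ```ignore
/// let tokens = tiktoken_manager().count_tokens(Some("gpt-4o"), "Hello, world!");
/// assert!(tokens > 0);
/// ```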
pub fn tiktoken_manager() -> &'static TiktokenManager {
&TIKTOKEN_MANAGER
}
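/// Token counter backed by `tiktoken-rs` when the `tiktoken` feature is
/// enabled; otherwise every count uses a characters/4 heuristic. Initialized
/// `CoreBPE` tables are cached per encoding behind an `RwLock`.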
pub struct TiktokenManager {
#[cfg(feature = "tiktoken")]
encodings: RwLock<HashMap<TiktokenEncoding, Arc<CoreBPE>>>,
#[cfg(not(feature = "tiktoken"))]
_marker: std::marker::PhantomData<()>,
}
impl TiktokenManager {
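    /// Creates an empty manager; BPE tables are built lazily on first use.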
pub fn new() -> Self {
#[cfg(feature = "tiktoken")]
{
Self {
encodings: RwLock::new(HashMap::new()),
}
}
#[cfg(not(feature = "tiktoken"))]
{
Self {
_marker: std::marker::PhantomData,
}
}
}
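    /// Picks an encoding by substring-matching the lowercased model name,
    /// defaulting to `cl100k_base` for unrecognized models. Note that Claude
    /// models use Anthropic's own tokenizer, so `cl100k_base` is only a rough
    /// proxy for them.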
pub fn encoding_for_model(&self, model: &str) -> TiktokenEncoding {
let model_lower = model.to_lowercase();
if model_lower.contains("gpt-4o") || model_lower.contains("gpt4o") {
return TiktokenEncoding::O200kBase;
}
if model_lower.contains("gpt-4")
|| model_lower.contains("gpt-3.5")
|| model_lower.contains("gpt-35")
|| model_lower.contains("text-embedding")
|| model_lower.contains("claude")
{
return TiktokenEncoding::Cl100kBase;
}
if model_lower.contains("code-")
|| model_lower.contains("codex")
|| model_lower.contains("text-davinci-003")
|| model_lower.contains("text-davinci-002")
{
return TiktokenEncoding::P50kBase;
}
TiktokenEncoding::Cl100kBase
}
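    /// Counts tokens in `text`, resolving the encoding from `model` when one
    /// is supplied and defaulting to `cl100k_base` otherwise.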
#[cfg(feature = "tiktoken")]
pub fn count_tokens(&self, model: Option<&str>, text: &str) -> u64 {
let encoding = model
.map(|m| self.encoding_for_model(m))
.unwrap_or(TiktokenEncoding::Cl100kBase);
self.count_tokens_with_encoding(encoding, text)
}
    #[cfg(not(feature = "tiktoken"))]
    pub fn count_tokens(&self, _model: Option<&str>, text: &str) -> u64 {
        Self::approximate_token_count(text)
    }
    /// Counts tokens in `text` with a specific encoding, falling back to the
    /// character heuristic if the encoding fails to initialize.
    #[cfg(feature = "tiktoken")]
    pub fn count_tokens_with_encoding(&self, encoding: TiktokenEncoding, text: &str) -> u64 {
        match self.get_or_create_bpe(encoding) {
            Some(bpe) => bpe.encode_with_special_tokens(text).len() as u64,
            None => Self::approximate_token_count(text),
        }
    }
    #[cfg(not(feature = "tiktoken"))]
    pub fn count_tokens_with_encoding(&self, _encoding: TiktokenEncoding, text: &str) -> u64 {
        Self::approximate_token_count(text)
    }
    /// Fallback heuristic: roughly four characters per token, minimum one,
    /// used whenever exact BPE counting is unavailable.
    fn approximate_token_count(text: &str) -> u64 {
        (text.chars().count() / 4).max(1) as u64
    }
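    /// Estimates tokens for an OpenAI-style chat completions body: sums each
    /// message's role, content (text and image parts), name, and tool calls,
    /// plus fixed per-message and reply-priming overheads. Bodies without a
    /// `messages` array are routed to `count_non_chat_request`; unparseable
    /// bodies are counted as raw text.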
pub fn count_chat_request(&self, body: &[u8], model: Option<&str>) -> u64 {
let json: Value = match serde_json::from_slice(body) {
Ok(v) => v,
Err(_) => {
let text = String::from_utf8_lossy(body);
return self.count_tokens(model, &text);
}
};
let model_name = model.or_else(|| json.get("model").and_then(|m| m.as_str()));
let messages = match json.get("messages").and_then(|m| m.as_array()) {
Some(msgs) => msgs,
None => {
return self.count_non_chat_request(&json, model_name);
}
};
let mut total_tokens: u64 = 0;
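        // Per-message formatting overhead; OpenAI's cookbook suggests roughly
        // 3-4 tokens per message for cl100k-family chat models.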
const MESSAGE_OVERHEAD: u64 = 4;
for message in messages {
total_tokens += MESSAGE_OVERHEAD;
if let Some(role) = message.get("role").and_then(|r| r.as_str()) {
total_tokens += self.count_tokens(model_name, role);
}
if let Some(content) = message.get("content") {
match content {
Value::String(text) => {
total_tokens += self.count_tokens(model_name, text);
}
Value::Array(parts) => {
for part in parts {
if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
total_tokens += self.count_tokens(model_name, text);
}
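                            // Flat per-image estimate; actual cost varies with
                            // detail and size (e.g. 85 base + 170 per 512px
                            // tile on OpenAI vision models).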
if part.get("image_url").is_some() {
total_tokens += 170; }
}
}
_ => {}
}
}
if let Some(name) = message.get("name").and_then(|n| n.as_str()) {
total_tokens += self.count_tokens(model_name, name);
}
if let Some(tool_calls) = message.get("tool_calls").and_then(|t| t.as_array()) {
for tool_call in tool_calls {
if let Some(function) = tool_call.get("function") {
if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
total_tokens += self.count_tokens(model_name, name);
}
if let Some(args) = function.get("arguments").and_then(|a| a.as_str()) {
total_tokens += self.count_tokens(model_name, args);
}
}
}
}
}
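        // Every reply is primed with <|start|>assistant<|message|>, about 3 tokens.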
total_tokens += 3;
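        // Rough response allowance: assume about half of max_tokens will
        // actually be generated.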
if let Some(max_tokens) = json.get("max_tokens").and_then(|m| m.as_u64()) {
total_tokens += max_tokens / 2;
}
trace!(
message_count = messages.len(),
total_tokens = total_tokens,
model = ?model_name,
"Counted tokens in chat request"
);
total_tokens
}
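    /// Handles completions-style (`prompt`) and embeddings-style (`input`)
    /// bodies; when neither field yields tokens, counts the serialized JSON.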
fn count_non_chat_request(&self, json: &Value, model: Option<&str>) -> u64 {
let mut total_tokens: u64 = 0;
if let Some(prompt) = json.get("prompt") {
match prompt {
Value::String(text) => {
total_tokens += self.count_tokens(model, text);
}
Value::Array(prompts) => {
for p in prompts {
if let Some(text) = p.as_str() {
total_tokens += self.count_tokens(model, text);
}
}
}
_ => {}
}
}
if let Some(input) = json.get("input") {
match input {
Value::String(text) => {
total_tokens += self.count_tokens(model, text);
}
Value::Array(inputs) => {
for i in inputs {
if let Some(text) = i.as_str() {
total_tokens += self.count_tokens(model, text);
}
}
}
_ => {}
}
}
if total_tokens == 0 {
let body_text = json.to_string();
total_tokens = self.count_tokens(model, &body_text);
}
total_tokens
}
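    /// Returns the cached `CoreBPE` for `encoding`, building it on first use.
    /// Double-checked locking: the cache is re-checked under the write lock so
    /// a racing initializer is not duplicated.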
#[cfg(feature = "tiktoken")]
fn get_or_create_bpe(&self, encoding: TiktokenEncoding) -> Option<Arc<CoreBPE>> {
{
let cache = self.encodings.read();
if let Some(bpe) = cache.get(&encoding) {
return Some(Arc::clone(bpe));
}
}
let mut cache = self.encodings.write();
if let Some(bpe) = cache.get(&encoding) {
return Some(Arc::clone(bpe));
}
let bpe = match encoding {
TiktokenEncoding::O200kBase => {
debug!(encoding = "o200k_base", "Initializing tiktoken encoding");
o200k_base().ok()
}
TiktokenEncoding::Cl100kBase => {
debug!(encoding = "cl100k_base", "Initializing tiktoken encoding");
cl100k_base().ok()
}
TiktokenEncoding::P50kBase => {
debug!(encoding = "p50k_base", "Initializing tiktoken encoding");
p50k_base().ok()
}
};
match bpe {
Some(bpe) => {
let arc_bpe = Arc::new(bpe);
cache.insert(encoding, Arc::clone(&arc_bpe));
Some(arc_bpe)
}
None => {
warn!(
encoding = encoding.name(),
"Failed to initialize tiktoken encoding"
);
None
}
}
}
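    /// True when exact BPE counting is compiled in via the `tiktoken` feature.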
pub fn is_available(&self) -> bool {
cfg!(feature = "tiktoken")
}
}
impl Default for TiktokenManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_encoding_for_model() {
let manager = TiktokenManager::new();
assert_eq!(
manager.encoding_for_model("gpt-4o"),
TiktokenEncoding::O200kBase
);
assert_eq!(
manager.encoding_for_model("gpt-4o-mini"),
TiktokenEncoding::O200kBase
);
assert_eq!(
manager.encoding_for_model("gpt-4"),
TiktokenEncoding::Cl100kBase
);
assert_eq!(
manager.encoding_for_model("gpt-4-turbo"),
TiktokenEncoding::Cl100kBase
);
assert_eq!(
manager.encoding_for_model("gpt-3.5-turbo"),
TiktokenEncoding::Cl100kBase
);
assert_eq!(
manager.encoding_for_model("claude-3-opus"),
TiktokenEncoding::Cl100kBase
);
assert_eq!(
manager.encoding_for_model("code-davinci-002"),
TiktokenEncoding::P50kBase
);
assert_eq!(
manager.encoding_for_model("unknown-model"),
TiktokenEncoding::Cl100kBase
);
}
#[test]
fn test_count_tokens_basic() {
let manager = TiktokenManager::new();
let tokens = manager.count_tokens(Some("gpt-4"), "Hello, world!");
assert!(tokens > 0);
let tokens = manager.count_tokens(None, "Hello, world!");
assert!(tokens > 0);
}
#[test]
fn test_count_chat_request() {
let manager = TiktokenManager::new();
let body = br#"{
"model": "gpt-4",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"}
]
}"#;
let tokens = manager.count_chat_request(body, None);
        assert!(tokens >= 10);
}
#[test]
fn test_count_chat_request_with_tools() {
let manager = TiktokenManager::new();
let body = br#"{
"model": "gpt-4",
"messages": [
{"role": "user", "content": "What's the weather?"},
{"role": "assistant", "tool_calls": [
{"function": {"name": "get_weather", "arguments": "{\"city\": \"NYC\"}"}}
]}
]
}"#;
let tokens = manager.count_chat_request(body, None);
assert!(tokens > 0);
}
#[test]
fn test_count_embeddings_request() {
let manager = TiktokenManager::new();
let body = br#"{
"model": "text-embedding-ada-002",
"input": "Hello, world!"
}"#;
let tokens = manager.count_chat_request(body, None);
assert!(tokens > 0);
}
#[test]
fn test_count_invalid_json() {
let manager = TiktokenManager::new();
let body = b"not valid json at all";
let tokens = manager.count_chat_request(body, Some("gpt-4"));
assert!(tokens > 0);
}
#[test]
#[cfg(feature = "tiktoken")]
fn test_tiktoken_accurate_hello_world() {
let manager = TiktokenManager::new();
let tokens =
manager.count_tokens_with_encoding(TiktokenEncoding::Cl100kBase, "Hello world");
assert_eq!(tokens, 2);
}
#[test]
#[cfg(feature = "tiktoken")]
fn test_tiktoken_caching() {
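        // Both calls resolve the same cached BPE, so counts must agree exactly.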
let manager = TiktokenManager::new();
let tokens1 = manager.count_tokens(Some("gpt-4"), "Test message");
let tokens2 = manager.count_tokens(Some("gpt-4"), "Test message");
assert_eq!(tokens1, tokens2);
}
}