//! Small process-global cache for token-count estimates, keyed by a hash of
//! `(model, text)`.

use std::collections::{HashMap, hash_map::DefaultHasher};
use std::hash::{Hash, Hasher};
use std::sync::{OnceLock, RwLock};
use tracing::debug;

/// Fixed-capacity map from key hashes to token counts, with a crude bulk
/// eviction policy once the capacity is reached.
#[derive(Default)]
#[allow(dead_code)]
struct SimpleCache {
    cache: HashMap<u64, u32>,
    capacity: usize,
}

#[allow(dead_code)]
impl SimpleCache {
    fn new(capacity: usize) -> Self {
        Self {
            cache: HashMap::with_capacity(capacity),
            capacity,
        }
    }

    fn get(&self, key: &u64) -> Option<u32> {
        self.cache.get(key).copied()
    }

    fn put(&mut self, key: u64, value: u32) {
        // When full, evict roughly half of the entries (in arbitrary HashMap
        // iteration order) rather than tracking recency; this keeps the cache
        // bounded without the bookkeeping cost of a true LRU.
        if self.cache.len() >= self.capacity {
            let keys_to_remove: Vec<_> =
                self.cache.keys().take(self.capacity / 2).copied().collect();
            for k in keys_to_remove {
                self.cache.remove(&k);
            }
        }
        self.cache.insert(key, value);
    }

    fn clear(&mut self) {
        self.cache.clear();
    }

    fn len(&self) -> usize {
        self.cache.len()
    }

    fn capacity(&self) -> usize {
        self.capacity
    }
}

/// Lazily initialized process-wide backing store for [`TokenCache`].
#[allow(dead_code)]
static TOKEN_CACHE: OnceLock<RwLock<SimpleCache>> = OnceLock::new();

/// Default number of entries the global cache is sized for.
#[allow(dead_code)]
const DEFAULT_CACHE_CAPACITY: usize = 10_000;

/// Returns the global cache, initializing it on first use.
#[allow(dead_code)]
fn global_cache() -> &'static RwLock<SimpleCache> {
    TOKEN_CACHE.get_or_init(|| RwLock::new(SimpleCache::new(DEFAULT_CACHE_CAPACITY)))
}

/// Zero-sized handle providing a convenient API over the global token cache.
#[allow(dead_code)]
pub struct TokenCache;

#[allow(dead_code)]
impl TokenCache {
    /// Returns the cached token count for `(model, text)`, if present.
    pub fn get(&self, model: &str, text: &str) -> Option<u32> {
        let key = self.cache_key(model, text);
        global_cache().read().ok()?.get(&key)
    }

    /// Stores a token count for `(model, text)`.
    pub fn set(&self, model: &str, text: &str, token_count: u32) {
        let key = self.cache_key(model, text);
        if let Ok(mut c) = global_cache().write() {
            c.put(key, token_count);
            debug!(
                "Cached token count: {} tokens for model: {}",
                token_count, model
            );
        }
    }

    /// Removes all cached entries.
    pub fn clear(&self) {
        if let Ok(mut c) = global_cache().write() {
            c.clear();
            debug!("Token cache cleared");
        }
    }

    /// Returns `(current length, configured capacity)` of the cache, or `None`
    /// if the lock is poisoned.
    pub fn stats(&self) -> Option<(usize, usize)> {
        if let Ok(c) = global_cache().read() {
            Some((c.len(), c.capacity()))
        } else {
            None
        }
    }

    /// Hashes `(model, text)` into the `u64` key used by the underlying map.
    fn cache_key(&self, model: &str, text: &str) -> u64 {
        let mut hasher = DefaultHasher::new();
        model.hash(&mut hasher);
        text.hash(&mut hasher);
        hasher.finish()
    }

    /// Estimates the token count for `text` from its character count, applying
    /// a correction for non-ASCII text and a 10% safety margin, then caches
    /// the result. Returns a cached value if one already exists.
    pub fn estimate_tokens_vectorized(&self, model: &str, text: &str, chars_per_token: f64) -> u32 {
        if let Some(cached) = self.get(model, text) {
            return cached;
        }
        if text.is_empty() {
            return 0;
        }
        let byte_len = text.len() as f64;
        let char_len = text.chars().count() as f64;
        // Non-ASCII text tends to tokenize into more tokens per character, so
        // apply a rough 1.2x correction when byte and char lengths differ.
        let unicode_factor = if char_len == byte_len { 1.0 } else { 1.2 };
        let estimated_tokens = (char_len / chars_per_token * unicode_factor).ceil() as u32;
        // Add a 10% margin so estimates err on the high side.
        let final_tokens = (estimated_tokens as f64 * 1.1).ceil() as u32;
        self.set(model, text, final_tokens);
        final_tokens
    }

    /// Estimates token counts for each text in `texts`.
    pub fn batch_estimate(&self, model: &str, texts: &[String], chars_per_token: f64) -> Vec<u32> {
        texts
            .iter()
            .map(|text| self.estimate_tokens_vectorized(model, text, chars_per_token))
            .collect()
    }
}

/// Returns a reference to the shared [`TokenCache`] handle.
#[allow(dead_code)]
pub fn token_cache() -> &'static TokenCache {
    static CACHE: TokenCache = TokenCache;
    &CACHE
}
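
// Illustrative sketch (hypothetical helper, not used elsewhere in this module):
// sums the estimated token budget for a batch of prompts, assuming a rough
// 4-characters-per-token ratio for English text.
#[allow(dead_code)]
fn example_prompt_budget(prompts: &[String]) -> u32 {
    let cache = token_cache();
    cache.batch_estimate("gpt-4", prompts, 4.0).into_iter().sum()
}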

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    // The cache behind `token_cache()` is process-global and tests run in
    // parallel, so tests that clear it or assert on its stored contents take
    // this lock to avoid interfering with each other.
    static TEST_LOCK: Mutex<()> = Mutex::new(());

    #[test]
    fn test_token_cache_basic() {
        let _guard = TEST_LOCK.lock().unwrap();
        let cache = token_cache();
        cache.clear();
        let model = "gpt-4-test";
        let text = "Hello world test";
        assert!(cache.get(model, text).is_none());
        cache.set(model, text, 2);
        assert_eq!(cache.get(model, text), Some(2));
    }

    #[test]
    fn test_cache_key_generation() {
        let cache = token_cache();
        let key1 = cache.cache_key("gpt-4", "hello");
        let key2 = cache.cache_key("gpt-4", "world");
        let key3 = cache.cache_key("gpt-3.5", "hello");
        assert_ne!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_vectorized_estimation() {
        // Takes the lock because this test clears the shared cache.
        let _guard = TEST_LOCK.lock().unwrap();
        let cache = token_cache();
        cache.clear();
        let model = "gpt-4";
        let text = "This is a test text for token estimation";
        let chars_per_token = 4.0;
        let tokens = cache.estimate_tokens_vectorized(model, text, chars_per_token);
        assert!(tokens > 0);
        let tokens2 = cache.estimate_tokens_vectorized(model, text, chars_per_token);
        assert_eq!(tokens, tokens2);
    }

    #[test]
    fn test_batch_estimation() {
        let cache = token_cache();
        let model = "gpt-4";
        let texts = vec![
            "Hello world".to_string(),
            "This is a longer text".to_string(),
            "Short".to_string(),
        ];
        let chars_per_token = 4.0;
        let results = cache.batch_estimate(model, &texts, chars_per_token);
        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|&count| count > 0));
    }

    #[test]
    fn test_empty_text() {
        let cache = token_cache();
        let tokens = cache.estimate_tokens_vectorized("gpt-4", "", 4.0);
        assert_eq!(tokens, 0);
    }
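
    // Illustrative addition: exercises the non-ASCII path, where the estimator
    // applies a 1.2x correction because byte length exceeds character count.
    #[test]
    fn test_unicode_text_estimation() {
        let cache = token_cache();
        let ascii = cache.estimate_tokens_vectorized("gpt-4", "hello world", 4.0);
        let accented = cache.estimate_tokens_vectorized("gpt-4", "héllo wörld", 4.0);
        assert!(accented >= ascii);
        assert!(accented > 0);
    }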
}