use moka::sync::Cache;
use std::collections::HashMap;
use std::sync::Arc;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TokenizerModel {
Cl100kBase,
P50kBase,
Claude,
Llama,
Generic,
}
impl TokenizerModel {
pub fn bytes_per_token(&self) -> f32 {
match self {
Self::Cl100kBase => 3.8,
Self::P50kBase => 4.0,
Self::Claude => 4.2,
Self::Llama => 4.0,
Self::Generic => 4.0,
}
}
pub fn name(&self) -> &'static str {
match self {
Self::Cl100kBase => "cl100k_base",
Self::P50kBase => "p50k_base",
Self::Claude => "claude",
Self::Llama => "llama",
Self::Generic => "generic",
}
}
}
#[derive(Debug, Clone)]
pub struct ExactTokenConfig {
pub model: TokenizerModel,
pub cache_size: usize,
pub cache_ttl_secs: u64,
pub fallback_on_error: bool,
pub max_cache_text_len: usize,
}
impl Default for ExactTokenConfig {
fn default() -> Self {
Self {
model: TokenizerModel::Cl100kBase,
cache_size: 10_000,
cache_ttl_secs: 3600,
fallback_on_error: true,
max_cache_text_len: 10_000,
}
}
}
impl ExactTokenConfig {
pub fn gpt4() -> Self {
Self {
model: TokenizerModel::Cl100kBase,
..Default::default()
}
}
pub fn claude() -> Self {
Self {
model: TokenizerModel::Claude,
..Default::default()
}
}
}
pub trait TokenCounter: Send + Sync {
fn count(&self, text: &str) -> usize;
fn count_for_model(&self, text: &str, model: TokenizerModel) -> usize {
let _ = model; self.count(text)
}
fn tokenize(&self, text: &str) -> Vec<u32>;
fn decode(&self, tokens: &[u32]) -> String;
fn model(&self) -> TokenizerModel;
fn is_exact(&self) -> bool;
}
pub struct ExactTokenCounter {
config: ExactTokenConfig,
cache: Cache<u64, usize>,
vocab: Arc<BpeVocab>,
stats: Arc<TokenCacheStats>,
}
struct BpeVocab {
token_to_id: HashMap<String, u32>,
id_to_token: HashMap<u32, String>,
#[allow(dead_code)]
merges: HashMap<(String, String), String>,
special_tokens: HashMap<String, u32>,
}
impl BpeVocab {
fn cl100k_base() -> Self {
let mut token_to_id = HashMap::new();
let mut id_to_token = HashMap::new();
for b in 32u8..127 {
let token = String::from(b as char);
let id = b as u32;
token_to_id.insert(token.clone(), id);
id_to_token.insert(id, token);
}
let common_tokens = [
"the", "ing", "tion", "ed", "er", "es", "en", "al", "re", "on", "an", "or", "ar", "is",
"it", "at", "as", "le", "ve", " the", " a", " to", " of", " and", " in", " is", " for",
" ", "\n", "\t", "```", "...", "->", "=>", "==", "!=",
];
let mut id = 200u32;
for token in common_tokens {
token_to_id.insert(token.to_string(), id);
id_to_token.insert(id, token.to_string());
id += 1;
}
let mut special_tokens = HashMap::new();
special_tokens.insert("<|endoftext|>".to_string(), 100257);
special_tokens.insert("<|fim_prefix|>".to_string(), 100258);
special_tokens.insert("<|fim_middle|>".to_string(), 100259);
special_tokens.insert("<|fim_suffix|>".to_string(), 100260);
Self {
token_to_id,
id_to_token,
merges: HashMap::new(),
special_tokens,
}
}
fn tokenize(&self, text: &str) -> Vec<u32> {
let mut tokens = Vec::new();
let mut remaining = text;
while !remaining.is_empty() {
let mut matched = false;
for (special, id) in &self.special_tokens {
if remaining.starts_with(special) {
tokens.push(*id);
remaining = &remaining[special.len()..];
matched = true;
break;
}
}
if matched {
continue;
}
for len in (1..=remaining.len().min(10)).rev() {
if let Some(substr) = remaining.get(..len) {
if let Some(&id) = self.token_to_id.get(substr) {
tokens.push(id);
remaining = &remaining[len..];
matched = true;
break;
}
}
}
if !matched {
if let Some(c) = remaining.chars().next() {
let byte_id = (c as u32).min(255);
tokens.push(byte_id);
remaining = &remaining[c.len_utf8()..];
}
}
}
tokens
}
fn decode(&self, tokens: &[u32]) -> String {
let mut result = String::new();
for &id in tokens {
if let Some(token) = self.id_to_token.get(&id) {
result.push_str(token);
} else {
if id < 256 {
if let Some(c) = char::from_u32(id) {
result.push(c);
}
}
}
}
result
}
}
#[derive(Debug, Default)]
pub struct TokenCacheStats {
pub hits: std::sync::atomic::AtomicUsize,
pub misses: std::sync::atomic::AtomicUsize,
pub tokenizations: std::sync::atomic::AtomicUsize,
pub total_tokens: std::sync::atomic::AtomicUsize,
}
impl TokenCacheStats {
pub fn hit_rate(&self) -> f64 {
let hits = self.hits.load(std::sync::atomic::Ordering::Relaxed);
let misses = self.misses.load(std::sync::atomic::Ordering::Relaxed);
let total = hits + misses;
if total == 0 {
0.0
} else {
hits as f64 / total as f64
}
}
}
impl ExactTokenCounter {
pub fn new(config: ExactTokenConfig) -> Self {
let cache = Cache::builder()
.max_capacity(config.cache_size as u64)
.time_to_live(std::time::Duration::from_secs(config.cache_ttl_secs))
.build();
Self {
config,
cache,
vocab: Arc::new(BpeVocab::cl100k_base()),
stats: Arc::new(TokenCacheStats::default()),
}
}
pub fn default_counter() -> Self {
Self::new(ExactTokenConfig::default())
}
pub fn stats(&self) -> &Arc<TokenCacheStats> {
&self.stats
}
fn text_hash(text: &str) -> u64 {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
text.hash(&mut hasher);
hasher.finish()
}
fn count_cached(&self, text: &str) -> usize {
if text.len() > self.config.max_cache_text_len {
self.stats
.misses
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
return self.tokenize(text).len();
}
let hash = Self::text_hash(text);
if let Some(count) = self.cache.get(&hash) {
self.stats
.hits
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
return count;
}
self.stats
.misses
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let tokens = self.tokenize(text);
let count = tokens.len();
self.cache.insert(hash, count);
self.stats
.total_tokens
.fetch_add(count, std::sync::atomic::Ordering::Relaxed);
count
}
#[allow(dead_code)]
fn estimate_tokens(&self, text: &str) -> usize {
let bytes = text.len();
((bytes as f32) / self.config.model.bytes_per_token()).ceil() as usize
}
}
impl TokenCounter for ExactTokenCounter {
fn count(&self, text: &str) -> usize {
self.stats
.tokenizations
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
self.count_cached(text)
}
fn count_for_model(&self, text: &str, model: TokenizerModel) -> usize {
if model == self.config.model {
self.count(text)
} else {
let bytes = text.len();
((bytes as f32) / model.bytes_per_token()).ceil() as usize
}
}
fn tokenize(&self, text: &str) -> Vec<u32> {
self.vocab.tokenize(text)
}
fn decode(&self, tokens: &[u32]) -> String {
self.vocab.decode(tokens)
}
fn model(&self) -> TokenizerModel {
self.config.model
}
fn is_exact(&self) -> bool {
true
}
}
pub struct HeuristicTokenCounter {
bytes_per_token: f32,
model: TokenizerModel,
}
impl HeuristicTokenCounter {
pub fn new() -> Self {
Self {
bytes_per_token: 4.0,
model: TokenizerModel::Generic,
}
}
pub fn for_model(model: TokenizerModel) -> Self {
Self {
bytes_per_token: model.bytes_per_token(),
model,
}
}
}
impl Default for HeuristicTokenCounter {
fn default() -> Self {
Self::new()
}
}
impl TokenCounter for HeuristicTokenCounter {
fn count(&self, text: &str) -> usize {
let bytes = text.len();
((bytes as f32) / self.bytes_per_token).ceil() as usize
}
fn tokenize(&self, text: &str) -> Vec<u32> {
text.split_whitespace()
.enumerate()
.map(|(i, _)| i as u32)
.collect()
}
fn decode(&self, _tokens: &[u32]) -> String {
"[decode not supported for heuristic counter]".to_string()
}
fn model(&self) -> TokenizerModel {
self.model
}
fn is_exact(&self) -> bool {
false
}
}
pub struct ExactBudgetEnforcer<C: TokenCounter> {
counter: Arc<C>,
budget: usize,
used: std::sync::atomic::AtomicUsize,
}
impl<C: TokenCounter> ExactBudgetEnforcer<C> {
pub fn new(counter: Arc<C>, budget: usize) -> Self {
Self {
counter,
budget,
used: std::sync::atomic::AtomicUsize::new(0),
}
}
pub fn remaining(&self) -> usize {
self.budget
.saturating_sub(self.used.load(std::sync::atomic::Ordering::Relaxed))
}
pub fn fits(&self, text: &str) -> bool {
let tokens = self.counter.count(text);
tokens <= self.remaining()
}
pub fn try_consume(&self, text: &str) -> Option<usize> {
let tokens = self.counter.count(text);
let remaining = self.remaining();
if tokens <= remaining {
self.used
.fetch_add(tokens, std::sync::atomic::Ordering::Relaxed);
Some(tokens)
} else {
None
}
}
pub fn force_consume(&self, tokens: usize) {
self.used
.fetch_add(tokens, std::sync::atomic::Ordering::Relaxed);
}
pub fn truncate_to_fit(&self, text: &str) -> (String, usize) {
let remaining = self.remaining();
if remaining == 0 {
return (String::new(), 0);
}
let mut low = 0;
let mut high = text.len();
let mut best_len = 0;
let mut best_tokens = 0;
while low < high {
let mid = (low + high + 1) / 2;
let truncated = if mid >= text.len() {
text.to_string()
} else {
let mut end = mid;
while !text.is_char_boundary(end) && end > 0 {
end -= 1;
}
text[..end].to_string()
};
let tokens = self.counter.count(&truncated);
if tokens <= remaining {
best_len = truncated.len();
best_tokens = tokens;
low = mid;
} else {
high = mid - 1;
}
}
if best_len == 0 {
(String::new(), 0)
} else {
(text[..best_len].to_string(), best_tokens)
}
}
pub fn summary(&self) -> BudgetSummary {
let used = self.used.load(std::sync::atomic::Ordering::Relaxed);
BudgetSummary {
budget: self.budget,
used,
remaining: self.budget.saturating_sub(used),
utilization: (used as f64) / (self.budget as f64),
}
}
}
#[derive(Debug, Clone)]
pub struct BudgetSummary {
pub budget: usize,
pub used: usize,
pub remaining: usize,
pub utilization: f64,
}
pub fn count_tokens_exact(text: &str) -> usize {
let counter = ExactTokenCounter::default_counter();
counter.count(text)
}
pub fn count_tokens_heuristic(text: &str) -> usize {
let counter = HeuristicTokenCounter::new();
counter.count(text)
}
pub fn create_budget_enforcer(budget: usize) -> ExactBudgetEnforcer<ExactTokenCounter> {
let counter = Arc::new(ExactTokenCounter::default_counter());
ExactBudgetEnforcer::new(counter, budget)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_exact_token_count() {
let counter = ExactTokenCounter::default_counter();
let count = counter.count("Hello, world!");
assert!(count > 0);
assert!(count < 20); }
#[test]
fn test_tokenize_and_decode() {
let counter = ExactTokenCounter::default_counter();
let text = "Hello world";
let tokens = counter.tokenize(text);
assert!(!tokens.is_empty());
let decoded = counter.decode(&tokens);
assert!(!decoded.is_empty());
}
#[test]
fn test_cache_hits() {
let counter = ExactTokenCounter::default_counter();
let _ = counter.count("test text for caching");
let _ = counter.count("test text for caching");
let stats = counter.stats();
let hits = stats.hits.load(std::sync::atomic::Ordering::Relaxed);
let misses = stats.misses.load(std::sync::atomic::Ordering::Relaxed);
assert!(hits >= 1);
assert!(misses >= 1);
}
#[test]
fn test_heuristic_counter() {
let counter = HeuristicTokenCounter::new();
let count = counter.count("Hello world");
assert!(count >= 2 && count <= 5);
}
#[test]
fn test_budget_enforcer() {
let counter = Arc::new(ExactTokenCounter::default_counter());
let enforcer = ExactBudgetEnforcer::new(counter, 100);
assert_eq!(enforcer.remaining(), 100);
let consumed = enforcer.try_consume("Hello world").unwrap();
assert!(consumed > 0);
assert!(enforcer.remaining() < 100);
}
#[test]
fn test_budget_truncation() {
let counter = Arc::new(ExactTokenCounter::default_counter());
let enforcer = ExactBudgetEnforcer::new(counter, 5);
let long_text =
"This is a very long text that definitely exceeds five tokens and should be truncated";
let (truncated, tokens) = enforcer.truncate_to_fit(long_text);
assert!(truncated.len() < long_text.len());
assert!(tokens <= 5);
}
#[test]
fn test_budget_summary() {
let counter = Arc::new(HeuristicTokenCounter::new());
let enforcer = ExactBudgetEnforcer::new(counter, 100);
enforcer.force_consume(25);
let summary = enforcer.summary();
assert_eq!(summary.budget, 100);
assert_eq!(summary.used, 25);
assert_eq!(summary.remaining, 75);
assert!((summary.utilization - 0.25).abs() < 0.01);
}
#[test]
fn test_model_specific_counting() {
let counter = ExactTokenCounter::default_counter();
let text = "Hello, world!";
let gpt4_count = counter.count_for_model(text, TokenizerModel::Cl100kBase);
let claude_count = counter.count_for_model(text, TokenizerModel::Claude);
assert!(gpt4_count > 0);
assert!(claude_count > 0);
}
}