ricecoder_providers/token_counter.rs

//! Token counting utilities for different providers
//!
//! This module provides token counting functionality for various AI providers.
//! For production use with OpenAI, consider using tiktoken-rs for accurate token counting.
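//!
//! # Examples
//!
//! A minimal usage sketch (the import path assumes this module is exposed as
//! `ricecoder_providers::token_counter`):
//!
//! ```ignore
//! use ricecoder_providers::token_counter::TokenCounter;
//!
//! let counter = TokenCounter::new();
//! let tokens = counter.count_tokens_openai("Hello world", "gpt-4");
//! assert!(tokens > 0);
//! ```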

use crate::error::ProviderError;
use std::collections::HashMap;
use std::sync::Mutex;

/// Trait for unified token counting across providers
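///
/// # Examples
///
/// A minimal sketch of provider-agnostic counting through a trait object
/// (the import path is an assumption about how the crate exposes this module):
///
/// ```ignore
/// use ricecoder_providers::token_counter::{TokenCounter, TokenCounterTrait};
///
/// let counter: Box<dyn TokenCounterTrait> = Box::new(TokenCounter::new());
/// let tokens = counter.count_tokens("Hello world", "gpt-4").unwrap();
/// assert!(tokens > 0);
/// ```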
pub trait TokenCounterTrait: Send + Sync {
    /// Count tokens for content in a specific model
    fn count_tokens(&self, content: &str, model: &str) -> Result<usize, ProviderError>;

    /// Clear the token count cache
    fn clear_cache(&self);

    /// Get cache size
    fn cache_size(&self) -> usize;
}

/// Token counter for estimating token usage
pub struct TokenCounter {
    cache: Mutex<HashMap<String, usize>>,
}

impl TokenCounter {
    /// Create a new token counter
    pub fn new() -> Self {
        Self {
            cache: Mutex::new(HashMap::new()),
        }
    }

    /// Count tokens for OpenAI models using estimation
    ///
    /// This uses a heuristic-based approach:
    /// - The average English word is ~4.7 characters
    /// - The average token is ~4 characters, so the estimate is roughly `ceil(len / 4)`
    /// - Special tokens and formatting add overhead that this estimate does not capture;
    ///   use tiktoken-rs when exact counts are required
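    ///
    /// # Examples
    ///
    /// A rough sketch of the estimate (the import path is an assumption):
    ///
    /// ```ignore
    /// use ricecoder_providers::token_counter::TokenCounter;
    ///
    /// let counter = TokenCounter::new();
    /// // "Hello world" is 11 characters: ceil(11 / 4) = 3 estimated tokens.
    /// assert_eq!(counter.count_tokens_openai("Hello world", "gpt-4"), 3);
    /// ```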
    pub fn count_tokens_openai(&self, content: &str, model: &str) -> usize {
        // Check cache first
        let cache_key = format!("{}:{}", model, content);
        if let Ok(cache) = self.cache.lock() {
            if let Some(&count) = cache.get(&cache_key) {
                return count;
            }
        }

        // Estimate tokens based on model and content
        let estimated = self.estimate_tokens(content, model);

        // Cache the result
        if let Ok(mut cache) = self.cache.lock() {
            cache.insert(cache_key, estimated);
        }

        estimated
    }

    /// Count tokens for content (unified interface)
    ///
    /// This method provides a unified interface for token counting across providers.
    /// It returns a Result type for better error handling.
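    ///
    /// # Examples
    ///
    /// A minimal sketch of the fallible interface (the import path is an assumption):
    ///
    /// ```ignore
    /// use ricecoder_providers::token_counter::TokenCounter;
    ///
    /// // Inside a function that returns Result<_, ProviderError>:
    /// let counter = TokenCounter::new();
    /// let tokens = counter.count("Hello world", "gpt-4")?;
    /// assert!(tokens > 0);
    /// ```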
    pub fn count(&self, content: &str, model: &str) -> Result<usize, ProviderError> {
        Ok(self.count_tokens_openai(content, model))
    }

    /// Estimate token count for content
    fn estimate_tokens(&self, content: &str, _model: &str) -> usize {
        if content.is_empty() {
            return 0;
        }

        // Estimate based on character count
        // Heuristic: roughly 1 token per 4 characters
        // This is a conservative estimate that should not exceed content length
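        // For example, "Hello world" is 11 characters, so the estimate is ceil(11 / 4) = 3 tokens.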
        let estimated = (content.len() as f64 / 4.0).ceil() as usize;

        // Ensure at least 1 token for non-empty content
        std::cmp::max(1, estimated)
    }

    /// Clear the token count cache
    pub fn clear_cache(&self) {
        if let Ok(mut cache) = self.cache.lock() {
            cache.clear();
        }
    }

    /// Get cache size
    pub fn cache_size(&self) -> usize {
        self.cache.lock().map(|c| c.len()).unwrap_or(0)
    }
}

impl Default for TokenCounter {
    fn default() -> Self {
        Self::new()
    }
}

impl TokenCounterTrait for TokenCounter {
    fn count_tokens(&self, content: &str, model: &str) -> Result<usize, ProviderError> {
        self.count(content, model)
    }

    fn clear_cache(&self) {
        if let Ok(mut cache) = self.cache.lock() {
            cache.clear();
        }
    }

    fn cache_size(&self) -> usize {
        self.cache.lock().map(|c| c.len()).unwrap_or(0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_counter_empty_string() {
        let counter = TokenCounter::new();
        assert_eq!(counter.count_tokens_openai("", "gpt-4"), 0);
    }

    #[test]
    fn test_token_counter_simple_text() {
        let counter = TokenCounter::new();
        let tokens = counter.count_tokens_openai("Hello world", "gpt-4");
        assert!(tokens > 0);
    }

    #[test]
    fn test_token_counter_caching() {
        let counter = TokenCounter::new();
        let content = "This is a test message";
        let tokens1 = counter.count_tokens_openai(content, "gpt-4");
        let tokens2 = counter.count_tokens_openai(content, "gpt-4");
        assert_eq!(tokens1, tokens2);
        assert_eq!(counter.cache_size(), 1);
    }

    #[test]
    fn test_token_counter_different_models() {
        let counter = TokenCounter::new();
        let content = "Test content";
        let _tokens_gpt4 = counter.count_tokens_openai(content, "gpt-4");
        let _tokens_gpt35 = counter.count_tokens_openai(content, "gpt-3.5-turbo");
        // Both should be cached
        assert_eq!(counter.cache_size(), 2);
    }

    #[test]
    fn test_token_counter_special_characters() {
        let counter = TokenCounter::new();
        let simple = counter.count_tokens_openai("hello", "gpt-4");
        let with_special = counter.count_tokens_openai("hello!!!???", "gpt-4");
        // Extra punctuation should not decrease the estimated token count
        assert!(with_special >= simple);
    }

    #[test]
    fn test_token_counter_clear_cache() {
        let counter = TokenCounter::new();
        counter.count_tokens_openai("test", "gpt-4");
        assert_eq!(counter.cache_size(), 1);
        counter.clear_cache();
        assert_eq!(counter.cache_size(), 0);
    }
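
    #[test]
    fn test_token_counter_character_heuristic() {
        // Illustrative check of the documented heuristic: roughly one token per
        // four characters, with a minimum of one token for non-empty content.
        let counter = TokenCounter::new();
        assert_eq!(counter.count_tokens_openai("Hello world", "gpt-4"), 3); // 11 chars -> ceil(11 / 4)
        assert_eq!(counter.count_tokens_openai("a", "gpt-4"), 1);
    }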
}