ricecoder_providers/token_counter.rs

use crate::error::ProviderError;
use std::collections::HashMap;
use std::sync::Mutex;

/// Abstraction over token counting so callers can depend on the interface
/// rather than on the concrete `TokenCounter` below.
pub trait TokenCounterTrait: Send + Sync {
    fn count_tokens(&self, content: &str, model: &str) -> Result<usize, ProviderError>;

    fn clear_cache(&self);

    fn cache_size(&self) -> usize;
}

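// Illustrative sketch, not part of the original file: a caller that depends
// only on the trait, so any conforming counter can be swapped in. The name
// `fits_budget` and the `limit` parameter are hypothetical.
#[allow(dead_code)]
fn fits_budget(
    counter: &dyn TokenCounterTrait,
    prompt: &str,
    model: &str,
    limit: usize,
) -> Result<bool, ProviderError> {
    // True when the estimated token count stays within the caller's budget.
    Ok(counter.count_tokens(prompt, model)? <= limit)
}
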
/// Token counter backed by an in-memory cache of previous estimates,
/// keyed by `model:content`.
pub struct TokenCounter {
    cache: Mutex<HashMap<String, usize>>,
}

impl TokenCounter {
    pub fn new() -> Self {
        Self {
            cache: Mutex::new(HashMap::new()),
        }
    }

    /// Counts tokens for OpenAI-style models, caching the result per
    /// `(model, content)` pair so repeated lookups are cheap.
    pub fn count_tokens_openai(&self, content: &str, model: &str) -> usize {
        let cache_key = format!("{}:{}", model, content);

        // Fast path: return the cached count if this input was seen before.
        if let Ok(cache) = self.cache.lock() {
            if let Some(&count) = cache.get(&cache_key) {
                return count;
            }
        }

        let estimated = self.estimate_tokens(content, model);

        // Store the estimate; a poisoned lock simply skips caching.
        if let Ok(mut cache) = self.cache.lock() {
            cache.insert(cache_key, estimated);
        }

        estimated
    }

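    // Hypothetical convenience, not in the original file: count several
    // messages in one pass, reusing the per-message cache above.
    #[allow(dead_code)]
    pub fn count_tokens_batch(&self, contents: &[&str], model: &str) -> usize {
        contents
            .iter()
            .map(|content| self.count_tokens_openai(content, model))
            .sum()
    }
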
    /// Fallible wrapper around `count_tokens_openai`, used by the trait impl.
    pub fn count(&self, content: &str, model: &str) -> Result<usize, ProviderError> {
        Ok(self.count_tokens_openai(content, model))
    }

    /// Rough heuristic: about four bytes per token, rounded up, with a
    /// minimum of one token for non-empty input. For example, "Hello world"
    /// is 11 bytes, so the estimate is ceil(11 / 4) = 3 tokens. The model is
    /// currently ignored.
    fn estimate_tokens(&self, content: &str, _model: &str) -> usize {
        if content.is_empty() {
            return 0;
        }

        let estimated = (content.len() as f64 / 4.0).ceil() as usize;

        std::cmp::max(1, estimated)
    }

    pub fn clear_cache(&self) {
        if let Ok(mut cache) = self.cache.lock() {
            cache.clear();
        }
    }

    pub fn cache_size(&self) -> usize {
        self.cache.lock().map(|c| c.len()).unwrap_or(0)
    }
}

impl Default for TokenCounter {
    fn default() -> Self {
        Self::new()
    }
}

impl TokenCounterTrait for TokenCounter {
    fn count_tokens(&self, content: &str, model: &str) -> Result<usize, ProviderError> {
        self.count(content, model)
    }

    // Delegate to the inherent methods rather than duplicating their bodies.
    fn clear_cache(&self) {
        TokenCounter::clear_cache(self)
    }

    fn cache_size(&self) -> usize {
        TokenCounter::cache_size(self)
    }
}

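// Illustrative sketch, not part of the original file: the `Send + Sync`
// bounds on the trait allow a counter to be shared across threads behind
// an `Arc`. `shared_counter` is a hypothetical helper name.
#[allow(dead_code)]
fn shared_counter() -> std::sync::Arc<dyn TokenCounterTrait> {
    std::sync::Arc::new(TokenCounter::new())
}
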
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_counter_empty_string() {
        let counter = TokenCounter::new();
        assert_eq!(counter.count_tokens_openai("", "gpt-4"), 0);
    }

    #[test]
    fn test_token_counter_simple_text() {
        let counter = TokenCounter::new();
        let tokens = counter.count_tokens_openai("Hello world", "gpt-4");
        assert!(tokens > 0);
    }

    #[test]
    fn test_token_counter_caching() {
        let counter = TokenCounter::new();
        let content = "This is a test message";
        let tokens1 = counter.count_tokens_openai(content, "gpt-4");
        let tokens2 = counter.count_tokens_openai(content, "gpt-4");
        assert_eq!(tokens1, tokens2);
        assert_eq!(counter.cache_size(), 1);
    }

    #[test]
    fn test_token_counter_different_models() {
        let counter = TokenCounter::new();
        let content = "Test content";
        let _tokens_gpt4 = counter.count_tokens_openai(content, "gpt-4");
        let _tokens_gpt35 = counter.count_tokens_openai(content, "gpt-3.5-turbo");
        assert_eq!(counter.cache_size(), 2);
    }

    #[test]
    fn test_token_counter_special_characters() {
        let counter = TokenCounter::new();
        let simple = counter.count_tokens_openai("hello", "gpt-4");
        let with_special = counter.count_tokens_openai("hello!!!???", "gpt-4");
        assert!(with_special >= simple);
    }

    #[test]
    fn test_token_counter_clear_cache() {
        let counter = TokenCounter::new();
        counter.count_tokens_openai("test", "gpt-4");
        assert_eq!(counter.cache_size(), 1);
        counter.clear_cache();
        assert_eq!(counter.cache_size(), 0);
    }
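
    // Added sketch: pins down the ~4 bytes/token heuristic documented on
    // `estimate_tokens`. Assumes the ceil(len / 4) rounding stays in place.
    #[test]
    fn test_token_counter_four_byte_heuristic() {
        let counter = TokenCounter::new();
        // "abcdefgh" is 8 bytes, so the estimate is ceil(8 / 4) = 2 tokens.
        assert_eq!(counter.count_tokens_openai("abcdefgh", "gpt-4"), 2);
    }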
}