1use crate::error::{ImageError, ImageResult};
8use crate::models::ImageMetadata;
9use ricecoder_providers::token_counter::TokenCounter;
10use std::sync::Arc;
11use tracing::{debug, info};
12
/// Estimates token usage for images (per provider) and for text prompts.
///
/// Image estimates are computed locally from [`ImageMetadata`] fields; prompt
/// text is delegated to the wrapped [`TokenCounter`].
pub struct ImageTokenCounter {
    /// Text tokenizer used for prompt counting. Held in an `Arc` so an existing
    /// instance can be injected via `ImageTokenCounter::with_counter`.
    token_counter: Arc<TokenCounter>,
}
23
24impl ImageTokenCounter {
25 pub fn new() -> Self {
27 Self {
28 token_counter: Arc::new(TokenCounter::new()),
29 }
30 }
31
32 pub fn with_counter(token_counter: Arc<TokenCounter>) -> Self {
34 Self { token_counter }
35 }
36
37 pub fn count_image_tokens(
49 &self,
50 metadata: &ImageMetadata,
51 provider_name: &str,
52 model: &str,
53 ) -> ImageResult<usize> {
54 let tokens = match provider_name.to_lowercase().as_str() {
55 "openai" => self.count_openai_tokens(metadata, model),
56 "anthropic" => self.count_anthropic_tokens(metadata, model),
57 "google" => self.count_google_tokens(metadata, model),
58 "ollama" => self.count_ollama_tokens(metadata, model),
59 _ => self.count_generic_tokens(metadata, model),
60 };
61
62 debug!(
63 provider = provider_name,
64 model = model,
65 image_tokens = tokens,
66 image_dimensions = format!("{}x{}", metadata.width, metadata.height),
67 "Counted image tokens"
68 );
69
70 Ok(tokens)
71 }
72
73 pub fn count_multiple_image_tokens(
85 &self,
86 images: &[ImageMetadata],
87 provider_name: &str,
88 model: &str,
89 ) -> ImageResult<usize> {
90 let mut total_tokens = 0;
91
92 for image in images {
93 let tokens = self.count_image_tokens(image, provider_name, model)?;
94 total_tokens += tokens;
95 }
96
97 info!(
98 provider = provider_name,
99 model = model,
100 image_count = images.len(),
101 total_tokens = total_tokens,
102 "Counted tokens for multiple images"
103 );
104
105 Ok(total_tokens)
106 }
107
108 fn count_openai_tokens(&self, metadata: &ImageMetadata, _model: &str) -> usize {
115 let is_high_res = metadata.width > 768 || metadata.height > 768;
117
118 let base_tokens = if is_high_res { 170 } else { 85 };
120
121 let resolution_factor = ((metadata.width as usize * metadata.height as usize) / 512)
124 .min(1000); base_tokens + resolution_factor
127 }
128
129 fn count_anthropic_tokens(&self, _metadata: &ImageMetadata, _model: &str) -> usize {
135 let base_tokens = 1600;
137
138 let metadata_tokens = 10;
140
141 base_tokens + metadata_tokens
142 }
143
144 fn count_google_tokens(&self, metadata: &ImageMetadata, _model: &str) -> usize {
150 let base_tokens = 258;
152
153 let resolution_factor = if metadata.width > 1024 || metadata.height > 1024 {
156 50
157 } else {
158 0
159 };
160
161 base_tokens + resolution_factor
162 }
163
164 fn count_ollama_tokens(&self, metadata: &ImageMetadata, _model: &str) -> usize {
170 let base_tokens = 100;
172
173 let size_factor = (metadata.size_bytes / 10240) as usize;
176
177 base_tokens + size_factor
178 }
179
180 fn count_generic_tokens(&self, metadata: &ImageMetadata, _model: &str) -> usize {
186 let base_tokens = 100;
188
189 let resolution_factor = (metadata.width as usize * metadata.height as usize) / 1000;
191
192 base_tokens + resolution_factor
193 }
194
195 pub fn count_prompt_tokens(
199 &self,
200 prompt: &str,
201 model: &str,
202 ) -> ImageResult<usize> {
203 self.token_counter
204 .count(prompt, model)
205 .map_err(|e| ImageError::AnalysisFailed(format!("Token counting failed: {}", e)))
206 }
207
208 pub fn count_total_tokens(
221 &self,
222 metadata: &ImageMetadata,
223 prompt: &str,
224 provider_name: &str,
225 model: &str,
226 ) -> ImageResult<usize> {
227 let image_tokens = self.count_image_tokens(metadata, provider_name, model)?;
228 let prompt_tokens = self.count_prompt_tokens(prompt, model)?;
229
230 let total = image_tokens + prompt_tokens;
231
232 debug!(
233 provider = provider_name,
234 model = model,
235 image_tokens = image_tokens,
236 prompt_tokens = prompt_tokens,
237 total_tokens = total,
238 "Counted total tokens for image analysis"
239 );
240
241 Ok(total)
242 }
243
244 pub fn token_counter(&self) -> &TokenCounter {
246 &self.token_counter
247 }
248}
249
250impl Default for ImageTokenCounter {
251 fn default() -> Self {
252 Self::new()
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::formats::ImageFormat;
    use std::path::PathBuf;

    /// Builds PNG metadata with the given dimensions and file size.
    fn create_test_metadata(width: u32, height: u32, size_bytes: u64) -> ImageMetadata {
        let path = PathBuf::from("/test.png");
        let hash = String::from("hash123");
        ImageMetadata::new(path, ImageFormat::Png, size_bytes, width, height, hash)
    }

    /// Counts tokens for the shared 800x600, 1 MiB fixture image.
    fn standard_image_tokens(counter: &ImageTokenCounter, provider: &str, model: &str) -> usize {
        let fixture = create_test_metadata(800, 600, 1024 * 1024);
        counter
            .count_image_tokens(&fixture, provider, model)
            .unwrap()
    }

    #[test]
    fn test_image_token_counter_creation() {
        // Smoke test: the wrapped text counter is reachable after construction.
        let _ = ImageTokenCounter::new().token_counter().cache_size();
    }

    #[test]
    fn test_count_openai_tokens_standard() {
        let counter = ImageTokenCounter::new();
        assert!(standard_image_tokens(&counter, "openai", "gpt-4-vision") >= 85);
    }

    #[test]
    fn test_count_openai_tokens_high_res() {
        let counter = ImageTokenCounter::new();
        let full_hd = create_test_metadata(1920, 1080, 2 * 1024 * 1024);

        let estimate = counter
            .count_image_tokens(&full_hd, "openai", "gpt-4-vision")
            .unwrap();

        assert!(estimate >= 170);
    }

    #[test]
    fn test_count_anthropic_tokens() {
        let counter = ImageTokenCounter::new();
        assert!(standard_image_tokens(&counter, "anthropic", "claude-3-vision") >= 1600);
    }

    #[test]
    fn test_count_google_tokens() {
        let counter = ImageTokenCounter::new();
        assert!(standard_image_tokens(&counter, "google", "gemini-pro-vision") >= 258);
    }

    #[test]
    fn test_count_ollama_tokens() {
        let counter = ImageTokenCounter::new();
        assert!(standard_image_tokens(&counter, "ollama", "llava") >= 100);
    }

    #[test]
    fn test_count_generic_tokens() {
        let counter = ImageTokenCounter::new();
        assert!(standard_image_tokens(&counter, "unknown", "unknown-model") >= 100);
    }

    #[test]
    fn test_count_multiple_image_tokens() {
        let counter = ImageTokenCounter::new();
        let batch = [
            create_test_metadata(800, 600, 1024 * 1024),
            create_test_metadata(1024, 768, 2 * 1024 * 1024),
        ];

        let combined = counter
            .count_multiple_image_tokens(&batch, "openai", "gpt-4-vision")
            .unwrap();

        assert!(combined > 0);
    }

    #[test]
    fn test_count_prompt_tokens() {
        let counter = ImageTokenCounter::new();
        let count = counter
            .count_prompt_tokens("Please analyze this image", "gpt-4")
            .unwrap();
        assert!(count > 0);
    }

    #[test]
    fn test_count_total_tokens() {
        let counter = ImageTokenCounter::new();
        let fixture = create_test_metadata(800, 600, 1024 * 1024);

        let grand_total = counter
            .count_total_tokens(&fixture, "Please analyze this image", "openai", "gpt-4-vision")
            .unwrap();

        assert!(grand_total > 0);
    }

    #[test]
    fn test_count_tokens_different_providers() {
        let counter = ImageTokenCounter::new();

        let [openai, anthropic, google] = [
            ("openai", "gpt-4-vision"),
            ("anthropic", "claude-3-vision"),
            ("google", "gemini-pro-vision"),
        ]
        .map(|(provider, model)| standard_image_tokens(&counter, provider, model));

        assert_ne!(openai, anthropic);
        assert_ne!(anthropic, google);
    }

    #[test]
    fn test_count_tokens_empty_prompt() {
        let counter = ImageTokenCounter::new();
        assert_eq!(counter.count_prompt_tokens("", "gpt-4").unwrap(), 0);
    }
}