// ricecoder_images/token_counting.rs

1//! Token counting for images across different AI providers.
2//!
3//! This module handles token counting for images, which varies significantly
4//! by provider. It integrates with ricecoder-providers TokenCounter for
5//! consistent token usage tracking.
6
7use crate::error::{ImageError, ImageResult};
8use crate::models::ImageMetadata;
9use ricecoder_providers::token_counter::TokenCounter;
10use std::sync::Arc;
11use tracing::{debug, info};
12
/// Image token counting for different providers.
///
/// Different AI providers count image tokens differently:
/// - OpenAI: ~85 tokens base + variable based on resolution
/// - Anthropic: ~1600 tokens per image
/// - Google: ~258 tokens per image
/// - Ollama: ~100 tokens per image (estimate)
pub struct ImageTokenCounter {
    // Shared text token counter from ricecoder-providers; used only for
    // prompt text (image token counts are computed locally per provider).
    token_counter: Arc<TokenCounter>,
}
23
24impl ImageTokenCounter {
25    /// Create a new image token counter.
26    pub fn new() -> Self {
27        Self {
28            token_counter: Arc::new(TokenCounter::new()),
29        }
30    }
31
32    /// Create with an existing token counter.
33    pub fn with_counter(token_counter: Arc<TokenCounter>) -> Self {
34        Self { token_counter }
35    }
36
37    /// Count tokens for a single image based on provider and model.
38    ///
39    /// # Arguments
40    ///
41    /// * `metadata` - Image metadata including dimensions
42    /// * `provider_name` - Name of the provider (openai, anthropic, google, ollama)
43    /// * `model` - Model identifier
44    ///
45    /// # Returns
46    ///
47    /// Number of tokens used for the image
48    pub fn count_image_tokens(
49        &self,
50        metadata: &ImageMetadata,
51        provider_name: &str,
52        model: &str,
53    ) -> ImageResult<usize> {
54        let tokens = match provider_name.to_lowercase().as_str() {
55            "openai" => self.count_openai_tokens(metadata, model),
56            "anthropic" => self.count_anthropic_tokens(metadata, model),
57            "google" => self.count_google_tokens(metadata, model),
58            "ollama" => self.count_ollama_tokens(metadata, model),
59            _ => self.count_generic_tokens(metadata, model),
60        };
61
62        debug!(
63            provider = provider_name,
64            model = model,
65            image_tokens = tokens,
66            image_dimensions = format!("{}x{}", metadata.width, metadata.height),
67            "Counted image tokens"
68        );
69
70        Ok(tokens)
71    }
72
73    /// Count tokens for multiple images.
74    ///
75    /// # Arguments
76    ///
77    /// * `images` - Vector of image metadata
78    /// * `provider_name` - Name of the provider
79    /// * `model` - Model identifier
80    ///
81    /// # Returns
82    ///
83    /// Total tokens for all images
84    pub fn count_multiple_image_tokens(
85        &self,
86        images: &[ImageMetadata],
87        provider_name: &str,
88        model: &str,
89    ) -> ImageResult<usize> {
90        let mut total_tokens = 0;
91
92        for image in images {
93            let tokens = self.count_image_tokens(image, provider_name, model)?;
94            total_tokens += tokens;
95        }
96
97        info!(
98            provider = provider_name,
99            model = model,
100            image_count = images.len(),
101            total_tokens = total_tokens,
102            "Counted tokens for multiple images"
103        );
104
105        Ok(total_tokens)
106    }
107
108    /// Count tokens for OpenAI models.
109    ///
110    /// OpenAI's vision models use the following token counting:
111    /// - Base: ~85 tokens per image
112    /// - Resolution factor: additional tokens based on image dimensions
113    /// - For high-resolution images: ~170 tokens base + resolution factor
114    fn count_openai_tokens(&self, metadata: &ImageMetadata, _model: &str) -> usize {
115        // Determine if this is a high-resolution image
116        let is_high_res = metadata.width > 768 || metadata.height > 768;
117
118        // Base tokens
119        let base_tokens = if is_high_res { 170 } else { 85 };
120
121        // Resolution factor: additional tokens based on image size
122        // Roughly 1 token per 512 pixels
123        let resolution_factor = ((metadata.width as usize * metadata.height as usize) / 512)
124            .min(1000); // Cap at 1000 to avoid excessive token counts
125
126        base_tokens + resolution_factor
127    }
128
129    /// Count tokens for Anthropic models.
130    ///
131    /// Anthropic's vision models use:
132    /// - ~1600 tokens per image (fixed)
133    /// - Additional tokens for image metadata
134    fn count_anthropic_tokens(&self, _metadata: &ImageMetadata, _model: &str) -> usize {
135        // Base tokens for image
136        let base_tokens = 1600;
137
138        // Additional tokens for metadata (format, dimensions)
139        let metadata_tokens = 10;
140
141        base_tokens + metadata_tokens
142    }
143
144    /// Count tokens for Google models.
145    ///
146    /// Google's vision models use:
147    /// - ~258 tokens per image (fixed)
148    /// - Additional tokens based on image complexity
149    fn count_google_tokens(&self, metadata: &ImageMetadata, _model: &str) -> usize {
150        // Base tokens for image
151        let base_tokens = 258;
152
153        // Additional tokens based on resolution
154        // Higher resolution images may require more tokens
155        let resolution_factor = if metadata.width > 1024 || metadata.height > 1024 {
156            50
157        } else {
158            0
159        };
160
161        base_tokens + resolution_factor
162    }
163
164    /// Count tokens for Ollama models.
165    ///
166    /// Ollama's vision models use:
167    /// - ~100 tokens per image (estimate)
168    /// - Additional tokens based on image size
169    fn count_ollama_tokens(&self, metadata: &ImageMetadata, _model: &str) -> usize {
170        // Base tokens for image
171        let base_tokens = 100;
172
173        // Additional tokens based on image size
174        // Roughly 1 token per 10KB
175        let size_factor = (metadata.size_bytes / 10240) as usize;
176
177        base_tokens + size_factor
178    }
179
180    /// Count tokens for generic/unknown providers.
181    ///
182    /// Uses a conservative estimate:
183    /// - ~100 tokens base per image
184    /// - Additional tokens based on resolution
185    fn count_generic_tokens(&self, metadata: &ImageMetadata, _model: &str) -> usize {
186        // Base tokens for image
187        let base_tokens = 100;
188
189        // Resolution factor: 1 token per 1000 pixels
190        let resolution_factor = (metadata.width as usize * metadata.height as usize) / 1000;
191
192        base_tokens + resolution_factor
193    }
194
195    /// Count tokens for image analysis prompt.
196    ///
197    /// This counts tokens for the analysis prompt text that accompanies the image.
198    pub fn count_prompt_tokens(
199        &self,
200        prompt: &str,
201        model: &str,
202    ) -> ImageResult<usize> {
203        self.token_counter
204            .count(prompt, model)
205            .map_err(|e| ImageError::AnalysisFailed(format!("Token counting failed: {}", e)))
206    }
207
208    /// Count total tokens for image analysis (image + prompt).
209    ///
210    /// # Arguments
211    ///
212    /// * `metadata` - Image metadata
213    /// * `prompt` - Analysis prompt text
214    /// * `provider_name` - Provider name
215    /// * `model` - Model identifier
216    ///
217    /// # Returns
218    ///
219    /// Total tokens for image + prompt
220    pub fn count_total_tokens(
221        &self,
222        metadata: &ImageMetadata,
223        prompt: &str,
224        provider_name: &str,
225        model: &str,
226    ) -> ImageResult<usize> {
227        let image_tokens = self.count_image_tokens(metadata, provider_name, model)?;
228        let prompt_tokens = self.count_prompt_tokens(prompt, model)?;
229
230        let total = image_tokens + prompt_tokens;
231
232        debug!(
233            provider = provider_name,
234            model = model,
235            image_tokens = image_tokens,
236            prompt_tokens = prompt_tokens,
237            total_tokens = total,
238            "Counted total tokens for image analysis"
239        );
240
241        Ok(total)
242    }
243
244    /// Get the underlying token counter.
245    pub fn token_counter(&self) -> &TokenCounter {
246        &self.token_counter
247    }
248}
249
250impl Default for ImageTokenCounter {
251    fn default() -> Self {
252        Self::new()
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::formats::ImageFormat;
    use std::path::PathBuf;

    /// Build throwaway PNG metadata at `/test.png` with the given
    /// dimensions and byte size.
    fn create_test_metadata(width: u32, height: u32, size_bytes: u64) -> ImageMetadata {
        ImageMetadata::new(
            PathBuf::from("/test.png"),
            ImageFormat::Png,
            size_bytes,
            width,
            height,
            "hash123".to_string(),
        )
    }

    #[test]
    fn test_image_token_counter_creation() {
        // Constructing the counter must also initialize the inner cache;
        // just touching cache_size() proves the counter is usable.
        let _ = ImageTokenCounter::new().token_counter().cache_size();
    }

    #[test]
    fn test_count_openai_tokens_standard() {
        let meta = create_test_metadata(800, 600, 1024 * 1024);
        let tokens = ImageTokenCounter::new()
            .count_image_tokens(&meta, "openai", "gpt-4-vision")
            .expect("openai counting should succeed");
        // Standard-resolution images must cost at least the 85-token base.
        assert!(tokens >= 85);
    }

    #[test]
    fn test_count_openai_tokens_high_res() {
        let meta = create_test_metadata(1920, 1080, 2 * 1024 * 1024);
        let tokens = ImageTokenCounter::new()
            .count_image_tokens(&meta, "openai", "gpt-4-vision")
            .expect("openai counting should succeed");
        // High-resolution images start from the 170-token base.
        assert!(tokens >= 170);
    }

    #[test]
    fn test_count_anthropic_tokens() {
        let meta = create_test_metadata(800, 600, 1024 * 1024);
        let tokens = ImageTokenCounter::new()
            .count_image_tokens(&meta, "anthropic", "claude-3-vision")
            .expect("anthropic counting should succeed");
        // Anthropic charges a fixed ~1600 plus metadata tokens.
        assert!(tokens >= 1600);
    }

    #[test]
    fn test_count_google_tokens() {
        let meta = create_test_metadata(800, 600, 1024 * 1024);
        let tokens = ImageTokenCounter::new()
            .count_image_tokens(&meta, "google", "gemini-pro-vision")
            .expect("google counting should succeed");
        // Google charges a fixed ~258 base.
        assert!(tokens >= 258);
    }

    #[test]
    fn test_count_ollama_tokens() {
        let meta = create_test_metadata(800, 600, 1024 * 1024);
        let tokens = ImageTokenCounter::new()
            .count_image_tokens(&meta, "ollama", "llava")
            .expect("ollama counting should succeed");
        // Ollama charges ~100 base plus a size factor.
        assert!(tokens >= 100);
    }

    #[test]
    fn test_count_generic_tokens() {
        let meta = create_test_metadata(800, 600, 1024 * 1024);
        let tokens = ImageTokenCounter::new()
            .count_image_tokens(&meta, "unknown", "unknown-model")
            .expect("generic counting should succeed");
        // Unknown providers fall back to the 100-token conservative base.
        assert!(tokens >= 100);
    }

    #[test]
    fn test_count_multiple_image_tokens() {
        let batch = [
            create_test_metadata(800, 600, 1024 * 1024),
            create_test_metadata(1024, 768, 2 * 1024 * 1024),
        ];
        let total = ImageTokenCounter::new()
            .count_multiple_image_tokens(&batch, "openai", "gpt-4-vision")
            .expect("batch counting should succeed");
        // Two real images must accumulate a positive total.
        assert!(total > 0);
    }

    #[test]
    fn test_count_prompt_tokens() {
        let tokens = ImageTokenCounter::new()
            .count_prompt_tokens("Please analyze this image", "gpt-4")
            .expect("prompt counting should succeed");
        // A non-empty prompt must cost at least one token.
        assert!(tokens > 0);
    }

    #[test]
    fn test_count_total_tokens() {
        let meta = create_test_metadata(800, 600, 1024 * 1024);
        let total = ImageTokenCounter::new()
            .count_total_tokens(&meta, "Please analyze this image", "openai", "gpt-4-vision")
            .expect("total counting should succeed");
        // Image tokens plus prompt tokens must be positive.
        assert!(total > 0);
    }

    #[test]
    fn test_count_tokens_different_providers() {
        let counter = ImageTokenCounter::new();
        let meta = create_test_metadata(800, 600, 1024 * 1024);

        let per_provider = |provider: &str, model: &str| {
            counter
                .count_image_tokens(&meta, provider, model)
                .expect("counting should succeed")
        };
        let openai = per_provider("openai", "gpt-4-vision");
        let anthropic = per_provider("anthropic", "claude-3-vision");
        let google = per_provider("google", "gemini-pro-vision");

        // Each provider formula should produce a distinct count for the
        // same source image.
        assert_ne!(openai, anthropic);
        assert_ne!(anthropic, google);
    }

    #[test]
    fn test_count_tokens_empty_prompt() {
        let tokens = ImageTokenCounter::new()
            .count_prompt_tokens("", "gpt-4")
            .expect("empty prompt counting should succeed");
        // An empty prompt costs nothing.
        assert_eq!(tokens, 0);
    }
}