1use crate::config::ImageConfig;
13use crate::error::{ImageError, ImageResult};
14use crate::models::{ImageAnalysisResult, ImageMetadata};
15use ricecoder_providers::models::{ChatRequest, Message};
16use ricecoder_providers::provider::Provider;
17use ricecoder_providers::token_counter::TokenCounter;
18use std::sync::Arc;
19use std::time::Duration;
20use tokio::time::sleep;
21use tracing::{debug, error, info, warn};
22
/// Delay in milliseconds before the first automatic retry; the retry loop
/// doubles it after every failed attempt.
const INITIAL_RETRY_DELAY_MS: u64 = 100;

/// Total number of attempts (the initial call included) made by the automatic
/// retry loop before the analysis is reported as failed.
const MAX_RETRIES: u32 = 3;
28
29#[derive(Debug, Clone)]
46pub struct AnalysisRetryContext {
47 pub metadata: ImageMetadata,
49 pub image_data: Vec<u8>,
51 pub last_error: Option<String>,
53 pub retry_attempts: u32,
55}
56
57impl AnalysisRetryContext {
58 pub fn new(metadata: ImageMetadata, image_data: Vec<u8>) -> Self {
60 Self {
61 metadata,
62 image_data,
63 last_error: None,
64 retry_attempts: 0,
65 }
66 }
67
68 pub fn record_failure(&mut self, error: String) {
70 self.last_error = Some(error);
71 self.retry_attempts += 1;
72 }
73
74 pub fn can_retry(&self) -> bool {
76 self.retry_attempts < 5 }
78
79 pub fn get_error_message(&self) -> String {
81 match &self.last_error {
82 Some(err) => {
83 if self.can_retry() {
84 format!(
85 "Analysis failed: {}. You can retry by clicking the retry button.",
86 err
87 )
88 } else {
89 format!(
90 "Analysis failed after {} attempts: {}. Please try again later.",
91 self.retry_attempts, err
92 )
93 }
94 }
95 None => "No error recorded".to_string(),
96 }
97 }
98}
99
100pub struct ImageAnalyzer {
102 config: ImageConfig,
103 token_counter: Arc<TokenCounter>,
104}
105
106impl ImageAnalyzer {
107 pub fn new() -> ImageResult<Self> {
109 let config = ImageConfig::load_with_hierarchy()?;
110 Ok(Self {
111 config,
112 token_counter: Arc::new(TokenCounter::new()),
113 })
114 }
115
116 pub fn with_config(config: ImageConfig) -> Self {
118 Self {
119 config,
120 token_counter: Arc::new(TokenCounter::new()),
121 }
122 }
123
124 pub async fn analyze(
140 &self,
141 metadata: &ImageMetadata,
142 provider: &dyn Provider,
143 image_data: &[u8],
144 ) -> ImageResult<ImageAnalysisResult> {
145 let optimized_data = if metadata.size_mb() > self.config.analysis.max_image_size_mb as f64 {
147 info!(
148 size_mb = metadata.size_mb(),
149 max_mb = self.config.analysis.max_image_size_mb,
150 "Optimizing large image before analysis"
151 );
152 self.optimize_image(image_data).await?
153 } else {
154 image_data.to_vec()
155 };
156
157 self.analyze_with_retry(metadata, provider, &optimized_data)
159 .await
160 }
161
162 pub async fn analyze_multiple(
173 &self,
174 images: Vec<(ImageMetadata, Vec<u8>)>,
175 provider: &dyn Provider,
176 ) -> Vec<ImageResult<ImageAnalysisResult>> {
177 let mut results = Vec::new();
178
179 for (metadata, image_data) in images {
180 let result = self.analyze(&metadata, provider, &image_data).await;
181 results.push(result);
182 }
183
184 results
185 }
186
187 async fn analyze_with_retry(
189 &self,
190 metadata: &ImageMetadata,
191 provider: &dyn Provider,
192 image_data: &[u8],
193 ) -> ImageResult<ImageAnalysisResult> {
194 let mut retry_count = 0;
195 let mut delay_ms = INITIAL_RETRY_DELAY_MS;
196
197 loop {
198 match self.perform_analysis(metadata, provider, image_data).await {
199 Ok(result) => {
200 info!(
201 image_hash = %metadata.hash,
202 provider = provider.name(),
203 tokens_used = result.tokens_used,
204 "Image analysis completed successfully"
205 );
206 return Ok(result);
207 }
208 Err(err) => {
209 retry_count += 1;
210
211 if retry_count >= MAX_RETRIES {
212 error!(
213 image_hash = %metadata.hash,
214 provider = provider.name(),
215 error = %err,
216 attempts = retry_count,
217 "Image analysis failed after retries"
218 );
219 return Err(ImageError::AnalysisFailed(format!(
220 "Analysis failed after {} attempts: {}. Please try again.",
221 retry_count, err
222 )));
223 }
224
225 warn!(
226 image_hash = %metadata.hash,
227 provider = provider.name(),
228 error = %err,
229 attempt = retry_count,
230 retry_delay_ms = delay_ms,
231 "Image analysis failed, retrying..."
232 );
233
234 sleep(Duration::from_millis(delay_ms)).await;
236 delay_ms *= 2; }
238 }
239 }
240 }
241
242 async fn perform_analysis(
244 &self,
245 metadata: &ImageMetadata,
246 provider: &dyn Provider,
247 image_data: &[u8],
248 ) -> ImageResult<ImageAnalysisResult> {
249 let prompt = format!(
251 "Please analyze this image. Provide a detailed description of what you see, \
252 including any text, objects, people, and overall context.\n\n\
253 Image format: {}\n\
254 Image dimensions: {}x{} pixels\n\
255 Image size: {:.1} MB",
256 metadata.format_str(),
257 metadata.width,
258 metadata.height,
259 metadata.size_mb()
260 );
261
262 let image_base64 = base64_encode(image_data);
264
265 let request = ChatRequest {
269 model: provider.models().first().ok_or_else(|| {
270 ImageError::AnalysisFailed("Provider has no available models".to_string())
271 })?.id.clone(),
272 messages: vec![Message {
273 role: "user".to_string(),
274 content: format!(
275 "{}\n\n[Image: format={}, size={} bytes, base64={}...]",
276 prompt,
277 metadata.format_str(),
278 image_data.len(),
279 &image_base64[..std::cmp::min(50, image_base64.len())]
280 ),
281 }],
282 temperature: Some(0.7),
283 max_tokens: Some(1000),
284 stream: false,
285 };
286
287 debug!(
289 provider = provider.name(),
290 model = &request.model,
291 image_size = image_data.len(),
292 image_format = metadata.format_str(),
293 "Sending image to provider for analysis"
294 );
295
296 let response = tokio::time::timeout(
297 Duration::from_secs(self.config.analysis.timeout_seconds),
298 provider.chat(request),
299 )
300 .await
301 .map_err(|_| {
302 ImageError::AnalysisFailed(format!(
303 "Analysis timeout after {} seconds",
304 self.config.analysis.timeout_seconds
305 ))
306 })?
307 .map_err(|e| ImageError::AnalysisFailed(e.to_string()))?;
308
309 let tokens_used = self
311 .count_image_tokens(metadata, &response.model)
312 .unwrap_or(0) as u32;
313
314 let result = ImageAnalysisResult::new(
316 metadata.hash.clone(),
317 response.content,
318 provider.name().to_string(),
319 tokens_used,
320 );
321
322 Ok(result)
323 }
324
325 pub fn count_image_tokens(&self, metadata: &ImageMetadata, model: &str) -> ImageResult<usize> {
333 let provider_name = if model.contains("gpt") {
335 "openai"
336 } else if model.contains("claude") {
337 "anthropic"
338 } else if model.contains("gemini") {
339 "google"
340 } else {
341 "ollama"
342 };
343
344 let base_tokens = match provider_name {
346 "openai" => {
347 let resolution_factor = (metadata.width as usize * metadata.height as usize) / 10000;
349 85 + resolution_factor
350 }
351 "anthropic" => {
352 1600
354 }
355 "google" => {
356 258
358 }
359 _ => {
360 100
362 }
363 };
364
365 debug!(
366 provider = provider_name,
367 model = model,
368 image_tokens = base_tokens,
369 "Counted image tokens"
370 );
371
372 Ok(base_tokens)
373 }
374
375 async fn optimize_image(&self, image_data: &[u8]) -> ImageResult<Vec<u8>> {
380 if image_data.is_empty() {
383 return Err(ImageError::InvalidFile(
384 "Image data is empty".to_string(),
385 ));
386 }
387
388 debug!(
389 original_size = image_data.len(),
390 "Image optimization placeholder (would resize/compress in production)"
391 );
392
393 Ok(image_data.to_vec())
396 }
397
398 pub fn config(&self) -> &ImageConfig {
400 &self.config
401 }
402
403 pub fn token_counter(&self) -> &TokenCounter {
405 &self.token_counter
406 }
407
408 pub async fn retry_analysis(
422 &self,
423 mut context: AnalysisRetryContext,
424 provider: &dyn Provider,
425 ) -> ImageResult<ImageAnalysisResult> {
426 if !context.can_retry() {
427 return Err(ImageError::AnalysisFailed(
428 "Maximum retry attempts exceeded. Please try again later.".to_string(),
429 ));
430 }
431
432 info!(
433 image_hash = %context.metadata.hash,
434 provider = provider.name(),
435 attempt = context.retry_attempts + 1,
436 "Retrying image analysis"
437 );
438
439 match self.analyze(&context.metadata, provider, &context.image_data).await {
440 Ok(result) => {
441 info!(
442 image_hash = %context.metadata.hash,
443 provider = provider.name(),
444 "Image analysis succeeded on retry"
445 );
446 Ok(result)
447 }
448 Err(err) => {
449 context.record_failure(err.to_string());
450 Err(ImageError::AnalysisFailed(context.get_error_message()))
451 }
452 }
453 }
454}
455
456impl Default for ImageAnalyzer {
457 fn default() -> Self {
458 Self::new().unwrap_or_else(|_| {
459 Self::with_config(ImageConfig::default())
460 })
461 }
462}
463
/// Encodes `data` as standard base64 (RFC 4648 alphabet, with `=` padding).
///
/// Every 3-byte input group becomes 4 output characters. A trailing group of
/// 1 or 2 bytes is zero-extended and padded with `==` or `=` respectively.
/// Empty input yields an empty string.
fn base64_encode(data: &[u8]) -> String {
    const BASE64_CHARS: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    // The output length is known up front: 4 chars per started 3-byte group.
    let mut result = String::with_capacity((data.len() + 2) / 3 * 4);

    for chunk in data.chunks(3) {
        // `chunks` never yields an empty slice, so index 0 always exists.
        let b1 = chunk[0];
        let b2 = chunk.get(1).copied().unwrap_or(0);
        let b3 = chunk.get(2).copied().unwrap_or(0);

        let n = (u32::from(b1) << 16) | (u32::from(b2) << 8) | u32::from(b3);
        let sextet = |shift: u32| BASE64_CHARS[((n >> shift) & 0x3f) as usize] as char;

        result.push(sextet(18));
        result.push(sextet(12));
        result.push(if chunk.len() > 1 { sextet(6) } else { '=' });
        result.push(if chunk.len() > 2 { sextet(0) } else { '=' });
    }

    result
}
494
#[cfg(test)]
mod tests {
    use super::*;

    /// Metadata fixture shared by the token-count and retry-context tests.
    fn sample_metadata() -> ImageMetadata {
        ImageMetadata::new(
            std::path::PathBuf::from("/test.png"),
            crate::formats::ImageFormat::Png,
            1024,
            800,
            600,
            "hash123".to_string(),
        )
    }

    /// Fresh retry context wrapping `sample_metadata()` and a tiny payload.
    fn sample_context() -> AnalysisRetryContext {
        AnalysisRetryContext::new(sample_metadata(), vec![1, 2, 3, 4, 5])
    }

    #[test]
    fn test_analyzer_creation() {
        let analyzer = ImageAnalyzer::new()
            .unwrap_or_else(|_| ImageAnalyzer::with_config(ImageConfig::default()));
        assert_eq!(analyzer.config().analysis.timeout_seconds, 10);
    }

    #[test]
    fn test_analyzer_with_config() {
        let analyzer = ImageAnalyzer::with_config(ImageConfig::default());
        assert_eq!(analyzer.config().analysis.max_image_size_mb, 10);
    }

    #[test]
    fn test_analyzer_default() {
        let _analyzer = ImageAnalyzer::default();
    }

    #[tokio::test]
    async fn test_optimize_image_empty() {
        let analyzer = ImageAnalyzer::default();
        let result = analyzer.optimize_image(&[]).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_optimize_image_valid() {
        let analyzer = ImageAnalyzer::default();
        let data = vec![1, 2, 3, 4, 5];
        let result = analyzer.optimize_image(&data).await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_base64_encode() {
        // RFC 4648 section 10 test vectors cover every padding case.
        assert_eq!(base64_encode(&[]), "");
        assert_eq!(base64_encode(b"f"), "Zg==");
        assert_eq!(base64_encode(b"fo"), "Zm8=");
        assert_eq!(base64_encode(b"foo"), "Zm9v");
        assert_eq!(base64_encode(b"foob"), "Zm9vYg==");
        assert_eq!(base64_encode(b"fooba"), "Zm9vYmE=");
        assert_eq!(base64_encode(b"foobar"), "Zm9vYmFy");

        assert_eq!(base64_encode(b"Hello"), "SGVsbG8=");
        assert_eq!(base64_encode(&[65]), "QQ==");

        // High bytes exercise the '+' and '/' entries of the alphabet.
        assert_eq!(base64_encode(&[0xff, 0xfe, 0xfd]), "//79");
        assert_eq!(base64_encode(&[0xfb, 0xff]), "+/8=");
    }

    #[test]
    fn test_count_image_tokens_openai() {
        let analyzer = ImageAnalyzer::default();
        let tokens = analyzer
            .count_image_tokens(&sample_metadata(), "gpt-4-vision")
            .unwrap();
        assert!(tokens >= 85);
    }

    #[test]
    fn test_count_image_tokens_anthropic() {
        let analyzer = ImageAnalyzer::default();
        let tokens = analyzer
            .count_image_tokens(&sample_metadata(), "claude-3-vision")
            .unwrap();
        assert_eq!(tokens, 1600);
    }

    #[test]
    fn test_count_image_tokens_google() {
        let analyzer = ImageAnalyzer::default();
        let tokens = analyzer
            .count_image_tokens(&sample_metadata(), "gemini-pro-vision")
            .unwrap();
        assert_eq!(tokens, 258);
    }

    #[test]
    fn test_count_image_tokens_ollama() {
        let analyzer = ImageAnalyzer::default();
        let tokens = analyzer
            .count_image_tokens(&sample_metadata(), "llava")
            .unwrap();
        assert_eq!(tokens, 100);
    }

    #[test]
    fn test_retry_context_creation() {
        let context = sample_context();
        assert_eq!(context.retry_attempts, 0);
        assert!(context.can_retry());
        assert!(context.last_error.is_none());
    }

    #[test]
    fn test_retry_context_record_failure() {
        let mut context = sample_context();
        context.record_failure("Test error".to_string());

        assert_eq!(context.retry_attempts, 1);
        assert!(context.can_retry());
        assert_eq!(context.last_error.as_deref(), Some("Test error"));
    }

    #[test]
    fn test_retry_context_max_retries() {
        let mut context = sample_context();
        for i in 0..5 {
            context.record_failure(format!("Error {}", i));
        }

        assert_eq!(context.retry_attempts, 5);
        assert!(!context.can_retry());
    }

    #[test]
    fn test_retry_context_error_message() {
        let mut context = sample_context();
        context.record_failure("Provider timeout".to_string());

        let msg = context.get_error_message();
        assert!(msg.contains("Provider timeout"));
        assert!(msg.contains("retry"));
    }

    #[test]
    fn test_retry_context_error_message_max_retries() {
        let mut context = sample_context();
        for i in 0..5 {
            context.record_failure(format!("Error {}", i));
        }

        let msg = context.get_error_message();
        assert!(msg.contains("5 attempts"));
        assert!(msg.contains("try again later"));
    }
}