oxify_connect_vision/
batch.rs

//! Batch processing API for OCR operations.
//!
//! This module provides efficient batch processing of multiple images
//! with parallel execution, progress tracking, and result aggregation.
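//!
//! # Example
//!
//! A minimal usage sketch (marked `ignore`; `provider` is assumed to be an
//! already-constructed `Arc<dyn VisionProvider>`):
//!
//! ```ignore
//! use std::sync::Arc;
//!
//! // `provider` stands in for any `Arc<dyn VisionProvider>` you have built.
//! let processor = BatchProcessor::new(BatchConfig::with_progress());
//! let images: Vec<Vec<u8>> = vec![std::fs::read("page1.png")?, std::fs::read("page2.png")?];
//!
//! let batch = processor.process_batch(provider, images).await?;
//! println!("{}", batch.summary());
//! ```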

use crate::errors::{Result, VisionError};
use crate::providers::VisionProvider;
use crate::types::OcrResult;
use futures::stream::{self, StreamExt};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;

/// Configuration for batch processing.
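///
/// # Example
///
/// A small sketch of the builder-style setup (values are illustrative):
///
/// ```ignore
/// // Low concurrency with strict error handling and progress reporting.
/// let config = BatchConfig::default()
///     .with_max_concurrency(2)
///     .with_continue_on_error(false)
///     .with_report_progress(true);
/// assert_eq!(config.max_concurrency, 2);
/// ```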
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Maximum number of concurrent operations
    pub max_concurrency: usize,
    /// Continue processing on error (don't stop entire batch)
    pub continue_on_error: bool,
    /// Enable progress reporting
    pub report_progress: bool,
}

impl Default for BatchConfig {
    fn default() -> Self {
        Self {
            max_concurrency: num_cpus::get(),
            continue_on_error: true,
            report_progress: false,
        }
    }
}

impl BatchConfig {
    /// Create a configuration optimized for fast processing.
    ///
    /// Uses twice the number of available CPUs to keep the work queue saturated.
    pub fn fast() -> Self {
        Self {
            max_concurrency: num_cpus::get() * 2,
            continue_on_error: true,
            report_progress: false,
        }
    }

    /// Create a configuration optimized for GPU processing.
    ///
    /// Uses lower concurrency to avoid GPU memory exhaustion.
    pub fn gpu() -> Self {
        Self {
            max_concurrency: 4,
            continue_on_error: true,
            report_progress: true,
        }
    }

    /// Create a configuration with progress reporting enabled.
    pub fn with_progress() -> Self {
        Self {
            max_concurrency: num_cpus::get(),
            continue_on_error: true,
            report_progress: true,
        }
    }

    /// Set the maximum concurrency (clamped to at least 1).
    pub fn with_max_concurrency(mut self, max: usize) -> Self {
        self.max_concurrency = max.max(1);
        self
    }

    /// Set whether to continue on error.
    pub fn with_continue_on_error(mut self, continue_on_error: bool) -> Self {
        self.continue_on_error = continue_on_error;
        self
    }

    /// Set whether to report progress.
    pub fn with_report_progress(mut self, report: bool) -> Self {
        self.report_progress = report;
        self
    }
}

/// Progress information for batch processing.
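///
/// # Example
///
/// A quick sketch of reading progress values:
///
/// ```ignore
/// let mut progress = BatchProgress::new(4);
/// progress.completed = 2;
/// progress.succeeded = 2;
/// assert_eq!(progress.percentage(), 0.5);
/// println!("{}", progress.format()); // "2/4 (50.0%) - 2 succeeded, 0 failed"
/// ```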
#[derive(Debug, Clone)]
pub struct BatchProgress {
    /// Total number of items to process
    pub total: usize,
    /// Number of items completed
    pub completed: usize,
    /// Number of items that succeeded
    pub succeeded: usize,
    /// Number of items that failed
    pub failed: usize,
}

impl BatchProgress {
    /// Create a new progress tracker.
    pub fn new(total: usize) -> Self {
        Self {
            total,
            completed: 0,
            succeeded: 0,
            failed: 0,
        }
    }

    /// Get the completion percentage (0.0 to 1.0).
    ///
    /// An empty batch (total of zero) reports 1.0.
    pub fn percentage(&self) -> f32 {
        if self.total == 0 {
            1.0
        } else {
            self.completed as f32 / self.total as f32
        }
    }

    /// Check if processing is complete.
    pub fn is_complete(&self) -> bool {
        self.completed >= self.total
    }

    /// Get a formatted progress string.
    pub fn format(&self) -> String {
        format!(
            "{}/{} ({:.1}%) - {} succeeded, {} failed",
            self.completed,
            self.total,
            self.percentage() * 100.0,
            self.succeeded,
            self.failed
        )
    }
}

/// Result of a single batch item.
#[derive(Debug, Clone)]
pub struct BatchItemResult {
    /// Index of the item in the batch
    pub index: usize,
    /// OCR result (if successful)
    pub result: Option<OcrResult>,
    /// Error message (if failed)
    pub error: Option<String>,
}

impl BatchItemResult {
    /// Check if this item succeeded.
    pub fn is_success(&self) -> bool {
        self.result.is_some()
    }

    /// Check if this item failed.
    pub fn is_error(&self) -> bool {
        self.error.is_some()
    }
}

/// Result of a batch processing operation.
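///
/// # Example
///
/// A sketch of inspecting a finished batch (`batch` is assumed to be a
/// `BatchResult` returned by [`BatchProcessor::process_batch`]):
///
/// ```ignore
/// println!("{}", batch.summary());
/// for (index, message) in batch.errors() {
///     eprintln!("image {index} failed: {message}");
/// }
/// let ok_results = batch.successful_results();
/// ```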
#[derive(Debug, Clone)]
pub struct BatchResult {
    /// Individual item results
    pub items: Vec<BatchItemResult>,
    /// Final progress statistics
    pub progress: BatchProgress,
    /// Total processing time (milliseconds)
    pub total_time_ms: u64,
}

impl BatchResult {
    /// Get all successful results.
    pub fn successful_results(&self) -> Vec<&OcrResult> {
        self.items
            .iter()
            .filter_map(|item| item.result.as_ref())
            .collect()
    }

    /// Get all errors as `(index, message)` pairs.
    pub fn errors(&self) -> Vec<(usize, &str)> {
        self.items
            .iter()
            .filter_map(|item| item.error.as_ref().map(|e| (item.index, e.as_str())))
            .collect()
    }

    /// Get the success rate (0.0 to 1.0).
    pub fn success_rate(&self) -> f32 {
        if self.items.is_empty() {
            0.0
        } else {
            self.progress.succeeded as f32 / self.items.len() as f32
        }
    }

    /// Get a formatted summary.
    pub fn summary(&self) -> String {
        format!(
            "Batch processing complete: {} items in {}ms\n\
             Success: {} ({:.1}%), Failed: {}",
            self.items.len(),
            self.total_time_ms,
            self.progress.succeeded,
            self.success_rate() * 100.0,
            self.progress.failed
        )
    }
}

/// Progress callback for batch processing.
pub type ProgressCallback = Arc<dyn Fn(&BatchProgress) + Send + Sync>;

/// Batch processor for OCR operations.
pub struct BatchProcessor {
    config: BatchConfig,
    progress_callback: Option<ProgressCallback>,
}

impl BatchProcessor {
    /// Create a new batch processor with the given configuration.
    pub fn new(config: BatchConfig) -> Self {
        Self {
            config,
            progress_callback: None,
        }
    }

    /// Create a batch processor with default configuration.
    pub fn default_config() -> Self {
        Self::new(BatchConfig::default())
    }

    /// Set a progress callback.
    ///
    /// The callback is invoked after each item completes, with the current progress.
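    ///
    /// # Example
    ///
    /// A sketch of attaching a simple logging callback:
    ///
    /// ```ignore
    /// let processor = BatchProcessor::new(BatchConfig::with_progress())
    ///     .with_progress_callback(|progress| {
    ///         eprintln!("OCR batch: {}", progress.format());
    ///     });
    /// ```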
    pub fn with_progress_callback<F>(mut self, callback: F) -> Self
    where
        F: Fn(&BatchProgress) + Send + Sync + 'static,
    {
        self.progress_callback = Some(Arc::new(callback));
        self
    }

    /// Process multiple images in parallel.
    ///
    /// # Arguments
    ///
    /// * `provider` - The vision provider to use
    /// * `images` - Vector of image data (one owned byte buffer per image)
    ///
    /// # Returns
    ///
    /// A batch result with individual item results and statistics.
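    ///
    /// # Example
    ///
    /// A sketch of running a batch (`provider` is assumed to be an
    /// `Arc<dyn VisionProvider>`; file paths are illustrative):
    ///
    /// ```ignore
    /// let processor = BatchProcessor::new(BatchConfig::default().with_max_concurrency(4));
    /// let images = vec![std::fs::read("a.png")?, std::fs::read("b.png")?];
    /// let batch = processor.process_batch(provider, images).await?;
    /// assert_eq!(batch.items.len(), 2);
    /// println!("{}", batch.summary());
    /// ```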
    pub async fn process_batch(
        &self,
        provider: Arc<dyn VisionProvider>,
        images: Vec<Vec<u8>>,
    ) -> Result<BatchResult> {
        let start_time = std::time::Instant::now();
        let total = images.len();

        if total == 0 {
            return Ok(BatchResult {
                items: vec![],
                progress: BatchProgress::new(0),
                total_time_ms: 0,
            });
        }

        // Create the shared progress tracker and counters
        let progress = Arc::new(RwLock::new(BatchProgress::new(total)));
        let completed_count = Arc::new(AtomicUsize::new(0));
        let succeeded_count = Arc::new(AtomicUsize::new(0));
        let failed_count = Arc::new(AtomicUsize::new(0));

        // Process items in parallel using a buffered stream
        let items = stream::iter(images.into_iter().enumerate())
            .map(|(index, image_data)| {
                let provider = Arc::clone(&provider);
                let progress = Arc::clone(&progress);
                let completed = Arc::clone(&completed_count);
                let succeeded = Arc::clone(&succeeded_count);
                let failed = Arc::clone(&failed_count);
                let progress_callback = self.progress_callback.clone();
                let continue_on_error = self.config.continue_on_error;

                async move {
                    // Process the image
                    let result = provider.process_image(&image_data).await;

                    // Update counters
                    let is_success = result.is_ok();
                    completed.fetch_add(1, Ordering::SeqCst);

                    if is_success {
                        succeeded.fetch_add(1, Ordering::SeqCst);
                    } else {
                        failed.fetch_add(1, Ordering::SeqCst);
                    }

                    // Update the shared progress tracker
                    {
                        let mut p = progress.write().await;
                        p.completed = completed.load(Ordering::SeqCst);
                        p.succeeded = succeeded.load(Ordering::SeqCst);
                        p.failed = failed.load(Ordering::SeqCst);

                        // Invoke the callback if set
                        if let Some(ref callback) = progress_callback {
                            callback(&p);
                        }
                    }

                    // Create the item result
                    match result {
                        Ok(ocr_result) => Ok(BatchItemResult {
                            index,
                            result: Some(ocr_result),
                            error: None,
                        }),
                        Err(e) => {
                            if !continue_on_error {
                                // Propagate the error to abort the batch
                                return Err(e);
                            }
                            Ok(BatchItemResult {
                                index,
                                result: None,
                                error: Some(e.to_string()),
                            })
                        }
                    }
                }
            })
            .buffer_unordered(self.config.max_concurrency);

        // Collect results (in completion order, not submission order)
        let results: Vec<Result<BatchItemResult>> = items.collect().await;

        // Check for fatal errors (if continue_on_error is false). The stream
        // above is drained before this check, so remaining items still run;
        // the first error encountered in completion order is returned.
        let mut item_results = Vec::new();
        for result in results {
            match result {
                Ok(item) => item_results.push(item),
                Err(e) => {
                    // Fatal error occurred
                    return Err(e);
                }
            }
        }

        // Sort by index to restore submission order
        item_results.sort_by_key(|item| item.index);

        let elapsed = start_time.elapsed();
        let final_progress = progress.read().await.clone();

        Ok(BatchResult {
            items: item_results,
            progress: final_progress,
            total_time_ms: elapsed.as_millis() as u64,
        })
    }

    /// Process multiple images, handling each result with a callback.
    ///
    /// The callback is invoked with each item's index and result as soon as it completes.
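    ///
    /// # Example
    ///
    /// A sketch of streaming results as they arrive (`provider` and `images`
    /// are assumed to exist):
    ///
    /// ```ignore
    /// let processor = BatchProcessor::default_config();
    /// processor
    ///     .process_batch_with_callback(provider, images, |index, result| match result {
    ///         Ok(_ocr) => println!("image {index}: ok"),
    ///         Err(e) => eprintln!("image {index} failed: {e}"),
    ///     })
    ///     .await?;
    /// ```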
    pub async fn process_batch_with_callback<F>(
        &self,
        provider: Arc<dyn VisionProvider>,
        images: Vec<Vec<u8>>,
        callback: F,
    ) -> Result<()>
    where
        F: Fn(usize, Result<OcrResult>) + Send + Sync,
    {
        let total = images.len();

        if total == 0 {
            return Ok(());
        }

        // Process items in parallel
        let mut stream = stream::iter(images.into_iter().enumerate())
            .map(|(index, image_data)| {
                let provider = Arc::clone(&provider);
                async move {
                    let result = provider.process_image(&image_data).await;
                    (index, result)
                }
            })
            .buffer_unordered(self.config.max_concurrency);

        // Process results as they complete
        while let Some((index, result)) = stream.next().await {
            if !self.config.continue_on_error && result.is_err() {
                return result.map(|_| ());
            }
            callback(index, result);
        }

        Ok(())
    }
}

/// Helper function to process a batch of images with default settings.
///
/// This is a convenience function for simple batch processing.
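///
/// # Example
///
/// A sketch of the one-call path (`provider` and `image_bytes` are assumed to exist):
///
/// ```ignore
/// let results = process_batch_simple(provider, vec![image_bytes]).await?;
/// for item in results {
///     match item {
///         Ok(_ocr) => println!("ok"),
///         Err(e) => eprintln!("failed: {e}"),
///     }
/// }
/// ```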
pub async fn process_batch_simple(
    provider: Arc<dyn VisionProvider>,
    images: Vec<Vec<u8>>,
) -> Result<Vec<Result<OcrResult>>> {
    let processor = BatchProcessor::default_config();
    let batch_result = processor.process_batch(provider, images).await?;

    // Convert each item back into a plain Result, turning stored error
    // strings into VisionError values.
    Ok(batch_result
        .items
        .into_iter()
        .map(|item| {
            if let Some(result) = item.result {
                Ok(result)
            } else {
                Err(VisionError::image_processing(
                    item.error.unwrap_or_else(|| "Unknown error".to_string()),
                ))
            }
        })
        .collect())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::MockVisionProvider;

    #[test]
    fn test_batch_config_default() {
        let config = BatchConfig::default();
        assert!(config.max_concurrency > 0);
        assert!(config.continue_on_error);
    }

    #[test]
    fn test_batch_config_fast() {
        let config = BatchConfig::fast();
        assert!(config.max_concurrency >= num_cpus::get());
    }

    #[test]
    fn test_batch_config_gpu() {
        let config = BatchConfig::gpu();
        assert_eq!(config.max_concurrency, 4);
        assert!(config.report_progress);
    }

    #[test]
    fn test_batch_config_builders() {
        let config = BatchConfig::default()
            .with_max_concurrency(8)
            .with_continue_on_error(false)
            .with_report_progress(true);

        assert_eq!(config.max_concurrency, 8);
        assert!(!config.continue_on_error);
        assert!(config.report_progress);
    }

    #[test]
    fn test_batch_progress() {
        let mut progress = BatchProgress::new(100);
        assert_eq!(progress.percentage(), 0.0);
        assert!(!progress.is_complete());

        progress.completed = 50;
        assert_eq!(progress.percentage(), 0.5);
        assert!(!progress.is_complete());

        progress.completed = 100;
        assert_eq!(progress.percentage(), 1.0);
        assert!(progress.is_complete());
    }

    #[test]
    fn test_batch_progress_format() {
        let progress = BatchProgress {
            total: 100,
            completed: 50,
            succeeded: 45,
            failed: 5,
        };

        let formatted = progress.format();
        assert!(formatted.contains("50/100"));
        assert!(formatted.contains("45 succeeded"));
        assert!(formatted.contains("5 failed"));
    }

    #[test]
    fn test_batch_item_result() {
        let success_item = BatchItemResult {
            index: 0,
            result: Some(OcrResult::from_text("test")),
            error: None,
        };
        assert!(success_item.is_success());
        assert!(!success_item.is_error());

        let error_item = BatchItemResult {
            index: 1,
            result: None,
            error: Some("error".to_string()),
        };
        assert!(!error_item.is_success());
        assert!(error_item.is_error());
    }

    #[tokio::test]
    async fn test_batch_processor_creation() {
        let processor = BatchProcessor::default_config();
        assert!(processor.config.max_concurrency > 0);
    }

    #[tokio::test]
    async fn test_process_batch_empty() {
        let provider = Arc::new(MockVisionProvider::new()) as Arc<dyn VisionProvider>;
        let processor = BatchProcessor::default_config();
        let result = processor.process_batch(provider, vec![]).await.unwrap();

        assert_eq!(result.items.len(), 0);
        assert_eq!(result.progress.total, 0);
    }

    #[tokio::test]
    async fn test_process_batch_single() {
        let provider = Arc::new(MockVisionProvider::new()) as Arc<dyn VisionProvider>;
        provider.load_model().await.unwrap();

        let processor = BatchProcessor::default_config();
        let images = vec![b"test image".to_vec()];
        let result = processor.process_batch(provider, images).await.unwrap();

        assert_eq!(result.items.len(), 1);
        assert_eq!(result.progress.total, 1);
        assert_eq!(result.progress.completed, 1);
        assert_eq!(result.progress.succeeded, 1);
    }

    #[tokio::test]
    async fn test_process_batch_multiple() {
        let provider = Arc::new(MockVisionProvider::new()) as Arc<dyn VisionProvider>;
        provider.load_model().await.unwrap();

        let processor = BatchProcessor::default_config();
        let images = vec![b"image1".to_vec(), b"image2".to_vec(), b"image3".to_vec()];
        let result = processor.process_batch(provider, images).await.unwrap();

        assert_eq!(result.items.len(), 3);
        assert_eq!(result.progress.succeeded, 3);
        assert_eq!(result.success_rate(), 1.0);
    }

    #[tokio::test]
    async fn test_batch_result_methods() {
        let batch_result = BatchResult {
            items: vec![
                BatchItemResult {
                    index: 0,
                    result: Some(OcrResult::from_text("success")),
                    error: None,
                },
                BatchItemResult {
                    index: 1,
                    result: None,
                    error: Some("failed".to_string()),
                },
            ],
            progress: BatchProgress {
                total: 2,
                completed: 2,
                succeeded: 1,
                failed: 1,
            },
            total_time_ms: 100,
        };

        assert_eq!(batch_result.successful_results().len(), 1);
        assert_eq!(batch_result.errors().len(), 1);
        assert_eq!(batch_result.success_rate(), 0.5);

        let summary = batch_result.summary();
        assert!(summary.contains("2 items"));
        assert!(summary.contains("100ms"));
    }

    #[tokio::test]
    async fn test_progress_callback() {
        let provider = Arc::new(MockVisionProvider::new()) as Arc<dyn VisionProvider>;
        provider.load_model().await.unwrap();

        let progress_updates = Arc::new(RwLock::new(Vec::new()));
        let updates_clone = Arc::clone(&progress_updates);

        let processor = BatchProcessor::new(BatchConfig::with_progress()).with_progress_callback(
            move |progress| {
                let updates = updates_clone.clone();
                let completed = progress.completed; // Clone the data, not the reference
                tokio::spawn(async move {
                    updates.write().await.push(completed);
                });
            },
        );

        let images = vec![b"img1".to_vec(), b"img2".to_vec()];
        let _result = processor.process_batch(provider, images).await.unwrap();

        // Give callbacks time to execute
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        let updates = progress_updates.read().await;
        assert!(!updates.is_empty());
    }

    #[tokio::test]
    async fn test_process_batch_simple() {
        let provider = Arc::new(MockVisionProvider::new()) as Arc<dyn VisionProvider>;
        provider.load_model().await.unwrap();

        let images = vec![b"test".to_vec()];
        let results = process_batch_simple(provider, images).await.unwrap();

        assert_eq!(results.len(), 1);
        assert!(results[0].is_ok());
    }
}