1use crate::errors::{Result, VisionError};
7use crate::providers::VisionProvider;
8use crate::types::OcrResult;
9use futures::stream::{self, StreamExt};
10use std::sync::atomic::{AtomicUsize, Ordering};
11use std::sync::Arc;
12use tokio::sync::RwLock;
13
14#[derive(Debug, Clone)]
16pub struct BatchConfig {
17 pub max_concurrency: usize,
19 pub continue_on_error: bool,
21 pub report_progress: bool,
23}
24
25impl Default for BatchConfig {
26 fn default() -> Self {
27 Self {
28 max_concurrency: num_cpus::get(),
29 continue_on_error: true,
30 report_progress: false,
31 }
32 }
33}
34
35impl BatchConfig {
36 pub fn fast() -> Self {
40 Self {
41 max_concurrency: num_cpus::get() * 2,
42 continue_on_error: true,
43 report_progress: false,
44 }
45 }
46
47 pub fn gpu() -> Self {
51 Self {
52 max_concurrency: 4,
53 continue_on_error: true,
54 report_progress: true,
55 }
56 }
57
58 pub fn with_progress() -> Self {
60 Self {
61 max_concurrency: num_cpus::get(),
62 continue_on_error: true,
63 report_progress: true,
64 }
65 }
66
67 pub fn with_max_concurrency(mut self, max: usize) -> Self {
69 self.max_concurrency = max.max(1);
70 self
71 }
72
73 pub fn with_continue_on_error(mut self, continue_on_error: bool) -> Self {
75 self.continue_on_error = continue_on_error;
76 self
77 }
78
79 pub fn with_report_progress(mut self, report: bool) -> Self {
81 self.report_progress = report;
82 self
83 }
84}
85
/// Running counters describing how far a batch has advanced.
#[derive(Clone, Debug)]
pub struct BatchProgress {
    /// Number of items in the batch.
    pub total: usize,
    /// Items finished so far, regardless of outcome.
    pub completed: usize,
    /// Items that finished successfully.
    pub succeeded: usize,
    /// Items that finished with an error.
    pub failed: usize,
}

impl BatchProgress {
    /// Starts tracking a batch of `total` items with every counter at zero.
    pub fn new(total: usize) -> Self {
        Self {
            total,
            completed: 0,
            succeeded: 0,
            failed: 0,
        }
    }

    /// Fraction of the batch that is finished, as `completed / total`.
    ///
    /// Despite the name this is a ratio, not a value scaled to 100. An empty
    /// batch reports `1.0`.
    pub fn percentage(&self) -> f32 {
        match self.total {
            0 => 1.0,
            total => self.completed as f32 / total as f32,
        }
    }

    /// Returns `true` once at least `total` items have completed.
    pub fn is_complete(&self) -> bool {
        self.total <= self.completed
    }

    /// Renders a one-line, human-readable summary of the counters.
    pub fn format(&self) -> String {
        let percent = self.percentage() * 100.0;
        format!(
            "{done}/{total} ({percent:.1}%) - {ok} succeeded, {bad} failed",
            done = self.completed,
            total = self.total,
            percent = percent,
            ok = self.succeeded,
            bad = self.failed
        )
    }
}
136
137#[derive(Debug, Clone)]
139pub struct BatchItemResult {
140 pub index: usize,
142 pub result: Option<OcrResult>,
144 pub error: Option<String>,
146}
147
148impl BatchItemResult {
149 pub fn is_success(&self) -> bool {
151 self.result.is_some()
152 }
153
154 pub fn is_error(&self) -> bool {
156 self.error.is_some()
157 }
158}
159
160#[derive(Debug, Clone)]
162pub struct BatchResult {
163 pub items: Vec<BatchItemResult>,
165 pub progress: BatchProgress,
167 pub total_time_ms: u64,
169}
170
171impl BatchResult {
172 pub fn successful_results(&self) -> Vec<&OcrResult> {
174 self.items
175 .iter()
176 .filter_map(|item| item.result.as_ref())
177 .collect()
178 }
179
180 pub fn errors(&self) -> Vec<(usize, &str)> {
182 self.items
183 .iter()
184 .filter_map(|item| item.error.as_ref().map(|e| (item.index, e.as_str())))
185 .collect()
186 }
187
188 pub fn success_rate(&self) -> f32 {
190 if self.items.is_empty() {
191 0.0
192 } else {
193 self.progress.succeeded as f32 / self.items.len() as f32
194 }
195 }
196
197 pub fn summary(&self) -> String {
199 format!(
200 "Batch processing complete: {} items in {}ms\n\
201 Success: {} ({:.1}%), Failed: {}",
202 self.items.len(),
203 self.total_time_ms,
204 self.progress.succeeded,
205 self.success_rate() * 100.0,
206 self.progress.failed
207 )
208 }
209}
210
211pub type ProgressCallback = Arc<dyn Fn(&BatchProgress) + Send + Sync>;
213
214pub struct BatchProcessor {
216 config: BatchConfig,
217 progress_callback: Option<ProgressCallback>,
218}
219
220impl BatchProcessor {
221 pub fn new(config: BatchConfig) -> Self {
223 Self {
224 config,
225 progress_callback: None,
226 }
227 }
228
229 pub fn default_config() -> Self {
231 Self::new(BatchConfig::default())
232 }
233
234 pub fn with_progress_callback<F>(mut self, callback: F) -> Self
238 where
239 F: Fn(&BatchProgress) + Send + Sync + 'static,
240 {
241 self.progress_callback = Some(Arc::new(callback));
242 self
243 }
244
245 pub async fn process_batch(
256 &self,
257 provider: Arc<dyn VisionProvider>,
258 images: Vec<Vec<u8>>,
259 ) -> Result<BatchResult> {
260 let start_time = std::time::Instant::now();
261 let total = images.len();
262
263 if total == 0 {
264 return Ok(BatchResult {
265 items: vec![],
266 progress: BatchProgress::new(0),
267 total_time_ms: 0,
268 });
269 }
270
271 let progress = Arc::new(RwLock::new(BatchProgress::new(total)));
273 let completed_count = Arc::new(AtomicUsize::new(0));
274 let succeeded_count = Arc::new(AtomicUsize::new(0));
275 let failed_count = Arc::new(AtomicUsize::new(0));
276
277 let items = stream::iter(images.into_iter().enumerate())
279 .map(|(index, image_data)| {
280 let provider = Arc::clone(&provider);
281 let progress = Arc::clone(&progress);
282 let completed = Arc::clone(&completed_count);
283 let succeeded = Arc::clone(&succeeded_count);
284 let failed = Arc::clone(&failed_count);
285 let progress_callback = self.progress_callback.clone();
286 let continue_on_error = self.config.continue_on_error;
287
288 async move {
289 let result = provider.process_image(&image_data).await;
291
292 let is_success = result.is_ok();
294 completed.fetch_add(1, Ordering::SeqCst);
295
296 if is_success {
297 succeeded.fetch_add(1, Ordering::SeqCst);
298 } else {
299 failed.fetch_add(1, Ordering::SeqCst);
300 }
301
302 {
304 let mut p = progress.write().await;
305 p.completed = completed.load(Ordering::SeqCst);
306 p.succeeded = succeeded.load(Ordering::SeqCst);
307 p.failed = failed.load(Ordering::SeqCst);
308
309 if let Some(ref callback) = progress_callback {
311 callback(&p);
312 }
313 }
314
315 match result {
317 Ok(ocr_result) => Ok(BatchItemResult {
318 index,
319 result: Some(ocr_result),
320 error: None,
321 }),
322 Err(e) => {
323 if !continue_on_error {
324 return Err(e);
326 }
327 Ok(BatchItemResult {
328 index,
329 result: None,
330 error: Some(e.to_string()),
331 })
332 }
333 }
334 }
335 })
336 .buffer_unordered(self.config.max_concurrency);
337
338 let results: Vec<Result<BatchItemResult>> = items.collect().await;
340
341 let mut item_results = Vec::new();
343 for result in results {
344 match result {
345 Ok(item) => item_results.push(item),
346 Err(e) => {
347 return Err(e);
349 }
350 }
351 }
352
353 item_results.sort_by_key(|item| item.index);
355
356 let elapsed = start_time.elapsed();
357 let final_progress = progress.read().await.clone();
358
359 Ok(BatchResult {
360 items: item_results,
361 progress: final_progress,
362 total_time_ms: elapsed.as_millis() as u64,
363 })
364 }
365
366 pub async fn process_batch_with_callback<F>(
370 &self,
371 provider: Arc<dyn VisionProvider>,
372 images: Vec<Vec<u8>>,
373 callback: F,
374 ) -> Result<()>
375 where
376 F: Fn(usize, Result<OcrResult>) + Send + Sync,
377 {
378 let total = images.len();
379
380 if total == 0 {
381 return Ok(());
382 }
383
384 let mut stream = stream::iter(images.into_iter().enumerate())
386 .map(|(index, image_data)| {
387 let provider = Arc::clone(&provider);
388 async move {
389 let result = provider.process_image(&image_data).await;
390 (index, result)
391 }
392 })
393 .buffer_unordered(self.config.max_concurrency);
394
395 while let Some((index, result)) = stream.next().await {
397 if !self.config.continue_on_error && result.is_err() {
398 return result.map(|_| ());
399 }
400 callback(index, result);
401 }
402
403 Ok(())
404 }
405}
406
407pub async fn process_batch_simple(
411 provider: Arc<dyn VisionProvider>,
412 images: Vec<Vec<u8>>,
413) -> Result<Vec<Result<OcrResult>>> {
414 let processor = BatchProcessor::default_config();
415 let batch_result = processor.process_batch(provider, images).await?;
416
417 Ok(batch_result
418 .items
419 .into_iter()
420 .map(|item| {
421 if let Some(result) = item.result {
422 Ok(result)
423 } else {
424 Err(VisionError::image_processing(
425 item.error.unwrap_or_else(|| "Unknown error".to_string()),
426 ))
427 }
428 })
429 .collect())
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::MockVisionProvider;
    use std::sync::Mutex;

    /// Builds a mock provider behind the trait object the processor expects.
    fn mock_provider() -> Arc<dyn VisionProvider> {
        Arc::new(MockVisionProvider::new()) as Arc<dyn VisionProvider>
    }

    #[test]
    fn test_batch_config_default() {
        let config = BatchConfig::default();
        assert!(config.max_concurrency > 0);
        assert!(config.continue_on_error);
    }

    #[test]
    fn test_batch_config_fast() {
        let config = BatchConfig::fast();
        assert!(config.max_concurrency >= num_cpus::get());
    }

    #[test]
    fn test_batch_config_gpu() {
        let config = BatchConfig::gpu();
        assert_eq!(config.max_concurrency, 4);
        assert!(config.report_progress);
    }

    #[test]
    fn test_batch_config_builders() {
        let config = BatchConfig::default()
            .with_max_concurrency(8)
            .with_continue_on_error(false)
            .with_report_progress(true);

        assert_eq!(config.max_concurrency, 8);
        assert!(!config.continue_on_error);
        assert!(config.report_progress);

        // A zero limit is raised to one by the builder.
        let clamped = BatchConfig::default().with_max_concurrency(0);
        assert_eq!(clamped.max_concurrency, 1);
    }

    #[test]
    fn test_batch_progress() {
        let mut progress = BatchProgress::new(100);
        assert_eq!(progress.percentage(), 0.0);
        assert!(!progress.is_complete());

        progress.completed = 50;
        assert_eq!(progress.percentage(), 0.5);
        assert!(!progress.is_complete());

        progress.completed = 100;
        assert_eq!(progress.percentage(), 1.0);
        assert!(progress.is_complete());
    }

    #[test]
    fn test_batch_progress_format() {
        let progress = BatchProgress {
            total: 100,
            completed: 50,
            succeeded: 45,
            failed: 5,
        };

        let formatted = progress.format();
        assert!(formatted.contains("50/100"));
        assert!(formatted.contains("45 succeeded"));
        assert!(formatted.contains("5 failed"));
    }

    #[test]
    fn test_batch_item_result() {
        let success_item = BatchItemResult {
            index: 0,
            result: Some(OcrResult::from_text("test")),
            error: None,
        };
        assert!(success_item.is_success());
        assert!(!success_item.is_error());

        let error_item = BatchItemResult {
            index: 1,
            result: None,
            error: Some("error".to_string()),
        };
        assert!(!error_item.is_success());
        assert!(error_item.is_error());
    }

    #[tokio::test]
    async fn test_batch_processor_creation() {
        let processor = BatchProcessor::default_config();
        assert!(processor.config.max_concurrency > 0);
    }

    #[tokio::test]
    async fn test_process_batch_empty() {
        let provider = mock_provider();
        let processor = BatchProcessor::default_config();
        let result = processor.process_batch(provider, vec![]).await.unwrap();

        assert_eq!(result.items.len(), 0);
        assert_eq!(result.progress.total, 0);
    }

    #[tokio::test]
    async fn test_process_batch_single() {
        let provider = mock_provider();
        provider.load_model().await.unwrap();

        let processor = BatchProcessor::default_config();
        let images = vec![b"test image".to_vec()];
        let result = processor.process_batch(provider, images).await.unwrap();

        assert_eq!(result.items.len(), 1);
        assert_eq!(result.progress.total, 1);
        assert_eq!(result.progress.completed, 1);
        assert_eq!(result.progress.succeeded, 1);
    }

    #[tokio::test]
    async fn test_process_batch_multiple() {
        let provider = mock_provider();
        provider.load_model().await.unwrap();

        let processor = BatchProcessor::default_config();
        let images = vec![b"image1".to_vec(), b"image2".to_vec(), b"image3".to_vec()];
        let result = processor.process_batch(provider, images).await.unwrap();

        assert_eq!(result.items.len(), 3);
        assert_eq!(result.progress.succeeded, 3);
        assert_eq!(result.success_rate(), 1.0);
    }

    #[tokio::test]
    async fn test_batch_result_methods() {
        let batch_result = BatchResult {
            items: vec![
                BatchItemResult {
                    index: 0,
                    result: Some(OcrResult::from_text("success")),
                    error: None,
                },
                BatchItemResult {
                    index: 1,
                    result: None,
                    error: Some("failed".to_string()),
                },
            ],
            progress: BatchProgress {
                total: 2,
                completed: 2,
                succeeded: 1,
                failed: 1,
            },
            total_time_ms: 100,
        };

        assert_eq!(batch_result.successful_results().len(), 1);
        assert_eq!(batch_result.errors().len(), 1);
        assert_eq!(batch_result.success_rate(), 0.5);

        let summary = batch_result.summary();
        assert!(summary.contains("2 items"));
        assert!(summary.contains("100ms"));
    }

    #[tokio::test]
    async fn test_progress_callback() {
        let provider = mock_provider();
        provider.load_model().await.unwrap();

        // Record updates synchronously inside the callback. Spawning detached
        // tasks and sleeping would make the test depend on scheduler timing.
        let progress_updates = Arc::new(Mutex::new(Vec::new()));
        let updates_clone = Arc::clone(&progress_updates);

        let processor = BatchProcessor::new(BatchConfig::with_progress()).with_progress_callback(
            move |progress| {
                updates_clone.lock().unwrap().push(progress.completed);
            },
        );

        let images = vec![b"img1".to_vec(), b"img2".to_vec()];
        let _result = processor.process_batch(provider, images).await.unwrap();

        let updates = progress_updates.lock().unwrap();
        // The callback fires once per finished image, and the last snapshot
        // must account for every image.
        assert_eq!(updates.len(), 2);
        assert_eq!(updates.iter().max(), Some(&2));
    }

    #[tokio::test]
    async fn test_process_batch_simple() {
        let provider = mock_provider();
        provider.load_model().await.unwrap();

        let images = vec![b"test".to_vec()];
        let results = process_batch_simple(provider, images).await.unwrap();

        assert_eq!(results.len(), 1);
        assert!(results[0].is_ok());
    }
}