1use super::config::{DynamicBatchConfig, PaddingStrategy, ShapeCompatibilityStrategy};
4use super::types::{CompatibleBatch, CrossImageBatch, CrossImageItem};
5use crate::core::OCRError;
6use crate::core::traits::StandardPredictor;
7use image::{ImageBuffer, Rgb, RgbImage};
8use std::collections::HashMap;
9use std::time::Instant;
10
11pub trait DynamicBatcher {
13 fn group_images_by_compatibility(
15 &self,
16 images: Vec<(usize, RgbImage)>,
17 config: &DynamicBatchConfig,
18 ) -> Result<Vec<CompatibleBatch>, OCRError>;
19
20 fn group_cross_image_items(
22 &self,
23 items: Vec<(usize, usize, RgbImage)>, config: &DynamicBatchConfig,
25 ) -> Result<Vec<CrossImageBatch>, OCRError>;
26
27 fn batch_predict<P>(
29 &self,
30 predictor: &P,
31 images: Vec<RgbImage>,
32 config: Option<P::Config>,
33 ) -> Result<Vec<P::Result>, OCRError>
34 where
35 P: StandardPredictor;
36}
37
38#[derive(Debug)]
40pub struct DefaultDynamicBatcher;
41
42impl DefaultDynamicBatcher {
43 pub fn new() -> Self {
45 Self
46 }
47
48 fn calculate_aspect_ratio(image: &RgbImage) -> f32 {
50 let (width, height) = image.dimensions();
51 width as f32 / height as f32
52 }
53
54 fn are_images_compatible(
56 img1: &RgbImage,
57 img2: &RgbImage,
58 strategy: &ShapeCompatibilityStrategy,
59 ) -> bool {
60 match strategy {
61 ShapeCompatibilityStrategy::Exact => img1.dimensions() == img2.dimensions(),
62 ShapeCompatibilityStrategy::AspectRatio { tolerance } => {
63 let ratio1 = Self::calculate_aspect_ratio(img1);
64 let ratio2 = Self::calculate_aspect_ratio(img2);
65 (ratio1 - ratio2).abs() <= *tolerance
66 }
67 ShapeCompatibilityStrategy::MaxDimension { bucket_size } => {
68 let (w1, h1) = img1.dimensions();
69 let (w2, h2) = img2.dimensions();
70 let max1 = w1.max(h1);
71 let max2 = w2.max(h2);
72 max1 / bucket_size == max2 / bucket_size
73 }
74 ShapeCompatibilityStrategy::Custom { targets, tolerance } => {
75 let target1 = Self::find_best_target(img1, targets, *tolerance);
77 let target2 = Self::find_best_target(img2, targets, *tolerance);
78 target1 == target2
79 }
80 }
81 }
82
83 fn find_best_target(
85 image: &RgbImage,
86 targets: &[(u32, u32)],
87 tolerance: f32,
88 ) -> Option<(u32, u32)> {
89 let (width, height) = image.dimensions();
90 let aspect_ratio = width as f32 / height as f32;
91
92 targets
93 .iter()
94 .find(|(target_w, target_h)| {
95 let target_ratio = *target_w as f32 / *target_h as f32;
96 (aspect_ratio - target_ratio).abs() <= tolerance
97 })
98 .copied()
99 }
100
101 fn calculate_target_dimensions(
103 images: &[RgbImage],
104 strategy: &ShapeCompatibilityStrategy,
105 ) -> (u32, u32) {
106 match strategy {
107 ShapeCompatibilityStrategy::Exact => {
108 images.first().map(|img| img.dimensions()).unwrap_or((0, 0))
110 }
111 _ => {
112 let max_width = images.iter().map(|img| img.width()).max().unwrap_or(0);
114 let max_height = images.iter().map(|img| img.height()).max().unwrap_or(0);
115 (max_width, max_height)
116 }
117 }
118 }
119
120 fn pad_image(
122 image: &RgbImage,
123 target_dims: (u32, u32),
124 strategy: &PaddingStrategy,
125 ) -> Result<RgbImage, OCRError> {
126 let (current_width, current_height) = image.dimensions();
127 let (target_width, target_height) = target_dims;
128
129 if current_width == target_width && current_height == target_height {
130 return Ok(image.clone());
131 }
132
133 if current_width > target_width || current_height > target_height {
134 return Err(OCRError::Processing {
135 kind: crate::core::ProcessingStage::ImageProcessing,
136 context: format!(
137 "Image dimensions ({}, {}) exceed target dimensions ({}, {})",
138 current_width, current_height, target_width, target_height
139 ),
140 source: Box::new(crate::core::errors::SimpleError::new("Image too large")),
141 });
142 }
143
144 let mut padded = ImageBuffer::new(target_width, target_height);
145
146 let x_offset = (target_width - current_width) / 2;
148 let y_offset = (target_height - current_height) / 2;
149
150 match strategy {
151 PaddingStrategy::Zero => {
152 for pixel in padded.pixels_mut() {
154 *pixel = Rgb([0, 0, 0]);
155 }
156 Self::copy_centered_image(&mut padded, image, x_offset, y_offset);
158 }
159 PaddingStrategy::Center { fill_color } => {
160 for pixel in padded.pixels_mut() {
162 *pixel = Rgb(*fill_color);
163 }
164 Self::copy_centered_image(&mut padded, image, x_offset, y_offset);
166 }
167 PaddingStrategy::Edge => {
168 Self::apply_optimized_edge_padding(&mut padded, image, x_offset, y_offset);
170 }
171 PaddingStrategy::Smart => {
172 let smart_color = Self::calculate_smart_padding_color(image);
174 for pixel in padded.pixels_mut() {
175 *pixel = smart_color;
176 }
177 Self::copy_centered_image(&mut padded, image, x_offset, y_offset);
179 }
180 }
181
182 Ok(padded)
183 }
184
185 fn copy_centered_image(
187 padded: &mut RgbImage,
188 original: &RgbImage,
189 x_offset: u32,
190 y_offset: u32,
191 ) {
192 let (orig_width, orig_height) = original.dimensions();
193 for y in 0..orig_height {
194 for x in 0..orig_width {
195 let pixel = original.get_pixel(x, y);
196 padded.put_pixel(x + x_offset, y + y_offset, *pixel);
197 }
198 }
199 }
200
201 fn apply_optimized_edge_padding(
203 padded: &mut RgbImage,
204 original: &RgbImage,
205 x_offset: u32,
206 y_offset: u32,
207 ) {
208 let (padded_width, padded_height) = padded.dimensions();
209 let (orig_width, orig_height) = original.dimensions();
210
211 for y in 0..padded_height {
213 for x in 0..padded_width {
214 let source_x = if x < x_offset {
216 0
218 } else if x >= x_offset + orig_width {
219 orig_width - 1
221 } else {
222 x - x_offset
224 };
225
226 let source_y = if y < y_offset {
227 0
229 } else if y >= y_offset + orig_height {
230 orig_height - 1
232 } else {
233 y - y_offset
235 };
236
237 let pixel = original.get_pixel(source_x, source_y);
238 padded.put_pixel(x, y, *pixel);
239 }
240 }
241 }
242
243 fn calculate_smart_padding_color(image: &RgbImage) -> Rgb<u8> {
245 let (width, height) = image.dimensions();
246
247 if width == 0 || height == 0 {
248 return Rgb([0, 0, 0]); }
250
251 let mut edge_pixels = Vec::new();
253
254 for x in 0..width {
256 edge_pixels.push(*image.get_pixel(x, 0)); if height > 1 {
258 edge_pixels.push(*image.get_pixel(x, height - 1)); }
260 }
261
262 for y in 1..height.saturating_sub(1) {
264 edge_pixels.push(*image.get_pixel(0, y)); if width > 1 {
266 edge_pixels.push(*image.get_pixel(width - 1, y)); }
268 }
269
270 if edge_pixels.is_empty() {
271 return Rgb([0, 0, 0]);
272 }
273
274 let mut r_values: Vec<u8> = edge_pixels.iter().map(|p| p.0[0]).collect();
276 let mut g_values: Vec<u8> = edge_pixels.iter().map(|p| p.0[1]).collect();
277 let mut b_values: Vec<u8> = edge_pixels.iter().map(|p| p.0[2]).collect();
278
279 r_values.sort_unstable();
280 g_values.sort_unstable();
281 b_values.sort_unstable();
282
283 let len = r_values.len();
284 let median_r = r_values[len / 2];
285 let median_g = g_values[len / 2];
286 let median_b = b_values[len / 2];
287
288 let adjusted_r = Self::adjust_padding_component(median_r);
292 let adjusted_g = Self::adjust_padding_component(median_g);
293 let adjusted_b = Self::adjust_padding_component(median_b);
294
295 Rgb([adjusted_r, adjusted_g, adjusted_b])
296 }
297
298 fn adjust_padding_component(component: u8) -> u8 {
300 match component {
301 0..=63 => (component as u16 + 16).min(255) as u8,
303 192..=255 => (component as i16 - 16).max(0) as u8,
305 _ => component,
307 }
308 }
309
310 fn generate_batch_id(target_dims: (u32, u32), batch_index: usize) -> String {
312 format!("{}x{}_{}", target_dims.0, target_dims.1, batch_index)
313 }
314}
315
316impl Default for DefaultDynamicBatcher {
317 fn default() -> Self {
318 Self::new()
319 }
320}
321
322impl DynamicBatcher for DefaultDynamicBatcher {
323 fn group_images_by_compatibility(
324 &self,
325 images: Vec<(usize, RgbImage)>,
326 config: &DynamicBatchConfig,
327 ) -> Result<Vec<CompatibleBatch>, OCRError> {
328 let _start_time = Instant::now();
329 let mut batches = Vec::new();
330 let mut batch_counter = 0;
331
332 let mut compatibility_groups: HashMap<String, Vec<(usize, RgbImage)>> = HashMap::new();
334
335 for (index, image) in images {
336 let mut target_group_key = None;
337
338 for (group_key, group_images) in compatibility_groups.iter() {
340 if let Some((_, first_image)) = group_images.first()
341 && Self::are_images_compatible(&image, first_image, &config.shape_compatibility)
342 {
343 target_group_key = Some(group_key.clone());
344 break;
345 }
346 }
347
348 if let Some(group_key) = target_group_key {
350 if let Some(group) = compatibility_groups.get_mut(&group_key) {
351 group.push((index, image));
352 } else {
353 let group_key = format!("group_{}", compatibility_groups.len());
355 compatibility_groups.insert(group_key, vec![(index, image)]);
356 }
357 } else {
358 let group_key = format!("group_{}", compatibility_groups.len());
359 compatibility_groups.insert(group_key, vec![(index, image)]);
360 }
361 }
362
363 for (_, group_images) in compatibility_groups {
365 if group_images.len() < config.min_batch_size {
366 for (index, image) in group_images {
368 let target_dims = image.dimensions();
369 let batch_id = Self::generate_batch_id(target_dims, batch_counter);
370 let mut batch = CompatibleBatch::new(batch_id, target_dims);
371 batch.add_image(image, index);
372 batches.push(batch);
373 batch_counter += 1;
374 }
375 } else {
376 let max_batch_size = config.max_detection_batch_size;
378 let images_vec: Vec<RgbImage> =
379 group_images.iter().map(|(_, img)| img.clone()).collect();
380 let target_dims =
381 Self::calculate_target_dimensions(&images_vec, &config.shape_compatibility);
382
383 for chunk in group_images.chunks(max_batch_size) {
384 let batch_id = Self::generate_batch_id(target_dims, batch_counter);
385 let mut batch = CompatibleBatch::new(batch_id, target_dims);
386
387 for (index, image) in chunk {
388 let padded_image =
390 Self::pad_image(image, target_dims, &config.padding_strategy)?;
391 batch.add_image(padded_image, *index);
392 }
393
394 batches.push(batch);
395 batch_counter += 1;
396 }
397 }
398 }
399
400 Ok(batches)
401 }
402
403 fn group_cross_image_items(
404 &self,
405 items: Vec<(usize, usize, RgbImage)>,
406 config: &DynamicBatchConfig,
407 ) -> Result<Vec<CrossImageBatch>, OCRError> {
408 let mut batches = Vec::new();
409 let mut batch_counter = 0;
410
411 let cross_items: Vec<CrossImageItem> = items
413 .into_iter()
414 .map(|(source_idx, item_idx, image)| CrossImageItem::new(source_idx, item_idx, image))
415 .collect();
416
417 let mut compatibility_groups: HashMap<String, Vec<CrossImageItem>> = HashMap::new();
419
420 for item in cross_items {
421 let mut target_group_key = None;
422
423 for (group_key, group_items) in compatibility_groups.iter() {
425 if let Some(first_item) = group_items.first()
426 && Self::are_images_compatible(
427 &item.image,
428 &first_item.image,
429 &config.shape_compatibility,
430 )
431 {
432 target_group_key = Some(group_key.clone());
433 break;
434 }
435 }
436
437 if let Some(group_key) = target_group_key {
439 if let Some(group) = compatibility_groups.get_mut(&group_key) {
440 group.push(item);
441 } else {
442 let group_key = format!("cross_group_{}", compatibility_groups.len());
444 compatibility_groups.insert(group_key, vec![item]);
445 }
446 } else {
447 let group_key = format!("cross_group_{}", compatibility_groups.len());
448 compatibility_groups.insert(group_key, vec![item]);
449 }
450 }
451
452 for (_, group_items) in compatibility_groups {
454 if group_items.len() < config.min_batch_size {
455 for item in group_items {
457 let target_dims = item.dimensions();
458 let batch_id = Self::generate_batch_id(target_dims, batch_counter);
459 let mut batch = CrossImageBatch::new(batch_id, target_dims);
460 batch.add_item(item);
461 batches.push(batch);
462 batch_counter += 1;
463 }
464 } else {
465 let max_batch_size = config.max_recognition_batch_size;
467 let images_vec: Vec<RgbImage> =
468 group_items.iter().map(|item| item.image.clone()).collect();
469 let target_dims =
470 Self::calculate_target_dimensions(&images_vec, &config.shape_compatibility);
471
472 for chunk in group_items.chunks(max_batch_size) {
473 let batch_id = Self::generate_batch_id(target_dims, batch_counter);
474 let mut batch = CrossImageBatch::new(batch_id, target_dims);
475
476 for item in chunk {
477 let padded_image =
479 Self::pad_image(&item.image, target_dims, &config.padding_strategy)?;
480 let mut padded_item = item.clone();
481 padded_item.image = padded_image;
482 batch.add_item(padded_item);
483 }
484
485 batches.push(batch);
486 batch_counter += 1;
487 }
488 }
489 }
490
491 Ok(batches)
492 }
493
494 fn batch_predict<P>(
495 &self,
496 predictor: &P,
497 images: Vec<RgbImage>,
498 config: Option<P::Config>,
499 ) -> Result<Vec<P::Result>, OCRError>
500 where
501 P: StandardPredictor,
502 {
503 let result = predictor.predict(images, config)?;
507 Ok(vec![result])
508 }
509}
510
#[cfg(test)]
mod tests {
    use super::*;
    use image::{ImageBuffer, Rgb};

    /// Builds a `width` x `height` test image with a named pattern:
    /// - `"solid_red"`: every pixel `[255, 0, 0]`;
    /// - `"gradient"`: red scales with x, green with y, blue fixed at 128;
    /// - `"border"`: red left column, green right column, blue top row,
    ///   yellow bottom row (the columns win at the corners), gray interior;
    /// - anything else: black.
    fn create_test_image(width: u32, height: u32, pattern: &str) -> RgbImage {
        let mut image = ImageBuffer::new(width, height);

        match pattern {
            "solid_red" => {
                for pixel in image.pixels_mut() {
                    *pixel = Rgb([255, 0, 0]);
                }
            }
            "gradient" => {
                for (x, y, pixel) in image.enumerate_pixels_mut() {
                    let r = (x * 255 / width.max(1)) as u8;
                    let g = (y * 255 / height.max(1)) as u8;
                    *pixel = Rgb([r, g, 128]);
                }
            }
            "border" => {
                for (x, y, pixel) in image.enumerate_pixels_mut() {
                    *pixel = if x == 0 {
                        Rgb([255, 0, 0])
                    } else if x == width - 1 {
                        Rgb([0, 255, 0])
                    } else if y == 0 {
                        Rgb([0, 0, 255])
                    } else if y == height - 1 {
                        Rgb([255, 255, 0])
                    } else {
                        Rgb([128, 128, 128])
                    };
                }
            }
            _ => {
                for pixel in image.pixels_mut() {
                    *pixel = Rgb([0, 0, 0]);
                }
            }
        }

        image
    }

    #[test]
    fn test_pad_image_zero_strategy() {
        let image = create_test_image(10, 10, "solid_red");
        let strategy = PaddingStrategy::Zero;
        let result = DefaultDynamicBatcher::pad_image(&image, (20, 20), &strategy).unwrap();

        assert_eq!(result.dimensions(), (20, 20));
        // Border is black; the 10x10 image occupies [5, 15) on both axes.
        assert_eq!(*result.get_pixel(0, 0), Rgb([0, 0, 0]));
        assert_eq!(*result.get_pixel(4, 4), Rgb([0, 0, 0]));
        assert_eq!(*result.get_pixel(15, 15), Rgb([0, 0, 0]));
        assert_eq!(*result.get_pixel(19, 19), Rgb([0, 0, 0]));
        assert_eq!(*result.get_pixel(5, 5), Rgb([255, 0, 0]));
        assert_eq!(*result.get_pixel(10, 10), Rgb([255, 0, 0]));
    }

    #[test]
    fn test_pad_image_center_strategy() {
        let image = create_test_image(10, 10, "solid_red");
        let strategy = PaddingStrategy::Center {
            fill_color: [0, 255, 0],
        };
        let result = DefaultDynamicBatcher::pad_image(&image, (20, 20), &strategy).unwrap();

        assert_eq!(result.dimensions(), (20, 20));
        assert_eq!(*result.get_pixel(0, 0), Rgb([0, 255, 0]));
        assert_eq!(*result.get_pixel(19, 19), Rgb([0, 255, 0]));
        assert_eq!(*result.get_pixel(10, 10), Rgb([255, 0, 0]));
    }

    #[test]
    fn test_pad_image_edge_strategy() {
        let image = create_test_image(6, 6, "border");
        let strategy = PaddingStrategy::Edge;
        let result = DefaultDynamicBatcher::pad_image(&image, (12, 12), &strategy).unwrap();

        assert_eq!(result.dimensions(), (12, 12));
        // Each side replicates the matching edge of the source.
        assert_eq!(*result.get_pixel(0, 6), Rgb([255, 0, 0]));
        assert_eq!(*result.get_pixel(11, 6), Rgb([0, 255, 0]));
        assert_eq!(*result.get_pixel(6, 0), Rgb([0, 0, 255]));
        assert_eq!(*result.get_pixel(6, 11), Rgb([255, 255, 0]));
        assert_eq!(*result.get_pixel(6, 6), Rgb([128, 128, 128]));
        // Corners replicate the source corners (columns win in the pattern).
        assert_eq!(*result.get_pixel(0, 0), Rgb([255, 0, 0]));
        assert_eq!(*result.get_pixel(11, 0), Rgb([0, 255, 0]));
        assert_eq!(*result.get_pixel(0, 11), Rgb([255, 0, 0]));
        assert_eq!(*result.get_pixel(11, 11), Rgb([0, 255, 0]));
    }

    #[test]
    fn test_pad_image_smart_strategy() {
        let image = create_test_image(10, 10, "border");
        let strategy = PaddingStrategy::Smart;
        let result = DefaultDynamicBatcher::pad_image(&image, (20, 20), &strategy).unwrap();

        assert_eq!(result.dimensions(), (20, 20));
        // Perimeter medians are (255, 255, 0), adjusted to (239, 239, 16).
        assert_eq!(*result.get_pixel(0, 0), Rgb([239, 239, 16]));
        assert_eq!(*result.get_pixel(19, 19), Rgb([239, 239, 16]));
        assert_eq!(*result.get_pixel(10, 10), Rgb([128, 128, 128]));
    }

    #[test]
    fn test_pad_image_no_padding_needed() {
        let image = create_test_image(10, 10, "solid_red");
        let strategy = PaddingStrategy::Zero;
        let result = DefaultDynamicBatcher::pad_image(&image, (10, 10), &strategy).unwrap();

        assert_eq!(result, image);
    }

    #[test]
    fn test_pad_image_error_on_oversized_image() {
        let image = create_test_image(20, 20, "solid_red");
        let strategy = PaddingStrategy::Zero;

        assert!(DefaultDynamicBatcher::pad_image(&image, (10, 10), &strategy).is_err());
        assert!(DefaultDynamicBatcher::pad_image(&image, (30, 10), &strategy).is_err());
    }

    #[test]
    fn test_calculate_smart_padding_color() {
        let uniform_image = create_test_image(10, 10, "solid_red");
        assert_eq!(
            DefaultDynamicBatcher::calculate_smart_padding_color(&uniform_image),
            Rgb([239, 16, 16])
        );

        let gradient_image = create_test_image(10, 10, "gradient");
        assert_eq!(
            DefaultDynamicBatcher::calculate_smart_padding_color(&gradient_image),
            Rgb([127, 127, 128])
        );

        let single_pixel = create_test_image(1, 1, "solid_red");
        assert_eq!(
            DefaultDynamicBatcher::calculate_smart_padding_color(&single_pixel),
            Rgb([239, 16, 16])
        );

        let empty: RgbImage = ImageBuffer::new(0, 0);
        assert_eq!(
            DefaultDynamicBatcher::calculate_smart_padding_color(&empty),
            Rgb([0, 0, 0])
        );
    }

    #[test]
    fn test_adjust_padding_component() {
        assert_eq!(DefaultDynamicBatcher::adjust_padding_component(30), 46);
        assert_eq!(DefaultDynamicBatcher::adjust_padding_component(220), 204);
        assert_eq!(DefaultDynamicBatcher::adjust_padding_component(128), 128);
        assert_eq!(DefaultDynamicBatcher::adjust_padding_component(0), 16);
        assert_eq!(DefaultDynamicBatcher::adjust_padding_component(63), 79);
        assert_eq!(DefaultDynamicBatcher::adjust_padding_component(64), 64);
        assert_eq!(DefaultDynamicBatcher::adjust_padding_component(191), 191);
        assert_eq!(DefaultDynamicBatcher::adjust_padding_component(192), 176);
        assert_eq!(DefaultDynamicBatcher::adjust_padding_component(255), 239);
    }

    #[test]
    fn test_are_images_compatible() {
        let tall = create_test_image(10, 20, "solid_red");
        let tall_twice = create_test_image(20, 40, "solid_red");
        let tall_other = create_test_image(10, 20, "gradient");
        let wide = create_test_image(40, 20, "solid_red");

        let exact = ShapeCompatibilityStrategy::Exact;
        assert!(DefaultDynamicBatcher::are_images_compatible(&tall, &tall_other, &exact));
        assert!(!DefaultDynamicBatcher::are_images_compatible(&tall, &tall_twice, &exact));

        let ratio = ShapeCompatibilityStrategy::AspectRatio { tolerance: 0.01 };
        assert!(DefaultDynamicBatcher::are_images_compatible(&tall, &tall_twice, &ratio));
        assert!(!DefaultDynamicBatcher::are_images_compatible(&tall, &wide, &ratio));

        let small_bucket = ShapeCompatibilityStrategy::MaxDimension { bucket_size: 32 };
        assert!(!DefaultDynamicBatcher::are_images_compatible(&tall, &tall_twice, &small_bucket));
        let large_bucket = ShapeCompatibilityStrategy::MaxDimension { bucket_size: 64 };
        assert!(DefaultDynamicBatcher::are_images_compatible(&tall, &tall_twice, &large_bucket));
    }

    #[test]
    fn test_calculate_target_dimensions() {
        let images = vec![
            create_test_image(10, 20, "solid_red"),
            create_test_image(12, 5, "solid_red"),
        ];

        assert_eq!(
            DefaultDynamicBatcher::calculate_target_dimensions(
                &images,
                &ShapeCompatibilityStrategy::AspectRatio { tolerance: 0.1 }
            ),
            (12, 20)
        );
        assert_eq!(
            DefaultDynamicBatcher::calculate_target_dimensions(
                &images,
                &ShapeCompatibilityStrategy::Exact
            ),
            (10, 20)
        );
        assert_eq!(
            DefaultDynamicBatcher::calculate_target_dimensions(
                &[],
                &ShapeCompatibilityStrategy::Exact
            ),
            (0, 0)
        );
    }

    #[test]
    fn test_generate_batch_id() {
        assert_eq!(DefaultDynamicBatcher::generate_batch_id((32, 64), 3), "32x64_3");
        assert_eq!(DefaultDynamicBatcher::generate_batch_id((0, 0), 0), "0x0_0");
    }
}