oar_ocr/core/batch/
mod.rs

1//! Batch processing utilities for the OCR pipeline.
2//!
3//! This module provides structures and functions for handling batched data
4//! in the OCR pipeline, including batching of input data, sampling, and
5//! tensor operations for batched processing.
6
7pub mod dynamic;
8
9use crate::core::traits::Sampler;
10use std::sync::Arc;
11
12/// A 2-dimensional tensor represented as a 2D array of f32 values.
13pub type Tensor2D = ndarray::Array2<f32>;
14
15/// A 3-dimensional tensor represented as a 3D array of f32 values.
16pub type Tensor3D = ndarray::Array3<f32>;
17
18/// A 4-dimensional tensor represented as a 4D array of f32 values.
19pub type Tensor4D = ndarray::Array4<f32>;
20
21/// A 1-dimensional tensor represented as a dynamic-dimensional array of f32 values.
22pub type Tensor1D = ndarray::ArrayD<f32>;
23
/// Data structure for holding batched input data.
///
/// Holds the instances of one batch, their input paths, and each instance's position in
/// the original data set, so the OCR pipeline can process several inputs together.
#[derive(Debug, Clone)]
pub struct BatchData {
    /// The instances in the batch, stored as `Arc<str>` so clones share the string data.
    pub instances: Vec<Arc<str>>,
    /// The input paths for the instances in the batch, stored as `Arc<str>` for cheap sharing.
    pub input_paths: Vec<Arc<str>>,
    /// The indexes of the instances in the original data set.
    pub indexes: Vec<usize>,
}

impl BatchData {
    /// Creates a new `BatchData` from shared `Arc<str>` paths and their indexes.
    ///
    /// Both `instances` and `input_paths` are filled from `paths`. They point at the same
    /// string allocations, so storing the list twice only costs reference-count increments.
    ///
    /// # Arguments
    ///
    /// * `paths` - The paths of the instances in this batch.
    /// * `indexes` - The position of each instance in the original data set. The lengths of
    ///   `paths` and `indexes` are not checked against each other.
    ///
    /// # Returns
    ///
    /// A new `BatchData` instance.
    pub fn from_shared_arc_paths(paths: Vec<Arc<str>>, indexes: Vec<usize>) -> Self {
        // Cloning a `Vec<Arc<str>>` copies pointers and bumps refcounts; no string data is copied.
        let input_paths = paths.clone();
        Self {
            instances: paths,
            input_paths,
            indexes,
        }
    }

    /// Returns the number of instances in the batch.
    pub fn len(&self) -> usize {
        self.instances.len()
    }

    /// Returns `true` if the batch contains no instances.
    pub fn is_empty(&self) -> bool {
        self.instances.is_empty()
    }

    /// Returns an iterator over the instances as string slices, in batch order.
    pub fn instances_as_str(&self) -> impl Iterator<Item = &str> + '_ {
        self.instances.iter().map(|arc| arc.as_ref())
    }

    /// Returns an iterator over the input paths as string slices, in batch order.
    pub fn input_paths_as_str(&self) -> impl Iterator<Item = &str> + '_ {
        self.input_paths.iter().map(|arc| arc.as_ref())
    }
}
93
94/// A sampler that creates batches of data with a specified batch size.
95///
96/// This struct is used to divide data into batches for processing in the OCR pipeline.
97/// It implements the Sampler trait for String data.
98#[derive(Debug)]
99pub struct BatchSampler {
100    /// The size of each batch.
101    batch_size: usize,
102}
103
104impl BatchSampler {
105    /// Creates a new BatchSampler with the specified batch size.
106    ///
107    /// # Arguments
108    ///
109    /// * `batch_size` - The size of each batch.
110    ///
111    /// # Returns
112    ///
113    /// A new BatchSampler instance.
114    pub fn new(batch_size: usize) -> Self {
115        Self { batch_size }
116    }
117
118    /// Returns the batch size.
119    ///
120    /// # Returns
121    ///
122    /// The batch size.
123    pub fn batch_size(&self) -> usize {
124        self.batch_size
125    }
126
127    /// Creates an iterator over batches of data.
128    ///
129    /// # Arguments
130    ///
131    /// * `data` - A slice of data to be batched.
132    ///
133    /// # Returns
134    ///
135    /// An iterator over batches of data.
136    pub fn batches<'a, T>(&self, data: &'a [T]) -> impl Iterator<Item = &'a [T]> {
137        if self.batch_size == 0 {
138            data.chunks(1).take(0)
139        } else {
140            data.chunks(self.batch_size).take(usize::MAX)
141        }
142    }
143
144    /// Creates an iterator over batches of data with their indexes.
145    ///
146    /// # Arguments
147    ///
148    /// * `data` - A slice of data to be batched.
149    ///
150    /// # Returns
151    ///
152    /// An iterator over tuples containing batches of data and their indexes.
153    pub fn batches_with_indexes<'a, T>(
154        &self,
155        data: &'a [T],
156    ) -> impl Iterator<Item = (&'a [T], Vec<usize>)> {
157        let batch_size = if self.batch_size == 0 {
158            1
159        } else {
160            self.batch_size
161        };
162        let take_count = if self.batch_size == 0 { 0 } else { usize::MAX };
163
164        data.chunks(batch_size)
165            .take(take_count)
166            .enumerate()
167            .map(move |(batch_idx, chunk)| {
168                let start_idx = batch_idx * self.batch_size;
169                let indexes: Vec<usize> = (0..chunk.len()).map(|i| start_idx + i).collect();
170                (chunk, indexes)
171            })
172    }
173
174    /// Samples batches of data from a vector of strings.
175    ///
176    /// # Arguments
177    ///
178    /// * `data` - A vector of strings to be batched.
179    ///
180    /// # Returns
181    ///
182    /// A vector of BatchData instances.
183    pub fn sample_batch(&self, data: Vec<String>) -> Vec<BatchData> {
184        if self.batch_size == 0 {
185            return Vec::new();
186        }
187
188        data.chunks(self.batch_size)
189            .enumerate()
190            .map(|(batch_idx, chunk)| {
191                let start_idx = batch_idx * self.batch_size;
192                let indexes: Vec<usize> = (0..chunk.len()).map(|i| start_idx + i).collect();
193
194                BatchData::from_shared_arc_paths(
195                    chunk.iter().map(|s| Arc::from(s.as_str())).collect(),
196                    indexes,
197                )
198            })
199            .collect()
200    }
201}
202
203impl Sampler<String> for BatchSampler {
204    type BatchData = BatchData;
205
206    /// Samples batches of data from a vector of strings.
207    ///
208    /// This method implements the Sampler trait for String data.
209    ///
210    /// # Arguments
211    ///
212    /// * `data` - A vector of strings to be batched.
213    ///
214    /// # Returns
215    ///
216    /// A vector of BatchData instances.
217    fn sample(&self, data: Vec<String>) -> Vec<Self::BatchData> {
218        self.sample_batch(data)
219    }
220}
221
/// Converts per-image `f32` buffers into one batched tensor buffer.
///
/// `ToBatch` holds no state, so it is `Copy`. Its methods validate the inputs and pack the
/// images into a single flat buffer suitable for batched processing in the OCR pipeline.
#[derive(Debug, Default, Clone, Copy)]
pub struct ToBatch;
228
229impl ToBatch {
230    /// Creates a new ToBatch instance.
231    ///
232    /// # Returns
233    ///
234    /// A new ToBatch instance.
235    pub fn new() -> Self {
236        ToBatch
237    }
238
239    /// Validates the input images and their shapes.
240    ///
241    /// This method checks that the images and shapes arrays have the same length,
242    /// that all images have the correct number of elements for their shapes,
243    /// and that all dimensions are greater than zero.
244    ///
245    /// # Arguments
246    ///
247    /// * `imgs` - A slice of vectors of f32 values representing the images.
248    /// * `shapes` - A slice of tuples representing the shapes of the images (channels, height, width).
249    ///
250    /// # Returns
251    ///
252    /// A Result indicating success or an OCRError if validation fails.
253    pub fn validate_inputs(
254        &self,
255        imgs: &[Vec<f32>],
256        shapes: &[(usize, usize, usize)],
257    ) -> Result<(), crate::core::OCRError> {
258        if imgs.is_empty() && shapes.is_empty() {
259            return Ok(());
260        }
261
262        if imgs.is_empty() {
263            return Err(crate::core::OCRError::InvalidInput {
264                message: "Images array is empty but shapes array is not".to_string(),
265            });
266        }
267
268        if shapes.is_empty() {
269            return Err(crate::core::OCRError::InvalidInput {
270                message: "Shapes array is empty but images array is not".to_string(),
271            });
272        }
273
274        if imgs.len() != shapes.len() {
275            return Err(crate::core::OCRError::InvalidInput {
276                message: format!(
277                    "Images and shapes must have the same length: got {} images and {} shapes",
278                    imgs.len(),
279                    shapes.len()
280                ),
281            });
282        }
283
284        for (i, (img, &(c, h, w))) in imgs.iter().zip(shapes.iter()).enumerate() {
285            let expected_len = c * h * w;
286            if img.len() != expected_len {
287                return Err(crate::core::OCRError::InvalidInput {
288                    message: format!(
289                        "Image {} has {} elements but shape ({}, {}, {}) requires {}",
290                        i,
291                        img.len(),
292                        c,
293                        h,
294                        w,
295                        expected_len
296                    ),
297                });
298            }
299
300            if c == 0 || h == 0 || w == 0 {
301                return Err(crate::core::OCRError::InvalidInput {
302                    message: format!(
303                        "Image {i} has invalid shape dimensions ({c}, {h}, {w}): all must be greater than 0"
304                    ),
305                });
306            }
307
308            if expected_len > crate::core::constants::MAX_TENSOR_SIZE {
309                return Err(crate::core::OCRError::InvalidInput {
310                    message: format!(
311                        "Image {} tensor size {} exceeds maximum allowed size {}",
312                        i,
313                        expected_len,
314                        crate::core::constants::MAX_TENSOR_SIZE
315                    ),
316                });
317            }
318        }
319
320        Ok(())
321    }
322
323    /// Applies the batch conversion to the input images and shapes.
324    ///
325    /// This method validates the inputs, then converts the images into a batched tensor format.
326    /// If all images have the same dimensions, it uses a more efficient contiguous copying method.
327    /// Otherwise, it uses a method that handles mixed dimensions.
328    ///
329    /// # Arguments
330    ///
331    /// * `imgs` - A slice of vectors of f32 values representing the images.
332    /// * `shapes` - A slice of tuples representing the shapes of the images (channels, height, width).
333    ///
334    /// # Returns
335    ///
336    /// A Result containing a vector of f32 values representing the batched tensor,
337    /// or an OCRError if the operation fails.
338    pub fn apply(
339        &self,
340        imgs: &[Vec<f32>],
341        shapes: &[(usize, usize, usize)],
342    ) -> Result<Vec<f32>, crate::core::OCRError> {
343        self.validate_inputs(imgs, shapes)?;
344
345        if imgs.is_empty() {
346            return Ok(Vec::new());
347        }
348
349        let batch_size = imgs.len();
350        let first_shape = shapes.first().copied().unwrap_or((0, 0, 0));
351        let channels = first_shape.0;
352        let mut max_height = first_shape.1;
353        let mut max_width = first_shape.2;
354        let mut all_same_dimensions = true;
355
356        for (i, &(c, h, w)) in shapes.iter().enumerate() {
357            if c != channels {
358                return Err(crate::core::OCRError::InvalidInput {
359                    message: format!(
360                        "All images must have the same channel count: image 0 has {channels} channels, image {i} has {c} channels"
361                    ),
362                });
363            }
364
365            if h > max_height {
366                max_height = h;
367            }
368            if w > max_width {
369                max_width = w;
370            }
371            if all_same_dimensions && (h != first_shape.1 || w != first_shape.2) {
372                all_same_dimensions = false;
373            }
374        }
375
376        let tensor_size = batch_size * channels * max_height * max_width;
377        let mut batch_tensor = vec![0.0; tensor_size];
378
379        if all_same_dimensions {
380            self.apply_contiguous(imgs, &mut batch_tensor, channels, max_height, max_width);
381        } else {
382            self.apply_mixed_dimensions(
383                imgs,
384                shapes,
385                &mut batch_tensor,
386                channels,
387                max_height,
388                max_width,
389            );
390        }
391
392        Ok(batch_tensor)
393    }
394
395    /// Applies contiguous copying for images with the same dimensions.
396    ///
397    /// This method is used when all images in the batch have the same dimensions,
398    /// allowing for more efficient copying.
399    ///
400    /// # Arguments
401    ///
402    /// * `imgs` - A slice of vectors of f32 values representing the images.
403    /// * `batch_tensor` - A mutable slice of f32 values representing the batched tensor.
404    /// * `channels` - The number of channels in the images.
405    /// * `height` - The height of the images.
406    /// * `width` - The width of the images.
407    fn apply_contiguous(
408        &self,
409        imgs: &[Vec<f32>],
410        batch_tensor: &mut [f32],
411        channels: usize,
412        height: usize,
413        width: usize,
414    ) {
415        let img_size = channels * height * width;
416
417        for (batch_idx, img) in imgs.iter().enumerate() {
418            let batch_offset = batch_idx * img_size;
419            let dst_slice = &mut batch_tensor[batch_offset..batch_offset + img.len()];
420
421            dst_slice.copy_from_slice(img);
422        }
423    }
424
425    /// Applies copying for images with mixed dimensions.
426    ///
427    /// This method is used when images in the batch have different dimensions,
428    /// requiring padding to the maximum dimensions.
429    ///
430    /// # Arguments
431    ///
432    /// * `imgs` - A slice of vectors of f32 values representing the images.
433    /// * `shapes` - A slice of tuples representing the shapes of the images (channels, height, width).
434    /// * `batch_tensor` - A mutable slice of f32 values representing the batched tensor.
435    /// * `channels` - The number of channels in the images.
436    /// * `max_height` - The maximum height among all images in the batch.
437    /// * `max_width` - The maximum width among all images in the batch.
438    fn apply_mixed_dimensions(
439        &self,
440        imgs: &[Vec<f32>],
441        shapes: &[(usize, usize, usize)],
442        batch_tensor: &mut [f32],
443        channels: usize,
444        max_height: usize,
445        max_width: usize,
446    ) {
447        for (batch_idx, (img, &(c, h, w))) in imgs.iter().zip(shapes.iter()).enumerate() {
448            let batch_base = batch_idx * channels * max_height * max_width;
449
450            for ch in 0..c {
451                let src_channel_start = ch * h * w;
452                let dst_channel_start = batch_base + ch * max_height * max_width;
453
454                for y in 0..h {
455                    let src_row_start = src_channel_start + y * w;
456                    let dst_row_start = dst_channel_start + y * max_width;
457
458                    let src_row = &img[src_row_start..src_row_start + w];
459                    let dst_row = &mut batch_tensor[dst_row_start..dst_row_start + w];
460                    dst_row.copy_from_slice(src_row);
461                }
462            }
463        }
464    }
465}
466
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_to_batch_apply_contiguous() {
        let to_batch = ToBatch::new();

        // Two 1x2x2 images with identical shapes take the contiguous-copy path.
        let imgs = vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]];
        let shapes = vec![(1, 2, 2), (1, 2, 2)];

        let result = to_batch.apply(&imgs, &shapes).unwrap();

        // batch=2, channels=1, height=2, width=2: the images are concatenated in order.
        assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_to_batch_apply_mixed_dimensions() {
        let to_batch = ToBatch::new();

        // A 1x1x2 image and a 1x2x2 image force padding to max_height=2, max_width=2.
        let imgs = vec![vec![1.0, 2.0], vec![3.0, 4.0, 5.0, 6.0]];
        let shapes = vec![(1, 1, 2), (1, 2, 2)];

        let result = to_batch.apply(&imgs, &shapes).unwrap();

        // The smaller image fills the top row of its 2x2 slot; the bottom row is zero padding.
        assert_eq!(result, vec![1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_to_batch_pads_each_row_to_max_width() {
        let to_batch = ToBatch::new();

        // A 1x2x1 (tall) and a 1x1x2 (wide) image: both are padded to 2x2.
        let imgs = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let shapes = vec![(1, 2, 1), (1, 1, 2)];

        let result = to_batch.apply(&imgs, &shapes).unwrap();

        // Rows of the tall image are written with a stride of max_width (2).
        assert_eq!(result, vec![1.0, 0.0, 2.0, 0.0, 3.0, 4.0, 0.0, 0.0]);
    }

    #[test]
    fn test_to_batch_empty_input() {
        let result = ToBatch::new().apply(&[], &[]).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_to_batch_validation_errors() {
        let to_batch = ToBatch::new();

        // Element count does not match the declared shape.
        assert!(to_batch
            .validate_inputs(&[vec![1.0, 2.0, 3.0]], &[(1, 2, 2)])
            .is_err());
        // One side empty, the other not.
        assert!(to_batch.validate_inputs(&[vec![1.0]], &[]).is_err());
        assert!(to_batch.validate_inputs(&[], &[(1, 1, 1)]).is_err());
        // Zero-sized dimension.
        assert!(to_batch.validate_inputs(&[Vec::new()], &[(0, 1, 1)]).is_err());
        // Length mismatch between images and shapes.
        assert!(to_batch
            .validate_inputs(&[vec![1.0], vec![2.0]], &[(1, 1, 1)])
            .is_err());
    }

    #[test]
    fn test_to_batch_rejects_mismatched_channels() {
        let imgs = vec![vec![0.0; 4], vec![0.0; 8]];
        let shapes = vec![(1, 2, 2), (2, 2, 2)];
        assert!(ToBatch::new().apply(&imgs, &shapes).is_err());
    }

    #[test]
    fn test_batch_sampler_indexes_and_zero_size() {
        let data = vec!["a".to_string(), "b".to_string(), "c".to_string()];

        let batches = BatchSampler::new(2).sample_batch(data.clone());
        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].indexes, vec![0, 1]);
        assert_eq!(batches[1].indexes, vec![2]);
        assert_eq!(batches[1].instances_as_str().collect::<Vec<_>>(), vec!["c"]);
        assert_eq!(
            batches[0].input_paths_as_str().collect::<Vec<_>>(),
            vec!["a", "b"]
        );

        // A zero batch size produces no batches rather than panicking.
        assert!(BatchSampler::new(0).sample_batch(data).is_empty());
    }
}