Skip to main content

oar_ocr_core/core/batch/
mod.rs

1//! Batch processing utilities for the OCR pipeline.
2//!
3//! This module provides structures and functions for handling batched data
4//! in the OCR pipeline, including batching of input data, sampling, and
5//! tensor operations for batched processing.
6
7pub mod dynamic;
8
9use crate::core::traits::Sampler;
10use std::sync::Arc;
11
/// One batch of inputs travelling through the OCR pipeline.
///
/// A batch bundles three parallel vectors: the instances themselves, the paths
/// they came from, and the index each instance had in the original data set.
pub struct BatchData {
    /// Instances that make up the batch, held as `Arc<str>` so they can be shared cheaply.
    pub instances: Vec<Arc<str>>,
    /// Input path of each instance, held as `Arc<str>` so they can be shared cheaply.
    pub input_paths: Vec<Arc<str>>,
    /// Index of each instance within the original data set.
    pub indexes: Vec<usize>,
}

impl BatchData {
    /// Builds a batch from already-shared paths and their indexes.
    ///
    /// The same `Arc<str>` handles populate both `instances` and `input_paths`.
    ///
    /// # Arguments
    ///
    /// * `paths` - Shared path strings, one per instance.
    /// * `indexes` - Index of each instance in the original data set.
    ///
    /// # Returns
    ///
    /// The assembled `BatchData`.
    pub fn from_shared_arc_paths(paths: Vec<Arc<str>>, indexes: Vec<usize>) -> Self {
        Self {
            // Only the reference counts are bumped here; the string data is not copied.
            input_paths: paths.iter().map(Arc::clone).collect(),
            instances: paths,
            indexes,
        }
    }

    /// Number of instances held by this batch.
    ///
    /// # Returns
    ///
    /// The length of the `instances` vector.
    pub fn len(&self) -> usize {
        self.instances.len()
    }

    /// Reports whether the batch holds no instances.
    ///
    /// # Returns
    ///
    /// `true` when `instances` is empty, `false` otherwise.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Borrows every instance as a plain `&str`.
    ///
    /// # Returns
    ///
    /// An iterator yielding the instances in batch order.
    pub fn instances_as_str(&self) -> impl Iterator<Item = &str> + '_ {
        self.instances.iter().map(|instance| &**instance)
    }

    /// Borrows every input path as a plain `&str`.
    ///
    /// # Returns
    ///
    /// An iterator yielding the input paths in batch order.
    pub fn input_paths_as_str(&self) -> impl Iterator<Item = &str> + '_ {
        self.input_paths.iter().map(|path| &**path)
    }
}
81
82/// A sampler that creates batches of data with a specified batch size.
83///
84/// This struct is used to divide data into batches for processing in the OCR pipeline.
85/// It implements the Sampler trait for String data.
86#[derive(Debug)]
87pub struct BatchSampler {
88    /// The size of each batch.
89    batch_size: usize,
90}
91
92impl BatchSampler {
93    /// Creates a new BatchSampler with the specified batch size.
94    ///
95    /// # Arguments
96    ///
97    /// * `batch_size` - The size of each batch.
98    ///
99    /// # Returns
100    ///
101    /// A new BatchSampler instance.
102    pub fn new(batch_size: usize) -> Self {
103        Self { batch_size }
104    }
105
106    /// Returns the batch size.
107    ///
108    /// # Returns
109    ///
110    /// The batch size.
111    pub fn batch_size(&self) -> usize {
112        self.batch_size
113    }
114
115    /// Creates an iterator over batches of data.
116    ///
117    /// # Arguments
118    ///
119    /// * `data` - A slice of data to be batched.
120    ///
121    /// # Returns
122    ///
123    /// An iterator over batches of data.
124    pub fn batches<'a, T>(&self, data: &'a [T]) -> impl Iterator<Item = &'a [T]> {
125        if self.batch_size == 0 {
126            data.chunks(1).take(0)
127        } else {
128            data.chunks(self.batch_size).take(usize::MAX)
129        }
130    }
131
132    /// Creates an iterator over batches of data with their indexes.
133    ///
134    /// # Arguments
135    ///
136    /// * `data` - A slice of data to be batched.
137    ///
138    /// # Returns
139    ///
140    /// An iterator over tuples containing batches of data and their indexes.
141    pub fn batches_with_indexes<'a, T>(
142        &self,
143        data: &'a [T],
144    ) -> impl Iterator<Item = (&'a [T], Vec<usize>)> {
145        let batch_size = if self.batch_size == 0 {
146            1
147        } else {
148            self.batch_size
149        };
150        let take_count = if self.batch_size == 0 { 0 } else { usize::MAX };
151
152        data.chunks(batch_size)
153            .take(take_count)
154            .enumerate()
155            .map(move |(batch_idx, chunk)| {
156                let start_idx = batch_idx * self.batch_size;
157                let indexes: Vec<usize> = (0..chunk.len()).map(|i| start_idx + i).collect();
158                (chunk, indexes)
159            })
160    }
161
162    /// Samples batches of data from a vector of strings.
163    ///
164    /// # Arguments
165    ///
166    /// * `data` - A vector of strings to be batched.
167    ///
168    /// # Returns
169    ///
170    /// A vector of BatchData instances.
171    pub fn sample_batch(&self, data: Vec<String>) -> Vec<BatchData> {
172        if self.batch_size == 0 {
173            return Vec::new();
174        }
175
176        data.chunks(self.batch_size)
177            .enumerate()
178            .map(|(batch_idx, chunk)| {
179                let start_idx = batch_idx * self.batch_size;
180                let indexes: Vec<usize> = (0..chunk.len()).map(|i| start_idx + i).collect();
181
182                BatchData::from_shared_arc_paths(
183                    chunk.iter().map(|s| Arc::from(s.as_str())).collect(),
184                    indexes,
185                )
186            })
187            .collect()
188    }
189}
190
191impl Sampler<String> for BatchSampler {
192    type BatchData = BatchData;
193
194    /// Samples batches of data from a vector of strings.
195    ///
196    /// This method implements the Sampler trait for String data.
197    ///
198    /// # Arguments
199    ///
200    /// * `data` - A vector of strings to be batched.
201    ///
202    /// # Returns
203    ///
204    /// A vector of BatchData instances.
205    fn sample(&self, data: Vec<String>) -> Vec<Self::BatchData> {
206        self.sample_batch(data)
207    }
208}
209
210/// A struct for converting image data into batched tensor format.
211///
212/// This struct provides methods for validating input data and converting
213/// images into a batched tensor format suitable for processing in the OCR pipeline.
214#[derive(Debug, Default)]
215pub struct ToBatch;
216
217impl ToBatch {
218    /// Creates a new ToBatch instance.
219    ///
220    /// # Returns
221    ///
222    /// A new ToBatch instance.
223    pub fn new() -> Self {
224        ToBatch
225    }
226
227    /// Validates the input images and their shapes.
228    ///
229    /// This method checks that the images and shapes arrays have the same length,
230    /// that all images have the correct number of elements for their shapes,
231    /// and that all dimensions are greater than zero.
232    ///
233    /// # Arguments
234    ///
235    /// * `imgs` - A slice of vectors of f32 values representing the images.
236    /// * `shapes` - A slice of tuples representing the shapes of the images (channels, height, width).
237    ///
238    /// # Returns
239    ///
240    /// A Result indicating success or an OCRError if validation fails.
241    pub fn validate_inputs(
242        &self,
243        imgs: &[Vec<f32>],
244        shapes: &[(usize, usize, usize)],
245    ) -> Result<(), crate::core::OCRError> {
246        if imgs.is_empty() && shapes.is_empty() {
247            return Ok(());
248        }
249
250        if imgs.is_empty() {
251            return Err(crate::core::OCRError::InvalidInput {
252                message: "Images array is empty but shapes array is not".to_string(),
253            });
254        }
255
256        if shapes.is_empty() {
257            return Err(crate::core::OCRError::InvalidInput {
258                message: "Shapes array is empty but images array is not".to_string(),
259            });
260        }
261
262        if imgs.len() != shapes.len() {
263            return Err(crate::core::OCRError::InvalidInput {
264                message: format!(
265                    "Images and shapes must have the same length: got {} images and {} shapes",
266                    imgs.len(),
267                    shapes.len()
268                ),
269            });
270        }
271
272        for (i, (img, &(c, h, w))) in imgs.iter().zip(shapes.iter()).enumerate() {
273            let expected_len = c * h * w;
274            if img.len() != expected_len {
275                return Err(crate::core::OCRError::InvalidInput {
276                    message: format!(
277                        "Image {} has {} elements but shape ({}, {}, {}) requires {}",
278                        i,
279                        img.len(),
280                        c,
281                        h,
282                        w,
283                        expected_len
284                    ),
285                });
286            }
287
288            if c == 0 || h == 0 || w == 0 {
289                return Err(crate::core::OCRError::InvalidInput {
290                    message: format!(
291                        "Image {i} has invalid shape dimensions ({c}, {h}, {w}): all must be greater than 0"
292                    ),
293                });
294            }
295
296            if expected_len > crate::core::constants::MAX_TENSOR_SIZE {
297                return Err(crate::core::OCRError::InvalidInput {
298                    message: format!(
299                        "Image {} tensor size {} exceeds maximum allowed size {}",
300                        i,
301                        expected_len,
302                        crate::core::constants::MAX_TENSOR_SIZE
303                    ),
304                });
305            }
306        }
307
308        Ok(())
309    }
310
311    /// Applies the batch conversion to the input images and shapes.
312    ///
313    /// This method validates the inputs, then converts the images into a batched tensor format.
314    /// If all images have the same dimensions, it uses a more efficient contiguous copying method.
315    /// Otherwise, it uses a method that handles mixed dimensions.
316    ///
317    /// # Arguments
318    ///
319    /// * `imgs` - A slice of vectors of f32 values representing the images.
320    /// * `shapes` - A slice of tuples representing the shapes of the images (channels, height, width).
321    ///
322    /// # Returns
323    ///
324    /// A Result containing a vector of f32 values representing the batched tensor,
325    /// or an OCRError if the operation fails.
326    pub fn apply(
327        &self,
328        imgs: &[Vec<f32>],
329        shapes: &[(usize, usize, usize)],
330    ) -> Result<Vec<f32>, crate::core::OCRError> {
331        self.validate_inputs(imgs, shapes)?;
332
333        if imgs.is_empty() {
334            return Ok(Vec::new());
335        }
336
337        let batch_size = imgs.len();
338        let first_shape = shapes.first().copied().unwrap_or((0, 0, 0));
339        let channels = first_shape.0;
340        let mut max_height = first_shape.1;
341        let mut max_width = first_shape.2;
342        let mut all_same_dimensions = true;
343
344        for (i, &(c, h, w)) in shapes.iter().enumerate() {
345            if c != channels {
346                return Err(crate::core::OCRError::InvalidInput {
347                    message: format!(
348                        "All images must have the same channel count: image 0 has {channels} channels, image {i} has {c} channels"
349                    ),
350                });
351            }
352
353            if h > max_height {
354                max_height = h;
355            }
356            if w > max_width {
357                max_width = w;
358            }
359            if all_same_dimensions && (h != first_shape.1 || w != first_shape.2) {
360                all_same_dimensions = false;
361            }
362        }
363
364        let tensor_size = batch_size * channels * max_height * max_width;
365        let mut batch_tensor = vec![0.0; tensor_size];
366
367        if all_same_dimensions {
368            self.apply_contiguous(imgs, &mut batch_tensor, channels, max_height, max_width);
369        } else {
370            self.apply_mixed_dimensions(
371                imgs,
372                shapes,
373                &mut batch_tensor,
374                channels,
375                max_height,
376                max_width,
377            );
378        }
379
380        Ok(batch_tensor)
381    }
382
383    /// Applies contiguous copying for images with the same dimensions.
384    ///
385    /// This method is used when all images in the batch have the same dimensions,
386    /// allowing for more efficient copying.
387    ///
388    /// # Arguments
389    ///
390    /// * `imgs` - A slice of vectors of f32 values representing the images.
391    /// * `batch_tensor` - A mutable slice of f32 values representing the batched tensor.
392    /// * `channels` - The number of channels in the images.
393    /// * `height` - The height of the images.
394    /// * `width` - The width of the images.
395    fn apply_contiguous(
396        &self,
397        imgs: &[Vec<f32>],
398        batch_tensor: &mut [f32],
399        channels: usize,
400        height: usize,
401        width: usize,
402    ) {
403        let img_size = channels * height * width;
404
405        for (batch_idx, img) in imgs.iter().enumerate() {
406            let batch_offset = batch_idx * img_size;
407            let dst_slice = &mut batch_tensor[batch_offset..batch_offset + img.len()];
408
409            dst_slice.copy_from_slice(img);
410        }
411    }
412
413    /// Applies copying for images with mixed dimensions.
414    ///
415    /// This method is used when images in the batch have different dimensions,
416    /// requiring padding to the maximum dimensions.
417    ///
418    /// # Arguments
419    ///
420    /// * `imgs` - A slice of vectors of f32 values representing the images.
421    /// * `shapes` - A slice of tuples representing the shapes of the images (channels, height, width).
422    /// * `batch_tensor` - A mutable slice of f32 values representing the batched tensor.
423    /// * `channels` - The number of channels in the images.
424    /// * `max_height` - The maximum height among all images in the batch.
425    /// * `max_width` - The maximum width among all images in the batch.
426    fn apply_mixed_dimensions(
427        &self,
428        imgs: &[Vec<f32>],
429        shapes: &[(usize, usize, usize)],
430        batch_tensor: &mut [f32],
431        channels: usize,
432        max_height: usize,
433        max_width: usize,
434    ) {
435        for (batch_idx, (img, &(c, h, w))) in imgs.iter().zip(shapes.iter()).enumerate() {
436            let batch_base = batch_idx * channels * max_height * max_width;
437
438            for ch in 0..c {
439                let src_channel_start = ch * h * w;
440                let dst_channel_start = batch_base + ch * max_height * max_width;
441
442                for y in 0..h {
443                    let src_row_start = src_channel_start + y * w;
444                    let dst_row_start = dst_channel_start + y * max_width;
445
446                    let src_row = &img[src_row_start..src_row_start + w];
447                    let dst_row = &mut batch_tensor[dst_row_start..dst_row_start + w];
448                    dst_row.copy_from_slice(src_row);
449                }
450            }
451        }
452    }
453}
454
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::OCRError;

    /// Images that share one shape are laid out back to back, unchanged.
    #[test]
    fn test_to_batch_apply_contiguous() -> Result<(), OCRError> {
        let to_batch = ToBatch::new();

        // Two 1x2x2 images.
        let imgs = vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]];
        let shapes = vec![(1, 2, 2), (1, 2, 2)];

        let result = to_batch.apply(&imgs, &shapes)?;

        // batch_size=2, channels=1, height=2, width=2 -> 8 elements,
        // first image at positions 0-3 and second image at positions 4-7.
        assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        Ok(())
    }

    /// A shorter image is padded with a zero row at the bottom.
    #[test]
    fn test_to_batch_apply_mixed_dimensions() -> Result<(), OCRError> {
        let to_batch = ToBatch::new();

        // A 1x1x2 image followed by a 1x2x2 image.
        let imgs = vec![vec![1.0, 2.0], vec![3.0, 4.0, 5.0, 6.0]];
        let shapes = vec![(1, 1, 2), (1, 2, 2)];

        let result = to_batch.apply(&imgs, &shapes)?;

        // batch_size=2, channels=1, max_height=2, max_width=2 -> 8 elements.
        // The first image fills row 0 of its slot and leaves row 1 zeroed;
        // the second image fits its slot exactly.
        assert_eq!(result, vec![1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 5.0, 6.0]);
        Ok(())
    }

    /// A narrower image is padded with zeros at the end of every row.
    #[test]
    fn test_to_batch_apply_pads_width() -> Result<(), OCRError> {
        let to_batch = ToBatch::new();

        // A 1x2x1 image followed by a 1x2x2 image.
        let imgs = vec![vec![1.0, 2.0], vec![3.0, 4.0, 5.0, 6.0]];
        let shapes = vec![(1, 2, 1), (1, 2, 2)];

        let result = to_batch.apply(&imgs, &shapes)?;

        assert_eq!(result, vec![1.0, 0.0, 2.0, 0.0, 3.0, 4.0, 5.0, 6.0]);
        Ok(())
    }

    /// An empty batch converts to an empty tensor.
    #[test]
    fn test_to_batch_apply_empty() -> Result<(), OCRError> {
        let to_batch = ToBatch::new();

        let result = to_batch.apply(&[], &[])?;

        assert!(result.is_empty());
        Ok(())
    }

    /// Inputs whose element count or channel count is inconsistent are rejected.
    #[test]
    fn test_to_batch_apply_rejects_invalid_inputs() {
        let to_batch = ToBatch::new();

        // One element supplied, but the shape requires four.
        assert!(to_batch.apply(&[vec![1.0]], &[(1, 2, 2)]).is_err());

        // Number of images and shapes differ.
        assert!(to_batch.apply(&[vec![1.0]], &[(1, 1, 1), (1, 1, 1)]).is_err());

        // Channel counts differ between images.
        assert!(to_batch
            .apply(&[vec![1.0], vec![1.0, 2.0]], &[(1, 1, 1), (2, 1, 1)])
            .is_err());
    }
}