oar_ocr/core/batch/mod.rs
1//! Batch processing utilities for the OCR pipeline.
2//!
3//! This module provides structures and functions for handling batched data
4//! in the OCR pipeline, including batching of input data, sampling, and
5//! tensor operations for batched processing.
6
7pub mod dynamic;
8
9use crate::core::traits::Sampler;
10use std::sync::Arc;
11
/// A 2-dimensional `f32` tensor (alias for `ndarray::Array2<f32>`).
pub type Tensor2D = ndarray::Array2<f32>;

/// A 3-dimensional `f32` tensor (alias for `ndarray::Array3<f32>`).
pub type Tensor3D = ndarray::Array3<f32>;

/// A 4-dimensional `f32` tensor (alias for `ndarray::Array4<f32>`).
pub type Tensor4D = ndarray::Array4<f32>;

/// An `f32` tensor of dynamic rank (alias for `ndarray::ArrayD<f32>`).
///
/// NOTE(review): despite the `1D` in its name, `ArrayD` does not fix the rank at
/// the type level, so a value of this type may have any number of dimensions.
/// Code that relies on rank 1 should check `ndim()` — confirm the intended usage
/// with the callers of this alias.
pub type Tensor1D = ndarray::ArrayD<f32>;
23
/// One batch of inputs flowing through the OCR pipeline.
///
/// A batch carries the item identifiers (`instances`), the path each item came
/// from (`input_paths`) and each item's position in the original, un-batched
/// input (`indexes`). Strings are held as `Arc<str>`, so duplicating an entry
/// only bumps a reference count instead of copying text.
pub struct BatchData {
    /// The items of this batch, shared via `Arc<str>`.
    pub instances: Vec<Arc<str>>,
    /// The source path of each item, shared via `Arc<str>`.
    /// `from_shared_arc_paths` fills this with the same handles as `instances`.
    pub input_paths: Vec<Arc<str>>,
    /// The position of each item in the original input sequence.
    pub indexes: Vec<usize>,
}

impl BatchData {
    /// Builds a batch whose `instances` and `input_paths` both hold `paths`.
    ///
    /// The two vectors share the same `Arc<str>` handles, so no string data is
    /// duplicated.
    ///
    /// # Arguments
    ///
    /// * `paths` - Shared path strings, one per item in the batch.
    /// * `indexes` - Position of each item in the original input sequence.
    ///
    /// # Returns
    ///
    /// The assembled `BatchData`.
    pub fn from_shared_arc_paths(paths: Vec<Arc<str>>, indexes: Vec<usize>) -> Self {
        Self {
            // Evaluated first: borrows `paths` to take a second handle to each entry...
            input_paths: paths.iter().map(Arc::clone).collect(),
            // ...before `paths` itself is moved into `instances`.
            instances: paths,
            indexes,
        }
    }

    /// Number of items in the batch, i.e. the length of `instances`.
    pub fn len(&self) -> usize {
        self.instances.len()
    }

    /// Returns `true` when `instances` holds no items.
    pub fn is_empty(&self) -> bool {
        self.instances.is_empty()
    }

    /// Borrows every entry of `instances` as a `&str`, in batch order.
    pub fn instances_as_str(&self) -> impl Iterator<Item = &str> + '_ {
        self.instances.iter().map(|shared| &**shared)
    }

    /// Borrows every entry of `input_paths` as a `&str`, in batch order.
    pub fn input_paths_as_str(&self) -> impl Iterator<Item = &str> + '_ {
        self.input_paths.iter().map(|shared| &**shared)
    }
}
93
94/// A sampler that creates batches of data with a specified batch size.
95///
96/// This struct is used to divide data into batches for processing in the OCR pipeline.
97/// It implements the Sampler trait for String data.
98#[derive(Debug)]
99pub struct BatchSampler {
100 /// The size of each batch.
101 batch_size: usize,
102}
103
104impl BatchSampler {
105 /// Creates a new BatchSampler with the specified batch size.
106 ///
107 /// # Arguments
108 ///
109 /// * `batch_size` - The size of each batch.
110 ///
111 /// # Returns
112 ///
113 /// A new BatchSampler instance.
114 pub fn new(batch_size: usize) -> Self {
115 Self { batch_size }
116 }
117
118 /// Returns the batch size.
119 ///
120 /// # Returns
121 ///
122 /// The batch size.
123 pub fn batch_size(&self) -> usize {
124 self.batch_size
125 }
126
127 /// Creates an iterator over batches of data.
128 ///
129 /// # Arguments
130 ///
131 /// * `data` - A slice of data to be batched.
132 ///
133 /// # Returns
134 ///
135 /// An iterator over batches of data.
136 pub fn batches<'a, T>(&self, data: &'a [T]) -> impl Iterator<Item = &'a [T]> {
137 if self.batch_size == 0 {
138 data.chunks(1).take(0)
139 } else {
140 data.chunks(self.batch_size).take(usize::MAX)
141 }
142 }
143
144 /// Creates an iterator over batches of data with their indexes.
145 ///
146 /// # Arguments
147 ///
148 /// * `data` - A slice of data to be batched.
149 ///
150 /// # Returns
151 ///
152 /// An iterator over tuples containing batches of data and their indexes.
153 pub fn batches_with_indexes<'a, T>(
154 &self,
155 data: &'a [T],
156 ) -> impl Iterator<Item = (&'a [T], Vec<usize>)> {
157 let batch_size = if self.batch_size == 0 {
158 1
159 } else {
160 self.batch_size
161 };
162 let take_count = if self.batch_size == 0 { 0 } else { usize::MAX };
163
164 data.chunks(batch_size)
165 .take(take_count)
166 .enumerate()
167 .map(move |(batch_idx, chunk)| {
168 let start_idx = batch_idx * self.batch_size;
169 let indexes: Vec<usize> = (0..chunk.len()).map(|i| start_idx + i).collect();
170 (chunk, indexes)
171 })
172 }
173
174 /// Samples batches of data from a vector of strings.
175 ///
176 /// # Arguments
177 ///
178 /// * `data` - A vector of strings to be batched.
179 ///
180 /// # Returns
181 ///
182 /// A vector of BatchData instances.
183 pub fn sample_batch(&self, data: Vec<String>) -> Vec<BatchData> {
184 if self.batch_size == 0 {
185 return Vec::new();
186 }
187
188 data.chunks(self.batch_size)
189 .enumerate()
190 .map(|(batch_idx, chunk)| {
191 let start_idx = batch_idx * self.batch_size;
192 let indexes: Vec<usize> = (0..chunk.len()).map(|i| start_idx + i).collect();
193
194 BatchData::from_shared_arc_paths(
195 chunk.iter().map(|s| Arc::from(s.as_str())).collect(),
196 indexes,
197 )
198 })
199 .collect()
200 }
201}
202
203impl Sampler<String> for BatchSampler {
204 type BatchData = BatchData;
205
206 /// Samples batches of data from a vector of strings.
207 ///
208 /// This method implements the Sampler trait for String data.
209 ///
210 /// # Arguments
211 ///
212 /// * `data` - A vector of strings to be batched.
213 ///
214 /// # Returns
215 ///
216 /// A vector of BatchData instances.
217 fn sample(&self, data: Vec<String>) -> Vec<Self::BatchData> {
218 self.sample_batch(data)
219 }
220}
221
222/// A struct for converting image data into batched tensor format.
223///
224/// This struct provides methods for validating input data and converting
225/// images into a batched tensor format suitable for processing in the OCR pipeline.
226#[derive(Debug, Default)]
227pub struct ToBatch;
228
229impl ToBatch {
230 /// Creates a new ToBatch instance.
231 ///
232 /// # Returns
233 ///
234 /// A new ToBatch instance.
235 pub fn new() -> Self {
236 ToBatch
237 }
238
239 /// Validates the input images and their shapes.
240 ///
241 /// This method checks that the images and shapes arrays have the same length,
242 /// that all images have the correct number of elements for their shapes,
243 /// and that all dimensions are greater than zero.
244 ///
245 /// # Arguments
246 ///
247 /// * `imgs` - A slice of vectors of f32 values representing the images.
248 /// * `shapes` - A slice of tuples representing the shapes of the images (channels, height, width).
249 ///
250 /// # Returns
251 ///
252 /// A Result indicating success or an OCRError if validation fails.
253 pub fn validate_inputs(
254 &self,
255 imgs: &[Vec<f32>],
256 shapes: &[(usize, usize, usize)],
257 ) -> Result<(), crate::core::OCRError> {
258 if imgs.is_empty() && shapes.is_empty() {
259 return Ok(());
260 }
261
262 if imgs.is_empty() {
263 return Err(crate::core::OCRError::InvalidInput {
264 message: "Images array is empty but shapes array is not".to_string(),
265 });
266 }
267
268 if shapes.is_empty() {
269 return Err(crate::core::OCRError::InvalidInput {
270 message: "Shapes array is empty but images array is not".to_string(),
271 });
272 }
273
274 if imgs.len() != shapes.len() {
275 return Err(crate::core::OCRError::InvalidInput {
276 message: format!(
277 "Images and shapes must have the same length: got {} images and {} shapes",
278 imgs.len(),
279 shapes.len()
280 ),
281 });
282 }
283
284 for (i, (img, &(c, h, w))) in imgs.iter().zip(shapes.iter()).enumerate() {
285 let expected_len = c * h * w;
286 if img.len() != expected_len {
287 return Err(crate::core::OCRError::InvalidInput {
288 message: format!(
289 "Image {} has {} elements but shape ({}, {}, {}) requires {}",
290 i,
291 img.len(),
292 c,
293 h,
294 w,
295 expected_len
296 ),
297 });
298 }
299
300 if c == 0 || h == 0 || w == 0 {
301 return Err(crate::core::OCRError::InvalidInput {
302 message: format!(
303 "Image {i} has invalid shape dimensions ({c}, {h}, {w}): all must be greater than 0"
304 ),
305 });
306 }
307
308 if expected_len > crate::core::constants::MAX_TENSOR_SIZE {
309 return Err(crate::core::OCRError::InvalidInput {
310 message: format!(
311 "Image {} tensor size {} exceeds maximum allowed size {}",
312 i,
313 expected_len,
314 crate::core::constants::MAX_TENSOR_SIZE
315 ),
316 });
317 }
318 }
319
320 Ok(())
321 }
322
323 /// Applies the batch conversion to the input images and shapes.
324 ///
325 /// This method validates the inputs, then converts the images into a batched tensor format.
326 /// If all images have the same dimensions, it uses a more efficient contiguous copying method.
327 /// Otherwise, it uses a method that handles mixed dimensions.
328 ///
329 /// # Arguments
330 ///
331 /// * `imgs` - A slice of vectors of f32 values representing the images.
332 /// * `shapes` - A slice of tuples representing the shapes of the images (channels, height, width).
333 ///
334 /// # Returns
335 ///
336 /// A Result containing a vector of f32 values representing the batched tensor,
337 /// or an OCRError if the operation fails.
338 pub fn apply(
339 &self,
340 imgs: &[Vec<f32>],
341 shapes: &[(usize, usize, usize)],
342 ) -> Result<Vec<f32>, crate::core::OCRError> {
343 self.validate_inputs(imgs, shapes)?;
344
345 if imgs.is_empty() {
346 return Ok(Vec::new());
347 }
348
349 let batch_size = imgs.len();
350 let first_shape = shapes.first().copied().unwrap_or((0, 0, 0));
351 let channels = first_shape.0;
352 let mut max_height = first_shape.1;
353 let mut max_width = first_shape.2;
354 let mut all_same_dimensions = true;
355
356 for (i, &(c, h, w)) in shapes.iter().enumerate() {
357 if c != channels {
358 return Err(crate::core::OCRError::InvalidInput {
359 message: format!(
360 "All images must have the same channel count: image 0 has {channels} channels, image {i} has {c} channels"
361 ),
362 });
363 }
364
365 if h > max_height {
366 max_height = h;
367 }
368 if w > max_width {
369 max_width = w;
370 }
371 if all_same_dimensions && (h != first_shape.1 || w != first_shape.2) {
372 all_same_dimensions = false;
373 }
374 }
375
376 let tensor_size = batch_size * channels * max_height * max_width;
377 let mut batch_tensor = vec![0.0; tensor_size];
378
379 if all_same_dimensions {
380 self.apply_contiguous(imgs, &mut batch_tensor, channels, max_height, max_width);
381 } else {
382 self.apply_mixed_dimensions(
383 imgs,
384 shapes,
385 &mut batch_tensor,
386 channels,
387 max_height,
388 max_width,
389 );
390 }
391
392 Ok(batch_tensor)
393 }
394
395 /// Applies contiguous copying for images with the same dimensions.
396 ///
397 /// This method is used when all images in the batch have the same dimensions,
398 /// allowing for more efficient copying.
399 ///
400 /// # Arguments
401 ///
402 /// * `imgs` - A slice of vectors of f32 values representing the images.
403 /// * `batch_tensor` - A mutable slice of f32 values representing the batched tensor.
404 /// * `channels` - The number of channels in the images.
405 /// * `height` - The height of the images.
406 /// * `width` - The width of the images.
407 fn apply_contiguous(
408 &self,
409 imgs: &[Vec<f32>],
410 batch_tensor: &mut [f32],
411 channels: usize,
412 height: usize,
413 width: usize,
414 ) {
415 let img_size = channels * height * width;
416
417 for (batch_idx, img) in imgs.iter().enumerate() {
418 let batch_offset = batch_idx * img_size;
419 let dst_slice = &mut batch_tensor[batch_offset..batch_offset + img.len()];
420
421 dst_slice.copy_from_slice(img);
422 }
423 }
424
425 /// Applies copying for images with mixed dimensions.
426 ///
427 /// This method is used when images in the batch have different dimensions,
428 /// requiring padding to the maximum dimensions.
429 ///
430 /// # Arguments
431 ///
432 /// * `imgs` - A slice of vectors of f32 values representing the images.
433 /// * `shapes` - A slice of tuples representing the shapes of the images (channels, height, width).
434 /// * `batch_tensor` - A mutable slice of f32 values representing the batched tensor.
435 /// * `channels` - The number of channels in the images.
436 /// * `max_height` - The maximum height among all images in the batch.
437 /// * `max_width` - The maximum width among all images in the batch.
438 fn apply_mixed_dimensions(
439 &self,
440 imgs: &[Vec<f32>],
441 shapes: &[(usize, usize, usize)],
442 batch_tensor: &mut [f32],
443 channels: usize,
444 max_height: usize,
445 max_width: usize,
446 ) {
447 for (batch_idx, (img, &(c, h, w))) in imgs.iter().zip(shapes.iter()).enumerate() {
448 let batch_base = batch_idx * channels * max_height * max_width;
449
450 for ch in 0..c {
451 let src_channel_start = ch * h * w;
452 let dst_channel_start = batch_base + ch * max_height * max_width;
453
454 for y in 0..h {
455 let src_row_start = src_channel_start + y * w;
456 let dst_row_start = dst_channel_start + y * max_width;
457
458 let src_row = &img[src_row_start..src_row_start + w];
459 let dst_row = &mut batch_tensor[dst_row_start..dst_row_start + w];
460 dst_row.copy_from_slice(src_row);
461 }
462 }
463 }
464 }
465}
466
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_to_batch_apply_contiguous() {
        let to_batch = ToBatch::new();

        // Two 1x2x2 images with identical shapes take the contiguous-copy path.
        let imgs = vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]];
        let shapes = vec![(1, 2, 2), (1, 2, 2)];

        let result = to_batch.apply(&imgs, &shapes).unwrap();

        // batch_size=2, channels=1, height=2, width=2: the images are stored back to back.
        assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_to_batch_apply_mixed_dimensions() {
        let to_batch = ToBatch::new();

        // A 1x1x2 image and a 1x2x2 image take the padded, row-by-row path.
        let imgs = vec![vec![1.0, 2.0], vec![3.0, 4.0, 5.0, 6.0]];
        let shapes = vec![(1, 1, 2), (1, 2, 2)];

        let result = to_batch.apply(&imgs, &shapes).unwrap();

        // Both images are padded to 2x2. The first fills the top row of its slot
        // and its bottom row is zero padding; the second fits exactly.
        assert_eq!(result, vec![1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_to_batch_pads_height_and_width_independently() {
        // A 1x2x1 image and a 1x1x2 image both pad to 1x2x2.
        let imgs = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let shapes = vec![(1, 2, 1), (1, 1, 2)];

        let result = ToBatch::new().apply(&imgs, &shapes).unwrap();

        assert_eq!(result, vec![1.0, 0.0, 2.0, 0.0, 3.0, 4.0, 0.0, 0.0]);
    }

    #[test]
    fn test_to_batch_pads_every_channel() {
        // A 2x1x1 image next to a 2x2x2 image: each channel plane is padded separately.
        let imgs = vec![
            vec![1.0, 2.0],
            vec![3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
        ];
        let shapes = vec![(2, 1, 1), (2, 2, 2)];

        let result = ToBatch::new().apply(&imgs, &shapes).unwrap();

        assert_eq!(
            result,
            vec![
                1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0
            ]
        );
    }

    #[test]
    fn test_to_batch_empty_input() {
        let result = ToBatch::new().apply(&[], &[]).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_to_batch_rejects_invalid_inputs() {
        let to_batch = ToBatch::new();

        // Images without shapes, and a length mismatch.
        assert!(to_batch.apply(&[vec![1.0]], &[]).is_err());
        assert!(to_batch.apply(&[vec![1.0]], &[(1, 1, 1), (1, 1, 1)]).is_err());
        // Element count does not match the declared shape.
        assert!(to_batch.apply(&[vec![1.0, 2.0]], &[(1, 1, 1)]).is_err());
        // Zero-sized dimension.
        assert!(to_batch.apply(&[Vec::new()], &[(0, 2, 2)]).is_err());
        // Channel counts differ across images.
        assert!(to_batch
            .apply(&[vec![1.0], vec![1.0, 2.0]], &[(1, 1, 1), (2, 1, 1)])
            .is_err());
    }

    #[test]
    fn test_batch_sampler_splits_with_global_indexes() {
        let sampler = BatchSampler::new(2);
        assert_eq!(sampler.batch_size(), 2);

        let data = [10, 20, 30, 40, 50];
        let chunks: Vec<&[i32]> = sampler.batches(&data).collect();
        assert_eq!(chunks, vec![&[10, 20][..], &[30, 40][..], &[50][..]]);

        let indexes: Vec<Vec<usize>> = sampler
            .batches_with_indexes(&data)
            .map(|(_, idx)| idx)
            .collect();
        assert_eq!(indexes, vec![vec![0, 1], vec![2, 3], vec![4]]);

        let batches = sampler.sample_batch((0..5).map(|i| i.to_string()).collect());
        assert_eq!(batches.len(), 3);
        assert_eq!(batches[1].instances_as_str().collect::<Vec<_>>(), vec!["2", "3"]);
        assert_eq!(batches[1].input_paths_as_str().collect::<Vec<_>>(), vec!["2", "3"]);
        assert_eq!(batches[2].indexes, vec![4]);
    }

    #[test]
    fn test_batch_sampler_zero_batch_size_yields_nothing() {
        let sampler = BatchSampler::new(0);

        assert_eq!(sampler.batches(&[1, 2, 3]).count(), 0);
        assert_eq!(sampler.batches_with_indexes(&[1, 2, 3]).count(), 0);
        assert!(sampler.sample_batch(vec!["a".to_string()]).is_empty());
    }
}