oar_ocr_core/core/batch/mod.rs
1//! Batch processing utilities for the OCR pipeline.
2//!
3//! This module provides structures and functions for handling batched data
4//! in the OCR pipeline, including batching of input data, sampling, and
5//! tensor operations for batched processing.
6
7pub mod dynamic;
8
9use crate::core::traits::Sampler;
10use std::sync::Arc;
11
/// One batch of inputs handed through the OCR pipeline as a unit.
///
/// The batch keeps three parallel vectors: the instances themselves, the paths
/// they were read from, and their positions in the original data set. Strings
/// are held as `Arc<str>` so that copying a batch only bumps reference counts.
pub struct BatchData {
    /// The instances in the batch, stored as `Arc<str>` for cheap sharing.
    pub instances: Vec<Arc<str>>,
    /// The input paths of the instances, stored as `Arc<str>` for cheap sharing.
    pub input_paths: Vec<Arc<str>>,
    /// Position of every instance in the original (un-batched) data set.
    pub indexes: Vec<usize>,
}

impl BatchData {
    /// Builds a batch in which every instance is identified by its path.
    ///
    /// # Arguments
    ///
    /// * `paths` - Shared path strings; they populate both `instances` and `input_paths`.
    /// * `indexes` - Position of each path in the original data set.
    ///
    /// # Returns
    ///
    /// A `BatchData` whose `instances` and `input_paths` point at the same
    /// underlying strings.
    pub fn from_shared_arc_paths(paths: Vec<Arc<str>>, indexes: Vec<usize>) -> Self {
        BatchData {
            // Cloning a `Vec<Arc<str>>` copies the pointers, not the string data.
            input_paths: paths.clone(),
            instances: paths,
            indexes,
        }
    }

    /// Number of instances held by the batch.
    ///
    /// # Returns
    ///
    /// The length of the `instances` vector.
    pub fn len(&self) -> usize {
        self.instances.len()
    }

    /// Reports whether the batch holds no instances.
    ///
    /// # Returns
    ///
    /// `true` when `instances` is empty, `false` otherwise.
    pub fn is_empty(&self) -> bool {
        self.instances.is_empty()
    }

    /// Borrows every instance as a plain `&str`.
    ///
    /// # Returns
    ///
    /// An iterator yielding the instances in batch order.
    pub fn instances_as_str(&self) -> impl Iterator<Item = &str> + '_ {
        self.instances.iter().map(|shared| &**shared)
    }

    /// Borrows every input path as a plain `&str`.
    ///
    /// # Returns
    ///
    /// An iterator yielding the input paths in batch order.
    pub fn input_paths_as_str(&self) -> impl Iterator<Item = &str> + '_ {
        self.input_paths.iter().map(|shared| &**shared)
    }
}
81
82/// A sampler that creates batches of data with a specified batch size.
83///
84/// This struct is used to divide data into batches for processing in the OCR pipeline.
85/// It implements the Sampler trait for String data.
86#[derive(Debug)]
87pub struct BatchSampler {
88 /// The size of each batch.
89 batch_size: usize,
90}
91
92impl BatchSampler {
93 /// Creates a new BatchSampler with the specified batch size.
94 ///
95 /// # Arguments
96 ///
97 /// * `batch_size` - The size of each batch.
98 ///
99 /// # Returns
100 ///
101 /// A new BatchSampler instance.
102 pub fn new(batch_size: usize) -> Self {
103 Self { batch_size }
104 }
105
106 /// Returns the batch size.
107 ///
108 /// # Returns
109 ///
110 /// The batch size.
111 pub fn batch_size(&self) -> usize {
112 self.batch_size
113 }
114
115 /// Creates an iterator over batches of data.
116 ///
117 /// # Arguments
118 ///
119 /// * `data` - A slice of data to be batched.
120 ///
121 /// # Returns
122 ///
123 /// An iterator over batches of data.
124 pub fn batches<'a, T>(&self, data: &'a [T]) -> impl Iterator<Item = &'a [T]> {
125 if self.batch_size == 0 {
126 data.chunks(1).take(0)
127 } else {
128 data.chunks(self.batch_size).take(usize::MAX)
129 }
130 }
131
132 /// Creates an iterator over batches of data with their indexes.
133 ///
134 /// # Arguments
135 ///
136 /// * `data` - A slice of data to be batched.
137 ///
138 /// # Returns
139 ///
140 /// An iterator over tuples containing batches of data and their indexes.
141 pub fn batches_with_indexes<'a, T>(
142 &self,
143 data: &'a [T],
144 ) -> impl Iterator<Item = (&'a [T], Vec<usize>)> {
145 let batch_size = if self.batch_size == 0 {
146 1
147 } else {
148 self.batch_size
149 };
150 let take_count = if self.batch_size == 0 { 0 } else { usize::MAX };
151
152 data.chunks(batch_size)
153 .take(take_count)
154 .enumerate()
155 .map(move |(batch_idx, chunk)| {
156 let start_idx = batch_idx * self.batch_size;
157 let indexes: Vec<usize> = (0..chunk.len()).map(|i| start_idx + i).collect();
158 (chunk, indexes)
159 })
160 }
161
162 /// Samples batches of data from a vector of strings.
163 ///
164 /// # Arguments
165 ///
166 /// * `data` - A vector of strings to be batched.
167 ///
168 /// # Returns
169 ///
170 /// A vector of BatchData instances.
171 pub fn sample_batch(&self, data: Vec<String>) -> Vec<BatchData> {
172 if self.batch_size == 0 {
173 return Vec::new();
174 }
175
176 data.chunks(self.batch_size)
177 .enumerate()
178 .map(|(batch_idx, chunk)| {
179 let start_idx = batch_idx * self.batch_size;
180 let indexes: Vec<usize> = (0..chunk.len()).map(|i| start_idx + i).collect();
181
182 BatchData::from_shared_arc_paths(
183 chunk.iter().map(|s| Arc::from(s.as_str())).collect(),
184 indexes,
185 )
186 })
187 .collect()
188 }
189}
190
191impl Sampler<String> for BatchSampler {
192 type BatchData = BatchData;
193
194 /// Samples batches of data from a vector of strings.
195 ///
196 /// This method implements the Sampler trait for String data.
197 ///
198 /// # Arguments
199 ///
200 /// * `data` - A vector of strings to be batched.
201 ///
202 /// # Returns
203 ///
204 /// A vector of BatchData instances.
205 fn sample(&self, data: Vec<String>) -> Vec<Self::BatchData> {
206 self.sample_batch(data)
207 }
208}
209
210/// A struct for converting image data into batched tensor format.
211///
212/// This struct provides methods for validating input data and converting
213/// images into a batched tensor format suitable for processing in the OCR pipeline.
214#[derive(Debug, Default)]
215pub struct ToBatch;
216
217impl ToBatch {
218 /// Creates a new ToBatch instance.
219 ///
220 /// # Returns
221 ///
222 /// A new ToBatch instance.
223 pub fn new() -> Self {
224 ToBatch
225 }
226
227 /// Validates the input images and their shapes.
228 ///
229 /// This method checks that the images and shapes arrays have the same length,
230 /// that all images have the correct number of elements for their shapes,
231 /// and that all dimensions are greater than zero.
232 ///
233 /// # Arguments
234 ///
235 /// * `imgs` - A slice of vectors of f32 values representing the images.
236 /// * `shapes` - A slice of tuples representing the shapes of the images (channels, height, width).
237 ///
238 /// # Returns
239 ///
240 /// A Result indicating success or an OCRError if validation fails.
241 pub fn validate_inputs(
242 &self,
243 imgs: &[Vec<f32>],
244 shapes: &[(usize, usize, usize)],
245 ) -> Result<(), crate::core::OCRError> {
246 if imgs.is_empty() && shapes.is_empty() {
247 return Ok(());
248 }
249
250 if imgs.is_empty() {
251 return Err(crate::core::OCRError::InvalidInput {
252 message: "Images array is empty but shapes array is not".to_string(),
253 });
254 }
255
256 if shapes.is_empty() {
257 return Err(crate::core::OCRError::InvalidInput {
258 message: "Shapes array is empty but images array is not".to_string(),
259 });
260 }
261
262 if imgs.len() != shapes.len() {
263 return Err(crate::core::OCRError::InvalidInput {
264 message: format!(
265 "Images and shapes must have the same length: got {} images and {} shapes",
266 imgs.len(),
267 shapes.len()
268 ),
269 });
270 }
271
272 for (i, (img, &(c, h, w))) in imgs.iter().zip(shapes.iter()).enumerate() {
273 let expected_len = c * h * w;
274 if img.len() != expected_len {
275 return Err(crate::core::OCRError::InvalidInput {
276 message: format!(
277 "Image {} has {} elements but shape ({}, {}, {}) requires {}",
278 i,
279 img.len(),
280 c,
281 h,
282 w,
283 expected_len
284 ),
285 });
286 }
287
288 if c == 0 || h == 0 || w == 0 {
289 return Err(crate::core::OCRError::InvalidInput {
290 message: format!(
291 "Image {i} has invalid shape dimensions ({c}, {h}, {w}): all must be greater than 0"
292 ),
293 });
294 }
295
296 if expected_len > crate::core::constants::MAX_TENSOR_SIZE {
297 return Err(crate::core::OCRError::InvalidInput {
298 message: format!(
299 "Image {} tensor size {} exceeds maximum allowed size {}",
300 i,
301 expected_len,
302 crate::core::constants::MAX_TENSOR_SIZE
303 ),
304 });
305 }
306 }
307
308 Ok(())
309 }
310
311 /// Applies the batch conversion to the input images and shapes.
312 ///
313 /// This method validates the inputs, then converts the images into a batched tensor format.
314 /// If all images have the same dimensions, it uses a more efficient contiguous copying method.
315 /// Otherwise, it uses a method that handles mixed dimensions.
316 ///
317 /// # Arguments
318 ///
319 /// * `imgs` - A slice of vectors of f32 values representing the images.
320 /// * `shapes` - A slice of tuples representing the shapes of the images (channels, height, width).
321 ///
322 /// # Returns
323 ///
324 /// A Result containing a vector of f32 values representing the batched tensor,
325 /// or an OCRError if the operation fails.
326 pub fn apply(
327 &self,
328 imgs: &[Vec<f32>],
329 shapes: &[(usize, usize, usize)],
330 ) -> Result<Vec<f32>, crate::core::OCRError> {
331 self.validate_inputs(imgs, shapes)?;
332
333 if imgs.is_empty() {
334 return Ok(Vec::new());
335 }
336
337 let batch_size = imgs.len();
338 let first_shape = shapes.first().copied().unwrap_or((0, 0, 0));
339 let channels = first_shape.0;
340 let mut max_height = first_shape.1;
341 let mut max_width = first_shape.2;
342 let mut all_same_dimensions = true;
343
344 for (i, &(c, h, w)) in shapes.iter().enumerate() {
345 if c != channels {
346 return Err(crate::core::OCRError::InvalidInput {
347 message: format!(
348 "All images must have the same channel count: image 0 has {channels} channels, image {i} has {c} channels"
349 ),
350 });
351 }
352
353 if h > max_height {
354 max_height = h;
355 }
356 if w > max_width {
357 max_width = w;
358 }
359 if all_same_dimensions && (h != first_shape.1 || w != first_shape.2) {
360 all_same_dimensions = false;
361 }
362 }
363
364 let tensor_size = batch_size * channels * max_height * max_width;
365 let mut batch_tensor = vec![0.0; tensor_size];
366
367 if all_same_dimensions {
368 self.apply_contiguous(imgs, &mut batch_tensor, channels, max_height, max_width);
369 } else {
370 self.apply_mixed_dimensions(
371 imgs,
372 shapes,
373 &mut batch_tensor,
374 channels,
375 max_height,
376 max_width,
377 );
378 }
379
380 Ok(batch_tensor)
381 }
382
383 /// Applies contiguous copying for images with the same dimensions.
384 ///
385 /// This method is used when all images in the batch have the same dimensions,
386 /// allowing for more efficient copying.
387 ///
388 /// # Arguments
389 ///
390 /// * `imgs` - A slice of vectors of f32 values representing the images.
391 /// * `batch_tensor` - A mutable slice of f32 values representing the batched tensor.
392 /// * `channels` - The number of channels in the images.
393 /// * `height` - The height of the images.
394 /// * `width` - The width of the images.
395 fn apply_contiguous(
396 &self,
397 imgs: &[Vec<f32>],
398 batch_tensor: &mut [f32],
399 channels: usize,
400 height: usize,
401 width: usize,
402 ) {
403 let img_size = channels * height * width;
404
405 for (batch_idx, img) in imgs.iter().enumerate() {
406 let batch_offset = batch_idx * img_size;
407 let dst_slice = &mut batch_tensor[batch_offset..batch_offset + img.len()];
408
409 dst_slice.copy_from_slice(img);
410 }
411 }
412
413 /// Applies copying for images with mixed dimensions.
414 ///
415 /// This method is used when images in the batch have different dimensions,
416 /// requiring padding to the maximum dimensions.
417 ///
418 /// # Arguments
419 ///
420 /// * `imgs` - A slice of vectors of f32 values representing the images.
421 /// * `shapes` - A slice of tuples representing the shapes of the images (channels, height, width).
422 /// * `batch_tensor` - A mutable slice of f32 values representing the batched tensor.
423 /// * `channels` - The number of channels in the images.
424 /// * `max_height` - The maximum height among all images in the batch.
425 /// * `max_width` - The maximum width among all images in the batch.
426 fn apply_mixed_dimensions(
427 &self,
428 imgs: &[Vec<f32>],
429 shapes: &[(usize, usize, usize)],
430 batch_tensor: &mut [f32],
431 channels: usize,
432 max_height: usize,
433 max_width: usize,
434 ) {
435 for (batch_idx, (img, &(c, h, w))) in imgs.iter().zip(shapes.iter()).enumerate() {
436 let batch_base = batch_idx * channels * max_height * max_width;
437
438 for ch in 0..c {
439 let src_channel_start = ch * h * w;
440 let dst_channel_start = batch_base + ch * max_height * max_width;
441
442 for y in 0..h {
443 let src_row_start = src_channel_start + y * w;
444 let dst_row_start = dst_channel_start + y * max_width;
445
446 let src_row = &img[src_row_start..src_row_start + w];
447 let dst_row = &mut batch_tensor[dst_row_start..dst_row_start + w];
448 dst_row.copy_from_slice(src_row);
449 }
450 }
451 }
452 }
453}
454
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::OCRError;

    /// Images with identical shapes are copied back-to-back in batch order.
    #[test]
    fn test_to_batch_apply_contiguous() -> Result<(), OCRError> {
        let to_batch = ToBatch::new();

        // Two 1x2x2 images with the same shape.
        let imgs = vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]];
        let shapes = vec![(1, 2, 2), (1, 2, 2)];

        let result = to_batch.apply(&imgs, &shapes)?;

        // batch_size=2, channels=1, height=2, width=2 -> 2 * 1 * 2 * 2 = 8 values:
        // the first image occupies positions 0-3, the second positions 4-7.
        assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        Ok(())
    }

    /// Images with different shapes are zero-padded to the largest height and width.
    #[test]
    fn test_to_batch_apply_mixed_dimensions() -> Result<(), OCRError> {
        let to_batch = ToBatch::new();

        // A 1x1x2 image followed by a 1x2x2 image.
        let imgs = vec![vec![1.0, 2.0], vec![3.0, 4.0, 5.0, 6.0]];
        let shapes = vec![(1, 1, 2), (1, 2, 2)];

        let result = to_batch.apply(&imgs, &shapes)?;

        // batch_size=2, channels=1, max_height=2, max_width=2 -> 8 values.
        // The first image fills only the first row of its 2x2 slot, so its second
        // row stays zero; the second image fills its slot exactly. Checking the
        // whole buffer catches misplaced rows that a membership check would miss.
        assert_eq!(result, vec![1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 5.0, 6.0]);
        Ok(())
    }
}