pub mod dynamic;
use crate::core::traits::Sampler;
use std::sync::Arc;
/// A batch of items described by shared string handles plus one index per item.
pub struct BatchData {
    /// Shared string handle for each instance in the batch.
    pub instances: Vec<Arc<str>>,
    /// Shared string handle for each input path in the batch.
    pub input_paths: Vec<Arc<str>>,
    /// Index associated with each item, as supplied by the caller.
    pub indexes: Vec<usize>,
}

impl BatchData {
    /// Builds a batch in which every path is used as both the instance and the input path.
    ///
    /// Only the `Arc` handles are duplicated; both lists point at the same string
    /// allocations.
    pub fn from_shared_arc_paths(paths: Vec<Arc<str>>, indexes: Vec<usize>) -> Self {
        BatchData {
            // Evaluated before `paths` is moved into `instances` below.
            input_paths: paths.iter().map(Arc::clone).collect(),
            instances: paths,
            indexes,
        }
    }

    /// Returns the number of instances in the batch.
    pub fn len(&self) -> usize {
        self.instances.len()
    }

    /// Returns `true` when the batch holds no instances.
    pub fn is_empty(&self) -> bool {
        self.instances.is_empty()
    }

    /// Iterates over the instances as borrowed string slices.
    pub fn instances_as_str(&self) -> impl Iterator<Item = &str> + '_ {
        self.instances.iter().map(|instance| &**instance)
    }

    /// Iterates over the input paths as borrowed string slices.
    pub fn input_paths_as_str(&self) -> impl Iterator<Item = &str> + '_ {
        self.input_paths.iter().map(|path| &**path)
    }
}
/// Splits data into consecutive batches of at most `batch_size` items.
///
/// A batch size of zero is accepted and simply produces no batches.
#[derive(Debug)]
pub struct BatchSampler {
    /// Maximum number of items per batch; `0` means "produce nothing".
    batch_size: usize,
}

impl BatchSampler {
    /// Creates a sampler that emits batches of at most `batch_size` items.
    pub fn new(batch_size: usize) -> Self {
        Self { batch_size }
    }

    /// Returns the configured batch size.
    pub fn batch_size(&self) -> usize {
        self.batch_size
    }

    /// Iterates over `data` in consecutive chunks of at most `batch_size` items.
    ///
    /// The last chunk may be shorter. Yields nothing when the batch size is zero.
    pub fn batches<'a, T>(&self, data: &'a [T]) -> impl Iterator<Item = &'a [T]> {
        // `chunks` panics on a zero chunk size, so a zero batch size chunks by
        // one and takes nothing; both cases keep the same iterator type.
        let limit = if self.batch_size == 0 { 0 } else { usize::MAX };
        data.chunks(self.batch_size.max(1)).take(limit)
    }

    /// Like [`BatchSampler::batches`], but pairs every chunk with the positions
    /// its items occupy in `data`.
    ///
    /// Yields nothing when the batch size is zero.
    pub fn batches_with_indexes<'a, T>(
        &self,
        data: &'a [T],
    ) -> impl Iterator<Item = (&'a [T], Vec<usize>)> {
        // Copy the size so the closure owns a plain `usize` instead of capturing
        // `self`; the returned iterator then borrows nothing but `data`.
        let batch_size = self.batch_size;
        let limit = if batch_size == 0 { 0 } else { usize::MAX };
        data.chunks(batch_size.max(1))
            .take(limit)
            .enumerate()
            .map(move |(batch_idx, chunk)| {
                let start_idx = batch_idx * batch_size;
                let indexes: Vec<usize> = (start_idx..start_idx + chunk.len()).collect();
                (chunk, indexes)
            })
    }

    /// Groups `data` into [`BatchData`] values of at most `batch_size` paths each.
    ///
    /// Each batch records the positions of its paths in `data`. Returns an empty
    /// vector when the batch size is zero.
    pub fn sample_batch(&self, data: Vec<String>) -> Vec<BatchData> {
        self.batches_with_indexes(&data)
            .map(|(chunk, indexes)| {
                let paths: Vec<Arc<str>> =
                    chunk.iter().map(|path| Arc::from(path.as_str())).collect();
                BatchData::from_shared_arc_paths(paths, indexes)
            })
            .collect()
    }
}
impl Sampler<String> for BatchSampler {
type BatchData = BatchData;
fn sample(&self, data: Vec<String>) -> Vec<Self::BatchData> {
self.sample_batch(data)
}
}
/// Packs per-image `f32` buffers into one flat, zero-padded batch buffer.
#[derive(Debug, Default)]
pub struct ToBatch;

impl ToBatch {
    /// Creates a new `ToBatch`.
    pub fn new() -> Self {
        ToBatch
    }

    /// Checks that `imgs` and `shapes` describe a consistent set of images.
    ///
    /// Each shape is `(channels, height, width)`. Two empty slices are accepted.
    ///
    /// # Errors
    ///
    /// Returns `OCRError::InvalidInput` when:
    /// - exactly one of the slices is empty, or their lengths differ;
    /// - a shape's element count overflows `usize`;
    /// - an image buffer's length differs from its shape's element count;
    /// - any shape dimension is zero;
    /// - a shape's element count exceeds `MAX_TENSOR_SIZE`.
    pub fn validate_inputs(
        &self,
        imgs: &[Vec<f32>],
        shapes: &[(usize, usize, usize)],
    ) -> Result<(), crate::core::OCRError> {
        if imgs.is_empty() && shapes.is_empty() {
            return Ok(());
        }
        if imgs.is_empty() {
            return Err(crate::core::OCRError::InvalidInput {
                message: "Images array is empty but shapes array is not".to_string(),
            });
        }
        if shapes.is_empty() {
            return Err(crate::core::OCRError::InvalidInput {
                message: "Shapes array is empty but images array is not".to_string(),
            });
        }
        if imgs.len() != shapes.len() {
            return Err(crate::core::OCRError::InvalidInput {
                message: format!(
                    "Images and shapes must have the same length: got {} images and {} shapes",
                    imgs.len(),
                    shapes.len()
                ),
            });
        }
        for (i, (img, &(c, h, w))) in imgs.iter().zip(shapes.iter()).enumerate() {
            // An unchecked product panics in debug builds and wraps in release
            // builds, where the wrapped value could match `img.len()` and slip
            // past the size limit below. A zero dimension always yields zero,
            // so it is short-circuited and reported by the dedicated check.
            let element_count = if c == 0 || h == 0 || w == 0 {
                Some(0)
            } else {
                c.checked_mul(h).and_then(|plane| plane.checked_mul(w))
            };
            let expected_len = match element_count {
                Some(len) => len,
                None => {
                    return Err(crate::core::OCRError::InvalidInput {
                        message: format!(
                            "Image {i} shape ({c}, {h}, {w}) overflows the maximum representable tensor size"
                        ),
                    });
                }
            };
            if img.len() != expected_len {
                return Err(crate::core::OCRError::InvalidInput {
                    message: format!(
                        "Image {} has {} elements but shape ({}, {}, {}) requires {}",
                        i,
                        img.len(),
                        c,
                        h,
                        w,
                        expected_len
                    ),
                });
            }
            if c == 0 || h == 0 || w == 0 {
                return Err(crate::core::OCRError::InvalidInput {
                    message: format!(
                        "Image {i} has invalid shape dimensions ({c}, {h}, {w}): all must be greater than 0"
                    ),
                });
            }
            if expected_len > crate::core::constants::MAX_TENSOR_SIZE {
                return Err(crate::core::OCRError::InvalidInput {
                    message: format!(
                        "Image {} tensor size {} exceeds maximum allowed size {}",
                        i,
                        expected_len,
                        crate::core::constants::MAX_TENSOR_SIZE
                    ),
                });
            }
        }
        Ok(())
    }

    /// Packs `imgs` into one flat buffer ordered by image, channel, row, column.
    ///
    /// Every image gets a slot of `channels * max_height * max_width` elements,
    /// where the maxima are taken over all `shapes`. Smaller images are written
    /// at the start of each row and channel of their slot; the rest stays `0.0`.
    /// Empty input produces an empty vector.
    ///
    /// # Errors
    ///
    /// Returns `OCRError::InvalidInput` when `validate_inputs` rejects the input,
    /// when the images do not all share one channel count, or when the padded
    /// batch size overflows `usize`.
    pub fn apply(
        &self,
        imgs: &[Vec<f32>],
        shapes: &[(usize, usize, usize)],
    ) -> Result<Vec<f32>, crate::core::OCRError> {
        self.validate_inputs(imgs, shapes)?;
        // After validation both slices are empty or both are non-empty.
        let (channels, first_height, first_width) = match shapes.first() {
            Some(&shape) => shape,
            None => return Ok(Vec::new()),
        };

        let mut max_height = first_height;
        let mut max_width = first_width;
        let mut all_same_dimensions = true;
        for (i, &(c, h, w)) in shapes.iter().enumerate() {
            if c != channels {
                return Err(crate::core::OCRError::InvalidInput {
                    message: format!(
                        "All images must have the same channel count: image 0 has {channels} channels, image {i} has {c} channels"
                    ),
                });
            }
            max_height = max_height.max(h);
            max_width = max_width.max(w);
            all_same_dimensions &= h == first_height && w == first_width;
        }

        // The padded slot can be far larger than any single image (tallest
        // height combined with widest width), so the product is checked too.
        let tensor_size = [channels, max_height, max_width]
            .iter()
            .try_fold(imgs.len(), |acc, &dim| acc.checked_mul(dim))
            .ok_or_else(|| crate::core::OCRError::InvalidInput {
                message: format!(
                    "Batch of {} images with padded shape ({channels}, {max_height}, {max_width}) overflows the maximum representable tensor size",
                    imgs.len()
                ),
            })?;

        let mut batch_tensor = vec![0.0_f32; tensor_size];
        if all_same_dimensions {
            self.apply_contiguous(imgs, &mut batch_tensor, channels, max_height, max_width);
        } else {
            self.apply_mixed_dimensions(
                imgs,
                shapes,
                &mut batch_tensor,
                channels,
                max_height,
                max_width,
            );
        }
        Ok(batch_tensor)
    }

    /// Copies equally shaped images back to back into `batch_tensor`.
    ///
    /// Image `k` is written starting at offset `k * channels * height * width`.
    fn apply_contiguous(
        &self,
        imgs: &[Vec<f32>],
        batch_tensor: &mut [f32],
        channels: usize,
        height: usize,
        width: usize,
    ) {
        let img_size = channels * height * width;
        for (batch_idx, img) in imgs.iter().enumerate() {
            let batch_offset = batch_idx * img_size;
            batch_tensor[batch_offset..batch_offset + img.len()].copy_from_slice(img);
        }
    }

    /// Copies differently sized images row by row into their padded slots.
    ///
    /// Destination rows are `max_width` wide; only the first `w` elements of
    /// each row are written, so the remaining elements keep their prior value.
    fn apply_mixed_dimensions(
        &self,
        imgs: &[Vec<f32>],
        shapes: &[(usize, usize, usize)],
        batch_tensor: &mut [f32],
        channels: usize,
        max_height: usize,
        max_width: usize,
    ) {
        let padded_plane = max_height * max_width;
        for (batch_idx, (img, &(c, h, w))) in imgs.iter().zip(shapes.iter()).enumerate() {
            let batch_base = batch_idx * channels * padded_plane;
            for ch in 0..c {
                let src_channel_start = ch * h * w;
                let dst_channel_start = batch_base + ch * padded_plane;
                for y in 0..h {
                    let src_row_start = src_channel_start + y * w;
                    let dst_row_start = dst_channel_start + y * max_width;
                    batch_tensor[dst_row_start..dst_row_start + w]
                        .copy_from_slice(&img[src_row_start..src_row_start + w]);
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::OCRError;

    /// Equally shaped images are laid out back to back in batch order.
    #[test]
    fn test_to_batch_apply_contiguous() -> Result<(), OCRError> {
        let to_batch = ToBatch::new();
        let imgs = vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]];
        let shapes = vec![(1, 2, 2), (1, 2, 2)];
        let result = to_batch.apply(&imgs, &shapes)?;
        assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        Ok(())
    }

    /// A shorter image is padded with zero rows up to the tallest image.
    ///
    /// The exact layout is asserted (not just membership) so that misplaced
    /// rows or missing padding are caught.
    #[test]
    fn test_to_batch_apply_mixed_dimensions() -> Result<(), OCRError> {
        let to_batch = ToBatch::new();
        let imgs = vec![vec![1.0, 2.0], vec![3.0, 4.0, 5.0, 6.0]];
        let shapes = vec![(1, 1, 2), (1, 2, 2)];
        let result = to_batch.apply(&imgs, &shapes)?;
        assert_eq!(result, vec![1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 5.0, 6.0]);
        Ok(())
    }

    /// A narrower image is padded with zeros at the end of every row.
    #[test]
    fn test_to_batch_apply_mixed_dimensions_pads_width() -> Result<(), OCRError> {
        let to_batch = ToBatch::new();
        let imgs = vec![vec![1.0, 2.0], vec![3.0, 4.0, 5.0, 6.0]];
        let shapes = vec![(1, 2, 1), (1, 2, 2)];
        let result = to_batch.apply(&imgs, &shapes)?;
        assert_eq!(result, vec![1.0, 0.0, 2.0, 0.0, 3.0, 4.0, 5.0, 6.0]);
        Ok(())
    }
}