use crate::core::OCRError;
use crate::utils::{PaddingStrategy, ResizePadConfig, resize_and_pad};
use image::RgbImage;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A named aspect-ratio range together with the size that images falling in that range are
/// resized and padded to.
///
/// The range is half-open: a ratio `r` belongs to the bucket when
/// `min_ratio <= r < max_ratio`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AspectRatioBucket {
/// Inclusive lower bound of the accepted width / height ratio.
pub min_ratio: f32,
/// Exclusive upper bound of the accepted width / height ratio.
pub max_ratio: f32,
/// Target size for images in this bucket, stored as `(height, width)`.
pub target_dims: (u32, u32),
/// Bucket identifier; used as the key of the cached resize configuration and of the
/// group returned by the grouping routine.
pub name: String,
}
/// Settings that control how images are assigned to aspect-ratio buckets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AspectRatioBucketingConfig {
/// Candidate buckets. They are searched in order and the first bucket whose range
/// contains the ratio is used.
pub buckets: Vec<AspectRatioBucket>,
/// Color used for solid-color padding when an image is fitted to a bucket's target size
/// (presumably RGB channel order, matching `RgbImage` — confirm against the padding helper).
pub padding_color: [u8; 3],
/// When `true`, images that match no bucket (or that arrive once their bucket has reached
/// `max_images_per_bucket`) are grouped, unmodified, by their exact `(height, width)`.
/// When `false`, an image that matches no bucket is reported as an error.
pub fallback_to_exact: bool,
/// Per-bucket image limit; `0` means unlimited. The limit only diverts images when
/// `fallback_to_exact` is `true` — otherwise images keep being added to the bucket.
pub max_images_per_bucket: usize,
}
impl Default for AspectRatioBucketingConfig {
    /// Returns the stock configuration: five contiguous buckets covering every ratio from
    /// `0.0` up to (but excluding) `f32::MAX`, black padding, no exact-size fallback and no
    /// per-bucket limit.
    fn default() -> Self {
        // One row per bucket: (name, min_ratio, max_ratio, (height, width)).
        // Each row's upper bound is the next row's lower bound, so the ranges do not overlap.
        let layout: [(&str, f32, f32, (u32, u32)); 5] = [
            ("tall", 0.0, 0.8, (64, 32)),
            ("square", 0.8, 1.2, (32, 32)),
            ("normal", 1.2, 2.5, (32, 80)),
            ("wide", 2.5, 4.5, (32, 160)),
            ("ultra_wide", 4.5, f32::MAX, (32, 320)),
        ];

        let buckets = layout
            .iter()
            .map(|&(name, min_ratio, max_ratio, target_dims)| AspectRatioBucket {
                min_ratio,
                max_ratio,
                target_dims,
                name: name.to_string(),
            })
            .collect();

        Self {
            buckets,
            padding_color: [0, 0, 0],
            fallback_to_exact: false,
            // Zero disables the per-bucket limit.
            max_images_per_bucket: 0,
        }
    }
}
/// Assigns images to aspect-ratio buckets and resizes/pads them to each bucket's target size.
#[derive(Debug, Clone)]
pub struct AspectRatioBucketing {
/// Bucket definitions and grouping policy supplied at construction.
config: AspectRatioBucketingConfig,
/// Resize/pad configuration for each bucket, keyed by bucket name and built once in `new`.
resize_configs: HashMap<String, ResizePadConfig>,
}
impl Default for AspectRatioBucketing {
    /// Builds a bucketing helper from the default [`AspectRatioBucketingConfig`].
    fn default() -> Self {
        let config = AspectRatioBucketingConfig::default();
        Self::new(config)
    }
}
impl AspectRatioBucketing {
pub fn new(config: AspectRatioBucketingConfig) -> Self {
let mut resize_configs = HashMap::new();
for bucket in &config.buckets {
let (target_height, target_width) = bucket.target_dims;
let resize_config = ResizePadConfig::new((target_width, target_height))
.with_padding_strategy(PaddingStrategy::SolidColor(config.padding_color));
resize_configs.insert(bucket.name.clone(), resize_config);
}
Self {
config,
resize_configs,
}
}
pub fn calculate_aspect_ratio(&self, image: &RgbImage) -> f32 {
let (width, height) = image.dimensions();
width as f32 / height as f32
}
pub fn find_bucket(&self, aspect_ratio: f32) -> Option<&AspectRatioBucket> {
self.config
.buckets
.iter()
.find(|bucket| aspect_ratio >= bucket.min_ratio && aspect_ratio < bucket.max_ratio)
}
pub fn resize_and_pad_to_bucket(
&self,
image: &RgbImage,
bucket: &AspectRatioBucket,
) -> Result<RgbImage, OCRError> {
let config =
self.resize_configs
.get(&bucket.name)
.ok_or_else(|| OCRError::ConfigError {
message: format!("No cached resize config found for bucket: {}", bucket.name),
})?;
let padded = resize_and_pad(image, config).map_err(OCRError::from)?;
Ok(padded)
}
pub fn group_images_by_buckets(
&self,
images: Vec<(usize, RgbImage)>,
) -> Result<HashMap<String, Vec<(usize, RgbImage)>>, OCRError> {
let mut bucket_groups: HashMap<String, Vec<(usize, RgbImage)>> = HashMap::new();
let mut exact_groups: HashMap<(u32, u32), Vec<(usize, RgbImage)>> = HashMap::new();
for (index, image) in images {
let aspect_ratio = self.calculate_aspect_ratio(&image);
if let Some(bucket) = self.find_bucket(aspect_ratio) {
let processed_image = self.resize_and_pad_to_bucket(&image, bucket)?;
let bucket_group = bucket_groups.entry(bucket.name.clone()).or_default();
if self.config.max_images_per_bucket == 0
|| bucket_group.len() < self.config.max_images_per_bucket
{
bucket_group.push((index, processed_image));
} else if self.config.fallback_to_exact {
let dims = (image.height(), image.width());
exact_groups.entry(dims).or_default().push((index, image));
} else {
bucket_group.push((index, processed_image));
}
} else if self.config.fallback_to_exact {
let dims = (image.height(), image.width());
exact_groups.entry(dims).or_default().push((index, image));
} else {
return Err(OCRError::ConfigError {
message: format!(
"No bucket found for aspect ratio {:.2} and fallback disabled",
aspect_ratio
),
});
}
}
for ((h, w), group) in exact_groups {
let exact_key = format!("exact_{}x{}", h, w);
bucket_groups.insert(exact_key, group);
}
Ok(bucket_groups)
}
pub fn get_bucket_stats(&self, images: &[(usize, RgbImage)]) -> HashMap<String, usize> {
let mut stats = HashMap::new();
for (_index, image) in images {
let aspect_ratio = self.calculate_aspect_ratio(image);
if let Some(bucket) = self.find_bucket(aspect_ratio) {
*stats.entry(bucket.name.clone()).or_insert(0) += 1;
} else {
*stats.entry("no_bucket".to_string()).or_insert(0) += 1;
}
}
stats
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use image::{ImageBuffer, Rgb};

    /// Builds a solid white RGB image of the requested size.
    fn create_test_image(width: u32, height: u32) -> RgbImage {
        let white: Rgb<u8> = Rgb([255, 255, 255]);
        ImageBuffer::from_pixel(width, height, white)
    }

    /// The aspect ratio is width divided by height.
    #[test]
    fn test_aspect_ratio_calculation() {
        let bucketing = AspectRatioBucketing::default();
        let landscape = create_test_image(100, 50);
        assert_eq!(bucketing.calculate_aspect_ratio(&landscape), 2.0);
    }

    /// Every default bucket is reachable, and very large ratios land in the last one.
    #[test]
    fn test_bucket_finding() {
        let bucketing = AspectRatioBucketing::default();
        let expectations: [(f32, &str); 7] = [
            (0.5, "tall"),
            (1.0, "square"),
            (2.0, "normal"),
            (3.0, "wide"),
            (5.0, "ultra_wide"),
            (100.0, "ultra_wide"),
            (5000.0, "ultra_wide"),
        ];
        for &(ratio, expected_name) in &expectations {
            let found = bucketing
                .find_bucket(ratio)
                .map(|bucket| bucket.name.as_str());
            assert_eq!(found, Some(expected_name));
        }
    }

    /// A resized image has exactly the bucket's `(height, width)`.
    #[test]
    fn test_resize_and_pad() -> Result<(), OCRError> {
        let bucketing = AspectRatioBucketing::default();
        let image = create_test_image(100, 50);
        let bucket = match bucketing.find_bucket(2.0) {
            Some(bucket) => bucket,
            None => panic!("expected bucket for aspect ratio 2.0"),
        };
        let resized = bucketing.resize_and_pad_to_bucket(&image, bucket)?;
        let (width, height) = resized.dimensions();
        assert_eq!((height, width), bucket.target_dims);
        Ok(())
    }

    /// Images with different ratios end up in the expected named groups.
    #[test]
    fn test_group_images_by_buckets() -> Result<(), OCRError> {
        let bucketing = AspectRatioBucketing::default();
        // (width, height) pairs: two "normal", one "tall", one "square", one "ultra_wide".
        let sizes: [(u32, u32); 5] = [(100, 50), (200, 100), (50, 100), (100, 100), (300, 60)];
        let images: Vec<(usize, RgbImage)> = sizes
            .iter()
            .enumerate()
            .map(|(index, &(width, height))| (index, create_test_image(width, height)))
            .collect();

        let groups = bucketing.group_images_by_buckets(images)?;

        let expected: [(&str, usize); 4] =
            [("normal", 2), ("tall", 1), ("square", 1), ("ultra_wide", 1)];
        assert!(groups.len() >= 4);
        for &(name, _) in &expected {
            assert!(groups.contains_key(name));
        }
        for &(name, count) in &expected {
            assert_eq!(groups.get(name).map(|v| v.len()), Some(count));
        }
        Ok(())
    }

    /// Bucketing collapses many distinct-but-similar sizes into far fewer groups than
    /// grouping by exact dimensions does.
    #[test]
    fn test_bucket_efficiency_comparison() -> Result<(), OCRError> {
        let bucketing = AspectRatioBucketing::default();
        let images: Vec<(usize, RgbImage)> = (0..20u32)
            .map(|i| (i as usize, create_test_image(100 + i * 2, 50 + i)))
            .collect();

        let bucket_groups = bucketing.group_images_by_buckets(images.clone())?;

        let mut exact_groups: HashMap<(u32, u32), Vec<(usize, RgbImage)>> = HashMap::new();
        for (index, image) in images {
            let key = (image.height(), image.width());
            exact_groups.entry(key).or_default().push((index, image));
        }

        assert!(bucket_groups.len() < exact_groups.len());

        let largest_bucket_size = bucket_groups.values().map(|v| v.len()).max().unwrap_or(0);
        assert!(largest_bucket_size > 10);

        let exact_single_groups = exact_groups.values().filter(|v| v.len() == 1).count();
        assert!(exact_single_groups > 15);
        Ok(())
    }

    /// The default configuration survives a JSON round trip, including the `f32::MAX`
    /// upper bound of the last bucket.
    #[test]
    fn test_json_serialization_roundtrip() -> Result<(), Box<dyn std::error::Error>> {
        let original = AspectRatioBucketingConfig::default();
        let encoded = serde_json::to_string(&original)?;
        let decoded: AspectRatioBucketingConfig = serde_json::from_str(&encoded)?;

        assert_eq!(original.buckets.len(), decoded.buckets.len());
        assert_eq!(original.padding_color, decoded.padding_color);
        assert_eq!(original.fallback_to_exact, decoded.fallback_to_exact);
        assert_eq!(
            original.max_images_per_bucket,
            decoded.max_images_per_bucket
        );

        let bucketing = AspectRatioBucketing::new(decoded);
        let bucket_name = bucketing
            .find_bucket(5000.0)
            .map(|bucket| bucket.name.as_str());
        assert_eq!(bucket_name, Some("ultra_wide"));
        Ok(())
    }

    /// Every bucket has a cached resize configuration with swapped `(width, height)` and
    /// the configured padding color, and the cache is usable for an actual resize.
    #[test]
    fn test_resize_config_caching() {
        let bucketing = AspectRatioBucketing::default();
        let bucket_count = bucketing.config.buckets.len();
        assert_eq!(bucketing.resize_configs.len(), bucket_count);

        for bucket in &bucketing.config.buckets {
            assert!(bucketing.resize_configs.contains_key(&bucket.name));
            let cached_config = &bucketing.resize_configs[&bucket.name];

            let (target_height, target_width) = bucket.target_dims;
            assert_eq!(cached_config.target_dims, (target_width, target_height));

            match cached_config.padding_strategy {
                PaddingStrategy::SolidColor(color) => {
                    assert_eq!(color, bucketing.config.padding_color);
                }
                _ => panic!("Expected SolidColor padding strategy"),
            }
        }

        let test_image = create_test_image(100, 50);
        let bucket = match bucketing.find_bucket(2.0) {
            Some(bucket) => bucket,
            None => panic!("aspect ratio 2.0 should map to an existing bucket"),
        };
        let result = bucketing
            .resize_and_pad_to_bucket(&test_image, bucket)
            .unwrap_or_else(|err| {
                panic!("expected resize_and_pad_to_bucket to succeed, got {err:?}")
            });

        let (target_height, target_width) = bucket.target_dims;
        assert_eq!(result.dimensions(), (target_width, target_height));
    }
}