use oar_ocr_core::core::OCRError;
use oar_ocr_core::core::traits::adapter::ModelAdapter;
use oar_ocr_core::core::traits::task::ImageTaskInput;
use oar_ocr_core::domain::adapters::{DocumentOrientationAdapter, UVDocRectifierAdapter};
use std::sync::Arc;
/// Metadata describing an orientation fix applied to a page image.
///
/// `angle` is the orientation reported by the classifier, in degrees, and
/// `rotated_width` / `rotated_height` are the pixel dimensions of the image
/// after the corrective rotation (if any) was applied.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) struct OrientationCorrection {
    /// Detected orientation angle in degrees.
    pub angle: f32,
    /// Width, in pixels, of the corrected image.
    pub rotated_width: u32,
    /// Height, in pixels, of the corrected image.
    pub rotated_height: u32,
}
/// Output of `DocumentPreprocessor::preprocess`.
#[derive(Debug)]
pub(crate) struct PreprocessResult {
    /// The image after every enabled preprocessing stage has run.
    pub image: Arc<image::RgbImage>,
    /// Angle reported by the orientation stage; `None` when that stage did not run
    /// or produced no classification. Kept even when rectification replaced the image.
    pub orientation_angle: Option<f32>,
    /// Orientation correction metadata. `None` when no orientation correction was
    /// produced, and also when rectification produced a new image.
    pub rotation: Option<OrientationCorrection>,
    /// The rectified image (sharing its allocation with `image`) when the
    /// rectification stage produced one; otherwise `None`.
    pub rectified_img: Option<Arc<image::RgbImage>>,
}
/// Optional orientation-correction and rectification stages for a page image.
///
/// The adapters are borrowed for `'a`; cloning the preprocessor copies only the
/// two optional references.
#[derive(Clone, Debug)]
pub(crate) struct DocumentPreprocessor<'a> {
    /// Orientation classifier; the orientation stage is skipped when `None`.
    orientation_adapter: Option<&'a DocumentOrientationAdapter>,
    /// Rectification model; the rectification stage is skipped when `None`.
    rectification_adapter: Option<&'a UVDocRectifierAdapter>,
}
impl<'a> DocumentPreprocessor<'a> {
    /// Creates a preprocessor from optional adapters.
    ///
    /// Passing `None` for an adapter disables the corresponding stage.
    pub(crate) fn new(
        orientation_adapter: Option<&'a DocumentOrientationAdapter>,
        rectification_adapter: Option<&'a UVDocRectifierAdapter>,
    ) -> Self {
        Self {
            orientation_adapter,
            rectification_adapter,
        }
    }

    /// Runs the enabled preprocessing stages on `image`.
    ///
    /// Stages run in this order:
    /// 1. orientation correction (when an orientation adapter is configured);
    /// 2. rectification (when a rectification adapter is configured).
    ///
    /// If rectification yields an image, that image becomes `result.image`,
    /// `result.rectified_img` shares the same allocation, and `result.rotation`
    /// is cleared while `result.orientation_angle` is kept. With no stages
    /// enabled, the input `Arc` is returned unchanged (no pixel copy).
    ///
    /// # Errors
    ///
    /// Propagates any [`OCRError`] returned by the orientation or rectification
    /// adapter.
    pub(crate) fn preprocess(
        &self,
        image: Arc<image::RgbImage>,
    ) -> Result<PreprocessResult, OCRError> {
        let (current_image, orientation_angle, rotation) = match self.orientation_adapter {
            Some(orientation_adapter) => {
                let (rotated, rotation) = correct_image_orientation(image, orientation_adapter)?;
                (rotated, rotation.map(|r| r.angle), rotation)
            }
            None => (image, None, None),
        };

        let rectified_img: Option<Arc<image::RgbImage>> = match self.rectification_adapter {
            Some(rectification_adapter) => {
                // The adapter takes owned images, so the pixels are copied into the batch.
                let input = ImageTaskInput::new(vec![(*current_image).clone()]);
                let rect_output = rectification_adapter.execute(input, None)?;
                rect_output.rectified_images.first().cloned().map(Arc::new)
            }
            None => None,
        };

        // The rotation metadata describes the orientation-corrected image; once
        // rectification replaces it, that metadata is dropped (presumably because
        // the rectified geometry no longer matches it).
        let (image, rotation) = match &rectified_img {
            Some(rectified) => (Arc::clone(rectified), None),
            None => (current_image, rotation),
        };

        Ok(PreprocessResult {
            image,
            orientation_angle,
            rotation,
            rectified_img,
        })
    }
}
/// Rotates `image` according to an orientation classifier's `class_id`.
///
/// Class ids map to corrections as follows:
/// - 1 → `imageops::rotate270`
/// - 2 → `imageops::rotate180`
/// - 3 → `imageops::rotate90`
/// - 0 and any other id → the input `Arc` is returned untouched.
///
/// Returns `(image, None)` when `class_id` is `None`. Otherwise returns the
/// resulting image plus metadata whose `angle` is `class_id * 90` degrees.
/// That angle is reported even for unrecognized ids, whose pixels are not rotated.
fn apply_orientation_from_class_id(
    image: Arc<image::RgbImage>,
    class_id: Option<usize>,
) -> (Arc<image::RgbImage>, Option<OrientationCorrection>) {
    match class_id {
        None => (image, None),
        Some(id) => {
            let oriented = match id {
                1 => Arc::new(image::imageops::rotate270(&*image)),
                2 => Arc::new(image::imageops::rotate180(&*image)),
                3 => Arc::new(image::imageops::rotate90(&*image)),
                // Class 0 is already upright; unknown ids are passed through as-is.
                _ => image,
            };
            let (rotated_width, rotated_height) = oriented.dimensions();
            let correction = OrientationCorrection {
                angle: id as f32 * 90.0,
                rotated_width,
                rotated_height,
            };
            (oriented, Some(correction))
        }
    }
}
/// Classifies the orientation of `image` and rotates it accordingly.
///
/// How it works:
/// - Runs `orientation_adapter` on a single-image batch.
/// - Takes the first classification entry for that image (presumably the
///   highest-ranked one; verify against the adapter).
/// - Delegates the rotation to `apply_orientation_from_class_id`.
///
/// When the adapter yields no classification, the input `Arc` is returned
/// together with `None`.
///
/// # Errors
///
/// Returns any [`OCRError`] produced by the orientation adapter.
pub(crate) fn correct_image_orientation(
    image: Arc<image::RgbImage>,
    orientation_adapter: &DocumentOrientationAdapter,
) -> Result<(Arc<image::RgbImage>, Option<OrientationCorrection>), OCRError> {
    // The adapter consumes owned images, so the pixels are copied once here.
    let batch = ImageTaskInput::new(vec![image::RgbImage::clone(&image)]);
    let output = orientation_adapter.execute(batch, None)?;
    let top_class_id = output
        .classifications
        .first()
        .and_then(|per_image| per_image.first())
        .map(|top| top.class_id);
    Ok(apply_orientation_from_class_id(image, top_class_id))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an all-black RGB image; these tests only depend on its dimensions.
    fn create_test_image(width: u32, height: u32) -> image::RgbImage {
        image::RgbImage::new(width, height)
    }

    /// Test-only mirror of the top-1 selection done in `correct_image_orientation`.
    ///
    /// It works over `(class_id, label, score)` tuples so it can run without an
    /// orientation adapter. It does not exercise the production code path itself.
    fn extract_class_id_from_classifications(
        classifications: &[Vec<(usize, String, f32)>],
    ) -> Option<usize> {
        classifications
            .first()
            .and_then(|c| c.first())
            .map(|(class_id, _, _)| *class_id)
    }

    #[test]
    fn test_orientation_correction_creation() {
        let correction = OrientationCorrection {
            angle: 90.0,
            rotated_width: 200,
            rotated_height: 100,
        };
        assert_eq!(correction.angle, 90.0);
        assert_eq!(correction.rotated_width, 200);
        assert_eq!(correction.rotated_height, 100);
    }

    #[test]
    fn test_orientation_correction_equality() {
        let c1 = OrientationCorrection {
            angle: 180.0,
            rotated_width: 100,
            rotated_height: 200,
        };
        let c2 = OrientationCorrection {
            angle: 180.0,
            rotated_width: 100,
            rotated_height: 200,
        };
        let c3 = OrientationCorrection {
            angle: 90.0,
            rotated_width: 100,
            rotated_height: 200,
        };
        assert_eq!(c1, c2);
        assert_ne!(c1, c3);
    }

    #[test]
    fn test_orientation_correction_copy() {
        let c1 = OrientationCorrection {
            angle: 270.0,
            rotated_width: 150,
            rotated_height: 300,
        };
        let c2 = c1;
        // `c1` remains usable after the assignment only because the type is `Copy`.
        assert_eq!(c1, c2);
    }

    #[test]
    fn test_preprocess_result_creation() {
        let image = Arc::new(create_test_image(100, 200));
        let result = PreprocessResult {
            image: Arc::clone(&image),
            orientation_angle: Some(90.0),
            rotation: Some(OrientationCorrection {
                angle: 90.0,
                rotated_width: 200,
                rotated_height: 100,
            }),
            rectified_img: None,
        };
        assert_eq!(result.image.width(), 100);
        assert_eq!(result.image.height(), 200);
        assert_eq!(result.orientation_angle, Some(90.0));
        assert!(result.rotation.is_some());
        assert!(result.rectified_img.is_none());
    }

    #[test]
    fn test_preprocess_result_with_rectified_image() {
        let rectified = Arc::new(create_test_image(120, 220));
        let result = PreprocessResult {
            image: Arc::clone(&rectified),
            orientation_angle: None,
            rotation: None,
            rectified_img: Some(Arc::clone(&rectified)),
        };
        assert!(result.orientation_angle.is_none());
        assert!(result.rotation.is_none());
        let Some(rectified_img) = result.rectified_img.as_ref() else {
            panic!("expected rectified image to be Some");
        };
        assert!(Arc::ptr_eq(&result.image, rectified_img));
    }

    #[test]
    fn test_document_preprocessor_no_adapters() -> Result<(), OCRError> {
        let preprocessor = DocumentPreprocessor::new(None, None);
        let image = Arc::new(create_test_image(100, 200));
        let result = preprocessor.preprocess(Arc::clone(&image))?;
        assert_eq!(result.image.width(), 100);
        assert_eq!(result.image.height(), 200);
        assert!(result.orientation_angle.is_none());
        assert!(result.rotation.is_none());
        assert!(result.rectified_img.is_none());
        Ok(())
    }

    #[test]
    fn test_document_preprocessor_clone() -> Result<(), OCRError> {
        let preprocessor = DocumentPreprocessor::<'static>::new(None, None);
        let cloned = preprocessor.clone();
        let image = Arc::new(create_test_image(50, 50));
        let r1 = preprocessor.preprocess(Arc::clone(&image))?;
        let r2 = cloned.preprocess(Arc::clone(&image))?;
        assert_eq!(r1.image.width(), r2.image.width());
        assert_eq!(r1.image.height(), r2.image.height());
        Ok(())
    }

    #[test]
    fn test_apply_orientation_none_class_id_returns_original_image() {
        let image = Arc::new(create_test_image(100, 200));
        let (result, correction) = apply_orientation_from_class_id(Arc::clone(&image), None);
        assert!(Arc::ptr_eq(&image, &result));
        assert!(correction.is_none());
    }

    #[test]
    fn test_apply_orientation_class_id_0_no_rotation() {
        let image = Arc::new(create_test_image(100, 200));
        let (result, correction) = apply_orientation_from_class_id(Arc::clone(&image), Some(0));
        assert!(Arc::ptr_eq(&image, &result));
        assert_eq!(result.width(), 100);
        assert_eq!(result.height(), 200);
        let Some(correction) = correction else {
            panic!("correction should be Some");
        };
        assert_eq!(correction.angle, 0.0);
        assert_eq!(correction.rotated_width, 100);
        assert_eq!(correction.rotated_height, 200);
    }

    #[test]
    fn test_apply_orientation_class_id_1_rotates_270_ccw() {
        let image = Arc::new(create_test_image(100, 200));
        let (result, correction) = apply_orientation_from_class_id(Arc::clone(&image), Some(1));
        assert_eq!(result.width(), 200);
        assert_eq!(result.height(), 100);
        assert!(!Arc::ptr_eq(&image, &result));
        let Some(correction) = correction else {
            panic!("correction should be Some");
        };
        assert_eq!(correction.angle, 90.0);
        assert_eq!(correction.rotated_width, 200);
        assert_eq!(correction.rotated_height, 100);
    }

    #[test]
    fn test_apply_orientation_class_id_2_rotates_180() {
        let image = Arc::new(create_test_image(100, 200));
        let (result, correction) = apply_orientation_from_class_id(Arc::clone(&image), Some(2));
        assert_eq!(result.width(), 100);
        assert_eq!(result.height(), 200);
        assert!(!Arc::ptr_eq(&image, &result));
        let Some(correction) = correction else {
            panic!("correction should be Some");
        };
        assert_eq!(correction.angle, 180.0);
        assert_eq!(correction.rotated_width, 100);
        assert_eq!(correction.rotated_height, 200);
    }

    #[test]
    fn test_apply_orientation_class_id_3_rotates_90_cw() {
        let image = Arc::new(create_test_image(100, 200));
        let (result, correction) = apply_orientation_from_class_id(Arc::clone(&image), Some(3));
        assert_eq!(result.width(), 200);
        assert_eq!(result.height(), 100);
        assert!(!Arc::ptr_eq(&image, &result));
        let Some(correction) = correction else {
            panic!("correction should be Some");
        };
        assert_eq!(correction.angle, 270.0);
        assert_eq!(correction.rotated_width, 200);
        assert_eq!(correction.rotated_height, 100);
    }

    #[test]
    fn test_apply_orientation_unknown_class_id_preserves_metadata() {
        let image = Arc::new(create_test_image(100, 200));
        let (result, correction) = apply_orientation_from_class_id(Arc::clone(&image), Some(99));
        assert!(Arc::ptr_eq(&image, &result));
        assert_eq!(result.width(), 100);
        assert_eq!(result.height(), 200);
        let Some(correction) = correction else {
            panic!("correction should be Some");
        };
        // Unknown ids leave the pixels untouched but still report `class_id * 90`.
        assert_eq!(correction.angle, 8910.0);
        assert_eq!(correction.rotated_width, 100);
        assert_eq!(correction.rotated_height, 200);
    }

    #[test]
    fn test_apply_orientation_square_image_all_rotations() {
        let image = Arc::new(create_test_image(150, 150));
        for class_id in 0..4 {
            let (result, correction) =
                apply_orientation_from_class_id(Arc::clone(&image), Some(class_id));
            assert_eq!(result.width(), 150);
            assert_eq!(result.height(), 150);
            let Some(correction) = correction else {
                panic!("correction should be Some");
            };
            assert_eq!(correction.angle, (class_id as f32) * 90.0);
        }
    }

    #[test]
    fn test_angle_calculation_from_class_id() {
        // Exercise the real mapping rather than re-checking float arithmetic.
        for (class_id, expected_angle) in [(0_usize, 0.0_f32), (1, 90.0), (2, 180.0), (3, 270.0)] {
            let (_, correction) =
                apply_orientation_from_class_id(Arc::new(create_test_image(1, 1)), Some(class_id));
            let Some(correction) = correction else {
                panic!("correction should be Some");
            };
            assert_eq!(correction.angle, expected_angle);
        }
    }

    #[test]
    fn test_arc_sharing_without_clone() {
        let image = Arc::new(create_test_image(100, 200));
        let (upright, _) = apply_orientation_from_class_id(Arc::clone(&image), Some(0));
        // Class 0 needs no rotation, so the same allocation is handed back (no pixel copy).
        assert!(Arc::ptr_eq(&image, &upright));
        assert_eq!(Arc::strong_count(&image), 2);
    }

    #[test]
    fn test_preprocess_result_invariant_rotation_none_when_rectified() {
        let rectified = Arc::new(create_test_image(110, 210));
        let result = PreprocessResult {
            image: Arc::clone(&rectified),
            orientation_angle: Some(90.0),
            rotation: None,
            rectified_img: Some(rectified),
        };
        assert!(result.rotation.is_none());
        assert!(result.rectified_img.is_some());
    }

    #[test]
    fn test_preprocess_result_rotation_preserved_without_rectification() {
        let (rotated_image, correction) =
            apply_orientation_from_class_id(Arc::new(create_test_image(100, 200)), Some(1));
        let result = PreprocessResult {
            image: rotated_image,
            orientation_angle: Some(90.0),
            rotation: correction,
            rectified_img: None,
        };
        assert!(result.rotation.is_some());
        assert!(result.rectified_img.is_none());
        let Some(rotation) = result.rotation.as_ref() else {
            panic!("expected rotation metadata to be Some");
        };
        assert_eq!(rotation.angle, 90.0);
    }

    #[test]
    fn test_classification_extraction_with_valid_result() {
        let classifications = vec![vec![(1_usize, "90".to_string(), 0.95_f32)]];
        let class_id = extract_class_id_from_classifications(&classifications);
        assert_eq!(class_id, Some(1));
    }

    #[test]
    fn test_classification_extraction_with_multiple_topk() {
        let classifications = vec![vec![
            (2_usize, "180".to_string(), 0.85_f32),
            (0_usize, "0".to_string(), 0.10_f32),
            (1_usize, "90".to_string(), 0.05_f32),
        ]];
        let class_id = extract_class_id_from_classifications(&classifications);
        assert_eq!(class_id, Some(2));
    }

    #[test]
    fn test_classification_extraction_with_empty_inner_vec() {
        let classifications: Vec<Vec<(usize, String, f32)>> = vec![vec![]];
        let class_id = extract_class_id_from_classifications(&classifications);
        assert_eq!(class_id, None);
    }

    #[test]
    fn test_classification_extraction_with_empty_outer_vec() {
        let classifications: Vec<Vec<(usize, String, f32)>> = vec![];
        let class_id = extract_class_id_from_classifications(&classifications);
        assert_eq!(class_id, None);
    }

    #[test]
    fn test_classification_extraction_for_batch_input() {
        let classifications = vec![
            vec![(1_usize, "90".to_string(), 0.95_f32)],
            vec![(2_usize, "180".to_string(), 0.90_f32)],
        ];
        let class_id = extract_class_id_from_classifications(&classifications);
        // Only the first image of the batch is consulted.
        assert_eq!(class_id, Some(1));
    }

    #[test]
    fn test_orientation_correction_flow_with_90_degree_detection() {
        let image = Arc::new(create_test_image(100, 200));
        let class_id = Some(1_usize);
        let (result, correction) = apply_orientation_from_class_id(Arc::clone(&image), class_id);
        assert_eq!(result.width(), 200);
        assert_eq!(result.height(), 100);
        let Some(correction) = correction else {
            panic!("expected correction metadata to be Some");
        };
        assert_eq!(correction.angle, 90.0);
        assert_eq!(correction.rotated_width, 200);
        assert_eq!(correction.rotated_height, 100);
    }

    #[test]
    fn test_orientation_correction_flow_no_classification_available() {
        let image = Arc::new(create_test_image(100, 200));
        let class_id = None;
        let (result, correction) = apply_orientation_from_class_id(Arc::clone(&image), class_id);
        assert!(Arc::ptr_eq(&image, &result));
        assert!(correction.is_none());
    }

    #[test]
    fn test_preprocessor_without_adapters_passes_image_through() -> Result<(), OCRError> {
        let preprocessor = DocumentPreprocessor::new(None, None);
        let image = Arc::new(create_test_image(100, 200));
        let result = preprocessor.preprocess(Arc::clone(&image))?;
        // With no stages enabled the input allocation is returned as-is, not copied.
        assert!(Arc::ptr_eq(&image, &result.image));
        assert!(result.orientation_angle.is_none());
        assert!(result.rotation.is_none());
        assert!(result.rectified_img.is_none());
        Ok(())
    }
}