// oar-ocr-core 0.6.3
//
// Core types and predictors for oar-ocr
// Documentation
//! Document Orientation Classification Adapter
//!
//! This adapter uses the PP-LCNet model to classify document image orientation.

use crate::apply_ort_config;
use crate::core::OCRError;
use crate::core::traits::{
    adapter::{AdapterInfo, ModelAdapter},
    task::Task,
};
use crate::domain::tasks::{
    Classification, DocumentOrientationConfig, DocumentOrientationOutput, DocumentOrientationTask,
};
use crate::impl_adapter_builder;
use crate::models::classification::{PPLCNetModel, PPLCNetModelBuilder, PPLCNetPostprocessConfig};

/// Adapter that classifies the orientation of document images with a PP-LCNet model.
///
/// It bundles the model with the adapter metadata, the default task
/// configuration, and the base postprocessing configuration that `execute`
/// clones and adjusts on every call.
#[derive(Debug)]
pub struct DocumentOrientationAdapter {
    /// The underlying PP-LCNet model used to produce predictions.
    model: PPLCNetModel,
    /// Adapter metadata; `info()` hands out a clone of it.
    info: AdapterInfo,
    /// Default task configuration, used when `execute` receives no override.
    config: DocumentOrientationConfig,
    /// Base postprocessing configuration; `execute` clones it and replaces `topk`.
    postprocess_config: PPLCNetPostprocessConfig,
}

impl DocumentOrientationAdapter {
    /// Input shape used when the builder is not given an explicit one.
    ///
    /// PP-LCNet document orientation models are trained with 224x224 inputs.
    pub const DEFAULT_INPUT_SHAPE: (u32, u32) = (224, 224);

    /// Assembles an adapter from an already-built model and its configuration.
    ///
    /// The values are stored as given; no validation happens here.
    pub fn new(
        model: PPLCNetModel,
        info: AdapterInfo,
        config: DocumentOrientationConfig,
        postprocess_config: PPLCNetPostprocessConfig,
    ) -> Self {
        Self {
            model,
            info,
            config,
            postprocess_config,
        }
    }

    /// Returns the orientation class labels, in class-id order:
    /// `"0"`, `"90"`, `"180"`, `"270"` (rotation in degrees).
    pub fn labels() -> Vec<String> {
        ["0", "90", "180", "270"]
            .iter()
            .copied()
            .map(String::from)
            .collect()
    }
}

impl ModelAdapter for DocumentOrientationAdapter {
    type Task = DocumentOrientationTask;

    fn info(&self) -> AdapterInfo {
        self.info.clone()
    }

    fn execute(
        &self,
        input: <Self::Task as Task>::Input,
        config: Option<&<Self::Task as Task>::Config>,
    ) -> Result<<Self::Task as Task>::Output, OCRError> {
        let effective_config = config.unwrap_or(&self.config);

        // Update postprocess config with task-specific topk
        let mut postprocess_config = self.postprocess_config.clone();
        postprocess_config.topk = effective_config.topk;

        // Use model to get predictions with error context
        let model_output = self
            .model
            .forward(input.images, &postprocess_config)
            .map_err(|e| {
                OCRError::adapter_execution_error(
                    "DocumentOrientationAdapter",
                    format!(
                        "failed to classify document orientation (topk={})",
                        effective_config.topk
                    ),
                    e,
                )
            })?;

        // Convert model output to task-specific output with structured classifications
        let label_names = model_output.label_names.unwrap_or_else(|| {
            model_output
                .class_ids
                .iter()
                .map(|ids| ids.iter().map(|&id| format!("{}", id * 90)).collect())
                .collect()
        });

        // Create structured classifications
        let classifications = model_output
            .class_ids
            .into_iter()
            .zip(model_output.scores)
            .zip(label_names)
            .map(|((class_ids, scores), labels)| {
                class_ids
                    .into_iter()
                    .zip(scores)
                    .zip(labels)
                    .map(|((class_id, score), label)| Classification::new(class_id, label, score))
                    .collect()
            })
            .collect();

        Ok(DocumentOrientationOutput { classifications })
    }

    fn supports_batching(&self) -> bool {
        true
    }

    fn recommended_batch_size(&self) -> usize {
        32
    }
}

// Declares `DocumentOrientationAdapterBuilder` through the crate's
// `impl_adapter_builder!` macro: the builder's extra fields, its fluent
// setters, and the closure that turns a builder plus a model path into a
// ready `DocumentOrientationAdapter`.
impl_adapter_builder! {
    builder_name: DocumentOrientationAdapterBuilder,
    adapter_name: DocumentOrientationAdapter,
    config_type: DocumentOrientationConfig,
    adapter_type: "document_orientation",
    adapter_desc: "Classifies document image orientation (0°, 90°, 180°, 270°)",
    task_type: DocumentOrientation,

    fields: {
        // Shape handed to the preprocessing step; defaults to 224x224.
        input_shape: (u32, u32) = DocumentOrientationAdapter::DEFAULT_INPUT_SHAPE,
        // When set, replaces the model name in the generated adapter info.
        model_name_override: Option<String> = None,
    },

    methods: {
        // Overrides the input shape used to build the preprocessing config.
        pub fn input_shape(mut self, input_shape: (u32, u32)) -> Self {
            self.input_shape = input_shape;
            self
        }

        // Sets a custom model name to report in the adapter info.
        pub fn model_name(mut self, model_name: impl Into<String>) -> Self {
            self.model_name_override = Some(model_name.into());
            self
        }
    }

    build: |builder: DocumentOrientationAdapterBuilder, model_path: &std::path::Path| -> Result<DocumentOrientationAdapter, OCRError> {
        // Split the builder config into the task config and the ORT config;
        // a validation failure is reported as `OCRError::ConfigError`.
        let (task_config, ort_config) = builder.config
            .into_validated_parts()
            .map_err(|err| OCRError::ConfigError {
                message: err.to_string(),
            })?;

        // Build the PP-LCNet model with preprocessing sized to `input_shape`
        let preprocess_config = super::preprocessing::pp_lcnet_preprocess(builder.input_shape);

        // `apply_ort_config!` presumably forwards `ort_config` to the model
        // builder when one is present - confirm against the macro definition.
        let model = apply_ort_config!(
            PPLCNetModelBuilder::new().preprocess_config(preprocess_config),
            ort_config
        )
        .build(model_path)?;

        // Create postprocessing configuration with the fixed orientation labels
        let postprocess_config = PPLCNetPostprocessConfig {
            labels: DocumentOrientationAdapter::labels(),
            topk: 1, // Placeholder: `execute` overwrites it with the task config's topk
        };

        // Create adapter info using the helper, then apply the optional
        // model-name override
        let mut info = DocumentOrientationAdapterBuilder::base_adapter_info();
        if let Some(model_name) = builder.model_name_override {
            info.model_name = model_name;
        }

        Ok(DocumentOrientationAdapter::new(
            model,
            info,
            task_config,
            postprocess_config,
        ))
    },
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::traits::adapter::AdapterBuilder;

    /// A freshly created builder reports the expected adapter type.
    #[test]
    fn test_builder_creation() {
        let fresh = DocumentOrientationAdapterBuilder::new();
        assert_eq!(fresh.adapter_type(), "document_orientation");
    }

    /// A task configuration passed to `with_config` is stored unchanged.
    #[test]
    fn test_builder_with_config() {
        let expected = DocumentOrientationConfig {
            score_threshold: 0.7,
            topk: 2,
        };

        let configured = DocumentOrientationAdapterBuilder::new().with_config(expected.clone());
        let stored = configured.config.task_config();
        assert_eq!(stored.topk, expected.topk);
        assert_eq!(stored.score_threshold, expected.score_threshold);
    }

    /// The fluent `input_shape` setter records the requested shape.
    #[test]
    fn test_builder_fluent_api() {
        let requested = (224, 224);
        let shaped = DocumentOrientationAdapterBuilder::new().input_shape(requested);

        assert_eq!(shaped.input_shape, requested);
    }

    /// `Default` yields the same adapter type and the default input shape.
    #[test]
    fn test_default_builder() {
        let defaulted = DocumentOrientationAdapterBuilder::default();
        assert_eq!(defaulted.adapter_type(), "document_orientation");
        assert_eq!(
            defaulted.input_shape,
            DocumentOrientationAdapter::DEFAULT_INPUT_SHAPE
        );
    }

    /// The label list covers the four rotations in class-id order.
    #[test]
    fn test_labels() {
        assert_eq!(
            DocumentOrientationAdapter::labels(),
            ["0", "90", "180", "270"]
        );
    }
}