// oar-ocr-core 0.6.3
//
// Core types and predictors for oar-ocr.
// Documentation:
//! Text Detection Task Adapter
//!
//! This adapter uses the DB model and adapts its output to the TextDetection task format.

use crate::apply_ort_config;
use crate::core::OCRError;
use crate::core::traits::{
    adapter::{AdapterInfo, ModelAdapter},
    task::Task,
};
use crate::domain::tasks::{
    Detection, TextDetectionConfig, TextDetectionOutput, TextDetectionTask,
};
use crate::impl_adapter_builder;
use crate::models::detection::db::{DBModel, DBModelBuilder, DBPostprocessConfig};
use crate::processors::{BoxType, ScoreMode};

/// Text detection adapter that uses the DB model.
///
/// Wraps a [`DBModel`] and adapts its raw output (per-image boxes and
/// scores) into the structured [`TextDetectionOutput`] expected by the
/// [`TextDetectionTask`]. Constructed via `TextDetectionAdapterBuilder`
/// (generated by `impl_adapter_builder!` below).
#[derive(Debug)]
pub struct TextDetectionAdapter {
    /// The underlying DB model used for inference.
    model: DBModel,
    /// Adapter information (type/description) reported via `ModelAdapter::info`.
    info: AdapterInfo,
    /// Default task configuration, used when `execute` receives no override.
    config: TextDetectionConfig,
}

impl ModelAdapter for TextDetectionAdapter {
    type Task = TextDetectionTask;

    /// Returns a copy of this adapter's metadata.
    fn info(&self) -> AdapterInfo {
        self.info.clone()
    }

    /// Runs text detection on the input images.
    ///
    /// Uses `config` when provided, otherwise falls back to the adapter's
    /// stored default configuration. Forwards through the DB model, then
    /// converts each image's `(bbox, score)` pairs into [`Detection`]s.
    ///
    /// # Errors
    ///
    /// Returns an [`OCRError`] (wrapped with adapter-execution context,
    /// including the thresholds in effect) if the model forward pass fails.
    fn execute(
        &self,
        input: <Self::Task as Task>::Input,
        config: Option<&<Self::Task as Task>::Config>,
    ) -> Result<<Self::Task as Task>::Output, OCRError> {
        // Per-call config takes precedence over the adapter default.
        let cfg = match config {
            Some(c) => c,
            None => &self.config,
        };

        // Run the DB model; attach the effective thresholds to any error
        // so failures are diagnosable without re-running.
        let raw = self
            .model
            .forward(
                input.images,
                cfg.score_threshold,
                cfg.box_threshold,
                cfg.unclip_ratio,
            )
            .map_err(|source| {
                let context = format!(
                    "failed to detect text (score_threshold={}, box_threshold={}, unclip_ratio={})",
                    cfg.score_threshold, cfg.box_threshold, cfg.unclip_ratio
                );
                OCRError::adapter_execution_error("TextDetectionAdapter", context, source)
            })?;

        // Pair each image's boxes with its scores and build structured
        // detections, one inner collection per input image.
        let detections = raw
            .boxes
            .into_iter()
            .zip(raw.scores)
            .map(|(image_boxes, image_scores)| {
                image_boxes
                    .into_iter()
                    .zip(image_scores)
                    .map(|(bbox, score)| Detection::new(bbox, score))
                    .collect()
            })
            .collect();

        Ok(TextDetectionOutput { detections })
    }

    /// DB-model detection handles batched image input.
    fn supports_batching(&self) -> bool {
        true
    }

    /// Suggested batch size for throughput; callers may override.
    fn recommended_batch_size(&self) -> usize {
        8
    }
}

// Generates `TextDetectionAdapterBuilder` plus the boilerplate needed to
// construct a `TextDetectionAdapter` from a model path. The macro wires up
// the config type, adapter metadata, extra builder fields/methods, and the
// `build` closure below.
impl_adapter_builder! {
    builder_name: TextDetectionAdapterBuilder,
    adapter_name: TextDetectionAdapter,
    config_type: TextDetectionConfig,
    adapter_type: "text_detection",
    adapter_desc: "Detects text regions in images with bounding boxes",
    task_type: TextDetection,

    fields: {
        // Optional text-type hint ("seal" selects seal-specific behavior).
        text_type: Option<String> = None,
    },

    methods: {
        /// Sets the text type for preprocessing and postprocessing configuration.
        ///
        /// This matches the text_type parameter:
        /// - "seal": Uses seal-specific preprocessing (limit_side_len=736, limit_type=Min) and polygon boxes
        /// - Other values or None: Uses general text configuration (limit_side_len=960, limit_type=Max) and quad boxes
        pub fn text_type(mut self, text_type: impl Into<String>) -> Self {
            self.text_type = Some(text_type.into());
            self
        }
    }

    // Build closure: validates config, derives pre/postprocessing settings
    // from the text type, constructs the DB model, and assembles the adapter.
    build: |builder: TextDetectionAdapterBuilder, model_path: &std::path::Path| -> Result<TextDetectionAdapter, OCRError> {
        // Split the builder config into the validated task config and the
        // ONNX Runtime session config; surface validation failures as
        // ConfigError with the original message.
        let (task_config, ort_config) = builder.config
            .into_validated_parts()
            .map_err(|err| OCRError::ConfigError {
                message: err.to_string(),
            })?;

        // Determine if this is seal text (uses different preprocessing and box type).
        // Comparison is case-insensitive.
        let is_seal_text = builder
            .text_type
            .as_ref()
            .map(|t| t.to_lowercase() == "seal")
            .unwrap_or(false);

        // Configure DB model preprocessing based on text type
        // Matches standard behavior:
        // - General text: limit_side_len=960, limit_type=Max
        // - Seal text: limit_side_len=736, limit_type=Min
        let mut preprocess_config =
            super::preprocessing::db_preprocess_for_text_type(builder.text_type.as_deref());

        // Override the text-type defaults with any explicit config values.
        if let Some(limit) = task_config.limit_side_len {
            preprocess_config.limit_side_len = Some(limit);
        }
        if let Some(limit_type) = task_config.limit_type.clone() {
            preprocess_config.limit_type = Some(limit_type);
        }
        if let Some(max_limit) = task_config.max_side_len {
            preprocess_config.max_side_limit = Some(max_limit);
        }

        // Configure postprocessing based on text type
        // Seal text uses polygon boxes for curved text, general text uses quad boxes
        let box_type = if is_seal_text {
            BoxType::Poly
        } else {
            BoxType::Quad
        };

        // Thresholds come from the validated task config; dilation is off and
        // fast score mode is used unconditionally here.
        let postprocess_config = DBPostprocessConfig {
            score_threshold: task_config.score_threshold,
            box_threshold: task_config.box_threshold,
            unclip_ratio: task_config.unclip_ratio,
            max_candidates: task_config.max_candidates,
            use_dilation: false,
            score_mode: ScoreMode::Fast,
            box_type,
        };

        // Build the DB model, applying the ORT session config via the
        // apply_ort_config! helper before loading the model file.
        let model = apply_ort_config!(
            DBModelBuilder::new()
                .preprocess_config(preprocess_config)
                .postprocess_config(postprocess_config),
            ort_config
        )
        .build(model_path)?;

        // Create adapter info using the helper generated by the macro.
        let info = TextDetectionAdapterBuilder::base_adapter_info();

        Ok(TextDetectionAdapter {
            model,
            info,
            config: task_config,
        })
    },
}