// oar_ocr_core/predictors/document_rectification.rs

//! Document Rectification Predictor
//!
//! This module provides a high-level API for document rectification (dewarping),
//! exposed as a predictor type together with its builder.
4
5use super::builder::PredictorBuilderState;
6use crate::TaskPredictorBuilder;
7use crate::core::OcrResult;
8use crate::core::traits::OrtConfigurable;
9use crate::core::traits::adapter::AdapterBuilder;
10use crate::core::traits::task::ImageTaskInput;
11use crate::domain::adapters::UVDocRectifierAdapterBuilder;
12use crate::domain::tasks::document_rectification::{
13    DocumentRectificationConfig, DocumentRectificationTask,
14};
15use crate::predictors::TaskPredictorCore;
16use image::RgbImage;
17use std::path::Path;
18
19/// Document rectification prediction result
20#[derive(Debug, Clone)]
21pub struct DocumentRectificationResult {
22    /// Rectified images
23    pub images: Vec<RgbImage>,
24}
25
26/// Document rectification predictor
27pub struct DocumentRectificationPredictor {
28    core: TaskPredictorCore<DocumentRectificationTask>,
29}
30
31impl DocumentRectificationPredictor {
32    pub fn builder() -> DocumentRectificationPredictorBuilder {
33        DocumentRectificationPredictorBuilder::new()
34    }
35
36    /// Predict document rectification for the given images.
37    pub fn predict(&self, images: Vec<RgbImage>) -> OcrResult<DocumentRectificationResult> {
38        let input = ImageTaskInput::new(images);
39        let output = self.core.predict(input)?;
40        Ok(DocumentRectificationResult {
41            images: output.rectified_images,
42        })
43    }
44}
45
46#[derive(TaskPredictorBuilder)]
47#[builder(config = DocumentRectificationConfig)]
48pub struct DocumentRectificationPredictorBuilder {
49    state: PredictorBuilderState<DocumentRectificationConfig>,
50}
51
52impl DocumentRectificationPredictorBuilder {
53    pub fn new() -> Self {
54        Self {
55            state: PredictorBuilderState::new(DocumentRectificationConfig {
56                rec_image_shape: [3, 0, 0],
57            }),
58        }
59    }
60
61    pub fn build<P: AsRef<Path>>(self, model_path: P) -> OcrResult<DocumentRectificationPredictor> {
62        let (config, ort_config) = self.state.into_parts();
63        let mut adapter_builder = UVDocRectifierAdapterBuilder::new().with_config(config.clone());
64
65        if let Some(ort_cfg) = ort_config {
66            adapter_builder = adapter_builder.with_ort_config(ort_cfg);
67        }
68
69        let adapter = Box::new(adapter_builder.build(model_path.as_ref())?);
70        let task = DocumentRectificationTask::new(config.clone());
71        Ok(DocumentRectificationPredictor {
72            core: TaskPredictorCore::new(adapter, task, config),
73        })
74    }
75}
76
77impl Default for DocumentRectificationPredictorBuilder {
78    fn default() -> Self {
79        Self::new()
80    }
81}