Skip to main content

oar_ocr_core/predictors/
layout_detection.rs

1//! Layout Detection Predictor
2//!
3//! This module provides a high-level API for document layout detection.
4
5use super::builder::PredictorBuilderState;
6use crate::TaskPredictorBuilder;
7use crate::core::OcrResult;
8use crate::core::errors::OCRError;
9use crate::core::traits::OrtConfigurable;
10use crate::core::traits::adapter::AdapterBuilder;
11use crate::core::traits::task::ImageTaskInput;
12use crate::domain::adapters::LayoutDetectionAdapterBuilder;
13use crate::domain::tasks::layout_detection::{LayoutDetectionConfig, LayoutDetectionTask};
14use crate::predictors::TaskPredictorCore;
15use image::RgbImage;
16use std::path::Path;
17
18/// Layout detection prediction result
19#[derive(Debug, Clone)]
20pub struct LayoutDetectionResult {
21    /// Detected layout elements for each input image
22    pub elements: Vec<Vec<crate::domain::tasks::layout_detection::LayoutDetectionElement>>,
23    /// Whether elements are already sorted by reading order (e.g., from PP-DocLayoutV2)
24    ///
25    /// When `true`, downstream consumers can skip reading order sorting algorithms
26    /// as the elements are already in the correct reading order based on model output.
27    pub is_reading_order_sorted: bool,
28}
29
30/// Layout detection predictor
31pub struct LayoutDetectionPredictor {
32    core: TaskPredictorCore<LayoutDetectionTask>,
33}
34
35impl LayoutDetectionPredictor {
36    pub fn builder() -> LayoutDetectionPredictorBuilder {
37        LayoutDetectionPredictorBuilder::new()
38    }
39
40    /// Predict layout elements in the given images.
41    pub fn predict(&self, images: Vec<RgbImage>) -> OcrResult<LayoutDetectionResult> {
42        let input = ImageTaskInput::new(images);
43        let output = self.core.predict(input)?;
44        Ok(LayoutDetectionResult {
45            elements: output.elements,
46            is_reading_order_sorted: output.is_reading_order_sorted,
47        })
48    }
49}
50
51#[derive(TaskPredictorBuilder)]
52#[builder(config = LayoutDetectionConfig)]
53pub struct LayoutDetectionPredictorBuilder {
54    state: PredictorBuilderState<LayoutDetectionConfig>,
55    model_name: Option<String>,
56}
57
58impl LayoutDetectionPredictorBuilder {
59    pub fn new() -> Self {
60        Self {
61            state: PredictorBuilderState::new(LayoutDetectionConfig::default()),
62            model_name: None,
63        }
64    }
65
66    /// Creates a builder with PP-StructureV3 default class thresholds.
67    pub fn with_pp_structurev3_thresholds() -> Self {
68        Self {
69            state: PredictorBuilderState::new(
70                LayoutDetectionConfig::with_pp_structurev3_thresholds(),
71            ),
72            model_name: None,
73        }
74    }
75
76    pub fn model_name(mut self, name: impl Into<String>) -> Self {
77        self.model_name = Some(name.into());
78        self
79    }
80
81    pub fn score_threshold(mut self, threshold: f32) -> Self {
82        self.state.config_mut().score_threshold = threshold;
83        self
84    }
85
86    pub fn build<P: AsRef<Path>>(self, model_path: P) -> OcrResult<LayoutDetectionPredictor> {
87        let (config, ort_config) = self.state.into_parts();
88        let mut adapter_builder = LayoutDetectionAdapterBuilder::new().task_config(config.clone());
89
90        // Set model configuration if model_name was provided
91        if let Some(model_name) = self.model_name {
92            let model_config = Self::get_model_config(&model_name)?;
93            adapter_builder = adapter_builder.model_config(model_config);
94        }
95
96        if let Some(ort_cfg) = ort_config {
97            adapter_builder = adapter_builder.with_ort_config(ort_cfg);
98        }
99
100        let adapter = Box::new(adapter_builder.build(model_path.as_ref())?);
101        let task = LayoutDetectionTask::new(config.clone());
102        Ok(LayoutDetectionPredictor {
103            core: TaskPredictorCore::new(adapter, task, config),
104        })
105    }
106
107    /// Supported layout model names
108    const SUPPORTED_MODELS: &'static [&'static str] = &[
109        "picodet_layout_1x",
110        "picodet_layout_1x_table",
111        "picodet_s_layout_3cls",
112        "picodet_l_layout_3cls",
113        "picodet_s_layout_17cls",
114        "picodet_l_layout_17cls",
115        "rtdetr_h_layout_3cls",
116        "rt_detr_h_layout_3cls",
117        "rtdetr_h_layout_17cls",
118        "rt_detr_h_layout_17cls",
119        "pp_docblocklayout",
120        "pp_doclayout_s",
121        "pp_doclayout_m",
122        "pp_doclayout_l",
123        "pp_doclayout_plus_l",
124        "pp_doclayoutv2",
125        "pp_doclayout_v2",
126        "pp_doclayoutv3",
127        "pp_doclayout_v3",
128    ];
129
130    fn get_model_config(model_name: &str) -> OcrResult<crate::domain::adapters::LayoutModelConfig> {
131        use crate::domain::adapters::LayoutModelConfig;
132
133        let normalized = model_name.to_lowercase().replace('-', "_");
134        let config = match normalized.as_str() {
135            "picodet_layout_1x" => LayoutModelConfig::picodet_layout_1x(),
136            "picodet_layout_1x_table" => LayoutModelConfig::picodet_layout_1x_table(),
137            "picodet_s_layout_3cls" => LayoutModelConfig::picodet_s_layout_3cls(),
138            "picodet_l_layout_3cls" => LayoutModelConfig::picodet_l_layout_3cls(),
139            "picodet_s_layout_17cls" => LayoutModelConfig::picodet_s_layout_17cls(),
140            "picodet_l_layout_17cls" => LayoutModelConfig::picodet_l_layout_17cls(),
141            "rtdetr_h_layout_3cls" | "rt_detr_h_layout_3cls" => {
142                LayoutModelConfig::rtdetr_h_layout_3cls()
143            }
144            "rtdetr_h_layout_17cls" | "rt_detr_h_layout_17cls" => {
145                LayoutModelConfig::rtdetr_h_layout_17cls()
146            }
147            "pp_docblocklayout" => LayoutModelConfig::pp_docblocklayout(),
148            "pp_doclayout_s" => LayoutModelConfig::pp_doclayout_s(),
149            "pp_doclayout_m" => LayoutModelConfig::pp_doclayout_m(),
150            "pp_doclayout_l" => LayoutModelConfig::pp_doclayout_l(),
151            "pp_doclayout_plus_l" => LayoutModelConfig::pp_doclayout_plus_l(),
152            "pp_doclayoutv2" | "pp_doclayout_v2" => LayoutModelConfig::pp_doclayoutv2(),
153            "pp_doclayoutv3" | "pp_doclayout_v3" => LayoutModelConfig::pp_doclayoutv3(),
154            _ => {
155                return Err(OCRError::ConfigError {
156                    message: format!(
157                        "Unknown model name: '{}'. Supported models: {}",
158                        model_name,
159                        Self::SUPPORTED_MODELS.join(", ")
160                    ),
161                });
162            }
163        };
164
165        Ok(config)
166    }
167}
168
169impl Default for LayoutDetectionPredictorBuilder {
170    fn default() -> Self {
171        Self::new()
172    }
173}