// oar_ocr_core/predictors/layout_detection.rs
use super::builder::PredictorBuilderState;
6use crate::TaskPredictorBuilder;
7use crate::core::OcrResult;
8use crate::core::errors::OCRError;
9use crate::core::traits::OrtConfigurable;
10use crate::core::traits::adapter::AdapterBuilder;
11use crate::core::traits::task::ImageTaskInput;
12use crate::domain::adapters::LayoutDetectionAdapterBuilder;
13use crate::domain::tasks::layout_detection::{LayoutDetectionConfig, LayoutDetectionTask};
14use crate::predictors::TaskPredictorCore;
15use image::RgbImage;
16use std::path::Path;
17
18#[derive(Debug, Clone)]
20pub struct LayoutDetectionResult {
21 pub elements: Vec<Vec<crate::domain::tasks::layout_detection::LayoutDetectionElement>>,
23 pub is_reading_order_sorted: bool,
28}
29
30pub struct LayoutDetectionPredictor {
32 core: TaskPredictorCore<LayoutDetectionTask>,
33}
34
35impl LayoutDetectionPredictor {
36 pub fn builder() -> LayoutDetectionPredictorBuilder {
37 LayoutDetectionPredictorBuilder::new()
38 }
39
40 pub fn predict(&self, images: Vec<RgbImage>) -> OcrResult<LayoutDetectionResult> {
42 let input = ImageTaskInput::new(images);
43 let output = self.core.predict(input)?;
44 Ok(LayoutDetectionResult {
45 elements: output.elements,
46 is_reading_order_sorted: output.is_reading_order_sorted,
47 })
48 }
49}
50
51#[derive(TaskPredictorBuilder)]
52#[builder(config = LayoutDetectionConfig)]
53pub struct LayoutDetectionPredictorBuilder {
54 state: PredictorBuilderState<LayoutDetectionConfig>,
55 model_name: Option<String>,
56}
57
58impl LayoutDetectionPredictorBuilder {
59 pub fn new() -> Self {
60 Self {
61 state: PredictorBuilderState::new(LayoutDetectionConfig::default()),
62 model_name: None,
63 }
64 }
65
66 pub fn with_pp_structurev3_thresholds() -> Self {
68 Self {
69 state: PredictorBuilderState::new(
70 LayoutDetectionConfig::with_pp_structurev3_thresholds(),
71 ),
72 model_name: None,
73 }
74 }
75
76 pub fn model_name(mut self, name: impl Into<String>) -> Self {
77 self.model_name = Some(name.into());
78 self
79 }
80
81 pub fn score_threshold(mut self, threshold: f32) -> Self {
82 self.state.config_mut().score_threshold = threshold;
83 self
84 }
85
86 pub fn build<P: AsRef<Path>>(self, model_path: P) -> OcrResult<LayoutDetectionPredictor> {
87 let (config, ort_config) = self.state.into_parts();
88 let mut adapter_builder = LayoutDetectionAdapterBuilder::new().task_config(config.clone());
89
90 if let Some(model_name) = self.model_name {
92 let model_config = Self::get_model_config(&model_name)?;
93 adapter_builder = adapter_builder.model_config(model_config);
94 }
95
96 if let Some(ort_cfg) = ort_config {
97 adapter_builder = adapter_builder.with_ort_config(ort_cfg);
98 }
99
100 let adapter = Box::new(adapter_builder.build(model_path.as_ref())?);
101 let task = LayoutDetectionTask::new(config.clone());
102 Ok(LayoutDetectionPredictor {
103 core: TaskPredictorCore::new(adapter, task, config),
104 })
105 }
106
107 const SUPPORTED_MODELS: &'static [&'static str] = &[
109 "picodet_layout_1x",
110 "picodet_layout_1x_table",
111 "picodet_s_layout_3cls",
112 "picodet_l_layout_3cls",
113 "picodet_s_layout_17cls",
114 "picodet_l_layout_17cls",
115 "rtdetr_h_layout_3cls",
116 "rt_detr_h_layout_3cls",
117 "rtdetr_h_layout_17cls",
118 "rt_detr_h_layout_17cls",
119 "pp_docblocklayout",
120 "pp_doclayout_s",
121 "pp_doclayout_m",
122 "pp_doclayout_l",
123 "pp_doclayout_plus_l",
124 "pp_doclayoutv2",
125 "pp_doclayout_v2",
126 "pp_doclayoutv3",
127 "pp_doclayout_v3",
128 ];
129
130 fn get_model_config(model_name: &str) -> OcrResult<crate::domain::adapters::LayoutModelConfig> {
131 use crate::domain::adapters::LayoutModelConfig;
132
133 let normalized = model_name.to_lowercase().replace('-', "_");
134 let config = match normalized.as_str() {
135 "picodet_layout_1x" => LayoutModelConfig::picodet_layout_1x(),
136 "picodet_layout_1x_table" => LayoutModelConfig::picodet_layout_1x_table(),
137 "picodet_s_layout_3cls" => LayoutModelConfig::picodet_s_layout_3cls(),
138 "picodet_l_layout_3cls" => LayoutModelConfig::picodet_l_layout_3cls(),
139 "picodet_s_layout_17cls" => LayoutModelConfig::picodet_s_layout_17cls(),
140 "picodet_l_layout_17cls" => LayoutModelConfig::picodet_l_layout_17cls(),
141 "rtdetr_h_layout_3cls" | "rt_detr_h_layout_3cls" => {
142 LayoutModelConfig::rtdetr_h_layout_3cls()
143 }
144 "rtdetr_h_layout_17cls" | "rt_detr_h_layout_17cls" => {
145 LayoutModelConfig::rtdetr_h_layout_17cls()
146 }
147 "pp_docblocklayout" => LayoutModelConfig::pp_docblocklayout(),
148 "pp_doclayout_s" => LayoutModelConfig::pp_doclayout_s(),
149 "pp_doclayout_m" => LayoutModelConfig::pp_doclayout_m(),
150 "pp_doclayout_l" => LayoutModelConfig::pp_doclayout_l(),
151 "pp_doclayout_plus_l" => LayoutModelConfig::pp_doclayout_plus_l(),
152 "pp_doclayoutv2" | "pp_doclayout_v2" => LayoutModelConfig::pp_doclayoutv2(),
153 "pp_doclayoutv3" | "pp_doclayout_v3" => LayoutModelConfig::pp_doclayoutv3(),
154 _ => {
155 return Err(OCRError::ConfigError {
156 message: format!(
157 "Unknown model name: '{}'. Supported models: {}",
158 model_name,
159 Self::SUPPORTED_MODELS.join(", ")
160 ),
161 });
162 }
163 };
164
165 Ok(config)
166 }
167}
168
169impl Default for LayoutDetectionPredictorBuilder {
170 fn default() -> Self {
171 Self::new()
172 }
173}