1use super::validation::ensure_non_empty_images;
6use crate::ConfigValidator;
7use crate::core::OCRError;
8use crate::core::traits::TaskDefinition;
9use crate::core::traits::task::{ImageTaskInput, Task, TaskSchema, TaskType};
10use crate::processors::BoundingBox;
11use crate::utils::{ScoreValidator, validate_max_value};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
/// Strategy used when overlapping boxes of one class are merged.
///
/// `Large` is the default (`#[default]`), and it is also what
/// `LayoutDetectionConfig::get_class_merge_mode` returns for classes without an
/// explicit entry. The merge itself is implemented elsewhere. The per-variant
/// notes below are inferred from the names — NOTE(review): confirm against the
/// merge implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum MergeBboxMode {
    /// Presumably keeps the larger of the overlapping boxes.
    #[default]
    Large,
    /// Presumably replaces the overlapping boxes with their union.
    Union,
    /// Presumably keeps the smaller of the overlapping boxes.
    Small,
}
26
/// Expansion ("unclip") ratio applied to detected layout boxes.
///
/// The default is `Separate(1.0, 1.0)` (see the `Default` impl).
/// NOTE(review): the meaning of each component is not visible here; the notes
/// below are assumptions to confirm against the post-processor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum UnclipRatio {
    /// One ratio for both axes.
    Uniform(f32),
    /// Two ratios — presumably `(width, height)`; TODO confirm.
    Separate(f32, f32),
    /// Per-class ratio pairs; presumably keyed by class id. TODO confirm.
    PerClass(HashMap<usize, (f32, f32)>),
}
38
39impl Default for UnclipRatio {
40 fn default() -> Self {
41 UnclipRatio::Separate(1.0, 1.0)
42 }
43}
44
/// Configuration for layout detection and its output validation.
///
/// `score_threshold` and `max_elements` have no serde default and must be present
/// when deserializing. Every other field falls back to a default when missing.
#[derive(Debug, Clone, Serialize, Deserialize, ConfigValidator)]
pub struct LayoutDetectionConfig {
    /// Global minimum score. Used for every class without an entry in
    /// `class_thresholds`. Validated to lie in `[0.0, 1.0]`.
    #[validate(range(0.0, 1.0))]
    pub score_threshold: f32,
    /// Maximum number of elements allowed per image. `LayoutDetectionTask`
    /// enforces this in `validate_output`. Validated to be at least 1.
    #[validate(min(1))]
    pub max_elements: usize,
    /// Optional per-class score thresholds keyed by class name. These override
    /// `score_threshold` in `get_class_threshold`.
    /// NOTE(review): the values are not range-validated like `score_threshold`.
    #[serde(default)]
    pub class_thresholds: Option<HashMap<String, f32>>,
    /// Optional per-class box-merge strategies keyed by class name. Classes
    /// without an entry use `MergeBboxMode::Large`.
    #[serde(default)]
    pub class_merge_modes: Option<HashMap<String, MergeBboxMode>>,
    /// Whether NMS is applied to layout boxes. Defaults to `true` when missing
    /// during deserialization.
    #[serde(default = "default_layout_nms")]
    pub layout_nms: bool,
    /// Overlap threshold used by NMS. Defaults to `0.5` when missing during
    /// deserialization.
    /// NOTE(review): unlike `score_threshold`, this is not range-validated.
    #[serde(default = "default_nms_threshold")]
    pub nms_threshold: f32,
    /// Optional box expansion ratio. `None` when missing during deserialization.
    #[serde(default)]
    pub layout_unclip_ratio: Option<UnclipRatio>,
}
77
/// Serde default for `LayoutDetectionConfig::layout_nms`: NMS is on.
fn default_layout_nms() -> bool {
    let nms_enabled = true;
    nms_enabled
}

/// Serde default for `LayoutDetectionConfig::nms_threshold`.
fn default_nms_threshold() -> f32 {
    let threshold: f32 = 0.5;
    threshold
}
85
86impl Default for LayoutDetectionConfig {
87 fn default() -> Self {
88 Self {
89 score_threshold: 0.5,
90 max_elements: 100,
91 class_thresholds: None,
92 class_merge_modes: None,
93 layout_nms: true,
94 nms_threshold: 0.5,
95 layout_unclip_ratio: None,
96 }
97 }
98}
99
100impl LayoutDetectionConfig {
101 pub fn with_pp_structurev3_thresholds() -> Self {
110 let mut class_thresholds = HashMap::new();
111 class_thresholds.insert("paragraph_title".to_string(), 0.3);
112 class_thresholds.insert("formula".to_string(), 0.3);
113 class_thresholds.insert("text".to_string(), 0.4);
114 class_thresholds.insert("seal".to_string(), 0.45);
115
116 Self {
117 score_threshold: 0.3,
120 max_elements: 100,
121 class_thresholds: Some(class_thresholds),
122 class_merge_modes: None,
123 layout_nms: true,
124 nms_threshold: 0.5,
125 layout_unclip_ratio: Some(UnclipRatio::Separate(1.0, 1.0)),
126 }
127 }
128
129 pub fn with_pp_doclayoutv2_defaults() -> Self {
139 let mut class_thresholds = HashMap::new();
140 class_thresholds.insert("abstract".to_string(), 0.5);
141 class_thresholds.insert("algorithm".to_string(), 0.5);
142 class_thresholds.insert("aside_text".to_string(), 0.5);
143 class_thresholds.insert("chart".to_string(), 0.5);
144 class_thresholds.insert("content".to_string(), 0.5);
145 class_thresholds.insert("display_formula".to_string(), 0.4);
146 class_thresholds.insert("doc_title".to_string(), 0.4);
147 class_thresholds.insert("figure_title".to_string(), 0.5);
148 class_thresholds.insert("footer".to_string(), 0.5);
149 class_thresholds.insert("footer_image".to_string(), 0.5);
150 class_thresholds.insert("footnote".to_string(), 0.5);
151 class_thresholds.insert("formula_number".to_string(), 0.5);
152 class_thresholds.insert("header".to_string(), 0.5);
153 class_thresholds.insert("header_image".to_string(), 0.5);
154 class_thresholds.insert("image".to_string(), 0.5);
155 class_thresholds.insert("inline_formula".to_string(), 0.4);
156 class_thresholds.insert("number".to_string(), 0.5);
157 class_thresholds.insert("paragraph_title".to_string(), 0.4);
158 class_thresholds.insert("reference".to_string(), 0.5);
159 class_thresholds.insert("reference_content".to_string(), 0.5);
160 class_thresholds.insert("seal".to_string(), 0.45);
161 class_thresholds.insert("table".to_string(), 0.5);
162 class_thresholds.insert("text".to_string(), 0.4);
163 class_thresholds.insert("vertical_text".to_string(), 0.4);
164 class_thresholds.insert("vision_footnote".to_string(), 0.5);
165
166 let mut merge_modes = HashMap::new();
167 merge_modes.insert("abstract".to_string(), MergeBboxMode::Union);
168 merge_modes.insert("algorithm".to_string(), MergeBboxMode::Union);
169 merge_modes.insert("aside_text".to_string(), MergeBboxMode::Union);
170 merge_modes.insert("chart".to_string(), MergeBboxMode::Large);
171 merge_modes.insert("content".to_string(), MergeBboxMode::Union);
172 merge_modes.insert("display_formula".to_string(), MergeBboxMode::Large);
173 merge_modes.insert("doc_title".to_string(), MergeBboxMode::Large);
174 merge_modes.insert("figure_title".to_string(), MergeBboxMode::Union);
175 merge_modes.insert("footer".to_string(), MergeBboxMode::Union);
176 merge_modes.insert("footer_image".to_string(), MergeBboxMode::Union);
177 merge_modes.insert("footnote".to_string(), MergeBboxMode::Union);
178 merge_modes.insert("formula_number".to_string(), MergeBboxMode::Union);
179 merge_modes.insert("header".to_string(), MergeBboxMode::Union);
180 merge_modes.insert("header_image".to_string(), MergeBboxMode::Union);
181 merge_modes.insert("image".to_string(), MergeBboxMode::Union);
182 merge_modes.insert("inline_formula".to_string(), MergeBboxMode::Large);
183 merge_modes.insert("number".to_string(), MergeBboxMode::Union);
184 merge_modes.insert("paragraph_title".to_string(), MergeBboxMode::Large);
185 merge_modes.insert("reference".to_string(), MergeBboxMode::Union);
186 merge_modes.insert("reference_content".to_string(), MergeBboxMode::Union);
187 merge_modes.insert("seal".to_string(), MergeBboxMode::Union);
188 merge_modes.insert("table".to_string(), MergeBboxMode::Union);
189 merge_modes.insert("text".to_string(), MergeBboxMode::Union);
190 merge_modes.insert("vertical_text".to_string(), MergeBboxMode::Union);
191 merge_modes.insert("vision_footnote".to_string(), MergeBboxMode::Union);
192
193 Self {
194 score_threshold: 0.4,
195 max_elements: 100,
196 class_thresholds: Some(class_thresholds),
197 class_merge_modes: Some(merge_modes),
198 layout_nms: true,
199 nms_threshold: 0.5,
200 layout_unclip_ratio: Some(UnclipRatio::Separate(1.0, 1.0)),
201 }
202 }
203
204 pub fn with_pp_structurev3_defaults() -> Self {
210 let mut cfg = Self::with_pp_structurev3_thresholds();
211
212 let mut merge_modes = HashMap::new();
213 merge_modes.insert("paragraph_title".to_string(), MergeBboxMode::Large);
214 merge_modes.insert("image".to_string(), MergeBboxMode::Large);
215 merge_modes.insert("text".to_string(), MergeBboxMode::Union);
216 merge_modes.insert("number".to_string(), MergeBboxMode::Union);
217 merge_modes.insert("abstract".to_string(), MergeBboxMode::Union);
218 merge_modes.insert("content".to_string(), MergeBboxMode::Union);
219 merge_modes.insert("figure_table_chart_title".to_string(), MergeBboxMode::Union);
220 merge_modes.insert("formula".to_string(), MergeBboxMode::Large);
221 merge_modes.insert("table".to_string(), MergeBboxMode::Union);
222 merge_modes.insert("reference".to_string(), MergeBboxMode::Union);
223 merge_modes.insert("doc_title".to_string(), MergeBboxMode::Union);
224 merge_modes.insert("footnote".to_string(), MergeBboxMode::Union);
225 merge_modes.insert("header".to_string(), MergeBboxMode::Union);
226 merge_modes.insert("algorithm".to_string(), MergeBboxMode::Union);
227 merge_modes.insert("footer".to_string(), MergeBboxMode::Union);
228 merge_modes.insert("seal".to_string(), MergeBboxMode::Union);
229 merge_modes.insert("chart".to_string(), MergeBboxMode::Large);
230 merge_modes.insert("formula_number".to_string(), MergeBboxMode::Union);
231 merge_modes.insert("aside_text".to_string(), MergeBboxMode::Union);
232 merge_modes.insert("reference_content".to_string(), MergeBboxMode::Union);
233
234 cfg.class_merge_modes = Some(merge_modes);
235 cfg.layout_unclip_ratio = Some(UnclipRatio::Separate(1.0, 1.0));
236 cfg
237 }
238
239 pub fn get_class_threshold(&self, class_name: &str) -> f32 {
243 self.class_thresholds
244 .as_ref()
245 .and_then(|thresholds| thresholds.get(class_name).copied())
246 .unwrap_or(self.score_threshold)
247 }
248
249 pub fn get_class_merge_mode(&self, class_name: &str) -> MergeBboxMode {
253 self.class_merge_modes
254 .as_ref()
255 .and_then(|modes| modes.get(class_name).copied())
256 .unwrap_or(MergeBboxMode::Large)
257 }
258}
259
/// A single detected layout region.
#[derive(Debug, Clone)]
pub struct LayoutDetectionElement {
    /// Region outline. The unit tests build it from four corner points.
    pub bbox: BoundingBox,
    /// Class label, e.g. `"text"`. Presumably the name used as the key in the
    /// config's per-class maps — confirm with the model adapter.
    pub element_type: String,
    /// Detection confidence. `LayoutDetectionTask::validate_output` requires it to
    /// lie in the unit range.
    pub score: f32,
}
273
/// Layout detections for a batch of images.
#[derive(Debug, Clone)]
pub struct LayoutDetectionOutput {
    /// Detected elements, one inner `Vec` per image. Validation reports the
    /// outer index as `Image {idx}`, presumably matching input order; confirm.
    pub elements: Vec<Vec<LayoutDetectionElement>>,
    /// Whether the elements are already sorted in reading order. Constructors
    /// default this to `false`.
    pub is_reading_order_sorted: bool,
}
285
286impl LayoutDetectionOutput {
287 pub fn empty() -> Self {
289 Self {
290 elements: Vec::new(),
291 is_reading_order_sorted: false,
292 }
293 }
294
295 pub fn with_capacity(capacity: usize) -> Self {
297 Self {
298 elements: Vec::with_capacity(capacity),
299 is_reading_order_sorted: false,
300 }
301 }
302
303 pub fn with_reading_order_sorted(mut self, sorted: bool) -> Self {
305 self.is_reading_order_sorted = sorted;
306 self
307 }
308}
309
310impl TaskDefinition for LayoutDetectionOutput {
311 const TASK_NAME: &'static str = "layout_detection";
312 const TASK_DOC: &'static str = "Layout detection/analysis";
313
314 fn empty() -> Self {
315 LayoutDetectionOutput::empty()
316 }
317}
318
/// Task wrapper that validates layout-detection inputs and outputs.
#[derive(Debug, Default)]
pub struct LayoutDetectionTask {
    /// Config whose `max_elements` bounds each image's element count in
    /// `validate_output`.
    config: LayoutDetectionConfig,
}
324
325impl LayoutDetectionTask {
326 pub fn new(config: LayoutDetectionConfig) -> Self {
328 Self { config }
329 }
330}
331
332impl Task for LayoutDetectionTask {
333 type Config = LayoutDetectionConfig;
334 type Input = ImageTaskInput;
335 type Output = LayoutDetectionOutput;
336
337 fn task_type(&self) -> TaskType {
338 TaskType::LayoutDetection
339 }
340
341 fn schema(&self) -> TaskSchema {
342 TaskSchema::new(
343 TaskType::LayoutDetection,
344 vec!["image".to_string()],
345 vec!["layout_elements".to_string()],
346 )
347 }
348
349 fn validate_input(&self, input: &Self::Input) -> Result<(), OCRError> {
350 ensure_non_empty_images(&input.images, "No images provided for layout detection")?;
351
352 Ok(())
353 }
354
355 fn validate_output(&self, output: &Self::Output) -> Result<(), OCRError> {
356 let validator = ScoreValidator::new_unit_range("score");
357
358 for (idx, elements) in output.elements.iter().enumerate() {
359 validate_max_value(
361 elements.len(),
362 self.config.max_elements,
363 "element count",
364 &format!("Image {}", idx),
365 )?;
366
367 let scores: Vec<f32> = elements.iter().map(|e| e.score).collect();
369 validator.validate_scores_with(&scores, |elem_idx| {
370 format!("Image {}, element {}", idx, elem_idx)
371 })?;
372 }
373
374 Ok(())
375 }
376
377 fn empty_output(&self) -> Self::Output {
378 LayoutDetectionOutput::empty()
379 }
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use crate::processors::Point;
    use image::RgbImage;

    /// Builds a 10x10 axis-aligned element with the given class label and score.
    fn square_element(element_type: &str, score: f32) -> LayoutDetectionElement {
        LayoutDetectionElement {
            bbox: BoundingBox::new(vec![
                Point::new(0.0, 0.0),
                Point::new(10.0, 0.0),
                Point::new(10.0, 10.0),
                Point::new(0.0, 10.0),
            ]),
            element_type: element_type.to_string(),
            score,
        }
    }

    /// Wraps the given per-image element lists in an unsorted output.
    fn output_of(elements: Vec<Vec<LayoutDetectionElement>>) -> LayoutDetectionOutput {
        LayoutDetectionOutput {
            elements,
            is_reading_order_sorted: false,
        }
    }

    #[test]
    fn test_layout_detection_task_creation() {
        let task = LayoutDetectionTask::default();
        assert_eq!(task.task_type(), TaskType::LayoutDetection);
    }

    #[test]
    fn test_input_validation() {
        let task = LayoutDetectionTask::default();

        let empty_input = ImageTaskInput::new(vec![]);
        assert!(task.validate_input(&empty_input).is_err());

        let valid_input = ImageTaskInput::new(vec![RgbImage::new(100, 100)]);
        assert!(task.validate_input(&valid_input).is_ok());
    }

    #[test]
    fn test_output_validation() {
        let task = LayoutDetectionTask::default();
        let output = output_of(vec![vec![square_element("text", 0.95)]]);
        assert!(task.validate_output(&output).is_ok());
    }

    #[test]
    fn test_empty_output_is_valid() {
        let task = LayoutDetectionTask::default();
        assert!(task.validate_output(&task.empty_output()).is_ok());
    }

    #[test]
    fn test_output_validation_rejects_out_of_range_score() {
        let task = LayoutDetectionTask::default();
        let output = output_of(vec![vec![square_element("text", 1.5)]]);
        assert!(task.validate_output(&output).is_err());
    }

    #[test]
    fn test_output_validation_rejects_too_many_elements() {
        let task = LayoutDetectionTask::new(LayoutDetectionConfig {
            max_elements: 1,
            ..Default::default()
        });
        let output = output_of(vec![vec![
            square_element("text", 0.9),
            square_element("table", 0.8),
        ]]);
        assert!(task.validate_output(&output).is_err());
    }

    #[test]
    fn test_schema() {
        let task = LayoutDetectionTask::default();
        let schema = task.schema();
        assert_eq!(schema.task_type, TaskType::LayoutDetection);
        assert!(schema.input_types.contains(&"image".to_string()));
        assert!(schema.output_types.contains(&"layout_elements".to_string()));
    }

    #[test]
    fn test_class_threshold_falls_back_to_global() {
        let cfg = LayoutDetectionConfig::with_pp_structurev3_thresholds();
        assert_eq!(cfg.get_class_threshold("seal"), 0.45);
        assert_eq!(cfg.get_class_threshold("table"), cfg.score_threshold);
        assert_eq!(
            LayoutDetectionConfig::default().get_class_threshold("text"),
            0.5
        );
    }

    #[test]
    fn test_class_merge_mode_falls_back_to_default() {
        let cfg = LayoutDetectionConfig::with_pp_structurev3_defaults();
        assert_eq!(cfg.get_class_merge_mode("text"), MergeBboxMode::Union);
        assert_eq!(cfg.get_class_merge_mode("formula"), MergeBboxMode::Large);
        assert_eq!(
            cfg.get_class_merge_mode("not_a_class"),
            MergeBboxMode::default()
        );
        assert_eq!(
            LayoutDetectionConfig::default().get_class_merge_mode("text"),
            MergeBboxMode::Large
        );
    }

    #[test]
    fn test_doclayoutv2_maps_cover_same_classes() {
        let cfg = LayoutDetectionConfig::with_pp_doclayoutv2_defaults();
        let thresholds = cfg
            .class_thresholds
            .as_ref()
            .expect("doclayoutv2 preset sets class thresholds");
        let modes = cfg
            .class_merge_modes
            .as_ref()
            .expect("doclayoutv2 preset sets merge modes");
        assert_eq!(thresholds.len(), modes.len());
        assert!(thresholds.keys().all(|class| modes.contains_key(class)));
    }
}