Skip to main content

oar_ocr_core/domain/tasks/
layout_detection.rs

1//! Concrete task implementations for layout detection.
2//!
3//! This module provides the layout detection task that identifies document layout elements.
4
5use super::validation::ensure_non_empty_images;
6use crate::ConfigValidator;
7use crate::core::OCRError;
8use crate::core::traits::TaskDefinition;
9use crate::core::traits::task::{ImageTaskInput, Task, TaskSchema, TaskType};
10use crate::processors::BoundingBox;
11use crate::utils::{ScoreValidator, validate_max_value};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
15/// Bounding box merge mode for layout elements.
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
17pub enum MergeBboxMode {
18    /// Keep the larger bounding box
19    #[default]
20    Large,
21    /// Merge to union of bounding boxes
22    Union,
23    /// Keep the smaller bounding box
24    Small,
25}
26
27/// Unclip ratio configuration for layout detection.
28/// Controls how bounding boxes are scaled while keeping the center fixed.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub enum UnclipRatio {
31    /// Single ratio applied to both width and height
32    Uniform(f32),
33    /// Separate ratios for (width, height)
34    Separate(f32, f32),
35    /// Per-class ratios: class_id -> (width_ratio, height_ratio)
36    PerClass(HashMap<usize, (f32, f32)>),
37}
38
39impl Default for UnclipRatio {
40    fn default() -> Self {
41        UnclipRatio::Separate(1.0, 1.0)
42    }
43}
44
45/// Configuration for layout detection task.
46#[derive(Debug, Clone, Serialize, Deserialize, ConfigValidator)]
47pub struct LayoutDetectionConfig {
48    /// Default score threshold for detection (default: 0.5)
49    #[validate(range(min = 0.0, max = 1.0))]
50    pub score_threshold: f32,
51    /// Maximum number of layout elements (default: 100)
52    #[validate(min = 1)]
53    pub max_elements: usize,
54    /// Per-class score thresholds (overrides score_threshold for specific classes)
55    /// PP-StructureV3 defaults:
56    /// - paragraph_title: 0.3
57    /// - formula: 0.3
58    /// - text: 0.4
59    /// - seal: 0.45
60    /// - others: 0.5
61    #[serde(default)]
62    pub class_thresholds: Option<HashMap<String, f32>>,
63    /// Per-class bounding box merge modes
64    #[serde(default)]
65    pub class_merge_modes: Option<HashMap<String, MergeBboxMode>>,
66    /// Enable NMS for layout detection (default: true)
67    #[serde(default = "default_layout_nms")]
68    pub layout_nms: bool,
69    /// NMS threshold (default: 0.5)
70    #[serde(default = "default_nms_threshold")]
71    pub nms_threshold: f32,
72    /// Unclip ratio for expanding/shrinking bounding boxes (PP-StructureV3)
73    /// Default: [1.0, 1.0] (no change)
74    #[serde(default)]
75    pub layout_unclip_ratio: Option<UnclipRatio>,
76}
77
/// Serde fallback for `LayoutDetectionConfig::layout_nms`: NMS is on unless
/// the serialized config says otherwise.
fn default_layout_nms() -> bool {
    true
}
81
/// Serde fallback for `LayoutDetectionConfig::nms_threshold` when the field
/// is missing from a serialized config.
fn default_nms_threshold() -> f32 {
    0.5
}
85
86impl Default for LayoutDetectionConfig {
87    fn default() -> Self {
88        Self {
89            score_threshold: 0.5,
90            max_elements: 100,
91            class_thresholds: None,
92            class_merge_modes: None,
93            layout_nms: true,
94            nms_threshold: 0.5,
95            layout_unclip_ratio: None,
96        }
97    }
98}
99
100impl LayoutDetectionConfig {
101    /// Creates a config with PP-StructureV3 default class thresholds.
102    ///
103    /// PP-StructureV3 uses different thresholds for different element types:
104    /// - paragraph_title: 0.3
105    /// - formula: 0.3
106    /// - text: 0.4
107    /// - seal: 0.45
108    /// - others: 0.5 (default)
109    pub fn with_pp_structurev3_thresholds() -> Self {
110        let mut class_thresholds = HashMap::new();
111        class_thresholds.insert("paragraph_title".to_string(), 0.3);
112        class_thresholds.insert("formula".to_string(), 0.3);
113        class_thresholds.insert("text".to_string(), 0.4);
114        class_thresholds.insert("seal".to_string(), 0.45);
115
116        Self {
117            // Use the minimum of the per-class thresholds so the postprocessor doesn't drop
118            // candidates before the per-class filter runs.
119            score_threshold: 0.3,
120            max_elements: 100,
121            class_thresholds: Some(class_thresholds),
122            class_merge_modes: None,
123            layout_nms: true,
124            nms_threshold: 0.5,
125            layout_unclip_ratio: Some(UnclipRatio::Separate(1.0, 1.0)),
126        }
127    }
128
129    /// Creates a config with PP-DocLayoutV2 default thresholds and merge modes.
130    ///
131    /// These defaults are aligned with OpenOCR/OpenDoc's pipeline config:
132    /// `OpenOCR/configs/rec/unirec/opendoc_pipeline.yml`.
133    ///
134    /// Notes:
135    /// - The postprocessor applies `score_threshold` before per-class thresholds, so we set it
136    ///   to the minimum per-class threshold (0.4) to avoid dropping valid candidates early.
137    /// - `class_merge_modes` is populated for all labels so merge behavior is deterministic.
138    pub fn with_pp_doclayoutv2_defaults() -> Self {
139        let mut class_thresholds = HashMap::new();
140        class_thresholds.insert("abstract".to_string(), 0.5);
141        class_thresholds.insert("algorithm".to_string(), 0.5);
142        class_thresholds.insert("aside_text".to_string(), 0.5);
143        class_thresholds.insert("chart".to_string(), 0.5);
144        class_thresholds.insert("content".to_string(), 0.5);
145        class_thresholds.insert("display_formula".to_string(), 0.4);
146        class_thresholds.insert("doc_title".to_string(), 0.4);
147        class_thresholds.insert("figure_title".to_string(), 0.5);
148        class_thresholds.insert("footer".to_string(), 0.5);
149        class_thresholds.insert("footer_image".to_string(), 0.5);
150        class_thresholds.insert("footnote".to_string(), 0.5);
151        class_thresholds.insert("formula_number".to_string(), 0.5);
152        class_thresholds.insert("header".to_string(), 0.5);
153        class_thresholds.insert("header_image".to_string(), 0.5);
154        class_thresholds.insert("image".to_string(), 0.5);
155        class_thresholds.insert("inline_formula".to_string(), 0.4);
156        class_thresholds.insert("number".to_string(), 0.5);
157        class_thresholds.insert("paragraph_title".to_string(), 0.4);
158        class_thresholds.insert("reference".to_string(), 0.5);
159        class_thresholds.insert("reference_content".to_string(), 0.5);
160        class_thresholds.insert("seal".to_string(), 0.45);
161        class_thresholds.insert("table".to_string(), 0.5);
162        class_thresholds.insert("text".to_string(), 0.4);
163        class_thresholds.insert("vertical_text".to_string(), 0.4);
164        class_thresholds.insert("vision_footnote".to_string(), 0.5);
165
166        let mut merge_modes = HashMap::new();
167        merge_modes.insert("abstract".to_string(), MergeBboxMode::Union);
168        merge_modes.insert("algorithm".to_string(), MergeBboxMode::Union);
169        merge_modes.insert("aside_text".to_string(), MergeBboxMode::Union);
170        merge_modes.insert("chart".to_string(), MergeBboxMode::Large);
171        merge_modes.insert("content".to_string(), MergeBboxMode::Union);
172        merge_modes.insert("display_formula".to_string(), MergeBboxMode::Large);
173        merge_modes.insert("doc_title".to_string(), MergeBboxMode::Large);
174        merge_modes.insert("figure_title".to_string(), MergeBboxMode::Union);
175        merge_modes.insert("footer".to_string(), MergeBboxMode::Union);
176        merge_modes.insert("footer_image".to_string(), MergeBboxMode::Union);
177        merge_modes.insert("footnote".to_string(), MergeBboxMode::Union);
178        merge_modes.insert("formula_number".to_string(), MergeBboxMode::Union);
179        merge_modes.insert("header".to_string(), MergeBboxMode::Union);
180        merge_modes.insert("header_image".to_string(), MergeBboxMode::Union);
181        merge_modes.insert("image".to_string(), MergeBboxMode::Union);
182        merge_modes.insert("inline_formula".to_string(), MergeBboxMode::Large);
183        merge_modes.insert("number".to_string(), MergeBboxMode::Union);
184        merge_modes.insert("paragraph_title".to_string(), MergeBboxMode::Large);
185        merge_modes.insert("reference".to_string(), MergeBboxMode::Union);
186        merge_modes.insert("reference_content".to_string(), MergeBboxMode::Union);
187        merge_modes.insert("seal".to_string(), MergeBboxMode::Union);
188        merge_modes.insert("table".to_string(), MergeBboxMode::Union);
189        merge_modes.insert("text".to_string(), MergeBboxMode::Union);
190        merge_modes.insert("vertical_text".to_string(), MergeBboxMode::Union);
191        merge_modes.insert("vision_footnote".to_string(), MergeBboxMode::Union);
192
193        Self {
194            score_threshold: 0.4,
195            max_elements: 100,
196            class_thresholds: Some(class_thresholds),
197            class_merge_modes: Some(merge_modes),
198            layout_nms: true,
199            nms_threshold: 0.5,
200            layout_unclip_ratio: Some(UnclipRatio::Separate(1.0, 1.0)),
201        }
202    }
203
204    /// Creates a config with PP-DocLayoutV3 defaults (0.3 threshold + PP-DocLayout merge modes).
205    ///
206    /// PP-DocLayoutV3 keeps the same label set as PP-DocLayoutV2 but uses a lower
207    /// global threshold in the reference pipeline.
208    pub fn with_pp_doclayoutv3_defaults() -> Self {
209        let mut merge_modes = HashMap::new();
210        merge_modes.insert("abstract".to_string(), MergeBboxMode::Union);
211        merge_modes.insert("algorithm".to_string(), MergeBboxMode::Union);
212        merge_modes.insert("aside_text".to_string(), MergeBboxMode::Union);
213        merge_modes.insert("chart".to_string(), MergeBboxMode::Large);
214        merge_modes.insert("content".to_string(), MergeBboxMode::Union);
215        merge_modes.insert("display_formula".to_string(), MergeBboxMode::Large);
216        merge_modes.insert("doc_title".to_string(), MergeBboxMode::Large);
217        merge_modes.insert("figure_title".to_string(), MergeBboxMode::Union);
218        merge_modes.insert("footer".to_string(), MergeBboxMode::Union);
219        merge_modes.insert("footer_image".to_string(), MergeBboxMode::Union);
220        merge_modes.insert("footnote".to_string(), MergeBboxMode::Union);
221        merge_modes.insert("formula_number".to_string(), MergeBboxMode::Union);
222        merge_modes.insert("header".to_string(), MergeBboxMode::Union);
223        merge_modes.insert("header_image".to_string(), MergeBboxMode::Union);
224        merge_modes.insert("image".to_string(), MergeBboxMode::Union);
225        merge_modes.insert("inline_formula".to_string(), MergeBboxMode::Large);
226        merge_modes.insert("number".to_string(), MergeBboxMode::Union);
227        merge_modes.insert("paragraph_title".to_string(), MergeBboxMode::Large);
228        merge_modes.insert("reference".to_string(), MergeBboxMode::Union);
229        merge_modes.insert("reference_content".to_string(), MergeBboxMode::Union);
230        merge_modes.insert("seal".to_string(), MergeBboxMode::Union);
231        merge_modes.insert("table".to_string(), MergeBboxMode::Union);
232        merge_modes.insert("text".to_string(), MergeBboxMode::Union);
233        merge_modes.insert("vertical_text".to_string(), MergeBboxMode::Union);
234        merge_modes.insert("vision_footnote".to_string(), MergeBboxMode::Union);
235
236        Self {
237            score_threshold: 0.3,
238            max_elements: 100,
239            class_thresholds: None,
240            class_merge_modes: Some(merge_modes),
241            layout_nms: true,
242            nms_threshold: 0.5,
243            layout_unclip_ratio: Some(UnclipRatio::Separate(1.0, 1.0)),
244        }
245    }
246
247    /// Creates a config with PP-StructureV3 default thresholds, merge modes, and unclip ratio.
248    ///
249    /// Merge modes follow standard configuration:
250    /// - "large": paragraph_title, image, formula, chart
251    /// - "union": all other PP-DocLayout_plus-L classes
252    pub fn with_pp_structurev3_defaults() -> Self {
253        let mut cfg = Self::with_pp_structurev3_thresholds();
254
255        let mut merge_modes = HashMap::new();
256        merge_modes.insert("paragraph_title".to_string(), MergeBboxMode::Large);
257        merge_modes.insert("image".to_string(), MergeBboxMode::Large);
258        merge_modes.insert("text".to_string(), MergeBboxMode::Union);
259        merge_modes.insert("number".to_string(), MergeBboxMode::Union);
260        merge_modes.insert("abstract".to_string(), MergeBboxMode::Union);
261        merge_modes.insert("content".to_string(), MergeBboxMode::Union);
262        merge_modes.insert("figure_table_chart_title".to_string(), MergeBboxMode::Union);
263        merge_modes.insert("formula".to_string(), MergeBboxMode::Large);
264        merge_modes.insert("table".to_string(), MergeBboxMode::Union);
265        merge_modes.insert("reference".to_string(), MergeBboxMode::Union);
266        merge_modes.insert("doc_title".to_string(), MergeBboxMode::Union);
267        merge_modes.insert("footnote".to_string(), MergeBboxMode::Union);
268        merge_modes.insert("header".to_string(), MergeBboxMode::Union);
269        merge_modes.insert("algorithm".to_string(), MergeBboxMode::Union);
270        merge_modes.insert("footer".to_string(), MergeBboxMode::Union);
271        merge_modes.insert("seal".to_string(), MergeBboxMode::Union);
272        merge_modes.insert("chart".to_string(), MergeBboxMode::Large);
273        merge_modes.insert("formula_number".to_string(), MergeBboxMode::Union);
274        merge_modes.insert("aside_text".to_string(), MergeBboxMode::Union);
275        merge_modes.insert("reference_content".to_string(), MergeBboxMode::Union);
276
277        cfg.class_merge_modes = Some(merge_modes);
278        cfg.layout_unclip_ratio = Some(UnclipRatio::Separate(1.0, 1.0));
279        cfg
280    }
281
282    /// Gets the threshold for a specific class.
283    ///
284    /// Returns the class-specific threshold if configured, otherwise the default threshold.
285    pub fn get_class_threshold(&self, class_name: &str) -> f32 {
286        self.class_thresholds
287            .as_ref()
288            .and_then(|thresholds| thresholds.get(class_name).copied())
289            .unwrap_or(self.score_threshold)
290    }
291
292    /// Gets the merge mode for a specific class.
293    ///
294    /// Returns the class-specific merge mode if configured, otherwise Large (default).
295    pub fn get_class_merge_mode(&self, class_name: &str) -> MergeBboxMode {
296        self.class_merge_modes
297            .as_ref()
298            .and_then(|modes| modes.get(class_name).copied())
299            .unwrap_or(MergeBboxMode::Large)
300    }
301}
302
303/// A detected layout element from the layout detection model.
304///
305/// This represents the raw output from layout detection before conversion
306/// to the final `LayoutElement` in `domain::structure`.
307#[derive(Debug, Clone)]
308pub struct LayoutDetectionElement {
309    /// Bounding box of the element
310    pub bbox: BoundingBox,
311    /// Type of layout element (raw string label from model)
312    pub element_type: String,
313    /// Confidence score (0.0 to 1.0)
314    pub score: f32,
315}
316
317/// Output from layout detection task.
318#[derive(Debug, Clone)]
319pub struct LayoutDetectionOutput {
320    /// Detected layout elements per image
321    pub elements: Vec<Vec<LayoutDetectionElement>>,
322    /// Whether elements are already sorted by reading order (e.g., from PP-DocLayoutV2)
323    ///
324    /// When `true`, downstream consumers can skip reading order sorting algorithms
325    /// as the elements are already in the correct reading order based on model output.
326    pub is_reading_order_sorted: bool,
327}
328
329impl LayoutDetectionOutput {
330    /// Creates an empty layout detection output.
331    pub fn empty() -> Self {
332        Self {
333            elements: Vec::new(),
334            is_reading_order_sorted: false,
335        }
336    }
337
338    /// Creates a layout detection output with the given capacity.
339    pub fn with_capacity(capacity: usize) -> Self {
340        Self {
341            elements: Vec::with_capacity(capacity),
342            is_reading_order_sorted: false,
343        }
344    }
345
346    /// Sets the reading order sorted flag.
347    pub fn with_reading_order_sorted(mut self, sorted: bool) -> Self {
348        self.is_reading_order_sorted = sorted;
349        self
350    }
351}
352
353impl TaskDefinition for LayoutDetectionOutput {
354    const TASK_NAME: &'static str = "layout_detection";
355    const TASK_DOC: &'static str = "Layout detection/analysis";
356
357    fn empty() -> Self {
358        LayoutDetectionOutput::empty()
359    }
360}
361
362/// Layout detection task implementation.
363#[derive(Debug, Default)]
364pub struct LayoutDetectionTask {
365    config: LayoutDetectionConfig,
366}
367
368impl LayoutDetectionTask {
369    /// Creates a new layout detection task.
370    pub fn new(config: LayoutDetectionConfig) -> Self {
371        Self { config }
372    }
373}
374
375impl Task for LayoutDetectionTask {
376    type Config = LayoutDetectionConfig;
377    type Input = ImageTaskInput;
378    type Output = LayoutDetectionOutput;
379
380    fn task_type(&self) -> TaskType {
381        TaskType::LayoutDetection
382    }
383
384    fn schema(&self) -> TaskSchema {
385        TaskSchema::new(
386            TaskType::LayoutDetection,
387            vec!["image".to_string()],
388            vec!["layout_elements".to_string()],
389        )
390    }
391
392    fn validate_input(&self, input: &Self::Input) -> Result<(), OCRError> {
393        ensure_non_empty_images(&input.images, "No images provided for layout detection")?;
394
395        Ok(())
396    }
397
398    fn validate_output(&self, output: &Self::Output) -> Result<(), OCRError> {
399        let validator = ScoreValidator::new_unit_range("score");
400
401        for (idx, elements) in output.elements.iter().enumerate() {
402            // Validate element count
403            validate_max_value(
404                elements.len(),
405                self.config.max_elements,
406                "element count",
407                &format!("Image {}", idx),
408            )?;
409
410            // Validate scores
411            let scores: Vec<f32> = elements.iter().map(|e| e.score).collect();
412            validator.validate_scores_with(&scores, |elem_idx| {
413                format!("Image {}, element {}", idx, elem_idx)
414            })?;
415        }
416
417        Ok(())
418    }
419
420    fn empty_output(&self) -> Self::Output {
421        LayoutDetectionOutput::empty()
422    }
423}
424
#[cfg(test)]
mod tests {
    use super::*;
    use crate::processors::Point;
    use image::RgbImage;

    /// A default-constructed task reports the layout detection task type.
    #[test]
    fn test_layout_detection_task_creation() {
        let detection_task = LayoutDetectionTask::default();
        assert_eq!(detection_task.task_type(), TaskType::LayoutDetection);
    }

    /// Input validation rejects an empty batch and accepts a single image.
    #[test]
    fn test_input_validation() {
        let detection_task = LayoutDetectionTask::default();

        // No images at all must be rejected.
        let no_images = ImageTaskInput::new(vec![]);
        assert!(detection_task.validate_input(&no_images).is_err());

        // A single blank 100x100 image is enough to pass.
        let one_image = ImageTaskInput::new(vec![RgbImage::new(100, 100)]);
        assert!(detection_task.validate_input(&one_image).is_ok());
    }

    /// A single well-formed element with an in-range score validates.
    #[test]
    fn test_output_validation() {
        let detection_task = LayoutDetectionTask::default();

        // Axis-aligned 10x10 square anchored at the origin.
        let corners = vec![
            Point::new(0.0, 0.0),
            Point::new(10.0, 0.0),
            Point::new(10.0, 10.0),
            Point::new(0.0, 10.0),
        ];
        let text_element = LayoutDetectionElement {
            bbox: BoundingBox::new(corners),
            element_type: "text".to_string(),
            score: 0.95,
        };
        let detection_output = LayoutDetectionOutput {
            elements: vec![vec![text_element]],
            is_reading_order_sorted: false,
        };

        assert!(detection_task.validate_output(&detection_output).is_ok());
    }

    /// The schema advertises the expected task type and input/output kinds.
    #[test]
    fn test_schema() {
        let task_schema = LayoutDetectionTask::default().schema();

        assert_eq!(task_schema.task_type, TaskType::LayoutDetection);
        assert!(task_schema.input_types.contains(&"image".to_string()));
        assert!(
            task_schema
                .output_types
                .contains(&"layout_elements".to_string())
        );
    }
}
481}