oar_ocr_core/domain/tasks/
layout_detection.rs

//! Concrete task implementations for layout detection.
//!
//! This module provides the layout detection task, which identifies document layout elements.
4
5use super::validation::ensure_non_empty_images;
6use crate::ConfigValidator;
7use crate::core::OCRError;
8use crate::core::traits::TaskDefinition;
9use crate::core::traits::task::{ImageTaskInput, Task, TaskSchema, TaskType};
10use crate::processors::BoundingBox;
11use crate::utils::{ScoreValidator, validate_max_value};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
15/// Bounding box merge mode for layout elements.
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
17pub enum MergeBboxMode {
18    /// Keep the larger bounding box
19    #[default]
20    Large,
21    /// Merge to union of bounding boxes
22    Union,
23    /// Keep the smaller bounding box
24    Small,
25}
26
27/// Unclip ratio configuration for layout detection.
28/// Controls how bounding boxes are scaled while keeping the center fixed.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub enum UnclipRatio {
31    /// Single ratio applied to both width and height
32    Uniform(f32),
33    /// Separate ratios for (width, height)
34    Separate(f32, f32),
35    /// Per-class ratios: class_id -> (width_ratio, height_ratio)
36    PerClass(HashMap<usize, (f32, f32)>),
37}
38
39impl Default for UnclipRatio {
40    fn default() -> Self {
41        UnclipRatio::Separate(1.0, 1.0)
42    }
43}
44
45/// Configuration for layout detection task.
46#[derive(Debug, Clone, Serialize, Deserialize, ConfigValidator)]
47pub struct LayoutDetectionConfig {
48    /// Default score threshold for detection (default: 0.5)
49    #[validate(range(0.0, 1.0))]
50    pub score_threshold: f32,
51    /// Maximum number of layout elements (default: 100)
52    #[validate(min(1))]
53    pub max_elements: usize,
54    /// Per-class score thresholds (overrides score_threshold for specific classes)
55    /// PP-StructureV3 defaults:
56    /// - paragraph_title: 0.3
57    /// - formula: 0.3
58    /// - text: 0.4
59    /// - seal: 0.45
60    /// - others: 0.5
61    #[serde(default)]
62    pub class_thresholds: Option<HashMap<String, f32>>,
63    /// Per-class bounding box merge modes
64    #[serde(default)]
65    pub class_merge_modes: Option<HashMap<String, MergeBboxMode>>,
66    /// Enable NMS for layout detection (default: true)
67    #[serde(default = "default_layout_nms")]
68    pub layout_nms: bool,
69    /// NMS threshold (default: 0.5)
70    #[serde(default = "default_nms_threshold")]
71    pub nms_threshold: f32,
72    /// Unclip ratio for expanding/shrinking bounding boxes (PP-StructureV3)
73    /// Default: [1.0, 1.0] (no change)
74    #[serde(default)]
75    pub layout_unclip_ratio: Option<UnclipRatio>,
76}
77
/// Serde fallback for `LayoutDetectionConfig::layout_nms`: NMS stays enabled unless the
/// configuration turns it off explicitly.
fn default_layout_nms() -> bool {
    true
}
81
/// Serde fallback for `LayoutDetectionConfig::nms_threshold` when the field is omitted.
fn default_nms_threshold() -> f32 {
    0.5_f32
}
85
86impl Default for LayoutDetectionConfig {
87    fn default() -> Self {
88        Self {
89            score_threshold: 0.5,
90            max_elements: 100,
91            class_thresholds: None,
92            class_merge_modes: None,
93            layout_nms: true,
94            nms_threshold: 0.5,
95            layout_unclip_ratio: None,
96        }
97    }
98}
99
100impl LayoutDetectionConfig {
101    /// Creates a config with PP-StructureV3 default class thresholds.
102    ///
103    /// PP-StructureV3 uses different thresholds for different element types:
104    /// - paragraph_title: 0.3
105    /// - formula: 0.3
106    /// - text: 0.4
107    /// - seal: 0.45
108    /// - others: 0.5 (default)
109    pub fn with_pp_structurev3_thresholds() -> Self {
110        let mut class_thresholds = HashMap::new();
111        class_thresholds.insert("paragraph_title".to_string(), 0.3);
112        class_thresholds.insert("formula".to_string(), 0.3);
113        class_thresholds.insert("text".to_string(), 0.4);
114        class_thresholds.insert("seal".to_string(), 0.45);
115
116        Self {
117            // Use the minimum of the per-class thresholds so the postprocessor doesn't drop
118            // candidates before the per-class filter runs.
119            score_threshold: 0.3,
120            max_elements: 100,
121            class_thresholds: Some(class_thresholds),
122            class_merge_modes: None,
123            layout_nms: true,
124            nms_threshold: 0.5,
125            layout_unclip_ratio: Some(UnclipRatio::Separate(1.0, 1.0)),
126        }
127    }
128
129    /// Creates a config with PP-DocLayoutV2 default thresholds and merge modes.
130    ///
131    /// These defaults are aligned with OpenOCR/OpenDoc's pipeline config:
132    /// `OpenOCR/configs/rec/unirec/opendoc_pipeline.yml`.
133    ///
134    /// Notes:
135    /// - The postprocessor applies `score_threshold` before per-class thresholds, so we set it
136    ///   to the minimum per-class threshold (0.4) to avoid dropping valid candidates early.
137    /// - `class_merge_modes` is populated for all labels so merge behavior is deterministic.
138    pub fn with_pp_doclayoutv2_defaults() -> Self {
139        let mut class_thresholds = HashMap::new();
140        class_thresholds.insert("abstract".to_string(), 0.5);
141        class_thresholds.insert("algorithm".to_string(), 0.5);
142        class_thresholds.insert("aside_text".to_string(), 0.5);
143        class_thresholds.insert("chart".to_string(), 0.5);
144        class_thresholds.insert("content".to_string(), 0.5);
145        class_thresholds.insert("display_formula".to_string(), 0.4);
146        class_thresholds.insert("doc_title".to_string(), 0.4);
147        class_thresholds.insert("figure_title".to_string(), 0.5);
148        class_thresholds.insert("footer".to_string(), 0.5);
149        class_thresholds.insert("footer_image".to_string(), 0.5);
150        class_thresholds.insert("footnote".to_string(), 0.5);
151        class_thresholds.insert("formula_number".to_string(), 0.5);
152        class_thresholds.insert("header".to_string(), 0.5);
153        class_thresholds.insert("header_image".to_string(), 0.5);
154        class_thresholds.insert("image".to_string(), 0.5);
155        class_thresholds.insert("inline_formula".to_string(), 0.4);
156        class_thresholds.insert("number".to_string(), 0.5);
157        class_thresholds.insert("paragraph_title".to_string(), 0.4);
158        class_thresholds.insert("reference".to_string(), 0.5);
159        class_thresholds.insert("reference_content".to_string(), 0.5);
160        class_thresholds.insert("seal".to_string(), 0.45);
161        class_thresholds.insert("table".to_string(), 0.5);
162        class_thresholds.insert("text".to_string(), 0.4);
163        class_thresholds.insert("vertical_text".to_string(), 0.4);
164        class_thresholds.insert("vision_footnote".to_string(), 0.5);
165
166        let mut merge_modes = HashMap::new();
167        merge_modes.insert("abstract".to_string(), MergeBboxMode::Union);
168        merge_modes.insert("algorithm".to_string(), MergeBboxMode::Union);
169        merge_modes.insert("aside_text".to_string(), MergeBboxMode::Union);
170        merge_modes.insert("chart".to_string(), MergeBboxMode::Large);
171        merge_modes.insert("content".to_string(), MergeBboxMode::Union);
172        merge_modes.insert("display_formula".to_string(), MergeBboxMode::Large);
173        merge_modes.insert("doc_title".to_string(), MergeBboxMode::Large);
174        merge_modes.insert("figure_title".to_string(), MergeBboxMode::Union);
175        merge_modes.insert("footer".to_string(), MergeBboxMode::Union);
176        merge_modes.insert("footer_image".to_string(), MergeBboxMode::Union);
177        merge_modes.insert("footnote".to_string(), MergeBboxMode::Union);
178        merge_modes.insert("formula_number".to_string(), MergeBboxMode::Union);
179        merge_modes.insert("header".to_string(), MergeBboxMode::Union);
180        merge_modes.insert("header_image".to_string(), MergeBboxMode::Union);
181        merge_modes.insert("image".to_string(), MergeBboxMode::Union);
182        merge_modes.insert("inline_formula".to_string(), MergeBboxMode::Large);
183        merge_modes.insert("number".to_string(), MergeBboxMode::Union);
184        merge_modes.insert("paragraph_title".to_string(), MergeBboxMode::Large);
185        merge_modes.insert("reference".to_string(), MergeBboxMode::Union);
186        merge_modes.insert("reference_content".to_string(), MergeBboxMode::Union);
187        merge_modes.insert("seal".to_string(), MergeBboxMode::Union);
188        merge_modes.insert("table".to_string(), MergeBboxMode::Union);
189        merge_modes.insert("text".to_string(), MergeBboxMode::Union);
190        merge_modes.insert("vertical_text".to_string(), MergeBboxMode::Union);
191        merge_modes.insert("vision_footnote".to_string(), MergeBboxMode::Union);
192
193        Self {
194            score_threshold: 0.4,
195            max_elements: 100,
196            class_thresholds: Some(class_thresholds),
197            class_merge_modes: Some(merge_modes),
198            layout_nms: true,
199            nms_threshold: 0.5,
200            layout_unclip_ratio: Some(UnclipRatio::Separate(1.0, 1.0)),
201        }
202    }
203
204    /// Creates a config with PP-StructureV3 default thresholds, merge modes, and unclip ratio.
205    ///
206    /// Merge modes follow standard configuration:
207    /// - "large": paragraph_title, image, formula, chart
208    /// - "union": all other PP-DocLayout_plus-L classes
209    pub fn with_pp_structurev3_defaults() -> Self {
210        let mut cfg = Self::with_pp_structurev3_thresholds();
211
212        let mut merge_modes = HashMap::new();
213        merge_modes.insert("paragraph_title".to_string(), MergeBboxMode::Large);
214        merge_modes.insert("image".to_string(), MergeBboxMode::Large);
215        merge_modes.insert("text".to_string(), MergeBboxMode::Union);
216        merge_modes.insert("number".to_string(), MergeBboxMode::Union);
217        merge_modes.insert("abstract".to_string(), MergeBboxMode::Union);
218        merge_modes.insert("content".to_string(), MergeBboxMode::Union);
219        merge_modes.insert("figure_table_chart_title".to_string(), MergeBboxMode::Union);
220        merge_modes.insert("formula".to_string(), MergeBboxMode::Large);
221        merge_modes.insert("table".to_string(), MergeBboxMode::Union);
222        merge_modes.insert("reference".to_string(), MergeBboxMode::Union);
223        merge_modes.insert("doc_title".to_string(), MergeBboxMode::Union);
224        merge_modes.insert("footnote".to_string(), MergeBboxMode::Union);
225        merge_modes.insert("header".to_string(), MergeBboxMode::Union);
226        merge_modes.insert("algorithm".to_string(), MergeBboxMode::Union);
227        merge_modes.insert("footer".to_string(), MergeBboxMode::Union);
228        merge_modes.insert("seal".to_string(), MergeBboxMode::Union);
229        merge_modes.insert("chart".to_string(), MergeBboxMode::Large);
230        merge_modes.insert("formula_number".to_string(), MergeBboxMode::Union);
231        merge_modes.insert("aside_text".to_string(), MergeBboxMode::Union);
232        merge_modes.insert("reference_content".to_string(), MergeBboxMode::Union);
233
234        cfg.class_merge_modes = Some(merge_modes);
235        cfg.layout_unclip_ratio = Some(UnclipRatio::Separate(1.0, 1.0));
236        cfg
237    }
238
239    /// Gets the threshold for a specific class.
240    ///
241    /// Returns the class-specific threshold if configured, otherwise the default threshold.
242    pub fn get_class_threshold(&self, class_name: &str) -> f32 {
243        self.class_thresholds
244            .as_ref()
245            .and_then(|thresholds| thresholds.get(class_name).copied())
246            .unwrap_or(self.score_threshold)
247    }
248
249    /// Gets the merge mode for a specific class.
250    ///
251    /// Returns the class-specific merge mode if configured, otherwise Large (default).
252    pub fn get_class_merge_mode(&self, class_name: &str) -> MergeBboxMode {
253        self.class_merge_modes
254            .as_ref()
255            .and_then(|modes| modes.get(class_name).copied())
256            .unwrap_or(MergeBboxMode::Large)
257    }
258}
259
260/// A detected layout element from the layout detection model.
261///
262/// This represents the raw output from layout detection before conversion
263/// to the final `LayoutElement` in `domain::structure`.
264#[derive(Debug, Clone)]
265pub struct LayoutDetectionElement {
266    /// Bounding box of the element
267    pub bbox: BoundingBox,
268    /// Type of layout element (raw string label from model)
269    pub element_type: String,
270    /// Confidence score (0.0 to 1.0)
271    pub score: f32,
272}
273
274/// Output from layout detection task.
275#[derive(Debug, Clone)]
276pub struct LayoutDetectionOutput {
277    /// Detected layout elements per image
278    pub elements: Vec<Vec<LayoutDetectionElement>>,
279    /// Whether elements are already sorted by reading order (e.g., from PP-DocLayoutV2)
280    ///
281    /// When `true`, downstream consumers can skip reading order sorting algorithms
282    /// as the elements are already in the correct reading order based on model output.
283    pub is_reading_order_sorted: bool,
284}
285
286impl LayoutDetectionOutput {
287    /// Creates an empty layout detection output.
288    pub fn empty() -> Self {
289        Self {
290            elements: Vec::new(),
291            is_reading_order_sorted: false,
292        }
293    }
294
295    /// Creates a layout detection output with the given capacity.
296    pub fn with_capacity(capacity: usize) -> Self {
297        Self {
298            elements: Vec::with_capacity(capacity),
299            is_reading_order_sorted: false,
300        }
301    }
302
303    /// Sets the reading order sorted flag.
304    pub fn with_reading_order_sorted(mut self, sorted: bool) -> Self {
305        self.is_reading_order_sorted = sorted;
306        self
307    }
308}
309
310impl TaskDefinition for LayoutDetectionOutput {
311    const TASK_NAME: &'static str = "layout_detection";
312    const TASK_DOC: &'static str = "Layout detection/analysis";
313
314    fn empty() -> Self {
315        LayoutDetectionOutput::empty()
316    }
317}
318
319/// Layout detection task implementation.
320#[derive(Debug, Default)]
321pub struct LayoutDetectionTask {
322    config: LayoutDetectionConfig,
323}
324
325impl LayoutDetectionTask {
326    /// Creates a new layout detection task.
327    pub fn new(config: LayoutDetectionConfig) -> Self {
328        Self { config }
329    }
330}
331
332impl Task for LayoutDetectionTask {
333    type Config = LayoutDetectionConfig;
334    type Input = ImageTaskInput;
335    type Output = LayoutDetectionOutput;
336
337    fn task_type(&self) -> TaskType {
338        TaskType::LayoutDetection
339    }
340
341    fn schema(&self) -> TaskSchema {
342        TaskSchema::new(
343            TaskType::LayoutDetection,
344            vec!["image".to_string()],
345            vec!["layout_elements".to_string()],
346        )
347    }
348
349    fn validate_input(&self, input: &Self::Input) -> Result<(), OCRError> {
350        ensure_non_empty_images(&input.images, "No images provided for layout detection")?;
351
352        Ok(())
353    }
354
355    fn validate_output(&self, output: &Self::Output) -> Result<(), OCRError> {
356        let validator = ScoreValidator::new_unit_range("score");
357
358        for (idx, elements) in output.elements.iter().enumerate() {
359            // Validate element count
360            validate_max_value(
361                elements.len(),
362                self.config.max_elements,
363                "element count",
364                &format!("Image {}", idx),
365            )?;
366
367            // Validate scores
368            let scores: Vec<f32> = elements.iter().map(|e| e.score).collect();
369            validator.validate_scores_with(&scores, |elem_idx| {
370                format!("Image {}, element {}", idx, elem_idx)
371            })?;
372        }
373
374        Ok(())
375    }
376
377    fn empty_output(&self) -> Self::Output {
378        LayoutDetectionOutput::empty()
379    }
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use crate::processors::Point;
    use image::RgbImage;

    /// 10x10 axis-aligned box anchored at the origin.
    fn sample_box() -> BoundingBox {
        let corners = vec![
            Point::new(0.0, 0.0),
            Point::new(10.0, 0.0),
            Point::new(10.0, 10.0),
            Point::new(0.0, 10.0),
        ];
        BoundingBox::new(corners)
    }

    #[test]
    fn test_layout_detection_task_creation() {
        assert_eq!(
            LayoutDetectionTask::default().task_type(),
            TaskType::LayoutDetection
        );
    }

    #[test]
    fn test_input_validation() {
        let task = LayoutDetectionTask::default();

        let no_images = ImageTaskInput::new(Vec::new());
        assert!(
            task.validate_input(&no_images).is_err(),
            "an empty batch must be rejected"
        );

        let one_image = ImageTaskInput::new(vec![RgbImage::new(100, 100)]);
        assert!(task.validate_input(&one_image).is_ok());
    }

    #[test]
    fn test_output_validation() {
        let task = LayoutDetectionTask::default();

        let text_element = LayoutDetectionElement {
            bbox: sample_box(),
            element_type: String::from("text"),
            score: 0.95,
        };
        let mut output = LayoutDetectionOutput::empty();
        output.elements.push(vec![text_element]);

        assert!(task.validate_output(&output).is_ok());
    }

    #[test]
    fn test_schema() {
        let schema = LayoutDetectionTask::default().schema();
        assert_eq!(schema.task_type, TaskType::LayoutDetection);
        assert!(schema.input_types.iter().any(|kind| kind == "image"));
        assert!(schema.output_types.iter().any(|kind| kind == "layout_elements"));
    }
}