1use super::validation::ensure_non_empty_images;
6use crate::ConfigValidator;
7use crate::core::OCRError;
8use crate::core::traits::TaskDefinition;
9use crate::core::traits::task::{ImageTaskInput, Task, TaskSchema, TaskType};
10use crate::processors::BoundingBox;
11use crate::utils::{ScoreValidator, validate_max_value};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
17pub enum MergeBboxMode {
18 #[default]
20 Large,
21 Union,
23 Small,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
30pub enum UnclipRatio {
31 Uniform(f32),
33 Separate(f32, f32),
35 PerClass(HashMap<usize, (f32, f32)>),
37}
38
39impl Default for UnclipRatio {
40 fn default() -> Self {
41 UnclipRatio::Separate(1.0, 1.0)
42 }
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize, ConfigValidator)]
47pub struct LayoutDetectionConfig {
48 #[validate(range(min = 0.0, max = 1.0))]
50 pub score_threshold: f32,
51 #[validate(min = 1)]
53 pub max_elements: usize,
54 #[serde(default)]
62 pub class_thresholds: Option<HashMap<String, f32>>,
63 #[serde(default)]
65 pub class_merge_modes: Option<HashMap<String, MergeBboxMode>>,
66 #[serde(default = "default_layout_nms")]
68 pub layout_nms: bool,
69 #[serde(default = "default_nms_threshold")]
71 pub nms_threshold: f32,
72 #[serde(default)]
75 pub layout_unclip_ratio: Option<UnclipRatio>,
76}
77
/// Serde default for `LayoutDetectionConfig::layout_nms`: NMS is enabled.
fn default_layout_nms() -> bool {
    const LAYOUT_NMS_ENABLED: bool = true;
    LAYOUT_NMS_ENABLED
}
81
/// Serde default for `LayoutDetectionConfig::nms_threshold`: `0.5`.
fn default_nms_threshold() -> f32 {
    const NMS_THRESHOLD: f32 = 0.5;
    NMS_THRESHOLD
}
85
86impl Default for LayoutDetectionConfig {
87 fn default() -> Self {
88 Self {
89 score_threshold: 0.5,
90 max_elements: 100,
91 class_thresholds: None,
92 class_merge_modes: None,
93 layout_nms: true,
94 nms_threshold: 0.5,
95 layout_unclip_ratio: None,
96 }
97 }
98}
99
100impl LayoutDetectionConfig {
101 pub fn with_pp_structurev3_thresholds() -> Self {
110 let mut class_thresholds = HashMap::new();
111 class_thresholds.insert("paragraph_title".to_string(), 0.3);
112 class_thresholds.insert("formula".to_string(), 0.3);
113 class_thresholds.insert("text".to_string(), 0.4);
114 class_thresholds.insert("seal".to_string(), 0.45);
115
116 Self {
117 score_threshold: 0.3,
120 max_elements: 100,
121 class_thresholds: Some(class_thresholds),
122 class_merge_modes: None,
123 layout_nms: true,
124 nms_threshold: 0.5,
125 layout_unclip_ratio: Some(UnclipRatio::Separate(1.0, 1.0)),
126 }
127 }
128
129 pub fn with_pp_doclayoutv2_defaults() -> Self {
139 let mut class_thresholds = HashMap::new();
140 class_thresholds.insert("abstract".to_string(), 0.5);
141 class_thresholds.insert("algorithm".to_string(), 0.5);
142 class_thresholds.insert("aside_text".to_string(), 0.5);
143 class_thresholds.insert("chart".to_string(), 0.5);
144 class_thresholds.insert("content".to_string(), 0.5);
145 class_thresholds.insert("display_formula".to_string(), 0.4);
146 class_thresholds.insert("doc_title".to_string(), 0.4);
147 class_thresholds.insert("figure_title".to_string(), 0.5);
148 class_thresholds.insert("footer".to_string(), 0.5);
149 class_thresholds.insert("footer_image".to_string(), 0.5);
150 class_thresholds.insert("footnote".to_string(), 0.5);
151 class_thresholds.insert("formula_number".to_string(), 0.5);
152 class_thresholds.insert("header".to_string(), 0.5);
153 class_thresholds.insert("header_image".to_string(), 0.5);
154 class_thresholds.insert("image".to_string(), 0.5);
155 class_thresholds.insert("inline_formula".to_string(), 0.4);
156 class_thresholds.insert("number".to_string(), 0.5);
157 class_thresholds.insert("paragraph_title".to_string(), 0.4);
158 class_thresholds.insert("reference".to_string(), 0.5);
159 class_thresholds.insert("reference_content".to_string(), 0.5);
160 class_thresholds.insert("seal".to_string(), 0.45);
161 class_thresholds.insert("table".to_string(), 0.5);
162 class_thresholds.insert("text".to_string(), 0.4);
163 class_thresholds.insert("vertical_text".to_string(), 0.4);
164 class_thresholds.insert("vision_footnote".to_string(), 0.5);
165
166 let mut merge_modes = HashMap::new();
167 merge_modes.insert("abstract".to_string(), MergeBboxMode::Union);
168 merge_modes.insert("algorithm".to_string(), MergeBboxMode::Union);
169 merge_modes.insert("aside_text".to_string(), MergeBboxMode::Union);
170 merge_modes.insert("chart".to_string(), MergeBboxMode::Large);
171 merge_modes.insert("content".to_string(), MergeBboxMode::Union);
172 merge_modes.insert("display_formula".to_string(), MergeBboxMode::Large);
173 merge_modes.insert("doc_title".to_string(), MergeBboxMode::Large);
174 merge_modes.insert("figure_title".to_string(), MergeBboxMode::Union);
175 merge_modes.insert("footer".to_string(), MergeBboxMode::Union);
176 merge_modes.insert("footer_image".to_string(), MergeBboxMode::Union);
177 merge_modes.insert("footnote".to_string(), MergeBboxMode::Union);
178 merge_modes.insert("formula_number".to_string(), MergeBboxMode::Union);
179 merge_modes.insert("header".to_string(), MergeBboxMode::Union);
180 merge_modes.insert("header_image".to_string(), MergeBboxMode::Union);
181 merge_modes.insert("image".to_string(), MergeBboxMode::Union);
182 merge_modes.insert("inline_formula".to_string(), MergeBboxMode::Large);
183 merge_modes.insert("number".to_string(), MergeBboxMode::Union);
184 merge_modes.insert("paragraph_title".to_string(), MergeBboxMode::Large);
185 merge_modes.insert("reference".to_string(), MergeBboxMode::Union);
186 merge_modes.insert("reference_content".to_string(), MergeBboxMode::Union);
187 merge_modes.insert("seal".to_string(), MergeBboxMode::Union);
188 merge_modes.insert("table".to_string(), MergeBboxMode::Union);
189 merge_modes.insert("text".to_string(), MergeBboxMode::Union);
190 merge_modes.insert("vertical_text".to_string(), MergeBboxMode::Union);
191 merge_modes.insert("vision_footnote".to_string(), MergeBboxMode::Union);
192
193 Self {
194 score_threshold: 0.4,
195 max_elements: 100,
196 class_thresholds: Some(class_thresholds),
197 class_merge_modes: Some(merge_modes),
198 layout_nms: true,
199 nms_threshold: 0.5,
200 layout_unclip_ratio: Some(UnclipRatio::Separate(1.0, 1.0)),
201 }
202 }
203
204 pub fn with_pp_doclayoutv3_defaults() -> Self {
209 let mut merge_modes = HashMap::new();
210 merge_modes.insert("abstract".to_string(), MergeBboxMode::Union);
211 merge_modes.insert("algorithm".to_string(), MergeBboxMode::Union);
212 merge_modes.insert("aside_text".to_string(), MergeBboxMode::Union);
213 merge_modes.insert("chart".to_string(), MergeBboxMode::Large);
214 merge_modes.insert("content".to_string(), MergeBboxMode::Union);
215 merge_modes.insert("display_formula".to_string(), MergeBboxMode::Large);
216 merge_modes.insert("doc_title".to_string(), MergeBboxMode::Large);
217 merge_modes.insert("figure_title".to_string(), MergeBboxMode::Union);
218 merge_modes.insert("footer".to_string(), MergeBboxMode::Union);
219 merge_modes.insert("footer_image".to_string(), MergeBboxMode::Union);
220 merge_modes.insert("footnote".to_string(), MergeBboxMode::Union);
221 merge_modes.insert("formula_number".to_string(), MergeBboxMode::Union);
222 merge_modes.insert("header".to_string(), MergeBboxMode::Union);
223 merge_modes.insert("header_image".to_string(), MergeBboxMode::Union);
224 merge_modes.insert("image".to_string(), MergeBboxMode::Union);
225 merge_modes.insert("inline_formula".to_string(), MergeBboxMode::Large);
226 merge_modes.insert("number".to_string(), MergeBboxMode::Union);
227 merge_modes.insert("paragraph_title".to_string(), MergeBboxMode::Large);
228 merge_modes.insert("reference".to_string(), MergeBboxMode::Union);
229 merge_modes.insert("reference_content".to_string(), MergeBboxMode::Union);
230 merge_modes.insert("seal".to_string(), MergeBboxMode::Union);
231 merge_modes.insert("table".to_string(), MergeBboxMode::Union);
232 merge_modes.insert("text".to_string(), MergeBboxMode::Union);
233 merge_modes.insert("vertical_text".to_string(), MergeBboxMode::Union);
234 merge_modes.insert("vision_footnote".to_string(), MergeBboxMode::Union);
235
236 Self {
237 score_threshold: 0.3,
238 max_elements: 100,
239 class_thresholds: None,
240 class_merge_modes: Some(merge_modes),
241 layout_nms: true,
242 nms_threshold: 0.5,
243 layout_unclip_ratio: Some(UnclipRatio::Separate(1.0, 1.0)),
244 }
245 }
246
247 pub fn with_pp_structurev3_defaults() -> Self {
253 let mut cfg = Self::with_pp_structurev3_thresholds();
254
255 let mut merge_modes = HashMap::new();
256 merge_modes.insert("paragraph_title".to_string(), MergeBboxMode::Large);
257 merge_modes.insert("image".to_string(), MergeBboxMode::Large);
258 merge_modes.insert("text".to_string(), MergeBboxMode::Union);
259 merge_modes.insert("number".to_string(), MergeBboxMode::Union);
260 merge_modes.insert("abstract".to_string(), MergeBboxMode::Union);
261 merge_modes.insert("content".to_string(), MergeBboxMode::Union);
262 merge_modes.insert("figure_table_chart_title".to_string(), MergeBboxMode::Union);
263 merge_modes.insert("formula".to_string(), MergeBboxMode::Large);
264 merge_modes.insert("table".to_string(), MergeBboxMode::Union);
265 merge_modes.insert("reference".to_string(), MergeBboxMode::Union);
266 merge_modes.insert("doc_title".to_string(), MergeBboxMode::Union);
267 merge_modes.insert("footnote".to_string(), MergeBboxMode::Union);
268 merge_modes.insert("header".to_string(), MergeBboxMode::Union);
269 merge_modes.insert("algorithm".to_string(), MergeBboxMode::Union);
270 merge_modes.insert("footer".to_string(), MergeBboxMode::Union);
271 merge_modes.insert("seal".to_string(), MergeBboxMode::Union);
272 merge_modes.insert("chart".to_string(), MergeBboxMode::Large);
273 merge_modes.insert("formula_number".to_string(), MergeBboxMode::Union);
274 merge_modes.insert("aside_text".to_string(), MergeBboxMode::Union);
275 merge_modes.insert("reference_content".to_string(), MergeBboxMode::Union);
276
277 cfg.class_merge_modes = Some(merge_modes);
278 cfg.layout_unclip_ratio = Some(UnclipRatio::Separate(1.0, 1.0));
279 cfg
280 }
281
282 pub fn get_class_threshold(&self, class_name: &str) -> f32 {
286 self.class_thresholds
287 .as_ref()
288 .and_then(|thresholds| thresholds.get(class_name).copied())
289 .unwrap_or(self.score_threshold)
290 }
291
292 pub fn get_class_merge_mode(&self, class_name: &str) -> MergeBboxMode {
296 self.class_merge_modes
297 .as_ref()
298 .and_then(|modes| modes.get(class_name).copied())
299 .unwrap_or(MergeBboxMode::Large)
300 }
301}
302
303#[derive(Debug, Clone)]
308pub struct LayoutDetectionElement {
309 pub bbox: BoundingBox,
311 pub element_type: String,
313 pub score: f32,
315}
316
317#[derive(Debug, Clone)]
319pub struct LayoutDetectionOutput {
320 pub elements: Vec<Vec<LayoutDetectionElement>>,
322 pub is_reading_order_sorted: bool,
327}
328
329impl LayoutDetectionOutput {
330 pub fn empty() -> Self {
332 Self {
333 elements: Vec::new(),
334 is_reading_order_sorted: false,
335 }
336 }
337
338 pub fn with_capacity(capacity: usize) -> Self {
340 Self {
341 elements: Vec::with_capacity(capacity),
342 is_reading_order_sorted: false,
343 }
344 }
345
346 pub fn with_reading_order_sorted(mut self, sorted: bool) -> Self {
348 self.is_reading_order_sorted = sorted;
349 self
350 }
351}
352
353impl TaskDefinition for LayoutDetectionOutput {
354 const TASK_NAME: &'static str = "layout_detection";
355 const TASK_DOC: &'static str = "Layout detection/analysis";
356
357 fn empty() -> Self {
358 LayoutDetectionOutput::empty()
359 }
360}
361
362#[derive(Debug, Default)]
364pub struct LayoutDetectionTask {
365 config: LayoutDetectionConfig,
366}
367
368impl LayoutDetectionTask {
369 pub fn new(config: LayoutDetectionConfig) -> Self {
371 Self { config }
372 }
373}
374
375impl Task for LayoutDetectionTask {
376 type Config = LayoutDetectionConfig;
377 type Input = ImageTaskInput;
378 type Output = LayoutDetectionOutput;
379
380 fn task_type(&self) -> TaskType {
381 TaskType::LayoutDetection
382 }
383
384 fn schema(&self) -> TaskSchema {
385 TaskSchema::new(
386 TaskType::LayoutDetection,
387 vec!["image".to_string()],
388 vec!["layout_elements".to_string()],
389 )
390 }
391
392 fn validate_input(&self, input: &Self::Input) -> Result<(), OCRError> {
393 ensure_non_empty_images(&input.images, "No images provided for layout detection")?;
394
395 Ok(())
396 }
397
398 fn validate_output(&self, output: &Self::Output) -> Result<(), OCRError> {
399 let validator = ScoreValidator::new_unit_range("score");
400
401 for (idx, elements) in output.elements.iter().enumerate() {
402 validate_max_value(
404 elements.len(),
405 self.config.max_elements,
406 "element count",
407 &format!("Image {}", idx),
408 )?;
409
410 let scores: Vec<f32> = elements.iter().map(|e| e.score).collect();
412 validator.validate_scores_with(&scores, |elem_idx| {
413 format!("Image {}, element {}", idx, elem_idx)
414 })?;
415 }
416
417 Ok(())
418 }
419
420 fn empty_output(&self) -> Self::Output {
421 LayoutDetectionOutput::empty()
422 }
423}
424
#[cfg(test)]
mod tests {
    use super::*;
    use crate::processors::Point;
    use image::RgbImage;

    /// A default-constructed task reports the layout detection task type.
    #[test]
    fn test_layout_detection_task_creation() {
        let layout_task = LayoutDetectionTask::default();
        assert_eq!(layout_task.task_type(), TaskType::LayoutDetection);
    }

    /// Empty image lists are rejected; a single image is accepted.
    #[test]
    fn test_input_validation() {
        let layout_task = LayoutDetectionTask::default();

        let no_images = ImageTaskInput::new(vec![]);
        assert!(layout_task.validate_input(&no_images).is_err());

        let one_image = ImageTaskInput::new(vec![RgbImage::new(100, 100)]);
        assert!(layout_task.validate_input(&one_image).is_ok());
    }

    /// A single in-range element within the element limit passes validation.
    #[test]
    fn test_output_validation() {
        let layout_task = LayoutDetectionTask::default();

        // 10x10 square anchored at the origin.
        let square = BoundingBox::new(vec![
            Point::new(0.0, 0.0),
            Point::new(10.0, 0.0),
            Point::new(10.0, 10.0),
            Point::new(0.0, 10.0),
        ]);
        let text_element = LayoutDetectionElement {
            bbox: square,
            element_type: "text".to_string(),
            score: 0.95,
        };

        let mut detections = LayoutDetectionOutput::empty();
        detections.elements.push(vec![text_element]);

        assert!(layout_task.validate_output(&detections).is_ok());
    }

    /// The schema advertises the expected task type, input and output names.
    #[test]
    fn test_schema() {
        let task_schema = LayoutDetectionTask::default().schema();
        assert_eq!(task_schema.task_type, TaskType::LayoutDetection);
        assert!(task_schema.input_types.iter().any(|name| name == "image"));
        assert!(
            task_schema
                .output_types
                .iter()
                .any(|name| name == "layout_elements")
        );
    }
}