use super::validation::ensure_non_empty_images;
use crate::ConfigValidator;
use crate::core::OCRError;
use crate::core::traits::TaskDefinition;
use crate::core::traits::task::{ImageTaskInput, Task, TaskSchema, TaskType};
use crate::processors::BoundingBox;
use crate::utils::{ScoreValidator, validate_max_value};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Per-class policy for merging overlapping layout boxes.
///
/// `LayoutDetectionConfig::get_class_merge_mode` returns this for a class name and falls
/// back to `Large` when no mode is configured, which matches the `#[default]` variant.
/// There is no `rename_all`, so serde uses the variant names verbatim (`"Large"`, ...).
/// NOTE(review): the merge itself is implemented outside this file; the per-variant
/// notes below are inferred from the names — confirm against the post-processor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum MergeBboxMode {
/// Presumably keeps the larger of the overlapping boxes. Default variant.
#[default]
Large,
/// Presumably replaces the overlapping boxes with their union.
Union,
/// Presumably keeps the smaller of the overlapping boxes.
Small,
}
/// Box expansion ("unclip") ratio setting for detected layout boxes.
///
/// `UnclipRatio::default()` is `Separate(1.0, 1.0)`.
/// NOTE(review): how each variant is applied lives outside this file; the per-variant
/// notes below are inferred from the payload shapes — confirm with the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum UnclipRatio {
/// A single ratio, presumably applied to both axes.
Uniform(f32),
/// Two ratios, presumably (width, height) — TODO confirm order.
Separate(f32, f32),
/// Ratio pairs keyed by `usize`, presumably a class id — confirm against callers.
PerClass(HashMap<usize, (f32, f32)>),
}
impl Default for UnclipRatio {
fn default() -> Self {
UnclipRatio::Separate(1.0, 1.0)
}
}
/// Configuration for layout-detection post-processing and output validation.
///
/// `score_threshold` and `max_elements` have no serde default, so they must be present
/// when deserializing. Every other field falls back to the value noted on it.
#[derive(Debug, Clone, Serialize, Deserialize, ConfigValidator)]
pub struct LayoutDetectionConfig {
    /// Global minimum score, in `[0, 1]`. Used for any class without an entry in
    /// `class_thresholds`.
    #[validate(range(min = 0.0, max = 1.0))]
    pub score_threshold: f32,
    /// Maximum number of elements accepted per image; must be at least 1.
    #[validate(min = 1)]
    pub max_elements: usize,
    /// Optional per-class score thresholds keyed by class name (serde default: `None`).
    #[serde(default)]
    pub class_thresholds: Option<HashMap<String, f32>>,
    /// Optional per-class merge modes keyed by class name (serde default: `None`).
    /// Classes without an entry resolve to `MergeBboxMode::Large`.
    #[serde(default)]
    pub class_merge_modes: Option<HashMap<String, MergeBboxMode>>,
    /// Whether layout NMS is enabled (serde default: `true`).
    #[serde(default = "default_layout_nms")]
    pub layout_nms: bool,
    /// NMS overlap threshold (serde default: `0.5`).
    ///
    /// Validated to `[0, 1]` like `score_threshold`. It is presumably an IoU-style
    /// overlap ratio, which is only meaningful in that range. Previously any value,
    /// including negatives, passed validation.
    #[validate(range(min = 0.0, max = 1.0))]
    #[serde(default = "default_nms_threshold")]
    pub nms_threshold: f32,
    /// Optional box-expansion ratio (serde default: `None`).
    #[serde(default)]
    pub layout_unclip_ratio: Option<UnclipRatio>,
}
/// Serde fallback for `LayoutDetectionConfig::layout_nms`: NMS stays enabled unless a
/// config explicitly turns it off.
fn default_layout_nms() -> bool {
    const NMS_ENABLED_BY_DEFAULT: bool = true;
    NMS_ENABLED_BY_DEFAULT
}
/// Serde fallback for `LayoutDetectionConfig::nms_threshold` when the field is omitted.
fn default_nms_threshold() -> f32 {
    const DEFAULT_NMS_THRESHOLD: f32 = 0.5;
    DEFAULT_NMS_THRESHOLD
}
impl Default for LayoutDetectionConfig {
fn default() -> Self {
Self {
score_threshold: 0.5,
max_elements: 100,
class_thresholds: None,
class_merge_modes: None,
layout_nms: true,
nms_threshold: 0.5,
layout_unclip_ratio: None,
}
}
}
impl LayoutDetectionConfig {
    /// Builds an owned, `String`-keyed map from a static `(class name, value)` table.
    fn class_map<V: Copy>(entries: &[(&str, V)]) -> HashMap<String, V> {
        entries
            .iter()
            .map(|&(name, value)| (name.to_string(), value))
            .collect()
    }

    /// Per-class merge modes shared by the PP-DocLayoutV2 and PP-DocLayoutV3 presets.
    ///
    /// Both presets use exactly this 25-class table. Defining it once keeps them from
    /// silently drifting apart when one of them is edited.
    fn pp_doclayout_merge_modes() -> HashMap<String, MergeBboxMode> {
        Self::class_map(&[
            ("abstract", MergeBboxMode::Union),
            ("algorithm", MergeBboxMode::Union),
            ("aside_text", MergeBboxMode::Union),
            ("chart", MergeBboxMode::Large),
            ("content", MergeBboxMode::Union),
            ("display_formula", MergeBboxMode::Large),
            ("doc_title", MergeBboxMode::Large),
            ("figure_title", MergeBboxMode::Union),
            ("footer", MergeBboxMode::Union),
            ("footer_image", MergeBboxMode::Union),
            ("footnote", MergeBboxMode::Union),
            ("formula_number", MergeBboxMode::Union),
            ("header", MergeBboxMode::Union),
            ("header_image", MergeBboxMode::Union),
            ("image", MergeBboxMode::Union),
            ("inline_formula", MergeBboxMode::Large),
            ("number", MergeBboxMode::Union),
            ("paragraph_title", MergeBboxMode::Large),
            ("reference", MergeBboxMode::Union),
            ("reference_content", MergeBboxMode::Union),
            ("seal", MergeBboxMode::Union),
            ("table", MergeBboxMode::Union),
            ("text", MergeBboxMode::Union),
            ("vertical_text", MergeBboxMode::Union),
            ("vision_footnote", MergeBboxMode::Union),
        ])
    }

    /// Threshold preset named for PP-StructureV3.
    ///
    /// Sets a global score threshold of 0.3 with per-class overrides for
    /// `paragraph_title`, `formula`, `text` and `seal`. No merge modes are set, NMS is on
    /// at 0.5, and the unclip ratio is `Separate(1.0, 1.0)`.
    pub fn with_pp_structurev3_thresholds() -> Self {
        let class_thresholds: HashMap<String, f32> = Self::class_map(&[
            ("paragraph_title", 0.3),
            ("formula", 0.3),
            ("text", 0.4),
            ("seal", 0.45),
        ]);
        Self {
            score_threshold: 0.3,
            max_elements: 100,
            class_thresholds: Some(class_thresholds),
            class_merge_modes: None,
            layout_nms: true,
            nms_threshold: 0.5,
            layout_unclip_ratio: Some(UnclipRatio::Separate(1.0, 1.0)),
        }
    }

    /// Preset named for PP-DocLayoutV2.
    ///
    /// Sets a global score threshold of 0.4 with an explicit threshold for each of the
    /// 25 classes, plus the shared PP-DocLayout merge-mode table. NMS is on at 0.5 and
    /// the unclip ratio is `Separate(1.0, 1.0)`.
    pub fn with_pp_doclayoutv2_defaults() -> Self {
        let class_thresholds: HashMap<String, f32> = Self::class_map(&[
            ("abstract", 0.5),
            ("algorithm", 0.5),
            ("aside_text", 0.5),
            ("chart", 0.5),
            ("content", 0.5),
            ("display_formula", 0.4),
            ("doc_title", 0.4),
            ("figure_title", 0.5),
            ("footer", 0.5),
            ("footer_image", 0.5),
            ("footnote", 0.5),
            ("formula_number", 0.5),
            ("header", 0.5),
            ("header_image", 0.5),
            ("image", 0.5),
            ("inline_formula", 0.4),
            ("number", 0.5),
            ("paragraph_title", 0.4),
            ("reference", 0.5),
            ("reference_content", 0.5),
            ("seal", 0.45),
            ("table", 0.5),
            ("text", 0.4),
            ("vertical_text", 0.4),
            ("vision_footnote", 0.5),
        ]);
        Self {
            score_threshold: 0.4,
            max_elements: 100,
            class_thresholds: Some(class_thresholds),
            class_merge_modes: Some(Self::pp_doclayout_merge_modes()),
            layout_nms: true,
            nms_threshold: 0.5,
            layout_unclip_ratio: Some(UnclipRatio::Separate(1.0, 1.0)),
        }
    }

    /// Preset named for PP-DocLayoutV3.
    ///
    /// Sets a single global score threshold of 0.3 with no per-class thresholds, plus the
    /// same merge-mode table as the V2 preset. NMS is on at 0.5 and the unclip ratio is
    /// `Separate(1.0, 1.0)`.
    pub fn with_pp_doclayoutv3_defaults() -> Self {
        Self {
            score_threshold: 0.3,
            max_elements: 100,
            class_thresholds: None,
            class_merge_modes: Some(Self::pp_doclayout_merge_modes()),
            layout_nms: true,
            nms_threshold: 0.5,
            layout_unclip_ratio: Some(UnclipRatio::Separate(1.0, 1.0)),
        }
    }

    /// PP-StructureV3 threshold preset (see `with_pp_structurev3_thresholds`) plus its
    /// 20-class merge-mode table.
    pub fn with_pp_structurev3_defaults() -> Self {
        let mut cfg = Self::with_pp_structurev3_thresholds();
        cfg.class_merge_modes = Some(Self::class_map(&[
            ("paragraph_title", MergeBboxMode::Large),
            ("image", MergeBboxMode::Large),
            ("text", MergeBboxMode::Union),
            ("number", MergeBboxMode::Union),
            ("abstract", MergeBboxMode::Union),
            ("content", MergeBboxMode::Union),
            ("figure_table_chart_title", MergeBboxMode::Union),
            ("formula", MergeBboxMode::Large),
            ("table", MergeBboxMode::Union),
            ("reference", MergeBboxMode::Union),
            ("doc_title", MergeBboxMode::Union),
            ("footnote", MergeBboxMode::Union),
            ("header", MergeBboxMode::Union),
            ("algorithm", MergeBboxMode::Union),
            ("footer", MergeBboxMode::Union),
            ("seal", MergeBboxMode::Union),
            ("chart", MergeBboxMode::Large),
            ("formula_number", MergeBboxMode::Union),
            ("aside_text", MergeBboxMode::Union),
            ("reference_content", MergeBboxMode::Union),
        ]));
        // Set explicitly rather than relying on the thresholds preset, so this preset's
        // unclip ratio does not change if that preset is edited.
        cfg.layout_unclip_ratio = Some(UnclipRatio::Separate(1.0, 1.0));
        cfg
    }

    /// Returns the score threshold for `class_name`.
    ///
    /// Uses the class's entry in `class_thresholds` when present; otherwise falls back to
    /// the global `score_threshold`.
    pub fn get_class_threshold(&self, class_name: &str) -> f32 {
        self.class_thresholds
            .as_ref()
            .and_then(|thresholds| thresholds.get(class_name).copied())
            .unwrap_or(self.score_threshold)
    }

    /// Returns the merge mode for `class_name`.
    ///
    /// Uses the class's entry in `class_merge_modes` when present; otherwise falls back
    /// to `MergeBboxMode::Large`.
    pub fn get_class_merge_mode(&self, class_name: &str) -> MergeBboxMode {
        self.class_merge_modes
            .as_ref()
            .and_then(|modes| modes.get(class_name).copied())
            .unwrap_or(MergeBboxMode::Large)
    }
}
/// A single detected layout region.
#[derive(Debug, Clone)]
pub struct LayoutDetectionElement {
/// Region geometry; the tests in this file build it from four corner points.
pub bbox: BoundingBox,
/// Class label such as `"text"`; presumably the same names used as keys in
/// `LayoutDetectionConfig`'s per-class maps — confirm at the producer.
pub element_type: String,
/// Detection confidence; `LayoutDetectionTask::validate_output` checks it with
/// `ScoreValidator::new_unit_range`.
pub score: f32,
}
/// Batched layout-detection result.
#[derive(Debug, Clone)]
pub struct LayoutDetectionOutput {
/// One inner vector per image; `validate_output` reports the outer index as `Image {idx}`.
pub elements: Vec<Vec<LayoutDetectionElement>>,
/// Whether the elements are already in reading order. The constructors in this file
/// start it at `false`; `with_reading_order_sorted` sets it.
pub is_reading_order_sorted: bool,
}
impl LayoutDetectionOutput {
    /// An output with no per-image results, not marked as reading-order sorted.
    pub fn empty() -> Self {
        // A zero-capacity `Vec` does not allocate, so this is equivalent to `Vec::new()`.
        Self::with_capacity(0)
    }

    /// An empty output whose outer (per-image) vector can hold `capacity` entries
    /// without reallocating. Reading order starts out marked unsorted.
    pub fn with_capacity(capacity: usize) -> Self {
        let per_image = Vec::with_capacity(capacity);
        Self {
            elements: per_image,
            is_reading_order_sorted: false,
        }
    }

    /// Builder-style setter that returns `self` with `is_reading_order_sorted = sorted`.
    pub fn with_reading_order_sorted(self, sorted: bool) -> Self {
        Self {
            is_reading_order_sorted: sorted,
            ..self
        }
    }
}
/// Registers `LayoutDetectionOutput` with the `TaskDefinition` trait.
impl TaskDefinition for LayoutDetectionOutput {
/// Task name constant for layout detection.
const TASK_NAME: &'static str = "layout_detection";
/// Short human-readable description of the task.
const TASK_DOC: &'static str = "Layout detection/analysis";
/// Delegates to the inherent `LayoutDetectionOutput::empty()`.
// Inherent associated functions take precedence over trait ones for a `Type::name`
// path, so this call does not recurse into itself. It would recurse if the inherent
// `empty` were ever removed.
fn empty() -> Self {
LayoutDetectionOutput::empty()
}
}
/// Layout-detection task holding its `LayoutDetectionConfig`.
///
/// The derived `Default` uses `LayoutDetectionConfig::default()`.
#[derive(Debug, Default)]
pub struct LayoutDetectionTask {
/// Task settings; in this file only `validate_output` reads it (`max_elements`).
config: LayoutDetectionConfig,
}
impl LayoutDetectionTask {
pub fn new(config: LayoutDetectionConfig) -> Self {
Self { config }
}
}
impl Task for LayoutDetectionTask {
    type Config = LayoutDetectionConfig;
    type Input = ImageTaskInput;
    type Output = LayoutDetectionOutput;

    /// Identifies this task as `TaskType::LayoutDetection`.
    fn task_type(&self) -> TaskType {
        TaskType::LayoutDetection
    }

    /// Declares one `"image"` input and one `"layout_elements"` output.
    fn schema(&self) -> TaskSchema {
        let inputs = vec![String::from("image")];
        let outputs = vec![String::from("layout_elements")];
        TaskSchema::new(TaskType::LayoutDetection, inputs, outputs)
    }

    /// Delegates to `ensure_non_empty_images` with a layout-specific message
    /// (presumably rejecting an input with no images).
    fn validate_input(&self, input: &Self::Input) -> Result<(), OCRError> {
        ensure_non_empty_images(&input.images, "No images provided for layout detection")?;
        Ok(())
    }

    /// Checks each image's results in order and stops at the first failure.
    ///
    /// For every image it first checks the element count against
    /// `config.max_elements`, then checks every element's score with a unit-range
    /// `ScoreValidator`.
    fn validate_output(&self, output: &Self::Output) -> Result<(), OCRError> {
        let score_validator = ScoreValidator::new_unit_range("score");
        let limit = self.config.max_elements;
        output.elements.iter().enumerate().try_for_each(
            |(image_idx, image_elements)| -> Result<(), OCRError> {
                let image_label = format!("Image {}", image_idx);
                validate_max_value(image_elements.len(), limit, "element count", &image_label)?;

                let image_scores: Vec<f32> =
                    image_elements.iter().map(|element| element.score).collect();
                score_validator.validate_scores_with(&image_scores, |element_idx| {
                    format!("Image {}, element {}", image_idx, element_idx)
                })?;
                Ok(())
            },
        )
    }

    /// Returns `LayoutDetectionOutput::empty()`.
    fn empty_output(&self) -> Self::Output {
        LayoutDetectionOutput::empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::processors::Point;
    use image::RgbImage;

    #[test]
    fn test_layout_detection_task_creation() {
        let task = LayoutDetectionTask::default();
        assert_eq!(task.task_type(), TaskType::LayoutDetection);
    }

    #[test]
    fn test_input_validation() {
        let task = LayoutDetectionTask::default();
        let empty_input = ImageTaskInput::new(vec![]);
        assert!(task.validate_input(&empty_input).is_err());
        let valid_input = ImageTaskInput::new(vec![RgbImage::new(100, 100)]);
        assert!(task.validate_input(&valid_input).is_ok());
    }

    #[test]
    fn test_output_validation() {
        let task = LayoutDetectionTask::default();
        let box1 = BoundingBox::new(vec![
            Point::new(0.0, 0.0),
            Point::new(10.0, 0.0),
            Point::new(10.0, 10.0),
            Point::new(0.0, 10.0),
        ]);
        let element = LayoutDetectionElement {
            bbox: box1,
            element_type: "text".to_string(),
            score: 0.95,
        };
        let output = LayoutDetectionOutput {
            elements: vec![vec![element]],
            is_reading_order_sorted: false,
        };
        assert!(task.validate_output(&output).is_ok());
    }

    /// An output with no images has nothing to validate and must pass.
    #[test]
    fn test_empty_output_passes_validation() {
        let task = LayoutDetectionTask::default();
        assert!(task.validate_output(&LayoutDetectionOutput::empty()).is_ok());
    }

    #[test]
    fn test_reading_order_builder() {
        let output = LayoutDetectionOutput::with_capacity(2);
        assert!(!output.is_reading_order_sorted);
        assert!(output.with_reading_order_sorted(true).is_reading_order_sorted);
    }

    #[test]
    fn test_schema() {
        let task = LayoutDetectionTask::default();
        let schema = task.schema();
        assert_eq!(schema.task_type, TaskType::LayoutDetection);
        assert!(schema.input_types.contains(&"image".to_string()));
        assert!(schema.output_types.contains(&"layout_elements".to_string()));
    }

    /// Per-class thresholds override the global one; unknown classes fall back to it.
    #[test]
    fn test_class_threshold_fallback() {
        let cfg = LayoutDetectionConfig::with_pp_structurev3_thresholds();
        assert_eq!(cfg.get_class_threshold("seal"), 0.45);
        assert_eq!(cfg.get_class_threshold("table"), 0.3);
        assert_eq!(LayoutDetectionConfig::default().get_class_threshold("text"), 0.5);
    }

    /// Configured merge modes are returned; anything else resolves to `Large`.
    #[test]
    fn test_class_merge_mode_fallback() {
        let cfg = LayoutDetectionConfig::with_pp_structurev3_defaults();
        assert_eq!(cfg.get_class_merge_mode("text"), MergeBboxMode::Union);
        assert_eq!(cfg.get_class_merge_mode("not_a_class"), MergeBboxMode::Large);
        assert_eq!(
            LayoutDetectionConfig::default().get_class_merge_mode("text"),
            MergeBboxMode::Large
        );
    }

    /// The PP-DocLayout V2 and V3 presets are meant to share one merge-mode table.
    #[test]
    fn test_doclayout_presets_share_merge_modes() {
        let v2 = LayoutDetectionConfig::with_pp_doclayoutv2_defaults();
        let v3 = LayoutDetectionConfig::with_pp_doclayoutv3_defaults();
        assert!(v2.class_merge_modes.is_some());
        assert_eq!(v2.class_merge_modes, v3.class_merge_modes);
        assert!(v3.class_thresholds.is_none());
    }
}