use super::validation::ensure_non_empty_images;
use crate::ConfigValidator;
use crate::core::OCRError;
use crate::core::traits::TaskDefinition;
use crate::core::traits::task::{ImageTaskInput, Task, TaskSchema, TaskType};
use crate::utils::{ScoreValidator, validate_max_value};
use serde::{Deserialize, Serialize};
/// Configuration for [`TableStructureRecognitionTask`].
///
/// Fields:
/// - `score_threshold`: declared (via `#[validate(range(...))]` for the
///   `ConfigValidator` derive) to lie in `[0.0, 1.0]`; defaults to `0.5`.
///   NOTE(review): not read anywhere in this module — presumably consumed by
///   the predictor/model layer; confirm.
/// - `max_structure_length`: declared (via `#[validate(min = 1)]`) to be at
///   least 1; defaults to `500`. `validate_output` rejects any image whose
///   structure token list is longer than this bound.
#[derive(Debug, Clone, Serialize, Deserialize, ConfigValidator)]
pub struct TableStructureRecognitionConfig {
    #[validate(range(min = 0.0, max = 1.0))]
    pub score_threshold: f32,
    #[validate(min = 1)]
    pub max_structure_length: usize,
}

impl Default for TableStructureRecognitionConfig {
    /// Returns a threshold of `0.5` and a limit of `500` structure tokens.
    fn default() -> Self {
        const SCORE_THRESHOLD: f32 = 0.5;
        const MAX_STRUCTURE_LENGTH: usize = 500;

        Self {
            max_structure_length: MAX_STRUCTURE_LENGTH,
            score_threshold: SCORE_THRESHOLD,
        }
    }
}
/// Per-image results of table structure recognition.
///
/// The three vectors are parallel: entry `i` of each describes the same input
/// image. `TableStructureRecognitionTask::validate_output` rejects an output
/// whose vectors have different lengths.
#[derive(Debug, Clone)]
pub struct TableStructureRecognitionOutput {
    /// Structure tokens for each image (e.g. HTML tags such as `<table>`).
    pub structures: Vec<Vec<String>>,
    /// Bounding boxes for each image; every box is a flat list of coordinates
    /// that output validation requires to hold exactly 8 values.
    pub bboxes: Vec<Vec<Vec<f32>>>,
    /// One structure score per image.
    pub structure_scores: Vec<f32>,
}

impl TableStructureRecognitionOutput {
    /// Creates an output that holds no images and has not allocated.
    pub fn empty() -> Self {
        // `Vec::with_capacity(0)` does not allocate, so this matches `Vec::new()`.
        Self::with_capacity(0)
    }

    /// Creates an empty output with room reserved for `capacity` images in
    /// each of the three parallel vectors.
    pub fn with_capacity(capacity: usize) -> Self {
        let structures = Vec::with_capacity(capacity);
        let bboxes = Vec::with_capacity(capacity);
        let structure_scores = Vec::with_capacity(capacity);
        Self {
            structures,
            bboxes,
            structure_scores,
        }
    }
}
impl TaskDefinition for TableStructureRecognitionOutput {
    /// Identifier string for this task.
    const TASK_NAME: &'static str = "table_structure_recognition";
    /// Short human-readable description of this task.
    const TASK_DOC: &'static str =
        "Table structure recognition - recognizing table structure as HTML with bboxes";

    /// Returns an output with no per-image entries in any of its vectors.
    fn empty() -> Self {
        Self {
            structures: Vec::new(),
            bboxes: Vec::new(),
            structure_scores: Vec::new(),
        }
    }
}
/// Task definition for table structure recognition.
///
/// Holds a [`TableStructureRecognitionConfig`]; its `max_structure_length`
/// bounds each image's structure token count in `validate_output`.
#[derive(Debug, Default)]
pub struct TableStructureRecognitionTask {
    config: TableStructureRecognitionConfig,
}

impl TableStructureRecognitionTask {
    /// Builds a task that validates outputs against the supplied `config`.
    pub fn new(config: TableStructureRecognitionConfig) -> Self {
        TableStructureRecognitionTask { config }
    }
}
impl Task for TableStructureRecognitionTask {
    type Config = TableStructureRecognitionConfig;
    type Input = ImageTaskInput;
    type Output = TableStructureRecognitionOutput;

    /// Identifies this task as [`TaskType::TableStructureRecognition`].
    fn task_type(&self) -> TaskType {
        TaskType::TableStructureRecognition
    }

    /// Declares one `"image"` input and the outputs `"structure"`, `"bbox"`
    /// and `"structure_score"`.
    fn schema(&self) -> TaskSchema {
        let input_types = vec!["image".to_string()];
        let output_types: Vec<String> = ["structure", "bbox", "structure_score"]
            .iter()
            .map(|name| (*name).to_string())
            .collect();
        TaskSchema::new(TaskType::TableStructureRecognition, input_types, output_types)
    }

    /// Rejects an input that carries no images.
    ///
    /// # Errors
    /// Propagates the error produced by `ensure_non_empty_images`.
    fn validate_input(&self, input: &Self::Input) -> Result<(), OCRError> {
        let message = "No images provided for table structure recognition";
        ensure_non_empty_images(&input.images, message)?;
        Ok(())
    }

    /// Checks that an output is internally consistent.
    ///
    /// The three per-image vectors must have equal lengths. For each image, in
    /// order:
    /// - the structure token count must not exceed
    ///   `config.max_structure_length`;
    /// - the score must pass the unit-range score validator;
    /// - every bbox must hold exactly 8 coordinates.
    ///
    /// # Errors
    /// Returns the first failure encountered. Length and bbox problems are
    /// reported as `OCRError::InvalidInput`; the other checks propagate their
    /// helpers' errors.
    fn validate_output(&self, output: &Self::Output) -> Result<(), OCRError> {
        let (n_structures, n_bboxes, n_scores) = (
            output.structures.len(),
            output.bboxes.len(),
            output.structure_scores.len(),
        );
        let lengths_agree = n_bboxes == n_structures && n_scores == n_structures;
        if !lengths_agree {
            return Err(OCRError::InvalidInput {
                message: format!(
                    "Output length mismatch: structures={}, bboxes={}, scores={}",
                    n_structures, n_bboxes, n_scores
                ),
            });
        }

        let score_validator = ScoreValidator::new_unit_range("score");
        let per_image = output
            .structures
            .iter()
            .zip(&output.bboxes)
            .zip(&output.structure_scores)
            .enumerate();

        for (img_idx, ((structure, bboxes), &score)) in per_image {
            // Shared error context for both helper validators on this image.
            let context = format!("Image {}", img_idx);

            validate_max_value(
                structure.len(),
                self.config.max_structure_length,
                "Structure length",
                &context,
            )?;
            score_validator.validate_score(score, &context)?;

            // Every bbox must carry exactly 8 coordinates; report the first offender.
            let malformed = bboxes
                .iter()
                .enumerate()
                .find(|(_, bbox)| bbox.len() != 8);
            if let Some((bbox_idx, bbox)) = malformed {
                return Err(OCRError::InvalidInput {
                    message: format!(
                        "Image {}, Bbox {}: expected 8 coordinates, got {}",
                        img_idx,
                        bbox_idx,
                        bbox.len()
                    ),
                });
            }
        }

        Ok(())
    }

    /// Returns an output containing no images.
    fn empty_output(&self) -> Self::Output {
        TableStructureRecognitionOutput::empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use image::RgbImage;

    /// A well-formed 8-coordinate bbox shared by the output tests.
    fn quad_bbox() -> Vec<f32> {
        vec![10.0, 10.0, 50.0, 10.0, 50.0, 30.0, 10.0, 30.0]
    }

    #[test]
    fn test_table_structure_recognition_task_creation() {
        let task = TableStructureRecognitionTask::default();
        assert_eq!(task.task_type(), TaskType::TableStructureRecognition);
    }

    #[test]
    fn test_default_config() {
        let config = TableStructureRecognitionConfig::default();
        assert_eq!(config.score_threshold, 0.5);
        assert_eq!(config.max_structure_length, 500);
    }

    #[test]
    fn test_task_name() {
        assert_eq!(
            TableStructureRecognitionOutput::TASK_NAME,
            "table_structure_recognition"
        );
    }

    #[test]
    fn test_input_validation() {
        let task = TableStructureRecognitionTask::default();
        let empty_input = ImageTaskInput::new(vec![]);
        assert!(task.validate_input(&empty_input).is_err());
        let valid_input = ImageTaskInput::new(vec![RgbImage::new(100, 100)]);
        assert!(task.validate_input(&valid_input).is_ok());
    }

    #[test]
    fn test_output_validation() {
        let task = TableStructureRecognitionTask::default();
        let output = TableStructureRecognitionOutput {
            structures: vec![vec!["<html>".to_string(), "<table>".to_string()]],
            bboxes: vec![vec![quad_bbox()]],
            structure_scores: vec![0.95],
        };
        assert!(task.validate_output(&output).is_ok());

        let bad_bbox_output = TableStructureRecognitionOutput {
            structures: vec![vec!["<html>".to_string()]],
            bboxes: vec![vec![vec![10.0, 10.0, 50.0]]],
            structure_scores: vec![0.95],
        };
        assert!(task.validate_output(&bad_bbox_output).is_err());

        let mismatched_output = TableStructureRecognitionOutput {
            structures: vec![vec!["<html>".to_string()]],
            bboxes: vec![vec![quad_bbox()]],
            structure_scores: vec![0.95, 0.90],
        };
        assert!(task.validate_output(&mismatched_output).is_err());
    }

    #[test]
    fn test_empty_output_is_valid() {
        // Zero images: all lengths agree and there is nothing to check per image.
        let task = TableStructureRecognitionTask::default();
        assert!(task.validate_output(&task.empty_output()).is_ok());
    }

    #[test]
    fn test_structure_length_limit() {
        let task = TableStructureRecognitionTask::new(TableStructureRecognitionConfig {
            max_structure_length: 2,
            ..Default::default()
        });
        let too_long = TableStructureRecognitionOutput {
            structures: vec![vec![
                "<html>".to_string(),
                "<table>".to_string(),
                "<tr>".to_string(),
            ]],
            bboxes: vec![vec![quad_bbox()]],
            structure_scores: vec![0.95],
        };
        assert!(task.validate_output(&too_long).is_err());
    }

    #[test]
    fn test_score_out_of_range() {
        let task = TableStructureRecognitionTask::default();
        let output = TableStructureRecognitionOutput {
            structures: vec![vec!["<html>".to_string()]],
            bboxes: vec![vec![quad_bbox()]],
            structure_scores: vec![1.5],
        };
        assert!(task.validate_output(&output).is_err());
    }

    #[test]
    fn test_schema() {
        let task = TableStructureRecognitionTask::default();
        let schema = task.schema();
        assert_eq!(schema.task_type, TaskType::TableStructureRecognition);
        assert!(schema.input_types.contains(&"image".to_string()));
        assert!(schema.output_types.contains(&"structure".to_string()));
        assert!(schema.output_types.contains(&"bbox".to_string()));
        assert!(schema.output_types.contains(&"structure_score".to_string()));
    }
}