Skip to main content

oar_ocr_core/domain/tasks/
table_classification.rs

//! Concrete task implementations for table classification.
//!
//! This module provides the table classification task that classifies table images
//! as either "wired_table" (tables with borders) or "wireless_table" (tables without borders).
5
6use super::document_orientation::Classification;
7use super::validation::ensure_non_empty_images;
8use crate::ConfigValidator;
9use crate::core::OCRError;
10use crate::core::traits::TaskDefinition;
11use crate::core::traits::task::{ImageTaskInput, Task, TaskSchema, TaskType};
12use crate::utils::ScoreValidator;
13use serde::{Deserialize, Serialize};
14
15/// Configuration for table classification task.
16#[derive(Debug, Clone, Serialize, Deserialize, ConfigValidator)]
17pub struct TableClassificationConfig {
18    /// Score threshold for classification (default: 0.5)
19    #[validate(range(min = 0.0, max = 1.0))]
20    pub score_threshold: f32,
21    /// Number of top predictions to return (default: 2)
22    #[validate(min = 1)]
23    pub topk: usize,
24}
25
26impl Default for TableClassificationConfig {
27    fn default() -> Self {
28        Self {
29            score_threshold: 0.5,
30            topk: 2,
31        }
32    }
33}
34
35/// Output from table classification task.
36#[derive(Debug, Clone)]
37pub struct TableClassificationOutput {
38    /// Classification results per image
39    pub classifications: Vec<Vec<Classification>>,
40}
41
42impl TableClassificationOutput {
43    /// Creates an empty table classification output.
44    pub fn empty() -> Self {
45        Self {
46            classifications: Vec::new(),
47        }
48    }
49
50    /// Creates a table classification output with the given capacity.
51    pub fn with_capacity(capacity: usize) -> Self {
52        Self {
53            classifications: Vec::with_capacity(capacity),
54        }
55    }
56}
57
58impl TaskDefinition for TableClassificationOutput {
59    const TASK_NAME: &'static str = "table_classification";
60    const TASK_DOC: &'static str =
61        "Table classification - classifying table images as wired or wireless";
62
63    fn empty() -> Self {
64        TableClassificationOutput::empty()
65    }
66}
67
68/// Table classification task implementation.
69#[derive(Debug, Default)]
70pub struct TableClassificationTask {
71    _config: TableClassificationConfig,
72}
73
74impl TableClassificationTask {
75    /// Creates a new table classification task.
76    pub fn new(config: TableClassificationConfig) -> Self {
77        Self { _config: config }
78    }
79}
80
81impl Task for TableClassificationTask {
82    type Config = TableClassificationConfig;
83    type Input = ImageTaskInput;
84    type Output = TableClassificationOutput;
85
86    fn task_type(&self) -> TaskType {
87        TaskType::TableClassification
88    }
89
90    fn schema(&self) -> TaskSchema {
91        TaskSchema::new(
92            TaskType::TableClassification,
93            vec!["image".to_string()],
94            vec!["table_type_labels".to_string(), "scores".to_string()],
95        )
96    }
97
98    fn validate_input(&self, input: &Self::Input) -> Result<(), OCRError> {
99        ensure_non_empty_images(&input.images, "No images provided for table classification")?;
100
101        Ok(())
102    }
103
104    fn validate_output(&self, output: &Self::Output) -> Result<(), OCRError> {
105        let validator = ScoreValidator::new_unit_range("score");
106
107        for (idx, classifications) in output.classifications.iter().enumerate() {
108            for classification in classifications.iter() {
109                // Validate class IDs (should be 0-1 for 2 table types)
110                if classification.class_id > 1 {
111                    return Err(OCRError::InvalidInput {
112                        message: format!(
113                            "Image {}: invalid class_id {}. Expected 0-1 (wired_table, wireless_table)",
114                            idx, classification.class_id
115                        ),
116                    });
117                }
118            }
119
120            // Validate score ranges
121            let scores: Vec<f32> = classifications.iter().map(|c| c.score).collect();
122            validator.validate_scores_with(&scores, |class_idx| {
123                format!("Image {}, classification {}", idx, class_idx)
124            })?;
125        }
126
127        Ok(())
128    }
129
130    fn empty_output(&self) -> Self::Output {
131        TableClassificationOutput::empty()
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use image::RgbImage;

    /// A default-constructed task reports the table classification task type.
    #[test]
    fn test_table_classification_task_creation() {
        let reported = TableClassificationTask::default().task_type();
        assert_eq!(reported, TaskType::TableClassification);
    }

    /// Input validation rejects an empty image list and accepts a single image.
    #[test]
    fn test_input_validation() {
        let task = TableClassificationTask::default();

        // No images at all must be rejected.
        let no_images = ImageTaskInput::new(vec![]);
        assert!(task.validate_input(&no_images).is_err());

        // A single 100x100 image is enough to pass.
        let one_image = ImageTaskInput::new(vec![RgbImage::new(100, 100)]);
        assert!(task.validate_input(&one_image).is_ok());
    }

    /// Output validation accepts well-formed results and rejects out-of-range
    /// class IDs and scores.
    #[test]
    fn test_output_validation() {
        let task = TableClassificationTask::default();
        // Wraps the classifications of one image into a task output.
        let single_image = |items: Vec<Classification>| TableClassificationOutput {
            classifications: vec![items],
        };

        // Both known classes with in-range scores are accepted.
        let good = single_image(vec![
            Classification::new(0, "wired_table".to_string(), 0.85),
            Classification::new(1, "wireless_table".to_string(), 0.15),
        ]);
        assert!(task.validate_output(&good).is_ok());

        // Class ID 2 is outside the allowed 0-1 range.
        let unknown_class = single_image(vec![Classification::new(
            2,
            "invalid".to_string(),
            0.95,
        )]);
        assert!(task.validate_output(&unknown_class).is_err());

        // A score of 1.5 is rejected.
        let score_too_high = single_image(vec![Classification::new(
            0,
            "wired_table".to_string(),
            1.5,
        )]);
        assert!(task.validate_output(&score_too_high).is_err());
    }

    /// The schema advertises the expected task type, input, and output label.
    #[test]
    fn test_schema() {
        let schema = TableClassificationTask::default().schema();
        assert_eq!(schema.task_type, TaskType::TableClassification);

        let expected_input = String::from("image");
        assert!(schema.input_types.contains(&expected_input));

        let expected_output = String::from("table_type_labels");
        assert!(schema.output_types.contains(&expected_output));
    }
}