oar_ocr_core/domain/tasks/
table_classification.rs1use super::document_orientation::Classification;
7use super::validation::ensure_non_empty_images;
8use crate::ConfigValidator;
9use crate::core::OCRError;
10use crate::core::traits::TaskDefinition;
11use crate::core::traits::task::{ImageTaskInput, Task, TaskSchema, TaskType};
12use crate::utils::ScoreValidator;
13use serde::{Deserialize, Serialize};
14
15#[derive(Debug, Clone, Serialize, Deserialize, ConfigValidator)]
17pub struct TableClassificationConfig {
18 #[validate(range(min = 0.0, max = 1.0))]
20 pub score_threshold: f32,
21 #[validate(min = 1)]
23 pub topk: usize,
24}
25
26impl Default for TableClassificationConfig {
27 fn default() -> Self {
28 Self {
29 score_threshold: 0.5,
30 topk: 2,
31 }
32 }
33}
34
35#[derive(Debug, Clone)]
37pub struct TableClassificationOutput {
38 pub classifications: Vec<Vec<Classification>>,
40}
41
42impl TableClassificationOutput {
43 pub fn empty() -> Self {
45 Self {
46 classifications: Vec::new(),
47 }
48 }
49
50 pub fn with_capacity(capacity: usize) -> Self {
52 Self {
53 classifications: Vec::with_capacity(capacity),
54 }
55 }
56}
57
58impl TaskDefinition for TableClassificationOutput {
59 const TASK_NAME: &'static str = "table_classification";
60 const TASK_DOC: &'static str =
61 "Table classification - classifying table images as wired or wireless";
62
63 fn empty() -> Self {
64 TableClassificationOutput::empty()
65 }
66}
67
68#[derive(Debug, Default)]
70pub struct TableClassificationTask {
71 _config: TableClassificationConfig,
72}
73
74impl TableClassificationTask {
75 pub fn new(config: TableClassificationConfig) -> Self {
77 Self { _config: config }
78 }
79}
80
81impl Task for TableClassificationTask {
82 type Config = TableClassificationConfig;
83 type Input = ImageTaskInput;
84 type Output = TableClassificationOutput;
85
86 fn task_type(&self) -> TaskType {
87 TaskType::TableClassification
88 }
89
90 fn schema(&self) -> TaskSchema {
91 TaskSchema::new(
92 TaskType::TableClassification,
93 vec!["image".to_string()],
94 vec!["table_type_labels".to_string(), "scores".to_string()],
95 )
96 }
97
98 fn validate_input(&self, input: &Self::Input) -> Result<(), OCRError> {
99 ensure_non_empty_images(&input.images, "No images provided for table classification")?;
100
101 Ok(())
102 }
103
104 fn validate_output(&self, output: &Self::Output) -> Result<(), OCRError> {
105 let validator = ScoreValidator::new_unit_range("score");
106
107 for (idx, classifications) in output.classifications.iter().enumerate() {
108 for classification in classifications.iter() {
109 if classification.class_id > 1 {
111 return Err(OCRError::InvalidInput {
112 message: format!(
113 "Image {}: invalid class_id {}. Expected 0-1 (wired_table, wireless_table)",
114 idx, classification.class_id
115 ),
116 });
117 }
118 }
119
120 let scores: Vec<f32> = classifications.iter().map(|c| c.score).collect();
122 validator.validate_scores_with(&scores, |class_idx| {
123 format!("Image {}, classification {}", idx, class_idx)
124 })?;
125 }
126
127 Ok(())
128 }
129
130 fn empty_output(&self) -> Self::Output {
131 TableClassificationOutput::empty()
132 }
133}
134
/// Unit tests for the table classification task: construction, defaults,
/// schema, and input/output validation.
#[cfg(test)]
mod tests {
    use super::*;
    use image::RgbImage;

    #[test]
    fn test_table_classification_task_creation() {
        let task = TableClassificationTask::default();
        assert_eq!(task.task_type(), TaskType::TableClassification);
    }

    #[test]
    fn test_default_config() {
        let config = TableClassificationConfig::default();
        assert!((config.score_threshold - 0.5).abs() < f32::EPSILON);
        assert_eq!(config.topk, 2);
    }

    #[test]
    fn test_new_stores_config() {
        let task = TableClassificationTask::new(TableClassificationConfig {
            score_threshold: 0.25,
            topk: 1,
        });
        assert!((task._config.score_threshold - 0.25).abs() < f32::EPSILON);
        assert_eq!(task._config.topk, 1);
    }

    #[test]
    fn test_output_constructors() {
        assert!(TableClassificationOutput::empty().classifications.is_empty());

        let reserved = TableClassificationOutput::with_capacity(4);
        assert!(reserved.classifications.is_empty());
        assert!(reserved.classifications.capacity() >= 4);
    }

    #[test]
    fn test_task_definition_metadata() {
        assert_eq!(
            <TableClassificationOutput as TaskDefinition>::TASK_NAME,
            "table_classification"
        );
        assert!(!<TableClassificationOutput as TaskDefinition>::TASK_DOC.is_empty());
        assert!(
            <TableClassificationOutput as TaskDefinition>::empty()
                .classifications
                .is_empty()
        );
    }

    #[test]
    fn test_empty_output() {
        let task = TableClassificationTask::default();
        let output = task.empty_output();
        assert!(output.classifications.is_empty());
        // With no images there is nothing to reject.
        assert!(task.validate_output(&output).is_ok());
    }

    #[test]
    fn test_input_validation() {
        let task = TableClassificationTask::default();

        // An empty image list must be rejected.
        let empty_input = ImageTaskInput::new(vec![]);
        assert!(task.validate_input(&empty_input).is_err());

        // A single image is enough to pass.
        let valid_input = ImageTaskInput::new(vec![RgbImage::new(100, 100)]);
        assert!(task.validate_input(&valid_input).is_ok());
    }

    #[test]
    fn test_output_validation() {
        let task = TableClassificationTask::default();

        // Both known classes with in-range scores are accepted.
        let classification1 = Classification::new(0, "wired_table".to_string(), 0.85);
        let classification2 = Classification::new(1, "wireless_table".to_string(), 0.15);
        let output = TableClassificationOutput {
            classifications: vec![vec![classification1, classification2]],
        };
        assert!(task.validate_output(&output).is_ok());

        // A class id outside 0-1 is rejected.
        let bad_classification = Classification::new(2, "invalid".to_string(), 0.95);
        let bad_output = TableClassificationOutput {
            classifications: vec![vec![bad_classification]],
        };
        assert!(task.validate_output(&bad_output).is_err());

        // A score above 1.0 is rejected even when the class id is valid.
        let bad_score_classification = Classification::new(0, "wired_table".to_string(), 1.5);
        let bad_score_output = TableClassificationOutput {
            classifications: vec![vec![bad_score_classification]],
        };
        assert!(task.validate_output(&bad_score_output).is_err());
    }

    #[test]
    fn test_invalid_class_id_message() {
        let task = TableClassificationTask::default();

        // The offending entry sits in the second image, so the message must
        // name image index 1 and the rejected class id.
        let bad_output = TableClassificationOutput {
            classifications: vec![
                vec![Classification::new(0, "wired_table".to_string(), 0.9)],
                vec![Classification::new(2, "invalid".to_string(), 0.5)],
            ],
        };

        match task.validate_output(&bad_output) {
            Err(OCRError::InvalidInput { message }) => {
                assert!(message.contains("Image 1"));
                assert!(message.contains("invalid class_id 2"));
            }
            _ => panic!("expected OCRError::InvalidInput for class_id 2"),
        }
    }

    #[test]
    fn test_schema() {
        let task = TableClassificationTask::default();
        let schema = task.schema();
        assert_eq!(schema.task_type, TaskType::TableClassification);
        assert!(schema.input_types.contains(&"image".to_string()));
        assert!(
            schema
                .output_types
                .contains(&"table_type_labels".to_string())
        );
        // The schema declares two outputs; check the second one as well.
        assert!(schema.output_types.contains(&"scores".to_string()));
    }
}