oar_ocr_core/domain/tasks/
text_recognition.rs1use super::validation::ensure_non_empty_images;
6use crate::ConfigValidator;
7use crate::core::OCRError;
8use crate::core::traits::TaskDefinition;
9use crate::core::traits::task::{ImageTaskInput, Task, TaskSchema, TaskType};
10use crate::utils::{ScoreValidator, validate_length_match, validate_max_value};
11use serde::{Deserialize, Serialize};
12
13#[derive(Debug, Clone, Serialize, Deserialize, ConfigValidator)]
17pub struct TextRecognitionConfig {
18 #[validate(range(min = 0.0, max = 1.0))]
20 pub score_threshold: f32,
21 #[validate(min = 1)]
23 pub max_text_length: usize,
24}
25
26impl Default for TextRecognitionConfig {
27 fn default() -> Self {
28 Self {
29 score_threshold: 0.0,
30 max_text_length: 25,
31 }
32 }
33}
34
/// Result container produced by the text recognition task.
///
/// Results are stored as parallel vectors. Output validation in this file
/// requires `texts` and `scores` to have the same length.
#[derive(Debug, Clone)]
pub struct TextRecognitionOutput {
    /// Recognized strings.
    pub texts: Vec<String>,
    /// One score per entry in `texts`.
    pub scores: Vec<f32>,
    /// Per-result list of floating-point character positions.
    ///
    /// NOTE(review): units and normalization are not visible in this file;
    /// confirm against the producer.
    pub char_positions: Vec<Vec<f32>>,
    /// Per-result list of character column indices.
    ///
    /// NOTE(review): presumably indices into the model's output sequence;
    /// confirm against the producer.
    pub char_col_indices: Vec<Vec<usize>>,
    /// One sequence length per result.
    ///
    /// NOTE(review): presumably the length of the decoded model sequence;
    /// confirm against the producer.
    pub sequence_lengths: Vec<usize>,
}

impl TextRecognitionOutput {
    /// Creates an output with no results in any of its vectors.
    pub fn empty() -> Self {
        // A zero-capacity vector does not allocate, so this is the same as
        // building every field with `Vec::new()`.
        Self::with_capacity(0)
    }

    /// Creates an output whose vectors all have room for `capacity` results
    /// before they need to reallocate. The vectors themselves start empty.
    pub fn with_capacity(capacity: usize) -> Self {
        let texts = Vec::with_capacity(capacity);
        let scores = Vec::with_capacity(capacity);
        let char_positions = Vec::with_capacity(capacity);
        let char_col_indices = Vec::with_capacity(capacity);
        let sequence_lengths = Vec::with_capacity(capacity);

        Self {
            texts,
            scores,
            char_positions,
            char_col_indices,
            sequence_lengths,
        }
    }
}

impl Default for TextRecognitionOutput {
    /// The default output is the empty output.
    fn default() -> Self {
        Self::empty()
    }
}
82
83impl TaskDefinition for TextRecognitionOutput {
84 const TASK_NAME: &'static str = "text_recognition";
85 const TASK_DOC: &'static str = "Text recognition - converting text regions to strings";
86
87 fn empty() -> Self {
88 TextRecognitionOutput::empty()
89 }
90}
91
92#[derive(Debug, Default)]
94pub struct TextRecognitionTask {
95 config: TextRecognitionConfig,
96}
97
98impl TextRecognitionTask {
99 pub fn new(config: TextRecognitionConfig) -> Self {
101 Self { config }
102 }
103}
104
105impl Task for TextRecognitionTask {
106 type Config = TextRecognitionConfig;
107 type Input = ImageTaskInput;
108 type Output = TextRecognitionOutput;
109
110 fn task_type(&self) -> TaskType {
111 TaskType::TextRecognition
112 }
113
114 fn schema(&self) -> TaskSchema {
115 TaskSchema::new(
116 TaskType::TextRecognition,
117 vec!["text_boxes".to_string()],
118 vec!["text_strings".to_string(), "scores".to_string()],
119 )
120 }
121
122 fn validate_input(&self, input: &Self::Input) -> Result<(), OCRError> {
123 ensure_non_empty_images(&input.images, "No images provided for text recognition")?;
124
125 Ok(())
126 }
127
128 fn validate_output(&self, output: &Self::Output) -> Result<(), OCRError> {
129 validate_length_match(output.texts.len(), output.scores.len(), "texts", "scores")?;
131
132 let validator = ScoreValidator::new_unit_range("score");
134 validator.validate_scores_with(&output.scores, |idx| format!("Text {}", idx))?;
135
136 for (idx, text) in output.texts.iter().enumerate() {
138 validate_max_value(
139 text.len(),
140 self.config.max_text_length,
141 "length",
142 &format!("Text {}", idx),
143 )?;
144 }
145
146 Ok(())
147 }
148
149 fn empty_output(&self) -> Self::Output {
150 TextRecognitionOutput::empty()
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use image::RgbImage;

    /// Builds an output that carries only texts and scores; every other
    /// field keeps its default (empty) value.
    fn output_with(texts: Vec<String>, scores: Vec<f32>) -> TextRecognitionOutput {
        TextRecognitionOutput {
            texts,
            scores,
            ..Default::default()
        }
    }

    #[test]
    fn test_text_recognition_task_creation() {
        let recognizer = TextRecognitionTask::default();
        assert_eq!(recognizer.task_type(), TaskType::TextRecognition);
    }

    #[test]
    fn test_input_validation() {
        let recognizer = TextRecognitionTask::default();

        // An input without images must be rejected.
        let no_images = ImageTaskInput::new(vec![]);
        assert!(recognizer.validate_input(&no_images).is_err());

        // A single image is enough to pass.
        let one_image = ImageTaskInput::new(vec![RgbImage::new(100, 32)]);
        assert!(recognizer.validate_input(&one_image).is_ok());
    }

    #[test]
    fn test_output_validation() {
        let recognizer = TextRecognitionTask::default();

        // Matching lengths and an in-range score are accepted.
        let well_formed = output_with(vec!["Hello".to_string()], vec![0.95]);
        assert!(recognizer.validate_output(&well_formed).is_ok());

        // A score without a corresponding text is a length mismatch.
        let length_mismatch = output_with(Vec::new(), vec![0.95]);
        assert!(recognizer.validate_output(&length_mismatch).is_err());

        // A score above 1.0 is outside the accepted range.
        let score_too_high = output_with(vec!["Hello".to_string()], vec![1.5]);
        assert!(recognizer.validate_output(&score_too_high).is_err());
    }

    #[test]
    fn test_schema() {
        let described = TextRecognitionTask::default().schema();
        assert_eq!(described.task_type, TaskType::TextRecognition);
        assert!(described.input_types.contains(&String::from("text_boxes")));
        assert!(described.output_types.contains(&String::from("text_strings")));
    }
}