oar_ocr_core/domain/tasks/
text_detection.rs1use super::validation::ensure_non_empty_images;
6use crate::ConfigValidator;
7use crate::core::OCRError;
8use crate::core::traits::TaskDefinition;
9use crate::core::traits::task::{ImageTaskInput, Task, TaskSchema, TaskType};
10use crate::processors::{BoundingBox, LimitType};
11use crate::utils::ScoreValidator;
12use serde::{Deserialize, Serialize};
13
14#[derive(Debug, Clone)]
16pub struct Detection {
17 pub bbox: BoundingBox,
19 pub score: f32,
21}
22
23impl Detection {
24 pub fn new(bbox: BoundingBox, score: f32) -> Self {
26 Self { bbox, score }
27 }
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize, ConfigValidator)]
34pub struct TextDetectionConfig {
35 #[validate(range(0.0, 1.0))]
37 pub score_threshold: f32,
38 #[validate(range(0.0, 1.0))]
40 pub box_threshold: f32,
41 #[validate(min(0.0))]
43 pub unclip_ratio: f32,
44 #[validate(min(1))]
46 pub max_candidates: usize,
47 pub limit_side_len: Option<u32>,
49 pub limit_type: Option<LimitType>,
51 pub max_side_len: Option<u32>,
53}
54
55impl Default for TextDetectionConfig {
56 fn default() -> Self {
57 Self {
58 score_threshold: 0.3,
59 box_threshold: 0.6,
60 unclip_ratio: 1.5,
61 max_candidates: 1000,
62 limit_side_len: None,
63 limit_type: None,
64 max_side_len: None,
65 }
66 }
67}
68
69#[derive(Debug, Clone)]
71pub struct TextDetectionOutput {
72 pub detections: Vec<Vec<Detection>>,
74}
75
76impl TextDetectionOutput {
77 pub fn empty() -> Self {
79 Self {
80 detections: Vec::new(),
81 }
82 }
83
84 pub fn with_capacity(capacity: usize) -> Self {
86 Self {
87 detections: Vec::with_capacity(capacity),
88 }
89 }
90}
91
92impl TaskDefinition for TextDetectionOutput {
93 const TASK_NAME: &'static str = "text_detection";
94 const TASK_DOC: &'static str = "Text detection - locating text regions in images";
95
96 fn empty() -> Self {
97 TextDetectionOutput::empty()
98 }
99}
100
101#[derive(Debug, Default)]
103pub struct TextDetectionTask {
104 _config: TextDetectionConfig,
105}
106
107impl TextDetectionTask {
108 pub fn new(config: TextDetectionConfig) -> Self {
110 Self { _config: config }
111 }
112}
113
114impl Task for TextDetectionTask {
115 type Config = TextDetectionConfig;
116 type Input = ImageTaskInput;
117 type Output = TextDetectionOutput;
118
119 fn task_type(&self) -> TaskType {
120 TaskType::TextDetection
121 }
122
123 fn schema(&self) -> TaskSchema {
124 TaskSchema::new(
125 TaskType::TextDetection,
126 vec!["image".to_string()],
127 vec!["text_boxes".to_string(), "scores".to_string()],
128 )
129 }
130
131 fn validate_input(&self, input: &Self::Input) -> Result<(), OCRError> {
132 ensure_non_empty_images(&input.images, "No images provided for text detection")?;
133
134 Ok(())
135 }
136
137 fn validate_output(&self, output: &Self::Output) -> Result<(), OCRError> {
138 let validator = ScoreValidator::new_unit_range("score");
139
140 for (idx, detections) in output.detections.iter().enumerate() {
142 let scores: Vec<f32> = detections.iter().map(|d| d.score).collect();
143 validator.validate_scores_with(&scores, |det_idx| {
144 format!("Image {}, detection {}", idx, det_idx)
145 })?;
146 }
147
148 Ok(())
149 }
150
151 fn empty_output(&self) -> Self::Output {
152 TextDetectionOutput::empty()
153 }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::processors::Point;
    use image::RgbImage;

    /// Builds a 10x10 axis-aligned square outline; each call returns a fresh
    /// box so every test case owns its own fixture.
    fn square_box() -> BoundingBox {
        BoundingBox::new(vec![
            Point::new(0.0, 0.0),
            Point::new(10.0, 0.0),
            Point::new(10.0, 10.0),
            Point::new(0.0, 10.0),
        ])
    }

    /// Wraps one detection with the given score into a single-image output.
    fn single_detection_output(score: f32) -> TextDetectionOutput {
        TextDetectionOutput {
            detections: vec![vec![Detection::new(square_box(), score)]],
        }
    }

    #[test]
    fn test_text_detection_task_creation() {
        let task = TextDetectionTask::default();
        assert_eq!(task.task_type(), TaskType::TextDetection);
    }

    #[test]
    fn test_input_validation() {
        let task = TextDetectionTask::default();

        // An input without any image must be rejected.
        let empty_input = ImageTaskInput::new(vec![]);
        assert!(task.validate_input(&empty_input).is_err());

        // A single image is enough to pass.
        let valid_input = ImageTaskInput::new(vec![RgbImage::new(100, 100)]);
        assert!(task.validate_input(&valid_input).is_ok());
    }

    #[test]
    fn test_output_validation() {
        let task = TextDetectionTask::default();

        // A score of 0.95 is accepted.
        let output = single_detection_output(0.95);
        assert!(task.validate_output(&output).is_ok());

        // A score of 1.5 is rejected.
        let bad_output = single_detection_output(1.5);
        assert!(task.validate_output(&bad_output).is_err());
    }

    #[test]
    fn test_empty_output_is_valid() {
        let task = TextDetectionTask::default();

        // With no detections there is nothing for the validator to reject.
        let output = task.empty_output();
        assert!(output.detections.is_empty());
        assert!(task.validate_output(&output).is_ok());
    }

    #[test]
    fn test_schema() {
        let task = TextDetectionTask::default();
        let schema = task.schema();
        assert_eq!(schema.task_type, TaskType::TextDetection);
        assert!(schema.input_types.contains(&"image".to_string()));
        assert!(schema.output_types.contains(&"text_boxes".to_string()));
        // `schema()` declares the scores output as well.
        assert!(schema.output_types.contains(&"scores".to_string()));
    }
}