Skip to main content

oar_ocr_core/domain/tasks/
text_detection.rs

//! Text detection task definition.
//!
//! This module defines the text detection task, which locates text regions in images,
//! together with its configuration and output types.

5use super::validation::ensure_non_empty_images;
6use crate::ConfigValidator;
7use crate::core::OCRError;
8use crate::core::traits::TaskDefinition;
9use crate::core::traits::task::{ImageTaskInput, Task, TaskSchema, TaskType};
10use crate::processors::{BoundingBox, LimitType};
11use crate::utils::ScoreValidator;
12use serde::{Deserialize, Serialize};
13
14/// A single text detection result with bounding box and confidence score.
15#[derive(Debug, Clone)]
16pub struct Detection {
17    /// The bounding box polygon coordinates
18    pub bbox: BoundingBox,
19    /// Confidence score for this detection (0.0 to 1.0)
20    pub score: f32,
21}
22
23impl Detection {
24    /// Creates a new detection.
25    pub fn new(bbox: BoundingBox, score: f32) -> Self {
26        Self { bbox, score }
27    }
28}
29
30/// Configuration for text detection task.
31///
32/// Default values are aligned with PP-StructureV3.
33#[derive(Debug, Clone, Serialize, Deserialize, ConfigValidator)]
34pub struct TextDetectionConfig {
35    /// Score threshold for detection (default: 0.3)
36    #[validate(range(min = 0.0, max = 1.0))]
37    pub score_threshold: f32,
38    /// Box threshold for filtering (default: 0.6)
39    #[validate(range(min = 0.0, max = 1.0))]
40    pub box_threshold: f32,
41    /// Unclip ratio for expanding detected regions (default: 1.5)
42    #[validate(min = 0.0)]
43    pub unclip_ratio: f32,
44    /// Maximum candidates to consider (default: 1000)
45    #[validate(min = 1)]
46    pub max_candidates: usize,
47    /// Target side length for image resizing (optional)
48    pub limit_side_len: Option<u32>,
49    /// Limit type for resizing (optional)
50    pub limit_type: Option<LimitType>,
51    /// Maximum side length to prevent OOM (optional)
52    pub max_side_len: Option<u32>,
53}
54
55impl Default for TextDetectionConfig {
56    fn default() -> Self {
57        Self {
58            score_threshold: 0.3,
59            box_threshold: 0.6,
60            unclip_ratio: 1.5,
61            max_candidates: 1000,
62            limit_side_len: None,
63            limit_type: None,
64            max_side_len: None,
65        }
66    }
67}
68
69/// Output from text detection task.
70#[derive(Debug, Clone)]
71pub struct TextDetectionOutput {
72    /// Detected text regions per image
73    pub detections: Vec<Vec<Detection>>,
74}
75
76impl TextDetectionOutput {
77    /// Creates an empty text detection output.
78    pub fn empty() -> Self {
79        Self {
80            detections: Vec::new(),
81        }
82    }
83
84    /// Creates a text detection output with the given capacity.
85    pub fn with_capacity(capacity: usize) -> Self {
86        Self {
87            detections: Vec::with_capacity(capacity),
88        }
89    }
90}
91
92impl TaskDefinition for TextDetectionOutput {
93    const TASK_NAME: &'static str = "text_detection";
94    const TASK_DOC: &'static str = "Text detection - locating text regions in images";
95
96    fn empty() -> Self {
97        TextDetectionOutput::empty()
98    }
99}
100
101/// Text detection task implementation.
102#[derive(Debug, Default)]
103pub struct TextDetectionTask {
104    _config: TextDetectionConfig,
105}
106
107impl TextDetectionTask {
108    /// Creates a new text detection task.
109    pub fn new(config: TextDetectionConfig) -> Self {
110        Self { _config: config }
111    }
112}
113
114impl Task for TextDetectionTask {
115    type Config = TextDetectionConfig;
116    type Input = ImageTaskInput;
117    type Output = TextDetectionOutput;
118
119    fn task_type(&self) -> TaskType {
120        TaskType::TextDetection
121    }
122
123    fn schema(&self) -> TaskSchema {
124        TaskSchema::new(
125            TaskType::TextDetection,
126            vec!["image".to_string()],
127            vec!["text_boxes".to_string(), "scores".to_string()],
128        )
129    }
130
131    fn validate_input(&self, input: &Self::Input) -> Result<(), OCRError> {
132        ensure_non_empty_images(&input.images, "No images provided for text detection")?;
133
134        Ok(())
135    }
136
137    fn validate_output(&self, output: &Self::Output) -> Result<(), OCRError> {
138        let validator = ScoreValidator::new_unit_range("score");
139
140        // Validate each image's detections
141        for (idx, detections) in output.detections.iter().enumerate() {
142            let scores: Vec<f32> = detections.iter().map(|d| d.score).collect();
143            validator.validate_scores_with(&scores, |det_idx| {
144                format!("Image {}, detection {}", idx, det_idx)
145            })?;
146        }
147
148        Ok(())
149    }
150
151    fn empty_output(&self) -> Self::Output {
152        TextDetectionOutput::empty()
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::processors::Point;
    use image::RgbImage;

    /// Axis-aligned 10x10 square with one corner at the origin, used as a fixture box.
    fn square_box() -> BoundingBox {
        let corners = vec![
            Point::new(0.0, 0.0),
            Point::new(10.0, 0.0),
            Point::new(10.0, 10.0),
            Point::new(0.0, 10.0),
        ];
        BoundingBox::new(corners)
    }

    /// Wraps one detection with the given score into a single-image output.
    fn single_detection_output(score: f32) -> TextDetectionOutput {
        TextDetectionOutput {
            detections: vec![vec![Detection::new(square_box(), score)]],
        }
    }

    #[test]
    fn test_text_detection_task_creation() {
        let task = TextDetectionTask::default();
        assert_eq!(task.task_type(), TaskType::TextDetection);
    }

    #[test]
    fn test_input_validation() {
        let task = TextDetectionTask::default();

        // A batch without images is rejected.
        let no_images = ImageTaskInput::new(Vec::new());
        assert!(task.validate_input(&no_images).is_err());

        // A single blank 100x100 image is accepted.
        let one_image = ImageTaskInput::new(vec![RgbImage::new(100, 100)]);
        assert!(task.validate_input(&one_image).is_ok());
    }

    #[test]
    fn test_output_validation() {
        let task = TextDetectionTask::default();

        // A confidence inside the unit range passes.
        assert!(task.validate_output(&single_detection_output(0.95)).is_ok());

        // A confidence above 1.0 is out of range and must be rejected.
        assert!(task.validate_output(&single_detection_output(1.5)).is_err());
    }

    #[test]
    fn test_schema() {
        let schema = TextDetectionTask::default().schema();

        assert_eq!(schema.task_type, TaskType::TextDetection);
        assert!(schema.input_types.iter().any(|t| t == "image"));
        assert!(schema.output_types.iter().any(|t| t == "text_boxes"));
    }
}