Skip to main content

oar_ocr_core/domain/tasks/text_recognition.rs

1//! Concrete task implementations for text recognition.
2//!
3//! This module provides the text recognition task that converts text regions to strings.
4
5use super::validation::ensure_non_empty_images;
6use crate::ConfigValidator;
7use crate::core::OCRError;
8use crate::core::traits::TaskDefinition;
9use crate::core::traits::task::{ImageTaskInput, Task, TaskSchema, TaskType};
10use crate::utils::{ScoreValidator, validate_length_match, validate_max_value};
11use serde::{Deserialize, Serialize};
12
13/// Configuration for text recognition task.
14///
15/// Default values are aligned with PP-StructureV3.
16#[derive(Debug, Clone, Serialize, Deserialize, ConfigValidator)]
17pub struct TextRecognitionConfig {
18    /// Score threshold for recognition (default: 0.0, no filtering)
19    #[validate(range(min = 0.0, max = 1.0))]
20    pub score_threshold: f32,
21    /// Maximum text length (default: 25)
22    #[validate(min = 1)]
23    pub max_text_length: usize,
24}
25
26impl Default for TextRecognitionConfig {
27    fn default() -> Self {
28        Self {
29            score_threshold: 0.0,
30            max_text_length: 25,
31        }
32    }
33}
34
/// Result produced by the text recognition task.
///
/// When populated, each vector is indexed by text line: entry `i` describes
/// `texts[i]`.
#[derive(Clone, Debug)]
pub struct TextRecognitionOutput {
    /// Recognized string for each text line.
    pub texts: Vec<String>,
    /// Confidence score for each entry of `texts`.
    pub scores: Vec<f32>,
    /// Optional per-line character positions.
    ///
    /// Each inner vector holds the normalized x-positions (0.0-1.0) of the
    /// characters in that line. Only populated when word box detection is
    /// enabled.
    pub char_positions: Vec<Vec<f32>>,
    /// Per-line CTC output column index of each decoded character.
    ///
    /// Used to generate accurate word boxes.
    pub char_col_indices: Vec<Vec<usize>>,
    /// Per-line total number of columns (sequence length) in the CTC output.
    pub sequence_lengths: Vec<usize>,
}
52
53impl TextRecognitionOutput {
54    /// Creates an empty text recognition output.
55    pub fn empty() -> Self {
56        Self {
57            texts: Vec::new(),
58            scores: Vec::new(),
59            char_positions: Vec::new(),
60            char_col_indices: Vec::new(),
61            sequence_lengths: Vec::new(),
62        }
63    }
64
65    /// Creates a text recognition output with the given capacity.
66    pub fn with_capacity(capacity: usize) -> Self {
67        Self {
68            texts: Vec::with_capacity(capacity),
69            scores: Vec::with_capacity(capacity),
70            char_positions: Vec::with_capacity(capacity),
71            char_col_indices: Vec::with_capacity(capacity),
72            sequence_lengths: Vec::with_capacity(capacity),
73        }
74    }
75}
76
77impl Default for TextRecognitionOutput {
78    fn default() -> Self {
79        Self::empty()
80    }
81}
82
83impl TaskDefinition for TextRecognitionOutput {
84    const TASK_NAME: &'static str = "text_recognition";
85    const TASK_DOC: &'static str = "Text recognition - converting text regions to strings";
86
87    fn empty() -> Self {
88        TextRecognitionOutput::empty()
89    }
90}
91
92/// Text recognition task implementation.
93#[derive(Debug, Default)]
94pub struct TextRecognitionTask {
95    config: TextRecognitionConfig,
96}
97
98impl TextRecognitionTask {
99    /// Creates a new text recognition task.
100    pub fn new(config: TextRecognitionConfig) -> Self {
101        Self { config }
102    }
103}
104
105impl Task for TextRecognitionTask {
106    type Config = TextRecognitionConfig;
107    type Input = ImageTaskInput;
108    type Output = TextRecognitionOutput;
109
110    fn task_type(&self) -> TaskType {
111        TaskType::TextRecognition
112    }
113
114    fn schema(&self) -> TaskSchema {
115        TaskSchema::new(
116            TaskType::TextRecognition,
117            vec!["text_boxes".to_string()],
118            vec!["text_strings".to_string(), "scores".to_string()],
119        )
120    }
121
122    fn validate_input(&self, input: &Self::Input) -> Result<(), OCRError> {
123        ensure_non_empty_images(&input.images, "No images provided for text recognition")?;
124
125        Ok(())
126    }
127
128    fn validate_output(&self, output: &Self::Output) -> Result<(), OCRError> {
129        // Validate that texts and scores have matching lengths
130        validate_length_match(output.texts.len(), output.scores.len(), "texts", "scores")?;
131
132        // Validate score ranges
133        let validator = ScoreValidator::new_unit_range("score");
134        validator.validate_scores_with(&output.scores, |idx| format!("Text {}", idx))?;
135
136        // Validate text lengths
137        for (idx, text) in output.texts.iter().enumerate() {
138            validate_max_value(
139                text.len(),
140                self.config.max_text_length,
141                "length",
142                &format!("Text {}", idx),
143            )?;
144        }
145
146        Ok(())
147    }
148
149    fn empty_output(&self) -> Self::Output {
150        TextRecognitionOutput::empty()
151    }
152}
153
154#[cfg(test)]
155mod tests {
156    use super::*;
157    use image::RgbImage;
158
159    #[test]
160    fn test_text_recognition_task_creation() {
161        let task = TextRecognitionTask::default();
162        assert_eq!(task.task_type(), TaskType::TextRecognition);
163    }
164
165    #[test]
166    fn test_input_validation() {
167        let task = TextRecognitionTask::default();
168
169        // Empty images should fail
170        let empty_input = ImageTaskInput::new(vec![]);
171        assert!(task.validate_input(&empty_input).is_err());
172
173        // Valid images should pass
174        let valid_input = ImageTaskInput::new(vec![RgbImage::new(100, 32)]);
175        assert!(task.validate_input(&valid_input).is_ok());
176    }
177
178    #[test]
179    fn test_output_validation() {
180        let task = TextRecognitionTask::default();
181
182        // Matching texts and scores should pass
183        let output = TextRecognitionOutput {
184            texts: vec!["Hello".to_string()],
185            scores: vec![0.95],
186            ..Default::default()
187        };
188        assert!(task.validate_output(&output).is_ok());
189
190        // Mismatched lengths should fail
191        let bad_output = TextRecognitionOutput {
192            texts: vec![],
193            scores: vec![0.95],
194            ..Default::default()
195        };
196        assert!(task.validate_output(&bad_output).is_err());
197
198        // Invalid score should fail
199        let bad_score = TextRecognitionOutput {
200            texts: vec!["Hello".to_string()],
201            scores: vec![1.5],
202            ..Default::default()
203        };
204        assert!(task.validate_output(&bad_score).is_err());
205    }
206
207    #[test]
208    fn test_schema() {
209        let task = TextRecognitionTask::default();
210        let schema = task.schema();
211        assert_eq!(schema.task_type, TaskType::TextRecognition);
212        assert!(schema.input_types.contains(&"text_boxes".to_string()));
213        assert!(schema.output_types.contains(&"text_strings".to_string()));
214    }
215}