oar_ocr/predictor/
crnn_recognizer.rs

1//! CRNN (Convolutional Recurrent Neural Network) Text Recognizer
2//!
3//! This module implements a text recognition predictor using the CRNN model,
4//! which combines convolutional layers for feature extraction and recurrent layers
5//! for sequence modeling. It's commonly used for recognizing text in images.
6//!
7//! The main components are:
8//! - `TextRecPredictor`: The main predictor that performs text recognition
9//! - `TextRecPredictorConfig`: Configuration for the predictor
10//! - `TextRecResult`: Results from text recognition
11//! - `TextRecPredictorBuilder`: Builder for creating predictor instances
12
13use crate::core::ImageReader as CoreImageReader;
14use crate::core::{
15    BatchData, CommonBuilderConfig, ConfigValidator, ConfigValidatorExt, DefaultImageReader,
16    OCRError, OrtInfer, Tensor3D, Tensor4D,
17};
18use crate::core::{
19    GranularImageReader as GIReader, ModularPredictor, OrtInfer3D, Postprocessor as GPostprocessor,
20    Preprocessor as GPreprocessor,
21};
22use crate::impl_common_builder_methods;
23use crate::impl_config_new_and_with_common;
24use crate::processors::{CTCLabelDecode, NormalizeImage, OCRResize};
25
26use image::RgbImage;
27use std::path::Path;
28use std::sync::Arc;
29
30/// Results from text recognition
31///
32/// This struct holds the results of text recognition operations,
33/// including the recognized text, confidence scores, and associated metadata.
34#[derive(Debug, Clone)]
35pub struct TextRecResult {
36    /// Paths to the input images
37    pub input_path: Vec<Arc<str>>,
38    /// Indexes of the input images
39    pub index: Vec<usize>,
40    /// Input images
41    pub input_img: Vec<Arc<RgbImage>>,
42    /// Recognized text
43    pub rec_text: Vec<Arc<str>>,
44    /// Confidence scores for the recognized text
45    pub rec_score: Vec<f32>,
46}
47
48/// Configuration for the text recognition predictor
49///
50/// This struct holds the configuration parameters for the text recognition predictor.
51#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
52pub struct TextRecPredictorConfig {
53    /// Common configuration parameters
54    pub common: CommonBuilderConfig,
55    /// Model input shape for image resizing [channels, height, width]
56    /// When specified, images are resized to fit this shape while maintaining aspect ratio.
57    /// If None, the predictor defaults to [3, 48, 320] (DEFAULT_REC_IMAGE_SHAPE).
58    pub model_input_shape: Option<[usize; 3]>,
59    /// Character dictionary for recognition
60    pub character_dict: Option<Vec<String>>,
61    /// Score threshold for filtering recognition results
62    pub score_thresh: Option<f32>,
63}
64
65impl_config_new_and_with_common!(
66    TextRecPredictorConfig,
67    common_defaults: (Some("crnn".to_string()), Some(32)),
68    fields: {
69        model_input_shape: Some([3, 48, 320]),
70        character_dict: None,
71        score_thresh: None
72    }
73);
74
75impl ConfigValidator for TextRecPredictorConfig {
76    fn validate(&self) -> Result<(), crate::core::ConfigError> {
77        self.common.validate()?;
78
79        if let Some(shape) = self.model_input_shape
80            && (shape[0] == 0 || shape[1] == 0 || shape[2] == 0)
81        {
82            return Err(crate::core::ConfigError::InvalidConfig {
83                message: "Model input shape dimensions must be greater than 0".to_string(),
84            });
85        }
86
87        Ok(())
88    }
89
90    fn get_defaults() -> Self {
91        Self::new()
92    }
93}
94
95impl TextRecResult {
96    /// Creates a new, empty `TextRecResult`
97    pub fn new() -> Self {
98        Self {
99            input_path: Vec::new(),
100            index: Vec::new(),
101            input_img: Vec::new(),
102            rec_text: Vec::new(),
103            rec_score: Vec::new(),
104        }
105    }
106}
107
108impl Default for TextRecResult {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114/// Text recognition predictor built from modular components
115///
116/// This is a type alias over `ModularPredictor` with concrete, composable components
117/// to eliminate duplicated StandardPredictor implementations across predictors.
118pub type TextRecPredictor =
119    ModularPredictor<TRImageReader, TRPreprocessor, OrtInfer3D, TRPostprocessor>;
120
121#[derive(Debug)]
122pub struct TRImageReader {
123    inner: DefaultImageReader,
124}
125impl TRImageReader {
126    pub fn new() -> Self {
127        Self {
128            inner: DefaultImageReader::new(),
129        }
130    }
131}
132impl Default for TRImageReader {
133    fn default() -> Self {
134        Self::new()
135    }
136}
137impl GIReader for TRImageReader {
138    fn read_images<'a>(
139        &self,
140        paths: impl Iterator<Item = &'a str>,
141    ) -> Result<Vec<RgbImage>, OCRError> {
142        self.inner.apply(paths)
143    }
144}
145
146#[derive(Debug)]
147pub struct TRPreprocessor {
148    pub resize: OCRResize,
149    pub normalize: NormalizeImage,
150}
151impl GPreprocessor for TRPreprocessor {
152    type Config = TextRecConfig;
153    type Output = Tensor4D;
154    fn preprocess(
155        &self,
156        images: Vec<RgbImage>,
157        _config: Option<&Self::Config>,
158    ) -> Result<Self::Output, OCRError> {
159        let resized_imgs = self.resize.apply_to_images(&images)?;
160        let dynamic_imgs: Vec<image::DynamicImage> = resized_imgs
161            .into_iter()
162            .map(image::DynamicImage::ImageRgb8)
163            .collect();
164        self.normalize.normalize_batch_to(dynamic_imgs)
165    }
166}
167
168#[derive(Debug)]
169pub struct TRPostprocessor {
170    pub decoder: CTCLabelDecode,
171}
172impl GPostprocessor for TRPostprocessor {
173    type Config = TextRecConfig;
174    type InferenceOutput = Tensor3D;
175    type PreprocessOutput = Tensor4D;
176    type Result = TextRecResult;
177    fn postprocess(
178        &self,
179        output: Self::InferenceOutput,
180        _pre: Option<&Self::PreprocessOutput>,
181        batch_data: &BatchData,
182        raw_images: Vec<RgbImage>,
183        _config: Option<&Self::Config>,
184    ) -> crate::core::OcrResult<Self::Result> {
185        let (texts, scores) = self.decoder.apply(&output);
186        Ok(TextRecResult {
187            input_path: batch_data.input_paths.clone(),
188            index: batch_data.indexes.clone(),
189            input_img: raw_images.into_iter().map(Arc::new).collect(),
190            rec_text: texts.into_iter().map(Arc::from).collect(),
191            rec_score: scores,
192        })
193    }
194    fn empty_result(&self) -> crate::core::OcrResult<Self::Result> {
195        Ok(TextRecResult::new())
196    }
197}
198
/// Per-call configuration type for the recognizer stages.
///
/// This is a unit struct with no fields. `TRPreprocessor` and `TRPostprocessor` name
/// it as their `Config` type and ignore the value they receive.
#[derive(Debug, Clone)]
pub struct TextRecConfig;
204
205/// Builder for `TextRecPredictor`
206///
207/// This struct is used to build a `TextRecPredictor` with the desired configuration.
208pub struct TextRecPredictorBuilder {
209    /// Common configuration parameters
210    common: CommonBuilderConfig,
211
212    /// Model input shape for image resizing [channels, height, width]
213    model_input_shape: Option<[usize; 3]>,
214    /// Character dictionary for recognition
215    character_dict: Option<Vec<String>>,
216    /// Score threshold for filtering recognition results
217    score_thresh: Option<f32>,
218}
219
220impl_common_builder_methods!(TextRecPredictorBuilder, common);
221
222impl TextRecPredictorBuilder {
223    /// Creates a new `TextRecPredictorBuilder`
224    ///
225    /// This function initializes a new builder with default values.
226    pub fn new() -> Self {
227        Self {
228            common: CommonBuilderConfig::new(),
229            model_input_shape: None,
230            character_dict: None,
231            score_thresh: None,
232        }
233    }
234
235    /// Sets the model input shape
236    ///
237    /// This function sets the model input shape for image resizing.
238    /// Images will be resized to fit this shape while maintaining aspect ratio.
239    pub fn model_input_shape(mut self, shape: [usize; 3]) -> Self {
240        self.model_input_shape = Some(shape);
241        self
242    }
243
244    /// Sets the character dictionary
245    ///
246    /// This function sets the character dictionary for recognition.
247    pub fn character_dict(mut self, character_dict: Vec<String>) -> Self {
248        self.character_dict = Some(character_dict);
249        self
250    }
251
252    /// Sets the score threshold for filtering recognition results
253    ///
254    /// This function sets the minimum score threshold for recognition results.
255    /// Results with scores below this threshold will be filtered out.
256    pub fn score_thresh(mut self, score_thresh: f32) -> Self {
257        self.score_thresh = Some(score_thresh);
258        self
259    }
260
261    /// Builds the `TextRecPredictor`
262    ///
263    /// This function builds the `TextRecPredictor` with the provided configuration.
264    pub fn build(self, model_path: &Path) -> crate::core::OcrResult<TextRecPredictor> {
265        self.build_internal(model_path)
266    }
267
268    /// Builds the `TextRecPredictor` internally
269    ///
270    /// This function builds the `TextRecPredictor` with the provided configuration.
271    /// It also validates the configuration and handles the model path.
272    fn build_internal(mut self, model_path: &Path) -> crate::core::OcrResult<TextRecPredictor> {
273        // Ensure model path is set first
274        if self.common.model_path.is_none() {
275            self.common = self.common.model_path(model_path.to_path_buf());
276        }
277
278        // Build the configuration
279        let config = TextRecPredictorConfig {
280            common: self.common,
281            model_input_shape: self.model_input_shape,
282            character_dict: self.character_dict,
283            score_thresh: self.score_thresh,
284        };
285
286        // Validate the configuration
287        let config = config.validate_and_wrap_ocr_error()?;
288
289        // Build modular components
290        let model_input_shape = config.model_input_shape.unwrap_or([3, 48, 320]);
291        let character_dict = config.character_dict.clone();
292
293        let image_reader = TRImageReader::new();
294        let resize = OCRResize::new(Some(model_input_shape), None);
295        let normalize = NormalizeImage::for_ocr_recognition()?;
296        let preprocessor = TRPreprocessor { resize, normalize };
297        let infer = OrtInfer::from_common(&config.common, model_path, None)?;
298        let inference_engine = OrtInfer3D::new(infer);
299        let decoder = CTCLabelDecode::from_string_list(character_dict.as_deref(), true, false);
300        let postprocessor = TRPostprocessor { decoder };
301
302        Ok(ModularPredictor::new(
303            image_reader,
304            preprocessor,
305            inference_engine,
306            postprocessor,
307        ))
308    }
309}
310
311impl Default for TextRecPredictorBuilder {
312    fn default() -> Self {
313        Self::new()
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    // Already compiled only for tests because the parent module is `#[cfg(test)]`.
    mod tests_local {
        use super::*;

        /// `new()` applies the recognizer defaults, and those defaults validate.
        #[test]
        fn test_text_rec_config_defaults_and_validate() {
            let defaults = TextRecPredictorConfig::new();

            assert_eq!(defaults.model_input_shape, Some([3, 48, 320]));
            assert_eq!(defaults.common.model_name.as_deref(), Some("crnn"));
            assert_eq!(defaults.common.batch_size, Some(32));
            assert!(defaults.validate().is_ok());
        }
    }

    /// The config has no score threshold by default and keeps one that is assigned.
    #[test]
    fn test_text_rec_predictor_config_score_thresh() {
        let untouched = TextRecPredictorConfig::new();
        assert!(untouched.score_thresh.is_none());

        let with_threshold = TextRecPredictorConfig {
            score_thresh: Some(0.5),
            ..TextRecPredictorConfig::new()
        };
        assert_eq!(with_threshold.score_thresh, Some(0.5));
    }

    /// The builder setter stores the threshold it is given.
    #[test]
    fn test_text_rec_predictor_builder_score_thresh() {
        let stored = TextRecPredictorBuilder::new().score_thresh(0.7).score_thresh;

        assert_eq!(stored, Some(0.7));
    }
}