oar_ocr/predictor/text_line_classifier.rs

//! Text Line Classifier
//!
//! This module provides functionality for classifying the orientation of text lines in images.
//! It detects whether a text line is upright (0°) or rotated by 180°.
//!
//! The classifier uses a pre-trained model to analyze images and determine their orientation.
//! It supports batch processing for efficient handling of multiple images.
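//!
//! A minimal usage sketch (the model path is a placeholder and the import path may differ
//! depending on how the crate re-exports these types):
//!
//! ```ignore
//! use std::path::Path;
//!
//! let predictor = TextLineClasPredictorBuilder::new()
//!     .topk(1)
//!     .build(Path::new("path/to/textline_orientation.onnx"))?;
//! ```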

use crate::common_builder_methods;
use crate::core::ImageReader as CoreImageReader;
use crate::core::{
    BatchData, CommonBuilderConfig, DefaultImageReader, OCRError, OrtInfer, Tensor2D, Tensor4D,
    config::{ConfigValidator, ConfigValidatorExt},
    get_text_line_orientation_labels,
};
use crate::core::{
    GranularImageReader as GIReader, ModularPredictor, OrtInfer2D, Postprocessor as GPostprocessor,
    Preprocessor as GPreprocessor,
};

use crate::processors::{Crop, NormalizeImage, Topk};
use image::{DynamicImage, RgbImage};
use std::path::Path;
use std::sync::Arc;

use crate::impl_config_new_and_with_common;

/// Results from text line classification
///
/// This struct contains the results of classifying text line orientations in images.
/// For each image, it provides the predicted orientations along with confidence scores.
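///
/// A sketch of reading the top prediction per image (assumes `result` is a populated
/// `TextLineClasResult`):
///
/// ```ignore
/// for (labels, scores) in result.label_names.iter().zip(&result.scores) {
///     if let (Some(label), Some(score)) = (labels.first(), scores.first()) {
///         println!("orientation: {label} (confidence {score:.3})");
///     }
/// }
/// ```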
#[derive(Debug, Clone)]
pub struct TextLineClasResult {
    /// Paths to the input images
    pub input_path: Vec<Arc<str>>,
    /// Indexes of the images in the batch
    pub index: Vec<usize>,
    /// The input images
    pub input_img: Vec<Arc<RgbImage>>,
    /// Predicted class IDs for each image (sorted by confidence)
    pub class_ids: Vec<Vec<usize>>,
    /// Confidence scores for each prediction
    pub scores: Vec<Vec<f32>>,
    /// Label names for each prediction (e.g., "0", "180")
    pub label_names: Vec<Vec<Arc<str>>>,
}

/// Configuration for the text line classifier
///
/// This struct holds configuration parameters for the text line classifier.
/// It includes common configuration options as well as classifier-specific parameters.
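///
/// A small sketch of tweaking the defaults (field values are illustrative):
///
/// ```ignore
/// let mut config = TextLineClasPredictorConfig::new();
/// config.topk = Some(1);
/// config.input_shape = Some((224, 224));
/// assert!(config.validate().is_ok());
/// ```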
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct TextLineClasPredictorConfig {
    /// Common configuration options shared across predictors
    pub common: CommonBuilderConfig,
    /// Number of top predictions to return for each image
    pub topk: Option<usize>,
    /// Input shape for the model (width, height)
    pub input_shape: Option<(u32, u32)>,
}

impl_config_new_and_with_common!(
    TextLineClasPredictorConfig,
    common_defaults: (Some("PP-LCNet_x0_25".to_string()), Some(1)),
    fields: {
        topk: None,
        input_shape: Some((224, 224))
    }
);

impl TextLineClasPredictorConfig {
    /// Validates the text line classifier configuration
    ///
    /// Checks that all configuration parameters are valid and within acceptable ranges.
    ///
    /// # Returns
    ///
    /// Ok if the configuration is valid, or an error if validation fails
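    ///
    /// A sketch of a failing check (assumes the positive-value check rejects a zero `topk`):
    ///
    /// ```ignore
    /// let mut config = TextLineClasPredictorConfig::new();
    /// config.topk = Some(0);
    /// assert!(config.validate().is_err());
    /// ```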
    pub fn validate(&self) -> Result<(), crate::core::ConfigError> {
        ConfigValidator::validate(self)
    }
}

impl ConfigValidator for TextLineClasPredictorConfig {
    /// Validates the text line classifier configuration
    ///
    /// Checks that all configuration parameters are valid and within acceptable ranges.
    /// This includes validating the common configuration, topk value, and input shape.
    ///
    /// # Returns
    ///
    /// Ok if the configuration is valid, or an error if validation fails
    fn validate(&self) -> Result<(), crate::core::ConfigError> {
        self.common.validate()?;

        if let Some(topk) = self.topk {
            self.validate_positive_usize(topk, "topk")?;
        }

        if let Some((width, height)) = self.input_shape {
            self.validate_image_dimensions(width, height)?;
        }

        Ok(())
    }

    /// Gets the default text line classifier configuration
    ///
    /// Returns a new instance of the text line classifier configuration
    /// with default values for all parameters.
    ///
    /// # Returns
    ///
    /// A new instance of `TextLineClasPredictorConfig` with default settings
    fn get_defaults() -> Self {
        Self {
            common: CommonBuilderConfig::get_defaults(),
            topk: Some(2),
            input_shape: Some((224, 224)),
        }
    }
}

impl TextLineClasResult {
    /// Creates a new empty text line classification result
    ///
    /// Initializes a new instance of the text line classification result with empty vectors
    /// for all fields.
    ///
    /// # Returns
    ///
    /// A new instance of `TextLineClasResult` with empty vectors
    pub fn new() -> Self {
        Self {
            input_path: Vec::new(),
            index: Vec::new(),
            input_img: Vec::new(),
            class_ids: Vec::new(),
            scores: Vec::new(),
            label_names: Vec::new(),
        }
    }
}

impl Default for TextLineClasResult {
    /// Creates a new empty text line classification result
    ///
    /// This is equivalent to calling `TextLineClasResult::new()`.
    ///
    /// # Returns
    ///
    /// A new instance of `TextLineClasResult` with empty vectors
    fn default() -> Self {
        Self::new()
    }
}

/// Text line classifier built from modular components
///
/// This is a type alias over `ModularPredictor` with concrete, composable components
/// to eliminate duplicated `StandardPredictor` implementations.
pub type TextLineClasPredictor =
    ModularPredictor<TLImageReader, TLPreprocessor, OrtInfer2D, TLPostprocessor>;

/// Per-call configuration type shared by the text line classification components (currently empty).
#[derive(Debug, Clone)]
pub struct TextLineClasConfig;

/// Image reader for the text line classification pipeline, wrapping `DefaultImageReader`.
#[derive(Debug)]
pub struct TLImageReader {
    inner: DefaultImageReader,
}
impl TLImageReader {
    /// Creates a new `TLImageReader`.
    ///
    /// Wraps the `DefaultImageReader` and is used in the text line classification
    /// pipeline to load images from paths into `RgbImage` values expected by the
    /// preprocessor.
    ///
    /// Returns a reader ready to be plugged into the modular `TextLineClasPredictor`.
    pub fn new() -> Self {
        Self {
            inner: DefaultImageReader::new(),
        }
    }
}
impl Default for TLImageReader {
    /// Creates a `TLImageReader` with default settings.
    ///
    /// This is equivalent to calling `TLImageReader::new()`.
    fn default() -> Self {
        Self::new()
    }
}
impl GIReader for TLImageReader {
    fn read_images<'a>(
        &self,
        paths: impl Iterator<Item = &'a str>,
    ) -> Result<Vec<RgbImage>, OCRError> {
        self.inner.apply(paths)
    }
}

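/// Preprocessor for text line classification.
///
/// Resizes each image to `input_shape`, optionally applies a center crop, and normalizes
/// the batch into a `Tensor4D` for inference.
///
/// A rough construction sketch (illustrative values; `build_internal` below wires these
/// from the validated config):
///
/// ```ignore
/// let preprocessor = TLPreprocessor {
///     input_shape: (224, 224),
///     crop: None, // the builder installs a center crop of the same size
///     normalize: NormalizeImage::new(
///         Some(1.0 / 255.0),
///         Some(vec![0.485, 0.456, 0.406]), // ImageNet mean
///         Some(vec![0.229, 0.224, 0.225]), // ImageNet std
///         None,
///     )
///     .expect("valid normalization parameters"),
/// };
/// ```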
#[derive(Debug)]
pub struct TLPreprocessor {
    pub input_shape: (u32, u32),
    pub crop: Option<Crop>,
    pub normalize: NormalizeImage,
}
impl GPreprocessor for TLPreprocessor {
    type Config = TextLineClasConfig;
    type Output = Tensor4D;
    fn preprocess(
        &self,
        images: Vec<RgbImage>,
        _config: Option<&Self::Config>,
    ) -> Result<Self::Output, OCRError> {
        use crate::utils::resize_images_batch;
        let (width, height) = self.input_shape;
        let mut batch_imgs = resize_images_batch(&images, width, height, None);
        if let Some(crop_op) = &self.crop {
            batch_imgs = crop_op.process_batch(&batch_imgs).map_err(|e| {
                OCRError::post_processing("Crop operation failed during text classification", e)
            })?;
        }
        let imgs_dynamic: Vec<DynamicImage> = batch_imgs
            .iter()
            .map(|img| DynamicImage::ImageRgb8(img.clone()))
            .collect();
        self.normalize.normalize_batch_to(imgs_dynamic)
    }
    fn preprocessing_info(&self) -> String {
        format!(
            "resize_to=({},{}) + crop? + normalize",
            self.input_shape.0, self.input_shape.1
        )
    }
}

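/// Postprocessor that turns the raw per-class scores into a `TextLineClasResult`.
///
/// Each row of the `Tensor2D` output is ranked by the `Topk` operator and mapped to the
/// orientation labels. For example, assuming a row of scores `[0.9, 0.1]` over the labels
/// `"0"` and `"180"` with `topk = 2`, the kept entries would be class IDs `[0, 1]`, scores
/// `[0.9, 0.1]`, and label names `["0", "180"]`.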
#[derive(Debug)]
pub struct TLPostprocessor {
    pub topk: usize,
    pub topk_op: Topk,
}
impl GPostprocessor for TLPostprocessor {
    type Config = TextLineClasConfig;
    type InferenceOutput = Tensor2D;
    type PreprocessOutput = Tensor4D;
    type Result = TextLineClasResult;
    fn postprocess(
        &self,
        output: Self::InferenceOutput,
        _pre: Option<&Self::PreprocessOutput>,
        batch_data: &BatchData,
        raw_images: Vec<RgbImage>,
        _config: Option<&Self::Config>,
    ) -> crate::core::OcrResult<Self::Result> {
        let predictions: Vec<Vec<f32>> = output.outer_iter().map(|row| row.to_vec()).collect();
        let topk_result = self
            .topk_op
            .process(&predictions, self.topk)
            .map_err(|e| OCRError::ConfigError { message: e })?;
        Ok(TextLineClasResult {
            input_path: batch_data.input_paths.clone(),
            index: batch_data.indexes.clone(),
            input_img: raw_images.into_iter().map(Arc::new).collect(),
            class_ids: topk_result.indexes,
            scores: topk_result.scores,
            label_names: topk_result
                .class_names
                .unwrap_or_default()
                .into_iter()
                .map(|names| names.into_iter().map(Arc::from).collect())
                .collect(),
        })
    }
    fn empty_result(&self) -> crate::core::OcrResult<Self::Result> {
        Ok(TextLineClasResult::new())
    }
}

/// Builder for the text line classifier
///
/// This struct provides a builder pattern for creating a text line classifier
/// with custom configuration options.
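///
/// # Example
///
/// A usage sketch (the model path is a placeholder):
///
/// ```ignore
/// use std::path::Path;
///
/// let predictor = TextLineClasPredictorBuilder::new()
///     .topk(2)
///     .input_shape((224, 224))
///     .build(Path::new("path/to/textline_orientation.onnx"))
///     .expect("failed to build the text line classifier");
/// ```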
pub struct TextLineClasPredictorBuilder {
    /// Common configuration options shared across predictors
    common: CommonBuilderConfig,

    /// Number of top predictions to return for each image
    topk: Option<usize>,
    /// Input shape for the model (width, height)
    input_shape: Option<(u32, u32)>,
}

impl TextLineClasPredictorBuilder {
    /// Creates a new text line classifier builder
    ///
    /// Initializes a new instance of the text line classifier builder
    /// with default configuration options.
    ///
    /// # Returns
    ///
    /// A new instance of `TextLineClasPredictorBuilder`
    pub fn new() -> Self {
        Self {
            common: CommonBuilderConfig::new(),
            topk: None,
            input_shape: None,
        }
    }

    // Inject common builder methods
    common_builder_methods!(common);

    /// Sets the number of top predictions to return
    ///
    /// Specifies how many of the top predictions to return for each image.
    ///
    /// # Arguments
    ///
    /// * `topk` - Number of top predictions to return
    ///
    /// # Returns
    ///
    /// The updated builder instance
    pub fn topk(mut self, topk: usize) -> Self {
        self.topk = Some(topk);
        self
    }

    /// Sets the input shape for the model
    ///
    /// Specifies the input shape (width, height) that the model expects.
    ///
    /// # Arguments
    ///
    /// * `input_shape` - Input shape as (width, height)
    ///
    /// # Returns
    ///
    /// The updated builder instance
    pub fn input_shape(mut self, input_shape: (u32, u32)) -> Self {
        self.input_shape = Some(input_shape);
        self
    }

    /// Builds the text line classifier
    ///
    /// Creates a new instance of the text line classifier with the
    /// configured options.
    ///
    /// # Arguments
    ///
    /// * `model_path` - Path to the ONNX model file
    ///
    /// # Returns
    ///
    /// A new instance of `TextLineClasPredictor` or an error if building fails
    pub fn build(self, model_path: &Path) -> crate::core::OcrResult<TextLineClasPredictor> {
        self.build_internal(model_path)
    }

    /// Internal method to build the text line classifier
    ///
    /// Creates a new instance of the text line classifier with the
    /// configured options. This method handles validation of the configuration
    /// and initialization of the classifier.
    ///
    /// # Arguments
    ///
    /// * `model_path` - Path to the ONNX model file
    ///
    /// # Returns
    ///
    /// A new instance of `TextLineClasPredictor` or an error if building fails
    fn build_internal(
        mut self,
        model_path: &Path,
    ) -> crate::core::OcrResult<TextLineClasPredictor> {
        if self.common.model_path.is_none() {
            self.common = self.common.model_path(model_path.to_path_buf());
        }

        let config = TextLineClasPredictorConfig {
            common: self.common,
            topk: self.topk,
            input_shape: self.input_shape,
        };

        let config = config.validate_and_wrap_ocr_error()?;

        let input_shape = config.input_shape.unwrap_or((224, 224));
        let (width, height) = input_shape;
        let crop = Some(
            Crop::new([width, height], crate::processors::CropMode::Center).map_err(|e| {
                OCRError::ConfigError {
                    message: format!("Failed to create crop operation: {e}"),
                }
            })?,
        );
        let normalize = NormalizeImage::new(
            Some(1.0 / 255.0),
            Some(vec![0.485, 0.456, 0.406]),
            Some(vec![0.229, 0.224, 0.225]),
            None,
        )?;
        let preprocessor = TLPreprocessor {
            input_shape,
            crop,
            normalize,
        };
        let infer_inner = OrtInfer::from_common(&config.common, model_path, None)?;
        let inference_engine = OrtInfer2D::new(infer_inner);
        let postprocessor = TLPostprocessor {
            topk: config.topk.unwrap_or(2),
            topk_op: Topk::from_class_names(get_text_line_orientation_labels()),
        };
        let image_reader = TLImageReader::new();
        Ok(ModularPredictor::new(
            image_reader,
            preprocessor,
            inference_engine,
            postprocessor,
        ))
    }
}

impl Default for TextLineClasPredictorBuilder {
    /// Creates a new text line classifier builder with default settings
    ///
    /// This is equivalent to calling `TextLineClasPredictorBuilder::new()`.
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_text_line_clas_config_defaults_and_validate() {
        let config = TextLineClasPredictorConfig::new();
        // Defaults from the impl_config_new_and_with_common! invocation
        assert_eq!(config.topk, None); // macro default; ConfigValidator::get_defaults() uses Some(2)
        assert_eq!(config.input_shape, Some((224, 224)));
        assert_eq!(config.common.model_name.as_deref(), Some("PP-LCNet_x0_25"));
        assert_eq!(config.common.batch_size, Some(1));
        // Validation should pass with the defaults
        assert!(config.validate().is_ok());
    }
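
    // A small sanity test built only from items defined in this file: the empty-result
    // constructor and the builder's setter methods.
    #[test]
    fn test_empty_result_and_builder_setters() {
        let result = TextLineClasResult::default();
        assert!(result.input_path.is_empty());
        assert!(result.class_ids.is_empty());
        assert!(result.scores.is_empty());
        assert!(result.label_names.is_empty());

        let builder = TextLineClasPredictorBuilder::new()
            .topk(3)
            .input_shape((160, 160));
        assert_eq!(builder.topk, Some(3));
        assert_eq!(builder.input_shape, Some((160, 160)));
    }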
}