oar_ocr/predictor/
doc_orientation_classifier.rs

1//! Document Orientation Classifier
2//!
3//! This module provides functionality for classifying the orientation of documents in images.
4//! It can detect if a document is rotated and by how much (0°, 90°, 180°, or 270°).
5//!
6//! The classifier uses a pre-trained model to analyze images and determine their orientation.
7//! It supports batch processing for efficient handling of multiple images.
8
9use crate::core::traits::ImageReader as CoreImageReader;
10use crate::core::{
11    BatchData, CommonBuilderConfig, DefaultImageReader, OCRError, OrtInfer, Tensor2D, Tensor4D,
12    config::{ConfigValidator, ConfigValidatorExt},
13};
14
15use crate::impl_config_new_and_with_common;
16
17use crate::core::get_document_orientation_labels;
18use crate::processors::{NormalizeImage, Topk};
19use image::RgbImage;
20use std::path::Path;
21use std::sync::Arc;
22
/// Results from document orientation classification
///
/// This struct contains the results of classifying document orientations in images.
/// For each image, it provides the predicted orientations along with confidence scores.
/// The per-image prediction fields (`class_ids`, `scores`, `label_names`) are filled
/// from the top-k output of the postprocessor.
#[derive(Debug, Clone)]
pub struct DocOrientationResult {
    /// Paths to the input images (cloned from the batch metadata)
    pub input_path: Vec<Arc<str>>,
    /// Indexes of the images in the batch (cloned from the batch metadata)
    pub index: Vec<usize>,
    /// The input images, as passed to the postprocessor
    pub input_img: Vec<Arc<RgbImage>>,
    /// Predicted class IDs for each image (sorted by confidence)
    pub class_ids: Vec<Vec<usize>>,
    /// Confidence scores for each prediction
    pub scores: Vec<Vec<f32>>,
    /// Label names for each prediction (e.g., "0", "90", "180", "270").
    ///
    /// May be empty if the top-k step produced no class names.
    pub label_names: Vec<Vec<Arc<str>>>,
}
42
/// Configuration for the document orientation classifier
///
/// This struct holds configuration parameters for the document orientation classifier.
/// It includes common configuration options as well as classifier-specific parameters.
///
/// Optional fields left as `None` are skipped during validation. The derived `Default`
/// leaves `topk` and `input_shape` as `None`. The generated `new()` constructor fills
/// them with `Some(4)` and `Some((224, 224))`.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct DocOrientationClassifierConfig {
    /// Common configuration options shared across predictors
    pub common: CommonBuilderConfig,
    /// Number of top predictions to return for each image (validated as positive when set)
    pub topk: Option<usize>,
    /// Input shape for the model (width, height)
    pub input_shape: Option<(u32, u32)>,
}
56
// Generates the `new()` constructor for the config with these defaults:
// model name "doc_orientation_classifier", batch size 1, topk 4, and a
// (224, 224) input shape. Per the macro name it presumably also generates a
// `with_common`-style constructor; see the macro definition to confirm.
impl_config_new_and_with_common!(
    DocOrientationClassifierConfig,
    common_defaults: (Some("doc_orientation_classifier".to_string()), Some(1)),
    fields: {
        topk: Some(4),
        input_shape: Some((224, 224))
    }
);
65
66impl DocOrientationClassifierConfig {
67    /// Validates the document orientation classifier configuration
68    ///
69    /// Checks that all configuration parameters are valid and within acceptable ranges.
70    ///
71    /// # Returns
72    ///
73    /// Ok if the configuration is valid, or an error if validation fails
74    pub fn validate(&self) -> Result<(), crate::core::ConfigError> {
75        ConfigValidator::validate(self)
76    }
77}
78
79impl ConfigValidator for DocOrientationClassifierConfig {
80    /// Validates the document orientation classifier configuration
81    ///
82    /// Checks that all configuration parameters are valid and within acceptable ranges.
83    /// This includes validating the common configuration, topk value, and input shape.
84    ///
85    /// # Returns
86    ///
87    /// Ok if the configuration is valid, or an error if validation fails
88    fn validate(&self) -> Result<(), crate::core::ConfigError> {
89        self.common.validate()?;
90
91        if let Some(topk) = self.topk {
92            self.validate_positive_usize(topk, "topk")?;
93        }
94
95        if let Some((width, height)) = self.input_shape {
96            self.validate_image_dimensions(width, height)?;
97        }
98
99        Ok(())
100    }
101
102    /// Gets the default document orientation classifier configuration
103    ///
104    /// Returns a new instance of the document orientation classifier configuration
105    /// with default values for all parameters.
106    ///
107    /// # Returns
108    ///
109    /// A new instance of `DocOrientationClassifierConfig` with default settings
110    fn get_defaults() -> Self {
111        Self {
112            common: CommonBuilderConfig::get_defaults(),
113            topk: Some(4),
114            input_shape: Some((224, 224)),
115        }
116    }
117}
118
119impl DocOrientationResult {
120    /// Creates a new empty document orientation result
121    ///
122    /// Initializes a new instance of the document orientation result with empty vectors
123    /// for all fields.
124    ///
125    /// # Returns
126    ///
127    /// A new instance of `DocOrientationResult` with empty vectors
128    pub fn new() -> Self {
129        Self {
130            input_path: Vec::new(),
131            index: Vec::new(),
132            input_img: Vec::new(),
133            class_ids: Vec::new(),
134            scores: Vec::new(),
135            label_names: Vec::new(),
136        }
137    }
138}
139
140impl Default for DocOrientationResult {
141    /// Creates a new empty document orientation result
142    ///
143    /// This is equivalent to calling `DocOrientationResult::new()`.
144    ///
145    /// # Returns
146    ///
147    /// A new instance of `DocOrientationResult` with empty vectors
148    fn default() -> Self {
149        Self::new()
150    }
151}
152
/// Document orientation classifier built from modular components
///
/// This is a type alias over `ModularPredictor` with concrete, composable components
/// to eliminate duplicated StandardPredictor implementations across predictors.
/// The pipeline is:
/// - [`DocOrImageReader`] loads the images.
/// - [`DocOrPreprocessor`] resizes and normalizes them.
/// - `OrtInfer2D` runs the ONNX model.
/// - [`DocOrPostprocessor`] selects the top-k orientations.
///
/// Construct it with `DocOrientationClassifierBuilder`.
pub type DocOrientationClassifier =
    ModularPredictor<DocOrImageReader, DocOrPreprocessor, OrtInfer2D, DocOrPostprocessor>;
159
160// Granular trait adapters for the document orientation classifier
161use crate::core::{
162    GranularImageReader as GIReader, ModularPredictor, OrtInfer2D, Postprocessor as GPostprocessor,
163    Preprocessor as GPreprocessor,
164};
165use image::DynamicImage;
166
/// Per-call configuration type for the document orientation pre/postprocessors.
///
/// This is a field-less placeholder. The preprocessor and postprocessor in this
/// module ignore their optional per-call config argument and take all settings
/// from their own fields.
#[derive(Debug, Clone)]
pub struct DocOrientationConfig;
169
/// Image reader stage for the document orientation classifier.
///
/// Delegates all loading to the wrapped `DefaultImageReader`.
#[derive(Debug)]
pub struct DocOrImageReader {
    /// Underlying reader that performs the actual image loading
    inner: DefaultImageReader,
}
174
175impl DocOrImageReader {
176    pub fn new() -> Self {
177        Self {
178            inner: DefaultImageReader::new(),
179        }
180    }
181}
182
183impl Default for DocOrImageReader {
184    fn default() -> Self {
185        Self::new()
186    }
187}
188
189impl GIReader for DocOrImageReader {
190    fn read_images<'a>(
191        &self,
192        paths: impl Iterator<Item = &'a str>,
193    ) -> Result<Vec<RgbImage>, OCRError> {
194        self.inner.apply(paths)
195    }
196}
197
/// Preprocessing stage: resizes each image to a fixed size and normalizes the batch.
#[derive(Debug)]
pub struct DocOrPreprocessor {
    /// Target size every image is resized to. Passed to the resize helper as
    /// `(input_shape.0, input_shape.1)`, which the classifier config documents as (width, height).
    pub input_shape: (u32, u32),
    /// Normalization applied to the resized batch. The builder configures it with a
    /// 1/255 scale and per-channel mean/std values.
    pub normalize: NormalizeImage,
}
203
204impl GPreprocessor for DocOrPreprocessor {
205    type Config = DocOrientationConfig;
206    type Output = Tensor4D;
207
208    fn preprocess(
209        &self,
210        images: Vec<RgbImage>,
211        _config: Option<&Self::Config>,
212    ) -> Result<Self::Output, OCRError> {
213        use crate::utils::resize_images_batch_to_dynamic;
214        let dynamic_images: Vec<DynamicImage> =
215            resize_images_batch_to_dynamic(&images, self.input_shape.0, self.input_shape.1, None);
216        self.normalize.normalize_batch_to(dynamic_images)
217    }
218
219    fn preprocessing_info(&self) -> String {
220        format!(
221            "resize_to=({},{}) + normalize",
222            self.input_shape.0, self.input_shape.1
223        )
224    }
225}
226
/// Postprocessing stage: turns the model's per-image output rows into top-k predictions.
#[derive(Debug)]
pub struct DocOrPostprocessor {
    /// Number of top predictions kept per image
    pub topk: usize,
    /// Top-k selector; the builder creates it with the document orientation labels
    pub topk_op: Topk,
}
232
233impl GPostprocessor for DocOrPostprocessor {
234    type Config = DocOrientationConfig;
235    type InferenceOutput = Tensor2D;
236    type PreprocessOutput = Tensor4D;
237    type Result = DocOrientationResult;
238
239    fn postprocess(
240        &self,
241        output: Self::InferenceOutput,
242        _preprocess_output: Option<&Self::PreprocessOutput>,
243        batch_data: &BatchData,
244        raw_images: Vec<RgbImage>,
245        _config: Option<&Self::Config>,
246    ) -> crate::core::OcrResult<Self::Result> {
247        // Convert ndarray output to Vec<Vec<f32>> format expected by Topk
248        let predictions: Vec<Vec<f32>> = output.outer_iter().map(|row| row.to_vec()).collect();
249        let topk_result = self
250            .topk_op
251            .process(&predictions, self.topk)
252            .map_err(|e| OCRError::ConfigError { message: e })?;
253
254        Ok(DocOrientationResult {
255            input_path: batch_data.input_paths.clone(),
256            index: batch_data.indexes.clone(),
257            input_img: raw_images.into_iter().map(Arc::new).collect(),
258            class_ids: topk_result.indexes,
259            scores: topk_result.scores,
260            label_names: topk_result
261                .class_names
262                .unwrap_or_default()
263                .into_iter()
264                .map(|names| names.into_iter().map(Arc::from).collect())
265                .collect(),
266        })
267    }
268
269    fn empty_result(&self) -> Result<Self::Result, OCRError> {
270        Ok(DocOrientationResult::new())
271    }
272}
273
274/// Builder for document orientation classifier
275///
276/// This struct provides a builder pattern for creating a document orientation classifier
277/// with custom configuration options.
278pub struct DocOrientationClassifierBuilder {
279    /// Common configuration options shared across predictors
280    common: CommonBuilderConfig,
281
282    /// Number of top predictions to return for each image
283    topk: Option<usize>,
284    /// Input shape for the model (width, height)
285    input_shape: Option<(u32, u32)>,
286}
287
288impl DocOrientationClassifierBuilder {
289    /// Creates a new document orientation classifier builder
290    ///
291    /// Initializes a new instance of the document orientation classifier builder
292    /// with default configuration options.
293    ///
294    /// # Returns
295    ///
296    /// A new instance of `DocOrientationClassifierBuilder`
297    pub fn new() -> Self {
298        Self {
299            common: CommonBuilderConfig::new(),
300            topk: None,
301            input_shape: None,
302        }
303    }
304
305    /// Sets the model path for the classifier
306    ///
307    /// Specifies the path to the ONNX model file that will be used for inference.
308    ///
309    /// # Arguments
310    ///
311    /// * `model_path` - Path to the ONNX model file
312    ///
313    /// # Returns
314    ///
315    /// The updated builder instance
316    pub fn model_path(mut self, model_path: impl Into<std::path::PathBuf>) -> Self {
317        self.common = self.common.model_path(model_path);
318        self
319    }
320
321    /// Sets the model name for the classifier
322    ///
323    /// Specifies the name of the model being used.
324    ///
325    /// # Arguments
326    ///
327    /// * `model_name` - Name of the model
328    ///
329    /// # Returns
330    ///
331    /// The updated builder instance
332    pub fn model_name(mut self, model_name: impl Into<String>) -> Self {
333        self.common = self.common.model_name(model_name);
334        self
335    }
336
337    /// Sets the batch size for the classifier
338    ///
339    /// Specifies the number of images to process in each batch.
340    ///
341    /// # Arguments
342    ///
343    /// * `batch_size` - Number of images to process in each batch
344    ///
345    /// # Returns
346    ///
347    /// The updated builder instance
348    pub fn batch_size(mut self, batch_size: usize) -> Self {
349        self.common = self.common.batch_size(batch_size);
350        self
351    }
352
353    /// Enables or disables logging for the classifier
354    ///
355    /// Controls whether logging is enabled during classification.
356    ///
357    /// # Arguments
358    ///
359    /// * `enable` - Whether to enable logging
360    ///
361    /// # Returns
362    ///
363    /// The updated builder instance
364    pub fn enable_logging(mut self, enable: bool) -> Self {
365        self.common = self.common.enable_logging(enable);
366        self
367    }
368
369    /// Sets the ONNX Runtime session configuration
370    ///
371    /// This function sets the ONNX Runtime session configuration for the predictor.
372    pub fn ort_session(mut self, config: crate::core::config::onnx::OrtSessionConfig) -> Self {
373        self.common = self.common.ort_session(config);
374        self
375    }
376
377    /// Sets the session pool size for concurrent predictions
378    ///
379    /// This function sets the size of the session pool used for concurrent predictions.
380    /// The pool size must be >= 1.
381    ///
382    /// # Arguments
383    ///
384    /// * `size` - The session pool size (minimum 1)
385    ///
386    /// # Returns
387    ///
388    /// The updated builder instance
389    pub fn session_pool_size(mut self, size: usize) -> Self {
390        self.common = self.common.session_pool_size(size);
391        self
392    }
393
394    /// Sets the number of top predictions to return
395    ///
396    /// Specifies how many of the top predictions to return for each image.
397    ///
398    /// # Arguments
399    ///
400    /// * `topk` - Number of top predictions to return
401    ///
402    /// # Returns
403    ///
404    /// The updated builder instance
405    pub fn topk(mut self, topk: usize) -> Self {
406        self.topk = Some(topk);
407        self
408    }
409
410    /// Sets the input shape for the model
411    ///
412    /// Specifies the input shape (width, height) that the model expects.
413    ///
414    /// # Arguments
415    ///
416    /// * `input_shape` - Input shape as (width, height)
417    ///
418    /// # Returns
419    ///
420    /// The updated builder instance
421    pub fn input_shape(mut self, input_shape: (u32, u32)) -> Self {
422        self.input_shape = Some(input_shape);
423        self
424    }
425
426    /// Builds the document orientation classifier
427    ///
428    /// Creates a new instance of the document orientation classifier with the
429    /// configured options.
430    ///
431    /// # Arguments
432    ///
433    /// * `model_path` - Path to the ONNX model file
434    ///
435    /// # Returns
436    ///
437    /// A new instance of `DocOrientationClassifier` or an error if building fails
438    pub fn build(self, model_path: &Path) -> Result<DocOrientationClassifier, OCRError> {
439        self.build_internal(model_path)
440    }
441
442    /// Internal method to build the document orientation classifier
443    ///
444    /// Creates a new instance of the document orientation classifier with the
445    /// configured options. This method handles validation of the configuration
446    /// and initialization of the classifier.
447    ///
448    /// # Arguments
449    ///
450    /// * `model_path` - Path to the ONNX model file
451    ///
452    /// # Returns
453    ///
454    /// A new instance of `DocOrientationClassifier` or an error if building fails
455    fn build_internal(mut self, model_path: &Path) -> Result<DocOrientationClassifier, OCRError> {
456        if self.common.model_path.is_none() {
457            self.common = self.common.model_path(model_path.to_path_buf());
458        }
459
460        let config = DocOrientationClassifierConfig {
461            common: self.common,
462            topk: self.topk,
463            input_shape: self.input_shape,
464        };
465
466        let config = config.validate_and_wrap_ocr_error()?;
467
468        // Build modular components
469        let input_shape = config.input_shape.unwrap_or((224, 224));
470        let image_reader = DocOrImageReader::new();
471        let normalize = NormalizeImage::new(
472            Some(1.0 / 255.0),
473            Some(vec![0.485, 0.456, 0.406]),
474            Some(vec![0.229, 0.224, 0.225]),
475            None,
476        )?;
477        let preprocessor = DocOrPreprocessor {
478            input_shape,
479            normalize,
480        };
481        let infer_inner = OrtInfer::from_common(&config.common, model_path, None)?;
482        let inference_engine = OrtInfer2D::new(infer_inner);
483        let postprocessor = DocOrPostprocessor {
484            topk: config.topk.unwrap_or(4),
485            topk_op: Topk::from_class_names(get_document_orientation_labels()),
486        };
487
488        Ok(ModularPredictor::new(
489            image_reader,
490            preprocessor,
491            inference_engine,
492            postprocessor,
493        ))
494    }
495}
496
497impl Default for DocOrientationClassifierBuilder {
498    /// Creates a new document orientation classifier builder with default settings
499    ///
500    /// This is equivalent to calling `DocOrientationClassifierBuilder::new()`.
501    ///
502    /// # Returns
503    ///
504    /// A new instance of `DocOrientationClassifierBuilder` with default settings
505    fn default() -> Self {
506        Self::new()
507    }
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// `new()` must apply the macro defaults, and the result must pass validation.
    #[test]
    fn test_doc_orientation_config_defaults_and_validate() {
        let cfg = DocOrientationClassifierConfig::new();

        assert_eq!(cfg.topk, Some(4));
        assert_eq!(cfg.input_shape, Some((224, 224)));

        let name = cfg.common.model_name.as_deref();
        assert_eq!(name, Some("doc_orientation_classifier"));
        assert_eq!(cfg.common.batch_size, Some(1));

        let outcome = cfg.validate();
        assert!(outcome.is_ok());
    }
}