oar_ocr/predictor/doc_orientation_classifier.rs
1//! Document Orientation Classifier
2//!
3//! This module provides functionality for classifying the orientation of documents in images.
4//! It can detect if a document is rotated and by how much (0°, 90°, 180°, or 270°).
5//!
6//! The classifier uses a pre-trained model to analyze images and determine their orientation.
7//! It supports batch processing for efficient handling of multiple images.
8
9use crate::core::traits::ImageReader as CoreImageReader;
10use crate::core::{
11 BatchData, CommonBuilderConfig, DefaultImageReader, OCRError, OrtInfer, Tensor2D, Tensor4D,
12 config::{ConfigValidator, ConfigValidatorExt},
13};
14
15use crate::impl_config_new_and_with_common;
16
17use crate::core::get_document_orientation_labels;
18use crate::processors::{NormalizeImage, Topk};
19use image::RgbImage;
20use std::path::Path;
21use std::sync::Arc;
22
23/// Results from document orientation classification
24///
25/// This struct contains the results of classifying document orientations in images.
26/// For each image, it provides the predicted orientations along with confidence scores.
27#[derive(Debug, Clone)]
28pub struct DocOrientationResult {
29 /// Paths to the input images
30 pub input_path: Vec<Arc<str>>,
31 /// Indexes of the images in the batch
32 pub index: Vec<usize>,
33 /// The input images
34 pub input_img: Vec<Arc<RgbImage>>,
35 /// Predicted class IDs for each image (sorted by confidence)
36 pub class_ids: Vec<Vec<usize>>,
37 /// Confidence scores for each prediction
38 pub scores: Vec<Vec<f32>>,
39 /// Label names for each prediction (e.g., "0", "90", "180", "270")
40 pub label_names: Vec<Vec<Arc<str>>>,
41}
42
43/// Configuration for the document orientation classifier
44///
45/// This struct holds configuration parameters for the document orientation classifier.
46/// It includes common configuration options as well as classifier-specific parameters.
47#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
48pub struct DocOrientationClassifierConfig {
49 /// Common configuration options shared across predictors
50 pub common: CommonBuilderConfig,
51 /// Number of top predictions to return for each image
52 pub topk: Option<usize>,
53 /// Input shape for the model (width, height)
54 pub input_shape: Option<(u32, u32)>,
55}
56
// Generates the default-filling constructors for `DocOrientationClassifierConfig`.
// The common section defaults to the model name "doc_orientation_classifier" and a
// batch size of 1. The orientation-specific fields default to `topk = 4` and a
// 224x224 input shape.
// NOTE(review): the exact set of generated methods (e.g. `new`, `with_common`) is
// defined by the macro elsewhere in the crate — confirm there.
impl_config_new_and_with_common!(
    DocOrientationClassifierConfig,
    common_defaults: (Some("doc_orientation_classifier".to_string()), Some(1)),
    fields: {
        topk: Some(4),
        input_shape: Some((224, 224))
    }
);
65
66impl DocOrientationClassifierConfig {
67 /// Validates the document orientation classifier configuration
68 ///
69 /// Checks that all configuration parameters are valid and within acceptable ranges.
70 ///
71 /// # Returns
72 ///
73 /// Ok if the configuration is valid, or an error if validation fails
74 pub fn validate(&self) -> Result<(), crate::core::ConfigError> {
75 ConfigValidator::validate(self)
76 }
77}
78
79impl ConfigValidator for DocOrientationClassifierConfig {
80 /// Validates the document orientation classifier configuration
81 ///
82 /// Checks that all configuration parameters are valid and within acceptable ranges.
83 /// This includes validating the common configuration, topk value, and input shape.
84 ///
85 /// # Returns
86 ///
87 /// Ok if the configuration is valid, or an error if validation fails
88 fn validate(&self) -> Result<(), crate::core::ConfigError> {
89 self.common.validate()?;
90
91 if let Some(topk) = self.topk {
92 self.validate_positive_usize(topk, "topk")?;
93 }
94
95 if let Some((width, height)) = self.input_shape {
96 self.validate_image_dimensions(width, height)?;
97 }
98
99 Ok(())
100 }
101
102 /// Gets the default document orientation classifier configuration
103 ///
104 /// Returns a new instance of the document orientation classifier configuration
105 /// with default values for all parameters.
106 ///
107 /// # Returns
108 ///
109 /// A new instance of `DocOrientationClassifierConfig` with default settings
110 fn get_defaults() -> Self {
111 Self {
112 common: CommonBuilderConfig::get_defaults(),
113 topk: Some(4),
114 input_shape: Some((224, 224)),
115 }
116 }
117}
118
119impl DocOrientationResult {
120 /// Creates a new empty document orientation result
121 ///
122 /// Initializes a new instance of the document orientation result with empty vectors
123 /// for all fields.
124 ///
125 /// # Returns
126 ///
127 /// A new instance of `DocOrientationResult` with empty vectors
128 pub fn new() -> Self {
129 Self {
130 input_path: Vec::new(),
131 index: Vec::new(),
132 input_img: Vec::new(),
133 class_ids: Vec::new(),
134 scores: Vec::new(),
135 label_names: Vec::new(),
136 }
137 }
138}
139
140impl Default for DocOrientationResult {
141 /// Creates a new empty document orientation result
142 ///
143 /// This is equivalent to calling `DocOrientationResult::new()`.
144 ///
145 /// # Returns
146 ///
147 /// A new instance of `DocOrientationResult` with empty vectors
148 fn default() -> Self {
149 Self::new()
150 }
151}
152
153/// Document orientation classifier built from modular components
154///
155/// This is a type alias over `ModularPredictor` with concrete, composable components
156/// to eliminate duplicated StandardPredictor implementations across predictors.
157pub type DocOrientationClassifier =
158 ModularPredictor<DocOrImageReader, DocOrPreprocessor, OrtInfer2D, DocOrPostprocessor>;
159
160// Granular trait adapters for the document orientation classifier
161use crate::core::{
162 GranularImageReader as GIReader, ModularPredictor, OrtInfer2D, Postprocessor as GPostprocessor,
163 Preprocessor as GPreprocessor,
164};
165use image::DynamicImage;
166
/// Data-less configuration marker for the orientation pipeline stages.
///
/// The preprocessor and postprocessor declare it as their `Config` type but
/// ignore whatever value they receive.
#[derive(Clone, Debug)]
pub struct DocOrientationConfig;
169
170#[derive(Debug)]
171pub struct DocOrImageReader {
172 inner: DefaultImageReader,
173}
174
175impl DocOrImageReader {
176 pub fn new() -> Self {
177 Self {
178 inner: DefaultImageReader::new(),
179 }
180 }
181}
182
183impl Default for DocOrImageReader {
184 fn default() -> Self {
185 Self::new()
186 }
187}
188
189impl GIReader for DocOrImageReader {
190 fn read_images<'a>(
191 &self,
192 paths: impl Iterator<Item = &'a str>,
193 ) -> Result<Vec<RgbImage>, OCRError> {
194 self.inner.apply(paths)
195 }
196}
197
198#[derive(Debug)]
199pub struct DocOrPreprocessor {
200 pub input_shape: (u32, u32),
201 pub normalize: NormalizeImage,
202}
203
204impl GPreprocessor for DocOrPreprocessor {
205 type Config = DocOrientationConfig;
206 type Output = Tensor4D;
207
208 fn preprocess(
209 &self,
210 images: Vec<RgbImage>,
211 _config: Option<&Self::Config>,
212 ) -> Result<Self::Output, OCRError> {
213 use crate::utils::resize_images_batch_to_dynamic;
214 let dynamic_images: Vec<DynamicImage> =
215 resize_images_batch_to_dynamic(&images, self.input_shape.0, self.input_shape.1, None);
216 self.normalize.normalize_batch_to(dynamic_images)
217 }
218
219 fn preprocessing_info(&self) -> String {
220 format!(
221 "resize_to=({},{}) + normalize",
222 self.input_shape.0, self.input_shape.1
223 )
224 }
225}
226
227#[derive(Debug)]
228pub struct DocOrPostprocessor {
229 pub topk: usize,
230 pub topk_op: Topk,
231}
232
233impl GPostprocessor for DocOrPostprocessor {
234 type Config = DocOrientationConfig;
235 type InferenceOutput = Tensor2D;
236 type PreprocessOutput = Tensor4D;
237 type Result = DocOrientationResult;
238
239 fn postprocess(
240 &self,
241 output: Self::InferenceOutput,
242 _preprocess_output: Option<&Self::PreprocessOutput>,
243 batch_data: &BatchData,
244 raw_images: Vec<RgbImage>,
245 _config: Option<&Self::Config>,
246 ) -> crate::core::OcrResult<Self::Result> {
247 // Convert ndarray output to Vec<Vec<f32>> format expected by Topk
248 let predictions: Vec<Vec<f32>> = output.outer_iter().map(|row| row.to_vec()).collect();
249 let topk_result = self
250 .topk_op
251 .process(&predictions, self.topk)
252 .map_err(|e| OCRError::ConfigError { message: e })?;
253
254 Ok(DocOrientationResult {
255 input_path: batch_data.input_paths.clone(),
256 index: batch_data.indexes.clone(),
257 input_img: raw_images.into_iter().map(Arc::new).collect(),
258 class_ids: topk_result.indexes,
259 scores: topk_result.scores,
260 label_names: topk_result
261 .class_names
262 .unwrap_or_default()
263 .into_iter()
264 .map(|names| names.into_iter().map(Arc::from).collect())
265 .collect(),
266 })
267 }
268
269 fn empty_result(&self) -> Result<Self::Result, OCRError> {
270 Ok(DocOrientationResult::new())
271 }
272}
273
274/// Builder for document orientation classifier
275///
276/// This struct provides a builder pattern for creating a document orientation classifier
277/// with custom configuration options.
278pub struct DocOrientationClassifierBuilder {
279 /// Common configuration options shared across predictors
280 common: CommonBuilderConfig,
281
282 /// Number of top predictions to return for each image
283 topk: Option<usize>,
284 /// Input shape for the model (width, height)
285 input_shape: Option<(u32, u32)>,
286}
287
288impl DocOrientationClassifierBuilder {
289 /// Creates a new document orientation classifier builder
290 ///
291 /// Initializes a new instance of the document orientation classifier builder
292 /// with default configuration options.
293 ///
294 /// # Returns
295 ///
296 /// A new instance of `DocOrientationClassifierBuilder`
297 pub fn new() -> Self {
298 Self {
299 common: CommonBuilderConfig::new(),
300 topk: None,
301 input_shape: None,
302 }
303 }
304
305 /// Sets the model path for the classifier
306 ///
307 /// Specifies the path to the ONNX model file that will be used for inference.
308 ///
309 /// # Arguments
310 ///
311 /// * `model_path` - Path to the ONNX model file
312 ///
313 /// # Returns
314 ///
315 /// The updated builder instance
316 pub fn model_path(mut self, model_path: impl Into<std::path::PathBuf>) -> Self {
317 self.common = self.common.model_path(model_path);
318 self
319 }
320
321 /// Sets the model name for the classifier
322 ///
323 /// Specifies the name of the model being used.
324 ///
325 /// # Arguments
326 ///
327 /// * `model_name` - Name of the model
328 ///
329 /// # Returns
330 ///
331 /// The updated builder instance
332 pub fn model_name(mut self, model_name: impl Into<String>) -> Self {
333 self.common = self.common.model_name(model_name);
334 self
335 }
336
337 /// Sets the batch size for the classifier
338 ///
339 /// Specifies the number of images to process in each batch.
340 ///
341 /// # Arguments
342 ///
343 /// * `batch_size` - Number of images to process in each batch
344 ///
345 /// # Returns
346 ///
347 /// The updated builder instance
348 pub fn batch_size(mut self, batch_size: usize) -> Self {
349 self.common = self.common.batch_size(batch_size);
350 self
351 }
352
353 /// Enables or disables logging for the classifier
354 ///
355 /// Controls whether logging is enabled during classification.
356 ///
357 /// # Arguments
358 ///
359 /// * `enable` - Whether to enable logging
360 ///
361 /// # Returns
362 ///
363 /// The updated builder instance
364 pub fn enable_logging(mut self, enable: bool) -> Self {
365 self.common = self.common.enable_logging(enable);
366 self
367 }
368
369 /// Sets the ONNX Runtime session configuration
370 ///
371 /// This function sets the ONNX Runtime session configuration for the predictor.
372 pub fn ort_session(mut self, config: crate::core::config::onnx::OrtSessionConfig) -> Self {
373 self.common = self.common.ort_session(config);
374 self
375 }
376
377 /// Sets the session pool size for concurrent predictions
378 ///
379 /// This function sets the size of the session pool used for concurrent predictions.
380 /// The pool size must be >= 1.
381 ///
382 /// # Arguments
383 ///
384 /// * `size` - The session pool size (minimum 1)
385 ///
386 /// # Returns
387 ///
388 /// The updated builder instance
389 pub fn session_pool_size(mut self, size: usize) -> Self {
390 self.common = self.common.session_pool_size(size);
391 self
392 }
393
394 /// Sets the number of top predictions to return
395 ///
396 /// Specifies how many of the top predictions to return for each image.
397 ///
398 /// # Arguments
399 ///
400 /// * `topk` - Number of top predictions to return
401 ///
402 /// # Returns
403 ///
404 /// The updated builder instance
405 pub fn topk(mut self, topk: usize) -> Self {
406 self.topk = Some(topk);
407 self
408 }
409
410 /// Sets the input shape for the model
411 ///
412 /// Specifies the input shape (width, height) that the model expects.
413 ///
414 /// # Arguments
415 ///
416 /// * `input_shape` - Input shape as (width, height)
417 ///
418 /// # Returns
419 ///
420 /// The updated builder instance
421 pub fn input_shape(mut self, input_shape: (u32, u32)) -> Self {
422 self.input_shape = Some(input_shape);
423 self
424 }
425
426 /// Builds the document orientation classifier
427 ///
428 /// Creates a new instance of the document orientation classifier with the
429 /// configured options.
430 ///
431 /// # Arguments
432 ///
433 /// * `model_path` - Path to the ONNX model file
434 ///
435 /// # Returns
436 ///
437 /// A new instance of `DocOrientationClassifier` or an error if building fails
438 pub fn build(self, model_path: &Path) -> Result<DocOrientationClassifier, OCRError> {
439 self.build_internal(model_path)
440 }
441
442 /// Internal method to build the document orientation classifier
443 ///
444 /// Creates a new instance of the document orientation classifier with the
445 /// configured options. This method handles validation of the configuration
446 /// and initialization of the classifier.
447 ///
448 /// # Arguments
449 ///
450 /// * `model_path` - Path to the ONNX model file
451 ///
452 /// # Returns
453 ///
454 /// A new instance of `DocOrientationClassifier` or an error if building fails
455 fn build_internal(mut self, model_path: &Path) -> Result<DocOrientationClassifier, OCRError> {
456 if self.common.model_path.is_none() {
457 self.common = self.common.model_path(model_path.to_path_buf());
458 }
459
460 let config = DocOrientationClassifierConfig {
461 common: self.common,
462 topk: self.topk,
463 input_shape: self.input_shape,
464 };
465
466 let config = config.validate_and_wrap_ocr_error()?;
467
468 // Build modular components
469 let input_shape = config.input_shape.unwrap_or((224, 224));
470 let image_reader = DocOrImageReader::new();
471 let normalize = NormalizeImage::new(
472 Some(1.0 / 255.0),
473 Some(vec![0.485, 0.456, 0.406]),
474 Some(vec![0.229, 0.224, 0.225]),
475 None,
476 )?;
477 let preprocessor = DocOrPreprocessor {
478 input_shape,
479 normalize,
480 };
481 let infer_inner = OrtInfer::from_common(&config.common, model_path, None)?;
482 let inference_engine = OrtInfer2D::new(infer_inner);
483 let postprocessor = DocOrPostprocessor {
484 topk: config.topk.unwrap_or(4),
485 topk_op: Topk::from_class_names(get_document_orientation_labels()),
486 };
487
488 Ok(ModularPredictor::new(
489 image_reader,
490 preprocessor,
491 inference_engine,
492 postprocessor,
493 ))
494 }
495}
496
497impl Default for DocOrientationClassifierBuilder {
498 /// Creates a new document orientation classifier builder with default settings
499 ///
500 /// This is equivalent to calling `DocOrientationClassifierBuilder::new()`.
501 ///
502 /// # Returns
503 ///
504 /// A new instance of `DocOrientationClassifierBuilder` with default settings
505 fn default() -> Self {
506 Self::new()
507 }
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// `new()` must fill in every default and yield a configuration that validates.
    #[test]
    fn test_doc_orientation_config_defaults_and_validate() {
        let cfg = DocOrientationClassifierConfig::new();

        // Orientation-specific defaults.
        assert_eq!(cfg.input_shape, Some((224, 224)));
        assert_eq!(cfg.topk, Some(4));

        // Defaults injected into the common section.
        assert_eq!(cfg.common.batch_size, Some(1));
        assert_eq!(
            cfg.common.model_name.as_deref(),
            Some("doc_orientation_classifier")
        );

        assert!(cfg.validate().is_ok());
    }
}