oar-ocr-core 0.6.3

Core types and predictors for oar-ocr
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
//! PP-LCNet Classification Model
//!
//! This module provides a pure implementation of the PP-LCNet model for image classification.
//! PP-LCNet is a lightweight classification network that can be used for various classification
//! tasks such as document orientation and text line orientation.

use crate::core::OCRError;
use crate::core::inference::{OrtInfer, TensorInput};
use crate::domain::adapters::preprocessing::rgb_to_dynamic;
use crate::processors::{NormalizeImage, TensorLayout};
use crate::utils::topk::Topk;
use image::{RgbImage, imageops::FilterType};

/// Configuration for PP-LCNet model preprocessing.
///
/// Controls the resize/crop geometry and the normalization constants applied
/// to each input image before it is packed into the batch tensor.
#[derive(Debug, Clone)]
pub struct PPLCNetPreprocessConfig {
    /// Model input shape as (height, width)
    pub input_shape: (u32, u32),
    /// Resizing filter to use (e.g. `FilterType::Triangle` for bilinear)
    pub resize_filter: FilterType,
    /// When set, resize by short edge to this size (keep ratio) then center-crop to `input_shape`.
    ///
    /// This matches `ResizeImage(resize_short=256)` + `CropImage(size=224)` used by
    /// most PP-LCNet classifiers (e.g. `PP-LCNet_x1_0_doc_ori`, `PP-LCNet_x1_0_table_cls`).
    ///
    /// Some specialized PP-LCNet classifiers (e.g. `PP-LCNet_x1_0_textline_ori`) use a direct
    /// resize to a fixed `(height,width)` without short-edge resize/crop; set this to `None`
    /// to match `ResizeImage(size=[w,h])`.
    pub resize_short: Option<u32>,
    /// Scaling factor applied to pixel values before mean/std normalization
    /// (defaults to 1.0 / 255.0, mapping u8 pixels into [0, 1])
    pub normalize_scale: f32,
    /// Per-channel mean values subtracted during normalization
    pub normalize_mean: Vec<f32>,
    /// Per-channel standard deviation values divided during normalization
    pub normalize_std: Vec<f32>,
    /// Tensor data layout of the output batch (CHW or HWC)
    pub tensor_layout: TensorLayout,
}

impl Default for PPLCNetPreprocessConfig {
    /// Defaults mirror the standard PP-LCNet classifier pipeline:
    /// resize the short edge to 256, center-crop to 224x224, then apply
    /// ImageNet normalization and emit a CHW tensor.
    fn default() -> Self {
        // ImageNet statistics used by the official PP-LCNet classifiers.
        let imagenet_mean = vec![0.485, 0.456, 0.406];
        let imagenet_std = vec![0.229, 0.224, 0.225];

        Self {
            input_shape: (224, 224),
            // Triangle approximates cv2.INTER_LINEAR used by PP-LCNet resize.
            resize_filter: FilterType::Triangle,
            // resize_short=256 followed by a 224x224 center crop.
            resize_short: Some(256),
            normalize_scale: 1.0 / 255.0,
            normalize_mean: imagenet_mean,
            normalize_std: imagenet_std,
            tensor_layout: TensorLayout::CHW,
        }
    }
}

/// Configuration for PP-LCNet model postprocessing.
#[derive(Debug, Clone)]
pub struct PPLCNetPostprocessConfig {
    /// Class labels indexed by class ID; leave empty to fall back to the
    /// top-k processor's own class names (if any)
    pub labels: Vec<String>,
    /// Number of top predictions to return per image
    pub topk: usize,
}

impl Default for PPLCNetPostprocessConfig {
    /// No label mapping and a single (top-1) prediction per image.
    fn default() -> Self {
        Self {
            labels: Vec::new(),
            topk: 1,
        }
    }
}

/// Output from PP-LCNet model.
///
/// The outer `Vec` of each field has one entry per input image; the inner
/// `Vec` holds that image's top-k predictions.
#[derive(Debug, Clone)]
pub struct PPLCNetModelOutput {
    /// Predicted class IDs per image
    pub class_ids: Vec<Vec<usize>>,
    /// Confidence scores, parallel to `class_ids`
    pub scores: Vec<Vec<f32>>,
    /// Label names for each prediction (if labels were provided, or if the
    /// top-k processor supplied class names)
    pub label_names: Option<Vec<Vec<String>>>,
}

/// Pure PP-LCNet model implementation.
///
/// This model performs image classification using the PP-LCNet architecture.
/// It owns the full pipeline: preprocessing (resize/crop + normalize),
/// ONNX Runtime inference, and top-k postprocessing.
#[derive(Debug)]
pub struct PPLCNetModel {
    /// ONNX Runtime inference engine
    inference: OrtInfer,
    /// Image normalizer for preprocessing
    normalizer: NormalizeImage,
    /// Top-k processor for postprocessing
    topk_processor: Topk,
    /// Model input shape (height, width)
    input_shape: (u32, u32),
    /// Resizing filter used for all image resizes
    resize_filter: FilterType,
    /// Optional short-edge resize (see `PPLCNetPreprocessConfig::resize_short`)
    resize_short: Option<u32>,
}

impl PPLCNetModel {
    /// Creates a new PP-LCNet model.
    pub fn new(
        inference: OrtInfer,
        normalizer: NormalizeImage,
        topk_processor: Topk,
        input_shape: (u32, u32),
        resize_filter: FilterType,
        resize_short: Option<u32>,
    ) -> Self {
        Self {
            inference,
            normalizer,
            topk_processor,
            input_shape,
            resize_filter,
            resize_short,
        }
    }

    /// Preprocesses images for classification.
    ///
    /// # Arguments
    ///
    /// * `images` - Input images to preprocess
    ///
    /// # Returns
    ///
    /// Preprocessed batch tensor
    pub fn preprocess(&self, images: Vec<RgbImage>) -> Result<ndarray::Array4<f32>, OCRError> {
        let (crop_h, crop_w) = self.input_shape;

        let resized_rgb: Vec<RgbImage> = if let Some(resize_short) = self.resize_short {
            // PP-LCNet classifier preprocessing:
            // 1) Resize by short edge (keep ratio)
            // 2) Center crop to the model input size
            //
            // This matches `ResizeImage(resize_short=256)` + `CropImage(size=224)` used by
            // models like `PP-LCNet_x1_0_doc_ori` / `PP-LCNet_x1_0_table_cls`.
            images
                .into_iter()
                .filter_map(|img| {
                    let (w, h) = (img.width(), img.height());
                    if w == 0 || h == 0 {
                        return None;
                    }

                    let short = w.min(h) as f32;
                    let scale = (resize_short as f32) / short;
                    let new_w = ((w as f32) * scale).round().max(crop_w as f32) as u32;
                    let new_h = ((h as f32) * scale).round().max(crop_h as f32) as u32;

                    let resized = image::imageops::resize(&img, new_w, new_h, self.resize_filter);

                    // Center crop to (crop_w, crop_h)
                    let x1 = (new_w.saturating_sub(crop_w)) / 2;
                    let y1 = (new_h.saturating_sub(crop_h)) / 2;
                    let cropped =
                        image::imageops::crop_imm(&resized, x1, y1, crop_w, crop_h).to_image();
                    Some(cropped)
                })
                .collect()
        } else {
            // Direct resize to input shape (height,width) without crop.
            // This matches `ResizeImage(size=[w,h])` used by
            // `PP-LCNet_x1_0_textline_ori` (80x160).
            images
                .into_iter()
                .filter_map(|img| {
                    let (w, h) = (img.width(), img.height());
                    if w == 0 || h == 0 {
                        return None;
                    }
                    Some(image::imageops::resize(
                        &img,
                        crop_w,
                        crop_h,
                        self.resize_filter,
                    ))
                })
                .collect()
        };

        // Convert to dynamic images and normalize using common helper
        let dynamic_images = rgb_to_dynamic(resized_rgb);
        self.normalizer.normalize_batch_to(dynamic_images)
    }

    /// Runs inference on the preprocessed batch.
    ///
    /// # Arguments
    ///
    /// * `batch_tensor` - Preprocessed batch tensor
    ///
    /// # Returns
    ///
    /// Model predictions as a 2D tensor (batch_size x num_classes)
    pub fn infer(
        &self,
        batch_tensor: &ndarray::Array4<f32>,
    ) -> Result<ndarray::Array2<f32>, OCRError> {
        let input_name = self.inference.input_name();
        let inputs = vec![(input_name, TensorInput::Array4(batch_tensor))];

        let outputs = self
            .inference
            .infer(&inputs)
            .map_err(|e| OCRError::Inference {
                model_name: "PP-LCNet".to_string(),
                context: format!(
                    "failed to run inference on batch with shape {:?}",
                    batch_tensor.shape()
                ),
                source: Box::new(e),
            })?;

        let output = outputs
            .into_iter()
            .next()
            .ok_or_else(|| OCRError::InvalidInput {
                message: "PP-LCNet: no output returned from inference".to_string(),
            })?;

        output
            .1
            .try_into_array2_f32()
            .map_err(|e| OCRError::Inference {
                model_name: "PP-LCNet".to_string(),
                context: "failed to convert output to 2D array".to_string(),
                source: Box::new(e),
            })
    }

    /// Postprocesses model predictions to class IDs and scores.
    ///
    /// # Arguments
    ///
    /// * `predictions` - Model predictions (batch_size x num_classes)
    /// * `config` - Postprocessing configuration
    ///
    /// # Returns
    ///
    /// PPLCNetModelOutput containing class IDs, scores, and optional label names
    pub fn postprocess(
        &self,
        predictions: &ndarray::Array2<f32>,
        config: &PPLCNetPostprocessConfig,
    ) -> Result<PPLCNetModelOutput, OCRError> {
        let predictions_vec: Vec<Vec<f32>> =
            predictions.outer_iter().map(|row| row.to_vec()).collect();

        let topk_result = self
            .topk_processor
            .process(&predictions_vec, config.topk)
            .unwrap_or_else(|_| crate::utils::topk::TopkResult {
                indexes: vec![],
                scores: vec![],
                class_names: None,
            });

        let class_ids = topk_result.indexes;
        let scores = topk_result.scores;

        // Map class IDs to label names if labels are provided
        let label_names = if !config.labels.is_empty() {
            Some(
                class_ids
                    .iter()
                    .map(|ids| {
                        ids.iter()
                            .map(|&id| {
                                config
                                    .labels
                                    .get(id)
                                    .cloned()
                                    .unwrap_or_else(|| format!("class_{}", id))
                            })
                            .collect()
                    })
                    .collect(),
            )
        } else {
            topk_result.class_names
        };

        Ok(PPLCNetModelOutput {
            class_ids,
            scores,
            label_names,
        })
    }

    /// Performs complete forward pass: preprocess -> infer -> postprocess.
    ///
    /// # Arguments
    ///
    /// * `images` - Input images to classify
    /// * `config` - Postprocessing configuration
    ///
    /// # Returns
    ///
    /// PPLCNetModelOutput containing classification results
    pub fn forward(
        &self,
        images: Vec<RgbImage>,
        config: &PPLCNetPostprocessConfig,
    ) -> Result<PPLCNetModelOutput, OCRError> {
        let batch_tensor = self.preprocess(images)?;
        let predictions = self.infer(&batch_tensor)?;
        self.postprocess(&predictions, config)
    }
}

/// Builder for PP-LCNet model.
///
/// Collects preprocessing and ONNX Runtime session options, then constructs
/// a `PPLCNetModel` via `build`.
#[derive(Debug, Default)]
pub struct PPLCNetModelBuilder {
    /// Preprocessing configuration
    preprocess_config: PPLCNetPreprocessConfig,
    /// ONNX Runtime session configuration (None = engine defaults)
    ort_config: Option<crate::core::config::OrtSessionConfig>,
}

impl PPLCNetModelBuilder {
    /// Creates a new PP-LCNet model builder.
    pub fn new() -> Self {
        Self {
            preprocess_config: PPLCNetPreprocessConfig::default(),
            ort_config: None,
        }
    }

    /// Sets the preprocessing configuration.
    pub fn preprocess_config(mut self, config: PPLCNetPreprocessConfig) -> Self {
        self.preprocess_config = config;
        self
    }

    /// Sets the input image shape.
    pub fn input_shape(mut self, shape: (u32, u32)) -> Self {
        self.preprocess_config.input_shape = shape;
        self
    }

    /// Sets the resizing filter.
    pub fn resize_filter(mut self, filter: FilterType) -> Self {
        self.preprocess_config.resize_filter = filter;
        self
    }

    /// Sets the ONNX Runtime session configuration.
    pub fn with_ort_config(mut self, config: crate::core::config::OrtSessionConfig) -> Self {
        self.ort_config = Some(config);
        self
    }

    /// Builds the PP-LCNet model.
    ///
    /// # Arguments
    ///
    /// * `model_path` - Path to the ONNX model file
    ///
    /// # Returns
    ///
    /// A configured PP-LCNet model instance
    pub fn build(self, model_path: &std::path::Path) -> Result<PPLCNetModel, OCRError> {
        // Create ONNX inference engine
        let inference = if self.ort_config.is_some() {
            use crate::core::config::ModelInferenceConfig;
            let common_config = ModelInferenceConfig {
                ort_session: self.ort_config,
                ..Default::default()
            };
            OrtInfer::from_config(&common_config, model_path, None)?
        } else {
            OrtInfer::new(model_path, None)?
        };

        // Create normalizer (ImageNet normalization).
        //
        // PP-LCNet classifiers read images as **RGB** by default (no `DecodeImage`
        // op in the official inference.yml), so we keep RGB order here.
        let mean = self.preprocess_config.normalize_mean.clone();
        let std = self.preprocess_config.normalize_std.clone();
        let normalizer = NormalizeImage::with_color_order(
            Some(self.preprocess_config.normalize_scale),
            Some(mean),
            Some(std),
            Some(self.preprocess_config.tensor_layout),
            Some(crate::processors::types::ColorOrder::RGB),
        )?;

        // Create top-k processor
        let topk_processor = Topk::new(None);

        Ok(PPLCNetModel::new(
            inference,
            normalizer,
            topk_processor,
            self.preprocess_config.input_shape,
            self.preprocess_config.resize_filter,
            self.preprocess_config.resize_short,
        ))
    }
}