// oar_ocr_core/models/classification/pp_lcnet.rs
//! PP-LCNet Classification Model
//!
//! This module provides a pure implementation of the PP-LCNet model for image classification.
//! PP-LCNet is a lightweight classification network that can be used for various classification
//! tasks such as document orientation and text line orientation.

7use crate::core::OCRError;
8use crate::core::inference::{OrtInfer, TensorInput};
9use crate::domain::adapters::preprocessing::rgb_to_dynamic;
10use crate::processors::{NormalizeImage, TensorLayout};
11use crate::utils::topk::Topk;
12use image::{RgbImage, imageops::FilterType};
13
/// Configuration for PP-LCNet model preprocessing.
#[derive(Debug, Clone)]
pub struct PPLCNetPreprocessConfig {
    /// Model input shape as (height, width) in pixels.
    pub input_shape: (u32, u32),
    /// Resampling filter used for all resize operations.
    pub resize_filter: FilterType,
    /// When set, resize by short edge to this size (keep ratio) then center-crop to `input_shape`.
    ///
    /// This matches `ResizeImage(resize_short=256)` + `CropImage(size=224)` used by
    /// most PP-LCNet classifiers (e.g. `PP-LCNet_x1_0_doc_ori`, `PP-LCNet_x1_0_table_cls`).
    ///
    /// Some specialized PP-LCNet classifiers (e.g. `PP-LCNet_x1_0_textline_ori`) use a direct
    /// resize to a fixed `(height,width)` without short-edge resize/crop; set this to `None`
    /// to match `ResizeImage(size=[w,h])`.
    pub resize_short: Option<u32>,
    /// Scaling factor applied to pixel values before mean/std normalization
    /// (defaults to 1.0 / 255.0, mapping u8 pixels into [0, 1]).
    pub normalize_scale: f32,
    /// Per-channel mean values subtracted during normalization.
    pub normalize_mean: Vec<f32>,
    /// Per-channel standard deviation values divided out during normalization.
    pub normalize_std: Vec<f32>,
    /// Tensor data layout of the produced batch (CHW or HWC).
    pub tensor_layout: TensorLayout,
}
39
40impl Default for PPLCNetPreprocessConfig {
41    fn default() -> Self {
42        Self {
43            input_shape: (224, 224),
44            // Use cv2.INTER_LINEAR for PP-LCNet resize.
45            resize_filter: FilterType::Triangle,
46            // PP-LCNet classifiers default to resize_short=256 then center-crop.
47            resize_short: Some(256),
48            normalize_scale: 1.0 / 255.0,
49            normalize_mean: vec![0.485, 0.456, 0.406],
50            normalize_std: vec![0.229, 0.224, 0.225],
51            tensor_layout: TensorLayout::CHW,
52        }
53    }
54}
55
/// Configuration for PP-LCNet model postprocessing.
#[derive(Debug, Clone)]
pub struct PPLCNetPostprocessConfig {
    /// Human-readable class labels, indexed by class ID. Leave empty to skip
    /// label-name mapping in the output.
    pub labels: Vec<String>,
    /// Number of top predictions to return per image.
    pub topk: usize,
}
64
65impl Default for PPLCNetPostprocessConfig {
66    fn default() -> Self {
67        Self {
68            labels: vec![],
69            topk: 1,
70        }
71    }
72}
73
/// Output from PP-LCNet model.
///
/// The outer `Vec`s are indexed by image in the input batch; the inner
/// `Vec`s hold that image's top-k predictions.
#[derive(Debug, Clone)]
pub struct PPLCNetModelOutput {
    /// Predicted class IDs per image.
    pub class_ids: Vec<Vec<usize>>,
    /// Confidence scores for each prediction, parallel to `class_ids`.
    pub scores: Vec<Vec<f32>>,
    /// Label names for each prediction (present when labels were provided,
    /// or when the top-k processor supplied class names).
    pub label_names: Option<Vec<Vec<String>>>,
}
84
/// Pure PP-LCNet model implementation.
///
/// This model performs image classification using the PP-LCNet architecture.
/// It bundles the ONNX session together with the pre/post-processing state
/// needed for the full `preprocess -> infer -> postprocess` pipeline.
#[derive(Debug)]
pub struct PPLCNetModel {
    /// ONNX Runtime inference engine
    inference: OrtInfer,
    /// Image normalizer for preprocessing (scale, mean/std, tensor layout)
    normalizer: NormalizeImage,
    /// Top-k processor for postprocessing
    topk_processor: Topk,
    /// Model input shape (height, width)
    input_shape: (u32, u32),
    /// Resampling filter used when resizing input images
    resize_filter: FilterType,
    /// Optional short-edge resize (see `PPLCNetPreprocessConfig::resize_short`)
    resize_short: Option<u32>,
}
103
104impl PPLCNetModel {
105    /// Creates a new PP-LCNet model.
106    pub fn new(
107        inference: OrtInfer,
108        normalizer: NormalizeImage,
109        topk_processor: Topk,
110        input_shape: (u32, u32),
111        resize_filter: FilterType,
112        resize_short: Option<u32>,
113    ) -> Self {
114        Self {
115            inference,
116            normalizer,
117            topk_processor,
118            input_shape,
119            resize_filter,
120            resize_short,
121        }
122    }
123
124    /// Preprocesses images for classification.
125    ///
126    /// # Arguments
127    ///
128    /// * `images` - Input images to preprocess
129    ///
130    /// # Returns
131    ///
132    /// Preprocessed batch tensor
133    pub fn preprocess(&self, images: Vec<RgbImage>) -> Result<ndarray::Array4<f32>, OCRError> {
134        let (crop_h, crop_w) = self.input_shape;
135
136        let resized_rgb: Vec<RgbImage> = if let Some(resize_short) = self.resize_short {
137            // PP-LCNet classifier preprocessing:
138            // 1) Resize by short edge (keep ratio)
139            // 2) Center crop to the model input size
140            //
141            // This matches `ResizeImage(resize_short=256)` + `CropImage(size=224)` used by
142            // models like `PP-LCNet_x1_0_doc_ori` / `PP-LCNet_x1_0_table_cls`.
143            images
144                .into_iter()
145                .filter_map(|img| {
146                    let (w, h) = (img.width(), img.height());
147                    if w == 0 || h == 0 {
148                        return None;
149                    }
150
151                    let short = w.min(h) as f32;
152                    let scale = (resize_short as f32) / short;
153                    let new_w = ((w as f32) * scale).round().max(crop_w as f32) as u32;
154                    let new_h = ((h as f32) * scale).round().max(crop_h as f32) as u32;
155
156                    let resized = image::imageops::resize(&img, new_w, new_h, self.resize_filter);
157
158                    // Center crop to (crop_w, crop_h)
159                    let x1 = (new_w.saturating_sub(crop_w)) / 2;
160                    let y1 = (new_h.saturating_sub(crop_h)) / 2;
161                    let cropped =
162                        image::imageops::crop_imm(&resized, x1, y1, crop_w, crop_h).to_image();
163                    Some(cropped)
164                })
165                .collect()
166        } else {
167            // Direct resize to input shape (height,width) without crop.
168            // This matches `ResizeImage(size=[w,h])` used by
169            // `PP-LCNet_x1_0_textline_ori` (80x160).
170            images
171                .into_iter()
172                .filter_map(|img| {
173                    let (w, h) = (img.width(), img.height());
174                    if w == 0 || h == 0 {
175                        return None;
176                    }
177                    Some(image::imageops::resize(
178                        &img,
179                        crop_w,
180                        crop_h,
181                        self.resize_filter,
182                    ))
183                })
184                .collect()
185        };
186
187        // Convert to dynamic images and normalize using common helper
188        let dynamic_images = rgb_to_dynamic(resized_rgb);
189        self.normalizer.normalize_batch_to(dynamic_images)
190    }
191
192    /// Runs inference on the preprocessed batch.
193    ///
194    /// # Arguments
195    ///
196    /// * `batch_tensor` - Preprocessed batch tensor
197    ///
198    /// # Returns
199    ///
200    /// Model predictions as a 2D tensor (batch_size x num_classes)
201    pub fn infer(
202        &self,
203        batch_tensor: &ndarray::Array4<f32>,
204    ) -> Result<ndarray::Array2<f32>, OCRError> {
205        let input_name = self.inference.input_name();
206        let inputs = vec![(input_name, TensorInput::Array4(batch_tensor))];
207
208        let outputs = self
209            .inference
210            .infer(&inputs)
211            .map_err(|e| OCRError::Inference {
212                model_name: "PP-LCNet".to_string(),
213                context: format!(
214                    "failed to run inference on batch with shape {:?}",
215                    batch_tensor.shape()
216                ),
217                source: Box::new(e),
218            })?;
219
220        let output = outputs
221            .into_iter()
222            .next()
223            .ok_or_else(|| OCRError::InvalidInput {
224                message: "PP-LCNet: no output returned from inference".to_string(),
225            })?;
226
227        output
228            .1
229            .try_into_array2_f32()
230            .map_err(|e| OCRError::Inference {
231                model_name: "PP-LCNet".to_string(),
232                context: "failed to convert output to 2D array".to_string(),
233                source: Box::new(e),
234            })
235    }
236
237    /// Postprocesses model predictions to class IDs and scores.
238    ///
239    /// # Arguments
240    ///
241    /// * `predictions` - Model predictions (batch_size x num_classes)
242    /// * `config` - Postprocessing configuration
243    ///
244    /// # Returns
245    ///
246    /// PPLCNetModelOutput containing class IDs, scores, and optional label names
247    pub fn postprocess(
248        &self,
249        predictions: &ndarray::Array2<f32>,
250        config: &PPLCNetPostprocessConfig,
251    ) -> Result<PPLCNetModelOutput, OCRError> {
252        let predictions_vec: Vec<Vec<f32>> =
253            predictions.outer_iter().map(|row| row.to_vec()).collect();
254
255        let topk_result = self
256            .topk_processor
257            .process(&predictions_vec, config.topk)
258            .unwrap_or_else(|_| crate::utils::topk::TopkResult {
259                indexes: vec![],
260                scores: vec![],
261                class_names: None,
262            });
263
264        let class_ids = topk_result.indexes;
265        let scores = topk_result.scores;
266
267        // Map class IDs to label names if labels are provided
268        let label_names = if !config.labels.is_empty() {
269            Some(
270                class_ids
271                    .iter()
272                    .map(|ids| {
273                        ids.iter()
274                            .map(|&id| {
275                                config
276                                    .labels
277                                    .get(id)
278                                    .cloned()
279                                    .unwrap_or_else(|| format!("class_{}", id))
280                            })
281                            .collect()
282                    })
283                    .collect(),
284            )
285        } else {
286            topk_result.class_names
287        };
288
289        Ok(PPLCNetModelOutput {
290            class_ids,
291            scores,
292            label_names,
293        })
294    }
295
296    /// Performs complete forward pass: preprocess -> infer -> postprocess.
297    ///
298    /// # Arguments
299    ///
300    /// * `images` - Input images to classify
301    /// * `config` - Postprocessing configuration
302    ///
303    /// # Returns
304    ///
305    /// PPLCNetModelOutput containing classification results
306    pub fn forward(
307        &self,
308        images: Vec<RgbImage>,
309        config: &PPLCNetPostprocessConfig,
310    ) -> Result<PPLCNetModelOutput, OCRError> {
311        let batch_tensor = self.preprocess(images)?;
312        let predictions = self.infer(&batch_tensor)?;
313        self.postprocess(&predictions, config)
314    }
315}
316
/// Builder for PP-LCNet model.
///
/// `Default` produces the same state as `PPLCNetModelBuilder::new()`:
/// default preprocessing and no explicit ONNX Runtime session configuration.
#[derive(Debug, Default)]
pub struct PPLCNetModelBuilder {
    /// Preprocessing configuration applied when building the model
    preprocess_config: PPLCNetPreprocessConfig,
    /// Optional ONNX Runtime session configuration
    ort_config: Option<crate::core::config::OrtSessionConfig>,
}
325
326impl PPLCNetModelBuilder {
327    /// Creates a new PP-LCNet model builder.
328    pub fn new() -> Self {
329        Self {
330            preprocess_config: PPLCNetPreprocessConfig::default(),
331            ort_config: None,
332        }
333    }
334
335    /// Sets the preprocessing configuration.
336    pub fn preprocess_config(mut self, config: PPLCNetPreprocessConfig) -> Self {
337        self.preprocess_config = config;
338        self
339    }
340
341    /// Sets the input image shape.
342    pub fn input_shape(mut self, shape: (u32, u32)) -> Self {
343        self.preprocess_config.input_shape = shape;
344        self
345    }
346
347    /// Sets the resizing filter.
348    pub fn resize_filter(mut self, filter: FilterType) -> Self {
349        self.preprocess_config.resize_filter = filter;
350        self
351    }
352
353    /// Sets the ONNX Runtime session configuration.
354    pub fn with_ort_config(mut self, config: crate::core::config::OrtSessionConfig) -> Self {
355        self.ort_config = Some(config);
356        self
357    }
358
359    /// Builds the PP-LCNet model.
360    ///
361    /// # Arguments
362    ///
363    /// * `model_path` - Path to the ONNX model file
364    ///
365    /// # Returns
366    ///
367    /// A configured PP-LCNet model instance
368    pub fn build(self, model_path: &std::path::Path) -> Result<PPLCNetModel, OCRError> {
369        // Create ONNX inference engine
370        let inference = if self.ort_config.is_some() {
371            use crate::core::config::ModelInferenceConfig;
372            let common_config = ModelInferenceConfig {
373                ort_session: self.ort_config,
374                ..Default::default()
375            };
376            OrtInfer::from_config(&common_config, model_path, None)?
377        } else {
378            OrtInfer::new(model_path, None)?
379        };
380
381        // Create normalizer (ImageNet normalization).
382        //
383        // PP-LCNet classifiers read images as **RGB** by default (no `DecodeImage`
384        // op in the official inference.yml), so we keep RGB order here.
385        let mean = self.preprocess_config.normalize_mean.clone();
386        let std = self.preprocess_config.normalize_std.clone();
387        let normalizer = NormalizeImage::with_color_order(
388            Some(self.preprocess_config.normalize_scale),
389            Some(mean),
390            Some(std),
391            Some(self.preprocess_config.tensor_layout),
392            Some(crate::processors::types::ColorOrder::RGB),
393        )?;
394
395        // Create top-k processor
396        let topk_processor = Topk::new(None);
397
398        Ok(PPLCNetModel::new(
399            inference,
400            normalizer,
401            topk_processor,
402            self.preprocess_config.input_shape,
403            self.preprocess_config.resize_filter,
404            self.preprocess_config.resize_short,
405        ))
406    }
407}