
oar_ocr_core/models/classification/pp_lcnet.rs

//! PP-LCNet Classification Model
//!
//! This module provides a pure implementation of the PP-LCNet model for image classification.
//! PP-LCNet is a lightweight classification network that can be used for various classification
//! tasks such as document orientation and text line orientation.
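//!
//! A minimal end-to-end sketch (the import path, model file, and label set below are
//! assumptions for illustration, not pinned by this module):
//!
//! ```ignore
//! use oar_ocr_core::models::classification::pp_lcnet::{
//!     PPLCNetModelBuilder, PPLCNetPostprocessConfig,
//! };
//!
//! // Build with the default preprocessing (resize_short=256 + center-crop to 224x224).
//! let model = PPLCNetModelBuilder::new()
//!     .build(std::path::Path::new("PP-LCNet_x1_0_doc_ori.onnx"))?;
//!
//! // Illustrative document-orientation labels; real label sets ship with the model.
//! let config = PPLCNetPostprocessConfig {
//!     labels: vec!["0".into(), "90".into(), "180".into(), "270".into()],
//!     topk: 1,
//! };
//!
//! let images: Vec<image::RgbImage> = load_images(); // hypothetical helper
//! let output = model.forward(images, &config)?;
//! println!("{:?}", output.class_ids);
//! ```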

use crate::core::inference::OrtInfer;
use crate::core::{OCRError, Tensor2D, Tensor4D};
use crate::domain::adapters::preprocessing::rgb_to_dynamic;
use crate::processors::{NormalizeImage, TensorLayout};
use crate::utils::topk::Topk;
use image::{RgbImage, imageops::FilterType};

/// Configuration for PP-LCNet model preprocessing.
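///
/// The default matches the common classifier pipeline (short-edge resize to 256, then
/// center-crop to 224x224). A sketch of the direct-resize variant used by the text line
/// orientation model, with dimensions taken from the `resize_short` docs below:
///
/// ```ignore
/// let textline_cfg = PPLCNetPreprocessConfig {
///     // (height, width) = (80, 160), resized directly with no short-edge resize/crop.
///     input_shape: (80, 160),
///     resize_short: None,
///     ..Default::default()
/// };
/// ```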
#[derive(Debug, Clone)]
pub struct PPLCNetPreprocessConfig {
    /// Input shape (height, width)
    pub input_shape: (u32, u32),
    /// Resizing filter to use
    pub resize_filter: FilterType,
    /// When set, resize by short edge to this size (keep ratio) then center-crop to `input_shape`.
    ///
    /// This matches `ResizeImage(resize_short=256)` + `CropImage(size=224)` used by
    /// most PP-LCNet classifiers (e.g. `PP-LCNet_x1_0_doc_ori`, `PP-LCNet_x1_0_table_cls`).
    ///
    /// Some specialized PP-LCNet classifiers (e.g. `PP-LCNet_x1_0_textline_ori`) use a direct
    /// resize to a fixed `(height,width)` without short-edge resize/crop; set this to `None`
    /// to match `ResizeImage(size=[w,h])`.
    pub resize_short: Option<u32>,
    /// Scaling factor applied before normalization (defaults to 1.0 / 255.0)
    pub normalize_scale: f32,
    /// Mean values for normalization
    pub normalize_mean: Vec<f32>,
    /// Standard deviation values for normalization
    pub normalize_std: Vec<f32>,
    /// Tensor data layout (CHW or HWC)
    pub tensor_layout: TensorLayout,
}

impl Default for PPLCNetPreprocessConfig {
    fn default() -> Self {
        Self {
            input_shape: (224, 224),
            // FilterType::Triangle is bilinear, matching the cv2.INTER_LINEAR resize used by PP-LCNet.
            resize_filter: FilterType::Triangle,
            // PP-LCNet classifiers default to resize_short=256 then center-crop.
            resize_short: Some(256),
            normalize_scale: 1.0 / 255.0,
            normalize_mean: vec![0.485, 0.456, 0.406],
            normalize_std: vec![0.229, 0.224, 0.225],
            tensor_layout: TensorLayout::CHW,
        }
    }
}

/// Configuration for PP-LCNet model postprocessing.
#[derive(Debug, Clone)]
pub struct PPLCNetPostprocessConfig {
    /// Class labels
    pub labels: Vec<String>,
    /// Number of top predictions to return
    pub topk: usize,
}

impl Default for PPLCNetPostprocessConfig {
    fn default() -> Self {
        Self {
            labels: vec![],
            topk: 1,
        }
    }
}

/// Output from PP-LCNet model.
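///
/// Results are indexed per image, then per rank. A minimal sketch, assuming `output`
/// was returned by [`PPLCNetModel::forward`]:
///
/// ```ignore
/// // Top-1 prediction for the first image in the batch.
/// let best_class = output.class_ids[0][0];
/// let best_score = output.scores[0][0];
/// ```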
#[derive(Debug, Clone)]
pub struct PPLCNetModelOutput {
    /// Predicted class IDs per image
    pub class_ids: Vec<Vec<usize>>,
    /// Confidence scores for each prediction
    pub scores: Vec<Vec<f32>>,
    /// Label names for each prediction (if labels provided)
    pub label_names: Option<Vec<Vec<String>>>,
}

/// Pure PP-LCNet model implementation.
///
/// This model performs image classification using the PP-LCNet architecture.
#[derive(Debug)]
pub struct PPLCNetModel {
    /// ONNX Runtime inference engine
    inference: OrtInfer,
    /// Image normalizer for preprocessing
    normalizer: NormalizeImage,
    /// Top-k processor for postprocessing
    topk_processor: Topk,
    /// Input shape (height, width)
    input_shape: (u32, u32),
    /// Resizing filter
    resize_filter: FilterType,
    /// Optional short-edge resize (see `PPLCNetPreprocessConfig::resize_short`)
    resize_short: Option<u32>,
}

impl PPLCNetModel {
    /// Creates a new PP-LCNet model.
    pub fn new(
        inference: OrtInfer,
        normalizer: NormalizeImage,
        topk_processor: Topk,
        input_shape: (u32, u32),
        resize_filter: FilterType,
        resize_short: Option<u32>,
    ) -> Self {
        Self {
            inference,
            normalizer,
            topk_processor,
            input_shape,
            resize_filter,
            resize_short,
        }
    }

    /// Preprocesses images for classification.
    ///
    /// # Arguments
    ///
    /// * `images` - Input images to preprocess
    ///
    /// # Returns
    ///
    /// Preprocessed batch tensor
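    ///
    /// A sketch of the expected result under the default config (CHW layout, 224x224
    /// input); the exact shape ordering is an assumption based on the CHW layout:
    ///
    /// ```ignore
    /// let batch = model.preprocess(images)?;
    /// // With the defaults: [N, 3, 224, 224] (batch, channels, height, width).
    /// println!("{:?}", batch.shape());
    /// ```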
    pub fn preprocess(&self, images: Vec<RgbImage>) -> Result<Tensor4D, OCRError> {
        let (crop_h, crop_w) = self.input_shape;

        let resized_rgb: Vec<RgbImage> = if let Some(resize_short) = self.resize_short {
            // PP-LCNet classifier preprocessing:
            // 1) Resize by short edge (keep ratio)
            // 2) Center crop to the model input size
            //
            // This matches `ResizeImage(resize_short=256)` + `CropImage(size=224)` used by
            // models like `PP-LCNet_x1_0_doc_ori` / `PP-LCNet_x1_0_table_cls`.
            // E.g. a 640x480 image with resize_short=256 is resized to 341x256, then
            // center-cropped to 224x224.
            images
                .into_iter()
                .filter_map(|img| {
                    let (w, h) = (img.width(), img.height());
                    if w == 0 || h == 0 {
                        return None;
                    }

                    let short = w.min(h) as f32;
                    let scale = (resize_short as f32) / short;
                    let new_w = ((w as f32) * scale).round().max(crop_w as f32) as u32;
                    let new_h = ((h as f32) * scale).round().max(crop_h as f32) as u32;

                    let resized = image::imageops::resize(&img, new_w, new_h, self.resize_filter);

                    // Center crop to (crop_w, crop_h)
                    let x1 = (new_w.saturating_sub(crop_w)) / 2;
                    let y1 = (new_h.saturating_sub(crop_h)) / 2;
                    let cropped =
                        image::imageops::crop_imm(&resized, x1, y1, crop_w, crop_h).to_image();
                    Some(cropped)
                })
                .collect()
        } else {
            // Direct resize to input shape (height,width) without crop.
            // This matches `ResizeImage(size=[w,h])` used by
            // `PP-LCNet_x1_0_textline_ori` (80x160).
            images
                .into_iter()
                .filter_map(|img| {
                    let (w, h) = (img.width(), img.height());
                    if w == 0 || h == 0 {
                        return None;
                    }
                    Some(image::imageops::resize(
                        &img,
                        crop_w,
                        crop_h,
                        self.resize_filter,
                    ))
                })
                .collect()
        };

        // Convert to dynamic images and normalize using common helper
        let dynamic_images = rgb_to_dynamic(resized_rgb);
        self.normalizer.normalize_batch_to(dynamic_images)
    }

    /// Runs inference on the preprocessed batch.
    ///
    /// # Arguments
    ///
    /// * `batch_tensor` - Preprocessed batch tensor
    ///
    /// # Returns
    ///
    /// Model predictions as a 2D tensor (batch_size x num_classes)
    pub fn infer(&self, batch_tensor: &Tensor4D) -> Result<Tensor2D, OCRError> {
        self.inference
            .infer_2d(batch_tensor)
            .map_err(|e| OCRError::Inference {
                model_name: "PP-LCNet".to_string(),
                context: format!(
                    "failed to run inference on batch with shape {:?}",
                    batch_tensor.shape()
                ),
                source: Box::new(e),
            })
    }

    /// Postprocesses model predictions to class IDs and scores.
    ///
    /// # Arguments
    ///
    /// * `predictions` - Model predictions (batch_size x num_classes)
    /// * `config` - Postprocessing configuration
    ///
    /// # Returns
    ///
    /// PPLCNetModelOutput containing class IDs, scores, and optional label names
    pub fn postprocess(
        &self,
        predictions: &Tensor2D,
        config: &PPLCNetPostprocessConfig,
    ) -> Result<PPLCNetModelOutput, OCRError> {
        let predictions_vec: Vec<Vec<f32>> =
            predictions.outer_iter().map(|row| row.to_vec()).collect();

        // If the top-k computation fails, fall back to an empty result rather than
        // propagating the error.
        let topk_result = self
            .topk_processor
            .process(&predictions_vec, config.topk)
            .unwrap_or_else(|_| crate::utils::topk::TopkResult {
                indexes: vec![],
                scores: vec![],
                class_names: None,
            });

        let class_ids = topk_result.indexes;
        let scores = topk_result.scores;

        // Map class IDs to label names if labels are provided
        let label_names = if !config.labels.is_empty() {
            Some(
                class_ids
                    .iter()
                    .map(|ids| {
                        ids.iter()
                            .map(|&id| {
                                config
                                    .labels
                                    .get(id)
                                    .cloned()
                                    .unwrap_or_else(|| format!("class_{}", id))
                            })
                            .collect()
                    })
                    .collect(),
            )
        } else {
            topk_result.class_names
        };

        Ok(PPLCNetModelOutput {
            class_ids,
            scores,
            label_names,
        })
    }

    /// Performs complete forward pass: preprocess -> infer -> postprocess.
    ///
    /// # Arguments
    ///
    /// * `images` - Input images to classify
    /// * `config` - Postprocessing configuration
    ///
    /// # Returns
    ///
    /// PPLCNetModelOutput containing classification results
    pub fn forward(
        &self,
        images: Vec<RgbImage>,
        config: &PPLCNetPostprocessConfig,
    ) -> Result<PPLCNetModelOutput, OCRError> {
        let batch_tensor = self.preprocess(images)?;
        let predictions = self.infer(&batch_tensor)?;
        self.postprocess(&predictions, config)
    }
}

/// Builder for PP-LCNet model.
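///
/// There is no dedicated setter for `resize_short`; pass a full preprocessing
/// configuration via [`PPLCNetModelBuilder::preprocess_config`]. A sketch for the text
/// line orientation model (the model path is an assumption):
///
/// ```ignore
/// let model = PPLCNetModelBuilder::new()
///     .preprocess_config(PPLCNetPreprocessConfig {
///         input_shape: (80, 160),
///         resize_short: None,
///         ..Default::default()
///     })
///     .build(std::path::Path::new("PP-LCNet_x1_0_textline_ori.onnx"))?;
/// ```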
#[derive(Debug, Default)]
pub struct PPLCNetModelBuilder {
    /// Preprocessing configuration
    preprocess_config: PPLCNetPreprocessConfig,
    /// ONNX Runtime session configuration
    ort_config: Option<crate::core::config::OrtSessionConfig>,
}

impl PPLCNetModelBuilder {
    /// Creates a new PP-LCNet model builder.
    pub fn new() -> Self {
        Self {
            preprocess_config: PPLCNetPreprocessConfig::default(),
            ort_config: None,
        }
    }

    /// Sets the preprocessing configuration.
    pub fn preprocess_config(mut self, config: PPLCNetPreprocessConfig) -> Self {
        self.preprocess_config = config;
        self
    }

    /// Sets the input image shape.
    pub fn input_shape(mut self, shape: (u32, u32)) -> Self {
        self.preprocess_config.input_shape = shape;
        self
    }

    /// Sets the resizing filter.
    pub fn resize_filter(mut self, filter: FilterType) -> Self {
        self.preprocess_config.resize_filter = filter;
        self
    }

    /// Sets the ONNX Runtime session configuration.
    pub fn with_ort_config(mut self, config: crate::core::config::OrtSessionConfig) -> Self {
        self.ort_config = Some(config);
        self
    }

    /// Builds the PP-LCNet model.
    ///
    /// # Arguments
    ///
    /// * `model_path` - Path to the ONNX model file
    ///
    /// # Returns
    ///
    /// A configured PP-LCNet model instance
    pub fn build(self, model_path: &std::path::Path) -> Result<PPLCNetModel, OCRError> {
        // Create ONNX inference engine
        let inference = if self.ort_config.is_some() {
            use crate::core::config::ModelInferenceConfig;
            let common_config = ModelInferenceConfig {
                ort_session: self.ort_config,
                ..Default::default()
            };
            OrtInfer::from_config(&common_config, model_path, None)?
        } else {
            OrtInfer::new(model_path, None)?
        };

        // Create normalizer (ImageNet normalization).
        //
        // PP-LCNet classifiers read images as **RGB** by default (no `DecodeImage`
        // op in the official inference.yml), so we keep RGB order here.
        let mean = self.preprocess_config.normalize_mean.clone();
        let std = self.preprocess_config.normalize_std.clone();
        let normalizer = NormalizeImage::with_color_order(
            Some(self.preprocess_config.normalize_scale),
            Some(mean),
            Some(std),
            Some(self.preprocess_config.tensor_layout),
            Some(crate::processors::types::ColorOrder::RGB),
        )?;

        // Create top-k processor
        let topk_processor = Topk::new(None);

        Ok(PPLCNetModel::new(
            inference,
            normalizer,
            topk_processor,
            self.preprocess_config.input_shape,
            self.preprocess_config.resize_filter,
            self.preprocess_config.resize_short,
        ))
    }
}