oar_ocr_core/models/classification/pp_lcnet.rs
1//! PP-LCNet Classification Model
2//!
3//! This module provides a pure implementation of the PP-LCNet model for image classification.
4//! PP-LCNet is a lightweight classification network that can be used for various classification
5//! tasks such as document orientation and text line orientation.
6
7use crate::core::inference::OrtInfer;
8use crate::core::{OCRError, Tensor2D, Tensor4D};
9use crate::domain::adapters::preprocessing::rgb_to_dynamic;
10use crate::processors::{NormalizeImage, TensorLayout};
11use crate::utils::topk::Topk;
12use image::{RgbImage, imageops::FilterType};
13
/// Configuration for PP-LCNet model preprocessing.
///
/// Controls how input images are resized/cropped and how pixel values are
/// scaled and normalized before being packed into the input tensor.
#[derive(Debug, Clone)]
pub struct PPLCNetPreprocessConfig {
    /// Model input shape as (height, width), in pixels.
    pub input_shape: (u32, u32),
    /// Resizing filter to use (e.g. `FilterType::Triangle` for bilinear).
    pub resize_filter: FilterType,
    /// When set, resize by short edge to this size (keep ratio) then center-crop to `input_shape`.
    ///
    /// This matches `ResizeImage(resize_short=256)` + `CropImage(size=224)` used by
    /// most PP-LCNet classifiers (e.g. `PP-LCNet_x1_0_doc_ori`, `PP-LCNet_x1_0_table_cls`).
    ///
    /// Some specialized PP-LCNet classifiers (e.g. `PP-LCNet_x1_0_textline_ori`) use a direct
    /// resize to a fixed `(height,width)` without short-edge resize/crop; set this to `None`
    /// to match `ResizeImage(size=[w,h])`.
    pub resize_short: Option<u32>,
    /// Scaling factor applied to raw pixel values before normalization
    /// (defaults to 1.0 / 255.0, mapping u8 pixels into [0, 1]).
    pub normalize_scale: f32,
    /// Per-channel mean values for normalization (applied after scaling).
    pub normalize_mean: Vec<f32>,
    /// Per-channel standard deviation values for normalization.
    pub normalize_std: Vec<f32>,
    /// Tensor data layout (CHW or HWC) of the produced batch.
    pub tensor_layout: TensorLayout,
}
39
40impl Default for PPLCNetPreprocessConfig {
41 fn default() -> Self {
42 Self {
43 input_shape: (224, 224),
44 // Use cv2.INTER_LINEAR for PP-LCNet resize.
45 resize_filter: FilterType::Triangle,
46 // PP-LCNet classifiers default to resize_short=256 then center-crop.
47 resize_short: Some(256),
48 normalize_scale: 1.0 / 255.0,
49 normalize_mean: vec![0.485, 0.456, 0.406],
50 normalize_std: vec![0.229, 0.224, 0.225],
51 tensor_layout: TensorLayout::CHW,
52 }
53 }
54}
55
/// Configuration for PP-LCNet model postprocessing.
#[derive(Debug, Clone)]
pub struct PPLCNetPostprocessConfig {
    /// Class labels indexed by class id; when empty, numeric fallback names
    /// (`class_{id}`) may be produced instead.
    pub labels: Vec<String>,
    /// Number of top predictions to return per image.
    pub topk: usize,
}
64
65impl Default for PPLCNetPostprocessConfig {
66 fn default() -> Self {
67 Self {
68 labels: vec![],
69 topk: 1,
70 }
71 }
72}
73
/// Output from PP-LCNet model.
///
/// The outer `Vec` of each field holds one entry per classified image; the
/// inner `Vec` holds the top-k predictions for that image.
#[derive(Debug, Clone)]
pub struct PPLCNetModelOutput {
    /// Predicted class IDs per image (top-k, best first as produced by the
    /// top-k processor).
    pub class_ids: Vec<Vec<usize>>,
    /// Confidence scores aligned index-for-index with `class_ids`.
    pub scores: Vec<Vec<f32>>,
    /// Label names for each prediction (if labels provided).
    pub label_names: Option<Vec<Vec<String>>>,
}
84
/// Pure PP-LCNet model implementation.
///
/// This model performs image classification using the PP-LCNet architecture.
/// It bundles the ONNX inference session with the preprocessing parameters
/// captured at construction time (see [`PPLCNetPreprocessConfig`]).
#[derive(Debug)]
pub struct PPLCNetModel {
    /// ONNX Runtime inference engine
    inference: OrtInfer,
    /// Image normalizer for preprocessing (scale + mean/std + layout)
    normalizer: NormalizeImage,
    /// Top-k processor for postprocessing logits into (id, score) lists
    topk_processor: Topk,
    /// Input shape (height, width)
    input_shape: (u32, u32),
    /// Resizing filter
    resize_filter: FilterType,
    /// Optional short-edge resize (see `PPLCNetPreprocessConfig::resize_short`)
    resize_short: Option<u32>,
}
103
impl PPLCNetModel {
    /// Creates a new PP-LCNet model from its already-constructed parts.
    ///
    /// # Arguments
    ///
    /// * `inference` - ONNX Runtime inference engine loaded with the model
    /// * `normalizer` - Pixel normalizer (scale, mean/std, tensor layout)
    /// * `topk_processor` - Top-k extractor used during postprocessing
    /// * `input_shape` - Model input shape as (height, width)
    /// * `resize_filter` - Filter used for image resizing
    /// * `resize_short` - Optional short-edge resize before center-crop
    pub fn new(
        inference: OrtInfer,
        normalizer: NormalizeImage,
        topk_processor: Topk,
        input_shape: (u32, u32),
        resize_filter: FilterType,
        resize_short: Option<u32>,
    ) -> Self {
        Self {
            inference,
            normalizer,
            topk_processor,
            input_shape,
            resize_filter,
            resize_short,
        }
    }

    /// Preprocesses images for classification.
    ///
    /// Resizes each image (short-edge resize + center crop, or direct resize,
    /// depending on `resize_short`), then normalizes the batch into a 4D
    /// tensor via the configured normalizer.
    ///
    /// NOTE(review): zero-sized images are silently dropped by `filter_map`,
    /// so the output batch can be SMALLER than `images` — downstream results
    /// would then no longer be index-aligned with the caller's input list.
    /// Consider returning an error for degenerate images instead; confirm
    /// against callers before changing.
    ///
    /// # Arguments
    ///
    /// * `images` - Input images to preprocess
    ///
    /// # Returns
    ///
    /// Preprocessed batch tensor
    pub fn preprocess(&self, images: Vec<RgbImage>) -> Result<Tensor4D, OCRError> {
        // input_shape is (height, width); these are also the crop dimensions.
        let (crop_h, crop_w) = self.input_shape;

        let resized_rgb: Vec<RgbImage> = if let Some(resize_short) = self.resize_short {
            // PP-LCNet classifier preprocessing:
            // 1) Resize by short edge (keep ratio)
            // 2) Center crop to the model input size
            //
            // This matches `ResizeImage(resize_short=256)` + `CropImage(size=224)` used by
            // models like `PP-LCNet_x1_0_doc_ori` / `PP-LCNet_x1_0_table_cls`.
            images
                .into_iter()
                .filter_map(|img| {
                    let (w, h) = (img.width(), img.height());
                    if w == 0 || h == 0 {
                        // Degenerate image: skip it (see NOTE(review) above).
                        return None;
                    }

                    // Scale so the shorter edge becomes `resize_short`;
                    // clamp with `.max(...)` so the result is never smaller
                    // than the crop window.
                    let short = w.min(h) as f32;
                    let scale = (resize_short as f32) / short;
                    let new_w = ((w as f32) * scale).round().max(crop_w as f32) as u32;
                    let new_h = ((h as f32) * scale).round().max(crop_h as f32) as u32;

                    let resized = image::imageops::resize(&img, new_w, new_h, self.resize_filter);

                    // Center crop to (crop_w, crop_h); saturating_sub guards
                    // the (clamped-away) case of a resize smaller than the crop.
                    let x1 = (new_w.saturating_sub(crop_w)) / 2;
                    let y1 = (new_h.saturating_sub(crop_h)) / 2;
                    let cropped =
                        image::imageops::crop_imm(&resized, x1, y1, crop_w, crop_h).to_image();
                    Some(cropped)
                })
                .collect()
        } else {
            // Direct resize to input shape (height,width) without crop.
            // This matches `ResizeImage(size=[w,h])` used by
            // `PP-LCNet_x1_0_textline_ori` (80x160).
            images
                .into_iter()
                .filter_map(|img| {
                    let (w, h) = (img.width(), img.height());
                    if w == 0 || h == 0 {
                        // Degenerate image: skip it (see NOTE(review) above).
                        return None;
                    }
                    Some(image::imageops::resize(
                        &img,
                        crop_w,
                        crop_h,
                        self.resize_filter,
                    ))
                })
                .collect()
        };

        // Convert to dynamic images and normalize using common helper
        let dynamic_images = rgb_to_dynamic(resized_rgb);
        self.normalizer.normalize_batch_to(dynamic_images)
    }

    /// Runs inference on the preprocessed batch.
    ///
    /// # Arguments
    ///
    /// * `batch_tensor` - Preprocessed batch tensor
    ///
    /// # Returns
    ///
    /// Model predictions as a 2D tensor (batch_size x num_classes)
    ///
    /// # Errors
    ///
    /// Wraps any engine failure in `OCRError::Inference`, recording the
    /// model name and the offending batch shape for diagnostics.
    pub fn infer(&self, batch_tensor: &Tensor4D) -> Result<Tensor2D, OCRError> {
        self.inference
            .infer_2d(batch_tensor)
            .map_err(|e| OCRError::Inference {
                model_name: "PP-LCNet".to_string(),
                context: format!(
                    "failed to run inference on batch with shape {:?}",
                    batch_tensor.shape()
                ),
                source: Box::new(e),
            })
    }

    /// Postprocesses model predictions to class IDs and scores.
    ///
    /// NOTE(review): a failure from the top-k processor is swallowed by
    /// `unwrap_or_else` and turned into an EMPTY result rather than an error;
    /// callers cannot distinguish "no predictions" from "top-k failed".
    /// Confirm this is intentional before relying on empty output.
    ///
    /// # Arguments
    ///
    /// * `predictions` - Model predictions (batch_size x num_classes)
    /// * `config` - Postprocessing configuration
    ///
    /// # Returns
    ///
    /// PPLCNetModelOutput containing class IDs, scores, and optional label names
    pub fn postprocess(
        &self,
        predictions: &Tensor2D,
        config: &PPLCNetPostprocessConfig,
    ) -> Result<PPLCNetModelOutput, OCRError> {
        // Copy each logits row out of the tensor for the top-k processor.
        let predictions_vec: Vec<Vec<f32>> =
            predictions.outer_iter().map(|row| row.to_vec()).collect();

        let topk_result = self
            .topk_processor
            .process(&predictions_vec, config.topk)
            .unwrap_or_else(|_| crate::utils::topk::TopkResult {
                indexes: vec![],
                scores: vec![],
                class_names: None,
            });

        let class_ids = topk_result.indexes;
        let scores = topk_result.scores;

        // Map class IDs to label names if labels are provided; IDs outside
        // the label table fall back to a synthetic "class_{id}" name.
        let label_names = if !config.labels.is_empty() {
            Some(
                class_ids
                    .iter()
                    .map(|ids| {
                        ids.iter()
                            .map(|&id| {
                                config
                                    .labels
                                    .get(id)
                                    .cloned()
                                    .unwrap_or_else(|| format!("class_{}", id))
                            })
                            .collect()
                    })
                    .collect(),
            )
        } else {
            // No labels configured: pass through whatever names the top-k
            // processor produced (may be None).
            topk_result.class_names
        };

        Ok(PPLCNetModelOutput {
            class_ids,
            scores,
            label_names,
        })
    }

    /// Performs complete forward pass: preprocess -> infer -> postprocess.
    ///
    /// # Arguments
    ///
    /// * `images` - Input images to classify
    /// * `config` - Postprocessing configuration
    ///
    /// # Returns
    ///
    /// PPLCNetModelOutput containing classification results
    ///
    /// # Errors
    ///
    /// Propagates any error from preprocessing, inference, or postprocessing.
    pub fn forward(
        &self,
        images: Vec<RgbImage>,
        config: &PPLCNetPostprocessConfig,
    ) -> Result<PPLCNetModelOutput, OCRError> {
        let batch_tensor = self.preprocess(images)?;
        let predictions = self.infer(&batch_tensor)?;
        self.postprocess(&predictions, config)
    }
}
293
/// Builder for PP-LCNet model.
///
/// Collects preprocessing and ONNX Runtime session settings, then
/// constructs a ready-to-use model via `build`.
#[derive(Debug, Default)]
pub struct PPLCNetModelBuilder {
    /// Preprocessing configuration (starts from `PPLCNetPreprocessConfig::default()`)
    preprocess_config: PPLCNetPreprocessConfig,
    /// Optional ONNX Runtime session configuration; `None` uses engine defaults
    ort_config: Option<crate::core::config::OrtSessionConfig>,
}
302
303impl PPLCNetModelBuilder {
304 /// Creates a new PP-LCNet model builder.
305 pub fn new() -> Self {
306 Self {
307 preprocess_config: PPLCNetPreprocessConfig::default(),
308 ort_config: None,
309 }
310 }
311
312 /// Sets the preprocessing configuration.
313 pub fn preprocess_config(mut self, config: PPLCNetPreprocessConfig) -> Self {
314 self.preprocess_config = config;
315 self
316 }
317
318 /// Sets the input image shape.
319 pub fn input_shape(mut self, shape: (u32, u32)) -> Self {
320 self.preprocess_config.input_shape = shape;
321 self
322 }
323
324 /// Sets the resizing filter.
325 pub fn resize_filter(mut self, filter: FilterType) -> Self {
326 self.preprocess_config.resize_filter = filter;
327 self
328 }
329
330 /// Sets the ONNX Runtime session configuration.
331 pub fn with_ort_config(mut self, config: crate::core::config::OrtSessionConfig) -> Self {
332 self.ort_config = Some(config);
333 self
334 }
335
336 /// Builds the PP-LCNet model.
337 ///
338 /// # Arguments
339 ///
340 /// * `model_path` - Path to the ONNX model file
341 ///
342 /// # Returns
343 ///
344 /// A configured PP-LCNet model instance
345 pub fn build(self, model_path: &std::path::Path) -> Result<PPLCNetModel, OCRError> {
346 // Create ONNX inference engine
347 let inference = if self.ort_config.is_some() {
348 use crate::core::config::ModelInferenceConfig;
349 let common_config = ModelInferenceConfig {
350 ort_session: self.ort_config,
351 ..Default::default()
352 };
353 OrtInfer::from_config(&common_config, model_path, None)?
354 } else {
355 OrtInfer::new(model_path, None)?
356 };
357
358 // Create normalizer (ImageNet normalization).
359 //
360 // PP-LCNet classifiers read images as **RGB** by default (no `DecodeImage`
361 // op in the official inference.yml), so we keep RGB order here.
362 let mean = self.preprocess_config.normalize_mean.clone();
363 let std = self.preprocess_config.normalize_std.clone();
364 let normalizer = NormalizeImage::with_color_order(
365 Some(self.preprocess_config.normalize_scale),
366 Some(mean),
367 Some(std),
368 Some(self.preprocess_config.tensor_layout),
369 Some(crate::processors::types::ColorOrder::RGB),
370 )?;
371
372 // Create top-k processor
373 let topk_processor = Topk::new(None);
374
375 Ok(PPLCNetModel::new(
376 inference,
377 normalizer,
378 topk_processor,
379 self.preprocess_config.input_shape,
380 self.preprocess_config.resize_filter,
381 self.preprocess_config.resize_short,
382 ))
383 }
384}