oar_ocr_core/models/classification/
pp_lcnet.rs1use crate::core::OCRError;
8use crate::core::inference::{OrtInfer, TensorInput};
9use crate::domain::adapters::preprocessing::rgb_to_dynamic;
10use crate::processors::{NormalizeImage, TensorLayout};
11use crate::utils::topk::Topk;
12use image::{RgbImage, imageops::FilterType};
13
14#[derive(Debug, Clone)]
16pub struct PPLCNetPreprocessConfig {
17 pub input_shape: (u32, u32),
19 pub resize_filter: FilterType,
21 pub resize_short: Option<u32>,
30 pub normalize_scale: f32,
32 pub normalize_mean: Vec<f32>,
34 pub normalize_std: Vec<f32>,
36 pub tensor_layout: TensorLayout,
38}
39
40impl Default for PPLCNetPreprocessConfig {
41 fn default() -> Self {
42 Self {
43 input_shape: (224, 224),
44 resize_filter: FilterType::Triangle,
46 resize_short: Some(256),
48 normalize_scale: 1.0 / 255.0,
49 normalize_mean: vec![0.485, 0.456, 0.406],
50 normalize_std: vec![0.229, 0.224, 0.225],
51 tensor_layout: TensorLayout::CHW,
52 }
53 }
54}
55
/// Postprocessing options for PP-LCNet classification models.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PPLCNetPostprocessConfig {
    /// Class labels indexed by class id.
    ///
    /// When empty, label names are taken from the top-k processor instead.
    pub labels: Vec<String>,
    /// Number of top-scoring classes to report per image.
    pub topk: usize,
}

impl Default for PPLCNetPostprocessConfig {
    /// Returns a configuration with no labels that reports only the single
    /// best class.
    fn default() -> Self {
        Self {
            labels: Vec::new(),
            topk: 1,
        }
    }
}
73
/// Classification results for a batch of images.
///
/// The outer `Vec` of each field has one entry per prediction row; the inner
/// `Vec` holds the top-k entries for that row.
#[derive(Debug, Clone, PartialEq)]
pub struct PPLCNetModelOutput {
    /// Class indexes selected by the top-k processor.
    pub class_ids: Vec<Vec<usize>>,
    /// Scores reported by the top-k processor for `class_ids`.
    pub scores: Vec<Vec<f32>>,
    /// Label names for `class_ids`, when available.
    pub label_names: Option<Vec<Vec<String>>>,
}
84
85#[derive(Debug)]
89pub struct PPLCNetModel {
90 inference: OrtInfer,
92 normalizer: NormalizeImage,
94 topk_processor: Topk,
96 input_shape: (u32, u32),
98 resize_filter: FilterType,
100 resize_short: Option<u32>,
102}
103
104impl PPLCNetModel {
105 pub fn new(
107 inference: OrtInfer,
108 normalizer: NormalizeImage,
109 topk_processor: Topk,
110 input_shape: (u32, u32),
111 resize_filter: FilterType,
112 resize_short: Option<u32>,
113 ) -> Self {
114 Self {
115 inference,
116 normalizer,
117 topk_processor,
118 input_shape,
119 resize_filter,
120 resize_short,
121 }
122 }
123
124 pub fn preprocess(&self, images: Vec<RgbImage>) -> Result<ndarray::Array4<f32>, OCRError> {
134 let (crop_h, crop_w) = self.input_shape;
135
136 let resized_rgb: Vec<RgbImage> = if let Some(resize_short) = self.resize_short {
137 images
144 .into_iter()
145 .filter_map(|img| {
146 let (w, h) = (img.width(), img.height());
147 if w == 0 || h == 0 {
148 return None;
149 }
150
151 let short = w.min(h) as f32;
152 let scale = (resize_short as f32) / short;
153 let new_w = ((w as f32) * scale).round().max(crop_w as f32) as u32;
154 let new_h = ((h as f32) * scale).round().max(crop_h as f32) as u32;
155
156 let resized = image::imageops::resize(&img, new_w, new_h, self.resize_filter);
157
158 let x1 = (new_w.saturating_sub(crop_w)) / 2;
160 let y1 = (new_h.saturating_sub(crop_h)) / 2;
161 let cropped =
162 image::imageops::crop_imm(&resized, x1, y1, crop_w, crop_h).to_image();
163 Some(cropped)
164 })
165 .collect()
166 } else {
167 images
171 .into_iter()
172 .filter_map(|img| {
173 let (w, h) = (img.width(), img.height());
174 if w == 0 || h == 0 {
175 return None;
176 }
177 Some(image::imageops::resize(
178 &img,
179 crop_w,
180 crop_h,
181 self.resize_filter,
182 ))
183 })
184 .collect()
185 };
186
187 let dynamic_images = rgb_to_dynamic(resized_rgb);
189 self.normalizer.normalize_batch_to(dynamic_images)
190 }
191
192 pub fn infer(
202 &self,
203 batch_tensor: &ndarray::Array4<f32>,
204 ) -> Result<ndarray::Array2<f32>, OCRError> {
205 let input_name = self.inference.input_name();
206 let inputs = vec![(input_name, TensorInput::Array4(batch_tensor))];
207
208 let outputs = self
209 .inference
210 .infer(&inputs)
211 .map_err(|e| OCRError::Inference {
212 model_name: "PP-LCNet".to_string(),
213 context: format!(
214 "failed to run inference on batch with shape {:?}",
215 batch_tensor.shape()
216 ),
217 source: Box::new(e),
218 })?;
219
220 let output = outputs
221 .into_iter()
222 .next()
223 .ok_or_else(|| OCRError::InvalidInput {
224 message: "PP-LCNet: no output returned from inference".to_string(),
225 })?;
226
227 output
228 .1
229 .try_into_array2_f32()
230 .map_err(|e| OCRError::Inference {
231 model_name: "PP-LCNet".to_string(),
232 context: "failed to convert output to 2D array".to_string(),
233 source: Box::new(e),
234 })
235 }
236
237 pub fn postprocess(
248 &self,
249 predictions: &ndarray::Array2<f32>,
250 config: &PPLCNetPostprocessConfig,
251 ) -> Result<PPLCNetModelOutput, OCRError> {
252 let predictions_vec: Vec<Vec<f32>> =
253 predictions.outer_iter().map(|row| row.to_vec()).collect();
254
255 let topk_result = self
256 .topk_processor
257 .process(&predictions_vec, config.topk)
258 .unwrap_or_else(|_| crate::utils::topk::TopkResult {
259 indexes: vec![],
260 scores: vec![],
261 class_names: None,
262 });
263
264 let class_ids = topk_result.indexes;
265 let scores = topk_result.scores;
266
267 let label_names = if !config.labels.is_empty() {
269 Some(
270 class_ids
271 .iter()
272 .map(|ids| {
273 ids.iter()
274 .map(|&id| {
275 config
276 .labels
277 .get(id)
278 .cloned()
279 .unwrap_or_else(|| format!("class_{}", id))
280 })
281 .collect()
282 })
283 .collect(),
284 )
285 } else {
286 topk_result.class_names
287 };
288
289 Ok(PPLCNetModelOutput {
290 class_ids,
291 scores,
292 label_names,
293 })
294 }
295
296 pub fn forward(
307 &self,
308 images: Vec<RgbImage>,
309 config: &PPLCNetPostprocessConfig,
310 ) -> Result<PPLCNetModelOutput, OCRError> {
311 let batch_tensor = self.preprocess(images)?;
312 let predictions = self.infer(&batch_tensor)?;
313 self.postprocess(&predictions, config)
314 }
315}
316
317#[derive(Debug, Default)]
319pub struct PPLCNetModelBuilder {
320 preprocess_config: PPLCNetPreprocessConfig,
322 ort_config: Option<crate::core::config::OrtSessionConfig>,
324}
325
326impl PPLCNetModelBuilder {
327 pub fn new() -> Self {
329 Self {
330 preprocess_config: PPLCNetPreprocessConfig::default(),
331 ort_config: None,
332 }
333 }
334
335 pub fn preprocess_config(mut self, config: PPLCNetPreprocessConfig) -> Self {
337 self.preprocess_config = config;
338 self
339 }
340
341 pub fn input_shape(mut self, shape: (u32, u32)) -> Self {
343 self.preprocess_config.input_shape = shape;
344 self
345 }
346
347 pub fn resize_filter(mut self, filter: FilterType) -> Self {
349 self.preprocess_config.resize_filter = filter;
350 self
351 }
352
353 pub fn with_ort_config(mut self, config: crate::core::config::OrtSessionConfig) -> Self {
355 self.ort_config = Some(config);
356 self
357 }
358
359 pub fn build(self, model_path: &std::path::Path) -> Result<PPLCNetModel, OCRError> {
369 let inference = if self.ort_config.is_some() {
371 use crate::core::config::ModelInferenceConfig;
372 let common_config = ModelInferenceConfig {
373 ort_session: self.ort_config,
374 ..Default::default()
375 };
376 OrtInfer::from_config(&common_config, model_path, None)?
377 } else {
378 OrtInfer::new(model_path, None)?
379 };
380
381 let mean = self.preprocess_config.normalize_mean.clone();
386 let std = self.preprocess_config.normalize_std.clone();
387 let normalizer = NormalizeImage::with_color_order(
388 Some(self.preprocess_config.normalize_scale),
389 Some(mean),
390 Some(std),
391 Some(self.preprocess_config.tensor_layout),
392 Some(crate::processors::types::ColorOrder::RGB),
393 )?;
394
395 let topk_processor = Topk::new(None);
397
398 Ok(PPLCNetModel::new(
399 inference,
400 normalizer,
401 topk_processor,
402 self.preprocess_config.input_shape,
403 self.preprocess_config.resize_filter,
404 self.preprocess_config.resize_short,
405 ))
406 }
407}