1use image::{DynamicImage, GenericImageView};
6use ndarray::{Array4, ArrayD};
7use std::path::Path;
8
9use crate::error::{OcrError, OcrResult};
10use crate::mnn::{InferenceConfig, InferenceEngine};
11use crate::preprocess::NormalizeParams;
12
/// Selects how an image is resized and normalized before orientation inference.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OriPreprocessMode {
    /// Whole-page input: the shorter side is scaled to `resize_shorter`, then
    /// the result is center-cropped (or stretched when it is too small) to the
    /// target size.
    Doc,
    /// Single text line: the image is scaled to the target height with its
    /// aspect ratio kept, capped at the target width; any columns left over
    /// stay zero-filled.
    Textline,
}
21
/// Outcome of one orientation classification.
#[derive(Debug, Clone)]
pub struct OrientationResult {
    /// Index of the winning class in the model output.
    pub class_idx: usize,
    /// Angle associated with `class_idx`.
    pub angle: i32,
    /// Score of the winning class.
    pub confidence: f32,
    /// Scores for every class, in model output order.
    pub scores: Vec<f32>,
}

impl OrientationResult {
    /// Bundles the given values into a result without any validation.
    pub fn new(class_idx: usize, angle: i32, confidence: f32, scores: Vec<f32>) -> Self {
        OrientationResult {
            scores,
            confidence,
            angle,
            class_idx,
        }
    }

    /// Returns `true` when the confidence reaches `threshold` (inclusive).
    ///
    /// A NaN confidence never passes, because the comparison is false for NaN.
    pub fn is_valid(&self, threshold: f32) -> bool {
        threshold <= self.confidence
    }
}
51
52#[derive(Debug, Clone)]
54pub struct OriOptions {
55 pub target_height: u32,
57 pub target_width: u32,
59 pub min_score: f32,
61 pub resize_shorter: u32,
63 pub preprocess_mode: OriPreprocessMode,
65 pub class_angles: Vec<i32>,
67}
68
69impl Default for OriOptions {
70 fn default() -> Self {
71 Self {
72 target_height: 224,
73 target_width: 224,
74 min_score: 0.5,
75 resize_shorter: 256,
76 preprocess_mode: OriPreprocessMode::Doc,
77 class_angles: vec![0, 90, 180, 270],
78 }
79 }
80}
81
82impl OriOptions {
83 pub fn new() -> Self {
85 Self::default()
86 }
87
88 pub fn doc() -> Self {
90 Self::default()
91 }
92
93 pub fn textline() -> Self {
95 Self {
96 target_height: 48,
97 target_width: 192,
98 min_score: 0.5,
99 resize_shorter: 256,
100 preprocess_mode: OriPreprocessMode::Textline,
101 class_angles: vec![0, 180],
102 }
103 }
104
105 pub fn with_target_height(mut self, height: u32) -> Self {
107 self.target_height = height;
108 self
109 }
110
111 pub fn with_target_width(mut self, width: u32) -> Self {
113 self.target_width = width;
114 self
115 }
116
117 pub fn with_min_score(mut self, score: f32) -> Self {
119 self.min_score = score;
120 self
121 }
122
123 pub fn with_resize_shorter(mut self, size: u32) -> Self {
125 self.resize_shorter = size;
126 self
127 }
128
129 pub fn with_preprocess_mode(mut self, mode: OriPreprocessMode) -> Self {
131 self.preprocess_mode = mode;
132 self
133 }
134
135 pub fn with_class_angles(mut self, angles: Vec<i32>) -> Self {
137 self.class_angles = angles;
138 self
139 }
140}
141
142pub struct OriModel {
144 engine: InferenceEngine,
145 options: OriOptions,
146 normalize_params: NormalizeParams,
147}
148
149impl OriModel {
150 pub fn from_file(
152 model_path: impl AsRef<Path>,
153 config: Option<InferenceConfig>,
154 ) -> OcrResult<Self> {
155 let engine = InferenceEngine::from_file(model_path, config)?;
156 let options = OriOptions::default();
157 let mode = options.preprocess_mode;
158 Ok(Self {
159 engine,
160 options,
161 normalize_params: normalize_params_for_mode(mode),
162 })
163 }
164
165 pub fn from_bytes(model_bytes: &[u8], config: Option<InferenceConfig>) -> OcrResult<Self> {
167 let engine = InferenceEngine::from_buffer(model_bytes, config)?;
168 let options = OriOptions::default();
169 let mode = options.preprocess_mode;
170 Ok(Self {
171 engine,
172 options,
173 normalize_params: normalize_params_for_mode(mode),
174 })
175 }
176
177 pub fn with_options(mut self, options: OriOptions) -> Self {
179 self.options = options;
180 self.normalize_params = normalize_params_for_mode(self.options.preprocess_mode);
181 self
182 }
183
184 pub fn options(&self) -> &OriOptions {
186 &self.options
187 }
188
189 pub fn options_mut(&mut self) -> &mut OriOptions {
191 &mut self.options
192 }
193
194 pub fn classify(&self, image: &DynamicImage) -> OcrResult<OrientationResult> {
196 let input = preprocess_for_ori(
197 image,
198 self.options.target_height,
199 self.options.target_width,
200 self.options.resize_shorter,
201 self.options.preprocess_mode,
202 &self.normalize_params,
203 )?;
204
205 let output = self.engine.run_dynamic(input.view().into_dyn())?;
206 self.decode_output(&output)
207 }
208
209 fn decode_output(&self, output: &ArrayD<f32>) -> OcrResult<OrientationResult> {
210 let shape = output.shape();
211 if shape.is_empty() {
212 return Err(OcrError::PostprocessError(
213 "Orientation model output shape is empty".to_string(),
214 ));
215 }
216
217 let num_classes = *shape.last().unwrap_or(&0);
218 if num_classes == 0 {
219 return Err(OcrError::PostprocessError(
220 "Orientation model output classes is zero".to_string(),
221 ));
222 }
223
224 let output_data: Vec<f32> = output.iter().cloned().collect();
225 if output_data.is_empty() {
226 return Err(OcrError::PostprocessError(
227 "Orientation model output data is empty".to_string(),
228 ));
229 }
230
231 let scores_raw = if output_data.len() >= num_classes {
232 output_data[..num_classes].to_vec()
233 } else {
234 return Err(OcrError::PostprocessError(
235 "Orientation model output data size mismatch".to_string(),
236 ));
237 };
238
239 let scores = softmax(&scores_raw);
240 let (class_idx, &confidence) = scores
241 .iter()
242 .enumerate()
243 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
244 .ok_or_else(|| {
245 OcrError::PostprocessError(
246 "Orientation model output has no valid scores".to_string(),
247 )
248 })?;
249
250 let angle = class_to_angle(num_classes, class_idx, &self.options.class_angles);
251 Ok(OrientationResult::new(class_idx, angle, confidence, scores))
252 }
253}
254
/// Maps a class index to a rotation angle.
///
/// When `class_angles` has exactly `num_classes` entries it is used as the
/// lookup table. Otherwise built-in tables apply: two classes map to 0/180
/// (any non-zero index gives 180) and four classes map to 0/90/180/270.
/// Anything not covered by a table falls back to the class index itself.
fn class_to_angle(num_classes: usize, class_idx: usize, class_angles: &[i32]) -> i32 {
    let fallback = class_idx as i32;

    if class_angles.len() == num_classes {
        return match class_angles.get(class_idx) {
            Some(&angle) => angle,
            None => fallback,
        };
    }

    match (num_classes, class_idx) {
        (2, 0) => 0,
        (2, _) => 180,
        // Indices 0..=3 are quarter turns: 0, 90, 180, 270.
        (4, 0..=3) => 90 * fallback,
        _ => fallback,
    }
}
282
/// Computes a softmax over `scores`.
///
/// The maximum is subtracted before exponentiation so that large inputs do
/// not overflow. Returns an empty vector for empty input, and all zeros if
/// the exponentials sum to exactly zero.
fn softmax(scores: &[f32]) -> Vec<f32> {
    if scores.is_empty() {
        return Vec::new();
    }

    let peak = scores
        .iter()
        .copied()
        .fold(f32::NEG_INFINITY, |acc, s| acc.max(s));

    let mut probs: Vec<f32> = scores.iter().map(|&s| (s - peak).exp()).collect();
    let total: f32 = probs.iter().sum();

    if total == 0.0 {
        probs.fill(0.0);
        return probs;
    }

    for p in &mut probs {
        *p /= total;
    }
    probs
}
298
299fn normalize_params_for_mode(mode: OriPreprocessMode) -> NormalizeParams {
300 match mode {
301 OriPreprocessMode::Doc => NormalizeParams::paddle_det(),
302 OriPreprocessMode::Textline => NormalizeParams::paddle_rec(),
303 }
304}
305
306fn preprocess_for_ori(
308 img: &DynamicImage,
309 target_height: u32,
310 target_width: u32,
311 resize_shorter: u32,
312 mode: OriPreprocessMode,
313 params: &NormalizeParams,
314) -> OcrResult<Array4<f32>> {
315 if target_height == 0 || target_width == 0 {
316 return Err(OcrError::PreprocessError(
317 "Target size must be greater than zero".to_string(),
318 ));
319 }
320
321 let processed = match mode {
322 OriPreprocessMode::Textline => {
323 let (w, h) = img.dimensions();
324 let ratio = w as f32 / h.max(1) as f32;
325 let mut resize_w = (target_height as f32 * ratio).round() as u32;
326 if resize_w == 0 {
327 resize_w = 1;
328 }
329 if resize_w > target_width {
330 resize_w = target_width;
331 }
332
333 img.resize_exact(
334 resize_w,
335 target_height,
336 image::imageops::FilterType::Lanczos3,
337 )
338 }
339 OriPreprocessMode::Doc => {
340 let (w, h) = img.dimensions();
341 let shorter = w.min(h).max(1) as f32;
342 let scale = resize_shorter as f32 / shorter;
343 let new_w = (w as f32 * scale).round().max(1.0) as u32;
344 let new_h = (h as f32 * scale).round().max(1.0) as u32;
345 let resized = img.resize_exact(new_w, new_h, image::imageops::FilterType::Lanczos3);
346
347 if new_w < target_width || new_h < target_height {
348 resized.resize_exact(
349 target_width,
350 target_height,
351 image::imageops::FilterType::Lanczos3,
352 )
353 } else {
354 let left = (new_w - target_width) / 2;
355 let top = (new_h - target_height) / 2;
356 resized.crop_imm(left, top, target_width, target_height)
357 }
358 }
359 };
360
361 let rgb_img = processed.to_rgb8();
362 let (proc_w, proc_h) = processed.dimensions();
363
364 let mut input = Array4::<f32>::zeros((1, 3, target_height as usize, target_width as usize));
365
366 let max_y = proc_h.min(target_height) as usize;
367 let max_x = proc_w.min(target_width) as usize;
368
369 for y in 0..max_y {
370 for x in 0..max_x {
371 let pixel = rgb_img.get_pixel(x as u32, y as u32);
372 let [r, g, b] = pixel.0;
373
374 input[[0, 0, y, x]] = (b as f32 / 255.0 - params.mean[0]) / params.std[0];
376 input[[0, 1, y, x]] = (g as f32 / 255.0 - params.mean[1]) / params.std[1];
377 input[[0, 2, y, x]] = (r as f32 / 255.0 - params.mean[2]) / params.std[2];
378 }
379 }
380
381 Ok(input)
382}
383
384impl OriModel {
386 pub fn run_raw(&self, input: ndarray::ArrayViewD<f32>) -> OcrResult<ArrayD<f32>> {
388 Ok(self.engine.run_dynamic(input)?)
389 }
390
391 pub fn input_shape(&self) -> &[usize] {
393 self.engine.input_shape()
394 }
395
396 pub fn output_shape(&self) -> &[usize] {
398 self.engine.output_shape()
399 }
400}
401
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults describe the document-orientation setup.
    #[test]
    fn test_ori_options_default() {
        let opts = OriOptions::default();
        assert_eq!(opts.target_height, 224);
        assert_eq!(opts.target_width, 224);
        assert_eq!(opts.min_score, 0.5);
        assert_eq!(opts.resize_shorter, 256);
        assert_eq!(opts.preprocess_mode, OriPreprocessMode::Doc);
        assert_eq!(opts.class_angles, vec![0, 90, 180, 270]);
    }

    /// Every builder method overrides exactly its own field.
    #[test]
    fn test_ori_options_builder() {
        let opts = OriOptions::new()
            .with_target_height(32)
            .with_target_width(128)
            .with_min_score(0.7)
            .with_resize_shorter(200)
            .with_preprocess_mode(OriPreprocessMode::Textline)
            .with_class_angles(vec![0, 180]);

        assert_eq!(opts.target_height, 32);
        assert_eq!(opts.target_width, 128);
        assert_eq!(opts.min_score, 0.7);
        assert_eq!(opts.resize_shorter, 200);
        assert_eq!(opts.preprocess_mode, OriPreprocessMode::Textline);
        assert_eq!(opts.class_angles, vec![0, 180]);
    }

    /// Matching tables are used as-is; a mismatched table with an unknown
    /// class count falls back to the class index.
    #[test]
    fn test_class_to_angle_mapping() {
        let angles_4 = vec![0, 90, 180, 270];
        let angles_2 = vec![0, 180];
        assert_eq!(class_to_angle(2, 0, &angles_2), 0);
        assert_eq!(class_to_angle(2, 1, &angles_2), 180);
        assert_eq!(class_to_angle(4, 0, &angles_4), 0);
        assert_eq!(class_to_angle(4, 1, &angles_4), 90);
        assert_eq!(class_to_angle(4, 2, &angles_4), 180);
        assert_eq!(class_to_angle(4, 3, &angles_4), 270);
        assert_eq!(class_to_angle(3, 2, &angles_2), 2);
    }

    /// Softmax handles empty input and yields a uniform distribution for
    /// equal scores.
    #[test]
    fn test_softmax_basic() {
        assert!(softmax(&[]).is_empty());
        assert_eq!(softmax(&[1.0, 1.0]), vec![0.5, 0.5]);

        let probs = softmax(&[1.0, 2.0, 3.0]);
        assert!((probs.iter().sum::<f32>() - 1.0).abs() < 1e-6);
        assert!(probs[2] > probs[1] && probs[1] > probs[0]);
    }

    /// Doc mode always produces a tensor of the requested size.
    #[test]
    fn test_preprocess_for_ori_shape() {
        let img = DynamicImage::new_rgb8(100, 32);
        let params = NormalizeParams::paddle_det();
        let tensor =
            preprocess_for_ori(&img, 224, 224, 256, OriPreprocessMode::Doc, &params).unwrap();
        assert_eq!(tensor.shape(), &[1, 3, 224, 224]);
    }

    /// Textline mode pads up to the requested width.
    #[test]
    fn test_preprocess_for_ori_textline_shape() {
        let img = DynamicImage::new_rgb8(100, 32);
        let params = NormalizeParams::paddle_rec();
        let tensor =
            preprocess_for_ori(&img, 48, 192, 256, OriPreprocessMode::Textline, &params).unwrap();
        assert_eq!(tensor.shape(), &[1, 3, 48, 192]);
    }

    /// A zero target dimension is rejected instead of producing an empty tensor.
    #[test]
    fn test_preprocess_for_ori_rejects_zero_target() {
        let img = DynamicImage::new_rgb8(10, 10);
        let params = NormalizeParams::paddle_det();
        assert!(preprocess_for_ori(&img, 0, 224, 256, OriPreprocessMode::Doc, &params).is_err());
        assert!(preprocess_for_ori(&img, 224, 0, 256, OriPreprocessMode::Doc, &params).is_err());
    }
}