1use image::{DynamicImage, GenericImageView};
6use ndarray::{Array4, ArrayD};
7use std::path::Path;
8
9use crate::error::{OcrError, OcrResult};
10use crate::mnn::{InferenceConfig, InferenceEngine};
11use crate::preprocess::NormalizeParams;
12
/// Strategy used to turn an input image into the model's input tensor.
///
/// The mode also selects the normalization constants (detection-style for
/// [`OriPreprocessMode::Doc`], recognition-style for [`OriPreprocessMode::Textline`]).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OriPreprocessMode {
    /// Whole-document input: the shorter side is resized to
    /// `OriOptions::resize_shorter`, then the result is center-cropped to the
    /// target size (or stretched to it when it is too small to crop).
    Doc,
    /// Single text line: resized to the target height with the aspect ratio kept,
    /// width capped at the target width; unused columns stay zero-padded.
    Textline,
}
21
/// Outcome of classifying one image's orientation.
#[derive(Debug, Clone)]
pub struct OrientationResult {
    /// Index of the winning class in the model output.
    pub class_idx: usize,
    /// Rotation (in degrees) that `class_idx` maps to.
    pub angle: i32,
    /// Score of the winning class (a softmax probability when produced by `OriModel`).
    pub confidence: f32,
    /// Per-class scores in model output order.
    pub scores: Vec<f32>,
}

impl OrientationResult {
    /// Assembles a result from its parts.
    pub fn new(class_idx: usize, angle: i32, confidence: f32, scores: Vec<f32>) -> Self {
        OrientationResult {
            scores,
            confidence,
            angle,
            class_idx,
        }
    }

    /// Returns `true` when `confidence` is at least `threshold` (inclusive).
    ///
    /// Any comparison involving NaN yields `false`.
    pub fn is_valid(&self, threshold: f32) -> bool {
        threshold <= self.confidence
    }
}
51
52#[derive(Debug, Clone)]
54pub struct OriOptions {
55 pub target_height: u32,
57 pub target_width: u32,
59 pub min_score: f32,
61 pub resize_shorter: u32,
63 pub preprocess_mode: OriPreprocessMode,
65 pub class_angles: Vec<i32>,
67}
68
69impl Default for OriOptions {
70 fn default() -> Self {
71 Self {
72 target_height: 224,
73 target_width: 224,
74 min_score: 0.5,
75 resize_shorter: 256,
76 preprocess_mode: OriPreprocessMode::Doc,
77 class_angles: vec![0, 90, 180, 270],
78 }
79 }
80}
81
82impl OriOptions {
83 pub fn new() -> Self {
85 Self::default()
86 }
87
88 pub fn doc() -> Self {
90 Self::default()
91 }
92
93 pub fn textline() -> Self {
95 Self {
96 target_height: 48,
97 target_width: 192,
98 min_score: 0.5,
99 resize_shorter: 256,
100 preprocess_mode: OriPreprocessMode::Textline,
101 class_angles: vec![0, 180],
102 }
103 }
104
105 pub fn with_target_height(mut self, height: u32) -> Self {
107 self.target_height = height;
108 self
109 }
110
111 pub fn with_target_width(mut self, width: u32) -> Self {
113 self.target_width = width;
114 self
115 }
116
117 pub fn with_min_score(mut self, score: f32) -> Self {
119 self.min_score = score;
120 self
121 }
122
123 pub fn with_resize_shorter(mut self, size: u32) -> Self {
125 self.resize_shorter = size;
126 self
127 }
128
129 pub fn with_preprocess_mode(mut self, mode: OriPreprocessMode) -> Self {
131 self.preprocess_mode = mode;
132 self
133 }
134
135 pub fn with_class_angles(mut self, angles: Vec<i32>) -> Self {
137 self.class_angles = angles;
138 self
139 }
140}
141
142pub struct OriModel {
144 engine: InferenceEngine,
145 options: OriOptions,
146 normalize_params: NormalizeParams,
147}
148
149impl OriModel {
150 pub fn from_file(
152 model_path: impl AsRef<Path>,
153 config: Option<InferenceConfig>,
154 ) -> OcrResult<Self> {
155 let engine = InferenceEngine::from_file(model_path, config)?;
156 let options = OriOptions::default();
157 let mode = options.preprocess_mode;
158 Ok(Self {
159 engine,
160 options,
161 normalize_params: normalize_params_for_mode(mode),
162 })
163 }
164
165 pub fn from_bytes(model_bytes: &[u8], config: Option<InferenceConfig>) -> OcrResult<Self> {
167 let engine = InferenceEngine::from_buffer(model_bytes, config)?;
168 let options = OriOptions::default();
169 let mode = options.preprocess_mode;
170 Ok(Self {
171 engine,
172 options,
173 normalize_params: normalize_params_for_mode(mode),
174 })
175 }
176
177 pub fn with_options(mut self, options: OriOptions) -> Self {
179 self.options = options;
180 self.normalize_params = normalize_params_for_mode(self.options.preprocess_mode);
181 self
182 }
183
184 pub fn options(&self) -> &OriOptions {
186 &self.options
187 }
188
189 pub fn options_mut(&mut self) -> &mut OriOptions {
191 &mut self.options
192 }
193
194 pub fn classify(&self, image: &DynamicImage) -> OcrResult<OrientationResult> {
196 let input = preprocess_for_ori(
197 image,
198 self.options.target_height,
199 self.options.target_width,
200 self.options.resize_shorter,
201 self.options.preprocess_mode,
202 &self.normalize_params,
203 )?;
204
205 let output = self.engine.run_dynamic(input.view().into_dyn())?;
206 self.decode_output(&output)
207 }
208
209 fn decode_output(&self, output: &ArrayD<f32>) -> OcrResult<OrientationResult> {
210 let shape = output.shape();
211 if shape.is_empty() {
212 return Err(OcrError::PostprocessError(
213 "Orientation model output shape is empty".to_string(),
214 ));
215 }
216
217 let num_classes = *shape.last().unwrap_or(&0);
218 if num_classes == 0 {
219 return Err(OcrError::PostprocessError(
220 "Orientation model output classes is zero".to_string(),
221 ));
222 }
223
224 let output_data: Vec<f32> = output.iter().cloned().collect();
225 if output_data.is_empty() {
226 return Err(OcrError::PostprocessError(
227 "Orientation model output data is empty".to_string(),
228 ));
229 }
230
231 let scores_raw = if output_data.len() >= num_classes {
232 output_data[..num_classes].to_vec()
233 } else {
234 return Err(OcrError::PostprocessError(
235 "Orientation model output data size mismatch".to_string(),
236 ));
237 };
238
239 let scores = softmax(&scores_raw);
240 let (class_idx, &confidence) = scores
241 .iter()
242 .enumerate()
243 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
244 .ok_or_else(|| {
245 OcrError::PostprocessError(
246 "Orientation model output has no valid scores".to_string(),
247 )
248 })?;
249
250 let angle = class_to_angle(num_classes, class_idx, &self.options.class_angles);
251 Ok(OrientationResult::new(class_idx, angle, confidence, scores))
252 }
253}
254
/// Maps a predicted class index to a rotation angle in degrees.
///
/// When `class_angles` has exactly one entry per class, it is used as the
/// lookup table. Otherwise built-in tables apply: 2 classes map to 0/180 and
/// 4 classes map to 0/90/180/270. Anything unmapped falls back to the raw
/// class index.
fn class_to_angle(num_classes: usize, class_idx: usize, class_angles: &[i32]) -> i32 {
    let fallback = class_idx as i32;

    if class_angles.len() == num_classes {
        return class_angles.get(class_idx).copied().unwrap_or(fallback);
    }

    match (num_classes, class_idx) {
        (2, 0) => 0,
        (2, _) => 180,
        (4, idx @ 0..=3) => (idx as i32) * 90,
        _ => fallback,
    }
}
279
/// Numerically stable softmax over `scores`.
///
/// The maximum is subtracted before exponentiating so `exp` cannot overflow.
/// Returns an empty vector for empty input, and all zeros if the exponentials
/// sum to exactly zero.
fn softmax(scores: &[f32]) -> Vec<f32> {
    if scores.is_empty() {
        return Vec::new();
    }

    let mut peak = f32::NEG_INFINITY;
    for &s in scores {
        peak = peak.max(s);
    }

    let mut probs: Vec<f32> = Vec::with_capacity(scores.len());
    let mut total = 0.0f32;
    for &s in scores {
        let e = (s - peak).exp();
        total += e;
        probs.push(e);
    }

    if total == 0.0 {
        return vec![0.0; scores.len()];
    }

    for p in probs.iter_mut() {
        *p /= total;
    }
    probs
}
298
299fn normalize_params_for_mode(mode: OriPreprocessMode) -> NormalizeParams {
300 match mode {
301 OriPreprocessMode::Doc => NormalizeParams::paddle_det(),
302 OriPreprocessMode::Textline => NormalizeParams::paddle_rec(),
303 }
304}
305
306fn preprocess_for_ori(
308 img: &DynamicImage,
309 target_height: u32,
310 target_width: u32,
311 resize_shorter: u32,
312 mode: OriPreprocessMode,
313 params: &NormalizeParams,
314) -> OcrResult<Array4<f32>> {
315 if target_height == 0 || target_width == 0 {
316 return Err(OcrError::PreprocessError(
317 "Target size must be greater than zero".to_string(),
318 ));
319 }
320
321 let processed = match mode {
322 OriPreprocessMode::Textline => {
323 let (w, h) = img.dimensions();
324 let ratio = w as f32 / h.max(1) as f32;
325 let mut resize_w = (target_height as f32 * ratio).round() as u32;
326 if resize_w == 0 {
327 resize_w = 1;
328 }
329 if resize_w > target_width {
330 resize_w = target_width;
331 }
332
333 img.resize_exact(
334 resize_w,
335 target_height,
336 image::imageops::FilterType::Lanczos3,
337 )
338 }
339 OriPreprocessMode::Doc => {
340 let (w, h) = img.dimensions();
341 let shorter = w.min(h).max(1) as f32;
342 let scale = resize_shorter as f32 / shorter;
343 let new_w = (w as f32 * scale).round().max(1.0) as u32;
344 let new_h = (h as f32 * scale).round().max(1.0) as u32;
345 let resized = img.resize_exact(
346 new_w,
347 new_h,
348 image::imageops::FilterType::Lanczos3,
349 );
350
351 if new_w < target_width || new_h < target_height {
352 resized.resize_exact(
353 target_width,
354 target_height,
355 image::imageops::FilterType::Lanczos3,
356 )
357 } else {
358 let left = (new_w - target_width) / 2;
359 let top = (new_h - target_height) / 2;
360 resized.crop_imm(left, top, target_width, target_height)
361 }
362 }
363 };
364
365 let rgb_img = processed.to_rgb8();
366 let (proc_w, proc_h) = processed.dimensions();
367
368 let mut input = Array4::<f32>::zeros((
369 1,
370 3,
371 target_height as usize,
372 target_width as usize,
373 ));
374
375 let max_y = proc_h.min(target_height) as usize;
376 let max_x = proc_w.min(target_width) as usize;
377
378 for y in 0..max_y {
379 for x in 0..max_x {
380 let pixel = rgb_img.get_pixel(x as u32, y as u32);
381 let [r, g, b] = pixel.0;
382
383 input[[0, 0, y, x]] = (b as f32 / 255.0 - params.mean[0]) / params.std[0];
385 input[[0, 1, y, x]] = (g as f32 / 255.0 - params.mean[1]) / params.std[1];
386 input[[0, 2, y, x]] = (r as f32 / 255.0 - params.mean[2]) / params.std[2];
387 }
388 }
389
390 Ok(input)
391}
392
393impl OriModel {
395 pub fn run_raw(&self, input: ndarray::ArrayViewD<f32>) -> OcrResult<ArrayD<f32>> {
397 Ok(self.engine.run_dynamic(input)?)
398 }
399
400 pub fn input_shape(&self) -> &[usize] {
402 self.engine.input_shape()
403 }
404
405 pub fn output_shape(&self) -> &[usize] {
407 self.engine.output_shape()
408 }
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ori_options_default() {
        let opts = OriOptions::default();
        assert_eq!(opts.target_height, 224);
        assert_eq!(opts.target_width, 224);
        assert_eq!(opts.min_score, 0.5);
        assert_eq!(opts.resize_shorter, 256);
        assert_eq!(opts.preprocess_mode, OriPreprocessMode::Doc);
        assert_eq!(opts.class_angles, vec![0, 90, 180, 270]);
    }

    #[test]
    fn test_ori_options_builder() {
        let opts = OriOptions::new()
            .with_target_height(32)
            .with_target_width(128)
            .with_min_score(0.7)
            .with_resize_shorter(200)
            .with_preprocess_mode(OriPreprocessMode::Textline)
            .with_class_angles(vec![0, 180]);

        assert_eq!(opts.target_height, 32);
        assert_eq!(opts.target_width, 128);
        assert_eq!(opts.min_score, 0.7);
        assert_eq!(opts.resize_shorter, 200);
        assert_eq!(opts.preprocess_mode, OriPreprocessMode::Textline);
        assert_eq!(opts.class_angles, vec![0, 180]);
    }

    #[test]
    fn test_class_to_angle_mapping() {
        let angles_4 = [0, 90, 180, 270];
        let angles_2 = [0, 180];
        assert_eq!(class_to_angle(2, 0, &angles_2), 0);
        assert_eq!(class_to_angle(2, 1, &angles_2), 180);
        assert_eq!(class_to_angle(4, 0, &angles_4), 0);
        assert_eq!(class_to_angle(4, 1, &angles_4), 90);
        assert_eq!(class_to_angle(4, 2, &angles_4), 180);
        assert_eq!(class_to_angle(4, 3, &angles_4), 270);
        assert_eq!(class_to_angle(3, 2, &angles_2), 2);
        // Built-in tables apply when `class_angles` does not match the class count.
        assert_eq!(class_to_angle(4, 3, &[]), 270);
        assert_eq!(class_to_angle(2, 1, &[]), 180);
    }

    #[test]
    fn test_softmax_basic() {
        assert!(softmax(&[]).is_empty());
        assert_eq!(softmax(&[1.0, 1.0]), vec![0.5, 0.5]);
        let probs = softmax(&[0.0, 2.0, 1.0]);
        let total: f32 = probs.iter().sum();
        assert!((total - 1.0).abs() < 1e-6);
        assert!(probs[1] > probs[2] && probs[2] > probs[0]);
    }

    #[test]
    fn test_preprocess_for_ori_shape() {
        let img = DynamicImage::new_rgb8(100, 32);
        let params = NormalizeParams::paddle_det();
        let tensor = preprocess_for_ori(
            &img,
            224,
            224,
            256,
            OriPreprocessMode::Doc,
            &params,
        )
        .unwrap();
        assert_eq!(tensor.shape(), &[1, 3, 224, 224]);
    }

    #[test]
    fn test_preprocess_for_ori_textline_padding() {
        // 100x32 at height 48 resizes to width 150, leaving columns 150..192 padded.
        let img = DynamicImage::new_rgb8(100, 32);
        let params = NormalizeParams::paddle_rec();
        let tensor = preprocess_for_ori(
            &img,
            48,
            192,
            256,
            OriPreprocessMode::Textline,
            &params,
        )
        .unwrap();
        assert_eq!(tensor.shape(), &[1, 3, 48, 192]);
        for c in 0..3 {
            for y in 0..48 {
                for x in 150..192 {
                    assert_eq!(tensor[[0, c, y, x]], 0.0);
                }
            }
        }
    }
}