// ori.rs (ocr_rs) — textline/document orientation classification model
1//! Textline orientation classification model
2//!
3//! Provides textline orientation classification based on PP-LCNet_x1_0_textline_ori
4
5use image::{DynamicImage, GenericImageView};
6use ndarray::{Array4, ArrayD};
7use std::path::Path;
8
9use crate::error::{OcrError, OcrResult};
10use crate::mnn::{InferenceConfig, InferenceEngine};
11use crate::preprocess::NormalizeParams;
12
/// How an input image is preprocessed before orientation inference.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OriPreprocessMode {
    /// Whole-document orientation (PP-LCNet_x1_0_doc_ori): resize the shorter
    /// side, then center-crop to the target size.
    Doc,
    /// Single text line orientation (PP-LCNet_x1_0_textline_ori):
    /// aspect-preserving resize to the target height, zero-padded on the right.
    Textline,
}
21
/// Outcome of a single orientation classification.
#[derive(Debug, Clone)]
pub struct OrientationResult {
    /// Index of the winning class.
    pub class_idx: usize,
    /// Angle in degrees mapped from the class index (best-effort mapping).
    pub angle: i32,
    /// Softmax probability of the winning class.
    pub confidence: f32,
    /// Softmax probabilities for every class.
    pub scores: Vec<f32>,
}

impl OrientationResult {
    /// Bundle the classifier outputs into a result value.
    pub fn new(class_idx: usize, angle: i32, confidence: f32, scores: Vec<f32>) -> Self {
        OrientationResult {
            class_idx,
            angle,
            confidence,
            scores,
        }
    }

    /// Returns `true` when the winning confidence reaches `threshold`.
    pub fn is_valid(&self, threshold: f32) -> bool {
        threshold <= self.confidence
    }
}
51
/// Configuration for the orientation classifier.
///
/// The defaults match the document-orientation preset; use
/// [`OriOptions::textline`] for text line models.
#[derive(Debug, Clone)]
pub struct OriOptions {
    /// Height of the model input tensor in pixels (224 for doc, 48 for textline).
    pub target_height: u32,
    /// Width of the model input tensor in pixels (224 for doc, 192 for textline).
    pub target_width: u32,
    /// Minimum confidence threshold for caller-side filtering; `classify`
    /// itself does not enforce it.
    pub min_score: f32,
    /// Shorter-side resize target, used only by `OriPreprocessMode::Doc`.
    pub resize_shorter: u32,
    /// Which preprocessing pipeline to apply before inference.
    pub preprocess_mode: OriPreprocessMode,
    /// Maps predicted class index to an angle in degrees (e.g. `[0, 90, 180, 270]`).
    pub class_angles: Vec<i32>,
}
68
69impl Default for OriOptions {
70    fn default() -> Self {
71        Self {
72            target_height: 224,
73            target_width: 224,
74            min_score: 0.5,
75            resize_shorter: 256,
76            preprocess_mode: OriPreprocessMode::Doc,
77            class_angles: vec![0, 90, 180, 270],
78        }
79    }
80}
81
82impl OriOptions {
83    /// Create new options
84    pub fn new() -> Self {
85        Self::default()
86    }
87
88    /// Preset for document orientation models
89    pub fn doc() -> Self {
90        Self::default()
91    }
92
93    /// Preset for textline orientation models
94    pub fn textline() -> Self {
95        Self {
96            target_height: 48,
97            target_width: 192,
98            min_score: 0.5,
99            resize_shorter: 256,
100            preprocess_mode: OriPreprocessMode::Textline,
101            class_angles: vec![0, 180],
102        }
103    }
104
105    /// Set target height
106    pub fn with_target_height(mut self, height: u32) -> Self {
107        self.target_height = height;
108        self
109    }
110
111    /// Set target width
112    pub fn with_target_width(mut self, width: u32) -> Self {
113        self.target_width = width;
114        self
115    }
116
117    /// Set minimum confidence threshold
118    pub fn with_min_score(mut self, score: f32) -> Self {
119        self.min_score = score;
120        self
121    }
122
123    /// Set resize shorter side (document mode)
124    pub fn with_resize_shorter(mut self, size: u32) -> Self {
125        self.resize_shorter = size;
126        self
127    }
128
129    /// Set preprocess mode
130    pub fn with_preprocess_mode(mut self, mode: OriPreprocessMode) -> Self {
131        self.preprocess_mode = mode;
132        self
133    }
134
135    /// Set class index to angle mapping
136    pub fn with_class_angles(mut self, angles: Vec<i32>) -> Self {
137        self.class_angles = angles;
138        self
139    }
140}
141
/// Orientation classification model (document or text line).
///
/// Wraps an [`InferenceEngine`] together with preprocessing options and the
/// normalization parameters derived from the current preprocess mode.
pub struct OriModel {
    /// Backend inference engine running the classification network.
    engine: InferenceEngine,
    /// Preprocessing / decoding options.
    options: OriOptions,
    /// Per-channel mean/std used to build the input tensor; derived from
    /// `options.preprocess_mode` when the model is constructed or
    /// `with_options` is called.
    normalize_params: NormalizeParams,
}
148
149impl OriModel {
150    /// Create orientation classifier from model file
151    pub fn from_file(
152        model_path: impl AsRef<Path>,
153        config: Option<InferenceConfig>,
154    ) -> OcrResult<Self> {
155        let engine = InferenceEngine::from_file(model_path, config)?;
156        let options = OriOptions::default();
157        let mode = options.preprocess_mode;
158        Ok(Self {
159            engine,
160            options,
161            normalize_params: normalize_params_for_mode(mode),
162        })
163    }
164
165    /// Create orientation classifier from model bytes
166    pub fn from_bytes(model_bytes: &[u8], config: Option<InferenceConfig>) -> OcrResult<Self> {
167        let engine = InferenceEngine::from_buffer(model_bytes, config)?;
168        let options = OriOptions::default();
169        let mode = options.preprocess_mode;
170        Ok(Self {
171            engine,
172            options,
173            normalize_params: normalize_params_for_mode(mode),
174        })
175    }
176
177    /// Set classifier options
178    pub fn with_options(mut self, options: OriOptions) -> Self {
179        self.options = options;
180        self.normalize_params = normalize_params_for_mode(self.options.preprocess_mode);
181        self
182    }
183
184    /// Get current options
185    pub fn options(&self) -> &OriOptions {
186        &self.options
187    }
188
189    /// Modify options
190    pub fn options_mut(&mut self) -> &mut OriOptions {
191        &mut self.options
192    }
193
194    /// Classify a single text line image
195    pub fn classify(&self, image: &DynamicImage) -> OcrResult<OrientationResult> {
196        let input = preprocess_for_ori(
197            image,
198            self.options.target_height,
199            self.options.target_width,
200            self.options.resize_shorter,
201            self.options.preprocess_mode,
202            &self.normalize_params,
203        )?;
204
205        let output = self.engine.run_dynamic(input.view().into_dyn())?;
206        self.decode_output(&output)
207    }
208
209    fn decode_output(&self, output: &ArrayD<f32>) -> OcrResult<OrientationResult> {
210        let shape = output.shape();
211        if shape.is_empty() {
212            return Err(OcrError::PostprocessError(
213                "Orientation model output shape is empty".to_string(),
214            ));
215        }
216
217        let num_classes = *shape.last().unwrap_or(&0);
218        if num_classes == 0 {
219            return Err(OcrError::PostprocessError(
220                "Orientation model output classes is zero".to_string(),
221            ));
222        }
223
224        let output_data: Vec<f32> = output.iter().cloned().collect();
225        if output_data.is_empty() {
226            return Err(OcrError::PostprocessError(
227                "Orientation model output data is empty".to_string(),
228            ));
229        }
230
231        let scores_raw = if output_data.len() >= num_classes {
232            output_data[..num_classes].to_vec()
233        } else {
234            return Err(OcrError::PostprocessError(
235                "Orientation model output data size mismatch".to_string(),
236            ));
237        };
238
239        let scores = softmax(&scores_raw);
240        let (class_idx, &confidence) = scores
241            .iter()
242            .enumerate()
243            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
244            .ok_or_else(|| {
245                OcrError::PostprocessError(
246                    "Orientation model output has no valid scores".to_string(),
247                )
248            })?;
249
250        let angle = class_to_angle(num_classes, class_idx, &self.options.class_angles);
251        Ok(OrientationResult::new(class_idx, angle, confidence, scores))
252    }
253}
254
/// Map a predicted class index to an angle in degrees (best-effort).
///
/// When `class_angles` has exactly one entry per class it is used as the
/// lookup table. Otherwise the well-known 2-class (0/180) and 4-class
/// (0/90/180/270) layouts are assumed, and anything else falls back to the
/// raw class index.
fn class_to_angle(num_classes: usize, class_idx: usize, class_angles: &[i32]) -> i32 {
    if class_angles.len() == num_classes {
        if let Some(&angle) = class_angles.get(class_idx) {
            return angle;
        }
        return class_idx as i32;
    }

    const FOUR_WAY: [i32; 4] = [0, 90, 180, 270];
    match (num_classes, class_idx) {
        (2, 0) => 0,
        (2, _) => 180,
        (4, idx) if idx < FOUR_WAY.len() => FOUR_WAY[idx],
        (_, idx) => idx as i32,
    }
}
282
/// Numerically stable softmax over a slice of logits.
///
/// Returns an empty vector for empty input, and an all-zero vector if the
/// exponential sum vanishes.
fn softmax(scores: &[f32]) -> Vec<f32> {
    if scores.is_empty() {
        return Vec::new();
    }

    // Shift by the maximum logit before exponentiating to avoid overflow.
    let peak = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|&s| (s - peak).exp()).collect();
    let total: f32 = exps.iter().sum();

    if total == 0.0 {
        vec![0.0; scores.len()]
    } else {
        exps.iter().map(|&e| e / total).collect()
    }
}
298
/// Select the per-channel normalization constants for a preprocess mode:
/// doc mode reuses the detection-model params, textline mode the
/// recognition-model params.
// NOTE(review): presumably paddle_det is ImageNet-style mean/std and
// paddle_rec is 0.5/0.5 — confirm against `NormalizeParams`.
fn normalize_params_for_mode(mode: OriPreprocessMode) -> NormalizeParams {
    match mode {
        OriPreprocessMode::Doc => NormalizeParams::paddle_det(),
        OriPreprocessMode::Textline => NormalizeParams::paddle_rec(),
    }
}
305
/// Preprocess an image for orientation classification.
///
/// Produces a normalized NCHW tensor of shape
/// `(1, 3, target_height, target_width)`. Textline mode does an
/// aspect-preserving resize to `target_height` (width clamped to
/// `[1, target_width]`); Doc mode resizes the shorter side to
/// `resize_shorter` and center-crops. Any tensor area not covered by the
/// resized image stays 0.0 (zero-padding in normalized space).
///
/// # Errors
/// Returns `OcrError::PreprocessError` when either target dimension is zero.
fn preprocess_for_ori(
    img: &DynamicImage,
    target_height: u32,
    target_width: u32,
    resize_shorter: u32,
    mode: OriPreprocessMode,
    params: &NormalizeParams,
) -> OcrResult<Array4<f32>> {
    if target_height == 0 || target_width == 0 {
        return Err(OcrError::PreprocessError(
            "Target size must be greater than zero".to_string(),
        ));
    }

    let processed = match mode {
        OriPreprocessMode::Textline => {
            // Scale to the target height keeping the aspect ratio; clamp the
            // resulting width into [1, target_width].
            let (w, h) = img.dimensions();
            let ratio = w as f32 / h.max(1) as f32;
            let mut resize_w = (target_height as f32 * ratio).round() as u32;
            if resize_w == 0 {
                resize_w = 1;
            }
            if resize_w > target_width {
                resize_w = target_width;
            }

            img.resize_exact(
                resize_w,
                target_height,
                image::imageops::FilterType::Lanczos3,
            )
        }
        OriPreprocessMode::Doc => {
            // Resize so the shorter side equals `resize_shorter` (aspect
            // preserved by construction), then take a centered crop.
            let (w, h) = img.dimensions();
            let shorter = w.min(h).max(1) as f32;
            let scale = resize_shorter as f32 / shorter;
            let new_w = (w as f32 * scale).round().max(1.0) as u32;
            let new_h = (h as f32 * scale).round().max(1.0) as u32;
            let resized = img.resize_exact(new_w, new_h, image::imageops::FilterType::Lanczos3);

            if new_w < target_width || new_h < target_height {
                // Too small to crop from: stretch directly to the target size
                // (aspect ratio is not preserved in this fallback).
                resized.resize_exact(
                    target_width,
                    target_height,
                    image::imageops::FilterType::Lanczos3,
                )
            } else {
                let left = (new_w - target_width) / 2;
                let top = (new_h - target_height) / 2;
                resized.crop_imm(left, top, target_width, target_height)
            }
        }
    };

    let rgb_img = processed.to_rgb8();
    let (proc_w, proc_h) = processed.dimensions();

    // Zero-initialized tensor; pixels outside the resized image remain 0.0.
    let mut input = Array4::<f32>::zeros((1, 3, target_height as usize, target_width as usize));

    let max_y = proc_h.min(target_height) as usize;
    let max_x = proc_w.min(target_width) as usize;

    for y in 0..max_y {
        for x in 0..max_x {
            let pixel = rgb_img.get_pixel(x as u32, y as u32);
            let [r, g, b] = pixel.0;

            // Channels are written B, G, R against mean/std indices 0, 1, 2.
            // NOTE(review): this assumes `params.mean`/`params.std` are stored
            // in BGR order — confirm against `NormalizeParams`, since
            // PaddleOCR pipelines commonly define them in RGB order.
            input[[0, 0, y, x]] = (b as f32 / 255.0 - params.mean[0]) / params.std[0];
            input[[0, 1, y, x]] = (g as f32 / 255.0 - params.mean[1]) / params.std[1];
            input[[0, 2, y, x]] = (r as f32 / 255.0 - params.mean[2]) / params.std[2];
        }
    }

    Ok(input)
}
383
384/// Low-level orientation API
385impl OriModel {
386    /// Raw inference interface
387    pub fn run_raw(&self, input: ndarray::ArrayViewD<f32>) -> OcrResult<ArrayD<f32>> {
388        Ok(self.engine.run_dynamic(input)?)
389    }
390
391    /// Get model input shape
392    pub fn input_shape(&self) -> &[usize] {
393        self.engine.input_shape()
394    }
395
396    /// Get model output shape
397    pub fn output_shape(&self) -> &[usize] {
398        self.engine.output_shape()
399    }
400}
401
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ori_options_default() {
        // Defaults must match the document-orientation preset.
        let defaults = OriOptions::default();
        assert_eq!(defaults.target_height, 224);
        assert_eq!(defaults.target_width, 224);
        assert_eq!(defaults.min_score, 0.5);
        assert_eq!(defaults.resize_shorter, 256);
        assert_eq!(defaults.preprocess_mode, OriPreprocessMode::Doc);
        assert_eq!(defaults.class_angles, vec![0, 90, 180, 270]);
    }

    #[test]
    fn test_ori_options_builder() {
        // Every builder method must override exactly its own field.
        let built = OriOptions::new()
            .with_target_height(32)
            .with_target_width(128)
            .with_min_score(0.7)
            .with_resize_shorter(200)
            .with_preprocess_mode(OriPreprocessMode::Textline)
            .with_class_angles(vec![0, 180]);

        assert_eq!(built.target_height, 32);
        assert_eq!(built.target_width, 128);
        assert_eq!(built.min_score, 0.7);
        assert_eq!(built.resize_shorter, 200);
        assert_eq!(built.preprocess_mode, OriPreprocessMode::Textline);
        assert_eq!(built.class_angles, vec![0, 180]);
    }

    #[test]
    fn test_class_to_angle_mapping() {
        let pair = [0, 180];
        let quad = [0, 90, 180, 270];
        for (idx, expected) in pair.iter().enumerate() {
            assert_eq!(class_to_angle(2, idx, &pair), *expected);
        }
        for (idx, expected) in quad.iter().enumerate() {
            assert_eq!(class_to_angle(4, idx, &quad), *expected);
        }
        // Unknown class count with a mismatched table falls back to the index.
        assert_eq!(class_to_angle(3, 2, &pair), 2);
    }

    #[test]
    fn test_preprocess_for_ori_shape() {
        let blank = DynamicImage::new_rgb8(100, 32);
        let params = NormalizeParams::paddle_det();
        let tensor = preprocess_for_ori(&blank, 224, 224, 256, OriPreprocessMode::Doc, &params)
            .unwrap();
        assert_eq!(tensor.shape(), &[1, 3, 224, 224]);
    }
}