// ocr_rs/ori.rs
//! Orientation classification models
//!
//! Provides document and textline orientation classification based on
//! PP-LCNet_x1_0_doc_ori and PP-LCNet_x1_0_textline_ori
4
5use image::{DynamicImage, GenericImageView};
6use ndarray::{Array4, ArrayD};
7use std::path::Path;
8
9use crate::error::{OcrError, OcrResult};
10use crate::mnn::{InferenceConfig, InferenceEngine};
11use crate::preprocess::NormalizeParams;
12
13/// Orientation preprocessing mode
/// Selects how input images are resized and normalized before classification.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OriPreprocessMode {
    /// Whole-document orientation (PP-LCNet_x1_0_doc_ori):
    /// shorter-side resize followed by a center crop to the target size.
    Doc,
    /// Textline orientation (PP-LCNet_x1_0_textline_ori):
    /// aspect-preserving resize to the target height, zero-padded on the
    /// right up to the target width.
    Textline,
}
21
22/// Orientation classification result
/// Orientation classification result
#[derive(Debug, Clone)]
pub struct OrientationResult {
    /// Index of the winning class (arg-max over `scores`)
    pub class_idx: usize,
    /// Predicted angle in degrees (best-effort mapping from `class_idx`)
    pub angle: i32,
    /// Softmax probability of the winning class
    pub confidence: f32,
    /// Softmax probabilities for every class
    pub scores: Vec<f32>,
}
34
35impl OrientationResult {
36    /// Create new orientation result
37    pub fn new(class_idx: usize, angle: i32, confidence: f32, scores: Vec<f32>) -> Self {
38        Self {
39            class_idx,
40            angle,
41            confidence,
42            scores,
43        }
44    }
45
46    /// Check if result is valid by confidence threshold
47    pub fn is_valid(&self, threshold: f32) -> bool {
48        self.confidence >= threshold
49    }
50}
51
/// Orientation model options
#[derive(Debug, Clone)]
pub struct OriOptions {
    /// Target input height in pixels (defaults: 224 doc, 48 textline)
    pub target_height: u32,
    /// Target input width in pixels (defaults: 224 doc, 192 textline)
    pub target_width: u32,
    /// Minimum confidence threshold (for caller-side filtering; not
    /// applied by `OriModel::classify` itself)
    pub min_score: f32,
    /// Shorter-side resize used only in document mode
    pub resize_shorter: u32,
    /// Preprocess mode (drives resizing strategy and normalization)
    pub preprocess_mode: OriPreprocessMode,
    /// Class index to angle mapping (must have one entry per model class
    /// to be used directly; otherwise a built-in fallback applies)
    pub class_angles: Vec<i32>,
}
68
69impl Default for OriOptions {
70    fn default() -> Self {
71        Self {
72            target_height: 224,
73            target_width: 224,
74            min_score: 0.5,
75            resize_shorter: 256,
76            preprocess_mode: OriPreprocessMode::Doc,
77            class_angles: vec![0, 90, 180, 270],
78        }
79    }
80}
81
82impl OriOptions {
83    /// Create new options
84    pub fn new() -> Self {
85        Self::default()
86    }
87
88    /// Preset for document orientation models
89    pub fn doc() -> Self {
90        Self::default()
91    }
92
93    /// Preset for textline orientation models
94    pub fn textline() -> Self {
95        Self {
96            target_height: 48,
97            target_width: 192,
98            min_score: 0.5,
99            resize_shorter: 256,
100            preprocess_mode: OriPreprocessMode::Textline,
101            class_angles: vec![0, 180],
102        }
103    }
104
105    /// Set target height
106    pub fn with_target_height(mut self, height: u32) -> Self {
107        self.target_height = height;
108        self
109    }
110
111    /// Set target width
112    pub fn with_target_width(mut self, width: u32) -> Self {
113        self.target_width = width;
114        self
115    }
116
117    /// Set minimum confidence threshold
118    pub fn with_min_score(mut self, score: f32) -> Self {
119        self.min_score = score;
120        self
121    }
122
123    /// Set resize shorter side (document mode)
124    pub fn with_resize_shorter(mut self, size: u32) -> Self {
125        self.resize_shorter = size;
126        self
127    }
128
129    /// Set preprocess mode
130    pub fn with_preprocess_mode(mut self, mode: OriPreprocessMode) -> Self {
131        self.preprocess_mode = mode;
132        self
133    }
134
135    /// Set class index to angle mapping
136    pub fn with_class_angles(mut self, angles: Vec<i32>) -> Self {
137        self.class_angles = angles;
138        self
139    }
140}
141
/// Orientation classification model (document or textline, depending on
/// the configured `OriPreprocessMode`)
pub struct OriModel {
    // Backing inference engine that executes the network
    engine: InferenceEngine,
    // Preprocessing / decoding configuration
    options: OriOptions,
    // Normalization constants; kept in sync with `options.preprocess_mode`
    // by the constructors and `with_options`
    normalize_params: NormalizeParams,
}
148
149impl OriModel {
150    /// Create orientation classifier from model file
151    pub fn from_file(
152        model_path: impl AsRef<Path>,
153        config: Option<InferenceConfig>,
154    ) -> OcrResult<Self> {
155        let engine = InferenceEngine::from_file(model_path, config)?;
156        let options = OriOptions::default();
157        let mode = options.preprocess_mode;
158        Ok(Self {
159            engine,
160            options,
161            normalize_params: normalize_params_for_mode(mode),
162        })
163    }
164
165    /// Create orientation classifier from model bytes
166    pub fn from_bytes(model_bytes: &[u8], config: Option<InferenceConfig>) -> OcrResult<Self> {
167        let engine = InferenceEngine::from_buffer(model_bytes, config)?;
168        let options = OriOptions::default();
169        let mode = options.preprocess_mode;
170        Ok(Self {
171            engine,
172            options,
173            normalize_params: normalize_params_for_mode(mode),
174        })
175    }
176
177    /// Set classifier options
178    pub fn with_options(mut self, options: OriOptions) -> Self {
179        self.options = options;
180        self.normalize_params = normalize_params_for_mode(self.options.preprocess_mode);
181        self
182    }
183
184    /// Get current options
185    pub fn options(&self) -> &OriOptions {
186        &self.options
187    }
188
189    /// Modify options
190    pub fn options_mut(&mut self) -> &mut OriOptions {
191        &mut self.options
192    }
193
194    /// Classify a single text line image
195    pub fn classify(&self, image: &DynamicImage) -> OcrResult<OrientationResult> {
196        let input = preprocess_for_ori(
197            image,
198            self.options.target_height,
199            self.options.target_width,
200            self.options.resize_shorter,
201            self.options.preprocess_mode,
202            &self.normalize_params,
203        )?;
204
205        let output = self.engine.run_dynamic(input.view().into_dyn())?;
206        self.decode_output(&output)
207    }
208
209    fn decode_output(&self, output: &ArrayD<f32>) -> OcrResult<OrientationResult> {
210        let shape = output.shape();
211        if shape.is_empty() {
212            return Err(OcrError::PostprocessError(
213                "Orientation model output shape is empty".to_string(),
214            ));
215        }
216
217        let num_classes = *shape.last().unwrap_or(&0);
218        if num_classes == 0 {
219            return Err(OcrError::PostprocessError(
220                "Orientation model output classes is zero".to_string(),
221            ));
222        }
223
224        let output_data: Vec<f32> = output.iter().cloned().collect();
225        if output_data.is_empty() {
226            return Err(OcrError::PostprocessError(
227                "Orientation model output data is empty".to_string(),
228            ));
229        }
230
231        let scores_raw = if output_data.len() >= num_classes {
232            output_data[..num_classes].to_vec()
233        } else {
234            return Err(OcrError::PostprocessError(
235                "Orientation model output data size mismatch".to_string(),
236            ));
237        };
238
239        let scores = softmax(&scores_raw);
240        let (class_idx, &confidence) = scores
241            .iter()
242            .enumerate()
243            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
244            .ok_or_else(|| {
245                OcrError::PostprocessError(
246                    "Orientation model output has no valid scores".to_string(),
247                )
248            })?;
249
250        let angle = class_to_angle(num_classes, class_idx, &self.options.class_angles);
251        Ok(OrientationResult::new(class_idx, angle, confidence, scores))
252    }
253}
254
/// Map a predicted class index to an angle in degrees (best effort).
///
/// When `class_angles` has exactly one entry per class it is used as the
/// lookup table. Otherwise the known 2-class (0/180) and 4-class
/// (0/90/180/270) layouts are handled, and anything else falls back to the
/// raw index.
fn class_to_angle(num_classes: usize, class_idx: usize, class_angles: &[i32]) -> i32 {
    if class_angles.len() == num_classes {
        if let Some(&angle) = class_angles.get(class_idx) {
            return angle;
        }
        return class_idx as i32;
    }

    match (num_classes, class_idx) {
        (2, 0) => 0,
        (2, _) => 180,
        (4, 0) => 0,
        (4, 1) => 90,
        (4, 2) => 180,
        (4, 3) => 270,
        _ => class_idx as i32,
    }
}
279
/// Numerically stable softmax over a score slice.
///
/// Returns an empty vector for empty input, and all zeros if the
/// exponential sum underflows to zero.
fn softmax(scores: &[f32]) -> Vec<f32> {
    if scores.is_empty() {
        return Vec::new();
    }

    // Shift by the maximum score so exp() cannot overflow.
    let peak = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut probs: Vec<f32> = Vec::with_capacity(scores.len());
    for &s in scores {
        probs.push((s - peak).exp());
    }

    let total: f32 = probs.iter().sum();
    if total == 0.0 {
        probs.iter_mut().for_each(|p| *p = 0.0);
    } else {
        probs.iter_mut().for_each(|p| *p /= total);
    }
    probs
}
298
299fn normalize_params_for_mode(mode: OriPreprocessMode) -> NormalizeParams {
300    match mode {
301        OriPreprocessMode::Doc => NormalizeParams::paddle_det(),
302        OriPreprocessMode::Textline => NormalizeParams::paddle_rec(),
303    }
304}
305
/// Preprocess image for orientation classification
///
/// Produces a `1 x 3 x target_height x target_width` NCHW tensor.
/// - `Textline`: aspect-preserving resize to `target_height` (width capped
///   at `target_width`); the uncovered width remains 0.0 padding.
/// - `Doc`: shorter side resized to `resize_shorter`, then a center crop to
///   the target size (or a direct, possibly distorting, resize when the
///   intermediate image is smaller than the target).
///
/// # Errors
/// Returns `OcrError::PreprocessError` when `target_height` or
/// `target_width` is zero.
fn preprocess_for_ori(
    img: &DynamicImage,
    target_height: u32,
    target_width: u32,
    resize_shorter: u32,
    mode: OriPreprocessMode,
    params: &NormalizeParams,
) -> OcrResult<Array4<f32>> {
    if target_height == 0 || target_width == 0 {
        return Err(OcrError::PreprocessError(
            "Target size must be greater than zero".to_string(),
        ));
    }

    let processed = match mode {
        OriPreprocessMode::Textline => {
            // Scale to the target height preserving aspect ratio;
            // h.max(1) guards against division by zero.
            let (w, h) = img.dimensions();
            let ratio = w as f32 / h.max(1) as f32;
            let mut resize_w = (target_height as f32 * ratio).round() as u32;
            if resize_w == 0 {
                resize_w = 1;
            }
            // Very wide lines are clamped (squashed) to the target width.
            if resize_w > target_width {
                resize_w = target_width;
            }

            img.resize_exact(
                resize_w,
                target_height,
                image::imageops::FilterType::Lanczos3,
            )
        }
        OriPreprocessMode::Doc => {
            // Classic classification pipeline: resize shorter side to
            // `resize_shorter`, then center-crop to the model input size.
            let (w, h) = img.dimensions();
            let shorter = w.min(h).max(1) as f32;
            let scale = resize_shorter as f32 / shorter;
            let new_w = (w as f32 * scale).round().max(1.0) as u32;
            let new_h = (h as f32 * scale).round().max(1.0) as u32;
            let resized = img.resize_exact(
                new_w,
                new_h,
                image::imageops::FilterType::Lanczos3,
            );

            if new_w < target_width || new_h < target_height {
                // Too small to crop from: fall back to a direct (possibly
                // aspect-distorting) resize to the target size.
                resized.resize_exact(
                    target_width,
                    target_height,
                    image::imageops::FilterType::Lanczos3,
                )
            } else {
                let left = (new_w - target_width) / 2;
                let top = (new_h - target_height) / 2;
                resized.crop_imm(left, top, target_width, target_height)
            }
        }
    };

    let rgb_img = processed.to_rgb8();
    let (proc_w, proc_h) = processed.dimensions();

    // NCHW tensor; cells never written below stay 0.0 and act as padding
    // in normalized space.
    let mut input = Array4::<f32>::zeros((
        1,
        3,
        target_height as usize,
        target_width as usize,
    ));

    // Copy only the region actually covered by the processed image.
    let max_y = proc_h.min(target_height) as usize;
    let max_x = proc_w.min(target_width) as usize;

    for y in 0..max_y {
        for x in 0..max_x {
            let pixel = rgb_img.get_pixel(x as u32, y as u32);
            let [r, g, b] = pixel.0;

            // Paddle models use BGR channel order in most preprocessing
            // pipelines. NOTE(review): assumed here, not proven by this
            // file — verify against the exported model's expected order.
            input[[0, 0, y, x]] = (b as f32 / 255.0 - params.mean[0]) / params.std[0];
            input[[0, 1, y, x]] = (g as f32 / 255.0 - params.mean[1]) / params.std[1];
            input[[0, 2, y, x]] = (r as f32 / 255.0 - params.mean[2]) / params.std[2];
        }
    }

    Ok(input)
}
392
393/// Low-level orientation API
394impl OriModel {
395    /// Raw inference interface
396    pub fn run_raw(&self, input: ndarray::ArrayViewD<f32>) -> OcrResult<ArrayD<f32>> {
397        Ok(self.engine.run_dynamic(input)?)
398    }
399
400    /// Get model input shape
401    pub fn input_shape(&self) -> &[usize] {
402        self.engine.input_shape()
403    }
404
405    /// Get model output shape
406    pub fn output_shape(&self) -> &[usize] {
407        self.engine.output_shape()
408    }
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    /// Default options must match the document-orientation preset.
    #[test]
    fn test_ori_options_default() {
        let defaults = OriOptions::default();
        assert_eq!(defaults.target_height, 224);
        assert_eq!(defaults.target_width, 224);
        assert_eq!(defaults.min_score, 0.5);
        assert_eq!(defaults.resize_shorter, 256);
        assert_eq!(defaults.preprocess_mode, OriPreprocessMode::Doc);
        assert_eq!(defaults.class_angles, [0, 90, 180, 270]);
    }

    /// Every builder method must override its field and leave the rest.
    #[test]
    fn test_ori_options_builder() {
        let opts = OriOptions::new()
            .with_target_height(32)
            .with_target_width(128)
            .with_min_score(0.7)
            .with_resize_shorter(200)
            .with_preprocess_mode(OriPreprocessMode::Textline)
            .with_class_angles(vec![0, 180]);

        assert_eq!(
            (opts.target_height, opts.target_width, opts.resize_shorter),
            (32, 128, 200)
        );
        assert_eq!(opts.min_score, 0.7);
        assert_eq!(opts.preprocess_mode, OriPreprocessMode::Textline);
        assert_eq!(opts.class_angles, [0, 180]);
    }

    /// Table lookups and the index fallback for unknown layouts.
    #[test]
    fn test_class_to_angle_mapping() {
        let two_way = [0, 180];
        let four_way = [0, 90, 180, 270];

        for (idx, expected) in two_way.iter().enumerate() {
            assert_eq!(class_to_angle(2, idx, &two_way), *expected);
        }
        for (idx, expected) in four_way.iter().enumerate() {
            assert_eq!(class_to_angle(4, idx, &four_way), *expected);
        }
        // Mismatched table length with an unknown class count falls back
        // to the raw index.
        assert_eq!(class_to_angle(3, 2, &two_way), 2);
    }

    /// Doc-mode preprocessing always emits a full NCHW target tensor.
    #[test]
    fn test_preprocess_for_ori_shape() {
        let img = DynamicImage::new_rgb8(100, 32);
        let tensor = preprocess_for_ori(
            &img,
            224,
            224,
            256,
            OriPreprocessMode::Doc,
            &NormalizeParams::paddle_det(),
        )
        .unwrap();
        assert_eq!(tensor.shape(), &[1, 3, 224, 224]);
    }
}