paddle_ocr_rs/
angle_net.rs

1use crate::{base_net::BaseNet, ocr_error::OcrError, ocr_result::Angle, ocr_utils::OcrUtils};
2
3use ort::{
4    inputs,
5    session::{Session, SessionOutputs},
6    value::Tensor,
7};
8
/// Per-channel mean handed to `OcrUtils::substract_mean_normalize`
/// (presumably subtracted from each pixel before scaling — confirm against
/// that helper).
const MEAN_VALUES: [f32; 3] = [127.5; 3];
/// Per-channel normalisation factor handed to
/// `OcrUtils::substract_mean_normalize`; the reciprocal of the mean above.
const NORM_VALUES: [f32; 3] = [1.0 / 127.5; 3];
/// Width, in pixels, every crop is resized to before inference.
const ANGLE_DST_WIDTH: u32 = 192;
/// Height, in pixels, every crop is resized to before inference.
const ANGLE_DST_HEIGHT: u32 = 48;
/// Number of leading output scores inspected when choosing the angle class.
const ANGLE_COLS: usize = 2;
14
15#[derive(Debug)]
16pub struct AngleNet {
17    session: Option<Session>,
18    input_names: Vec<String>,
19}
20
21impl BaseNet for AngleNet {
22    fn new() -> Self {
23        Self {
24            session: None,
25            input_names: Vec::new(),
26        }
27    }
28
29    fn set_input_names(&mut self, input_names: Vec<String>) {
30        self.input_names = input_names;
31    }
32
33    fn set_session(&mut self, session: Option<Session>) {
34        self.session = session;
35    }
36}
37
38impl AngleNet {
39    pub fn get_angles(
40        &mut self,
41        part_imgs: &[image::RgbImage],
42        do_angle: bool,
43        most_angle: bool,
44    ) -> Result<Vec<Angle>, OcrError> {
45        let mut angles = Vec::new();
46
47        if do_angle {
48            for img in part_imgs {
49                let angle = self.get_angle(img)?;
50                angles.push(angle);
51            }
52        } else {
53            angles.extend(part_imgs.iter().map(|_| Angle::default()));
54        }
55
56        if do_angle && most_angle {
57            let sum: i32 = angles.iter().map(|x| x.index).sum();
58            let half_percent = angles.len() as f32 / 2.0;
59            let most_angle_index = if (sum as f32) < half_percent { 0 } else { 1 };
60
61            for angle in angles.iter_mut() {
62                angle.index = most_angle_index;
63            }
64        }
65
66        Ok(angles)
67    }
68
69    fn get_angle(&mut self, img_src: &image::RgbImage) -> Result<Angle, OcrError> {
70        let Some(session) = &mut self.session else {
71            return Err(OcrError::SessionNotInitialized);
72        };
73
74        let angle_img = image::imageops::resize(
75            img_src,
76            ANGLE_DST_WIDTH,
77            ANGLE_DST_HEIGHT,
78            image::imageops::FilterType::Triangle,
79        );
80
81        let input_tensors =
82            OcrUtils::substract_mean_normalize(&angle_img, &MEAN_VALUES, &NORM_VALUES);
83
84        let input_tensors = Tensor::from_array(input_tensors)?;
85
86        let outputs = session.run(inputs![self.input_names[0].clone() => input_tensors])?;
87
88        let angle = Self::score_to_angle(&outputs, ANGLE_COLS)?;
89
90        Ok(angle)
91    }
92
93    fn score_to_angle(
94        output_tensor: &SessionOutputs,
95        angle_cols: usize,
96    ) -> Result<Angle, OcrError> {
97        let (_, red_data) = output_tensor.iter().next().unwrap();
98
99        let src_data: Vec<f32> = red_data.try_extract_tensor::<f32>()?.1.to_vec();
100
101        let mut angle = Angle::default();
102        let mut max_value = f32::MIN;
103        let mut angle_index = 0;
104
105        for (i, value) in src_data.iter().take(angle_cols).enumerate() {
106            if i == 0 || value > &max_value {
107                max_value = *value;
108                angle_index = i as i32;
109            }
110        }
111
112        angle.index = angle_index;
113        angle.score = max_value;
114        Ok(angle)
115    }
116}