paddle_ocr_rs/
angle_net.rs1use crate::{base_net::BaseNet, ocr_error::OcrError, ocr_result::Angle, ocr_utils::OcrUtils};
2
3use ort::{
4 inputs,
5 session::{Session, SessionOutputs},
6 value::Tensor,
7};
8
/// Per-channel mean handed to `OcrUtils::substract_mean_normalize` (all channels equal).
const MEAN_VALUES: [f32; 3] = [127.5; 3];
/// Per-channel scale handed to `OcrUtils::substract_mean_normalize` (all channels equal).
const NORM_VALUES: [f32; 3] = [1.0 / 127.5; 3];
/// Width, in pixels, every crop is resized to before inference.
const ANGLE_DST_WIDTH: u32 = 192;
/// Height, in pixels, every crop is resized to before inference.
const ANGLE_DST_HEIGHT: u32 = 48;
/// Number of leading output scores considered when picking the angle class.
const ANGLE_COLS: usize = 2;
14
15#[derive(Debug)]
16pub struct AngleNet {
17 session: Option<Session>,
18 input_names: Vec<String>,
19}
20
21impl BaseNet for AngleNet {
22 fn new() -> Self {
23 Self {
24 session: None,
25 input_names: Vec::new(),
26 }
27 }
28
29 fn set_input_names(&mut self, input_names: Vec<String>) {
30 self.input_names = input_names;
31 }
32
33 fn set_session(&mut self, session: Option<Session>) {
34 self.session = session;
35 }
36}
37
38impl AngleNet {
39 pub fn get_angles(
40 &mut self,
41 part_imgs: &[image::RgbImage],
42 do_angle: bool,
43 most_angle: bool,
44 ) -> Result<Vec<Angle>, OcrError> {
45 let mut angles = Vec::new();
46
47 if do_angle {
48 for img in part_imgs {
49 let angle = self.get_angle(img)?;
50 angles.push(angle);
51 }
52 } else {
53 angles.extend(part_imgs.iter().map(|_| Angle::default()));
54 }
55
56 if do_angle && most_angle {
57 let sum: i32 = angles.iter().map(|x| x.index).sum();
58 let half_percent = angles.len() as f32 / 2.0;
59 let most_angle_index = if (sum as f32) < half_percent { 0 } else { 1 };
60
61 for angle in angles.iter_mut() {
62 angle.index = most_angle_index;
63 }
64 }
65
66 Ok(angles)
67 }
68
69 fn get_angle(&mut self, img_src: &image::RgbImage) -> Result<Angle, OcrError> {
70 let Some(session) = &mut self.session else {
71 return Err(OcrError::SessionNotInitialized);
72 };
73
74 let angle_img = image::imageops::resize(
75 img_src,
76 ANGLE_DST_WIDTH,
77 ANGLE_DST_HEIGHT,
78 image::imageops::FilterType::Triangle,
79 );
80
81 let input_tensors =
82 OcrUtils::substract_mean_normalize(&angle_img, &MEAN_VALUES, &NORM_VALUES);
83
84 let input_tensors = Tensor::from_array(input_tensors)?;
85
86 let outputs = session.run(inputs![self.input_names[0].clone() => input_tensors])?;
87
88 let angle = Self::score_to_angle(&outputs, ANGLE_COLS)?;
89
90 Ok(angle)
91 }
92
93 fn score_to_angle(
94 output_tensor: &SessionOutputs,
95 angle_cols: usize,
96 ) -> Result<Angle, OcrError> {
97 let (_, red_data) = output_tensor.iter().next().unwrap();
98
99 let src_data: Vec<f32> = red_data.try_extract_tensor::<f32>()?.1.to_vec();
100
101 let mut angle = Angle::default();
102 let mut max_value = f32::MIN;
103 let mut angle_index = 0;
104
105 for (i, value) in src_data.iter().take(angle_cols).enumerate() {
106 if i == 0 || value > &max_value {
107 max_value = *value;
108 angle_index = i as i32;
109 }
110 }
111
112 angle.index = angle_index;
113 angle.score = max_value;
114 Ok(angle)
115 }
116}