paddle_ocr_rs/
base_net.rs1use ort::session::{
2 builder::{GraphOptimizationLevel, SessionBuilder},
3 Session,
4};
5
6use crate::ocr_error::OcrError;
7
8pub trait BaseNet {
9 fn new() -> Self;
10
11 fn get_session_builder(
12 &self,
13 num_thread: usize,
14 builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
15 ) -> Result<SessionBuilder, OcrError> {
16 let builder = Session::builder()?;
17 let builder = match builder_fn {
18 Some(custom) => custom(builder)?,
19 None => builder
20 .with_optimization_level(GraphOptimizationLevel::Level2)?
21 .with_intra_threads(num_thread)?
22 .with_inter_threads(num_thread)?,
23 };
24
25 Ok(builder)
26 }
27
28 fn set_input_names(&mut self, input_names: Vec<String>);
29 fn set_session(&mut self, session: Option<Session>);
30
31 fn init(&mut self, session: Session) {
32 let input_names: Vec<String> = session
33 .inputs
34 .iter()
35 .map(|input| input.name.clone())
36 .collect();
37
38 self.set_input_names(input_names);
39 self.set_session(Some(session));
40 }
41
42 fn init_model(
43 &mut self,
44 path: &str,
45 num_thread: usize,
46 builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
47 ) -> Result<(), OcrError> {
48 let session = self
49 .get_session_builder(num_thread, builder_fn)?
50 .commit_from_file(path)?;
51 self.init(session);
52
53 Ok(())
54 }
55
56 fn init_model_from_memory(
57 &mut self,
58 model_bytes: &[u8],
59 num_thread: usize,
60 builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
61 ) -> Result<(), OcrError> {
62 let session = self
63 .get_session_builder(num_thread, builder_fn)?
64 .commit_from_memory(model_bytes)?;
65
66 self.init(session);
67
68 Ok(())
69 }
70}