paddle_ocr_rs/
base_net.rs

1use ort::session::{
2    builder::{GraphOptimizationLevel, SessionBuilder},
3    Session,
4};
5
6use crate::ocr_error::OcrError;
7
8pub trait BaseNet {
9    fn new() -> Self;
10
11    fn get_session_builder(
12        &self,
13        num_thread: usize,
14        builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
15    ) -> Result<SessionBuilder, OcrError> {
16        let builder = Session::builder()?;
17        let builder = match builder_fn {
18            Some(custom) => custom(builder)?,
19            None => builder
20                .with_optimization_level(GraphOptimizationLevel::Level2)?
21                .with_intra_threads(num_thread)?
22                .with_inter_threads(num_thread)?,
23        };
24
25        Ok(builder)
26    }
27
28    fn set_input_names(&mut self, input_names: Vec<String>);
29    fn set_session(&mut self, session: Option<Session>);
30
31    fn init(&mut self, session: Session) {
32        let input_names: Vec<String> = session
33            .inputs
34            .iter()
35            .map(|input| input.name.clone())
36            .collect();
37
38        self.set_input_names(input_names);
39        self.set_session(Some(session));
40    }
41
42    fn init_model(
43        &mut self,
44        path: &str,
45        num_thread: usize,
46        builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
47    ) -> Result<(), OcrError> {
48        let session = self
49            .get_session_builder(num_thread, builder_fn)?
50            .commit_from_file(path)?;
51        self.init(session);
52
53        Ok(())
54    }
55
56    fn init_model_from_memory(
57        &mut self,
58        model_bytes: &[u8],
59        num_thread: usize,
60        builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
61    ) -> Result<(), OcrError> {
62        let session = self
63            .get_session_builder(num_thread, builder_fn)?
64            .commit_from_memory(model_bytes)?;
65
66        self.init(session);
67
68        Ok(())
69    }
70}