oar_ocr_core/core/inference/
mod.rs1use crate::core::{
7 batch::{Tensor2D, Tensor3D, Tensor4D},
8 errors::OCRError,
9};
10use ort::{session::Session, value::ValueType};
11use std::sync::Mutex;
12
13pub mod session;
14pub mod wrappers;
15
16#[path = "ort_infer_builders.rs"]
18mod ort_infer_builders;
19#[path = "ort_infer_config.rs"]
20mod ort_infer_config;
21#[path = "ort_infer_execution.rs"]
22mod ort_infer_execution;
23
24pub use session::load_session;
25pub use wrappers::{OrtInfer2D, OrtInfer3D, OrtInfer4D};
26
27pub struct OrtInfer {
29 pub(self) sessions: Vec<Mutex<Session>>,
30 pub(self) next_idx: std::sync::atomic::AtomicUsize,
31 pub(self) input_name: String,
32 pub(self) output_name: Option<String>,
33 pub(self) model_path: std::path::PathBuf,
34 pub(self) model_name: String,
35}
36
37impl std::fmt::Debug for OrtInfer {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 f.debug_struct("OrtInfer")
40 .field("sessions", &self.sessions.len())
41 .field("input_name", &self.input_name)
42 .field("output_name", &self.output_name)
43 .field("model_path", &self.model_path)
44 .field("model_name", &self.model_name)
45 .finish()
46 }
47}
48
49impl OrtInfer {
50 pub fn input_name(&self) -> &str {
52 &self.input_name
53 }
54
55 pub fn get_session(&self, idx: usize) -> Result<std::sync::MutexGuard<'_, Session>, OCRError> {
57 self.sessions[idx % self.sessions.len()]
58 .lock()
59 .map_err(|_| OCRError::ConfigError {
60 message: "Failed to acquire session lock".to_string(),
61 })
62 }
63
64 pub fn primary_input_shape(&self) -> Option<Vec<i64>> {
68 let session_mutex = self.sessions.first()?;
69 let session_guard = session_mutex.lock().ok()?;
70 let input = session_guard.inputs().first()?;
71 match input.dtype() {
72 ValueType::Tensor { shape, .. } => Some(shape.iter().copied().collect()),
73 _ => None,
74 }
75 }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::config::ModelInferenceConfig;

    /// Building an `OrtInfer` from a model path that does not exist must yield an error.
    #[test]
    fn test_from_config_with_ort_session() {
        let config = ModelInferenceConfig::new();
        let outcome = OrtInfer::from_config(&config, "dummy_path.onnx", None);
        assert!(matches!(outcome, Err(_)));
    }
}