use crate::core::errors::OCRError;
use ort::{session::Session, value::ValueType};
use std::sync::Mutex;
pub mod session;
mod tensor_output;
#[path = "ort_infer_builders.rs"]
mod ort_infer_builders;
#[path = "ort_infer_config.rs"]
mod ort_infer_config;
#[path = "ort_infer_execution.rs"]
mod ort_infer_execution;
pub use ort_infer_execution::TensorInput;
pub use session::load_session;
pub use tensor_output::TensorOutput;
/// ONNX Runtime inference wrapper that owns one or more sessions for a single model.
///
/// Each [`Session`] sits behind its own [`Mutex`] so it can be locked
/// independently; see [`OrtInfer::get_session`].
///
/// Fields are private to this module (and therefore visible to its child
/// modules, such as the builder/config/execution submodules).
pub struct OrtInfer {
    /// Locked sessions; `get_session` selects one by index modulo this length.
    sessions: Vec<Mutex<Session>>,
    /// Atomic index counter.
    /// NOTE(review): presumably used to choose the next session to run on — confirm
    /// in the execution module.
    next_idx: std::sync::atomic::AtomicUsize,
    /// Input tensor name returned by `input_name()`.
    input_name: String,
    /// Filesystem path of the model (presumably the file the sessions were loaded from).
    model_path: std::path::PathBuf,
    /// Model name, shown in the `Debug` output.
    model_name: String,
}
impl std::fmt::Debug for OrtInfer {
    /// Formats the wrapper for diagnostics.
    ///
    /// Sessions are summarised by their count rather than printed, and the
    /// `next_idx` counter is omitted from the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to `OrtInfer` becomes a compile
        // error here, forcing an explicit decision about whether to print it.
        let Self {
            sessions,
            next_idx: _,
            input_name,
            model_path,
            model_name,
        } = self;
        let session_count = sessions.len();
        f.debug_struct("OrtInfer")
            .field("sessions", &session_count)
            .field("input_name", input_name)
            .field("model_path", model_path)
            .field("model_name", model_name)
            .finish()
    }
}
impl OrtInfer {
    /// Returns the input tensor name stored on this wrapper.
    pub fn input_name(&self) -> &str {
        &self.input_name
    }

    /// Locks and returns the session selected by `idx`.
    ///
    /// Any `idx` is accepted: it is reduced modulo the number of sessions, so it
    /// always maps onto an existing session when at least one is present.
    ///
    /// # Errors
    ///
    /// Returns [`OCRError::ConfigError`] if there are no sessions, or if the
    /// selected session's mutex is poisoned.
    pub fn get_session(
        &self,
        idx: usize,
    ) -> Result<std::sync::MutexGuard<'_, Session>, OCRError> {
        // `checked_rem` returns `None` for an empty pool instead of panicking with
        // a division by zero, so the failure surfaces through the `Result`.
        let slot = idx
            .checked_rem(self.sessions.len())
            .ok_or_else(|| OCRError::ConfigError {
                message: "No ONNX sessions available".to_string(),
            })?;
        self.sessions[slot]
            .lock()
            .map_err(|_| OCRError::ConfigError {
                message: "Failed to acquire session lock".to_string(),
            })
    }

    /// Returns the names of all inputs declared by the model.
    ///
    /// Metadata is read from the first session. Returns an empty vector when
    /// there are no sessions or the first session's mutex is poisoned.
    pub fn input_names_from_model(&self) -> Vec<String> {
        let Some(session_mutex) = self.sessions.first() else {
            return Vec::new();
        };
        let Ok(session_guard) = session_mutex.lock() else {
            return Vec::new();
        };
        session_guard
            .inputs()
            .iter()
            .map(|i| i.name().to_string())
            .collect()
    }

    /// Returns the declared shape of the model's first input, if it is a tensor.
    ///
    /// Metadata is read from the first session. Returns `None` when:
    /// - there are no sessions,
    /// - the first session's mutex is poisoned,
    /// - the model declares no inputs, or
    /// - the first input is not a tensor.
    ///
    /// NOTE(review): dynamic dimensions are presumably reported as negative values
    /// by `ort` — confirm before relying on it.
    pub fn primary_input_shape(&self) -> Option<Vec<i64>> {
        let session_mutex = self.sessions.first()?;
        let session_guard = session_mutex.lock().ok()?;
        let input = session_guard.inputs().first()?;
        match input.dtype() {
            ValueType::Tensor { shape, .. } => Some(shape.iter().copied().collect()),
            _ => None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::config::ModelInferenceConfig;

    /// Building from a placeholder model path with default inference config must
    /// produce an error rather than a usable `OrtInfer`.
    #[test]
    fn test_from_config_with_ort_session() {
        let config = ModelInferenceConfig::new();
        let outcome = OrtInfer::from_config(&config, "dummy_path.onnx", None);
        assert!(outcome.is_err());
    }
}