// oar_ocr_core/core/inference/mod.rs
//! Structures and helpers for ONNX Runtime inference.
//!
//! This module centralizes the low level inference engine along with thin wrappers
//! that adapt it to the `InferenceEngine` trait used across the pipeline.

use crate::core::{
    batch::{Tensor2D, Tensor3D, Tensor4D},
    errors::OCRError,
};
use ort::{session::Session, value::ValueType};
use std::sync::Mutex;

pub mod session;
pub mod wrappers;

// OrtInfer implementation modules
#[path = "ort_infer_builders.rs"]
mod ort_infer_builders;
#[path = "ort_infer_config.rs"]
mod ort_infer_config;
#[path = "ort_infer_execution.rs"]
mod ort_infer_execution;

pub use session::load_session;
pub use wrappers::{OrtInfer2D, OrtInfer3D, OrtInfer4D};
27/// Core ONNX Runtime inference engine with support for pooling and configurable sessions.
28pub struct OrtInfer {
29    pub(self) sessions: Vec<Mutex<Session>>,
30    pub(self) next_idx: std::sync::atomic::AtomicUsize,
31    pub(self) input_name: String,
32    pub(self) output_name: Option<String>,
33    pub(self) model_path: std::path::PathBuf,
34    pub(self) model_name: String,
35}
37impl std::fmt::Debug for OrtInfer {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        f.debug_struct("OrtInfer")
40            .field("sessions", &self.sessions.len())
41            .field("input_name", &self.input_name)
42            .field("output_name", &self.output_name)
43            .field("model_path", &self.model_path)
44            .field("model_name", &self.model_name)
45            .finish()
46    }
47}
49impl OrtInfer {
50    /// Returns the input tensor name.
51    pub fn input_name(&self) -> &str {
52        &self.input_name
53    }
54
55    /// Gets a session from the pool.
56    pub fn get_session(&self, idx: usize) -> Result<std::sync::MutexGuard<'_, Session>, OCRError> {
57        self.sessions[idx % self.sessions.len()]
58            .lock()
59            .map_err(|_| OCRError::ConfigError {
60                message: "Failed to acquire session lock".to_string(),
61            })
62    }
63
64    /// Attempts to retrieve the primary input tensor shape from the first session.
65    ///
66    /// Returns a vector of dimensions if available. Dynamic dimensions (e.g., -1) are returned as-is.
67    pub fn primary_input_shape(&self) -> Option<Vec<i64>> {
68        let session_mutex = self.sessions.first()?;
69        let session_guard = session_mutex.lock().ok()?;
70        let input = session_guard.inputs().first()?;
71        match input.dtype() {
72            ValueType::Tensor { shape, .. } => Some(shape.iter().copied().collect()),
73            _ => None,
74        }
75    }
76}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::config::ModelInferenceConfig;

    /// Building an engine from a model path that does not exist must fail.
    #[test]
    fn test_from_config_with_ort_session() {
        let config = ModelInferenceConfig::new();
        let built = OrtInfer::from_config(&config, "dummy_path.onnx", None);
        // The dummy file is never created, so construction cannot succeed.
        assert!(built.is_err());
    }
}
89}