oar_ocr/core/inference/
wrappers.rs

1//! Small helpers that wrap OrtInfer into concrete dimensional InferenceEngine implementations.
2
3use super::ort_infer::OrtInfer;
4use crate::core::{InferenceEngine as GInferenceEngine, OCRError, Tensor2D, Tensor3D, Tensor4D};
5
6#[derive(Debug)]
7pub struct OrtInfer2D(OrtInfer);
8
9impl OrtInfer2D {
10    /// Creates a new OrtInfer2D wrapper around an OrtInfer instance.
11    pub fn new(inner: OrtInfer) -> Self {
12        Self(inner)
13    }
14
15    /// Returns a reference to the inner OrtInfer instance.
16    pub fn inner(&self) -> &OrtInfer {
17        &self.0
18    }
19
20    /// Returns a mutable reference to the inner OrtInfer instance.
21    pub fn inner_mut(&mut self) -> &mut OrtInfer {
22        &mut self.0
23    }
24
25    /// Consumes the wrapper and returns the inner OrtInfer instance.
26    pub fn into_inner(self) -> OrtInfer {
27        self.0
28    }
29}
30
31impl From<OrtInfer> for OrtInfer2D {
32    fn from(inner: OrtInfer) -> Self {
33        Self::new(inner)
34    }
35}
36
37impl GInferenceEngine for OrtInfer2D {
38    type Input = Tensor4D;
39    type Output = Tensor2D;
40    fn infer(&self, input: &Self::Input) -> Result<Self::Output, OCRError> {
41        // Performance improvement: Pass reference instead of cloning the tensor
42        self.0.infer_2d(input)
43    }
44    fn engine_info(&self) -> String {
45        "ONNXRuntime-2D".to_string()
46    }
47}
48
49#[derive(Debug)]
50pub struct OrtInfer3D(OrtInfer);
51
52impl OrtInfer3D {
53    /// Creates a new OrtInfer3D wrapper around an OrtInfer instance.
54    pub fn new(inner: OrtInfer) -> Self {
55        Self(inner)
56    }
57
58    /// Returns a reference to the inner OrtInfer instance.
59    pub fn inner(&self) -> &OrtInfer {
60        &self.0
61    }
62
63    /// Returns a mutable reference to the inner OrtInfer instance.
64    pub fn inner_mut(&mut self) -> &mut OrtInfer {
65        &mut self.0
66    }
67
68    /// Consumes the wrapper and returns the inner OrtInfer instance.
69    pub fn into_inner(self) -> OrtInfer {
70        self.0
71    }
72}
73
74impl From<OrtInfer> for OrtInfer3D {
75    fn from(inner: OrtInfer) -> Self {
76        Self::new(inner)
77    }
78}
79
80impl GInferenceEngine for OrtInfer3D {
81    type Input = Tensor4D;
82    type Output = Tensor3D;
83    fn infer(&self, input: &Self::Input) -> Result<Self::Output, OCRError> {
84        // Performance improvement: Pass reference instead of cloning the tensor
85        self.0.infer_3d(input)
86    }
87    fn engine_info(&self) -> String {
88        "ONNXRuntime-3D".to_string()
89    }
90}
91
92#[derive(Debug)]
93pub struct OrtInfer4D(OrtInfer);
94
95impl OrtInfer4D {
96    /// Creates a new OrtInfer4D wrapper around an OrtInfer instance.
97    pub fn new(inner: OrtInfer) -> Self {
98        Self(inner)
99    }
100
101    /// Returns a reference to the inner OrtInfer instance.
102    pub fn inner(&self) -> &OrtInfer {
103        &self.0
104    }
105
106    /// Returns a mutable reference to the inner OrtInfer instance.
107    pub fn inner_mut(&mut self) -> &mut OrtInfer {
108        &mut self.0
109    }
110
111    /// Consumes the wrapper and returns the inner OrtInfer instance.
112    pub fn into_inner(self) -> OrtInfer {
113        self.0
114    }
115}
116
117impl From<OrtInfer> for OrtInfer4D {
118    fn from(inner: OrtInfer) -> Self {
119        Self::new(inner)
120    }
121}
122
123impl GInferenceEngine for OrtInfer4D {
124    type Input = Tensor4D;
125    type Output = Tensor4D;
126    fn infer(&self, input: &Self::Input) -> Result<Self::Output, OCRError> {
127        // Performance improvement: Pass reference instead of cloning the tensor
128        self.0.infer_4d(input)
129    }
130    fn engine_info(&self) -> String {
131        "ONNXRuntime-4D".to_string()
132    }
133}