oar_ocr_core/core/config/builder.rs
1//! Model inference configuration types and utilities.
2
3use super::errors::{ConfigError, ConfigValidator};
4use super::onnx::OrtSessionConfig;
5use serde::{Deserialize, Serialize};
6use std::path::PathBuf;
7
8/// Configuration for model inference and ONNX Runtime session setup.
9///
10/// This struct provides configuration options for building ONNX inference engines,
11/// including runtime settings and model metadata.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ModelInferenceConfig {
14 /// The path to the model file (optional).
15 pub model_path: Option<PathBuf>,
16 /// The name of the model (optional).
17 pub model_name: Option<String>,
18 /// The batch size for processing (optional).
19 pub batch_size: Option<usize>,
20 /// Whether to enable logging (optional).
21 pub enable_logging: Option<bool>,
22 /// ONNX Runtime session configuration for this model (optional).
23 #[serde(default)]
24 pub ort_session: Option<OrtSessionConfig>,
25}
26
27impl ModelInferenceConfig {
28 /// Creates a new ModelInferenceConfig with default values.
29 ///
30 /// # Returns
31 ///
32 /// A new ModelInferenceConfig instance.
33 pub fn new() -> Self {
34 Self {
35 model_path: None,
36 model_name: None,
37 batch_size: None,
38 enable_logging: None,
39 ort_session: None,
40 }
41 }
42
43 /// Creates a new ModelInferenceConfig with default values for model name and batch size.
44 ///
45 /// # Arguments
46 ///
47 /// * `model_name` - The name of the model (optional).
48 /// * `batch_size` - The batch size for processing (optional).
49 ///
50 /// # Returns
51 ///
52 /// A new ModelInferenceConfig instance.
53 pub fn with_defaults(model_name: Option<String>, batch_size: Option<usize>) -> Self {
54 Self {
55 model_path: None,
56 model_name,
57 batch_size,
58 enable_logging: Some(true),
59 ort_session: None,
60 }
61 }
62
63 /// Creates a new ModelInferenceConfig with a model path.
64 ///
65 /// # Arguments
66 ///
67 /// * `model_path` - The path to the model file.
68 ///
69 /// # Returns
70 ///
71 /// A new ModelInferenceConfig instance.
72 pub fn with_model_path(model_path: PathBuf) -> Self {
73 Self {
74 model_path: Some(model_path),
75 model_name: None,
76 batch_size: None,
77 enable_logging: Some(true),
78 ort_session: None,
79 }
80 }
81
82 /// Sets the model path for the configuration.
83 ///
84 /// # Arguments
85 ///
86 /// * `model_path` - The path to the model file.
87 ///
88 /// # Returns
89 ///
90 /// The updated ModelInferenceConfig instance.
91 pub fn model_path(mut self, model_path: impl Into<PathBuf>) -> Self {
92 self.model_path = Some(model_path.into());
93 self
94 }
95
96 /// Sets the model name for the configuration.
97 ///
98 /// # Arguments
99 ///
100 /// * `model_name` - The name of the model.
101 ///
102 /// # Returns
103 ///
104 /// The updated ModelInferenceConfig instance.
105 pub fn model_name(mut self, model_name: impl Into<String>) -> Self {
106 self.model_name = Some(model_name.into());
107 self
108 }
109
110 /// Sets the batch size for the configuration.
111 ///
112 /// # Arguments
113 ///
114 /// * `batch_size` - The batch size for processing.
115 ///
116 /// # Returns
117 ///
118 /// The updated ModelInferenceConfig instance.
119 pub fn batch_size(mut self, batch_size: usize) -> Self {
120 self.batch_size = Some(batch_size);
121 self
122 }
123
124 /// Sets whether logging is enabled for the configuration.
125 ///
126 /// # Arguments
127 ///
128 /// * `enable` - Whether to enable logging.
129 ///
130 /// # Returns
131 ///
132 /// The updated ModelInferenceConfig instance.
133 pub fn enable_logging(mut self, enable: bool) -> Self {
134 self.enable_logging = Some(enable);
135 self
136 }
137
138 /// Gets whether logging is enabled for the configuration.
139 ///
140 /// # Returns
141 ///
142 /// True if logging is enabled, false otherwise.
143 pub fn get_enable_logging(&self) -> bool {
144 self.enable_logging.unwrap_or(true)
145 }
146
147 /// Sets the ORT session configuration.
148 ///
149 /// # Arguments
150 ///
151 /// * `cfg` - The ONNX Runtime session configuration.
152 ///
153 /// # Returns
154 ///
155 /// The updated ModelInferenceConfig instance.
156 pub fn ort_session(mut self, cfg: OrtSessionConfig) -> Self {
157 self.ort_session = Some(cfg);
158 self
159 }
160
161 /// Validates the configuration.
162 ///
163 /// # Returns
164 ///
165 /// A Result indicating success or a ConfigError if validation fails.
166 pub fn validate(&self) -> Result<(), ConfigError> {
167 ConfigValidator::validate(self)
168 }
169
170 /// Merges this configuration with another configuration.
171 ///
172 /// Values from the other configuration will override values in this configuration
173 /// if they are present in the other configuration.
174 ///
175 /// # Arguments
176 ///
177 /// * `other` - The other configuration to merge with.
178 ///
179 /// # Returns
180 ///
181 /// The updated ModelInferenceConfig instance.
182 pub fn merge_with(mut self, other: &ModelInferenceConfig) -> Self {
183 if other.model_path.is_some() {
184 self.model_path = other.model_path.clone();
185 }
186 if other.model_name.is_some() {
187 self.model_name = other.model_name.clone();
188 }
189 if other.batch_size.is_some() {
190 self.batch_size = other.batch_size;
191 }
192 if other.enable_logging.is_some() {
193 self.enable_logging = other.enable_logging;
194 }
195 if other.ort_session.is_some() {
196 self.ort_session = other.ort_session.clone();
197 }
198 self
199 }
200
201 /// Gets the effective batch size.
202 ///
203 /// # Returns
204 ///
205 /// The batch size, or a default value if not set.
206 pub fn get_batch_size(&self) -> usize {
207 self.batch_size.unwrap_or(1)
208 }
209
210 /// Gets the model name.
211 ///
212 /// # Returns
213 ///
214 /// The model name, or a default value if not set.
215 pub fn get_model_name(&self) -> String {
216 self.model_name
217 .clone()
218 .unwrap_or_else(|| "unnamed_model".to_string())
219 }
220}
221
222impl ConfigValidator for ModelInferenceConfig {
223 /// Validates the configuration.
224 ///
225 /// This method checks that the batch size is valid and that the model path exists
226 /// if it is specified.
227 ///
228 /// # Returns
229 ///
230 /// A Result indicating success or a ConfigError if validation fails.
231 fn validate(&self) -> Result<(), ConfigError> {
232 if let Some(batch_size) = self.batch_size {
233 self.validate_batch_size_with_limits(batch_size, 1000)?;
234 }
235
236 if let Some(model_path) = &self.model_path {
237 self.validate_model_path(model_path)?;
238 }
239
240 Ok(())
241 }
242
243 /// Returns the default configuration.
244 ///
245 /// # Returns
246 ///
247 /// The default ModelInferenceConfig instance.
248 fn get_defaults() -> Self {
249 Self {
250 model_path: None,
251 model_name: Some("default_model".to_string()),
252 batch_size: Some(32),
253 enable_logging: Some(false),
254 ort_session: None,
255 }
256 }
257}
258
259impl Default for ModelInferenceConfig {
260 /// This allows ModelInferenceConfig to be created with default values.
261 fn default() -> Self {
262 Self::new()
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_common_builder_config_builder_pattern() {
        let ort_cfg = OrtSessionConfig::default();
        let config = ModelInferenceConfig::new()
            .model_name("test_model")
            .batch_size(16)
            .enable_logging(true)
            .ort_session(ort_cfg);

        assert_eq!(config.model_name, Some("test_model".to_string()));
        assert_eq!(config.batch_size, Some(16));
        assert_eq!(config.enable_logging, Some(true));
        assert!(config.ort_session.is_some());
    }

    #[test]
    fn test_common_builder_config_merge() {
        let config1 = ModelInferenceConfig::new()
            .model_name("model1")
            .batch_size(8);
        let config2 = ModelInferenceConfig::new()
            .model_name("model2")
            .enable_logging(true);

        let merged = config1.merge_with(&config2);
        assert_eq!(merged.model_name, Some("model2".to_string()));
        assert_eq!(merged.batch_size, Some(8));
        assert_eq!(merged.enable_logging, Some(true));
        // Fields unset on both sides stay unset.
        assert!(merged.model_path.is_none());
        assert!(merged.ort_session.is_none());
    }

    #[test]
    fn test_common_builder_config_merge_with_empty_keeps_values() {
        let base = ModelInferenceConfig::new()
            .model_path("models/det.onnx")
            .model_name("base")
            .batch_size(4)
            .enable_logging(false)
            .ort_session(OrtSessionConfig::default());

        // An all-`None` config must not clear anything.
        let merged = base.merge_with(&ModelInferenceConfig::new());
        assert_eq!(merged.model_path, Some(PathBuf::from("models/det.onnx")));
        assert_eq!(merged.model_name, Some("base".to_string()));
        assert_eq!(merged.batch_size, Some(4));
        assert_eq!(merged.enable_logging, Some(false));
        assert!(merged.ort_session.is_some());
    }

    #[test]
    fn test_common_builder_config_getters() {
        let ort_cfg = OrtSessionConfig::default();
        let config = ModelInferenceConfig::new()
            .model_name("test")
            .batch_size(16)
            .ort_session(ort_cfg);

        assert_eq!(config.get_model_name(), "test");
        assert_eq!(config.get_batch_size(), 16);
        assert!(config.get_enable_logging()); // Default is true
    }

    #[test]
    fn test_common_builder_config_getter_fallbacks() {
        let config = ModelInferenceConfig::new();

        assert_eq!(config.get_model_name(), "unnamed_model");
        assert_eq!(config.get_batch_size(), 1);
        assert!(config.get_enable_logging());
    }

    #[test]
    fn test_common_builder_config_constructors() {
        let named = ModelInferenceConfig::with_defaults(Some("m".to_string()), Some(2));
        assert_eq!(named.model_name, Some("m".to_string()));
        assert_eq!(named.batch_size, Some(2));
        assert_eq!(named.enable_logging, Some(true));
        assert!(named.model_path.is_none());
        assert!(named.ort_session.is_none());

        let pathed = ModelInferenceConfig::with_model_path(PathBuf::from("x.onnx"));
        assert_eq!(pathed.model_path, Some(PathBuf::from("x.onnx")));
        assert!(pathed.model_name.is_none());
        assert!(pathed.batch_size.is_none());
        assert_eq!(pathed.enable_logging, Some(true));
        assert!(pathed.ort_session.is_none());
    }

    #[test]
    fn test_common_builder_config_default_matches_new() {
        let config = ModelInferenceConfig::default();

        assert!(config.model_path.is_none());
        assert!(config.model_name.is_none());
        assert!(config.batch_size.is_none());
        assert!(config.enable_logging.is_none());
        assert!(config.ort_session.is_none());
    }

    #[test]
    fn test_common_builder_config_validation() {
        let ort_cfg = OrtSessionConfig::default();
        let valid_config = ModelInferenceConfig::new()
            .batch_size(16)
            .ort_session(ort_cfg);
        assert!(valid_config.validate().is_ok());

        let invalid_batch_config = ModelInferenceConfig::new().batch_size(0);
        assert!(invalid_batch_config.validate().is_err());
    }
}
324}