// oar_ocr_core/core/config/builder.rs
//! Model inference configuration types and utilities.

use super::errors::{ConfigError, ConfigValidator};
use super::onnx::OrtSessionConfig;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
7
/// Configuration for model inference and ONNX Runtime session setup.
///
/// This struct provides configuration options for building ONNX inference engines,
/// including runtime settings and model metadata. Every field is optional so that
/// partial configurations can be layered on top of one another with `merge_with`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInferenceConfig {
    /// The filesystem path to the model file (optional).
    pub model_path: Option<PathBuf>,
    /// The name of the model (optional); `get_model_name` falls back to "unnamed_model".
    pub model_name: Option<String>,
    /// The batch size for processing (optional); `get_batch_size` falls back to 1.
    pub batch_size: Option<usize>,
    /// Whether to enable logging (optional); `get_enable_logging` falls back to `true`.
    pub enable_logging: Option<bool>,
    /// ONNX Runtime session configuration for this model (optional).
    // NOTE(review): `#[serde(default)]` is likely here so payloads written before this
    // field existed still deserialize — serde already treats missing `Option` fields
    // as `None`, so the attribute appears redundant but harmless; confirm intent.
    #[serde(default)]
    pub ort_session: Option<OrtSessionConfig>,
}
26
27impl ModelInferenceConfig {
28    /// Creates a new ModelInferenceConfig with default values.
29    ///
30    /// # Returns
31    ///
32    /// A new ModelInferenceConfig instance.
33    pub fn new() -> Self {
34        Self {
35            model_path: None,
36            model_name: None,
37            batch_size: None,
38            enable_logging: None,
39            ort_session: None,
40        }
41    }
42
43    /// Creates a new ModelInferenceConfig with default values for model name and batch size.
44    ///
45    /// # Arguments
46    ///
47    /// * `model_name` - The name of the model (optional).
48    /// * `batch_size` - The batch size for processing (optional).
49    ///
50    /// # Returns
51    ///
52    /// A new ModelInferenceConfig instance.
53    pub fn with_defaults(model_name: Option<String>, batch_size: Option<usize>) -> Self {
54        Self {
55            model_path: None,
56            model_name,
57            batch_size,
58            enable_logging: Some(true),
59            ort_session: None,
60        }
61    }
62
63    /// Creates a new ModelInferenceConfig with a model path.
64    ///
65    /// # Arguments
66    ///
67    /// * `model_path` - The path to the model file.
68    ///
69    /// # Returns
70    ///
71    /// A new ModelInferenceConfig instance.
72    pub fn with_model_path(model_path: PathBuf) -> Self {
73        Self {
74            model_path: Some(model_path),
75            model_name: None,
76            batch_size: None,
77            enable_logging: Some(true),
78            ort_session: None,
79        }
80    }
81
82    /// Sets the model path for the configuration.
83    ///
84    /// # Arguments
85    ///
86    /// * `model_path` - The path to the model file.
87    ///
88    /// # Returns
89    ///
90    /// The updated ModelInferenceConfig instance.
91    pub fn model_path(mut self, model_path: impl Into<PathBuf>) -> Self {
92        self.model_path = Some(model_path.into());
93        self
94    }
95
96    /// Sets the model name for the configuration.
97    ///
98    /// # Arguments
99    ///
100    /// * `model_name` - The name of the model.
101    ///
102    /// # Returns
103    ///
104    /// The updated ModelInferenceConfig instance.
105    pub fn model_name(mut self, model_name: impl Into<String>) -> Self {
106        self.model_name = Some(model_name.into());
107        self
108    }
109
110    /// Sets the batch size for the configuration.
111    ///
112    /// # Arguments
113    ///
114    /// * `batch_size` - The batch size for processing.
115    ///
116    /// # Returns
117    ///
118    /// The updated ModelInferenceConfig instance.
119    pub fn batch_size(mut self, batch_size: usize) -> Self {
120        self.batch_size = Some(batch_size);
121        self
122    }
123
124    /// Sets whether logging is enabled for the configuration.
125    ///
126    /// # Arguments
127    ///
128    /// * `enable` - Whether to enable logging.
129    ///
130    /// # Returns
131    ///
132    /// The updated ModelInferenceConfig instance.
133    pub fn enable_logging(mut self, enable: bool) -> Self {
134        self.enable_logging = Some(enable);
135        self
136    }
137
138    /// Gets whether logging is enabled for the configuration.
139    ///
140    /// # Returns
141    ///
142    /// True if logging is enabled, false otherwise.
143    pub fn get_enable_logging(&self) -> bool {
144        self.enable_logging.unwrap_or(true)
145    }
146
147    /// Sets the ORT session configuration.
148    ///
149    /// # Arguments
150    ///
151    /// * `cfg` - The ONNX Runtime session configuration.
152    ///
153    /// # Returns
154    ///
155    /// The updated ModelInferenceConfig instance.
156    pub fn ort_session(mut self, cfg: OrtSessionConfig) -> Self {
157        self.ort_session = Some(cfg);
158        self
159    }
160
161    /// Validates the configuration.
162    ///
163    /// # Returns
164    ///
165    /// A Result indicating success or a ConfigError if validation fails.
166    pub fn validate(&self) -> Result<(), ConfigError> {
167        ConfigValidator::validate(self)
168    }
169
170    /// Merges this configuration with another configuration.
171    ///
172    /// Values from the other configuration will override values in this configuration
173    /// if they are present in the other configuration.
174    ///
175    /// # Arguments
176    ///
177    /// * `other` - The other configuration to merge with.
178    ///
179    /// # Returns
180    ///
181    /// The updated ModelInferenceConfig instance.
182    pub fn merge_with(mut self, other: &ModelInferenceConfig) -> Self {
183        if other.model_path.is_some() {
184            self.model_path = other.model_path.clone();
185        }
186        if other.model_name.is_some() {
187            self.model_name = other.model_name.clone();
188        }
189        if other.batch_size.is_some() {
190            self.batch_size = other.batch_size;
191        }
192        if other.enable_logging.is_some() {
193            self.enable_logging = other.enable_logging;
194        }
195        if other.ort_session.is_some() {
196            self.ort_session = other.ort_session.clone();
197        }
198        self
199    }
200
201    /// Gets the effective batch size.
202    ///
203    /// # Returns
204    ///
205    /// The batch size, or a default value if not set.
206    pub fn get_batch_size(&self) -> usize {
207        self.batch_size.unwrap_or(1)
208    }
209
210    /// Gets the model name.
211    ///
212    /// # Returns
213    ///
214    /// The model name, or a default value if not set.
215    pub fn get_model_name(&self) -> String {
216        self.model_name
217            .clone()
218            .unwrap_or_else(|| "unnamed_model".to_string())
219    }
220}
221
222impl ConfigValidator for ModelInferenceConfig {
223    /// Validates the configuration.
224    ///
225    /// This method checks that the batch size is valid and that the model path exists
226    /// if it is specified.
227    ///
228    /// # Returns
229    ///
230    /// A Result indicating success or a ConfigError if validation fails.
231    fn validate(&self) -> Result<(), ConfigError> {
232        if let Some(batch_size) = self.batch_size {
233            self.validate_batch_size_with_limits(batch_size, 1000)?;
234        }
235
236        if let Some(model_path) = &self.model_path {
237            self.validate_model_path(model_path)?;
238        }
239
240        Ok(())
241    }
242
243    /// Returns the default configuration.
244    ///
245    /// # Returns
246    ///
247    /// The default ModelInferenceConfig instance.
248    fn get_defaults() -> Self {
249        Self {
250            model_path: None,
251            model_name: Some("default_model".to_string()),
252            batch_size: Some(32),
253            enable_logging: Some(false),
254            ort_session: None,
255        }
256    }
257}
258
259impl Default for ModelInferenceConfig {
260    /// This allows ModelInferenceConfig to be created with default values.
261    fn default() -> Self {
262        Self::new()
263    }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// The fluent builder records every value it is given.
    #[test]
    fn test_common_builder_config_builder_pattern() {
        let config = ModelInferenceConfig::new()
            .model_name("test_model")
            .batch_size(16)
            .enable_logging(true)
            .ort_session(OrtSessionConfig::default());

        assert_eq!(config.model_name.as_deref(), Some("test_model"));
        assert_eq!(config.batch_size, Some(16));
        assert_eq!(config.enable_logging, Some(true));
        assert!(config.ort_session.is_some());
    }

    /// Merging overrides only the fields set in the overlay config.
    #[test]
    fn test_common_builder_config_merge() {
        let base = ModelInferenceConfig::new()
            .model_name("model1")
            .batch_size(8);
        let overlay = ModelInferenceConfig::new()
            .model_name("model2")
            .enable_logging(true);

        let merged = base.merge_with(&overlay);
        assert_eq!(merged.model_name.as_deref(), Some("model2"));
        assert_eq!(merged.batch_size, Some(8));
        assert_eq!(merged.enable_logging, Some(true));
    }

    /// Getters return set values and documented fallbacks.
    #[test]
    fn test_common_builder_config_getters() {
        let config = ModelInferenceConfig::new()
            .model_name("test")
            .batch_size(16)
            .ort_session(OrtSessionConfig::default());

        assert_eq!(config.get_model_name(), "test");
        assert_eq!(config.get_batch_size(), 16);
        // enable_logging was never set, so the getter falls back to true.
        assert!(config.get_enable_logging());
    }

    /// Validation accepts an in-range batch size and rejects zero.
    #[test]
    fn test_common_builder_config_validation() {
        let valid = ModelInferenceConfig::new()
            .batch_size(16)
            .ort_session(OrtSessionConfig::default());
        assert!(valid.validate().is_ok());

        let zero_batch = ModelInferenceConfig::new().batch_size(0);
        assert!(zero_batch.validate().is_err());
    }
}