oar_ocr/core/config/
builder.rs

1//! Common builder configuration types and utilities.
2
3use super::errors::{ConfigError, ConfigValidator};
4use super::onnx::OrtSessionConfig;
5use serde::{Deserialize, Serialize};
6use std::path::PathBuf;
7
8/// Common configuration for model builders.
9///
10/// This struct contains configuration options that are common across different
11/// types of model builders in the OCR pipeline.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct CommonBuilderConfig {
14    /// The path to the model file (optional).
15    pub model_path: Option<PathBuf>,
16    /// The name of the model (optional).
17    pub model_name: Option<String>,
18    /// The batch size for processing (optional).
19    pub batch_size: Option<usize>,
20    /// Whether to enable logging (optional).
21    pub enable_logging: Option<bool>,
22    /// ONNX Runtime session configuration for this model (optional)
23    #[serde(default)]
24    pub ort_session: Option<OrtSessionConfig>,
25    /// Size of the pinned session pool to allow concurrent predictions (>=1)
26    /// If None, defaults to 1 (single session)
27    #[serde(default)]
28    pub session_pool_size: Option<usize>,
29}
30
31impl CommonBuilderConfig {
32    /// Creates a new CommonBuilderConfig with default values.
33    ///
34    /// # Returns
35    ///
36    /// A new CommonBuilderConfig instance.
37    pub fn new() -> Self {
38        Self {
39            model_path: None,
40            model_name: None,
41            batch_size: None,
42            enable_logging: None,
43            ort_session: None,
44            session_pool_size: Some(1),
45        }
46    }
47
48    /// Creates a new CommonBuilderConfig with default values for model name and batch size.
49    ///
50    /// # Arguments
51    ///
52    /// * `model_name` - The name of the model (optional).
53    /// * `batch_size` - The batch size for processing (optional).
54    ///
55    /// # Returns
56    ///
57    /// A new CommonBuilderConfig instance.
58    pub fn with_defaults(model_name: Option<String>, batch_size: Option<usize>) -> Self {
59        Self {
60            model_path: None,
61            model_name,
62            batch_size,
63            enable_logging: Some(true),
64            ort_session: None,
65            session_pool_size: Some(1),
66        }
67    }
68
69    /// Creates a new CommonBuilderConfig with a model path.
70    ///
71    /// # Arguments
72    ///
73    /// * `model_path` - The path to the model file.
74    ///
75    /// # Returns
76    ///
77    /// A new CommonBuilderConfig instance.
78    pub fn with_model_path(model_path: PathBuf) -> Self {
79        Self {
80            model_path: Some(model_path),
81            model_name: None,
82            batch_size: None,
83            enable_logging: Some(true),
84            ort_session: None,
85            session_pool_size: Some(1),
86        }
87    }
88
89    /// Sets the model path for the configuration.
90    ///
91    /// # Arguments
92    ///
93    /// * `model_path` - The path to the model file.
94    ///
95    /// # Returns
96    ///
97    /// The updated CommonBuilderConfig instance.
98    pub fn model_path(mut self, model_path: impl Into<PathBuf>) -> Self {
99        self.model_path = Some(model_path.into());
100        self
101    }
102
103    /// Sets the model name for the configuration.
104    ///
105    /// # Arguments
106    ///
107    /// * `model_name` - The name of the model.
108    ///
109    /// # Returns
110    ///
111    /// The updated CommonBuilderConfig instance.
112    pub fn model_name(mut self, model_name: impl Into<String>) -> Self {
113        self.model_name = Some(model_name.into());
114        self
115    }
116
117    /// Sets the batch size for the configuration.
118    ///
119    /// # Arguments
120    ///
121    /// * `batch_size` - The batch size for processing.
122    ///
123    /// # Returns
124    ///
125    /// The updated CommonBuilderConfig instance.
126    pub fn batch_size(mut self, batch_size: usize) -> Self {
127        self.batch_size = Some(batch_size);
128        self
129    }
130
131    /// Sets whether logging is enabled for the configuration.
132    ///
133    /// # Arguments
134    ///
135    /// * `enable` - Whether to enable logging.
136    ///
137    /// # Returns
138    ///
139    /// The updated CommonBuilderConfig instance.
140    pub fn enable_logging(mut self, enable: bool) -> Self {
141        self.enable_logging = Some(enable);
142        self
143    }
144
145    /// Gets whether logging is enabled for the configuration.
146    ///
147    /// # Returns
148    ///
149    /// True if logging is enabled, false otherwise.
150    pub fn get_enable_logging(&self) -> bool {
151        self.enable_logging.unwrap_or(true)
152    }
153
154    /// Sets the ORT session configuration.
155    ///
156    /// # Arguments
157    ///
158    /// * `cfg` - The ONNX Runtime session configuration.
159    ///
160    /// # Returns
161    ///
162    /// The updated CommonBuilderConfig instance.
163    pub fn ort_session(mut self, cfg: OrtSessionConfig) -> Self {
164        self.ort_session = Some(cfg);
165        self
166    }
167
168    /// Sets the session pool size used for concurrent predictions (>=1).
169    ///
170    /// # Arguments
171    ///
172    /// * `size` - The session pool size (minimum 1).
173    ///
174    /// # Returns
175    ///
176    /// The updated CommonBuilderConfig instance.
177    pub fn session_pool_size(mut self, size: usize) -> Self {
178        self.session_pool_size = Some(size);
179        self
180    }
181
182    /// Validates the configuration.
183    ///
184    /// # Returns
185    ///
186    /// A Result indicating success or a ConfigError if validation fails.
187    pub fn validate(&self) -> Result<(), ConfigError> {
188        ConfigValidator::validate(self)
189    }
190
191    /// Merges this configuration with another configuration.
192    ///
193    /// Values from the other configuration will override values in this configuration
194    /// if they are present in the other configuration.
195    ///
196    /// # Arguments
197    ///
198    /// * `other` - The other configuration to merge with.
199    ///
200    /// # Returns
201    ///
202    /// The updated CommonBuilderConfig instance.
203    pub fn merge_with(mut self, other: &CommonBuilderConfig) -> Self {
204        if other.model_path.is_some() {
205            self.model_path = other.model_path.clone();
206        }
207        if other.model_name.is_some() {
208            self.model_name = other.model_name.clone();
209        }
210        if other.batch_size.is_some() {
211            self.batch_size = other.batch_size;
212        }
213        if other.enable_logging.is_some() {
214            self.enable_logging = other.enable_logging;
215        }
216        if other.ort_session.is_some() {
217            self.ort_session = other.ort_session.clone();
218        }
219        if other.session_pool_size.is_some() {
220            self.session_pool_size = other.session_pool_size;
221        }
222        self
223    }
224
225    /// Gets the effective batch size.
226    ///
227    /// # Returns
228    ///
229    /// The batch size, or a default value if not set.
230    pub fn get_batch_size(&self) -> usize {
231        self.batch_size.unwrap_or(1)
232    }
233
234    /// Gets the effective session pool size.
235    ///
236    /// # Returns
237    ///
238    /// The session pool size, or a default value if not set.
239    pub fn get_session_pool_size(&self) -> usize {
240        self.session_pool_size.unwrap_or(1)
241    }
242
243    /// Gets the model name.
244    ///
245    /// # Returns
246    ///
247    /// The model name, or a default value if not set.
248    pub fn get_model_name(&self) -> String {
249        self.model_name
250            .clone()
251            .unwrap_or_else(|| "unnamed_model".to_string())
252    }
253}
254
255impl ConfigValidator for CommonBuilderConfig {
256    /// Validates the configuration.
257    ///
258    /// This method checks that the batch size is valid and that the model path exists
259    /// if it is specified.
260    ///
261    /// # Returns
262    ///
263    /// A Result indicating success or a ConfigError if validation fails.
264    fn validate(&self) -> Result<(), ConfigError> {
265        if let Some(batch_size) = self.batch_size {
266            self.validate_batch_size_with_limits(batch_size, 1000)?;
267        }
268
269        if let Some(model_path) = &self.model_path {
270            self.validate_model_path(model_path)?;
271        }
272
273        if let Some(pool) = self.session_pool_size
274            && pool == 0
275        {
276            return Err(ConfigError::InvalidConfig {
277                message: "session_pool_size must be >= 1".to_string(),
278            });
279        }
280
281        Ok(())
282    }
283
284    /// Returns the default configuration.
285    ///
286    /// # Returns
287    ///
288    /// The default CommonBuilderConfig instance.
289    fn get_defaults() -> Self {
290        Self {
291            model_path: None,
292            model_name: Some("default_model".to_string()),
293            batch_size: Some(32),
294            enable_logging: Some(false),
295            ort_session: None,
296            session_pool_size: Some(1),
297        }
298    }
299}
300
301impl Default for CommonBuilderConfig {
302    /// This allows CommonBuilderConfig to be created with default values.
303    fn default() -> Self {
304        Self::new()
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_common_builder_config_builder_pattern() {
        let config = CommonBuilderConfig::new()
            .model_name("test_model")
            .batch_size(16)
            .enable_logging(true)
            .session_pool_size(4);

        assert_eq!(config.model_name, Some("test_model".to_string()));
        assert_eq!(config.batch_size, Some(16));
        assert_eq!(config.enable_logging, Some(true));
        assert_eq!(config.session_pool_size, Some(4));
    }

    #[test]
    fn test_common_builder_config_merge() {
        let config1 = CommonBuilderConfig::new()
            .model_name("model1")
            .batch_size(8);
        let config2 = CommonBuilderConfig::new()
            .model_name("model2")
            .enable_logging(true);

        let merged = config1.merge_with(&config2);
        assert_eq!(merged.model_name, Some("model2".to_string()));
        assert_eq!(merged.batch_size, Some(8));
        assert_eq!(merged.enable_logging, Some(true));
    }

    #[test]
    fn test_common_builder_config_merge_keeps_values_unset_in_other() {
        let base = CommonBuilderConfig::new()
            .model_name("base")
            .batch_size(8)
            .session_pool_size(3);
        // `new()` sets `session_pool_size` to `Some(1)`; clear it so it does not override.
        let mut overrides = CommonBuilderConfig::new().enable_logging(false);
        overrides.session_pool_size = None;

        let merged = base.merge_with(&overrides);
        assert_eq!(merged.model_name, Some("base".to_string()));
        assert_eq!(merged.batch_size, Some(8));
        assert_eq!(merged.session_pool_size, Some(3));
        assert_eq!(merged.enable_logging, Some(false));
    }

    #[test]
    fn test_common_builder_config_getters() {
        let config = CommonBuilderConfig::new()
            .model_name("test")
            .batch_size(16)
            .session_pool_size(2);

        assert_eq!(config.get_model_name(), "test");
        assert_eq!(config.get_batch_size(), 16);
        assert_eq!(config.get_session_pool_size(), 2);
        assert!(config.get_enable_logging()); // Unset logging is reported as enabled
    }

    #[test]
    fn test_common_builder_config_unset_fallbacks() {
        let config = CommonBuilderConfig::new();

        assert_eq!(config.get_model_name(), "unnamed_model");
        assert_eq!(config.get_batch_size(), 1);
        assert_eq!(config.get_session_pool_size(), 1);
        assert!(config.get_enable_logging());
    }

    #[test]
    fn test_common_builder_config_constructors() {
        let with_defaults = CommonBuilderConfig::with_defaults(Some("m".to_string()), Some(4));
        assert_eq!(with_defaults.model_name, Some("m".to_string()));
        assert_eq!(with_defaults.batch_size, Some(4));
        assert_eq!(with_defaults.enable_logging, Some(true));
        assert_eq!(with_defaults.session_pool_size, Some(1));
        assert!(with_defaults.model_path.is_none());
        assert!(with_defaults.ort_session.is_none());

        let path = std::path::PathBuf::from("model.onnx");
        let with_path = CommonBuilderConfig::with_model_path(path.clone());
        assert_eq!(with_path.model_path, Some(path));
        assert_eq!(with_path.enable_logging, Some(true));
        assert_eq!(with_path.session_pool_size, Some(1));
        assert!(with_path.model_name.is_none());
        assert!(with_path.batch_size.is_none());
        assert!(with_path.ort_session.is_none());
    }

    #[test]
    fn test_common_builder_config_default_matches_new() {
        let config = CommonBuilderConfig::default();

        assert!(config.model_path.is_none());
        assert!(config.model_name.is_none());
        assert!(config.batch_size.is_none());
        assert!(config.enable_logging.is_none());
        assert!(config.ort_session.is_none());
        assert_eq!(config.session_pool_size, Some(1));
    }

    #[test]
    fn test_common_builder_config_validation() {
        let valid_config = CommonBuilderConfig::new()
            .batch_size(16)
            .session_pool_size(2);
        assert!(valid_config.validate().is_ok());

        let invalid_batch_config = CommonBuilderConfig::new().batch_size(0);
        assert!(invalid_batch_config.validate().is_err());

        let invalid_pool_config = CommonBuilderConfig::new().session_pool_size(0);
        assert!(invalid_pool_config.validate().is_err());
    }
}