oar_ocr/core/config/builder.rs
1//! Common builder configuration types and utilities.
2
3use super::errors::{ConfigError, ConfigValidator};
4use super::onnx::OrtSessionConfig;
5use serde::{Deserialize, Serialize};
6use std::path::PathBuf;
7
8/// Common configuration for model builders.
9///
10/// This struct contains configuration options that are common across different
11/// types of model builders in the OCR pipeline.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct CommonBuilderConfig {
14 /// The path to the model file (optional).
15 pub model_path: Option<PathBuf>,
16 /// The name of the model (optional).
17 pub model_name: Option<String>,
18 /// The batch size for processing (optional).
19 pub batch_size: Option<usize>,
20 /// Whether to enable logging (optional).
21 pub enable_logging: Option<bool>,
22 /// ONNX Runtime session configuration for this model (optional)
23 #[serde(default)]
24 pub ort_session: Option<OrtSessionConfig>,
25 /// Size of the pinned session pool to allow concurrent predictions (>=1)
26 /// If None, defaults to 1 (single session)
27 #[serde(default)]
28 pub session_pool_size: Option<usize>,
29}
30
31impl CommonBuilderConfig {
32 /// Creates a new CommonBuilderConfig with default values.
33 ///
34 /// # Returns
35 ///
36 /// A new CommonBuilderConfig instance.
37 pub fn new() -> Self {
38 Self {
39 model_path: None,
40 model_name: None,
41 batch_size: None,
42 enable_logging: None,
43 ort_session: None,
44 session_pool_size: Some(1),
45 }
46 }
47
48 /// Creates a new CommonBuilderConfig with default values for model name and batch size.
49 ///
50 /// # Arguments
51 ///
52 /// * `model_name` - The name of the model (optional).
53 /// * `batch_size` - The batch size for processing (optional).
54 ///
55 /// # Returns
56 ///
57 /// A new CommonBuilderConfig instance.
58 pub fn with_defaults(model_name: Option<String>, batch_size: Option<usize>) -> Self {
59 Self {
60 model_path: None,
61 model_name,
62 batch_size,
63 enable_logging: Some(true),
64 ort_session: None,
65 session_pool_size: Some(1),
66 }
67 }
68
69 /// Creates a new CommonBuilderConfig with a model path.
70 ///
71 /// # Arguments
72 ///
73 /// * `model_path` - The path to the model file.
74 ///
75 /// # Returns
76 ///
77 /// A new CommonBuilderConfig instance.
78 pub fn with_model_path(model_path: PathBuf) -> Self {
79 Self {
80 model_path: Some(model_path),
81 model_name: None,
82 batch_size: None,
83 enable_logging: Some(true),
84 ort_session: None,
85 session_pool_size: Some(1),
86 }
87 }
88
89 /// Sets the model path for the configuration.
90 ///
91 /// # Arguments
92 ///
93 /// * `model_path` - The path to the model file.
94 ///
95 /// # Returns
96 ///
97 /// The updated CommonBuilderConfig instance.
98 pub fn model_path(mut self, model_path: impl Into<PathBuf>) -> Self {
99 self.model_path = Some(model_path.into());
100 self
101 }
102
103 /// Sets the model name for the configuration.
104 ///
105 /// # Arguments
106 ///
107 /// * `model_name` - The name of the model.
108 ///
109 /// # Returns
110 ///
111 /// The updated CommonBuilderConfig instance.
112 pub fn model_name(mut self, model_name: impl Into<String>) -> Self {
113 self.model_name = Some(model_name.into());
114 self
115 }
116
117 /// Sets the batch size for the configuration.
118 ///
119 /// # Arguments
120 ///
121 /// * `batch_size` - The batch size for processing.
122 ///
123 /// # Returns
124 ///
125 /// The updated CommonBuilderConfig instance.
126 pub fn batch_size(mut self, batch_size: usize) -> Self {
127 self.batch_size = Some(batch_size);
128 self
129 }
130
131 /// Sets whether logging is enabled for the configuration.
132 ///
133 /// # Arguments
134 ///
135 /// * `enable` - Whether to enable logging.
136 ///
137 /// # Returns
138 ///
139 /// The updated CommonBuilderConfig instance.
140 pub fn enable_logging(mut self, enable: bool) -> Self {
141 self.enable_logging = Some(enable);
142 self
143 }
144
145 /// Gets whether logging is enabled for the configuration.
146 ///
147 /// # Returns
148 ///
149 /// True if logging is enabled, false otherwise.
150 pub fn get_enable_logging(&self) -> bool {
151 self.enable_logging.unwrap_or(true)
152 }
153
154 /// Sets the ORT session configuration.
155 ///
156 /// # Arguments
157 ///
158 /// * `cfg` - The ONNX Runtime session configuration.
159 ///
160 /// # Returns
161 ///
162 /// The updated CommonBuilderConfig instance.
163 pub fn ort_session(mut self, cfg: OrtSessionConfig) -> Self {
164 self.ort_session = Some(cfg);
165 self
166 }
167
168 /// Sets the session pool size used for concurrent predictions (>=1).
169 ///
170 /// # Arguments
171 ///
172 /// * `size` - The session pool size (minimum 1).
173 ///
174 /// # Returns
175 ///
176 /// The updated CommonBuilderConfig instance.
177 pub fn session_pool_size(mut self, size: usize) -> Self {
178 self.session_pool_size = Some(size);
179 self
180 }
181
182 /// Validates the configuration.
183 ///
184 /// # Returns
185 ///
186 /// A Result indicating success or a ConfigError if validation fails.
187 pub fn validate(&self) -> Result<(), ConfigError> {
188 ConfigValidator::validate(self)
189 }
190
191 /// Merges this configuration with another configuration.
192 ///
193 /// Values from the other configuration will override values in this configuration
194 /// if they are present in the other configuration.
195 ///
196 /// # Arguments
197 ///
198 /// * `other` - The other configuration to merge with.
199 ///
200 /// # Returns
201 ///
202 /// The updated CommonBuilderConfig instance.
203 pub fn merge_with(mut self, other: &CommonBuilderConfig) -> Self {
204 if other.model_path.is_some() {
205 self.model_path = other.model_path.clone();
206 }
207 if other.model_name.is_some() {
208 self.model_name = other.model_name.clone();
209 }
210 if other.batch_size.is_some() {
211 self.batch_size = other.batch_size;
212 }
213 if other.enable_logging.is_some() {
214 self.enable_logging = other.enable_logging;
215 }
216 if other.ort_session.is_some() {
217 self.ort_session = other.ort_session.clone();
218 }
219 if other.session_pool_size.is_some() {
220 self.session_pool_size = other.session_pool_size;
221 }
222 self
223 }
224
225 /// Gets the effective batch size.
226 ///
227 /// # Returns
228 ///
229 /// The batch size, or a default value if not set.
230 pub fn get_batch_size(&self) -> usize {
231 self.batch_size.unwrap_or(1)
232 }
233
234 /// Gets the effective session pool size.
235 ///
236 /// # Returns
237 ///
238 /// The session pool size, or a default value if not set.
239 pub fn get_session_pool_size(&self) -> usize {
240 self.session_pool_size.unwrap_or(1)
241 }
242
243 /// Gets the model name.
244 ///
245 /// # Returns
246 ///
247 /// The model name, or a default value if not set.
248 pub fn get_model_name(&self) -> String {
249 self.model_name
250 .clone()
251 .unwrap_or_else(|| "unnamed_model".to_string())
252 }
253}
254
255impl ConfigValidator for CommonBuilderConfig {
256 /// Validates the configuration.
257 ///
258 /// This method checks that the batch size is valid and that the model path exists
259 /// if it is specified.
260 ///
261 /// # Returns
262 ///
263 /// A Result indicating success or a ConfigError if validation fails.
264 fn validate(&self) -> Result<(), ConfigError> {
265 if let Some(batch_size) = self.batch_size {
266 self.validate_batch_size_with_limits(batch_size, 1000)?;
267 }
268
269 if let Some(model_path) = &self.model_path {
270 self.validate_model_path(model_path)?;
271 }
272
273 if let Some(pool) = self.session_pool_size
274 && pool == 0
275 {
276 return Err(ConfigError::InvalidConfig {
277 message: "session_pool_size must be >= 1".to_string(),
278 });
279 }
280
281 Ok(())
282 }
283
284 /// Returns the default configuration.
285 ///
286 /// # Returns
287 ///
288 /// The default CommonBuilderConfig instance.
289 fn get_defaults() -> Self {
290 Self {
291 model_path: None,
292 model_name: Some("default_model".to_string()),
293 batch_size: Some(32),
294 enable_logging: Some(false),
295 ort_session: None,
296 session_pool_size: Some(1),
297 }
298 }
299}
300
301impl Default for CommonBuilderConfig {
302 /// This allows CommonBuilderConfig to be created with default values.
303 fn default() -> Self {
304 Self::new()
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Chained setters store exactly the values they were given.
    #[test]
    fn test_common_builder_config_builder_pattern() {
        let built = CommonBuilderConfig::new()
            .model_name("test_model")
            .batch_size(16)
            .enable_logging(true)
            .session_pool_size(4);

        assert_eq!(built.model_name, Some(String::from("test_model")));
        assert_eq!(built.batch_size, Some(16));
        assert_eq!(built.enable_logging, Some(true));
        assert_eq!(built.session_pool_size, Some(4));
    }

    /// Values set on the other configuration win; unset ones keep the base value.
    #[test]
    fn test_common_builder_config_merge() {
        let base = CommonBuilderConfig::new()
            .model_name("model1")
            .batch_size(8);
        let overrides = CommonBuilderConfig::new()
            .model_name("model2")
            .enable_logging(true);

        let merged = base.merge_with(&overrides);

        assert_eq!(merged.model_name, Some(String::from("model2")));
        assert_eq!(merged.batch_size, Some(8));
        assert_eq!(merged.enable_logging, Some(true));
    }

    /// Accessors return the configured values and fall back when unset.
    #[test]
    fn test_common_builder_config_getters() {
        let cfg = CommonBuilderConfig::new()
            .model_name("test")
            .batch_size(16)
            .session_pool_size(2);

        assert_eq!(cfg.get_model_name(), "test");
        assert_eq!(cfg.get_batch_size(), 16);
        assert_eq!(cfg.get_session_pool_size(), 2);
        assert!(cfg.get_enable_logging()); // Default is true
    }

    /// Validation accepts sane values and rejects a zero batch or pool size.
    #[test]
    fn test_common_builder_config_validation() {
        let accepted = CommonBuilderConfig::new()
            .batch_size(16)
            .session_pool_size(2)
            .validate();
        assert!(accepted.is_ok());

        let zero_batch = CommonBuilderConfig::new().batch_size(0).validate();
        assert!(zero_batch.is_err());

        let zero_pool = CommonBuilderConfig::new().session_pool_size(0).validate();
        assert!(zero_pool.is_err());
    }
}