1use crate::errors::{Result, VisionError};
11use serde::{Deserialize, Serialize};
12use std::path::{Path, PathBuf};
13use std::time::Duration;
14
/// Top-level configuration, with one settings section per concern.
///
/// Every section is marked `#[serde(default)]`, so a section that is missing
/// from the input is filled in from that section type's `Default` impl.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct VisionConfig {
    /// Provider selection and provider-specific options.
    #[serde(default)]
    pub provider: ProviderConfig,

    /// Result cache settings.
    #[serde(default)]
    pub cache: CacheConfig,

    /// Image preprocessing switches.
    #[serde(default)]
    pub preprocessing: PreprocessingConfig,

    /// Batch processing settings.
    #[serde(default)]
    pub batch: BatchConfig,

    /// Downloader settings.
    #[serde(default)]
    pub downloader: DownloaderConfig,
}
38
39impl VisionConfig {
40 pub fn from_yaml_file<P: AsRef<Path>>(path: P) -> Result<Self> {
42 let content = std::fs::read_to_string(path.as_ref())
43 .map_err(|e| VisionError::config(format!("Failed to read config file: {}", e)))?;
44
45 Self::from_yaml_str(&content)
46 }
47
48 pub fn from_yaml_str(content: &str) -> Result<Self> {
50 serde_yaml::from_str(content)
51 .map_err(|e| VisionError::config(format!("Failed to parse YAML config: {}", e)))
52 }
53
54 pub fn from_toml_file<P: AsRef<Path>>(path: P) -> Result<Self> {
56 let content = std::fs::read_to_string(path.as_ref())
57 .map_err(|e| VisionError::config(format!("Failed to read config file: {}", e)))?;
58
59 Self::from_toml_str(&content)
60 }
61
62 pub fn from_toml_str(content: &str) -> Result<Self> {
64 toml::from_str(content)
65 .map_err(|e| VisionError::config(format!("Failed to parse TOML config: {}", e)))
66 }
67
68 pub fn with_env_overrides(mut self) -> Self {
70 if let Ok(provider) = std::env::var("OXIFY_VISION_PROVIDER") {
72 self.provider.name = provider;
73 }
74 if let Ok(model_path) = std::env::var("OXIFY_VISION_MODEL_PATH") {
75 self.provider.model_path = Some(PathBuf::from(model_path));
76 }
77 if let Ok(use_gpu) = std::env::var("OXIFY_VISION_USE_GPU") {
78 self.provider.use_gpu = use_gpu.parse().unwrap_or(false);
79 }
80
81 if let Ok(enabled) = std::env::var("OXIFY_VISION_CACHE_ENABLED") {
83 self.cache.enabled = enabled.parse().unwrap_or(true);
84 }
85 if let Ok(max_entries) = std::env::var("OXIFY_VISION_CACHE_MAX_ENTRIES") {
86 if let Ok(n) = max_entries.parse() {
87 self.cache.max_entries = n;
88 }
89 }
90
91 if let Ok(enabled) = std::env::var("OXIFY_VISION_PREPROCESSING_ENABLED") {
93 self.preprocessing.enabled = enabled.parse().unwrap_or(false);
94 }
95
96 self
97 }
98
99 pub fn validate(&self) -> Result<()> {
101 self.provider.validate()?;
103
104 self.cache.validate()?;
106
107 self.batch.validate()?;
109
110 Ok(())
111 }
112
113 pub fn save_yaml<P: AsRef<Path>>(&self, path: P) -> Result<()> {
115 let content = serde_yaml::to_string(self)
116 .map_err(|e| VisionError::config(format!("Failed to serialize config: {}", e)))?;
117
118 std::fs::write(path.as_ref(), content)
119 .map_err(|e| VisionError::config(format!("Failed to write config file: {}", e)))
120 }
121
122 pub fn save_toml<P: AsRef<Path>>(&self, path: P) -> Result<()> {
124 let content = toml::to_string_pretty(self)
125 .map_err(|e| VisionError::config(format!("Failed to serialize config: {}", e)))?;
126
127 std::fs::write(path.as_ref(), content)
128 .map_err(|e| VisionError::config(format!("Failed to write config file: {}", e)))
129 }
130}
131
/// Settings that select a provider and configure it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// Provider identifier. `validate` accepts `mock`, `tesseract`, `surya`
    /// and `paddle`.
    pub name: String,

    /// Path to model files. `validate` requires it for the `surya` and
    /// `paddle` providers.
    pub model_path: Option<PathBuf>,

    /// Optional language setting (not interpreted in this file).
    pub language: Option<String>,

    /// GPU flag; `false` when omitted from the input.
    #[serde(default)]
    pub use_gpu: bool,

    /// GPU device id; `0` when omitted from the input.
    // NOTE(review): presumably only meaningful when `use_gpu` is set; confirm with the consumer.
    #[serde(default)]
    pub gpu_device_id: u32,
}
152
153impl Default for ProviderConfig {
154 fn default() -> Self {
155 Self {
156 name: "mock".to_string(),
157 model_path: None,
158 language: None,
159 use_gpu: false,
160 gpu_device_id: 0,
161 }
162 }
163}
164
165impl ProviderConfig {
166 pub fn validate(&self) -> Result<()> {
168 let valid_providers = ["mock", "tesseract", "surya", "paddle"];
169 if !valid_providers.contains(&self.name.as_str()) {
170 return Err(VisionError::config(format!(
171 "Invalid provider: {}. Must be one of: {}",
172 self.name,
173 valid_providers.join(", ")
174 )));
175 }
176
177 if matches!(self.name.as_str(), "surya" | "paddle") && self.model_path.is_none() {
179 return Err(VisionError::config(format!(
180 "Provider '{}' requires model_path to be set",
181 self.name
182 )));
183 }
184
185 Ok(())
186 }
187}
188
/// Cache settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Cache on/off flag; `true` when omitted from the input.
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// Cache backend name; `"memory"` when omitted. `validate` accepts
    /// `memory`, `redis` and `sqlite`.
    #[serde(default = "default_cache_type")]
    pub cache_type: String,

    /// Maximum number of entries; `1000` when omitted.
    #[serde(default = "default_cache_max_entries")]
    pub max_entries: usize,

    /// Time-to-live in seconds; `3600` when omitted. Exposed as a `Duration`
    /// through `ttl()`.
    #[serde(default = "default_cache_ttl_secs")]
    pub ttl_seconds: u64,

    /// Redis URL; `validate` requires it when `cache_type` is `redis`.
    pub redis_url: Option<String>,

    /// SQLite file path; `validate` requires it when `cache_type` is `sqlite`.
    pub sqlite_path: Option<PathBuf>,
}
214
/// Serde default helper for boolean fields that should be `true` when the
/// field is omitted from the input.
fn default_true() -> bool {
    true
}
218
/// Serde default for `CacheConfig::cache_type`: the `"memory"` backend.
fn default_cache_type() -> String {
    String::from("memory")
}
222
/// Serde default for `CacheConfig::max_entries`.
fn default_cache_max_entries() -> usize {
    1_000
}
226
/// Serde default for `CacheConfig::ttl_seconds`: one hour, in seconds.
fn default_cache_ttl_secs() -> u64 {
    60 * 60
}
230
231impl Default for CacheConfig {
232 fn default() -> Self {
233 Self {
234 enabled: true,
235 cache_type: "memory".to_string(),
236 max_entries: 1000,
237 ttl_seconds: 3600,
238 redis_url: None,
239 sqlite_path: None,
240 }
241 }
242}
243
244impl CacheConfig {
245 pub fn validate(&self) -> Result<()> {
247 let valid_types = ["memory", "redis", "sqlite"];
248 if !valid_types.contains(&self.cache_type.as_str()) {
249 return Err(VisionError::config(format!(
250 "Invalid cache type: {}. Must be one of: {}",
251 self.cache_type,
252 valid_types.join(", ")
253 )));
254 }
255
256 if self.cache_type == "redis" && self.redis_url.is_none() {
257 return Err(VisionError::config(
258 "Redis cache requires redis_url to be set",
259 ));
260 }
261
262 if self.cache_type == "sqlite" && self.sqlite_path.is_none() {
263 return Err(VisionError::config(
264 "SQLite cache requires sqlite_path to be set",
265 ));
266 }
267
268 Ok(())
269 }
270
271 pub fn ttl(&self) -> Duration {
273 Duration::from_secs(self.ttl_seconds)
274 }
275}
276
/// Image preprocessing switches.
///
/// Every boolean field is `false` when omitted from the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreprocessingConfig {
    /// Preprocessing on/off flag.
    #[serde(default)]
    pub enabled: bool,

    /// Optional maximum image dimension. `Default` and `high_quality()` both
    /// set it to `4096`.
    // NOTE(review): presumably in pixels; confirm with the consumer.
    pub max_dimension: Option<u32>,

    /// Denoise flag.
    #[serde(default)]
    pub denoise: bool,

    /// Contrast enhancement flag.
    #[serde(default)]
    pub enhance_contrast: bool,

    /// Deskew flag.
    #[serde(default)]
    pub deskew: bool,

    /// Border removal flag.
    #[serde(default)]
    pub remove_borders: bool,

    /// Grayscale flag.
    #[serde(default)]
    pub grayscale: bool,
}
307
308impl Default for PreprocessingConfig {
309 fn default() -> Self {
310 Self {
311 enabled: false,
312 max_dimension: Some(4096),
313 denoise: false,
314 enhance_contrast: false,
315 deskew: false,
316 remove_borders: false,
317 grayscale: false,
318 }
319 }
320}
321
322impl PreprocessingConfig {
323 pub fn high_quality() -> Self {
325 Self {
326 enabled: true,
327 max_dimension: Some(4096),
328 denoise: true,
329 enhance_contrast: true,
330 deskew: true,
331 remove_borders: true,
332 grayscale: true,
333 }
334 }
335}
336
/// Batch processing settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchConfig {
    /// Maximum concurrency; `num_cpus::get()` when omitted. `validate`
    /// rejects `0`.
    #[serde(default = "default_batch_concurrency")]
    pub max_concurrency: usize,

    /// Continue-on-error flag; `true` when omitted.
    #[serde(default = "default_true")]
    pub continue_on_error: bool,

    /// Progress reporting flag; `false` when omitted.
    #[serde(default)]
    pub report_progress: bool,
}
352
/// Serde default for `BatchConfig::max_concurrency`: the value reported by
/// `num_cpus::get()`.
fn default_batch_concurrency() -> usize {
    num_cpus::get()
}
356
357impl Default for BatchConfig {
358 fn default() -> Self {
359 Self {
360 max_concurrency: num_cpus::get(),
361 continue_on_error: true,
362 report_progress: false,
363 }
364 }
365}
366
367impl BatchConfig {
368 pub fn validate(&self) -> Result<()> {
370 if self.max_concurrency == 0 {
371 return Err(VisionError::config(
372 "Batch max_concurrency must be at least 1",
373 ));
374 }
375 Ok(())
376 }
377}
378
/// Downloader settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DownloaderConfig {
    /// Optional cache directory (not interpreted in this file).
    pub cache_dir: Option<PathBuf>,

    /// Checksum verification flag; `true` when omitted.
    #[serde(default = "default_true")]
    pub verify_checksums: bool,

    /// Timeout in seconds; `600` when omitted.
    #[serde(default = "default_download_timeout")]
    pub timeout_seconds: u64,

    /// Progress reporting flag; `true` when omitted.
    #[serde(default = "default_true")]
    pub report_progress: bool,
}
397
/// Serde default for `DownloaderConfig::timeout_seconds`: ten minutes, in
/// seconds.
fn default_download_timeout() -> u64 {
    10 * 60
}
401
402impl Default for DownloaderConfig {
403 fn default() -> Self {
404 Self {
405 cache_dir: None,
406 verify_checksums: true,
407 timeout_seconds: 600,
408 report_progress: true,
409 }
410 }
411}
412
/// Polling watcher that tracks the modification time of one config file.
pub struct ConfigWatcher {
    /// Path of the watched config file.
    config_path: PathBuf,
    /// Modification time seen by the last `has_changed` call that recorded
    /// one; `None` until the first successful check.
    last_modified: Option<std::time::SystemTime>,
}
418
419impl ConfigWatcher {
420 pub fn new<P: AsRef<Path>>(config_path: P) -> Self {
422 Self {
423 config_path: config_path.as_ref().to_path_buf(),
424 last_modified: None,
425 }
426 }
427
428 pub fn has_changed(&mut self) -> Result<bool> {
430 let metadata = std::fs::metadata(&self.config_path)
431 .map_err(|e| VisionError::config(format!("Failed to read config metadata: {}", e)))?;
432
433 let modified = metadata
434 .modified()
435 .map_err(|e| VisionError::config(format!("Failed to get modification time: {}", e)))?;
436
437 if let Some(last_mod) = self.last_modified {
438 if modified > last_mod {
439 self.last_modified = Some(modified);
440 return Ok(true);
441 }
442 } else {
443 self.last_modified = Some(modified);
444 }
445
446 Ok(false)
447 }
448
449 pub fn reload_if_changed(&mut self) -> Result<Option<VisionConfig>> {
451 if self.has_changed()? {
452 let ext = self
453 .config_path
454 .extension()
455 .and_then(|s| s.to_str())
456 .unwrap_or("");
457
458 let config = match ext {
459 "yaml" | "yml" => VisionConfig::from_yaml_file(&self.config_path)?,
460 "toml" => VisionConfig::from_toml_file(&self.config_path)?,
461 _ => {
462 return Err(VisionError::config(format!(
463 "Unsupported config file extension: {}",
464 ext
465 )))
466 }
467 };
468
469 config.validate()?;
470 Ok(Some(config))
471 } else {
472 Ok(None)
473 }
474 }
475}
476
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config uses the mock provider, an enabled cache and
    /// disabled preprocessing.
    #[test]
    fn test_default_config() {
        let config = VisionConfig::default();
        assert_eq!(config.provider.name, "mock");
        assert!(config.cache.enabled);
        assert!(!config.preprocessing.enabled);
    }

    /// The default config survives a YAML round trip.
    #[test]
    fn test_yaml_serialization() {
        let config = VisionConfig::default();
        let yaml = serde_yaml::to_string(&config).unwrap();
        assert!(yaml.contains("provider"));
        assert!(yaml.contains("cache"));

        let parsed: VisionConfig = serde_yaml::from_str(&yaml).unwrap();
        assert_eq!(parsed.provider.name, "mock");
    }

    /// The default config survives a TOML round trip.
    #[test]
    fn test_toml_serialization() {
        let config = VisionConfig::default();
        let toml_str = toml::to_string(&config).unwrap();
        assert!(toml_str.contains("provider"));
        assert!(toml_str.contains("cache"));

        let parsed: VisionConfig = toml::from_str(&toml_str).unwrap();
        assert_eq!(parsed.provider.name, "mock");
    }

    /// Unknown providers are rejected, and `surya` needs a model path.
    #[test]
    fn test_provider_validation() {
        let mut config = ProviderConfig::default();
        assert!(config.validate().is_ok());

        config.name = "invalid".to_string();
        assert!(config.validate().is_err());

        config.name = "surya".to_string();
        assert!(config.validate().is_err());

        config.model_path = Some(PathBuf::from("/path/to/models"));
        assert!(config.validate().is_ok());
    }

    /// Unknown cache types are rejected, and `redis` needs a URL.
    #[test]
    fn test_cache_validation() {
        let mut config = CacheConfig::default();
        assert!(config.validate().is_ok());

        config.cache_type = "invalid".to_string();
        assert!(config.validate().is_err());

        config.cache_type = "redis".to_string();
        assert!(config.validate().is_err());

        config.redis_url = Some("redis://localhost".to_string());
        assert!(config.validate().is_ok());
    }

    /// A concurrency of zero is rejected.
    #[test]
    fn test_batch_validation() {
        let mut config = BatchConfig::default();
        assert!(config.validate().is_ok());

        config.max_concurrency = 0;
        assert!(config.validate().is_err());
    }

    /// The default TTL is 3600 seconds.
    #[test]
    fn test_cache_ttl() {
        let config = CacheConfig::default();
        let ttl = config.ttl();
        assert_eq!(ttl.as_secs(), 3600);
    }

    /// The high-quality preset switches the processing steps on.
    #[test]
    fn test_preprocessing_high_quality() {
        let config = PreprocessingConfig::high_quality();
        assert!(config.enabled);
        assert!(config.denoise);
        assert!(config.enhance_contrast);
        assert!(config.deskew);
    }

    /// Environment variables override the provider name and GPU flag.
    #[test]
    fn test_env_overrides() {
        std::env::set_var("OXIFY_VISION_PROVIDER", "tesseract");
        std::env::set_var("OXIFY_VISION_USE_GPU", "true");

        let config = VisionConfig::default().with_env_overrides();

        // Restore the environment before asserting, so a failed assertion
        // cannot leak these variables into other tests in the same process.
        std::env::remove_var("OXIFY_VISION_PROVIDER");
        std::env::remove_var("OXIFY_VISION_USE_GPU");

        assert_eq!(config.provider.name, "tesseract");
        assert!(config.provider.use_gpu);
    }

    /// Checking a file that does not exist is an error.
    #[test]
    fn test_config_watcher() {
        // Use the platform temp dir and a process-unique name rather than a
        // hard-coded Unix path that might not exist, or might be occupied.
        let missing = std::env::temp_dir().join(format!(
            "oxify_vision_missing_{}.yaml",
            std::process::id()
        ));

        let mut watcher = ConfigWatcher::new(&missing);
        assert!(watcher.has_changed().is_err());
    }
}