oxify_connect_vision/
config.rs

1//! Configuration management for vision/OCR system.
2//!
3//! This module provides:
4//! - YAML/TOML configuration file support
5//! - Environment variable overrides
6//! - Configuration validation
7//! - Hot-reload capability
8//! - Default configurations for common scenarios
9
10use crate::errors::{Result, VisionError};
11use serde::{Deserialize, Serialize};
12use std::path::{Path, PathBuf};
13use std::time::Duration;
14
15/// Main vision configuration.
16#[derive(Debug, Clone, Serialize, Deserialize, Default)]
17pub struct VisionConfig {
18    /// Provider configuration
19    #[serde(default)]
20    pub provider: ProviderConfig,
21
22    /// Cache configuration
23    #[serde(default)]
24    pub cache: CacheConfig,
25
26    /// Preprocessing configuration
27    #[serde(default)]
28    pub preprocessing: PreprocessingConfig,
29
30    /// Batch processing configuration
31    #[serde(default)]
32    pub batch: BatchConfig,
33
34    /// Model download configuration
35    #[serde(default)]
36    pub downloader: DownloaderConfig,
37}
38
39impl VisionConfig {
40    /// Load configuration from a YAML file.
41    pub fn from_yaml_file<P: AsRef<Path>>(path: P) -> Result<Self> {
42        let content = std::fs::read_to_string(path.as_ref())
43            .map_err(|e| VisionError::config(format!("Failed to read config file: {}", e)))?;
44
45        Self::from_yaml_str(&content)
46    }
47
48    /// Load configuration from a YAML string.
49    pub fn from_yaml_str(content: &str) -> Result<Self> {
50        serde_yaml::from_str(content)
51            .map_err(|e| VisionError::config(format!("Failed to parse YAML config: {}", e)))
52    }
53
54    /// Load configuration from a TOML file.
55    pub fn from_toml_file<P: AsRef<Path>>(path: P) -> Result<Self> {
56        let content = std::fs::read_to_string(path.as_ref())
57            .map_err(|e| VisionError::config(format!("Failed to read config file: {}", e)))?;
58
59        Self::from_toml_str(&content)
60    }
61
62    /// Load configuration from a TOML string.
63    pub fn from_toml_str(content: &str) -> Result<Self> {
64        toml::from_str(content)
65            .map_err(|e| VisionError::config(format!("Failed to parse TOML config: {}", e)))
66    }
67
68    /// Load configuration with environment variable overrides.
69    pub fn with_env_overrides(mut self) -> Self {
70        // Provider overrides
71        if let Ok(provider) = std::env::var("OXIFY_VISION_PROVIDER") {
72            self.provider.name = provider;
73        }
74        if let Ok(model_path) = std::env::var("OXIFY_VISION_MODEL_PATH") {
75            self.provider.model_path = Some(PathBuf::from(model_path));
76        }
77        if let Ok(use_gpu) = std::env::var("OXIFY_VISION_USE_GPU") {
78            self.provider.use_gpu = use_gpu.parse().unwrap_or(false);
79        }
80
81        // Cache overrides
82        if let Ok(enabled) = std::env::var("OXIFY_VISION_CACHE_ENABLED") {
83            self.cache.enabled = enabled.parse().unwrap_or(true);
84        }
85        if let Ok(max_entries) = std::env::var("OXIFY_VISION_CACHE_MAX_ENTRIES") {
86            if let Ok(n) = max_entries.parse() {
87                self.cache.max_entries = n;
88            }
89        }
90
91        // Preprocessing overrides
92        if let Ok(enabled) = std::env::var("OXIFY_VISION_PREPROCESSING_ENABLED") {
93            self.preprocessing.enabled = enabled.parse().unwrap_or(false);
94        }
95
96        self
97    }
98
99    /// Validate configuration.
100    pub fn validate(&self) -> Result<()> {
101        // Validate provider
102        self.provider.validate()?;
103
104        // Validate cache
105        self.cache.validate()?;
106
107        // Validate batch
108        self.batch.validate()?;
109
110        Ok(())
111    }
112
113    /// Save configuration to YAML file.
114    pub fn save_yaml<P: AsRef<Path>>(&self, path: P) -> Result<()> {
115        let content = serde_yaml::to_string(self)
116            .map_err(|e| VisionError::config(format!("Failed to serialize config: {}", e)))?;
117
118        std::fs::write(path.as_ref(), content)
119            .map_err(|e| VisionError::config(format!("Failed to write config file: {}", e)))
120    }
121
122    /// Save configuration to TOML file.
123    pub fn save_toml<P: AsRef<Path>>(&self, path: P) -> Result<()> {
124        let content = toml::to_string_pretty(self)
125            .map_err(|e| VisionError::config(format!("Failed to serialize config: {}", e)))?;
126
127        std::fs::write(path.as_ref(), content)
128            .map_err(|e| VisionError::config(format!("Failed to write config file: {}", e)))
129    }
130}
131
132/// Provider configuration.
133#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct ProviderConfig {
135    /// Provider name (mock, tesseract, surya, paddle)
136    pub name: String,
137
138    /// Model path for ONNX providers
139    pub model_path: Option<PathBuf>,
140
141    /// Language code (e.g., "en", "ja", "zh")
142    pub language: Option<String>,
143
144    /// Use GPU acceleration
145    #[serde(default)]
146    pub use_gpu: bool,
147
148    /// GPU device ID
149    #[serde(default)]
150    pub gpu_device_id: u32,
151}
152
153impl Default for ProviderConfig {
154    fn default() -> Self {
155        Self {
156            name: "mock".to_string(),
157            model_path: None,
158            language: None,
159            use_gpu: false,
160            gpu_device_id: 0,
161        }
162    }
163}
164
165impl ProviderConfig {
166    /// Validate provider configuration.
167    pub fn validate(&self) -> Result<()> {
168        let valid_providers = ["mock", "tesseract", "surya", "paddle"];
169        if !valid_providers.contains(&self.name.as_str()) {
170            return Err(VisionError::config(format!(
171                "Invalid provider: {}. Must be one of: {}",
172                self.name,
173                valid_providers.join(", ")
174            )));
175        }
176
177        // ONNX providers require model_path
178        if matches!(self.name.as_str(), "surya" | "paddle") && self.model_path.is_none() {
179            return Err(VisionError::config(format!(
180                "Provider '{}' requires model_path to be set",
181                self.name
182            )));
183        }
184
185        Ok(())
186    }
187}
188
189/// Cache configuration.
190#[derive(Debug, Clone, Serialize, Deserialize)]
191pub struct CacheConfig {
192    /// Enable caching
193    #[serde(default = "default_true")]
194    pub enabled: bool,
195
196    /// Cache type (memory, redis, sqlite)
197    #[serde(default = "default_cache_type")]
198    pub cache_type: String,
199
200    /// Maximum cache entries (for memory cache)
201    #[serde(default = "default_cache_max_entries")]
202    pub max_entries: usize,
203
204    /// Cache TTL in seconds
205    #[serde(default = "default_cache_ttl_secs")]
206    pub ttl_seconds: u64,
207
208    /// Redis URL (if using redis cache)
209    pub redis_url: Option<String>,
210
211    /// SQLite database path (if using sqlite cache)
212    pub sqlite_path: Option<PathBuf>,
213}
214
/// Serde default helper for boolean fields that should be on when omitted.
fn default_true() -> bool {
    true
}
218
/// Serde default helper: the cache backend used when none is specified.
fn default_cache_type() -> String {
    String::from("memory")
}
222
/// Serde default helper: entry limit of the cache when none is specified.
fn default_cache_max_entries() -> usize {
    1_000
}
226
/// Serde default helper: cache entry lifetime in seconds (one hour).
fn default_cache_ttl_secs() -> u64 {
    60 * 60
}
230
231impl Default for CacheConfig {
232    fn default() -> Self {
233        Self {
234            enabled: true,
235            cache_type: "memory".to_string(),
236            max_entries: 1000,
237            ttl_seconds: 3600,
238            redis_url: None,
239            sqlite_path: None,
240        }
241    }
242}
243
244impl CacheConfig {
245    /// Validate cache configuration.
246    pub fn validate(&self) -> Result<()> {
247        let valid_types = ["memory", "redis", "sqlite"];
248        if !valid_types.contains(&self.cache_type.as_str()) {
249            return Err(VisionError::config(format!(
250                "Invalid cache type: {}. Must be one of: {}",
251                self.cache_type,
252                valid_types.join(", ")
253            )));
254        }
255
256        if self.cache_type == "redis" && self.redis_url.is_none() {
257            return Err(VisionError::config(
258                "Redis cache requires redis_url to be set",
259            ));
260        }
261
262        if self.cache_type == "sqlite" && self.sqlite_path.is_none() {
263            return Err(VisionError::config(
264                "SQLite cache requires sqlite_path to be set",
265            ));
266        }
267
268        Ok(())
269    }
270
271    /// Get cache TTL as Duration.
272    pub fn ttl(&self) -> Duration {
273        Duration::from_secs(self.ttl_seconds)
274    }
275}
276
277/// Preprocessing configuration.
278#[derive(Debug, Clone, Serialize, Deserialize)]
279pub struct PreprocessingConfig {
280    /// Enable preprocessing
281    #[serde(default)]
282    pub enabled: bool,
283
284    /// Maximum image dimension
285    pub max_dimension: Option<u32>,
286
287    /// Apply noise reduction
288    #[serde(default)]
289    pub denoise: bool,
290
291    /// Enhance contrast
292    #[serde(default)]
293    pub enhance_contrast: bool,
294
295    /// Apply deskewing
296    #[serde(default)]
297    pub deskew: bool,
298
299    /// Remove borders
300    #[serde(default)]
301    pub remove_borders: bool,
302
303    /// Convert to grayscale
304    #[serde(default)]
305    pub grayscale: bool,
306}
307
308impl Default for PreprocessingConfig {
309    fn default() -> Self {
310        Self {
311            enabled: false,
312            max_dimension: Some(4096),
313            denoise: false,
314            enhance_contrast: false,
315            deskew: false,
316            remove_borders: false,
317            grayscale: false,
318        }
319    }
320}
321
322impl PreprocessingConfig {
323    /// Create a high-quality preprocessing config.
324    pub fn high_quality() -> Self {
325        Self {
326            enabled: true,
327            max_dimension: Some(4096),
328            denoise: true,
329            enhance_contrast: true,
330            deskew: true,
331            remove_borders: true,
332            grayscale: true,
333        }
334    }
335}
336
337/// Batch processing configuration.
338#[derive(Debug, Clone, Serialize, Deserialize)]
339pub struct BatchConfig {
340    /// Maximum concurrent operations
341    #[serde(default = "default_batch_concurrency")]
342    pub max_concurrency: usize,
343
344    /// Continue on error
345    #[serde(default = "default_true")]
346    pub continue_on_error: bool,
347
348    /// Report progress
349    #[serde(default)]
350    pub report_progress: bool,
351}
352
353fn default_batch_concurrency() -> usize {
354    num_cpus::get()
355}
356
357impl Default for BatchConfig {
358    fn default() -> Self {
359        Self {
360            max_concurrency: num_cpus::get(),
361            continue_on_error: true,
362            report_progress: false,
363        }
364    }
365}
366
367impl BatchConfig {
368    /// Validate batch configuration.
369    pub fn validate(&self) -> Result<()> {
370        if self.max_concurrency == 0 {
371            return Err(VisionError::config(
372                "Batch max_concurrency must be at least 1",
373            ));
374        }
375        Ok(())
376    }
377}
378
379/// Model downloader configuration.
380#[derive(Debug, Clone, Serialize, Deserialize)]
381pub struct DownloaderConfig {
382    /// Cache directory for models
383    pub cache_dir: Option<PathBuf>,
384
385    /// Verify checksums
386    #[serde(default = "default_true")]
387    pub verify_checksums: bool,
388
389    /// Download timeout in seconds
390    #[serde(default = "default_download_timeout")]
391    pub timeout_seconds: u64,
392
393    /// Report progress
394    #[serde(default = "default_true")]
395    pub report_progress: bool,
396}
397
/// Serde default helper: download timeout in seconds (ten minutes).
fn default_download_timeout() -> u64 {
    10 * 60
}
401
402impl Default for DownloaderConfig {
403    fn default() -> Self {
404        Self {
405            cache_dir: None,
406            verify_checksums: true,
407            timeout_seconds: 600,
408            report_progress: true,
409        }
410    }
411}
412
/// Configuration file watcher for hot-reload.
///
/// The watcher is poll-based: it remembers the modification time it last
/// observed for the file and compares against it on each check.
#[derive(Debug)]
pub struct ConfigWatcher {
    /// Path of the configuration file being watched.
    config_path: PathBuf,
    /// Modification time seen by the most recent successful check, or `None`
    /// if the file has not been checked yet.
    last_modified: Option<std::time::SystemTime>,
}
418
419impl ConfigWatcher {
420    /// Create a new configuration watcher.
421    pub fn new<P: AsRef<Path>>(config_path: P) -> Self {
422        Self {
423            config_path: config_path.as_ref().to_path_buf(),
424            last_modified: None,
425        }
426    }
427
428    /// Check if configuration file has been modified.
429    pub fn has_changed(&mut self) -> Result<bool> {
430        let metadata = std::fs::metadata(&self.config_path)
431            .map_err(|e| VisionError::config(format!("Failed to read config metadata: {}", e)))?;
432
433        let modified = metadata
434            .modified()
435            .map_err(|e| VisionError::config(format!("Failed to get modification time: {}", e)))?;
436
437        if let Some(last_mod) = self.last_modified {
438            if modified > last_mod {
439                self.last_modified = Some(modified);
440                return Ok(true);
441            }
442        } else {
443            self.last_modified = Some(modified);
444        }
445
446        Ok(false)
447    }
448
449    /// Reload configuration if changed.
450    pub fn reload_if_changed(&mut self) -> Result<Option<VisionConfig>> {
451        if self.has_changed()? {
452            let ext = self
453                .config_path
454                .extension()
455                .and_then(|s| s.to_str())
456                .unwrap_or("");
457
458            let config = match ext {
459                "yaml" | "yml" => VisionConfig::from_yaml_file(&self.config_path)?,
460                "toml" => VisionConfig::from_toml_file(&self.config_path)?,
461                _ => {
462                    return Err(VisionError::config(format!(
463                        "Unsupported config file extension: {}",
464                        ext
465                    )))
466                }
467            };
468
469            config.validate()?;
470            Ok(Some(config))
471        } else {
472            Ok(None)
473        }
474    }
475}
476
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults: mock provider, cache on, preprocessing off.
    #[test]
    fn test_default_config() {
        let VisionConfig {
            provider,
            cache,
            preprocessing,
            ..
        } = VisionConfig::default();

        assert_eq!(provider.name, "mock");
        assert!(cache.enabled);
        assert!(!preprocessing.enabled);
    }

    /// The default configuration survives a YAML round trip.
    #[test]
    fn test_yaml_serialization() {
        let yaml = serde_yaml::to_string(&VisionConfig::default()).unwrap();
        for section in ["provider", "cache"] {
            assert!(yaml.contains(section));
        }

        let parsed: VisionConfig = serde_yaml::from_str(&yaml).unwrap();
        assert_eq!(parsed.provider.name, "mock");
    }

    /// The default configuration survives a TOML round trip.
    #[test]
    fn test_toml_serialization() {
        let toml_str = toml::to_string(&VisionConfig::default()).unwrap();
        for section in ["provider", "cache"] {
            assert!(toml_str.contains(section));
        }

        let parsed: VisionConfig = toml::from_str(&toml_str).unwrap();
        assert_eq!(parsed.provider.name, "mock");
    }

    /// Unknown providers and ONNX providers without a model path are rejected.
    #[test]
    fn test_provider_validation() {
        let mut provider = ProviderConfig::default();
        assert!(provider.validate().is_ok());

        provider.name = String::from("invalid");
        assert!(provider.validate().is_err());

        // `surya` needs a model path before it validates.
        provider.name = String::from("surya");
        assert!(provider.validate().is_err());

        provider.model_path = Some(PathBuf::from("/path/to/models"));
        assert!(provider.validate().is_ok());
    }

    /// Unknown cache types and a redis cache without a URL are rejected.
    #[test]
    fn test_cache_validation() {
        let mut cache = CacheConfig::default();
        assert!(cache.validate().is_ok());

        cache.cache_type = String::from("invalid");
        assert!(cache.validate().is_err());

        // `redis` needs a URL before it validates.
        cache.cache_type = String::from("redis");
        assert!(cache.validate().is_err());

        cache.redis_url = Some(String::from("redis://localhost"));
        assert!(cache.validate().is_ok());
    }

    /// A concurrency of zero is rejected.
    #[test]
    fn test_batch_validation() {
        let mut batch = BatchConfig::default();
        assert!(batch.validate().is_ok());

        batch.max_concurrency = 0;
        assert!(batch.validate().is_err());
    }

    /// The default TTL is one hour.
    #[test]
    fn test_cache_ttl() {
        let ttl = CacheConfig::default().ttl();
        assert_eq!(ttl.as_secs(), 3600);
    }

    /// The high-quality preset enables preprocessing and its steps.
    #[test]
    fn test_preprocessing_high_quality() {
        let preset = PreprocessingConfig::high_quality();
        assert!(preset.enabled);
        assert!(preset.denoise);
        assert!(preset.enhance_contrast);
        assert!(preset.deskew);
    }

    /// Environment variables take precedence over the defaults.
    #[test]
    fn test_env_overrides() {
        const PROVIDER_VAR: &str = "OXIFY_VISION_PROVIDER";
        const GPU_VAR: &str = "OXIFY_VISION_USE_GPU";

        std::env::set_var(PROVIDER_VAR, "tesseract");
        std::env::set_var(GPU_VAR, "true");

        let config = VisionConfig::default().with_env_overrides();
        assert_eq!(config.provider.name, "tesseract");
        assert!(config.provider.use_gpu);

        std::env::remove_var(PROVIDER_VAR);
        std::env::remove_var(GPU_VAR);
    }

    /// Checking a file that does not exist reports an error.
    #[test]
    fn test_config_watcher() {
        let mut watcher = ConfigWatcher::new("/tmp/nonexistent.yaml");
        let outcome = watcher.has_changed();
        assert!(outcome.is_err());
    }
}