codeprysm_config/
loader.rs

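//! Configuration loader with inheritance support.
//!
//! Starts from the built-in defaults (`PrismConfig::default()`) and merges
//! configuration from these sources, in order:
//! 1. Global config: `~/.codeprysm/config.toml`
//! 2. Local config: `.codeprysm/config.toml` (in workspace)
//! 3. CLI overrides
//!
//! Later sources take precedence over earlier ones for single-valued settings.
//! List- and map-valued settings (such as exclude/include patterns and
//! workspaces) are combined rather than replaced.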
9
10use crate::error::ConfigError;
11use crate::{ConfigOverrides, PrismConfig};
12use std::path::{Path, PathBuf};
13use tracing::{debug, trace};
14
/// File name of the configuration file inside every config directory.
const CONFIG_FILE_NAME: &str = "config.toml";

/// Directory, relative to the user's home directory, that holds the global config.
const GLOBAL_CONFIG_DIR: &str = ".codeprysm";

/// Directory, relative to a workspace root, that holds the local config.
const LOCAL_CONFIG_DIR: &str = ".codeprysm";
23
24/// Configuration loader with caching and inheritance support.
25#[derive(Debug, Clone)]
26pub struct ConfigLoader {
27    /// Global config directory (e.g., `~/.codeprysm`)
28    global_config_dir: Option<PathBuf>,
29
30    /// Cached global config
31    global_config: Option<PrismConfig>,
32}
33
34impl Default for ConfigLoader {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40impl ConfigLoader {
41    /// Create a new configuration loader.
42    ///
43    /// Automatically detects the global config directory (`~/.codeprysm`).
44    pub fn new() -> Self {
45        let global_config_dir = dirs::home_dir().map(|h| h.join(GLOBAL_CONFIG_DIR));
46
47        Self {
48            global_config_dir,
49            global_config: None,
50        }
51    }
52
53    /// Create a loader with a custom global config directory.
54    ///
55    /// Useful for testing.
56    pub fn with_global_dir(global_dir: impl Into<PathBuf>) -> Self {
57        Self {
58            global_config_dir: Some(global_dir.into()),
59            global_config: None,
60        }
61    }
62
63    /// Get the global config file path.
64    pub fn global_config_path(&self) -> Option<PathBuf> {
65        self.global_config_dir
66            .as_ref()
67            .map(|d| d.join(CONFIG_FILE_NAME))
68    }
69
70    /// Get the local config file path for a workspace.
71    pub fn local_config_path(&self, workspace_root: &Path) -> PathBuf {
72        workspace_root.join(LOCAL_CONFIG_DIR).join(CONFIG_FILE_NAME)
73    }
74
75    /// Load configuration for a workspace with optional CLI overrides.
76    ///
77    /// Merges config in order: global → local → overrides.
78    pub fn load(
79        &mut self,
80        workspace_root: &Path,
81        overrides: Option<&ConfigOverrides>,
82    ) -> Result<PrismConfig, ConfigError> {
83        // Start with default config
84        let mut config = PrismConfig::default();
85
86        // Apply global config if available
87        if let Some(global_config) = self.load_global()? {
88            config = merge_configs(config, global_config);
89        }
90
91        // Apply local config if available
92        if let Some(local_config) = self.load_local(workspace_root)? {
93            config = merge_configs(config, local_config);
94        }
95
96        // Apply CLI overrides
97        if let Some(ovr) = overrides {
98            config.apply_overrides(ovr);
99        }
100
101        Ok(config)
102    }
103
104    /// Load only the global configuration.
105    pub fn load_global(&mut self) -> Result<Option<PrismConfig>, ConfigError> {
106        // Return cached global config if available
107        if let Some(ref config) = self.global_config {
108            return Ok(Some(config.clone()));
109        }
110
111        let Some(global_path) = self.global_config_path() else {
112            debug!("No home directory found, skipping global config");
113            return Ok(None);
114        };
115
116        if !global_path.exists() {
117            trace!("Global config not found at {:?}", global_path);
118            return Ok(None);
119        }
120
121        debug!("Loading global config from {:?}", global_path);
122        let config = load_config_file(&global_path)?;
123
124        // Cache the global config
125        self.global_config = Some(config.clone());
126
127        Ok(Some(config))
128    }
129
130    /// Load only the local configuration for a workspace.
131    pub fn load_local(&self, workspace_root: &Path) -> Result<Option<PrismConfig>, ConfigError> {
132        let local_path = self.local_config_path(workspace_root);
133
134        if !local_path.exists() {
135            trace!("Local config not found at {:?}", local_path);
136            return Ok(None);
137        }
138
139        debug!("Loading local config from {:?}", local_path);
140        load_config_file(&local_path).map(Some)
141    }
142
143    /// Save configuration to the global config file.
144    pub fn save_global(&self, config: &PrismConfig) -> Result<(), ConfigError> {
145        let Some(ref global_dir) = self.global_config_dir else {
146            return Err(ConfigError::NoHomeDir);
147        };
148
149        let global_path = global_dir.join(CONFIG_FILE_NAME);
150        save_config_file(&global_path, config)
151    }
152
153    /// Save configuration to the local config file for a workspace.
154    pub fn save_local(
155        &self,
156        workspace_root: &Path,
157        config: &PrismConfig,
158    ) -> Result<(), ConfigError> {
159        let local_path = self.local_config_path(workspace_root);
160        save_config_file(&local_path, config)
161    }
162
163    /// Initialize global configuration directory.
164    ///
165    /// Creates `~/.codeprysm/config.toml` with default configuration.
166    pub fn init_global(&self) -> Result<PathBuf, ConfigError> {
167        let Some(ref global_dir) = self.global_config_dir else {
168            return Err(ConfigError::NoHomeDir);
169        };
170
171        // Create directory if it doesn't exist
172        if !global_dir.exists() {
173            std::fs::create_dir_all(global_dir)
174                .map_err(|e| ConfigError::create_dir(global_dir, e))?;
175        }
176
177        let config_path = global_dir.join(CONFIG_FILE_NAME);
178        if !config_path.exists() {
179            let default_config = PrismConfig::default();
180            save_config_file(&config_path, &default_config)?;
181        }
182
183        Ok(config_path)
184    }
185
186    /// Initialize local configuration for a workspace.
187    ///
188    /// Creates `.codeprysm/config.toml` with default configuration.
189    pub fn init_local(&self, workspace_root: &Path) -> Result<PathBuf, ConfigError> {
190        let local_dir = workspace_root.join(LOCAL_CONFIG_DIR);
191
192        // Create directory if it doesn't exist
193        if !local_dir.exists() {
194            std::fs::create_dir_all(&local_dir)
195                .map_err(|e| ConfigError::create_dir(&local_dir, e))?;
196        }
197
198        let config_path = local_dir.join(CONFIG_FILE_NAME);
199        if !config_path.exists() {
200            let default_config = PrismConfig::default();
201            save_config_file(&config_path, &default_config)?;
202        }
203
204        Ok(config_path)
205    }
206
207    /// Clear cached global configuration.
208    ///
209    /// Forces reload on next `load_global()` call.
210    pub fn clear_cache(&mut self) {
211        self.global_config = None;
212    }
213}
214
215/// Load a configuration file from disk.
216fn load_config_file(path: &Path) -> Result<PrismConfig, ConfigError> {
217    let content = std::fs::read_to_string(path).map_err(|e| ConfigError::read_file(path, e))?;
218
219    toml::from_str(&content).map_err(|e| ConfigError::parse_toml(path, e))
220}
221
222/// Save a configuration file to disk.
223fn save_config_file(path: &Path, config: &PrismConfig) -> Result<(), ConfigError> {
224    // Ensure parent directory exists
225    if let Some(parent) = path.parent() {
226        if !parent.exists() {
227            std::fs::create_dir_all(parent).map_err(|e| ConfigError::create_dir(parent, e))?;
228        }
229    }
230
231    let content = toml::to_string_pretty(config)?;
232    std::fs::write(path, content).map_err(|e| ConfigError::write_file(path, e))
233}
234
235/// Merge two configurations, with `overlay` taking precedence.
236///
237/// This performs a field-by-field merge, allowing partial configs.
238fn merge_configs(base: PrismConfig, overlay: PrismConfig) -> PrismConfig {
239    PrismConfig {
240        storage: merge_storage(base.storage, overlay.storage),
241        backend: merge_backend(base.backend, overlay.backend),
242        embedding: merge_embedding(base.embedding, overlay.embedding),
243        analysis: merge_analysis(base.analysis, overlay.analysis),
244        workspace: merge_workspace(base.workspace, overlay.workspace),
245        logging: merge_logging(base.logging, overlay.logging),
246    }
247}
248
249/// Merge storage config, overlay values override base.
250fn merge_storage(
251    base: crate::StorageConfig,
252    overlay: crate::StorageConfig,
253) -> crate::StorageConfig {
254    crate::StorageConfig {
255        // Use overlay if it differs from default, otherwise keep base
256        prism_dir: if overlay.prism_dir != Path::new(".codeprysm") {
257            overlay.prism_dir
258        } else {
259            base.prism_dir
260        },
261        graph_format: overlay.graph_format, // Always use overlay
262        compression: overlay.compression,
263        max_partition_size_mb: if overlay.max_partition_size_mb != 100 {
264            overlay.max_partition_size_mb
265        } else {
266            base.max_partition_size_mb
267        },
268    }
269}
270
271/// Merge backend config.
272fn merge_backend(
273    base: crate::BackendConfig,
274    overlay: crate::BackendConfig,
275) -> crate::BackendConfig {
276    crate::BackendConfig {
277        backend_type: overlay.backend_type,
278        qdrant: merge_qdrant(base.qdrant, overlay.qdrant),
279        remote: overlay.remote.or(base.remote),
280    }
281}
282
283/// Merge Qdrant config.
284fn merge_qdrant(base: crate::QdrantConfig, overlay: crate::QdrantConfig) -> crate::QdrantConfig {
285    crate::QdrantConfig {
286        url: if overlay.url != "http://localhost:6334" {
287            overlay.url
288        } else {
289            base.url
290        },
291        api_key: overlay.api_key.or(base.api_key),
292        collection_prefix: if overlay.collection_prefix != "codeprysm" {
293            overlay.collection_prefix
294        } else {
295            base.collection_prefix
296        },
297        vector_dimension: if overlay.vector_dimension != 768 {
298            overlay.vector_dimension
299        } else {
300            base.vector_dimension
301        },
302        hnsw_enabled: overlay.hnsw_enabled,
303    }
304}
305
306/// Merge embedding config.
307fn merge_embedding(
308    base: crate::EmbeddingConfig,
309    overlay: crate::EmbeddingConfig,
310) -> crate::EmbeddingConfig {
311    crate::EmbeddingConfig {
312        // Provider type from overlay if it differs from default
313        provider: if overlay.provider != crate::EmbeddingProviderType::Local {
314            overlay.provider
315        } else {
316            base.provider
317        },
318        // Overlay azure_ml takes precedence if set
319        azure_ml: overlay.azure_ml.or(base.azure_ml),
320        // Overlay openai takes precedence if set
321        openai: overlay.openai.or(base.openai),
322    }
323}
324
325/// Merge analysis config.
326fn merge_analysis(
327    base: crate::AnalysisConfig,
328    overlay: crate::AnalysisConfig,
329) -> crate::AnalysisConfig {
330    crate::AnalysisConfig {
331        max_file_size_kb: if overlay.max_file_size_kb != 1024 {
332            overlay.max_file_size_kb
333        } else {
334            base.max_file_size_kb
335        },
336        // Merge patterns: overlay patterns extend base patterns
337        exclude_patterns: if overlay.exclude_patterns.is_empty() {
338            base.exclude_patterns
339        } else {
340            // Combine both, with overlay patterns added
341            let mut patterns = base.exclude_patterns;
342            for pattern in overlay.exclude_patterns {
343                if !patterns.contains(&pattern) {
344                    patterns.push(pattern);
345                }
346            }
347            patterns
348        },
349        include_patterns: if overlay.include_patterns.is_empty() {
350            base.include_patterns
351        } else {
352            let mut patterns = base.include_patterns;
353            for pattern in overlay.include_patterns {
354                if !patterns.contains(&pattern) {
355                    patterns.push(pattern);
356                }
357            }
358            patterns
359        },
360        detect_components: overlay.detect_components,
361        parallelism: if overlay.parallelism != 0 {
362            overlay.parallelism
363        } else {
364            base.parallelism
365        },
366        languages: {
367            let mut langs = base.languages;
368            langs.extend(overlay.languages);
369            langs
370        },
371    }
372}
373
374/// Merge workspace config.
375fn merge_workspace(
376    base: crate::WorkspaceConfig,
377    overlay: crate::WorkspaceConfig,
378) -> crate::WorkspaceConfig {
379    crate::WorkspaceConfig {
380        workspaces: {
381            let mut ws = base.workspaces;
382            ws.extend(overlay.workspaces);
383            ws
384        },
385        active: overlay.active.or(base.active),
386        cross_workspace_search: overlay.cross_workspace_search,
387    }
388}
389
390/// Merge logging config.
391fn merge_logging(
392    base: crate::LoggingConfig,
393    overlay: crate::LoggingConfig,
394) -> crate::LoggingConfig {
395    crate::LoggingConfig {
396        level: if overlay.level != "info" {
397            overlay.level
398        } else {
399            base.level
400        },
401        format: overlay.format,
402        file: overlay.file.or(base.file),
403    }
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Write `content` to `<dir>/.codeprysm/<filename>` (creating the
    /// directory) and return the written path.
    fn create_test_config(content: &str, dir: &Path, filename: &str) -> PathBuf {
        let prysm_dir = dir.join(".codeprysm");
        std::fs::create_dir_all(&prysm_dir).unwrap();
        let file_path = prysm_dir.join(filename);
        std::fs::write(&file_path, content).unwrap();
        file_path
    }

    /// Write `content` to `<global_dir>/config.toml`, creating `global_dir`.
    fn write_global_config(global_dir: &Path, content: &str) {
        std::fs::create_dir_all(global_dir).unwrap();
        std::fs::write(global_dir.join("config.toml"), content).unwrap();
    }

    #[test]
    fn test_load_default_config() {
        let tmp = TempDir::new().unwrap();
        let mut loader = ConfigLoader::with_global_dir(tmp.path().join("global"));

        let cfg = loader.load(tmp.path(), None).unwrap();

        // No config files exist, so the built-in defaults come through.
        assert_eq!(cfg.storage.prism_dir, PathBuf::from(".codeprysm"));
        assert_eq!(cfg.backend.qdrant.url, "http://localhost:6334");
    }

    #[test]
    fn test_load_local_config() {
        let tmp = TempDir::new().unwrap();
        let mut loader = ConfigLoader::with_global_dir(tmp.path().join("global"));

        let local_toml = r#"
            [storage]
            prism_dir = ".custom-prism"

            [backend.qdrant]
            url = "http://custom:6334"
            "#;
        create_test_config(local_toml, tmp.path(), "config.toml");

        let cfg = loader.load(tmp.path(), None).unwrap();

        assert_eq!(cfg.storage.prism_dir, PathBuf::from(".custom-prism"));
        assert_eq!(cfg.backend.qdrant.url, "http://custom:6334");
    }

    #[test]
    fn test_global_overrides_default() {
        let tmp = TempDir::new().unwrap();
        let global_dir = tmp.path().join("global");

        write_global_config(
            &global_dir,
            r#"
            [logging]
            level = "debug"
            "#,
        );

        let mut loader = ConfigLoader::with_global_dir(&global_dir);
        let cfg = loader.load(tmp.path(), None).unwrap();

        assert_eq!(cfg.logging.level, "debug");
    }

    #[test]
    fn test_local_overrides_global() {
        let tmp = TempDir::new().unwrap();
        let global_dir = tmp.path().join("global");

        write_global_config(
            &global_dir,
            r#"
            [logging]
            level = "debug"

            [backend.qdrant]
            url = "http://global:6334"
            "#,
        );

        // The local file only touches the Qdrant URL, not the log level.
        let local_toml = r#"
            [backend.qdrant]
            url = "http://local:6334"
            "#;
        create_test_config(local_toml, tmp.path(), "config.toml");

        let mut loader = ConfigLoader::with_global_dir(&global_dir);
        let cfg = loader.load(tmp.path(), None).unwrap();

        // The local value wins where it is set...
        assert_eq!(cfg.backend.qdrant.url, "http://local:6334");
        // ...and the global value survives where it is not.
        assert_eq!(cfg.logging.level, "debug");
    }

    #[test]
    fn test_cli_overrides_all() {
        let tmp = TempDir::new().unwrap();

        let local_toml = r#"
            [backend.qdrant]
            url = "http://local:6334"
            "#;
        create_test_config(local_toml, tmp.path(), "config.toml");

        let mut loader = ConfigLoader::with_global_dir(tmp.path().join("global"));

        let cli = ConfigOverrides {
            qdrant_url: Some("http://cli:6334".to_string()),
            log_level: Some("trace".to_string()),
            ..Default::default()
        };

        let cfg = loader.load(tmp.path(), Some(&cli)).unwrap();

        // CLI values beat the local file.
        assert_eq!(cfg.backend.qdrant.url, "http://cli:6334");
        assert_eq!(cfg.logging.level, "trace");
    }

    #[test]
    fn test_save_and_load_config() {
        let tmp = TempDir::new().unwrap();
        let global_dir = tmp.path().join("global");

        let mut config = PrismConfig::default();
        config.backend.qdrant.url = "http://saved:6334".to_string();
        config.logging.level = "warn".to_string();

        // Persist with one loader...
        let writer = ConfigLoader::with_global_dir(&global_dir);
        writer.save_local(tmp.path(), &config).unwrap();

        // ...and read back with a fresh one.
        let mut reader = ConfigLoader::with_global_dir(&global_dir);
        let round_tripped = reader.load(tmp.path(), None).unwrap();

        assert_eq!(round_tripped.backend.qdrant.url, "http://saved:6334");
        assert_eq!(round_tripped.logging.level, "warn");
    }

    #[test]
    fn test_init_local_creates_config() {
        let tmp = TempDir::new().unwrap();
        let loader = ConfigLoader::with_global_dir(tmp.path().join("global"));

        let written = loader.init_local(tmp.path()).unwrap();

        assert!(written.exists());
        assert!(written.ends_with(".codeprysm/config.toml"));

        // The generated file must parse back as a config.
        let text = std::fs::read_to_string(&written).unwrap();
        let _: PrismConfig = toml::from_str(&text).unwrap();
    }

    #[test]
    fn test_exclude_patterns_merge() {
        let base = crate::AnalysisConfig {
            exclude_patterns: vec!["**/node_modules/**".to_string()],
            ..Default::default()
        };

        let overlay = crate::AnalysisConfig {
            exclude_patterns: vec!["**/custom/**".to_string()],
            ..Default::default()
        };

        let merged = merge_analysis(base, overlay);

        // Patterns from both sides must be present.
        for expected in ["**/node_modules/**", "**/custom/**"] {
            assert!(merged.exclude_patterns.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_workspace_merge() {
        let base = crate::WorkspaceConfig {
            workspaces: std::collections::HashMap::from([(
                "project-a".to_string(),
                PathBuf::from("/a"),
            )]),
            active: Some("project-a".to_string()),
            ..Default::default()
        };

        let overlay = crate::WorkspaceConfig {
            workspaces: std::collections::HashMap::from([(
                "project-b".to_string(),
                PathBuf::from("/b"),
            )]),
            // Leave the active workspace alone.
            active: None,
            ..Default::default()
        };

        let merged = merge_workspace(base, overlay);

        // Both workspaces survive the merge...
        assert!(merged.workspaces.contains_key("project-a"));
        assert!(merged.workspaces.contains_key("project-b"));
        // ...and the base's active selection is kept.
        assert_eq!(merged.active, Some("project-a".to_string()));
    }

    #[test]
    fn test_cache_clearing() {
        let tmp = TempDir::new().unwrap();
        let global_dir = tmp.path().join("global");

        write_global_config(
            &global_dir,
            r#"
            [logging]
            level = "debug"
            "#,
        );

        let mut loader = ConfigLoader::with_global_dir(&global_dir);

        // A successful load populates the cache...
        let _ = loader.load_global().unwrap();
        assert!(loader.global_config.is_some());

        // ...and clearing it drops the cached value.
        loader.clear_cache();
        assert!(loader.global_config.is_none());
    }
}