1use crate::error::ConfigError;
11use crate::{ConfigOverrides, PrismConfig};
12use std::path::{Path, PathBuf};
13use tracing::{debug, trace};
14
/// File name of the configuration file, used for both the global and the
/// workspace-local config.
const CONFIG_FILE_NAME: &str = "config.toml";

/// Directory name joined onto the user's home directory to form the default
/// global config directory (see `ConfigLoader::new`).
const GLOBAL_CONFIG_DIR: &str = ".codeprysm";

/// Directory name joined onto a workspace root to locate the local config.
const LOCAL_CONFIG_DIR: &str = ".codeprysm";
23
/// Locates, loads, merges, saves and initializes `config.toml` files.
///
/// Configuration is layered: built-in defaults, then the global config file,
/// then the workspace-local config file, then optional overrides
/// (see `ConfigLoader::load`).
#[derive(Debug, Clone)]
pub struct ConfigLoader {
    /// Directory that holds the global config file; `None` when no home
    /// directory could be determined by `ConfigLoader::new`.
    global_config_dir: Option<PathBuf>,

    /// Global config cached after the first successful read by
    /// `load_global`; reset to `None` by `clear_cache`.
    global_config: Option<PrismConfig>,
}
33
34impl Default for ConfigLoader {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40impl ConfigLoader {
41 pub fn new() -> Self {
45 let global_config_dir = dirs::home_dir().map(|h| h.join(GLOBAL_CONFIG_DIR));
46
47 Self {
48 global_config_dir,
49 global_config: None,
50 }
51 }
52
53 pub fn with_global_dir(global_dir: impl Into<PathBuf>) -> Self {
57 Self {
58 global_config_dir: Some(global_dir.into()),
59 global_config: None,
60 }
61 }
62
63 pub fn global_config_path(&self) -> Option<PathBuf> {
65 self.global_config_dir
66 .as_ref()
67 .map(|d| d.join(CONFIG_FILE_NAME))
68 }
69
70 pub fn local_config_path(&self, workspace_root: &Path) -> PathBuf {
72 workspace_root.join(LOCAL_CONFIG_DIR).join(CONFIG_FILE_NAME)
73 }
74
75 pub fn load(
79 &mut self,
80 workspace_root: &Path,
81 overrides: Option<&ConfigOverrides>,
82 ) -> Result<PrismConfig, ConfigError> {
83 let mut config = PrismConfig::default();
85
86 if let Some(global_config) = self.load_global()? {
88 config = merge_configs(config, global_config);
89 }
90
91 if let Some(local_config) = self.load_local(workspace_root)? {
93 config = merge_configs(config, local_config);
94 }
95
96 if let Some(ovr) = overrides {
98 config.apply_overrides(ovr);
99 }
100
101 Ok(config)
102 }
103
104 pub fn load_global(&mut self) -> Result<Option<PrismConfig>, ConfigError> {
106 if let Some(ref config) = self.global_config {
108 return Ok(Some(config.clone()));
109 }
110
111 let Some(global_path) = self.global_config_path() else {
112 debug!("No home directory found, skipping global config");
113 return Ok(None);
114 };
115
116 if !global_path.exists() {
117 trace!("Global config not found at {:?}", global_path);
118 return Ok(None);
119 }
120
121 debug!("Loading global config from {:?}", global_path);
122 let config = load_config_file(&global_path)?;
123
124 self.global_config = Some(config.clone());
126
127 Ok(Some(config))
128 }
129
130 pub fn load_local(&self, workspace_root: &Path) -> Result<Option<PrismConfig>, ConfigError> {
132 let local_path = self.local_config_path(workspace_root);
133
134 if !local_path.exists() {
135 trace!("Local config not found at {:?}", local_path);
136 return Ok(None);
137 }
138
139 debug!("Loading local config from {:?}", local_path);
140 load_config_file(&local_path).map(Some)
141 }
142
143 pub fn save_global(&self, config: &PrismConfig) -> Result<(), ConfigError> {
145 let Some(ref global_dir) = self.global_config_dir else {
146 return Err(ConfigError::NoHomeDir);
147 };
148
149 let global_path = global_dir.join(CONFIG_FILE_NAME);
150 save_config_file(&global_path, config)
151 }
152
153 pub fn save_local(
155 &self,
156 workspace_root: &Path,
157 config: &PrismConfig,
158 ) -> Result<(), ConfigError> {
159 let local_path = self.local_config_path(workspace_root);
160 save_config_file(&local_path, config)
161 }
162
163 pub fn init_global(&self) -> Result<PathBuf, ConfigError> {
167 let Some(ref global_dir) = self.global_config_dir else {
168 return Err(ConfigError::NoHomeDir);
169 };
170
171 if !global_dir.exists() {
173 std::fs::create_dir_all(global_dir)
174 .map_err(|e| ConfigError::create_dir(global_dir, e))?;
175 }
176
177 let config_path = global_dir.join(CONFIG_FILE_NAME);
178 if !config_path.exists() {
179 let default_config = PrismConfig::default();
180 save_config_file(&config_path, &default_config)?;
181 }
182
183 Ok(config_path)
184 }
185
186 pub fn init_local(&self, workspace_root: &Path) -> Result<PathBuf, ConfigError> {
190 let local_dir = workspace_root.join(LOCAL_CONFIG_DIR);
191
192 if !local_dir.exists() {
194 std::fs::create_dir_all(&local_dir)
195 .map_err(|e| ConfigError::create_dir(&local_dir, e))?;
196 }
197
198 let config_path = local_dir.join(CONFIG_FILE_NAME);
199 if !config_path.exists() {
200 let default_config = PrismConfig::default();
201 save_config_file(&config_path, &default_config)?;
202 }
203
204 Ok(config_path)
205 }
206
207 pub fn clear_cache(&mut self) {
211 self.global_config = None;
212 }
213}
214
215fn load_config_file(path: &Path) -> Result<PrismConfig, ConfigError> {
217 let content = std::fs::read_to_string(path).map_err(|e| ConfigError::read_file(path, e))?;
218
219 toml::from_str(&content).map_err(|e| ConfigError::parse_toml(path, e))
220}
221
222fn save_config_file(path: &Path, config: &PrismConfig) -> Result<(), ConfigError> {
224 if let Some(parent) = path.parent() {
226 if !parent.exists() {
227 std::fs::create_dir_all(parent).map_err(|e| ConfigError::create_dir(parent, e))?;
228 }
229 }
230
231 let content = toml::to_string_pretty(config)?;
232 std::fs::write(path, content).map_err(|e| ConfigError::write_file(path, e))
233}
234
235fn merge_configs(base: PrismConfig, overlay: PrismConfig) -> PrismConfig {
239 PrismConfig {
240 storage: merge_storage(base.storage, overlay.storage),
241 backend: merge_backend(base.backend, overlay.backend),
242 embedding: merge_embedding(base.embedding, overlay.embedding),
243 analysis: merge_analysis(base.analysis, overlay.analysis),
244 workspace: merge_workspace(base.workspace, overlay.workspace),
245 logging: merge_logging(base.logging, overlay.logging),
246 }
247}
248
249fn merge_storage(
251 base: crate::StorageConfig,
252 overlay: crate::StorageConfig,
253) -> crate::StorageConfig {
254 crate::StorageConfig {
255 prism_dir: if overlay.prism_dir != Path::new(".codeprysm") {
257 overlay.prism_dir
258 } else {
259 base.prism_dir
260 },
261 graph_format: overlay.graph_format, compression: overlay.compression,
263 max_partition_size_mb: if overlay.max_partition_size_mb != 100 {
264 overlay.max_partition_size_mb
265 } else {
266 base.max_partition_size_mb
267 },
268 }
269}
270
271fn merge_backend(
273 base: crate::BackendConfig,
274 overlay: crate::BackendConfig,
275) -> crate::BackendConfig {
276 crate::BackendConfig {
277 backend_type: overlay.backend_type,
278 qdrant: merge_qdrant(base.qdrant, overlay.qdrant),
279 remote: overlay.remote.or(base.remote),
280 }
281}
282
283fn merge_qdrant(base: crate::QdrantConfig, overlay: crate::QdrantConfig) -> crate::QdrantConfig {
285 crate::QdrantConfig {
286 url: if overlay.url != "http://localhost:6334" {
287 overlay.url
288 } else {
289 base.url
290 },
291 api_key: overlay.api_key.or(base.api_key),
292 collection_prefix: if overlay.collection_prefix != "codeprysm" {
293 overlay.collection_prefix
294 } else {
295 base.collection_prefix
296 },
297 vector_dimension: if overlay.vector_dimension != 768 {
298 overlay.vector_dimension
299 } else {
300 base.vector_dimension
301 },
302 hnsw_enabled: overlay.hnsw_enabled,
303 }
304}
305
306fn merge_embedding(
308 base: crate::EmbeddingConfig,
309 overlay: crate::EmbeddingConfig,
310) -> crate::EmbeddingConfig {
311 crate::EmbeddingConfig {
312 provider: if overlay.provider != crate::EmbeddingProviderType::Local {
314 overlay.provider
315 } else {
316 base.provider
317 },
318 azure_ml: overlay.azure_ml.or(base.azure_ml),
320 openai: overlay.openai.or(base.openai),
322 }
323}
324
325fn merge_analysis(
327 base: crate::AnalysisConfig,
328 overlay: crate::AnalysisConfig,
329) -> crate::AnalysisConfig {
330 crate::AnalysisConfig {
331 max_file_size_kb: if overlay.max_file_size_kb != 1024 {
332 overlay.max_file_size_kb
333 } else {
334 base.max_file_size_kb
335 },
336 exclude_patterns: if overlay.exclude_patterns.is_empty() {
338 base.exclude_patterns
339 } else {
340 let mut patterns = base.exclude_patterns;
342 for pattern in overlay.exclude_patterns {
343 if !patterns.contains(&pattern) {
344 patterns.push(pattern);
345 }
346 }
347 patterns
348 },
349 include_patterns: if overlay.include_patterns.is_empty() {
350 base.include_patterns
351 } else {
352 let mut patterns = base.include_patterns;
353 for pattern in overlay.include_patterns {
354 if !patterns.contains(&pattern) {
355 patterns.push(pattern);
356 }
357 }
358 patterns
359 },
360 detect_components: overlay.detect_components,
361 parallelism: if overlay.parallelism != 0 {
362 overlay.parallelism
363 } else {
364 base.parallelism
365 },
366 languages: {
367 let mut langs = base.languages;
368 langs.extend(overlay.languages);
369 langs
370 },
371 }
372}
373
374fn merge_workspace(
376 base: crate::WorkspaceConfig,
377 overlay: crate::WorkspaceConfig,
378) -> crate::WorkspaceConfig {
379 crate::WorkspaceConfig {
380 workspaces: {
381 let mut ws = base.workspaces;
382 ws.extend(overlay.workspaces);
383 ws
384 },
385 active: overlay.active.or(base.active),
386 cross_workspace_search: overlay.cross_workspace_search,
387 }
388}
389
390fn merge_logging(
392 base: crate::LoggingConfig,
393 overlay: crate::LoggingConfig,
394) -> crate::LoggingConfig {
395 crate::LoggingConfig {
396 level: if overlay.level != "info" {
397 overlay.level
398 } else {
399 base.level
400 },
401 format: overlay.format,
402 file: overlay.file.or(base.file),
403 }
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Writes `content` to `<dir>/.codeprysm/<filename>`, creating the
    /// directory first, and returns the written path. With
    /// `filename == "config.toml"` this is the local config location for a
    /// workspace rooted at `dir`.
    fn create_test_config(content: &str, dir: &Path, filename: &str) -> PathBuf {
        let config_dir = dir.join(".codeprysm");
        std::fs::create_dir_all(&config_dir).unwrap();
        let path = config_dir.join(filename);
        std::fs::write(&path, content).unwrap();
        path
    }

    /// With neither a global nor a local config file, `load` yields defaults.
    #[test]
    fn test_load_default_config() {
        let temp = TempDir::new().unwrap();
        // The "global" dir is never created, so there is no global config.
        let mut loader = ConfigLoader::with_global_dir(temp.path().join("global"));

        let config = loader.load(temp.path(), None).unwrap();

        // Values expected from `PrismConfig::default()`.
        assert_eq!(config.storage.prism_dir, PathBuf::from(".codeprysm"));
        assert_eq!(config.backend.qdrant.url, "http://localhost:6334");
    }

    /// Values from the workspace-local config file replace the defaults.
    #[test]
    fn test_load_local_config() {
        let temp = TempDir::new().unwrap();
        let mut loader = ConfigLoader::with_global_dir(temp.path().join("global"));

        create_test_config(
            // Local config at `<temp>/.codeprysm/config.toml`.
            r#"
            [storage]
            prism_dir = ".custom-prism"

            [backend.qdrant]
            url = "http://custom:6334"
            "#,
            temp.path(),
            "config.toml",
        );

        let config = loader.load(temp.path(), None).unwrap();

        assert_eq!(config.storage.prism_dir, PathBuf::from(".custom-prism"));
        assert_eq!(config.backend.qdrant.url, "http://custom:6334");
    }

    /// Values from the global config file replace the defaults.
    #[test]
    fn test_global_overrides_default() {
        let temp = TempDir::new().unwrap();
        let global_dir = temp.path().join("global");

        std::fs::create_dir_all(&global_dir).unwrap();
        // Global config at `<temp>/global/config.toml`.
        std::fs::write(
            global_dir.join("config.toml"),
            r#"
            [logging]
            level = "debug"
            "#,
        )
        .unwrap();

        let mut loader = ConfigLoader::with_global_dir(&global_dir);
        let config = loader.load(temp.path(), None).unwrap();

        assert_eq!(config.logging.level, "debug");
    }

    /// Local values win over global ones, and global values the local file
    /// does not touch survive the merge.
    #[test]
    fn test_local_overrides_global() {
        let temp = TempDir::new().unwrap();
        let global_dir = temp.path().join("global");

        std::fs::create_dir_all(&global_dir).unwrap();
        // Global config sets both the log level and the Qdrant URL.
        std::fs::write(
            global_dir.join("config.toml"),
            r#"
            [logging]
            level = "debug"

            [backend.qdrant]
            url = "http://global:6334"
            "#,
        )
        .unwrap();

        create_test_config(
            // Local config overrides only the Qdrant URL.
            r#"
            [backend.qdrant]
            url = "http://local:6334"
            "#,
            temp.path(),
            "config.toml",
        );

        let mut loader = ConfigLoader::with_global_dir(&global_dir);
        let config = loader.load(temp.path(), None).unwrap();

        // Local value wins.
        assert_eq!(config.backend.qdrant.url, "http://local:6334");
        // Global-only value is preserved.
        assert_eq!(config.logging.level, "debug");
    }

    /// Explicit overrides passed to `load` win over file-based values.
    #[test]
    fn test_cli_overrides_all() {
        let temp = TempDir::new().unwrap();

        create_test_config(
            // Local config that the overrides should beat.
            r#"
            [backend.qdrant]
            url = "http://local:6334"
            "#,
            temp.path(),
            "config.toml",
        );

        let mut loader = ConfigLoader::with_global_dir(temp.path().join("global"));

        let overrides = ConfigOverrides {
            qdrant_url: Some("http://cli:6334".to_string()),
            log_level: Some("trace".to_string()),
            ..Default::default()
        };

        let config = loader.load(temp.path(), Some(&overrides)).unwrap();

        // Override beats the local file.
        assert_eq!(config.backend.qdrant.url, "http://cli:6334");
        // Override beats the default.
        assert_eq!(config.logging.level, "trace");
    }

    /// A config written with `save_local` round-trips through `load`.
    #[test]
    fn test_save_and_load_config() {
        let temp = TempDir::new().unwrap();
        let loader = ConfigLoader::with_global_dir(temp.path().join("global"));

        let mut config = PrismConfig::default();
        config.backend.qdrant.url = "http://saved:6334".to_string();
        config.logging.level = "warn".to_string();

        // Writes `<temp>/.codeprysm/config.toml`, creating the directory.
        loader.save_local(temp.path(), &config).unwrap();

        // A fresh loader, so nothing is served from an in-memory cache.
        let mut loader = ConfigLoader::with_global_dir(temp.path().join("global"));
        let loaded = loader.load(temp.path(), None).unwrap();

        assert_eq!(loaded.backend.qdrant.url, "http://saved:6334");
        assert_eq!(loaded.logging.level, "warn");
    }

    /// `init_local` creates `.codeprysm/config.toml` holding a parseable config.
    #[test]
    fn test_init_local_creates_config() {
        let temp = TempDir::new().unwrap();
        let loader = ConfigLoader::with_global_dir(temp.path().join("global"));

        let config_path = loader.init_local(temp.path()).unwrap();

        assert!(config_path.exists());
        assert!(config_path.ends_with(".codeprysm/config.toml"));

        let content = std::fs::read_to_string(&config_path).unwrap();
        // The written file must deserialize back into a `PrismConfig`.
        let _: PrismConfig = toml::from_str(&content).unwrap();
    }

    /// `merge_analysis` unions exclude patterns from both layers.
    #[test]
    fn test_exclude_patterns_merge() {
        let base = crate::AnalysisConfig {
            exclude_patterns: vec!["**/node_modules/**".to_string()],
            ..Default::default()
        };

        let overlay = crate::AnalysisConfig {
            exclude_patterns: vec!["**/custom/**".to_string()],
            ..Default::default()
        };

        let merged = merge_analysis(base, overlay);

        // Base pattern is kept...
        assert!(
            merged
                .exclude_patterns
                .contains(&"**/node_modules/**".to_string())
        );
        // ...and the overlay pattern is added.
        assert!(
            merged
                .exclude_patterns
                .contains(&"**/custom/**".to_string())
        );
    }

    /// `merge_workspace` combines workspace maps and keeps the base `active`
    /// when the overlay leaves it unset.
    #[test]
    fn test_workspace_merge() {
        let mut base_ws = std::collections::HashMap::new();
        base_ws.insert("project-a".to_string(), PathBuf::from("/a"));

        let mut overlay_ws = std::collections::HashMap::new();
        overlay_ws.insert("project-b".to_string(), PathBuf::from("/b"));

        let base = crate::WorkspaceConfig {
            workspaces: base_ws,
            active: Some("project-a".to_string()),
            ..Default::default()
        };

        let overlay = crate::WorkspaceConfig {
            workspaces: overlay_ws,
            active: None,
            ..Default::default()
        };

        let merged = merge_workspace(base, overlay);

        // Entries from both layers are present.
        assert!(merged.workspaces.contains_key("project-a"));
        assert!(merged.workspaces.contains_key("project-b"));
        // Overlay `active` is `None`, so the base value survives.
        assert_eq!(merged.active, Some("project-a".to_string()));
    }

    /// `load_global` populates the cache and `clear_cache` empties it.
    #[test]
    fn test_cache_clearing() {
        let temp = TempDir::new().unwrap();
        let global_dir = temp.path().join("global");

        std::fs::create_dir_all(&global_dir).unwrap();
        // A global config must exist for anything to be cached.
        std::fs::write(
            global_dir.join("config.toml"),
            r#"
            [logging]
            level = "debug"
            "#,
        )
        .unwrap();

        let mut loader = ConfigLoader::with_global_dir(&global_dir);

        let _ = loader.load_global().unwrap();
        // First successful load stores the config in the cache.
        assert!(loader.global_config.is_some());

        loader.clear_cache();
        // Cache is emptied.
        assert!(loader.global_config.is_none());
    }
}