1mod error;
11mod loader;
12
13pub use error::ConfigError;
14pub use loader::ConfigLoader;
15
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::path::PathBuf;
19
/// Top-level configuration, grouping every configuration section.
///
/// The container is marked `#[serde(default)]`, so any section missing from the
/// serialized input falls back to that section's `Default` value.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
pub struct PrismConfig {
    /// Where and how analysis artifacts are stored on disk.
    pub storage: StorageConfig,

    /// Backend selection (local/remote) and its connection settings.
    pub backend: BackendConfig,

    /// Embedding provider selection and provider-specific settings.
    pub embedding: EmbeddingConfig,

    /// Source-analysis options (file filters, parallelism, per-language settings).
    pub analysis: AnalysisConfig,

    /// Registered workspaces and which one is active.
    pub workspace: WorkspaceConfig,

    /// Log level, log output format and optional log file.
    pub logging: LoggingConfig,
}
44
/// Embedding provider configuration.
///
/// `provider` selects the provider. Only the settings section matching the
/// selected provider is required; `EmbeddingConfig::validate` enforces this.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
pub struct EmbeddingConfig {
    /// Which embedding provider to use (defaults to `local`).
    pub provider: EmbeddingProviderType,

    /// Settings for the `azure-ml` provider; must be present when it is selected.
    pub azure_ml: Option<AzureMLSettings>,

    /// Settings for the `openai` provider; must be present when it is selected.
    pub openai: Option<OpenAISettings>,
}
77
78impl EmbeddingConfig {
79 pub fn validate(&self) -> Result<(), ConfigError> {
81 match self.provider {
82 EmbeddingProviderType::Local => Ok(()),
83 EmbeddingProviderType::AzureMl => {
84 if self.azure_ml.is_none() {
85 return Err(ConfigError::ValidationError(
86 "embedding.provider is 'azure-ml' but [embedding.azure_ml] section is missing".to_string()
87 ));
88 }
89 let settings = self.azure_ml.as_ref().unwrap();
90 if settings.semantic_endpoint.is_empty() {
91 return Err(ConfigError::ValidationError(
92 "embedding.azure_ml.semantic_endpoint is required".to_string(),
93 ));
94 }
95 if settings.code_endpoint.is_empty() {
96 return Err(ConfigError::ValidationError(
97 "embedding.azure_ml.code_endpoint is required".to_string(),
98 ));
99 }
100 Ok(())
101 }
102 EmbeddingProviderType::Openai => {
103 if self.openai.is_none() {
104 return Err(ConfigError::ValidationError(
105 "embedding.provider is 'openai' but [embedding.openai] section is missing"
106 .to_string(),
107 ));
108 }
109 let settings = self.openai.as_ref().unwrap();
110 if settings.url.is_empty() {
111 return Err(ConfigError::ValidationError(
112 "embedding.openai.url is required".to_string(),
113 ));
114 }
115 if settings.semantic_model.is_empty() {
116 return Err(ConfigError::ValidationError(
117 "embedding.openai.semantic_model is required".to_string(),
118 ));
119 }
120 Ok(())
121 }
122 }
123 }
124}
125
126#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)]
128#[serde(rename_all = "kebab-case")]
129pub enum EmbeddingProviderType {
130 #[default]
132 Local,
133 AzureMl,
135 Openai,
137}
138
139impl std::fmt::Display for EmbeddingProviderType {
140 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141 match self {
142 Self::Local => write!(f, "local"),
143 Self::AzureMl => write!(f, "azure-ml"),
144 Self::Openai => write!(f, "openai"),
145 }
146 }
147}
148
149impl std::str::FromStr for EmbeddingProviderType {
150 type Err = ConfigError;
151
152 fn from_str(s: &str) -> Result<Self, Self::Err> {
153 match s.to_lowercase().as_str() {
154 "local" => Ok(Self::Local),
155 "azure-ml" | "azureml" | "azure_ml" => Ok(Self::AzureMl),
156 "openai" => Ok(Self::Openai),
157 _ => Err(ConfigError::ValidationError(format!(
158 "Unknown embedding provider: '{}'. Valid values: local, azure-ml, openai",
159 s
160 ))),
161 }
162 }
163}
164
/// Settings for the `azure-ml` embedding provider.
///
/// Separate endpoints are configured for semantic and for code embeddings. Both
/// must be non-empty when `embedding.provider` is `azure-ml`
/// (checked by `EmbeddingConfig::validate`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct AzureMLSettings {
    /// URL of the endpoint used for semantic embeddings (empty by default).
    pub semantic_endpoint: String,

    /// URL of the endpoint used for code embeddings (empty by default).
    pub code_endpoint: String,

    /// Name of the environment variable holding the semantic endpoint's key
    /// (default `CODEPRYSM_AZURE_ML_SEMANTIC_API_KEY`).
    pub semantic_auth_key_env: Option<String>,

    /// Name of the environment variable holding the code endpoint's key
    /// (default `CODEPRYSM_AZURE_ML_CODE_API_KEY`).
    pub code_auth_key_env: Option<String>,

    /// Optional additional key env-var name; omitted from serialized output when
    /// `None` (the default).
    /// NOTE(review): looks like a legacy/shared-key fallback for the per-endpoint
    /// variables above — confirm how the client prioritises it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub auth_key_env: Option<String>,

    /// Timeout in seconds (default 30).
    pub timeout_secs: u64,

    /// Maximum number of retries (default 3).
    pub max_retries: u32,
}
194
195impl Default for AzureMLSettings {
196 fn default() -> Self {
197 Self {
198 semantic_endpoint: String::new(),
199 code_endpoint: String::new(),
200 semantic_auth_key_env: Some("CODEPRYSM_AZURE_ML_SEMANTIC_API_KEY".to_string()),
201 code_auth_key_env: Some("CODEPRYSM_AZURE_ML_CODE_API_KEY".to_string()),
202 auth_key_env: None,
203 timeout_secs: 30,
204 max_retries: 3,
205 }
206 }
207}
208
/// Settings for the `openai` embedding provider.
///
/// `url` and `semantic_model` must be non-empty when `embedding.provider` is
/// `openai` (checked by `EmbeddingConfig::validate`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct OpenAISettings {
    /// API URL (default `https://api.openai.com/v1`).
    pub url: String,

    /// Name of the environment variable holding the API key
    /// (default `OPENAI_API_KEY`).
    pub api_key_env: Option<String>,

    /// Model used for semantic embeddings (default `text-embedding-3-small`).
    pub semantic_model: String,

    /// Optional model used for code embeddings (default `None`).
    /// NOTE(review): presumably `semantic_model` is reused when unset — confirm.
    pub code_model: Option<String>,

    /// Timeout in seconds (default 30).
    pub timeout_secs: u64,

    /// Maximum number of retries (default 3).
    pub max_retries: u32,

    /// Azure mode flag (default `false`).
    /// NOTE(review): presumably switches to Azure OpenAI request/auth
    /// conventions — confirm against the client implementation.
    pub azure_mode: bool,
}
234
235impl Default for OpenAISettings {
236 fn default() -> Self {
237 Self {
238 url: "https://api.openai.com/v1".to_string(),
239 api_key_env: Some("OPENAI_API_KEY".to_string()),
240 semantic_model: "text-embedding-3-small".to_string(),
241 code_model: None,
242 timeout_secs: 30,
243 max_retries: 3,
244 azure_mode: false,
245 }
246 }
247}
248
/// On-disk storage settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct StorageConfig {
    /// Artifact directory (default `.codeprysm`). A relative path is resolved
    /// against the workspace root by `PrismConfig::prism_dir`; an absolute path
    /// is used as-is.
    pub prism_dir: PathBuf,

    /// Graph storage format (default `sqlite`).
    pub graph_format: GraphFormat,

    /// Compression flag (default `false`).
    /// NOTE(review): the consumer of this flag is not visible here — confirm what
    /// is compressed.
    pub compression: bool,

    /// Maximum partition size in megabytes (default 100).
    pub max_partition_size_mb: u32,
}
265
266impl Default for StorageConfig {
267 fn default() -> Self {
268 Self {
269 prism_dir: PathBuf::from(".codeprysm"),
270 graph_format: GraphFormat::default(),
271 compression: false,
272 max_partition_size_mb: 100,
273 }
274 }
275}
276
/// Graph storage format, serialized in lowercase (`sqlite`, `json`).
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum GraphFormat {
    /// SQLite storage (default); `PrismConfig::graph_path` points at
    /// `manifest.json` for this format.
    #[default]
    Sqlite,
    /// Single JSON file; `PrismConfig::graph_path` points at `graph.json`.
    Json,
}
287
/// Backend selection and connection settings.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
pub struct BackendConfig {
    /// Which backend to use (default `local`).
    pub backend_type: BackendType,

    /// Qdrant connection settings.
    pub qdrant: QdrantConfig,

    /// Remote backend settings (default `None`).
    /// NOTE(review): nothing in this module checks that this is set when
    /// `backend_type` is `remote` — confirm whether the loader does.
    pub remote: Option<RemoteConfig>,
}
301
/// Backend kind, serialized in lowercase (`local`, `remote`).
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum BackendType {
    /// Local backend (default).
    #[default]
    Local,
    /// Remote backend, configured through `BackendConfig::remote`.
    Remote,
}
312
/// Qdrant vector-store connection settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct QdrantConfig {
    /// Qdrant URL (default `http://localhost:6334`).
    pub url: String,

    /// Optional Qdrant API key (default `None`).
    /// NOTE(review): unlike the `*_env` fields of the embedding settings, this
    /// appears to hold the secret itself — confirm that is intended.
    pub api_key: Option<String>,

    /// Prefix applied to collection names (default `codeprysm`).
    pub collection_prefix: String,

    /// Vector dimensionality (default 768).
    /// NOTE(review): presumably must match the embedding model's output size.
    pub vector_dimension: u32,

    /// Whether HNSW indexing is enabled (default `true`).
    pub hnsw_enabled: bool,
}
332
333impl Default for QdrantConfig {
334 fn default() -> Self {
335 Self {
336 url: "http://localhost:6334".to_string(),
337 api_key: None,
338 collection_prefix: "codeprysm".to_string(),
339 vector_dimension: 768, hnsw_enabled: true,
341 }
342 }
343}
344
345#[derive(Debug, Clone, Serialize, Deserialize)]
347pub struct RemoteConfig {
348 pub url: String,
350
351 pub api_key: Option<String>,
353
354 pub timeout_secs: u64,
356}
357
358impl Default for RemoteConfig {
359 fn default() -> Self {
360 Self {
361 url: "http://localhost:8080".to_string(),
362 api_key: None,
363 timeout_secs: 30,
364 }
365 }
366}
367
/// Source-analysis settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct AnalysisConfig {
    /// File-size limit in kilobytes (default 1024).
    /// NOTE(review): presumably larger files are skipped — confirm in the analyzer.
    pub max_file_size_kb: u64,

    /// Glob-style exclude patterns. The defaults exclude `node_modules`,
    /// `target`, `.git`, `vendor`, `__pycache__`, `dist` and `build` directories.
    pub exclude_patterns: Vec<String>,

    /// Glob-style include patterns (empty by default).
    /// NOTE(review): presumably an empty list means "include everything" — confirm.
    pub include_patterns: Vec<String>,

    /// Whether component detection is enabled (default `true`).
    pub detect_components: bool,

    /// Degree of parallelism (default 0).
    /// NOTE(review): 0 presumably means "choose automatically" — confirm.
    pub parallelism: usize,

    /// Per-language settings (empty by default).
    /// NOTE(review): keys are presumably language names — confirm expected spelling.
    pub languages: HashMap<String, LanguageConfig>,
}
390
391impl Default for AnalysisConfig {
392 fn default() -> Self {
393 Self {
394 max_file_size_kb: 1024, exclude_patterns: vec![
396 "**/node_modules/**".to_string(),
397 "**/target/**".to_string(),
398 "**/.git/**".to_string(),
399 "**/vendor/**".to_string(),
400 "**/__pycache__/**".to_string(),
401 "**/dist/**".to_string(),
402 "**/build/**".to_string(),
403 ],
404 include_patterns: Vec::new(),
405 detect_components: true,
406 parallelism: 0, languages: HashMap::new(),
408 }
409 }
410}
411
/// Per-language analysis settings.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
pub struct LanguageConfig {
    /// Whether analysis of this language is enabled.
    /// NOTE(review): because of `derive(Default)` plus `#[serde(default)]`, a
    /// language table that omits `enabled` deserializes as *disabled* (`false`).
    /// Confirm that is intended, e.g. for a table that only sets `query_file`.
    pub enabled: bool,

    /// Optional path to a custom query file (default `None`).
    pub query_file: Option<PathBuf>,
}
422
/// Workspace registry settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct WorkspaceConfig {
    /// Registered workspaces (empty by default).
    /// NOTE(review): presumably keyed by workspace name, mapping to its root path.
    pub workspaces: HashMap<String, PathBuf>,

    /// Active workspace, if any (default `None`).
    /// NOTE(review): presumably a key of `workspaces` — not validated here.
    pub active: Option<String>,

    /// Whether searches may span multiple workspaces (default `true`).
    pub cross_workspace_search: bool,
}
436
437impl Default for WorkspaceConfig {
438 fn default() -> Self {
439 Self {
440 workspaces: HashMap::new(),
441 active: None,
442 cross_workspace_search: true,
443 }
444 }
445}
446
/// Logging settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct LoggingConfig {
    /// Log level as a free-form string (default `info`); not validated here.
    pub level: String,

    /// Output format (default `text`).
    pub format: LogFormat,

    /// Optional log file path (default `None`).
    /// NOTE(review): presumably logs go to the console when unset — confirm.
    pub file: Option<PathBuf>,
}
460
461impl Default for LoggingConfig {
462 fn default() -> Self {
463 Self {
464 level: "info".to_string(),
465 format: LogFormat::default(),
466 file: None,
467 }
468 }
469}
470
/// Log output format, serialized in lowercase (`text`, `json`).
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum LogFormat {
    /// Plain-text output (default).
    #[default]
    Text,
    /// JSON output.
    Json,
}
481
/// Explicit overrides layered on top of a loaded `PrismConfig`.
///
/// Applied by `PrismConfig::apply_overrides`: each `Some` field replaces the
/// corresponding setting, and each `None` field leaves it unchanged. This type
/// is not serde-serializable.
#[derive(Debug, Clone, Default)]
pub struct ConfigOverrides {
    /// Workspace root override.
    /// NOTE(review): `PrismConfig::apply_overrides` does not read this field;
    /// presumably it is consumed by the loader — confirm.
    pub workspace_root: Option<PathBuf>,

    /// Replaces `storage.prism_dir`.
    pub prism_dir: Option<PathBuf>,

    /// Replaces `backend.qdrant.url`.
    pub qdrant_url: Option<String>,

    /// Replaces `backend.backend_type`.
    pub backend_type: Option<BackendType>,

    /// Replaces `embedding.provider`.
    pub embedding_provider: Option<EmbeddingProviderType>,

    /// Replaces `logging.level`.
    pub log_level: Option<String>,

    /// Replaces `analysis.parallelism`.
    pub parallelism: Option<usize>,
}
508
509impl PrismConfig {
510 pub fn apply_overrides(&mut self, overrides: &ConfigOverrides) {
512 if let Some(ref dir) = overrides.prism_dir {
513 self.storage.prism_dir = dir.clone();
514 }
515
516 if let Some(ref url) = overrides.qdrant_url {
517 self.backend.qdrant.url = url.clone();
518 }
519
520 if let Some(ref backend_type) = overrides.backend_type {
521 self.backend.backend_type = backend_type.clone();
522 }
523
524 if let Some(embedding_provider) = overrides.embedding_provider {
525 self.embedding.provider = embedding_provider;
526 }
527
528 if let Some(ref level) = overrides.log_level {
529 self.logging.level = level.clone();
530 }
531
532 if let Some(parallelism) = overrides.parallelism {
533 self.analysis.parallelism = parallelism;
534 }
535 }
536
537 pub fn validate(&self) -> Result<(), ConfigError> {
541 self.embedding.validate()?;
542 Ok(())
543 }
544
545 pub fn prism_dir(&self, workspace_root: &std::path::Path) -> PathBuf {
547 if self.storage.prism_dir.is_absolute() {
548 self.storage.prism_dir.clone()
549 } else {
550 workspace_root.join(&self.storage.prism_dir)
551 }
552 }
553
554 pub fn graph_path(&self, workspace_root: &std::path::Path) -> PathBuf {
556 let prism_dir = self.prism_dir(workspace_root);
557 match self.storage.graph_format {
558 GraphFormat::Sqlite => prism_dir.join("manifest.json"),
559 GraphFormat::Json => prism_dir.join("graph.json"),
560 }
561 }
562}
563
/// Unit tests for defaults, overrides, path resolution, provider parsing,
/// embedding validation and TOML round-tripping of the configuration types.
#[cfg(test)]
mod tests {
    use super::*;

    // --- Defaults, overrides and path resolution ---

    #[test]
    fn test_default_config() {
        let config = PrismConfig::default();
        assert_eq!(config.storage.prism_dir, PathBuf::from(".codeprysm"));
        assert_eq!(config.storage.graph_format, GraphFormat::Sqlite);
        assert_eq!(config.backend.backend_type, BackendType::Local);
        assert_eq!(config.backend.qdrant.url, "http://localhost:6334");
        assert!(config.analysis.detect_components);
    }

    #[test]
    fn test_apply_overrides() {
        let mut config = PrismConfig::default();
        let overrides = ConfigOverrides {
            prism_dir: Some(PathBuf::from("/custom/prism")),
            qdrant_url: Some("http://remote:6334".to_string()),
            log_level: Some("debug".to_string()),
            ..Default::default()
        };

        config.apply_overrides(&overrides);

        assert_eq!(config.storage.prism_dir, PathBuf::from("/custom/prism"));
        assert_eq!(config.backend.qdrant.url, "http://remote:6334");
        assert_eq!(config.logging.level, "debug");
    }

    // A relative prism_dir is joined onto the workspace root.
    #[test]
    fn test_prism_dir_resolution() {
        let config = PrismConfig::default();
        let workspace = PathBuf::from("/home/user/project");

        let prism_dir = config.prism_dir(&workspace);
        assert_eq!(prism_dir, PathBuf::from("/home/user/project/.codeprysm"));
    }

    // An absolute prism_dir ignores the workspace root.
    #[test]
    fn test_prism_dir_absolute() {
        let mut config = PrismConfig::default();
        config.storage.prism_dir = PathBuf::from("/absolute/path/.codeprysm");
        let workspace = PathBuf::from("/home/user/project");

        let prism_dir = config.prism_dir(&workspace);
        assert_eq!(prism_dir, PathBuf::from("/absolute/path/.codeprysm"));
    }

    #[test]
    fn test_graph_path_sqlite() {
        let config = PrismConfig::default();
        let workspace = PathBuf::from("/project");

        let path = config.graph_path(&workspace);
        assert_eq!(path, PathBuf::from("/project/.codeprysm/manifest.json"));
    }

    #[test]
    fn test_graph_path_json() {
        let mut config = PrismConfig::default();
        config.storage.graph_format = GraphFormat::Json;
        let workspace = PathBuf::from("/project");

        let path = config.graph_path(&workspace);
        assert_eq!(path, PathBuf::from("/project/.codeprysm/graph.json"));
    }

    // --- Embedding provider type and configuration ---

    #[test]
    fn test_embedding_config_default() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.provider, EmbeddingProviderType::Local);
        assert!(config.azure_ml.is_none());
        assert!(config.openai.is_none());
    }

    #[test]
    fn test_embedding_provider_type_display() {
        assert_eq!(EmbeddingProviderType::Local.to_string(), "local");
        assert_eq!(EmbeddingProviderType::AzureMl.to_string(), "azure-ml");
        assert_eq!(EmbeddingProviderType::Openai.to_string(), "openai");
    }

    // FromStr accepts the alternate Azure ML spellings and rejects unknown names.
    #[test]
    fn test_embedding_provider_type_from_str() {
        assert_eq!(
            "local".parse::<EmbeddingProviderType>().unwrap(),
            EmbeddingProviderType::Local
        );
        assert_eq!(
            "azure-ml".parse::<EmbeddingProviderType>().unwrap(),
            EmbeddingProviderType::AzureMl
        );
        assert_eq!(
            "azureml".parse::<EmbeddingProviderType>().unwrap(),
            EmbeddingProviderType::AzureMl
        );
        assert_eq!(
            "azure_ml".parse::<EmbeddingProviderType>().unwrap(),
            EmbeddingProviderType::AzureMl
        );
        assert_eq!(
            "openai".parse::<EmbeddingProviderType>().unwrap(),
            EmbeddingProviderType::Openai
        );
        assert!("unknown".parse::<EmbeddingProviderType>().is_err());
    }

    // --- EmbeddingConfig::validate ---

    #[test]
    fn test_embedding_config_validate_local() {
        let config = EmbeddingConfig::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_embedding_config_validate_azure_ml_missing() {
        let config = EmbeddingConfig {
            provider: EmbeddingProviderType::AzureMl,
            azure_ml: None,
            openai: None,
        };
        let err = config.validate().unwrap_err();
        assert!(err.to_string().contains("azure_ml"));
    }

    #[test]
    fn test_embedding_config_validate_azure_ml_valid() {
        let config = EmbeddingConfig {
            provider: EmbeddingProviderType::AzureMl,
            azure_ml: Some(AzureMLSettings {
                semantic_endpoint: "https://semantic.example.com/score".to_string(),
                code_endpoint: "https://code.example.com/score".to_string(),
                ..Default::default()
            }),
            openai: None,
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_embedding_config_validate_openai_missing() {
        let config = EmbeddingConfig {
            provider: EmbeddingProviderType::Openai,
            azure_ml: None,
            openai: None,
        };
        let err = config.validate().unwrap_err();
        assert!(err.to_string().contains("openai"));
    }

    #[test]
    fn test_embedding_config_validate_openai_valid() {
        let config = EmbeddingConfig {
            provider: EmbeddingProviderType::Openai,
            azure_ml: None,
            openai: Some(OpenAISettings {
                url: "https://api.openai.com/v1".to_string(),
                semantic_model: "text-embedding-3-small".to_string(),
                ..Default::default()
            }),
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_apply_embedding_provider_override() {
        let mut config = PrismConfig::default();
        assert_eq!(config.embedding.provider, EmbeddingProviderType::Local);

        let overrides = ConfigOverrides {
            embedding_provider: Some(EmbeddingProviderType::AzureMl),
            ..Default::default()
        };
        config.apply_overrides(&overrides);

        assert_eq!(config.embedding.provider, EmbeddingProviderType::AzureMl);
    }

    // Serialize to TOML and parse back; non-default values must survive the trip.
    #[test]
    fn test_embedding_config_toml_roundtrip() {
        let config = EmbeddingConfig {
            provider: EmbeddingProviderType::AzureMl,
            azure_ml: Some(AzureMLSettings {
                semantic_endpoint: "https://semantic.example.com/score".to_string(),
                code_endpoint: "https://code.example.com/score".to_string(),
                semantic_auth_key_env: Some("MY_SEMANTIC_KEY".to_string()),
                code_auth_key_env: Some("MY_CODE_KEY".to_string()),
                auth_key_env: None,
                timeout_secs: 60,
                max_retries: 5,
            }),
            openai: None,
        };

        let toml_str = toml::to_string(&config).unwrap();
        let parsed: EmbeddingConfig = toml::from_str(&toml_str).unwrap();

        assert_eq!(parsed.provider, EmbeddingProviderType::AzureMl);
        assert!(parsed.azure_ml.is_some());
        let azure_ml = parsed.azure_ml.unwrap();
        assert_eq!(
            azure_ml.semantic_endpoint,
            "https://semantic.example.com/score"
        );
        assert_eq!(
            azure_ml.semantic_auth_key_env,
            Some("MY_SEMANTIC_KEY".to_string())
        );
        assert_eq!(azure_ml.code_auth_key_env, Some("MY_CODE_KEY".to_string()));
        assert_eq!(azure_ml.timeout_secs, 60);
    }
}