1use std::collections::HashMap;
4use std::path::PathBuf;
5
6use crate::error::ProviderError;
7use crate::models::{DefaultsConfig, ProviderConfig, ProviderSettings};
8
9pub struct ConfigurationManager {
11 config: ProviderConfig,
12}
13
14impl ConfigurationManager {
15 pub fn new() -> Self {
17 Self {
18 config: ProviderConfig {
19 defaults: DefaultsConfig {
20 provider: "openai".to_string(),
21 model: "gpt-4".to_string(),
22 per_command: HashMap::new(),
23 per_action: HashMap::new(),
24 },
25 providers: HashMap::new(),
26 },
27 }
28 }
29
30 pub fn load_with_precedence(&mut self) -> Result<(), ProviderError> {
36 let global_config_path = Self::get_global_config_path();
40 if global_config_path.exists() {
41 self.load_from_file(&global_config_path)?;
42 }
43
44 let project_config_path = Self::get_project_config_path();
46 if project_config_path.exists() {
47 self.merge_from_file(&project_config_path)?;
48 }
49
50 self.load_from_env()?;
52
53 Ok(())
54 }
55
56 pub fn get_global_config_path() -> PathBuf {
58 let home = std::env::var("HOME")
59 .or_else(|_| std::env::var("USERPROFILE"))
60 .unwrap_or_else(|_| ".".to_string());
61 PathBuf::from(home).join("Documents/.ricecoder/config.yaml")
62 }
63
64 pub fn get_project_config_path() -> PathBuf {
66 PathBuf::from("./.agent/config.yaml")
67 }
68
69 pub fn load_from_env(&mut self) -> Result<(), ProviderError> {
72 let providers_to_check = vec!["openai", "anthropic", "google", "ollama"];
75
76 for provider in providers_to_check {
77 let env_var = format!("{}_API_KEY", provider.to_uppercase());
78 if let Ok(api_key) = std::env::var(&env_var) {
79 self.config
81 .providers
82 .entry(provider.to_string())
83 .and_modify(|s| {
84 s.api_key = Some(api_key.clone());
85 })
86 .or_insert_with(|| ProviderSettings {
87 api_key: Some(api_key.clone()),
88 base_url: None,
89 timeout: None,
90 retry_count: None,
91 });
92 }
93 }
94
95 for (key, value) in std::env::vars() {
97 if key.starts_with("RICECODER_PROVIDER_") {
98 let provider_id = key
99 .strip_prefix("RICECODER_PROVIDER_")
100 .unwrap()
101 .to_lowercase();
102 self.config
103 .providers
104 .entry(provider_id)
105 .and_modify(|s| {
106 s.api_key = Some(value.clone());
107 })
108 .or_insert_with(|| ProviderSettings {
109 api_key: Some(value.clone()),
110 base_url: None,
111 timeout: None,
112 retry_count: None,
113 });
114 }
115 }
116
117 Ok(())
118 }
119
120 pub fn load_from_file(&mut self, path: &PathBuf) -> Result<(), ProviderError> {
122 if !path.exists() {
123 return Ok(()); }
125
126 let content = std::fs::read_to_string(path).map_err(|e| {
127 ProviderError::ConfigError(format!("Failed to read config file: {}", e))
128 })?;
129
130 let config: ProviderConfig = serde_yaml::from_str(&content).map_err(|e| {
131 ProviderError::ConfigError(format!("Failed to parse config file: {}", e))
132 })?;
133
134 self.config = config;
135 Ok(())
136 }
137
138 pub fn merge_from_file(&mut self, path: &PathBuf) -> Result<(), ProviderError> {
140 if !path.exists() {
141 return Ok(()); }
143
144 let content = std::fs::read_to_string(path).map_err(|e| {
145 ProviderError::ConfigError(format!("Failed to read config file: {}", e))
146 })?;
147
148 let new_config: ProviderConfig = serde_yaml::from_str(&content).map_err(|e| {
149 ProviderError::ConfigError(format!("Failed to parse config file: {}", e))
150 })?;
151
152 if !new_config.defaults.provider.is_empty() {
154 self.config.defaults.provider = new_config.defaults.provider;
155 }
156 if !new_config.defaults.model.is_empty() {
157 self.config.defaults.model = new_config.defaults.model;
158 }
159 self.config
160 .defaults
161 .per_command
162 .extend(new_config.defaults.per_command);
163 self.config
164 .defaults
165 .per_action
166 .extend(new_config.defaults.per_action);
167
168 for (provider_id, settings) in new_config.providers {
170 self.config
171 .providers
172 .entry(provider_id)
173 .and_modify(|existing| {
174 if settings.api_key.is_some() {
175 existing.api_key = settings.api_key.clone();
176 }
177 if settings.base_url.is_some() {
178 existing.base_url = settings.base_url.clone();
179 }
180 if settings.timeout.is_some() {
181 existing.timeout = settings.timeout;
182 }
183 if settings.retry_count.is_some() {
184 existing.retry_count = settings.retry_count;
185 }
186 })
187 .or_insert(settings);
188 }
189
190 Ok(())
191 }
192
193 pub fn validate(&self) -> Result<(), ProviderError> {
202 if self.config.providers.is_empty() {
204 return Err(ProviderError::ConfigError(
205 "No providers configured".to_string(),
206 ));
207 }
208
209 if !self
211 .config
212 .providers
213 .contains_key(&self.config.defaults.provider)
214 {
215 return Err(ProviderError::ConfigError(format!(
216 "Default provider '{}' not configured",
217 self.config.defaults.provider
218 )));
219 }
220
221 for (provider_id, settings) in &self.config.providers {
223 if settings.api_key.is_none() {
224 let env_var = format!("{}_API_KEY", provider_id.to_uppercase());
226 if std::env::var(&env_var).is_err() {
227 return Err(ProviderError::ConfigError(format!(
228 "API key not found for provider '{}'. Set {} environment variable or configure in config file",
229 provider_id, env_var
230 )));
231 }
232 }
233
234 if let Some(settings) = self.config.providers.get(provider_id) {
236 if let Some(timeout) = settings.timeout {
239 if timeout.as_secs() == 0 {
240 return Err(ProviderError::ConfigError(format!(
241 "Invalid timeout for provider '{}': must be greater than 0",
242 provider_id
243 )));
244 }
245 }
246
247 if let Some(retry_count) = settings.retry_count {
248 if retry_count > 10 {
249 return Err(ProviderError::ConfigError(format!(
250 "Invalid retry count for provider '{}': must be <= 10",
251 provider_id
252 )));
253 }
254 }
255 }
256 }
257
258 for command in self.config.defaults.per_command.keys() {
260 match command.as_str() {
261 "gen" | "refactor" | "review" => {} _ => {
263 return Err(ProviderError::ConfigError(format!(
264 "Invalid command in per_command defaults: '{}'. Valid commands are: gen, refactor, review",
265 command
266 )));
267 }
268 }
269 }
270
271 for action in self.config.defaults.per_action.keys() {
273 match action.as_str() {
274 "analysis" | "generation" => {} _ => {
276 return Err(ProviderError::ConfigError(format!(
277 "Invalid action in per_action defaults: '{}'. Valid actions are: analysis, generation",
278 action
279 )));
280 }
281 }
282 }
283
284 Ok(())
285 }
286
287 pub fn validate_with_registry(
294 &self,
295 registry: &crate::provider::ProviderRegistry,
296 ) -> Result<(), ProviderError> {
297 self.validate()?;
299
300 let default_provider_id = &self.config.defaults.provider;
302 let default_model_id = &self.config.defaults.model;
303
304 let provider = registry.get(default_provider_id)?;
305 let models = provider.models();
306
307 if !models.iter().any(|m| m.id == *default_model_id) {
308 return Err(ProviderError::ConfigError(format!(
309 "Default model '{}' not found in provider '{}'",
310 default_model_id, default_provider_id
311 )));
312 }
313
314 for (command, model_id) in &self.config.defaults.per_command {
316 let provider_id = default_provider_id; let provider = registry.get(provider_id)?;
318 let models = provider.models();
319
320 if !models.iter().any(|m| m.id == *model_id) {
321 return Err(ProviderError::ConfigError(format!(
322 "Model '{}' for command '{}' not found in provider '{}'",
323 model_id, command, provider_id
324 )));
325 }
326 }
327
328 for (action, model_id) in &self.config.defaults.per_action {
330 let provider_id = default_provider_id; let provider = registry.get(provider_id)?;
332 let models = provider.models();
333
334 if !models.iter().any(|m| m.id == *model_id) {
335 return Err(ProviderError::ConfigError(format!(
336 "Model '{}' for action '{}' not found in provider '{}'",
337 model_id, action, provider_id
338 )));
339 }
340 }
341
342 Ok(())
343 }
344
345 pub fn config(&self) -> &ProviderConfig {
347 &self.config
348 }
349
350 pub fn config_mut(&mut self) -> &mut ProviderConfig {
352 &mut self.config
353 }
354
355 pub fn default_provider(&self) -> &str {
357 &self.config.defaults.provider
358 }
359
360 pub fn default_model(&self) -> &str {
362 &self.config.defaults.model
363 }
364
365 pub fn get_provider_settings(&self, provider_id: &str) -> Option<&ProviderSettings> {
367 self.config.providers.get(provider_id)
368 }
369
370 pub fn get_api_key(&self, provider_id: &str) -> Result<String, ProviderError> {
372 if let Some(settings) = self.config.providers.get(provider_id) {
374 if let Some(key) = &settings.api_key {
375 return Ok(key.clone());
376 }
377 }
378
379 let env_var = format!("{}_API_KEY", provider_id.to_uppercase());
381 std::env::var(&env_var).map_err(|_| {
382 ProviderError::ConfigError(format!(
383 "API key not found for provider '{}'. Set {} environment variable",
384 provider_id, env_var
385 ))
386 })
387 }
388}
389
390impl Default for ConfigurationManager {
391 fn default() -> Self {
392 Self::new()
393 }
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::ModelInfo;
    use crate::provider::{Provider, ProviderRegistry};
    use async_trait::async_trait;
    use std::sync::{Arc, Mutex, MutexGuard};

    /// Serializes tests that read or mutate process environment variables.
    ///
    /// The test harness runs tests on parallel threads while the environment is
    /// process-global, so without this lock e.g. `test_validate_missing_api_key`
    /// (removes `OPENAI_API_KEY`) races with `test_load_from_env_sets_api_keys` (sets it).
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], ignoring poisoning so one failing env test does not
    /// cascade into the others.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Provider settings with the given API key and every other field unset.
    fn settings_with_key(api_key: Option<&str>) -> ProviderSettings {
        ProviderSettings {
            api_key: api_key.map(str::to_string),
            base_url: None,
            timeout: None,
            retry_count: None,
        }
    }

    /// A manager whose only provider is `openai` with API key `"test-key"`.
    fn manager_with_openai_key() -> ConfigurationManager {
        let mut manager = ConfigurationManager::new();
        manager
            .config_mut()
            .providers
            .insert("openai".to_string(), settings_with_key(Some("test-key")));
        manager
    }

    /// Minimal `openai` provider whose model list is produced by `models`; every
    /// request method fails with `NotFound`.
    struct MockProvider {
        models: fn() -> Vec<ModelInfo>,
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn id(&self) -> &str {
            "openai"
        }

        fn name(&self) -> &str {
            "OpenAI"
        }

        fn models(&self) -> Vec<ModelInfo> {
            (self.models)()
        }

        async fn chat(
            &self,
            _request: crate::models::ChatRequest,
        ) -> Result<crate::models::ChatResponse, crate::error::ProviderError> {
            Err(crate::error::ProviderError::NotFound(
                "Not implemented".to_string(),
            ))
        }

        async fn chat_stream(
            &self,
            _request: crate::models::ChatRequest,
        ) -> Result<crate::provider::ChatStream, crate::error::ProviderError> {
            Err(crate::error::ProviderError::NotFound(
                "Not implemented".to_string(),
            ))
        }

        fn count_tokens(
            &self,
            _content: &str,
            _model: &str,
        ) -> Result<usize, crate::error::ProviderError> {
            Ok(0)
        }

        async fn health_check(&self) -> Result<bool, crate::error::ProviderError> {
            Ok(true)
        }
    }

    /// A registry containing a single [`MockProvider`] that reports `models`.
    fn registry_with(models: fn() -> Vec<ModelInfo>) -> ProviderRegistry {
        let mut registry = ProviderRegistry::new();
        registry.register(Arc::new(MockProvider { models })).unwrap();
        registry
    }

    #[test]
    fn test_new_configuration_manager() {
        let manager = ConfigurationManager::new();
        assert_eq!(manager.default_provider(), "openai");
        assert_eq!(manager.default_model(), "gpt-4");
    }

    #[test]
    fn test_validate_empty_config() {
        let manager = ConfigurationManager::new();
        assert!(manager.validate().is_err());
    }

    #[test]
    fn test_get_default_provider() {
        let manager = ConfigurationManager::new();
        assert_eq!(manager.default_provider(), "openai");
    }

    #[test]
    fn test_get_default_model() {
        let manager = ConfigurationManager::new();
        assert_eq!(manager.default_model(), "gpt-4");
    }

    #[test]
    fn test_get_global_config_path() {
        let path = ConfigurationManager::get_global_config_path();
        assert!(path.to_string_lossy().contains(".ricecoder"));
        assert!(path.to_string_lossy().contains("config.yaml"));
    }

    #[test]
    fn test_get_project_config_path() {
        let path = ConfigurationManager::get_project_config_path();
        assert_eq!(path, PathBuf::from("./.agent/config.yaml"));
    }

    #[test]
    fn test_merge_from_file_preserves_existing() {
        // NOTE: this applies the overlay to the config structs directly; it does not
        // read a file or call `merge_from_file`.
        let mut manager = ConfigurationManager::new();
        manager
            .config_mut()
            .providers
            .insert("openai".to_string(), settings_with_key(Some("initial-key")));

        let merged_config = ProviderConfig {
            defaults: DefaultsConfig {
                provider: "anthropic".to_string(),
                model: "claude-3".to_string(),
                per_command: HashMap::new(),
                per_action: HashMap::new(),
            },
            providers: std::iter::once((
                "anthropic".to_string(),
                settings_with_key(Some("anthropic-key")),
            ))
            .collect(),
        };

        manager.config_mut().defaults = merged_config.defaults;
        manager
            .config_mut()
            .providers
            .extend(merged_config.providers);

        assert_eq!(manager.default_provider(), "anthropic");
        assert_eq!(manager.default_model(), "claude-3");
        assert!(manager.config().providers.contains_key("openai"));
        assert!(manager.config().providers.contains_key("anthropic"));
    }

    #[test]
    fn test_load_from_env_sets_api_keys() {
        let _env = lock_env();
        let mut manager = ConfigurationManager::new();

        std::env::set_var("OPENAI_API_KEY", "test-key-123");
        let result = manager.load_from_env();
        // Clean up before asserting so a failure does not leak the variable.
        std::env::remove_var("OPENAI_API_KEY");
        result.unwrap();

        let openai_settings = manager.get_provider_settings("openai");
        assert!(openai_settings.is_some());
        assert_eq!(
            openai_settings.unwrap().api_key,
            Some("test-key-123".to_string())
        );
    }

    #[test]
    fn test_get_api_key_from_config() {
        let mut manager = ConfigurationManager::new();
        manager
            .config_mut()
            .providers
            .insert("openai".to_string(), settings_with_key(Some("config-key")));

        let key = manager.get_api_key("openai");
        assert!(key.is_ok());
        assert_eq!(key.unwrap(), "config-key");
    }

    #[test]
    fn test_get_api_key_from_env() {
        let _env = lock_env();
        let mut manager = ConfigurationManager::new();
        manager
            .config_mut()
            .providers
            .insert("anthropic".to_string(), settings_with_key(None));

        std::env::set_var("ANTHROPIC_API_KEY", "env-key-456");
        let key = manager.get_api_key("anthropic");
        std::env::remove_var("ANTHROPIC_API_KEY");

        assert!(key.is_ok());
        assert_eq!(key.unwrap(), "env-key-456");
    }

    #[test]
    fn test_validate_with_valid_config() {
        let manager = manager_with_openai_key();
        assert!(manager.validate().is_ok());
    }

    #[test]
    fn test_validate_missing_api_key() {
        let _env = lock_env();
        std::env::remove_var("OPENAI_API_KEY");
        std::env::remove_var("ANTHROPIC_API_KEY");
        std::env::remove_var("GOOGLE_API_KEY");
        std::env::remove_var("ZEN_API_KEY");

        let mut manager = ConfigurationManager::new();
        manager
            .config_mut()
            .providers
            .insert("openai".to_string(), settings_with_key(None));

        assert!(manager.validate().is_err());
    }

    #[test]
    fn test_validate_invalid_timeout() {
        let mut manager = ConfigurationManager::new();
        manager.config_mut().providers.insert(
            "openai".to_string(),
            ProviderSettings {
                timeout: Some(std::time::Duration::from_secs(0)),
                ..settings_with_key(Some("test-key"))
            },
        );

        assert!(manager.validate().is_err());
    }

    #[test]
    fn test_validate_invalid_retry_count() {
        let mut manager = ConfigurationManager::new();
        manager.config_mut().providers.insert(
            "openai".to_string(),
            ProviderSettings {
                retry_count: Some(15),
                ..settings_with_key(Some("test-key"))
            },
        );

        assert!(manager.validate().is_err());
    }

    #[test]
    fn test_validate_invalid_command() {
        let mut manager = manager_with_openai_key();
        manager
            .config_mut()
            .defaults
            .per_command
            .insert("invalid_command".to_string(), "gpt-4".to_string());

        assert!(manager.validate().is_err());
    }

    #[test]
    fn test_validate_invalid_action() {
        let mut manager = manager_with_openai_key();
        manager
            .config_mut()
            .defaults
            .per_action
            .insert("invalid_action".to_string(), "gpt-4".to_string());

        assert!(manager.validate().is_err());
    }

    #[test]
    fn test_validate_valid_commands() {
        let mut manager = manager_with_openai_key();
        for command in ["gen", "refactor", "review"] {
            manager
                .config_mut()
                .defaults
                .per_command
                .insert(command.to_string(), "gpt-4".to_string());
        }

        assert!(manager.validate().is_ok());
    }

    #[test]
    fn test_validate_valid_actions() {
        let mut manager = manager_with_openai_key();
        for action in ["analysis", "generation"] {
            manager
                .config_mut()
                .defaults
                .per_action
                .insert(action.to_string(), "gpt-4".to_string());
        }

        assert!(manager.validate().is_ok());
    }

    #[test]
    fn test_validate_with_registry_valid_model() {
        let manager = manager_with_openai_key();
        let registry = registry_with(|| {
            vec![ModelInfo {
                id: "gpt-4".to_string(),
                name: "GPT-4".to_string(),
                provider: "openai".to_string(),
                context_window: 8192,
                capabilities: vec![],
                pricing: None,
            }]
        });

        assert!(manager.validate_with_registry(&registry).is_ok());
    }

    #[test]
    fn test_validate_with_registry_invalid_model() {
        let mut manager = manager_with_openai_key();
        manager.config_mut().defaults.model = "gpt-4".to_string();
        let registry = registry_with(|| {
            vec![ModelInfo {
                id: "gpt-3.5-turbo".to_string(),
                name: "GPT-3.5 Turbo".to_string(),
                provider: "openai".to_string(),
                context_window: 4096,
                capabilities: vec![],
                pricing: None,
            }]
        });

        assert!(manager.validate_with_registry(&registry).is_err());
    }
}