1use notify::{Event, RecommendedWatcher, RecursiveMode, Watcher};
6use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10use std::fs;
11use std::path::{Path, PathBuf};
12use std::sync::Arc;
13use std::time::SystemTime;
14
15pub(crate) type ConfigReloadCallback = Box<dyn Fn(&HashMap<String, Value>) + Send + Sync>;
17
18pub(crate) type ConfigReloadCallbackList = Arc<RwLock<Vec<ConfigReloadCallback>>>;
20
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
23#[serde(rename_all = "camelCase")]
24pub enum ConfigSource {
25 Default,
27 UserSettings,
29 ProjectSettings,
31 LocalSettings,
33 EnvSettings,
35 FlagSettings,
37 PolicySettings,
39}
40
41impl ConfigSource {
42 pub fn priority(&self) -> u8 {
44 match self {
45 ConfigSource::Default => 0,
46 ConfigSource::UserSettings => 1,
47 ConfigSource::ProjectSettings => 2,
48 ConfigSource::LocalSettings => 3,
49 ConfigSource::EnvSettings => 4,
50 ConfigSource::FlagSettings => 5,
51 ConfigSource::PolicySettings => 6,
52 }
53 }
54}
55
56#[derive(Debug, Clone)]
58pub struct ConfigSourceInfo {
59 pub source: ConfigSource,
61 pub path: Option<PathBuf>,
63 pub priority: u8,
65 pub exists: bool,
67 pub loaded_at: Option<SystemTime>,
69}
70
71#[derive(Debug, Clone)]
73pub struct ConfigKeySource {
74 pub key: String,
76 pub value: Value,
78 pub source: ConfigSource,
80 pub source_path: Option<PathBuf>,
82 pub overridden_by: Vec<ConfigSource>,
84}
85
86#[derive(Debug, Clone, Default, Serialize, Deserialize)]
88pub struct EnterprisePolicyConfig {
89 #[serde(default)]
91 pub enforced: HashMap<String, Value>,
92 #[serde(default)]
94 pub defaults: HashMap<String, Value>,
95 #[serde(default)]
97 pub disabled_features: Vec<String>,
98 #[serde(default)]
100 pub allowed_tools: Vec<String>,
101 #[serde(default)]
103 pub denied_tools: Vec<String>,
104 #[serde(default)]
106 pub metadata: PolicyMetadata,
107}
108
109#[derive(Debug, Clone, Default, Serialize, Deserialize)]
111pub struct PolicyMetadata {
112 pub version: Option<String>,
113 pub last_updated: Option<String>,
114 pub organization_id: Option<String>,
115 pub policy_name: Option<String>,
116}
117
118#[derive(Debug, Clone, Default)]
120pub struct ConfigManagerOptions {
121 pub flag_settings_path: Option<PathBuf>,
123 pub working_directory: Option<PathBuf>,
125 pub debug_mode: bool,
127 pub cli_flags: HashMap<String, Value>,
129}
130
131pub struct ConfigManager {
133 global_config_dir: PathBuf,
135 user_config_file: PathBuf,
137 project_config_file: PathBuf,
139 local_config_file: PathBuf,
141 policy_config_file: PathBuf,
143 flag_config_file: Option<PathBuf>,
145
146 merged_config: RwLock<HashMap<String, Value>>,
148 config_sources: RwLock<HashMap<String, ConfigSource>>,
150 config_source_paths: RwLock<HashMap<String, PathBuf>>,
152 config_history: RwLock<HashMap<String, Vec<ConfigKeySource>>>,
154 loaded_sources: RwLock<Vec<ConfigSourceInfo>>,
156 enterprise_policy: RwLock<Option<EnterprisePolicyConfig>>,
158 watcher: RwLock<Option<RecommendedWatcher>>,
160 reload_callbacks: ConfigReloadCallbackList,
162 cli_flags: HashMap<String, Value>,
164 debug_mode: bool,
166}
167
168impl ConfigManager {
169 pub fn new(options: ConfigManagerOptions) -> Self {
171 let working_dir = options
172 .working_directory
173 .unwrap_or_else(|| std::env::current_dir().unwrap_or_default());
174
175 let global_config_dir = std::env::var("ASTER_CONFIG_DIR")
177 .map(PathBuf::from)
178 .unwrap_or_else(|_| dirs::home_dir().unwrap_or_default().join(".aster"));
179
180 let user_config_file = global_config_dir.join("settings.yaml");
182
183 let managed_settings = global_config_dir.join("managed_settings.yaml");
185 let policy_json = global_config_dir.join("policy.yaml");
186 let policy_config_file = if managed_settings.exists() {
187 managed_settings
188 } else {
189 policy_json
190 };
191
192 let project_config_file = working_dir.join(".aster").join("settings.yaml");
194
195 let local_config_file = working_dir.join(".aster").join("settings.local.yaml");
197
198 let debug_mode = options.debug_mode
199 || std::env::var("ASTER_DEBUG")
200 .map(|v| v == "true")
201 .unwrap_or(false);
202
203 let mut manager = Self {
204 global_config_dir,
205 user_config_file,
206 project_config_file,
207 local_config_file,
208 policy_config_file,
209 flag_config_file: options.flag_settings_path,
210 merged_config: RwLock::new(HashMap::new()),
211 config_sources: RwLock::new(HashMap::new()),
212 config_source_paths: RwLock::new(HashMap::new()),
213 config_history: RwLock::new(HashMap::new()),
214 loaded_sources: RwLock::new(Vec::new()),
215 enterprise_policy: RwLock::new(None),
216 watcher: RwLock::new(None),
217 reload_callbacks: Arc::new(RwLock::new(Vec::new())),
218 cli_flags: options.cli_flags,
219 debug_mode,
220 };
221
222 manager.load_and_merge_config();
223 manager
224 }
225
226 fn load_and_merge_config(&mut self) {
237 self.config_sources.write().clear();
238 self.config_source_paths.write().clear();
239 self.config_history.write().clear();
240 self.loaded_sources.write().clear();
241
242 let load_time = SystemTime::now();
243 let mut config: HashMap<String, Value> = HashMap::new();
244
245 let defaults = self.get_default_config();
247 self.track_config_source(&defaults, ConfigSource::Default, None);
248 config.extend(defaults);
249 self.loaded_sources.write().push(ConfigSourceInfo {
250 source: ConfigSource::Default,
251 path: None,
252 priority: ConfigSource::Default.priority(),
253 exists: true,
254 loaded_at: Some(load_time),
255 });
256
257 if let Some(policy) = self.load_enterprise_policy() {
259 if !policy.defaults.is_empty() {
260 self.merge_config(
261 &mut config,
262 &policy.defaults,
263 ConfigSource::PolicySettings,
264 Some(&self.policy_config_file.clone()),
265 );
266 self.debug_log("加载企业策略默认值");
267 }
268 *self.enterprise_policy.write() = Some(policy);
269 }
270
271 let user_exists = self.user_config_file.exists();
273 self.loaded_sources.write().push(ConfigSourceInfo {
274 source: ConfigSource::UserSettings,
275 path: Some(self.user_config_file.clone()),
276 priority: ConfigSource::UserSettings.priority(),
277 exists: user_exists,
278 loaded_at: Some(load_time),
279 });
280 if user_exists {
281 if let Some(user_config) = self.load_config_file(&self.user_config_file) {
282 self.merge_config(
283 &mut config,
284 &user_config,
285 ConfigSource::UserSettings,
286 Some(&self.user_config_file.clone()),
287 );
288 self.debug_log(&format!("加载用户配置: {:?}", self.user_config_file));
289 }
290 }
291
292 let project_exists = self.project_config_file.exists();
294 self.loaded_sources.write().push(ConfigSourceInfo {
295 source: ConfigSource::ProjectSettings,
296 path: Some(self.project_config_file.clone()),
297 priority: ConfigSource::ProjectSettings.priority(),
298 exists: project_exists,
299 loaded_at: Some(load_time),
300 });
301 if project_exists {
302 if let Some(project_config) = self.load_config_file(&self.project_config_file) {
303 self.merge_config(
304 &mut config,
305 &project_config,
306 ConfigSource::ProjectSettings,
307 Some(&self.project_config_file.clone()),
308 );
309 self.debug_log(&format!("加载项目配置: {:?}", self.project_config_file));
310 }
311 }
312
313 let local_exists = self.local_config_file.exists();
315 self.loaded_sources.write().push(ConfigSourceInfo {
316 source: ConfigSource::LocalSettings,
317 path: Some(self.local_config_file.clone()),
318 priority: ConfigSource::LocalSettings.priority(),
319 exists: local_exists,
320 loaded_at: Some(load_time),
321 });
322 if local_exists {
323 if let Some(local_config) = self.load_config_file(&self.local_config_file) {
324 self.merge_config(
325 &mut config,
326 &local_config,
327 ConfigSource::LocalSettings,
328 Some(&self.local_config_file.clone()),
329 );
330 self.debug_log(&format!("加载本地配置: {:?}", self.local_config_file));
331 }
332 }
333
334 let env_config = self.get_env_config();
336 if !env_config.is_empty() {
337 self.merge_config(&mut config, &env_config, ConfigSource::EnvSettings, None);
338 self.loaded_sources.write().push(ConfigSourceInfo {
339 source: ConfigSource::EnvSettings,
340 path: None,
341 priority: ConfigSource::EnvSettings.priority(),
342 exists: true,
343 loaded_at: Some(load_time),
344 });
345 self.debug_log(&format!("加载 {} 个环境变量配置", env_config.len()));
346 }
347
348 if let Some(ref flag_file) = self.flag_config_file {
350 let flag_exists = flag_file.exists();
351 self.loaded_sources.write().push(ConfigSourceInfo {
352 source: ConfigSource::FlagSettings,
353 path: Some(flag_file.clone()),
354 priority: ConfigSource::FlagSettings.priority(),
355 exists: flag_exists,
356 loaded_at: Some(load_time),
357 });
358 if flag_exists {
359 if let Some(flag_config) = self.load_config_file(flag_file) {
360 self.merge_config(
361 &mut config,
362 &flag_config,
363 ConfigSource::FlagSettings,
364 Some(flag_file),
365 );
366 self.debug_log(&format!("加载标志配置: {:?}", flag_file));
367 }
368 }
369 }
370
371 if !self.cli_flags.is_empty() {
373 self.merge_config(
374 &mut config,
375 &self.cli_flags,
376 ConfigSource::FlagSettings,
377 None,
378 );
379 self.debug_log(&format!("应用 {} 个 CLI 标志", self.cli_flags.len()));
380 }
381
382 if let Some(ref policy) = *self.enterprise_policy.read() {
384 if !policy.enforced.is_empty() {
385 self.merge_config(
386 &mut config,
387 &policy.enforced,
388 ConfigSource::PolicySettings,
389 Some(&self.policy_config_file.clone()),
390 );
391 self.loaded_sources.write().push(ConfigSourceInfo {
392 source: ConfigSource::PolicySettings,
393 path: Some(self.policy_config_file.clone()),
394 priority: ConfigSource::PolicySettings.priority(),
395 exists: true,
396 loaded_at: Some(load_time),
397 });
398 self.debug_log("应用企业策略强制设置");
399 }
400 }
401
402 *self.merged_config.write() = config;
403
404 if self.debug_mode {
405 self.print_debug_info();
406 }
407 }
408
409 fn get_default_config(&self) -> HashMap<String, Value> {
411 let mut defaults = HashMap::new();
412 defaults.insert(
413 "model".to_string(),
414 Value::String("claude-3-5-sonnet".to_string()),
415 );
416 defaults.insert("max_tokens".to_string(), Value::Number(4096.into()));
417 defaults.insert(
418 "temperature".to_string(),
419 Value::Number(serde_json::Number::from_f64(0.7).unwrap()),
420 );
421 defaults.insert("enable_telemetry".to_string(), Value::Bool(false));
422 defaults.insert("theme".to_string(), Value::String("auto".to_string()));
423 defaults
424 }
425
426 fn get_env_config(&self) -> HashMap<String, Value> {
428 let mut config = HashMap::new();
429 let env_mappings = [
430 ("ASTER_API_KEY", "api_key"),
431 ("ASTER_MODEL", "model"),
432 ("ASTER_MAX_TOKENS", "max_tokens"),
433 ("ASTER_PROVIDER", "api_provider"),
434 ("ASTER_ENABLE_TELEMETRY", "enable_telemetry"),
435 ];
436
437 for (env_key, config_key) in env_mappings {
438 if let Ok(val) = std::env::var(env_key) {
439 if let Some(parsed) = self.parse_env_value(&val) {
440 config.insert(config_key.to_string(), parsed);
441 }
442 }
443 }
444 config
445 }
446
447 fn parse_env_value(&self, val: &str) -> Option<Value> {
449 if let Ok(json_value) = serde_json::from_str(val) {
451 return Some(json_value);
452 }
453
454 let trimmed = val.trim();
455
456 match trimmed.to_lowercase().as_str() {
458 "true" => return Some(Value::Bool(true)),
459 "false" => return Some(Value::Bool(false)),
460 _ => {}
461 }
462
463 if let Ok(int_val) = trimmed.parse::<i64>() {
465 return Some(Value::Number(int_val.into()));
466 }
467
468 if let Ok(float_val) = trimmed.parse::<f64>() {
470 if let Some(num) = serde_json::Number::from_f64(float_val) {
471 return Some(Value::Number(num));
472 }
473 }
474
475 Some(Value::String(val.to_string()))
477 }
478
479 fn load_config_file(&self, path: &Path) -> Option<HashMap<String, Value>> {
481 if !path.exists() {
482 return None;
483 }
484
485 match fs::read_to_string(path) {
486 Ok(content) => {
487 if let Ok(yaml_value) = serde_yaml::from_str::<serde_yaml::Value>(&content) {
489 if let Ok(Value::Object(map)) = serde_json::to_value(yaml_value) {
490 return Some(map.into_iter().collect());
491 }
492 }
493 if let Ok(Value::Object(map)) = serde_json::from_str::<Value>(&content) {
495 return Some(map.into_iter().collect());
496 }
497 tracing::warn!("无法解析配置文件: {:?}", path);
498 None
499 }
500 Err(e) => {
501 tracing::warn!("读取配置文件失败: {:?}, 错误: {}", path, e);
502 None
503 }
504 }
505 }
506
507 fn load_enterprise_policy(&self) -> Option<EnterprisePolicyConfig> {
509 if !self.policy_config_file.exists() {
510 return None;
511 }
512
513 match fs::read_to_string(&self.policy_config_file) {
514 Ok(content) => {
515 if let Ok(policy) = serde_yaml::from_str(&content) {
517 self.debug_log(&format!("加载企业策略: {:?}", self.policy_config_file));
518 return Some(policy);
519 }
520 if let Ok(policy) = serde_json::from_str(&content) {
522 self.debug_log(&format!("加载企业策略: {:?}", self.policy_config_file));
523 return Some(policy);
524 }
525 tracing::warn!("无法解析企业策略文件");
526 None
527 }
528 Err(e) => {
529 tracing::warn!("读取企业策略失败: {}", e);
530 None
531 }
532 }
533 }
534
535 fn merge_config(
537 &self,
538 base: &mut HashMap<String, Value>,
539 override_config: &HashMap<String, Value>,
540 source: ConfigSource,
541 source_path: Option<&PathBuf>,
542 ) {
543 for (key, value) in override_config {
544 if let Some(prev_source) = self.config_sources.read().get(key) {
546 if *prev_source != source {
547 let mut history = self.config_history.write();
548 let entry = history.entry(key.clone()).or_default();
549 entry.push(ConfigKeySource {
550 key: key.clone(),
551 value: value.clone(),
552 source,
553 source_path: source_path.cloned(),
554 overridden_by: vec![*prev_source],
555 });
556 }
557 }
558
559 self.config_sources.write().insert(key.clone(), source);
561 if let Some(path) = source_path {
562 self.config_source_paths
563 .write()
564 .insert(key.clone(), path.clone());
565 }
566
567 base.insert(key.clone(), self.deep_merge(base.get(key), value));
569 }
570 }
571
572 fn deep_merge(&self, base: Option<&Value>, override_val: &Value) -> Value {
574 match (base, override_val) {
575 (Some(Value::Object(base_map)), Value::Object(override_map)) => {
576 let mut result = base_map.clone();
577 for (k, v) in override_map {
578 let merged = self.deep_merge(base_map.get(k), v);
579 result.insert(k.clone(), merged);
580 }
581 Value::Object(result)
582 }
583 _ => override_val.clone(),
584 }
585 }
586
587 fn track_config_source(
589 &self,
590 config: &HashMap<String, Value>,
591 source: ConfigSource,
592 source_path: Option<&PathBuf>,
593 ) {
594 for key in config.keys() {
595 self.config_sources.write().insert(key.clone(), source);
596 if let Some(path) = source_path {
597 self.config_source_paths
598 .write()
599 .insert(key.clone(), path.clone());
600 }
601 }
602 }
603
604 fn debug_log(&self, message: &str) {
606 if self.debug_mode {
607 tracing::debug!("[Config] {}", message);
608 }
609 }
610
611 fn print_debug_info(&self) {
613 tracing::debug!("\n=== 配置调试信息 ===");
614 tracing::debug!("已加载的配置源:");
615 for source in self.loaded_sources.read().iter() {
616 let status = if source.exists { "OK" } else { "未找到" };
617 let path_info = source
618 .path
619 .as_ref()
620 .map(|p| format!(" ({:?})", p))
621 .unwrap_or_default();
622 tracing::debug!(
623 " [{}] {:?}{}: {}",
624 source.priority,
625 source.source,
626 path_info,
627 status
628 );
629 }
630
631 tracing::debug!("\n配置项来源:");
632 for (key, source) in self.config_sources.read().iter() {
633 let path_info = self
634 .config_source_paths
635 .read()
636 .get(key)
637 .map(|p| format!(" ({:?})", p))
638 .unwrap_or_default();
639 tracing::debug!(" {}: {:?}{}", key, source, path_info);
640 }
641 tracing::debug!("================================\n");
642 }
643
644 pub fn get<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
648 self.merged_config
649 .read()
650 .get(key)
651 .and_then(|v| serde_json::from_value(v.clone()).ok())
652 }
653
654 pub fn get_or<T: for<'de> Deserialize<'de>>(&self, key: &str, default: T) -> T {
656 self.get(key).unwrap_or(default)
657 }
658
659 pub fn get_value(&self, key: &str) -> Option<Value> {
661 self.merged_config.read().get(key).cloned()
662 }
663
664 pub fn set<T: Serialize>(&self, key: &str, value: T) {
666 if let Ok(json_value) = serde_json::to_value(value) {
667 self.merged_config
668 .write()
669 .insert(key.to_string(), json_value);
670 }
671 }
672
673 pub fn get_all(&self) -> HashMap<String, Value> {
675 self.merged_config.read().clone()
676 }
677
678 pub fn get_with_source<T: for<'de> Deserialize<'de>>(
680 &self,
681 key: &str,
682 ) -> Option<(T, ConfigSource, Option<PathBuf>)> {
683 let value = self.get::<T>(key)?;
684 let source = self
685 .config_sources
686 .read()
687 .get(key)
688 .copied()
689 .unwrap_or(ConfigSource::Default);
690 let path = self.config_source_paths.read().get(key).cloned();
691 Some((value, source, path))
692 }
693
694 pub fn get_config_source(&self, key: &str) -> Option<ConfigSource> {
696 self.config_sources.read().get(key).copied()
697 }
698
699 pub fn get_all_config_sources(&self) -> HashMap<String, ConfigSource> {
701 self.config_sources.read().clone()
702 }
703
704 pub fn get_config_source_info(&self) -> Vec<ConfigSourceInfo> {
706 self.loaded_sources.read().clone()
707 }
708
709 pub fn get_config_history(&self, key: &str) -> Vec<ConfigKeySource> {
711 self.config_history
712 .read()
713 .get(key)
714 .cloned()
715 .unwrap_or_default()
716 }
717
718 pub fn is_enforced_by_policy(&self, key: &str) -> bool {
720 self.enterprise_policy
721 .read()
722 .as_ref()
723 .map(|p| p.enforced.contains_key(key))
724 .unwrap_or(false)
725 }
726
727 pub fn get_enterprise_policy(&self) -> Option<EnterprisePolicyConfig> {
729 self.enterprise_policy.read().clone()
730 }
731
732 pub fn is_feature_disabled(&self, feature: &str) -> bool {
734 self.enterprise_policy
735 .read()
736 .as_ref()
737 .map(|p| p.disabled_features.contains(&feature.to_string()))
738 .unwrap_or(false)
739 }
740
741 pub fn get_config_paths(&self) -> HashMap<String, PathBuf> {
743 let mut paths = HashMap::new();
744 paths.insert("user_settings".to_string(), self.user_config_file.clone());
745 paths.insert(
746 "project_settings".to_string(),
747 self.project_config_file.clone(),
748 );
749 paths.insert("local_settings".to_string(), self.local_config_file.clone());
750 paths.insert(
751 "policy_settings".to_string(),
752 self.policy_config_file.clone(),
753 );
754 paths.insert(
755 "global_config_dir".to_string(),
756 self.global_config_dir.clone(),
757 );
758 if let Some(ref flag_file) = self.flag_config_file {
759 paths.insert("flag_settings".to_string(), flag_file.clone());
760 }
761 paths
762 }
763
764 pub fn save(&self, config: Option<&HashMap<String, Value>>) -> Result<(), std::io::Error> {
768 if let Some(cfg) = config {
769 self.merged_config.write().extend(cfg.clone());
770 }
771
772 if let Some(parent) = self.user_config_file.parent() {
773 fs::create_dir_all(parent)?;
774 }
775
776 self.backup_config(&self.user_config_file)?;
778
779 let yaml = serde_yaml::to_string(&*self.merged_config.read())
780 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
781 fs::write(&self.user_config_file, yaml)
782 }
783
784 pub fn save_local(&self, config: &HashMap<String, Value>) -> Result<(), std::io::Error> {
786 let mut filtered_config = config.clone();
788 if let Some(ref policy) = *self.enterprise_policy.read() {
789 for key in policy.enforced.keys() {
790 if filtered_config.contains_key(key) {
791 tracing::warn!("配置项 {} 被企业策略强制,无法本地覆盖", key);
792 filtered_config.remove(key);
793 }
794 }
795 }
796
797 if let Some(parent) = self.local_config_file.parent() {
798 fs::create_dir_all(parent)?;
799 }
800
801 let mut local_config = self
803 .load_config_file(&self.local_config_file)
804 .unwrap_or_default();
805 local_config.extend(filtered_config);
806
807 self.backup_config(&self.local_config_file)?;
808
809 let yaml = serde_yaml::to_string(&local_config)
810 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
811 fs::write(&self.local_config_file, yaml)
812 }
813
814 pub fn save_project(&self, config: &HashMap<String, Value>) -> Result<(), std::io::Error> {
816 if let Some(parent) = self.project_config_file.parent() {
817 fs::create_dir_all(parent)?;
818 }
819
820 let mut project_config = self
821 .load_config_file(&self.project_config_file)
822 .unwrap_or_default();
823 project_config.extend(config.clone());
824
825 let yaml = serde_yaml::to_string(&project_config)
826 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
827 fs::write(&self.project_config_file, yaml)
828 }
829
830 pub fn reload(&mut self) {
832 self.load_and_merge_config();
833 let config = self.merged_config.read().clone();
834 for callback in self.reload_callbacks.read().iter() {
835 callback(&config);
836 }
837 }
838
839 pub fn watch<F>(&self, callback: F) -> Result<(), notify::Error>
841 where
842 F: Fn(&HashMap<String, Value>) + Send + Sync + 'static,
843 {
844 self.reload_callbacks.write().push(Box::new(callback));
845
846 let mut watcher_guard = self.watcher.write();
847 if watcher_guard.is_some() {
848 return Ok(());
849 }
850
851 let callbacks = self.reload_callbacks.clone();
852 let user_file = self.user_config_file.clone();
853 let project_file = self.project_config_file.clone();
854 let local_file = self.local_config_file.clone();
855
856 let watcher = notify::recommended_watcher(move |res: Result<Event, _>| {
857 if let Ok(event) = res {
858 if event.kind.is_modify() {
859 let cbs = callbacks.read();
861 for cb in cbs.iter() {
862 cb(&HashMap::new()); }
864 }
865 }
866 })?;
867
868 let mut w = watcher;
870 if user_file.exists() {
871 let _ = w.watch(&user_file, RecursiveMode::NonRecursive);
872 }
873 if project_file.exists() {
874 let _ = w.watch(&project_file, RecursiveMode::NonRecursive);
875 }
876 if local_file.exists() {
877 let _ = w.watch(&local_file, RecursiveMode::NonRecursive);
878 }
879
880 *watcher_guard = Some(w);
881 Ok(())
882 }
883
884 fn backup_config(&self, file_path: &Path) -> Result<(), std::io::Error> {
888 if !file_path.exists() {
889 return Ok(());
890 }
891
892 let backup_dir = file_path
893 .parent()
894 .map(|p| p.join(".backups"))
895 .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::NotFound, "无效路径"))?;
896
897 fs::create_dir_all(&backup_dir)?;
898
899 let timestamp = chrono::Utc::now().format("%Y-%m-%dT%H-%M-%S");
900 let filename = file_path
901 .file_stem()
902 .and_then(|s| s.to_str())
903 .unwrap_or("config");
904 let backup_path = backup_dir.join(format!("{}.{}.yaml", filename, timestamp));
905
906 fs::copy(file_path, &backup_path)?;
907 self.clean_old_backups(&backup_dir, filename)?;
908 Ok(())
909 }
910
911 fn clean_old_backups(&self, backup_dir: &Path, filename: &str) -> Result<(), std::io::Error> {
913 let mut backups: Vec<_> = fs::read_dir(backup_dir)?
914 .filter_map(|e| e.ok())
915 .filter(|e| e.file_name().to_string_lossy().starts_with(filename))
916 .collect();
917
918 backups.sort_by_key(|e| std::cmp::Reverse(e.metadata().and_then(|m| m.modified()).ok()));
919
920 for backup in backups.into_iter().skip(10) {
921 let _ = fs::remove_file(backup.path());
922 }
923 Ok(())
924 }
925
926 pub fn list_backups(&self, config_type: &str) -> Vec<String> {
928 let config_file = match config_type {
929 "user" => &self.user_config_file,
930 "project" => &self.project_config_file,
931 "local" => &self.local_config_file,
932 _ => return Vec::new(),
933 };
934
935 let backup_dir = match config_file.parent() {
936 Some(p) => p.join(".backups"),
937 None => return Vec::new(),
938 };
939
940 if !backup_dir.exists() {
941 return Vec::new();
942 }
943
944 let filename = config_file
945 .file_stem()
946 .and_then(|s| s.to_str())
947 .unwrap_or("settings");
948
949 fs::read_dir(&backup_dir)
950 .ok()
951 .map(|entries| {
952 let mut backups: Vec<_> = entries
953 .filter_map(|e| e.ok())
954 .filter(|e| e.file_name().to_string_lossy().starts_with(filename))
955 .map(|e| e.file_name().to_string_lossy().to_string())
956 .collect();
957 backups.sort();
958 backups.reverse();
959 backups
960 })
961 .unwrap_or_default()
962 }
963
964 pub fn restore_from_backup(
966 &mut self,
967 backup_filename: &str,
968 config_type: &str,
969 ) -> Result<(), std::io::Error> {
970 let config_file = match config_type {
971 "user" => &self.user_config_file,
972 "project" => &self.project_config_file,
973 "local" => &self.local_config_file,
974 _ => {
975 return Err(std::io::Error::new(
976 std::io::ErrorKind::InvalidInput,
977 "无效的配置类型",
978 ))
979 }
980 };
981
982 let backup_dir = config_file
983 .parent()
984 .map(|p| p.join(".backups"))
985 .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::NotFound, "无效路径"))?;
986
987 let backup_path = backup_dir.join(backup_filename);
988 if !backup_path.exists() {
989 return Err(std::io::Error::new(
990 std::io::ErrorKind::NotFound,
991 "备份文件不存在",
992 ));
993 }
994
995 self.backup_config(config_file)?;
997
998 fs::copy(&backup_path, config_file)?;
1000
1001 self.reload();
1003 Ok(())
1004 }
1005
1006 pub fn reset(&mut self) {
1008 *self.merged_config.write() = self.get_default_config();
1009 let _ = self.save(None);
1010 }
1011
1012 pub fn export(&self, mask_secrets: bool) -> String {
1016 let config = self.merged_config.read().clone();
1017
1018 if mask_secrets {
1019 let masked = self.mask_sensitive_fields(&config);
1020 serde_json::to_string_pretty(&masked).unwrap_or_default()
1021 } else {
1022 serde_json::to_string_pretty(&config).unwrap_or_default()
1023 }
1024 }
1025
1026 fn mask_sensitive_fields(&self, config: &HashMap<String, Value>) -> HashMap<String, Value> {
1028 let sensitive_keys = ["api_key", "secret", "password", "token", "credential"];
1029 let mut masked = config.clone();
1030
1031 for (key, value) in masked.iter_mut() {
1032 let key_lower = key.to_lowercase();
1033 if sensitive_keys.iter().any(|s| key_lower.contains(s)) {
1034 if let Value::String(s) = value {
1035 if s.len() > 8 {
1036 *value = Value::String(format!(
1037 "{}...{}",
1038 s.get(..4).unwrap_or(""),
1039 s.get(s.len().saturating_sub(4)..).unwrap_or("")
1040 ));
1041 } else {
1042 *value = Value::String("****".to_string());
1043 }
1044 }
1045 }
1046 }
1047 masked
1048 }
1049
1050 pub fn import(&mut self, config_json: &str) -> Result<(), String> {
1052 let config: HashMap<String, Value> =
1053 serde_json::from_str(config_json).map_err(|e| format!("JSON 解析失败: {}", e))?;
1054
1055 *self.merged_config.write() = config;
1056 self.save(None).map_err(|e| format!("保存失败: {}", e))?;
1057 Ok(())
1058 }
1059}
1060
1061impl Default for ConfigManager {
1062 fn default() -> Self {
1063 Self::new(ConfigManagerOptions::default())
1064 }
1065}
1066
#[cfg(test)]
mod tests {
    use super::*;
    // `TempDir` is not used by the tests below; the allow silences the unused-import warning.
    #[allow(unused_imports)]
    use tempfile::TempDir;

    #[test]
    fn test_config_manager_default() {
        let manager = ConfigManager::default();
        let model: Option<String> = manager.get("model");
        assert!(model.is_some());
    }

    #[test]
    fn test_get_default_config() {
        let model = ConfigManager::default().get_or("model", String::from("default"));
        assert_eq!(model, "claude-3-5-sonnet");
    }

    #[test]
    fn test_set_and_get() {
        let manager = ConfigManager::default();
        manager.set("test_key", "test_value");
        assert_eq!(
            manager.get::<String>("test_key").as_deref(),
            Some("test_value")
        );
    }

    #[test]
    fn test_config_source_priority() {
        // Each source must outrank the one before it.
        let ascending = [
            ConfigSource::LocalSettings,
            ConfigSource::EnvSettings,
            ConfigSource::FlagSettings,
            ConfigSource::PolicySettings,
        ];
        for pair in ascending.windows(2) {
            assert!(pair[1].priority() > pair[0].priority());
        }
    }

    #[test]
    fn test_parse_env_value() {
        let manager = ConfigManager::default();
        let cases = [
            ("true", Value::Bool(true)),
            ("false", Value::Bool(false)),
            ("42", Value::Number(42.into())),
            ("hello", Value::String("hello".to_string())),
        ];
        for (input, expected) in cases {
            assert_eq!(manager.parse_env_value(input), Some(expected));
        }
    }

    #[test]
    fn test_mask_sensitive_fields() {
        let manager = ConfigManager::default();
        let config = HashMap::from([
            (
                "api_key".to_string(),
                Value::String("sk-1234567890abcdef".to_string()),
            ),
            ("model".to_string(), Value::String("claude-3".to_string())),
        ]);

        let masked = manager.mask_sensitive_fields(&config);
        let api_key = masked.get("api_key").and_then(Value::as_str).unwrap();
        assert!(api_key.contains("..."));
        assert_eq!(
            masked.get("model").and_then(Value::as_str),
            Some("claude-3")
        );
    }
}