hehe_core/config/loader.rs

1use super::types::Config;
2use crate::error::{Error, Result};
3use std::path::Path;
4
5impl Config {
6    pub fn load_from_file(path: impl AsRef<Path>) -> Result<Self> {
7        let content = std::fs::read_to_string(path.as_ref())?;
8        Self::from_toml(&content)
9    }
10
11    pub fn from_toml(content: &str) -> Result<Self> {
12        toml::from_str(content).map_err(|e| Error::Config(format!("Failed to parse config: {}", e)))
13    }
14
15    pub fn from_json(content: &str) -> Result<Self> {
16        serde_json::from_str(content).map_err(Error::Json)
17    }
18
19    pub fn load_default() -> Result<Self> {
20        let paths = [
21            "./hehe.toml",
22            "~/.hehe/config.toml",
23            "~/.config/hehe/config.toml",
24            "/etc/hehe/config.toml",
25        ];
26
27        for path in &paths {
28            let expanded = shellexpand::tilde(path);
29            let path = Path::new(expanded.as_ref());
30            if path.exists() {
31                return Self::load_from_file(path);
32            }
33        }
34
35        Ok(Config::default())
36    }
37
38    pub fn merge_env(mut self) -> Self {
39        if let Ok(level) = std::env::var("HEHE_LOG_LEVEL") {
40            self.general.log_level = match level.to_lowercase().as_str() {
41                "trace" => super::types::LogLevel::Trace,
42                "debug" => super::types::LogLevel::Debug,
43                "info" => super::types::LogLevel::Info,
44                "warn" => super::types::LogLevel::Warn,
45                "error" => super::types::LogLevel::Error,
46                _ => self.general.log_level,
47            };
48        }
49
50        if let Ok(dir) = std::env::var("HEHE_DATA_DIR") {
51            self.general.data_dir = dir.into();
52        }
53
54        if let Ok(provider) = std::env::var("HEHE_DEFAULT_PROVIDER") {
55            self.llm.default_provider = Some(provider);
56        }
57
58        self
59    }
60
61    pub fn to_toml(&self) -> Result<String> {
62        toml::to_string_pretty(self)
63            .map_err(|e| Error::Config(format!("Failed to serialize config: {}", e)))
64    }
65
66    pub fn to_json(&self) -> Result<String> {
67        serde_json::to_string_pretty(self).map_err(Error::Json)
68    }
69
70    pub fn save_to_file(&self, path: impl AsRef<Path>) -> Result<()> {
71        let content = self.to_toml()?;
72        std::fs::write(path.as_ref(), content)?;
73        Ok(())
74    }
75
76    pub fn data_dir(&self) -> std::path::PathBuf {
77        let expanded = shellexpand::tilde(self.general.data_dir.as_str());
78        std::path::PathBuf::from(expanded.as_ref())
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::super::types::LogLevel;
    use super::*;

    // The built-in defaults log at `Info`.
    #[test]
    fn test_default_config() {
        let defaults = Config::default();
        assert_eq!(defaults.general.log_level, LogLevel::Info);
    }

    // A small TOML document sets both the log level and the default provider.
    #[test]
    fn test_config_from_toml() {
        let source = r#"
            [general]
            log_level = "debug"
            
            [llm]
            default_provider = "openai"
            
            [llm.providers.openai]
            provider_type = "openai"
            model = "gpt-4"
        "#;

        let parsed = Config::from_toml(source).unwrap();
        assert_eq!(parsed.general.log_level, LogLevel::Debug);
        assert_eq!(parsed.llm.default_provider.as_deref(), Some("openai"));
    }
}