// hehe_core/config/loader.rs
use std::path::Path;

use crate::error::{Error, Result};
use super::types::{Config, LogLevel};

5impl Config {
6 pub fn load_from_file(path: impl AsRef<Path>) -> Result<Self> {
7 let content = std::fs::read_to_string(path.as_ref())?;
8 Self::from_toml(&content)
9 }
10
11 pub fn from_toml(content: &str) -> Result<Self> {
12 toml::from_str(content).map_err(|e| Error::Config(format!("Failed to parse config: {}", e)))
13 }
14
15 pub fn from_json(content: &str) -> Result<Self> {
16 serde_json::from_str(content).map_err(Error::Json)
17 }
18
19 pub fn load_default() -> Result<Self> {
20 let paths = [
21 "./hehe.toml",
22 "~/.hehe/config.toml",
23 "~/.config/hehe/config.toml",
24 "/etc/hehe/config.toml",
25 ];
26
27 for path in &paths {
28 let expanded = shellexpand::tilde(path);
29 let path = Path::new(expanded.as_ref());
30 if path.exists() {
31 return Self::load_from_file(path);
32 }
33 }
34
35 Ok(Config::default())
36 }
37
38 pub fn merge_env(mut self) -> Self {
39 if let Ok(level) = std::env::var("HEHE_LOG_LEVEL") {
40 self.general.log_level = match level.to_lowercase().as_str() {
41 "trace" => super::types::LogLevel::Trace,
42 "debug" => super::types::LogLevel::Debug,
43 "info" => super::types::LogLevel::Info,
44 "warn" => super::types::LogLevel::Warn,
45 "error" => super::types::LogLevel::Error,
46 _ => self.general.log_level,
47 };
48 }
49
50 if let Ok(dir) = std::env::var("HEHE_DATA_DIR") {
51 self.general.data_dir = dir.into();
52 }
53
54 if let Ok(provider) = std::env::var("HEHE_DEFAULT_PROVIDER") {
55 self.llm.default_provider = Some(provider);
56 }
57
58 self
59 }
60
61 pub fn to_toml(&self) -> Result<String> {
62 toml::to_string_pretty(self)
63 .map_err(|e| Error::Config(format!("Failed to serialize config: {}", e)))
64 }
65
66 pub fn to_json(&self) -> Result<String> {
67 serde_json::to_string_pretty(self).map_err(Error::Json)
68 }
69
70 pub fn save_to_file(&self, path: impl AsRef<Path>) -> Result<()> {
71 let content = self.to_toml()?;
72 std::fs::write(path.as_ref(), content)?;
73 Ok(())
74 }
75
76 pub fn data_dir(&self) -> std::path::PathBuf {
77 let expanded = shellexpand::tilde(self.general.data_dir.as_str());
78 std::path::PathBuf::from(expanded.as_ref())
79 }
80}
#[cfg(test)]
mod tests {
    use super::super::types::LogLevel;
    use super::*;

    /// A default-constructed config logs at `Info`.
    #[test]
    fn test_default_config() {
        assert_eq!(Config::default().general.log_level, LogLevel::Info);
    }

    /// TOML input overrides the log level and picks the default LLM provider.
    #[test]
    fn test_config_from_toml() {
        let source = r#"
[general]
log_level = "debug"

[llm]
default_provider = "openai"

[llm.providers.openai]
provider_type = "openai"
model = "gpt-4"
"#;

        let parsed = Config::from_toml(source).unwrap();
        assert_eq!(parsed.general.log_level, LogLevel::Debug);
        assert_eq!(parsed.llm.default_provider.as_deref(), Some("openai"));
    }
}