core/
conf.rs

//! 配置加载工具(对齐 go-zero 的 `core/conf`)。
//!
//! 目标:
//! - 对齐 go-zero 的调用习惯:`conf::must_load(path, &mut cfg)`;
//! - 支持 YAML/JSON/TOML(按文件扩展名自动识别);
//! - 支持环境变量展开(类似 Go 的 `os.ExpandEnv`:`${VAR}` / `$VAR`);
//! - 同时提供非 panic 的 `load` / `load_into` 系列 API。

9use serde::de::DeserializeOwned;
10use std::path::Path;
11
/// Supported configuration file formats, selected from the file extension by
/// [`Format::from_path`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Format {
    /// YAML (`.yaml` / `.yml`).
    Yaml,
    /// JSON (`.json`).
    Json,
    /// TOML (`.toml`).
    Toml,
}
19
20impl Format {
21    pub fn from_path(path: &Path) -> anyhow::Result<Self> {
22        let ext = path
23            .extension()
24            .and_then(|s| s.to_str())
25            .unwrap_or_default()
26            .to_ascii_lowercase();
27        match ext.as_str() {
28            "yaml" | "yml" => Ok(Self::Yaml),
29            "json" => Ok(Self::Json),
30            "toml" => Ok(Self::Toml),
31            _ => Err(anyhow::anyhow!("unsupported config file type: .{ext}")),
32        }
33    }
34}
35
/// Knobs for the `load_*` family (mirrors go-zero's functional-option idea).
#[derive(Debug, Clone)]
pub struct Options {
    /// Expand `${VAR}` / `$VAR` references before parsing. Defaults to `true`.
    pub expand_env: bool,
}

impl Default for Options {
    /// Environment-variable expansion is enabled by default.
    fn default() -> Self {
        Options { expand_env: true }
    }
}
48
49/// 从配置文件读取并反序列化为 `T`。
50pub fn load<T: DeserializeOwned>(path: impl AsRef<Path>) -> anyhow::Result<T> {
51    let mut v = None;
52    load_into_with(path.as_ref(), &mut v, Options::default())?;
53    Ok(v.expect("load_into_with must set value"))
54}
55
56/// 从配置文件读取并写入 `cfg`。
57pub fn load_into<T: DeserializeOwned>(path: impl AsRef<Path>, cfg: &mut T) -> anyhow::Result<()> {
58    load_into_with(path.as_ref(), cfg, Options::default())
59}
60
61/// 带 options 的 load_into。
62pub fn load_into_with<T: DeserializeOwned>(
63    path: &Path,
64    cfg: &mut T,
65    opts: Options,
66) -> anyhow::Result<()> {
67    let fmt = Format::from_path(path)?;
68    let raw = std::fs::read_to_string(path)
69        .map_err(|e| anyhow::anyhow!("read config {}: {e}", path.display()))?;
70    let raw = if opts.expand_env {
71        expand_env(&raw)
72    } else {
73        raw
74    };
75    load_from_str(fmt, &raw, cfg).map_err(|e| anyhow::anyhow!("load {}: {e}", path.display()))
76}
77
78/// 从 bytes 加载(主要对齐 go-zero 的 `LoadFromBytes`)。
79pub fn load_from_bytes_into<T: DeserializeOwned>(
80    fmt: Format,
81    bytes: &[u8],
82    cfg: &mut T,
83    opts: Options,
84) -> anyhow::Result<()> {
85    let mut s = String::from_utf8(bytes.to_vec())
86        .map_err(|e| anyhow::anyhow!("config bytes not utf-8: {e}"))?;
87    if opts.expand_env {
88        s = expand_env(&s);
89    }
90    load_from_str(fmt, &s, cfg)
91}
92
93fn load_from_str<T: DeserializeOwned>(fmt: Format, s: &str, cfg: &mut T) -> anyhow::Result<()> {
94    *cfg = match fmt {
95        Format::Yaml => {
96            serde_yaml::from_str::<T>(s).map_err(|e| anyhow::anyhow!("parse yaml: {e}"))?
97        }
98        Format::Json => {
99            serde_json::from_str::<T>(s).map_err(|e| anyhow::anyhow!("parse json: {e}"))?
100        }
101        Format::Toml => toml::from_str::<T>(s).map_err(|e| anyhow::anyhow!("parse toml: {e}"))?,
102    };
103    Ok(())
104}
105
106/// 对齐 go-zero:加载失败直接 panic(用于快速失败的启动场景)。
107pub fn must_load<T: DeserializeOwned>(path: impl AsRef<Path>, cfg: &mut T) {
108    if let Err(e) = load_into(path.as_ref(), cfg) {
109        panic!("conf.MustLoad failed: {e}");
110    }
111}
112
/// Expands environment-variable references in `input`, in the spirit of Go's
/// `os.ExpandEnv`:
/// - `${VAR}` and `$VAR` are replaced by the variable's value. Unset variables,
///   and values that are not valid Unicode, expand to the empty string.
/// - `$VAR` names are the longest run of ASCII letters, digits and `_`.
/// - `$$` is an escape that produces a literal `$`.
/// - A `$` not followed by a name, or a `${` with no closing `}`, is kept as-is.
///
/// Text outside references is copied verbatim, so non-ASCII (multi-byte UTF-8)
/// content is preserved.
fn expand_env(input: &str) -> String {
    // `std::env::var` returns `Err` for unset or non-Unicode values -> "".
    let lookup = |key: &str| std::env::var(key).unwrap_or_default();

    let mut out = String::with_capacity(input.len());
    let mut rest = input;
    while let Some(dollar) = rest.find('$') {
        // Copy the literal run as a `&str` slice. Pushing individual bytes as
        // `char`s would split multi-byte UTF-8 sequences into mojibake.
        out.push_str(&rest[..dollar]);
        let after = &rest[dollar + 1..];

        if let Some(tail) = after.strip_prefix('$') {
            // "$$" -> "$"
            out.push('$');
            rest = tail;
        } else if let Some(braced) = after.strip_prefix('{') {
            match braced.find('}') {
                // ${VAR}
                Some(close) => {
                    out.push_str(&lookup(&braced[..close]));
                    rest = &braced[close + 1..];
                }
                // Unmatched "{": keep the '$' literally and resume at the '{'.
                None => {
                    out.push('$');
                    rest = after;
                }
            }
        } else {
            // $VAR: the name is all ASCII, so its byte length is also a valid
            // char boundary for slicing.
            let name_len = after
                .bytes()
                .take_while(|b| b.is_ascii_alphanumeric() || *b == b'_')
                .count();
            if name_len == 0 {
                out.push('$');
            } else {
                out.push_str(&lookup(&after[..name_len]));
            }
            rest = &after[name_len..];
        }
    }
    out.push_str(rest);
    out
}
172
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    #[derive(Debug, Default, serde::Deserialize, PartialEq, Eq)]
    struct TestCfg {
        name: String,
        port: u16,
    }

    /// Returns a fresh path in the OS temp dir, made unique by the current
    /// time in nanoseconds.
    fn tmp_file(name: &str, ext: &str) -> PathBuf {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        std::env::temp_dir().join(format!("rz_core_conf_{name}_{nanos}.{ext}"))
    }

    /// The config encoded by every happy-path fixture below.
    fn expected() -> TestCfg {
        TestCfg {
            name: "test".into(),
            port: 8080,
        }
    }

    /// Writes `body` to a temp file with extension `ext`, loads it through
    /// `load_into`, deletes the file and returns the parsed config.
    fn load_fixture(tag: &str, ext: &str, body: &str) -> TestCfg {
        let path = tmp_file(tag, ext);
        std::fs::write(&path, body).unwrap();
        let mut cfg = TestCfg::default();
        load_into(&path, &mut cfg).unwrap();
        let _ = std::fs::remove_file(&path);
        cfg
    }

    #[test]
    fn load_into_should_parse_yaml() {
        assert_eq!(load_fixture("ok", "yaml", "name: test\nport: 8080\n"), expected());
    }

    #[test]
    fn load_into_should_reject_non_yaml() {
        let mut cfg = TestCfg::default();
        let err = load_into("a.unknown", &mut cfg).unwrap_err();
        assert!(err.to_string().contains("unsupported config file type"));
    }

    #[test]
    fn load_into_should_parse_json() {
        assert_eq!(load_fixture("json", "json", r#"{"name":"test","port":8080}"#), expected());
    }

    #[test]
    fn load_into_should_parse_toml() {
        assert_eq!(load_fixture("toml", "toml", "name = \"test\"\nport = 8080\n"), expected());
    }

    #[test]
    fn expand_env_should_work() {
        const KEY: &str = "RZ_CONF_X";
        unsafe { std::env::set_var(KEY, "abc") };
        assert_eq!(super::expand_env("a=${RZ_CONF_X},b=$RZ_CONF_X"), "a=abc,b=abc");
        unsafe { std::env::remove_var(KEY) };
        assert_eq!(super::expand_env("x=${RZ_CONF_X}"), "x=");
    }

    #[test]
    #[should_panic]
    fn must_load_should_panic_on_error() {
        must_load("not-exist.yaml", &mut TestCfg::default());
    }
}