1use serde::de::DeserializeOwned;
10use std::path::Path;
11
/// Configuration file syntaxes this module can parse.
///
/// The variant is normally picked from a file extension via [`Format::from_path`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Format {
    /// YAML, selected by the `.yaml` / `.yml` extensions.
    Yaml,
    /// JSON, selected by the `.json` extension.
    Json,
    /// TOML, selected by the `.toml` extension.
    Toml,
}
19
20impl Format {
21 pub fn from_path(path: &Path) -> anyhow::Result<Self> {
22 let ext = path
23 .extension()
24 .and_then(|s| s.to_str())
25 .unwrap_or_default()
26 .to_ascii_lowercase();
27 match ext.as_str() {
28 "yaml" | "yml" => Ok(Self::Yaml),
29 "json" => Ok(Self::Json),
30 "toml" => Ok(Self::Toml),
31 _ => Err(anyhow::anyhow!("unsupported config file type: .{ext}")),
32 }
33 }
34}
35
/// Settings that control how raw configuration text is pre-processed before parsing.
#[derive(Clone, Debug)]
pub struct Options {
    /// When `true`, `$NAME`, `${NAME}` and `$$` are expanded against the process
    /// environment before the text is handed to the parser.
    pub expand_env: bool,
}

impl Default for Options {
    /// Environment expansion is enabled by default.
    fn default() -> Self {
        Options { expand_env: true }
    }
}
48
49pub fn load<T: DeserializeOwned>(path: impl AsRef<Path>) -> anyhow::Result<T> {
51 let mut v = None;
52 load_into_with(path.as_ref(), &mut v, Options::default())?;
53 Ok(v.expect("load_into_with must set value"))
54}
55
56pub fn load_into<T: DeserializeOwned>(path: impl AsRef<Path>, cfg: &mut T) -> anyhow::Result<()> {
58 load_into_with(path.as_ref(), cfg, Options::default())
59}
60
61pub fn load_into_with<T: DeserializeOwned>(
63 path: &Path,
64 cfg: &mut T,
65 opts: Options,
66) -> anyhow::Result<()> {
67 let fmt = Format::from_path(path)?;
68 let raw = std::fs::read_to_string(path)
69 .map_err(|e| anyhow::anyhow!("read config {}: {e}", path.display()))?;
70 let raw = if opts.expand_env {
71 expand_env(&raw)
72 } else {
73 raw
74 };
75 load_from_str(fmt, &raw, cfg).map_err(|e| anyhow::anyhow!("load {}: {e}", path.display()))
76}
77
78pub fn load_from_bytes_into<T: DeserializeOwned>(
80 fmt: Format,
81 bytes: &[u8],
82 cfg: &mut T,
83 opts: Options,
84) -> anyhow::Result<()> {
85 let mut s = String::from_utf8(bytes.to_vec())
86 .map_err(|e| anyhow::anyhow!("config bytes not utf-8: {e}"))?;
87 if opts.expand_env {
88 s = expand_env(&s);
89 }
90 load_from_str(fmt, &s, cfg)
91}
92
93fn load_from_str<T: DeserializeOwned>(fmt: Format, s: &str, cfg: &mut T) -> anyhow::Result<()> {
94 *cfg = match fmt {
95 Format::Yaml => {
96 serde_yaml::from_str::<T>(s).map_err(|e| anyhow::anyhow!("parse yaml: {e}"))?
97 }
98 Format::Json => {
99 serde_json::from_str::<T>(s).map_err(|e| anyhow::anyhow!("parse json: {e}"))?
100 }
101 Format::Toml => toml::from_str::<T>(s).map_err(|e| anyhow::anyhow!("parse toml: {e}"))?,
102 };
103 Ok(())
104}
105
106pub fn must_load<T: DeserializeOwned>(path: impl AsRef<Path>, cfg: &mut T) {
108 if let Err(e) = load_into(path.as_ref(), cfg) {
109 panic!("conf.MustLoad failed: {e}");
110 }
111}
112
/// Expands environment-variable references in `input`.
///
/// Recognised forms:
/// - `${NAME}`: everything up to the next `}` is the variable name;
/// - `$NAME`: the longest run of ASCII letters, digits and `_` after the `$`;
/// - `$$`: a literal `$`.
///
/// Unset variables, and variables whose value is not valid Unicode, expand to the
/// empty string. A `$` that starts none of the forms above is copied through
/// unchanged; this includes an unterminated `${`. All other text is copied verbatim,
/// so multi-byte UTF-8 characters are preserved.
fn expand_env(input: &str) -> String {
    let lookup = |name: &str| std::env::var(name).unwrap_or_default();

    let mut out = String::with_capacity(input.len());
    let mut rest = input;
    // `$` is ASCII, so every split point below is a valid char boundary. Literal
    // text between references is copied as whole `&str` slices. Pushing individual
    // bytes as `char`s would turn each byte of a non-ASCII character into a separate
    // Latin-1 character.
    while let Some(dollar) = rest.find('$') {
        out.push_str(&rest[..dollar]);
        let after = &rest[dollar + 1..];

        if let Some(tail) = after.strip_prefix('$') {
            // `$$` escapes a literal dollar sign.
            out.push('$');
            rest = tail;
        } else if let Some(braced) = after.strip_prefix('{') {
            match braced.find('}') {
                Some(close) => {
                    out.push_str(&lookup(&braced[..close]));
                    rest = &braced[close + 1..];
                }
                None => {
                    // Unterminated `${`: keep the `$`; the `{` is copied as literal text.
                    out.push('$');
                    rest = after;
                }
            }
        } else {
            let name_len = after
                .bytes()
                .take_while(|b| b.is_ascii_alphanumeric() || *b == b'_')
                .count();
            if name_len == 0 {
                // A `$` not followed by a name is kept as-is.
                out.push('$');
            } else {
                out.push_str(&lookup(&after[..name_len]));
            }
            rest = &after[name_len..];
        }
    }
    out.push_str(rest);
    out
}
172
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    #[derive(Debug, Default, serde::Deserialize, PartialEq, Eq)]
    struct TestCfg {
        name: String,
        port: u16,
    }

    /// A uniquely named file in the system temp dir, removed on drop so a failing
    /// assertion does not leave it behind.
    struct TmpFile(PathBuf);

    impl TmpFile {
        fn with_contents(name: &str, ext: &str, contents: &str) -> Self {
            let uniq = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos();
            // pid + nanos keeps names distinct across concurrently running test binaries.
            let path = std::env::temp_dir().join(format!(
                "rz_core_conf_{name}_{}_{uniq}.{ext}",
                std::process::id()
            ));
            std::fs::write(&path, contents).unwrap();
            TmpFile(path)
        }
    }

    impl Drop for TmpFile {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    fn expected() -> TestCfg {
        TestCfg {
            name: "test".into(),
            port: 8080,
        }
    }

    #[test]
    fn load_into_should_parse_yaml() {
        let file = TmpFile::with_contents("ok", "yaml", "name: test\nport: 8080\n");
        let mut cfg = TestCfg::default();
        load_into(&file.0, &mut cfg).unwrap();
        assert_eq!(cfg, expected());
    }

    #[test]
    fn load_into_should_accept_yml_case_insensitively() {
        let file = TmpFile::with_contents("upper", "YML", "name: test\nport: 8080\n");
        let mut cfg = TestCfg::default();
        load_into(&file.0, &mut cfg).unwrap();
        assert_eq!(cfg, expected());
    }

    #[test]
    fn load_into_should_reject_unknown_extension() {
        let mut cfg = TestCfg::default();
        let err = load_into("a.unknown", &mut cfg).unwrap_err();
        assert!(err.to_string().contains("unsupported config file type"));
    }

    #[test]
    fn load_into_should_parse_json() {
        let file = TmpFile::with_contents("json", "json", r#"{"name":"test","port":8080}"#);
        let mut cfg = TestCfg::default();
        load_into(&file.0, &mut cfg).unwrap();
        assert_eq!(cfg, expected());
    }

    #[test]
    fn load_into_should_parse_toml() {
        let file = TmpFile::with_contents("toml", "toml", "name = \"test\"\nport = 8080\n");
        let mut cfg = TestCfg::default();
        load_into(&file.0, &mut cfg).unwrap();
        assert_eq!(cfg, expected());
    }

    #[test]
    fn expand_env_should_work() {
        // SAFETY: `set_var` is unsafe because another thread may access the environment
        // concurrently, and the test harness runs tests in parallel. `RZ_CONF_X` is used
        // by no other test. The other tests here read the environment only through
        // `std::env::var`. NOTE(review): confirm nothing in this test binary reads the
        // environment via libc/FFI while tests run.
        unsafe { std::env::set_var("RZ_CONF_X", "abc") };
        assert_eq!(
            super::expand_env("a=${RZ_CONF_X},b=$RZ_CONF_X"),
            "a=abc,b=abc"
        );
        // SAFETY: same reasoning as the `set_var` above.
        unsafe { std::env::remove_var("RZ_CONF_X") };
        assert_eq!(super::expand_env("x=${RZ_CONF_X}"), "x=");
    }

    #[test]
    fn expand_env_should_keep_escapes_and_unterminated_braces() {
        assert_eq!(super::expand_env("cost=$$5"), "cost=$5");
        assert_eq!(super::expand_env("lit=${open"), "lit=${open");
        assert_eq!(super::expand_env("end$"), "end$");
    }

    #[test]
    #[should_panic(expected = "conf.MustLoad failed")]
    fn must_load_should_panic_on_error() {
        let mut cfg = TestCfg::default();
        must_load("not-exist.yaml", &mut cfg);
    }
}