1use std::str::FromStr;
28
29use anyhow::anyhow;
30use anyhow::Result;
31use config_rs::Config;
32use config_rs::Environment as EnvironmentVariables;
33use config_rs::File;
34use config_rs::FileFormat;
35use config_rs::FileSourceFile;
36use serde::Deserialize;
37use strum::EnumString;
38
39pub fn load_config<'de, T: Deserialize<'de>>(environment: Environment) -> Result<T> {
61 let base_config_file = File::with_name("config/base").required(true);
62 let env_config_file = File::with_name(&format!("config/{}", environment)).required(true);
63
64 let custom_env_vars = EnvironmentVariables::with_prefix("app")
65 .prefix_separator("_")
66 .separator("__");
67
68 load_custom_config(base_config_file, env_config_file, custom_env_vars)
69}
70
71pub fn load_config_by_path<'de, T: Deserialize<'de>>(
93 environment: Environment,
94 path: &str,
95) -> Result<T> {
96 let base_config_file = File::with_name(&format!("{}/base", path)).required(true);
97 let env_config_file = File::with_name(&format!("{}/{}", path, environment)).required(true);
98
99 let custom_env_vars = EnvironmentVariables::with_prefix("app")
100 .prefix_separator("_")
101 .separator("__");
102
103 load_custom_config(base_config_file, env_config_file, custom_env_vars)
104}
105
106pub fn load_custom_config<'de, T: Deserialize<'de>>(
130 base_config_file: File<FileSourceFile, FileFormat>,
131 env_config_file: File<FileSourceFile, FileFormat>,
132 custom_env_vars: EnvironmentVariables,
133) -> Result<T> {
134 Config::builder()
135 .add_source(base_config_file)
136 .add_source(env_config_file)
137 .add_source(custom_env_vars)
138 .build()?
139 .try_deserialize()
140 .map_err(|err| {
141 anyhow!(
142 "Unable to deserialize into config with type {} with error: {}",
143 std::any::type_name::<T>(),
144 err
145 )
146 })
147}
148
149#[derive(PartialEq, Eq, Debug, EnumString, strum::Display)]
153pub enum Environment {
154 #[strum(serialize = "local")]
156 Local,
157
158 #[strum(serialize = "test")]
160 Test,
161
162 #[strum(serialize = "develop")]
164 Develop,
165
166 #[strum(serialize = "production")]
168 Production,
169}
170
171impl Environment {
172 pub fn from_env() -> Result<Self> {
183 Self::from_custom_env("APP_ENVIRONMENT")
184 }
185
186 pub fn from_custom_env(key: &str) -> Result<Self> {
197 std::env::var(key)
198 .map(|environment_string| {
199 Environment::from_str(&environment_string)
200 .map_err(|_| anyhow!("Unknown environment: {environment_string}"))
201 })
202 .unwrap_or_else(|_| Ok(Environment::default()))
203 }
204}
205
206impl Default for Environment {
207 fn default() -> Self {
208 if cfg!(test) {
209 Environment::Test
210 } else {
211 Environment::Local
212 }
213 }
214}
215
#[cfg(test)]
mod tests {
    use serial_test::serial;

    use super::*;

    #[derive(Clone, Debug, Deserialize, PartialEq)]
    struct MyConfig {
        log_level: String,
        db: MyDbConfig,
    }

    #[derive(Clone, Debug, Deserialize, PartialEq)]
    struct MyDbConfig {
        host: String,
        user: String,
        password: String,
        db_name: String,
        max_connections: u32,
    }

    /// Sets an environment variable and removes it again when dropped.
    ///
    /// `Drop` also runs while a failing assertion unwinds. The variable therefore
    /// cannot leak into tests that run later in the same process.
    struct EnvVarGuard {
        key: &'static str,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            std::env::set_var(key, value);
            Self { key }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            std::env::remove_var(self.key);
        }
    }

    /// The configuration the fixture directories are expected to produce once
    /// `APP_DB__PASSWORD` is set to `supersecurepassword`.
    fn expected_config() -> MyConfig {
        MyConfig {
            log_level: "info".to_string(),
            db: MyDbConfig {
                host: "localhost".to_string(),
                user: "username".to_string(),
                password: "supersecurepassword".to_string(),
                db_name: "my_db".to_string(),
                max_connections: 30,
            },
        }
    }

    /// The `APP_`-prefixed, `__`-nested environment source used by the loaders.
    fn app_env_vars() -> EnvironmentVariables {
        EnvironmentVariables::with_prefix("app")
            .prefix_separator("_")
            .separator("__")
    }

    #[test]
    #[serial]
    fn test_load_config_success() {
        let _password = EnvVarGuard::set("APP_DB__PASSWORD", "supersecurepassword");
        let expected = expected_config();

        let actual = load_custom_config::<MyConfig>(
            File::with_name("config/base").required(true),
            File::with_name("config/develop").required(true),
            app_env_vars(),
        )
        .unwrap();
        assert_eq!(expected, actual);

        for environment in [Environment::Develop, Environment::Test, Environment::Production] {
            let label = environment.to_string();
            let actual = load_config::<MyConfig>(environment).unwrap();
            assert_eq!(expected, actual, "environment: {label}");
        }
    }

    #[test]
    #[serial]
    fn test_load_config_by_path_success() {
        let _password = EnvVarGuard::set("APP_DB__PASSWORD", "supersecurepassword");
        let expected = expected_config();

        let actual = load_custom_config::<MyConfig>(
            File::with_name("config-workspace/config/base").required(true),
            File::with_name("config-workspace/config/develop").required(true),
            app_env_vars(),
        )
        .unwrap();
        assert_eq!(expected, actual);

        for environment in [Environment::Develop, Environment::Test, Environment::Production] {
            let label = environment.to_string();
            let actual =
                load_config_by_path::<MyConfig>(environment, "config-workspace/config").unwrap();
            assert_eq!(expected, actual, "environment: {label}");
        }
    }

    #[test]
    #[serial]
    #[should_panic(expected = "configuration file \"config/staging\" not found")]
    fn test_load_config_file_not_found() {
        load_custom_config::<MyConfig>(
            File::with_name("config/base").required(true),
            File::with_name("config/staging").required(true),
            EnvironmentVariables::with_prefix("app").separator("__"),
        )
        .unwrap();
    }

    #[test]
    #[serial]
    #[should_panic(
        expected = "Unable to deserialize into config with type avantis_utils::config::tests::MyConfig with error: missing field"
    )]
    fn test_load_config_missing_fields() {
        load_custom_config::<MyConfig>(
            File::with_name("config/base").required(true),
            File::with_name("config/base").required(true),
            EnvironmentVariables::with_prefix("app").separator("__"),
        )
        .unwrap();
    }

    #[test]
    #[serial]
    fn test_environment_from_env() {
        // Start from a known state. An APP_ENVIRONMENT inherited from the shell or CI
        // would otherwise break the default-value assertions below.
        std::env::remove_var("APP_ENVIRONMENT");

        assert_eq!(Environment::Test, Environment::from_env().unwrap());
        assert_eq!(
            Environment::Test,
            Environment::from_custom_env("APP_ENVIRONMENT").unwrap()
        );

        let _environment = EnvVarGuard::set("APP_ENVIRONMENT", "local");

        assert_eq!(Environment::Local, Environment::from_env().unwrap());
        assert_eq!(
            Environment::Local,
            Environment::from_custom_env("APP_ENVIRONMENT").unwrap()
        );
    }

    #[test]
    #[serial]
    #[should_panic(expected = "Unknown environment: staging")]
    fn test_environment_from_unknown_env() {
        // The guard removes the variable even though this test panics by design.
        let _invalid = EnvVarGuard::set("APP_ENVIRONMENT_INVALID", "staging");

        Environment::from_custom_env("APP_ENVIRONMENT_INVALID").unwrap();
    }
}