1use std::str::FromStr;
28
29use anyhow::anyhow;
30use anyhow::Result;
31use config::Config;
32use config::Environment as EnvironmentVariables;
33use config::File;
34use config::FileFormat;
35use config::FileSourceFile;
36use serde::Deserialize;
37use strum::EnumString;
38
39pub fn load_config<'de, T: Deserialize<'de>>(environment: Environment) -> Result<T> {
61 let base_config_file = File::with_name("config/base").required(true);
62 let env_config_file =
63 File::with_name(&format!("config/{}", environment.to_string())).required(true);
64
65 let custom_env_vars = EnvironmentVariables::with_prefix("app").separator("__");
66
67 load_custom_config(base_config_file, env_config_file, custom_env_vars)
68}
69
70pub fn load_custom_config<'de, T: Deserialize<'de>>(
94 base_config_file: File<FileSourceFile, FileFormat>,
95 env_config_file: File<FileSourceFile, FileFormat>,
96 custom_env_vars: EnvironmentVariables,
97) -> Result<T> {
98 Ok(Config::builder()
99 .add_source(base_config_file)
100 .add_source(env_config_file)
101 .add_source(custom_env_vars)
102 .build()?
103 .try_deserialize()
104 .map_err(|err| {
105 anyhow!(
106 "Unable to deserialize into config with type {} with error: {}",
107 std::any::type_name::<T>(),
108 err
109 )
110 })?)
111}
112
113#[derive(PartialEq, Debug, EnumString, strum::Display)]
117pub enum Environment {
118 #[strum(serialize = "local")]
120 Local,
121
122 #[strum(serialize = "test")]
124 Test,
125
126 #[strum(serialize = "development")]
128 Development,
129
130 #[strum(serialize = "production")]
132 Production,
133}
134
135impl Environment {
136 pub fn from_env() -> Result<Self> {
147 Self::from_custom_env("APP_ENVIRONMENT")
148 }
149
150 pub fn from_custom_env(key: &str) -> Result<Self> {
161 std::env::var(key)
162 .map(|environment_string| {
163 Environment::from_str(&environment_string)
164 .map_err(|_| anyhow!("Unknown environment: {environment_string}"))
165 })
166 .unwrap_or(Ok(Environment::default()))
167 }
168}
169
170impl Default for Environment {
171 fn default() -> Self {
172 if cfg!(test) {
173 Environment::Test
174 } else {
175 Environment::Local
176 }
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    // Target shape for the `config/*` fixture files used below
    // (assumes those fixtures exist relative to the test working directory).
    #[derive(Clone, Debug, Deserialize, PartialEq)]
    struct MyConfig {
        log_level: String,
        db: MyDbConfig,
    }

    #[derive(Clone, Debug, Deserialize, PartialEq)]
    struct MyDbConfig {
        host: String,
        user: String,
        password: String,
    }

    /// Sets an environment variable and removes it again when dropped.
    ///
    /// Removal also runs while a panic unwinds. A failing assertion therefore
    /// cannot leak the variable into other tests in the same process.
    struct EnvVarGuard {
        key: &'static str,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            std::env::set_var(key, value);
            Self { key }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            std::env::remove_var(self.key);
        }
    }

    #[test]
    fn test_load_config_success() {
        // Bound to a named variable so it lives until the end of the test;
        // binding to `_` would drop it (and unset the variable) immediately.
        let _password = EnvVarGuard::set("APP__DB__PASSWORD", "supersecurepassword");

        let expected = MyConfig {
            log_level: "info".to_string(),
            db: MyDbConfig {
                host: "localhost".to_string(),
                user: "username".to_string(),
                password: "supersecurepassword".to_string(),
            },
        };

        let actual = load_custom_config::<MyConfig>(
            File::with_name("config/base").required(true),
            File::with_name("config/development").required(true),
            EnvironmentVariables::with_prefix("app").separator("__"),
        )
        .unwrap();

        assert_eq!(expected, actual);

        let actual = load_config::<MyConfig>(Environment::Development).unwrap();

        assert_eq!(expected, actual);

        let actual = load_config::<MyConfig>(Environment::Test).unwrap();

        assert_eq!(expected, actual);

        let actual = load_config::<MyConfig>(Environment::Production).unwrap();

        assert_eq!(expected, actual);
    }

    #[test]
    #[should_panic(expected = "configuration file \"config/staging\" not found")]
    fn test_load_config_file_not_found() {
        load_custom_config::<MyConfig>(
            File::with_name("config/base").required(true),
            File::with_name("config/staging").required(true),
            EnvironmentVariables::with_prefix("app").separator("__"),
        )
        .unwrap();
    }

    #[test]
    #[should_panic(
        expected = "Unable to deserialize into config with type avantis_utils::config::tests::MyConfig with error: missing field"
    )]
    fn test_load_config_missing_fields() {
        load_custom_config::<MyConfig>(
            File::with_name("config/base").required(true),
            File::with_name("config/base").required(true),
            EnvironmentVariables::with_prefix("app").separator("__"),
        )
        .unwrap();
    }

    #[test]
    fn test_environment_from_env() {
        // With the variable unset, the default applies; it is `Test` in unit tests.
        assert_eq!(Environment::Test, Environment::from_env().unwrap());

        assert_eq!(
            Environment::Test,
            Environment::from_custom_env("APP_ENVIRONMENT").unwrap()
        );

        let _environment = EnvVarGuard::set("APP_ENVIRONMENT", "local");

        assert_eq!(Environment::Local, Environment::from_env().unwrap());

        assert_eq!(
            Environment::Local,
            Environment::from_custom_env("APP_ENVIRONMENT").unwrap()
        );
    }

    #[test]
    #[should_panic(expected = "Unknown environment: staging")]
    fn test_environment_from_unknown_env() {
        let _invalid = EnvVarGuard::set("APP_ENVIRONMENT_INVALID", "staging");

        // The guard unsets the variable while this panic unwinds.
        Environment::from_custom_env("APP_ENVIRONMENT_INVALID").unwrap();
    }
}