Skip to main content

nestforge_config/
lib.rs

1use std::collections::HashMap;
2use std::env;
3use std::path::Path;
4use thiserror::Error;
5
6#[derive(Debug, Error)]
7pub enum ConfigError {
8    #[error("Failed to read env file `{path}`: {source}")]
9    ReadEnvFile {
10        path: String,
11        #[source]
12        source: dotenvy::Error,
13    },
14    #[error("Missing config key: {key}")]
15    MissingKey { key: String },
16}
17
/// Builder that records the names of environment keys a caller requires.
///
/// Keys are kept in first-registration order without duplicates. This type
/// only collects names; read them back with [`EnvSchema::requirements`].
#[derive(Clone, Debug, Default)]
pub struct EnvSchema {
    requirements: Vec<String>,
}

impl EnvSchema {
    /// Creates an empty schema.
    pub fn new() -> Self {
        Self::default()
    }

    /// Marks `key` as required and returns `self` for chaining.
    ///
    /// Registering the same key more than once has no additional effect.
    pub fn required(&mut self, key: &str) -> &mut Self {
        let already_present = self
            .requirements
            .iter()
            .any(|existing| existing.as_str() == key);
        if !already_present {
            self.requirements.push(key.to_string());
        }
        self
    }

    /// Returns the required keys in the order they were first registered.
    pub fn requirements(&self) -> &[String] {
        &self.requirements
    }
}
33
/// Snapshot of string key/value pairs with keyed lookup.
///
/// [`EnvStore::new`] / [`Default`] capture the current process environment;
/// an `EnvStore` can also be converted from a loaded `ConfigService`.
#[derive(Clone, Debug)]
pub struct EnvStore {
    values: HashMap<String, String>,
}

impl EnvStore {
    /// Captures the current process environment (same as [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the value stored under `key`, if any.
    pub fn get(&self, key: &str) -> Option<&str> {
        self.values.get(key).map(String::as_str)
    }
}

impl Default for EnvStore {
    /// Snapshots the process environment.
    ///
    /// Entries whose key or value is not valid Unicode are skipped.
    /// `std::env::vars` panics on such entries, so a single odd variable
    /// would otherwise abort the whole snapshot.
    fn default() -> Self {
        let values = env::vars_os()
            .filter_map(|(key, value)| {
                Some((key.into_string().ok()?, value.into_string().ok()?))
            })
            .collect();
        Self { values }
    }
}
56
57impl From<ConfigService> for EnvStore {
58    fn from(config: ConfigService) -> Self {
59        Self {
60            values: config.values,
61        }
62    }
63}
64
/// A problem associated with one environment key.
///
/// Derives `PartialEq`/`Eq` so collections of issues can be compared directly,
/// for example in tests.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct EnvValidationIssue {
    /// Name of the environment key the issue refers to.
    pub key: String,
    /// Human-readable description of the problem.
    pub message: String,
}
70
71pub trait FromEnv: Sized {
72    fn from_env(env: &EnvStore) -> Result<Self, ConfigError>;
73}
74
/// Flat string key/value configuration, produced by [`ConfigService::load`]
/// or [`ConfigService::load_with_options`].
///
/// [`ConfigService::new`] and [`Default`] both yield an empty service.
#[derive(Clone, Debug, Default)]
pub struct ConfigService {
    values: HashMap<String, String>,
}
79
80impl ConfigService {
81    pub fn new() -> Self {
82        Self::default()
83    }
84
85    pub fn load() -> Result<Self, ConfigError> {
86        Self::load_with_options(&ConfigOptions::default())
87    }
88
89    pub fn load_with_options(options: &ConfigOptions) -> Result<Self, ConfigError> {
90        let path_ref = Path::new(&options.env_file_path);
91        let mut values = if options.include_process_env {
92            env::vars().collect::<HashMap<_, _>>()
93        } else {
94            HashMap::new()
95        };
96
97        if path_ref.exists() {
98            dotenvy::from_path_iter(path_ref)
99                .map_err(|source| ConfigError::ReadEnvFile {
100                    path: path_ref.display().to_string(),
101                    source,
102                })?
103                .for_each(|result| {
104                    if let Ok((key, value)) = result {
105                        values.insert(key, value);
106                    }
107                });
108        }
109
110        Ok(Self { values })
111    }
112
113    pub fn get(&self, key: &str) -> Option<&str> {
114        self.values.get(key).map(String::as_str)
115    }
116
117    pub fn get_string(&self, key: &str) -> String {
118        self.get(key).map(|v| v.to_string()).unwrap_or_default()
119    }
120
121    pub fn get_string_or(&self, key: &str, default: &str) -> String {
122        self.get(key)
123            .map(|v| v.to_string())
124            .unwrap_or_else(|| default.to_string())
125    }
126
127    pub fn get_i32(&self, key: &str) -> i32 {
128        self.get(key).and_then(|v| v.parse().ok()).unwrap_or(0)
129    }
130
131    pub fn get_i32_or(&self, key: &str, default: i32) -> i32 {
132        self.get(key)
133            .and_then(|v| v.parse().ok())
134            .unwrap_or(default)
135    }
136
137    pub fn get_u16(&self, key: &str) -> u16 {
138        self.get(key).and_then(|v| v.parse().ok()).unwrap_or(0)
139    }
140
141    pub fn get_u16_or(&self, key: &str, default: u16) -> u16 {
142        self.get(key)
143            .and_then(|v| v.parse().ok())
144            .unwrap_or(default)
145    }
146
147    pub fn get_u32(&self, key: &str) -> u32 {
148        self.get(key).and_then(|v| v.parse().ok()).unwrap_or(0)
149    }
150
151    pub fn get_u32_or(&self, key: &str, default: u32) -> u32 {
152        self.get(key)
153            .and_then(|v| v.parse().ok())
154            .unwrap_or(default)
155    }
156
157    pub fn get_bool(&self, key: &str) -> bool {
158        self.get(key)
159            .map(|v| v == "true" || v == "1" || v == "yes")
160            .unwrap_or(false)
161    }
162
163    pub fn get_bool_or(&self, key: &str, default: bool) -> bool {
164        self.get(key)
165            .map(|v| v == "true" || v == "1" || v == "yes")
166            .unwrap_or(default)
167    }
168
169    pub fn get_usize(&self, key: &str) -> usize {
170        self.get(key).and_then(|v| v.parse().ok()).unwrap_or(0)
171    }
172
173    pub fn get_usize_or(&self, key: &str, default: usize) -> usize {
174        self.get(key)
175            .and_then(|v| v.parse().ok())
176            .unwrap_or(default)
177    }
178
179    pub fn get_f64(&self, key: &str) -> f64 {
180        self.get(key).and_then(|v| v.parse().ok()).unwrap_or(0.0)
181    }
182
183    pub fn get_f64_or(&self, key: &str, default: f64) -> f64 {
184        self.get(key)
185            .and_then(|v| v.parse().ok())
186            .unwrap_or(default)
187    }
188
189    pub fn get_isize(&self, key: &str) -> isize {
190        self.get(key).and_then(|v| v.parse().ok()).unwrap_or(0)
191    }
192
193    pub fn get_isize_or(&self, key: &str, default: isize) -> isize {
194        self.get(key)
195            .and_then(|v| v.parse().ok())
196            .unwrap_or(default)
197    }
198
199    pub fn get_i64(&self, key: &str) -> i64 {
200        self.get(key).and_then(|v| v.parse().ok()).unwrap_or(0)
201    }
202
203    pub fn get_i64_or(&self, key: &str, default: i64) -> i64 {
204        self.get(key)
205            .and_then(|v| v.parse().ok())
206            .unwrap_or(default)
207    }
208
209    pub fn get_u64(&self, key: &str) -> u64 {
210        self.get(key).and_then(|v| v.parse().ok()).unwrap_or(0)
211    }
212
213    pub fn get_u64_or(&self, key: &str, default: u64) -> u64 {
214        self.get(key)
215            .and_then(|v| v.parse().ok())
216            .unwrap_or(default)
217    }
218
219    pub fn has(&self, key: &str) -> bool {
220        self.values.contains_key(key)
221    }
222}
223
/// Options controlling how `ConfigService::load_with_options` gathers values.
///
/// Defaults: read `.env` (relative to the current working directory) and
/// include the process environment.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConfigOptions {
    /// Path of the dotenv-style file to read.
    pub env_file_path: String,
    /// Whether to seed the configuration with the process environment.
    pub include_process_env: bool,
}

impl Default for ConfigOptions {
    fn default() -> Self {
        Self {
            env_file_path: ".env".to_string(),
            include_process_env: true,
        }
    }
}

impl ConfigOptions {
    /// Returns the default options (see the [`Default`] impl).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the env file path. Consumes and returns `self` for chaining.
    #[must_use]
    pub fn env_file(mut self, path: impl Into<String>) -> Self {
        self.env_file_path = path.into();
        self
    }

    /// Stops the process environment from seeding the configuration, so only
    /// the env file contributes values. Consumes and returns `self`.
    #[must_use]
    pub fn without_process_env(mut self) -> Self {
        self.include_process_env = false;
        self
    }
}
254
255pub struct ConfigModule;
256
257impl ConfigModule {
258    pub fn for_root() -> ConfigOptions {
259        ConfigOptions::new()
260    }
261
262    pub fn for_root_with_options(options: ConfigOptions) -> ConfigService {
263        ConfigService::load_with_options(&options).expect("Failed to load configuration")
264    }
265
266    pub fn for_feature() -> ConfigOptions {
267        ConfigOptions::new()
268    }
269}
270
271pub fn load_config() -> ConfigService {
272    ConfigModule::for_root_with_options(ConfigModule::for_root())
273}
274
/// Zero-sized marker naming a configuration type `T` without storing one.
///
/// Uses `PhantomData<fn() -> T>` rather than `PhantomData<T>` because a
/// `Config<T>` never owns a `T`. This keeps the marker `Send`, `Sync` and
/// `Copy`, and free of drop-check obligations, whatever `T` is. It also stays
/// covariant in `T`.
pub struct Config<T> {
    _phantom: std::marker::PhantomData<fn() -> T>,
}

impl<T> Config<T> {
    /// Creates the marker.
    pub fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<T> Default for Config<T> {
    fn default() -> Self {
        Self::new()
    }
}

// Hand-written impls: `#[derive]` would add needless `T: Clone` / `T: Debug`
// bounds even though no `T` is stored.
impl<T> Clone for Config<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for Config<T> {}

impl<T> std::fmt::Debug for Config<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Config<{}>", std::any::type_name::<T>())
    }
}
292
/// Registers a named configuration `factory`.
///
/// The factory is not invoked here. Each [`ConfigRegistration::load`] call
/// runs it again and returns a fresh value.
pub fn register_config<T: Send + Sync + 'static>(
    name: &'static str,
    factory: fn() -> T,
) -> ConfigRegistration<T> {
    ConfigRegistration { name, factory }
}

/// A named configuration factory created by [`register_config`].
///
/// Trait bounds are enforced by [`register_config`] rather than on the struct,
/// so naming `ConfigRegistration<T>` elsewhere requires no extra bounds.
pub struct ConfigRegistration<T> {
    name: &'static str,
    // `fn() -> T` already ties the struct to `T`; no PhantomData is needed.
    factory: fn() -> T,
}

impl<T> ConfigRegistration<T> {
    /// Returns the name this configuration was registered under.
    pub fn name(&self) -> &'static str {
        self.name
    }

    /// Runs the factory and returns the value it builds.
    pub fn load(&self) -> T {
        (self.factory)()
    }
}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Options that read only the process environment.
    ///
    /// Pointing at a file that does not exist keeps a developer's local `.env`
    /// from leaking into (and overriding) the values under test.
    fn process_env_only() -> ConfigOptions {
        ConfigOptions::new().env_file("nestforge-config-tests-no-such-file.env")
    }

    /// Builds a service directly from literal pairs, bypassing the environment.
    fn service_with(pairs: &[(&str, &str)]) -> ConfigService {
        ConfigService {
            values: pairs
                .iter()
                .map(|&(key, value)| (key.to_string(), value.to_string()))
                .collect(),
        }
    }

    #[test]
    fn test_config_service_load() {
        // Crate-specific names so parallel tests and the real environment cannot collide.
        std::env::set_var("NESTFORGE_TEST_APP_NAME", "TestApp");
        std::env::set_var("NESTFORGE_TEST_APP_PORT", "8080");

        let loaded = ConfigService::load_with_options(&process_env_only());

        // Clean up before asserting so a failure does not leave variables behind.
        std::env::remove_var("NESTFORGE_TEST_APP_NAME");
        std::env::remove_var("NESTFORGE_TEST_APP_PORT");

        let config = loaded.unwrap();
        assert_eq!(config.get("NESTFORGE_TEST_APP_NAME"), Some("TestApp"));
        assert_eq!(config.get_string("NESTFORGE_TEST_APP_NAME"), "TestApp");
        assert_eq!(config.get_u16("NESTFORGE_TEST_APP_PORT"), 8080);
        assert_eq!(config.get_u16_or("NESTFORGE_TEST_MISSING", 3000), 3000);
        assert!(config.has("NESTFORGE_TEST_APP_NAME"));
        assert!(!config.has("NESTFORGE_TEST_MISSING"));
    }

    #[test]
    fn test_env_file_without_process_env() {
        let path = std::env::temp_dir().join(format!(
            "nestforge-config-test-{}.env",
            std::process::id()
        ));
        std::fs::write(&path, "FILE_ONLY_KEY=from_file\nFILE_PORT=9090\n").unwrap();

        let options = ConfigOptions::new()
            .env_file(path.display().to_string())
            .without_process_env();
        let loaded = ConfigService::load_with_options(&options);
        let _ = std::fs::remove_file(&path);

        let config = loaded.unwrap();
        assert_eq!(config.get("FILE_ONLY_KEY"), Some("from_file"));
        assert_eq!(config.get_u16("FILE_PORT"), 9090);
        // Only the file contributes values when the process env is excluded.
        assert!(!config.has("PATH"));
    }

    #[test]
    fn test_config_service_defaults() {
        let config = ConfigService::new();

        assert_eq!(config.get_string("MISSING"), "");
        assert_eq!(config.get_string_or("MISSING", "default"), "default");
        assert_eq!(config.get_u16_or("MISSING", 3000), 3000);
        assert!(config.get_bool_or("MISSING", true));
    }

    #[test]
    fn test_typed_getters() {
        let config = service_with(&[
            ("PORT", "8080"),
            ("RATIO", "0.5"),
            ("NEG", "-7"),
            ("FLAG", "yes"),
            ("BAD", "not-a-number"),
        ]);

        assert_eq!(config.get_u16("PORT"), 8080);
        assert_eq!(config.get_i64("NEG"), -7);
        assert_eq!(config.get_f64("RATIO"), 0.5);
        assert!(config.get_bool("FLAG"));
        assert_eq!(config.get_u32("BAD"), 0);
        assert_eq!(config.get_u32_or("BAD", 42), 42);
        assert!(!config.get_bool_or("BAD", true));
    }

    #[test]
    fn test_config_options_builder() {
        let options = ConfigOptions::new().env_file(".env.test");
        assert_eq!(options.env_file_path, ".env.test");
    }

    #[test]
    fn test_register_config() {
        let db_config = register_config("database", || DbConfig {
            host: "localhost".to_string(),
            port: 5432,
        });

        let config = db_config.load();
        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, 5432);
    }

    #[derive(Debug, Clone)]
    struct DbConfig {
        host: String,
        port: u16,
    }
}