Skip to main content

oxidite_config/
lib.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::env;
4use std::fs;
5use std::path::Path;
6use thiserror::Error;
7
/// Errors returned while loading or querying a [`Config`].
#[derive(Debug, Error)]
pub enum ConfigError {
    /// Reading a configuration file failed.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// The TOML input could not be deserialized into a [`Config`].
    #[error("TOML parse error: {0}")]
    TomlDe(#[from] toml::de::Error),
    /// The YAML input could not be deserialized into a [`Config`].
    #[error("YAML parse error: {0}")]
    YamlDe(#[from] serde_yaml::Error),
    /// An environment-variable override (e.g. `SERVER_PORT`) held a value that could not
    /// be parsed; `name` is the variable and `value` its raw contents.
    #[error("invalid value for environment variable `{name}`: `{value}`")]
    InvalidEnvValue { name: String, value: String },
    /// A required key was not found (see `Config::get_required`).
    #[error("missing configuration key: {0}")]
    MissingKey(String),
    /// A required key exists but its value could not be converted to the requested type.
    #[error("invalid type for configuration key: {0}")]
    InvalidType(String),
}
23
/// Deployment environment an application runs in.
///
/// Parse with `Environment::from_str` (unrecognised names map to
/// [`Environment::Development`]) and render with `Environment::as_str`.
///
/// The enum is fieldless, so it is `Copy` and `Hash` (usable as a map/set key). Its
/// `Default` is [`Environment::Development`], the same value the parser falls back to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum Environment {
    /// Development environment; also the fallback for unrecognised names.
    #[default]
    Development,
    /// Testing environment (`"testing"` / `"test"`).
    Testing,
    /// Production environment (`"production"` / `"prod"`).
    Production,
}
30
31impl Environment {
32    pub fn from_str(s: &str) -> Self {
33        match s.to_lowercase().as_str() {
34            "production" | "prod" => Self::Production,
35            "testing" | "test" => Self::Testing,
36            _ => Self::Development,
37        }
38    }
39
40    pub fn as_str(&self) -> &str {
41        match self {
42            Self::Development => "development",
43            Self::Testing => "testing",
44            Self::Production => "production",
45        }
46    }
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct Config {
51    #[serde(default)]
52    pub app: AppConfig,
53    #[serde(default)]
54    pub server: ServerConfig,
55    #[serde(default)]
56    pub database: DatabaseConfig,
57    #[serde(default)]
58    pub cache: CacheConfig,
59    #[serde(default)]
60    pub queue: QueueConfig,
61    #[serde(default)]
62    pub security: SecurityConfig,
63    #[serde(default)]
64    pub custom: HashMap<String, toml::Value>,
65}
66
/// Application identity and runtime mode (`[app]` table).
///
/// NOTE(review): the defaults depend on how the table is written. If `[app]` is absent,
/// the value comes from `AppConfig::default()` (`debug = true`,
/// `environment = "development"`, `version` = `CARGO_PKG_VERSION`). If `[app]` is present
/// but incomplete, the per-field `#[serde(default)]` attributes below use each field
/// type's `Default` instead (`""` for `version`/`environment`, `false` for `debug`).
/// Confirm which behavior is intended before relying on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppConfig {
    /// Application name; falls back to `default_app_name()` when the key is omitted.
    #[serde(default = "default_app_name")]
    pub name: String,
    /// Application version string.
    #[serde(default)]
    pub version: String,
    /// Environment name such as `"development"` or `"production"`. The `Config::load*`
    /// functions may replace it with the value of `OXIDITE_ENV` / `ENVIRONMENT`.
    #[serde(default)]
    pub environment: String,
    /// Debug flag.
    #[serde(default)]
    pub debug: bool,
}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct ServerConfig {
81    #[serde(default = "default_host")]
82    pub host: String,
83    #[serde(default = "default_port")]
84    pub port: u16,
85    #[serde(default)]
86    pub workers: usize,
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct DatabaseConfig {
91    #[serde(default)]
92    pub url: String,
93    #[serde(default = "default_pool_size")]
94    pub pool_size: u32,
95    #[serde(default)]
96    pub ssl: bool,
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct CacheConfig {
101    #[serde(default)]
102    pub driver: String,
103    #[serde(default)]
104    pub redis_url: String,
105    #[serde(default = "default_ttl")]
106    pub default_ttl: u64,
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct QueueConfig {
111    #[serde(default)]
112    pub driver: String,
113    #[serde(default)]
114    pub redis_url: String,
115    #[serde(default = "default_workers")]
116    pub workers: usize,
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct SecurityConfig {
121    #[serde(default)]
122    pub jwt_secret: String,
123    #[serde(default = "default_jwt_expiry")]
124    pub jwt_expiry: u64,
125    #[serde(default)]
126    pub cors_origins: Vec<String>,
127    #[serde(default)]
128    pub rate_limit: u32,
129}
130
// Named default-value providers, referenced from serde `default = "..."` attributes
// and called from the `Default` impls.

/// Application name used when none is configured.
fn default_app_name() -> String {
    String::from("oxidite-app")
}

/// Server bind address: the IPv4 loopback interface.
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// Server TCP port.
const fn default_port() -> u16 {
    3000
}

/// Database connection pool size.
const fn default_pool_size() -> u32 {
    10
}

/// Cache entry TTL. NOTE(review): presumably seconds (one hour); confirm.
const fn default_ttl() -> u64 {
    3600
}

/// Number of queue workers.
const fn default_workers() -> usize {
    4
}

/// JWT lifetime in seconds: 15 minutes.
const fn default_jwt_expiry() -> u64 {
    15 * 60
}
159
160impl Default for AppConfig {
161    fn default() -> Self {
162        Self {
163            name: default_app_name(),
164            version: env!("CARGO_PKG_VERSION").to_string(),
165            environment: "development".to_string(),
166            debug: true,
167        }
168    }
169}
170
171impl Default for ServerConfig {
172    fn default() -> Self {
173        Self {
174            host: default_host(),
175            port: default_port(),
176            workers: num_cpus::get(),
177        }
178    }
179}
180
181impl Default for DatabaseConfig {
182    fn default() -> Self {
183        Self {
184            url: String::new(),
185            pool_size: default_pool_size(),
186            ssl: false,
187        }
188    }
189}
190
191impl Default for CacheConfig {
192    fn default() -> Self {
193        Self {
194            driver: "memory".to_string(),
195            redis_url: String::new(),
196            default_ttl: default_ttl(),
197        }
198    }
199}
200
201impl Default for QueueConfig {
202    fn default() -> Self {
203        Self {
204            driver: "memory".to_string(),
205            redis_url: String::new(),
206            workers: default_workers(),
207        }
208    }
209}
210
211impl Default for SecurityConfig {
212    fn default() -> Self {
213        Self {
214            jwt_secret: String::new(),
215            jwt_expiry: default_jwt_expiry(),
216            cors_origins: vec![],
217            rate_limit: 0,
218        }
219    }
220}
221
222impl Default for Config {
223    fn default() -> Self {
224        Self {
225            app: AppConfig::default(),
226            server: ServerConfig::default(),
227            database: DatabaseConfig::default(),
228            cache: CacheConfig::default(),
229            queue: QueueConfig::default(),
230            security: SecurityConfig::default(),
231            custom: HashMap::new(),
232        }
233    }
234}
235
236impl Config {
237    fn apply_env_overrides(&mut self) -> Result<(), ConfigError> {
238        if let Ok(val) = env::var("APP_NAME") {
239            self.app.name = val;
240        }
241        if let Ok(val) = env::var("SERVER_HOST") {
242            self.server.host = val;
243        }
244        if let Ok(val) = env::var("SERVER_PORT") {
245            self.server.port = val
246                .parse()
247                .map_err(|_| ConfigError::InvalidEnvValue {
248                    name: "SERVER_PORT".to_string(),
249                    value: val,
250                })?;
251        }
252        if let Ok(val) = env::var("DATABASE_URL") {
253            self.database.url = val;
254        }
255        if let Ok(val) = env::var("REDIS_URL") {
256            self.cache.redis_url = val.clone();
257            self.queue.redis_url = val;
258        }
259        if let Ok(val) = env::var("JWT_SECRET") {
260            self.security.jwt_secret = val;
261        }
262        Ok(())
263    }
264
265    fn has_key(&self, key: &str) -> bool {
266        if self.custom.contains_key(key) {
267            return true;
268        }
269        let Some(root) = toml::Value::try_from(self).ok() else {
270            return false;
271        };
272        let mut cursor = &root;
273        for part in key.split('.') {
274            let Some(next) = cursor.get(part) else {
275                return false;
276            };
277            cursor = next;
278        }
279        true
280    }
281
282    /// Load configuration from environment variables and config files
283    pub fn load() -> Result<Self, ConfigError> {
284        // Load .env file if it exists
285        let _ = dotenv::dotenv();
286
287        let env = env::var("OXIDITE_ENV")
288            .or_else(|_| env::var("ENVIRONMENT"))
289            .unwrap_or_else(|_| "development".to_string());
290
291        // Try to load oxidite.toml
292        let mut config = if Path::new("oxidite.toml").exists() {
293            let content = fs::read_to_string("oxidite.toml")?;
294            toml::from_str(&content)?
295        } else {
296            Config::default()
297        };
298
299        // Override with environment variables
300        config.apply_env_overrides()?;
301
302        config.app.environment = env;
303
304        Ok(config)
305    }
306
307    /// Load configuration from a specific TOML file path and env overrides.
308    pub fn load_from(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
309        let _ = dotenv::dotenv();
310        let env_name = env::var("OXIDITE_ENV")
311            .or_else(|_| env::var("ENVIRONMENT"))
312            .unwrap_or_else(|_| "development".to_string());
313
314        let mut config = if path.as_ref().exists() {
315            let content = fs::read_to_string(path)?;
316            toml::from_str(&content)?
317        } else {
318            Config::default()
319        };
320
321        config.app.environment = env_name;
322        config.apply_env_overrides()?;
323        Ok(config)
324    }
325
326    /// Load configuration from a YAML file path and env overrides.
327    pub fn load_yaml_from(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
328        let _ = dotenv::dotenv();
329        let env_name = env::var("OXIDITE_ENV")
330            .or_else(|_| env::var("ENVIRONMENT"))
331            .unwrap_or_else(|_| "development".to_string());
332
333        let mut config = if path.as_ref().exists() {
334            let content = fs::read_to_string(path)?;
335            serde_yaml::from_str(&content)?
336        } else {
337            Config::default()
338        };
339
340        config.app.environment = env_name;
341        config.apply_env_overrides()?;
342        Ok(config)
343    }
344
345    /// Get value from custom configuration
346    pub fn get<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
347        // Prefer explicitly registered custom keys first.
348        if let Some(value) = self.custom.get(key) {
349            if let Ok(parsed) = T::deserialize(value.clone()) {
350                return Some(parsed);
351            }
352        }
353
354        // Support dotted lookup across the full config tree, e.g. "database.url".
355        let root = toml::Value::try_from(self).ok()?;
356        let mut cursor = &root;
357        for part in key.split('.') {
358            cursor = cursor.get(part)?;
359        }
360
361        T::deserialize(cursor.clone()).ok()
362    }
363
364    /// Get required typed configuration value.
365    pub fn get_required<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Result<T, ConfigError> {
366        self.get(key).ok_or_else(|| {
367            if self.has_key(key) {
368                ConfigError::InvalidType(key.to_string())
369            } else {
370                ConfigError::MissingKey(key.to_string())
371            }
372        })
373    }
374
375    /// Get a required string configuration value.
376    pub fn get_string(&self, key: &str) -> Result<String, ConfigError> {
377        self.get_required(key)
378    }
379
380    /// Get a required boolean configuration value.
381    pub fn get_bool(&self, key: &str) -> Result<bool, ConfigError> {
382        self.get_required(key)
383    }
384
385    /// Get a required unsigned 16-bit integer configuration value.
386    pub fn get_u16(&self, key: &str) -> Result<u16, ConfigError> {
387        self.get_required(key)
388    }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that read or write process-wide environment variables.
    ///
    /// The test harness runs tests on parallel threads. Without this lock, the invalid
    /// `SERVER_PORT` set by one test can leak into another test's `Config::load*` call.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], recovering from poisoning (the guarded data is `()`).
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Sets an environment variable and restores its previous state when dropped,
    /// including when the owning test panics part-way through.
    struct EnvVarGuard {
        name: &'static str,
        previous: Option<String>,
    }

    impl EnvVarGuard {
        fn set(name: &'static str, value: &str) -> Self {
            let previous = env::var(name).ok();
            env::set_var(name, value);
            Self { name, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => env::set_var(self.name, value),
                None => env::remove_var(self.name),
            }
        }
    }

    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.server.host, "127.0.0.1");
        assert_eq!(config.server.port, 3000);
    }

    #[test]
    fn test_environment_parsing() {
        assert_eq!(Environment::from_str("production"), Environment::Production);
        assert_eq!(Environment::from_str("PROD"), Environment::Production);
        assert_eq!(Environment::from_str("test"), Environment::Testing);
        assert_eq!(Environment::from_str("development"), Environment::Development);
        assert_eq!(Environment::from_str("staging"), Environment::Development);
        assert_eq!(Environment::Production.as_str(), "production");
    }

    #[test]
    fn test_get_required_typed_values() {
        let config = Config::default();
        assert_eq!(config.get_u16("server.port").unwrap(), 3000);
        assert!(config.get_bool("app.debug").unwrap());
    }

    #[test]
    fn test_get_required_reports_missing_and_mistyped_keys() {
        let config = Config::default();
        assert!(matches!(
            config.get_string("no.such.key"),
            Err(ConfigError::MissingKey(key)) if key == "no.such.key"
        ));
        assert!(matches!(
            config.get_u16("server.host"),
            Err(ConfigError::InvalidType(key)) if key == "server.host"
        ));
    }

    #[test]
    fn test_get_prefers_custom_then_dotted_path() {
        let mut config = Config::default();
        config
            .custom
            .insert("feature_flag".to_string(), toml::Value::Boolean(true));
        assert_eq!(config.get::<bool>("feature_flag"), Some(true));
        assert_eq!(config.get::<u32>("database.pool_size"), Some(10));
        assert_eq!(config.get::<String>("missing"), None);
    }

    #[test]
    fn test_invalid_server_port_env_returns_error() {
        let _lock = lock_env();
        let _port = EnvVarGuard::set("SERVER_PORT", "not-a-port");

        let result = Config::load();
        assert!(matches!(
            result,
            Err(ConfigError::InvalidEnvValue { name, .. }) if name == "SERVER_PORT"
        ));
    }

    #[test]
    fn test_load_from_applies_env_overrides() {
        let _lock = lock_env();
        let _host = EnvVarGuard::set("SERVER_HOST", "0.0.0.0");

        let cfg = Config::load_from("this-file-does-not-exist.toml").expect("load");
        assert_eq!(cfg.server.host, "0.0.0.0");
    }
}