1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::env;
4use std::fs;
5use std::path::Path;
6use thiserror::Error;
7
/// Errors returned while loading or querying a [`Config`].
#[derive(Debug, Error)]
pub enum ConfigError {
    /// An I/O operation failed, e.g. reading a configuration file.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// A TOML configuration file could not be deserialized into [`Config`].
    #[error("TOML parse error: {0}")]
    TomlDe(#[from] toml::de::Error),
    /// A YAML configuration file could not be deserialized into [`Config`].
    #[error("YAML parse error: {0}")]
    YamlDe(#[from] serde_yaml::Error),
    /// An environment-variable override held a value that could not be used
    /// (for example a `SERVER_PORT` that does not parse as a port number).
    #[error("invalid value for environment variable `{name}`: `{value}`")]
    InvalidEnvValue { name: String, value: String },
    /// A required key was not present in the configuration.
    #[error("missing configuration key: {0}")]
    MissingKey(String),
    /// A key exists, but its value could not be deserialized into the
    /// requested type.
    #[error("invalid type for configuration key: {0}")]
    InvalidType(String),
}
23
/// Deployment environment the application is running in.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Environment {
    Development,
    Testing,
    Production,
}

impl Environment {
    /// Parses an environment name leniently; never fails.
    ///
    /// Matching is case-insensitive and ignores surrounding whitespace.
    /// `production`/`prod` map to [`Environment::Production`], and
    /// `testing`/`test` map to [`Environment::Testing`]. Anything else,
    /// including an empty string, maps to [`Environment::Development`].
    pub fn from_str(s: &str) -> Self {
        Self::parse_lenient(s)
    }

    /// Canonical lowercase name of the environment.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Development => "development",
            Self::Testing => "testing",
            Self::Production => "production",
        }
    }

    /// Shared parsing logic for the inherent `from_str` and the `FromStr` impl.
    fn parse_lenient(s: &str) -> Self {
        // Trim first: values such as "production\n" (common in env files and
        // shell exports) must not be silently demoted to Development.
        match s.trim().to_lowercase().as_str() {
            "production" | "prod" => Self::Production,
            "testing" | "test" => Self::Testing,
            _ => Self::Development,
        }
    }
}

impl std::str::FromStr for Environment {
    type Err = std::convert::Infallible;

    /// Same lenient parsing as the inherent `Environment::from_str`; never fails.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(Self::parse_lenient(s))
    }
}

impl std::fmt::Display for Environment {
    /// Writes the canonical lowercase name (same as [`Environment::as_str`]).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct Config {
51 #[serde(default)]
52 pub app: AppConfig,
53 #[serde(default)]
54 pub server: ServerConfig,
55 #[serde(default)]
56 pub database: DatabaseConfig,
57 #[serde(default)]
58 pub cache: CacheConfig,
59 #[serde(default)]
60 pub queue: QueueConfig,
61 #[serde(default)]
62 pub security: SecurityConfig,
63 #[serde(default)]
64 pub custom: HashMap<String, toml::Value>,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct AppConfig {
69 #[serde(default = "default_app_name")]
70 pub name: String,
71 #[serde(default)]
72 pub version: String,
73 #[serde(default)]
74 pub environment: String,
75 #[serde(default)]
76 pub debug: bool,
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct ServerConfig {
81 #[serde(default = "default_host")]
82 pub host: String,
83 #[serde(default = "default_port")]
84 pub port: u16,
85 #[serde(default)]
86 pub workers: usize,
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct DatabaseConfig {
91 #[serde(default)]
92 pub url: String,
93 #[serde(default = "default_pool_size")]
94 pub pool_size: u32,
95 #[serde(default)]
96 pub ssl: bool,
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct CacheConfig {
101 #[serde(default)]
102 pub driver: String,
103 #[serde(default)]
104 pub redis_url: String,
105 #[serde(default = "default_ttl")]
106 pub default_ttl: u64,
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct QueueConfig {
111 #[serde(default)]
112 pub driver: String,
113 #[serde(default)]
114 pub redis_url: String,
115 #[serde(default = "default_workers")]
116 pub workers: usize,
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct SecurityConfig {
121 #[serde(default)]
122 pub jwt_secret: String,
123 #[serde(default = "default_jwt_expiry")]
124 pub jwt_expiry: u64,
125 #[serde(default)]
126 pub cors_origins: Vec<String>,
127 #[serde(default)]
128 pub rate_limit: u32,
129}
130
/// Fallback for `app.name`.
fn default_app_name() -> String {
    String::from("oxidite-app")
}

/// Fallback for `server.host`: the IPv4 loopback address.
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// Fallback for `server.port`.
const fn default_port() -> u16 {
    3_000
}

/// Fallback for `database.pool_size`.
const fn default_pool_size() -> u32 {
    10
}

/// Fallback for `cache.default_ttl`. Presumably seconds (one hour); confirm.
const fn default_ttl() -> u64 {
    3_600
}

/// Fallback for `queue.workers`.
const fn default_workers() -> usize {
    4
}

/// Fallback for `security.jwt_expiry`. Presumably seconds (15 minutes); confirm.
const fn default_jwt_expiry() -> u64 {
    900
}
159
160impl Default for AppConfig {
161 fn default() -> Self {
162 Self {
163 name: default_app_name(),
164 version: env!("CARGO_PKG_VERSION").to_string(),
165 environment: "development".to_string(),
166 debug: true,
167 }
168 }
169}
170
171impl Default for ServerConfig {
172 fn default() -> Self {
173 Self {
174 host: default_host(),
175 port: default_port(),
176 workers: num_cpus::get(),
177 }
178 }
179}
180
181impl Default for DatabaseConfig {
182 fn default() -> Self {
183 Self {
184 url: String::new(),
185 pool_size: default_pool_size(),
186 ssl: false,
187 }
188 }
189}
190
191impl Default for CacheConfig {
192 fn default() -> Self {
193 Self {
194 driver: "memory".to_string(),
195 redis_url: String::new(),
196 default_ttl: default_ttl(),
197 }
198 }
199}
200
201impl Default for QueueConfig {
202 fn default() -> Self {
203 Self {
204 driver: "memory".to_string(),
205 redis_url: String::new(),
206 workers: default_workers(),
207 }
208 }
209}
210
211impl Default for SecurityConfig {
212 fn default() -> Self {
213 Self {
214 jwt_secret: String::new(),
215 jwt_expiry: default_jwt_expiry(),
216 cors_origins: vec![],
217 rate_limit: 0,
218 }
219 }
220}
221
222impl Default for Config {
223 fn default() -> Self {
224 Self {
225 app: AppConfig::default(),
226 server: ServerConfig::default(),
227 database: DatabaseConfig::default(),
228 cache: CacheConfig::default(),
229 queue: QueueConfig::default(),
230 security: SecurityConfig::default(),
231 custom: HashMap::new(),
232 }
233 }
234}
235
236impl Config {
237 fn apply_env_overrides(&mut self) -> Result<(), ConfigError> {
238 if let Ok(val) = env::var("APP_NAME") {
239 self.app.name = val;
240 }
241 if let Ok(val) = env::var("SERVER_HOST") {
242 self.server.host = val;
243 }
244 if let Ok(val) = env::var("SERVER_PORT") {
245 self.server.port = val
246 .parse()
247 .map_err(|_| ConfigError::InvalidEnvValue {
248 name: "SERVER_PORT".to_string(),
249 value: val,
250 })?;
251 }
252 if let Ok(val) = env::var("DATABASE_URL") {
253 self.database.url = val;
254 }
255 if let Ok(val) = env::var("REDIS_URL") {
256 self.cache.redis_url = val.clone();
257 self.queue.redis_url = val;
258 }
259 if let Ok(val) = env::var("JWT_SECRET") {
260 self.security.jwt_secret = val;
261 }
262 Ok(())
263 }
264
265 fn has_key(&self, key: &str) -> bool {
266 if self.custom.contains_key(key) {
267 return true;
268 }
269 let Some(root) = toml::Value::try_from(self).ok() else {
270 return false;
271 };
272 let mut cursor = &root;
273 for part in key.split('.') {
274 let Some(next) = cursor.get(part) else {
275 return false;
276 };
277 cursor = next;
278 }
279 true
280 }
281
282 pub fn load() -> Result<Self, ConfigError> {
284 let _ = dotenv::dotenv();
286
287 let env = env::var("OXIDITE_ENV")
288 .or_else(|_| env::var("ENVIRONMENT"))
289 .unwrap_or_else(|_| "development".to_string());
290
291 let mut config = if Path::new("oxidite.toml").exists() {
293 let content = fs::read_to_string("oxidite.toml")?;
294 toml::from_str(&content)?
295 } else {
296 Config::default()
297 };
298
299 config.apply_env_overrides()?;
301
302 config.app.environment = env;
303
304 Ok(config)
305 }
306
307 pub fn load_from(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
309 let _ = dotenv::dotenv();
310 let env_name = env::var("OXIDITE_ENV")
311 .or_else(|_| env::var("ENVIRONMENT"))
312 .unwrap_or_else(|_| "development".to_string());
313
314 let mut config = if path.as_ref().exists() {
315 let content = fs::read_to_string(path)?;
316 toml::from_str(&content)?
317 } else {
318 Config::default()
319 };
320
321 config.app.environment = env_name;
322 config.apply_env_overrides()?;
323 Ok(config)
324 }
325
326 pub fn load_yaml_from(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
328 let _ = dotenv::dotenv();
329 let env_name = env::var("OXIDITE_ENV")
330 .or_else(|_| env::var("ENVIRONMENT"))
331 .unwrap_or_else(|_| "development".to_string());
332
333 let mut config = if path.as_ref().exists() {
334 let content = fs::read_to_string(path)?;
335 serde_yaml::from_str(&content)?
336 } else {
337 Config::default()
338 };
339
340 config.app.environment = env_name;
341 config.apply_env_overrides()?;
342 Ok(config)
343 }
344
345 pub fn get<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
347 if let Some(value) = self.custom.get(key) {
349 if let Ok(parsed) = T::deserialize(value.clone()) {
350 return Some(parsed);
351 }
352 }
353
354 let root = toml::Value::try_from(self).ok()?;
356 let mut cursor = &root;
357 for part in key.split('.') {
358 cursor = cursor.get(part)?;
359 }
360
361 T::deserialize(cursor.clone()).ok()
362 }
363
364 pub fn get_required<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Result<T, ConfigError> {
366 self.get(key).ok_or_else(|| {
367 if self.has_key(key) {
368 ConfigError::InvalidType(key.to_string())
369 } else {
370 ConfigError::MissingKey(key.to_string())
371 }
372 })
373 }
374
375 pub fn get_string(&self, key: &str) -> Result<String, ConfigError> {
377 self.get_required(key)
378 }
379
380 pub fn get_bool(&self, key: &str) -> Result<bool, ConfigError> {
382 self.get_required(key)
383 }
384
385 pub fn get_u16(&self, key: &str) -> Result<u16, ConfigError> {
387 self.get_required(key)
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that touch process-wide environment variables.
    ///
    /// The test harness runs tests on parallel threads, and every loader reads
    /// env vars, so an unguarded `SERVER_PORT=not-a-port` from one test can
    /// break another.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_LOCK`, recovering from poisoning.
    ///
    /// A failed test poisons the lock, but the guarded data is `()`, so
    /// continuing is safe.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Sets an env var and restores its previous value on drop, even when an
    /// assertion panics.
    struct EnvVarGuard {
        name: &'static str,
        previous: Option<String>,
    }

    impl EnvVarGuard {
        fn set(name: &'static str, value: &str) -> Self {
            let previous = env::var(name).ok();
            env::set_var(name, value);
            Self { name, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => env::set_var(self.name, value),
                None => env::remove_var(self.name),
            }
        }
    }

    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.server.host, "127.0.0.1");
        assert_eq!(config.server.port, 3000);
    }

    #[test]
    fn test_environment_parsing() {
        assert_eq!(Environment::from_str("production"), Environment::Production);
        assert_eq!(Environment::from_str("PROD"), Environment::Production);
        assert_eq!(Environment::from_str("test"), Environment::Testing);
        assert_eq!(Environment::from_str("development"), Environment::Development);
        assert_eq!(Environment::from_str("unknown"), Environment::Development);
    }

    #[test]
    fn test_get_required_typed_values() {
        let config = Config::default();
        assert_eq!(config.get_u16("server.port").unwrap(), 3000);
        assert!(config.get_bool("app.debug").unwrap());
    }

    #[test]
    fn test_get_required_distinguishes_missing_and_mistyped_keys() {
        let config = Config::default();
        assert!(matches!(
            config.get_u16("server.nope"),
            Err(ConfigError::MissingKey(key)) if key == "server.nope"
        ));
        assert!(matches!(
            config.get_u16("server.host"),
            Err(ConfigError::InvalidType(key)) if key == "server.host"
        ));
    }

    #[test]
    fn test_invalid_server_port_env_returns_error() {
        // Declaration order matters: the guard (declared last) drops first and
        // restores SERVER_PORT while the lock is still held.
        let _lock = lock_env();
        let _port = EnvVarGuard::set("SERVER_PORT", "not-a-port");

        let result = Config::load();
        assert!(matches!(
            result,
            Err(ConfigError::InvalidEnvValue { name, .. }) if name == "SERVER_PORT"
        ));
    }

    #[test]
    fn test_load_from_applies_env_overrides() {
        let _lock = lock_env();
        let _host = EnvVarGuard::set("SERVER_HOST", "0.0.0.0");

        let cfg = Config::load_from("this-file-does-not-exist.toml").expect("load");
        assert_eq!(cfg.server.host, "0.0.0.0");
    }
}