1use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fs;
7use std::path::{Path, PathBuf};
8use thiserror::Error;
9
/// Errors produced while loading, parsing, validating, or saving configuration.
#[derive(Debug, Error)]
pub enum ConfigError {
    /// The configuration file at the contained path could not be read (e.g. it does not exist).
    #[error("Configuration file not found: {0}")]
    FileNotFound(String),
    /// Content could not be parsed or serialized in the selected format, or the format does not
    /// support the requested operation; carries a human-readable description.
    #[error("Invalid configuration format: {0}")]
    InvalidFormat(String),
    /// A configuration value failed validation, an unknown key was set, or an unknown
    /// environment profile was requested.
    #[error("Configuration validation error: {0}")]
    Validation(String),
    /// An environment variable held an unusable value.
    #[error("Environment variable error: {0}")]
    Environment(String),
    /// Underlying I/O failure, converted automatically by `?` (e.g. when writing a saved file).
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// JSON serialization failure from `serde_json`, converted automatically by `?`.
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
    /// TOML deserialization error, convertible via `?`. Note: the parsers in this module report
    /// TOML parse failures as `InvalidFormat` instead.
    #[error("TOML parsing error: {0}")]
    Toml(#[from] toml::de::Error),
    /// YAML error from `serde_yaml`, convertible via `?`. Note: the parsers in this module report
    /// YAML parse failures as `InvalidFormat` instead.
    #[error("YAML parsing error: {0}")]
    Yaml(#[from] serde_yaml::Error),
}
29
/// Source format of the active configuration.
///
/// `ConfigManager::load_from_file` picks it from the file extension, and
/// `ConfigManager::save_to_file` uses it to choose the serializer.
///
/// The enum is a plain fieldless tag, so it is `Copy` and comparable. Callers can test it with
/// `==` instead of `match`/`matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConfigFormat {
    /// JSON, selected for `.json` files.
    Json,
    /// TOML, selected for `.toml` files and used as the fallback for unknown extensions.
    Toml,
    /// YAML, selected for `.yaml` / `.yml` files.
    Yaml,
    /// Values read from process environment variables; cannot be saved to a file.
    Environment,
}
38
/// A named environment profile that can be registered with `ConfigManager::add_environment`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnvironmentConfig {
    /// Profile name. NOTE(review): the manager keys profiles by the name passed to
    /// `add_environment`, not by this field, so the two can disagree.
    pub name: String,
    /// Variables attached to the profile; not read anywhere in this file.
    pub variables: HashMap<String, String>,
    /// Dotted-key overrides (e.g. `"server.port"`) applied by `ConfigManager::switch_environment`.
    pub overrides: HashMap<String, serde_json::Value>,
    /// Names of secret values. NOTE(review): not consumed in this file; confirm intended use.
    pub secrets: Vec<String>,
    /// Variables the profile expects. NOTE(review): not checked in this file; confirm.
    pub required_vars: Vec<String>,
}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct DatabaseConfig {
52 pub host: String,
53 pub port: u16,
54 pub database: String,
55 pub username: String,
56 pub password: String,
57 pub ssl_mode: String,
58 pub pool_size: u32,
59 pub timeout: u64,
60}
61
/// Server settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
    /// Server host/address; `ConfigManager` validation rejects an empty string.
    pub host: String,
    /// Server port; `ConfigManager` validation rejects `0`.
    pub port: u16,
    /// Worker count; `ConfigManager` validation rejects `0`.
    pub workers: u32,
    /// Maximum number of connections (not validated in this file).
    pub max_connections: u32,
    /// Timeout value; units are not stated here (presumably seconds; confirm).
    pub timeout: u64,
    /// Optional TLS certificate; presumably a file path (confirm with the server setup code).
    pub tls_cert: Option<String>,
    /// Optional TLS private key; presumably a file path (confirm with the server setup code).
    pub tls_key: Option<String>,
}
73
/// Logging settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoggingConfig {
    /// Log level. `ConfigManager` validation accepts only `trace`, `debug`, `info`, `warn` or
    /// `error` (case-sensitive).
    pub level: String,
    /// Output format name (the default configuration uses `"json"`).
    pub format: String,
    /// Output targets (the default configuration uses `["stdout"]`).
    pub output: Vec<String>,
    /// Optional log rotation policy.
    pub rotation: Option<LogRotationConfig>,
    /// Whether structured logging is enabled.
    pub structured: bool,
}
83
/// Log rotation policy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogRotationConfig {
    /// Rotation size as text (e.g. `"10MB"`); not parsed or validated in this file.
    pub size: String,
    /// Retention count; presumably the number of rotated files kept (confirm).
    pub keep: u32,
    /// Whether rotated output is compressed.
    pub compress: bool,
}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct SecurityConfig {
94 pub jwt_secret: String,
95 pub session_timeout: u64,
96 pub bcrypt_cost: u32,
97 pub rate_limiting: RateLimitConfig,
98 pub cors: CorsConfig,
99}
100
/// Request rate limiting settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Whether rate limiting is enabled.
    pub enabled: bool,
    /// Allowed requests per minute.
    pub requests_per_minute: u32,
    /// Burst allowance.
    pub burst_size: u32,
    /// Exempt entries (the default configuration lists `"127.0.0.1"`); presumably client
    /// addresses, confirm.
    pub whitelist: Vec<String>,
}
108
/// Cross-origin resource sharing (CORS) settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorsConfig {
    /// Whether CORS handling is enabled.
    pub enabled: bool,
    /// Allowed origins (the default configuration allows `"*"`).
    pub allowed_origins: Vec<String>,
    /// Allowed HTTP methods.
    pub allowed_methods: Vec<String>,
    /// Allowed request headers.
    pub allowed_headers: Vec<String>,
    /// Max age value; presumably seconds (confirm).
    pub max_age: u32,
}
117
118#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct AppConfig {
121 pub environment: String,
122 pub debug: bool,
123 pub server: ServerConfig,
124 pub database: DatabaseConfig,
125 pub logging: LoggingConfig,
126 pub security: SecurityConfig,
127 pub features: HashMap<String, bool>,
128 pub custom: HashMap<String, serde_json::Value>,
129}
130
/// Owns the active [`AppConfig`] and the state needed to reload it, switch environment
/// profiles, save it back to disk, and notify registered watchers.
pub struct ConfigManager {
    /// Currently active configuration.
    config: AppConfig,
    /// File path re-read by `reload`; set by `load_from_file` (defaults to `config.toml`).
    config_path: PathBuf,
    /// Format of the loaded configuration; selects the serializer used by `save_to_file`.
    format: ConfigFormat,
    /// Named environment profiles available to `switch_environment`.
    environments: HashMap<String, EnvironmentConfig>,
    /// Observers notified after configuration changes.
    watchers: Vec<Box<dyn ConfigWatcher>>,
}
139
/// Observer that [`ConfigManager`] notifies after its configuration changes.
///
/// Implementors must be `Send + Sync`.
pub trait ConfigWatcher: Send + Sync {
    /// Called with the configuration as it is after the change.
    ///
    /// # Errors
    ///
    /// Returning an error stops notification of the remaining watchers. The error is passed
    /// back to the caller of the `ConfigManager` method that triggered the notification.
    fn on_config_changed(&self, config: &AppConfig) -> Result<(), ConfigError>;
}
144
145impl Default for ConfigManager {
146 fn default() -> Self {
147 Self::new()
148 }
149}
150
151impl ConfigManager {
152 pub fn new() -> Self {
154 Self {
155 config: AppConfig::default(),
156 config_path: PathBuf::from("config.toml"),
157 format: ConfigFormat::Toml,
158 environments: HashMap::new(),
159 watchers: Vec::new(),
160 }
161 }
162
163 pub fn load_from_file<P: AsRef<Path>>(&mut self, path: P) -> Result<(), ConfigError> {
165 let path = path.as_ref();
166 self.config_path = path.to_path_buf();
167
168 self.format = match path.extension().and_then(|ext| ext.to_str()) {
170 Some("json") => ConfigFormat::Json,
171 Some("toml") => ConfigFormat::Toml,
172 Some("yaml") | Some("yml") => ConfigFormat::Yaml,
173 _ => ConfigFormat::Toml,
174 };
175
176 let content = fs::read_to_string(path)
177 .map_err(|_| ConfigError::FileNotFound(path.display().to_string()))?;
178
179 self.config = self.parse_config(&content)?;
180 self.validate_config()?;
181
182 Ok(())
183 }
184
185 pub fn load_from_env(&mut self) -> Result<(), ConfigError> {
187 self.format = ConfigFormat::Environment;
188
189 let mut config = AppConfig::default();
190
191 if let Ok(host) = std::env::var("SERVER_HOST") {
193 config.server.host = host;
194 }
195 if let Ok(port) = std::env::var("SERVER_PORT") {
196 config.server.port = port
197 .parse()
198 .map_err(|_| ConfigError::Environment("Invalid SERVER_PORT".to_string()))?;
199 }
200
201 if let Ok(host) = std::env::var("DATABASE_HOST") {
203 config.database.host = host;
204 }
205 if let Ok(port) = std::env::var("DATABASE_PORT") {
206 config.database.port = port
207 .parse()
208 .map_err(|_| ConfigError::Environment("Invalid DATABASE_PORT".to_string()))?;
209 }
210 if let Ok(database) = std::env::var("DATABASE_NAME") {
211 config.database.database = database;
212 }
213 if let Ok(username) = std::env::var("DATABASE_USER") {
214 config.database.username = username;
215 }
216 if let Ok(password) = std::env::var("DATABASE_PASSWORD") {
217 config.database.password = password;
218 }
219
220 if let Ok(jwt_secret) = std::env::var("JWT_SECRET") {
222 config.security.jwt_secret = jwt_secret;
223 }
224
225 if let Ok(env) = std::env::var("ENVIRONMENT") {
227 config.environment = env;
228 }
229
230 self.config = config;
231 self.validate_config()?;
232
233 Ok(())
234 }
235
236 fn parse_config(&self, content: &str) -> Result<AppConfig, ConfigError> {
238 match self.format {
239 ConfigFormat::Json => {
240 serde_json::from_str(content).map_err(|e| ConfigError::InvalidFormat(e.to_string()))
241 }
242 ConfigFormat::Toml => {
243 toml::from_str(content).map_err(|e| ConfigError::InvalidFormat(e.to_string()))
244 }
245 ConfigFormat::Yaml => {
246 serde_yaml::from_str(content).map_err(|e| ConfigError::InvalidFormat(e.to_string()))
247 }
248 ConfigFormat::Environment => Err(ConfigError::InvalidFormat(
249 "Environment loading not supported here".to_string(),
250 )),
251 }
252 }
253
254 fn validate_config(&self) -> Result<(), ConfigError> {
256 if self.config.server.host.is_empty() {
258 return Err(ConfigError::Validation(
259 "Server host cannot be empty".to_string(),
260 ));
261 }
262 if self.config.server.port == 0 {
263 return Err(ConfigError::Validation(
264 "Server port must be greater than 0".to_string(),
265 ));
266 }
267 if self.config.server.workers == 0 {
268 return Err(ConfigError::Validation(
269 "Server workers must be greater than 0".to_string(),
270 ));
271 }
272
273 if self.config.database.host.is_empty() {
275 return Err(ConfigError::Validation(
276 "Database host cannot be empty".to_string(),
277 ));
278 }
279 if self.config.database.port == 0 {
280 return Err(ConfigError::Validation(
281 "Database port must be greater than 0".to_string(),
282 ));
283 }
284 if self.config.database.database.is_empty() {
285 return Err(ConfigError::Validation(
286 "Database name cannot be empty".to_string(),
287 ));
288 }
289
290 if self.config.security.jwt_secret.is_empty() {
292 return Err(ConfigError::Validation(
293 "JWT secret cannot be empty".to_string(),
294 ));
295 }
296 if self.config.security.jwt_secret.len() < 32 {
297 return Err(ConfigError::Validation(
298 "JWT secret must be at least 32 characters".to_string(),
299 ));
300 }
301 if self.config.security.bcrypt_cost < 4 || self.config.security.bcrypt_cost > 31 {
302 return Err(ConfigError::Validation(
303 "Bcrypt cost must be between 4 and 31".to_string(),
304 ));
305 }
306
307 let valid_levels = ["trace", "debug", "info", "warn", "error"];
309 if !valid_levels.contains(&self.config.logging.level.as_str()) {
310 return Err(ConfigError::Validation("Invalid logging level".to_string()));
311 }
312
313 Ok(())
314 }
315
316 pub fn get_config(&self) -> &AppConfig {
318 &self.config
319 }
320
321 pub fn set_value(&mut self, key: &str, value: serde_json::Value) -> Result<(), ConfigError> {
323 let parts: Vec<&str> = key.split('.').collect();
325
326 match parts.as_slice() {
327 ["server", "host"] => {
328 if let Some(host) = value.as_str() {
329 self.config.server.host = host.to_string();
330 } else {
331 return Err(ConfigError::Validation(
332 "Server host must be a string".to_string(),
333 ));
334 }
335 }
336 ["server", "port"] => {
337 if let Some(port) = value.as_u64() {
338 self.config.server.port = port as u16;
339 } else {
340 return Err(ConfigError::Validation(
341 "Server port must be a number".to_string(),
342 ));
343 }
344 }
345 ["database", "host"] => {
346 if let Some(host) = value.as_str() {
347 self.config.database.host = host.to_string();
348 } else {
349 return Err(ConfigError::Validation(
350 "Database host must be a string".to_string(),
351 ));
352 }
353 }
354 ["database", "port"] => {
355 if let Some(port) = value.as_u64() {
356 self.config.database.port = port as u16;
357 } else {
358 return Err(ConfigError::Validation(
359 "Database port must be a number".to_string(),
360 ));
361 }
362 }
363 ["features", feature] => {
364 if let Some(feature_value) = value.as_bool() {
365 self.config
366 .features
367 .insert(feature.to_string(), feature_value);
368 } else {
369 return Err(ConfigError::Validation(
370 "Feature value must be boolean".to_string(),
371 ));
372 }
373 }
374 ["custom", custom_key] => {
375 self.config.custom.insert(custom_key.to_string(), value);
376 }
377 _ => {
378 return Err(ConfigError::Validation(format!(
379 "Unknown configuration key: {}",
380 key
381 )));
382 }
383 }
384
385 self.validate_config()?;
386 self.notify_watchers()?;
387
388 Ok(())
389 }
390
391 pub fn add_environment(&mut self, name: String, env_config: EnvironmentConfig) {
393 self.environments.insert(name, env_config);
394 }
395
396 pub fn switch_environment(&mut self, env_name: &str) -> Result<(), ConfigError> {
398 let overrides = if let Some(env_config) = self.environments.get(env_name) {
399 env_config.overrides.clone()
400 } else {
401 return Err(ConfigError::Validation(format!(
402 "Environment not found: {}",
403 env_name
404 )));
405 };
406
407 for (key, value) in &overrides {
409 self.set_value(key, value.clone())?;
410 }
411
412 self.config.environment = env_name.to_string();
413 self.notify_watchers()?;
414
415 Ok(())
416 }
417
418 pub fn add_watcher(&mut self, watcher: Box<dyn ConfigWatcher>) {
420 self.watchers.push(watcher);
421 }
422
423 fn notify_watchers(&self) -> Result<(), ConfigError> {
425 for watcher in &self.watchers {
426 watcher.on_config_changed(&self.config)?;
427 }
428 Ok(())
429 }
430
431 pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> Result<(), ConfigError> {
433 let content = match self.format {
434 ConfigFormat::Json => serde_json::to_string_pretty(&self.config)?,
435 ConfigFormat::Toml => toml::to_string(&self.config)
436 .map_err(|e| ConfigError::InvalidFormat(e.to_string()))?,
437 ConfigFormat::Yaml => serde_yaml::to_string(&self.config)
438 .map_err(|e| ConfigError::InvalidFormat(e.to_string()))?,
439 ConfigFormat::Environment => {
440 return Err(ConfigError::InvalidFormat(
441 "Cannot save environment config to file".to_string(),
442 ));
443 }
444 };
445
446 fs::write(path, content)?;
447 Ok(())
448 }
449
450 pub fn reload(&mut self) -> Result<(), ConfigError> {
452 let config_path = self.config_path.clone();
453 if config_path.exists() {
454 self.load_from_file(&config_path)?;
455 self.notify_watchers()?;
456 }
457 Ok(())
458 }
459}
460
461impl Default for AppConfig {
462 fn default() -> Self {
463 Self {
464 environment: "development".to_string(),
465 debug: true,
466 server: ServerConfig {
467 host: "127.0.0.1".to_string(),
468 port: 8080,
469 workers: 4,
470 max_connections: 1000,
471 timeout: 30,
472 tls_cert: None,
473 tls_key: None,
474 },
475 database: DatabaseConfig {
476 host: "localhost".to_string(),
477 port: 5432,
478 database: "authframework".to_string(),
479 username: "postgres".to_string(),
480 password: "password".to_string(),
481 ssl_mode: "prefer".to_string(),
482 pool_size: 10,
483 timeout: 30,
484 },
485 logging: LoggingConfig {
486 level: "info".to_string(),
487 format: "json".to_string(),
488 output: vec!["stdout".to_string()],
489 rotation: Some(LogRotationConfig {
490 size: "10MB".to_string(),
491 keep: 7,
492 compress: true,
493 }),
494 structured: true,
495 },
496 security: SecurityConfig {
497 jwt_secret: "your-super-secret-jwt-key-change-this-in-production".to_string(),
498 session_timeout: 3600,
499 bcrypt_cost: 12,
500 rate_limiting: RateLimitConfig {
501 enabled: true,
502 requests_per_minute: 100,
503 burst_size: 20,
504 whitelist: vec!["127.0.0.1".to_string()],
505 },
506 cors: CorsConfig {
507 enabled: true,
508 allowed_origins: vec!["*".to_string()],
509 allowed_methods: vec![
510 "GET".to_string(),
511 "POST".to_string(),
512 "PUT".to_string(),
513 "DELETE".to_string(),
514 ],
515 allowed_headers: vec!["Content-Type".to_string(), "Authorization".to_string()],
516 max_age: 3600,
517 },
518 },
519 features: HashMap::new(),
520 custom: HashMap::new(),
521 }
522 }
523}
524
/// Minimal [`ConfigWatcher`] that prints a line to stdout on every configuration change.
pub struct SimpleConfigWatcher {
    /// Label included in the printed message.
    name: String,
}
529
530impl SimpleConfigWatcher {
531 pub fn new(name: String) -> Self {
532 Self { name }
533 }
534}
535
536impl ConfigWatcher for SimpleConfigWatcher {
537 fn on_config_changed(&self, _config: &AppConfig) -> Result<(), ConfigError> {
538 println!("Configuration changed for watcher: {}", self.name);
539 Ok(())
540 }
541}
542
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::write;
    use std::sync::{Mutex, MutexGuard};
    use tempfile::tempdir;

    /// Serializes every test that writes or reads process environment variables.
    ///
    /// The test harness runs tests on parallel threads. `std::env::set_var` and `remove_var`
    /// are only sound while no other thread touches the environment. `tempdir()` resolves the
    /// temp directory from the environment (e.g. `TMPDIR`), so the file-based test takes this
    /// lock too.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], recovering the guard if an earlier test panicked while holding it.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Removes the listed environment variables when dropped, so a failing assertion cannot
    /// leak them into later tests.
    ///
    /// Only create this while holding [`ENV_LOCK`], and declare it after the lock guard so it
    /// is dropped first.
    struct EnvVarsGuard(&'static [&'static str]);

    impl Drop for EnvVarsGuard {
        fn drop(&mut self) {
            for name in self.0 {
                // SAFETY: the creator holds ENV_LOCK for this guard's entire lifetime, so no
                // other test thread reads or writes the environment concurrently.
                unsafe { std::env::remove_var(name) };
            }
        }
    }

    #[test]
    fn test_config_manager_creation() {
        let manager = ConfigManager::new();
        assert_eq!(manager.config.environment, "development");
    }

    #[test]
    fn test_load_from_env() {
        let _env = lock_env();
        let _cleanup = EnvVarsGuard(&["SERVER_HOST", "SERVER_PORT"]);

        // SAFETY: ENV_LOCK is held, so no other test thread accesses the environment.
        unsafe {
            std::env::set_var("SERVER_HOST", "0.0.0.0");
            std::env::set_var("SERVER_PORT", "9090");
        }

        let mut manager = ConfigManager::new();
        let result = manager.load_from_env();

        assert!(result.is_ok());
        assert_eq!(manager.config.server.host, "0.0.0.0");
        assert_eq!(manager.config.server.port, 9090);
    }

    #[test]
    fn test_load_from_toml_file() {
        // Hold the lock only while tempdir() consults the environment.
        let dir = {
            let _env = lock_env();
            tempdir().unwrap()
        };
        let file_path = dir.path().join("config.toml");

        let toml_content = r#"
environment = "test"
debug = false

[server]
host = "0.0.0.0"
port = 9000
workers = 8
max_connections = 2000
timeout = 60

[database]
host = "db.example.com"
port = 5432
database = "test_db"
username = "test_user"
password = "test_pass"
ssl_mode = "require"
pool_size = 20
timeout = 60

[logging]
level = "debug"
format = "text"
output = ["stdout", "file"]
structured = false

[security]
jwt_secret = "test-secret-key-that-is-long-enough-for-validation"
session_timeout = 7200
bcrypt_cost = 10

[security.rate_limiting]
enabled = true
requests_per_minute = 200
burst_size = 40
whitelist = ["192.168.1.1"]

[security.cors]
enabled = true
allowed_origins = ["https://example.com"]
allowed_methods = ["GET", "POST"]
allowed_headers = ["Content-Type"]
max_age = 1800

[features]
# Add some example features
mfa = true
oauth = false

[custom]
# Custom configuration values
app_version = "1.0.0"
    "#;

        write(&file_path, toml_content).unwrap();

        let mut manager = ConfigManager::new();
        if let Err(e) = manager.load_from_file(&file_path) {
            panic!("Failed to load config: {:?}", e);
        }

        assert_eq!(manager.config.environment, "test");
        assert_eq!(manager.config.server.host, "0.0.0.0");
        assert_eq!(manager.config.server.port, 9000);
        assert_eq!(manager.config.database.host, "db.example.com");
        assert_eq!(manager.config.security.bcrypt_cost, 10);
    }

    #[test]
    fn test_config_validation() {
        let mut config = AppConfig::default();
        config.security.jwt_secret = "short".to_string();
        let manager = ConfigManager {
            config,
            config_path: PathBuf::new(),
            format: ConfigFormat::Toml,
            environments: HashMap::new(),
            watchers: Vec::new(),
        };

        let result = manager.validate_config();
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), ConfigError::Validation(_)));
    }

    #[test]
    fn test_set_value() {
        let mut manager = ConfigManager::new();

        let result = manager.set_value(
            "server.port",
            serde_json::Value::Number(serde_json::Number::from(9999)),
        );
        assert!(result.is_ok());
        assert_eq!(manager.config.server.port, 9999);
    }

    #[test]
    fn test_config_watcher() {
        let mut manager = ConfigManager::new();
        let watcher = Box::new(SimpleConfigWatcher::new("test".to_string()));
        manager.add_watcher(watcher);

        let result = manager.set_value(
            "server.port",
            serde_json::Value::Number(serde_json::Number::from(8888)),
        );
        assert!(result.is_ok());
    }
}
698
699