// auth_framework/config/app_config.rs

//! Configuration management with environment variable support.
//!
//! This module provides easy configuration loading from environment
//! variables, config files, and other sources.
5use super::SecurityConfig;
6use serde::{Deserialize, Serialize};
7use std::{env, time::Duration};
8
9impl Default for ConfigBuilder {
10    fn default() -> Self {
11        Self::new()
12    }
13}
14
15/// Complete application configuration
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct AppConfig {
18    /// Database configuration
19    pub database: DatabaseConfig,
20    /// Redis configuration
21    pub redis: Option<RedisConfig>,
22    /// JWT configuration
23    pub jwt: JwtConfig,
24    /// OAuth providers
25    pub oauth: OAuthConfig,
26    /// Security settings
27    pub security: SecuritySettings,
28    /// Logging configuration
29    pub logging: LoggingConfig,
30}
31
32/// Database connection settings.
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct DatabaseConfig {
35    /// Database connection URL (PostgreSQL, MySQL, SQLite, etc.)
36    pub url: String,
37    /// Maximum number of concurrent database connections
38    pub max_connections: u32,
39    /// Minimum number of idle connections to maintain
40    pub min_connections: u32,
41    /// Connection timeout in seconds
42    pub connect_timeout_seconds: u64,
43}
44
45/// Redis cache and session storage settings.
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct RedisConfig {
48    /// Redis connection URL
49    pub url: String,
50    /// Number of connections in the Redis connection pool
51    pub pool_size: u32,
52}
53
54/// JWT authentication settings.
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct JwtConfig {
57    /// Cryptographic secret key used to sign and verify JWTs
58    pub secret_key: String,
59    /// The 'iss' (issuer) claim to embed in generated tokens
60    pub issuer: String,
61    /// The 'aud' (audience) claim to embed in generated tokens
62    pub audience: String,
63    /// Primary access token lifetime in seconds
64    pub access_token_ttl_seconds: u64,
65    /// Refresh token lifetime in seconds
66    pub refresh_token_ttl_seconds: u64,
67}
68
69/// Supported OAuth provider identities.
70#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, std::hash::Hash)]
71#[serde(rename_all = "lowercase")]
72pub enum OAuthProvider {
73    /// Google OAuth identity provider
74    Google,
75    /// GitHub OAuth identity provider
76    GitHub,
77    /// Microsoft OAuth identity provider
78    Microsoft,
79    /// Other custom OAuth provider
80    Custom(String),
81}
82
83/// OAuth 2.0 configuration settings.
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct OAuthConfig {
86    /// OAuth provider configurations, keyed by provider
87    pub providers: std::collections::HashMap<String, OAuthProviderConfig>,
88}
89
90/// Individual OAuth provider settings.
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct OAuthProviderConfig {
93    /// OAuth client ID provided by the Identity Provider
94    pub client_id: String,
95    /// OAuth client secret provided by the Identity Provider
96    pub client_secret: String,
97    /// The redirect URI where the IDP will send the user post-authentication
98    pub redirect_uri: String,
99    /// List of scopes to request during authentication
100    pub scopes: Vec<String>,
101}
102
103/// Application-wide security policies.
104#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct SecuritySettings {
106    /// Minimum character length for new passwords
107    pub password_min_length: usize,
108    /// Whether to require at least one special character in new passwords
109    pub password_require_special_chars: bool,
110    /// Maximum number of requests allowed per minute per IP
111    pub rate_limit_requests_per_minute: u32,
112    /// Maximum hours a session token remains valid without activity
113    pub session_timeout_hours: u64,
114    /// Maximum number of simultaneous active sessions a user can have
115    pub max_concurrent_sessions: u32,
116    /// Whether Multi-Factor Authentication is globally required
117    pub require_mfa: bool,
118}
119
120/// System logging and auditing settings.
121#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct LoggingConfig {
123    /// Log level filter (e.g., "info", "debug", "warn")
124    pub level: String,
125    /// True if audit logging for security events is enabled
126    pub audit_enabled: bool,
127    /// Location to store audit events ("database", "file", "syslog")
128    pub audit_storage: String,
129}
130
131impl AppConfig {
132    /// Load configuration from environment variables
133    pub fn from_env() -> Result<Self, ConfigError> {
134        Ok(Self {
135            database: DatabaseConfig {
136                url: env::var("DATABASE_URL")
137                    .map_err(|_| ConfigError::MissingEnvVar("DATABASE_URL"))?,
138                max_connections: env::var("DB_MAX_CONNECTIONS")
139                    .unwrap_or_else(|_| "10".to_string())
140                    .parse()
141                    .map_err(|_| ConfigError::InvalidValue("DB_MAX_CONNECTIONS"))?,
142                min_connections: 1,
143                connect_timeout_seconds: 30,
144            },
145            redis: if let Ok(redis_url) = env::var("REDIS_URL") {
146                Some(RedisConfig {
147                    url: redis_url,
148                    pool_size: 10,
149                })
150            } else {
151                None
152            },
153            jwt: JwtConfig {
154                secret_key: env::var("JWT_SECRET")
155                    .map_err(|_| ConfigError::MissingEnvVar("JWT_SECRET"))?,
156                issuer: env::var("JWT_ISSUER").unwrap_or_else(|_| "auth-framework".to_string()),
157                audience: env::var("JWT_AUDIENCE").unwrap_or_else(|_| "api".to_string()),
158                access_token_ttl_seconds: 3600,
159                refresh_token_ttl_seconds: 86400 * 7,
160            },
161            oauth: OAuthConfig {
162                providers: {
163                    let mut map = std::collections::HashMap::new();
164                    if let Some(cfg) = Self::load_oauth_provider("GOOGLE") {
165                        map.insert("google".to_string(), cfg);
166                    }
167                    if let Some(cfg) = Self::load_oauth_provider("GITHUB") {
168                        map.insert("github".to_string(), cfg);
169                    }
170                    if let Some(cfg) = Self::load_oauth_provider("MICROSOFT") {
171                        map.insert("microsoft".to_string(), cfg);
172                    }
173                    map
174                },
175            },
176            security: SecuritySettings {
177                password_min_length: 8,
178                password_require_special_chars: true,
179                rate_limit_requests_per_minute: 60,
180                session_timeout_hours: 24,
181                max_concurrent_sessions: 5,
182                require_mfa: env::var("REQUIRE_MFA").unwrap_or_default() == "true",
183            },
184            logging: LoggingConfig {
185                level: env::var("LOG_LEVEL").unwrap_or_else(|_| "info".to_string()),
186                audit_enabled: true,
187                audit_storage: env::var("AUDIT_STORAGE").unwrap_or_else(|_| "database".to_string()),
188            },
189        })
190    }
191
192    fn load_oauth_provider(provider: &str) -> Option<OAuthProviderConfig> {
193        let client_id = env::var(format!("{}_CLIENT_ID", provider)).ok()?;
194        let client_secret = env::var(format!("{}_CLIENT_SECRET", provider)).ok()?;
195
196        Some(OAuthProviderConfig {
197            client_id,
198            client_secret,
199            redirect_uri: env::var(format!("{}_REDIRECT_URI", provider))
200                .unwrap_or_else(|_| format!("/auth/{}/callback", provider.to_lowercase())),
201            scopes: env::var(format!("{}_SCOPES", provider))
202                .unwrap_or_default()
203                .split(',')
204                .map(|s| s.trim().to_string())
205                .filter(|s| !s.is_empty())
206                .collect(),
207        })
208    }
209
210    /// Convert to AuthConfig
211    pub fn to_auth_config(&self) -> super::AuthConfig {
212        let mut config = super::AuthConfig::new()
213            .token_lifetime(Duration::from_secs(self.jwt.access_token_ttl_seconds))
214            .refresh_token_lifetime(Duration::from_secs(self.jwt.refresh_token_ttl_seconds))
215            .issuer(&self.jwt.issuer)
216            .audience(&self.jwt.audience)
217            .secret(&self.jwt.secret_key)
218            .security(self.to_security_config());
219
220        config.storage = self.primary_storage_config();
221        config.enable_multi_factor = self.security.require_mfa;
222        config.rate_limiting = super::RateLimitConfig {
223            enabled: self.security.rate_limit_requests_per_minute > 0,
224            max_requests: self.security.rate_limit_requests_per_minute,
225            window: Duration::from_secs(60),
226            burst: (self.security.rate_limit_requests_per_minute / 10).max(1),
227        };
228        config.audit.enabled = self.logging.audit_enabled;
229        config
230    }
231
232    /// Convert to SecurityConfig
233    pub fn to_security_config(&self) -> SecurityConfig {
234        let mut config = SecurityConfig::default();
235        config.min_password_length = self.security.password_min_length;
236        config.require_password_complexity = self.security.password_require_special_chars;
237        config.secret_key = Some(self.jwt.secret_key.clone());
238        config.session_timeout = Duration::from_secs(self.security.session_timeout_hours * 3600);
239        config
240    }
241
242    /// Build an initialized AuthFramework using the configured storage backend.
243    pub async fn build_auth_framework(&self) -> crate::errors::Result<crate::AuthFramework> {
244        let auth_config = self.to_auth_config();
245        let pool_size = self.primary_storage_pool_size();
246
247        let mut framework = crate::AuthFramework::new(auth_config.clone());
248        let storage =
249            crate::storage::factory::build_storage_backend(&auth_config.storage, pool_size).await?;
250        framework.replace_storage(storage);
251        framework.initialize().await?;
252        Ok(framework)
253    }
254
255    pub(crate) fn primary_storage_config(&self) -> super::StorageConfig {
256        let database_url = self.database.url.trim();
257
258        if database_url.starts_with("postgres://") || database_url.starts_with("postgresql://") {
259            #[cfg(feature = "postgres-storage")]
260            {
261                return super::StorageConfig::Postgres {
262                    connection_string: database_url.to_string(),
263                    table_prefix: "auth_".to_string(),
264                };
265            }
266
267            #[cfg(not(feature = "postgres-storage"))]
268            {
269                return super::StorageConfig::Custom(
270                    "postgres-storage feature is required for PostgreSQL DATABASE_URL".to_string(),
271                );
272            }
273        }
274
275        if database_url.starts_with("mysql://") {
276            #[cfg(feature = "mysql-storage")]
277            {
278                return super::StorageConfig::MySQL {
279                    connection_string: database_url.to_string(),
280                    table_prefix: "auth_".to_string(),
281                };
282            }
283
284            #[cfg(not(feature = "mysql-storage"))]
285            {
286                return super::StorageConfig::Custom(
287                    "mysql-storage feature is required for MySQL DATABASE_URL".to_string(),
288                );
289            }
290        }
291
292        if database_url.starts_with("sqlite:") {
293            #[cfg(feature = "sqlite-storage")]
294            {
295                return super::StorageConfig::Sqlite {
296                    connection_string: database_url.to_string(),
297                };
298            }
299
300            #[cfg(not(feature = "sqlite-storage"))]
301            {
302                return super::StorageConfig::Custom(
303                    "sqlite-storage feature is required for SQLite DATABASE_URL".to_string(),
304                );
305            }
306        }
307
308        super::StorageConfig::Memory
309    }
310
311    fn primary_storage_pool_size(&self) -> Option<u32> {
312        let database_url = self.database.url.trim();
313        if database_url.starts_with("postgres://")
314            || database_url.starts_with("postgresql://")
315            || database_url.starts_with("mysql://")
316            || database_url.starts_with("sqlite:")
317        {
318            return Some(self.database.max_connections);
319        }
320
321        None
322    }
323}
324
325#[derive(Debug, thiserror::Error)]
326pub enum ConfigError {
327    #[error("Missing environment variable: {0}")]
328    MissingEnvVar(&'static str),
329    #[error("Invalid value for: {0}")]
330    InvalidValue(&'static str),
331    #[error("Configuration validation error: {0}")]
332    Validation(String),
333}
334
335/// Configuration builder for easy setup
336pub struct ConfigBuilder {
337    config: AppConfig,
338}
339
340impl ConfigBuilder {
341    pub fn new() -> Self {
342        Self {
343            config: AppConfig::from_env().unwrap_or_else(|_| AppConfig::default()),
344        }
345    }
346
347    pub fn with_database_url(mut self, url: impl Into<String>) -> Self {
348        self.config.database.url = url.into();
349        self
350    }
351
352    pub fn with_database_max_connections(mut self, max_connections: u32) -> Self {
353        self.config.database.max_connections = max_connections;
354        self
355    }
356
357    pub fn with_database_min_connections(mut self, min_connections: u32) -> Self {
358        self.config.database.min_connections = min_connections;
359        self
360    }
361
362    pub fn with_database_connect_timeout(mut self, seconds: u64) -> Self {
363        self.config.database.connect_timeout_seconds = seconds;
364        self
365    }
366
367    pub fn with_jwt_secret(mut self, secret: impl Into<String>) -> Self {
368        self.config.jwt.secret_key = secret.into();
369        self
370    }
371
372    pub fn with_jwt_issuer(mut self, issuer: impl Into<String>) -> Self {
373        self.config.jwt.issuer = issuer.into();
374        self
375    }
376
377    pub fn with_jwt_audience(mut self, audience: impl Into<String>) -> Self {
378        self.config.jwt.audience = audience.into();
379        self
380    }
381
382    pub fn with_access_token_ttl_seconds(mut self, ttl_seconds: u64) -> Self {
383        self.config.jwt.access_token_ttl_seconds = ttl_seconds;
384        self
385    }
386
387    pub fn with_refresh_token_ttl_seconds(mut self, ttl_seconds: u64) -> Self {
388        self.config.jwt.refresh_token_ttl_seconds = ttl_seconds;
389        self
390    }
391
392    pub fn with_redis_url(mut self, url: impl Into<String>) -> Self {
393        self.config.redis = Some(RedisConfig {
394            url: url.into(),
395            pool_size: 10,
396        });
397        self
398    }
399
400    pub fn with_redis_pool_size(mut self, pool_size: u32) -> Self {
401        let redis = self.config.redis.get_or_insert(RedisConfig {
402            url: "redis://127.0.0.1:6379".to_string(),
403            pool_size: 10,
404        });
405        redis.pool_size = pool_size;
406        self
407    }
408
409    /// Set password policy constraints.
410    ///
411    /// # Arguments
412    /// * `min_length` - minimum password length (typically 8+)
413    /// * `require_special` - require at least one special character (!@#$%^&* etc.)
414    pub fn with_password_policy(mut self, min_length: usize, require_special: bool) -> Self {
415        self.config.security.password_min_length = min_length;
416        self.config.security.password_require_special_chars = require_special;
417        self
418    }
419
420    pub fn with_rate_limit_requests_per_minute(mut self, requests: u32) -> Self {
421        self.config.security.rate_limit_requests_per_minute = requests;
422        self
423    }
424
425    pub fn with_session_timeout_hours(mut self, hours: u64) -> Self {
426        self.config.security.session_timeout_hours = hours;
427        self
428    }
429
430    pub fn with_require_mfa(mut self, require_mfa: bool) -> Self {
431        self.config.security.require_mfa = require_mfa;
432        self
433    }
434
435    pub fn with_log_level(mut self, level: impl Into<String>) -> Self {
436        self.config.logging.level = level.into();
437        self
438    }
439
440    pub fn build(self) -> AppConfig {
441        self.config
442    }
443}
444
445impl Default for AppConfig {
446    fn default() -> Self {
447        Self {
448            database: DatabaseConfig {
449                url: "postgresql://localhost/auth_framework".to_string(),
450                max_connections: 10,
451                min_connections: 1,
452                connect_timeout_seconds: 30,
453            },
454            redis: None,
455            jwt: JwtConfig {
456                secret_key: "development-only-secret-change-in-production".to_string(),
457                issuer: "auth-framework".to_string(),
458                audience: "api".to_string(),
459                access_token_ttl_seconds: 3600,
460                refresh_token_ttl_seconds: 86400 * 7,
461            },
462            oauth: OAuthConfig {
463                providers: std::collections::HashMap::new(),
464            },
465            security: SecuritySettings {
466                password_min_length: 8,
467                password_require_special_chars: true,
468                rate_limit_requests_per_minute: 60,
469                session_timeout_hours: 24,
470                max_concurrent_sessions: 5,
471                require_mfa: false,
472            },
473            logging: LoggingConfig {
474                level: "info".to_string(),
475                audit_enabled: true,
476                audit_storage: "database".to_string(),
477            },
478        }
479    }
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    /// Every value set through the builder must appear unchanged in the built config.
    #[test]
    fn test_config_builder() {
        let builder = ConfigBuilder::new()
            .with_database_url("postgresql://test")
            .with_database_max_connections(25)
            .with_jwt_secret("test-secret")
            .with_jwt_issuer("issuer")
            .with_jwt_audience("audience")
            .with_rate_limit_requests_per_minute(120);
        let built = builder.build();

        let database = &built.database;
        assert_eq!(database.url, "postgresql://test");
        assert_eq!(database.max_connections, 25);

        let jwt = &built.jwt;
        assert_eq!(jwt.secret_key, "test-secret");
        assert_eq!(jwt.issuer, "issuer");
        assert_eq!(jwt.audience, "audience");

        assert_eq!(built.security.rate_limit_requests_per_minute, 120);
    }
}