pub mod app_config;
pub mod config_manager;
pub use config_manager::{
AuthFrameworkSettings, ConfigBuilder, ConfigIntegration, ConfigManager, SessionCookieSettings,
SessionSettings,
};
use crate::errors::{AuthError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
/// Top-level authentication configuration.
///
/// Start from [`AuthConfig::new`] (or [`Default`]), adjust it with the
/// builder-style setters, and check it with [`AuthConfig::validate`].
///
/// NOTE(review): the derived `Debug` prints every field verbatim, including
/// `secret`, `security.secret_key` and any storage connection strings. Avoid
/// logging this value with `{:?}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthConfig {
/// Token lifetime. `validate` requires at least one whole second and a value
/// strictly below `refresh_token_lifetime`.
pub token_lifetime: Duration,
/// Refresh-token lifetime. `validate` requires at least one whole second.
pub refresh_token_lifetime: Duration,
/// Multi-factor authentication flag. `require_mfa` sets this field too.
pub enable_multi_factor: bool,
/// Issuer string (default `"auth-framework"`).
pub issuer: String,
/// Audience string (default `"api"`).
pub audience: String,
/// Fallback JWT secret. `validate` only consults it when
/// `security.secret_key` is `None`.
pub secret: Option<String>,
/// Storage backend selection.
pub storage: StorageConfig,
/// Rate-limiting settings.
pub rate_limiting: RateLimitConfig,
/// Password, token and cookie security settings.
pub security: SecurityConfig,
/// Audit-logging settings.
pub audit: AuditConfig,
/// Per-method settings serialized to JSON and keyed by method name. Written by
/// `method_config` and read back by `get_method_config`.
pub method_configs: HashMap<String, serde_json::Value>,
}
/// Storage backend selection for [`AuthConfig::storage`].
///
/// The `Redis`, `Postgres` and `MySQL` variants exist only when the matching
/// `redis-storage`, `postgres-storage` or `mysql-storage` feature is enabled.
///
/// NOTE(review): the derived `Debug` prints connection strings verbatim, and
/// these may embed credentials.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StorageConfig {
/// In-process storage. `AuthConfig::validate` rejects it when a production
/// environment is detected and a test environment is not.
Memory,
/// Redis backend at `url`. `AuthConfig::redis_storage` uses the prefix
/// `"auth:"`. `key_prefix` presumably namespaces stored keys; confirm in the
/// storage implementation.
#[cfg(feature = "redis-storage")]
Redis { url: String, key_prefix: String },
/// PostgreSQL backend.
#[cfg(feature = "postgres-storage")]
Postgres {
connection_string: String,
table_prefix: String,
},
/// MySQL backend. `AuthConfig::validate` logs a warning referencing
/// RUSTSEC-2023-0071 when this is selected.
#[cfg(feature = "mysql-storage")]
MySQL {
connection_string: String,
table_prefix: String,
},
/// Custom backend. How the string is interpreted is not visible in this module.
Custom(String),
}
/// Rate-limiting settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
/// Whether rate limiting is active. Production validation only logs a warning
/// when this is `false`.
pub enabled: bool,
/// Request allowance. `AuthConfig::validate` requires it to be non-zero when
/// `enabled` is `true`.
pub max_requests: u32,
/// Time window. Presumably the period over which `max_requests` is counted;
/// confirm against the limiter implementation.
pub window: Duration,
/// Burst allowance. `RateLimitConfig::new` derives it as `max_requests / 10`.
pub burst: u32,
}
/// Password, token and cookie security settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityConfig {
/// Minimum password length. `AuthConfig::validate` requires at least 4, and at
/// least 8 in production.
pub min_password_length: usize,
/// Password complexity flag. Production validation only logs a warning when
/// this is `false`.
pub require_password_complexity: bool,
/// Password hashing algorithm.
pub password_hash_algorithm: PasswordHashAlgorithm,
/// JWT signing algorithm.
pub jwt_algorithm: JwtAlgorithm,
/// Preferred JWT secret. `AuthConfig::validate` checks it before
/// `AuthConfig::secret` and the `JWT_SECRET` environment variable.
pub secret_key: Option<String>,
/// Secure-cookie flag. Must be `true` for production validation to pass.
pub secure_cookies: bool,
/// Cookie `SameSite` setting.
pub cookie_same_site: CookieSameSite,
/// CSRF protection flag.
pub csrf_protection: bool,
/// Session timeout.
pub session_timeout: Duration,
}
/// Password hashing algorithm selector.
///
/// `SecurityConfig::default` and `SecurityConfig::secure` pick `Argon2`, and
/// `SecurityConfig::development` picks `Bcrypt`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PasswordHashAlgorithm {
Argon2,
Bcrypt,
Scrypt,
}
/// JWT signing algorithm selector.
///
/// Variant names follow the JWA `alg` identifiers (RFC 7518): `HS*` for
/// HMAC-SHA, `RS*` for RSA and `ES*` for ECDSA.
/// NOTE(review): the mapping to an actual signer is not visible in this module.
/// `SecurityConfig::default` and `SecurityConfig::development` use `HS256`, and
/// `SecurityConfig::secure` uses `RS256`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum JwtAlgorithm {
HS256,
HS384,
HS512,
RS256,
RS384,
RS512,
ES256,
ES384,
}
/// Cookie `SameSite` setting.
///
/// Presumably rendered as the cookie attribute of the same name; the rendering
/// code is not visible here. `SecurityConfig::secure` uses `Strict`, while the
/// default and development presets use `Lax`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CookieSameSite {
Strict,
Lax,
None,
}
/// Audit-logging settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditConfig {
/// Master switch. Must be `true` for production validation to pass.
pub enabled: bool,
/// Flag for logging successful events.
pub log_success: bool,
/// Flag for logging failed events.
pub log_failures: bool,
/// Flag for logging permission events.
pub log_permissions: bool,
/// Flag for logging token events (off by default).
pub log_tokens: bool,
/// Destination for audit records.
pub storage: AuditStorage,
}
/// Destination for audit records.
///
/// NOTE(review): the derived `Debug` prints `api_key` and `connection_string`
/// verbatim.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AuditStorage {
/// The default, as set by `AuditConfig::default`.
Tracing,
File { path: String },
Database { connection_string: String },
External { endpoint: String, api_key: String },
}
impl Default for AuthConfig {
    /// Baseline configuration.
    ///
    /// Tokens last one hour and refresh tokens seven days. MFA is off, no secret
    /// is set, and storage is in-memory. The rate-limit, security and audit
    /// sections use their own defaults, and `method_configs` starts empty.
    fn default() -> Self {
        const HOUR: u64 = 60 * 60;
        const DAY: u64 = 24 * HOUR;

        let rate_limiting = RateLimitConfig::default();
        let security = SecurityConfig::default();
        let audit = AuditConfig::default();

        Self {
            issuer: String::from("auth-framework"),
            audience: String::from("api"),
            token_lifetime: Duration::from_secs(HOUR),
            refresh_token_lifetime: Duration::from_secs(7 * DAY),
            enable_multi_factor: false,
            secret: None,
            storage: StorageConfig::Memory,
            rate_limiting,
            security,
            audit,
            method_configs: HashMap::default(),
        }
    }
}
impl Default for RateLimitConfig {
    /// Enabled, with `max_requests` 100, a 60-second `window` and `burst` 10.
    fn default() -> Self {
        let window = Duration::from_secs(60);
        Self {
            max_requests: 100,
            burst: 10,
            window,
            enabled: true,
        }
    }
}
impl Default for SecurityConfig {
    /// General-purpose preset.
    ///
    /// Uses Argon2 password hashing and HS256 JWTs. Passwords need at least 8
    /// characters and must meet complexity rules. Cookies are secure with
    /// `SameSite=Lax`, CSRF protection is on, and sessions time out after one day.
    /// No secret key is set.
    fn default() -> Self {
        const ONE_DAY: Duration = Duration::from_secs(24 * 60 * 60);
        Self {
            password_hash_algorithm: PasswordHashAlgorithm::Argon2,
            jwt_algorithm: JwtAlgorithm::HS256,
            min_password_length: 8,
            require_password_complexity: true,
            secret_key: None,
            cookie_same_site: CookieSameSite::Lax,
            secure_cookies: true,
            csrf_protection: true,
            session_timeout: ONE_DAY,
        }
    }
}
impl Default for AuditConfig {
    /// Auditing is on for successes, failures and permission events.
    ///
    /// Token logging is off, and records go to `AuditStorage::Tracing`.
    fn default() -> Self {
        let storage = AuditStorage::Tracing;
        Self {
            storage,
            log_tokens: false,
            log_permissions: true,
            log_failures: true,
            log_success: true,
            enabled: true,
        }
    }
}
impl AuthConfig {
pub fn new() -> Self {
Self::default()
}
pub fn token_lifetime(mut self, lifetime: Duration) -> Self {
self.token_lifetime = lifetime;
self
}
pub fn refresh_token_lifetime(mut self, lifetime: Duration) -> Self {
self.refresh_token_lifetime = lifetime;
self
}
pub fn enable_multi_factor(mut self, enabled: bool) -> Self {
self.enable_multi_factor = enabled;
self
}
pub fn issuer(mut self, issuer: impl Into<String>) -> Self {
self.issuer = issuer.into();
self
}
pub fn audience(mut self, audience: impl Into<String>) -> Self {
self.audience = audience.into();
self
}
pub fn secret(mut self, secret: impl Into<String>) -> Self {
self.secret = Some(secret.into());
self
}
pub fn require_mfa(mut self, required: bool) -> Self {
self.enable_multi_factor = required;
self
}
pub fn enable_caching(self, _enabled: bool) -> Self {
self
}
pub fn max_failed_attempts(self, _max: u32) -> Self {
self
}
pub fn enable_rbac(self, _enabled: bool) -> Self {
self
}
pub fn enable_security_audit(self, _enabled: bool) -> Self {
self
}
pub fn enable_middleware(self, _enabled: bool) -> Self {
self
}
pub fn storage(mut self, storage: StorageConfig) -> Self {
self.storage = storage;
self
}
#[cfg(feature = "redis-storage")]
pub fn redis_storage(mut self, url: impl Into<String>) -> Self {
self.storage = StorageConfig::Redis {
url: url.into(),
key_prefix: "auth:".to_string(),
};
self
}
pub fn rate_limiting(mut self, config: RateLimitConfig) -> Self {
self.rate_limiting = config;
self
}
pub fn security(mut self, config: SecurityConfig) -> Self {
self.security = config;
self
}
pub fn audit(mut self, config: AuditConfig) -> Self {
self.audit = config;
self
}
pub fn method_config(
mut self,
method_name: impl Into<String>,
config: impl Serialize,
) -> Result<Self> {
let value = serde_json::to_value(config)
.map_err(|e| AuthError::config(format!("Failed to serialize method config: {e}")))?;
self.method_configs.insert(method_name.into(), value);
Ok(self)
}
pub fn get_method_config<T>(&self, method_name: &str) -> Result<Option<T>>
where
T: for<'de> Deserialize<'de>,
{
if let Some(value) = self.method_configs.get(method_name) {
let config = serde_json::from_value(value.clone()).map_err(|e| {
AuthError::config(format!("Failed to deserialize method config: {e}"))
})?;
Ok(Some(config))
} else {
Ok(None)
}
}
pub fn validate(&self) -> Result<()> {
if self.token_lifetime.as_secs() == 0 {
return Err(AuthError::config("Token lifetime must be greater than 0"));
}
if self.refresh_token_lifetime.as_secs() == 0 {
return Err(AuthError::config(
"Refresh token lifetime must be greater than 0",
));
}
if self.refresh_token_lifetime <= self.token_lifetime {
return Err(AuthError::config(
"Refresh token lifetime must be greater than token lifetime",
));
}
self.validate_jwt_secret()?;
if self.security.min_password_length < 4 {
return Err(AuthError::config(
"Minimum password length must be at least 4 characters",
));
}
if self.is_production_environment() && !self.is_test_environment() {
self.validate_production_security()?;
}
if self.rate_limiting.enabled && self.rate_limiting.max_requests == 0 {
return Err(AuthError::config(
"Rate limit max requests must be greater than 0 when enabled",
));
}
self.validate_storage_config()?;
Ok(())
}
fn validate_jwt_secret(&self) -> Result<()> {
let env_secret = std::env::var("JWT_SECRET").ok();
let jwt_secret = self
.security
.secret_key
.as_ref()
.or(self.secret.as_ref())
.or(env_secret.as_ref());
if let Some(secret) = jwt_secret {
if secret.len() < 32 {
return Err(AuthError::config(
"JWT secret must be at least 32 characters for security. \
Generate with: openssl rand -base64 32",
));
}
if !self.is_test_environment()
&& (secret.contains("secret")
|| secret.contains("password")
|| secret.contains("123"))
{
return Err(AuthError::config(
"JWT secret appears to contain common words or patterns. \
Use a cryptographically secure random string.",
));
}
if secret.len() < 44
&& secret
.chars()
.all(|c| c.is_alphanumeric() || c == '+' || c == '/' || c == '=')
{
tracing::warn!(
"JWT secret may be too short for optimal security. \
Consider using at least 44 characters (32 bytes base64-encoded)."
);
}
} else if self.is_production_environment() {
return Err(AuthError::config(
"JWT secret is required for production environments. \
Set JWT_SECRET environment variable or configure security.secret_key",
));
}
Ok(())
}
fn validate_production_security(&self) -> Result<()> {
if self.security.min_password_length < 8 {
return Err(AuthError::config(
"Production environments require minimum password length of 8 characters",
));
}
if !self.security.require_password_complexity {
tracing::warn!("Production deployment should enable password complexity requirements");
}
if !self.security.secure_cookies {
return Err(AuthError::config(
"Production environments must use secure cookies (HTTPS required)",
));
}
if !self.rate_limiting.enabled {
tracing::warn!("Production deployment should enable rate limiting for security");
}
if !self.audit.enabled {
return Err(AuthError::config(
"Production environments require audit logging for compliance",
));
}
Ok(())
}
fn validate_storage_config(&self) -> Result<()> {
match &self.storage {
StorageConfig::Memory => {
if self.is_production_environment() && !self.is_test_environment() {
return Err(AuthError::config(
"Memory storage is not suitable for production environments. \
Use PostgreSQL, Redis, or MySQL storage.",
));
}
}
#[cfg(feature = "mysql-storage")]
StorageConfig::MySQL { .. } => {
tracing::warn!(
"MySQL storage has known RSA vulnerability (RUSTSEC-2023-0071). \
Consider using PostgreSQL for enhanced security."
);
}
_ => {} }
Ok(())
}
fn is_production_environment(&self) -> bool {
if let Ok(env) = std::env::var("ENVIRONMENT")
&& (env.to_lowercase() == "production" || env.to_lowercase() == "prod")
{
return true;
}
if let Ok(env) = std::env::var("ENV")
&& (env.to_lowercase() == "production" || env.to_lowercase() == "prod")
{
return true;
}
if let Ok(env) = std::env::var("NODE_ENV")
&& env.to_lowercase() == "production"
{
return true;
}
if let Ok(env) = std::env::var("RUST_ENV")
&& env.to_lowercase() == "production"
{
return true;
}
if std::env::var("KUBERNETES_SERVICE_HOST").is_ok() {
return true;
}
if std::env::var("DOCKER_CONTAINER").is_ok() {
return true;
}
false
}
fn is_test_environment(&self) -> bool {
cfg!(test)
|| std::thread::current()
.name()
.is_some_and(|name| name.contains("test"))
|| std::env::var("RUST_TEST").is_ok()
|| std::env::var("ENVIRONMENT").as_deref() == Ok("test")
|| std::env::var("ENV").as_deref() == Ok("test")
|| std::env::args().any(|arg| arg.contains("test"))
}
}
impl RateLimitConfig {
    /// Creates an enabled configuration with the given `max_requests` and `window`.
    ///
    /// `burst` is one tenth of `max_requests` using integer division.
    /// NOTE(review): this makes `burst` zero when `max_requests < 10`; confirm
    /// that is intended.
    pub fn new(max_requests: u32, window: Duration) -> Self {
        let burst = max_requests / 10;
        Self {
            window,
            max_requests,
            burst,
            enabled: true,
        }
    }

    /// Returns the default settings with `enabled` set to `false`.
    pub fn disabled() -> Self {
        let defaults = Self::default();
        Self {
            enabled: false,
            ..defaults
        }
    }
}
impl SecurityConfig {
    /// Hardened preset.
    ///
    /// Uses Argon2 password hashing and RS256 JWTs. Passwords need at least 12
    /// characters and must meet complexity rules. Cookies are secure with
    /// `SameSite=Strict`, CSRF protection is on, and sessions time out after
    /// 8 hours. No secret key is set.
    pub fn secure() -> Self {
        const EIGHT_HOURS: Duration = Duration::from_secs(8 * 60 * 60);
        Self {
            jwt_algorithm: JwtAlgorithm::RS256,
            password_hash_algorithm: PasswordHashAlgorithm::Argon2,
            min_password_length: 12,
            require_password_complexity: true,
            cookie_same_site: CookieSameSite::Strict,
            secure_cookies: true,
            csrf_protection: true,
            session_timeout: EIGHT_HOURS,
            secret_key: None,
        }
    }

    /// Relaxed preset.
    ///
    /// Uses Bcrypt password hashing and HS256 JWTs. Passwords need at least 6
    /// characters and have no complexity rules. Cookies are insecure with
    /// `SameSite=Lax`, CSRF protection is off, and sessions time out after one
    /// day. No secret key is set.
    ///
    /// `AuthConfig::validate` rejects these insecure cookies in a detected
    /// production environment.
    pub fn development() -> Self {
        const ONE_DAY: Duration = Duration::from_secs(24 * 60 * 60);
        Self {
            jwt_algorithm: JwtAlgorithm::HS256,
            password_hash_algorithm: PasswordHashAlgorithm::Bcrypt,
            min_password_length: 6,
            require_password_complexity: false,
            cookie_same_site: CookieSameSite::Lax,
            secure_cookies: false,
            csrf_protection: false,
            session_timeout: ONE_DAY,
            secret_key: None,
        }
    }
}