use super::auth::{ApiKeyInfo, OAuthConfig};
use super::logging::LogLevel;
use super::{
AuthenticationMiddleware, LoggingMiddleware, MiddlewareChain, PerformanceMiddleware,
RateLimitMiddleware, ValidationMiddleware,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
/// Security-related middleware settings: authentication and rate limiting.
///
/// Consumed (by value) by `MiddlewareConfig::build_chain`, which turns each
/// enabled section into a middleware on the chain.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityConfig {
    /// Settings for the authentication middleware.
    pub authentication: AuthenticationConfig,
    /// Settings for the rate-limiting middleware.
    pub rate_limiting: RateLimitingConfig,
}
/// Authentication middleware settings.
///
/// `MiddlewareConfig::build_chain` only installs an authentication middleware
/// when `enabled` is `true`. If `require_auth` is `false`, a permissive
/// middleware is installed and `jwt_secret`, `api_keys` and `oauth` are not
/// used at all.
///
/// `Debug` is implemented by hand rather than derived so that `jwt_secret`
/// never ends up in logs, panic messages or debug dumps of an enclosing config.
#[derive(Clone, Serialize, Deserialize)]
pub struct AuthenticationConfig {
    /// Whether an authentication middleware is added to the chain at all.
    pub enabled: bool,
    /// When `false`, the permissive authentication middleware is used and
    /// the credentials below are ignored.
    pub require_auth: bool,
    /// Secret handed to the authentication middleware; presumably used to
    /// sign/verify JWTs (confirm against `AuthenticationMiddleware`).
    pub jwt_secret: String,
    /// Statically configured API keys.
    pub api_keys: Vec<ApiKeyConfig>,
    /// Optional OAuth 2 client settings; only used when `require_auth` is `true`.
    pub oauth: Option<OAuth2Config>,
}

impl std::fmt::Debug for AuthenticationConfig {
    /// Formats like the derived impl, except `jwt_secret` is replaced by a
    /// fixed `<redacted>` marker.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthenticationConfig")
            .field("enabled", &self.enabled)
            .field("require_auth", &self.require_auth)
            .field("jwt_secret", &"<redacted>")
            .field("api_keys", &self.api_keys)
            .field("oauth", &self.oauth)
            .finish()
    }
}
/// A statically configured API key and its metadata.
///
/// `Debug` is implemented by hand rather than derived so the secret `key`
/// value is never printed; all other fields are formatted normally.
#[derive(Clone, Serialize, Deserialize)]
pub struct ApiKeyConfig {
    /// The secret key string; used as the lookup key of the API-key table
    /// built by `MiddlewareConfig::build_chain`.
    pub key: String,
    /// Non-secret identifier for the key.
    pub key_id: String,
    /// Permission names granted to holders of this key.
    pub permissions: Vec<String>,
    /// Optional expiry as an RFC 3339 timestamp, e.g. `2030-01-01T00:00:00Z`.
    pub expires_at: Option<String>,
}

impl std::fmt::Debug for ApiKeyConfig {
    /// Formats like the derived impl, except `key` is replaced by a fixed
    /// `<redacted>` marker.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ApiKeyConfig")
            .field("key", &"<redacted>")
            .field("key_id", &self.key_id)
            .field("permissions", &self.permissions)
            .field("expires_at", &self.expires_at)
            .finish()
    }
}
/// OAuth 2 client settings.
///
/// `MiddlewareConfig::build_chain` copies these into an `OAuthConfig`
/// (`scopes` becomes its `scope` field) when authentication is required.
///
/// `Debug` is implemented by hand rather than derived so `client_secret` is
/// never printed.
#[derive(Clone, Serialize, Deserialize)]
pub struct OAuth2Config {
    /// OAuth client identifier.
    pub client_id: String,
    /// OAuth client secret.
    pub client_secret: String,
    /// Token endpoint URL.
    pub token_endpoint: String,
    /// Requested scopes.
    pub scopes: Vec<String>,
}

impl std::fmt::Debug for OAuth2Config {
    /// Formats like the derived impl, except `client_secret` is replaced by a
    /// fixed `<redacted>` marker.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuth2Config")
            .field("client_id", &self.client_id)
            .field("client_secret", &"<redacted>")
            .field("token_endpoint", &self.token_endpoint)
            .field("scopes", &self.scopes)
            .finish()
    }
}
/// Rate-limiting middleware settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitingConfig {
    /// Whether a rate-limiting middleware is added to the chain.
    pub enabled: bool,
    /// First argument to `RateLimitMiddleware::with_limits`.
    pub requests_per_minute: u32,
    /// Second argument to `RateLimitMiddleware::with_limits`.
    pub burst_limit: u32,
    /// Per-key limit overrides.
    /// NOTE(review): `MiddlewareConfig::build_chain` never reads this field, so
    /// values set here currently have no effect; the meaning of the map keys
    /// (route? client id?) is not established here — confirm before relying on it.
    pub custom_limits: Option<HashMap<String, u32>>,
}
impl Default for SecurityConfig {
    /// Returns development-oriented defaults.
    ///
    /// - Authentication is enabled but `require_auth` is off, so the chain
    ///   uses the permissive authentication middleware.
    /// - There are no API keys and no OAuth settings.
    /// - Rate limiting is enabled at 60 requests per minute with a burst limit of 10.
    ///
    /// The JWT secret is a well-known placeholder. Replace it before setting
    /// `require_auth = true`; anyone who knows this string could presumably
    /// forge tokens.
    fn default() -> Self {
        let placeholder_secret = String::from("your-secret-key-change-this-in-production");

        let authentication = AuthenticationConfig {
            enabled: true,
            require_auth: false,
            jwt_secret: placeholder_secret,
            api_keys: Vec::new(),
            oauth: None,
        };

        let rate_limiting = RateLimitingConfig {
            enabled: true,
            requests_per_minute: 60,
            burst_limit: 10,
            custom_limits: None,
        };

        Self {
            authentication,
            rate_limiting,
        }
    }
}
/// Top-level middleware configuration.
///
/// Call `MiddlewareConfig::build_chain` to turn it into a `MiddlewareChain`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MiddlewareConfig {
    /// Request logging settings.
    pub logging: LoggingConfig,
    /// Request validation settings.
    pub validation: ValidationConfig,
    /// Slow-request / performance tracking settings.
    pub performance: PerformanceConfig,
    /// Authentication and rate-limiting settings.
    pub security: SecurityConfig,
}
/// Logging middleware settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoggingConfig {
    /// Whether a logging middleware is added to the chain.
    pub enabled: bool,
    /// Log level name (e.g. `"debug"`, `"info"`, `"warn"`, `"error"`), matched
    /// case-insensitively by `MiddlewareConfig::build_chain`. Unrecognized
    /// names fall back to the info level.
    pub level: String,
}
/// Validation middleware settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationConfig {
    /// Whether a validation middleware is added to the chain.
    pub enabled: bool,
    /// Passed straight to `ValidationMiddleware::new`; what "strict" enforces
    /// is defined by that middleware (not visible here).
    pub strict_mode: bool,
}
/// Performance-tracking middleware settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceConfig {
    /// Whether a performance middleware is added to the chain.
    pub enabled: bool,
    /// Slow-request threshold in milliseconds; converted with
    /// `Duration::from_millis` and passed to `PerformanceMiddleware::with_threshold`.
    pub slow_request_threshold_ms: u64,
}
impl Default for MiddlewareConfig {
    /// Returns defaults with every section enabled.
    ///
    /// - Logging is at the `info` level.
    /// - Validation is non-strict.
    /// - The slow-request threshold is 1000 ms.
    /// - Security settings come from `SecurityConfig::default()`.
    fn default() -> Self {
        let logging = LoggingConfig {
            enabled: true,
            level: String::from("info"),
        };
        let validation = ValidationConfig {
            enabled: true,
            strict_mode: false,
        };
        let performance = PerformanceConfig {
            enabled: true,
            slow_request_threshold_ms: 1_000,
        };

        Self {
            logging,
            validation,
            performance,
            security: SecurityConfig::default(),
        }
    }
}
impl MiddlewareConfig {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn build_chain(self) -> MiddlewareChain {
let mut chain = MiddlewareChain::new();
if self.security.authentication.enabled {
let api_keys: HashMap<String, ApiKeyInfo> = self
.security
.authentication
.api_keys
.into_iter()
.map(|config| {
let expires_at = config.expires_at.and_then(|date_str| {
chrono::DateTime::parse_from_rfc3339(&date_str)
.ok()
.map(|dt| dt.with_timezone(&chrono::Utc))
});
let api_key_info = ApiKeyInfo {
key_id: config.key_id,
permissions: config.permissions,
expires_at,
};
(config.key, api_key_info)
})
.collect();
let auth_middleware = if self.security.authentication.require_auth {
if let Some(oauth_config) = self.security.authentication.oauth {
let oauth = OAuthConfig {
client_id: oauth_config.client_id,
client_secret: oauth_config.client_secret,
token_endpoint: oauth_config.token_endpoint,
scope: oauth_config.scopes,
};
AuthenticationMiddleware::with_oauth(
api_keys,
self.security.authentication.jwt_secret,
oauth,
)
} else {
AuthenticationMiddleware::new(api_keys, self.security.authentication.jwt_secret)
}
} else {
AuthenticationMiddleware::permissive()
};
chain = chain.add_middleware(auth_middleware);
}
if self.security.rate_limiting.enabled {
let rate_limit_middleware = RateLimitMiddleware::with_limits(
self.security.rate_limiting.requests_per_minute,
self.security.rate_limiting.burst_limit,
);
chain = chain.add_middleware(rate_limit_middleware);
}
if self.logging.enabled {
let log_level = match self.logging.level.to_lowercase().as_str() {
"debug" => LogLevel::Debug,
"warn" => LogLevel::Warn,
"error" => LogLevel::Error,
_ => LogLevel::Info,
};
chain = chain.add_middleware(LoggingMiddleware::new(log_level));
}
if self.validation.enabled {
chain = chain.add_middleware(ValidationMiddleware::new(self.validation.strict_mode));
}
if self.performance.enabled {
let threshold = Duration::from_millis(self.performance.slow_request_threshold_ms);
chain = chain.add_middleware(PerformanceMiddleware::with_threshold(threshold));
}
chain
}
}