use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::path::PathBuf;
use fortress_core::config::Config as CoreConfig;
/// Interface the server binds to when none is configured (all IPv4 interfaces).
pub const DEFAULT_HOST: &str = "0.0.0.0";
/// TCP port used when none is configured.
pub const DEFAULT_PORT: u16 = 8080;
/// Default request-body cap in bytes: 10 MiB (10 * 2^20).
pub const DEFAULT_MAX_BODY_SIZE: usize = 10 << 20;
/// Default request timeout; NOTE(review): unit presumably seconds — confirm with the consumer.
pub const DEFAULT_REQUEST_TIMEOUT: u64 = 30;
/// Default CORS allow-list: the wildcard origin only.
pub const DEFAULT_CORS_ORIGINS: &[&str] = &["*"];
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
pub network: NetworkConfig,
pub security: SecurityConfig,
pub core: CoreConfig,
pub features: FeatureFlags,
pub logging: LoggingConfig,
pub metrics: MetricsConfig,
pub storage: CoreConfig,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
network: NetworkConfig::default(),
security: SecurityConfig::default(),
core: CoreConfig::default(),
features: FeatureFlags::default(),
logging: LoggingConfig::default(),
metrics: MetricsConfig::default(),
storage: CoreConfig::default(),
}
}
}
/// Listener configuration for the HTTP server.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct NetworkConfig {
    /// Address to bind; combined with `port` by `bind_address`.
    pub host: String,
    /// TCP port to listen on.
    pub port: u16,
    /// Maximum accepted request-body size in bytes.
    pub max_body_size: usize,
    /// Request timeout; NOTE(review): unit presumably seconds — confirm.
    pub request_timeout: u64,
    /// Keep-alive duration; NOTE(review): unit presumably seconds — confirm.
    pub keep_alive: u64,
    /// Upper bound on concurrent connections.
    pub max_connections: usize,
}

impl Default for NetworkConfig {
    /// Uses the module-level `DEFAULT_*` constants plus fixed keep-alive and connection limits.
    fn default() -> Self {
        Self {
            host: String::from(DEFAULT_HOST),
            port: DEFAULT_PORT,
            max_body_size: DEFAULT_MAX_BODY_SIZE,
            request_timeout: DEFAULT_REQUEST_TIMEOUT,
            keep_alive: 75,
            max_connections: 10_000,
        }
    }
}
impl NetworkConfig {
    /// Combines `host` and `port` into the socket address the server should bind.
    ///
    /// `host` must be an IP literal. Accepted forms are IPv4 (`"0.0.0.0"`), bare IPv6
    /// (`"::"`, `"::1"`) or bracketed IPv6 (`"[::1]"`). Hostnames such as `"localhost"`
    /// are not resolved.
    ///
    /// # Errors
    ///
    /// Returns [`std::net::AddrParseError`] when `host` is not a valid IP literal.
    pub fn bind_address(&self) -> std::result::Result<SocketAddr, std::net::AddrParseError> {
        // `SocketAddr` only accepts IPv6 in "[ip]:port" form. Naively formatting
        // "host:port" turns an IPv6 host such as "::" into ":::8080", which never
        // parses. An IPv4 literal never contains ':', so a colon in an unbracketed
        // host means IPv6 and needs brackets added.
        let needs_brackets = self.host.contains(':') && !self.host.starts_with('[');
        let candidate = if needs_brackets {
            format!("[{}]:{}", self.host, self.port)
        } else {
            format!("{}:{}", self.host, self.port)
        };
        candidate.parse()
    }
}
/// Security-related settings: token signing, CORS, rate limiting and optional TLS.
///
/// `Debug` is implemented by hand so that `jwt_secret` is never written to logs.
/// `ServerConfig` derives `Debug`, so a derived impl here would leak the signing key
/// whenever the whole configuration is logged.
#[derive(Clone, Serialize, Deserialize)]
pub struct SecurityConfig {
    /// Secret used to sign/verify JWTs. Redacted in `Debug` output.
    pub jwt_secret: String,
    /// Token lifetime; NOTE(review): unit presumably seconds — confirm.
    pub token_expiration: u64,
    /// Cross-origin resource sharing policy.
    pub cors: CorsConfig,
    /// Request rate limiting and DDoS protection.
    pub rate_limit: RateLimitConfig,
    /// TLS material; `None` means no TLS configuration is provided.
    pub tls: Option<TlsConfig>,
}

impl std::fmt::Debug for SecurityConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SecurityConfig")
            .field("jwt_secret", &"<redacted>")
            .field("token_expiration", &self.token_expiration)
            .field("cors", &self.cors)
            .field("rate_limit", &self.rate_limit)
            .field("tls", &self.tls)
            .finish()
    }
}
impl Default for SecurityConfig {
    /// Builds security defaults with a freshly generated random JWT secret.
    ///
    /// Each call produces a different secret. NOTE(review): relying on this default
    /// presumably invalidates issued tokens on every restart and differs between
    /// instances — configure an explicit secret for real deployments.
    fn default() -> Self {
        // Generated first, matching the original field-evaluation order.
        let jwt_secret = generate_default_jwt_secret();
        Self {
            jwt_secret,
            token_expiration: 3600,
            cors: CorsConfig::default(),
            rate_limit: RateLimitConfig::default(),
            tls: None,
        }
    }
}
/// Cross-origin resource sharing policy.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CorsConfig {
    /// Origins permitted to make cross-origin requests (`"*"` = any).
    pub allowed_origins: Vec<String>,
    /// HTTP methods permitted cross-origin.
    pub allowed_methods: Vec<String>,
    /// Request headers permitted cross-origin.
    pub allowed_headers: Vec<String>,
    /// Whether credentials may accompany cross-origin requests.
    pub allow_credentials: bool,
}

impl Default for CorsConfig {
    /// Wildcard origins, the common REST verbs, a small header set, and no credentials.
    fn default() -> Self {
        /// Copies a slice of string literals into owned `String`s.
        fn owned(items: &[&str]) -> Vec<String> {
            items.iter().map(|item| (*item).to_string()).collect()
        }

        Self {
            allowed_origins: owned(DEFAULT_CORS_ORIGINS),
            allowed_methods: owned(&["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]),
            allowed_headers: owned(&["Content-Type", "Authorization", "X-Requested-With"]),
            allow_credentials: false,
        }
    }
}
/// Request rate-limiting settings.
///
/// `algorithm` and `ddos_protection` carry serde defaults, so a config that omits
/// them still deserializes. The other fields are required.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RateLimitConfig {
    /// Master switch for rate limiting.
    pub enabled: bool,
    /// Per-minute request budget.
    pub requests_per_minute: u32,
    /// Per-hour request budget.
    pub requests_per_hour: u32,
    /// Burst allowance on top of the steady rate.
    pub burst_size: u32,
    /// Limiting strategy; falls back to `default_rate_limit_algorithm()` when absent.
    #[serde(default = "default_rate_limit_algorithm")]
    pub algorithm: RateLimitAlgorithm,
    /// DDoS protection settings; falls back to `DdosProtectionConfig::default()` when absent.
    #[serde(default)]
    pub ddos_protection: DdosProtectionConfig,
}
/// Selectable rate-limiting strategy.
///
/// No serde rename attributes are applied, so each variant uses serde's default
/// representation (its Rust name). Variant order is kept stable on purpose.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum RateLimitAlgorithm {
    /// Token-bucket limiting.
    TokenBucket,
    /// Sliding-window limiting.
    SlidingWindow,
    /// Fixed-window limiting.
    FixedWindow,
    /// Leaky-bucket limiting.
    LeakyBucket,
}
/// DDoS mitigation settings nested inside `RateLimitConfig`.
///
/// NOTE(review): what a `None` threshold means (disabled vs. computed elsewhere) is
/// decided by the consumer of this struct — confirm before relying on it.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct DdosProtectionConfig {
    /// Master switch for DDoS protection.
    pub enabled: bool,
    /// Optional server-wide requests-per-second threshold.
    pub global_rps_threshold: Option<u32>,
    /// Optional per-IP requests-per-second threshold.
    pub ip_rps_threshold: Option<u32>,
    /// Optional threshold after which a client is blocked automatically.
    pub auto_block_threshold: Option<u32>,
    /// How long a block lasts, in seconds.
    pub block_duration_seconds: u64,
    /// Reputation decay rate; NOTE(review): unit/scale not visible here — confirm.
    pub reputation_decay_rate: u8,
}
fn default_rate_limit_algorithm() -> RateLimitAlgorithm {
RateLimitAlgorithm::TokenBucket
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
enabled: true,
requests_per_minute: 60,
requests_per_hour: 1000,
burst_size: 10,
algorithm: RateLimitAlgorithm::TokenBucket,
ddos_protection: DdosProtectionConfig::default(),
}
}
}
impl Default for DdosProtectionConfig {
    /// Protection enabled, no explicit thresholds, 5-minute blocks, decay rate 10.
    fn default() -> Self {
        let five_minutes: u64 = 5 * 60;
        Self {
            enabled: true,
            global_rps_threshold: None,
            ip_rps_threshold: None,
            auto_block_threshold: None,
            block_duration_seconds: five_minutes,
            reputation_decay_rate: 10,
        }
    }
}
/// Filesystem locations of TLS material.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TlsConfig {
    /// Path to the certificate file.
    pub cert_path: PathBuf,
    /// Path to the private-key file.
    pub key_path: PathBuf,
    /// Optional path to a CA bundle.
    pub ca_path: Option<PathBuf>,
}
/// OpenID Connect client settings, carried in `FeatureFlags::oidc_config`.
///
/// `Debug` is implemented by hand so that `client_secret` is redacted. Through
/// `FeatureFlags` and `ServerConfig`, both of which derive `Debug`, a derived impl
/// would print the secret whenever the configuration is logged.
#[derive(Clone, Serialize, Deserialize)]
pub struct OidcConfig {
    /// Issuer (identity provider) URL.
    pub issuer_url: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// OAuth client secret. Redacted in `Debug` output.
    pub client_secret: String,
    /// Redirect URI registered with the provider.
    pub redirect_uri: String,
    /// Scopes to request.
    pub scopes: Vec<String>,
    /// Whether PKCE is enabled.
    pub enable_pkce: bool,
}

impl std::fmt::Debug for OidcConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OidcConfig")
            .field("issuer_url", &self.issuer_url)
            .field("client_id", &self.client_id)
            .field("client_secret", &"<redacted>")
            .field("redirect_uri", &self.redirect_uri)
            .field("scopes", &self.scopes)
            .field("enable_pkce", &self.enable_pkce)
            .finish()
    }
}
/// Switches for optional server features.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FeatureFlags {
    /// Authentication.
    pub auth_enabled: bool,
    /// Audit logging.
    pub audit_enabled: bool,
    /// Metrics collection.
    pub metrics_enabled: bool,
    /// Health endpoint.
    pub health_enabled: bool,
    /// Multi-tenancy.
    pub multi_tenant: bool,
    /// Field-level encryption.
    pub field_encryption: bool,
    /// OpenID Connect login.
    pub oidc_enabled: bool,
    /// OIDC client settings. NOTE(review): nothing here ties this to `oidc_enabled`;
    /// the consumer must check both.
    pub oidc_config: Option<OidcConfig>,
}

impl Default for FeatureFlags {
    /// Everything on except OIDC, which is off and left unconfigured.
    fn default() -> Self {
        Self {
            // Core features: enabled out of the box.
            auth_enabled: true,
            audit_enabled: true,
            metrics_enabled: true,
            health_enabled: true,
            multi_tenant: true,
            field_encryption: true,
            // OIDC needs provider-specific settings, so it is opt-in.
            oidc_enabled: false,
            oidc_config: None,
        }
    }
}
/// Logging settings.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct LoggingConfig {
    /// Log level name (default `"info"`).
    pub level: String,
    /// Emit JSON-formatted records when `true`.
    pub json_format: bool,
    /// Optional log file; `None` means no file path is configured.
    pub file_path: Option<PathBuf>,
    /// Whether individual requests are logged.
    pub log_requests: bool,
}

impl Default for LoggingConfig {
    /// `info` level, plain-text output, no file, request logging on.
    fn default() -> Self {
        Self {
            level: String::from("info"),
            json_format: false,
            file_path: None,
            log_requests: true,
        }
    }
}
/// Metrics export settings.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct MetricsConfig {
    /// Whether the Prometheus exporter is enabled.
    pub prometheus_enabled: bool,
    /// HTTP path where metrics are served.
    pub metrics_path: String,
    /// Collection interval; NOTE(review): unit presumably seconds — confirm.
    pub collection_interval: u64,
}

impl Default for MetricsConfig {
    /// Prometheus on, served at `/metrics`, collection interval 60.
    fn default() -> Self {
        Self {
            prometheus_enabled: true,
            metrics_path: String::from("/metrics"),
            collection_interval: 60,
        }
    }
}
/// Produces a 64-character random secret drawn uniformly from `[a-zA-Z0-9]`.
///
/// Uses `rand`'s thread-local generator. Each call returns a new value.
fn generate_default_jwt_secret() -> String {
    use rand::Rng;

    const ALPHABET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
    const SECRET_LEN: usize = 64;

    let mut rng = rand::thread_rng();
    (0..SECRET_LEN)
        .map(|_| ALPHABET[rng.gen_range(0..ALPHABET.len())] as char)
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks a few values from `ServerConfig::default()`.
    #[test]
    fn test_default_config() {
        let cfg = ServerConfig::default();
        let (network, security, features) = (&cfg.network, &cfg.security, &cfg.features);

        assert_eq!(network.host, DEFAULT_HOST);
        assert_eq!(network.port, DEFAULT_PORT);
        assert_eq!(security.token_expiration, 3600);
        assert!(features.auth_enabled);
        assert!(features.audit_enabled);
    }

    /// The default host/port must form a parseable socket address.
    #[test]
    fn test_bind_address() {
        let bound = NetworkConfig::default()
            .bind_address()
            .expect("Default bind address should be valid");
        assert_eq!(bound.port(), DEFAULT_PORT);
    }

    /// Generated secrets have the expected length and differ between calls.
    #[test]
    fn test_jwt_secret_generation() {
        let secrets = [generate_default_jwt_secret(), generate_default_jwt_secret()];
        for secret in &secrets {
            assert_eq!(secret.len(), 64);
        }
        assert_ne!(secrets[0], secrets[1]);
    }
}