use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Top-level configuration for the HTTP server.
///
/// Groups the listener settings with the rate-limit, WebSocket, security and
/// performance sub-configurations. Call [`HttpConfig::validate`] before use.
///
/// `Debug` is implemented by hand rather than derived so that `jwt_secret` is
/// redacted: a derived `Debug` would print the secret verbatim whenever the
/// config is logged with `{:?}`.
#[derive(Clone, Serialize, Deserialize)]
pub struct HttpConfig {
    /// Address the server binds to (e.g. `"127.0.0.1"` or `"0.0.0.0"`).
    pub bind_address: String,
    /// TCP port to listen on.
    pub port: u16,
    /// URL of the Smith service (default `tcp://127.0.0.1:7878`).
    pub smith_service_url: String,
    /// JWT signing secret; `validate` requires `len() >= 32`. Never shown by `Debug`.
    pub jwt_secret: String,
    /// Whether CORS is enabled.
    pub cors_enabled: bool,
    pub rate_limit: RateLimitConfig,
    pub websocket: WebSocketConfig,
    pub security: SecurityConfig,
    pub performance: PerformanceConfig,
}

impl std::fmt::Debug for HttpConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HttpConfig")
            .field("bind_address", &self.bind_address)
            .field("port", &self.port)
            .field("smith_service_url", &self.smith_service_url)
            // Redacted so secrets never reach logs via `{:?}`.
            .field("jwt_secret", &"<redacted>")
            .field("cors_enabled", &self.cors_enabled)
            .field("rate_limit", &self.rate_limit)
            .field("websocket", &self.websocket)
            .field("security", &self.security)
            .field("performance", &self.performance)
            .finish()
    }
}
/// Rate-limiting settings; `validate` requires every field to be non-zero.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RateLimitConfig {
    /// Request budget per minute (default 1000).
    pub requests_per_minute: u32,
    /// Burst allowance (default 100).
    pub burst_size: u32,
    /// WebSocket message budget per minute (default 6000).
    pub websocket_messages_per_minute: u32,
    /// Connection cap per client IP (default 10).
    pub max_connections_per_ip: u32,
}
/// WebSocket transport settings.
///
/// Every `Duration` field is (de)serialized as an integer number of
/// milliseconds through the `duration_serde` adapter.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct WebSocketConfig {
    /// Largest accepted message in bytes (default 64 KiB; `validate` allows 1 KiB to 100 MiB).
    pub max_message_size: usize,
    /// Ping interval (default 30 s).
    #[serde(with = "duration_serde")]
    pub ping_interval: Duration,
    /// Connection timeout (default 300 s).
    #[serde(with = "duration_serde")]
    pub connection_timeout: Duration,
    /// Maximum number of WebSocket connections (default 1000).
    pub max_connections: usize,
    /// Event buffer capacity (default 1000).
    pub event_buffer_size: usize,
    /// Heartbeat interval (default 10 s).
    #[serde(with = "duration_serde")]
    pub heartbeat_interval: Duration,
}
/// Authentication and transport-security settings.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SecurityConfig {
    /// JWT lifetime, serialized as milliseconds (default 24 h; `validate` warns above 7 days).
    #[serde(with = "duration_serde")]
    pub jwt_expiration: Duration,
    /// Whether WebSocket connections require authentication (default `false`).
    pub require_auth_websocket: bool,
    /// Whether API requests require authentication (default `false`).
    pub require_auth_api: bool,
    /// Whether only HTTPS is allowed (default `false`).
    pub https_only: bool,
    /// Trusted proxies: IP addresses or the literal `"localhost"` (checked by `validate`).
    pub trusted_proxies: Vec<String>,
}
/// Throughput and resource-limit tuning.
///
/// `Duration` fields are (de)serialized as integer milliseconds via `duration_serde`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct PerformanceConfig {
    /// Number of events per batch (default 50).
    pub event_batch_size: usize,
    /// Batch flush timeout (default 10 ms; `validate` warns above 100 ms).
    #[serde(with = "duration_serde")]
    pub event_batch_timeout: Duration,
    /// Size of the Smith connection pool (default 10).
    pub smith_connection_pool_size: usize,
    /// Per-request timeout (default 30 s).
    #[serde(with = "duration_serde")]
    pub request_timeout: Duration,
    /// Whether compression is enabled (default `true`).
    pub enable_compression: bool,
    /// Largest accepted request body in bytes (default 16 MiB; `validate` requires >= 1 KiB).
    pub max_request_size: usize,
}
impl Default for HttpConfig {
fn default() -> Self {
Self {
bind_address: "127.0.0.1".to_string(),
port: 3000,
smith_service_url: "tcp://127.0.0.1:7878".to_string(),
jwt_secret: "dev-secret-change-in-production-secure-key".to_string(),
cors_enabled: false,
rate_limit: RateLimitConfig::default(),
websocket: WebSocketConfig::default(),
security: SecurityConfig::default(),
performance: PerformanceConfig::default(),
}
}
}
impl Default for RateLimitConfig {
    /// Baseline limits: 1000 requests/min with a burst of 100, 6000 WebSocket
    /// messages/min, and at most 10 connections per IP.
    fn default() -> Self {
        let requests_per_minute = 1_000;
        let burst_size = 100;
        let websocket_messages_per_minute = 6_000;
        let max_connections_per_ip = 10;
        Self {
            requests_per_minute,
            burst_size,
            websocket_messages_per_minute,
            max_connections_per_ip,
        }
    }
}
impl Default for WebSocketConfig {
    /// 64 KiB messages, 30 s pings, a 5 min connection timeout, 1000
    /// connections, a 1000-entry event buffer and a 10 s heartbeat.
    fn default() -> Self {
        const KIB: usize = 1024;
        let five_minutes = Duration::from_secs(5 * 60);
        Self {
            max_message_size: 64 * KIB,
            ping_interval: Duration::from_secs(30),
            connection_timeout: five_minutes,
            max_connections: 1_000,
            event_buffer_size: 1_000,
            heartbeat_interval: Duration::from_secs(10),
        }
    }
}
impl Default for SecurityConfig {
    /// Permissive defaults:
    /// - JWT lifetime of one day;
    /// - no authentication required;
    /// - HTTPS not enforced;
    /// - no trusted proxies.
    fn default() -> Self {
        const ONE_DAY: Duration = Duration::from_secs(24 * 60 * 60);
        Self {
            jwt_expiration: ONE_DAY,
            require_auth_websocket: false,
            require_auth_api: false,
            https_only: false,
            trusted_proxies: Vec::new(),
        }
    }
}
impl Default for PerformanceConfig {
    /// Defaults:
    /// - batches of 50 events flushed after 10 ms;
    /// - a pool of 10 Smith connections;
    /// - a 30 s request timeout;
    /// - compression enabled;
    /// - a 16 MiB request-size cap.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        let batch_flush = Duration::from_millis(10);
        let per_request = Duration::from_secs(30);
        Self {
            event_batch_size: 50,
            event_batch_timeout: batch_flush,
            smith_connection_pool_size: 10,
            request_timeout: per_request,
            enable_compression: true,
            max_request_size: 16 * MIB,
        }
    }
}
impl HttpConfig {
    /// Validates this configuration and all nested sub-configurations.
    ///
    /// Port `0` is accepted, so the config returned by [`HttpConfig::testing`]
    /// (which uses port 0) passes validation. Every other port must be in
    /// `1024..=65535`. While the built-in development JWT secret is still in
    /// use, a warning is logged but no error is returned.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - `bind_address` or `smith_service_url` is empty;
    /// - `port` is in `1..=1023`;
    /// - `jwt_secret` is shorter than 32 bytes;
    /// - the rate-limit, WebSocket, security or performance section fails its
    ///   own `validate`.
    pub fn validate(&self) -> Result<()> {
        if self.bind_address.is_empty() {
            return Err(anyhow::anyhow!("Bind address cannot be empty"));
        }
        // `testing()` sets port 0 (presumably so the OS picks a free port);
        // rejecting 0 here made that preset fail its own validation.
        if self.port != 0 && self.port < 1024 {
            return Err(anyhow::anyhow!(
                "Port must be 0 or between 1024 and 65535, got: {}",
                self.port
            ));
        }
        if self.smith_service_url.is_empty() {
            return Err(anyhow::anyhow!("Smith service URL cannot be empty"));
        }
        // The placeholder secret is long enough to pass the length check
        // below, so this warning is the only signal that it is still in use.
        if self.jwt_secret.contains("dev-secret-change-in-production") {
            tracing::warn!("⚠️ Using default JWT secret - change this in production!");
        }
        // Note: `len()` counts bytes, not Unicode characters.
        if self.jwt_secret.len() < 32 {
            return Err(anyhow::anyhow!("JWT secret must be at least 32 characters"));
        }
        self.rate_limit.validate()?;
        self.websocket.validate()?;
        self.security.validate()?;
        self.performance.validate()?;
        Ok(())
    }

    /// Preset for local development.
    ///
    /// Differs from `Default` as follows:
    /// - binds `127.0.0.1:3000`;
    /// - enables CORS;
    /// - leaves API and WebSocket authentication off;
    /// - raises `event_batch_timeout` to 50 ms.
    ///
    /// Everything else comes from `Default`.
    pub fn development() -> Self {
        Self {
            bind_address: "127.0.0.1".to_string(),
            port: 3000,
            cors_enabled: true,
            security: SecurityConfig {
                require_auth_websocket: false,
                require_auth_api: false,
                https_only: false,
                ..Default::default()
            },
            performance: PerformanceConfig {
                event_batch_timeout: Duration::from_millis(50),
                ..Default::default()
            },
            ..Default::default()
        }
    }

    /// Preset for production deployments.
    ///
    /// Differs from `Default` as follows:
    /// - binds `0.0.0.0:3000` and disables CORS;
    /// - requires authentication for both the API and WebSockets;
    /// - sets `https_only`;
    /// - shortens the JWT lifetime to 8 h;
    /// - trusts only loopback proxies;
    /// - doubles every default rate limit;
    /// - uses larger/faster batching, a 15 s request timeout and an 8 MiB
    ///   request cap.
    ///
    /// NOTE: `jwt_secret` and `smith_service_url` are not overridden here. They
    /// come from `Default`, which means the development placeholder secret.
    /// Callers must supply a real secret; `validate` only warns about the
    /// placeholder.
    pub fn production() -> Self {
        Self {
            bind_address: "0.0.0.0".to_string(),
            port: 3000,
            cors_enabled: false,
            security: SecurityConfig {
                require_auth_websocket: true,
                require_auth_api: true,
                https_only: true,
                jwt_expiration: Duration::from_secs(8 * 60 * 60),
                trusted_proxies: vec!["127.0.0.1".to_string(), "::1".to_string()],
            },
            rate_limit: RateLimitConfig {
                requests_per_minute: 2000,
                burst_size: 200,
                websocket_messages_per_minute: 12000,
                max_connections_per_ip: 20,
            },
            performance: PerformanceConfig {
                event_batch_size: 100,
                event_batch_timeout: Duration::from_millis(5),
                smith_connection_pool_size: 20,
                request_timeout: Duration::from_secs(15),
                enable_compression: true,
                max_request_size: 8 * 1024 * 1024,
            },
            ..Default::default()
        }
    }

    /// Preset for automated tests.
    ///
    /// Differs from `Default` as follows:
    /// - binds `127.0.0.1` on port 0 (presumably an OS-assigned port; confirm
    ///   against the server's bind code);
    /// - enables CORS;
    /// - uses low rate limits;
    /// - allows only 10 WebSocket connections, with a 100-entry buffer and a
    ///   10 s timeout;
    /// - uses a 5 s request timeout and a 1 MiB request cap.
    pub fn testing() -> Self {
        Self {
            bind_address: "127.0.0.1".to_string(),
            port: 0,
            cors_enabled: true,
            rate_limit: RateLimitConfig {
                requests_per_minute: 100,
                burst_size: 50,
                websocket_messages_per_minute: 600,
                max_connections_per_ip: 5,
            },
            websocket: WebSocketConfig {
                max_connections: 10,
                event_buffer_size: 100,
                connection_timeout: Duration::from_secs(10),
                ..Default::default()
            },
            performance: PerformanceConfig {
                request_timeout: Duration::from_secs(5),
                max_request_size: 1024 * 1024,
                ..Default::default()
            },
            ..Default::default()
        }
    }
}
impl RateLimitConfig {
    /// Checks that every rate-limit knob is strictly positive.
    ///
    /// # Errors
    ///
    /// Fails on the first zero field, checked in declaration order:
    /// `requests_per_minute`, `burst_size`, `websocket_messages_per_minute`,
    /// `max_connections_per_ip`.
    pub fn validate(&self) -> Result<()> {
        anyhow::ensure!(
            self.requests_per_minute != 0,
            "Rate limit requests_per_minute must be > 0"
        );
        anyhow::ensure!(self.burst_size != 0, "Rate limit burst_size must be > 0");
        anyhow::ensure!(
            self.websocket_messages_per_minute != 0,
            "WebSocket rate limit must be > 0"
        );
        anyhow::ensure!(
            self.max_connections_per_ip != 0,
            "Max connections per IP must be > 0"
        );
        Ok(())
    }
}
impl WebSocketConfig {
    /// Validates WebSocket size limits and timing parameters.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - `max_message_size` is outside `1 KiB..=100 MiB`;
    /// - `ping_interval`, `connection_timeout` or `heartbeat_interval` is zero;
    /// - `max_connections` or `event_buffer_size` is zero.
    pub fn validate(&self) -> Result<()> {
        const MIN_MESSAGE_SIZE: usize = 1024;
        const MAX_MESSAGE_SIZE: usize = 100 * 1024 * 1024;

        if self.max_message_size < MIN_MESSAGE_SIZE {
            return Err(anyhow::anyhow!("WebSocket max message size must be >= 1KB"));
        }
        if self.max_message_size > MAX_MESSAGE_SIZE {
            return Err(anyhow::anyhow!(
                "WebSocket max message size must be <= 100MB"
            ));
        }
        // Durations use `is_zero()` rather than `as_secs() == 0`: `as_secs`
        // truncates, which would reject positive sub-second values such as
        // 500 ms even though they satisfy "> 0".
        if self.ping_interval.is_zero() {
            return Err(anyhow::anyhow!("WebSocket ping interval must be > 0"));
        }
        if self.connection_timeout.is_zero() {
            return Err(anyhow::anyhow!("WebSocket connection timeout must be > 0"));
        }
        if self.max_connections == 0 {
            return Err(anyhow::anyhow!("WebSocket max_connections must be > 0"));
        }
        if self.event_buffer_size == 0 {
            return Err(anyhow::anyhow!("WebSocket event buffer size must be > 0"));
        }
        if self.heartbeat_interval.is_zero() {
            return Err(anyhow::anyhow!("WebSocket heartbeat interval must be > 0"));
        }
        Ok(())
    }
}
impl SecurityConfig {
    /// Validates JWT lifetime and trusted-proxy entries.
    ///
    /// A JWT lifetime longer than seven days is allowed but logs a warning.
    ///
    /// # Errors
    ///
    /// Returns an error if `jwt_expiration` is under one whole second. Also
    /// fails on the first `trusted_proxies` entry that is neither a parseable
    /// IP address nor the literal `"localhost"`.
    pub fn validate(&self) -> Result<()> {
        const SEVEN_DAYS_SECS: u64 = 7 * 24 * 60 * 60;

        let expiration_secs = self.jwt_expiration.as_secs();
        anyhow::ensure!(expiration_secs != 0, "JWT expiration must be > 0");
        if expiration_secs > SEVEN_DAYS_SECS {
            tracing::warn!("JWT expiration > 7 days may be a security risk");
        }

        let first_invalid = self.trusted_proxies.iter().find(|candidate| {
            candidate.as_str() != "localhost" && candidate.parse::<std::net::IpAddr>().is_err()
        });
        if let Some(proxy) = first_invalid {
            anyhow::bail!("Invalid trusted proxy address: {}", proxy);
        }
        Ok(())
    }
}
impl PerformanceConfig {
pub fn validate(&self) -> Result<()> {
if self.event_batch_size == 0 {
return Err(anyhow::anyhow!("Event batch size must be > 0"));
}
if self.event_batch_timeout.as_millis() == 0 {
return Err(anyhow::anyhow!("Event batch timeout must be > 0"));
}
if self.event_batch_timeout.as_millis() > 100 {
tracing::warn!(
"⚠️ Event batch timeout > 100ms may not meet sub-100ms latency requirement"
);
}
if self.smith_connection_pool_size == 0 {
return Err(anyhow::anyhow!("Smith connection pool size must be > 0"));
}
if self.request_timeout.as_secs() == 0 {
return Err(anyhow::anyhow!("Request timeout must be > 0"));
}
if self.max_request_size < 1024 {
return Err(anyhow::anyhow!("Max request size must be >= 1KB"));
}
Ok(())
}
}
mod duration_serde {
use serde::{Deserialize, Deserializer, Serializer};
use std::time::Duration;
pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_u64(duration.as_millis() as u64)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
where
D: Deserializer<'de>,
{
let millis = u64::deserialize(deserializer)?;
Ok(Duration::from_millis(millis))
}
}