use crate::middleware::Middleware;
use crate::{Request, Response};
use http::header::{HeaderName, HeaderValue};
use http::StatusCode;
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tracing::{debug, warn};
use super::Next;
/// Default per-client request budget for one rate-limit window.
///
/// Used as the initial `rate_limit_max`, which serves as both the token-bucket
/// capacity and the number of tokens refilled per window.
const DEFAULT_RATE_LIMIT: u32 = 1_000;
/// Content-Security-Policy configuration.
///
/// Each `*_src` field lists the source expressions for the directive of the
/// same name (`script_src` -> `script-src`, ...). The expressions are joined
/// with single spaces when the header is built. When `report_uri` is set, it
/// is emitted as a `report-uri` directive.
#[derive(Debug, Clone)]
pub struct CspConfig {
    /// Sources for `default-src`.
    pub default_src: Vec<String>,
    /// Sources for `script-src`.
    pub script_src: Vec<String>,
    /// Sources for `style-src`.
    pub style_src: Vec<String>,
    /// Sources for `img-src`.
    pub img_src: Vec<String>,
    /// Sources for `connect-src`.
    pub connect_src: Vec<String>,
    /// Sources for `font-src`.
    pub font_src: Vec<String>,
    /// Sources for `object-src`.
    pub object_src: Vec<String>,
    /// Optional target for the `report-uri` directive.
    pub report_uri: Option<String>,
}

impl Default for CspConfig {
    /// Baseline policy.
    ///
    /// - Most directives use `'self'`.
    /// - `style-src` also allows `'unsafe-inline'`.
    /// - `img-src` also allows `data:`.
    /// - `object-src` is `'none'`.
    /// - No report URI is set.
    fn default() -> Self {
        let same_origin = || vec![String::from("'self'")];

        Self {
            default_src: same_origin(),
            script_src: same_origin(),
            style_src: vec![String::from("'self'"), String::from("'unsafe-inline'")],
            img_src: vec![String::from("'self'"), String::from("data:")],
            connect_src: same_origin(),
            font_src: same_origin(),
            object_src: vec![String::from("'none'")],
            report_uri: None,
        }
    }
}
/// Per-client token bucket used for rate limiting.
#[derive(Debug, Clone)]
struct RateLimitBucket {
    /// Tokens currently available; each allowed request consumes one.
    tokens: u32,
    /// Instant up to which elapsed time has already been converted into tokens.
    last_refill: Instant,
    /// Upper bound on `tokens`.
    capacity: u32,
}

impl RateLimitBucket {
    /// Creates a full bucket holding `capacity` tokens.
    fn new(capacity: u32) -> Self {
        Self {
            tokens: capacity,
            last_refill: Instant::now(),
            capacity,
        }
    }

    /// Refills the bucket for the time elapsed since the last refill, then
    /// tries to take one token.
    ///
    /// Tokens accrue pro rata at `refill_rate` per `window`. If a full `window`
    /// (or more) has elapsed, the bucket is reset to `capacity`. Returns `true`
    /// when a token was consumed (request allowed), `false` when the bucket is
    /// empty.
    fn try_consume(&mut self, refill_rate: u32, window: Duration) -> bool {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill);

        if elapsed >= window {
            self.tokens = self.capacity;
            self.last_refill = now;
        } else {
            // `window` is non-zero here because `elapsed < window`.
            let tokens_to_add =
                (elapsed.as_secs_f64() / window.as_secs_f64() * f64::from(refill_rate)) as u32;
            if tokens_to_add > 0 {
                // `tokens + tokens_to_add` can exceed u32::MAX for large limits.
                self.tokens = self
                    .tokens
                    .saturating_add(tokens_to_add)
                    .min(self.capacity);
                // Advance the refill clock only by the time that paid for the whole
                // tokens just added (refill_rate > 0 since tokens_to_add > 0). The
                // fractional remainder then carries over to the next call instead
                // of being discarded, which would under-deliver the refill rate.
                let credited =
                    window.mul_f64(f64::from(tokens_to_add) / f64::from(refill_rate));
                // Clamp to guard against floating-point rounding past `now`.
                self.last_refill = (self.last_refill + credited).min(now);
            }
        }

        if self.tokens > 0 {
            self.tokens -= 1;
            true
        } else {
            false
        }
    }
}
/// Middleware that rate-limits requests per client IP and adds security
/// response headers (HSTS, CSP and a fixed set of hardening headers).
#[derive(Debug)]
pub struct SecurityMiddleware {
    /// Emit `Strict-Transport-Security` on responses.
    enable_hsts: bool,
    /// HSTS `max-age` value, in seconds.
    hsts_max_age: u32,
    /// Append `; includeSubDomains` to the HSTS value.
    hsts_include_subdomains: bool,
    /// Append `; preload` to the HSTS value.
    hsts_preload: bool,
    /// Emit `Content-Security-Policy` built from `csp_config`.
    enable_csp: bool,
    /// Directive sources used to build the CSP header.
    csp_config: CspConfig,
    /// Reject clients that exhaust their token bucket with 429.
    enable_rate_limiting: bool,
    /// Token buckets keyed by client IP, shared across requests.
    rate_limit_buckets: Arc<Mutex<HashMap<IpAddr, RateLimitBucket>>>,
    /// Bucket capacity and number of tokens refilled per window.
    rate_limit_max: u32,
    /// Length of one rate-limit window.
    rate_limit_window: Duration,
    /// Emit the fixed hardening headers (frame, sniffing, referrer, permissions, XSS).
    enable_security_headers: bool,
    /// Strip `Server` and `X-Powered-By` from responses.
    remove_server_header: bool,
}

impl Default for SecurityMiddleware {
    /// Default configuration:
    ///
    /// - HSTS on: one-year max-age, `includeSubDomains`, no preload.
    /// - Default CSP.
    /// - Rate limiting: `DEFAULT_RATE_LIMIT` requests per 60-second window.
    /// - Hardening headers on, and server-identifying headers stripped.
    fn default() -> Self {
        // One year in seconds, for the HSTS `max-age` directive.
        const ONE_YEAR_SECS: u32 = 365 * 24 * 60 * 60;

        Self {
            enable_hsts: true,
            hsts_max_age: ONE_YEAR_SECS,
            hsts_include_subdomains: true,
            hsts_preload: false,

            enable_csp: true,
            csp_config: CspConfig::default(),

            enable_rate_limiting: true,
            rate_limit_buckets: Arc::default(),
            rate_limit_max: DEFAULT_RATE_LIMIT,
            rate_limit_window: Duration::from_secs(60),

            enable_security_headers: true,
            remove_server_header: true,
        }
    }
}
impl SecurityMiddleware {
    /// Creates a middleware with the default configuration.
    ///
    /// Equivalent to `SecurityMiddleware::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Enables or disables the `Strict-Transport-Security` response header.
    pub fn with_hsts(mut self, enabled: bool) -> Self {
        self.enable_hsts = enabled;
        self
    }

    /// Sets the HSTS `max-age` (seconds) and the `includeSubDomains` / `preload` flags.
    ///
    /// This does not itself enable HSTS. If it was switched off, combine it
    /// with [`Self::with_hsts`].
    pub fn with_hsts_config(
        mut self,
        max_age: u32,
        include_subdomains: bool,
        preload: bool,
    ) -> Self {
        self.hsts_max_age = max_age;
        self.hsts_include_subdomains = include_subdomains;
        self.hsts_preload = preload;
        self
    }

    /// Enables the `Content-Security-Policy` header and replaces its configuration.
    pub fn with_csp(mut self, config: CspConfig) -> Self {
        self.enable_csp = true;
        self.csp_config = config;
        self
    }

    /// Enables per-client rate limiting at `max_requests` per `window`.
    ///
    /// `max_requests` is both the bucket capacity (burst size) and the number
    /// of tokens refilled over one `window`.
    pub fn with_rate_limit(mut self, max_requests: u32, window: Duration) -> Self {
        self.enable_rate_limiting = true;
        self.rate_limit_max = max_requests;
        self.rate_limit_window = window;
        self
    }

    /// Determines the address used as the rate-limiting key.
    ///
    /// Tries, in order:
    /// 1. The first comma-separated entry of `X-Forwarded-For`.
    /// 2. `X-Real-IP`.
    /// 3. A fallback of `127.0.0.1`.
    ///
    /// Always returns `Some`.
    ///
    /// NOTE(review): both headers are taken from the request as-is. Unless a
    /// trusted reverse proxy overwrites them, a client can rotate
    /// `X-Forwarded-For` values to evade the limit. Clients that send neither
    /// header all share the single `127.0.0.1` bucket. The socket peer address
    /// is not consulted here. Confirm whether `Request` exposes it and whether a
    /// trusted-proxy setting is needed.
    fn get_client_ip(&self, req: &Request) -> Option<IpAddr> {
        let forwarded_for = req
            .headers
            .get("x-forwarded-for")
            .and_then(|value| value.to_str().ok())
            .and_then(|list| list.split(',').next())
            .and_then(|first| first.trim().parse::<IpAddr>().ok());

        let real_ip = || {
            req.headers
                .get("x-real-ip")
                .and_then(|value| value.to_str().ok())
                .and_then(|raw| raw.parse::<IpAddr>().ok())
        };

        Some(
            forwarded_for
                .or_else(real_ip)
                .unwrap_or(IpAddr::from([127, 0, 0, 1])),
        )
    }

    /// Consumes one token from `ip`'s bucket, creating a full bucket on first use.
    ///
    /// Returns `true` if the request is allowed. Always returns `true` when
    /// rate limiting is disabled.
    fn check_rate_limit(&self, ip: IpAddr) -> bool {
        if !self.enable_rate_limiting {
            return true;
        }

        // A poisoned lock only means another thread panicked while holding it. The
        // bucket map is still structurally valid, so keep serving instead of
        // panicking on every subsequent request.
        let mut buckets = self
            .rate_limit_buckets
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        // Keys come from client-supplied headers, so without pruning the map grows
        // without bound. Before a new key would force the table to grow, drop
        // buckets idle for at least a full window. `try_consume` would reset such a
        // bucket to full capacity anyway, so recreating it later is equivalent.
        // Sweeping only when the table is full keeps the O(n) `retain` off the
        // common path.
        if buckets.len() >= buckets.capacity() && !buckets.contains_key(&ip) {
            let window = self.rate_limit_window;
            buckets.retain(|_, bucket| bucket.last_refill.elapsed() < window);
        }

        buckets
            .entry(ip)
            .or_insert_with(|| RateLimitBucket::new(self.rate_limit_max))
            .try_consume(self.rate_limit_max, self.rate_limit_window)
    }

    /// Builds the `Content-Security-Policy` header value from `csp_config`.
    ///
    /// Directives with an empty source list are omitted. Directives are joined
    /// with `"; "` in a fixed order, followed by `report-uri` when configured.
    fn build_csp_header(&self) -> String {
        let csp = &self.csp_config;
        let source_directives = [
            ("default-src", &csp.default_src),
            ("script-src", &csp.script_src),
            ("style-src", &csp.style_src),
            ("img-src", &csp.img_src),
            ("connect-src", &csp.connect_src),
            ("font-src", &csp.font_src),
            ("object-src", &csp.object_src),
        ];

        let mut directives: Vec<String> = source_directives
            .iter()
            .filter(|(_, sources)| !sources.is_empty())
            .map(|(name, sources)| format!("{} {}", name, sources.join(" ")))
            .collect();

        if let Some(report_uri) = &csp.report_uri {
            directives.push(format!("report-uri {}", report_uri));
        }

        directives.join("; ")
    }
}
#[async_trait::async_trait]
impl Middleware for SecurityMiddleware {
    /// Applies rate limiting, runs the rest of the chain, then decorates the
    /// response with the configured security headers.
    ///
    /// A request whose client bucket is exhausted is answered immediately with
    /// `429 Too Many Requests`. It never reaches `next`, and it does not receive
    /// the security headers.
    async fn handle(&self, req: Request, next: Next) -> Response {
        // Only consult the limiter when it is enabled and a client address is known.
        let throttled_ip = if self.enable_rate_limiting {
            self.get_client_ip(&req)
                .filter(|ip| !self.check_rate_limit(*ip))
        } else {
            None
        };

        if let Some(client_ip) = throttled_ip {
            warn!(ip = %client_ip, "Rate limit exceeded");
            return Response::text("Too Many Requests")
                .with_status(StatusCode::TOO_MANY_REQUESTS);
        }
        debug!("Security validations passed");

        let mut response = next.run(req).await;
        let headers = &mut response.headers;

        if self.enable_hsts {
            let mut hsts = format!("max-age={}", self.hsts_max_age);
            if self.hsts_include_subdomains {
                hsts += "; includeSubDomains";
            }
            if self.hsts_preload {
                hsts += "; preload";
            }
            // Best effort: an unrepresentable value is skipped rather than failing the response.
            if let Ok(value) = HeaderValue::from_str(&hsts) {
                headers.insert(HeaderName::from_static("strict-transport-security"), value);
            }
        }

        if self.enable_csp {
            if let Ok(value) = HeaderValue::from_str(&self.build_csp_header()) {
                headers.insert(HeaderName::from_static("content-security-policy"), value);
            }
        }

        if self.enable_security_headers {
            // Inserted in this order; each overwrites any value set downstream.
            // `x-xss-protection: 0` disables the legacy browser XSS filter.
            const HARDENING_HEADERS: [(&str, &str); 5] = [
                ("x-frame-options", "DENY"),
                ("x-content-type-options", "nosniff"),
                ("referrer-policy", "strict-origin-when-cross-origin"),
                ("permissions-policy", "geolocation=(), microphone=(), camera=()"),
                ("x-xss-protection", "0"),
            ];
            for &(name, value) in HARDENING_HEADERS.iter() {
                headers.insert(HeaderName::from_static(name), HeaderValue::from_static(value));
            }
        }

        if self.remove_server_header {
            headers.remove("server");
            headers.remove("x-powered-by");
        }

        debug!("Security headers added to response");
        response
    }
}
impl SecurityMiddleware {
pub fn for_api() -> Self {
Self::new().with_rate_limit(2000, Duration::from_secs(60)) }
pub fn for_web() -> Self {
Self::new()
.with_rate_limit(500, Duration::from_secs(60))
.with_csp(CspConfig {
default_src: vec!["'self'".to_string()],
script_src: vec!["'self'".to_string(), "'unsafe-inline'".to_string()],
style_src: vec!["'self'".to_string(), "'unsafe-inline'".to_string()],
img_src: vec![
"'self'".to_string(),
"data:".to_string(),
"https:".to_string(),
],
font_src: vec!["'self'".to_string(), "https:".to_string()],
..Default::default()
})
}
pub fn high_security() -> Self {
Self::new()
.with_rate_limit(100, Duration::from_secs(60)) .with_hsts_config(63072000, true, true) .with_csp(CspConfig {
default_src: vec!["'none'".to_string()],
script_src: vec!["'self'".to_string()],
style_src: vec!["'self'".to_string()],
img_src: vec!["'self'".to_string()],
connect_src: vec!["'self'".to_string()],
font_src: vec!["'self'".to_string()],
object_src: vec!["'none'".to_string()],
..Default::default()
})
}
pub fn for_development() -> Self {
Self::new()
.with_rate_limit(10000, Duration::from_secs(60)) .with_hsts(false) }
}