use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::time::{Duration, Instant};
/// Thresholds that decide when the traffic shield delays or rejects requests.
///
/// Start from [`TrafficShieldConfig::new`] (identical to `Default`) and adjust
/// individual limits with the chainable `with_*` methods.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct TrafficShieldConfig {
    /// Scheduler lag at or above which requests are rejected.
    pub max_event_loop_lag: Duration,
    /// Database probe latency at or above which requests are rejected.
    pub max_db_latency: Duration,
    /// Number of in-flight requests at or above which new requests are rejected.
    pub max_active_requests: usize,
    /// Whether the database latency probe runs and its reading is taken into account.
    pub enable_db_probe: bool,
}

impl Default for TrafficShieldConfig {
    /// 100 ms lag limit, 500 ms database latency limit, 1000 concurrent
    /// requests, database probe enabled.
    fn default() -> Self {
        let max_event_loop_lag = Duration::from_millis(100);
        let max_db_latency = Duration::from_millis(500);
        Self {
            max_event_loop_lag,
            max_db_latency,
            max_active_requests: 1000,
            enable_db_probe: true,
        }
    }
}

impl TrafficShieldConfig {
    /// Returns the default configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the scheduler-lag limit.
    pub fn with_max_event_loop_lag(self, lag: Duration) -> Self {
        Self {
            max_event_loop_lag: lag,
            ..self
        }
    }

    /// Replaces the database-latency limit.
    pub fn with_max_db_latency(self, latency: Duration) -> Self {
        Self {
            max_db_latency: latency,
            ..self
        }
    }

    /// Replaces the limit on concurrently running requests.
    pub fn with_max_active_requests(self, limit: usize) -> Self {
        Self {
            max_active_requests: limit,
            ..self
        }
    }

    /// Turns the database latency probe on or off.
    pub fn with_db_probe(self, enable: bool) -> Self {
        Self {
            enable_db_probe: enable,
            ..self
        }
    }
}
/// Load gauges together with the thresholds they are compared against.
///
/// Every gauge lives behind an `Arc`, so clones of a `TrafficShield` observe
/// and update the same counters.
#[derive(Clone)]
pub struct TrafficShield {
/// Limits that `backpressure_middleware` compares the gauges below against.
pub(crate) config: TrafficShieldConfig,
/// Latest scheduler-lag sample in milliseconds, written by the lag monitor task.
event_loop_lag_ms: Arc<AtomicU64>,
/// Latest database probe result in milliseconds, written by the probe task.
db_latency_ms: Arc<AtomicU64>,
/// Number of requests currently inside `backpressure_middleware`.
active_requests: Arc<AtomicUsize>,
}
impl TrafficShield {
pub fn new(config: TrafficShieldConfig) -> Self {
let shield = Self {
config,
event_loop_lag_ms: Arc::new(AtomicU64::new(0)),
db_latency_ms: Arc::new(AtomicU64::new(0)),
active_requests: Arc::new(AtomicUsize::new(0)),
};
shield.spawn_monitors();
shield
}
fn spawn_monitors(&self) {
let lag_ms = self.event_loop_lag_ms.clone();
tokio::spawn(async move {
let interval = Duration::from_millis(100);
loop {
let start = Instant::now();
tokio::time::sleep(interval).await;
let elapsed = start.elapsed();
let lag = if elapsed > interval {
elapsed - interval
} else {
Duration::from_millis(0)
};
lag_ms.store(lag.as_millis() as u64, Ordering::Relaxed);
}
});
if self.config.enable_db_probe {
let db_lat_ms = self.db_latency_ms.clone();
tokio::spawn(async move {
let interval = Duration::from_millis(1000);
loop {
tokio::time::sleep(interval).await;
if let Some(pool) = crate::db::safe_pool() {
let start = Instant::now();
let res = sqlx::query("SELECT 1").execute(pool).await;
match res {
Ok(_) => {
let latency = start.elapsed();
db_lat_ms.store(latency.as_millis() as u64, Ordering::Relaxed);
}
Err(_) => {
db_lat_ms.store(9999, Ordering::Relaxed);
}
}
} else {
db_lat_ms.store(0, Ordering::Relaxed);
}
}
});
}
}
pub fn event_loop_lag(&self) -> Duration {
Duration::from_millis(self.event_loop_lag_ms.load(Ordering::Relaxed))
}
pub fn db_latency(&self) -> Duration {
Duration::from_millis(self.db_latency_ms.load(Ordering::Relaxed))
}
pub fn active_requests(&self) -> usize {
self.active_requests.load(Ordering::Relaxed)
}
}
pub async fn backpressure_middleware(shield: TrafficShield, req: Request, next: Next) -> Response {
let active = shield.active_requests.fetch_add(1, Ordering::SeqCst);
let max_active = shield.config.max_active_requests;
let lag = shield.event_loop_lag();
let db_lat = shield.db_latency();
let is_critical_cpu = lag >= shield.config.max_event_loop_lag;
let is_critical_db = shield.config.enable_db_probe && db_lat >= shield.config.max_db_latency;
let is_critical_active = active >= max_active;
if is_critical_cpu || is_critical_db || is_critical_active {
shield.active_requests.fetch_sub(1, Ordering::SeqCst);
eprintln!(
"⚠️ [Rullst Backpressure] Load shedding active! CPU lag: {:?}, DB latency: {:?}, Active requests: {}",
lag, db_lat, active
);
match Response::builder()
.status(StatusCode::SERVICE_UNAVAILABLE)
.header(axum::http::header::RETRY_AFTER, "5")
.header(
axum::http::header::CONTENT_TYPE,
"text/plain; charset=utf-8",
)
.body(axum::body::Body::from(
"Service Temporarily Saturated. Please try again soon.",
)) {
Ok(res) => return res,
Err(_) => {
let mut res = Response::new(axum::body::Body::empty());
*res.status_mut() = StatusCode::SERVICE_UNAVAILABLE;
return res;
}
}
}
let is_moderate_cpu = lag >= shield.config.max_event_loop_lag / 2;
let is_moderate_db =
shield.config.enable_db_probe && db_lat >= shield.config.max_db_latency / 2;
let is_moderate_active = active >= max_active / 2;
if is_moderate_cpu || is_moderate_db || is_moderate_active {
tokio::time::sleep(Duration::from_millis(25)).await;
}
let response = next.run(req).await;
shield.active_requests.fetch_sub(1, Ordering::SeqCst);
response
}
/// Parameters of a token bucket: its capacity and how fast it refills.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Bucket capacity; also the number of tokens a new bucket starts with.
    pub max_tokens: f64,
    /// Tokens added per second of elapsed time.
    pub refill_rate: f64,
}

impl RateLimitConfig {
    /// Builds a configuration from an explicit capacity and per-second refill rate.
    pub fn new(max_tokens: f64, refill_rate: f64) -> Self {
        Self {
            max_tokens,
            refill_rate,
        }
    }

    /// Allows `limit` requests per second, with a burst capacity of `limit`.
    pub fn per_second(limit: f64) -> Self {
        Self {
            max_tokens: limit,
            refill_rate: limit,
        }
    }

    /// Allows `limit` requests per minute, with a burst capacity of `limit`.
    pub fn per_minute(limit: f64) -> Self {
        Self {
            max_tokens: limit,
            refill_rate: limit / 60.0,
        }
    }

    /// Allows `limit` requests per hour, with a burst capacity of `limit`.
    pub fn per_hour(limit: f64) -> Self {
        Self {
            max_tokens: limit,
            refill_rate: limit / 3600.0,
        }
    }
}
/// State of one token bucket; `RateLimiter` keeps one per key.
#[derive(Clone, Debug)]
struct TokenBucket {
/// Tokens currently available. Fractional, because refill is proportional to
/// elapsed time.
tokens: f64,
/// Instant at which `tokens` was last brought up to date.
last_refill: Instant,
}
/// Token-bucket rate limiter keyed by a string derived from each request.
///
/// The bucket map and the key extractor are behind `Arc`, so clones share the
/// same buckets.
#[derive(Clone)]
pub struct RateLimiter {
/// Capacity and refill rate applied to every bucket.
pub(crate) config: RateLimitConfig,
/// One bucket per key, created the first time that key is seen.
buckets: Arc<DashMap<String, TokenBucket>>,
/// Maps a request to the key whose bucket it is charged against.
key_extractor: Arc<dyn Fn(&Request) -> String + Send + Sync>,
}
impl RateLimiter {
    /// Creates a limiter with no buckets that keys requests with
    /// [`default_key_extractor`].
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            config,
            buckets: Arc::new(DashMap::new()),
            key_extractor: Arc::new(default_key_extractor),
        }
    }

    /// Replaces the function that derives the bucket key from a request.
    pub fn with_key_extractor<F>(mut self, extractor: F) -> Self
    where
        F: Fn(&Request) -> String + Send + Sync + 'static,
    {
        self.key_extractor = Arc::new(extractor);
        self
    }

    /// Refills the bucket for `key` according to the time elapsed since its
    /// last update, then tries to take one token from it.
    ///
    /// Returns `true` when a token was taken (the request may proceed) and
    /// `false` when less than one token is available. A key seen for the
    /// first time starts with a full bucket.
    ///
    /// NOTE(review): nothing in this impl removes buckets, so the map grows
    /// with the number of distinct keys; consider eviction if keys can be
    /// chosen by clients.
    pub fn check_and_consume(&self, key: &str) -> bool {
        let capacity = self.config.max_tokens;
        let mut bucket = self
            .buckets
            .entry(key.to_string())
            .or_insert_with(|| TokenBucket {
                tokens: capacity,
                last_refill: Instant::now(),
            });

        // Read the clock only once the bucket is exclusively held. A timestamp
        // taken before waiting for the lock can be older than the one a
        // concurrent caller just stored, which would move `last_refill`
        // backwards and credit the same time span twice on the next call.
        let now = Instant::now();
        let elapsed_secs = now
            .saturating_duration_since(bucket.last_refill)
            .as_secs_f64();

        let refilled = bucket.tokens + elapsed_secs * self.config.refill_rate;
        bucket.tokens = refilled.min(capacity);
        bucket.last_refill = now;

        let granted = bucket.tokens >= 1.0;
        if granted {
            bucket.tokens -= 1.0;
        }
        granted
    }
}
/// Derives a rate-limit key identifying the client of `req`.
///
/// Sources are tried in this order; the first one that yields a non-blank
/// value wins:
///
/// 1. the first non-blank entry of the comma-separated `X-Forwarded-For` header,
/// 2. the `X-Real-IP` header,
/// 3. the peer IP from the `ConnectInfo<SocketAddr>` request extension,
/// 4. the literal `"anonymous"`.
///
/// Header values that are missing, not readable as text, or blank are skipped
/// rather than returned as an empty key, which would otherwise put every such
/// client into one shared bucket.
///
/// NOTE(review): both headers are supplied by the sender of the request. Unless
/// a trusted proxy overwrites them, a client can pick its own key; confirm the
/// deployment or install a custom extractor.
pub fn default_key_extractor(req: &Request) -> String {
    // Text of header `name`, if present and convertible with `to_str`.
    fn header_text<'a>(req: &'a Request, name: &str) -> Option<&'a str> {
        req.headers().get(name)?.to_str().ok()
    }

    let forwarded_client = header_text(req, "x-forwarded-for")
        .and_then(|list| list.split(',').map(str::trim).find(|hop| !hop.is_empty()));
    if let Some(ip) = forwarded_client {
        return ip.to_string();
    }

    let real_ip = header_text(req, "x-real-ip")
        .map(str::trim)
        .filter(|ip| !ip.is_empty());
    if let Some(ip) = real_ip {
        return ip.to_string();
    }

    if let Some(conn_info) = req
        .extensions()
        .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
    {
        return conn_info.0.ip().to_string();
    }

    "anonymous".to_string()
}
/// Axum middleware that charges each request one token from its key's bucket.
///
/// The key comes from the limiter's key extractor. When a token is available
/// the request is forwarded to `next`; otherwise the response is a plain-text
/// `429 Too Many Requests` and the inner service is not called.
pub async fn rate_limit_middleware(limiter: RateLimiter, req: Request, next: Next) -> Response {
    let client_key = (limiter.key_extractor)(&req);

    if limiter.check_and_consume(&client_key) {
        return next.run(req).await;
    }

    let rejection = Response::builder()
        .status(StatusCode::TOO_MANY_REQUESTS)
        .header(
            axum::http::header::CONTENT_TYPE,
            "text/plain; charset=utf-8",
        )
        .body(axum::body::Body::from(
            "Rate limit exceeded. Please try again later.",
        ));

    // Fall back to an empty 429 should the builder ever report an error.
    rejection.unwrap_or_else(|_| {
        let mut fallback = Response::new(axum::body::Body::empty());
        *fallback.status_mut() = StatusCode::TOO_MANY_REQUESTS;
        fallback
    })
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use axum::http::Request;
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    use std::sync::atomic::Ordering;

    /// Builds a body-less request, optionally carrying a single header.
    fn empty_request(header: Option<(&str, &str)>) -> Request<axum::body::Body> {
        let builder = match header {
            Some((name, value)) => Request::builder().header(name, value),
            None => Request::builder(),
        };
        builder.body(axum::body::Body::empty()).unwrap()
    }

    /// Creates a shield whose database probe is switched off.
    fn shield_without_db_probe() -> TrafficShield {
        TrafficShield::new(TrafficShieldConfig::new().with_db_probe(false))
    }

    #[test]
    fn test_default_key_extractor() {
        // The first entry of X-Forwarded-For is used.
        let forwarded = empty_request(Some(("x-forwarded-for", "192.168.1.1, 10.0.0.1")));
        assert_eq!(default_key_extractor(&forwarded), "192.168.1.1");

        // X-Real-IP is used when X-Forwarded-For is absent.
        let real_ip = empty_request(Some(("x-real-ip", "10.0.0.2")));
        assert_eq!(default_key_extractor(&real_ip), "10.0.0.2");

        // Without headers the peer address from ConnectInfo is used.
        let mut connected = empty_request(None);
        let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
        connected
            .extensions_mut()
            .insert(axum::extract::ConnectInfo(peer));
        assert_eq!(default_key_extractor(&connected), "127.0.0.1");

        // Nothing to go on: the fixed fallback key.
        let bare = empty_request(None);
        assert_eq!(default_key_extractor(&bare), "anonymous");
    }

    #[tokio::test]
    async fn test_traffic_shield_active_requests() {
        let shield = shield_without_db_probe();
        assert_eq!(shield.active_requests(), 0);

        shield.active_requests.fetch_add(5, Ordering::SeqCst);
        assert_eq!(shield.active_requests(), 5);
    }

    #[test]
    fn test_rate_limit_config_per_minute() {
        let RateLimitConfig {
            max_tokens,
            refill_rate,
        } = RateLimitConfig::per_minute(60.0);
        assert_eq!(max_tokens, 60.0);
        assert_eq!(refill_rate, 1.0);
    }

    #[test]
    fn test_rate_limit_config_per_second() {
        let RateLimitConfig {
            max_tokens,
            refill_rate,
        } = RateLimitConfig::per_second(10.0);
        assert_eq!(max_tokens, 10.0);
        assert_eq!(refill_rate, 10.0);
    }

    #[test]
    fn test_rate_limit_config_per_hour() {
        let RateLimitConfig {
            max_tokens,
            refill_rate,
        } = RateLimitConfig::per_hour(3600.0);
        assert_eq!(max_tokens, 3600.0);
        assert_eq!(refill_rate, 1.0);
    }

    #[tokio::test]
    async fn test_traffic_shield_db_latency() {
        let shield = shield_without_db_probe();
        assert_eq!(shield.db_latency().as_millis(), 0);

        shield.db_latency_ms.store(50, Ordering::Relaxed);
        assert_eq!(shield.db_latency().as_millis(), 50);
    }

    #[tokio::test]
    async fn test_traffic_shield_event_loop_lag() {
        let shield = shield_without_db_probe();
        assert_eq!(shield.event_loop_lag().as_millis(), 0);

        shield.event_loop_lag_ms.store(100, Ordering::Relaxed);
        assert_eq!(shield.event_loop_lag().as_millis(), 100);
    }

    #[test]
    fn test_check_and_consume() {
        // Capacity of two: two back-to-back calls pass, the third is refused.
        let limiter = RateLimiter::new(RateLimitConfig::per_second(2.0));
        let outcomes: Vec<bool> = (0..3)
            .map(|_| limiter.check_and_consume("test_key"))
            .collect();
        assert!(outcomes[0]);
        assert!(outcomes[1]);
        assert!(!outcomes[2]);
    }
}