use super::config::RateLimitConfig;
use axum::{
extract::Request,
http::StatusCode,
response::{IntoResponse, Response},
};
use governor::{
Quota, RateLimiter,
clock::DefaultClock,
middleware::NoOpMiddleware,
state::{InMemoryState, NotKeyed},
};
use std::{
num::NonZeroU32,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
};
use tower::{Layer, Service};
// Every N per-IP checks, sweep the keyed limiter's DashMap and drop buckets
// that have returned to full capacity (see `check_rate_limit`) to bound memory.
const SHRINK_INTERVAL: u64 = 1000;
// Shared bucket key for requests whose client IP could not be determined, so
// anonymous traffic is still rate limited instead of bypassing the limiter.
const UNKNOWN_CLIENT_BUCKET: &str = "__tideway_unknown_client__";
/// JSON body returned to clients that exceed the rate limit.
///
/// Field names are part of the serialized wire format — do not rename.
#[derive(serde::Serialize)]
struct RateLimitError {
    // Machine-readable error code (always "rate_limit_exceeded" here).
    error: String,
    // Human-readable explanation including the wait time.
    message: String,
    // Seconds to wait before retrying; mirrored in the Retry-After header.
    retry_after: u64,
}
impl IntoResponse for RateLimitError {
    /// Renders as `429 Too Many Requests` with a `Retry-After` header and
    /// this error serialized as the JSON body.
    fn into_response(self) -> Response {
        let status = StatusCode::TOO_MANY_REQUESTS;
        let headers = [("Retry-After", self.retry_after.to_string())];
        let body = axum::Json(self);
        (status, headers, body).into_response()
    }
}
// Single shared limiter used by the "global" strategy: one bucket for all clients.
type GlobalLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>;
// Keyed limiter used by the "per_ip" strategy: one bucket per client-IP string,
// stored in a concurrent DashMap state store.
type KeyedLimiter = RateLimiter<
    String,
    governor::state::keyed::DashMapStateStore<String>,
    DefaultClock,
    NoOpMiddleware,
>;
/// The active limiter, chosen once at construction from `config.strategy`
/// ("per_ip" selects `PerIp`; any other value falls back to `Global`).
#[derive(Clone)]
enum LimiterState {
    // One shared bucket for every request.
    Global(Arc<GlobalLimiter>),
    // One bucket per client IP; unknown clients share UNKNOWN_CLIENT_BUCKET.
    PerIp(Arc<KeyedLimiter>),
}
/// Shared, cheaply-clonable limiter state; clones share the same underlying
/// limiter and counter via `Arc`, so layer/service clones enforce one limit.
#[derive(Clone)]
struct RateLimitState {
    limiter: LimiterState,
    // Kept for per-request decisions (e.g. `trust_proxy` header handling).
    config: RateLimitConfig,
    // Total per-IP checks observed; drives periodic DashMap shrinking.
    request_count: Arc<AtomicU64>,
}
impl RateLimitState {
    /// Builds the limiter described by `config`.
    ///
    /// The quota means "`max_requests` per `window_seconds`": permits
    /// replenish once every `window / max_requests`, with a burst capacity
    /// of `max_requests`. (The previous code used `with_period(window)`,
    /// which refilled only ONE permit per full window — after the initial
    /// burst, sustained throughput was ~1/window instead of the configured
    /// rate.)
    ///
    /// Zero values for `max_requests` / `window_seconds` are clamped to 1.
    /// `strategy == "per_ip"` selects a keyed limiter; any other value
    /// falls back to a single global bucket.
    fn new(config: RateLimitConfig) -> Self {
        let max_requests = NonZeroU32::new(config.max_requests.max(1)).unwrap_or(NonZeroU32::MIN);
        let window = std::time::Duration::from_secs(config.window_seconds.max(1));
        // `with_period` returns None only for a zero period (possible when
        // the division truncates to 0ns for huge max_requests); fall back
        // to a per-second quota with the same burst.
        let quota = Quota::with_period(window / max_requests.get())
            .unwrap_or_else(|| Quota::per_second(max_requests))
            .allow_burst(max_requests);
        let limiter = if config.strategy == "per_ip" {
            LimiterState::PerIp(Arc::new(RateLimiter::keyed(quota)))
        } else {
            LimiterState::Global(Arc::new(RateLimiter::direct(quota)))
        };
        Self {
            limiter,
            config,
            request_count: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Checks whether one request identified by `key` may proceed.
    ///
    /// Returns `Ok(())` when admitted, or `Err(retry_after_secs)` (always
    /// >= 1) when rejected. For the per-IP limiter, a missing key shares
    /// the `UNKNOWN_CLIENT_BUCKET` so anonymous traffic cannot bypass the
    /// limit; every `SHRINK_INTERVAL` checks, fully-replenished buckets are
    /// evicted from the keyed store to bound memory.
    fn check_rate_limit(&self, key: Option<&str>) -> Result<(), u64> {
        match &self.limiter {
            LimiterState::PerIp(limiter) => {
                let count = self.request_count.fetch_add(1, Ordering::Relaxed);
                if count > 0 && count % SHRINK_INTERVAL == 0 {
                    // Amortized cleanup of idle per-IP buckets.
                    limiter.retain_recent();
                }
                let bucket = key.unwrap_or(UNKNOWN_CLIENT_BUCKET);
                match limiter.check_key(&bucket.to_string()) {
                    Ok(_) => Ok(()),
                    Err(not_until) => {
                        // NOTE(review): assumes a freshly-defaulted
                        // DefaultClock's `now` is comparable with the
                        // limiter's internal instants — confirm against the
                        // governor version in use.
                        let wait = not_until
                            .wait_time_from(governor::clock::Clock::now(&DefaultClock::default()));
                        // Round sub-second waits up so clients never retry
                        // immediately with Retry-After: 0.
                        Err(wait.as_secs().max(1))
                    }
                }
            }
            LimiterState::Global(limiter) => match limiter.check() {
                Ok(_) => Ok(()),
                Err(not_until) => {
                    let wait = not_until
                        .wait_time_from(governor::clock::Clock::now(&DefaultClock::default()));
                    Err(wait.as_secs().max(1))
                }
            },
        }
    }
}
/// Tower layer that wraps services with rate limiting.
///
/// Cloning is cheap: the inner state shares its limiter via `Arc`, so every
/// service produced by this layer enforces the same limit.
#[derive(Clone)]
pub struct RateLimitLayer {
    state: RateLimitState,
}
impl RateLimitLayer {
pub fn new(config: RateLimitConfig) -> Self {
Self {
state: RateLimitState::new(config),
}
}
}
impl<S> Layer<S> for RateLimitLayer {
    type Service = RateLimitService<S>;

    /// Wraps `inner` in a service that shares this layer's limiter state.
    fn layer(&self, inner: S) -> Self::Service {
        let state = self.state.clone();
        RateLimitService { state, inner }
    }
}
/// Middleware service that rejects requests over the configured rate limit
/// with a 429 before they reach the wrapped service.
#[derive(Clone)]
pub struct RateLimitService<S> {
    inner: S,
    // Shared with the layer and all sibling clones (Arc-backed).
    state: RateLimitState,
}
impl<S> RateLimitService<S> {
    /// Resolves the client IP string used as the rate-limit bucket key.
    ///
    /// When `trust_proxy` is set, forwarding headers take precedence: the
    /// first entry of `X-Forwarded-For`, then `X-Real-IP`, then the socket
    /// address from `ConnectInfo`. Otherwise only the socket address is
    /// consulted, since proxy headers are trivially spoofable by clients.
    fn client_ip(req: &Request, trust_proxy: bool) -> Option<String> {
        let peer_addr = || {
            req.extensions()
                .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
                .map(|info| info.ip().to_string())
        };
        if !trust_proxy {
            return peer_addr();
        }
        req.headers()
            .get("x-forwarded-for")
            .and_then(|v| v.to_str().ok())
            // `split` always yields at least one item, so this never skips
            // a present header; the first entry is the originating client.
            .and_then(|s| s.split(',').next())
            .map(|s| s.trim().to_string())
            .or_else(|| {
                req.headers()
                    .get("x-real-ip")
                    .and_then(|v| v.to_str().ok())
                    .map(str::to_string)
            })
            .or_else(peer_addr)
    }
}

impl<S> Service<Request> for RateLimitService<S>
where
    S: Service<Request> + Clone + Send + Sync + 'static,
    S::Response: IntoResponse,
    S::Future: Send + 'static,
{
    type Response = Response;
    type Error = S::Error;
    type Future = std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
    >;

    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request) -> Self::Future {
        // Take the service instance that `poll_ready` drove to readiness and
        // leave a fresh clone in its place. The previous code called
        // `self.inner.clone()` directly, invoking a clone that was never
        // polled ready — a violation of the `Service` contract (tower's
        // documented clone-and-swap pattern avoids this).
        let clone = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, clone);

        // Health endpoints bypass rate limiting entirely.
        let path = req.uri().path();
        if path == "/health" || path.starts_with("/health/") {
            return Box::pin(async move {
                let response = inner.call(req).await?;
                Ok(response.into_response())
            });
        }

        let ip = Self::client_ip(&req, self.state.config.trust_proxy);
        match self.state.check_rate_limit(ip.as_deref()) {
            Ok(()) => Box::pin(async move {
                let response = inner.call(req).await?;
                Ok(response.into_response())
            }),
            Err(retry_after) => {
                // Short-circuit: the wrapped service is never invoked.
                let error = RateLimitError {
                    error: "rate_limit_exceeded".to_string(),
                    message: format!(
                        "Rate limit exceeded. Please try again in {} seconds",
                        retry_after
                    ),
                    retry_after,
                };
                Box::pin(async move { Ok(error.into_response()) })
            }
        }
    }
}
/// Builds the rate-limit layer from `config`, or `None` when rate limiting
/// is disabled so callers can skip installing the middleware entirely.
pub fn build_rate_limit_layer(config: &RateLimitConfig) -> Option<RateLimitLayer> {
    config.enabled.then(|| RateLimitLayer::new(config.clone()))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline config for most tests: 5 requests per 60s, keyed per IP.
    fn test_config() -> RateLimitConfig {
        RateLimitConfig {
            enabled: true,
            max_requests: 5,
            window_seconds: 60,
            strategy: "per_ip".to_string(),
            trust_proxy: false,
        }
    }

    // All requests within the burst allowance are admitted.
    #[test]
    fn test_rate_limit_allows_requests_under_limit() {
        let config = test_config();
        let state = RateLimitState::new(config);
        for i in 0..5 {
            let result = state.check_rate_limit(Some("192.168.1.1"));
            assert!(result.is_ok(), "Request {} should be allowed", i + 1);
        }
    }

    // The first request past the allowance is rejected.
    #[test]
    fn test_rate_limit_blocks_requests_over_limit() {
        let config = test_config();
        let state = RateLimitState::new(config);
        for _ in 0..5 {
            state.check_rate_limit(Some("192.168.1.1")).unwrap();
        }
        let result = state.check_rate_limit(Some("192.168.1.1"));
        assert!(result.is_err(), "6th request should be blocked");
    }

    // Exhausting one IP's bucket must not affect another IP's bucket.
    #[test]
    fn test_rate_limit_per_ip_isolation() {
        let config = test_config();
        let state = RateLimitState::new(config);
        for _ in 0..5 {
            state.check_rate_limit(Some("192.168.1.1")).unwrap();
        }
        let result = state.check_rate_limit(Some("192.168.1.2"));
        assert!(result.is_ok(), "Different IP should have separate quota");
    }

    // With the "global" strategy all IPs share one bucket.
    #[test]
    fn test_global_rate_limiting() {
        let mut config = test_config();
        config.strategy = "global".to_string();
        let state = RateLimitState::new(config);
        for _ in 0..5 {
            state.check_rate_limit(Some("192.168.1.1")).unwrap();
        }
        let result = state.check_rate_limit(Some("192.168.1.2"));
        assert!(result.is_err(), "Global limit should block all IPs");
    }

    // A rejection carries a positive retry hint no larger than the window.
    #[test]
    fn test_rate_limit_returns_retry_after() {
        let config = RateLimitConfig {
            enabled: true,
            max_requests: 1,
            window_seconds: 60,
            strategy: "per_ip".to_string(),
            trust_proxy: false,
        };
        let state = RateLimitState::new(config);
        state.check_rate_limit(Some("192.168.1.1")).unwrap();
        let result = state.check_rate_limit(Some("192.168.1.1"));
        assert!(result.is_err());
        if let Err(retry_after) = result {
            assert!(retry_after > 0, "Should return positive retry_after");
            assert!(retry_after <= 60, "retry_after should be within window");
        }
    }

    // Requests with no resolvable IP all share one fallback bucket, so
    // anonymous traffic cannot bypass the limiter.
    #[test]
    fn test_missing_ip_uses_shared_bucket() {
        let config = RateLimitConfig {
            enabled: true,
            max_requests: 1,
            window_seconds: 60,
            strategy: "per_ip".to_string(),
            trust_proxy: false,
        };
        let state = RateLimitState::new(config);
        assert!(state.check_rate_limit(None).is_ok());
        assert!(state.check_rate_limit(None).is_err());
    }

    // Smoke test: many threads hammering clones of the shared state must not
    // panic or poison it; afterwards a fresh IP is still admitted.
    #[test]
    fn test_concurrent_access() {
        use std::thread;
        let config = RateLimitConfig {
            enabled: true,
            max_requests: 100,
            window_seconds: 60,
            strategy: "per_ip".to_string(),
            trust_proxy: false,
        };
        let state = RateLimitState::new(config);
        let mut handles = vec![];
        for i in 0..10 {
            let state = state.clone();
            handles.push(thread::spawn(move || {
                for j in 0..50 {
                    let ip = format!("192.168.{}.{}", i, j % 256);
                    let _ = state.check_rate_limit(Some(&ip));
                }
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
        let result = state.check_rate_limit(Some("10.0.0.1"));
        assert!(result.is_ok(), "Should still work after concurrent access");
    }
}