use axum::{
extract::{ConnectInfo, Request},
http::{HeaderValue, StatusCode},
response::{IntoResponse, Response},
};
use dashmap::DashMap;
use std::{
net::{IpAddr, SocketAddr},
sync::Arc,
time::{Duration, Instant},
};
use thiserror::Error;
use tower::{Layer, Service};
/// Errors produced by the rate-limiting middleware.
///
/// Each variant is turned into an HTTP response by the `IntoResponse` implementation for this
/// type.
#[derive(Debug, Error)]
pub enum RateLimitError {
/// The client's bucket did not hold enough tokens for the request. The payload is the
/// suggested wait in whole seconds and is rendered into the error message.
#[error("Rate limit exceeded. Retry after {0} seconds")]
TooManyRequests(u64),
/// No client IP address could be determined for the request.
/// NOTE(review): nothing in the code visible here constructs this variant — confirm whether
/// it is still used elsewhere.
#[error("Unable to determine client IP address")]
IpNotAvailable,
}
impl IntoResponse for RateLimitError {
    /// Renders the error as an HTTP response.
    ///
    /// * `IpNotAvailable` becomes `400 Bad Request` with a plain-text body.
    /// * `TooManyRequests` becomes `429 Too Many Requests` with a JSON body
    ///   (`error`, `code`, `retry_after`) plus the `Retry-After` and
    ///   `X-RateLimit-Remaining: 0` headers.
    fn into_response(self) -> Response {
        // Deal with the simple variant first so the rest of the function only handles 429s.
        let retry_secs = match &self {
            RateLimitError::IpNotAvailable => {
                return (StatusCode::BAD_REQUEST, "IP address not available").into_response();
            }
            RateLimitError::TooManyRequests(secs) => *secs,
        };

        let payload = serde_json::json!({
            "error": self.to_string(),
            "code": "RATE_LIMIT_EXCEEDED",
            "retry_after": retry_secs
        });

        let mut rejected = (StatusCode::TOO_MANY_REQUESTS, axum::Json(payload)).into_response();
        let header_map = rejected.headers_mut();
        header_map.insert("Retry-After", HeaderValue::from(retry_secs));
        header_map.insert("X-RateLimit-Remaining", HeaderValue::from_static("0"));
        rejected
    }
}
/// Token-bucket state for a single client.
///
/// Tokens accumulate continuously at `refill_rate` per second, never exceeding `capacity`, and
/// are spent through [`TokenBucket::consume`].
#[derive(Debug, Clone)]
struct TokenBucket {
    /// Tokens currently available; fractional while the bucket is refilling.
    tokens: f64,
    /// Upper bound for `tokens`.
    capacity: f64,
    /// Tokens credited per second of elapsed time.
    refill_rate: f64,
    /// Moment at which `tokens` was last brought up to date.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts out full.
    fn new(capacity: f64, refill_rate: f64) -> Self {
        let created_at = Instant::now();
        Self {
            tokens: capacity,
            capacity,
            refill_rate,
            last_refill: created_at,
        }
    }

    /// Credits the tokens earned since the previous refill, capped at `capacity`, and moves
    /// `last_refill` forward to now.
    fn refill(&mut self) {
        let now = Instant::now();
        let idle_secs = now.duration_since(self.last_refill).as_secs_f64();
        let earned = idle_secs * self.refill_rate;
        let topped_up = self.tokens + earned;
        self.tokens = topped_up.min(self.capacity);
        self.last_refill = now;
    }

    /// Tries to spend `amount` tokens after refilling.
    ///
    /// Returns `Ok(())` when enough tokens were available. Otherwise nothing is spent and the
    /// error carries the number of seconds (rounded up) needed to earn the missing tokens at
    /// the current refill rate.
    fn consume(&mut self, amount: f64) -> Result<(), u64> {
        self.refill();

        if self.tokens >= amount {
            self.tokens -= amount;
            return Ok(());
        }

        let shortfall = amount - self.tokens;
        let wait_secs = (shortfall / self.refill_rate).ceil();
        Err(wait_secs as u64)
    }

    /// Refills the bucket and returns the number of whole tokens currently available.
    fn remaining(&mut self) -> u64 {
        self.refill();
        let whole_tokens = self.tokens.floor();
        whole_tokens as u64
    }
}
/// Per-IP token-bucket rate limiter.
///
/// Cloning is cheap: the bucket map sits behind an `Arc`, so every clone operates on the same
/// set of buckets.
#[derive(Clone)]
pub struct RateLimiter {
/// One token bucket per client IP address, shared by all clones of this limiter.
buckets: Arc<DashMap<IpAddr, TokenBucket>>,
/// Sustained request rate; `check` divides it by 60 to get the per-second refill rate.
requests_per_minute: u32,
/// Capacity given to each newly created bucket (buckets start full).
burst_capacity: u32,
/// Selects the IP lookup order in the middleware: when `true` the proxy headers are consulted
/// before the connection's peer address, when `false` the peer address is consulted first.
secure_ip: bool,
}
impl RateLimiter {
    /// Creates a limiter allowing `requests_per_minute` sustained requests per client.
    ///
    /// The burst capacity defaults to the same number and `secure_ip` defaults to `false`.
    pub fn new(requests_per_minute: u32) -> Self {
        let buckets = Arc::new(DashMap::new());
        Self {
            buckets,
            requests_per_minute,
            burst_capacity: requests_per_minute,
            secure_ip: false,
        }
    }

    /// Builder-style setter for the capacity given to newly created buckets.
    pub fn with_burst_capacity(self, capacity: u32) -> Self {
        Self {
            burst_capacity: capacity,
            ..self
        }
    }

    /// Builder-style setter for the `secure_ip` flag.
    pub fn with_secure_ip(self, secure: bool) -> Self {
        Self {
            secure_ip: secure,
            ..self
        }
    }

    /// Spends one token from the bucket belonging to `ip`.
    ///
    /// A bucket is created on first sight of an address; it starts full at the burst capacity
    /// and refills at `requests_per_minute / 60` tokens per second.
    ///
    /// # Errors
    ///
    /// Returns [`RateLimitError::TooManyRequests`] carrying the suggested wait in seconds when
    /// the bucket holds less than one token.
    pub fn check(&self, ip: IpAddr) -> Result<u64, RateLimitError> {
        let tokens_per_second = f64::from(self.requests_per_minute) / 60.0;
        let initial_capacity = f64::from(self.burst_capacity);

        let mut bucket = self
            .buckets
            .entry(ip)
            .or_insert_with(|| TokenBucket::new(initial_capacity, tokens_per_second));

        bucket
            .consume(1.0)
            .map_err(RateLimitError::TooManyRequests)?;
        Ok(bucket.remaining())
    }

    /// Returns `(remaining_tokens, burst_capacity)` for `ip`, or `None` when no bucket exists
    /// for that address yet. The bucket is refilled as part of the lookup.
    pub fn bucket_info(&self, ip: IpAddr) -> Option<(u64, u64)> {
        let mut bucket = self.buckets.get_mut(&ip)?;
        bucket.refill();
        let remaining = bucket.remaining();
        Some((remaining, u64::from(self.burst_capacity)))
    }

    /// Removes every bucket whose last refill happened `max_age` or longer ago.
    pub fn cleanup(&self, max_age: Duration) {
        self.buckets
            .retain(|_ip, state| state.last_refill.elapsed() < max_age);
    }

    /// Wraps this limiter in a [`RateLimiterLayer`] so it can be installed as middleware.
    pub fn into_layer(self) -> RateLimiterLayer {
        let limiter = self;
        RateLimiterLayer { limiter }
    }
}
impl Default for RateLimiter {
    /// Builds a limiter with the default rate of 100 requests per minute.
    fn default() -> Self {
        const DEFAULT_REQUESTS_PER_MINUTE: u32 = 100;
        Self::new(DEFAULT_REQUESTS_PER_MINUTE)
    }
}
/// Layer that wraps an inner service in a [`RateLimiterService`].
#[derive(Clone)]
pub struct RateLimiterLayer {
/// Limiter handed (as a clone) to every service produced by this layer.
limiter: RateLimiter,
}
impl<S> Layer<S> for RateLimiterLayer {
    type Service = RateLimiterService<S>;

    /// Wraps `inner` in a [`RateLimiterService`] that uses a clone of this layer's limiter.
    fn layer(&self, inner: S) -> Self::Service {
        let limiter = self.limiter.clone();
        RateLimiterService { inner, limiter }
    }
}
/// Middleware service that checks the rate limit before forwarding a request to `inner`.
#[derive(Clone)]
pub struct RateLimiterService<S> {
/// The wrapped service that handles requests which pass the rate limit.
inner: S,
/// Limiter consulted once per request.
limiter: RateLimiter,
}
impl<S> Service<Request> for RateLimiterService<S>
where
    S: Service<Request, Response = Response> + Clone + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
    >;

    /// Readiness is delegated to the wrapped service; the rate limit itself is only evaluated
    /// in `call`.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Applies the rate limit for the request's client IP and, if allowed, forwards the request.
    ///
    /// * The client IP comes from the proxy headers and the connection's peer address; the
    ///   limiter's `secure_ip` flag decides which of the two is consulted first.
    /// * When neither source yields an address, `127.0.0.1` is used, so all such requests
    ///   share one bucket.
    /// * Allowed requests are forwarded and the response gains the `X-RateLimit-Limit` and
    ///   `X-RateLimit-Remaining` headers.
    /// * Rejected requests never reach the inner service; the limiter error is returned as an
    ///   ordinary (`Ok`) HTTP response.
    fn call(&mut self, req: Request) -> Self::Future {
        let limiter = self.limiter.clone();

        // `poll_ready` was driven on `self.inner`, so that exact instance has to serve this
        // request; a fresh clone is not guaranteed to be ready. Park the clone in `self` for
        // the next `poll_ready`/`call` cycle and move the ready instance into the future.
        let replacement = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, replacement);

        Box::pin(async move {
            let detected_ip = if limiter.secure_ip {
                extract_proxy_ip(&req).or_else(|| extract_connect_ip(&req))
            } else {
                extract_connect_ip(&req).or_else(|| extract_proxy_ip(&req))
            };

            // NOTE(review): requests without a detectable address are all accounted to
            // localhost instead of being rejected with `RateLimitError::IpNotAvailable` —
            // confirm this best-effort fallback is intended.
            let ip = detected_ip.unwrap_or(IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));

            match limiter.check(ip) {
                Ok(remaining) => {
                    let mut response = inner.call(req).await?;
                    let headers = response.headers_mut();
                    headers.insert(
                        "X-RateLimit-Limit",
                        HeaderValue::from(limiter.requests_per_minute),
                    );
                    headers.insert("X-RateLimit-Remaining", HeaderValue::from(remaining));
                    Ok(response)
                }
                Err(err) => Ok(err.into_response()),
            }
        })
    }
}
fn extract_connect_ip(req: &Request) -> Option<IpAddr> {
req.extensions()
.get::<ConnectInfo<SocketAddr>>()
.map(|ci| ci.0.ip())
}
/// Reads the client IP address from the proxy headers.
///
/// The first comma-separated entry of `x-forwarded-for` is tried first; if that header is
/// missing, is not valid text, or its first entry does not parse as an IP address, `x-real-ip`
/// is tried. Returns `None` when neither header yields an address.
fn extract_proxy_ip(req: &Request) -> Option<IpAddr> {
    let headers = req.headers();

    let from_forwarded_for = headers
        .get("x-forwarded-for")
        .and_then(|raw| raw.to_str().ok())
        .and_then(|list| list.split(',').next())
        .and_then(|first| first.trim().parse::<IpAddr>().ok());

    from_forwarded_for.or_else(|| {
        headers
            .get("x-real-ip")
            .and_then(|raw| raw.to_str().ok())
            .and_then(|value| value.trim().parse::<IpAddr>().ok())
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Builds `127.0.0.<last_octet>` as a test client address.
    fn loopback(last_octet: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(127, 0, 0, last_octet))
    }

    /// A new bucket starts full.
    #[test]
    fn test_token_bucket_new() {
        let fresh = TokenBucket::new(100.0, 10.0);
        assert_eq!(fresh.tokens, 100.0);
        assert_eq!(fresh.capacity, 100.0);
    }

    /// Tokens are spent until the bucket is empty, after which `consume` fails.
    #[test]
    fn test_token_bucket_consume() {
        let mut bucket = TokenBucket::new(10.0, 1.0);

        assert!(bucket.consume(5.0).is_ok());
        assert!((4.9..=5.1).contains(&bucket.tokens));

        assert!(bucket.consume(5.0).is_ok());
        assert!(bucket.tokens < 0.1);

        assert!(bucket.consume(1.0).is_err());
    }

    /// The 61st request in quick succession exceeds a 60-per-minute limit.
    #[test]
    fn test_rate_limiter_check() {
        let limiter = RateLimiter::new(60);
        let client = loopback(1);

        assert!(limiter.check(client).is_ok());
        (0..59).for_each(|_| {
            let _ = limiter.check(client);
        });

        let verdict = limiter.check(client);
        assert!(matches!(verdict, Err(RateLimitError::TooManyRequests(_))));
    }

    /// Exhausting one client's bucket leaves other clients unaffected.
    #[test]
    fn test_rate_limiter_multiple_ips() {
        let limiter = RateLimiter::new(10);
        let first_client = loopback(1);
        let second_client = loopback(2);

        for _ in 0..10 {
            assert!(limiter.check(first_client).is_ok());
        }

        assert!(limiter.check(first_client).is_err());
        assert!(limiter.check(second_client).is_ok());
    }

    /// `bucket_info` reports the remaining tokens and the configured capacity.
    #[test]
    fn test_bucket_info() {
        let limiter = RateLimiter::new(100).with_burst_capacity(100);
        let client = loopback(1);

        limiter.check(client).unwrap();
        let (left, capacity) = limiter.bucket_info(client).unwrap();
        assert_eq!(left, 99);
        assert_eq!(capacity, 100);

        limiter.check(client).unwrap();
        let (left, _) = limiter.bucket_info(client).unwrap();
        assert_eq!(left, 98);
    }

    /// A zero `max_age` evicts every bucket.
    #[test]
    fn test_cleanup() {
        let limiter = RateLimiter::new(100);

        limiter.check(loopback(1)).unwrap();
        limiter.check(loopback(2)).unwrap();
        assert_eq!(limiter.buckets.len(), 2);

        limiter.cleanup(Duration::from_secs(0));
        assert_eq!(limiter.buckets.len(), 0);
    }

    /// An emptied bucket earns tokens back as time passes.
    #[tokio::test]
    async fn test_token_bucket_refill() {
        let mut bucket = TokenBucket::new(10.0, 10.0);
        bucket.consume(10.0).unwrap();
        assert_eq!(bucket.tokens, 0.0);

        tokio::time::sleep(Duration::from_secs(1)).await;

        bucket.refill();
        assert!(bucket.tokens >= 9.0);
    }
}