use axum::{body::Body, extract::ConnectInfo, http::Request, http::Response};
use std::net::SocketAddr;
use tower::{Layer, Service};
use super::types::{RateLimitBackend, RateLimitConfig, RateLimitError, RateLimitResult};
use crate::utils::extract_client_ip;
#[derive(Clone)]
pub enum KeyExtractor {
IpOnly,
IpAndPath,
}
#[derive(Clone)]
pub struct RateLimitLayer {
backend: RateLimitBackend,
config: RateLimitConfig,
key_extractor: KeyExtractor,
trust_proxy: bool,
}
impl RateLimitLayer {
pub fn new<B: Into<RateLimitBackend>>(backend: B, config: RateLimitConfig) -> Self {
Self {
backend: backend.into(),
config,
key_extractor: KeyExtractor::IpAndPath,
trust_proxy: false,
}
}
pub fn with_key_extractor(mut self, extractor: KeyExtractor) -> Self {
self.key_extractor = extractor;
self
}
pub fn with_trust_proxy(mut self, trust_proxy: bool) -> Self {
self.trust_proxy = trust_proxy;
self
}
}
impl<S> Layer<S> for RateLimitLayer {
type Service = RateLimitService<S>;
fn layer(&self, inner: S) -> Self::Service {
RateLimitService {
inner,
backend: self.backend.clone(),
config: self.config.clone(),
key_extractor: self.key_extractor.clone(),
trust_proxy: self.trust_proxy,
}
}
}
#[derive(Clone)]
pub struct RateLimitService<S> {
inner: S,
backend: RateLimitBackend,
config: RateLimitConfig,
key_extractor: KeyExtractor,
trust_proxy: bool,
}
impl<S> Service<Request<Body>> for RateLimitService<S>
where
S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
S::Future: Send,
{
type Response = Response<Body>;
type Error = S::Error;
type Future = std::pin::Pin<
Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
let backend = self.backend.clone();
let config = self.config.clone();
let key_extractor = self.key_extractor.clone();
let trust_proxy = self.trust_proxy;
let mut inner = self.inner.clone();
Box::pin(async move {
let key = build_rate_limit_key(&req, key_extractor, trust_proxy);
let result = backend.check_and_record(&key, &config).await;
if !result.allowed {
let error = RateLimitError {
code: "RATE_LIMITED",
message: "Too many requests. Please try again later.".to_string(),
retry_after: result.reset_secs,
};
let mut response = axum::response::IntoResponse::into_response(error);
add_rate_limit_headers(&mut response, &result);
return Ok(response);
}
let mut response = inner.call(req).await?;
add_rate_limit_headers(&mut response, &result);
Ok(response)
})
}
}
fn add_rate_limit_headers(response: &mut Response<Body>, result: &RateLimitResult) {
use axum::http::HeaderValue;
let headers = response.headers_mut();
if let Ok(limit) = HeaderValue::from_str(&result.limit.to_string()) {
headers.insert("X-RateLimit-Limit", limit);
}
if let Ok(remaining) = HeaderValue::from_str(&result.remaining.to_string()) {
headers.insert("X-RateLimit-Remaining", remaining);
}
if let Ok(reset) = HeaderValue::from_str(&result.reset_secs.to_string()) {
headers.insert("X-RateLimit-Reset", reset);
}
}
fn build_rate_limit_key(
req: &Request<Body>,
key_extractor: KeyExtractor,
trust_proxy: bool,
) -> String {
let ip = resolve_client_ip(req, trust_proxy);
let path = req.uri().path();
match key_extractor {
KeyExtractor::IpOnly => ip.unwrap_or_else(|| {
tracing::warn!(
path = path,
"Rate limit: unable to resolve client IP, using shared fallback key"
);
format!("unknown:{}", path)
}),
KeyExtractor::IpAndPath => match ip {
Some(ip) => format!("{}:{}", ip, path),
None => {
tracing::warn!(
path = path,
"Rate limit: unable to resolve client IP, using shared fallback key"
);
format!("unknown:{}", path)
}
},
}
}
fn resolve_client_ip(req: &Request<Body>, trust_proxy: bool) -> Option<String> {
if trust_proxy {
extract_client_ip(req.headers(), true).or_else(|| {
req.extensions()
.get::<ConnectInfo<SocketAddr>>()
.map(|ci| ci.0.ip().to_string())
})
} else {
req.extensions()
.get::<ConnectInfo<SocketAddr>>()
.map(|ci| ci.0.ip().to_string())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::middleware::rate_limit::RateLimitStore;
#[test]
fn test_key_extractor_clone() {
let extractor = KeyExtractor::IpOnly;
let cloned = extractor.clone();
assert!(matches!(cloned, KeyExtractor::IpOnly));
let extractor = KeyExtractor::IpAndPath;
let cloned = extractor.clone();
assert!(matches!(cloned, KeyExtractor::IpAndPath));
}
#[test]
fn test_rate_limit_layer_creation() {
let store = RateLimitStore::new();
let config = RateLimitConfig {
limit: 10,
window_secs: 60,
};
let layer = RateLimitLayer::new(store.clone(), config.clone());
let _ = layer.with_key_extractor(KeyExtractor::IpOnly);
}
#[test]
fn test_build_rate_limit_key_connect_info() {
let mut req = Request::builder()
.uri("/login")
.body(Body::empty())
.unwrap();
let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
req.extensions_mut().insert(ConnectInfo(addr));
let key = build_rate_limit_key(&req, KeyExtractor::IpOnly, false);
assert_eq!(key, "127.0.0.1");
let key = build_rate_limit_key(&req, KeyExtractor::IpAndPath, false);
assert_eq!(key, "127.0.0.1:/login");
}
#[test]
fn test_build_rate_limit_key_trust_proxy() {
let mut req = Request::builder()
.uri("/login")
.body(Body::empty())
.unwrap();
req.headers_mut().insert(
"x-forwarded-for",
axum::http::HeaderValue::from_static("192.168.1.10"),
);
let key = build_rate_limit_key(&req, KeyExtractor::IpOnly, true);
assert_eq!(key, "192.168.1.10");
}
#[test]
fn test_build_rate_limit_key_trust_proxy_prefers_xff() {
let mut req = Request::builder()
.uri("/login")
.body(Body::empty())
.unwrap();
let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
req.extensions_mut().insert(ConnectInfo(addr));
req.headers_mut().insert(
"x-forwarded-for",
axum::http::HeaderValue::from_static("203.0.113.9"),
);
let key = build_rate_limit_key(&req, KeyExtractor::IpOnly, true);
assert_eq!(key, "203.0.113.9");
}
#[test]
fn test_build_rate_limit_key_missing_ip_fallback() {
let req = Request::builder()
.uri("/login")
.body(Body::empty())
.unwrap();
let key = build_rate_limit_key(&req, KeyExtractor::IpOnly, false);
assert_eq!(key, "unknown:/login");
let key = build_rate_limit_key(&req, KeyExtractor::IpAndPath, false);
assert_eq!(key, "unknown:/login");
}
}