use std::sync::Arc;
use std::task::{Context, Poll};
use futures_util::future::BoxFuture;
use tower::Layer;
use tower_service::Service;
use crate::rate_limit::{RateLimitInfo, RateLimitKey, RateLimiter};
/// Derives, from an incoming request, the identifiers a rate limiter counts against.
///
/// Both methods are generic over the request body type `B`, so an extractor works
/// with any body type; a consequence of those generic methods is that the trait
/// cannot be used as a `dyn KeyExtractor` trait object. The `Send + Sync + 'static`
/// bounds allow one extractor to be shared (behind an `Arc`) by services on any thread.
pub trait KeyExtractor: Send + Sync + 'static {
/// Returns the key identifying who is making the request (for example, the client IP).
fn extract_key<B>(&self, req: &http::Request<B>) -> RateLimitKey;
/// Returns the endpoint name handed to the limiter together with the key; the limiter
/// presumably uses it to pick a per-endpoint limit (see `RateLimiter::check`).
fn extract_endpoint<B>(&self, req: &http::Request<B>) -> String;
}
/// Keys requests by client IP address read from reverse-proxy headers.
///
/// The client address is the leftmost entry of `X-Forwarded-For`. If that header is
/// missing, is not valid visible ASCII, or its first entry is blank, the trimmed
/// `X-Real-IP` header is used instead. If neither yields a non-blank value the request
/// is keyed as `"unknown"`, so all such requests share a single bucket.
///
/// The endpoint is the last non-empty path segment (`/oauth/token` and `/oauth/token/`
/// both map to `"token"`), or `"unknown"` when the path has no non-empty segment.
///
/// # Security
///
/// Both headers are chosen by whoever sends the request. Unless a trusted reverse proxy
/// overwrites them, a client can present a different address on every request and
/// never reach its limit. Only rely on this extractor behind such a proxy.
#[derive(Debug, Clone, Default)]
pub struct IpKeyExtractor;

impl KeyExtractor for IpKeyExtractor {
    fn extract_key<B>(&self, req: &http::Request<B>) -> RateLimitKey {
        let headers = req.headers();
        // Blank entries (e.g. `X-Forwarded-For: , 10.0.0.1`) are not addresses; treating
        // them as one would lump unrelated clients under the empty key.
        let forwarded_for = headers
            .get("x-forwarded-for")
            .and_then(|value| value.to_str().ok())
            .and_then(|list| list.split(',').next())
            .map(str::trim)
            .filter(|ip| !ip.is_empty());
        let ip = forwarded_for
            .or_else(|| {
                headers
                    .get("x-real-ip")
                    .and_then(|value| value.to_str().ok())
                    .map(str::trim)
                    .filter(|ip| !ip.is_empty())
            })
            .unwrap_or("unknown");
        RateLimitKey::ip(ip)
    }

    fn extract_endpoint<B>(&self, req: &http::Request<B>) -> String {
        // Skip empty segments so a trailing slash cannot move a request out of its
        // endpoint's limit and into the `"unknown"` bucket.
        // NOTE(review): only the last segment is used, so `/a/token` and `/b/token`
        // share an endpoint and `/users/42` yields `"42"` — confirm this is intended.
        req.uri()
            .path()
            .rsplit('/')
            .find(|segment| !segment.is_empty())
            .unwrap_or("unknown")
            .to_string()
    }
}
/// Tower [`Layer`] that wraps services in a [`RateLimitService`].
///
/// `K` decides how each request is mapped to a rate-limit key and endpoint and
/// defaults to [`IpKeyExtractor`]. The extractor is stored behind an `Arc`, so every
/// service produced by this layer shares the same extractor instance.
#[derive(Debug)]
pub struct RateLimitLayer<K = IpKeyExtractor> {
    limiter: RateLimiter,
    key_extractor: Arc<K>,
}

// Implemented by hand: `#[derive(Clone)]` would add a `K: Clone` bound even though only
// the `Arc<K>` handle is cloned, making layers with non-`Clone` extractors un-cloneable.
impl<K> Clone for RateLimitLayer<K> {
    fn clone(&self) -> Self {
        Self {
            limiter: self.limiter.clone(),
            key_extractor: Arc::clone(&self.key_extractor),
        }
    }
}
impl RateLimitLayer<IpKeyExtractor> {
pub fn new(limiter: RateLimiter) -> Self {
Self {
limiter,
key_extractor: Arc::new(IpKeyExtractor),
}
}
}
impl<K: KeyExtractor> RateLimitLayer<K> {
    /// Creates a layer that derives rate-limit keys and endpoints with the given
    /// extractor, which is moved into a shared `Arc`.
    pub fn with_key_extractor(limiter: RateLimiter, key_extractor: K) -> Self {
        let key_extractor = Arc::new(key_extractor);
        Self { limiter, key_extractor }
    }
}
/// Tower service that checks every request against a [`RateLimiter`] before
/// forwarding it to the wrapped `inner` service.
///
/// Usually produced by [`RateLimitLayer`]. `K` selects the key extractor and defaults
/// to [`IpKeyExtractor`]; it is held behind an `Arc` shared with the layer.
#[derive(Debug)]
pub struct RateLimitService<S, K = IpKeyExtractor> {
    inner: S,
    limiter: RateLimiter,
    key_extractor: Arc<K>,
}

// Implemented by hand: `#[derive(Clone)]` would also require `K: Clone`, although only
// the `Arc<K>` handle is cloned.
impl<S: Clone, K> Clone for RateLimitService<S, K> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            limiter: self.limiter.clone(),
            key_extractor: Arc::clone(&self.key_extractor),
        }
    }
}
impl<S, K> RateLimitService<S, K> {
    /// Wraps `inner` so that requests are checked against `limiter`, using the keys and
    /// endpoints produced by `key_extractor`, before being forwarded.
    pub fn new(inner: S, limiter: RateLimiter, key_extractor: Arc<K>) -> Self {
        Self { inner, limiter, key_extractor }
    }

    /// Borrows the wrapped service.
    pub fn inner(&self) -> &S {
        let Self { inner, .. } = self;
        inner
    }

    /// Mutably borrows the wrapped service.
    pub fn inner_mut(&mut self) -> &mut S {
        let Self { inner, .. } = self;
        inner
    }
}
/// Wraps `inner` in a [`RateLimitService`] holding a clone of this layer's limiter and
/// a shared handle to its key extractor.
///
/// Only the `Arc` around the extractor is cloned, so `K` itself need not be `Clone`.
impl<S, K: KeyExtractor> Layer<S> for RateLimitLayer<K> {
    type Service = RateLimitService<S, K>;

    fn layer(&self, inner: S) -> Self::Service {
        RateLimitService::new(inner, self.limiter.clone(), Arc::clone(&self.key_extractor))
    }
}
impl<S, K, B> Service<http::Request<B>> for RateLimitService<S, K>
where
    S: Service<http::Request<B>> + Clone + Send + 'static,
    S::Response: Send,
    S::Error: Send,
    S::Future: Send,
    K: KeyExtractor,
    B: Send + 'static,
{
    /// Rejections (rate-limited or inner failures) are reported inside the response, so
    /// this middleware itself never fails.
    type Response = Result<S::Response, RateLimitRejection<S::Error>>;
    type Error = std::convert::Infallible;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Readiness is delegated to the inner service. `Self::Error` is `Infallible`, so an
        // inner readiness error cannot be surfaced here; it is dropped and the service
        // still reports itself ready.
        // NOTE(review): tower's `Service` docs treat a service whose `poll_ready` failed as
        // unusable — confirm that calling `inner` afterwards is acceptable for the
        // services this middleware wraps.
        self.inner.poll_ready(cx).map(|_inner_readiness| Ok(()))
    }

    fn call(&mut self, req: http::Request<B>) -> Self::Future {
        // Identifiers are extracted up front, before `req` moves into the future.
        let key = self.key_extractor.extract_key(&req);
        let endpoint = self.key_extractor.extract_endpoint(&req);
        let limiter = self.limiter.clone();

        // `self.inner` is the instance `poll_ready` drove to readiness, so that is the one
        // the future must call; a fresh clone takes its place for the next request.
        let replacement = self.inner.clone();
        let mut ready_inner = std::mem::replace(&mut self.inner, replacement);

        Box::pin(async move {
            if let Err(info) = limiter.check(&key, &endpoint).await {
                return Ok(Err(RateLimitRejection::RateLimited(info)));
            }
            Ok(ready_inner.call(req).await.map_err(RateLimitRejection::Inner))
        })
    }
}
/// Failure outcome produced by `RateLimitService` for a request it did not pass through
/// successfully.
#[derive(Debug)]
pub enum RateLimitRejection<E> {
/// The limiter's `check` refused the request, so the inner service was never called;
/// carries the `RateLimitInfo` the limiter returned.
RateLimited(RateLimitInfo),
/// The inner service was called and returned this error.
Inner(E),
}
impl<E: std::fmt::Display> std::fmt::Display for RateLimitRejection<E> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::RateLimited(info) => write!(f, "Rate limited: {info}"),
Self::Inner(e) => write!(f, "{e}"),
}
}
}
/// Exposes the underlying cause for error-chain reporting.
///
/// `Display` already includes the rate-limit details and the inner error's message, so
/// `source` does not hand the same value out again (which would make chain reporters
/// print it twice). `RateLimited` has no further cause. `Inner` is transparent: it
/// forwards to the inner error's own `source`, just as `Display` forwards to its message.
impl<E: std::error::Error + 'static> std::error::Error for RateLimitRejection<E> {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::RateLimited(_) => None,
            Self::Inner(e) => e.source(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rate_limit::{EndpointLimit, RateLimitConfig};
    use std::convert::Infallible;
    use std::time::Duration;
    use tower::ServiceExt;

    /// Waits for `service` to become ready and then calls it, honouring the tower
    /// `Service` contract (`poll_ready` before `call`). The middleware's outer result is
    /// `Infallible`, so it is unwrapped here.
    async fn call_when_ready<S>(service: &mut S, req: http::Request<()>) -> S::Response
    where
        S: Service<http::Request<()>, Error = Infallible>,
    {
        let ready = ServiceExt::<http::Request<()>>::ready(service).await.unwrap();
        ready.call(req).await.unwrap()
    }

    /// Builds a body-less request to `uri` whose `X-Forwarded-For` header is `ip`.
    fn request_from(uri: &str, ip: &str) -> http::Request<()> {
        http::Request::builder()
            .uri(uri)
            .header("x-forwarded-for", ip)
            .body(())
            .unwrap()
    }

    /// A limiter admitting a single request per minute to the `test` endpoint.
    fn one_request_per_minute_on_test() -> RateLimiter {
        RateLimiter::new(
            RateLimitConfig::builder()
                .endpoint_limit(
                    "test",
                    EndpointLimit {
                        requests: 1,
                        window: Duration::from_secs(60),
                        burst: 0,
                    },
                )
                .build(),
        )
    }

    #[test]
    fn test_ip_key_extractor_x_forwarded_for() {
        let extractor = IpKeyExtractor;
        let req = http::Request::builder()
            .header("x-forwarded-for", "192.168.1.100, 10.0.0.1")
            .body(())
            .unwrap();
        let key = extractor.extract_key(&req);
        assert_eq!(key.key_type, "ip");
        assert_eq!(key.value, "192.168.1.100");
    }

    #[test]
    fn test_ip_key_extractor_x_real_ip() {
        let extractor = IpKeyExtractor;
        let req = http::Request::builder()
            .header("x-real-ip", "10.20.30.40")
            .body(())
            .unwrap();
        let key = extractor.extract_key(&req);
        assert_eq!(key.key_type, "ip");
        assert_eq!(key.value, "10.20.30.40");
    }

    #[test]
    fn test_ip_key_extractor_no_headers() {
        let extractor = IpKeyExtractor;
        let req = http::Request::builder().body(()).unwrap();
        let key = extractor.extract_key(&req);
        assert_eq!(key.key_type, "ip");
        assert_eq!(key.value, "unknown");
    }

    #[test]
    fn test_ip_key_extractor_endpoint_from_path() {
        let extractor = IpKeyExtractor;
        let req = http::Request::builder()
            .uri("/oauth/token")
            .body(())
            .unwrap();
        let endpoint = extractor.extract_endpoint(&req);
        assert_eq!(endpoint, "token");
    }

    #[test]
    fn test_ip_key_extractor_endpoint_root_path() {
        let extractor = IpKeyExtractor;
        let req = http::Request::builder().uri("/").body(()).unwrap();
        let endpoint = extractor.extract_endpoint(&req);
        assert_eq!(endpoint, "unknown");
    }

    #[test]
    fn test_layer_creation() {
        let limiter = RateLimiter::for_auth();
        let _layer = RateLimitLayer::new(limiter);
    }

    #[test]
    fn test_layer_with_custom_extractor() {
        struct TestExtractor;
        impl KeyExtractor for TestExtractor {
            fn extract_key<B>(&self, _req: &http::Request<B>) -> RateLimitKey {
                RateLimitKey::ip("test")
            }
            fn extract_endpoint<B>(&self, _req: &http::Request<B>) -> String {
                "test".to_string()
            }
        }
        let limiter = RateLimiter::for_auth();
        let _layer = RateLimitLayer::with_key_extractor(limiter, TestExtractor);
    }

    #[tokio::test]
    async fn test_service_allows_under_limit() {
        let limiter = RateLimiter::new(
            RateLimitConfig::builder()
                .default_limit(100, Duration::from_secs(60))
                .build(),
        );
        let inner_service = tower::service_fn(|_req: http::Request<()>| async move {
            Ok::<_, Infallible>(http::Response::new(()))
        });
        let mut service = RateLimitService::new(inner_service, limiter, Arc::new(IpKeyExtractor));
        let result = call_when_ready(&mut service, request_from("/", "192.168.1.1")).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_service_blocks_over_limit() {
        let limiter = one_request_per_minute_on_test();
        let inner_service = tower::service_fn(|_req: http::Request<()>| async move {
            Ok::<_, Infallible>(http::Response::new(()))
        });
        let mut service = RateLimitService::new(inner_service, limiter, Arc::new(IpKeyExtractor));

        let first = call_when_ready(&mut service, request_from("/test", "192.168.1.1")).await;
        assert!(first.is_ok());

        let second = call_when_ready(&mut service, request_from("/test", "192.168.1.1")).await;
        assert!(second.is_err());
        match second {
            Err(RateLimitRejection::RateLimited(info)) => {
                assert_eq!(info.limit, 1);
                assert!(info.retry_after.as_secs() > 0);
            }
            _ => panic!("Expected RateLimited error"),
        }
    }

    #[tokio::test]
    async fn test_service_different_ips_different_limits() {
        let limiter = one_request_per_minute_on_test();
        let inner_service = tower::service_fn(|_req: http::Request<()>| async move {
            Ok::<_, Infallible>(http::Response::new(()))
        });
        let mut service = RateLimitService::new(inner_service, limiter, Arc::new(IpKeyExtractor));

        let first = call_when_ready(&mut service, request_from("/test", "192.168.1.1")).await;
        assert!(first.is_ok());

        let other_client =
            call_when_ready(&mut service, request_from("/test", "192.168.1.2")).await;
        assert!(other_client.is_ok());
    }

    #[test]
    fn test_rate_limit_rejection_display() {
        let info = RateLimitInfo {
            retry_after: Duration::from_secs(30),
            current_count: 5,
            limit: 3,
            window: Duration::from_secs(60),
        };
        let rejection = RateLimitRejection::<std::io::Error>::RateLimited(info);
        let display = format!("{rejection}");
        assert!(display.contains("Rate limited"));
    }

    #[test]
    fn test_rate_limit_rejection_inner_display() {
        use std::io;
        let err = io::Error::other("test error");
        let rejection = RateLimitRejection::Inner(err);
        let display = format!("{rejection}");
        assert!(display.contains("test error"));
    }
}