#![cfg(feature = "axum")]
use axum::{
body::Body,
http::{Request, StatusCode},
routing::get,
Router,
};
use rok_rate_limit::{Limiter, ThrottleLayer, ThrottleRule};
use tower::ServiceExt;
/// Builds a minimal router exposing a single `GET /` route whose handler always
/// responds with the body `"ok"`.
fn ok_handler() -> Router {
    let respond_ok = || async { "ok" };
    Router::new().route("/", get(respond_ok))
}
/// A single request against a global rule with a limit of 10 is passed through to
/// the handler and answered with `200 OK`.
#[tokio::test]
async fn global_layer_allows_within_limit() {
    let store = Limiter::memory();
    let rule = ThrottleRule::global("test:allow", 10, 60);
    let throttled = ok_handler().layer(ThrottleLayer::new(store, vec![rule]));

    let request = Request::builder().uri("/").body(Body::empty()).unwrap();
    let response = throttled.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::OK);
}
/// With a global limit of 1, the first request is admitted and a second request
/// through a clone of the same layer is rejected with `429 Too Many Requests`.
#[tokio::test]
async fn global_layer_returns_429_when_exceeded() {
    let limiter = Limiter::memory();
    let rules = vec![ThrottleRule::global("test:exceed-gl", 1, 60)];
    let layer = ThrottleLayer::new(limiter.clone(), rules);

    // `oneshot` consumes the router, so a fresh one is built per request; the second
    // call can only be throttled if the layer clones share their counters.
    let expected_statuses = [StatusCode::OK, StatusCode::TOO_MANY_REQUESTS];
    for want in expected_statuses {
        let got = ok_handler()
            .layer(layer.clone())
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap()
            .status();
        assert_eq!(got, want);
    }
}
/// A request admitted by the throttle carries the `x-ratelimit-limit`,
/// `x-ratelimit-remaining` and `x-ratelimit-reset` response headers.
#[tokio::test]
async fn response_contains_ratelimit_headers_when_allowed() {
    let layer = ThrottleLayer::new(
        Limiter::memory(),
        vec![ThrottleRule::global("test:headers", 5, 60)],
    );
    let response = ok_handler()
        .layer(layer)
        .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
        .await
        .unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let headers = response.headers();
    let required = [
        ("x-ratelimit-limit", "should have x-ratelimit-limit"),
        ("x-ratelimit-remaining", "should have x-ratelimit-remaining"),
        ("x-ratelimit-reset", "should have x-ratelimit-reset"),
    ];
    for (name, message) in required {
        assert!(headers.contains_key(name), "{}", message);
    }
}
/// Once the global limit is exhausted, the `429 Too Many Requests` response carries
/// a `Retry-After` header.
///
/// The budget is consumed by sending a request through the layer itself, as the
/// other global-layer tests do. It is not consumed by calling
/// `Limiter::for_key(..).check()` directly, so the test does not depend on how
/// `ThrottleRule::global` maps its name onto the limiter's internal storage key.
#[tokio::test]
async fn response_contains_retry_after_when_exceeded() {
    let layer = ThrottleLayer::new(
        Limiter::memory(),
        vec![ThrottleRule::global("test:retry-hdr", 1, 60)],
    );

    // First request uses up the single permitted slot; assert it so a failure here is
    // reported as "limit not applied" rather than as a missing header.
    let first = ok_handler()
        .layer(layer.clone())
        .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
        .await
        .unwrap();
    assert_eq!(first.status(), StatusCode::OK);

    // Second request is over budget and must tell the client when to retry.
    let resp = ok_handler()
        .layer(layer)
        .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
        .await
        .unwrap();
    assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
    assert!(
        resp.headers().contains_key("retry-after"),
        "should have Retry-After header"
    );
}
/// `ThrottleRule::by_ip` keys its counter on the `x-forwarded-for` header.
///
/// With a limit of 1, a request from a second address is still admitted, but a
/// repeat request from the first address is rejected with `429`.
#[tokio::test]
async fn by_ip_layer_uses_forwarded_for_header() {
    let limiter = Limiter::memory();
    let layer = ThrottleLayer::new(limiter.clone(), vec![ThrottleRule::by_ip("test:ip", 1, 60)]);

    let request_from = |ip: &str| {
        Request::builder()
            .uri("/")
            .header("x-forwarded-for", ip)
            .body(Body::empty())
            .unwrap()
    };

    // (client address, status the layer should produce), in request order.
    let steps = [
        ("1.1.1.1", StatusCode::OK),
        ("2.2.2.2", StatusCode::OK),
        ("1.1.1.1", StatusCode::TOO_MANY_REQUESTS),
    ];
    for (ip, expected) in steps {
        let status = ok_handler()
            .layer(layer.clone())
            .oneshot(request_from(ip))
            .await
            .unwrap()
            .status();
        assert_eq!(status, expected);
    }
}