use http::StatusCode;
use rapina::middleware::{
BodyLimitMiddleware, CompressionConfig, CorsConfig, RateLimitConfig, RateLimitMiddleware,
TRACE_ID_HEADER, TimeoutMiddleware, TraceIdMiddleware,
};
use rapina::prelude::*;
use rapina::testing::TestClient;
use std::time::Duration;
#[tokio::test]
async fn test_middleware_execution() {
let app = Rapina::new()
.with_introspection(false)
.middleware(TraceIdMiddleware::new())
.router(Router::new().route(http::Method::GET, "/", |_, _, _| async { "Hello!" }));
let client = TestClient::new(app).await;
let response = client.get("/").send().await;
assert_eq!(response.status(), StatusCode::OK);
assert!(response.headers().get(TRACE_ID_HEADER).is_some());
}
#[tokio::test]
async fn test_trace_id_middleware_adds_header() {
let app = Rapina::new()
.with_introspection(false)
.middleware(TraceIdMiddleware::new())
.router(Router::new().route(http::Method::GET, "/health", |_, _, _| async { "ok" }));
let client = TestClient::new(app).await;
let response = client.get("/health").send().await;
assert_eq!(response.status(), StatusCode::OK);
let trace_id = response.headers().get(TRACE_ID_HEADER);
assert!(trace_id.is_some());
let trace_id_str = trace_id.unwrap().to_str().unwrap();
assert_eq!(trace_id_str.len(), 36);
}
#[tokio::test]
async fn test_trace_id_unique_per_request() {
let app = Rapina::new()
.with_introspection(false)
.middleware(TraceIdMiddleware::new())
.router(Router::new().route(http::Method::GET, "/", |_, _, _| async { "ok" }));
let client = TestClient::new(app).await;
let response1 = client.get("/").send().await;
let response2 = client.get("/").send().await;
let trace_id1 = response1
.headers()
.get(TRACE_ID_HEADER)
.unwrap()
.to_str()
.unwrap();
let trace_id2 = response2
.headers()
.get(TRACE_ID_HEADER)
.unwrap()
.to_str()
.unwrap();
assert_ne!(trace_id1, trace_id2);
}
#[tokio::test]
async fn test_timeout_middleware_passes_fast_request() {
let app = Rapina::new()
.with_introspection(false)
.middleware(TimeoutMiddleware::new(Duration::from_secs(5)))
.router(
Router::new().route(http::Method::GET, "/fast", |_, _, _| async {
"fast response"
}),
);
let client = TestClient::new(app).await;
let response = client.get("/fast").send().await;
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(response.text(), "fast response");
}
#[tokio::test]
async fn test_body_limit_middleware_allows_small_body() {
let app = Rapina::new()
.with_introspection(false)
.middleware(BodyLimitMiddleware::new(1024 * 1024)) .router(
Router::new().route(http::Method::POST, "/upload", |req, _, _| async move {
use http_body_util::BodyExt;
let body = req.into_body().collect().await.unwrap().to_bytes();
format!("Received {} bytes", body.len())
}),
);
let client = TestClient::new(app).await;
let response = client.post("/upload").body("small payload").send().await;
assert_eq!(response.status(), StatusCode::OK);
assert!(response.text().contains("13 bytes")); }
#[tokio::test]
async fn test_multiple_middlewares() {
let app = Rapina::new()
.with_introspection(false)
.middleware(TraceIdMiddleware::new())
.middleware(TimeoutMiddleware::new(Duration::from_secs(30)))
.middleware(BodyLimitMiddleware::new(1024 * 1024))
.router(Router::new().route(http::Method::GET, "/multi", |_, _, _| async { "ok" }));
let client = TestClient::new(app).await;
let response = client.get("/multi").send().await;
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(response.text(), "ok");
assert!(response.headers().get(TRACE_ID_HEADER).is_some());
}
#[tokio::test]
async fn test_middleware_order_trace_id_first() {
let app = Rapina::new()
.with_introspection(false)
.middleware(TraceIdMiddleware::new())
.middleware(TimeoutMiddleware::default())
.router(Router::new().route(http::Method::GET, "/", |_, _, _| async { "ok" }));
let client = TestClient::new(app).await;
let response = client.get("/").send().await;
assert_eq!(response.status(), StatusCode::OK);
assert!(response.headers().get(TRACE_ID_HEADER).is_some());
}
#[tokio::test]
async fn test_middleware_with_post_request() {
let app = Rapina::new()
.with_introspection(false)
.middleware(TraceIdMiddleware::new())
.router(
Router::new().route(http::Method::POST, "/data", |req, _, _| async move {
use http_body_util::BodyExt;
let body = req.into_body().collect().await.unwrap().to_bytes();
String::from_utf8_lossy(&body).to_string()
}),
);
let client = TestClient::new(app).await;
let response = client
.post("/data")
.json(&serde_json::json!({"key": "value"}))
.send()
.await;
assert_eq!(response.status(), StatusCode::OK);
assert!(response.headers().get(TRACE_ID_HEADER).is_some());
assert!(response.text().contains("key"));
}
#[tokio::test]
async fn test_default_timeout_middleware() {
let app = Rapina::new()
.with_introspection(false)
.middleware(TimeoutMiddleware::default()) .router(Router::new().route(http::Method::GET, "/", |_, _, _| async { "ok" }));
let client = TestClient::new(app).await;
let response = client.get("/").send().await;
assert_eq!(response.status(), StatusCode::OK);
}
#[tokio::test]
async fn test_default_body_limit_middleware() {
let app = Rapina::new()
.with_introspection(false)
.middleware(BodyLimitMiddleware::default()) .router(Router::new().route(http::Method::POST, "/", |_, _, _| async { "ok" }));
let client = TestClient::new(app).await;
let response = client.post("/").body("test").send().await;
assert_eq!(response.status(), StatusCode::OK);
}
#[tokio::test]
async fn test_middleware_preserves_response_body() {
let app = Rapina::new()
.with_introspection(false)
.middleware(TraceIdMiddleware::new())
.router(
Router::new().route(http::Method::GET, "/json", |_, _, _| async {
Json(serde_json::json!({
"status": "success",
"data": [1, 2, 3]
}))
}),
);
let client = TestClient::new(app).await;
let response = client.get("/json").send().await;
assert_eq!(response.status(), StatusCode::OK);
let json: serde_json::Value = response.json();
assert_eq!(json["status"], "success");
assert_eq!(json["data"], serde_json::json!([1, 2, 3]));
}
#[tokio::test]
async fn test_middleware_with_error_response() {
let app = Rapina::new()
.with_introspection(false)
.middleware(TraceIdMiddleware::new())
.router(
Router::new().route(http::Method::GET, "/error", |_, _, _| async {
Error::not_found("resource not found")
}),
);
let client = TestClient::new(app).await;
let response = client.get("/error").send().await;
assert_eq!(response.status(), StatusCode::NOT_FOUND);
assert!(response.headers().get(TRACE_ID_HEADER).is_some());
}
#[tokio::test]
async fn test_middleware_with_404() {
let app = Rapina::new()
.with_introspection(false)
.middleware(TraceIdMiddleware::new())
.router(Router::new().route(http::Method::GET, "/exists", |_, _, _| async { "ok" }));
let client = TestClient::new(app).await;
let response = client.get("/does-not-exist").send().await;
assert_eq!(response.status(), StatusCode::NOT_FOUND);
assert!(response.headers().get(TRACE_ID_HEADER).is_some());
}
#[tokio::test]
async fn test_cors_preflight_returns_204() {
let app = Rapina::new()
.with_introspection(false)
.with_cors(CorsConfig::permissive())
.router(Router::new().route(http::Method::GET, "/", |_, _, _| async { "ok" }));
let client = TestClient::new(app).await;
let response = client
.request(http::Method::OPTIONS, "/")
.header("Origin", "http://userapina.com")
.send()
.await;
assert_eq!(response.status(), StatusCode::NO_CONTENT);
assert!(
response
.headers()
.get("access-control-allow-origin")
.is_some()
);
assert!(
response
.headers()
.get("access-control-allow-methods")
.is_some()
);
}
#[tokio::test]
async fn test_cors_rejects_disallowed_origin() {
let app = Rapina::new()
.with_introspection(false)
.with_cors(CorsConfig::with_origins(vec![
"http://userapina.com".to_string(),
]))
.router(Router::new().route(http::Method::GET, "/", |_, _, _| async { "ok" }));
let client = TestClient::new(app).await;
let response = client
.request(http::Method::GET, "/")
.header("Origin", "http://evil.com")
.send()
.await;
assert_eq!(response.status(), StatusCode::OK);
assert!(
response
.headers()
.get("access-control-allow-origin")
.is_none()
);
}
#[tokio::test]
async fn test_cors_allows_matching_origin() {
let app = Rapina::new()
.with_introspection(false)
.with_cors(CorsConfig::with_origins(vec![
"http://userapina.com".to_string(),
]))
.router(Router::new().route(http::Method::GET, "/", |_, _, _| async { "ok" }));
let client = TestClient::new(app).await;
let response = client
.request(http::Method::GET, "/")
.header("Origin", "http://userapina.com")
.send()
.await;
assert_eq!(response.status(), StatusCode::OK);
let origin_header = response.headers().get("access-control-allow-origin");
assert!(origin_header.is_some());
assert_eq!(
origin_header.unwrap().to_str().unwrap(),
"http://userapina.com"
);
}
#[tokio::test]
async fn test_cors_permissive_returns_wildcard() {
let app = Rapina::new()
.with_introspection(false)
.with_cors(CorsConfig::permissive())
.router(Router::new().route(http::Method::GET, "/", |_, _, _| async { "ok" }));
let client = TestClient::new(app).await;
let response = client
.request(http::Method::OPTIONS, "/")
.header("Origin", "http://any.com")
.send()
.await;
let origin_header = response.headers().get("access-control-allow-origin");
assert_eq!(origin_header.unwrap().to_str().unwrap(), "*");
}
#[tokio::test]
async fn test_rate_limit_allows_under_limit() {
let app = Rapina::new()
.with_introspection(false)
.with_rate_limit(RateLimitConfig::new(100.0, 10)) .router(Router::new().route(http::Method::GET, "/", |_, _, _| async { "ok" }));
let client = TestClient::new(app).await;
for _ in 0..5 {
let response = client.get("/").send().await;
assert_eq!(response.status(), StatusCode::OK);
}
}
#[tokio::test]
async fn test_rate_limit_returns_429_when_exceeded() {
let app = Rapina::new()
.with_introspection(false)
.middleware(RateLimitMiddleware::new(RateLimitConfig::new(1.0, 2))) .router(Router::new().route(http::Method::GET, "/", |_, _, _| async { "ok" }));
let client = TestClient::new(app).await;
assert_eq!(client.get("/").send().await.status(), StatusCode::OK);
assert_eq!(client.get("/").send().await.status(), StatusCode::OK);
let response = client.get("/").send().await;
assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
}
#[tokio::test]
async fn test_rate_limit_includes_retry_after_header() {
let app = Rapina::new()
.with_introspection(false)
.with_rate_limit(RateLimitConfig::new(1.0, 1)) .router(Router::new().route(http::Method::GET, "/", |_, _, _| async { "ok" }));
let client = TestClient::new(app).await;
assert_eq!(client.get("/").send().await.status(), StatusCode::OK);
let response = client.get("/").send().await;
assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
let retry_after = response.headers().get("retry-after");
assert!(retry_after.is_some());
let retry_secs: u64 = retry_after.unwrap().to_str().unwrap().parse().unwrap();
assert!(retry_secs >= 1);
}
#[tokio::test]
async fn test_rate_limit_returns_json_error() {
let app = Rapina::new()
.with_introspection(false)
.enable_rfc7807_errors()
.rfc7807_base_uri("https://userapina.com/errors/")
.with_rate_limit(RateLimitConfig::new(1.0, 1))
.router(Router::new().route(http::Method::GET, "/", |_, _, _| async { "ok" }));
let client = TestClient::new(app).await;
client.get("/").send().await;
let response = client.get("/").send().await;
assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
let json: serde_json::Value = response.json();
assert_eq!(json["type"], "https://userapina.com/errors/rate-limited");
assert_eq!(json["title"], "Rate Limited");
assert_eq!(json["detail"], "too many requests");
assert!(json["trace_id"].is_string());
}
#[tokio::test]
async fn test_rate_limit_per_minute_convenience() {
let app = Rapina::new()
.with_introspection(false)
.with_rate_limit(RateLimitConfig::per_minute(60)) .router(Router::new().route(http::Method::GET, "/", |_, _, _| async { "ok" }));
let client = TestClient::new(app).await;
for i in 0..60 {
let response = client.get("/").send().await;
assert_eq!(
response.status(),
StatusCode::OK,
"Request {} should succeed",
i + 1
);
}
let response = client.get("/").send().await;
assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
}
#[tokio::test]
async fn test_compression_gzip() {
let large_body = "hello from rapina ".repeat(100);
let body_clone = large_body.clone();
let app = Rapina::new()
.with_introspection(false)
.with_compression(CompressionConfig::default())
.router(Router::new().route(http::Method::GET, "/", move |_, _, _| {
let body = body_clone.clone();
async move { body }
}));
let client = TestClient::new(app).await;
let response = client
.get("/")
.header("Accept-Encoding", "gzip")
.send()
.await;
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(response.headers().get("content-encoding").unwrap(), "gzip");
assert_eq!(response.headers().get("vary").unwrap(), "Accept-Encoding");
}
#[tokio::test]
async fn test_compression_skips_small_response() {
let app = Rapina::new()
.with_introspection(false)
.with_compression(CompressionConfig::default())
.router(Router::new().route(http::Method::GET, "/", |_, _, _| async { "small" }));
let client = TestClient::new(app).await;
let response = client
.get("/")
.header("Accept-Encoding", "gzip")
.send()
.await;
assert_eq!(response.status(), StatusCode::OK);
assert!(response.headers().get("content-encoding").is_none());
}
#[tokio::test]
async fn test_compression_skips_without_accept_encoding() {
let large_body = "hello from rapina ".repeat(100);
let body_clone = large_body.clone();
let app = Rapina::new()
.with_introspection(false)
.with_compression(CompressionConfig::default())
.router(Router::new().route(http::Method::GET, "/", move |_, _, _| {
let body = body_clone.clone();
async move { body }
}));
let client = TestClient::new(app).await;
let response = client.get("/").send().await;
assert_eq!(response.status(), StatusCode::OK);
assert!(response.headers().get("content-encoding").is_none());
}
#[tokio::test]
async fn test_trace_id_middleware_preserves_incoming_trace_id() {
let app = Rapina::new()
.with_introspection(false)
.middleware(TraceIdMiddleware::new())
.router(Router::new().route(http::Method::GET, "/health", |_, _, _| async { "ok" }));
let client = TestClient::new(app).await;
let custom_trace_id = "my-custom-trace-id-123";
let response = client
.get("/health")
.header("x-trace-id", custom_trace_id)
.send()
.await;
assert_eq!(response.status(), StatusCode::OK);
let header_value = response.headers().get(TRACE_ID_HEADER).unwrap();
assert_eq!(header_value.to_str().unwrap(), custom_trace_id);
}