use std::collections::VecDeque;
use std::convert::Infallible;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use bytes::Bytes;
use http::header::{
self, HeaderName, HeaderValue, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
ORIGIN,
};
use http_body_util::Full;
use hyper::{Method, Request, Response, StatusCode};
use serde_json::json;
use tower::{Layer, Service, ServiceBuilder, ServiceExt};
use tower_http::cors::{AllowOrigin, CorsLayer};
use tower_http::request_id::{
MakeRequestUuid, PropagateRequestIdLayer, RequestId, SetRequestIdLayer,
};
use hyperlite::{failure, success, ApiError, BoxBody, Router};
mod test_helpers;
use test_helpers::*;
/// Layer that stamps a fixed header onto every response produced by the
/// wrapped service.
///
/// The header is inserted after the inner service's future resolves, via
/// `HeaderMap::insert`, so it replaces any value already present under the
/// same name.
#[derive(Clone)]
struct AddHeaderLayer {
    name: HeaderName,
    value: HeaderValue,
}

impl AddHeaderLayer {
    /// Builds the layer from static strings.
    ///
    /// # Panics
    ///
    /// Panics if `name` is not a valid (lowercase) header name or `value` is
    /// not a valid header value, as enforced by `HeaderName::from_static` and
    /// `HeaderValue::from_static`.
    fn new(name: &'static str, value: &'static str) -> Self {
        let name = HeaderName::from_static(name);
        let value = HeaderValue::from_static(value);
        Self { name, value }
    }
}

/// Service produced by [`AddHeaderLayer`]; wraps `inner` and decorates its
/// responses.
#[derive(Clone)]
struct AddHeaderService<S> {
    inner: S,
    name: HeaderName,
    value: HeaderValue,
}

impl<S> Layer<S> for AddHeaderLayer
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible> + Clone,
{
    type Service = AddHeaderService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        let Self { name, value } = self;
        AddHeaderService {
            inner,
            name: name.clone(),
            value: value.clone(),
        }
    }
}

impl<S> Service<Request<BoxBody>> for AddHeaderService<S>
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
        + Send
        + 'static,
    S::Future: Send + 'static,
{
    type Response = Response<Full<Bytes>>;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Readiness is delegated entirely to the wrapped service.
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<BoxBody>) -> Self::Future {
        let (name, value) = (self.name.clone(), self.value.clone());
        let pending = self.inner.call(req);
        Box::pin(async move {
            pending.await.map(|mut response| {
                response.headers_mut().insert(name, value);
                response
            })
        })
    }
}
/// Layer that appends its `label` to a shared queue each time a request
/// passes through it. Tests use it to observe the order in which middleware
/// are entered.
#[derive(Clone)]
struct OrderLayer {
    label: &'static str,
    order: Arc<Mutex<VecDeque<&'static str>>>,
}

/// Service produced by [`OrderLayer`].
#[derive(Clone)]
struct OrderService<S> {
    inner: S,
    label: &'static str,
    order: Arc<Mutex<VecDeque<&'static str>>>,
}

impl OrderLayer {
    /// Creates a layer that pushes `label` onto `order` on every call.
    fn new(label: &'static str, order: Arc<Mutex<VecDeque<&'static str>>>) -> Self {
        Self { order, label }
    }
}

impl<S> Layer<S> for OrderLayer
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible> + Clone,
{
    type Service = OrderService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        OrderService {
            order: Arc::clone(&self.order),
            label: self.label,
            inner,
        }
    }
}

impl<S> Service<Request<BoxBody>> for OrderService<S>
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
        + Send
        + 'static,
    S::Future: Send + 'static,
{
    type Response = Response<Full<Bytes>>;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<BoxBody>) -> Self::Future {
        // Record synchronously inside `call`, before the inner service is
        // invoked, so the queue reflects the order the layers were entered.
        // The guard is released at the end of this scope.
        {
            let mut seen = self.order.lock().expect("order vec poisoned");
            seen.push_back(self.label);
        }
        Box::pin(self.inner.call(req))
    }
}
/// Layer that inserts a fixed header into every incoming request before it
/// reaches the wrapped service.
#[derive(Clone)]
struct RequestHeaderLayer {
    name: HeaderName,
    value: HeaderValue,
}

/// Service produced by [`RequestHeaderLayer`].
#[derive(Clone)]
struct RequestHeaderService<S> {
    inner: S,
    name: HeaderName,
    value: HeaderValue,
}

impl RequestHeaderLayer {
    /// Builds the layer from static strings.
    ///
    /// # Panics
    ///
    /// Panics if `name` or `value` is rejected by `HeaderName::from_static` /
    /// `HeaderValue::from_static`.
    fn new(name: &'static str, value: &'static str) -> Self {
        let (name, value) = (HeaderName::from_static(name), HeaderValue::from_static(value));
        Self { name, value }
    }
}

impl<S> Layer<S> for RequestHeaderLayer
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible> + Clone,
{
    type Service = RequestHeaderService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        let Self { name, value } = self;
        RequestHeaderService {
            inner,
            name: name.clone(),
            value: value.clone(),
        }
    }
}

impl<S> Service<Request<BoxBody>> for RequestHeaderService<S>
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
        + Send
        + 'static,
    S::Future: Send + 'static,
{
    type Response = Response<Full<Bytes>>;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, mut req: Request<BoxBody>) -> Self::Future {
        // `HeaderMap::insert` replaces any header already sent under `name`.
        let headers = req.headers_mut();
        headers.insert(self.name.clone(), self.value.clone());
        Box::pin(self.inner.call(req))
    }
}
/// Layer that adds a fixed header to every response once the wrapped service
/// has produced it.
#[derive(Clone)]
struct ResponseHeaderLayer {
    name: HeaderName,
    value: HeaderValue,
}

/// Service produced by [`ResponseHeaderLayer`].
#[derive(Clone)]
struct ResponseHeaderService<S> {
    inner: S,
    name: HeaderName,
    value: HeaderValue,
}

impl ResponseHeaderLayer {
    /// Builds the layer from static strings.
    ///
    /// # Panics
    ///
    /// Panics if `name` or `value` is rejected by `HeaderName::from_static` /
    /// `HeaderValue::from_static`.
    fn new(name: &'static str, value: &'static str) -> Self {
        let name = HeaderName::from_static(name);
        let value = HeaderValue::from_static(value);
        Self { name, value }
    }
}

impl<S> Layer<S> for ResponseHeaderLayer
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible> + Clone,
{
    type Service = ResponseHeaderService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        ResponseHeaderService {
            name: self.name.clone(),
            value: self.value.clone(),
            inner,
        }
    }
}

impl<S> Service<Request<BoxBody>> for ResponseHeaderService<S>
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
        + Send
        + 'static,
    S::Future: Send + 'static,
{
    type Response = Response<Full<Bytes>>;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<BoxBody>) -> Self::Future {
        let header_name = self.name.clone();
        let header_value = self.value.clone();
        let pending = self.inner.call(req);
        Box::pin(async move {
            match pending.await {
                Ok(mut response) => {
                    response.headers_mut().insert(header_name, header_value);
                    Ok(response)
                }
                Err(never) => Err(never),
            }
        })
    }
}
#[derive(Clone)]
struct ExtensionLayer<T>
where
T: Clone + Send + Sync + 'static,
{
value: T,
}
#[derive(Clone)]
struct ExtensionService<S, T>
where
S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>,
T: Clone + Send + Sync + 'static,
{
inner: S,
value: T,
}
impl<T> ExtensionLayer<T>
where
T: Clone + Send + Sync + 'static,
{
fn new(value: T) -> Self {
Self { value }
}
}
impl<S, T> Layer<S> for ExtensionLayer<T>
where
S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible> + Clone,
T: Clone + Send + Sync + 'static,
{
type Service = ExtensionService<S, T>;
fn layer(&self, inner: S) -> Self::Service {
ExtensionService {
inner,
value: self.value.clone(),
}
}
}
impl<S, T> Service<Request<BoxBody>> for ExtensionService<S, T>
where
S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
+ Send
+ 'static,
S::Future: Send + 'static,
T: Clone + Send + Sync + 'static,
{
type Response = Response<Full<Bytes>>;
type Error = Infallible;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: Request<BoxBody>) -> Self::Future {
req.extensions_mut().insert(self.value.clone());
let future = self.inner.call(req);
Box::pin(future)
}
}
/// Layer that calls `TestState::increment` once per request, before the
/// request reaches the wrapped service.
#[derive(Clone)]
struct StateIncrementLayer {
    state: TestState,
}

/// Service produced by [`StateIncrementLayer`].
#[derive(Clone)]
struct StateIncrementService<S> {
    inner: S,
    state: TestState,
}

impl StateIncrementLayer {
    /// Wraps `state`; each request through the resulting service increments it.
    fn new(state: TestState) -> Self {
        Self { state }
    }
}

impl<S> Layer<S> for StateIncrementLayer
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible> + Clone,
{
    type Service = StateIncrementService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        let state = self.state.clone();
        StateIncrementService { inner, state }
    }
}

impl<S> Service<Request<BoxBody>> for StateIncrementService<S>
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
        + Send
        + 'static,
    S::Future: Send + 'static,
{
    type Response = Response<Full<Bytes>>;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<BoxBody>) -> Self::Future {
        // The increment happens synchronously in `call`, before the inner
        // service is invoked.
        self.state.increment();
        let downstream = self.inner.call(req);
        Box::pin(downstream)
    }
}
/// Layer that admits a request only when its `Authorization` header equals
/// `expected` exactly. Otherwise it answers `401 Unauthorized` without
/// calling the wrapped service.
#[derive(Clone)]
struct AuthLayer {
    expected: &'static str,
}

/// Service produced by [`AuthLayer`].
#[derive(Clone)]
struct AuthService<S> {
    inner: S,
    expected: &'static str,
}

impl AuthLayer {
    /// Creates a layer that requires `Authorization: <expected>`.
    fn new(expected: &'static str) -> Self {
        Self { expected }
    }
}

impl<S> Layer<S> for AuthLayer
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible> + Clone,
{
    type Service = AuthService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        let expected = self.expected;
        AuthService { inner, expected }
    }
}

impl<S> Service<Request<BoxBody>> for AuthService<S>
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
        + Send
        + 'static,
    S::Future: Send + 'static,
{
    type Response = Response<Full<Bytes>>;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, mut req: Request<BoxBody>) -> Self::Future {
        // A header that is absent or not valid visible ASCII counts as "no token".
        let presented = req
            .headers()
            .get(header::AUTHORIZATION)
            .and_then(|value| value.to_str().ok());
        if presented != Some(self.expected) {
            // Short-circuit: the inner service is never called.
            return Box::pin(async move {
                Ok(failure(
                    StatusCode::UNAUTHORIZED,
                    vec![ApiError::new("UNAUTHORIZED", "missing token")],
                ))
            });
        }
        // Expose a fixed user id to downstream handlers via the extensions.
        req.extensions_mut().insert(String::from("user-123"));
        Box::pin(self.inner.call(req))
    }
}
/// Layer that appends a `"<METHOD> <path>"` line to a shared log for every
/// request it forwards.
#[derive(Clone)]
struct LoggingLayer {
    logs: Arc<Mutex<Vec<String>>>,
}

/// Service produced by [`LoggingLayer`].
#[derive(Clone)]
struct LoggingService<S> {
    inner: S,
    logs: Arc<Mutex<Vec<String>>>,
}

impl LoggingLayer {
    /// Creates a layer that writes into `logs`.
    fn new(logs: Arc<Mutex<Vec<String>>>) -> Self {
        Self { logs }
    }
}

impl<S> Layer<S> for LoggingLayer
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible> + Clone,
{
    type Service = LoggingService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        LoggingService {
            logs: Arc::clone(&self.logs),
            inner,
        }
    }
}

impl<S> Service<Request<BoxBody>> for LoggingService<S>
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
        + Send
        + 'static,
    S::Future: Send + 'static,
{
    type Response = Response<Full<Bytes>>;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<BoxBody>) -> Self::Future {
        // Capture what we need from the request before it is moved downstream.
        let entry = format!("{} {}", req.method(), req.uri().path());
        let sink = Arc::clone(&self.logs);
        let downstream = self.inner.call(req);
        Box::pin(async move {
            // Appended on first poll, before the inner future is driven. The
            // guard is a temporary and is dropped before the `.await`.
            sink.lock().expect("logs poisoned").push(entry);
            downstream.await
        })
    }
}
/// Layer that lets at most `limit` requests through in total. Every later
/// request gets `429 Too Many Requests`. The counter is shared by all
/// services produced from (or cloned from) the same layer and never resets.
#[derive(Clone)]
struct RateLimitLayer {
    limit: usize,
    counter: Arc<Mutex<usize>>,
}

/// Service produced by [`RateLimitLayer`].
#[derive(Clone)]
struct RateLimitService<S> {
    inner: S,
    limit: usize,
    counter: Arc<Mutex<usize>>,
}

impl RateLimitLayer {
    /// Creates a layer with a fresh counter starting at zero.
    fn new(limit: usize) -> Self {
        let counter = Arc::new(Mutex::new(0));
        Self { limit, counter }
    }
}

impl<S> Layer<S> for RateLimitLayer
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible> + Clone,
{
    type Service = RateLimitService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        RateLimitService {
            counter: Arc::clone(&self.counter),
            limit: self.limit,
            inner,
        }
    }
}

impl<S> Service<Request<BoxBody>> for RateLimitService<S>
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
        + Send
        + 'static,
    S::Future: Send + 'static,
{
    type Response = Response<Full<Bytes>>;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<BoxBody>) -> Self::Future {
        // Check and increment under one lock acquisition so two callers cannot
        // both observe the same count.
        let admitted = {
            let mut served = self.counter.lock().expect("rate limit poisoned");
            let under_limit = *served < self.limit;
            if under_limit {
                *served += 1;
            }
            under_limit
        };
        if admitted {
            Box::pin(self.inner.call(req))
        } else {
            Box::pin(async move {
                Ok(failure(
                    StatusCode::TOO_MANY_REQUESTS,
                    vec![ApiError::new("RATE_LIMIT", "Too many requests")],
                ))
            })
        }
    }
}
/// Stateless router with a single `GET /test` route that always answers
/// `200 OK` with the payload `"ok"`.
fn basic_router() -> Router<()> {
    let router = Router::new(());
    // The handler stays inline so its parameter and future types are
    // inferred from `route`'s expected handler type.
    router.route(
        "/test",
        Method::GET,
        Arc::new(|_, _| Box::pin(async { Ok(success(StatusCode::OK, "ok")) })),
    )
}
/// A single custom layer applied directly with `Layer::layer` adds its header.
#[tokio::test]
async fn test_router_with_single_middleware() {
    let layered = AddHeaderLayer::new("x-test", "one").layer(basic_router());
    let request = build_request(Method::GET, "/test", empty_body());
    let response = layered.oneshot(request).await.unwrap();
    let header = response.headers().get("x-test").unwrap();
    assert_eq!(header, "one");
}
/// Two stacked header layers both contribute their header to the response.
#[tokio::test]
async fn test_router_with_multiple_middleware() {
    let stack = ServiceBuilder::new()
        .layer(AddHeaderLayer::new("x-first", "a"))
        .layer(AddHeaderLayer::new("x-second", "b"));
    let service = stack.service(basic_router());
    let request = build_request(Method::GET, "/test", empty_body());
    let response = service.oneshot(request).await.unwrap();
    let headers = response.headers();
    assert_eq!(headers.get("x-first").unwrap(), "a");
    assert_eq!(headers.get("x-second").unwrap(), "b");
}
/// Layers added to a `ServiceBuilder` first are the outermost, so they see the
/// request before any layer added after them.
#[tokio::test]
async fn test_middleware_order() {
    let order = Arc::new(Mutex::new(VecDeque::new()));
    // `ServiceBuilder` wraps from the outside in: the first `.layer(...)` call
    // becomes the outermost service, so it is labelled "outer" here.
    let service = ServiceBuilder::new()
        .layer(OrderLayer::new("outer", order.clone()))
        .layer(OrderLayer::new("inner", order.clone()))
        .service(basic_router());
    service
        .oneshot(build_request(Method::GET, "/test", empty_body()))
        .await
        .unwrap();
    let recorded: Vec<_> = order.lock().unwrap().iter().copied().collect();
    assert_eq!(
        recorded,
        vec!["outer", "inner"],
        "expected the outer (first-added) middleware to run before the inner one"
    );
}
/// A simple GET from an allowed origin passes through the CORS layer and gets
/// an `Access-Control-Allow-Origin` header.
#[tokio::test]
async fn test_cors_middleware() {
    let origin = HeaderValue::from_static("http://localhost:3000");
    let cors = CorsLayer::new()
        .allow_origin(AllowOrigin::exact(origin.clone()))
        .allow_methods([Method::GET]);
    let service = ServiceBuilder::new().layer(cors).service(basic_router());

    let mut request = build_request(Method::GET, "/test", empty_body());
    request.headers_mut().insert(ORIGIN, origin);

    let response = service.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    assert!(response.headers().contains_key(ACCESS_CONTROL_ALLOW_ORIGIN));
}
/// An OPTIONS preflight for an allowed method is answered by the CORS layer
/// itself and advertises the allowed methods.
#[tokio::test]
async fn test_cors_preflight() {
    let cors = CorsLayer::new()
        .allow_origin(AllowOrigin::exact(HeaderValue::from_static(
            "http://localhost:3000",
        )))
        .allow_methods([Method::POST]);
    let service = ServiceBuilder::new().layer(cors).service(basic_router());

    let mut preflight = build_request(Method::OPTIONS, "/test", empty_body());
    let headers = preflight.headers_mut();
    headers.insert(ORIGIN, HeaderValue::from_static("http://localhost:3000"));
    headers.insert(
        header::ACCESS_CONTROL_REQUEST_METHOD,
        HeaderValue::from_static("POST"),
    );

    let response = service.oneshot(preflight).await.unwrap();
    let status = response.status();
    // Either success status is acceptable for a preflight answer.
    assert!(status == StatusCode::OK || status == StatusCode::NO_CONTENT);
    assert!(response
        .headers()
        .contains_key(ACCESS_CONTROL_ALLOW_METHODS));
}
/// A permissive CORS layer echoes an allow-origin header for any origin.
#[tokio::test]
async fn test_cors_headers() {
    let service = ServiceBuilder::new()
        .layer(CorsLayer::permissive())
        .service(basic_router());
    let mut request = build_request(Method::GET, "/test", empty_body());
    let origin = HeaderValue::from_static("http://example.com");
    request.headers_mut().insert(ORIGIN, origin);
    let response = service.oneshot(request).await.unwrap();
    assert!(response.headers().contains_key(ACCESS_CONTROL_ALLOW_ORIGIN));
}
/// Setting and propagating `x-request-id` puts the id on the response.
#[tokio::test]
async fn test_request_id_middleware() {
    let stack = ServiceBuilder::new()
        .layer(SetRequestIdLayer::x_request_id(MakeRequestUuid))
        .layer(PropagateRequestIdLayer::x_request_id());
    let request = build_request(Method::GET, "/test", empty_body());
    let response = stack
        .service(basic_router())
        .oneshot(request)
        .await
        .unwrap();
    assert!(response.headers().contains_key("x-request-id"));
}
#[tokio::test]
async fn test_request_id_propagation() {
let service = ServiceBuilder::new()
.layer(SetRequestIdLayer::x_request_id(MakeRequestUuid))
.layer(PropagateRequestIdLayer::x_request_id())
.service(Router::new(()).route(
"/test",
Method::GET,
Arc::new(|req, _| {
Box::pin(async move {
let request_id = req
.extensions()
.get::<RequestId>()
.and_then(|id| id.header_value().to_str().ok())
.unwrap_or("")
.to_string();
Ok(success(StatusCode::OK, json!({ "id": request_id })))
})
}),
));
let response = service
.oneshot(build_request(Method::GET, "/test", empty_body()))
.await
.unwrap();
let body = read_body_json(response).await;
assert!(!body["data"]["id"].as_str().unwrap().is_empty());
}
/// `SetRequestIdLayer` alone stores a `RequestId` in the request extensions.
#[tokio::test]
async fn test_request_id_in_extensions() {
    let router = Router::new(()).route(
        "/test",
        Method::GET,
        Arc::new(|req, _| {
            Box::pin(async move {
                assert!(req.extensions().get::<RequestId>().is_some());
                Ok(success(StatusCode::OK, "ok"))
            })
        }),
    );
    let service = ServiceBuilder::new()
        .layer(SetRequestIdLayer::x_request_id(MakeRequestUuid))
        .service(router);
    let request = build_request(Method::GET, "/test", empty_body());
    let response = service.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
}
/// A hand-written layer wraps the router and decorates the response.
#[tokio::test]
async fn test_custom_middleware_layer() {
    let layer = AddHeaderLayer::new("x-custom", "value");
    let service = layer.layer(basic_router());
    let response = service
        .oneshot(build_request(Method::GET, "/test", empty_body()))
        .await
        .unwrap();
    let custom = response.headers().get("x-custom").unwrap();
    assert_eq!(custom, "value");
}
/// Middleware that mutates shared state runs before the handler, which then
/// observes the updated value.
#[tokio::test]
async fn test_custom_middleware_state_access() {
    let state = TestState::new();
    let router = Router::new(state.clone()).route(
        "/test",
        Method::GET,
        Arc::new(|_, state| Box::pin(async move { Ok(success(StatusCode::OK, state.get())) })),
    );
    let service = StateIncrementLayer::new(state.clone()).layer(router);
    let request = build_request(Method::GET, "/test", empty_body());
    let response = service.oneshot(request).await.unwrap();
    let json = read_body_json(response).await;
    assert_eq!(json["data"], 1);
    assert_eq!(state.get(), 1);
}
/// A request header inserted by middleware is visible to the handler.
#[tokio::test]
async fn test_custom_middleware_request_modification() {
    let router = Router::new(()).route(
        "/test",
        Method::GET,
        Arc::new(|req, _| {
            Box::pin(async move {
                Ok(success(StatusCode::OK, req.headers().contains_key("x-added")))
            })
        }),
    );
    let service = RequestHeaderLayer::new("x-added", "yes").layer(router);
    let request = build_request(Method::GET, "/test", empty_body());
    let response = service.oneshot(request).await.unwrap();
    let json = read_body_json(response).await;
    assert_eq!(json["data"], true);
}
/// Response headers added by both the outer and the inner layer survive.
#[tokio::test]
async fn test_custom_middleware_response_modification() {
    // The outer layer ("x-after") inserts its header after the inner layer
    // ("x-before") has already inserted its own.
    let stack = ServiceBuilder::new()
        .layer(ResponseHeaderLayer::new("x-after", "2"))
        .layer(AddHeaderLayer::new("x-before", "1"));
    let response = stack
        .service(basic_router())
        .oneshot(build_request(Method::GET, "/test", empty_body()))
        .await
        .unwrap();
    let headers = response.headers();
    assert_eq!(headers.get("x-before").unwrap(), "1");
    assert_eq!(headers.get("x-after").unwrap(), "2");
}
/// Wrapping the router in middleware does not drop the state the router puts
/// into the request extensions.
///
/// NOTE(review): assumes `Router` inserts its state as `Arc<TestState>` and
/// that it shares storage with the handler's `state` argument; confirm
/// against `hyperlite::Router`.
#[tokio::test]
async fn test_middleware_preserves_extensions() {
    let state = TestState::new();
    let router = Router::new(state.clone()).route(
        "/test",
        Method::GET,
        Arc::new(|req, state| {
            Box::pin(async move {
                let ext_state: Arc<TestState> = req
                    .extensions()
                    .get()
                    .cloned()
                    .expect("state missing in extensions");
                ext_state.increment();
                Ok(success(StatusCode::OK, state.get()))
            })
        }),
    );
    let service = AddHeaderLayer::new("x-test", "1").layer(router);
    let request = build_request(Method::GET, "/test", empty_body());
    let response = service.oneshot(request).await.unwrap();
    let json = read_body_json(response).await;
    assert_eq!(json["data"], 1);
    assert_eq!(state.get(), 1);
}
/// A `String` extension inserted by middleware reaches the handler.
#[tokio::test]
async fn test_middleware_adds_extensions() {
    let router = Router::new(()).route(
        "/test",
        Method::GET,
        Arc::new(|req, _| {
            Box::pin(async move {
                let value: String = req.extensions().get().cloned().unwrap();
                Ok(success(StatusCode::OK, value))
            })
        }),
    );
    let service = ExtensionLayer::new(String::from("ext")).layer(router);
    let request = build_request(Method::GET, "/test", empty_body());
    let response = service.oneshot(request).await.unwrap();
    let json = read_body_json(response).await;
    assert_eq!(json["data"], "ext");
}
/// A `Copy` extension (`u32`) inserted by middleware reaches the handler.
#[tokio::test]
async fn test_router_extensions_accessible() {
    let router = Router::new(()).route(
        "/test",
        Method::GET,
        Arc::new(|req, _| {
            Box::pin(async move {
                let value = *req.extensions().get::<u32>().unwrap();
                Ok(success(StatusCode::OK, value))
            })
        }),
    );
    let service = ExtensionLayer::new(42u32).layer(router);
    let request = build_request(Method::GET, "/test", empty_body());
    let response = service.oneshot(request).await.unwrap();
    let json = read_body_json(response).await;
    assert_eq!(json["data"], 42);
}
/// `ServiceBuilder` composes multiple custom layers around the router.
#[tokio::test]
async fn test_service_builder_composition() {
    let builder = ServiceBuilder::new()
        .layer(AddHeaderLayer::new("x-one", "1"))
        .layer(AddHeaderLayer::new("x-two", "2"));
    let service = builder.service(basic_router());
    let request = build_request(Method::GET, "/test", empty_body());
    let response = service.oneshot(request).await.unwrap();
    let headers = response.headers();
    assert_eq!(headers.get("x-one").unwrap(), "1");
    assert_eq!(headers.get("x-two").unwrap(), "2");
}
/// A router built inline can be handed straight to `ServiceBuilder::service`.
#[tokio::test]
async fn test_service_builder_with_router() {
    let router = Router::new(()).route(
        "/test",
        Method::GET,
        Arc::new(|_, _| Box::pin(async { Ok(success(StatusCode::OK, "ok")) })),
    );
    let service = ServiceBuilder::new()
        .layer(AddHeaderLayer::new("x-layer", "value"))
        .service(router);
    let request = build_request(Method::GET, "/test", empty_body());
    let response = service.oneshot(request).await.unwrap();
    assert_eq!(response.headers().get("x-layer").unwrap(), "value");
}
/// Layers are entered in the order they were added to the `ServiceBuilder`.
#[tokio::test]
async fn test_service_builder_order() {
    let order = Arc::new(Mutex::new(VecDeque::new()));
    let first = OrderLayer::new("first", Arc::clone(&order));
    let second = OrderLayer::new("second", Arc::clone(&order));
    let service = ServiceBuilder::new()
        .layer(first)
        .layer(second)
        .service(basic_router());
    let request = build_request(Method::GET, "/test", empty_body());
    service.oneshot(request).await.unwrap();
    let recorded: Vec<&str> = order.lock().unwrap().iter().copied().collect();
    assert_eq!(recorded, vec!["first", "second"]);
}
#[tokio::test]
async fn test_middleware_error_handling() {
let service = ServiceBuilder::new()
.layer(AddHeaderLayer::new("x-layer", "present"))
.service(Router::new(()).route(
"/test",
Method::GET,
Arc::new(|_, _| {
Box::pin(async move {
Ok(failure(
StatusCode::INTERNAL_SERVER_ERROR,
vec![ApiError::new("ERROR", "failure")],
))
})
}),
));
let response = service
.oneshot(build_request(Method::GET, "/test", empty_body()))
.await
.unwrap();
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(response.headers().get("x-layer").unwrap(), "present");
}
/// A handler's error status passes through middleware unchanged.
#[tokio::test]
async fn test_middleware_error_propagation() {
    let failing_router = Router::new(()).route(
        "/test",
        Method::GET,
        Arc::new(|_, _| {
            Box::pin(async move {
                let errors = vec![ApiError::new("BAD", "oops")];
                Ok(failure(StatusCode::BAD_REQUEST, errors))
            })
        }),
    );
    let service = ServiceBuilder::new()
        .layer(AddHeaderLayer::new("x-layer", "present"))
        .service(failing_router);
    let request = build_request(Method::GET, "/test", empty_body());
    let response = service.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}
/// An auth layer forwards requests with the right token (injecting the user)
/// and rejects requests without one.
#[tokio::test]
async fn test_auth_middleware_pattern() {
    let router = Router::new(()).route(
        "/protected",
        Method::GET,
        Arc::new(|req, _| {
            Box::pin(async move {
                let user: String = req.extensions().get().cloned().unwrap();
                Ok(success(StatusCode::OK, user))
            })
        }),
    );
    let service = AuthLayer::new("Bearer token").layer(router);

    // Authorised request: the handler sees the injected user id.
    let mut authorised = build_request(Method::GET, "/protected", empty_body());
    let token = HeaderValue::from_static("Bearer token");
    authorised.headers_mut().insert(header::AUTHORIZATION, token);
    let response = service.clone().oneshot(authorised).await.unwrap();
    let json = read_body_json(response).await;
    assert_eq!(json["data"], "user-123");

    // No Authorization header: rejected before reaching the handler.
    let anonymous = build_request(Method::GET, "/protected", empty_body());
    let unauthorized = service.oneshot(anonymous).await.unwrap();
    assert_eq!(unauthorized.status(), StatusCode::UNAUTHORIZED);
}
/// The logging layer records exactly one `"<METHOD> <path>"` entry per request.
#[tokio::test]
async fn test_logging_middleware_pattern() {
    let logs = Arc::new(Mutex::new(Vec::new()));
    let service = LoggingLayer::new(Arc::clone(&logs)).layer(basic_router());
    let request = build_request(Method::GET, "/test", empty_body());
    service.oneshot(request).await.unwrap();
    let entries = logs.lock().unwrap();
    assert_eq!(*entries, ["GET /test"]);
}
/// With a limit of two, the first two requests succeed and the third is
/// rejected with `429`. Clones of the service share one counter.
#[tokio::test]
async fn test_rate_limiting_pattern() {
    let service = RateLimitLayer::new(2).layer(basic_router());
    for _ in 0..2 {
        let allowed = service
            .clone()
            .oneshot(build_request(Method::GET, "/test", empty_body()))
            .await
            .unwrap();
        assert_eq!(allowed.status(), StatusCode::OK);
    }
    let third = service
        .oneshot(build_request(Method::GET, "/test", empty_body()))
        .await
        .unwrap();
    assert_eq!(third.status(), StatusCode::TOO_MANY_REQUESTS);
}