use super::{AcceptHeader, BoxValidateRequestFn, ValidateRequest};
use crate::{Request, Response};
use rama_core::error::BoxError;
use rama_core::{Layer, Service};
use rama_http_types::mime::Mime;
use rama_utils::macros::define_inner_service_accessors;
#[derive(Debug, Clone)]
pub struct ValidateRequestHeaderLayer<T> {
pub(crate) validate: T,
}
impl<ResBody> ValidateRequestHeaderLayer<AcceptHeader<ResBody>> {
pub fn try_accept_for_str(value: &str) -> Result<Self, BoxError> {
Ok(Self::custom(AcceptHeader::try_new(value)?))
}
#[must_use]
pub fn accept(mime: Mime) -> Self {
Self::custom(AcceptHeader::new(mime))
}
}
impl<T> ValidateRequestHeaderLayer<T> {
pub fn custom(validate: T) -> Self {
Self { validate }
}
}
impl<F, A> ValidateRequestHeaderLayer<BoxValidateRequestFn<F, A>> {
pub fn custom_fn(validate: F) -> Self {
Self {
validate: BoxValidateRequestFn::new(validate),
}
}
}
impl<S, T> Layer<S> for ValidateRequestHeaderLayer<T>
where
T: Clone,
{
type Service = ValidateRequestHeader<S, T>;
fn layer(&self, inner: S) -> Self::Service {
ValidateRequestHeader::new(inner, self.validate.clone())
}
fn into_layer(self, inner: S) -> Self::Service {
ValidateRequestHeader::new(inner, self.validate)
}
}
#[derive(Debug, Clone)]
pub struct ValidateRequestHeader<S, T> {
inner: S,
pub(crate) validate: T,
}
impl<S, T> ValidateRequestHeader<S, T> {
fn new(inner: S, validate: T) -> Self {
Self::custom(inner, validate)
}
define_inner_service_accessors!();
}
impl<S, ResBody> ValidateRequestHeader<S, AcceptHeader<ResBody>> {
pub fn try_accept_for_str(inner: S, value: &str) -> Result<Self, BoxError> {
Ok(Self::custom(inner, AcceptHeader::try_new(value)?))
}
#[must_use]
pub fn accept(inner: S, mime: Mime) -> Self {
Self::custom(inner, AcceptHeader::new(mime))
}
}
impl<S, T> ValidateRequestHeader<S, T> {
pub fn custom(inner: S, validate: T) -> Self {
Self { inner, validate }
}
}
impl<S, F, A> ValidateRequestHeader<S, BoxValidateRequestFn<F, A>> {
pub fn custom_fn(inner: S, validate: F) -> Self {
Self {
inner,
validate: BoxValidateRequestFn::new(validate),
}
}
}
impl<ReqBody, ServiceResBody, ValidateResBody, S, V> Service<Request<ReqBody>>
for ValidateRequestHeader<S, V>
where
ReqBody: Send + 'static,
ServiceResBody: Send + 'static,
ValidateResBody: From<ServiceResBody> + Send + 'static,
V: ValidateRequest<ReqBody, ResponseBody = ValidateResBody>,
S: Service<Request<ReqBody>, Output = Response<ServiceResBody>>,
{
type Output = Response<ValidateResBody>;
type Error = S::Error;
async fn serve(&self, req: Request<ReqBody>) -> Result<Self::Output, Self::Error> {
match self.validate.validate(req).await {
Ok(req) => Ok(self.inner.serve(req).await?.map(ValidateResBody::from)),
Err(res) => Ok(res),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{Body, StatusCode, header};
use rama_core::{Layer, error::BoxError, service::service_fn};
use rama_http_types::mime::APPLICATION_JSON;
#[tokio::test]
async fn valid_accept_header() {
let service = ValidateRequestHeaderLayer::try_accept_for_str("application/json")
.unwrap()
.into_layer(service_fn(echo));
let request = Request::get("/")
.header(header::ACCEPT, "application/json")
.body(Body::empty())
.unwrap();
let res = service.serve(request).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}
#[tokio::test]
async fn valid_accept_header_with_mime() {
let service =
ValidateRequestHeaderLayer::accept(APPLICATION_JSON).into_layer(service_fn(echo));
let request = Request::get("/")
.header(header::ACCEPT, "application/json")
.body(Body::empty())
.unwrap();
let res = service.serve(request).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}
#[tokio::test]
async fn valid_accept_header_accept_all_json() {
let service = ValidateRequestHeaderLayer::try_accept_for_str("application/json")
.unwrap()
.into_layer(service_fn(echo));
let request = Request::get("/")
.header(header::ACCEPT, "application/*")
.body(Body::empty())
.unwrap();
let res = service.serve(request).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}
#[tokio::test]
async fn valid_accept_header_accept_all() {
let service = ValidateRequestHeaderLayer::try_accept_for_str("application/json")
.unwrap()
.into_layer(service_fn(echo));
let request = Request::get("/")
.header(header::ACCEPT, "*/*")
.body(Body::empty())
.unwrap();
let res = service.serve(request).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}
#[tokio::test]
async fn invalid_accept_header() {
let service =
ValidateRequestHeaderLayer::accept(APPLICATION_JSON).into_layer(service_fn(echo));
let request = Request::get("/")
.header(header::ACCEPT, "invalid")
.body(Body::empty())
.unwrap();
let res = service.serve(request).await.unwrap();
assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE);
}
#[tokio::test]
async fn not_accepted_accept_header_subtype() {
let service =
ValidateRequestHeaderLayer::accept(APPLICATION_JSON).into_layer(service_fn(echo));
let request = Request::get("/")
.header(header::ACCEPT, "application/strings")
.body(Body::empty())
.unwrap();
let res = service.serve(request).await.unwrap();
assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE);
}
#[tokio::test]
async fn not_accepted_accept_header() {
let service =
ValidateRequestHeaderLayer::accept(APPLICATION_JSON).into_layer(service_fn(echo));
let request = Request::get("/")
.header(header::ACCEPT, "text/strings")
.body(Body::empty())
.unwrap();
let res = service.serve(request).await.unwrap();
assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE);
}
#[tokio::test]
async fn accepted_multiple_header_value() {
let service =
ValidateRequestHeaderLayer::accept(APPLICATION_JSON).into_layer(service_fn(echo));
let request = Request::get("/")
.header(header::ACCEPT, "text/strings")
.header(header::ACCEPT, "invalid, application/json")
.body(Body::empty())
.unwrap();
let res = service.serve(request).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}
#[tokio::test]
async fn accepted_inner_header_value() {
let service =
ValidateRequestHeaderLayer::accept(APPLICATION_JSON).into_layer(service_fn(echo));
let request = Request::get("/")
.header(header::ACCEPT, "text/strings, invalid, application/json")
.body(Body::empty())
.unwrap();
let res = service.serve(request).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}
#[tokio::test]
async fn accepted_header_with_quotes_valid() {
let value = "foo/bar; parisien=\"baguette, text/html, jambon, fromage\", application/*";
let service = ValidateRequestHeaderLayer::try_accept_for_str("application/xml")
.unwrap()
.into_layer(service_fn(echo));
let request = Request::get("/")
.header(header::ACCEPT, value)
.body(Body::empty())
.unwrap();
let res = service.serve(request).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}
#[tokio::test]
async fn accepted_header_with_quotes_invalid() {
let value = "foo/bar; parisien=\"baguette, text/html, jambon, fromage\"";
let service = ValidateRequestHeaderLayer::try_accept_for_str("text/html")
.unwrap()
.into_layer(service_fn(echo));
let request = Request::get("/")
.header(header::ACCEPT, value)
.body(Body::empty())
.unwrap();
let res = service.serve(request).await.unwrap();
assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE);
}
async fn echo<B>(req: Request<B>) -> Result<Response<B>, BoxError> {
Ok(Response::new(req.into_body()))
}
}