use crate::{HeaderName, Request, utils::HeaderValueGetter};
use rama_core::error::BoxErrorExt as _;
use rama_core::{
Layer, Service,
error::{BoxError, ErrorContext as _, ErrorExt},
extensions::{Extension, ExtensionsRef},
telemetry::tracing,
};
use rama_utils::macros::define_inner_service_accessors;
use std::{fmt, marker::PhantomData};
/// [`Service`] that interprets a request header as an on/off option.
///
/// When the header holds an "enabled" value, `T::default()` is inserted into
/// the request extensions before the request is handed to the inner service.
/// The `optional` flag decides whether a missing header is tolerated.
pub struct HeaderOptionValueService<T, S> {
    inner: S,
    header_name: HeaderName,
    optional: bool,
    _marker: PhantomData<fn() -> T>,
}

impl<T, S> HeaderOptionValueService<T, S> {
    /// Wraps `inner`, reading the option from `header_name`.
    ///
    /// With `optional == true` a request without the header is forwarded
    /// unchanged; otherwise such a request is rejected with an error.
    pub const fn new(inner: S, header_name: HeaderName, optional: bool) -> Self {
        Self {
            _marker: PhantomData,
            optional,
            header_name,
            inner,
        }
    }

    /// Same as [`Self::new`] with `optional = false`: the header must be present.
    pub const fn required(inner: S, header_name: HeaderName) -> Self {
        Self::new(inner, header_name, false)
    }

    /// Same as [`Self::new`] with `optional = true`: the header may be absent.
    pub const fn optional(inner: S, header_name: HeaderName) -> Self {
        Self::new(inner, header_name, true)
    }

    // Accessors for the wrapped `inner` service, generated by the macro.
    define_inner_service_accessors!();
}
impl<T, S: fmt::Debug> fmt::Debug for HeaderOptionValueService<T, S> {
    /// Debug-prints the service; `T` is only a marker, so its type name is
    /// shown in place of a value.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let marker_type = std::any::type_name::<fn() -> T>();
        let mut debug = f.debug_struct("HeaderOptionValueService");
        debug.field("inner", &self.inner);
        debug.field("header_name", &self.header_name);
        debug.field("optional", &self.optional);
        debug.field("_marker", &format_args!("{marker_type}"));
        debug.finish()
    }
}
impl<T, S> Clone for HeaderOptionValueService<T, S>
where
S: Clone,
{
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
header_name: self.header_name.clone(),
optional: self.optional,
_marker: PhantomData,
}
}
}
impl<T, S, Body, E> Service<Request<Body>> for HeaderOptionValueService<T, S>
where
    S: Service<Request<Body>, Error = E>,
    T: Default + Extension,
    Body: Send + Sync + 'static,
    E: Into<BoxError> + Send + Sync + 'static,
{
    type Output = S::Output;
    type Error = BoxError;

    /// Evaluates the configured header as a flag, then forwards to the inner service.
    ///
    /// After trimming surrounding whitespace:
    /// - `1` or `true` (ASCII case-insensitive) inserts `T::default()` into
    ///   the request extensions;
    /// - `0` or `false` (ASCII case-insensitive) leaves the extensions as-is;
    /// - any other value is rejected with an error.
    ///
    /// If the header cannot be read, an error is returned. The one exception
    /// is an optional service whose header is simply missing: that request is
    /// forwarded untouched.
    async fn serve(&self, request: Request<Body>) -> Result<Self::Output, Self::Error> {
        // `Some(true)` for enabled values, `Some(false)` for disabled ones,
        // `None` for anything unrecognised.
        fn parse_flag(value: &str) -> Option<bool> {
            if value == "1" || value.eq_ignore_ascii_case("true") {
                Some(true)
            } else if value == "0" || value.eq_ignore_ascii_case("false") {
                Some(false)
            } else {
                None
            }
        }

        let enabled = match request.header_str(&self.header_name) {
            Ok(raw_value) => {
                let value = raw_value.trim();
                let Some(flag) = parse_flag(value) else {
                    return Err(BoxError::from_static_str("invalid header option")
                        .context_field("header_name", self.header_name.clone())
                        .context_str_field("header_value", value));
                };
                flag
            }
            Err(err @ crate::utils::HeaderValueErr::HeaderMissing(_)) if self.optional => {
                tracing::debug!(
                    http.header.name = %self.header_name,
                    "failed to determine header option: {err:?}",
                );
                false
            }
            Err(err) => {
                return Err(err
                    .context("determine header option")
                    .context_field("header_name", self.header_name.clone()));
            }
        };

        if enabled {
            request.extensions().insert(T::default());
        }
        self.inner.serve(request).await.into_box_error()
    }
}
/// [`Layer`] that wraps services in a [`HeaderOptionValueService`] reading
/// the option from `header_name`.
pub struct HeaderOptionValueLayer<T> {
    header_name: HeaderName,
    optional: bool,
    _marker: PhantomData<fn() -> T>,
}

impl<T> fmt::Debug for HeaderOptionValueLayer<T> {
    /// Debug-prints the layer configuration plus the marker type's name.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self {
            header_name,
            optional,
            ..
        } = self;
        let marker_type = std::any::type_name::<fn() -> T>();
        f.debug_struct("HeaderOptionValueLayer")
            .field("header_name", header_name)
            .field("optional", optional)
            .field("_marker", &format_args!("{marker_type}"))
            .finish()
    }
}
// Hand-written so the layer is cloneable regardless of whether the marker
// type `T` implements `Clone`.
impl<T> Clone for HeaderOptionValueLayer<T> {
    fn clone(&self) -> Self {
        let header_name = self.header_name.clone();
        Self {
            _marker: PhantomData,
            optional: self.optional,
            header_name,
        }
    }
}
impl<T> HeaderOptionValueLayer<T> {
    /// Creates a layer whose services require `header_name` to be present.
    ///
    /// A request that lacks the header, whose header cannot be read, or whose
    /// header holds anything other than `1`/`true`/`0`/`false` is rejected.
    ///
    /// `const`, like the [`HeaderOptionValueService`] constructors, so the
    /// layer can also be built in constant contexts.
    pub const fn required(header_name: HeaderName) -> Self {
        Self {
            header_name,
            optional: false,
            _marker: PhantomData,
        }
    }

    /// Creates a layer whose services tolerate a missing `header_name`.
    ///
    /// A request without the header is forwarded untouched. A header that is
    /// present must still hold a valid option value.
    pub const fn optional(header_name: HeaderName) -> Self {
        Self {
            header_name,
            optional: true,
            _marker: PhantomData,
        }
    }
}
impl<T, S> Layer<S> for HeaderOptionValueLayer<T> {
    type Service = HeaderOptionValueService<T, S>;

    /// Wraps `inner`, cloning the configured header name so the layer stays usable.
    fn layer(&self, inner: S) -> Self::Service {
        let header_name = self.header_name.clone();
        HeaderOptionValueService::new(inner, header_name, self.optional)
    }

    /// Wraps `inner`, moving the header name out of the consumed layer (no clone).
    fn into_layer(self, inner: S) -> Self::Service {
        let Self {
            header_name,
            optional,
            ..
        } = self;
        HeaderOptionValueService::new(inner, header_name, optional)
    }
}
#[cfg(test)]
mod test {
    use rama_core::extensions::{Extension, ExtensionsRef};

    use super::*;
    use crate::Method;

    #[derive(Debug, Clone, Default, Extension)]
    struct UnitValue;

    /// Header inspected by every service under test.
    const HEADER: &str = "x-unit-value";

    /// Accepted header values, paired with whether `UnitValue` must be inserted.
    const VALID_CASES: &[(&str, bool)] = &[
        ("1", true),
        ("true", true),
        ("True", true),
        ("TrUE", true),
        ("TRUE", true),
        ("0", false),
        ("false", false),
        ("False", false),
        ("FaLsE", false),
        ("FALSE", false),
    ];

    /// Header values that must be rejected in both required and optional mode.
    const INVALID_CASES: &[&str] = &["", "foo", "yes"];

    /// Builds a GET request; the `HEADER` header is only added when `value` is `Some`.
    fn request_with(value: Option<&str>) -> Request<()> {
        let builder = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com");
        let builder = match value {
            Some(value) => builder.header(HEADER, value),
            None => builder,
        };
        builder.body(()).unwrap()
    }

    /// Sends a request carrying `value` (no header when `None`) through a
    /// `HeaderOptionValueService`. If the inner service is reached, it asserts
    /// whether the `UnitValue` extension is present.
    async fn serve_with(
        optional: bool,
        value: Option<&str>,
        expect_extension: bool,
    ) -> Result<(), BoxError> {
        let inner_service =
            rama_core::service::service_fn(move |req: Request<()>| async move {
                assert_eq!(expect_extension, req.extensions().contains::<UnitValue>());
                Ok::<_, std::convert::Infallible>(())
            });
        let service = HeaderOptionValueService::<UnitValue, _>::new(
            inner_service,
            HeaderName::from_static(HEADER),
            optional,
        );
        service.serve(request_with(value)).await.map(|_output| ())
    }

    #[tokio::test]
    async fn test_header_option_value_required_happy_path() {
        for &(value, expected) in VALID_CASES {
            serve_with(false, Some(value), expected).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_header_option_value_optional_found() {
        for &(value, expected) in VALID_CASES {
            serve_with(true, Some(value), expected).await.unwrap();
        }
    }

    // The service trims the header value before parsing; make sure that stays true.
    #[tokio::test]
    async fn test_header_option_value_surrounding_whitespace_is_trimmed() {
        let cases = [
            (" true ", true),
            ("  1", true),
            ("0  ", false),
            (" FALSE ", false),
        ];
        for (value, expected) in cases {
            serve_with(false, Some(value), expected).await.unwrap();
            serve_with(true, Some(value), expected).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_header_option_value_optional_missing() {
        serve_with(true, None, false).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_option_value_required_missing_header() {
        assert!(serve_with(false, None, false).await.is_err());
    }

    #[tokio::test]
    async fn test_header_option_value_required_invalid_value() {
        for &value in INVALID_CASES {
            let result = serve_with(false, Some(value), false).await;
            assert!(result.is_err(), "value {value:?} should be rejected");
        }
    }

    #[tokio::test]
    async fn test_header_option_value_optional_invalid_value() {
        for &value in INVALID_CASES {
            let result = serve_with(true, Some(value), false).await;
            assert!(result.is_err(), "value {value:?} should be rejected");
        }
    }
}