#![expect(
clippy::allow_attributes,
reason = "macro-generated `#[allow]` attributes whose underlying lints fire only for some expansions"
)]
use crate::{HeaderValue, Request, Response};
use rama_core::{Layer, Service};
use rama_http_headers::authorization::Credentials;
use rama_utils::macros::define_inner_service_accessors;
/// Layer that applies [`AddAuthorization`] to a service, setting the
/// `Authorization` header on every request it forwards.
///
/// The header value is computed once, when the layer is constructed, by
/// encoding the supplied [`Credentials`].
#[derive(Debug, Clone)]
pub struct AddAuthorizationLayer {
    /// Pre-encoded `Authorization` header value; `None` means no header is added.
    value: Option<HeaderValue>,
    /// When `true`, an `Authorization` header already on the request is left untouched.
    if_not_present: bool,
}

impl AddAuthorizationLayer {
    /// Creates a layer that never adds an `Authorization` header.
    #[must_use]
    pub fn none() -> Self {
        Self {
            if_not_present: false,
            value: None,
        }
    }

    /// Creates a layer that adds the `Authorization` header encoded from `credential`.
    ///
    /// If encoding yields no value, nothing will be added to requests.
    // `credential` is only read by reference here, but taking it by value keeps the
    // constructor ergonomic and lets `AddAuthorization::new` hand over ownership.
    #[allow(clippy::needless_pass_by_value)]
    pub fn new(credential: impl Credentials) -> Self {
        let encoded = credential.encode();
        Self {
            value: encoded,
            if_not_present: false,
        }
    }

    rama_utils::macros::generate_set_and_with! {
        /// Marks the configured header value as sensitive (or not),
        /// see [`HeaderValue::set_sensitive`]. No-op when no value is configured.
        pub fn sensitive(mut self, sensitive: bool) -> Self {
            if let Some(header) = self.value.as_mut() {
                header.set_sensitive(sensitive);
            }
            self
        }
    }

    rama_utils::macros::generate_set_and_with! {
        /// When `true`, a request that already carries an `Authorization` header
        /// keeps it; the configured value is only added when the header is missing.
        /// Defaults to `false`, which replaces any existing header.
        pub fn if_not_present(mut self, value: bool) -> Self {
            self.if_not_present = value;
            self
        }
    }
}
impl<S> Layer<S> for AddAuthorizationLayer {
    type Service = AddAuthorization<S>;

    /// Wraps `inner`, cloning this layer's header value into the new service.
    fn layer(&self, inner: S) -> Self::Service {
        let value = self.value.clone();
        AddAuthorization {
            inner,
            value,
            if_not_present: self.if_not_present,
        }
    }

    /// Wraps `inner`, moving this layer's header value into the new service
    /// without cloning it.
    fn into_layer(self, inner: S) -> Self::Service {
        let Self {
            value,
            if_not_present,
        } = self;
        AddAuthorization {
            inner,
            value,
            if_not_present,
        }
    }
}
/// Middleware that sets the `Authorization` header on requests before
/// forwarding them to the inner service.
///
/// Usually constructed through [`AddAuthorizationLayer`].
#[derive(Debug, Clone)]
pub struct AddAuthorization<S> {
    inner: S,
    /// Pre-encoded header value; `None` means requests are forwarded unchanged.
    value: Option<HeaderValue>,
    /// When `true`, an existing `Authorization` header on the request is not replaced.
    if_not_present: bool,
}

impl<S> AddAuthorization<S> {
    /// Wraps `inner` without adding any `Authorization` header.
    pub fn none(inner: S) -> Self {
        let layer = AddAuthorizationLayer::none();
        layer.into_layer(inner)
    }

    /// Wraps `inner`, adding the `Authorization` header encoded from `credential`.
    pub fn new(inner: S, credential: impl Credentials) -> Self {
        let layer = AddAuthorizationLayer::new(credential);
        layer.into_layer(inner)
    }

    define_inner_service_accessors!();

    rama_utils::macros::generate_set_and_with! {
        /// Marks the configured header value as sensitive (or not),
        /// see [`HeaderValue::set_sensitive`]. No-op when no value is configured.
        pub fn sensitive(mut self, sensitive: bool) -> Self {
            if let Some(header) = self.value.as_mut() {
                header.set_sensitive(sensitive);
            }
            self
        }
    }

    rama_utils::macros::generate_set_and_with! {
        /// When `true`, a request that already carries an `Authorization` header
        /// keeps it; the configured value is only added when the header is missing.
        /// Defaults to `false`, which replaces any existing header.
        pub fn if_not_present(mut self, value: bool) -> Self {
            self.if_not_present = value;
            self
        }
    }
}
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for AddAuthorization<S>
where
    S: Service<Request<ReqBody>, Output = Response<ResBody>>,
    ReqBody: Send + 'static,
    ResBody: Send + 'static,
{
    type Output = S::Output;
    type Error = S::Error;

    /// Sets the `Authorization` header (when a value is configured) and forwards
    /// the request to the inner service.
    ///
    /// With `if_not_present` enabled, an existing `Authorization` header is kept.
    /// Otherwise every existing `Authorization` value is replaced by the configured one.
    async fn serve(&self, mut req: Request<ReqBody>) -> Result<Self::Output, Self::Error> {
        if let Some(value) = &self.value {
            let headers = req.headers_mut();
            if self.if_not_present {
                // Entry API: one lookup, and the value is only cloned when the
                // header is actually absent (avoids `contains_key` + `insert`).
                headers
                    .entry(rama_http_types::header::AUTHORIZATION)
                    .or_insert_with(|| value.clone());
            } else {
                // `insert` drops all previous values for the key.
                headers.insert(rama_http_types::header::AUTHORIZATION, value.clone());
            }
        }
        self.inner.serve(req).await
    }
}
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use crate::layer::validate_request::ValidateRequestHeaderLayer;
    use crate::{Body, Request, Response, StatusCode};
    use rama_core::Service;
    use rama_core::error::BoxError;
    use rama_core::service::service_fn;
    use rama_net::user::credentials::{basic, bearer};
    use std::convert::Infallible;

    #[tokio::test]
    async fn test_basic() {
        let svc =
            ValidateRequestHeaderLayer::auth(basic!("foo", "bar")).into_layer(service_fn(echo));
        let client = AddAuthorization::new(svc, basic!("foo", "bar"));
        let res = client.serve(Request::new(Body::empty())).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_token() {
        let svc = ValidateRequestHeaderLayer::auth(bearer!("foo")).into_layer(service_fn(echo));
        let client = AddAuthorization::new(svc, bearer!("foo"));
        let res = client.serve(Request::new(Body::empty())).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn making_header_sensitive() {
        let svc = ValidateRequestHeaderLayer::auth(bearer!("foo")).into_layer(service_fn(
            async |request: Request<Body>| {
                let auth = request
                    .headers()
                    .get(rama_http_types::header::AUTHORIZATION)
                    .unwrap();
                assert!(auth.is_sensitive());
                Ok::<_, Infallible>(Response::new(Body::empty()))
            },
        ));
        let client = AddAuthorization::new(svc, bearer!("foo")).with_sensitive(true);
        let res = client.serve(Request::new(Body::empty())).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
    }

    /// By default an existing `Authorization` header is replaced.
    #[tokio::test]
    async fn overrides_existing_header_by_default() {
        let svc = ValidateRequestHeaderLayer::auth(bearer!("foo")).into_layer(service_fn(echo));
        let client = AddAuthorization::new(svc, bearer!("foo"));
        let mut req = Request::new(Body::empty());
        req.headers_mut().insert(
            rama_http_types::header::AUTHORIZATION,
            HeaderValue::from_static("Bearer other"),
        );
        let res = client.serve(req).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
    }

    /// With `if_not_present`, a header already on the request must survive.
    #[tokio::test]
    async fn if_not_present_keeps_existing_header() {
        let svc = service_fn(async |request: Request<Body>| {
            let auth = request
                .headers()
                .get(rama_http_types::header::AUTHORIZATION)
                .unwrap();
            assert_eq!(auth.as_bytes(), b"Bearer existing");
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });
        let client = AddAuthorization::new(svc, bearer!("foo")).with_if_not_present(true);
        let mut req = Request::new(Body::empty());
        req.headers_mut().insert(
            rama_http_types::header::AUTHORIZATION,
            HeaderValue::from_static("Bearer existing"),
        );
        let res = client.serve(req).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
    }

    /// With `if_not_present`, the header is still added when it is missing.
    #[tokio::test]
    async fn if_not_present_inserts_when_missing() {
        let svc = ValidateRequestHeaderLayer::auth(bearer!("foo")).into_layer(service_fn(echo));
        let client = AddAuthorization::new(svc, bearer!("foo")).with_if_not_present(true);
        let res = client.serve(Request::new(Body::empty())).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
    }

    // Generic named `B` so it does not shadow the imported `Body` type.
    async fn echo<B>(req: Request<B>) -> Result<Response<B>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}