use crate::Request;
use crate::headers::{HeaderMapExt, ProxyAuthorization};
use rama_core::extensions::ExtensionsRef;
use rama_core::telemetry::tracing;
use rama_core::{Layer, Service};
use rama_http_types::proxy::is_req_http_proxy_connect;
use rama_net::user::credentials::DpiProxyCredential;
use rama_net::user::{Basic, Bearer, ProxyCredential};
use rama_utils::macros::define_inner_service_accessors;
/// [`Layer`]-style factory that wraps an inner service in a
/// [`DpiProxyCredentialExtractor`].
///
/// The layer carries no configuration of its own: it is a unit struct, so
/// [`Self::new`] and [`Default::default`] produce the same value.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct DpiProxyCredentialExtractorLayer;

impl DpiProxyCredentialExtractorLayer {
    /// Creates the (stateless) layer; usable in `const` contexts.
    #[inline(always)]
    pub const fn new() -> Self {
        DpiProxyCredentialExtractorLayer {}
    }
}
impl<S> Layer<S> for DpiProxyCredentialExtractorLayer {
    type Service = DpiProxyCredentialExtractor<S>;

    /// Wraps `service` so that credentials are extracted before it is called.
    #[inline(always)]
    fn layer(&self, service: S) -> Self::Service {
        DpiProxyCredentialExtractor::<S>::new(service)
    }
}
/// Middleware service that copies proxy authorization data from HTTP proxy
/// `CONNECT` requests into the request extensions as a [`DpiProxyCredential`],
/// then delegates to the wrapped service.
#[derive(Debug, Clone)]
pub struct DpiProxyCredentialExtractor<S> {
    inner: S,
}

impl<S> DpiProxyCredentialExtractor<S> {
    /// Wraps `service`; usable in `const` contexts.
    #[inline(always)]
    pub const fn new(service: S) -> Self {
        DpiProxyCredentialExtractor { inner: service }
    }

    define_inner_service_accessors!();
}
impl<S, ReqBody> Service<Request<ReqBody>> for DpiProxyCredentialExtractor<S>
where
    S: Service<Request<ReqBody>>,
    ReqBody: Send + 'static,
{
    type Output = S::Output;
    type Error = S::Error;

    /// Attaches a [`DpiProxyCredential`] to the request and forwards it to the inner service.
    ///
    /// This only applies when `is_req_http_proxy_connect` holds for the request. The
    /// `ProxyAuthorization` typed header is tried as `Basic` first. As `Bearer`, it is only
    /// tried when no `Basic` credential was found. When neither is present, or the request
    /// is not a proxy `CONNECT`, the request is forwarded unchanged.
    async fn serve(&self, req: Request<ReqBody>) -> Result<Self::Output, Self::Error> {
        if is_req_http_proxy_connect(&req) {
            tracing::trace!("DpiProxyCredentialExtractor: try to extract proxy authorization data");

            // Kept inside this block so the header borrow is over before `req` is moved below.
            let headers = req.headers();
            let found = if let Some(ProxyAuthorization::<Basic>(basic)) = headers.typed_get() {
                tracing::debug!(
                    "DpiProxyCredentialExtractor: extracted Basic proxy auth: inserted in req extensions"
                );
                Some(ProxyCredential::Basic(basic))
            } else if let Some(ProxyAuthorization::<Bearer>(bearer)) = headers.typed_get() {
                tracing::debug!(
                    "DpiProxyCredentialExtractor: extracted Bearer proxy auth: inserted in req extensions"
                );
                Some(ProxyCredential::Bearer(bearer))
            } else {
                None
            };

            if let Some(credential) = found {
                req.extensions().insert(DpiProxyCredential(credential));
            }
        }

        self.inner.serve(req).await
    }
}