use std::{
ops::Deref,
task::{Context, Poll},
};
use ecow::EcoString;
use futures::future::{self, Either, Ready};
use headers::{HeaderMapExt, Host};
use http::{uri::Scheme, Request, Response, StatusCode};
use http_api_problem::ApiError;
use hyper::Body;
use iri_string::types::{UriAbsoluteStr, UriAbsoluteString, UriStr};
use tower::Service;
use tracing::{debug, error, info};
use crate::{
header::{
forwarded::Forwarded, x_forwarded_host::XForwardedHost, x_forwarded_proto::XForwardedProto,
},
uri::invariant::AbsoluteHttpUri,
};
/// Tower middleware that reconstructs the absolute target URI of an incoming request
/// and stores it as an [`AbsoluteHttpUri`] request extension before calling `inner`.
///
/// The scheme is taken from the `Forwarded` / `X-Forwarded-Proto` headers when they
/// supply one, otherwise `default_scheme` is used. The authority comes from the
/// `Forwarded` / `X-Forwarded-Host` headers or the `Host` header.
///
/// No trait bound is placed on the struct itself; the impls that need `S: Clone`
/// declare it themselves.
#[derive(Debug, Clone)]
pub struct ReconstructTargetUri<S> {
    /// The wrapped service that receives the request after reconstruction.
    inner: S,
    /// Scheme used when no forwarding header provides one.
    default_scheme: Scheme,
}
impl<S: Clone> ReconstructTargetUri<S> {
    /// Wraps `inner` so that every request reaching it carries a reconstructed
    /// absolute target URI.
    ///
    /// `default_scheme` is the scheme assumed when neither `Forwarded` nor
    /// `X-Forwarded-Proto` states one.
    #[inline]
    pub fn new(default_scheme: Scheme, inner: S) -> Self {
        Self { default_scheme, inner }
    }
}
impl<S> Service<Request<Body>> for ReconstructTargetUri<S>
where
    S: Service<Request<Body>, Response = Response<Body>> + Clone,
{
    type Response = Response<Body>;
    type Error = S::Error;
    type Future = Either<S::Future, Ready<Result<Self::Response, Self::Error>>>;

    /// Readiness is delegated entirely to the inner service.
    #[inline]
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Reconstructs the request's absolute target URI, attaches it as an
    /// [`AbsoluteHttpUri`] extension, and forwards the request to the inner service.
    ///
    /// Resolution order:
    /// * If an `AbsoluteHttpUri` extension is already present, the request is
    ///   forwarded untouched.
    /// * The authority starts from the `Host` header; the scheme starts from
    ///   `default_scheme`.
    /// * If a `Forwarded` header decodes, the `host` / `proto` of its first element
    ///   (when present) override them.
    /// * Otherwise `X-Forwarded-Host` / `X-Forwarded-Proto` override them.
    /// * Path and query always come from the request's own URI.
    ///
    /// If the URI cannot be built, or is rejected by `AbsoluteHttpUri`, a
    /// `400 Bad Request` problem response is returned and the inner service is
    /// not called.
    fn call(&mut self, mut req: Request<Body>) -> Self::Future {
        // Something upstream already established the target URI; keep it.
        if req.extensions().get::<AbsoluteHttpUri>().is_some() {
            return Either::Left(self.inner.call(req));
        }

        let (mut scheme, mut authority) = (EcoString::from(self.default_scheme.as_str()), None);

        if let Some(h_host) = req.headers().typed_get::<Host>() {
            debug!("Host header present.");
            authority = Some(h_host);
        }

        if let Some(h_forwarded) = req.headers().typed_get::<Forwarded>() {
            debug!("Forwarded header present. {:?}", h_forwarded);
            // Only the first element is consulted. Presumably that is the one added by
            // the proxy nearest the client; confirm against the `Forwarded` decoder.
            // `first()` is used instead of `[0]` so an element-less header cannot panic.
            if let Some(element) = h_forwarded.elements.first() {
                if let Some(forwarded_host) = element.host_decoded() {
                    authority = Some(forwarded_host);
                }
                if let Some(forwarded_proto) = element.proto() {
                    scheme = forwarded_proto.deref().into();
                }
            }
        } else {
            // The de-facto X-Forwarded-* headers are only consulted when there is no
            // (decodable) standard `Forwarded` header.
            if let Some(x_forwarded_host) = req.headers().typed_get::<XForwardedHost>() {
                authority = Some(x_forwarded_host.into());
            }
            if let Some(x_forwarded_proto) = req.headers().typed_get::<XForwardedProto>() {
                scheme = x_forwarded_proto.deref().deref().into();
            }
        }

        let mut builder = iri_string::build::Builder::new();
        builder.scheme(scheme.as_str());
        if let Some(authority) = authority.as_ref() {
            builder.host(authority.hostname());
            if let Some(port) = authority.port() {
                builder.port(port);
            }
        }
        builder.path(req.uri().path());
        if let Some(query) = req.uri().query() {
            builder.query(query);
        }
        debug!("Target uri builder: {:?}", builder);

        // Build an absolute URI (no fragment), then validate it as an `AbsoluteHttpUri`.
        let target_uri = builder.build::<UriAbsoluteStr>().ok().and_then(|built| {
            AbsoluteHttpUri::try_new_from(AsRef::<UriStr>::as_ref(&UriAbsoluteString::from(built)))
                .ok()
        });

        match target_uri {
            Some(target_uri) => {
                info!("Reconstructed target uri: {:?}", target_uri);
                req.extensions_mut().insert(target_uri);
                Either::Left(self.inner.call(req))
            }
            None => {
                error!("Error in reconstructing target uri.");
                // NOTE(review): `poll_ready` readied `inner`, but it is not called on this
                // path; confirm that is acceptable for the wrapped service.
                Either::Right(future::ready(Ok(ApiError::builder(StatusCode::BAD_REQUEST)
                    .message("Invalid request target")
                    .finish()
                    .into_hyper_response())))
            }
        }
    }
}