use std::fmt::{self, Debug, Formatter};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use http::header::{HOST, LOCATION};
use http::uri::{Authority, Scheme};
use http::{Request, Response, StatusCode, Uri};
use hyper::service::Service as HyperService;
use hyper::Body;
use pin_project::pin_project;
use tower_layer::Layer;
use crate::Either;
/// [`Layer`] that wraps an inner service in [`UpgradeHttp`].
///
/// It carries no configuration, so it is a zero-sized, freely copyable value.
#[derive(Debug, Clone, Copy)]
pub struct UpgradeHttpLayer;
impl<Service> Layer<Service> for UpgradeHttpLayer {
    type Service = UpgradeHttp<Service>;

    /// Wraps `inner` in the HTTP-to-HTTPS upgrade middleware.
    fn layer(&self, inner: Service) -> Self::Service {
        UpgradeHttp::new(inner)
    }
}
/// Middleware that forwards requests whose URI scheme is `https` to the
/// wrapped service and answers every other request itself with a redirect
/// to the HTTPS equivalent (or `400 Bad Request` if no host can be found).
#[derive(Debug, Clone)]
pub struct UpgradeHttp<Service> {
    /// Service that handles requests already using HTTPS.
    service: Service,
}

impl<Service> UpgradeHttp<Service> {
    /// Wraps `service` in the upgrade middleware.
    pub const fn new(service: Service) -> Self {
        UpgradeHttp { service }
    }

    /// Consumes the middleware and hands back the wrapped service.
    // NOTE(review): the reason clippy's `missing_const_for_fn` suggestion is
    // declined here is not visible in this file — confirm (MSRV / const drop
    // rules are the likely candidates).
    #[allow(clippy::missing_const_for_fn)]
    pub fn into_inner(self) -> Service {
        let UpgradeHttp { service } = self;
        service
    }

    /// Shared access to the wrapped service.
    pub const fn get_ref(&self) -> &Service {
        let UpgradeHttp { service } = self;
        service
    }

    /// Exclusive access to the wrapped service.
    pub fn get_mut(&mut self) -> &mut Service {
        let UpgradeHttp { service } = self;
        service
    }
}
impl<Service, RequestBody, ResponseBody> HyperService<Request<RequestBody>> for UpgradeHttp<Service>
where
Service: HyperService<Request<RequestBody>, Response = Response<ResponseBody>>,
{
type Response = Response<Either<ResponseBody, Body>>;
type Error = Service::Error;
type Future = UpgradeHttpFuture<Service, Request<RequestBody>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
match request.uri().scheme() {
Some(scheme) if scheme == &Scheme::HTTPS => {
UpgradeHttpFuture::new_service(self.service.call(request))
}
_ => {
let response = Response::builder();
let response = if let Some(authority) = extract_authority(&request) {
let mut uri = Uri::builder().scheme(Scheme::HTTPS).authority(authority);
if let Some(path_and_query) = request.uri().path_and_query() {
uri = uri.path_and_query(path_and_query.clone());
}
let uri = uri.build().expect("invalid path and query");
response
.status(StatusCode::MOVED_PERMANENTLY)
.header(LOCATION, uri.to_string())
} else {
response.status(StatusCode::BAD_REQUEST)
}
.body(Body::empty())
.expect("invalid header or body");
UpgradeHttpFuture::new_upgrade(response)
}
}
}
}
/// Future returned by [`UpgradeHttp`]'s `call`.
///
/// It either drives the inner service's response future, or yields a
/// response that `call` built up front (the HTTPS redirect or the
/// `400 Bad Request`).
#[pin_project]
pub struct UpgradeHttpFuture<Service, Request>(#[pin] FutureServe<Service, Request>)
where
Service: HyperService<Request>;
/// Internal state of an [`UpgradeHttpFuture`].
#[derive(Debug)]
// `UpgradeHttpFutureProj` is the pinned projection matched in
// `UpgradeHttpFuture::poll`.
#[pin_project(project = UpgradeHttpFutureProj)]
enum FutureServe<Service, Request>
where
Service: HyperService<Request>,
{
/// The request was forwarded; polls the inner service's response future.
Service(#[pin] Service::Future),
/// Response built without calling the inner service. Taken (left `None`)
/// by the first `poll`; polling again after that panics.
Upgrade(Option<Response<Body>>),
}
/// Hand-written so the only bound required is that the inner state is
/// `Debug`.
impl<Service, Request> Debug for UpgradeHttpFuture<Service, Request>
where
    Service: HyperService<Request>,
    FutureServe<Service, Request>: Debug,
{
    /// Renders as `UpgradeHttpFuture(<state>)`.
    fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
        let state = &self.0;
        let mut tuple = formatter.debug_tuple("UpgradeHttpFuture");
        tuple.field(state);
        tuple.finish()
    }
}
// NOTE(review): the reason these constructors are kept non-`const` is not
// visible here — confirm before removing the allow.
#[allow(clippy::missing_const_for_fn)]
impl<Service, Request> UpgradeHttpFuture<Service, Request>
where
    Service: HyperService<Request>,
{
    /// Future that resolves with whatever the inner service's `future`
    /// yields.
    fn new_service(future: Service::Future) -> Self {
        let state = FutureServe::Service(future);
        Self(state)
    }

    /// Future that resolves to `response` on its first poll.
    fn new_upgrade(response: Response<Body>) -> Self {
        let state = FutureServe::Upgrade(Some(response));
        Self(state)
    }
}
impl<Service, Request, ResponseBody> Future for UpgradeHttpFuture<Service, Request>
where
    Service: HyperService<Request, Response = Response<ResponseBody>>,
{
    type Output = Result<Response<Either<ResponseBody, Body>>, Service::Error>;

    /// Forwarded requests yield the inner response body as `Either::Left`.
    /// Locally built responses yield their body as `Either::Right`.
    ///
    /// # Panics
    ///
    /// Panics if polled again after a locally built response was returned.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let state = self.project().0;
        match state.project() {
            UpgradeHttpFutureProj::Service(inner) => match inner.poll(cx) {
                Poll::Ready(Ok(response)) => Poll::Ready(Ok(response.map(Either::Left))),
                Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
                Poll::Pending => Poll::Pending,
            },
            UpgradeHttpFutureProj::Upgrade(slot) => {
                let response = slot
                    .take()
                    .expect("polled again after `Poll::Ready`");
                Poll::Ready(Ok(response.map(Either::Right)))
            }
        }
    }
}
/// Determines the authority that the HTTPS redirect should point at.
///
/// Sources are tried in order, and the first value that parses as an
/// [`Authority`] wins:
/// 1. the first entry of the `X-Forwarded-Host` header,
/// 2. the `Host` header,
/// 3. the host of the request URI (its port is not included).
///
/// A source that is missing, not valid visible ASCII, or not a valid
/// authority is skipped rather than ending the search. Returns `None` when
/// no source yields a valid authority.
fn extract_authority<RequestBody>(request: &Request<RequestBody>) -> Option<Authority> {
    const X_FORWARDED_HOST: &str = "x-forwarded-host";

    /// Parses one candidate, discarding the parse error.
    fn parse(candidate: &str) -> Option<Authority> {
        Authority::try_from(candidate).ok()
    }

    let headers = request.headers();
    let forwarded_host = headers
        .get(X_FORWARDED_HOST)
        .and_then(|value| value.to_str().ok())
        // Chained proxies may append to this header, giving a comma-separated
        // list. Presumably the first entry is the client-facing host; the
        // whole list would never parse as an authority.
        .and_then(|value| value.split(',').next())
        .map(str::trim);
    let host = headers.get(HOST).and_then(|value| value.to_str().ok());

    forwarded_host
        .and_then(parse)
        .or_else(|| host.and_then(parse))
        .or_else(|| request.uri().host().and_then(parse))
}