use std::{
future::Future,
pin::Pin,
str,
task::{Context, Poll, ready},
};
use futures_util::future::Either;
use http::{
HeaderMap, Method, Request, Response, StatusCode, Uri,
header::{CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, LOCATION, TRANSFER_ENCODING},
request::Parts,
};
use http_body::Body;
use pin_project_lite::pin_project;
use tower::{Service, util::Oneshot};
use url::Url;
use super::{Action, Attempt, BodyRepr, Policy};
use crate::{Error, error::BoxError, ext::RequestUri, into_uri::IntoUriSealed};
/// State parked while an asynchronous redirect decision is outstanding.
///
/// Created when the policy answers a redirect attempt with `Action::Pending`; it keeps
/// everything required to resume redirect handling once the boxed future resolves.
pub struct Pending<ReqBody, Response> {
/// Future that resolves to the policy's final [`Action`] for this redirect.
future: Pin<Box<dyn Future<Output = Action> + Send>>,
/// Redirect target, already resolved against the previous request URI.
location: Uri,
/// Request body that will be sent with the follow-up request if the redirect is followed.
body: ReqBody,
/// The redirect response itself; handed back to the caller when the policy stops.
res: Response,
}
pin_project! {
/// Response future that can transparently follow HTTP redirects.
///
/// `Redirect` carries the state needed to re-issue the request against a new location,
/// while `Direct` simply forwards the inner service future's output.
#[project = ResponseFutureProj]
pub enum ResponseFuture<S, B, P>
where
S: Service<Request<B>>,
{
// Redirect-aware mode: responses are inspected and may trigger follow-up requests.
Redirect {
// Request currently in flight; follow-up requests are stored as `Either::Right`.
#[pin]
future: Either<S::Future, Oneshot<S, Request<B>>>,
// Parked state while the policy's asynchronous decision is still pending.
pending_future: Option<Pending<B, S::Response>>,
// Service that is cloned to issue each follow-up request.
service: S,
// Policy consulted for every redirect attempt and notified of requests/responses.
policy: P,
// Head (method, URI, headers, ...) of the most recent request; mutated per hop.
parts: Parts,
// Retained request body representation used to build the next request.
body_repr: BodyRepr<B>,
},
// Pass-through mode: the inner future is polled and its output returned as-is.
Direct {
#[pin]
future: S::Future,
},
}
}
impl<S, ReqBody, ResBody, P> Future for ResponseFuture<S, ReqBody, P>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
    S::Error: From<BoxError>,
    P: Policy<ReqBody, S::Error>,
    ReqBody: Body + Default,
{
    type Output = Result<Response<ResBody>, S::Error>;

    /// Drives the in-flight request and, in `Redirect` mode, follows redirect responses.
    ///
    /// Each poll does one of the following:
    /// * resumes a parked asynchronous policy decision, if one is stored;
    /// * otherwise polls the current request, and for 301/302/303/307/308 responses asks
    ///   the policy whether to follow the `Location` header.
    ///
    /// Returns the final response, or an error produced by the service or the policy.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.project() {
            // No redirect handling: forward the inner future's result untouched.
            ResponseFutureProj::Direct { mut future } => future.as_mut().poll(cx),
            ResponseFutureProj::Redirect {
                mut future,
                pending_future,
                service,
                policy,
                parts,
                body_repr,
            } => {
                // A previous poll parked an asynchronous policy decision: resolve it first.
                if let Some(mut state) = pending_future.take() {
                    let action = match state.future.as_mut().poll(cx) {
                        Poll::Ready(action) => action,
                        Poll::Pending => {
                            // Not decided yet; put the state back for the next poll.
                            *pending_future = Some(state);
                            return Poll::Pending;
                        }
                    };
                    return handle_action(
                        cx,
                        RedirectAction {
                            action,
                            future: &mut future,
                            service,
                            policy,
                            parts,
                            body: state.body,
                            body_repr,
                            res: state.res,
                            location: state.location,
                        },
                    );
                }

                // Wait for the current request and tag the response with the URI it came from.
                let mut res = {
                    let mut res = ready!(future.as_mut().poll(cx)?);
                    res.extensions_mut().insert(RequestUri(parts.uri.clone()));
                    res
                };

                // Adjust method/body for the next hop according to the redirect status.
                match res.status() {
                    StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND => {
                        // User agents may rewrite POST to GET on 301/302; the payload and
                        // its describing headers are dropped along with it.
                        if parts.method == Method::POST {
                            parts.method = Method::GET;
                            *body_repr = BodyRepr::Empty;
                            drop_payload_headers(&mut parts.headers);
                        }
                    }
                    StatusCode::SEE_OTHER => {
                        // 303 is always retrieved with GET (HEAD is kept) and without payload.
                        if parts.method != Method::HEAD {
                            parts.method = Method::GET;
                        }
                        *body_repr = BodyRepr::Empty;
                        drop_payload_headers(&mut parts.headers);
                    }
                    // 307/308 must repeat the request unchanged.
                    StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT => {}
                    _ => {
                        // Not a redirect: notify the policy and hand the response back.
                        policy.on_response(&mut res);
                        return Poll::Ready(Ok(res));
                    }
                };

                // Without a body to send on the next hop the redirect cannot be followed.
                let Some(body) = body_repr.take() else {
                    return Poll::Ready(Ok(res));
                };

                // Decode the Location header as UTF-8 rather than with `HeaderValue::to_str`:
                // `to_str` rejects any byte outside visible ASCII, which would silently stop
                // redirects whose target contains raw non-ASCII characters. URL joining
                // percent-encodes such characters.
                let Some(location) = res
                    .headers()
                    .get(LOCATION)
                    .and_then(|loc| str::from_utf8(loc.as_bytes()).ok())
                    .and_then(|loc| resolve_uri(loc, &parts.uri))
                else {
                    // Missing, undecodable or unresolvable Location: return the response as-is.
                    return Poll::Ready(Ok(res));
                };

                let attempt = Attempt {
                    status: res.status(),
                    headers: res.headers(),
                    location: &location,
                    previous: &parts.uri,
                };

                let action = match policy.redirect(attempt)? {
                    Action::Pending(future) => {
                        // The policy needs to decide asynchronously: park the state and
                        // request an immediate re-poll so the decision future gets driven.
                        *pending_future = Some(Pending {
                            future,
                            location,
                            body,
                            res,
                        });
                        cx.waker().wake_by_ref();
                        return Poll::Pending;
                    }
                    action => action,
                };

                handle_action(
                    cx,
                    RedirectAction {
                        action,
                        future: &mut future,
                        service,
                        policy,
                        parts,
                        body,
                        body_repr,
                        res,
                        location,
                    },
                )
            }
        }
    }
}
fn resolve_uri(relative: &str, base: &Uri) -> Option<Uri> {
Url::parse(&base.to_string())
.ok()?
.join(relative)
.map(String::from)
.ok()?
.into_uri()
.ok()
}
/// Removes the headers that describe a request payload.
///
/// Called when a redirect discards the request body, so the follow-up request does not
/// advertise content it no longer carries.
fn drop_payload_headers(headers: &mut HeaderMap) {
    [
        CONTENT_TYPE,
        CONTENT_LENGTH,
        CONTENT_ENCODING,
        TRANSFER_ENCODING,
    ]
    .iter()
    .for_each(|name| {
        headers.remove(name);
    });
}
/// Pinned mutable handle to the in-flight request future of the `Redirect` state:
/// either the service's own future or a `Oneshot` wrapping a follow-up request.
type RedirectFuturePin<'a, S, ReqBody> =
Pin<&'a mut Either<<S as Service<Request<ReqBody>>>::Future, Oneshot<S, Request<ReqBody>>>>;
/// Bundle of everything `handle_action` needs to act on a resolved redirect decision.
///
/// Borrowed fields point into the projected `ResponseFuture::Redirect` state; owned fields
/// are the per-redirect values (body, response, target location).
struct RedirectAction<'a, S, ReqBody, ResBody, P>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
P: Policy<ReqBody, S::Error>,
{
/// The policy's decision for this redirect.
action: Action,
/// Slot holding the in-flight request future; replaced when the redirect is followed.
future: &'a mut RedirectFuturePin<'a, S, ReqBody>,
/// Service cloned to issue the follow-up request.
service: &'a S,
/// Redirect policy; notified of the follow-up request before it is sent.
policy: &'a mut P,
/// Head of the previous request; its URI is updated to the new location on follow.
parts: &'a mut Parts,
/// Body to send with the follow-up request.
body: ReqBody,
/// Retained body representation, refreshed from `body` before `body` is consumed.
body_repr: &'a mut BodyRepr<ReqBody>,
/// The redirect response, returned to the caller when the action is `Stop`.
res: Response<ResBody>,
/// Resolved redirect target.
location: Uri,
}
fn handle_action<S, ReqBody, ResBody, P>(
cx: &mut Context<'_>,
redirect: RedirectAction<'_, S, ReqBody, ResBody, P>,
) -> Poll<Result<Response<ResBody>, S::Error>>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
S::Error: From<BoxError>,
P: Policy<ReqBody, S::Error>,
ReqBody: Body + Default,
{
match redirect.action {
Action::Follow => {
redirect.parts.uri = redirect.location;
redirect
.body_repr
.try_clone_from(&redirect.body, redirect.policy);
let mut req = Request::from_parts(redirect.parts.clone(), redirect.body);
redirect.policy.on_request(&mut req);
redirect
.future
.set(Either::Right(Oneshot::new(redirect.service.clone(), req)));
cx.waker().wake_by_ref();
Poll::Pending
}
Action::Stop => Poll::Ready(Ok(redirect.res)),
Action::Pending(_) => Poll::Ready(Err(S::Error::from(
Error::redirect(
"Nested pending Action is not supported in redirect policy",
redirect.parts.uri.clone(),
)
.into(),
))),
Action::Error(err) => Poll::Ready(Err(err.into())),
}
}