pub mod policy;
use self::policy::{Action, Attempt, Policy, Standard};
use http::{
header::LOCATION, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Uri, Version,
};
use http_body::Body;
use iri_string::types::{UriAbsoluteString, UriReferenceStr};
use std::{convert::TryFrom, mem, str};
use tower_async_layer::Layer;
use tower_async_service::Service;
#[derive(Clone, Copy, Debug, Default)]
pub struct FollowRedirectLayer<P = Standard> {
policy: P,
}
impl FollowRedirectLayer {
pub fn new() -> Self {
Self::default()
}
}
impl<P> FollowRedirectLayer<P> {
pub fn with_policy(policy: P) -> Self {
FollowRedirectLayer { policy }
}
}
impl<S, P> Layer<S> for FollowRedirectLayer<P>
where
S: Clone,
P: Clone,
{
type Service = FollowRedirect<S, P>;
fn layer(&self, inner: S) -> Self::Service {
FollowRedirect::with_policy(inner, self.policy.clone())
}
}
/// Middleware that inspects redirect responses from the inner service and,
/// as directed by a [`Policy`], re-issues the request to the `Location` target.
///
/// The policy type defaults to [`Standard`].
#[derive(Clone, Copy, Debug)]
pub struct FollowRedirect<S, P = Standard> {
    inner: S,
    policy: P,
}

impl<S> FollowRedirect<S> {
    /// Wraps `inner` using a default-constructed [`Standard`] policy.
    pub fn new(inner: S) -> Self {
        let policy = Standard::default();
        Self::with_policy(inner, policy)
    }

    /// Returns a [`FollowRedirectLayer`] that uses the [`Standard`] policy.
    pub fn layer() -> FollowRedirectLayer {
        FollowRedirectLayer::default()
    }
}

impl<S, P> FollowRedirect<S, P>
where
    P: Clone,
{
    /// Wraps `inner`, following redirects according to `policy`.
    pub fn with_policy(inner: S, policy: P) -> Self {
        Self { inner, policy }
    }

    /// Returns a [`FollowRedirectLayer`] that applies `policy`.
    pub fn layer_with_policy(policy: P) -> FollowRedirectLayer<P> {
        FollowRedirectLayer::with_policy(policy)
    }

    define_inner_service_accessors!();
}
/// Per-call bookkeeping used to rebuild the request for each redirect hop.
struct RedirectServiceState<B> {
    /// URI of the most recently sent request; a `Location` value is resolved
    /// against it, and it is replaced by the target when a redirect is followed.
    uri: Uri,
    /// Method for the next hop; some redirect statuses rewrite it to `GET`.
    method: Method,
    /// Protocol version copied onto every follow-up request.
    version: Version,
    /// Header map copied onto every follow-up request.
    headers: HeaderMap<HeaderValue>,
    /// Replayable copy of the request body, if one could be obtained.
    body: BodyRepr<B>,
}
impl<ReqBody, ResBody, S, P> Service<Request<ReqBody>> for FollowRedirect<S, P>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
ReqBody: Body + Default,
P: Policy<ReqBody, S::Error>,
{
type Response = Response<ResBody>;
type Error = S::Error;
async fn call(&self, mut req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
let mut this = RedirectServiceState {
method: req.method().clone(),
uri: req.uri().clone(),
version: req.version(),
headers: req.headers().clone(),
body: BodyRepr::None,
};
this.body.try_clone_from(req.body(), &self.policy);
self.policy.on_request(&mut req);
loop {
let mut res = self.inner.call(req).await?;
res.extensions_mut().insert(RequestUri(this.uri.clone()));
match res.status() {
StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND => {
if this.method == Method::POST {
this.method = Method::GET;
this.body = BodyRepr::Empty;
}
}
StatusCode::SEE_OTHER => {
if this.method != Method::HEAD {
this.method = Method::GET;
}
this.body = BodyRepr::Empty;
}
StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT => {}
_ => return Ok(res),
};
let body = if let Some(body) = this.body.take() {
body
} else {
return Ok(res);
};
let location = res
.headers()
.get(&LOCATION)
.and_then(|loc| resolve_uri(str::from_utf8(loc.as_bytes()).ok()?, &this.uri));
let location = if let Some(loc) = location {
loc
} else {
return Ok(res);
};
let attempt = Attempt {
status: res.status(),
location: &location,
previous: &this.uri,
};
match self.policy.redirect(&attempt)? {
Action::Follow => {
this.uri = location;
this.body.try_clone_from(&body, &self.policy);
req = Request::new(body);
*req.uri_mut() = this.uri.clone();
*req.method_mut() = this.method.clone();
*req.version_mut() = this.version;
*req.headers_mut() = this.headers.clone();
self.policy.on_request(&mut req);
}
Action::Stop => return Ok(res),
}
}
}
}
/// Response extension holding the URI of the request that produced the
/// response. On the response returned by [`FollowRedirect`], this is the URI
/// of the last hop after any followed redirects.
#[derive(Debug, Clone)]
pub struct RequestUri(pub Uri);
/// Replayable copy of a request body, kept so a redirect can resend it.
#[derive(Debug)]
enum BodyRepr<B> {
    /// A cloned body that can be sent once more.
    Some(B),
    /// Follow-up requests carry an empty (`B::default()`) body.
    Empty,
    /// No replayable body is available.
    None,
}

impl<B> BodyRepr<B>
where
    B: Body + Default,
{
    /// Takes the body for the next request.
    ///
    /// - `Some(b)` is moved out and `self` becomes `None`.
    /// - `Empty` stays `Empty` and yields a fresh `B::default()`.
    /// - `None` yields `None`.
    fn take(&mut self) -> Option<B> {
        if let BodyRepr::Empty = self {
            // `Empty` is sticky: every hop gets its own default body.
            return Some(B::default());
        }
        match mem::replace(self, BodyRepr::None) {
            BodyRepr::Some(body) => Some(body),
            BodyRepr::Empty | BodyRepr::None => None,
        }
    }

    /// Stores a clone of `body` (via `clone_body`), but only when nothing is
    /// stored yet. `Some` and `Empty` are left untouched.
    fn try_clone_from<P, E>(&mut self, body: &B, policy: &P)
    where
        P: Policy<B, E>,
    {
        if !matches!(self, BodyRepr::None) {
            return;
        }
        if let Some(copy) = clone_body(policy, body) {
            *self = BodyRepr::Some(copy);
        }
    }
}
/// Attempts to duplicate `body` so it can be replayed on a redirect.
///
/// If the body's size hint reports exactly zero bytes, `B::default()` is
/// returned without consulting the policy. Otherwise the decision is
/// delegated to the policy's `clone_body`.
fn clone_body<P, B, E>(policy: &P, body: &B) -> Option<B>
where
    P: Policy<B, E>,
    B: Body + Default,
{
    match body.size_hint().exact() {
        Some(0) => Some(B::default()),
        _ => policy.clone_body(body),
    }
}
/// Resolves the URI reference `relative` (e.g. a `Location` value) against
/// `base`.
///
/// Returns `None` in three cases: `relative` is not a valid URI reference,
/// `base` does not form an absolute URI, or the resolved string does not
/// parse as a [`Uri`].
fn resolve_uri(relative: &str, base: &Uri) -> Option<Uri> {
    let reference = UriReferenceStr::new(relative).ok()?;
    UriAbsoluteString::try_from(base.to_string())
        .ok()
        .map(|absolute_base| reference.resolve_against(&absolute_base).to_string())
        .and_then(|resolved| Uri::try_from(resolved).ok())
}
#[cfg(test)]
mod tests {
    use super::{policy::*, *};
    use crate::test_helpers::Body;
    use http::header::LOCATION;
    use std::convert::Infallible;
    use tower_async::{ServiceBuilder, ServiceExt};

    /// Starting request: `/42` kicks off a chain of 42 redirects in `handle`.
    fn start_request() -> Request<Body> {
        Request::builder()
            .uri("http://example.com/42")
            .body(Body::empty())
            .unwrap()
    }

    /// URI recorded by the middleware in the response's `RequestUri` extension.
    fn final_uri(res: &Response<u64>) -> Uri {
        res.extensions().get::<RequestUri>().unwrap().0.clone()
    }

    #[tokio::test]
    async fn follows() {
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(Action::Follow))
            .service_fn(handle);
        let res = svc.oneshot(start_request()).await.unwrap();
        assert_eq!(*res.body(), 0);
        assert_eq!(final_uri(&res), "http://example.com/0");
    }

    #[tokio::test]
    async fn stops() {
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(Action::Stop))
            .service_fn(handle);
        let res = svc.oneshot(start_request()).await.unwrap();
        assert_eq!(*res.body(), 42);
        assert_eq!(final_uri(&res), "http://example.com/42");
    }

    #[tokio::test]
    async fn limited() {
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(Limited::new(10)))
            .service_fn(handle);
        let res = svc.oneshot(start_request()).await.unwrap();
        assert_eq!(*res.body(), 42 - 10);
        assert_eq!(final_uri(&res), "http://example.com/32");
    }

    /// Test handler. A request for `/n` with `n > 0` gets a 301 pointing at
    /// `/{n - 1}`, and `/0` gets the builder's default status. The response
    /// body is always `n`.
    async fn handle<B>(req: Request<B>) -> Result<Response<u64>, Infallible> {
        let n: u64 = req.uri().path()[1..].parse().unwrap();
        let builder = match n {
            0 => Response::builder(),
            _ => Response::builder()
                .status(StatusCode::MOVED_PERMANENTLY)
                .header(LOCATION, format!("/{}", n - 1)),
        };
        Ok(builder.body(n).unwrap())
    }
}