pub mod policy;
use crate::{Method, Request, Response, StatusCode, Uri, dep::http_body::Body, header::LOCATION};
use iri_string::types::{UriAbsoluteString, UriReferenceStr};
use rama_core::{Context, Layer, Service};
use rama_utils::macros::define_inner_service_accessors;
use std::fmt;
use self::policy::{Action, Attempt, Policy, Standard};
#[derive(Clone)]
pub struct FollowRedirectLayer<P = Standard> {
policy: P,
}
impl FollowRedirectLayer {
pub fn new() -> Self {
Self::default()
}
}
impl Default for FollowRedirectLayer {
fn default() -> Self {
FollowRedirectLayer {
policy: Standard::default(),
}
}
}
impl<P: fmt::Debug> fmt::Debug for FollowRedirectLayer<P> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FollowRedirectLayer")
.field("policy", &self.policy)
.finish()
}
}
impl<P> FollowRedirectLayer<P> {
pub fn with_policy(policy: P) -> Self {
FollowRedirectLayer { policy }
}
}
impl<S, P> Layer<S> for FollowRedirectLayer<P>
where
S: Clone,
P: Clone,
{
type Service = FollowRedirect<S, P>;
fn layer(&self, inner: S) -> Self::Service {
FollowRedirect {
inner,
policy: self.policy.clone(),
}
}
fn into_layer(self, inner: S) -> Self::Service {
FollowRedirect {
inner,
policy: self.policy,
}
}
}
pub struct FollowRedirect<S, P = Standard> {
inner: S,
policy: P,
}
impl<S> FollowRedirect<S> {
pub fn new(inner: S) -> Self {
Self::with_policy(inner, Standard::default())
}
}
impl<S, P> fmt::Debug for FollowRedirect<S, P>
where
S: fmt::Debug,
P: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FollowRedirect")
.field("inner", &self.inner)
.field("policy", &self.policy)
.finish()
}
}
impl<S, P> Clone for FollowRedirect<S, P>
where
S: Clone,
P: Clone,
{
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
policy: self.policy.clone(),
}
}
}
impl<S, P> FollowRedirect<S, P> {
pub fn with_policy(inner: S, policy: P) -> Self {
FollowRedirect { inner, policy }
}
define_inner_service_accessors!();
}
impl<State, ReqBody, ResBody, S, P> Service<State, Request<ReqBody>> for FollowRedirect<S, P>
where
State: Clone + Send + Sync + 'static,
S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
ReqBody: Body + Default + Send + 'static,
ResBody: Send + 'static,
P: Policy<State, ReqBody, S::Error> + Clone,
{
type Response = Response<ResBody>;
type Error = S::Error;
fn serve(
&self,
mut ctx: Context<State>,
mut req: Request<ReqBody>,
) -> impl Future<Output = Result<Self::Response, Self::Error>> {
let mut method = req.method().clone();
let mut uri = req.uri().clone();
let version = req.version();
let headers = req.headers().clone();
let mut policy = self.policy.clone();
let mut body = BodyRepr::None;
body.try_clone_from(&ctx, &mut policy, req.body());
policy.on_request(&mut ctx, &mut req);
let service = &self.inner;
async move {
loop {
let mut res = service.serve(ctx.clone(), req).await?;
res.extensions_mut().insert(RequestUri(uri.clone()));
match res.status() {
StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND => {
if method == Method::POST {
method = Method::GET;
body = BodyRepr::Empty;
}
}
StatusCode::SEE_OTHER => {
if method != Method::HEAD {
method = Method::GET;
}
body = BodyRepr::Empty;
}
StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT => {}
_ => return Ok(res),
};
let taken_body = if let Some(body) = body.take() {
body
} else {
return Ok(res);
};
let location = res
.headers()
.get(&LOCATION)
.and_then(|loc| resolve_uri(std::str::from_utf8(loc.as_bytes()).ok()?, &uri));
let location = if let Some(loc) = location {
loc
} else {
return Ok(res);
};
let attempt = Attempt {
status: res.status(),
location: &location,
previous: &uri,
};
match policy.redirect(&ctx, &attempt)? {
Action::Follow => {
uri = location;
body.try_clone_from(&ctx, &mut policy, &taken_body);
req = Request::new(taken_body);
*req.uri_mut() = uri.clone();
*req.method_mut() = method.clone();
*req.version_mut() = version;
*req.headers_mut() = headers.clone();
policy.on_request(&mut ctx, &mut req);
}
Action::Stop => return Ok(res),
}
}
}
}
}
/// Response extension recording the URI of the request that produced the
/// response, i.e. the last URI visited after any redirects were followed.
#[derive(Clone, Debug)]
pub struct RequestUri(pub Uri);
/// How the request body can be reproduced for the next redirect hop.
#[derive(Debug)]
enum BodyRepr<B> {
    /// A replayable copy of the body, produced by [`clone_body`].
    Some(B),
    /// The body was deliberately dropped; every hop gets `B::default()`.
    Empty,
    /// No copy is available (not cloned yet, already handed out, or unclonable).
    None,
}

impl<B> BodyRepr<B>
where
    B: Body + Default,
{
    /// Hands out the body to use for the next hop.
    ///
    /// A stored copy is moved out (leaving `None`), `Empty` stays `Empty` and
    /// yields `B::default()`, and `None` yields nothing.
    fn take(&mut self) -> Option<B> {
        if matches!(self, BodyRepr::Empty) {
            return Some(B::default());
        }
        match std::mem::replace(self, BodyRepr::None) {
            BodyRepr::Some(stored) => Some(stored),
            BodyRepr::Empty | BodyRepr::None => None,
        }
    }

    /// Stores a copy of `body` if nothing is held yet (state `None`).
    ///
    /// `Some` and `Empty` are left untouched; if no copy can be produced the
    /// state stays `None`.
    fn try_clone_from<S, P, E>(&mut self, ctx: &Context<S>, policy: &mut P, body: &B)
    where
        P: Policy<S, B, E>,
    {
        if !matches!(self, BodyRepr::None) {
            return;
        }
        if let Some(copy) = clone_body(ctx, policy, body) {
            *self = BodyRepr::Some(copy);
        }
    }
}
/// Produces a replayable copy of `body`, if possible.
///
/// When the body's size hint is exactly zero, `B::default()` (presumably an
/// empty body) is returned without consulting the policy. Otherwise the
/// decision is delegated to the policy's `clone_body`.
fn clone_body<S, P, B, E>(ctx: &Context<S>, policy: &mut P, body: &B) -> Option<B>
where
    P: Policy<S, B, E>,
    B: Body + Default,
{
    match body.size_hint().exact() {
        Some(0) => Some(B::default()),
        _ => policy.clone_body(ctx, body),
    }
}
/// Resolves a `Location` value against the URI of the current request
/// (RFC 3986 §5 reference resolution).
///
/// Returns `None` in three cases: `relative` is not a valid URI reference;
/// `base` is not an absolute URI (for example an origin-form `/path`); or the
/// resolved string cannot be parsed back into a [`Uri`].
fn resolve_uri(relative: &str, base: &Uri) -> Option<Uri> {
    let reference = UriReferenceStr::new(relative).ok()?;
    let absolute_base = UriAbsoluteString::try_from(base.to_string()).ok()?;
    let resolved = reference.resolve_against(&absolute_base).to_string();
    Uri::try_from(resolved).ok()
}
#[cfg(test)]
mod tests {
    use super::{policy::*, *};
    use crate::{Body, header::LOCATION};
    use rama_core::Layer;
    use rama_core::service::service_fn;
    use std::convert::Infallible;

    /// Builds an empty-bodied request targeting `uri`.
    fn request(uri: &str) -> Request<Body> {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn follows() {
        let svc = FollowRedirectLayer::with_policy(Action::Follow).into_layer(service_fn(handle));
        let response = svc
            .serve(Context::default(), request("http://example.com/42"))
            .await
            .unwrap();

        assert_eq!(*response.body(), 0);
        let visited = &response.extensions().get::<RequestUri>().unwrap().0;
        assert_eq!(*visited, "http://example.com/0");
    }

    #[tokio::test]
    async fn stops() {
        let svc = FollowRedirectLayer::with_policy(Action::Stop).into_layer(service_fn(handle));
        let response = svc
            .serve(Context::default(), request("http://example.com/42"))
            .await
            .unwrap();

        assert_eq!(*response.body(), 42);
        let visited = &response.extensions().get::<RequestUri>().unwrap().0;
        assert_eq!(*visited, "http://example.com/42");
    }

    #[tokio::test]
    async fn limited() {
        let svc = FollowRedirectLayer::with_policy(Limited::new(10)).into_layer(service_fn(handle));
        let response = svc
            .serve(Context::default(), request("http://example.com/42"))
            .await
            .unwrap();

        assert_eq!(*response.body(), 42 - 10);
        let visited = &response.extensions().get::<RequestUri>().unwrap().0;
        assert_eq!(*visited, "http://example.com/32");
    }

    /// Test service: for path `/n` it answers `301` with `Location: /{n-1}`
    /// while `n > 0`, and always returns `n` as the response body.
    async fn handle<S, B>(_ctx: Context<S>, req: Request<B>) -> Result<Response<u64>, Infallible> {
        let remaining: u64 = req.uri().path()[1..].parse().unwrap();
        let builder = match remaining {
            0 => Response::builder(),
            n => Response::builder()
                .status(StatusCode::MOVED_PERMANENTLY)
                .header(LOCATION, format!("/{}", n - 1)),
        };
        Ok::<_, Infallible>(builder.body(remaining).unwrap())
    }
}