pub mod policy;
use crate::{Method, Request, Response, StatusCode, StreamingBody, header::LOCATION};
use iri_string::types::{UriAbsoluteString, UriReferenceStr};
use rama_core::{
Layer, Service,
extensions::{Extension, Extensions, ExtensionsRef},
};
use rama_http_types::{
HeaderMap,
header::{CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, TRANSFER_ENCODING},
};
use rama_net::uri::Uri;
use rama_utils::macros::{define_inner_service_accessors, generate_set_and_with};
use std::fmt;
use self::policy::{Action, Attempt, Policy, Standard};
/// Selects which [`Extensions`] a follow-up (redirected) request starts with.
///
/// The extensions of the initial request are captured once and every
/// follow-up request derives its own extensions from that capture according
/// to the selected variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum RedirectExtensionsBehaviour {
/// Follow-up requests receive a `clone()` of the captured extensions.
///
/// This is the default variant.
#[default]
Preserve,
/// Follow-up requests receive the result of calling `fork()` on the
/// captured extensions.
Fork,
/// Follow-up requests start with a brand-new `Extensions::new()`.
Drop,
}
impl RedirectExtensionsBehaviour {
    /// Produces the [`Extensions`] to attach to a follow-up request.
    ///
    /// `source` is the extensions captured from the initial request; it is
    /// never modified here, only cloned, forked or ignored depending on the
    /// selected behaviour.
    fn redirect_extensions(self, source: &Extensions) -> Extensions {
        match self {
            // Nothing is carried over: start from scratch.
            Self::Drop => Extensions::new(),
            Self::Fork => source.fork(),
            Self::Preserve => source.clone(),
        }
    }
}
/// [`Layer`] that wraps a service in [`FollowRedirect`].
///
/// `P` is the redirection policy type and defaults to [`Standard`].
#[derive(Clone)]
pub struct FollowRedirectLayer<P = Standard> {
// Policy handed to every `FollowRedirect` service produced by this layer.
policy: P,
// How the extensions of follow-up requests are derived from the initial request.
extensions_behaviour: RedirectExtensionsBehaviour,
}
impl FollowRedirectLayer {
#[must_use]
pub fn new() -> Self {
Self::default()
}
}
impl Default for FollowRedirectLayer {
fn default() -> Self {
Self::with_policy(Standard::default())
}
}
impl<P: fmt::Debug> fmt::Debug for FollowRedirectLayer<P> {
    /// Formats the layer as a struct named `FollowRedirectLayer` exposing
    /// both of its fields.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            policy,
            extensions_behaviour,
        } = self;
        let mut out = f.debug_struct("FollowRedirectLayer");
        out.field("policy", policy);
        out.field("extensions_behaviour", extensions_behaviour);
        out.finish()
    }
}
impl<P> FollowRedirectLayer<P> {
/// Creates a layer that uses the given redirection `policy`.
///
/// The extensions behaviour is initialised to
/// [`RedirectExtensionsBehaviour::default`].
pub fn with_policy(policy: P) -> Self {
Self {
policy,
extensions_behaviour: RedirectExtensionsBehaviour::default(),
}
}
// Setter for how follow-up requests derive their extensions. The macro
// expands this definition into the public builder method(s); the tests in
// this file call the generated `with_redirect_extensions_behaviour`.
// NOTE(review): presumably a `set_redirect_extensions_behaviour` variant is
// generated as well — confirm against the macro definition.
generate_set_and_with! {
pub fn redirect_extensions_behaviour(
mut self,
behaviour: RedirectExtensionsBehaviour,
) -> Self {
self.extensions_behaviour = behaviour;
self
}
}
}
impl<S, P> Layer<S> for FollowRedirectLayer<P>
where
    S: Clone,
    P: Clone,
{
    type Service = FollowRedirect<S, P>;

    /// Wraps `inner` in a [`FollowRedirect`], giving it a clone of this
    /// layer's policy and a copy of its extensions behaviour.
    fn layer(&self, inner: S) -> Self::Service {
        let Self {
            policy,
            extensions_behaviour,
        } = self;
        FollowRedirect {
            inner,
            policy: P::clone(policy),
            extensions_behaviour: *extensions_behaviour,
        }
    }

    /// Same as [`Layer::layer`], but consumes the layer so the policy is
    /// moved into the service instead of being cloned.
    fn into_layer(self, inner: S) -> Self::Service {
        let Self {
            policy,
            extensions_behaviour,
        } = self;
        FollowRedirect {
            inner,
            policy,
            extensions_behaviour,
        }
    }
}
/// Service wrapper that follows HTTP redirect responses (301, 302, 303, 307
/// and 308) returned by the inner service.
///
/// `P` is the redirection policy type and defaults to [`Standard`].
#[derive(Debug, Clone)]
pub struct FollowRedirect<S, P = Standard> {
// The wrapped service that actually sends each request.
inner: S,
// Policy template; it is cloned for every call to `serve`.
policy: P,
// How the extensions of follow-up requests are derived from the initial request.
extensions_behaviour: RedirectExtensionsBehaviour,
}
impl<S> FollowRedirect<S> {
pub fn new(inner: S) -> Self {
Self::with_policy(inner, Standard::default())
}
}
impl<S, P> FollowRedirect<S, P> {
/// Wraps `inner` using the given redirection `policy`.
///
/// The extensions behaviour is initialised to
/// [`RedirectExtensionsBehaviour::default`].
pub fn with_policy(inner: S, policy: P) -> Self {
Self {
inner,
policy,
extensions_behaviour: RedirectExtensionsBehaviour::default(),
}
}
// Setter for how follow-up requests derive their extensions; the macro
// expands this definition into the public builder method(s).
// NOTE(review): the exact set of generated methods (`with_…` / `set_…`) is
// defined by the macro — confirm there.
generate_set_and_with! {
pub fn redirect_extensions_behaviour(
mut self,
behaviour: RedirectExtensionsBehaviour,
) -> Self {
self.extensions_behaviour = behaviour;
self
}
}
// NOTE(review): presumably generates accessors for the `inner` field —
// confirm against the macro definition.
define_inner_service_accessors!();
}
impl<ReqBody, ResBody, S, P> Service<Request<ReqBody>> for FollowRedirect<S, P>
where
S: Service<Request<ReqBody>, Output = Response<ResBody>>,
ReqBody: StreamingBody + Default + Send + 'static,
ResBody: Send + 'static,
P: Policy<ReqBody, S::Error> + Clone,
{
type Output = Response<ResBody>;
type Error = S::Error;
/// Sends `req` through the inner service and keeps following redirect
/// responses for as long as the policy allows it.
///
/// The returned response is the first one that is not followed, which
/// happens when any of the following holds:
///
/// - its status is not 301, 302, 303, 307 or 308;
/// - no reusable copy of the request body is available;
/// - it has no `Location` header that is valid UTF-8 and resolves against
///   the current request URI;
/// - the policy answers [`Action::Stop`].
///
/// Every response passing through has a [`RequestUri`] extension inserted
/// that holds the URI of the request which produced it.
///
/// # Errors
///
/// Errors of the inner service and of `Policy::redirect` are propagated.
fn serve(
&self,
mut req: Request<ReqBody>,
) -> impl Future<Output = Result<Self::Output, Self::Error>> {
// Snapshot the request parts that are needed to rebuild follow-up requests.
let mut method = req.method().clone();
let mut uri = req.uri().clone();
let version = req.version();
// NOTE(review): this header snapshot is taken before `policy.on_request`
// runs, whereas the per-hop snapshot in the `Action::Follow` arm below is
// taken after it — confirm the asymmetry is intended.
let mut headers = req.headers().clone();
// Each call works on its own copy of the policy.
let mut policy = self.policy.clone();
// Try to keep a copy of the body so the request can be re-sent.
let mut body = BodyRepr::None;
body.try_clone_from(&mut policy, req.body());
policy.on_request(&mut req);
let extensions_behaviour = self.extensions_behaviour;
// Captured after `on_request`, so it includes anything the policy stored.
let extensions_source = req.extensions().clone();
let service = &self.inner;
async move {
loop {
let res = service.serve(req).await?;
// Record which URI this response belongs to.
res.extensions().insert(RequestUri(uri.clone()));
// Remember the method before the status-specific rewrite below so the
// policy can see both the old and the new method.
let previous_method = method.clone();
// Removes the headers that describe a payload which is no longer sent.
let drop_payload_headers = |headers: &mut HeaderMap| {
for header in &[
CONTENT_TYPE,
CONTENT_LENGTH,
CONTENT_ENCODING,
TRANSFER_ENCODING,
] {
headers.remove(header);
}
};
match res.status() {
StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND => {
// 301/302: a POST is turned into a body-less GET
// (permitted by RFC 9110 §15.4.2 and §15.4.3).
if method == Method::POST {
method = Method::GET;
body = BodyRepr::Empty;
drop_payload_headers(&mut headers);
}
}
StatusCode::SEE_OTHER => {
// 303: everything except HEAD becomes GET, and the body is
// always dropped (RFC 9110 §15.4.4).
if method != Method::HEAD {
method = Method::GET;
}
body = BodyRepr::Empty;
drop_payload_headers(&mut headers);
}
// 307/308: method and body are kept as they are.
StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT => {}
// Not a redirect that is handled here: hand the response to the caller.
_ => return Ok(res),
};
// Without a reusable body the request cannot be re-sent.
let Some(taken_body) = body.take() else {
return Ok(res);
};
// The `Location` header must be valid UTF-8 and resolvable against
// the URI of the request that was just sent.
let location = res
.headers()
.get(&LOCATION)
.and_then(|loc| resolve_uri(std::str::from_utf8(loc.as_bytes()).ok()?, &uri));
let Some(location) = location else {
return Ok(res);
};
let attempt = Attempt {
status: res.status(),
method: &method,
location: &location,
previous_method: &previous_method,
previous: &uri,
};
match policy.redirect(&attempt)? {
Action::Follow => {
uri = location;
// Re-arm the stored body for a possible next hop before the
// current copy is moved into the new request. This is a no-op
// while the state is `Empty`.
body.try_clone_from(&mut policy, &taken_body);
req = Request::new(taken_body);
*req.uri_mut() = uri.clone();
*req.method_mut() = method.clone();
*req.version_mut() = version;
*req.headers_mut() = headers.clone();
req.set_extensions(
extensions_behaviour.redirect_extensions(&extensions_source),
);
policy.on_request(&mut req);
// Keep whatever the policy changed in the headers, so that a
// header it removed on this hop is not restored on the next one.
headers = req.headers().clone();
}
Action::Stop => return Ok(res),
}
}
}
}
}
/// Response extension holding the URI of the request that produced the
/// response.
///
/// [`FollowRedirect`] inserts it into every response it receives from the
/// inner service, so on the final response it holds the URI of the last
/// request that was sent.
#[derive(Debug, Clone, Extension)]
#[extension(tags(http))]
pub struct RequestUri(pub Uri);
/// Tracks which body, if any, is available for re-sending a request.
#[derive(Debug)]
enum BodyRepr<B> {
/// A copy of the request body that can be sent with the next request.
Some(B),
/// Follow-up requests are sent with `B::default()`. This state is sticky:
/// neither `take` nor `try_clone_from` moves away from it.
Empty,
/// No reusable body is available.
None,
}
impl<B> BodyRepr<B>
where
    B: StreamingBody + Default,
{
    /// Hands out the body to use for the next request, if there is one.
    ///
    /// - `Empty` yields `B::default()` and stays `Empty`.
    /// - `Some(body)` yields `body` and leaves `None` behind.
    /// - `None` yields `None`.
    fn take(&mut self) -> Option<B> {
        if matches!(self, Self::Empty) {
            // The empty state is reusable for every hop, so it is left untouched.
            return Some(B::default());
        }
        match std::mem::replace(self, Self::None) {
            Self::Some(stored) => Some(stored),
            Self::Empty | Self::None => None,
        }
    }

    /// Stores a copy of `body` when nothing is stored yet.
    ///
    /// Does nothing unless the current state is `None`; in that state it
    /// becomes `Some(copy)` if `clone_body` manages to produce a copy.
    fn try_clone_from<P, E>(&mut self, policy: &mut P, body: &B)
    where
        P: Policy<B, E>,
    {
        if !matches!(self, Self::None) {
            return;
        }
        if let Some(copy) = clone_body(policy, body) {
            *self = Self::Some(copy);
        }
    }
}
/// Tries to obtain a copy of `body` for a potential follow-up request.
///
/// A body whose size hint is exactly zero is replaced by `B::default()`
/// without consulting the policy; every other body is delegated to
/// `Policy::clone_body`, which decides whether a copy can be made.
fn clone_body<P, B, E>(policy: &mut P, body: &B) -> Option<B>
where
    P: Policy<B, E>,
    B: StreamingBody + Default,
{
    match body.size_hint().exact() {
        Some(0) => Some(B::default()),
        _ => policy.clone_body(body),
    }
}
/// Resolves the URI reference `relative` against `base`.
///
/// Returns `None` when `relative` is not a valid URI reference, when `base`
/// cannot be represented as an absolute URI string, or when the resolved
/// result cannot be converted back into a [`Uri`].
fn resolve_uri(relative: &str, base: &Uri) -> Option<Uri> {
    let reference = UriReferenceStr::new(relative).ok()?;
    let absolute_base = UriAbsoluteString::try_from(base.to_string()).ok()?;
    let resolved = reference.resolve_against(&absolute_base);
    Uri::try_from(resolved.to_string()).ok()
}
#[cfg(test)]
mod tests {
use super::{policy::*, *};
use crate::{Body, header::LOCATION};
use rama_core::Layer;
use rama_core::extensions::ExtensionsRef;
use rama_core::service::service_fn;
use std::convert::Infallible;
/// With an always-follow policy the `/42 -> /41 -> … -> /0` chain served by
/// `handle` is followed all the way to `/0`.
#[tokio::test]
async fn follows() {
    let layer = FollowRedirectLayer::with_policy(Action::Follow);
    let service = layer.into_layer(service_fn(handle));
    let request = Request::builder()
        .uri("http://example.com/42")
        .body(Body::empty())
        .unwrap();

    let response = service.serve(request).await.unwrap();

    assert_eq!(*response.body(), 0);
    assert_eq!(
        response.extensions().get_ref::<RequestUri>().unwrap().0,
        "http://example.com/0"
    );
}
/// With an always-stop policy the very first redirect response is returned
/// unchanged.
#[tokio::test]
async fn stops() {
    let layer = FollowRedirectLayer::with_policy(Action::Stop);
    let service = layer.into_layer(service_fn(handle));
    let request = Request::builder()
        .uri("http://example.com/42")
        .body(Body::empty())
        .unwrap();

    let response = service.serve(request).await.unwrap();

    assert_eq!(*response.body(), 42);
    assert_eq!(
        response.extensions().get_ref::<RequestUri>().unwrap().0,
        "http://example.com/42"
    );
}
/// A `Limited::new(10)` policy follows ten redirects of the `/42` chain and
/// then returns the redirect response of `/32`.
#[tokio::test]
async fn limited() {
    let layer = FollowRedirectLayer::with_policy(Limited::new(10));
    let service = layer.into_layer(service_fn(handle));
    let request = Request::builder()
        .uri("http://example.com/42")
        .body(Body::empty())
        .unwrap();

    let response = service.serve(request).await.unwrap();

    assert_eq!(*response.body(), 42 - 10);
    assert_eq!(
        response.extensions().get_ref::<RequestUri>().unwrap().0,
        "http://example.com/32"
    );
}
/// Test handler implementing a countdown redirect chain.
///
/// The first path segment is parsed as a number `n`, which is returned as
/// the response body. While `n > 0` the response is a 301 pointing at
/// `/{n - 1}`; for `n == 0` the builder's default status is used.
///
/// Panics when the first path segment is missing or is not a `u64`.
async fn handle<B>(req: Request<B>) -> Result<Response<u64>, Infallible> {
    let remaining: u64 = req
        .uri()
        .first_path_segment()
        .and_then(|segment| segment.as_encoded_str().parse().ok())
        .unwrap();
    let builder = if remaining > 0 {
        Response::builder()
            .status(StatusCode::MOVED_PERMANENTLY)
            .header(LOCATION, format!("/{}", remaining - 1))
    } else {
        Response::builder()
    };
    Ok(builder.body(remaining).unwrap())
}
/// Test-only extension used to observe whether request extensions reach the
/// inner service on follow-up requests.
#[derive(Clone, Debug, PartialEq, rama_core::extensions::Extension)]
struct Marker(u32);
/// Countdown redirect handler (same chain as `handle`) that additionally
/// echoes a [`Marker`] request extension into the response extensions, so a
/// test can tell whether the request it answered carried the marker.
///
/// Panics when the first path segment is missing or is not a `u64`.
async fn handle_marker<B>(req: Request<B>) -> Result<Response<u64>, Infallible> {
    let remaining: u64 = req
        .uri()
        .first_path_segment()
        .and_then(|segment| segment.as_encoded_str().parse().ok())
        .unwrap();
    let builder = if remaining > 0 {
        Response::builder()
            .status(StatusCode::MOVED_PERMANENTLY)
            .header(LOCATION, format!("/{}", remaining - 1))
    } else {
        Response::builder()
    };
    let response = builder.body(remaining).unwrap();
    if let Some(marker) = req.extensions().get_ref::<Marker>() {
        response.extensions().insert(marker.clone());
    }
    Ok(response)
}
/// Without any explicit configuration, the marker set on the initial request
/// is still visible to the handler on the last hop of the `/3` chain.
#[tokio::test]
async fn preserves_extensions_by_default() {
    let service = FollowRedirectLayer::new().into_layer(service_fn(handle_marker));
    let expected = Marker(7);
    let request = Request::builder()
        .uri("http://example.com/3")
        .body(Body::empty())
        .unwrap();
    request.extensions().insert(expected.clone());

    let response = service.serve(request).await.unwrap();

    assert_eq!(response.extensions().get_ref::<Marker>(), Some(&expected));
}
/// Explicitly selecting `RedirectExtensionsBehaviour::Preserve` keeps the
/// marker visible to the handler on the last hop of the `/3` chain.
#[tokio::test]
async fn preserve_shares_extensions() {
    let layer = FollowRedirectLayer::new()
        .with_redirect_extensions_behaviour(RedirectExtensionsBehaviour::Preserve);
    let service = layer.into_layer(service_fn(handle_marker));
    let expected = Marker(7);
    let request = Request::builder()
        .uri("http://example.com/3")
        .body(Body::empty())
        .unwrap();
    request.extensions().insert(expected.clone());

    let response = service.serve(request).await.unwrap();

    assert_eq!(response.extensions().get_ref::<Marker>(), Some(&expected));
}
/// Test handler for a cross-origin followed by a same-origin redirect:
///
/// - any request to host `a.example.com` is redirected to
///   `http://b.example.com/second`;
/// - `b.example.com/second` is redirected to `http://b.example.com/final`;
/// - everything else is answered without a redirect.
///
/// Whenever the incoming request carries a `Cookie` header, the response is
/// tagged with an `x-saw-cookie: 1` header.
async fn handle_cookie_chain<B>(req: Request<B>) -> Result<Response<u64>, Infallible> {
    let host = req.uri().host().map(|h| h.to_string());
    let path = req.uri().path_ref_or_root();
    let location = match host.as_deref() {
        Some("a.example.com") => Some("http://b.example.com/second"),
        Some("b.example.com") if path == "/second" => Some("http://b.example.com/final"),
        _ => None,
    };
    let builder = match location {
        Some(target) => Response::builder()
            .status(StatusCode::MOVED_PERMANENTLY)
            .header(LOCATION, target),
        None => Response::builder(),
    };
    let mut response = builder.body(0u64).unwrap();
    if req.headers().contains_key(crate::header::COOKIE) {
        response
            .headers_mut()
            .insert("x-saw-cookie", crate::HeaderValue::from_static("1"));
    }
    Ok(response)
}
/// Regression test: a `Cookie` header that is gone after the cross-origin hop
/// (`a.example.com` -> `b.example.com`) must stay gone on the same-origin hop
/// that follows (`/second` -> `/final`).
///
/// The test relies on the default layer's policy removing the header on the
/// cross-origin hop.
#[tokio::test]
async fn credentials_do_not_resurrect_after_cross_origin() {
    let service = FollowRedirectLayer::default().into_layer(service_fn(handle_cookie_chain));
    let request = Request::builder()
        .uri("http://a.example.com/")
        .header(crate::header::COOKIE, "session=secret")
        .body(Body::empty())
        .unwrap();

    let response = service.serve(request).await.unwrap();

    let saw_cookie = response.headers().contains_key("x-saw-cookie");
    assert!(
        !saw_cookie,
        "Cookie resurrected on a same-origin hop after being dropped cross-origin",
    );
    assert_eq!(
        response.extensions().get_ref::<RequestUri>().unwrap().0,
        "http://b.example.com/final"
    );
}
/// With `RedirectExtensionsBehaviour::Drop`, follow-up requests no longer
/// carry the marker, so the final response of the `/3` chain has none.
#[tokio::test]
async fn drop_extensions_opt_out() {
    let layer = FollowRedirectLayer::new()
        .with_redirect_extensions_behaviour(RedirectExtensionsBehaviour::Drop);
    let service = layer.into_layer(service_fn(handle_marker));
    let request = Request::builder()
        .uri("http://example.com/3")
        .body(Body::empty())
        .unwrap();
    request.extensions().insert(Marker(7));

    let response = service.serve(request).await.unwrap();

    assert!(response.extensions().get_ref::<Marker>().is_none());
}
/// 301 handling: a POST is rewritten to GET (which the policy used here
/// rejects, so the redirect response itself is returned), while a GET is
/// followed unchanged.
#[tokio::test]
async fn test_301_redirects() {
    // Stop exactly when the layer rewrote POST into GET; follow otherwise.
    let policy = policy::redirect_fn(|attempt| -> Result<_, Infallible> {
        let rewritten =
            attempt.previous_method() == Method::POST && attempt.method() == Method::GET;
        Ok(if rewritten {
            Action::Stop
        } else {
            Action::Follow
        })
    });
    let service = FollowRedirectLayer::with_policy(policy).into_layer(service_fn(redirections));

    // (request method, expected body, expected URI of the last request sent)
    let cases = [
        (Method::POST, "/target/301", "http://example.com/301"),
        (
            Method::GET,
            "/target/301/final",
            "http://example.com/target/301",
        ),
    ];
    for (method, expected_body, expected_uri) in cases {
        let request = Request::builder()
            .method(method)
            .uri("http://example.com/301")
            .body(Body::empty())
            .unwrap();
        let response = service.clone().serve(request).await.unwrap();
        assert_eq!(*response.body(), expected_body);
        assert_eq!(
            response.extensions().get_ref::<RequestUri>().unwrap().0,
            expected_uri
        );
    }
}
/// 302 handling: only POST has its method rewritten (detected and stopped by
/// the policy used here); PUT and HEAD keep their method and are followed.
#[tokio::test]
async fn test_302_redirects() {
    // Stop whenever the layer changed the method; follow otherwise.
    let policy = policy::redirect_fn(|attempt| -> Result<_, Infallible> {
        let method_changed = attempt.previous_method() != attempt.method();
        Ok(if method_changed {
            Action::Stop
        } else {
            Action::Follow
        })
    });
    let service = FollowRedirectLayer::with_policy(policy).into_layer(service_fn(redirections));

    // (request method, expected body, expected URI of the last request sent)
    let cases = [
        (Method::POST, "/target/302", "http://example.com/302"),
        (
            Method::PUT,
            "/target/302/final",
            "http://example.com/target/302",
        ),
        (
            Method::HEAD,
            "/target/302/final",
            "http://example.com/target/302",
        ),
    ];
    for (method, expected_body, expected_uri) in cases {
        let request = Request::builder()
            .method(method)
            .uri("http://example.com/302")
            .body(Body::empty())
            .unwrap();
        let response = service.clone().serve(request).await.unwrap();
        assert_eq!(*response.body(), expected_body);
        assert_eq!(
            response.extensions().get_ref::<RequestUri>().unwrap().0,
            expected_uri
        );
    }
}
/// 303 handling: POST and PUT are rewritten to GET (detected and stopped by
/// the policy used here); HEAD keeps its method and is followed.
#[tokio::test]
async fn test_303_redirects() {
    // Stop whenever the layer changed the method; follow otherwise.
    let policy = policy::redirect_fn(|attempt| -> Result<_, Infallible> {
        let method_changed = attempt.previous_method() != attempt.method();
        Ok(if method_changed {
            Action::Stop
        } else {
            Action::Follow
        })
    });
    let service = FollowRedirectLayer::with_policy(policy).into_layer(service_fn(redirections));

    // (request method, expected body, expected URI of the last request sent)
    let cases = [
        (Method::POST, "/target/303", "http://example.com/303"),
        (Method::PUT, "/target/303", "http://example.com/303"),
        (
            Method::HEAD,
            "/target/303/final",
            "http://example.com/target/303",
        ),
    ];
    for (method, expected_body, expected_uri) in cases {
        let request = Request::builder()
            .method(method)
            .uri("http://example.com/303")
            .body(Body::empty())
            .unwrap();
        let response = service.clone().serve(request).await.unwrap();
        assert_eq!(*response.body(), expected_body);
        assert_eq!(
            response.extensions().get_ref::<RequestUri>().unwrap().0,
            expected_uri
        );
    }
}
/// 307/308 handling: the method is preserved, so a POST stays a POST and the
/// policy used here (which only follows POST -> POST) lets it through.
#[tokio::test]
async fn test_307_308_redirects() {
    // Follow only when the method was POST both before and after the redirect.
    let policy = policy::redirect_fn(|attempt| -> Result<_, Infallible> {
        let not_post_to_post =
            attempt.previous_method() != Method::POST || attempt.method() != Method::POST;
        Ok(if not_post_to_post {
            Action::Stop
        } else {
            Action::Follow
        })
    });
    let service = FollowRedirectLayer::with_policy(policy).into_layer(service_fn(redirections));

    // (request URI, expected body, expected URI of the last request sent)
    let cases = [
        (
            "http://example.com/307",
            "/target/307/final",
            "http://example.com/target/307",
        ),
        (
            "http://example.com/308",
            "/target/308/final",
            "http://example.com/target/308",
        ),
    ];
    for (request_uri, expected_body, expected_uri) in cases {
        let request = Request::builder()
            .method(Method::POST)
            .uri(request_uri)
            .body(Body::empty())
            .unwrap();
        let response = service.clone().serve(request).await.unwrap();
        assert_eq!(*response.body(), expected_body);
        assert_eq!(
            response.extensions().get_ref::<RequestUri>().unwrap().0,
            expected_uri
        );
    }
}
/// Test handler with one redirecting route per supported status code.
///
/// `/301`, `/302`, `/303`, `/307` and `/308` answer with the matching status,
/// a `Location` of `/target/<code>` and that same target as the body. Every
/// other path answers `200 OK` with the body `<path>/final`.
async fn redirections<B>(req: Request<B>) -> Result<Response<String>, Infallible> {
    let path = req.uri().path_ref_or_root();
    // (request path, redirect target, status to answer with)
    let routes = [
        ("/301", "/target/301", StatusCode::MOVED_PERMANENTLY),
        ("/302", "/target/302", StatusCode::FOUND),
        ("/303", "/target/303", StatusCode::SEE_OTHER),
        ("/307", "/target/307", StatusCode::TEMPORARY_REDIRECT),
        ("/308", "/target/308", StatusCode::PERMANENT_REDIRECT),
    ];
    let matched = routes.into_iter().find(|(source, _, _)| path == *source);
    let (builder, body) = match matched {
        Some((_, target, status)) => (
            Response::builder().status(status).header(LOCATION, target),
            target.to_owned(),
        ),
        None => (
            Response::builder().status(StatusCode::OK),
            format!("{path}/final"),
        ),
    };
    Ok(builder.body(body).unwrap())
}
}