use std::{borrow::Cow, error::Error as StdError, fmt, sync::Arc};
use bytes::Bytes;
use futures_util::FutureExt;
use http::{HeaderMap, HeaderName, HeaderValue, StatusCode, Uri};
use crate::{
client::{Body, layer::redirect},
config::RequestConfig,
error::{BoxError, Error},
ext::UriExt,
header::{AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, REFERER, WWW_AUTHENTICATE},
};
/// A redirect policy: decides, for each redirect response, whether the client
/// follows it, stops, or fails with an error.
///
/// Construct one with [`Policy::limited`], [`Policy::none`], or [`Policy::custom`].
/// The `Default` policy is `Policy::limited(10)`.
#[derive(Debug, Clone)]
pub struct Policy {
inner: PolicyKind,
}
/// One redirect presented to a [`Policy`] for a decision.
///
/// Consume it with [`Attempt::follow`], [`Attempt::stop`], or [`Attempt::error`].
/// [`Attempt::pending`] is also available while `PENDING` is `true`. The
/// asynchronous callback given to `pending` receives an owned
/// `Attempt<'static, false>`, so that callback cannot defer the decision again.
#[derive(Debug)]
#[non_exhaustive]
pub struct Attempt<'a, const PENDING: bool = true> {
/// Status code of the redirect response.
pub status: StatusCode,
/// Headers supplied with this redirect by the redirect layer (presumably the
/// redirect response's headers — confirm against `client::layer::redirect`).
pub headers: Cow<'a, HeaderMap>,
/// The redirect target, i.e. the URI that would be requested next.
pub uri: Cow<'a, Uri>,
/// URIs recorded so far in this redirect chain, oldest first. When driven by
/// the client, this includes the URI whose response triggered this redirect.
pub previous: Cow<'a, [Uri]>,
}
/// The decision a [`Policy`] returns for an [`Attempt`].
///
/// Its field is private, so code outside this module obtains one from the
/// `Attempt` methods (`follow`, `stop`, `error`, `pending`).
#[derive(Debug)]
pub struct Action {
inner: redirect::Action,
}
/// The redirects that were followed for a request, in the order they were
/// recorded.
///
/// Inserted into the final response's extensions when at least one redirect was
/// recorded. Iterate it by value or by reference through `IntoIterator`.
#[derive(Debug, Clone)]
pub struct History(Vec<HistoryEntry>);
/// A single followed redirect within a [`History`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct HistoryEntry {
/// Status code of the redirect response.
pub status: StatusCode,
/// The URI that was redirected to.
pub uri: Uri,
/// The URI that issued the redirect.
pub previous: Uri,
/// Headers supplied with the redirect (presumably the redirect response's
/// headers — confirm against the redirect layer).
pub headers: HeaderMap,
}
/// Internal representation of a [`Policy`].
///
/// `Debug` is implemented by hand because the `Custom` closure has no `Debug`
/// impl.
#[derive(Clone)]
enum PolicyKind {
/// User-supplied decision callback. It is behind an `Arc`, so cloning a policy
/// only bumps a reference count.
Custom(Arc<dyn Fn(Attempt) -> Action + Send + Sync + 'static>),
/// Follow while `Attempt::previous` holds at most this many URIs.
Limit(usize),
/// Never follow redirects.
None,
}
/// Error returned by a limited policy once the redirect limit has been exceeded.
#[derive(Debug)]
struct TooManyRedirects;
/// Adapter that applies a [`Policy`] inside the client's redirect layer and
/// keeps the per-chain redirect state.
#[derive(Clone)]
pub(crate) struct FollowRedirectPolicy {
/// The configured policy. `follow_redirects` calls `RequestConfig::load` with the
/// request extensions, presumably allowing a per-request override — confirm.
policy: RequestConfig<Policy>,
/// Whether to insert a `Referer` header on redirected requests.
referer: bool,
/// URIs visited so far in the current redirect chain, oldest first.
uris: Vec<Uri>,
/// When `true`, redirects to non-`https` targets are rejected.
https_only: bool,
/// Followed redirects; moved into the response extensions by `on_response`.
history: Option<Vec<HistoryEntry>>,
}
impl Policy {
#[inline]
pub fn limited(max: usize) -> Self {
Self {
inner: PolicyKind::Limit(max),
}
}
#[inline]
pub fn none() -> Self {
Self {
inner: PolicyKind::None,
}
}
#[inline]
pub fn custom<T>(policy: T) -> Self
where
T: Fn(Attempt) -> Action + Send + Sync + 'static,
{
Self {
inner: PolicyKind::Custom(Arc::new(policy)),
}
}
pub fn redirect(&self, attempt: Attempt) -> Action {
match self.inner {
PolicyKind::Custom(ref custom) => custom(attempt),
PolicyKind::Limit(max) => {
if attempt.previous.len() > max {
attempt.error(TooManyRedirects)
} else {
attempt.follow()
}
}
PolicyKind::None => attempt.stop(),
}
}
#[inline]
fn check(
&self,
status: StatusCode,
headers: &HeaderMap,
next: &Uri,
previous: &[Uri],
) -> redirect::Action {
self.redirect(Attempt {
status,
headers: Cow::Borrowed(headers),
uri: Cow::Borrowed(next),
previous: Cow::Borrowed(previous),
})
.inner
}
}
impl Default for Policy {
#[inline]
fn default() -> Policy {
Policy::limited(10)
}
}
// NOTE(review): presumably implements the request-config value plumbing for
// `Policy`, so it can be held in `RequestConfig<Policy>` and loaded from request
// extensions (see `FollowRedirectPolicy::follow_redirects`). Confirm against the
// macro definition.
impl_request_config_value!(Policy);
impl<const PENDING: bool> Attempt<'_, PENDING> {
#[inline]
pub fn follow(self) -> Action {
Action {
inner: redirect::Action::Follow,
}
}
#[inline]
pub fn stop(self) -> Action {
Action {
inner: redirect::Action::Stop,
}
}
#[inline]
pub fn error<E: Into<BoxError>>(self, error: E) -> Action {
Action {
inner: redirect::Action::Error(error.into()),
}
}
}
impl Attempt<'_, true> {
    /// Defers the decision to an asynchronous callback.
    ///
    /// The attempt's headers, target URI, and previous URIs are copied into an
    /// owned `Attempt<'static, false>`, so `func` can move it into a `'static`
    /// future. Because `pending` exists only on `Attempt<_, true>`, the callback
    /// cannot defer again. The returned action is `Pending`, and its future
    /// resolves to the callback's decision.
    pub fn pending<F, Fut>(self, func: F) -> Action
    where
        F: FnOnce(Attempt<'static, false>) -> Fut + Send + 'static,
        Fut: Future<Output = Action> + Send + 'static,
    {
        let Attempt {
            status,
            headers,
            uri,
            previous,
        } = self;
        let owned: Attempt<'static, false> = Attempt {
            status,
            headers: Cow::Owned(headers.into_owned()),
            uri: Cow::Owned(uri.into_owned()),
            previous: Cow::Owned(previous.into_owned()),
        };
        let decision = func(owned).map(|Action { inner }| inner);
        Action {
            inner: redirect::Action::Pending(Box::pin(decision)),
        }
    }
}
impl IntoIterator for History {
type Item = HistoryEntry;
type IntoIter = std::vec::IntoIter<HistoryEntry>;
#[inline]
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
impl<'a> IntoIterator for &'a History {
type Item = &'a HistoryEntry;
type IntoIter = std::slice::Iter<'a, HistoryEntry>;
#[inline]
fn into_iter(self) -> Self::IntoIter {
self.0.iter()
}
}
impl fmt::Debug for PolicyKind {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
PolicyKind::Custom(..) => f.pad("Custom"),
PolicyKind::Limit(max) => f.debug_tuple("Limit").field(&max).finish(),
PolicyKind::None => f.pad("None"),
}
}
}
impl fmt::Display for TooManyRedirects {
    /// Writes the fixed message `too many redirects`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const MESSAGE: &str = "too many redirects";
        f.write_str(MESSAGE)
    }
}

// Leaf error: no `source`; the message comes from `Display`.
impl StdError for TooManyRedirects {}
impl FollowRedirectPolicy {
    /// Wraps `policy` with every optional behavior off: no `Referer` header,
    /// no https-only restriction, and an empty redirect chain and history.
    pub fn new(policy: Policy) -> Self {
        let policy = RequestConfig::new(Some(policy));
        Self {
            policy,
            uris: Vec::new(),
            history: None,
            referer: false,
            https_only: false,
        }
    }

    /// Turns insertion of a `Referer` header on redirected requests on or off.
    #[inline]
    pub fn with_referer(self, referer: bool) -> Self {
        Self { referer, ..self }
    }

    /// When `true`, redirects to targets that are not `https` are rejected.
    #[inline]
    pub fn with_https_only(self, https_only: bool) -> Self {
        Self { https_only, ..self }
    }
}
// Bridges the public `Policy` API to the client's redirect layer
// (`client::layer::redirect`). That layer is not visible here, so the order in
// which it calls these hooks is assumed from their names; confirm against the
// layer implementation.
impl redirect::Policy<Body, BoxError> for FollowRedirectPolicy {
/// Decides what to do with one redirect response.
///
/// The redirecting URI is appended to `self.uris` *before* the policy is consulted,
/// so the policy's `previous` slice includes it. On `Follow`, the target must be
/// `http`/`https`, and must be `https` when `https_only` is set. The hop is then
/// appended to the history. Errors produced by the policy are wrapped with
/// `Error::redirect` together with the redirecting URI.
///
/// NOTE(review): the scheme checks below only run for a synchronous `Follow`.
/// A `Pending` action is passed through unchanged. If its future later resolves
/// to `Follow`, neither the http/https check nor `https_only` is applied here,
/// and no `HistoryEntry` is recorded. Confirm whether the redirect layer
/// re-validates resolved pending actions.
fn redirect(&mut self, attempt: redirect::Attempt<'_>) -> Result<redirect::Action, BoxError> {
let previous_uri = attempt.previous;
let next_uri = attempt.location;
self.uris.push(previous_uri.clone());
let policy = self
.policy
.as_ref()
.expect("[BUG] FollowRedirectPolicy should always have a policy set");
match policy.check(attempt.status, attempt.headers, next_uri, &self.uris) {
redirect::Action::Follow => {
// Only http(s) targets can be followed. This error is returned unwrapped,
// whereas the `https_only` error below is wrapped in `Error::redirect`.
if !(next_uri.is_http() || next_uri.is_https()) {
return Err(Error::uri_bad_scheme(next_uri.clone()).into());
}
if self.https_only && !next_uri.is_https() {
return Err(Error::redirect(
Error::uri_bad_scheme(next_uri.clone()),
next_uri.clone(),
)
.into());
}
// `Policy::redirect` never yields `Follow` for `PolicyKind::None`, so this
// guard appears to always hold on this path.
if !matches!(policy.inner, PolicyKind::None) {
self.history.get_or_insert_default().push(HistoryEntry {
status: attempt.status,
uri: attempt.location.clone(),
previous: attempt.previous.clone(),
headers: attempt.headers.clone(),
});
}
Ok(redirect::Action::Follow)
}
redirect::Action::Stop => Ok(redirect::Action::Stop),
redirect::Action::Pending(task) => Ok(redirect::Action::Pending(task)),
redirect::Action::Error(err) => Err(Error::redirect(err, previous_uri.clone()).into()),
}
}
/// Returns whether redirects may be followed for `request`.
///
/// Loads the effective policy via `RequestConfig::load` with the request's
/// extensions (presumably a per-request override falling back to the configured
/// policy; confirm). Returns `true` unless the policy is `Policy::none()`.
fn follow_redirects(&mut self, request: &mut http::Request<Body>) -> bool {
self.policy
.load(request.extensions_mut())
.is_some_and(|policy| !matches!(policy.inner, PolicyKind::None))
}
/// Prepares an outgoing request in the chain.
///
/// Strips credential headers when the request's scheme/host/port differ from the
/// last recorded URI. If `referer` is enabled, inserts a `Referer` header built
/// from that URI; `make_referer` suppresses it on https -> http hops.
fn on_request(&mut self, req: &mut http::Request<Body>) {
let next_url = req.uri().clone();
remove_sensitive_headers(req.headers_mut(), &next_url, &self.uris);
if self.referer
&& let Some(previous_url) = self.uris.last()
&& let Some(v) = make_referer(next_url, previous_url)
{
req.headers_mut().insert(REFERER, v);
}
}
/// Moves any recorded redirect history into the response extensions as
/// `History`, leaving `self.history` empty.
///
/// The generic parameter `Body` shadows the imported `Body` type inside this
/// method.
fn on_response<Body>(&mut self, response: &mut http::Response<Body>) {
if let Some(history) = self.history.take() {
response.extensions_mut().insert(History(history));
}
}
/// Delegates to `Body::try_clone`. `None` presumably means the body cannot be
/// replayed for the redirected request; confirm with the layer.
#[inline]
fn clone_body(&self, body: &Body) -> Option<Body> {
body.try_clone()
}
}
/// Builds the `Referer` header value for a hop from `previous` to `next`.
///
/// Returns `None` when going from an `https` URI to an `http` one, so a secure
/// URL is not disclosed over plain HTTP. Also returns `None` when the
/// serialized URI is not a valid header value. The userinfo of the referring
/// URI is reset with `set_userinfo("", None)` before serialization, presumably
/// dropping any `user:password@` credentials; confirm against `UriExt`.
fn make_referer(next: Uri, previous: &Uri) -> Option<HeaderValue> {
    let downgrades_to_http = next.is_http() && previous.is_https();
    if downgrades_to_http {
        return None;
    }

    let mut referrer = previous.clone();
    referrer.set_userinfo("", None);
    let serialized = Bytes::from(referrer.to_string());
    HeaderValue::from_maybe_shared(serialized).ok()
}
/// Strips credential-bearing headers when a redirect leaves the origin of the
/// last visited URI.
///
/// Only the last entry of `previous` is compared. Any difference in scheme,
/// host, or port (as returned by `Uri::port`) counts as a new origin. An
/// omitted port and an explicit default port therefore differ, and such hops
/// also strip headers. Nothing is removed when `previous` is empty.
fn remove_sensitive_headers(headers: &mut HeaderMap, next: &Uri, previous: &[Uri]) {
    const COOKIE2: HeaderName = HeaderName::from_static("cookie2");

    let Some(last) = previous.last() else {
        return;
    };
    let same_origin = next.host() == last.host()
        && next.port() == last.port()
        && next.scheme() == last.scheme();
    if same_origin {
        return;
    }

    headers.remove(WWW_AUTHENTICATE);
    headers.remove(PROXY_AUTHORIZATION);
    headers.remove(COOKIE2);
    headers.remove(COOKIE);
    headers.remove(AUTHORIZATION);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds headers containing every name `remove_sensitive_headers` strips,
    /// plus `Accept`, which must survive.
    fn all_sensitive_headers() -> HeaderMap {
        use http::header::{ACCEPT, AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, WWW_AUTHENTICATE};
        let mut headers = HeaderMap::new();
        headers.insert(ACCEPT, HeaderValue::from_static("*/*"));
        headers.insert(AUTHORIZATION, HeaderValue::from_static("let me in"));
        headers.insert(COOKIE, HeaderValue::from_static("foo=bar"));
        headers.insert(
            HeaderName::from_static("cookie2"),
            HeaderValue::from_static("baz=qux"),
        );
        headers.insert(PROXY_AUTHORIZATION, HeaderValue::from_static("proxy creds"));
        headers.insert(WWW_AUTHENTICATE, HeaderValue::from_static("Basic"));
        headers
    }

    /// Asserts that a cross-origin strip left only the `Accept` header.
    fn assert_only_accept_left(headers: &HeaderMap) {
        assert_eq!(headers.len(), 1, "unexpected headers left: {headers:?}");
        assert!(headers.contains_key(http::header::ACCEPT));
    }

    /// The default policy allows 10 previous URIs; an 11th trips the limit.
    #[test]
    fn test_redirect_policy_limit() {
        let policy = Policy::default();
        let next = Uri::try_from("http://x.y/z").unwrap();
        let mut previous = (0..=9)
            .map(|i| Uri::try_from(&format!("http://a.b/c/{i}")).unwrap())
            .collect::<Vec<_>>();
        match policy.check(StatusCode::FOUND, &HeaderMap::new(), &next, &previous) {
            redirect::Action::Follow => (),
            other => panic!("unexpected {other:?}"),
        }
        previous.push(Uri::try_from("http://a.b.d/e/33").unwrap());
        match policy.check(StatusCode::FOUND, &HeaderMap::new(), &next, &previous) {
            redirect::Action::Error(err) if err.is::<TooManyRedirects>() => (),
            other => panic!("unexpected {other:?}"),
        }
    }

    /// `limited(0)` rejects the first redirect, whose `previous` holds the origin URI.
    #[test]
    fn test_redirect_policy_limit_to_0() {
        let policy = Policy::limited(0);
        let next = Uri::try_from("http://x.y/z").unwrap();
        let previous = vec![Uri::try_from("http://a.b/c").unwrap()];
        match policy.check(StatusCode::FOUND, &HeaderMap::new(), &next, &previous) {
            redirect::Action::Error(err) if err.is::<TooManyRedirects>() => (),
            other => panic!("unexpected {other:?}"),
        }
    }

    /// `Policy::none()` always stops, even with no history.
    #[test]
    fn test_redirect_policy_none() {
        let policy = Policy::none();
        let next = Uri::try_from("http://x.y/z").unwrap();
        match policy.check(StatusCode::FOUND, &HeaderMap::new(), &next, &[]) {
            redirect::Action::Stop => (),
            other => panic!("unexpected {other:?}"),
        }
    }

    #[test]
    fn test_redirect_policy_custom() {
        let policy = Policy::custom(|attempt| {
            if attempt.uri.host() == Some("foo") {
                attempt.stop()
            } else {
                attempt.follow()
            }
        });
        let next = Uri::try_from("http://bar/baz").unwrap();
        match policy.check(StatusCode::FOUND, &HeaderMap::new(), &next, &[]) {
            redirect::Action::Follow => (),
            other => panic!("unexpected {other:?}"),
        }
        let next = Uri::try_from("http://foo/baz").unwrap();
        match policy.check(StatusCode::FOUND, &HeaderMap::new(), &next, &[]) {
            redirect::Action::Stop => (),
            other => panic!("unexpected {other:?}"),
        }
    }

    /// A custom policy receives the status and previous URIs passed to `check`.
    #[test]
    fn test_redirect_policy_custom_sees_attempt() {
        let policy = Policy::custom(|attempt| {
            if attempt.status == StatusCode::SEE_OTHER && attempt.previous.len() == 2 {
                attempt.follow()
            } else {
                attempt.stop()
            }
        });
        let next = Uri::try_from("http://x.y/z").unwrap();
        let previous = vec![
            Uri::try_from("http://a.b/1").unwrap(),
            Uri::try_from("http://a.b/2").unwrap(),
        ];
        match policy.check(StatusCode::SEE_OTHER, &HeaderMap::new(), &next, &previous) {
            redirect::Action::Follow => (),
            other => panic!("unexpected {other:?}"),
        }
        match policy.check(StatusCode::FOUND, &HeaderMap::new(), &next, &previous) {
            redirect::Action::Stop => (),
            other => panic!("unexpected {other:?}"),
        }
    }

    #[test]
    fn test_remove_sensitive_headers() {
        use http::header::{ACCEPT, AUTHORIZATION, COOKIE, HeaderValue};
        let mut headers = HeaderMap::new();
        headers.insert(ACCEPT, HeaderValue::from_static("*/*"));
        headers.insert(AUTHORIZATION, HeaderValue::from_static("let me in"));
        headers.insert(COOKIE, HeaderValue::from_static("foo=bar"));
        let next = Uri::try_from("http://initial-domain.com/path").unwrap();
        let mut prev = vec![Uri::try_from("http://initial-domain.com/new_path").unwrap()];
        let mut filtered_headers = headers.clone();
        remove_sensitive_headers(&mut headers, &next, &prev);
        assert_eq!(headers, filtered_headers);
        prev.push(Uri::try_from("http://new-domain.com/path").unwrap());
        filtered_headers.remove(AUTHORIZATION);
        filtered_headers.remove(COOKIE);
        remove_sensitive_headers(&mut headers, &next, &prev);
        assert_eq!(headers, filtered_headers);
    }

    /// Same host but a different scheme counts as cross-origin and strips all
    /// credential headers, including `cookie2`, `proxy-authorization` and
    /// `www-authenticate`.
    #[test]
    fn test_remove_sensitive_headers_on_scheme_change() {
        let mut headers = all_sensitive_headers();
        let next = Uri::try_from("https://example.com/path").unwrap();
        let prev = vec![Uri::try_from("http://example.com/path").unwrap()];
        remove_sensitive_headers(&mut headers, &next, &prev);
        assert_only_accept_left(&headers);
    }

    /// Same scheme and host but a different explicit port counts as cross-origin.
    #[test]
    fn test_remove_sensitive_headers_on_port_change() {
        let mut headers = all_sensitive_headers();
        let next = Uri::try_from("http://example.com:8080/path").unwrap();
        let prev = vec![Uri::try_from("http://example.com/path").unwrap()];
        remove_sensitive_headers(&mut headers, &next, &prev);
        assert_only_accept_left(&headers);
    }

    /// With no previous URI there is nothing to compare against, so nothing is
    /// removed.
    #[test]
    fn test_remove_sensitive_headers_without_previous() {
        let mut headers = all_sensitive_headers();
        let expected = headers.clone();
        let next = Uri::try_from("http://other.com/").unwrap();
        remove_sensitive_headers(&mut headers, &next, &[]);
        assert_eq!(headers, expected);
    }

    /// An https -> http hop must not leak the secure URL in `Referer`.
    #[test]
    fn test_make_referer_suppressed_on_https_to_http() {
        let next = Uri::try_from("http://x.y/z").unwrap();
        let previous = Uri::try_from("https://a.b/c").unwrap();
        assert!(make_referer(next, &previous).is_none());
    }
}