use std::fmt;
use header::{
HeaderMap,
AUTHORIZATION,
COOKIE,
PROXY_AUTHORIZATION,
WWW_AUTHENTICATE,
};
use hyper::StatusCode;
use Url;
/// Public handle describing how redirects should be handled.
///
/// The actual strategy lives in the private `Policy` enum; instances are
/// created through `limited`, `none`, `custom`, or `Default`.
#[derive(Debug)]
pub struct RedirectPolicy {
// The concrete strategy consulted by `redirect`.
inner: Policy,
}
/// A single redirect decision handed to a policy.
///
/// Borrows the candidate URL and the chain of URLs visited before it, and is
/// consumed by `follow`, `stop`, `loop_detected`, or `too_many_redirects` to
/// produce a `RedirectAction`.
#[derive(Debug)]
pub struct RedirectAttempt<'a> {
// Status code passed in with this attempt (see `RedirectPolicy::check`).
status: StatusCode,
// The URL that would be requested next.
next: &'a Url,
// URLs already visited in this chain.
previous: &'a [Url],
}
/// Opaque verdict returned by a redirect policy.
///
/// Wraps the crate-internal `Action`; values are obtained by consuming a
/// `RedirectAttempt`.
#[derive(Debug)]
pub struct RedirectAction {
// The crate-visible decision carried by this verdict.
inner: Action,
}
impl RedirectPolicy {
    /// Creates a policy that follows redirects until the chain of previously
    /// visited URLs holds `max` entries.
    ///
    /// Once that many URLs have been visited the attempt is answered with
    /// `too_many_redirects`; a target that already appears in the chain is
    /// answered with `loop_detected`.
    pub fn limited(max: usize) -> RedirectPolicy {
        RedirectPolicy {
            inner: Policy::Limit(max),
        }
    }

    /// Creates a policy that stops at every redirect.
    pub fn none() -> RedirectPolicy {
        RedirectPolicy {
            inner: Policy::None,
        }
    }

    /// Creates a policy whose decisions are delegated to `policy`.
    ///
    /// The closure is called with each `RedirectAttempt` and returns the
    /// `RedirectAction` to take, typically by calling one of the attempt's
    /// consuming methods such as `follow` or `stop`.
    pub fn custom<T>(policy: T) -> RedirectPolicy
    where
        T: Fn(RedirectAttempt) -> RedirectAction + Send + Sync + 'static,
    {
        RedirectPolicy {
            inner: Policy::Custom(Box::new(policy)),
        }
    }

    /// Applies this policy to one redirect attempt and returns the verdict.
    pub fn redirect(&self, attempt: RedirectAttempt) -> RedirectAction {
        match self.inner {
            Policy::Custom(ref custom) => custom(attempt),
            Policy::Limit(max) => {
                // `>=` rather than `==`: a chain that is already longer than
                // the limit (for example `limited(0)` with a non-empty
                // history) must be rejected too, otherwise it would never hit
                // the cap and could be followed indefinitely.
                if attempt.previous.len() >= max {
                    attempt.too_many_redirects()
                } else if attempt.previous.contains(attempt.next) {
                    attempt.loop_detected()
                } else {
                    attempt.follow()
                }
            }
            Policy::None => attempt.stop(),
        }
    }

    /// Crate-internal entry point: builds a `RedirectAttempt` from its parts,
    /// runs `redirect` on it, and unwraps the resulting `Action`.
    pub(crate) fn check(
        &self,
        status: StatusCode,
        next: &Url,
        previous: &[Url],
    ) -> Action {
        self
            .redirect(RedirectAttempt {
                status: status,
                next: next,
                previous: previous,
            })
            .inner
    }
}
impl Default for RedirectPolicy {
fn default() -> RedirectPolicy {
RedirectPolicy::limited(10)
}
}
impl<'a> RedirectAttempt<'a> {
pub fn status(&self) -> StatusCode {
self.status
}
pub fn url(&self) -> &Url {
self.next
}
pub fn previous(&self) -> &[Url] {
self.previous
}
pub fn follow(self) -> RedirectAction {
RedirectAction {
inner: Action::Follow,
}
}
pub fn stop(self) -> RedirectAction {
RedirectAction {
inner: Action::Stop,
}
}
pub fn loop_detected(self) -> RedirectAction {
RedirectAction {
inner: Action::LoopDetected,
}
}
pub fn too_many_redirects(self) -> RedirectAction {
RedirectAction {
inner: Action::TooManyRedirects,
}
}
}
/// Private representation of the strategy held by a `RedirectPolicy`.
enum Policy {
// User-supplied decision closure, boxed so any closure type can be stored.
Custom(Box<Fn(RedirectAttempt) -> RedirectAction + Send + Sync + 'static>),
// Cap on the length of the redirect chain (compared against `previous.len()`).
Limit(usize),
// Stop at every redirect.
None,
}
impl fmt::Debug for Policy {
    /// Formats the variant name; the boxed closure of `Custom` is not
    /// printed, while `Limit` is rendered as a tuple with its cap.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            Policy::Limit(ref cap) => {
                let mut tuple = f.debug_tuple("Limit");
                tuple.field(cap);
                tuple.finish()
            }
            Policy::Custom(..) => f.pad("Custom"),
            Policy::None => f.pad("None"),
        }
    }
}
/// Crate-internal outcome of evaluating a redirect policy.
#[derive(Debug, PartialEq)]
pub(crate) enum Action {
// Produced by `RedirectAttempt::follow`.
Follow,
// Produced by `RedirectAttempt::stop`.
Stop,
// Produced by `RedirectAttempt::loop_detected`.
LoopDetected,
// Produced by `RedirectAttempt::too_many_redirects`.
TooManyRedirects,
}
/// Strips credential-carrying headers when a redirect leaves the origin of
/// the most recently visited URL.
///
/// The target `next` is compared with the last entry of `previous` by host
/// string and by the value of `port_or_known_default`. If either differs,
/// the `Authorization`, `Cookie`, `cookie2`, `Proxy-Authorization` and
/// `WWW-Authenticate` headers are removed from `headers`. With an empty
/// `previous` the headers are left untouched.
pub(crate) fn remove_sensitive_headers(headers: &mut HeaderMap, next: &Url, previous: &[Url]) {
    let last_visited = match previous.last() {
        Some(url) => url,
        // Nothing to compare against, so nothing to strip.
        None => return,
    };

    let same_origin = next.host_str() == last_visited.host_str()
        && next.port_or_known_default() == last_visited.port_or_known_default();
    if same_origin {
        return;
    }

    headers.remove(AUTHORIZATION);
    headers.remove(COOKIE);
    headers.remove("cookie2");
    headers.remove(PROXY_AUTHORIZATION);
    headers.remove(WWW_AUTHENTICATE);
}
/// The default policy follows with nine prior URLs and reports
/// `TooManyRedirects` once a tenth has been recorded.
#[test]
fn test_redirect_policy_limit() {
    let policy = RedirectPolicy::default();
    let next = Url::parse("http://x.y/z").unwrap();

    // Nine distinct earlier hops: one short of the default cap.
    let mut previous = Vec::new();
    for i in 0..9 {
        let hop = format!("http://a.b/c/{}", i);
        previous.push(Url::parse(&hop).unwrap());
    }

    let below_cap = policy.check(StatusCode::FOUND, &next, &previous);
    assert_eq!(below_cap, Action::Follow);

    // The tenth hop brings the chain up to the cap.
    previous.push(Url::parse("http://a.b.d/e/33").unwrap());

    let at_cap = policy.check(StatusCode::FOUND, &next, &previous);
    assert_eq!(at_cap, Action::TooManyRedirects);
}
/// A custom policy that stops only for host `foo` follows other hosts and
/// stops for `foo`.
#[test]
fn test_redirect_policy_custom() {
    let policy = RedirectPolicy::custom(|attempt| {
        // Finish borrowing the attempt before it is consumed below.
        let is_foo = attempt.url().host_str() == Some("foo");
        if is_foo {
            attempt.stop()
        } else {
            attempt.follow()
        }
    });

    let cases = [
        ("http://bar/baz", Action::Follow),
        ("http://foo/baz", Action::Stop),
    ];
    for &(target, ref expected) in &cases {
        let next = Url::parse(target).unwrap();
        assert_eq!(policy.check(StatusCode::FOUND, &next, &[]), *expected);
    }
}
/// Headers survive a same-host redirect and lose `Authorization` and
/// `Cookie` once the latest visited URL is on a different host.
#[test]
fn test_remove_sensitive_headers() {
    use hyper::header::{ACCEPT, AUTHORIZATION, COOKIE, HeaderValue};

    let mut actual = HeaderMap::new();
    let initial = vec![
        (ACCEPT, "*/*"),
        (AUTHORIZATION, "let me in"),
        (COOKIE, "foo=bar"),
    ];
    for (name, value) in initial {
        actual.insert(name, HeaderValue::from_static(value));
    }

    let next = Url::parse("http://initial-domain.com/path").unwrap();
    let mut visited = vec![Url::parse("http://initial-domain.com/new_path").unwrap()];
    let mut expected = actual.clone();

    // Same host as the last visited URL: nothing may be removed.
    remove_sensitive_headers(&mut actual, &next, &visited);
    assert_eq!(actual, expected);

    // Last visited URL is now on another host: credentials must go.
    visited.push(Url::parse("http://new-domain.com/path").unwrap());
    expected.remove(AUTHORIZATION);
    expected.remove(COOKIE);

    remove_sensitive_headers(&mut actual, &next, &visited);
    assert_eq!(actual, expected);
}