use crate::Response;
use http::header::{AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, WWW_AUTHENTICATE};
use http::HeaderMap;
use std::path::PathBuf;
use std::str::FromStr;
/// Decides whether redirects are followed.
///
/// The default policy (see the `Default` impl) is `Policy::limited(10)`.
#[derive(Clone, Debug)]
pub enum Policy {
/// Delegate every decision to a user-supplied function.
Custom(fn(Attempt) -> Action),
/// Follow redirects while fewer than this many URIs have been visited
/// (`previous().len() < limit`); once the limit is reached the redirect is stopped.
Limit(usize),
/// Never follow redirects.
None,
}
impl PartialEq for Policy {
    /// Structural equality for the comparable variants.
    ///
    /// Two `Limit` policies are equal when their limits match, and any two `None`
    /// policies are equal. `Custom` policies never compare equal, not even to
    /// themselves, because function-pointer addresses are not a dependable identity
    /// in Rust. (This is also why `Policy` does not implement `Eq`.)
    fn eq(&self, other: &Self) -> bool {
        if let (Policy::Limit(lhs), Policy::Limit(rhs)) = (self, other) {
            return lhs == rhs;
        }
        matches!((self, other), (Policy::None, Policy::None))
    }
}
/// The input to a redirect decision: the response being checked plus the URIs
/// already visited in this redirect chain.
///
/// Built by `Policy::check` and passed to `Policy::redirect` (and to custom policies).
#[derive(Clone, Debug, PartialEq)]
pub struct Attempt<'a> {
/// The response being evaluated; its URI and `Location` header drive `default_redirect`.
response: &'a Response,
/// URIs visited earlier in the chain; its length is what `Policy::Limit` counts.
/// NOTE(review): presumably in request order — confirm with the caller of `check`.
previous: &'a [http::Uri],
}
/// The outcome of a redirect policy decision.
#[derive(Clone, Debug, PartialEq)]
pub enum Action {
/// Follow the redirect to this URI.
Follow(http::Uri),
/// Do not follow the redirect to this URI (for example, `Policy::Limit` returns
/// this once its limit has been reached).
Stop(http::Uri),
/// There is nothing to follow: the policy is `Policy::None`, or no redirect target
/// could be determined from the response.
None,
}
impl Policy {
    /// Follow redirects until `max` URIs have been visited, then stop.
    pub fn limited(max: usize) -> Self {
        Self::Limit(max)
    }

    /// Never follow redirects.
    pub fn none() -> Self {
        Self::None
    }

    /// Let `policy` decide every redirect.
    pub fn custom(policy: fn(Attempt) -> Action) -> Self {
        Self::Custom(policy)
    }

    /// Apply this policy to `attempt`.
    ///
    /// `Custom` delegates entirely to the user function and `None` always answers
    /// `Action::None`. `Limit(max)` resolves the default redirect target; it follows
    /// the target while fewer than `max` URIs precede it, otherwise it stops.
    pub fn redirect(&self, attempt: Attempt) -> Action {
        let max = match *self {
            Self::Custom(decide) => return decide(attempt),
            Self::None => return attempt.none(),
            Self::Limit(max) => max,
        };

        match attempt.default_redirect() {
            None => attempt.none(),
            Some(target) if attempt.previous.len() < max => attempt.follow(target),
            Some(target) => attempt.stop(target),
        }
    }

    /// Crate-internal entry point: wrap `response` and `previous` into an `Attempt`
    /// and evaluate it.
    pub(crate) fn check(&self, response: &Response, previous: &[http::Uri]) -> Action {
        let attempt = Attempt { response, previous };
        self.redirect(attempt)
    }
}
impl Default for Policy {
fn default() -> Policy {
Policy::limited(10)
}
}
impl Attempt<'_> {
    /// The response whose redirect is being evaluated.
    pub fn response(&self) -> &Response {
        self.response
    }

    /// The URI of the response being evaluated.
    pub fn url(&self) -> &http::Uri {
        self.response.uri()
    }

    /// URIs visited earlier in this redirect chain.
    pub fn previous(&self) -> &[http::Uri] {
        self.previous
    }

    /// Decide to follow the redirect to `next`.
    pub fn follow(self, next: http::Uri) -> Action {
        Action::Follow(next)
    }

    /// Decide not to follow the redirect to `next`.
    pub fn stop(self, next: http::Uri) -> Action {
        Action::Stop(next)
    }

    /// Report that there is no redirect to act on.
    pub fn none(self) -> Action {
        Action::None
    }

    /// Resolves the response's `Location` header against the response URI.
    ///
    /// Absolute `http`/`https` URLs (any case) are used as-is. Protocol-relative
    /// (`//host/...`), absolute-path, query-only and relative-path references are
    /// resolved per RFC 3986 §5.2: relative paths replace the last segment of the
    /// current path, and `.`/`..` segments are removed.
    ///
    /// Returns `None` in these cases:
    /// - there is no `Location` header, or it is not valid visible ASCII;
    /// - it names a scheme other than `http`/`https`;
    /// - the response URI lacks a scheme/authority;
    /// - the result is not a valid URI.
    pub fn default_redirect(&self) -> Option<http::Uri> {
        let location = self
            .response
            .headers()
            .get(http::header::LOCATION)?
            .to_str()
            .ok()?;
        resolve_location(self.response.uri(), location)
    }
}

/// Resolves a `Location` reference against `base` (RFC 3986 §5.2.2).
fn resolve_location(base: &http::Uri, location: &str) -> Option<http::Uri> {
    let location = location.trim();

    if has_scheme(location) {
        let target = http::Uri::from_str(location).ok()?;
        // Only http(s) targets are followable; anything else is not a redirect we can act on.
        return match target.scheme_str() {
            Some(s) if s.eq_ignore_ascii_case("http") || s.eq_ignore_ascii_case("https") => {
                Some(target)
            }
            _ => None,
        };
    }

    let scheme = base.scheme_str()?;
    if location.starts_with("//") {
        // Network-path reference: new authority, same scheme.
        return http::Uri::from_str(&format!("{}:{}", scheme, location)).ok();
    }
    let authority = base.authority()?;

    // Fragments are client-side only and never part of the request target.
    let location = match location.find('#') {
        Some(i) => &location[..i],
        None => location,
    };
    let (ref_path, ref_query) = match location.find('?') {
        Some(i) => (&location[..i], Some(&location[i + 1..])),
        None => (location, None),
    };

    let (path, query) = if ref_path.is_empty() {
        // Empty path keeps the current path (and its query unless a new one is given).
        (base.path().to_owned(), ref_query.or_else(|| base.query()))
    } else if ref_path.starts_with('/') {
        (ref_path.to_owned(), ref_query)
    } else {
        // Merge: drop everything after the last '/' of the current path.
        let base_path = base.path();
        let dir_end = base_path.rfind('/').map_or(0, |i| i + 1);
        (format!("{}{}", &base_path[..dir_end], ref_path), ref_query)
    };

    let mut path_and_query = remove_dot_segments(&path);
    if let Some(query) = query {
        path_and_query.push('?');
        path_and_query.push_str(query);
    }

    http::Uri::builder()
        .scheme(scheme)
        .authority(authority.as_str())
        .path_and_query(path_and_query.as_str())
        .build()
        .ok()
}

/// True when `reference` starts with an RFC 3986 scheme (`ALPHA *( ALPHA / DIGIT / "+" / "-" / "." ) ":"`).
fn has_scheme(reference: &str) -> bool {
    match reference.find(':') {
        Some(colon) => {
            let mut chars = reference[..colon].chars();
            chars.next().map_or(false, |c| c.is_ascii_alphabetic())
                && chars.all(|c| c.is_ascii_alphanumeric() || c == '+' || c == '-' || c == '.')
        }
        None => false,
    }
}

/// Removes `.` and `..` segments from a path (RFC 3986 §5.2.4); the result always starts with '/'.
fn remove_dot_segments(path: &str) -> String {
    let relative = path.strip_prefix('/').unwrap_or(path);
    let mut segments: Vec<&str> = Vec::new();
    for segment in relative.split('/') {
        match segment {
            "." => {}
            ".." => {
                segments.pop();
            }
            other => segments.push(other),
        }
    }
    // A trailing "." or ".." denotes a directory, so the output keeps a trailing '/'.
    let last = relative.rsplit('/').next().unwrap_or("");
    let ends_in_directory = last == "." || last == "..";

    let mut out = String::with_capacity(path.len() + 1);
    for segment in &segments {
        out.push('/');
        out.push_str(segment);
    }
    if ends_in_directory || out.is_empty() {
        out.push('/');
    }
    out
}
pub fn only_same_host(attempt: Attempt) -> Action {
match attempt.default_redirect() {
Some(next) => {
let p = attempt.url();
if p.host() == next.host() {
if attempt.previous.len() > 10_usize {
attempt.stop(next)
} else {
attempt.follow(next)
}
} else {
attempt.stop(next)
}
}
None => Action::None,
}
}
/// Strips credential-bearing headers before following a redirect to a different origin.
///
/// Only the last entry of `previous` is compared with `next`; if `previous` is empty,
/// nothing is removed. The headers are removed when any of these differ:
/// - the scheme;
/// - the host;
/// - the effective port (the explicit port, else 80 for `http` / 443 for `https`).
///
/// The scheme check means a same-host downgrade from `https` to `http` also strips
/// them. Removed: `Authorization`, `Cookie`, `Cookie2`, `Proxy-Authorization` and
/// `WWW-Authenticate`.
pub(crate) fn remove_sensitive_headers(
    headers: &mut HeaderMap,
    next: &http::Uri,
    previous: &[http::Uri],
) {
    /// Explicit port if present, otherwise the scheme's well-known default.
    fn effective_port(uri: &http::Uri) -> Option<u16> {
        uri.port_u16().or_else(|| match uri.scheme_str() {
            Some(s) if s.eq_ignore_ascii_case("http") => Some(80),
            Some(s) if s.eq_ignore_ascii_case("https") => Some(443),
            _ => None,
        })
    }

    if let Some(last) = previous.last() {
        // Comparing only explicit ports let `https://a/` -> `http://a/` count as the
        // same target, leaking credentials over plaintext.
        let same_origin = next.scheme_str() == last.scheme_str()
            && next.host() == last.host()
            && effective_port(next) == effective_port(last);
        if !same_origin {
            headers.remove(AUTHORIZATION);
            headers.remove(COOKIE);
            headers.remove("cookie2");
            headers.remove(PROXY_AUTHORIZATION);
            headers.remove(WWW_AUTHENTICATE);
        }
    }
}