use crate::{
auth::Authentication,
body::AsyncBody,
config::{request::RequestConfig, RedirectPolicy},
error::{Error, ErrorKind},
handler::RequestBody,
interceptor::{Context, Interceptor, InterceptorFuture},
request::RequestExt,
};
use http::{header::ToStrError, uri::Scheme, HeaderMap, HeaderValue, Request, Response, Uri};
use std::{borrow::Cow, convert::TryFrom, fmt::Write, str};
use url::Url;
/// Upper bound on redirects followed when the request's redirect policy does
/// not carry an explicit `RedirectPolicy::Limit`.
const DEFAULT_REDIRECT_LIMIT: u32 = 1_024;
/// Response extension holding the URI that produced the final response, i.e.
/// the last URI requested after any redirects were followed.
pub(crate) struct EffectiveUri(pub(crate) Uri);

/// Interceptor that follows redirect responses according to the request's
/// configured `RedirectPolicy`.
pub(crate) struct RedirectInterceptor;
impl Interceptor for RedirectInterceptor {
    type Err = Error;

    /// Send `request` and follow redirect responses according to the
    /// request's [`RedirectPolicy`], read from its [`RequestConfig`]
    /// extension (the default policy is used when none is set).
    ///
    /// On every followed redirect:
    /// - the hop count is checked against the limit: the value of
    ///   `RedirectPolicy::Limit`, otherwise [`DEFAULT_REDIRECT_LIMIT`];
    /// - when `auto_referer` is enabled, `Referer` is set to the previous URI.
    ///   It is removed entirely when that would expose an HTTPS URI to a
    ///   non-HTTPS target;
    /// - 301/302/303 responses switch the method to `GET`, except that a
    ///   `HEAD` request stays `HEAD`;
    /// - when scheme, host or port change, sensitive headers and the
    ///   [`Authentication`] extension are dropped;
    /// - the request body, recovered from the response's [`RequestBody`]
    ///   extension, is reset and re-sent.
    ///
    /// The returned response carries an [`EffectiveUri`] extension with the
    /// URI that produced it.
    ///
    /// # Errors
    ///
    /// Returns [`ErrorKind::TooManyRedirects`] when the limit is exceeded,
    /// [`ErrorKind::RequestBodyNotRewindable`] when the body cannot be reset,
    /// and [`ErrorKind::InvalidRequest`] when the follow-up request cannot be
    /// built. Errors from the rest of the interceptor chain are propagated.
    fn intercept<'a>(
        &'a self,
        mut request: Request<AsyncBody>,
        ctx: Context<'a>,
    ) -> InterceptorFuture<'a, Self::Err> {
        Box::pin(async move {
            let mut effective_uri = request.uri().clone();
            let policy = request
                .extensions()
                .get::<RequestConfig>()
                .and_then(|config| config.redirect_policy.as_ref())
                .cloned()
                .unwrap_or_default();

            // Redirects disabled: a single round trip, but still report the URI.
            if policy == RedirectPolicy::None {
                let mut response = ctx.send(request).await?;
                response
                    .extensions_mut()
                    .insert(EffectiveUri(effective_uri));
                return Ok(response);
            }

            let auto_referer = request
                .extensions()
                .get::<RequestConfig>()
                .and_then(|config| config.auto_referer)
                .unwrap_or(false);

            let limit = match policy {
                RedirectPolicy::Limit(limit) => limit,
                _ => DEFAULT_REDIRECT_LIMIT,
            };

            let mut redirect_count: u32 = 0;

            loop {
                // Capture what the follow-up request needs before `request`
                // is moved into the chain.
                let mut request_builder = request.to_builder();
                let is_head = request.method() == &http::Method::HEAD;
                let mut response = ctx.send(request).await?;

                let redirect_location = match get_redirect_location(&effective_uri, &response) {
                    Some(location) => location,
                    None => {
                        response
                            .extensions_mut()
                            .insert(EffectiveUri(effective_uri));
                        return Ok(response);
                    }
                };

                if redirect_count >= limit {
                    return Err(Error::with_response(ErrorKind::TooManyRedirects, &response));
                }

                if auto_referer {
                    let referer = create_referer(&effective_uri, &redirect_location);
                    if let Some(headers) = request_builder.headers_mut() {
                        match referer {
                            Some(referer) => {
                                headers.insert(http::header::REFERER, referer);
                            }
                            // The builder may still hold a Referer from an
                            // earlier hop (possibly an HTTPS URI); drop it so it
                            // is not forwarded to a non-HTTPS target.
                            None => {
                                headers.remove(http::header::REFERER);
                            }
                        }
                    }
                }

                // 301/302/303 are retried as GET. HEAD is preserved: RFC 7231
                // allows HEAD for 303, and converting it to GET would fetch a
                // body the caller never asked for.
                if matches!(response.status().as_u16(), 301..=303) && !is_head {
                    request_builder = request_builder.method(http::Method::GET);
                }

                // Do not carry credentials across to a different scheme/host/port.
                if !is_same_authority(&effective_uri, &redirect_location) {
                    if let Some(headers) = request_builder.headers_mut() {
                        scrub_sensitive_headers(headers);
                    }
                    if let Some(extensions) = request_builder.extensions_mut() {
                        extensions.remove::<Authentication>();
                    }
                }

                // Recover the original body from the response extensions (or a
                // default one if absent); it must be reset before re-sending.
                let mut request_body = response
                    .extensions_mut()
                    .remove::<RequestBody>()
                    .map(|v| v.0)
                    .unwrap_or_default();

                if !request_body.reset() {
                    return Err(Error::with_response(
                        ErrorKind::RequestBodyNotRewindable,
                        &response,
                    ));
                }

                effective_uri = redirect_location.clone();
                request = request_builder
                    .uri(redirect_location)
                    .body(request_body)
                    .map_err(|e| Error::new(ErrorKind::InvalidRequest, e))?;
                redirect_count += 1;
            }
        })
    }
}
/// Determine where a redirect response points, if anywhere.
///
/// Returns `Some(uri)` only when all of the following hold:
/// - `response` has a 3xx status other than `304 Not Modified` and
///   `305 Use Proxy`;
/// - it carries a `Location` header;
/// - that header can be parsed and resolved against `request_uri`.
///
/// Unparseable or unresolvable locations are logged at debug level and
/// treated as "no redirect".
fn get_redirect_location<T>(request_uri: &Uri, response: &Response<T>) -> Option<Uri> {
    let status = response.status();

    // 304 tells the client to use its cached copy, and 305 names a proxy rather
    // than the resource; neither is a new location to fetch.
    if !status.is_redirection() || matches!(status.as_u16(), 304 | 305) {
        return None;
    }

    let location = response.headers().get(http::header::LOCATION)?;

    match parse_location(location) {
        Ok(location) => match resolve(request_uri, location.as_ref()) {
            Ok(uri) => Some(uri),
            Err(e) => {
                tracing::debug!("invalid redirect location: {}", e);
                None
            }
        },
        Err(e) => {
            tracing::debug!("invalid redirect location: {}", e);
            None
        }
    }
}
/// Interpret a `Location` header value as a URI-reference string.
///
/// A value consisting of visible ASCII is borrowed unchanged. A value that is
/// valid UTF-8 but contains non-ASCII bytes is accepted. Each of those bytes
/// is percent-encoded, so the result contains only ASCII.
///
/// # Errors
///
/// Returns the original [`ToStrError`] when the value is neither visible
/// ASCII nor valid UTF-8.
fn parse_location(location: &HeaderValue) -> Result<Cow<'_, str>, ToStrError> {
    match location.to_str() {
        Ok(s) => Ok(Cow::Borrowed(s)),
        Err(e) => {
            let bytes = location.as_bytes();
            if str::from_utf8(bytes).is_err() {
                return Err(e);
            }

            let mut s = String::with_capacity(location.len());
            for &byte in bytes {
                if byte.is_ascii() {
                    s.push(byte as char);
                } else {
                    // RFC 3986 §2.1: producers should use uppercase hex digits.
                    write!(&mut s, "%{:02X}", byte).expect("writing to a String cannot fail");
                }
            }
            Ok(Cow::Owned(s))
        }
    }
}
/// Resolve a redirect `target` against the `base` URI of the request that
/// produced it.
///
/// If `target` parses as an absolute URL it is used directly. If the URL
/// parser reports it as relative, it is joined onto `base` with `Url::join`.
/// Any other parse failure is returned as an error.
fn resolve(base: &Uri, target: &str) -> Result<Uri, Box<dyn std::error::Error>> {
    let absolute = match Url::parse(target) {
        Ok(url) => url,
        Err(url::ParseError::RelativeUrlWithoutBase) => {
            let base_url = Url::parse(base.to_string().as_str())?;
            base_url.join(target)?
        }
        Err(other) => return Err(Box::new(other)),
    };

    Ok(Uri::try_from(absolute.as_str())?)
}
fn create_referer(uri: &Uri, target_uri: &Uri) -> Option<HeaderValue> {
if uri.scheme() == Some(&Scheme::HTTPS) && target_uri.scheme() != Some(&Scheme::HTTPS) {
return None;
}
let mut referer = String::new();
if let Some(scheme) = uri.scheme() {
referer.push_str(scheme.as_str());
referer.push_str("://");
}
if let Some(authority) = uri.authority() {
referer.push_str(authority.host());
if let Some(port) = authority.port() {
referer.push(':');
referer.push_str(port.as_str());
}
}
referer.push_str(uri.path());
if let Some(query) = uri.query() {
referer.push('?');
referer.push_str(query);
}
HeaderValue::try_from(referer).ok()
}
/// Whether `a` and `b` share scheme, host and port.
///
/// Userinfo and path are ignored. A URI without an explicit port does not
/// match one that spells a port out, even if it is the scheme's default.
fn is_same_authority(a: &Uri, b: &Uri) -> bool {
    let key = |uri: &Uri| (uri.scheme(), uri.host(), uri.port());
    key(a) == key(b)
}
fn scrub_sensitive_headers(headers: &mut HeaderMap) {
headers.remove(http::header::AUTHORIZATION);
headers.remove(http::header::COOKIE);
headers.remove("cookie2");
headers.remove(http::header::PROXY_AUTHORIZATION);
headers.remove(http::header::WWW_AUTHENTICATE);
}
#[cfg(test)]
mod tests {
    use http::{Response, Uri};
    use test_case::test_case;

    /// Parse a URI literal used by the test cases below.
    fn parse_uri(s: &str) -> Uri {
        s.parse().unwrap()
    }

    #[test_case("http://foo.com", "http://foo.com", "http://foo.com/")]
    #[test_case("http://foo.com", "/two", "http://foo.com/two")]
    #[test_case("http://foo.com", "http://foo.com#foo", "http://foo.com/")]
    fn resolve_redirect_location(request_uri: &str, location: &str, resolved: &str) {
        let response = Response::builder()
            .status(301)
            .header("Location", location)
            .body(())
            .unwrap();

        let target = super::get_redirect_location(&parse_uri(request_uri), &response).unwrap();

        assert_eq!(target.to_string(), resolved);
    }

    #[test_case(
        "http://example.org/Overview.html",
        "http://example.org/Overview.html",
        Some("http://example.org/Overview.html")
    )]
    #[test_case(
        "http://example.org/#heading",
        "http://example.org/#heading",
        Some("http://example.org/")
    )]
    #[test_case(
        "http://user:pass@example.org",
        "http://user:pass@example.org",
        Some("http://example.org/")
    )]
    #[test_case("https://example.com", "http://example.org", None)]
    fn create_referer_from_uri(uri: &str, target_uri: &str, referer: Option<&str>) {
        let header = super::create_referer(&parse_uri(uri), &parse_uri(target_uri));
        let actual = header.as_ref().and_then(|value| value.to_str().ok());

        assert_eq!(actual, referer);
    }

    #[test_case("http://example.com", "http://example.com", true)]
    #[test_case("http://example.com", "http://example.com/foo", true)]
    #[test_case("http://example.com", "http://user:pass@example.com", true)]
    #[test_case("http://example.com", "http://example.com:9000", false)]
    #[test_case("http://example.com:9000", "http://example.com:9000", true)]
    #[test_case("http://example.com", "http://example.org", false)]
    #[test_case("http://example.com", "https://example.com", false)]
    #[test_case("http://example.com", "http://www.example.com", false)]
    fn is_same_authority(a: &str, b: &str, expected: bool) {
        let actual = super::is_same_authority(&parse_uri(a), &parse_uri(b));

        assert_eq!(actual, expected, "comparing {} with {}", a, b);
    }
}