tower-canonical-redirect 0.1.0

A Tower middleware to enforce canonical hosts in HTTP requests
Documentation
use crate::http::HeaderError;
use crate::http::HttpExt;
use crate::http::uri::Origin;
use futures::future;
use http::HeaderName;
use http::Request;
use http::Response;
use http::StatusCode;
use http::Uri;
use http::header::LOCATION;
use http::uri::Authority;
use http::uri::Scheme;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use tower::Service;

/// Tower [`Service`] that redirects requests whose origin (scheme and
/// authority) differs from a configured canonical origin.
///
/// Requests whose origin equals `canonical_origin` are passed to `inner`
/// unchanged. Any other request is answered directly with a redirect to the
/// canonical origin. A request whose scheme or authority cannot be determined
/// is answered with `400 Bad Request`.
#[derive(Clone, Debug)]
pub struct CanonicalRedirect<S>
{
	/// Wrapped service that handles requests already on the canonical origin.
	pub(crate) inner: S,
	/// Origin (scheme + authority) that non-canonical requests are redirected to.
	pub(crate) canonical_origin: Origin,
	/// Non-canonical origins answered with `307 Temporary Redirect`; every
	/// other non-canonical origin gets `308 Permanent Redirect`.
	pub(crate) temporary_origins: Arc<[Origin]>,
	/// Extra headers passed to `Scheme::from_request` when determining the
	/// request's scheme (shared via `Arc` so clones of the service are cheap).
	pub(crate) proto_headers: Arc<[HeaderName]>,
	/// Extra headers passed to `Authority::from_request` when determining the
	/// request's authority.
	pub(crate) host_headers: Arc<[HeaderName]>,
}

impl<S, I, O> Service<Request<I>> for CanonicalRedirect<S>
where
	S: Service<Request<I>, Response = Response<O>>,
	O: Default,
{
	type Response = S::Response;
	type Error = S::Error;
	type Future = future::Either<S::Future, future::Ready<Result<Self::Response, Self::Error>>>;

	/// Readiness is delegated entirely to the wrapped service.
	fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>
	{
		self.inner.poll_ready(cx)
	}

	/// Passes canonical requests to the inner service and answers every other
	/// request immediately. The answer is either the redirect built by
	/// `parse`, or `400 Bad Request` when the request's origin cannot be
	/// determined. Short-circuited responses always carry `O::default()` as
	/// their body.
	fn call(&mut self, req: Request<I>) -> Self::Future
	{
		let short_circuit = match self.parse(&req)
		{
			// Already on the canonical origin: hand it to the wrapped service.
			Ok(None) => return future::Either::Left(self.inner.call(req)),
			// Non-canonical origin: reply with the redirect, swapping in an empty body.
			Ok(Some(redirect)) => redirect.map(|()| O::default()),
			// Scheme or authority could not be extracted from the request.
			Err(_) => Response::builder()
				.status(StatusCode::BAD_REQUEST)
				.body(O::default())
				.expect("infallible"),
		};

		future::Either::Right(future::ready(Ok(short_circuit)))
	}
}

impl<S> CanonicalRedirect<S>
{
	fn parse<B>(&self, req: &Request<B>) -> Result<Option<Response<()>>, HeaderError>
	{
		let origin = Origin {
			scheme: Scheme::from_request(req, self.proto_headers.iter().cloned())?,
			authority: Authority::from_request(req, self.host_headers.iter().cloned())?,
		};
		if origin == self.canonical_origin
		{
			return Ok(None);
		}

		let mut redirect_uri_builder = Uri::builder()
			.scheme(self.canonical_origin.scheme.clone())
			.authority(self.canonical_origin.authority.clone());
		if let Some(path_and_query) = req.uri().path_and_query()
		{
			redirect_uri_builder = redirect_uri_builder.path_and_query(path_and_query.clone());
		}

		let redirect_uri = redirect_uri_builder.build().expect("infallible");
		let redirect_status = if self.temporary_origins.contains(&origin)
		{
			StatusCode::TEMPORARY_REDIRECT
		}
		else
		{
			StatusCode::PERMANENT_REDIRECT
		};
		let redirect_res = Response::builder()
			.status(redirect_status)
			.header(LOCATION, redirect_uri.to_string())
			.body(())
			.expect("infallible");

		Ok(Some(redirect_res))
	}
}

#[cfg(test)]
mod test
{
	use super::CanonicalRedirect;
	use crate::http::header::X_FORWARDED_HOST;
	use crate::http::header::X_FORWARDED_PROTO;
	use crate::http::uri::Origin;
	use http::HeaderName;
	use http::Request;
	use http::Response;
	use http::StatusCode;
	use http::header::FORWARDED;
	use http::header::HOST;
	use http::header::LOCATION;
	use test_case::test_case;
	use test_case::test_matrix;

	/// Custom scheme header configured as the mock service's `proto_headers`.
	const MOCK_PROTO: HeaderName = HeaderName::from_static("mock-proto");
	/// Custom host header configured as the mock service's `host_headers`.
	const MOCK_HOST: HeaderName = HeaderName::from_static("mock-host");

	/// Builds a middleware around a unit inner service with canonical origin
	/// `uri`, the given temporary origins, and the mock proto/host headers.
	fn mock_service(
		uri: &'static str,
		temporary_origins: impl IntoIterator<Item = &'static str>,
	) -> CanonicalRedirect<()>
	{
		CanonicalRedirect::<()> {
			inner: (),
			canonical_origin: uri.parse().unwrap(),
			temporary_origins: temporary_origins
				.into_iter()
				.map(str::parse)
				.collect::<Result<Vec<_>, _>>()
				.unwrap()
				.into(),
			proto_headers: [MOCK_PROTO].into(),
			host_headers: [MOCK_HOST].into(),
		}
	}

	/// Parses `req` with `service` and returns the redirect it produces.
	/// Panics with the error or the missing redirect if it produces neither.
	fn expect_redirect(service: &CanonicalRedirect<()>, req: &Request<()>) -> Response<()>
	{
		service
			.parse(req)
			.expect("request origin should be parseable")
			.expect("request should be redirected")
	}

	/// Parses `req` with `service` and asserts that it is passed through
	/// without a redirect.
	fn assert_not_redirected(service: &CanonicalRedirect<()>, req: &Request<()>)
	{
		let res = service
			.parse(req)
			.expect("request origin should be parseable");

		assert!(res.is_none(), "unexpected redirect: {res:?}");
	}

	/// Returns the `Location` header of a redirect response.
	fn location(res: &Response<()>) -> &str
	{
		res.headers()
			.get(LOCATION)
			.expect("redirect should carry a Location header")
			.to_str()
			.expect("Location header should be visible ASCII")
	}

	#[test]
	fn request_without_headers_errors()
	{
		let req = Request::builder().body(()).unwrap();
		let res = mock_service("https://example.com", []).parse(&req);

		assert!(res.is_err(), "expected an error, got {res:?}");
	}

	#[test_case(MOCK_PROTO)]
	#[test_case(X_FORWARDED_PROTO)]
	fn request_with_proto_header_only_errors(proto: HeaderName)
	{
		let req = Request::builder().header(proto, "https").body(()).unwrap();
		let res = mock_service("https://example.com", []).parse(&req);

		assert!(res.is_err(), "expected an error, got {res:?}");
	}

	#[test_case(MOCK_HOST)]
	#[test_case(X_FORWARDED_HOST)]
	#[test_case(HOST)]
	fn request_with_host_header_only_errors(host: HeaderName)
	{
		let req = Request::builder()
			.header(host, "example.com")
			.body(())
			.unwrap();
		let res = mock_service("https://example.com", []).parse(&req);

		assert!(res.is_err(), "expected an error, got {res:?}");
	}

	#[test_case("proto=https")]
	#[test_case("host=example.com")]
	fn forwarded_request_with_partial_header_only_errors(header: &'static str)
	{
		let req = Request::builder()
			.header(FORWARDED, header)
			.body(())
			.unwrap();
		let res = mock_service("https://example.com", []).parse(&req);

		assert!(res.is_err(), "expected an error, got {res:?}");
	}

	#[test_matrix(
		[MOCK_PROTO, X_FORWARDED_PROTO],
		[MOCK_HOST, X_FORWARDED_HOST, HOST]
	)]
	fn match_does_not_redirect(proto: HeaderName, host: HeaderName)
	{
		let req = Request::builder()
			.header(proto, "https")
			.header(host, "example.com")
			.body(())
			.unwrap();

		assert_not_redirected(&mock_service("https://example.com", []), &req);
	}

	#[test]
	fn forwarded_match_does_not_redirect()
	{
		let req = Request::builder()
			.header(FORWARDED, "proto=https;host=example.com")
			.body(())
			.unwrap();

		assert_not_redirected(&mock_service("https://example.com", []), &req);
	}

	#[test]
	fn uri_match_does_not_redirect()
	{
		let req = Request::builder()
			.uri("https://example.com")
			.body(())
			.unwrap();

		assert_not_redirected(&mock_service("https://example.com", []), &req);
	}

	#[test_matrix(
		[MOCK_PROTO, X_FORWARDED_PROTO],
		[MOCK_HOST, X_FORWARDED_HOST, HOST],
		[
			("http://example.com", "https://example.com"),
			("https://example.com", "http://example.com"),
			("http://www.example.com", "http://example.com"),
			("http://example.com", "http://www.example.com"),
			("http://example.com:8000", "http://example.com"),
			("http://example.com", "http://example.com:8000"),
			("http://other.com", "https://example.com"),
			("http://example.com", "https://other.com"),
			("http://username:password@example.com", "http://example.com"),
			("http://example.com", "http://username:password@example.com"),
		]
	)]
	fn mismatch_redirects(
		proto: HeaderName,
		host: HeaderName,
		(origin, redirect): (&'static str, &'static str),
	)
	{
		let Origin { scheme, authority } = origin.parse().unwrap();
		let req = Request::builder()
			.header(proto, scheme.as_str())
			.header(host, authority.as_str())
			.body(())
			.unwrap();
		let res = expect_redirect(&mock_service(redirect, []), &req);

		assert_eq!(location(&res), format!("{redirect}/"));
	}

	#[test_case("http://example.com", "https://example.com")]
	#[test_case("https://example.com", "http://example.com")]
	#[test_case("http://www.example.com", "http://example.com")]
	#[test_case("http://example.com", "http://www.example.com")]
	#[test_case("http://example.com:8000", "http://example.com")]
	#[test_case("http://example.com", "http://example.com:8000")]
	#[test_case("http://other.com", "https://example.com")]
	#[test_case("http://example.com", "https://other.com")]
	#[test_case("http://username:password@example.com", "http://example.com")]
	#[test_case("http://example.com", "http://username:password@example.com")]
	fn forwarded_mismatch_redirects(origin: &'static str, redirect: &'static str)
	{
		let Origin { scheme, authority } = origin.parse().unwrap();
		let req = Request::builder()
			.header(FORWARDED, format!("proto={scheme};host={authority}"))
			.body(())
			.unwrap();
		let res = expect_redirect(&mock_service(redirect, []), &req);

		assert_eq!(location(&res), format!("{redirect}/"));
	}

	#[test_case("http://example.com", "https://example.com")]
	#[test_case("https://example.com", "http://example.com")]
	#[test_case("http://www.example.com", "http://example.com")]
	#[test_case("http://example.com", "http://www.example.com")]
	#[test_case("http://example.com:8000", "http://example.com")]
	#[test_case("http://example.com", "http://example.com:8000")]
	#[test_case("http://other.com", "https://example.com")]
	#[test_case("http://example.com", "https://other.com")]
	#[test_case("http://username:password@example.com", "http://example.com")]
	#[test_case("http://example.com", "http://username:password@example.com")]
	fn uri_mismatch_redirects(origin: &'static str, redirect: &'static str)
	{
		let req = Request::builder().uri(origin).body(()).unwrap();
		let res = expect_redirect(&mock_service(redirect, []), &req);

		assert_eq!(location(&res), format!("{redirect}/"));
	}

	#[test]
	fn temporary_redirect()
	{
		let req = Request::builder()
			.uri("http://temporary.example.com")
			.body(())
			.unwrap();
		let service = mock_service("http://example.com", ["http://temporary.example.com"]);
		let res = expect_redirect(&service, &req);

		assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT);
	}

	#[test]
	fn permanent_redirect()
	{
		let req = Request::builder()
			.uri("http://permanent.example.com")
			.body(())
			.unwrap();
		let res = expect_redirect(&mock_service("http://example.com", []), &req);

		assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
	}

	#[test]
	fn redirect_preserves_path_and_query()
	{
		let req = Request::builder()
			.uri("http://www.example.com/path?query=1")
			.body(())
			.unwrap();
		let res = expect_redirect(&mock_service("https://example.com", []), &req);

		assert_eq!(location(&res), "https://example.com/path?query=1");
	}

	#[test]
	fn redirect_preserves_path_and_query_behind_proxy()
	{
		let req = Request::builder()
			.uri("/path?query=1")
			.header(X_FORWARDED_PROTO, "http")
			.header(X_FORWARDED_HOST, "www.example.com")
			.body(())
			.unwrap();
		let res = expect_redirect(&mock_service("https://example.com", []), &req);

		assert_eq!(location(&res), "https://example.com/path?query=1");
	}
}