use std::sync::Arc;
use futures::future::{self, Either, Ready};
use http::StatusCode;
use policy::Policy;
use tower::{Layer, Service};
pub use authorizer::*;
pub use policy::PolicyBuilder;
pub use reporter::*;
mod authorizer;
pub mod header;
mod policy;
mod reporter;
/// Tower [`Layer`] that wraps services with Fetch Metadata (`Sec-Fetch-*`) request checks.
///
/// Every request is first offered to the authorizer `A`. If the authorizer leaves the decision
/// open, the configured policy decides. Policy rejections are handed to the reporter `R` and,
/// while enforcement is on, answered with `403 Forbidden` instead of reaching the inner service.
pub struct SecFetchLayer<A = NoopAuthorizer, R = NoopReporter> {
    /// Consulted before the policy; shared (not cloned) by every service this layer builds.
    authorizer: Arc<A>,
    /// Notified of policy rejections; shared (not cloned) by every service this layer builds.
    reporter: Arc<R>,
    /// Rules applied to requests the authorizer neither allows nor denies.
    policy: Policy,
    /// When `false`, policy rejections are only reported and the request is still forwarded.
    enforce: bool,
}
impl<A, R> Clone for SecFetchLayer<A, R> {
fn clone(&self) -> Self {
Self {
enforce: self.enforce,
policy: self.policy,
authorizer: self.authorizer.clone(),
reporter: self.reporter.clone(),
}
}
}
impl Default for SecFetchLayer {
    /// An enforcing layer using `Policy::default()`, a [`NoopAuthorizer`] and a
    /// [`NoopReporter`].
    fn default() -> Self {
        let policy = Policy::default();
        let authorizer = Arc::new(NoopAuthorizer);
        let reporter = Arc::new(NoopReporter);
        Self {
            enforce: true,
            policy,
            authorizer,
            reporter,
        }
    }
}
impl SecFetchLayer {
    /// Creates a layer whose policy is configured by `make_policy`.
    ///
    /// `make_policy` receives a fresh [`PolicyBuilder`]. Every other setting (no-op authorizer
    /// and reporter, enforcement on) comes from [`SecFetchLayer::default`].
    pub fn new<F>(make_policy: F) -> Self
    where
        F: FnOnce(&mut PolicyBuilder),
    {
        let policy = {
            let mut builder = PolicyBuilder::new();
            make_policy(&mut builder);
            builder.build()
        };
        Self {
            policy,
            ..Self::default()
        }
    }
}
impl<OldA, OldR> SecFetchLayer<OldA, OldR> {
    /// Replaces the authorizer with a [`PathAuthorizer`] built from `paths`.
    ///
    /// Any previously configured authorizer is discarded. The reporter, policy and enforcement
    /// setting are kept.
    pub fn allowing(
        self,
        paths: impl Into<Arc<[&'static str]>>,
    ) -> SecFetchLayer<PathAuthorizer, OldR> {
        let authorizer = PathAuthorizer::new(paths);
        self.with_authorizer(authorizer)
    }

    /// Switches the layer to report-only mode.
    ///
    /// Requests rejected by the policy are still handed to the reporter, but they are then
    /// forwarded to the inner service instead of being answered with `403 Forbidden`. A request
    /// the authorizer explicitly denies is still rejected.
    pub fn no_enforce(self) -> Self {
        Self {
            enforce: false,
            ..self
        }
    }

    /// Replaces the authorizer with `authorizer`.
    ///
    /// The reporter, policy and enforcement setting are kept.
    pub fn with_authorizer<A: SecFetchAuthorizer>(self, authorizer: A) -> SecFetchLayer<A, OldR> {
        let SecFetchLayer {
            enforce,
            policy,
            reporter,
            ..
        } = self;
        SecFetchLayer {
            authorizer: Arc::new(authorizer),
            reporter,
            policy,
            enforce,
        }
    }

    /// Replaces the reporter with `reporter`.
    ///
    /// The authorizer, policy and enforcement setting are kept.
    pub fn with_reporter<R: SecFetchReporter>(self, reporter: R) -> SecFetchLayer<OldA, R> {
        let SecFetchLayer {
            enforce,
            policy,
            authorizer,
            ..
        } = self;
        SecFetchLayer {
            authorizer,
            reporter: Arc::new(reporter),
            policy,
            enforce,
        }
    }
}
impl<A, R, S> Layer<S> for SecFetchLayer<A, R> {
type Service = SecFetch<A, R, S>;
fn layer(&self, inner: S) -> Self::Service {
SecFetch {
enforce: self.enforce,
policy: self.policy,
authorizer: self.authorizer.clone(),
reporter: self.reporter.clone(),
inner,
}
}
}
/// Service produced by [`SecFetchLayer`]. It checks each request before handing it to `S`.
///
/// A rejected request (while enforcement is on) never reaches the inner service. It is answered
/// with a `403 Forbidden` response whose body is `ResB::default()`.
pub struct SecFetch<A, R, S> {
    /// The wrapped service.
    inner: S,
    /// Consulted first for every request.
    authorizer: Arc<A>,
    /// Notified when the policy rejects a request.
    reporter: Arc<R>,
    /// Applied when the authorizer leaves the decision open.
    policy: Policy,
    /// Whether policy rejections are enforced (`true`) or only reported (`false`).
    enforce: bool,
}
impl<A, R, S> Clone for SecFetch<A, R, S>
where
S: Clone,
{
fn clone(&self) -> Self {
Self {
enforce: self.enforce,
policy: self.policy,
authorizer: self.authorizer.clone(),
reporter: self.reporter.clone(),
inner: self.inner.clone(),
}
}
}
// Requests are either forwarded to the inner service (`Either::Left`) or answered immediately
// with `403 Forbidden` (`Either::Right`).
impl<A, R, ReqB, ResB, S> Service<http::Request<ReqB>> for SecFetch<A, R, S>
where
    A: SecFetchAuthorizer,
    R: SecFetchReporter,
    S: Service<http::Request<ReqB>, Response = http::Response<ResB>>,
    ResB: Default,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Either<S::Future, Ready<Result<Self::Response, Self::Error>>>;

    /// Readiness is delegated entirely to the inner service.
    #[inline]
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: http::Request<ReqB>) -> Self::Future {
        #[cfg(feature = "tracing")]
        tracing::debug!(
            method = %request.method(),
            path = request.uri().path(),
            "processing request",
        );

        // An explicit authorizer decision wins outright. Otherwise the policy decides. A policy
        // rejection is always reported but only enforced when `enforce` is set.
        let forward = match self.authorizer.authorize(&request) {
            AuthorizationDecision::Allowed => true,
            AuthorizationDecision::Denied => false,
            AuthorizationDecision::Continue if self.policy.allow(&request) => true,
            AuthorizationDecision::Continue => {
                self.reporter.on_request_denied(&request);
                !self.enforce
            }
        };

        if forward {
            #[cfg(feature = "tracing")]
            tracing::debug!(
                method = %request.method(),
                path = request.uri().path(),
                "request allowed",
            );
            return Either::Left(self.inner.call(request));
        }

        #[cfg(feature = "tracing")]
        tracing::debug!(
            method = %request.method(),
            path = request.uri().path(),
            "request denied",
        );
        let mut forbidden = http::Response::new(ResB::default());
        *forbidden.status_mut() = StatusCode::FORBIDDEN;
        Either::Right(future::ready(Ok(forbidden)))
    }
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use assert2::check;
    use http::Method;
    use tower::ServiceExt;
    use tower_test::mock;

    use super::*;

    // Builds a request (GET to "/" unless given) for `https://example.com{path}` carrying the
    // three `Sec-Fetch-*` headers.
    macro_rules! request {
        (site => $site:expr, mode => $mode:expr, dest => $dest:expr) => {
            request!(::http::Method::GET, "/", site => $site, mode => $mode, dest => $dest)
        };
        ($path:expr, site => $site:expr, mode => $mode:expr, dest => $dest:expr) => {
            request!(::http::Method::GET, $path, site => $site, mode => $mode, dest => $dest)
        };
        ($method:expr, $path:expr, site => $site:expr, mode => $mode:expr, dest => $dest:expr) => {
            ::http::Request::builder()
                .method($method)
                .uri(format!("https://example.com{}", $path))
                .header(header::SEC_FETCH_SITE, $site)
                .header(header::SEC_FETCH_MODE, $mode)
                .header(header::SEC_FETCH_DEST, $dest)
                .body(())
                .unwrap()
        };
    }

    // Sends `$req` through `$layer` (the default layer when omitted), which wraps a mock
    // inner service answering `200 OK`. The response is then passed to `$assert_resp`.
    macro_rules! assert_request {
        ($req:expr, $assert_resp:expr) => {
            assert_request!($req, $assert_resp, SecFetchLayer::default())
        };
        ($req:expr, $assert_resp:expr, $layer:expr) => {
            let (service, mut handler) =
                mock::spawn_layer::<http::Request<()>, http::Response<()>, _>($layer);
            tokio::spawn(async move {
                // A rejected request never reaches the inner service. The mock handle then
                // observes no request at all, and that must not panic.
                if let Some((_, send)) = handler.next_request().await {
                    send.send_response(http::Response::new(()));
                }
            });
            let response = service.into_inner().oneshot($req).await.unwrap();
            ($assert_resp)(response);
        };
    }

    #[tokio::test]
    async fn it_allows_requests_missing_the_fetch_metadata() {
        let request = http::Request::new(());
        assert_request!(request, |response: http::Response<()>| {
            check!(response.status().is_success());
        });
    }

    #[tokio::test]
    async fn it_rejects_requests_missing_the_fetch_metadata_if_configured() {
        let layer = SecFetchLayer::new(|policy| {
            policy.reject_missing_metadata();
        });
        let request = http::Request::new(());
        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status() == StatusCode::FORBIDDEN);
            },
            layer
        );
    }

    #[tokio::test]
    async fn it_allows_same_site_requests() {
        let request = request!(site => "same-site", mode => "navigate", dest => "document");
        assert_request!(request, |response: http::Response<()>| {
            check!(response.status().is_success());
        });
    }

    #[tokio::test]
    async fn it_rejects_cross_origin_requests() {
        let request = request!(site => "cross-site", mode => "cors", dest => "empty");
        assert_request!(request, |response: http::Response<()>| {
            check!(response.status() == StatusCode::FORBIDDEN);
        });
    }

    #[tokio::test]
    async fn it_allows_cross_origin_requests_safe_methods_if_configured() {
        let layer = SecFetchLayer::new(|policy| {
            policy.allow_safe_methods();
        });
        let request =
            request!(Method::GET, "/", site => "cross-site", mode => "cors", dest => "empty");
        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );
    }

    #[tokio::test]
    async fn it_allows_navigation_requests() {
        let request = request!(site => "cross-site", mode => "navigate", dest => "document");
        assert_request!(request, |response: http::Response<()>| {
            check!(response.status().is_success());
        });
    }

    #[tokio::test]
    async fn it_rejects_navigation_requests_resulting_from_embedding() {
        let request = request!(site => "cross-site", mode => "navigate", dest => "iframe");
        assert_request!(request, |response: http::Response<()>| {
            check!(response.status() == StatusCode::FORBIDDEN);
        });
    }

    #[tokio::test]
    async fn it_ignores_explicitly_authorized_requests() {
        let layer = SecFetchLayer::default().allowing(["/allowed"]);
        let request = request!("/allowed", site => "cross-site", mode => "cors", dest => "empty");
        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );
    }

    #[tokio::test]
    async fn it_allows_denied_requests_if_enforcement_is_turned_off() {
        let layer = SecFetchLayer::default().no_enforce();
        let request = request!(site => "cross-site", mode => "cors", dest => "empty");
        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );
    }

    /// Reporter that records whether it has been called at least once.
    #[derive(Default)]
    struct TestReporter {
        called: AtomicBool,
    }

    impl SecFetchReporter for TestReporter {
        fn on_request_denied<B>(&self, _: &http::Request<B>) {
            self.called.store(true, Ordering::SeqCst);
        }
    }

    #[tokio::test]
    async fn it_reports_denied_requests() {
        let reporter = Arc::new(TestReporter::default());
        let layer = SecFetchLayer::default().with_reporter(reporter.clone());
        let request = request!(site => "cross-site", mode => "cors", dest => "empty");
        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status() == StatusCode::FORBIDDEN);
            },
            layer
        );
        let called = reporter.called.load(Ordering::SeqCst);
        check!(
            called,
            "reporter was not called despite the request being rejected"
        );
    }

    #[tokio::test]
    async fn it_reports_denied_requests_even_if_enforcement_is_turned_off() {
        let reporter = Arc::new(TestReporter::default());
        let layer = SecFetchLayer::default()
            .no_enforce()
            .with_reporter(reporter.clone());
        let request = request!(site => "cross-site", mode => "cors", dest => "empty");
        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );
        let called = reporter.called.load(Ordering::SeqCst);
        check!(
            called,
            "reporter was not called for a request rejected in report-only mode"
        );
    }

    #[tokio::test]
    async fn it_does_not_report_allowed_requests() {
        let reporter = Arc::new(TestReporter::default());
        let layer = SecFetchLayer::default().with_reporter(reporter.clone());
        let request = request!(site => "same-site", mode => "navigate", dest => "document");
        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );
        let called = reporter.called.load(Ordering::SeqCst);
        check!(!called, "reporter was called for an allowed request");
    }
}