use std::collections::HashSet;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
use http::{Method, Uri};
mod future;
mod layer;
mod response;
mod service;
mod url;
pub use self::future::ResponseFuture;
pub use self::layer::CsrfLayer;
pub use self::response::{DefaultResponseForProtectionError, ResponseForProtectionError};
pub use self::service::Csrf;
/// Error returned when a trusted origin handed to the layer's configuration
/// (e.g. `CsrfLayer::add_trusted_origin`) cannot be accepted.
///
/// Each variant records the rejected `origin`, which is included (Debug-quoted)
/// in the `Display` message.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum ConfigError {
    /// The origin could not be parsed as a URL; `message` is a human-readable
    /// reason appended to the `Display` output.
    InvalidOriginUrl { origin: String, message: String },
    /// The origin carried a path, query, or fragment, which an origin may not have.
    InvalidOriginUrlComponents { origin: String },
    /// The origin's scheme was not `http` or `https`.
    OpaqueOrigin { origin: String },
    /// The origin's hostname was non-ASCII instead of punycode (`xn--…`).
    NonAsciiHostname { origin: String },
}

impl fmt::Display for ConfigError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidOriginUrl { origin, message } => {
                write!(f, "invalid origin {origin:?}: {message}")
            }
            Self::InvalidOriginUrlComponents { origin } => {
                write!(
                    f,
                    "invalid origin {origin:?}: path, query, and fragment are not allowed"
                )
            }
            Self::OpaqueOrigin { origin } => {
                write!(f, "invalid origin {origin:?}: scheme must be http or https")
            }
            Self::NonAsciiHostname { origin } => {
                write!(
                    f,
                    "invalid origin {origin:?}: non-ASCII hostnames must be supplied in punycode (xn--…)"
                )
            }
        }
    }
}

impl std::error::Error for ConfigError {}
/// Error describing why a request was rejected by the cross-origin protection.
///
/// The reason is exposed through [`ProtectionError::kind`].
#[derive(Clone, Debug)]
pub struct ProtectionError {
    kind: ProtectionErrorKind,
}

impl ProtectionError {
    /// Wraps a rejection reason in a new error value.
    pub(crate) fn new(kind: ProtectionErrorKind) -> Self {
        ProtectionError { kind }
    }

    /// Returns the reason this request was rejected.
    pub fn kind(&self) -> ProtectionErrorKind {
        let Self { kind } = self;
        *kind
    }
}

impl fmt::Display for ProtectionError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let description = match self.kind() {
            ProtectionErrorKind::CrossOriginRequest => "Cross-Origin request detected",
            ProtectionErrorKind::CrossOriginRequestFromOldBrowser => {
                "Cross-Origin request from old browser detected"
            }
        };
        f.write_str(description)
    }
}

impl std::error::Error for ProtectionError {}

/// Reason a request was rejected.
///
/// Judging by the test expectations in this file:
/// - `CrossOriginRequest` is reported when the `Sec-Fetch-Site` header rejects
///   the request.
/// - `CrossOriginRequestFromOldBrowser` is reported when that header is absent
///   and the `Origin` fallback check fails.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum ProtectionErrorKind {
    CrossOriginRequest,
    CrossOriginRequestFromOldBrowser,
}
/// Predicate over a request's method and URI; returning `true` marks the request as
/// exempt from the cross-origin checks. This is presumably the stored form of the
/// closure given to `CsrfLayer::with_insecure_bypass` (TODO confirm in layer/service).
type BypassFn = dyn Fn(&Method, &Uri) -> bool + Sync + Send + 'static;
/// Placeholder whose `Debug` rendering is always the literal `<fn>`.
///
/// It stands in for closure-valued fields, which have no `Debug` impl of their
/// own, when a containing type is `Debug`-formatted.
struct DebugFn;

impl Debug for DebugFn {
    fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
        const PLACEHOLDER: &str = "<fn>";
        formatter.write_str(PLACEHOLDER)
    }
}
/// Set of trusted origins, stored as raw bytes and matched byte-for-byte.
///
/// The set sits behind an `Arc`, so clones share it. Mutation goes through
/// [`Arc::make_mut`], which copies the set first if it is currently shared.
#[derive(Clone, Default)]
struct Origins(Arc<HashSet<Vec<u8>>>);

impl Origins {
    /// Returns `true` if `origin` is byte-for-byte equal to a stored origin.
    fn contains(&self, origin: &[u8]) -> bool {
        self.0.contains(origin)
    }

    /// Adds `origin` to the set, un-sharing the underlying set if necessary.
    fn insert(&mut self, origin: impl Into<Vec<u8>>) {
        Arc::make_mut(&mut self.0).insert(origin.into());
    }
}

impl Debug for Origins {
    /// Renders as `Origins({"a", "b", ...})`.
    ///
    /// Non-UTF-8 bytes are shown lossily. Entries are sorted because `HashSet`
    /// iteration order is unspecified and varies between runs, which would
    /// otherwise make logs and Debug-based assertions nondeterministic.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut entries: Vec<&Vec<u8>> = self.0.iter().collect();
        entries.sort_unstable();
        write!(f, "Origins(")?;
        f.debug_set()
            .entries(entries.into_iter().map(|o| String::from_utf8_lossy(o)))
            .finish()?;
        write!(f, ")")
    }
}
#[cfg(test)]
mod tests {
use std::convert::Infallible;
use http::{Request, Response, StatusCode};
use tower::{service_fn, ServiceExt};
use tower_layer::Layer;
use super::*;
use crate::test_helpers::{to_bytes, Body};
/// Test-only equality so `Result<(), ProtectionError>` values can be compared with
/// `assert_eq!`: two errors are equal exactly when their kinds are equal.
impl PartialEq for super::ProtectionError {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.kind(), other.kind());
        lhs == rhs
    }
}
/// Inner service for the service-level tests. It always answers with a default
/// (200 OK) response whose body depends on the path:
/// - `/foo` → `"foo"`
/// - `/bar` → `"bar"`
/// - any other path → empty body
fn echo_service() -> impl tower::Service<
    Request<Body>,
    Response = Response<Body>,
    Error = Infallible,
    Future = impl std::future::Future<Output = Result<Response<Body>, Infallible>>,
> + Clone {
    service_fn(|req: Request<Body>| async move {
        let path = req.uri().path();
        let body = if path == "/foo" {
            Body::from("foo")
        } else if path == "/bar" {
            Body::from("bar")
        } else {
            Body::empty()
        };
        Ok::<_, Infallible>(Response::new(body))
    })
}
/// A GET (safe method) is forwarded to the inner service, which produces its normal body.
#[tokio::test]
async fn test_service_allows_safe_method() {
    let csrf = CsrfLayer::new()
        .add_trusted_origin("https://example.com")
        .unwrap();
    let svc = csrf.layer(echo_service());

    let request = Request::get("/foo").body(Body::empty()).unwrap();
    let response = svc.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let payload = to_bytes(response.into_body()).await.unwrap();
    assert_eq!(&payload[..], b"foo");
}
/// A POST whose `Origin` is a configured trusted origin is forwarded to the inner service.
#[tokio::test]
async fn test_service_allows_post_from_trusted_origin() {
    let csrf = CsrfLayer::new()
        .add_trusted_origin("https://example.com")
        .unwrap();
    let svc = csrf.layer(echo_service());

    let request = Request::post("/bar")
        .header("origin", "https://example.com")
        .body(Body::empty())
        .unwrap();
    let response = svc.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let payload = to_bytes(response.into_body()).await.unwrap();
    assert_eq!(&payload[..], b"bar");
}
/// With no custom rejection response configured, a POST from an untrusted origin
/// gets 403 Forbidden, and the `ProtectionError` is stored in the response extensions.
#[tokio::test]
async fn test_service_rejects_post_from_untrusted_origin() {
    let csrf = CsrfLayer::new()
        .add_trusted_origin("https://example.com")
        .unwrap();
    let svc = csrf.layer(echo_service());

    let request = Request::post("/bar")
        .header("origin", "https://malicious.example")
        .body(Body::empty())
        .unwrap();
    let response = svc.oneshot(request).await.unwrap();

    let expected =
        ProtectionError::new(ProtectionErrorKind::CrossOriginRequestFromOldBrowser);
    assert_eq!(response.status(), StatusCode::FORBIDDEN);
    assert_eq!(
        response.extensions().get::<ProtectionError>(),
        Some(&expected),
    );
}
/// A custom rejection response replaces the default one: its status and body are
/// returned, and the `ProtectionError` is still stored in the response extensions.
#[tokio::test]
async fn test_service_uses_custom_rejection_response() {
    let teapot = |_err: ProtectionError| {
        let mut res = Response::new(Body::from("denied"));
        *res.status_mut() = StatusCode::IM_A_TEAPOT;
        res
    };
    let svc = CsrfLayer::new()
        .with_rejection_response(teapot)
        .layer(echo_service());

    let request = Request::post("/bar")
        .header("origin", "https://malicious.example")
        .body(Body::empty())
        .unwrap();
    let response = svc.oneshot(request).await.unwrap();

    let status = response.status();
    assert_eq!(status, StatusCode::IM_A_TEAPOT);
    assert_ne!(status, StatusCode::OK);
    assert_eq!(
        response.extensions().get::<ProtectionError>(),
        Some(&ProtectionError::new(
            ProtectionErrorKind::CrossOriginRequestFromOldBrowser
        )),
    );

    let payload = to_bytes(response.into_body()).await.unwrap();
    assert_eq!(&payload[..], b"denied");
}
/// The custom rejection response is not used for an accepted request: the inner
/// service's response comes back unchanged and carries no `ProtectionError`.
#[tokio::test]
async fn test_service_custom_rejection_response_not_invoked_when_allowed() {
    let teapot = |_err: ProtectionError| {
        let mut res = Response::new(Body::from("denied"));
        *res.status_mut() = StatusCode::IM_A_TEAPOT;
        res
    };
    let svc = CsrfLayer::new()
        .add_trusted_origin("https://example.com")
        .unwrap()
        .with_rejection_response(teapot)
        .layer(echo_service());

    let request = Request::post("/bar")
        .header("origin", "https://example.com")
        .body(Body::empty())
        .unwrap();
    let response = svc.oneshot(request).await.unwrap();

    let status = response.status();
    assert_eq!(status, StatusCode::OK);
    assert_ne!(status, StatusCode::IM_A_TEAPOT);
    assert!(response.extensions().get::<ProtectionError>().is_none());

    let payload = to_bytes(response.into_body()).await.unwrap();
    assert_eq!(&payload[..], b"bar");
}
/// `add_trusted_origin` accepts a well-formed origin and reports an unparseable
/// string as `ConfigError::InvalidOriginUrl`.
#[test]
fn test_layer_add_trusted_origin() {
    let accepted = CsrfLayer::new().add_trusted_origin("https://example.com");
    assert!(accepted.is_ok());

    let rejected = CsrfLayer::new().add_trusted_origin("not a valid url");
    assert!(matches!(rejected, Err(ConfigError::InvalidOriginUrl { .. })));
}
/// Requests matched by the insecure-bypass predicate are accepted even from a
/// foreign origin. Non-matching requests are still checked normally.
#[test]
fn test_middleware_bypass() {
    let middleware = CsrfLayer::new()
        .with_insecure_bypass(|_method, uri| -> bool { uri.path() == "/bypass" })
        .layer(());

    fn reject(kind: ProtectionErrorKind) -> Result<(), ProtectionError> {
        Err(ProtectionError::new(kind))
    }

    // (name, request path, sec-fetch-site header, expected outcome)
    let cases: [(&str, &str, Option<&str>, Result<(), ProtectionError>); 4] = [
        ("bypass path without sec-fetch-site", "/bypass", None, Ok(())),
        ("bypass path with cross-site", "/bypass", Some("cross-site"), Ok(())),
        (
            "non-bypass path without sec-fetch-site",
            "/api",
            None,
            reject(ProtectionErrorKind::CrossOriginRequestFromOldBrowser),
        ),
        (
            "non-bypass path with cross-site",
            "/api",
            Some("cross-site"),
            reject(ProtectionErrorKind::CrossOriginRequest),
        ),
    ];

    for (name, path, sec_fetch_site, expected) in cases {
        let mut builder = Request::post(format!("https://example.com{path}"))
            .header("host", "example.com")
            .header("origin", "https://attacker.example");
        if let Some(value) = sec_fetch_site {
            builder = builder.header("sec-fetch-site", value);
        }
        let req = builder.body(()).unwrap();
        assert_eq!(middleware.verify(&req), expected, "{}", name);
    }
}
/// The bypass predicate still applies when the `Origin` header is not decodable text.
#[test]
fn test_middleware_bypass_applies_when_origin_unparseable() {
    let middleware = CsrfLayer::new()
        .with_insecure_bypass(|_method, uri| uri.path() == "/bypass")
        .layer(());

    // 0xFF 0xFE is accepted by `HeaderValue::from_bytes`, but it is not valid text.
    let unparseable_origin = http::HeaderValue::from_bytes(&[0xFF, 0xFE]).unwrap();
    let req = Request::post("https://example.com/bypass")
        .header("host", "example.com")
        .header("origin", unparseable_origin)
        .body(())
        .unwrap();

    assert_eq!(middleware.verify(&req), Ok(()));
}
/// `Csrf`'s Debug output renders closures as `<fn>` and reflects whether a bypass
/// predicate was configured.
#[test]
fn test_middleware_debug_trait() {
    let base = CsrfLayer::new();

    let with_bypass = base
        .clone()
        .with_insecure_bypass(|method, uri| method == Method::POST && uri.path() == "/bypass")
        .layer(());
    assert_eq!(
        format!("{with_bypass:?}"),
        "Csrf { inner: (), insecure_bypass: Some(<fn>), trusted_origins: Origins({}), rejection_response: <fn> }"
    );

    let without_bypass = base.layer(());
    assert_eq!(
        format!("{without_bypass:?}"),
        "Csrf { inner: (), insecure_bypass: None, trusted_origins: Origins({}), rejection_response: <fn> }"
    );
}
/// Without `Sec-Fetch-Site`, the `Origin` header is compared against the request's
/// authority:
/// - the request-target authority is used when present, otherwise the `Host` header;
/// - ports must match textually, so an explicit default port is not equivalent to
///   an omitted one;
/// - origins without an http(s) scheme never match.
#[test]
fn test_middleware_origin_host_port_match() {
    let middleware = Csrf::<()>::default();

    // (name, request uri, host header, origin header, accepted)
    let cases: &[(&str, &str, Option<&str>, &str, bool)] = &[
        ("default port both sides", "/", Some("example.com"), "https://example.com", true),
        (
            "same non-default port both sides",
            "/",
            Some("example.com:8443"),
            "https://example.com:8443",
            true,
        ),
        (
            "explicit default port both sides",
            "/",
            Some("example.com:443"),
            "https://example.com:443",
            true,
        ),
        (
            "mismatched non-default ports",
            "/",
            Some("example.com:8443"),
            "https://example.com:8444",
            false,
        ),
        (
            "origin has explicit default, host implicit",
            "/",
            Some("example.com"),
            "https://example.com:443",
            false,
        ),
        (
            "host has explicit default, origin implicit",
            "/",
            Some("example.com:443"),
            "https://example.com",
            false,
        ),
        (
            "host implicit, origin explicit non-default",
            "/",
            Some("example.com"),
            "https://example.com:8443",
            false,
        ),
        (
            "missing host, uri authority implicit, origin explicit non-default",
            "https://example.com/path",
            None,
            "https://example.com:8443",
            false,
        ),
        (
            "malformed host header compared verbatim",
            "/path",
            Some("not a valid authority"),
            "https://example.com",
            false,
        ),
        (
            "request-target authority wins over host header (match)",
            "https://example.com/path",
            Some("other.example"),
            "https://example.com",
            true,
        ),
        (
            "origin matching host header but not authority is rejected",
            "https://example.com/path",
            Some("other.example"),
            "https://other.example",
            false,
        ),
        (
            "missing host, uri carries authority (match)",
            "https://example.com/path",
            None,
            "https://example.com",
            true,
        ),
        (
            "missing host, uri authority mismatch",
            "https://other.example/path",
            None,
            "https://example.com",
            false,
        ),
        (
            "missing host and no uri authority",
            "/path",
            None,
            "https://example.com",
            false,
        ),
        (
            "scheme-less origin does not match host even if bytes agree",
            "/",
            Some("example.com:8443"),
            "example.com:8443",
            false,
        ),
        (
            "non-http origin scheme does not enter host fallback",
            "/",
            Some("example.com:8443"),
            "ftp://example.com:8443",
            false,
        ),
    ];

    for &(name, uri, host, origin, accepted) in cases {
        let mut builder = Request::post(uri);
        if let Some(host) = host {
            builder = builder.header("host", host);
        }
        let req = builder.header("origin", origin).body(()).unwrap();
        // Every rejection in this table comes from the Origin fallback path.
        let expected = if accepted {
            Ok(())
        } else {
            Err(ProtectionError::new(
                ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
            ))
        };
        assert_eq!(middleware.verify(&req), expected, "{}", name);
    }
}
/// Table-driven coverage of the `Sec-Fetch-Site` check and its `Origin` fallback.
///
/// Every request carries `Host: example.com` and no explicit URI. GET, HEAD, and
/// OPTIONS are accepted even with `Sec-Fetch-Site: cross-site` (see the
/// "... allowed" method cases). Every case meant to exercise header handling
/// therefore uses a non-safe method.
#[test]
fn test_middleware_sec_fetch_site() {
    let middleware: Csrf<()> = Default::default();
    // A header value that `HeaderValue::from_bytes` accepts but `to_str` rejects;
    // the assertion below guards that precondition for the two cases that use it.
    const NON_DECODABLE: &[u8] = &[0xFF, 0xFE];
    assert!(
        http::HeaderValue::from_bytes(NON_DECODABLE)
            .expect("NON_DECODABLE must be a valid HeaderValue")
            .to_str()
            .is_err(),
        "NON_DECODABLE must fail HeaderValue::to_str()"
    );
    struct Test {
        name: &'static str,
        method: http::Method,
        sec_fetch_site: Option<&'static [u8]>,
        origin: Option<&'static [u8]>,
        result: Result<(), ProtectionError>,
    }
    let tests = [
        Test {
            name: "same-origin allowed",
            // POST, not GET: a GET is accepted as a safe method regardless of
            // Sec-Fetch-Site, so it would never exercise the `same-origin` branch.
            method: Method::POST,
            sec_fetch_site: Some(b"same-origin"),
            origin: None,
            result: Ok(()),
        },
        Test {
            name: "none allowed",
            method: Method::POST,
            sec_fetch_site: Some(b"none"),
            origin: None,
            result: Ok(()),
        },
        Test {
            name: "cross-site blocked",
            method: Method::POST,
            sec_fetch_site: Some(b"cross-site"),
            origin: None,
            result: Err(ProtectionError::new(
                ProtectionErrorKind::CrossOriginRequest,
            )),
        },
        Test {
            name: "same-site blocked",
            method: Method::POST,
            sec_fetch_site: Some(b"same-site"),
            origin: None,
            result: Err(ProtectionError::new(
                ProtectionErrorKind::CrossOriginRequest,
            )),
        },
        Test {
            name: "no header with no origin",
            method: Method::POST,
            sec_fetch_site: None,
            origin: None,
            result: Ok(()),
        },
        Test {
            name: "no header with matching origin",
            method: Method::POST,
            sec_fetch_site: None,
            origin: Some(b"https://example.com"),
            result: Ok(()),
        },
        Test {
            name: "no header with mismatched origin",
            method: Method::POST,
            sec_fetch_site: None,
            origin: Some(b"https://attacker.example"),
            result: Err(ProtectionError::new(
                ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
            )),
        },
        Test {
            name: "no header with null origin",
            method: Method::POST,
            sec_fetch_site: None,
            origin: Some(b"null"),
            result: Err(ProtectionError::new(
                ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
            )),
        },
        Test {
            name: "GET allowed",
            method: Method::GET,
            sec_fetch_site: Some(b"cross-site"),
            origin: None,
            result: Ok(()),
        },
        Test {
            name: "HEAD allowed",
            method: Method::HEAD,
            sec_fetch_site: Some(b"cross-site"),
            origin: None,
            result: Ok(()),
        },
        Test {
            name: "OPTIONS allowed",
            method: Method::OPTIONS,
            sec_fetch_site: Some(b"cross-site"),
            origin: None,
            result: Ok(()),
        },
        Test {
            name: "PUT blocked",
            method: Method::PUT,
            sec_fetch_site: Some(b"cross-site"),
            origin: None,
            result: Err(ProtectionError::new(
                ProtectionErrorKind::CrossOriginRequest,
            )),
        },
        Test {
            name: "non-decodable origin without sec-fetch-site rejected",
            method: Method::POST,
            sec_fetch_site: None,
            origin: Some(NON_DECODABLE),
            result: Err(ProtectionError::new(
                ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
            )),
        },
        Test {
            name: "non-decodable sec-fetch-site without origin rejected",
            method: Method::POST,
            sec_fetch_site: Some(NON_DECODABLE),
            origin: None,
            result: Err(ProtectionError::new(
                ProtectionErrorKind::CrossOriginRequest,
            )),
        },
        Test {
            name: "empty sec-fetch-site without origin allowed",
            method: Method::POST,
            sec_fetch_site: Some(b""),
            origin: None,
            result: Ok(()),
        },
        Test {
            name: "empty origin without sec-fetch-site allowed",
            method: Method::POST,
            sec_fetch_site: None,
            origin: Some(b""),
            result: Ok(()),
        },
    ];
    for test in tests {
        let mut req = Request::builder()
            .method(test.method)
            .header("host", "example.com");
        if let Some(sec_fetch_site) = test.sec_fetch_site {
            req = req.header("sec-fetch-site", sec_fetch_site);
        }
        if let Some(origin) = test.origin {
            req = req.header("origin", origin);
        }
        let req = req.body(()).unwrap();
        assert_eq!(middleware.verify(&req), test.result, "{}", test.name);
    }
}
/// A request whose `Origin` is in the trusted list is accepted whatever its
/// `Sec-Fetch-Site`. An untrusted origin is rejected by whichever check applies.
#[test]
fn test_middleware_trusted_origin_bypass() {
    let middleware = CsrfLayer::new()
        .add_trusted_origin("https://trusted.example")
        .unwrap()
        .layer(());

    fn reject(kind: ProtectionErrorKind) -> Result<(), ProtectionError> {
        Err(ProtectionError::new(kind))
    }

    // (name, sec-fetch-site header, origin header, expected outcome)
    let cases: [(&str, Option<&str>, Option<&str>, Result<(), ProtectionError>); 4] = [
        (
            "trusted origin without sec-fetch-site",
            None,
            Some("https://trusted.example"),
            Ok(()),
        ),
        (
            "trusted origin with cross-site",
            Some("cross-site"),
            Some("https://trusted.example"),
            Ok(()),
        ),
        (
            "untrusted origin without sec-fetch-site",
            None,
            Some("https://attacker.example"),
            reject(ProtectionErrorKind::CrossOriginRequestFromOldBrowser),
        ),
        (
            "untrusted origin with cross-site",
            Some("cross-site"),
            Some("https://attacker.example"),
            reject(ProtectionErrorKind::CrossOriginRequest),
        ),
    ];

    for (name, sec_fetch_site, origin, expected) in cases {
        let mut builder = Request::builder()
            .method(Method::POST)
            .header("host", "example.com");
        if let Some(value) = sec_fetch_site {
            builder = builder.header("sec-fetch-site", value);
        }
        if let Some(value) = origin {
            builder = builder.header("origin", value);
        }
        let req = builder.body(()).unwrap();
        assert_eq!(middleware.verify(&req), expected, "{}", name);
    }
}
/// Trusted origins are matched exactly against the `Origin` header, with no case
/// folding and no default-port normalization.
///
/// Each request is otherwise cross-site (`Sec-Fetch-Site: cross-site`, unrelated
/// `Host`), so only the trusted-origin list can let it through.
#[test]
fn test_middleware_trusted_origin_strict_byte_match() {
    // (name, configured trusted origin, request Origin header, accepted)
    let cases = [
        ("exact match trusted", "https://example.com", "https://example.com", true),
        (
            "exact match with non-default port",
            "https://example.com:8443",
            "https://example.com:8443",
            true,
        ),
        (
            "host case mismatch not trusted",
            "https://Example.COM",
            "https://example.com",
            false,
        ),
        (
            "explicit default port not trusted against bare origin",
            "https://example.com:443",
            "https://example.com",
            false,
        ),
        (
            "bare trusted not matched by explicit-default-port origin",
            "https://example.com",
            "https://example.com:443",
            false,
        ),
    ];

    for (name, trusted, origin, accepted) in cases {
        let middleware = CsrfLayer::new()
            .add_trusted_origin(trusted)
            .unwrap_or_else(|e| panic!("{name}: add_trusted_origin failed: {e}"))
            .layer(());
        let req = Request::builder()
            .method(Method::POST)
            .header("host", "other.example")
            .header("origin", origin)
            .header("sec-fetch-site", "cross-site")
            .body(())
            .unwrap();
        let expected = if accepted {
            Ok(())
        } else {
            Err(ProtectionError::new(ProtectionErrorKind::CrossOriginRequest))
        };
        assert_eq!(middleware.verify(&req), expected, "{}", name);
    }
}
}