use crate::{Fang, FangProc, IntoResponse, Request, Response};
use std::sync::Arc;
/// CSRF protection fang driven by the `Sec-Fetch-Site` and `Origin` request headers.
///
/// Requests with safe methods always pass. Other requests are rejected when they look
/// cross-origin, unless their `Origin` header is one of the configured trusted origins.
#[derive(Debug, Clone)]
pub struct Csrf {
    /// `Origin` header values that are accepted (by exact string equality) even when the
    /// request otherwise looks cross-origin. Kept behind an `Arc` so cloning the fang only
    /// bumps a reference count.
    trusted_origins: Arc<Vec<String>>,
}
impl Default for Csrf {
fn default() -> Self {
Self::new()
}
}
impl Csrf {
    /// Creates a `Csrf` fang that trusts no additional origins.
    pub fn new() -> Self {
        Self {
            trusted_origins: Arc::new(Vec::new()),
        }
    }

    /// Creates a `Csrf` fang that additionally accepts cross-origin requests whose
    /// `Origin` header is exactly equal to one of `trusted_origins`.
    ///
    /// # Panics
    ///
    /// Panics if any entry is rejected by the parent module's `validate_origin`.
    /// The tests in this file expect values such as `https://example.com` or
    /// `https://example.com:8080` to be accepted. They expect values with a path, query,
    /// fragment, missing scheme, or the literal `null` to be rejected.
    pub fn with_trusted_origins(
        trusted_origins: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        let origins: Vec<String> = trusted_origins
            .into_iter()
            .map(|origin| origin.into())
            .collect();

        // Convert everything first, then validate in order, stopping at the first bad entry.
        for origin in origins.iter() {
            if let Err(err) = super::validate_origin(origin) {
                panic!("[Csrf::with_trusted_origins] {err}");
            }
        }

        Self {
            trusted_origins: Arc::new(origins),
        }
    }
}
/// Reason a request was rejected by `Csrf::verify`.
///
/// Turned into an HTTP response through its `IntoResponse` implementation.
/// Derives `Debug`/`PartialEq` so callers can log it and compare against it
/// (e.g. `assert_eq!(csrf.verify(&req), Err(CsrfError::NoHostHeader))`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CsrfError {
    /// `Sec-Fetch-Site` is present but is neither `same-origin` nor `none`,
    /// and the request's `Origin` is not a trusted origin.
    InvalidSecFetchSite,
    /// `Sec-Fetch-Site` is absent, and `Origin` neither matches the `Host`
    /// header nor is a trusted origin.
    OriginNotMatchHost,
    /// `Sec-Fetch-Site` is absent and `Origin` is present, but the request
    /// carries no `Host` header to compare it against.
    NoHostHeader,
}
impl IntoResponse for CsrfError {
    /// A missing `Host` header yields a bare `BadRequest` response. Both
    /// cross-origin rejections yield `Forbidden` with an explanatory text body.
    fn into_response(self) -> Response {
        let message = match self {
            Self::NoHostHeader => return Response::BadRequest(),
            Self::InvalidSecFetchSite => {
                "cross-origin request detected from Sec-Fetch-Site header"
            }
            Self::OriginNotMatchHost => {
                "cross-origin request detected, and/or browser is out of date: Sec-Fetch-Site is missing, and Origin does not match Host"
            }
        };
        Response::Forbidden().with_text(message)
    }
}
impl Csrf {
    /// Checks `req` for signs of a cross-site request forgery.
    ///
    /// Decision order:
    /// 1. Requests whose method is safe (`req.method.is_safe()`) are always accepted.
    /// 2. If `Sec-Fetch-Site` is present, `same-origin` and `none` are accepted. Any other
    ///    value is accepted only when `Origin` is trusted.
    /// 3. Without `Sec-Fetch-Site`, a request with no `Origin` is accepted. A request with
    ///    `Origin` but no `Host` is rejected. An `Origin` equal to `http://<Host>` or
    ///    `https://<Host>` is accepted. Anything else is accepted only when `Origin` is
    ///    trusted.
    ///
    /// "Trusted" means the `Origin` header exactly equals one of the configured trusted
    /// origins.
    ///
    /// # Errors
    ///
    /// - `CsrfError::InvalidSecFetchSite` for an untrusted non-same-origin `Sec-Fetch-Site`.
    /// - `CsrfError::NoHostHeader` when `Origin` must be compared but `Host` is missing.
    /// - `CsrfError::OriginNotMatchHost` for an untrusted `Origin` that differs from `Host`.
    pub fn verify(&self, req: &Request) -> Result<(), CsrfError> {
        if req.method.is_safe() {
            return Ok(());
        }

        // Last resort once a request already looks cross-origin: let it through only if
        // its `Origin` is explicitly trusted, otherwise fail with `rejection`.
        let allow_if_trusted = |rejection: CsrfError| -> Result<(), CsrfError> {
            let trusted = match req.headers.origin() {
                Some(origin) => self.trusted_origins.iter().any(|t| t == origin),
                None => false,
            };
            if trusted { Ok(()) } else { Err(rejection) }
        };

        if let Some(site) = req.headers.sec_fetch_site() {
            return if matches!(site, "same-origin" | "none") {
                Ok(())
            } else {
                allow_if_trusted(CsrfError::InvalidSecFetchSite)
            };
        }

        // No `Sec-Fetch-Site`: fall back to comparing `Origin` against `Host`.
        let Some(origin) = req.headers.origin() else {
            return Ok(());
        };
        let Some(host) = req.headers.host() else {
            return Err(CsrfError::NoHostHeader);
        };

        // `Host` carries no scheme, so either `http://` or `https://` in front of it counts
        // as a match.
        let scheme = origin.strip_suffix(host);
        if matches!(scheme, Some("http://" | "https://")) {
            Ok(())
        } else {
            allow_if_trusted(CsrfError::OriginNotMatchHost)
        }
    }
}
const _: () = {
pub struct CsrfProc<I: FangProc> {
csrf: Csrf,
inner: I,
}
impl<I: FangProc> Fang<I> for Csrf {
type Proc = CsrfProc<I>;
fn chain(&self, inner: I) -> Self::Proc {
CsrfProc {
csrf: self.clone(),
inner,
}
}
}
impl<I: FangProc> FangProc for CsrfProc<I> {
async fn bite<'b>(&'b self, req: &'b mut Request) -> Response {
match self.csrf.verify(req) {
Ok(()) => self.inner.bite(req).await,
Err(e) => e.into_response(),
}
}
}
};
#[cfg(test)]
#[cfg(feature = "__rt_native__")]
mod tests {
    use super::*;
    use crate::testing::*;
    use crate::{Ohkami, Route};

    #[test]
    fn test_csrf_with_trusted_origins_with_str_or_string() {
        // Both `&str` and `String` items must be accepted.
        let _: Csrf = Csrf::with_trusted_origins(["https://example.com"]);
        let _: Csrf = Csrf::with_trusted_origins([String::from("https://example.com")]);
    }

    // `x!(METHOD)` builds a request to `/` carrying `Host: example.com`.
    macro_rules! x {
        ($method:ident) => {
            TestRequest::$method("/").header("host", "example.com")
        };
    }

    #[test]
    fn test_sec_fetch_site() {
        let t = Ohkami::new((
            Csrf::new(),
            "/".GET(async || ()).PUT(async || ()).POST(async || ()),
        ))
        .test();

        crate::__rt__::testing::block_on(async {
            let cases = [
                // Sec-Fetch-Site present on an unsafe method.
                (x!(POST).header("sec-fetch-site", "same-origin"), Status::OK),
                (x!(POST).header("sec-fetch-site", "none"), Status::OK),
                (x!(POST).header("sec-fetch-site", "cross-site"), Status::Forbidden),
                (x!(POST).header("sec-fetch-site", "same-site"), Status::Forbidden),
                // Sec-Fetch-Site absent: fall back to comparing Origin with Host.
                (x!(POST), Status::OK),
                (x!(POST).header("origin", "https://example.com"), Status::OK),
                (x!(POST).header("origin", "https://attacker.example"), Status::Forbidden),
                (x!(POST).header("origin", "null"), Status::Forbidden),
                // Safe methods skip the check (OPTIONS passes it; the 404 comes from routing).
                (x!(GET).header("sec-fetch-site", "cross-site"), Status::OK),
                (x!(HEAD).header("sec-fetch-site", "cross-site"), Status::OK),
                (x!(OPTIONS).header("sec-fetch-site", "cross-site"), Status::NotFound),
                // Other unsafe methods are checked just like POST.
                (x!(PUT).header("sec-fetch-site", "cross-site"), Status::Forbidden),
            ];
            for (i, (req, expected)) in cases.into_iter().enumerate() {
                let res = t.oneshot(req).await;
                assert_eq!(res.status(), expected, "test_sec_fetch_site case #{i}");
            }
        });
    }

    #[test]
    fn test_trusted_origins() {
        let t = Ohkami::new((
            Csrf::with_trusted_origins(["https://trusted.example"]),
            "/".POST(async || ()),
        ))
        .test();

        crate::__rt__::testing::block_on(async {
            let cases = [
                (
                    x!(POST).header("origin", "https://trusted.example"),
                    Status::OK,
                ),
                (
                    x!(POST)
                        .header("origin", "https://trusted.example")
                        .header("sec-fetch-site", "cross-site"),
                    Status::OK,
                ),
                (
                    x!(POST).header("origin", "https://attacker.example"),
                    Status::Forbidden,
                ),
                (
                    x!(POST)
                        .header("origin", "https://attacker.example")
                        .header("sec-fetch-site", "cross-site"),
                    Status::Forbidden,
                ),
            ];
            for (i, (req, expected)) in cases.into_iter().enumerate() {
                let res = t.oneshot(req).await;
                assert_eq!(res.status(), expected, "test_trusted_origins case #{i}");
            }
        });
    }

    #[test]
    fn test_invalid_trusted_origins() {
        for (trusted_origin, should_judged_as_invalid) in [
            ("https://example.com", false),
            ("https://example.com:8080", false),
            ("http://example.com", false),
            ("example.com", true),
            ("https://", true),
            ("https://example.com/", true),
            ("https://example.com/path", true),
            ("https://example.com?query=1", true),
            ("https://example.com#fragment", true),
            ("https://ex ample.com", true),
            ("", true),
            ("null", true),
            ("https://example.com:port", true),
        ] {
            // `with_trusted_origins` reports an invalid origin by panicking.
            let is_judged_as_invalid = std::panic::catch_unwind(|| {
                let _ = Csrf::with_trusted_origins([trusted_origin]);
            })
            .is_err();
            assert_eq!(
                is_judged_as_invalid, should_judged_as_invalid,
                "unexpected result for trusted origin `{trusted_origin}`"
            );
        }
    }
}