#![doc = include_str!("../README.md")]
#![forbid(unsafe_code, future_incompatible)]
#![deny(
missing_debug_implementations,
nonstandard_style,
missing_docs,
unreachable_pub,
missing_copy_implementations,
unused_qualifications,
clippy::unwrap_in_result,
clippy::unwrap_used
)]
use std::collections::HashSet;
use std::time::Duration;
use csrf::{
AesGcmCsrfProtection, CsrfCookie, CsrfProtection, CsrfToken, UnencryptedCsrfCookie,
UnencryptedCsrfToken,
};
use data_encoding::{BASE64, BASE64URL};
use tide::{
http::{cookies::SameSite, mime},
http::{headers::HeaderName, Cookie, Method},
Body, Middleware, Next, Request, Response, StatusCode,
};
/// Per-request CSRF data that `CsrfMiddleware::handle` stores in the request
/// extensions and that `CsrfRequestExt` reads back.
struct CsrfRequestExtData {
/// Token generated for this request, as produced by `CsrfToken::b64_url_string`.
csrf_token: String,
/// Copy of the middleware's configured header name.
csrf_header_name: HeaderName,
/// Copy of the middleware's configured query parameter name.
csrf_query_param: String,
/// Copy of the middleware's configured form field name.
csrf_field_name: String,
}
/// Accessors for the CSRF data that `CsrfMiddleware` attaches to a request.
///
/// The implementation for `tide::Request` in this file panics when no such
/// data has been attached, i.e. when the middleware has not run.
pub trait CsrfRequestExt {
/// Returns the CSRF token generated for the current request.
fn csrf_token(&self) -> &str;
/// Returns the name of the header in which the middleware looks for a token.
fn csrf_header_name(&self) -> &str;
/// Returns the name of the query parameter in which the middleware looks
/// for a token.
fn csrf_query_param(&self) -> &str;
/// Returns the name of the form field in which the middleware looks for a
/// token.
fn csrf_field_name(&self) -> &str;
}
/// Exposes the data stored by `CsrfMiddleware` on a `tide::Request`.
///
/// # Panics
///
/// Every accessor panics if the request carries no `CsrfRequestExtData`,
/// which is the case when `CsrfMiddleware` did not process the request.
impl<State> CsrfRequestExt for Request<State>
where
    State: Send + Sync + 'static,
{
    fn csrf_token(&self) -> &str {
        &csrf_ext_data(self).csrf_token
    }

    fn csrf_header_name(&self) -> &str {
        csrf_ext_data(self).csrf_header_name.as_str()
    }

    fn csrf_query_param(&self) -> &str {
        csrf_ext_data(self).csrf_query_param.as_str()
    }

    fn csrf_field_name(&self) -> &str {
        csrf_ext_data(self).csrf_field_name.as_str()
    }
}

/// Looks up the CSRF extension data on `req`, panicking with a hint about the
/// missing middleware when it is absent.
fn csrf_ext_data<State>(req: &Request<State>) -> &CsrfRequestExtData
where
    State: Send + Sync + 'static,
{
    req.ext::<CsrfRequestExtData>()
        .expect("You must install CsrfMiddleware to access the CSRF token.")
}
/// Tide middleware that issues a CSRF cookie/token pair on every response and
/// rejects requests using a protected HTTP method unless they carry a token
/// that verifies against the request's CSRF cookie.
pub struct CsrfMiddleware {
/// `Path` attribute of the CSRF cookie.
cookie_path: String,
/// Name of the CSRF cookie, used both when reading and when writing it.
cookie_name: String,
/// Optional `Domain` attribute of the CSRF cookie.
cookie_domain: Option<String>,
/// Lifetime used for the cookie expiry and passed to token generation.
ttl: Duration,
/// Request header that is searched for a token.
header_name: HeaderName,
/// Query parameter that is searched for a token.
query_param: String,
/// Form field (in URL-encoded bodies) that is searched for a token.
form_field: String,
/// HTTP methods for which a valid token/cookie pair is required.
protected_methods: HashSet<Method>,
/// Protection implementation that generates, parses and verifies tokens
/// and cookies.
protect: AesGcmCsrfProtection,
}
/// Debug output lists the configuration fields and leaves out `protect`.
impl std::fmt::Debug for CsrfMiddleware {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure exhaustively so that a newly added field has to be
        // considered here explicitly.
        let Self {
            cookie_path,
            cookie_name,
            cookie_domain,
            ttl,
            header_name,
            query_param,
            form_field,
            protected_methods,
            protect: _,
        } = self;

        let mut out = f.debug_struct("CsrfMiddleware");
        out.field("cookie_path", cookie_path);
        out.field("cookie_name", cookie_name);
        out.field("cookie_domain", cookie_domain);
        out.field("ttl", ttl);
        out.field("header_name", header_name);
        out.field("query_param", query_param);
        out.field("form_field", form_field);
        out.field("protected_methods", protected_methods);
        out.finish()
    }
}
impl CsrfMiddleware {
pub fn new(secret: &[u8]) -> Self {
let mut key = [0u8; 32];
derive_key(secret, &mut key);
Self {
cookie_path: "/".into(),
cookie_name: "tide.csrf".into(),
cookie_domain: None,
ttl: Duration::from_secs(24 * 60 * 60),
header_name: "X-CSRF-Token".into(),
query_param: "csrf-token".into(),
form_field: "csrf-token".into(),
protected_methods: vec![Method::Post, Method::Put, Method::Patch, Method::Delete]
.iter()
.cloned()
.collect(),
protect: AesGcmCsrfProtection::from_key(key),
}
}
pub fn with_ttl(mut self, ttl: Duration) -> Self {
self.ttl = ttl;
self
}
pub fn with_header_name(mut self, header_name: impl AsRef<str>) -> Self {
self.header_name = header_name.as_ref().into();
self
}
pub fn with_query_param(mut self, query_param: impl AsRef<str>) -> Self {
self.query_param = query_param.as_ref().into();
self
}
pub fn with_form_field(mut self, form_field: impl AsRef<str>) -> Self {
self.form_field = form_field.as_ref().into();
self
}
pub fn with_protected_methods(mut self, methods: &[Method]) -> Self {
self.protected_methods = methods.iter().cloned().collect();
self
}
/// Builds the response cookie carrying `cookie_value`.
///
/// The cookie is HTTP-only, `SameSite=Strict`, scoped to the configured
/// path (and domain, when one is set), expires `ttl` from now, and has its
/// `Secure` flag set to `secure`.
fn build_cookie(&self, secure: bool, cookie_value: String) -> Cookie<'static> {
    let builder = Cookie::build(self.cookie_name.clone(), cookie_value)
        .http_only(true)
        .same_site(SameSite::Strict)
        .path(self.cookie_path.clone())
        .secure(secure)
        .expires((std::time::SystemTime::now() + self.ttl).into());

    let mut cookie = builder.finish();
    match &self.cookie_domain {
        Some(domain) => cookie.set_domain(domain.clone()),
        None => {}
    }
    cookie
}
/// Generates a fresh token/cookie pair.
///
/// When `existing_cookie` holds a value of exactly 64 bytes, those bytes
/// are handed to `generate_token_pair` as the previous token value; any
/// other length is treated as if there were no existing cookie.
///
/// # Panics
///
/// Panics if the underlying protection fails to generate a pair.
fn generate_token(
    &self,
    existing_cookie: Option<&UnencryptedCsrfCookie>,
) -> (CsrfToken, CsrfCookie) {
    let existing_cookie_bytes = existing_cookie.and_then(|cookie| {
        let value = cookie.value();
        // `copy_from_slice` panics unless both slices have the same
        // length, so a value that is longer than the buffer must be
        // rejected just like one that is shorter.
        if value.len() != 64 {
            return None;
        }
        let mut buf = [0u8; 64];
        buf.copy_from_slice(value);
        Some(buf)
    });

    // NOTE(review): `as i64` wraps for TTLs above `i64::MAX` seconds; left
    // unchanged here.
    self.protect
        .generate_token_pair(existing_cookie_bytes.as_ref(), self.ttl.as_secs() as i64)
        .expect("couldn't generate token/cookie pair")
}
/// Reads the CSRF cookie from `req`.
///
/// Returns `None` when the cookie is absent, is not valid base64, or is
/// rejected by the protection's `parse_cookie`.
fn find_csrf_cookie<State>(&self, req: &Request<State>) -> Option<UnencryptedCsrfCookie>
where
    State: Clone + Send + Sync + 'static,
{
    let raw_cookie = req.cookie(&self.cookie_name)?;
    let decoded = BASE64.decode(raw_cookie.value().as_bytes()).ok()?;
    self.protect.parse_cookie(&decoded).ok()
}
async fn find_csrf_token<State>(
&self,
req: &mut Request<State>,
) -> Result<Option<UnencryptedCsrfToken>, tide::Error>
where
State: Clone + Send + Sync + 'static,
{
let csrf_token = if let Some(csrf_token) = self.find_csrf_token_in_header(req) {
csrf_token
} else if let Some(csrf_token) = self.find_csrf_token_in_query(req) {
csrf_token
} else if let Some(csrf_token) = self.find_csrf_token_in_form(req).await? {
csrf_token
} else {
return Ok(None);
};
Ok(Some(self.protect.parse_token(&csrf_token).map_err(
|err| tide::Error::new(StatusCode::Forbidden, err),
)?))
}
/// Returns the bytes of the first value of the configured header that
/// decodes as URL-safe base64, if any.
fn find_csrf_token_in_header<State>(&self, req: &Request<State>) -> Option<Vec<u8>>
where
    State: Clone + Send + Sync + 'static,
{
    let header_values = req.header(&self.header_name)?;
    header_values
        .iter()
        .find_map(|value| BASE64URL.decode(value.as_str().as_bytes()).ok())
}
/// Returns the bytes of the first query pair whose key equals the
/// configured query parameter and whose value decodes as URL-safe base64.
fn find_csrf_token_in_query<State>(&self, req: &Request<State>) -> Option<Vec<u8>>
where
    State: Clone + Send + Sync + 'static,
{
    for (key, value) in req.url().query_pairs() {
        if key != self.query_param {
            continue;
        }
        // A matching key with an undecodable value does not end the
        // search; later pairs are still considered.
        if let Ok(bytes) = BASE64URL.decode(value.as_bytes()) {
            return Some(bytes);
        }
    }
    None
}
/// Looks for a token in a URL-encoded form body.
///
/// Requests whose content type is not `mime::FORM` are left untouched and
/// yield `Ok(None)`. Otherwise the body is read in full, searched for the
/// first pair whose key equals the configured form field and whose value
/// decodes as URL-safe base64, and then put back on the request so later
/// handlers can still read it. A body that fails to parse as URL-encoded
/// pairs is treated as containing no fields.
///
/// # Errors
///
/// Propagates any error from reading the request body.
async fn find_csrf_token_in_form<State>(
    &self,
    req: &mut Request<State>,
) -> Result<Option<Vec<u8>>, tide::Error>
where
    State: Clone + Send + Sync + 'static,
{
    let is_form = req.content_type() == Some(mime::FORM);
    if !is_form {
        return Ok(None);
    }

    let raw_body = req.take_body().into_bytes().await?;
    let fields: Vec<(String, String)> =
        serde_urlencoded::from_bytes(&raw_body).unwrap_or_default();
    let found = fields
        .into_iter()
        .filter(|(name, _)| *name == self.form_field)
        .find_map(|(_, value)| BASE64URL.decode(value.as_bytes()).ok());

    // Restore the consumed body for downstream handlers.
    req.set_body(Body::from_bytes(raw_body));
    Ok(found)
}
}
#[tide::utils::async_trait]
impl<State> Middleware<State> for CsrfMiddleware
where
    State: Clone + Send + Sync + 'static,
{
    /// Verifies the CSRF token on protected methods, then runs the rest of
    /// the chain with a fresh token attached to the request and a fresh
    /// cookie attached to the response.
    ///
    /// A request using a protected method is answered with `403 Forbidden`
    /// when the CSRF cookie is missing, the token is missing, or the pair
    /// does not verify. The cookie's `Secure` flag is set when the request
    /// URL scheme is `https`.
    async fn handle(&self, mut req: Request<State>, next: Next<'_, State>) -> tide::Result {
        let existing_cookie = self.find_csrf_cookie(&req);

        if self.protected_methods.contains(&req.method()) {
            let cookie = match existing_cookie.as_ref() {
                Some(cookie) => cookie,
                None => {
                    tide::log::debug!("Rejecting request due to missing CSRF cookie.",);
                    return Ok(Response::new(StatusCode::Forbidden));
                }
            };

            let token = match self.find_csrf_token(&mut req).await? {
                Some(token) => token,
                None => {
                    tide::log::debug!("Rejecting request due to missing CSRF token.",);
                    return Ok(Response::new(StatusCode::Forbidden));
                }
            };

            if !self.protect.verify_token_pair(&token, cookie) {
                tide::log::debug!("Rejecting request due to invalid or expired CSRF token.");
                return Ok(Response::new(StatusCode::Forbidden));
            }
            tide::log::debug!("Verified CSRF token.");
        }

        // A new pair is issued on every request that gets this far.
        let (token, cookie) = self.generate_token(existing_cookie.as_ref());
        let secure_cookie = req.url().scheme() == "https";

        let ext_data = CsrfRequestExtData {
            csrf_token: token.b64_url_string(),
            csrf_header_name: self.header_name.clone(),
            csrf_query_param: self.query_param.clone(),
            csrf_field_name: self.form_field.clone(),
        };
        req.set_ext(ext_data);

        let mut res = next.run(req).await;
        res.insert_cookie(self.build_cookie(secure_cookie, cookie.b64_string()));
        Ok(res)
    }
}
/// Fills `key` with 32 bytes derived from `secret` using HKDF over SHA-256,
/// with no salt and an empty info string.
fn derive_key(secret: &[u8], key: &mut [u8; 32]) {
    let empty_info: [u8; 0] = [];
    hkdf::Hkdf::<sha2::Sha256>::new(None, secret)
        .expand(&empty_info, key)
        .expect("Sha256 should be able to produce a 32 byte key.");
}
#[cfg(test)]
mod tests {
use super::*;
use tide::{
http::headers::{COOKIE, SET_COOKIE},
Request,
};
use tide_testing::{surf::Response, TideTestingExt};
/// Secret passed to `CsrfMiddleware::new` by every test in this module.
const SECRET: [u8; 32] = *b"secrets must be >= 32 bytes long";
/// A handler behind the middleware can read the token and all three
/// configured lookup names through `CsrfRequestExt`.
#[async_std::test]
async fn middleware_exposes_csrf_request_extensions() -> tide::Result<()> {
    let mut app = tide::new();
    app.with(CsrfMiddleware::new(&SECRET));

    app.at("/").get(|req: Request<()>| async move {
        assert_ne!(req.csrf_token(), "");
        assert_eq!(req.csrf_header_name(), "x-csrf-token");
        // The query parameter and form field accessors are part of the same
        // extension trait and default to "csrf-token" in `CsrfMiddleware::new`.
        assert_eq!(req.csrf_query_param(), "csrf-token");
        assert_eq!(req.csrf_field_name(), "csrf-token");
        Ok("")
    });

    let res = app.get("/").await?;
    assert_eq!(res.status(), StatusCode::Ok);

    Ok(())
}
/// A plain GET yields a non-empty token in the handler and a `tide.csrf`
/// cookie on the response.
#[async_std::test]
async fn middleware_adds_csrf_cookie_sets_request_token() -> tide::Result<()> {
    let mut app = tide::new();
    app.with(CsrfMiddleware::new(&SECRET));
    app.at("/")
        .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) });

    let mut page = app.get("/").await?;
    assert_eq!(page.status(), StatusCode::Ok);

    let token = page.body_string().await?;
    assert_ne!(token, "");

    let issued_cookie = get_csrf_cookie(&page).expect("Expected CSRF cookie in response.");
    assert_eq!(issued_cookie.name(), "tide.csrf");

    Ok(())
}
/// A POST is rejected without credentials and accepted when the cookie and
/// the token (in the default header) are both supplied.
#[async_std::test]
async fn middleware_validates_token_in_header() -> tide::Result<()> {
    let mut app = tide::new();
    app.with(CsrfMiddleware::new(&SECRET));
    app.at("/")
        .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
        .post(|_| async { Ok("POST") });

    // Obtain a token/cookie pair.
    let mut page = app.get("/").await?;
    assert_eq!(page.status(), StatusCode::Ok);
    let token = page.body_string().await?;
    let issued_cookie = get_csrf_cookie(&page).expect("Expected CSRF cookie in response.");
    assert_eq!(issued_cookie.name(), "tide.csrf");

    // No cookie, no token.
    let bare_post = app.post("/").await?;
    assert_eq!(bare_post.status(), StatusCode::Forbidden);

    // Cookie plus header token.
    let mut accepted = app
        .post("/")
        .header(COOKIE, issued_cookie.to_string())
        .header("X-CSRF-Token", token)
        .await?;
    assert_eq!(accepted.status(), StatusCode::Ok);
    assert_eq!(accepted.body_string().await?, "POST");

    Ok(())
}
/// A custom header name is reported to handlers (lower-cased) and accepted
/// as the token carrier.
#[async_std::test]
async fn middleware_validates_token_in_alternate_header() -> tide::Result<()> {
    let mut app = tide::new();
    app.with(CsrfMiddleware::new(&SECRET).with_header_name("X-MyCSRF-Header"));
    app.at("/")
        .get(|req: Request<()>| async move {
            assert_eq!(req.csrf_header_name(), "x-mycsrf-header");
            Ok(req.csrf_token().to_string())
        })
        .post(|_| async { Ok("POST") });

    let mut page = app.get("/").await?;
    assert_eq!(page.status(), StatusCode::Ok);
    let token = page.body_string().await?;
    let issued_cookie = get_csrf_cookie(&page).expect("Expected CSRF cookie in response.");

    let mut accepted = app
        .post("/")
        .header(COOKIE, issued_cookie.to_string())
        .header("X-MyCSRF-Header", token)
        .await?;
    assert_eq!(accepted.status(), StatusCode::Ok);
    assert_eq!(accepted.body_string().await?, "POST");

    Ok(())
}
/// A custom query parameter name is accepted as the token carrier.
#[async_std::test]
async fn middleware_validates_token_in_alternate_query() -> tide::Result<()> {
    let mut app = tide::new();
    app.with(CsrfMiddleware::new(&SECRET).with_query_param("my-csrf-token"));
    app.at("/")
        .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
        .post(|_| async { Ok("POST") });

    let mut page = app.get("/").await?;
    assert_eq!(page.status(), StatusCode::Ok);
    let token = page.body_string().await?;
    let issued_cookie = get_csrf_cookie(&page).expect("Expected CSRF cookie in response.");
    assert_eq!(issued_cookie.name(), "tide.csrf");

    let bare_post = app.post("/").await?;
    assert_eq!(bare_post.status(), StatusCode::Forbidden);

    // The token sits between unrelated query pairs.
    let mut accepted = app
        .post(format!("/?a=1&my-csrf-token={}&b=2", token))
        .header(COOKIE, issued_cookie.to_string())
        .await?;
    assert_eq!(accepted.status(), StatusCode::Ok);
    assert_eq!(accepted.body_string().await?, "POST");

    Ok(())
}
/// The default query parameter is accepted as the token carrier.
#[async_std::test]
async fn middleware_validates_token_in_query() -> tide::Result<()> {
    let mut app = tide::new();
    app.with(CsrfMiddleware::new(&SECRET));
    app.at("/")
        .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
        .post(|_| async { Ok("POST") });

    let mut page = app.get("/").await?;
    assert_eq!(page.status(), StatusCode::Ok);
    let token = page.body_string().await?;
    let issued_cookie = get_csrf_cookie(&page).expect("Expected CSRF cookie in response.");
    assert_eq!(issued_cookie.name(), "tide.csrf");

    let bare_post = app.post("/").await?;
    assert_eq!(bare_post.status(), StatusCode::Forbidden);

    // The token sits between unrelated query pairs.
    let mut accepted = app
        .post(format!("/?a=1&csrf-token={}&b=2", token))
        .header(COOKIE, issued_cookie.to_string())
        .await?;
    assert_eq!(accepted.status(), StatusCode::Ok);
    assert_eq!(accepted.body_string().await?, "POST");

    Ok(())
}
/// A token in a URL-encoded form body is accepted, and the handler can still
/// deserialize the body afterwards.
#[async_std::test]
async fn middleware_validates_token_in_form() -> tide::Result<()> {
    let mut app = tide::new();
    app.with(CsrfMiddleware::new(&SECRET));
    app.at("/")
        .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
        .post(|mut req: Request<()>| async move {
            // Fields other than the token must survive the middleware.
            #[derive(serde::Deserialize)]
            struct Form {
                a: String,
                b: i32,
            }
            let form: Form = req.body_form().await?;
            assert_eq!(form.a, "1");
            assert_eq!(form.b, 2);
            Ok("POST")
        });

    let mut page = app.get("/").await?;
    assert_eq!(page.status(), StatusCode::Ok);
    let token = page.body_string().await?;
    let issued_cookie = get_csrf_cookie(&page).expect("Expected CSRF cookie in response.");
    assert_eq!(issued_cookie.name(), "tide.csrf");

    let bare_post = app.post("/").await?;
    assert_eq!(bare_post.status(), StatusCode::Forbidden);

    let mut accepted = app
        .post("/")
        .header(COOKIE, issued_cookie.to_string())
        .content_type("application/x-www-form-urlencoded")
        .body(format!("a=1&csrf-token={}&b=2", token))
        .await?;
    assert_eq!(accepted.status(), StatusCode::Ok);
    assert_eq!(accepted.body_string().await?, "POST");

    Ok(())
}
/// A token placed in a body that is not declared as a URL-encoded form is
/// not picked up, so the request is rejected.
#[async_std::test]
async fn middleware_ignores_non_form_bodies() -> tide::Result<()> {
    let mut app = tide::new();
    app.with(CsrfMiddleware::new(&SECRET));
    app.at("/")
        .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
        .post(|_| async { Ok("POST") });

    let mut page = app.get("/").await?;
    assert_eq!(page.status(), StatusCode::Ok);
    let token = page.body_string().await?;
    let issued_cookie = get_csrf_cookie(&page).expect("Expected CSRF cookie in response.");
    assert_eq!(issued_cookie.name(), "tide.csrf");

    let bare_post = app.post("/").await?;
    assert_eq!(bare_post.status(), StatusCode::Forbidden);

    // Same payload as the form test, but with a non-form content type.
    let rejected = app
        .post("/")
        .header(COOKIE, issued_cookie.to_string())
        .content_type("text/html")
        .body(format!("a=1&csrf-token={}&b=2", token))
        .await?;
    assert_eq!(rejected.status(), StatusCode::Forbidden);

    Ok(())
}
/// Tokens and cookies issued on different requests of the same chain can be
/// mixed: an older token verifies against a newer cookie and vice versa.
#[async_std::test]
async fn middleware_allows_different_generation_cookies_and_tokens() -> tide::Result<()> {
    let mut app = tide::new();
    app.with(CsrfMiddleware::new(&SECRET));
    app.at("/")
        .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
        .post(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) });

    // First generation.
    let mut first_page = app.get("/").await?;
    assert_eq!(first_page.status(), StatusCode::Ok);
    let first_token = first_page.body_string().await?;
    let first_cookie =
        get_csrf_cookie(&first_page).expect("Expected CSRF cookie in response.");
    assert_eq!(first_cookie.name(), "tide.csrf");

    let bare_post = app.post("/").await?;
    assert_eq!(bare_post.status(), StatusCode::Forbidden);

    // Second generation, obtained by presenting the first.
    let mut second_page = app
        .post("/")
        .header(COOKIE, first_cookie.to_string())
        .header("X-CSRF-Token", &first_token)
        .await?;
    assert_eq!(second_page.status(), StatusCode::Ok);
    let second_token = second_page.body_string().await?;
    assert_ne!(second_token, first_token);
    let second_cookie =
        get_csrf_cookie(&second_page).expect("Expected CSRF cookie in response.");
    assert_eq!(second_cookie.name(), "tide.csrf");
    assert_ne!(second_cookie.to_string(), first_cookie.to_string());

    // New cookie, old token.
    let new_cookie_old_token = app
        .post("/")
        .header(COOKIE, second_cookie.to_string())
        .header("X-CSRF-Token", first_token)
        .await?;
    assert_eq!(new_cookie_old_token.status(), StatusCode::Ok);

    // Old cookie, new token.
    let old_cookie_new_token = app
        .post("/")
        .header(COOKIE, first_cookie.to_string())
        .header("X-CSRF-Token", second_token)
        .await?;
    assert_eq!(old_cookie_new_token.status(), StatusCode::Ok);

    Ok(())
}
/// A header value that is valid base64 but is not a token produced by the
/// middleware is rejected.
#[async_std::test]
async fn middleware_rejects_short_token() -> tide::Result<()> {
    let mut app = tide::new();
    app.with(CsrfMiddleware::new(&SECRET));
    app.at("/")
        .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
        .post(|_| async { Ok("POST") });

    let page = app.get("/").await?;
    assert_eq!(page.status(), StatusCode::Ok);
    let issued_cookie = get_csrf_cookie(&page).expect("Expected CSRF cookie in response.");
    assert_eq!(issued_cookie.name(), "tide.csrf");

    let bare_post = app.post("/").await?;
    assert_eq!(bare_post.status(), StatusCode::Forbidden);

    // Padded base64 of a few bytes only.
    let rejected = app
        .post("/")
        .header(COOKIE, issued_cookie.to_string())
        .header("X-CSRF-Token", "aGVsbG8=")
        .await?;
    assert_eq!(rejected.status(), StatusCode::Forbidden);

    Ok(())
}
/// A header value that lacks its base64 padding is rejected.
#[async_std::test]
async fn middleware_rejects_invalid_base64_token() -> tide::Result<()> {
    let mut app = tide::new();
    app.with(CsrfMiddleware::new(&SECRET));
    app.at("/")
        .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
        .post(|_| async { Ok("POST") });

    let page = app.get("/").await?;
    assert_eq!(page.status(), StatusCode::Ok);
    let issued_cookie = get_csrf_cookie(&page).expect("Expected CSRF cookie in response.");
    assert_eq!(issued_cookie.name(), "tide.csrf");

    let bare_post = app.post("/").await?;
    assert_eq!(bare_post.status(), StatusCode::Forbidden);

    // Same payload as the short-token test with the trailing `=` removed.
    let rejected = app
        .post("/")
        .header(COOKIE, issued_cookie.to_string())
        .header("X-CSRF-Token", "aGVsbG8")
        .await?;
    assert_eq!(rejected.status(), StatusCode::Forbidden);

    Ok(())
}
/// A token from one cookie-less GET does not verify against the cookie
/// issued by a separate cookie-less GET.
#[async_std::test]
async fn middleware_rejects_mismatched_token() -> tide::Result<()> {
    let mut app = tide::new();
    app.with(CsrfMiddleware::new(&SECRET));
    app.at("/")
        .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
        .post(|_| async { Ok("POST") });

    // Take the token from the first response only.
    let mut token_page = app.get("/").await?;
    assert_eq!(token_page.status(), StatusCode::Ok);
    let unrelated_token = token_page.body_string().await?;

    // Take the cookie from an independent second response.
    let cookie_page = app.get("/").await?;
    assert_eq!(cookie_page.status(), StatusCode::Ok);
    let issued_cookie =
        get_csrf_cookie(&cookie_page).expect("Expected CSRF cookie in response.");
    assert_eq!(issued_cookie.name(), "tide.csrf");

    let bare_post = app.post("/").await?;
    assert_eq!(bare_post.status(), StatusCode::Forbidden);

    let rejected = app
        .post("/")
        .header(COOKIE, issued_cookie.to_string())
        .header("X-CSRF-Token", unrelated_token)
        .await?;
    assert_eq!(rejected.status(), StatusCode::Forbidden);

    Ok(())
}
/// Parses the first `Set-Cookie` header value of `res` into a cookie.
///
/// Returns `None` when the header is absent, has no values, or its first
/// value does not parse as a cookie.
fn get_csrf_cookie(res: &Response) -> Option<Cookie> {
    let first_value = res.header(SET_COOKIE)?.get(0)?;
    Cookie::parse(first_value.to_string()).ok()
}
}