use std::borrow::Cow;
use std::collections::HashSet;
use bytes::Bytes;
use error::{ResponseError, Result};
use http::{header, HeaderMap, HttpTryFrom, Uri};
use httprequest::HttpRequest;
use httpresponse::HttpResponse;
use middleware::{Middleware, Started};
use server::Request;
/// Reasons a request can be rejected by [`CsrfFilter`].
///
/// Every variant is turned into a `403 Forbidden` response whose body is the
/// variant's display message (see the `ResponseError` impl below).
#[derive(Debug, Fail)]
pub enum CsrfError {
/// The request is not considered safe, carries neither an `Origin` nor a
/// `Referer` header, and the filter was not configured to allow that.
#[fail(display = "Origin header required")]
MissingOrigin,
/// The `Origin` header is not valid visible ASCII, or the `Referer` header
/// could not be parsed into a URI with both a scheme and a host.
#[fail(display = "Could not parse Origin header")]
BadOrigin,
/// An origin was found, but it is not in the filter's allowed set.
#[fail(display = "Cross-site request denied")]
CsrDenied,
}
/// Every CSRF rejection maps to `403 Forbidden`, with the error's display
/// text as the plain response body.
impl ResponseError for CsrfError {
    fn error_response(&self) -> HttpResponse {
        let message = self.to_string();
        HttpResponse::Forbidden().body(message)
    }
}
/// Reduce an absolute URI to its origin string: `scheme://host` or
/// `scheme://host:port` when the URI carries an explicit port.
///
/// Returns `None` when the URI lacks a scheme or a host (e.g. a relative
/// path), since no origin can be derived from it.
fn uri_origin(uri: &Uri) -> Option<String> {
    let scheme = uri.scheme_part()?;
    let host = uri.host()?;
    let formatted = match uri.port_part() {
        Some(explicit) => format!("{}://{}:{}", scheme, host, explicit.as_u16()),
        None => format!("{}://{}", scheme, host),
    };
    Some(formatted)
}
fn origin(headers: &HeaderMap) -> Option<Result<Cow<str>, CsrfError>> {
headers
.get(header::ORIGIN)
.map(|origin| {
origin
.to_str()
.map_err(|_| CsrfError::BadOrigin)
.map(|o| o.into())
}).or_else(|| {
headers.get(header::REFERER).map(|referer| {
Uri::try_from(Bytes::from(referer.as_bytes()))
.ok()
.as_ref()
.and_then(uri_origin)
.ok_or(CsrfError::BadOrigin)
.map(|o| o.into())
})
})
}
/// Middleware that rejects cross-site requests by checking the `Origin`
/// (or, failing that, `Referer`) header against an allow-list.
///
/// The derived `Default` yields an empty allow-list with every `allow_*`
/// flag disabled, the same state `CsrfFilter::new()` builds.
#[derive(Default)]
pub struct CsrfFilter {
// Exact origin strings (e.g. "https://www.example.com") accepted for
// requests that are not considered safe.
origins: HashSet<String>,
// When true, requests carrying an `x-requested-with` header skip the
// origin check entirely.
allow_xhr: bool,
// When true, requests with neither `Origin` nor `Referer` are accepted.
allow_missing_origin: bool,
// When true, a safe-method request that also carries an `Upgrade` header
// is still treated as safe; otherwise such requests get the origin check.
allow_upgrade: bool,
}
impl CsrfFilter {
pub fn new() -> CsrfFilter {
CsrfFilter {
origins: HashSet::new(),
allow_xhr: false,
allow_missing_origin: false,
allow_upgrade: false,
}
}
pub fn allowed_origin<T: Into<String>>(mut self, origin: T) -> CsrfFilter {
self.origins.insert(origin.into());
self
}
pub fn allow_xhr(mut self) -> CsrfFilter {
self.allow_xhr = true;
self
}
pub fn allow_missing_origin(mut self) -> CsrfFilter {
self.allow_missing_origin = true;
self
}
pub fn allow_upgrade(mut self) -> CsrfFilter {
self.allow_upgrade = true;
self
}
fn validate(&self, req: &Request) -> Result<(), CsrfError> {
let is_upgrade = req.headers().contains_key(header::UPGRADE);
let is_safe = req.method().is_safe() && (self.allow_upgrade || !is_upgrade);
if is_safe || (self.allow_xhr && req.headers().contains_key("x-requested-with"))
{
Ok(())
} else if let Some(header) = origin(req.headers()) {
match header {
Ok(ref origin) if self.origins.contains(origin.as_ref()) => Ok(()),
Ok(_) => Err(CsrfError::CsrDenied),
Err(err) => Err(err),
}
} else if self.allow_missing_origin {
Ok(())
} else {
Err(CsrfError::MissingOrigin)
}
}
}
impl<S> Middleware<S> for CsrfFilter {
    /// Run the CSRF check before the handler. A failed check is converted
    /// into the framework error type, so the handler never runs.
    fn start(&self, req: &HttpRequest<S>) -> Result<Started> {
        self.validate(req)
            .map(|()| Started::Done)
            .map_err(Into::into)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use http::Method;
    use test::TestRequest;

    /// A safe method passes even when the origin is not allowed.
    #[test]
    fn test_safe() {
        let csrf = CsrfFilter::new().allowed_origin("https://www.example.com");
        let req = TestRequest::with_header("Origin", "https://www.w3.org")
            .method(Method::HEAD)
            .finish();
        assert!(csrf.start(&req).is_ok());
    }

    /// An unsafe method from a foreign origin is rejected.
    #[test]
    fn test_csrf() {
        let csrf = CsrfFilter::new().allowed_origin("https://www.example.com");
        let req = TestRequest::with_header("Origin", "https://www.w3.org")
            .method(Method::POST)
            .finish();
        assert!(csrf.start(&req).is_err());
    }

    /// The origin is derived from the Referer when no Origin header is sent.
    #[test]
    fn test_referer() {
        let csrf = CsrfFilter::new().allowed_origin("https://www.example.com");
        let req = TestRequest::with_header(
            "Referer",
            "https://www.example.com/some/path?query=param",
        ).method(Method::POST)
            .finish();
        assert!(csrf.start(&req).is_ok());
    }

    /// A Referer with no scheme/host cannot yield an origin and is rejected.
    #[test]
    fn test_relative_referer() {
        let csrf = CsrfFilter::new()
            .allowed_origin("https://www.example.com")
            .allow_missing_origin();
        let req = TestRequest::with_header("Referer", "/some/path")
            .method(Method::POST)
            .finish();
        assert!(csrf.start(&req).is_err());
    }

    /// An unsafe request with no Origin/Referer is rejected unless explicitly
    /// allowed.
    #[test]
    fn test_missing_origin() {
        let strict_csrf = CsrfFilter::new().allowed_origin("https://www.example.com");
        let lax_csrf = CsrfFilter::new()
            .allowed_origin("https://www.example.com")
            .allow_missing_origin();
        let req = TestRequest::default().method(Method::POST).finish();
        assert!(strict_csrf.start(&req).is_err());
        assert!(lax_csrf.start(&req).is_ok());
    }

    /// `x-requested-with` bypasses the origin check only when `allow_xhr` is on.
    #[test]
    fn test_xhr() {
        let strict_csrf = CsrfFilter::new().allowed_origin("https://www.example.com");
        let xhr_csrf = CsrfFilter::new()
            .allowed_origin("https://www.example.com")
            .allow_xhr();
        let req = TestRequest::with_header("X-Requested-With", "XMLHttpRequest")
            .method(Method::POST)
            .finish();
        assert!(strict_csrf.start(&req).is_err());
        assert!(xhr_csrf.start(&req).is_ok());
    }

    /// A GET with an Upgrade header is only treated as safe when upgrades are
    /// allowed (prevents cross-site WebSocket hijacking).
    #[test]
    fn test_upgrade() {
        let strict_csrf = CsrfFilter::new().allowed_origin("https://www.example.com");
        let lax_csrf = CsrfFilter::new()
            .allowed_origin("https://www.example.com")
            .allow_upgrade();
        let req = TestRequest::with_header("Origin", "https://cswsh.com")
            .header("Connection", "Upgrade")
            .header("Upgrade", "websocket")
            .method(Method::GET)
            .finish();
        assert!(strict_csrf.start(&req).is_err());
        assert!(lax_csrf.start(&req).is_ok());
    }
}