use std::collections::HashSet;
use std::fmt::{self, Debug, Display};
use std::str;
use cookie::Cookie;
use csrf::{CsrfError, CsrfProtection, CsrfToken};
use data_encoding::{BASE64, BASE64URL};
use iron::error::HttpError;
use iron::headers::{Cookie as IronCookie, Header, HeaderFormat, SetCookie};
use iron::method;
use iron::middleware::{AroundMiddleware, Handler};
use iron::prelude::*;
use iron::status;
use time::Duration;
use urlencoded::{UrlEncodedBody, UrlEncodedQuery};
pub const CSRF_COOKIE_NAME: &'static str = "csrf";
pub const CSRF_FORM_FIELD: &'static str = "csrf-token";
pub const CSRF_HEADER: &'static str = "X-CSRF-Token";
pub const CSRF_QUERY_STRING: &'static str = "csrf-token";
/// Converts a [`CsrfError`] into an [`IronError`] whose response is
/// `403 Forbidden` with the error's `Display` text as the body.
fn iron_error(err: CsrfError) -> IronError {
    let body = err.to_string();
    IronError {
        response: Response::with((status::Forbidden, body)),
        error: Box::new(err),
    }
}
/// Typed form of the `X-CSRF-Token` request header.
///
/// The header value is the token encoded with the standard (not URL-safe)
/// base64 alphabet.
#[derive(Clone, PartialEq)]
struct XCsrfToken(CsrfToken);

impl Header for XCsrfToken {
    fn header_name() -> &'static str {
        CSRF_HEADER
    }

    /// Accepts exactly one header value and base64-decodes it into a token.
    /// Zero values, repeated values, or invalid base64 all yield
    /// `HttpError::Header`.
    fn parse_header(raw: &[Vec<u8>]) -> Result<Self, HttpError> {
        match raw {
            [single] => BASE64
                .decode(single)
                .map(|bytes| XCsrfToken(CsrfToken::new(bytes)))
                .map_err(|_| HttpError::Header),
            _ => Err(HttpError::Header),
        }
    }
}

impl HeaderFormat for XCsrfToken {
    fn fmt_header(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        // The wire format is exactly the `Display` form.
        Display::fmt(self, f)
    }
}

impl Display for XCsrfToken {
    /// Writes the token as standard base64.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0.b64_string())
    }
}

impl Debug for XCsrfToken {
    /// Same output as `Display`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        Display::fmt(self, f)
    }
}
/// Reasons [`CsrfConfigBuilder::finish`] rejects a configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CsrfConfigError {
    /// The configured TTL (in seconds) was negative.
    InvalidTtl,
    /// The set of protected HTTP methods was empty.
    NoProtectedMethods,
}
/// Runtime settings for [`CsrfProtectionMiddleware`].
///
/// Use [`CsrfConfig::default`] for the stock settings, or [`CsrfConfig::build`]
/// to customise and validate them.
#[derive(Debug, Clone)]
pub struct CsrfConfig {
    /// Passed as the TTL to `generate_token_pair`; also used as the CSRF
    /// cookie's `Max-Age`, in seconds.
    ttl_seconds: i64,
    /// Requests whose method is in this set must present a verifiable
    /// token/cookie pair.
    protected_methods: HashSet<method::Method>,
    /// Whether the CSRF cookie is emitted with the `Secure` attribute.
    secure_cookie: bool,
}

impl CsrfConfig {
    /// Starts a [`CsrfConfigBuilder`] seeded with [`CsrfConfig::default`].
    pub fn build() -> CsrfConfigBuilder {
        CsrfConfigBuilder {
            config: CsrfConfig::default(),
        }
    }
}

impl Default for CsrfConfig {
    /// Returns the stock settings:
    /// - TTL of 3600 seconds.
    /// - Cookie not marked `Secure`.
    /// - `POST`, `PUT`, `PATCH` and `DELETE` protected.
    fn default() -> Self {
        let protected_methods: HashSet<method::Method> =
            [method::Post, method::Put, method::Patch, method::Delete]
                .iter()
                .cloned()
                .collect();
        CsrfConfig {
            ttl_seconds: 3600,
            protected_methods,
            secure_cookie: false,
        }
    }
}
pub struct CsrfConfigBuilder {
config: CsrfConfig,
}
impl CsrfConfigBuilder {
pub fn ttl_seconds(mut self, ttl_seconds: i64) -> Self {
self.config.ttl_seconds = ttl_seconds;
self
}
pub fn protected_methods(mut self, protected_methods: HashSet<method::Method>) -> Self {
self.config.protected_methods = protected_methods;
self
}
pub fn secure_cookie(mut self, secure_cookie: bool) -> Self {
self.config.secure_cookie = secure_cookie;
self
}
pub fn finish(self) -> Result<CsrfConfig, CsrfConfigError> {
let config = self.config;
if config.ttl_seconds < 0 {
return Err(CsrfConfigError::InvalidTtl);
}
if config.protected_methods.is_empty() {
return Err(CsrfConfigError::NoProtectedMethods);
}
Ok(config)
}
}
struct CsrfHandler<P: CsrfProtection, H: Handler> {
protect: P,
config: CsrfConfig,
handler: H,
}
impl<P: CsrfProtection, H: Handler> CsrfHandler<P, H> {
fn new(protect: P, config: CsrfConfig, handler: H) -> Self {
CsrfHandler {
protect: protect,
config: config,
handler: handler,
}
}
fn extract_csrf_cookie(&self, request: &Request) -> Option<Vec<u8>> {
request.headers.get::<IronCookie>().and_then(|raw_cookie| {
raw_cookie
.0
.iter()
.filter_map(|c| {
Cookie::parse_encoded(c.clone())
.ok()
.and_then(|cookie| match cookie.name_value() {
(CSRF_COOKIE_NAME, value) => Some(value.to_string()),
_ => None,
})
.and_then(|c| BASE64.decode(c.as_bytes()).ok())
})
.collect::<Vec<Vec<u8>>>()
.first()
.map(|c| c.clone())
})
}
fn extract_csrf_token(&self, mut request: &mut Request) -> Option<Vec<u8>> {
let f_token = self.extract_csrf_token_from_form_url_encoded(&mut request);
let q_token = self.extract_csrf_token_from_query(&mut request);
let h_token = self.extract_csrf_token_from_headers(&mut request);
debug!(
"CSRF token found in Form: {}, Query: {}, Header: {}",
f_token.is_some(),
q_token.is_some(),
h_token.is_some()
);
f_token.or(q_token).or(h_token)
}
fn extract_csrf_token_from_form_url_encoded(&self, request: &mut Request) -> Option<Vec<u8>> {
request
.get_ref::<UrlEncodedBody>()
.ok()
.and_then(|form| form.get(CSRF_FORM_FIELD))
.and_then(|vs| {
vs.iter()
.filter_map(|v| BASE64URL.decode(v.as_bytes()).ok())
.next()
})
.map(|t| t.clone())
}
fn extract_csrf_token_from_query(&self, request: &mut Request) -> Option<Vec<u8>> {
request
.get_ref::<UrlEncodedQuery>()
.ok()
.and_then(|query| query.get(CSRF_QUERY_STRING))
.and_then(|vs| {
vs.iter()
.filter_map(|v| BASE64URL.decode(v.as_bytes()).ok())
.next()
})
.map(|t| t.clone())
}
fn extract_csrf_token_from_headers(&self, request: &mut Request) -> Option<Vec<u8>> {
let token = request
.headers
.get::<XCsrfToken>()
.map(|t| t.to_string())
.and_then(|s| BASE64.decode(s.as_bytes()).ok());
let _ = request.headers.remove::<XCsrfToken>();
token
}
}
impl<P: CsrfProtection + 'static, H: Handler> Handler for CsrfHandler<P, H> {
fn handle(&self, mut request: &mut Request) -> IronResult<Response> {
let token_opt = self
.extract_csrf_token(&mut request)
.and_then(|t| self.protect.parse_token(&t).ok());
let cookie_opt = self
.extract_csrf_cookie(&request)
.and_then(|c| self.protect.parse_cookie(&c).ok());
if self.config.protected_methods.contains(&request.method) {
debug!(
"CSRF elements present. token: {}, cookie: {}",
token_opt.is_some(),
cookie_opt.is_some()
);
match (token_opt.as_ref(), cookie_opt.as_ref()) {
(Some(token), Some(cookie)) => {
let verified = self.protect.verify_token_pair(&token, &cookie);
if !verified {
return Ok(Response::with((status::Forbidden, "CSRF Error")));
}
}
_ => return Ok(Response::with((status::Forbidden, "CSRF Error"))),
}
}
let (token, csrf_cookie) = self
.protect
.generate_token_pair(
cookie_opt
.and_then(|c| {
let c = c.value();
if c.len() < 64 {
None
} else {
let mut buf = [0; 64];
buf.copy_from_slice(&c);
Some(buf)
}
})
.as_ref(),
self.config.ttl_seconds,
)
.map_err(iron_error)?;
let _ = request.extensions.insert::<CsrfToken>(token);
let mut response = self.handler.handle(&mut request)?;
let cookie = Cookie::build(CSRF_COOKIE_NAME, csrf_cookie.b64_string())
.path("/")
.http_only(true)
.secure(self.config.secure_cookie)
.max_age(Duration::seconds(self.config.ttl_seconds))
.finish();
let mut cookies = vec![format!("{}", cookie.encoded())];
{
if let Some(set_cookie) = response.headers.get::<SetCookie>() {
cookies.extend(set_cookie.0.clone())
}
}
response.headers.set(SetCookie(cookies));
Ok(response)
}
}
/// Iron [`AroundMiddleware`] that wraps a handler with CSRF protection.
///
/// `protect` generates, parses and verifies token/cookie pairs. `config`
/// controls:
/// - the TTL;
/// - whether the cookie is marked `Secure`;
/// - which HTTP methods require verification.
pub struct CsrfProtectionMiddleware<P> {
    protect: P,
    config: CsrfConfig,
}

impl<P: CsrfProtection> CsrfProtectionMiddleware<P> {
    /// Creates the middleware from a protection implementation and a configuration.
    pub fn new(protect: P, config: CsrfConfig) -> Self {
        CsrfProtectionMiddleware { protect, config }
    }
}

impl<P: CsrfProtection + 'static> AroundMiddleware for CsrfProtectionMiddleware<P> {
    /// Wraps `handler` so every request passes through the CSRF checks first.
    fn around(self, handler: Box<dyn Handler>) -> Box<dyn Handler> {
        Box::new(CsrfHandler::new(self.protect, self.config, handler))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use csrf::{AesGcmCsrfProtection, ChaCha20Poly1305CsrfProtection, MultiCsrfProtection};
    use iron::headers::Headers;
    use iron_test::request as mock_request;
    use iron_test::response::extract_body_to_string;

    // Methods sent with a URL-encoded payload in the form test. For these, the
    // mock query/form handlers also assert that `TEST_QUERY_PARAM` reached them.
    lazy_static! {
        static ref BODY_METHODS: Vec<method::Method> = vec![
            method::Post,
            method::Put,
            method::Patch,
            method::Connect,
            method::Extension("WAT".to_string())
        ];
    }

    // Every method the "credentials present" tests cycle through.
    lazy_static! {
        static ref ALL_METHODS: Vec<method::Method> = vec![
            method::Get,
            method::Post,
            method::Put,
            method::Patch,
            method::Delete,
            method::Options,
            method::Connect,
            method::Trace,
            method::Extension("WAT".to_string())
        ];
    }

    const TEST_QUERY_PARAM: &'static str = "test-param";
    const TEST_QUERY_VALUE: &'static str = "test-value";
    const TEST_COOKIE_NAME: &'static str = "some-cookie";
    const TEST_COOKIE_VALUE: &'static str = "some-value";
    const KEY_32: [u8; 32] = *b"01234567012345670123456701234567";

    #[test]
    fn cookies_and_tokens_can_be_verified() {
        let protect = AesGcmCsrfProtection::from_key(KEY_32);
        let (token, cookie) = protect
            .generate_token_pair(None, 300)
            .expect("couldn't generate token/cookie pair");
        let token = BASE64
            .decode(token.b64_string().as_bytes())
            .expect("token not base64");
        let token = protect.parse_token(&token).expect("token not parsed");
        let cookie = BASE64
            .decode(cookie.b64_string().as_bytes())
            .expect("cookie not base64");
        let cookie = protect.parse_cookie(&cookie).expect("cookie not parsed");
        assert!(
            protect.verify_token_pair(&token, &cookie),
            "could not verify token/cookie pair"
        );
    }

    #[test]
    fn config_properties() {
        assert!(CsrfConfig::build().ttl_seconds(0).finish().is_ok());
        assert!(CsrfConfig::build().ttl_seconds(-1).finish().is_err());
        assert!(CsrfConfig::build()
            .protected_methods(HashSet::new())
            .finish()
            .is_err());
    }

    /// Middleware backed by AES-GCM with the default configuration.
    fn get_middleware() -> CsrfProtectionMiddleware<AesGcmCsrfProtection> {
        let protect = AesGcmCsrfProtection::from_key(KEY_32);
        CsrfProtectionMiddleware::new(protect, CsrfConfig::default())
    }

    #[test]
    fn middleware_validates_token() {
        let middleware = get_middleware();
        let (token, cookie) = middleware.protect.generate_token_pair(None, 300).unwrap();
        let token = BASE64
            .decode(token.b64_string().as_bytes())
            .expect("token not base64");
        let token = middleware
            .protect
            .parse_token(&token)
            .expect("token not parsed");
        let cookie = BASE64
            .decode(cookie.b64_string().as_bytes())
            .expect("cookie not base64");
        let cookie = middleware
            .protect
            .parse_cookie(&cookie)
            .expect("cookie not parsed");
        assert!(middleware.protect.verify_token_pair(&token, &cookie));
    }

    #[test]
    fn middleware_fails_modified_token() {
        let middleware = get_middleware();
        let (token, _) = middleware.protect.generate_token_pair(None, 300).unwrap();
        let mut token = BASE64
            .decode(token.b64_string().as_bytes())
            .expect("token not base64");
        // Flip one bit in byte 0, byte 16 and the last byte in turn,
        // restoring each before the next check.
        token[0] ^= 0x01;
        assert!(middleware.protect.parse_token(&token).is_err());
        token[0] ^= 0x01;
        token[16] ^= 0x01;
        assert!(middleware.protect.parse_token(&token).is_err());
        token[16] ^= 0x01;
        let len = token.len();
        token[len - 1] ^= 0x01;
        assert!(middleware.protect.parse_token(&token).is_err());
    }

    #[test]
    fn middleware_fails_modified_cookie() {
        let middleware = get_middleware();
        let (_, cookie) = middleware.protect.generate_token_pair(None, 300).unwrap();
        let mut cookie = BASE64
            .decode(cookie.b64_string().as_bytes())
            .expect("cookie not base64");
        // Same bit-flip pattern as `middleware_fails_modified_token`.
        cookie[0] ^= 0x01;
        assert!(middleware.protect.parse_cookie(&cookie).is_err());
        cookie[0] ^= 0x01;
        cookie[16] ^= 0x01;
        assert!(middleware.protect.parse_cookie(&cookie).is_err());
        cookie[16] ^= 0x01;
        let len = cookie.len();
        cookie[len - 1] ^= 0x01;
        assert!(middleware.protect.parse_cookie(&cookie).is_err());
    }

    #[test]
    fn middleware_fails_expired_tokens() {
        let middleware = get_middleware();
        let (token, cookie) = middleware.protect.generate_token_pair(None, 0).unwrap();
        let token = BASE64
            .decode(token.b64_string().as_bytes())
            .expect("token not base64");
        let token = middleware
            .protect
            .parse_token(&token)
            .expect("token not parsed");
        let cookie = BASE64
            .decode(cookie.b64_string().as_bytes())
            .expect("cookie not base64");
        let cookie = middleware
            .protect
            .parse_cookie(&cookie)
            .expect("cookie not parsed");
        assert!(!middleware.protect.verify_token_pair(&token, &cookie));
    }

    #[test]
    fn middleware_fails_mismatched_tokens() {
        let middleware = get_middleware();
        let (token, _) = middleware.protect.generate_token_pair(None, 300).unwrap();
        let (_, cookie) = middleware.protect.generate_token_pair(None, 300).unwrap();
        let token = BASE64
            .decode(token.b64_string().as_bytes())
            .expect("token not base64");
        let token = middleware
            .protect
            .parse_token(&token)
            .expect("token not parsed");
        let cookie = BASE64
            .decode(cookie.b64_string().as_bytes())
            .expect("cookie not base64");
        let cookie = middleware
            .protect
            .parse_cookie(&cookie)
            .expect("cookie not parsed");
        assert!(!middleware.protect.verify_token_pair(&token, &cookie));
    }

    /// Base64 form of the `CsrfToken` the middleware stored in the request
    /// extensions, or `"<no token>"` when there is none.
    fn token_from_extensions(request: &Request) -> String {
        request
            .extensions
            .get::<CsrfToken>()
            .map(|t| t.b64_string())
            .unwrap_or_else(|| "<no token>".to_string())
    }

    /// Builds a URL carrying `token` in the CSRF query parameter alongside an
    /// unrelated test parameter.
    fn csrf_query_path(token: &CsrfToken) -> String {
        format!(
            "http://localhost/?{}={}&{}={}",
            CSRF_QUERY_STRING,
            token.b64_url_string(),
            TEST_QUERY_PARAM,
            TEST_QUERY_VALUE
        )
    }

    fn mock_header_handler(request: &mut Request) -> IronResult<Response> {
        // The middleware must strip the CSRF header before calling the handler.
        assert_eq!(request.headers.get::<XCsrfToken>(), None);
        Ok(Response::with((status::Ok, token_from_extensions(request))))
    }

    fn mock_handler(request: &mut Request) -> IronResult<Response> {
        Ok(Response::with((status::Ok, token_from_extensions(request))))
    }

    fn mock_query_handler(request: &mut Request) -> IronResult<Response> {
        let token = token_from_extensions(request);
        if BODY_METHODS.contains(&request.method) {
            let query = request
                .get_ref::<UrlEncodedQuery>()
                .expect("no url encoded query");
            assert_eq!(
                query.get(TEST_QUERY_PARAM),
                Some(&vec![TEST_QUERY_VALUE.to_string()])
            );
        }
        Ok(Response::with((status::Ok, token)))
    }

    fn mock_url_form_handler(request: &mut Request) -> IronResult<Response> {
        let token = token_from_extensions(request);
        if BODY_METHODS.contains(&request.method) {
            let form_data = request
                .get_ref::<UrlEncodedBody>()
                .expect("not url form encoded");
            assert_eq!(
                form_data.get(TEST_QUERY_PARAM),
                Some(&vec![TEST_QUERY_VALUE.to_string()])
            );
        }
        Ok(Response::with((status::Ok, token)))
    }

    fn mock_cookie_handler(_: &mut Request) -> IronResult<Response> {
        let cookie = Cookie::new(TEST_COOKIE_NAME, TEST_COOKIE_VALUE);
        let mut response = Response::with((status::Ok, ""));
        response.headers.set(SetCookie(vec![cookie.to_string()]));
        Ok(response)
    }

    /// Wraps `handler` in the default middleware and performs one GET to obtain
    /// a fresh token and cookie.
    fn get_handler_token_cookie<H: Handler>(handler: H) -> (Box<dyn Handler>, CsrfToken, String) {
        let middleware = get_middleware();
        let handler = middleware.around(Box::new(handler));
        let response = mock_request::get("http://localhost/", Headers::new(), &handler).unwrap();
        assert_eq!(response.status, Some(status::Ok));
        let (csrf_token, csrf_cookie) = extract_token_cookie(response);
        (handler, csrf_token, csrf_cookie)
    }

    /// Extracts two values from a response.
    /// - The token comes from the body, which the mock handlers write as base64.
    /// - The CSRF cookie is the first `Set-Cookie` value.
    fn extract_token_cookie(resp: Response) -> (CsrfToken, String) {
        // Clone the headers because reading the body consumes the response.
        let headers = resp.headers.clone();
        let set_cookie = headers.get::<SetCookie>().unwrap();
        let cookie = Cookie::parse(set_cookie.0[0].clone()).unwrap();
        (
            CsrfToken::new(
                BASE64
                    .decode(extract_body_to_string(resp).as_bytes())
                    .unwrap(),
            ),
            cookie.to_string(),
        )
    }

    #[test]
    fn methods_without_token() {
        let (handler, _, csrf_cookie) = get_handler_token_cookie(mock_handler);
        let path = "http://localhost/";
        let mut headers = Headers::new();
        headers.set(IronCookie(vec![csrf_cookie.clone()]));
        let body = "";

        // Unprotected methods succeed with only the cookie.
        let response = mock_request::get(path, headers.clone(), &handler).unwrap();
        assert_eq!(response.status, Some(status::Ok));
        let response = mock_request::head(path, headers.clone(), &handler).unwrap();
        assert_eq!(response.status, Some(status::Ok));
        let response =
            mock_request::request(method::Options, path, body, headers.clone(), &handler).unwrap();
        assert_eq!(response.status, Some(status::Ok));
        let response =
            mock_request::request(method::Trace, path, body, headers.clone(), &handler).unwrap();
        assert_eq!(response.status, Some(status::Ok));
        let response =
            mock_request::request(method::Connect, path, body, headers.clone(), &handler).unwrap();
        assert_eq!(response.status, Some(status::Ok));
        let response = mock_request::request(
            method::Extension("WAT".to_string()),
            path,
            body,
            headers.clone(),
            &handler,
        )
        .unwrap();
        assert_eq!(response.status, Some(status::Ok));

        // Protected methods are rejected without a token.
        let response = mock_request::post(path, headers.clone(), body, &handler).unwrap();
        assert_eq!(response.status, Some(status::Forbidden));
        let response = mock_request::put(path, headers.clone(), body, &handler).unwrap();
        assert_eq!(response.status, Some(status::Forbidden));
        let response =
            mock_request::request(method::Delete, path, body, headers.clone(), &handler).unwrap();
        assert_eq!(response.status, Some(status::Forbidden));
        let response = mock_request::patch(path, headers.clone(), body, &handler).unwrap();
        assert_eq!(response.status, Some(status::Forbidden));
    }

    #[test]
    fn methods_with_csrf_header() {
        let (handler, csrf_token, csrf_cookie) = get_handler_token_cookie(mock_header_handler);
        let path = "http://localhost/";
        let mut headers = Headers::new();
        headers.set(IronCookie(vec![csrf_cookie.clone()]));
        headers.set(XCsrfToken(csrf_token));
        let body = "";
        for verb in ALL_METHODS.iter().cloned() {
            let response =
                mock_request::request(verb, path, body, headers.clone(), &handler).unwrap();
            assert_eq!(response.status, Some(status::Ok));
        }
    }

    #[test]
    fn methods_with_csrf_url() {
        let (handler, csrf_token, csrf_cookie) = get_handler_token_cookie(mock_query_handler);
        let path = csrf_query_path(&csrf_token);
        let mut headers = Headers::new();
        headers.set(IronCookie(vec![csrf_cookie.clone()]));
        let body = "";
        for verb in ALL_METHODS.iter().cloned() {
            let response =
                mock_request::request(verb, &path, body, headers.clone(), &handler).unwrap();
            assert_eq!(response.status, Some(status::Ok));
        }
    }

    #[test]
    fn methods_with_csrf_url_form() {
        let (handler, csrf_token, csrf_cookie) = get_handler_token_cookie(mock_url_form_handler);
        let path = "http://localhost/";
        let mut headers = Headers::new();
        headers.set(IronCookie(vec![csrf_cookie.clone()]));
        headers.set_raw(
            "content-type",
            vec![b"application/x-www-form-urlencoded".to_vec()],
        );
        let body = format!(
            "{}={}&{}={}",
            CSRF_FORM_FIELD,
            csrf_token.b64_url_string(),
            TEST_QUERY_PARAM,
            TEST_QUERY_VALUE
        );
        for verb in BODY_METHODS.iter().cloned() {
            let response =
                mock_request::request(verb, path, &body, headers.clone(), &handler).unwrap();
            assert_eq!(response.status, Some(status::Ok));
        }
    }

    #[test]
    fn cookies_not_overwritten() {
        let middleware = get_middleware();
        let handler = middleware.around(Box::new(mock_cookie_handler));
        let response = mock_request::get("http://localhost/", Headers::new(), &handler).unwrap();
        let set_cookie = response
            .headers
            .get::<SetCookie>()
            .expect("SetCookie header not set");
        assert_eq!(set_cookie.0.len(), 2);
        assert!(set_cookie
            .0
            .iter()
            .any(|c| c.contains(TEST_COOKIE_NAME) && c.contains(TEST_COOKIE_VALUE)));
    }

    #[test]
    fn multiprotect_and_rotation() {
        let protect_1 = AesGcmCsrfProtection::from_key(KEY_32);
        let protect_1_clone = AesGcmCsrfProtection::from_key(KEY_32);
        let middle_1 = CsrfProtectionMiddleware::new(protect_1, CsrfConfig::default());
        let handler_1 = middle_1.around(Box::new(mock_handler));

        let protect_2 = ChaCha20Poly1305CsrfProtection::from_key(KEY_32);
        let protect_2_clone = ChaCha20Poly1305CsrfProtection::from_key(KEY_32);
        let middle_2 = CsrfProtectionMiddleware::new(protect_2, CsrfConfig::default());
        let handler_2 = middle_2.around(Box::new(mock_handler));

        let multi_protect =
            MultiCsrfProtection::new(Box::new(protect_2_clone), vec![Box::new(protect_1_clone)]);
        let multi_middle = CsrfProtectionMiddleware::new(multi_protect, CsrfConfig::default());
        let multi_handler = multi_middle.around(Box::new(mock_handler));

        // Credentials issued by protection #1 work with #1 and with the multi handler.
        let resp = mock_request::get("http://localhost/", Headers::new(), &handler_1).unwrap();
        let (token, cookie) = extract_token_cookie(resp);
        let path = csrf_query_path(&token);
        let mut headers = Headers::new();
        headers.set(IronCookie(vec![cookie]));
        let body = "";
        let resp =
            mock_request::request(method::Post, &path, body, headers.clone(), &handler_1).unwrap();
        assert_eq!(resp.status, Some(status::Ok));
        let resp = mock_request::request(method::Post, &path, body, headers.clone(), &multi_handler)
            .unwrap();
        assert_eq!(resp.status, Some(status::Ok));

        // The multi handler rotates to protection #2, so its new credentials work with #2.
        let (token, cookie) = extract_token_cookie(resp);
        let path = csrf_query_path(&token);
        let mut headers = Headers::new();
        headers.set(IronCookie(vec![cookie]));
        let body = "";
        let resp =
            mock_request::request(method::Post, &path, body, headers.clone(), &handler_2).unwrap();
        assert_eq!(resp.status, Some(status::Ok));
    }
}