use std::error::Error;
use std::{fmt, mem, str};
use crypto::aead::{AeadEncryptor, AeadDecryptor};
use crypto::aes::KeySize;
use crypto::aes_gcm::AesGcm;
use crypto::scrypt::{scrypt, ScryptParams};
use ring::rand::SystemRandom;
use rustc_serialize::base64::{self, ToBase64};
use time;
#[cfg(feature = "iron")]
use typemap;
/// Name of the cookie that carries the CSRF cookie value.
pub const CSRF_COOKIE_NAME: &str = "csrf";
/// Name of the form field that carries the CSRF token.
pub const CSRF_FORM_FIELD: &str = "csrf-token";
/// Name of the HTTP header that carries the CSRF token.
pub const CSRF_HEADER: &str = "X-CSRF-Token";
/// Name of the query-string parameter that carries the CSRF token.
pub const CSRF_QUERY_STRING: &str = "csrf-token";
/// Errors produced while generating, parsing, or verifying CSRF tokens and cookies.
#[derive(Debug)]
pub enum CsrfError {
    /// A failure inside the library itself (for example, the RNG could not supply bytes).
    InternalError,
    /// A token or cookie was malformed, failed to decrypt/authenticate, or did not verify.
    ValidationFailure,
}

impl CsrfError {
    /// Human-readable message shared by `Display` and `Error::description`.
    fn message(&self) -> &'static str {
        match *self {
            CsrfError::InternalError => "CSRF library error",
            CsrfError::ValidationFailure => "CSRF validation failed",
        }
    }
}

impl Error for CsrfError {
    fn description(&self) -> &str {
        self.message()
    }
}

impl fmt::Display for CsrfError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Write the message text directly: formatting `self` with "{}" here
        // would re-enter this impl and recurse until the stack overflows.
        f.write_str(self.message())
    }
}
/// Errors describing an invalid CSRF configuration.
///
/// Derives `Debug` (and cheap comparison/copy traits) so callers can log,
/// `unwrap`/`expect`, and match on it like any other public error type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CsrfConfigError {
    /// The configured TTL was rejected.
    InvalidTtl,
    /// No HTTP methods were configured for protection.
    NoProtectedMethods,
    /// An unspecified configuration error.
    Unspecified,
}
#[derive(Eq, PartialEq, Debug)]
pub struct CsrfToken {
bytes: Vec<u8>,
}
impl CsrfToken {
pub fn new(bytes: Vec<u8>) -> Self {
CsrfToken { bytes: bytes }
}
pub fn b64_string(&self) -> String {
self.bytes.to_base64(base64::STANDARD)
}
pub fn b64_url_string(&self) -> String {
self.bytes.to_base64(base64::URL_SAFE)
}
}
#[derive(Debug, Eq, PartialEq)]
pub struct CsrfCookie {
bytes: Vec<u8>,
}
impl CsrfCookie {
pub fn new(bytes: Vec<u8>) -> Self {
CsrfCookie { bytes: bytes }
}
pub fn b64_string(&self) -> String {
self.bytes.to_base64(base64::STANDARD)
}
}
/// A decrypted CSRF token holding the raw token bytes (see `CsrfProtection::parse_token`).
#[derive(Clone, Debug)]
pub struct UnencryptedCsrfToken {
    token: Vec<u8>,
}

impl UnencryptedCsrfToken {
    /// Wraps raw, unencrypted token bytes.
    pub fn new(token: Vec<u8>) -> Self {
        UnencryptedCsrfToken { token }
    }

    /// Borrows the raw token bytes.
    pub fn token(&self) -> &[u8] {
        &self.token[..]
    }
}
/// A decrypted CSRF cookie: its expiry timestamp and the raw token bytes
/// (see `CsrfProtection::parse_cookie`).
#[derive(Clone, Debug)]
pub struct UnencryptedCsrfCookie {
    expires: i64,
    token: Vec<u8>,
}

impl UnencryptedCsrfCookie {
    /// Bundles an expiry timestamp in seconds with raw token bytes.
    pub fn new(expires: i64, token: Vec<u8>) -> Self {
        Self { expires, token }
    }
}
/// Core operations for producing and checking CSRF token/cookie pairs.
///
/// Implementors supply the primitives (key derivation, encryption/decryption
/// of tokens and cookies, and an RNG). `random_bytes` and
/// `generate_token_pair` have default implementations built on those.
pub trait CsrfProtection: Send + Sync {
    /// Builds a protector whose key material is derived from `password`.
    fn from_password(password: &[u8]) -> Self;

    /// Returns `true` if the parsed `token` and `cookie` belong together and
    /// the pair is acceptable to the implementation.
    fn verify_token_pair(&self,
                         token: &UnencryptedCsrfToken,
                         cookie: &UnencryptedCsrfCookie)
                         -> bool;

    /// Seals raw `token` bytes into a cookie that expires after `ttl_seconds`.
    fn generate_cookie(&self, token: &[u8], ttl_seconds: i64) -> Result<CsrfCookie, CsrfError>;

    /// Seals raw `token` bytes into a transportable CSRF token.
    fn generate_token(&self, token: &[u8]) -> Result<CsrfToken, CsrfError>;

    /// Opens a cookie previously produced by `generate_cookie`.
    fn parse_cookie(&self, cookie: &[u8]) -> Result<UnencryptedCsrfCookie, CsrfError>;

    /// Opens a token previously produced by `generate_token`.
    fn parse_token(&self, token: &[u8]) -> Result<UnencryptedCsrfToken, CsrfError>;

    /// Source of randomness used by `random_bytes`.
    fn rng(&self) -> &SystemRandom;

    /// Fills `buf` from `rng()`.
    ///
    /// # Errors
    /// Returns `CsrfError::InternalError` if the RNG fails.
    fn random_bytes(&self, buf: &mut [u8]) -> Result<(), CsrfError> {
        self.rng()
            .fill(buf)
            .map_err(|_| {
                warn!("Failed to get random bytes");
                CsrfError::InternalError
            })
    }

    /// Produces a matching (token, cookie) pair for one 64-byte secret.
    ///
    /// If `previous_token` is exactly 64 bytes it is reused as the secret;
    /// otherwise a fresh random secret is drawn.
    ///
    /// # Errors
    /// Propagates the error from `random_bytes`, `generate_token`, or
    /// `generate_cookie` unchanged, so an RNG failure surfaces as
    /// `InternalError` instead of being misreported as `ValidationFailure`.
    fn generate_token_pair(&self,
                           previous_token: Option<Vec<u8>>,
                           ttl_seconds: i64)
                           -> Result<(CsrfToken, CsrfCookie), CsrfError> {
        let mut token = vec![0; 64];
        match previous_token {
            Some(ref previous) if previous.len() == 64 => token.copy_from_slice(previous),
            _ => self.random_bytes(&mut token)?,
        }
        let csrf_token = self.generate_token(&token)?;
        let csrf_cookie = self.generate_cookie(&token, ttl_seconds)?;
        Ok((csrf_token, csrf_cookie))
    }
}
/// CSRF protection that seals tokens and cookies with AES-256-GCM.
pub struct AesGcmCsrfProtection {
    rng: SystemRandom,
    aes_key: [u8; 32],
}

impl AesGcmCsrfProtection {
    /// Builds a protector from a raw 256-bit AES key, with a fresh system RNG.
    pub fn from_key(aes_key: [u8; 32]) -> Self {
        let rng = SystemRandom::new();
        AesGcmCsrfProtection { rng, aes_key }
    }

    /// Creates an AES-256-GCM context for `nonce` with empty associated data.
    fn aead<'a>(&self, nonce: &[u8; 12]) -> AesGcm<'a> {
        let no_aad: &[u8] = &[];
        AesGcm::new(KeySize::KeySize256, &self.aes_key, nonce, no_aad)
    }
}
impl CsrfProtection for AesGcmCsrfProtection {
fn from_password(password: &[u8]) -> Self {
let params = if cfg!(test) {
ScryptParams::new(1, 8, 1)
} else {
ScryptParams::new(12, 8, 1)
};
let salt = b"rust-csrf-scrypt-salt";
let mut aes_key = [0; 32];
info!("Generating key material. This may take some time.");
scrypt(password, salt, ¶ms, &mut aes_key);
info!("Key material generated.");
AesGcmCsrfProtection::from_key(aes_key)
}
fn rng(&self) -> &SystemRandom {
&self.rng
}
fn verify_token_pair(&self,
token: &UnencryptedCsrfToken,
cookie: &UnencryptedCsrfCookie)
-> bool {
let tokens_match = token.token == cookie.token;
let not_expired = cookie.expires > time::precise_time_s() as i64;
tokens_match && not_expired
}
fn generate_cookie(&self, token: &[u8], ttl_seconds: i64) -> Result<CsrfCookie, CsrfError> {
if cfg!(test) {
assert!(token.len() == 64);
}
let expires = time::precise_time_s() as i64 + ttl_seconds;
let expires_bytes = unsafe { mem::transmute::<i64, [u8; 8]>(expires) };
let mut nonce = [0; 12];
self.random_bytes(&mut nonce)?;
let mut padding = [0; 16];
self.random_bytes(&mut padding)?;
let mut plaintext = [0; 88];
for i in 0..16 {
plaintext[i] = padding[i];
}
for i in 0..8 {
plaintext[i + 16] = expires_bytes[i];
}
for i in 0..64 {
plaintext[i + 24] = token[i];
}
let mut ciphertext = [0; 88];
let mut tag = [0; 16];
let mut aead = self.aead(&nonce);
aead.encrypt(&plaintext, &mut ciphertext, &mut tag);
let mut transport = [0; 116];
for i in 0..88 {
transport[i] = ciphertext[i];
}
for i in 0..12 {
transport[i + 88] = nonce[i];
}
for i in 0..16 {
transport[i + 100] = tag[i];
}
Ok(CsrfCookie::new(transport.to_vec()))
}
fn generate_token(&self, token: &[u8]) -> Result<CsrfToken, CsrfError> {
if cfg!(test) {
assert!(token.len() == 64);
}
let mut nonce = [0; 12];
self.random_bytes(&mut nonce)?;
let mut padding = [0; 16];
self.random_bytes(&mut padding)?;
let mut plaintext = [0; 80];
for i in 0..16 {
plaintext[i] = padding[i];
}
for i in 0..64 {
plaintext[i + 16] = token[i];
}
let mut ciphertext = [0; 80];
let mut tag = vec![0; 16];
let mut aead = self.aead(&nonce);
aead.encrypt(&plaintext, &mut ciphertext, &mut tag);
let mut transport = [0; 108];
for i in 0..80 {
transport[i] = ciphertext[i];
}
for i in 0..12 {
transport[i + 80] = nonce[i];
}
for i in 0..16 {
transport[i + 92] = tag[i];
}
Ok(CsrfToken::new(transport.to_vec()))
}
fn parse_cookie(&self, cookie: &[u8]) -> Result<UnencryptedCsrfCookie, CsrfError> {
if cookie.len() != 116 {
return Err(CsrfError::ValidationFailure);
}
let mut ciphertext = [0; 88];
let mut plaintext = [0; 88];
let mut nonce = [0; 12];
let mut tag = [0; 16];
for i in 0..88 {
ciphertext[i] = cookie[i];
}
for i in 0..12 {
nonce[i] = cookie[i + 88];
}
for i in 0..16 {
tag[i] = cookie[i + 100];
}
let mut aead = self.aead(&nonce);
if !aead.decrypt(&ciphertext, &mut plaintext, &tag) {
info!("Failed to decrypt CSRF cookie");
return Err(CsrfError::ValidationFailure);
}
let mut expires_bytes = [0; 8];
let mut token = [0; 64];
for i in 0..8 {
expires_bytes[i] = plaintext[i + 16];
}
for i in 0..64 {
token[i] = plaintext[i + 24];
}
let expires = unsafe { mem::transmute::<[u8; 8], i64>(expires_bytes) };
Ok(UnencryptedCsrfCookie::new(expires, token.to_vec()))
}
fn parse_token(&self, token: &[u8]) -> Result<UnencryptedCsrfToken, CsrfError> {
if token.len() != 108 {
return Err(CsrfError::ValidationFailure);
}
let mut ciphertext = [0; 80];
let mut plaintext = [0; 80];
let mut nonce = [0; 12];
let mut tag = [0; 16];
for i in 0..80 {
ciphertext[i] = token[i];
}
for i in 0..12 {
nonce[i] = token[i + 80];
}
for i in 0..16 {
tag[i] = token[i + 92];
}
let mut aead = self.aead(&nonce);
if !aead.decrypt(&ciphertext, &mut plaintext, &tag) {
info!("Failed to decrypt CSRF token");
return Err(CsrfError::ValidationFailure);
}
let mut token = [0; 64];
for i in 0..64 {
token[i] = plaintext[i + 16];
}
Ok(UnencryptedCsrfToken::new(token.to_vec()))
}
fn generate_token_pair(&self,
previous_token: Option<Vec<u8>>,
ttl_seconds: i64)
-> Result<(CsrfToken, CsrfCookie), CsrfError> {
let mut token = vec![0; 64];
match previous_token {
Some(ref previous) if previous.len() == 64 => {
for i in 0..64 {
token[i] = previous[i];
}
}
_ => self.random_bytes(&mut token)?,
}
match (self.generate_token(&token), self.generate_cookie(&token, ttl_seconds)) {
(Ok(t), Ok(c)) => Ok((t, c)),
_ => Err(CsrfError::ValidationFailure),
}
}
}
/// With the `iron` feature, `CsrfToken` can act as a `typemap` key whose
/// stored value is the token itself.
#[cfg(feature = "iron")]
impl typemap::Key for CsrfToken {
    type Value = Self;
}
#[cfg(test)]
mod tests {
    use super::*;
    use rustc_serialize::base64::FromBase64;

    /// Builds an AES-GCM protector from a fixed test password.
    fn aes_gcm() -> AesGcmCsrfProtection {
        AesGcmCsrfProtection::from_password(b"hunter2")
    }

    fn verification_succeeds<P: CsrfProtection>(protect: P) {
        let (token, cookie) = protect.generate_token_pair(None, 300)
            .expect("couldn't generate token/cookie pair");
        let token = token.b64_string().from_base64().expect("token not base64");
        let token = protect.parse_token(&token).expect("token not parsed");
        let cookie = cookie.b64_string().from_base64().expect("cookie not base64");
        let cookie = protect.parse_cookie(&cookie).expect("cookie not parsed");
        assert!(protect.verify_token_pair(&token, &cookie),
                "could not verify token/cookie pair");
    }

    fn modified_cookie_sig_fails<P: CsrfProtection>(protect: P) {
        let (_, mut cookie) = protect.generate_token_pair(None, 300)
            .expect("couldn't generate token/cookie pair");
        let cookie_len = cookie.bytes.len();
        cookie.bytes[cookie_len - 1] ^= 0x01;
        let cookie = cookie.b64_string().from_base64().expect("cookie not base64");
        assert!(protect.parse_cookie(&cookie).is_err());
    }

    fn modified_cookie_value_fails<P: CsrfProtection>(protect: P) {
        let (_, mut cookie) = protect.generate_token_pair(None, 300)
            .expect("couldn't generate token/cookie pair");
        cookie.bytes[0] ^= 0x01;
        let cookie = cookie.b64_string().from_base64().expect("cookie not base64");
        assert!(protect.parse_cookie(&cookie).is_err());
    }

    fn modified_token_sig_fails<P: CsrfProtection>(protect: P) {
        let (mut token, _) = protect.generate_token_pair(None, 300)
            .expect("couldn't generate token/token pair");
        let token_len = token.bytes.len();
        token.bytes[token_len - 1] ^= 0x01;
        let token = token.b64_string().from_base64().expect("token not base64");
        assert!(protect.parse_token(&token).is_err());
    }

    fn modified_token_value_fails<P: CsrfProtection>(protect: P) {
        let (mut token, _) = protect.generate_token_pair(None, 300)
            .expect("couldn't generate token/token pair");
        token.bytes[0] ^= 0x01;
        let token = token.b64_string().from_base64().expect("token not base64");
        assert!(protect.parse_token(&token).is_err());
    }

    /// A token from one pair must not verify against the cookie of another.
    fn mismatched_pair_fails<P: CsrfProtection>(protect: P) {
        let (token, _) = protect.generate_token_pair(None, 300)
            .expect("couldn't generate first pair");
        let (_, cookie) = protect.generate_token_pair(None, 300)
            .expect("couldn't generate second pair");
        let token = protect.parse_token(&token.bytes).expect("token not parsed");
        let cookie = protect.parse_cookie(&cookie.bytes).expect("cookie not parsed");
        assert!(!protect.verify_token_pair(&token, &cookie));
    }

    /// A cookie whose expiry is already in the past must not verify.
    fn expired_cookie_fails<P: CsrfProtection>(protect: P) {
        // A negative TTL places the expiry before "now".
        let (token, cookie) = protect.generate_token_pair(None, -1)
            .expect("couldn't generate token/cookie pair");
        let token = protect.parse_token(&token.bytes).expect("token not parsed");
        let cookie = protect.parse_cookie(&cookie.bytes).expect("cookie not parsed");
        assert!(!protect.verify_token_pair(&token, &cookie));
    }

    /// A 64-byte previous token is reused verbatim as the pair's secret.
    fn previous_token_is_reused<P: CsrfProtection>(protect: P) {
        let previous: Vec<u8> = (0u8..64).collect();
        let (token, cookie) = protect.generate_token_pair(Some(previous.clone()), 300)
            .expect("couldn't generate token/cookie pair");
        let token = protect.parse_token(&token.bytes).expect("token not parsed");
        let cookie = protect.parse_cookie(&cookie.bytes).expect("cookie not parsed");
        assert_eq!(token.token(), previous.as_slice());
        assert!(protect.verify_token_pair(&token, &cookie));
    }

    #[test]
    fn aes_gcm_verification_succeeds() {
        verification_succeeds(aes_gcm());
    }

    #[test]
    fn aes_gcm_modified_cookie_sig_fails() {
        modified_cookie_sig_fails(aes_gcm());
    }

    #[test]
    fn aes_gcm_modified_cookie_value_fails() {
        modified_cookie_value_fails(aes_gcm());
    }

    #[test]
    fn aes_gcm_modified_token_sig_fails() {
        modified_token_sig_fails(aes_gcm());
    }

    #[test]
    fn aes_gcm_modified_token_value_fails() {
        modified_token_value_fails(aes_gcm());
    }

    #[test]
    fn aes_gcm_mismatched_pair_fails() {
        mismatched_pair_fails(aes_gcm());
    }

    #[test]
    fn aes_gcm_expired_cookie_fails() {
        expired_cookie_fails(aes_gcm());
    }

    #[test]
    fn aes_gcm_previous_token_is_reused() {
        previous_token_is_reused(aes_gcm());
    }
}