use base64::{engine::general_purpose, Engine as _};
use bcrypt::{hash, verify, BcryptError};
use rand::{distributions::Standard, Rng};
use rocket::{
async_trait, error,
fairing::{self, Fairing as RocketFairing, Info, Kind},
http::{
Cookie,
Status,
},
info,
request::{FromRequest, Outcome},
response::{Responder, Response},
time::{Duration, OffsetDateTime},
Data, Request, Rocket, State,
};
use std::{
borrow::Cow,
fmt,
};
/// bcrypt cost factor used by `CsrfToken::authenticity_token` when hashing the session secret.
const BCRYPT_COST: u32 = 8;
/// Request header the `CsrfToken` fairing reads the submitted token from.
const HEADER_NAME: &str = "X-CSRF-Token";
// Form-parameter / meta-tag names. `_PARAM_NAME` is rendered by `_ajax_csrf_meta_tags`;
// the other two appear unreferenced (the leading underscore silences dead-code warnings).
const _PARAM_NAME: &str = "authenticity_token";
const _PARAM_META_NAME: &str = "csrf-param";
const _TOKEN_META_NAME: &str = "csrf-token";
/// Configuration shared by the CSRF fairing and the `CsrfToken` request guard.
///
/// Start from [`CsrfConfig::default`] and adjust it with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct CsrfConfig {
    /// Lifetime of the CSRF cookie; `None` builds the cookie with no expiry date
    /// (a session cookie).
    lifespan: Option<Duration>,
    /// Name of the private cookie that stores the CSRF secret.
    cookie_name: Cow<'static, str>,
    /// Number of random bytes in a newly generated secret, and the minimum decoded
    /// length an existing cookie must have to be accepted.
    cookie_len: usize,
}

impl Default for CsrfConfig {
    /// One-day lifespan, cookie named `csrf_token`, 32-byte secret.
    fn default() -> Self {
        Self {
            lifespan: Some(Duration::days(1)),
            cookie_name: "csrf_token".into(),
            cookie_len: 32,
        }
    }
}

impl CsrfConfig {
    /// Sets how long the CSRF cookie lives; `None` makes it a session cookie.
    #[must_use = "builder methods consume the config and return the updated one"]
    pub fn with_lifetime(mut self, time: Option<Duration>) -> Self {
        self.lifespan = time;
        self
    }

    /// Sets the name of the private cookie holding the CSRF secret.
    #[must_use = "builder methods consume the config and return the updated one"]
    pub fn with_cookie_name(mut self, name: impl Into<Cow<'static, str>>) -> Self {
        self.cookie_name = name.into();
        self
    }

    /// Sets the secret length in bytes.
    ///
    /// This is also the minimum decoded length for an existing cookie to be reused.
    /// A length of `0` yields an empty secret and makes any decodable cookie
    /// acceptable, which defeats CSRF protection — use a reasonably large value.
    #[must_use = "builder methods consume the config and return the updated one"]
    pub fn with_cookie_len(mut self, length: usize) -> Self {
        self.cookie_len = length;
        self
    }
}
pub struct Fairing {
config: CsrfConfig,
}
impl Default for Fairing {
fn default() -> Self {
Self::new(CsrfConfig::default())
}
}
impl Fairing {
pub fn new(config: CsrfConfig) -> Self {
Self { config }
}
}
#[derive(Clone)]
pub struct CsrfToken(String);
impl CsrfToken {
pub fn authenticity_token(&self) -> Result<String, BcryptError> {
match hash(&self.0, BCRYPT_COST) {
Ok(token) => Ok(token),
Err(err) => Err(err),
}
}
pub fn verify(&self, form_authenticity_token: &String) -> Result<(), VerificationFailure> {
if verify(&self.0, form_authenticity_token).unwrap_or(false) {
info!("CSRF token verification succeeded.");
Ok(())
} else {
Err(VerificationFailure {})
}
}
}
#[async_trait]
impl RocketFairing for Fairing {
fn info(&self) -> Info {
Info {
name: "CSRF",
kind: Kind::Ignite | Kind::Request,
}
}
async fn on_ignite(&self, rocket: Rocket<rocket::Build>) -> fairing::Result {
Ok(rocket.manage(self.config.clone()))
}
async fn on_request(&self, request: &mut Request<'_>, data: &mut Data<'_>) {
let config = match request.guard::<&State<CsrfConfig>>().await {
Outcome::Success(cfg) => cfg,
Outcome::Error(e) => {
error!("CSRF config is missing: {:?}", e);
return;
}
Outcome::Forward(_) => {
error!("Request should be forwarded");
return;
}
};
if let Some(_) = request.valid_csrf_token_from_session(&config) {
return;
}
let values: Vec<u8> = rand::thread_rng()
.sample_iter(Standard)
.take(config.cookie_len)
.collect();
let encoded = general_purpose::STANDARD.encode(&values[..]);
let expires = match config.lifespan {
Some(duration) => Some(OffsetDateTime::now_utc() + duration),
None => None, };
let cookie_builder = Cookie::build((config.cookie_name.clone(), encoded)).path("/");
let cookie_builder = match expires {
Some(expiration) => cookie_builder.expires(expiration),
None => cookie_builder.expires(None), };
let cookie = cookie_builder.build();
if request.cookies().add_private(cookie) == () {
info!("CSRF cookie added successfully.");
} else {
error!("Failed to add CSRF cookie");
}
let _ = CsrfToken("".to_string()).on_request(request, data).await;
}
}
#[async_trait]
impl<'r> FromRequest<'r> for CsrfToken {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let config = request.guard::<&State<CsrfConfig>>().await.unwrap();
match request.valid_csrf_token_from_session(&config) {
Some(token) => {
let encoded = general_purpose::STANDARD.encode(token);
Outcome::Success(Self(encoded))
}
None => Outcome::Error((Status::Forbidden, ())),
}
}
}
/// Writes the token's inner string verbatim; formatter flags (width, fill) are ignored.
impl fmt::Display for CsrfToken {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let raw: &str = self.0.as_str();
        f.write_str(raw)
    }
}
/// Renders `<meta>` tags exposing the CSRF token and its form-parameter name to
/// AJAX clients.
///
/// The token comes from the request-local cache; if none was cached, an empty
/// token is inserted and rendered. The token is HTML-attribute-escaped because
/// the cached value may originate from the `X-CSRF-Token` request header.
fn _ajax_csrf_meta_tags(request: &Request) -> String {
    /// Escapes characters that are significant inside a quoted HTML attribute value.
    fn escape_attribute(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '&' => escaped.push_str("&amp;"),
                '<' => escaped.push_str("&lt;"),
                '>' => escaped.push_str("&gt;"),
                '"' => escaped.push_str("&quot;"),
                '\'' => escaped.push_str("&#x27;"),
                other => escaped.push(other),
            }
        }
        escaped
    }

    let csrf_token = request.local_cache(|| CsrfToken("".to_string()));
    format!(
        r#"<meta name="csrf-token" content="{}">
<meta name="csrf-param" content="{}">"#,
        escape_attribute(&csrf_token.to_string()),
        _PARAM_NAME
    )
}
/// Wrapper around a Rocket `Response`.
/// NOTE(review): unused and without impls here; presumably intended to inject the
/// AJAX CSRF meta tags into HTML responses — confirm or remove.
struct _AjaxCsrfMetaTagsResponder<'o>(Response<'o>);
#[async_trait]
impl RocketFairing for CsrfToken {
fn info(&self) -> Info {
Info {
name: "VerifyAllRequests",
kind: Kind::Request,
}
}
async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
let csrf_token = request.headers().get_one(HEADER_NAME).map(String::from);
let csrf_config = request.guard::<&State<CsrfConfig>>().await;
match csrf_config {
Outcome::Success(_config) => {
if csrf_token.is_some() {
match self.verify(&csrf_token.clone().unwrap()) {
Ok(_) => {
info!("CsrfToken is successfully created");
request.local_cache(|| CsrfToken(csrf_token.unwrap()));
}
Err(err) => {
error!("{:?}", err);
}
}
} else {
error!("Request lacks X-CSRF-Token");
}
}
Outcome::Error(e) => {
error!("CSRF config is missing: {:?}", e);
}
Outcome::Forward(_) => {
error!("Request should be forwarded");
}
}
}
async fn on_response<'r>(&self, _req: &'r Request<'_>, res: &mut Response<'r>) {
if let Some(content_type) = res.content_type() {
if content_type.is_html() {
}
}
}
}
/// Error returned when a submitted CSRF authenticity token does not match the
/// secret it is checked against.
pub struct VerificationFailure;

impl fmt::Debug for VerificationFailure {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "CSRF token verification failed!")
    }
}

/// Same human-readable message as `Debug`, so the type works with `{}` and `?`-based
/// error plumbing.
impl fmt::Display for VerificationFailure {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("CSRF token verification failed!")
    }
}

impl std::error::Error for VerificationFailure {}
/// Rejects the request with a bare `403 Forbidden` (no body, no extra headers).
impl<'r> Responder<'r, 'static> for VerificationFailure {
    fn respond_to(self, _request: &Request) -> rocket::response::Result<'static> {
        let mut builder = Response::build();
        builder.status(Status::Forbidden);
        Ok(builder.finalize())
    }
}
/// Access to the CSRF session secret carried by a request.
trait RequestCsrf {
    /// Returns the session secret only if it is at least `config.cookie_len` bytes
    /// long; shorter (or missing) secrets yield `None`.
    fn valid_csrf_token_from_session(&self, config: &CsrfConfig) -> Option<Vec<u8>> {
        let min_len = config.cookie_len;
        self.csrf_token_from_session(config)
            .filter(|secret| secret.len() >= min_len)
    }

    /// Returns the raw (decoded) session secret, or `None` if unavailable.
    fn csrf_token_from_session(&self, config: &CsrfConfig) -> Option<Vec<u8>>;
}
impl RequestCsrf for Request<'_> {
    /// Reads the private cookie named `config.cookie_name` and base64-decodes its
    /// value. A missing cookie or an undecodable value both yield `None`.
    fn csrf_token_from_session(&self, config: &CsrfConfig) -> Option<Vec<u8>> {
        self.cookies()
            .get_private(&config.cookie_name)
            .and_then(|cookie| general_purpose::STANDARD.decode(cookie.value()).ok())
    }
}