#![deny(missing_docs)]
use mini_moka::sync::Cache;
use rand::RngCore;
use rocket::{
form::{Form, FromForm},
get,
http::{Cookie, CookieJar, SameSite, Status},
outcome::IntoOutcome,
post,
request::{FromRequest, Outcome, Request},
response::Redirect,
routes, uri, Build, Rocket, State,
};
use rocket_dyn_templates::{context, Template};
use sha3::{Digest, Sha3_256};
extern crate self as rocket_csrf_guard;
use super::{
with_csrf_token, CheckCsrfProtectionHeader, CsrfCheckProof, CsrfProtectedForm,
DoubleSubmitCookieCsrfProtectedForm, SetDoubleSubmitCookieCsrfToken,
VerifierWithKnownExpectedToken,
};
/// Name of the private cookie that carries the raw session id.
///
/// The `__Host-` name prefix is only honoured by browsers for `Secure` cookies
/// without a `Domain` attribute and with `Path=/`.
const SESSION_COOKIE_NAME: &str = "__Host-session";
/// Name of the request header that may carry the raw session id as a
/// `Bearer ` token instead of the cookie.
const SESSION_HEADER_NAME: &str = "Authorization";
/// Produces a random identifier built from `len` random bytes.
///
/// The bytes come from the thread-local RNG and are returned encoded with the
/// `base64::URL_SAFE` configuration.
///
/// # Panics
///
/// Panics if the RNG fails to fill the buffer.
fn random_id(len: usize) -> String {
    let mut bytes = vec![0u8; len];
    let mut rng = rand::thread_rng();
    rng.try_fill_bytes(&mut bytes)
        .expect("Couldn't generate random number");
    base64::encode_config(bytes, base64::URL_SAFE)
}
/// Returns the hex-encoded SHA3-256 digest of `id`.
///
/// Session ids are stored and looked up by this hash rather than by their raw
/// value.
fn hash(id: &str) -> String {
    let mut hasher = Sha3_256::new();
    hasher.update(id.as_bytes());
    hex::encode(hasher.finalize())
}
/// A logged-in user's session as stored in the [`SessionManager`].
///
/// It doubles as a request guard: a route that takes a `Session` parameter is
/// only entered when the request carries a known session id.
#[derive(Clone, Debug)]
pub struct Session {
/// Hex-encoded SHA3-256 hash of the session id; also the key in the session cache.
session_id_hash: String,
/// Name the user supplied when logging in.
username: String,
/// Random per-session token that CSRF checks compare submitted tokens against.
csrf_token: String,
}
/// Request guard that resolves the caller's [`Session`].
///
/// The raw session id is taken from the `Authorization` header when it starts
/// with `Bearer `, otherwise from the private session cookie. Its hash is then
/// looked up in the managed [`SessionManager`]. The lookup result is stored in
/// the request-local cache, so it is computed at most once per request.
#[async_trait::async_trait]
impl<'r> FromRequest<'r> for Session {
    type Error = std::convert::Infallible;

    /// Succeeds with a clone of the cached session, or forwards when the
    /// request carries no session id, the `SessionManager` state is not
    /// available, or the id is unknown.
    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
        let cached = request
            .local_cache_async(async {
                // The bearer token wins; the cookie is only consulted when the
                // header is missing or does not use the `Bearer ` scheme.
                let from_header = request
                    .headers()
                    .get_one(SESSION_HEADER_NAME)
                    .and_then(|raw| raw.strip_prefix("Bearer "))
                    .map(hash);
                let session_id_hash = match from_header {
                    Some(hashed) => hashed,
                    None => {
                        let cookie = request.cookies().get_private(SESSION_COOKIE_NAME)?;
                        hash(cookie.value())
                    }
                };

                let manager = request
                    .guard::<&State<SessionManager>>()
                    .await
                    .succeeded()?;
                manager.fetch_session(&session_id_hash)
            })
            .await;

        // NOTE(review): a missing session forwards with 500 Internal Server Error;
        // confirm this status is intended rather than e.g. 401 Unauthorized.
        cached.clone().or_forward(Status::InternalServerError)
    }
}
/// Exposes the session's CSRF token as the value submitted tokens are
/// verified against, with `CsrfCheckProof` as the associated proof type.
impl VerifierWithKnownExpectedToken for Session {
    type Proof = CsrfCheckProof;

    /// Returns the CSRF token stored in this session.
    fn expected_token(&self) -> &str {
        self.csrf_token.as_str()
    }
}
/// In-memory store of active sessions, registered as Rocket managed state.
#[derive(Debug)]
pub struct SessionManager {
/// Active sessions keyed by the hex-encoded hash of their session id.
sessions: Cache<String, Session>,
}
impl SessionManager {
    /// Creates an empty manager whose cache is configured with both an
    /// initial and a maximum capacity of 16.
    fn new() -> Self {
        Self {
            sessions: Cache::builder()
                .max_capacity(16)
                .initial_capacity(16)
                .build(),
        }
    }

    /// Starts a new session for `username`.
    ///
    /// Generates a random session id and CSRF token, adds the raw session id
    /// to `cookies` as a private cookie (max age one day, `SameSite=Strict`,
    /// `Secure`), and stores the session under the hash of its id. Returns the
    /// stored session.
    fn create_session(&self, username: String, cookies: &CookieJar<'_>) -> Session {
        let session_id = random_id(16);
        let csrf_token = random_id(16);
        let key = hash(&session_id);

        let session = Session {
            session_id_hash: key.clone(),
            username,
            csrf_token,
        };

        // Only the client ever holds the raw id; the cache is keyed by its hash.
        cookies.add_private(
            Cookie::build((SESSION_COOKIE_NAME, session_id))
                .max_age(rocket::time::Duration::days(1))
                .same_site(SameSite::Strict)
                .secure(true),
        );

        self.sessions.insert(key, session.clone());
        session
    }

    /// Looks up the session stored under `session_id_hash`, if there is one.
    fn fetch_session(&self, session_id_hash: &String) -> Option<Session> {
        self.sessions.get(session_id_hash)
    }

    /// Ends the session stored under `session_id_hash`.
    ///
    /// Removes the session cookie when the request carried one and drops the
    /// session from the cache. The `_proof` argument is unused; requiring it
    /// means callers must have obtained a `CsrfCheckProof` first.
    fn logout(&self, _proof: CsrfCheckProof, session_id_hash: &String, cookies: &CookieJar<'_>) {
        match cookies.get(SESSION_COOKIE_NAME).cloned() {
            Some(cookie) => cookies.remove_private(cookie),
            None => {}
        }
        self.sessions.invalidate(session_id_hash);
    }
}
/// `CheckCsrfProtectionHeader` specialised to verify against a [`Session`];
/// used as a request guard by `check_csrf_header`.
type VerifyCsrfTokenViaHeaders = CheckCsrfProtectionHeader<Session>;
/// `CsrfProtectedForm` specialised to verify the form `F` against a [`Session`];
/// used as the data guard of `do_logout`.
type SessionCsrfProtectedForm<F> = CsrfProtectedForm<Session, F>;
/// `GET /header`: responds with a fixed confirmation message.
///
/// The handler body only runs once the `VerifyCsrfTokenViaHeaders` guard has
/// succeeded; the guard value itself is not used.
#[get("/header")]
fn check_csrf_header(_csrf_check: VerifyCsrfTokenViaHeaders) -> String {
    String::from("You successfully passed the right CSRF token, congrats!")
}
/// Body of the login form handled by `do_login`.
// NOTE(review): `'r` is not used by any field visible here; presumably
// `#[with_csrf_token]` adds a borrowed token field. Confirm against the macro.
#[with_csrf_token]
#[derive(Debug, FromForm)]
struct LoginForm<'r> {
/// User name to log in as; becomes the session's `username`.
name: String,
}
/// Body of the logout form handled by `do_logout`; it declares no fields of its own.
// NOTE(review): `'r` is not used by any field visible here; presumably
// `#[with_csrf_token]` adds a borrowed token field. Confirm against the macro.
#[with_csrf_token]
#[derive(Debug, FromForm)]
struct LogoutForm<'r> {}
/// `GET /` (rank 2): renders the `login` template.
///
/// This is the lower-priority `GET /` route, so it is reached when the rank 1
/// route forwards. The `csrf_token` guard value is passed to the template
/// under the key `csrf_token`.
#[get("/", rank = 2)]
fn show_login_page(csrf_token: SetDoubleSubmitCookieCsrfToken) -> Template {
    let page_context = context! { csrf_token: csrf_token };
    Template::render("login", page_context)
}
/// `POST /`: logs the user in and redirects to the logged-in page.
///
/// The form is accepted through the double-submit-cookie CSRF wrapper. A new
/// session is created for the submitted `name`, which also sets the session
/// cookie on `cookies`.
#[post("/", data = "<form>")]
fn do_login(
    form: DoubleSubmitCookieCsrfProtectedForm<Form<LoginForm>>,
    manager: &State<SessionManager>,
    cookies: &CookieJar<'_>,
) -> Redirect {
    let username = form.name.clone();
    // The returned session is not needed here; the cookie is what matters.
    manager.create_session(username, cookies);
    Redirect::to(uri!(show_loggedin_page))
}
/// `GET /` (rank 1): renders the `loggedin` template for the current session.
///
/// The template receives the session's CSRF token as `csrf_token`, the
/// username as `name`, and the raw session id from the session cookie as
/// `session_id`.
#[get("/", rank = 1)]
fn show_loggedin_page(cookies: &CookieJar<'_>, session: Session) -> Template {
    // The `Session` guard also accepts a bearer token from the `Authorization`
    // header, so a valid session does not guarantee that the session cookie is
    // present. Render an empty id in that case instead of panicking.
    let session_id = cookies
        .get_private(SESSION_COOKIE_NAME)
        .map(|cookie| cookie.value().to_owned())
        .unwrap_or_default();
    Template::render(
        "loggedin",
        context! {
            csrf_token: session.csrf_token,
            name: session.username,
            session_id
        },
    )
}
#[post("/logout", data = "<form>")]
fn do_logout(
session: Session,
form: SessionCsrfProtectedForm<Form<LogoutForm>>,
manager: &State<SessionManager>,
cookies: &CookieJar<'_>,
) -> Redirect {
let (proof, _) = form.into_parts();
manager.logout(proof, &session.session_id_hash, cookies);
Redirect::to(uri!(show_login_page))
}
/// Assembles the demo application.
///
/// Mounts all routes at `/`, registers a fresh [`SessionManager`] as managed
/// state, and attaches the template fairing.
pub fn build_rocket() -> Rocket<Build> {
    let handlers = routes![
        check_csrf_header,
        show_login_page,
        show_loggedin_page,
        do_login,
        do_logout
    ];
    let sessions = SessionManager::new();

    rocket::build()
        .mount("/", handlers)
        .manage(sessions)
        .attach(Template::fairing())
}