use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt::{self, Display, Formatter};
use std::marker::PhantomData;
use std::ops::Add;
use std::time::{Duration, Instant};
use parking_lot::{Mutex, RwLock, RwLockUpgradableReadGuard};
use rand::{rngs::OsRng, Rng, TryRngCore};
use rocket::{
fairing::{self, Fairing, Info},
http::{Cookie, Status},
outcome::Outcome,
request::FromRequest,
Build, Request, Response, Rocket, State,
};
/// Server-wide session storage.
///
/// An instance is registered as managed state by the session fairing's `on_ignite`
/// and retrieved by the `Session` request guard.
#[derive(Debug)]
pub struct SessionStore<D>
where
    D: 'static + Sync + Send + Default,
{
    /// Session table and expiry-sweep bookkeeping, guarded by a reader-writer lock.
    inner: RwLock<StoreInner<D>>,
    /// Settings cloned from the fairing when the store is created.
    config: SessionConfig,
}
/// Tunable settings shared by the session fairing (builder side) and the session
/// store (runtime side).
#[derive(Debug, Clone)]
struct SessionConfig {
    /// How long a session stays valid after its most recent use.
    lifespan: Duration,
    /// Name of the cookie that carries the session ID.
    cookie_name: Cow<'static, str>,
    /// Cookie path, configurable through `SessionFairing::with_cookie_path`.
    cookie_path: Cow<'static, str>,
    /// Number of alphanumeric characters in a freshly generated session ID.
    cookie_len: usize,
}

impl Default for SessionConfig {
    /// One-hour lifespan, cookie named `rocket_session` on path `/`, 16-character IDs.
    fn default() -> Self {
        const ONE_HOUR: Duration = Duration::from_secs(60 * 60);
        SessionConfig {
            cookie_name: Cow::Borrowed("rocket_session"),
            cookie_path: Cow::Borrowed("/"),
            cookie_len: 16,
            lifespan: ONE_HOUR,
        }
    }
}
/// Mutable state of a `SessionStore`, kept behind the store's reader-writer lock.
#[derive(Debug)]
struct StoreInner<D>
where
    D: 'static + Sync + Send + Default,
{
    /// Sessions keyed by their ID string; expired entries linger until the next sweep.
    sessions: HashMap<String, Mutex<SessionInstance<D>>>,
    /// When expired sessions were last purged from `sessions`.
    last_expiry_sweep: Instant,
}

impl<D> Default for StoreInner<D>
where
    D: 'static + Sync + Send + Default,
{
    /// An empty session table whose sweep clock starts now.
    fn default() -> Self {
        let last_expiry_sweep = Instant::now();
        StoreInner {
            sessions: HashMap::new(),
            last_expiry_sweep,
        }
    }
}
/// A single stored session: the application payload plus its expiry deadline.
#[derive(Debug)]
struct SessionInstance<D>
where
    D: 'static + Sync + Send + Default,
{
    /// Application-defined payload; reset to `D::default()` when an expired session is
    /// revived or the session is cleared.
    data: D,
    /// Deadline after which the session counts as expired; pushed forward on each access
    /// through the request guard.
    expires: Instant,
}
/// Opaque session identifier: the value carried in the session cookie.
#[derive(Clone, Debug)]
struct SessionID(String);

impl SessionID {
    /// Borrow the identifier as a string slice.
    fn as_str(&self) -> &str {
        &self.0
    }
}

impl Display for SessionID {
    /// Writes the raw identifier with no quoting or escaping.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Request guard giving a handler access to the current client's session data.
///
/// Obtaining the guard looks up the session named by the session cookie, or creates a
/// new session when there is none.
#[derive(Debug)]
pub struct Session<'a, D>
where
    D: 'static + Sync + Send + Default,
{
    /// Managed store that holds every session.
    store: &'a State<SessionStore<D>>,
    /// This request's session ID, memoized in the request-local cache.
    id: &'a SessionID,
}
#[rocket::async_trait]
impl<'r, D> FromRequest<'r> for Session<'r, D>
where
D: 'static + Sync + Send + Default,
{
type Error = ();
async fn from_request(
request: &'r Request<'_>,
) -> Outcome<Self, (Status, Self::Error), Status> {
let store = request.guard::<&State<SessionStore<D>>>().await.unwrap();
Outcome::Success(Session {
id: request.local_cache(|| {
let store_ug = store.inner.upgradable_read();
let id = request
.cookies()
.get(&store.config.cookie_name)
.map(|cookie| SessionID(cookie.value().to_string()));
let expires = Instant::now().add(store.config.lifespan);
if let Some(m) = id
.as_ref()
.and_then(|token| store_ug.sessions.get(token.as_str()))
{
let mut inner = m.lock();
if inner.expires <= Instant::now() {
inner.data = D::default();
}
inner.expires = expires;
id.unwrap()
} else {
let mut store_wg = RwLockUpgradableReadGuard::upgrade(store_ug);
if store_wg.last_expiry_sweep.elapsed() > store.config.lifespan {
let now = Instant::now();
store_wg.sessions.retain(|_k, v| v.lock().expires > now);
store_wg.last_expiry_sweep = now;
}
let new_id = SessionID(loop {
let token: String = OsRng
.unwrap_err()
.sample_iter(&rand::distr::Alphanumeric)
.take(store.config.cookie_len)
.map(char::from)
.collect();
if !store_wg.sessions.contains_key(&token) {
break token;
}
});
store_wg.sessions.insert(
new_id.to_string(),
Mutex::new(SessionInstance {
data: Default::default(),
expires,
}),
);
new_id
}
}),
store,
})
}
}
impl<'a, D> Session<'a, D>
where
    D: 'static + Sync + Send + Default,
{
    /// Build the fairing that must be attached for this guard to work.
    pub fn fairing() -> SessionFairing<D> {
        SessionFairing::new()
    }

    /// Replace this session's data with `D::default()`.
    pub fn clear(&self) {
        self.tap(|data| *data = D::default());
    }

    /// Run `func` with exclusive access to this session's data and return its result.
    ///
    /// The store's read lock and the session's own mutex are both held while `func`
    /// runs.
    ///
    /// # Panics
    ///
    /// Panics if the session is no longer present in the store.
    pub fn tap<T>(&self, func: impl FnOnce(&mut D) -> T) -> T {
        let guard = self.store.inner.read();
        let entry = guard
            .sessions
            .get(self.id.as_str())
            .expect("Session data unexpectedly missing");
        let mut instance = entry.lock();
        func(&mut instance.data)
    }
}
/// Fairing that installs a `SessionStore` at ignite time and adds the session cookie
/// to responses.
///
/// Obtain it with `Session::fairing()` and customize it with the `with_*` methods.
#[derive(Default)]
pub struct SessionFairing<D>
where
    D: 'static + Sync + Send + Default,
{
    /// Settings cloned into the store at ignite and read when writing the response cookie.
    config: SessionConfig,
    /// Ties the fairing to the session data type `D` without storing a value of it.
    phantom: PhantomData<D>,
}
impl<D> SessionFairing<D>
where
    D: 'static + Sync + Send + Default,
{
    /// Fairing with the default configuration (`SessionConfig::default`).
    fn new() -> Self {
        Self::default()
    }

    /// Set how long a session stays valid after its most recent use.
    #[must_use]
    pub fn with_lifetime(mut self, time: Duration) -> Self {
        self.config.lifespan = time;
        self
    }

    /// Set the name of the session cookie.
    #[must_use]
    pub fn with_cookie_name(mut self, name: impl Into<Cow<'static, str>>) -> Self {
        self.config.cookie_name = name.into();
        self
    }

    /// Set the number of alphanumeric characters in generated session IDs.
    ///
    /// # Panics
    ///
    /// Panics if `length` is zero. A zero-length ID is always the empty string. The
    /// fairing never sends a cookie for an empty ID. Once one empty-ID session exists,
    /// generating a new unique ID can never succeed, so the request guard would loop
    /// forever while holding the store's write lock.
    #[must_use]
    pub fn with_cookie_len(mut self, length: usize) -> Self {
        assert!(length > 0, "session cookie length must be greater than zero");
        self.config.cookie_len = length;
        self
    }

    /// Set the cookie path recorded in the session configuration (`cookie_path`).
    #[must_use]
    pub fn with_cookie_path(mut self, path: impl Into<Cow<'static, str>>) -> Self {
        self.config.cookie_path = path.into();
        self
    }
}
#[rocket::async_trait]
impl<D> Fairing for SessionFairing<D>
where
    D: 'static + Sync + Send + Default,
{
    fn info(&self) -> Info {
        Info {
            name: "Session",
            kind: fairing::Kind::Ignite | fairing::Kind::Response,
        }
    }

    /// Register an empty `SessionStore` that carries a copy of this fairing's config.
    async fn on_ignite(&self, rocket: Rocket<Build>) -> Result<Rocket<Build>, Rocket<Build>> {
        let store = SessionStore::<D> {
            inner: Default::default(),
            config: self.config.clone(),
        };
        Ok(rocket.manage(store))
    }

    /// Attach the session cookie, using the configured name and path, when the request
    /// resolved a session.
    async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response) {
        // The `Session` guard memoizes the ID in the request-local cache. If the guard
        // never ran, this inserts an empty placeholder, which means "no session".
        let session = request.local_cache(|| SessionID(String::new()));
        if session.0.is_empty() {
            return;
        }
        // Honour `cookie_path` (set via `with_cookie_path`) instead of a hard-coded "/".
        let cookie = Cookie::build((self.config.cookie_name.clone(), session.to_string()))
            .path(self.config.cookie_path.clone())
            .build();
        response.adjoin_header(cookie);
    }
}