use crate::filter::{Filter, Next};
use crate::{Request, Response, Result};
use crate::state::State;
use async_trait::async_trait;
use cookie::Cookie;
use headers::{Header, SetCookie};
use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::sync::Mutex;
use tokio::sync::Mutex as AsyncMutex;
use tracing::debug;
use uuid::Uuid;
#[async_trait]
pub trait SessionStore {
async fn get(&self, id: &str) -> Result<Option<String>>;
async fn set(&mut self, id: String, value: String) -> Result<()>;
async fn clear(&mut self, id: &str) -> Result<()>;
}
#[derive(Default)]
pub struct MemorySessionStore {
data: HashMap<String, String>,
}
impl MemorySessionStore {
pub fn new() -> Self {
Self::default()
}
}
#[async_trait]
impl SessionStore for MemorySessionStore {
async fn get(&self, id: &str) -> Result<Option<String>> {
debug!(id, "memory store get");
Ok(self.data.get(id).cloned())
}
async fn set(&mut self, id: String, value: String) -> Result<()> {
debug!(%id, %value, "memory store set");
self.data.insert(id, value);
Ok(())
}
async fn clear(&mut self, id: &str) -> Result<()> {
debug!(id, "memory store clear");
self.data.remove(id);
Ok(())
}
}
pub const DEFAULT_COOKIE_NAME: &str = "sid";
type DynCookieCallback = dyn Fn(&mut Cookie) + Send + Sync + 'static;
pub struct SessionFilter {
cookie_name: Cow<'static, str>,
expiry: time::Duration,
cookie_callback: Option<Box<DynCookieCallback>>,
store: AsyncMutex<Box<dyn SessionStore + Send + Sync + 'static>>,
}
impl SessionFilter {
pub fn new(store: impl SessionStore + Send + Sync + 'static) -> SessionFilter {
SessionFilter {
cookie_name: Cow::Borrowed(DEFAULT_COOKIE_NAME),
expiry: time::Duration::hours(1),
cookie_callback: None,
store: AsyncMutex::new(Box::new(store)),
}
}
pub fn with_cookie_name(mut self, name: impl Into<Cow<'static, str>>) -> Self {
self.cookie_name = name.into();
self
}
pub fn with_expiry(mut self, expiry: time::Duration) -> Self {
self.expiry = expiry;
self
}
pub fn with_callback<F>(mut self, callback: F) -> Self
where
F: Fn(&mut Cookie) + Send + Sync + 'static,
{
self.cookie_callback = Some(Box::new(callback));
self
}
}
#[derive(Default)]
struct SessionInner {
modified: AtomicBool,
data: Mutex<HashMap<String, String>>,
}
#[derive(Default)]
pub struct Session {
inner: Arc<SessionInner>,
}
impl SessionInner {
fn get(&self, key: &str) -> Option<String> {
debug!(key, "session get");
let data = self.data.lock().unwrap();
data.get(key).cloned()
}
fn set(&self, key: String, value: String) {
debug!(%key, %value, "session set");
self.data.lock().unwrap().insert(key, value);
self.modified.store(true, Ordering::Relaxed);
}
fn is_modified(&self) -> bool {
self.modified.load(Ordering::Relaxed)
}
fn load(&self, data: HashMap<String, String>) {
*self.data.lock().unwrap() = data;
self.modified.store(false, Ordering::Relaxed);
}
}
impl Session {
pub fn get(&self, key: &str) -> Option<String> {
self.inner.get(key)
}
pub fn set(&self, key: String, value: String) {
self.inner.set(key, value)
}
pub fn is_modified(&self) -> bool {
self.inner.is_modified()
}
}
pub trait HasSession {
fn session(&mut self) -> &mut Session;
}
impl<S> HasSession for Request<S>
where
S: State,
S::Context: HasSession,
{
fn session(&mut self) -> &mut Session {
self.context_mut().session()
}
}
#[async_trait]
impl<S> Filter<S> for SessionFilter
where
S: State,
S::Context: HasSession,
{
async fn apply(&self, mut req: Request<S>, next: Next<'_, S>) -> Result<Response> {
let session = Arc::clone(&req.session().inner);
let maybe_sid = req
.cookies()?
.get(self.cookie_name.as_ref())
.map(|c| c.value().to_owned());
let sid = if let Some(sid) = maybe_sid {
debug!(%sid, "request has session cookie");
let store = self.store.lock().await;
let raw_data = store.get(&sid).await?.unwrap_or_default();
let data = serde_urlencoded::from_str(&raw_data)?;
session.load(data);
sid
} else {
debug!("request has no session cookie");
Uuid::new_v4().to_string()
};
let mut resp = next.next(req).await?;
if session.is_modified() {
debug!("session was modified");
let mut store = self.store.lock().await;
let raw_data = {
let data = session.data.lock().unwrap();
serde_urlencoded::to_string(&*data)?
};
let mut cookie = Cookie::new(self.cookie_name.as_ref(), &sid);
cookie.set_http_only(true);
cookie.set_secure(true);
cookie.set_same_site(cookie::SameSite::Strict);
let expiry = time::OffsetDateTime::now_utc() + self.expiry;
cookie.set_expires(expiry);
if let Some(ref callback) = self.cookie_callback {
callback(&mut cookie);
}
resp.set_raw_header(SetCookie::name(), cookie.to_string())?;
store.set(sid, raw_data).await?;
}
Ok(resp)
}
}