use super::{Session, SessionStore};
use crate::http::{
cookies::{Cookie, Key, SameSite},
format_err,
};
use crate::{utils::async_trait, Middleware, Next, Request};
use std::time::Duration;
use async_session::{
base64,
hmac::{Hmac, Mac, NewMac},
sha2::Sha256,
};
/// Length of a base64-encoded HMAC-SHA256 tag as prepended to signed cookie
/// values: 32 digest bytes encode to `ceil(32 / 3) * 4 = 44` characters
/// (including `=` padding).
const BASE64_DIGEST_LEN: usize = (32 + 2) / 3 * 4;
/// Middleware that loads a [`Session`] for every request, exposes it to
/// downstream handlers, and persists or destroys it in a [`SessionStore`]
/// once the response is ready.
///
/// The session id travels in an HMAC-signed cookie; `SessionMiddleware::new`
/// lists the defaults applied to each setting.
pub struct SessionMiddleware<Store> {
    /// Backing store used to load, save and destroy sessions.
    store: Store,
    /// Key whose signing half authenticates the session cookie.
    key: Key,
    /// Name of the cookie carrying the signed session id.
    cookie_name: String,
    /// `Path` attribute applied to the session cookie.
    cookie_path: String,
    /// Optional `Domain` attribute applied to the session cookie.
    cookie_domain: Option<String>,
    /// `SameSite` attribute applied to the session cookie.
    same_site_policy: SameSite,
    /// Lifetime granted to the session and its cookie on each request;
    /// `None` means this middleware sets no expiry at all.
    session_ttl: Option<Duration>,
    /// When `true`, the session is saved even if its data did not change.
    save_unchanged: bool,
}
impl<Store: SessionStore> std::fmt::Debug for SessionMiddleware<Store> {
    /// Formats the middleware configuration for diagnostics.
    ///
    /// The signing key is always rendered as `".."` so secret material never
    /// ends up in logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            store,
            cookie_path,
            cookie_name,
            cookie_domain,
            session_ttl,
            save_unchanged,
            same_site_policy,
            key: _,
        } = self;

        let mut builder = f.debug_struct("SessionMiddleware");
        builder.field("store", store);
        builder.field("cookie_path", cookie_path);
        builder.field("cookie_name", cookie_name);
        builder.field("cookie_domain", cookie_domain);
        builder.field("session_ttl", session_ttl);
        builder.field("same_site_policy", same_site_policy);
        // Mask the key rather than printing it.
        builder.field("key", &"..");
        builder.field("save_unchanged", save_unchanged);
        builder.finish()
    }
}
#[async_trait]
impl<Store, State> Middleware<State> for SessionMiddleware<Store>
where
    Store: SessionStore,
    State: Clone + Send + Sync + 'static,
{
    /// Loads (or creates) the request's session, exposes it to downstream
    /// handlers as a request extension, and after the response is produced
    /// either destroys it (clearing the cookie) or persists it (setting a
    /// fresh signed cookie).
    ///
    /// # Errors
    ///
    /// Returns an error when the store fails to persist the session. Failures
    /// to destroy a session are only logged.
    async fn handle(&self, mut request: Request<State>, next: Next<'_, State>) -> crate::Result {
        let cookie = request.cookie(&self.cookie_name);
        // A cookie whose signature fails to verify is treated as if it were
        // absent, so an unverified id is never passed to the store.
        let cookie_value = cookie
            .as_ref()
            .and_then(|cookie| self.verify_signature(cookie.value()).ok());

        let mut session = self.load_or_create(cookie_value).await;
        if let Some(ttl) = self.session_ttl {
            session.expire_in(ttl);
        }

        let secure_cookie = request.url().scheme() == "https";
        // NOTE(review): the checks after `next.run` only observe handler
        // changes if `Session` clones share their data/flags — async-session's
        // `Session` is believed to do so; confirm on upgrades.
        request.set_ext(session.clone());
        let mut response = next.run(request).await;

        if session.is_destroyed() {
            if let Err(e) = self.store.destroy_session(session).await {
                crate::log::error!("unable to destroy session", { error: e.to_string() });
            }

            if let Some(mut cookie) = cookie {
                // A removal cookie only clears the browser's copy when its path
                // and domain match the ones the session cookie was issued with.
                cookie.set_path(self.cookie_path.clone());
                if let Some(domain) = &self.cookie_domain {
                    cookie.set_domain(domain.clone());
                }
                response.remove_cookie(cookie);
            }
        } else if self.save_unchanged || session.data_changed() {
            if let Some(cookie_value) = self
                .store
                .store_session(session)
                .await
                .map_err(|e| format_err!("{}", e))?
            {
                let cookie = self.build_cookie(secure_cookie, cookie_value);
                response.insert_cookie(cookie);
            }
        }

        Ok(response)
    }
}
impl<Store: SessionStore> SessionMiddleware<Store> {
pub fn new(store: Store, secret: &[u8]) -> Self {
Self {
store,
save_unchanged: true,
cookie_path: "/".into(),
cookie_name: "tide.sid".into(),
cookie_domain: None,
same_site_policy: SameSite::Strict,
session_ttl: Some(Duration::from_secs(24 * 60 * 60)),
key: Key::derive_from(secret),
}
}
pub fn with_cookie_path(mut self, cookie_path: impl AsRef<str>) -> Self {
self.cookie_path = cookie_path.as_ref().to_owned();
self
}
pub fn with_session_ttl(mut self, session_ttl: Option<Duration>) -> Self {
self.session_ttl = session_ttl;
self
}
pub fn with_cookie_name(mut self, cookie_name: impl AsRef<str>) -> Self {
self.cookie_name = cookie_name.as_ref().to_owned();
self
}
pub fn without_save_unchanged(mut self) -> Self {
self.save_unchanged = false;
self
}
pub fn with_same_site_policy(mut self, policy: SameSite) -> Self {
self.same_site_policy = policy;
self
}
pub fn with_cookie_domain(mut self, cookie_domain: impl AsRef<str>) -> Self {
self.cookie_domain = Some(cookie_domain.as_ref().to_owned());
self
}
async fn load_or_create(&self, cookie_value: Option<String>) -> Session {
let session = match cookie_value {
Some(cookie_value) => self.store.load_session(cookie_value).await.ok().flatten(),
None => None,
};
session
.and_then(|session| session.validate())
.unwrap_or_default()
}
fn build_cookie(&self, secure: bool, cookie_value: String) -> Cookie<'static> {
let mut cookie = Cookie::build(self.cookie_name.clone(), cookie_value)
.http_only(true)
.same_site(self.same_site_policy)
.secure(secure)
.path(self.cookie_path.clone())
.finish();
if let Some(ttl) = self.session_ttl {
cookie.set_expires(Some((std::time::SystemTime::now() + ttl).into()));
}
if let Some(cookie_domain) = self.cookie_domain.clone() {
cookie.set_domain(cookie_domain)
}
self.sign_cookie(&mut cookie);
cookie
}
fn sign_cookie(&self, cookie: &mut Cookie<'_>) {
let mut mac = Hmac::<Sha256>::new_varkey(&self.key.signing()).expect("good key");
mac.update(cookie.value().as_bytes());
let mut new_value = base64::encode(&mac.finalize().into_bytes());
new_value.push_str(cookie.value());
cookie.set_value(new_value);
}
fn verify_signature(&self, cookie_value: &str) -> Result<String, &'static str> {
if cookie_value.len() < BASE64_DIGEST_LEN {
return Err("length of value is <= BASE64_DIGEST_LEN");
}
let (digest_str, value) = cookie_value.split_at(BASE64_DIGEST_LEN);
let digest = base64::decode(digest_str).map_err(|_| "bad base64 digest")?;
let mut mac = Hmac::<Sha256>::new_varkey(&self.key.signing()).expect("good key");
mac.update(value.as_bytes());
mac.verify(&digest)
.map(|_| value.to_string())
.map_err(|_| "value did not verify")
}
}