use crate::{
CookieExt, Handler, MiddleWareHandler, Next, Request, Response, SilentError, StatusCode,
};
use async_lock::RwLock;
use async_session::{MemoryStore, Session, SessionStore};
use async_trait::async_trait;
use cookie::{Cookie, CookieJar};
use std::sync::Arc;
/// Middleware that resolves (or creates) an `async_session` [`Session`] for each request.
///
/// The session id travels in the `silent-web-session` cookie; the session data itself lives in
/// the store `T`.
pub struct SessionMiddleware<T: SessionStore> {
    /// The backing session store, shared behind an `Arc` and an async `RwLock`.
    pub session_store: Arc<RwLock<T>>,
}
impl Default for SessionMiddleware<MemoryStore> {
    /// Builds a middleware backed by a brand-new, empty in-memory [`MemoryStore`].
    fn default() -> Self {
        Self::new(MemoryStore::new())
    }
}
impl<T> SessionMiddleware<T>
where
    T: SessionStore,
{
    /// Takes ownership of `session` (the store) and wraps it in `Arc<RwLock<_>>` so the
    /// middleware can share it across requests.
    pub fn new(session: T) -> Self {
        Self {
            session_store: Arc::new(RwLock::new(session)),
        }
    }
}
#[async_trait]
impl<T> MiddleWareHandler for SessionMiddleware<T>
where
    T: SessionStore,
{
    /// Attaches a [`Session`] to the request, runs the rest of the chain, then decorates the
    /// response with the session and the merged cookie jar.
    ///
    /// Session resolution:
    /// - If the request jar holds a `silent-web-session` cookie whose value the store can load,
    ///   that session is reused and no new cookie is issued.
    /// - Otherwise (cookie missing, load error, or unknown id) a fresh [`Session`] is stored and
    ///   a new `silent-web-session` cookie (`Path=/`, `HttpOnly`, `Secure`, `Max-Age` 2h) is
    ///   added to the jar.
    ///
    /// After the downstream chain returns, the session is inserted into the response extensions
    /// unless the handler already placed one there. The request jar (plus any newly issued
    /// session cookie) is then merged with the cookies of the response's own [`CookieJar`], if
    /// present; on a name clash the response's cookie replaces ours.
    ///
    /// # Errors
    ///
    /// Propagates store errors raised while creating the fresh session, returns an
    /// `INTERNAL_SERVER_ERROR` business error if the store yields no cookie value for it or
    /// cannot load it back, and propagates any error from the downstream chain.
    async fn handle(&self, mut req: Request, next: &Next) -> crate::Result<Response> {
        const SESSION_COOKIE: &str = "silent-web-session";
        let internal_error = |msg: &'static str| {
            SilentError::business_error(StatusCode::INTERNAL_SERVER_ERROR, msg)
        };

        let mut cookies = req.cookies().clone();
        let incoming_value = cookies
            .get(SESSION_COOKIE)
            .map(|c| c.value().to_string());

        // Resolve the session inside its own scope so the store's read guard is dropped before
        // `next.call`. Holding it for the whole downstream chain would block any writer on the
        // (public) `session_store` lock, and because async_lock's RwLock is write-preferring, a
        // waiting writer would in turn stall every subsequent request's read.
        let (session, issued_value) = {
            let store = self.session_store.read().await;
            let existing = match incoming_value {
                // A load error is treated like an unknown id: fall through to a fresh session.
                Some(value) => store.load_session(value).await.ok().flatten(),
                None => None,
            };
            match existing {
                Some(session) => (session, None),
                None => {
                    let value = store
                        .store_session(Session::new())
                        .await?
                        .ok_or_else(|| internal_error("failed to create session"))?;
                    let session = store
                        .load_session(value.clone())
                        .await?
                        .ok_or_else(|| internal_error("failed to load session"))?;
                    (session, Some(value))
                }
            }
        };

        if let Some(value) = issued_value {
            cookies.add(
                Cookie::build((SESSION_COOKIE, value))
                    // Scope to the whole site; without an explicit Path the browser defaults to
                    // the directory of the request that happened to create the session.
                    .path("/")
                    // Keep the session id out of reach of page scripts.
                    .http_only(true)
                    .max_age(cookie::time::Duration::hours(2))
                    .secure(true),
            );
        }

        req.extensions_mut().insert(session.clone());
        let mut res = next.call(req).await?;

        if res.extensions().get::<Session>().is_none() {
            res.extensions_mut().insert(session);
        }

        // Fold the handler's cookies over ours so both sets reach the response.
        if let Some(handler_jar) = res.extensions().get::<CookieJar>() {
            for cookie in handler_jar.iter() {
                cookies.add(cookie.clone());
            }
        }
        res.extensions_mut().insert(cookies);
        Ok(res)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handler::HandlerWrapper;
    use crate::session::session_ext::SessionExt;
    use async_session::MemoryStore;
    use cookie::{Cookie, CookieJar};
    use std::sync::Arc;

    /// Builds a secure, owned cookie used to seed request jars in these tests.
    fn test_cookie(name: &str, value: impl AsRef<str>) -> Cookie<'static> {
        Cookie::build((name.to_owned(), value.as_ref().to_owned()))
            .secure(true)
            .build()
    }

    /// Terminal handler that returns an empty response without touching the session.
    async fn test_handler(_req: Request) -> crate::Result<Response> {
        Ok(Response::empty())
    }

    #[test]
    fn test_session_middleware_default() {
        let middleware = SessionMiddleware::<MemoryStore>::default();
        // A freshly built middleware is the sole owner of its store.
        assert_eq!(Arc::strong_count(&middleware.session_store), 1);
    }

    #[test]
    fn test_session_middleware_new() {
        let store = MemoryStore::new();
        let middleware = SessionMiddleware::new(store);
        assert_eq!(Arc::strong_count(&middleware.session_store), 1);
    }

    #[tokio::test]
    async fn test_middleware_with_no_session_cookie() {
        let middleware = SessionMiddleware::default();
        let mut req = Request::empty();
        req.extensions_mut().insert(CookieJar::new());
        let handler = HandlerWrapper::new(test_handler).arc();
        let middlewares: Vec<Arc<dyn MiddleWareHandler>> = vec![];
        let next = Next::build(handler, &middlewares);
        let result = middleware.handle(req, &next).await;
        assert!(result.is_ok());
        let res = result.unwrap();
        assert!(res.extensions().get::<Session>().is_some());
    }

    #[tokio::test]
    async fn test_middleware_with_valid_session_cookie() {
        let middleware = SessionMiddleware::default();
        let store = middleware.session_store.read().await;
        let session = Session::new();
        let session_id = session.id().to_string();
        let cookie_value = store.store_session(session).await.unwrap().unwrap();
        drop(store);
        let mut jar = CookieJar::new();
        jar.add(test_cookie("silent-web-session", &cookie_value));
        let mut req = Request::empty();
        req.extensions_mut().insert(jar);
        let handler = HandlerWrapper::new(test_handler).arc();
        let middlewares: Vec<Arc<dyn MiddleWareHandler>> = vec![];
        let next = Next::build(handler, &middlewares);
        let result = middleware.handle(req, &next).await;
        assert!(result.is_ok());
        let res = result.unwrap();
        // The stored session must be reused, not replaced.
        let res_session = res
            .extensions()
            .get::<Session>()
            .expect("response should carry a session");
        assert_eq!(res_session.id(), session_id.as_str());
        // And the client's existing cookie must be left untouched.
        let res_jar = res
            .extensions()
            .get::<CookieJar>()
            .expect("response should carry a cookie jar");
        assert_eq!(
            res_jar.get("silent-web-session").map(|c| c.value()),
            Some(cookie_value.as_str())
        );
    }

    #[tokio::test]
    async fn test_middleware_creates_new_session_if_cookie_invalid() {
        let middleware = SessionMiddleware::default();
        let mut jar = CookieJar::new();
        jar.add(test_cookie("silent-web-session", "invalid_cookie_value"));
        let mut req = Request::empty();
        req.extensions_mut().insert(jar);
        let handler = HandlerWrapper::new(test_handler).arc();
        let middlewares: Vec<Arc<dyn MiddleWareHandler>> = vec![];
        let next = Next::build(handler, &middlewares);
        let result = middleware.handle(req, &next).await;
        assert!(result.is_ok());
        let res = result.unwrap();
        assert!(res.extensions().get::<Session>().is_some());
        // A replacement session cookie must have been issued.
        let res_jar = res
            .extensions()
            .get::<CookieJar>()
            .expect("response should carry a cookie jar");
        let issued = res_jar
            .get("silent-web-session")
            .expect("a session cookie should be issued");
        assert_ne!(issued.value(), "invalid_cookie_value");
    }

    #[tokio::test]
    async fn test_middleware_session_inserted_to_request() {
        let middleware = SessionMiddleware::default();
        let mut req = Request::empty();
        req.extensions_mut().insert(CookieJar::new());
        // Panics (failing the test) if no session was attached before the handler ran.
        async fn session_probe(mut req: Request) -> crate::Result<Response> {
            assert!(req.extensions_mut().get::<Session>().is_some());
            Ok(Response::empty())
        }
        let handler = HandlerWrapper::new(session_probe).arc();
        let middlewares: Vec<Arc<dyn MiddleWareHandler>> = vec![];
        let next = Next::build(handler, &middlewares);
        let result = middleware.handle(req, &next).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_middleware_preserves_response_session() {
        let middleware = SessionMiddleware::default();
        let mut req = Request::empty();
        req.extensions_mut().insert(CookieJar::new());
        async fn session_handler(mut req: Request) -> crate::Result<Response> {
            let session = req.sessions_mut();
            session.insert("test_key", "test_value").unwrap();
            Ok(Response::empty())
        }
        let handler = HandlerWrapper::new(session_handler).arc();
        let middlewares: Vec<Arc<dyn MiddleWareHandler>> = vec![];
        let next = Next::build(handler, &middlewares);
        let result = middleware.handle(req, &next).await;
        assert!(result.is_ok());
        let res = result.unwrap();
        assert!(res.extensions().get::<Session>().is_some());
    }

    #[tokio::test]
    async fn test_middleware_with_existing_cookie_jar() {
        let middleware = SessionMiddleware::default();
        let mut req = Request::empty();
        req.extensions_mut().insert(CookieJar::new());
        async fn cookie_handler(_req: Request) -> crate::Result<Response> {
            let mut res = Response::empty();
            let mut jar = CookieJar::new();
            jar.add(test_cookie("test_cookie", "test_value"));
            res.extensions_mut().insert(jar);
            Ok(res)
        }
        let handler = HandlerWrapper::new(cookie_handler).arc();
        let middlewares: Vec<Arc<dyn MiddleWareHandler>> = vec![];
        let next = Next::build(handler, &middlewares);
        let result = middleware.handle(req, &next).await;
        assert!(result.is_ok());
        let res = result.unwrap();
        // The handler's cookie and the middleware's session cookie must both survive the merge.
        let res_jar = res
            .extensions()
            .get::<CookieJar>()
            .expect("response should carry a cookie jar");
        assert_eq!(
            res_jar.get("test_cookie").map(|c| c.value()),
            Some("test_value")
        );
        assert!(res_jar.get("silent-web-session").is_some());
    }

    #[tokio::test]
    async fn test_middleware_adds_cookie_when_session_key_not_exists() {
        let middleware = SessionMiddleware::default();
        let jar = CookieJar::new();
        let mut req = Request::empty();
        req.extensions_mut().insert(jar);
        let handler = HandlerWrapper::new(test_handler).arc();
        let middlewares: Vec<Arc<dyn MiddleWareHandler>> = vec![];
        let next = Next::build(handler, &middlewares);
        let result = middleware.handle(req, &next).await;
        assert!(result.is_ok());
        let res = result.unwrap();
        // The jar must be present; a missing jar is a failure, not a silent pass.
        let cookie_jar = res
            .extensions()
            .get::<CookieJar>()
            .expect("response should carry a cookie jar");
        assert!(cookie_jar.get("silent-web-session").is_some());
    }
}