use crate::http::cookie::{Cookie, SameSite};
use crate::http::Response;
use crate::middleware::{Middleware, Next};
use crate::Request;
use async_trait::async_trait;
use rand::Rng;
use std::sync::Arc;
use tokio::sync::RwLock;
use super::config::SessionConfig;
use super::driver::DatabaseSessionDriver;
use super::store::{SessionData, SessionStore};
tokio::task_local! {
/// Per-request session slot.
///
/// [`SessionMiddleware`] installs a fresh `Arc` here with `scope` for the
/// duration of the downstream handler. The free functions in this module read
/// and mutate it through the lock. The inner `Option` becomes `None` once the
/// middleware takes the session back out to persist it.
pub(crate) static SESSION_CONTEXT: Arc<RwLock<Option<SessionData>>>;
}
/// Returns a snapshot (clone) of the current request's session.
///
/// Yields `None` in any of these cases:
/// - the caller is not running inside the middleware's `SESSION_CONTEXT`
///   scope (task-locals are not inherited by spawned tasks);
/// - the session slot is empty;
/// - the lock cannot be acquired immediately, e.g. when called from inside a
///   [`session_mut`] closure.
pub fn session() -> Option<SessionData> {
    let snapshot = SESSION_CONTEXT.try_with(|shared| match shared.try_read() {
        Ok(guard) => guard.as_ref().cloned(),
        Err(_) => None,
    });
    snapshot.unwrap_or(None)
}
/// Runs `f` with mutable access to the current request's session and returns
/// its result.
///
/// Returns `None` without calling `f` in any of these cases:
/// - there is no session context;
/// - the session slot is empty;
/// - the write lock cannot be acquired immediately.
///
/// `f` runs while the write lock is held, so [`session`] / `session_mut` calls
/// made from inside `f` will return `None`.
pub fn session_mut<F, R>(f: F) -> Option<R>
where
    F: FnOnce(&mut SessionData) -> R,
{
    SESSION_CONTEXT
        .try_with(|shared| {
            let mut guard = shared.try_write().ok()?;
            let data = guard.as_mut()?;
            Some(f(data))
        })
        .unwrap_or(None)
}
/// Moves the session out of `ctx`, leaving the slot empty.
///
/// Returns `None` if the slot was already empty or the write lock could not be
/// acquired immediately.
fn take_session_internal(ctx: &Arc<RwLock<Option<SessionData>>>) -> Option<SessionData> {
    match ctx.try_write() {
        Ok(mut slot) => slot.take(),
        Err(_) => None,
    }
}
/// Generates a random 40-character session identifier over `[a-z0-9]`.
///
/// Each character is drawn independently with `gen_range` from
/// `rand::thread_rng()`.
pub fn generate_session_id() -> String {
    const ALPHABET: &[u8] = b"abcdefghijklmnopqrstuvwxyz0123456789";
    const ID_LEN: usize = 40;

    let mut rng = rand::thread_rng();
    std::iter::repeat_with(|| ALPHABET[rng.gen_range(0..ALPHABET.len())] as char)
        .take(ID_LEN)
        .collect()
}
/// Generates a new CSRF token.
///
/// Delegates to [`generate_session_id`], so tokens have the same shape:
/// 40 characters drawn from `a-z0-9`.
pub fn generate_csrf_token() -> String {
generate_session_id()
}
/// Middleware that loads a session before the handler runs and persists it
/// afterwards. Handler code reaches it through [`session`] / [`session_mut`].
pub struct SessionMiddleware {
    /// Cookie and lifetime settings.
    config: SessionConfig,
    /// Backend used to read and write session records.
    store: Arc<dyn SessionStore>,
}

impl SessionMiddleware {
    /// Creates the middleware backed by a [`DatabaseSessionDriver`] built from
    /// the idle and absolute lifetimes in `config`.
    pub fn new(config: SessionConfig) -> Self {
        let store = Arc::new(DatabaseSessionDriver::new(
            config.idle_lifetime,
            config.absolute_lifetime,
        ));
        Self { config, store }
    }

    /// Creates the middleware with a caller-supplied session store.
    pub fn with_store(config: SessionConfig, store: Arc<dyn SessionStore>) -> Self {
        Self { config, store }
    }

    /// Builds the session cookie carrying `session_id`.
    ///
    /// `cookie_same_site` is trimmed and matched case-insensitively:
    /// - `"strict"` maps to `Strict`;
    /// - `"none"` maps to `None`;
    /// - anything else falls back to `Lax`.
    ///
    /// Browsers reject `SameSite=None` cookies that lack the `Secure`
    /// attribute, so `Secure` is forced on in that case regardless of
    /// `cookie_secure`. Max-age is the larger of the idle and absolute
    /// lifetimes.
    fn create_session_cookie(&self, session_id: &str) -> Cookie {
        let mode = self.config.cookie_same_site.trim();
        let (same_site, requires_secure) = if mode.eq_ignore_ascii_case("strict") {
            (SameSite::Strict, false)
        } else if mode.eq_ignore_ascii_case("none") {
            (SameSite::None, true)
        } else {
            (SameSite::Lax, false)
        };

        Cookie::new(&self.config.cookie_name, session_id)
            .http_only(self.config.cookie_http_only)
            .secure(self.config.cookie_secure || requires_secure)
            .path(&self.config.cookie_path)
            .max_age(std::cmp::max(
                self.config.idle_lifetime,
                self.config.absolute_lifetime,
            ))
            .same_site(same_site)
    }
}
#[async_trait]
impl Middleware for SessionMiddleware {
async fn handle(&self, request: Request, next: Next) -> Response {
let session_id = request
.cookie(&self.config.cookie_name)
.unwrap_or_else(generate_session_id);
let mut session = match self.store.read(&session_id).await {
Ok(Some(s)) => s,
Ok(None) => {
SessionData::new(session_id.clone(), generate_csrf_token())
}
Err(e) => {
eprintln!("Session read error: {e}");
SessionData::new(session_id.clone(), generate_csrf_token())
}
};
session.age_flash_data();
let ctx = Arc::new(RwLock::new(Some(session)));
let response = SESSION_CONTEXT
.scope(ctx.clone(), async { next(request).await })
.await;
let session = take_session_internal(&ctx);
if let Some(session) = session {
if let Err(e) = self.store.write(&session).await {
eprintln!("Session write error: {e}");
}
let cookie = self.create_session_cookie(&session.id);
match response {
Ok(res) => Ok(res.cookie(cookie)),
Err(res) => Err(res.cookie(cookie)),
}
} else {
response
}
}
}
/// Gives the current session a brand-new random ID and marks it dirty.
///
/// Useful after a privilege change (e.g. login), so the session continues
/// under an ID nobody has seen before. This is a no-op when there is no
/// accessible session (see [`session_mut`]).
///
/// NOTE(review): this does not remove the record stored under the old ID.
/// Confirm whether the store or middleware cleans it up.
pub fn regenerate_session_id() {
    let _ = session_mut(|current| {
        current.dirty = true;
        current.id = generate_session_id();
    });
}
/// Empties the current session, logs out its user and rotates its CSRF token.
///
/// This is a no-op when there is no accessible session (see [`session_mut`]).
/// The session ID itself is kept.
pub fn invalidate_session() {
    session_mut(|session| {
        session.flush();
        // `flush()` is opaque from here, so clear the authenticated user
        // explicitly. An invalidated session must never stay logged in.
        session.user_id = None;
        session.csrf_token = generate_csrf_token();
        // Mark as modified, like the other mutators in this module do.
        session.dirty = true;
    });
}
/// Returns the current session's CSRF token, or `None` if no session is
/// accessible.
pub fn get_csrf_token() -> Option<String> {
    let current = session()?;
    Some(current.csrf_token)
}
/// Reports whether the current session carries an authenticated user ID.
///
/// Returns `false` when no session is accessible.
pub fn is_authenticated() -> bool {
    session().and_then(|current| current.user_id).is_some()
}
/// Returns the authenticated user's ID.
///
/// Returns `None` when there is no accessible session or no user is set on it.
pub fn auth_user_id() -> Option<i64> {
    let current = session()?;
    current.user_id
}
/// Records `user_id` as the authenticated user and marks the session dirty.
///
/// This is a no-op when there is no accessible session. Consider pairing it
/// with [`regenerate_session_id`] so the authenticated session does not keep a
/// pre-login ID.
pub fn set_auth_user(user_id: i64) {
    let _ = session_mut(|current| {
        current.dirty = true;
        current.user_id = Some(user_id);
    });
}
/// Removes the authenticated user from the session and marks it dirty.
///
/// This is a no-op when there is no accessible session.
pub fn clear_auth_user() {
    let _ = session_mut(|current| {
        current.dirty = true;
        current.user_id = None;
    });
}
/// Destroys the stored sessions belonging to `user_id` via
/// `store.destroy_for_user`.
///
/// A session whose ID equals `except_session_id` (e.g. the current one) is
/// passed through as the exclusion. The returned count is whatever the store
/// reports, presumably the number of sessions removed; confirm this against
/// the store implementation.
///
/// # Errors
/// Propagates any error returned by the store.
pub async fn invalidate_all_for_user(
    store: &dyn super::store::SessionStore,
    user_id: i64,
    except_session_id: Option<&str>,
) -> Result<u64, crate::error::FrameworkError> {
    let removed = store.destroy_for_user(user_id, except_session_id).await?;
    Ok(removed)
}