use async_trait::async_trait;
use axum::{extract::FromRequestParts, http::request::Parts};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tower_sessions::Session;
use crate::error::Error;
/// A [`Session`] handle paired with an in-memory copy of one typed payload.
///
/// The payload type `T` is unconstrained on the struct itself; the bounds it
/// needs are placed on the `impl` blocks that use it.
pub struct TypedSession<T> {
/// Underlying session handle used for every read and write.
session: Session,
/// In-memory copy of the typed payload.
data: T,
}
impl<T> TypedSession<T>
where
T: Default + DeserializeOwned + Serialize + Send + Sync,
{
const DATA_KEY: &'static str = "_typed_session_data";
#[must_use]
pub fn data(&self) -> &T {
&self.data
}
pub fn data_mut(&mut self) -> &mut T {
&mut self.data
}
#[must_use]
pub fn into_data(self) -> T {
self.data
}
#[must_use]
pub fn session(&self) -> &Session {
&self.session
}
pub async fn save(&self) -> Result<(), Error> {
self.session
.insert(Self::DATA_KEY, &self.data)
.await
.map_err(|e| Error::Session(format!("Failed to save session data: {e}")))
}
pub async fn update<F>(&mut self, f: F) -> Result<(), Error>
where
F: FnOnce(&mut T),
{
f(&mut self.data);
self.save().await
}
pub async fn clear(&mut self) -> Result<(), Error> {
self.data = T::default();
self.session
.remove::<T>(Self::DATA_KEY)
.await
.map_err(|e| Error::Session(format!("Failed to clear session data: {e}")))?;
Ok(())
}
pub async fn destroy(&self) -> Result<(), Error> {
self.session
.flush()
.await
.map_err(|e| Error::Session(format!("Failed to destroy session: {e}")))
}
pub async fn regenerate(&self) -> Result<(), Error> {
self.session
.cycle_id()
.await
.map_err(|e| Error::Session(format!("Failed to regenerate session ID: {e}")))
}
}
impl<S, T> FromRequestParts<S> for TypedSession<T>
where
S: Send + Sync,
T: Default + DeserializeOwned + Serialize + Send + Sync,
{
type Rejection = Error;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let session = parts.extensions.get::<Session>().cloned().ok_or_else(|| {
Error::Session(
"Session not found in request extensions. Is SessionManagerLayer configured?"
.to_string(),
)
})?;
let data: T = session
.get(Self::DATA_KEY)
.await
.map_err(|e| Error::Session(format!("Failed to read session data: {e}")))?
.unwrap_or_default();
Ok(Self { session, data })
}
}
#[async_trait]
pub trait SessionData {
async fn get_value<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, Error>;
async fn set_value<T: Serialize + Send + Sync>(
&self,
key: &str,
value: &T,
) -> Result<(), Error>;
async fn remove_value<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, Error>;
async fn has_key(&self, key: &str) -> Result<bool, Error>;
}
#[async_trait]
impl SessionData for Session {
async fn get_value<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, Error> {
self.get(key)
.await
.map_err(|e| Error::Session(format!("Session get error: {e}")))
}
async fn set_value<T: Serialize + Send + Sync>(
&self,
key: &str,
value: &T,
) -> Result<(), Error> {
self.insert(key, value)
.await
.map_err(|e| Error::Session(format!("Session set error: {e}")))
}
async fn remove_value<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, Error> {
self.remove(key)
.await
.map_err(|e| Error::Session(format!("Session remove error: {e}")))
}
async fn has_key(&self, key: &str) -> Result<bool, Error> {
let value: Option<serde_json::Value> = self
.get(key)
.await
.map_err(|e| Error::Session(format!("Session check error: {e}")))?;
Ok(value.is_some())
}
}
/// Serializable authentication state: user identity, roles, login timestamp,
/// and free-form string attributes.
///
/// The `Default` value has no user, no roles, no timestamp, and an empty
/// `extra` map.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AuthSession {
/// Identifier of the logged-in user; `None` means not authenticated.
pub user_id: Option<String>,
/// Role names attached to the session.
pub roles: Vec<String>,
/// Timestamp recorded by `login` from `chrono::Utc::now().timestamp()`.
pub authenticated_at: Option<i64>,
/// Additional string key/value pairs; defaults to empty when absent from the
/// serialized form.
#[serde(default)]
pub extra: std::collections::HashMap<String, String>,
}
impl AuthSession {
    /// Returns `true` when a user ID is set.
    #[must_use]
    pub fn is_authenticated(&self) -> bool {
        self.user_id.is_some()
    }

    /// Returns the user ID as a string slice, if one is set.
    #[must_use]
    pub fn user_id(&self) -> Option<&str> {
        self.user_id.as_deref()
    }

    /// Records a login: sets the user ID and roles and stamps
    /// `authenticated_at` with the current UTC timestamp.
    ///
    /// `extra` is left untouched, so entries set before this call remain.
    pub fn login(&mut self, user_id: String, roles: Vec<String>) {
        self.user_id = Some(user_id);
        self.roles = roles;
        self.authenticated_at = Some(chrono::Utc::now().timestamp());
    }

    /// Same as [`login`](Self::login), and additionally replaces the whole
    /// `extra` map with `extra`.
    pub fn login_with_extra(
        &mut self,
        user_id: String,
        roles: Vec<String>,
        extra: std::collections::HashMap<String, String>,
    ) {
        self.login(user_id, roles);
        self.extra = extra;
    }

    /// Clears the user ID, roles, login timestamp, and all extra entries.
    pub fn logout(&mut self) {
        self.user_id = None;
        self.roles.clear();
        self.authenticated_at = None;
        self.extra.clear();
    }

    /// Returns `true` when `role` is among the session's roles (exact match).
    #[must_use]
    pub fn has_role(&self, role: &str) -> bool {
        self.roles.iter().any(|r| r == role)
    }

    /// Returns `true` when at least one of `roles` is held.
    ///
    /// An empty `roles` slice yields `false`.
    #[must_use]
    pub fn has_any_role(&self, roles: &[&str]) -> bool {
        roles.iter().any(|r| self.has_role(r))
    }

    /// Returns `true` when every one of `roles` is held.
    ///
    /// An empty `roles` slice yields `true`.
    #[must_use]
    pub fn has_all_roles(&self, roles: &[&str]) -> bool {
        roles.iter().all(|r| self.has_role(r))
    }

    /// Looks up an extra attribute by key.
    #[must_use]
    pub fn get_extra(&self, key: &str) -> Option<&str> {
        self.extra.get(key).map(String::as_str)
    }

    /// Inserts or overwrites an extra attribute.
    pub fn set_extra(&mut self, key: impl Into<String>, value: impl Into<String>) {
        self.extra.insert(key.into(), value.into());
    }

    /// Returns the time elapsed since `authenticated_at`, or `None` when no
    /// timestamp is recorded.
    ///
    /// The result is negative when the recorded timestamp lies in the future.
    /// Extreme timestamps are saturated instead of overflowing or panicking.
    #[must_use]
    pub fn session_age(&self) -> Option<chrono::Duration> {
        // `chrono::Duration::seconds` panics outside its supported range; this
        // bound keeps the argument within it.
        const MAX_SECS: i64 = i64::MAX / 1_000;

        self.authenticated_at.map(|ts| {
            let now = chrono::Utc::now().timestamp();
            // `authenticated_at` is a public, deserialized field, so it may hold
            // any `i64`; a plain `now - ts` could overflow.
            let elapsed = now.saturating_sub(ts).clamp(-MAX_SECS, MAX_SECS);
            chrono::Duration::seconds(elapsed)
        })
    }
}
/// A [`TypedSession`] whose payload is an [`AuthSession`].
pub type SessionAuth = TypedSession<AuthSession>;
#[cfg(test)]
mod tests {
    use super::*;

    /// Login marks the session authenticated with the given roles; logout
    /// reverts it.
    #[test]
    fn test_auth_session_login_logout() {
        let mut session = AuthSession::default();
        assert!(!session.is_authenticated());

        session.login(String::from("user-123"), vec![String::from("admin")]);
        assert!(session.is_authenticated());
        assert_eq!(session.user_id(), Some("user-123"));
        assert!(session.has_role("admin"));
        assert!(!session.has_role("superuser"));

        session.logout();
        assert!(!session.is_authenticated());
        assert!(session.roles.is_empty());
    }

    /// `has_any_role` needs one match; `has_all_roles` needs every role.
    #[test]
    fn test_auth_session_roles() {
        let granted = vec![String::from("admin"), String::from("editor")];
        let mut session = AuthSession::default();
        session.login(String::from("user"), granted);

        assert!(session.has_any_role(&["admin", "viewer"]));
        assert!(session.has_all_roles(&["admin", "editor"]));
        assert!(!session.has_all_roles(&["admin", "superuser"]));
    }

    /// Extra attributes round-trip through `set_extra` / `get_extra`.
    #[test]
    fn test_auth_session_extra_data() {
        let mut session = AuthSession::default();
        session.set_extra("email", "user@example.com");

        assert_eq!(session.get_extra("email"), Some("user@example.com"));
        assert_eq!(session.get_extra("missing"), None);
    }
}