use base64ct::{Base64UrlUnpadded, Encoding};
use chrono::{DateTime, Duration, Utc};
use rand::TryRngCore;
use rand::rngs::OsRng;
use sha2::{Digest, Sha256};
use serde::Serialize;
use crate::db::Db;
use crate::error::AuthError;
use crate::event_sink::AuthEvent;
use crate::handle::AllowThem;
use crate::types::{Email, Session, SessionId, SessionToken, TokenHash, UserId};
/// Settings that control how long sessions live and how the session cookie
/// is emitted.
pub struct SessionConfig {
    /// Session lifetime; `session_cookie` writes it (in whole seconds) as the
    /// cookie's `Max-Age`.
    pub ttl: Duration,
    /// Name of the cookie that carries the session token.
    pub cookie_name: &'static str,
    /// When `true`, the cookie builders append the `Secure` attribute.
    pub secure: bool,
}
impl Default for SessionConfig {
fn default() -> Self {
Self {
ttl: Duration::hours(24),
cookie_name: "allowthem_session",
secure: true,
}
}
}
/// One row of the session listing produced by `Db::list_all_sessions`:
/// session columns joined with the owning user's email.
#[derive(Debug, Clone, Serialize, sqlx::FromRow)]
pub struct SessionListEntry {
    /// Identifier of the session row.
    pub id: SessionId,
    /// Identifier of the user the session belongs to.
    pub user_id: UserId,
    /// The owning user's email (selected as `u.email AS user_email`).
    pub user_email: Email,
    /// Client IP recorded for the session, if any.
    pub ip_address: Option<String>,
    /// Client user-agent recorded for the session, if any.
    pub user_agent: Option<String>,
    /// Instant after which the session is no longer returned by lookups.
    pub expires_at: DateTime<Utc>,
    /// Creation time of the session row; listings are ordered by it, newest first.
    pub created_at: DateTime<Utc>,
}
/// Filter and paging arguments for `Db::list_all_sessions`.
pub struct ListSessionsParams {
    /// When `Some`, only sessions of this user are listed; `None` lists all users.
    pub user_id: Option<UserId>,
    /// Maximum number of rows to return (bound to SQL `LIMIT`).
    pub limit: u32,
    /// Number of rows to skip (bound to SQL `OFFSET`).
    pub offset: u32,
}
/// One page of the session listing together with the overall match count.
pub struct ListSessionsResult {
    /// The rows of the requested page.
    pub sessions: Vec<SessionListEntry>,
    /// Number of unexpired sessions matching the filter, ignoring paging.
    pub total: u32,
}
/// Creates a fresh session token from 32 bytes of operating-system randomness,
/// encoded as unpadded URL-safe base64.
///
/// # Panics
///
/// Panics with `"OS RNG unavailable"` if the OS random source cannot fill the
/// buffer.
pub fn generate_token() -> SessionToken {
    let mut entropy = [0u8; 32];
    let filled = OsRng.try_fill_bytes(&mut entropy);
    filled.expect("OS RNG unavailable");

    let encoded = Base64UrlUnpadded::encode_string(&entropy);
    SessionToken::from_encoded(encoded)
}
/// Computes the SHA-256 digest of the token's string form and wraps its
/// lowercase hexadecimal rendering in a [`TokenHash`].
///
/// The database methods in this file store and look up sessions by this hash
/// rather than by the token itself.
pub fn hash_token(token: &SessionToken) -> TokenHash {
    let mut hasher = Sha256::new();
    hasher.update(token.as_str().as_bytes());
    let digest = hasher.finalize();
    let hex = format!("{digest:x}");
    TokenHash::new_unchecked(hex)
}
impl Db {
pub async fn create_session(
&self,
user_id: UserId,
token_hash: TokenHash,
ip_address: Option<&str>,
user_agent: Option<&str>,
expires_at: DateTime<Utc>,
) -> Result<Session, AuthError> {
let id = SessionId::new();
let expires_at_str = expires_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
sqlx::query_as::<_, Session>(
"INSERT INTO allowthem_sessions (id, token_hash, user_id, ip_address, user_agent, expires_at)
VALUES (?, ?, ?, ?, ?, ?)
RETURNING id, token_hash, user_id, ip_address, user_agent, expires_at, created_at",
)
.bind(id)
.bind(token_hash)
.bind(user_id)
.bind(ip_address)
.bind(user_agent)
.bind(expires_at_str)
.fetch_one(self.pool())
.await
.map_err(AuthError::Database)
}
pub async fn lookup_session(&self, token: &SessionToken) -> Result<Option<Session>, AuthError> {
let hash = hash_token(token);
let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
sqlx::query_as::<_, Session>(
"SELECT id, token_hash, user_id, ip_address, user_agent, expires_at, created_at
FROM allowthem_sessions
WHERE token_hash = ? AND expires_at > ?",
)
.bind(hash)
.bind(now)
.fetch_optional(self.pool())
.await
.map_err(AuthError::Database)
}
pub async fn validate_session(
&self,
token: &SessionToken,
ttl: Duration,
) -> Result<Option<Session>, AuthError> {
let session = match self.lookup_session(token).await? {
Some(s) => s,
None => return Ok(None),
};
let now = Utc::now();
let halfway = session.expires_at - ttl / 2;
if now > halfway {
let new_expires_at = now + ttl;
let new_expires_str = new_expires_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
let hash = hash_token(token);
sqlx::query("UPDATE allowthem_sessions SET expires_at = ? WHERE token_hash = ?")
.bind(&new_expires_str)
.bind(hash)
.execute(self.pool())
.await
.map_err(AuthError::Database)?;
return Ok(Some(Session {
expires_at: new_expires_at,
..session
}));
}
Ok(Some(session))
}
pub async fn delete_session(&self, token: &SessionToken) -> Result<bool, AuthError> {
let hash = hash_token(token);
let result = sqlx::query("DELETE FROM allowthem_sessions WHERE token_hash = ?")
.bind(hash)
.execute(self.pool())
.await
.map_err(AuthError::Database)?;
Ok(result.rows_affected() > 0)
}
pub async fn delete_user_sessions(&self, user_id: &UserId) -> Result<u64, AuthError> {
let result = sqlx::query("DELETE FROM allowthem_sessions WHERE user_id = ?")
.bind(*user_id)
.execute(self.pool())
.await
.map_err(AuthError::Database)?;
Ok(result.rows_affected())
}
pub async fn list_user_sessions(&self, user_id: UserId) -> Result<Vec<Session>, AuthError> {
let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
sqlx::query_as::<_, Session>(
"SELECT id, token_hash, user_id, ip_address, user_agent, expires_at, created_at \
FROM allowthem_sessions \
WHERE user_id = ? AND expires_at > ? \
ORDER BY created_at DESC",
)
.bind(user_id)
.bind(now)
.fetch_all(self.pool())
.await
.map_err(AuthError::Database)
}
pub async fn list_all_sessions(
&self,
params: ListSessionsParams,
) -> Result<ListSessionsResult, AuthError> {
let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
if let Some(user_id) = params.user_id {
let total = sqlx::query_scalar::<_, i64>(
"SELECT COUNT(*) FROM allowthem_sessions s \
JOIN allowthem_users u ON s.user_id = u.id \
WHERE s.expires_at > ? AND s.user_id = ?",
)
.bind(&now)
.bind(user_id)
.fetch_one(self.pool())
.await
.map_err(AuthError::Database)? as u32;
let sessions = sqlx::query_as::<_, SessionListEntry>(
"SELECT s.id, s.user_id, u.email AS user_email, \
s.ip_address, s.user_agent, s.expires_at, s.created_at \
FROM allowthem_sessions s \
JOIN allowthem_users u ON s.user_id = u.id \
WHERE s.expires_at > ? AND s.user_id = ? \
ORDER BY s.created_at DESC \
LIMIT ? OFFSET ?",
)
.bind(&now)
.bind(user_id)
.bind(params.limit)
.bind(params.offset)
.fetch_all(self.pool())
.await
.map_err(AuthError::Database)?;
Ok(ListSessionsResult { sessions, total })
} else {
let total = sqlx::query_scalar::<_, i64>(
"SELECT COUNT(*) FROM allowthem_sessions s \
JOIN allowthem_users u ON s.user_id = u.id \
WHERE s.expires_at > ?",
)
.bind(&now)
.fetch_one(self.pool())
.await
.map_err(AuthError::Database)? as u32;
let sessions = sqlx::query_as::<_, SessionListEntry>(
"SELECT s.id, s.user_id, u.email AS user_email, \
s.ip_address, s.user_agent, s.expires_at, s.created_at \
FROM allowthem_sessions s \
JOIN allowthem_users u ON s.user_id = u.id \
WHERE s.expires_at > ? \
ORDER BY s.created_at DESC \
LIMIT ? OFFSET ?",
)
.bind(&now)
.bind(params.limit)
.bind(params.offset)
.fetch_all(self.pool())
.await
.map_err(AuthError::Database)?;
Ok(ListSessionsResult { sessions, total })
}
}
pub async fn delete_session_by_id(&self, id: SessionId) -> Result<bool, AuthError> {
let result = sqlx::query("DELETE FROM allowthem_sessions WHERE id = ?")
.bind(id)
.execute(self.pool())
.await
.map_err(AuthError::Database)?;
Ok(result.rows_affected() > 0)
}
}
impl AllowThem {
    /// Deletes the session identified by `token` and, if a row was actually
    /// removed, emits a `session.destroyed` event with scope `single`.
    ///
    /// Returns whether a session was deleted. No event is emitted when
    /// nothing matched.
    pub async fn delete_session(&self, token: &SessionToken) -> Result<bool, AuthError> {
        if !self.db().delete_session(token).await? {
            return Ok(false);
        }

        let event = AuthEvent::new(
            "session.destroyed",
            None,
            serde_json::json!({ "scope": "single" }),
        );
        self.emit_event(event).await;
        Ok(true)
    }

    /// Deletes the session with the given `id` and, if a row was actually
    /// removed, emits a `session.destroyed` event with scope `single`.
    ///
    /// Returns whether a session was deleted. No event is emitted when
    /// nothing matched.
    pub async fn delete_session_by_id(&self, id: SessionId) -> Result<bool, AuthError> {
        if !self.db().delete_session_by_id(id).await? {
            return Ok(false);
        }

        let event = AuthEvent::new(
            "session.destroyed",
            None,
            serde_json::json!({ "scope": "single" }),
        );
        self.emit_event(event).await;
        Ok(true)
    }

    /// Deletes every session of `user_id` and emits a `session.destroyed`
    /// event with scope `all` carrying the user id and the removed-row count.
    ///
    /// The event is emitted even when the count is zero. Returns the number
    /// of sessions removed.
    pub async fn delete_user_sessions(&self, user_id: &UserId) -> Result<u64, AuthError> {
        let removed = self.db().delete_user_sessions(user_id).await?;

        let payload = serde_json::json!({ "user_id": user_id, "count": removed, "scope": "all" });
        let event = AuthEvent::new("session.destroyed", Some(*user_id), payload);
        self.emit_event(event).await;

        Ok(removed)
    }
}
/// Builds the `Set-Cookie` value that stores `token` under the configured
/// cookie name.
///
/// The cookie is always `HttpOnly; SameSite=Lax; Path=/` with `Max-Age` set to
/// the configured TTL in whole seconds. A `Domain` attribute is appended only
/// when `domain` is non-empty, and `Secure` only when `config.secure` is set.
pub fn session_cookie(token: &SessionToken, config: &SessionConfig, domain: &str) -> String {
    let base = format!(
        "{}={}; HttpOnly; SameSite=Lax; Path=/; Max-Age={}",
        config.cookie_name,
        token.as_str(),
        config.ttl.num_seconds(),
    );

    let domain_attr = if domain.is_empty() {
        String::new()
    } else {
        ["; Domain=", domain].concat()
    };
    let secure_attr = if config.secure { "; Secure" } else { "" };

    [base.as_str(), domain_attr.as_str(), secure_attr].concat()
}
/// Builds the `Set-Cookie` value that clears the session cookie: an empty
/// value with `Max-Age=0` under the configured cookie name.
///
/// The attributes mirror those of the session cookie: always
/// `HttpOnly; SameSite=Lax; Path=/`, with `Domain` appended only when
/// `domain` is non-empty and `Secure` only when `config.secure` is set.
pub fn clear_session_cookie(config: &SessionConfig, domain: &str) -> String {
    let mut header = format!(
        "{}=; HttpOnly; SameSite=Lax; Path=/; Max-Age=0",
        config.cookie_name,
    );

    match domain {
        "" => {}
        host => {
            header += "; Domain=";
            header += host;
        }
    }

    if config.secure {
        header += "; Secure";
    }

    header
}
/// Extracts the session token from a raw `Cookie` request-header value.
///
/// The header is split on `;`, and each `name=value` pair is examined in
/// order. The first pair whose trimmed name equals `cookie_name` wins; its
/// trimmed value is wrapped in a [`SessionToken`]. Only the first `=` in a
/// pair separates name from value, so values that themselves contain `=` are
/// kept intact.
///
/// Splitting on `;` and trimming (rather than on the exact sequence `"; "`)
/// means a header sent without a space after the separator, such as
/// `a=1;allowthem_session=tok`, still yields the token.
///
/// Returns `None` when no pair is named `cookie_name`. A matching pair with an
/// empty value yields a token wrapping the empty string.
pub fn parse_session_cookie(cookie_header: &str, cookie_name: &str) -> Option<SessionToken> {
    cookie_header
        .split(';')
        // Fragments without `=` are not cookie pairs; skip them.
        .filter_map(|pair| pair.split_once('='))
        .find(|(name, _)| name.trim() == cookie_name)
        .map(|(_, value)| SessionToken::from_encoded(value.trim().to_string()))
}