use base64ct::{Base64UrlUnpadded, Encoding};
use chrono::{DateTime, Duration, Utc};
use rand::TryRngCore;
use rand::rngs::OsRng;
use sha2::{Digest, Sha256};
use crate::db::Db;
use crate::error::AuthError;
use crate::types::{Session, SessionId, SessionToken, TokenHash, UserId};
/// Settings for session lifetime and the session cookie.
///
/// Construct via [`Default`] and override fields as needed.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Session time-to-live. `session_cookie` emits it (in whole seconds) as the cookie's
    /// `Max-Age`; presumably callers also pass it to `Db::validate_session` — verify.
    pub ttl: Duration,
    /// Name of the cookie that carries the session token.
    pub cookie_name: &'static str,
    /// When `true`, `session_cookie` appends the `Secure` attribute.
    pub secure: bool,
}
impl Default for SessionConfig {
fn default() -> Self {
Self {
ttl: Duration::hours(24),
cookie_name: "allowthem_session",
secure: true,
}
}
}
/// Creates a new random session token.
///
/// Fills 32 bytes from the operating-system RNG and encodes them as unpadded
/// URL-safe base64 (43 characters).
///
/// # Panics
///
/// Panics if the OS random number generator cannot supply bytes.
pub fn generate_token() -> SessionToken {
    const TOKEN_LEN: usize = 32;
    let mut raw = [0u8; TOKEN_LEN];
    let filled = OsRng.try_fill_bytes(&mut raw);
    filled.expect("OS RNG unavailable");
    let encoded = Base64UrlUnpadded::encode_string(&raw);
    SessionToken::from_encoded(encoded)
}
/// Returns the lowercase hex SHA-256 digest of the token's string form.
///
/// Session rows are stored and looked up by this value (the `token_hash` column).
pub fn hash_token(token: &SessionToken) -> TokenHash {
    let mut hasher = Sha256::new();
    hasher.update(token.as_str().as_bytes());
    let hex = format!("{:x}", hasher.finalize());
    TokenHash::new_unchecked(hex)
}
/// Layout used for every session timestamp written to or compared against the database.
///
/// `lookup_session` and `validate_session` compare `expires_at` against the current time as
/// *strings*. That only orders correctly if every value uses exactly this zero-padded layout
/// (for four-digit years), so all session timestamps go through [`format_db_timestamp`].
const DB_TIMESTAMP_FORMAT: &str = "%Y-%m-%dT%H:%M:%S%.3fZ";

/// Renders `ts` as UTC with millisecond precision in [`DB_TIMESTAMP_FORMAT`].
fn format_db_timestamp(ts: DateTime<Utc>) -> String {
    ts.format(DB_TIMESTAMP_FORMAT).to_string()
}

impl Db {
    /// Inserts a new session row and returns it as stored.
    ///
    /// A fresh [`SessionId`] is generated for the row. `token_hash` must come from
    /// [`hash_token`], because `lookup_session` finds rows by re-hashing the presented token
    /// the same way. `created_at` is not supplied here, so its value comes from the table
    /// default and is read back via `RETURNING`.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::Database`] if the insert fails.
    pub async fn create_session(
        &self,
        user_id: UserId,
        token_hash: TokenHash,
        ip_address: Option<&str>,
        user_agent: Option<&str>,
        expires_at: DateTime<Utc>,
    ) -> Result<Session, AuthError> {
        let id = SessionId::new();
        sqlx::query_as::<_, Session>(
            "INSERT INTO allowthem_sessions (id, token_hash, user_id, ip_address, user_agent, expires_at)
             VALUES (?, ?, ?, ?, ?, ?)
             RETURNING id, token_hash, user_id, ip_address, user_agent, expires_at, created_at",
        )
        .bind(id)
        .bind(token_hash)
        .bind(user_id)
        .bind(ip_address)
        .bind(user_agent)
        .bind(format_db_timestamp(expires_at))
        .fetch_one(self.pool())
        .await
        .map_err(AuthError::Database)
    }

    /// Returns the unexpired session whose `token_hash` matches `hash_token(token)`.
    ///
    /// A session counts as expired once `expires_at` is not strictly later than now.
    /// Returns `Ok(None)` if no such session exists.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::Database`] if the query fails.
    pub async fn lookup_session(&self, token: &SessionToken) -> Result<Option<Session>, AuthError> {
        sqlx::query_as::<_, Session>(
            "SELECT id, token_hash, user_id, ip_address, user_agent, expires_at, created_at
             FROM allowthem_sessions
             WHERE token_hash = ? AND expires_at > ?",
        )
        .bind(hash_token(token))
        .bind(format_db_timestamp(Utc::now()))
        .fetch_optional(self.pool())
        .await
        .map_err(AuthError::Database)
    }

    /// Looks up the session for `token`, extending its expiry when it is past the halfway point.
    ///
    /// Once less than `ttl / 2` remains before `expires_at`, the expiry is moved to
    /// `now + ttl` and the returned session carries the new value. Returns `Ok(None)` in two
    /// cases:
    /// - the session is missing or expired;
    /// - the session was deleted or expired before the renewal could be written.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::Database`] if the lookup or the renewal query fails.
    pub async fn validate_session(
        &self,
        token: &SessionToken,
        ttl: Duration,
    ) -> Result<Option<Session>, AuthError> {
        let Some(session) = self.lookup_session(token).await? else {
            return Ok(None);
        };

        let now = Utc::now();
        // Renew only in the second half of the TTL, so an active session triggers at most
        // one write per half-TTL instead of one per request.
        if now <= session.expires_at - ttl / 2 {
            return Ok(Some(session));
        }

        let new_expires_at = now + ttl;
        // Guard on `expires_at > now` so a row that expired after the lookup is not revived.
        let result = sqlx::query(
            "UPDATE allowthem_sessions SET expires_at = ? WHERE token_hash = ? AND expires_at > ?",
        )
        .bind(format_db_timestamp(new_expires_at))
        .bind(hash_token(token))
        .bind(format_db_timestamp(now))
        .execute(self.pool())
        .await
        .map_err(AuthError::Database)?;

        // The row may have been removed (logout, `delete_user_sessions`) or expired between
        // the lookup and the update. Never authenticate a session that no longer exists.
        if result.rows_affected() == 0 {
            return Ok(None);
        }

        Ok(Some(Session {
            expires_at: new_expires_at,
            ..session
        }))
    }

    /// Deletes the session identified by `token`, whether or not it has expired.
    ///
    /// Returns `true` if a row was removed.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::Database`] if the delete fails.
    pub async fn delete_session(&self, token: &SessionToken) -> Result<bool, AuthError> {
        let result = sqlx::query("DELETE FROM allowthem_sessions WHERE token_hash = ?")
            .bind(hash_token(token))
            .execute(self.pool())
            .await
            .map_err(AuthError::Database)?;
        Ok(result.rows_affected() > 0)
    }

    /// Deletes every session belonging to `user_id` and returns how many were removed.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::Database`] if the delete fails.
    pub async fn delete_user_sessions(&self, user_id: &UserId) -> Result<u64, AuthError> {
        let result = sqlx::query("DELETE FROM allowthem_sessions WHERE user_id = ?")
            .bind(*user_id)
            .execute(self.pool())
            .await
            .map_err(AuthError::Database)?;
        Ok(result.rows_affected())
    }
}
/// Builds the cookie string (for a `Set-Cookie` header) that stores `token` in the browser.
///
/// The cookie always has these attributes:
/// - `HttpOnly`
/// - `SameSite=Lax`
/// - `Path=/`
/// - `Max-Age`, set to `config.ttl` in whole seconds
///
/// `Domain` is added only when `domain` is non-empty, and `Secure` only when
/// `config.secure` is set.
pub fn session_cookie(token: &SessionToken, config: &SessionConfig, domain: &str) -> String {
    let mut parts: Vec<String> = vec![
        format!("{}={}", config.cookie_name, token.as_str()),
        String::from("HttpOnly"),
        String::from("SameSite=Lax"),
        String::from("Path=/"),
        format!("Max-Age={}", config.ttl.num_seconds()),
    ];
    if !domain.is_empty() {
        parts.push(format!("Domain={domain}"));
    }
    if config.secure {
        parts.push(String::from("Secure"));
    }
    parts.join("; ")
}
/// Extracts the session token named `cookie_name` from a `Cookie` request header.
///
/// Pairs are split on `;` and trimmed, so both the RFC 6265 form (`a=1; b=2`) and the
/// unspaced form (`a=1;b=2`) are accepted. The first pair with a matching name and a
/// non-empty value wins. Returns `None` when there is no such pair.
pub fn parse_session_cookie(cookie_header: &str, cookie_name: &str) -> Option<SessionToken> {
    cookie_header
        .split(';')
        .filter_map(|pair| pair.split_once('='))
        .map(|(name, value)| (name.trim(), value.trim()))
        // An empty value can never match a stored session, so treat it as absent.
        .find(|(name, value)| *name == cookie_name && !value.is_empty())
        .map(|(_, value)| SessionToken::from_encoded(value.to_string()))
}