use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use crate::auth::session::Session;
use crate::auth::session::meta::SessionMeta;
use crate::auth::session::store::SessionStore;
use crate::auth::session::token::SessionToken;
use crate::db::Database;
use crate::{Error, Result};
use super::claims::Claims;
use super::config::JwtSessionsConfig;
use super::decoder::JwtDecoder;
use super::encoder::JwtEncoder;
use super::tokens::TokenPair;
/// `aud` claim value stamped on access tokens and required by `logout`.
const AUD_ACCESS: &str = "access";
/// `aud` claim value stamped on refresh tokens and required by `rotate`.
const AUD_REFRESH: &str = "refresh";
/// Current wall-clock time in whole seconds since the UNIX epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than the UNIX epoch.
fn now_secs() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock before UNIX epoch");
    since_epoch.as_secs()
}
/// Issues and manages JWT access/refresh token pairs whose `jti` refers to a
/// session persisted through a [`SessionStore`].
///
/// Cloning only clones the inner [`Arc`], so all clones share the same state.
#[derive(Clone)]
pub struct JwtSessionService {
inner: Arc<Inner>,
}
/// Shared state behind [`JwtSessionService`].
struct Inner {
// Session rows; the session token is what gets embedded as the JWT `jti`.
store: SessionStore,
// Encodes `Claims` into JWT strings.
encoder: JwtEncoder,
// Decodes incoming JWT strings into `Claims` (what it validates is up to `JwtDecoder`).
decoder: JwtDecoder,
// Configuration the service was built from (TTLs, issuer, limits).
config: JwtSessionsConfig,
}
impl JwtSessionService {
pub fn new(db: Database, config: JwtSessionsConfig) -> Result<Self> {
if config.signing_secret.is_empty() {
return Err(Error::internal("jwt: signing_secret must be set"));
}
let encoder = JwtEncoder::from_config(&config);
let decoder = JwtDecoder::from_config(&config);
let store_cfg = crate::auth::session::cookie::CookieSessionsConfig {
session_ttl_secs: config.refresh_ttl_secs,
touch_interval_secs: config.touch_interval_secs,
max_sessions_per_user: config.max_per_user.max(1),
cookie_name: String::new(),
validate_fingerprint: false,
cookie: Default::default(),
};
let store = SessionStore::new(db, store_cfg);
Ok(Self {
inner: Arc::new(Inner {
store,
encoder,
decoder,
config,
}),
})
}
pub fn encoder(&self) -> &JwtEncoder {
&self.inner.encoder
}
pub fn decoder(&self) -> &JwtDecoder {
&self.inner.decoder
}
pub fn config(&self) -> &JwtSessionsConfig {
&self.inner.config
}
#[cfg(any(test, feature = "test-helpers"))]
pub fn store(&self) -> &SessionStore {
&self.inner.store
}
#[cfg(not(any(test, feature = "test-helpers")))]
pub(crate) fn store(&self) -> &SessionStore {
&self.inner.store
}
pub fn layer(&self) -> super::middleware::JwtLayer {
super::middleware::JwtLayer::from_service(self.clone())
}
pub async fn authenticate(&self, user_id: &str, meta: &SessionMeta) -> Result<TokenPair> {
let (raw, token) = self.inner.store.create(meta, user_id, None).await?;
self.mint_pair(&raw.user_id, &token.expose())
}
pub async fn rotate(&self, refresh_token: &str) -> Result<TokenPair> {
let claims: Claims = self.inner.decoder.decode(refresh_token)?;
if claims.aud.as_deref() != Some(AUD_REFRESH) {
return Err(Error::unauthorized("unauthorized").with_code("auth:aud_mismatch"));
}
let jti = claims.jti.as_deref().ok_or_else(|| {
Error::unauthorized("unauthorized").with_code("auth:session_not_found")
})?;
let old_token = SessionToken::from_raw(jti).ok_or_else(|| {
Error::unauthorized("unauthorized").with_code("auth:session_not_found")
})?;
let raw = self
.inner
.store
.read_by_token_hash(&old_token.hash())
.await?
.ok_or_else(|| {
Error::unauthorized("unauthorized").with_code("auth:session_not_found")
})?;
let new_token = SessionToken::generate();
self.inner
.store
.rotate_token_to(&raw.id, &new_token)
.await?;
self.mint_pair(&raw.user_id, &new_token.expose())
}
pub async fn logout(&self, access_token: &str) -> Result<()> {
let claims: Claims = self.inner.decoder.decode(access_token)?;
if claims.aud.as_deref() != Some(AUD_ACCESS) {
return Err(Error::unauthorized("unauthorized").with_code("auth:aud_mismatch"));
}
let jti = claims.jti.as_deref().ok_or_else(|| {
Error::unauthorized("unauthorized").with_code("auth:session_not_found")
})?;
let token = SessionToken::from_raw(jti).ok_or_else(|| {
Error::unauthorized("unauthorized").with_code("auth:session_not_found")
})?;
if let Some(raw) = self.inner.store.read_by_token_hash(&token.hash()).await? {
self.inner.store.destroy(&raw.id).await?;
}
Ok(())
}
pub async fn list(&self, user_id: &str) -> Result<Vec<Session>> {
let raws = self.inner.store.list_for_user(user_id).await?;
Ok(raws.into_iter().map(Session::from).collect())
}
pub async fn revoke(&self, user_id: &str, id: &str) -> Result<()> {
let row = self.inner.store.read(id).await?.ok_or_else(|| {
Error::not_found("session not found").with_code("auth:session_not_found")
})?;
if row.user_id != user_id {
return Err(Error::not_found("session not found").with_code("auth:session_not_found"));
}
self.inner.store.destroy(id).await
}
pub async fn revoke_all(&self, user_id: &str) -> Result<()> {
self.inner.store.destroy_all_for_user(user_id).await
}
pub async fn revoke_all_except(&self, user_id: &str, keep_id: &str) -> Result<()> {
self.inner.store.destroy_all_except(user_id, keep_id).await
}
pub async fn cleanup_expired(&self) -> Result<u64> {
self.inner.store.cleanup_expired().await
}
fn mint_pair(&self, user_id: &str, jti: &str) -> Result<TokenPair> {
let now = now_secs();
let access_exp = now + self.inner.config.access_ttl_secs;
let refresh_exp = now + self.inner.config.refresh_ttl_secs;
let access = Claims::new()
.with_sub(user_id)
.with_aud(AUD_ACCESS)
.with_jti(jti)
.with_exp(access_exp)
.with_iat_now();
let access = if let Some(ref iss) = self.inner.config.issuer {
access.with_iss(iss)
} else {
access
};
let refresh = Claims::new()
.with_sub(user_id)
.with_aud(AUD_REFRESH)
.with_jti(jti)
.with_exp(refresh_exp)
.with_iat_now();
let refresh = if let Some(ref iss) = self.inner.config.issuer {
refresh.with_iss(iss)
} else {
refresh
};
Ok(TokenPair {
access_token: self.inner.encoder.encode(&access)?,
refresh_token: self.inner.encoder.encode(&refresh)?,
access_expires_at: access_exp,
refresh_expires_at: refresh_exp,
})
}
}