#[cfg(any(test, feature = "memory"))]
use crate::health::{HealthCheck, HealthStatus};
use crate::session::{data::SessionData, id::SessionId};
#[cfg(any(test, feature = "memory"))]
use crate::store::{MemoryStore, Store};
#[cfg(any(test, feature = "memory"))]
use dashmap::DashMap;
use std::future::Future;
use std::pin::Pin;
#[cfg(any(test, feature = "memory"))]
use std::sync::Arc;
#[cfg(any(test, feature = "memory"))]
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
pub trait SessionStore: Send + Sync + Clone + 'static {
type Error: std::error::Error + Send + Sync + 'static;
fn load(
&self,
id: &SessionId,
) -> impl std::future::Future<Output = Result<Option<SessionData>, Self::Error>> + Send;
fn save(
&self,
id: &SessionId,
data: &SessionData,
ttl: Duration,
) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;
fn delete(
&self,
id: &SessionId,
) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;
fn cycle(
&self,
old_id: &SessionId,
new_id: &SessionId,
data: &SessionData,
ttl: Duration,
) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;
fn prune_expired(&self) -> impl std::future::Future<Output = Result<u64, Self::Error>> + Send;
fn find_sessions_for_user(
&self,
user_id: &crate::authn::ids::UserId,
limit: usize,
) -> impl std::future::Future<Output = Result<Vec<(SessionId, SessionData)>, Self::Error>> + Send
{
let _ = (user_id, limit);
std::future::ready(Ok(Vec::new()))
}
}
pub trait SessionRegistry: Send + Sync + Clone + 'static {
type Error: std::error::Error + Send + Sync + 'static;
fn register(
&self,
user_id: &crate::authn::ids::UserId,
session_id: &SessionId,
) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;
fn is_valid(
&self,
user_id: &crate::authn::ids::UserId,
session_id: &SessionId,
) -> impl std::future::Future<Output = Result<bool, Self::Error>> + Send;
fn invalidate_user(
&self,
user_id: &crate::authn::ids::UserId,
) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;
fn invalidate_session(
&self,
user_id: &crate::authn::ids::UserId,
session_id: &SessionId,
) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;
fn active_sessions(
&self,
user_id: &crate::authn::ids::UserId,
) -> impl std::future::Future<Output = Result<Vec<SessionId>, Self::Error>> + Send;
fn watch_revocation(
&self,
user_id: &crate::authn::ids::UserId,
session_id: &SessionId,
) -> impl std::future::Future<Output = ()> + Send {
let _ = (user_id, session_id);
std::future::pending()
}
}
/// Session revocation interface whose methods return boxed futures.
///
/// Both futures resolve to `()`, so no error is reported to the caller. The
/// `SessionRegistryAdapter` implementation in this file forwards to a
/// `SessionRegistry` and logs a warning when the registry call fails.
pub trait SessionRevoker: Send + Sync + 'static {
/// Invalidates the sessions of `user_id`.
///
/// The returned future borrows `self` and `user_id` for `'a`.
fn invalidate_user<'a>(
&'a self,
user_id: &'a crate::authn::ids::UserId,
) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
/// Invalidates the single session `session_id` of `user_id`.
///
/// The returned future borrows `self` and both arguments for `'a`.
fn invalidate_session<'a>(
&'a self,
user_id: &'a crate::authn::ids::UserId,
session_id: &'a SessionId,
) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
}
/// Extension of [`SessionRevoker`] with validity checks, registration and
/// session listing, all returned as boxed futures with plain (non-`Result`)
/// outputs.
///
/// How failures map onto those outputs is up to the implementor; the
/// `SessionRegistryAdapter` implementation in this file logs the error and
/// returns `false` or an empty list.
pub trait SessionRegistryHandle: SessionRevoker {
/// Resolves to whether `session_id` is valid for `user_id`.
fn is_valid<'a>(
&'a self,
user_id: &'a crate::authn::ids::UserId,
session_id: &'a SessionId,
) -> Pin<Box<dyn Future<Output = bool> + Send + 'a>>;
/// Registers `session_id` for `user_id`.
///
/// In the adapter implementation in this file the `bool` is `true` when
/// the registry accepted the registration and `false` when it failed.
fn register<'a>(
&'a self,
user_id: &'a crate::authn::ids::UserId,
session_id: &'a SessionId,
) -> Pin<Box<dyn Future<Output = bool> + Send + 'a>>;
/// Resolves to the session ids currently registered for `user_id`.
fn active_sessions<'a>(
&'a self,
user_id: &'a crate::authn::ids::UserId,
) -> Pin<Box<dyn Future<Output = Vec<SessionId>> + Send + 'a>>;
}
/// Newtype around a [`SessionRegistry`].
///
/// This file implements [`SessionRevoker`] and [`SessionRegistryHandle`] for
/// it; those implementations log registry errors and replace them with a
/// fallback value instead of returning them.
pub struct SessionRegistryAdapter<T>(pub T)
where
    T: SessionRegistry;
/// [`SessionRevoker`] backed by the wrapped [`SessionRegistry`].
///
/// The revoker futures resolve to `()`, so a registry error is logged at
/// `warn` level and then discarded.
impl<T> SessionRevoker for SessionRegistryAdapter<T>
where
    T: SessionRegistry + 'static,
{
    /// Forwards to the wrapped registry's `invalidate_user`.
    ///
    /// On failure only a warning is emitted; as the log message states, the
    /// user's sessions may then remain active.
    fn invalidate_user<'a>(
        &'a self,
        user_id: &'a crate::authn::ids::UserId,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
        Box::pin(async move {
            self.0.invalidate_user(user_id).await.unwrap_or_else(|err| {
                tracing::warn!(
                    user_id = %user_id,
                    error = %err,
                    "session registry invalidate_user failed; user sessions may remain active"
                );
            });
        })
    }

    /// Forwards to the wrapped registry's `invalidate_session`.
    ///
    /// On failure only a warning is emitted; the log record carries the user
    /// id and the error, not the session id.
    fn invalidate_session<'a>(
        &'a self,
        user_id: &'a crate::authn::ids::UserId,
        session_id: &'a SessionId,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
        Box::pin(async move {
            self.0
                .invalidate_session(user_id, session_id)
                .await
                .unwrap_or_else(|err| {
                    tracing::warn!(
                        user_id = %user_id,
                        error = %err,
                        "session registry invalidate_session failed"
                    );
                });
        })
    }
}
/// [`SessionRegistryHandle`] backed by the wrapped [`SessionRegistry`].
///
/// Each method awaits the registry call and, on error, logs it and returns a
/// fixed fallback value instead of propagating the error.
impl<T> SessionRegistryHandle for SessionRegistryAdapter<T>
where
    T: SessionRegistry + 'static,
{
    /// Resolves to the registry's answer, or to `false` (fail closed) with a
    /// `warn` log when the registry call fails.
    fn is_valid<'a>(
        &'a self,
        user_id: &'a crate::authn::ids::UserId,
        session_id: &'a SessionId,
    ) -> Pin<Box<dyn Future<Output = bool> + Send + 'a>> {
        Box::pin(async move {
            self.0
                .is_valid(user_id, session_id)
                .await
                .unwrap_or_else(|err| {
                    tracing::warn!(
                        user_id = %user_id,
                        error = %err,
                        "session registry is_valid check failed; failing closed"
                    );
                    false
                })
        })
    }

    /// Resolves to `true` when the registry accepted the registration and to
    /// `false`, with a `warn` log, when it failed.
    ///
    /// Per the log message, a caller receiving `false` must clear the
    /// session so it does not end up authenticated but untracked.
    fn register<'a>(
        &'a self,
        user_id: &'a crate::authn::ids::UserId,
        session_id: &'a SessionId,
    ) -> Pin<Box<dyn Future<Output = bool> + Send + 'a>> {
        Box::pin(async move {
            let outcome = self.0.register(user_id, session_id).await;
            outcome.map(|()| true).unwrap_or_else(|err| {
                tracing::warn!(
                    user_id = %user_id,
                    error = %err,
                    "session registry register failed; caller must clear the session \
                     to avoid an Authenticated-but-untracked outcome"
                );
                false
            })
        })
    }

    /// Resolves to the registry's session list, or to an empty list with an
    /// `error` log when the registry call fails.
    ///
    /// Per the log message, the empty fallback disables the concurrent
    /// session limit for the current request.
    fn active_sessions<'a>(
        &'a self,
        user_id: &'a crate::authn::ids::UserId,
    ) -> Pin<Box<dyn Future<Output = Vec<SessionId>> + Send + 'a>> {
        Box::pin(async move {
            self.0.active_sessions(user_id).await.unwrap_or_else(|err| {
                tracing::error!(
                    user_id = %user_id,
                    error = %err,
                    "session registry active_sessions failed; concurrent session limit disabled for this request"
                );
                Vec::new()
            })
        })
    }
}
/// In-memory session store, compiled only for tests or with the `memory`
/// feature enabled.
///
/// Clones share the write counter because it sits behind an `Arc`.
/// NOTE(review): whether clones also share the stored sessions depends on
/// `MemoryStore`'s `Clone` implementation, which is not visible here.
#[cfg(any(test, feature = "memory"))]
#[derive(Clone)]
pub struct MemorySessionStore {
// Backing store mapping session ids to session data.
inner: MemoryStore<SessionId, SessionData>,
// Number of writes so far; `maybe_auto_purge` uses it to decide when to purge.
write_count: Arc<AtomicU64>,
}
#[cfg(any(test, feature = "memory"))]
impl Default for MemorySessionStore {
    /// Builds a store around a fresh `MemoryStore::new()` with the write
    /// counter starting at zero.
    fn default() -> Self {
        let inner = MemoryStore::new();
        let write_count = Arc::new(AtomicU64::new(0));
        Self { inner, write_count }
    }
}
#[cfg(any(test, feature = "memory"))]
impl MemorySessionStore {
    /// Creates a store in its default state; same as `Default::default()`.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns this store with its inner `MemoryStore` passed through
    /// `MemoryStore::with_clock(clock)`; the write counter is kept as is.
    pub fn with_clock(self, clock: Arc<dyn axess_clock::Clock>) -> Self {
        let Self { inner, write_count } = self;
        Self {
            inner: inner.with_clock(clock),
            write_count,
        }
    }

    /// Runs the inner store's synchronous expiry pass
    /// (`prune_expired_sync`).
    pub fn purge_expired(&self) {
        self.inner.prune_expired_sync();
    }

    /// Counts one write and occasionally purges expired entries.
    ///
    /// `fetch_add` yields the counter value from before the increment, so
    /// the purge fires when that prior count is a multiple of 1024 (which
    /// includes the very first write, where it is 0) and the inner store is
    /// not empty.
    fn maybe_auto_purge(&self) {
        let writes_before = self.write_count.fetch_add(1, Ordering::Relaxed);
        if !writes_before.is_multiple_of(1024) {
            return;
        }
        if self.inner.is_empty() {
            return;
        }
        self.purge_expired();
    }
}
/// Error type of [`MemorySessionStore`]: an alias for
/// [`std::convert::Infallible`], a type that has no values.
#[cfg(any(test, feature = "memory"))]
pub type MemoryStoreError = std::convert::Infallible;
/// [`SessionStore`] implementation that delegates to the inner `MemoryStore`.
///
/// `find_sessions_for_user` is not overridden here, so the trait's default
/// (an empty list) applies.
#[cfg(any(test, feature = "memory"))]
impl SessionStore for MemorySessionStore {
type Error = MemoryStoreError;
/// Returns whatever the inner store's `get` yields for `id`.
async fn load(&self, id: &SessionId) -> Result<Option<SessionData>, Self::Error> {
self.inner.get(id).await
}
/// Writes `data` under `id` with `ttl` through the inner store's `put`,
/// then counts the write towards the periodic auto-purge.
async fn save(
&self,
id: &SessionId,
data: &SessionData,
ttl: Duration,
) -> Result<(), Self::Error> {
self.inner.put(id, data, ttl).await?;
// Only reached when the write succeeded; may trigger an expiry pass.
self.maybe_auto_purge();
Ok(())
}
/// Removes `id` through the inner store's `delete`.
async fn delete(&self, id: &SessionId) -> Result<(), Self::Error> {
self.inner.delete(id).await
}
/// Deletes `old_id`, then writes `data` under `new_id` with `ttl`, and
/// counts one write towards the periodic auto-purge.
///
/// The delete and the put are two separate calls on the inner store.
async fn cycle(
&self,
old_id: &SessionId,
new_id: &SessionId,
data: &SessionData,
ttl: Duration,
) -> Result<(), Self::Error> {
// Delete first: if `old_id == new_id`, the put below still leaves the entry present.
self.inner.delete(old_id).await?;
self.inner.put(new_id, data, ttl).await?;
self.maybe_auto_purge();
Ok(())
}
/// Returns the result of the inner store's `prune_expired`.
///
/// NOTE(review): the `u64` is presumably the number of pruned entries;
/// confirm against `MemoryStore`.
async fn prune_expired(&self) -> Result<u64, Self::Error> {
self.inner.prune_expired().await
}
}
#[cfg(any(test, feature = "memory"))]
impl HealthCheck for MemorySessionStore {
    /// Always reports [`HealthStatus::Healthy`]; the store's state is not
    /// inspected.
    fn check(&self) -> Pin<Box<dyn Future<Output = HealthStatus> + Send + '_>> {
        let always_healthy = async { HealthStatus::Healthy };
        Box::pin(always_healthy)
    }
}
/// In-memory session registry, compiled only for tests or with the `memory`
/// feature enabled.
///
/// Clones share the same underlying map through the `Arc`.
#[cfg(any(test, feature = "memory"))]
#[derive(Clone, Default)]
pub struct MemorySessionRegistry {
// Registered session ids per user, keyed by the user id rendered with `to_string()`.
valid: Arc<DashMap<Arc<str>, Vec<SessionId>>>,
}
#[cfg(any(test, feature = "memory"))]
impl MemorySessionRegistry {
    /// Creates a registry with no registered sessions; same state as the
    /// derived `Default`.
    pub fn new() -> Self {
        Self {
            valid: Arc::default(),
        }
    }
}
/// Error type of [`MemorySessionRegistry`]: an alias for
/// [`std::convert::Infallible`], a type that has no values.
#[cfg(any(test, feature = "memory"))]
pub type MemoryRegistryError = std::convert::Infallible;
/// [`SessionRegistry`] implementation backed by a shared `DashMap` from the
/// user id (rendered with `to_string()`) to that user's registered session
/// ids.
///
/// None of the methods awaits, so map guards are never held across an await
/// point. `watch_revocation` is not overridden; the trait default applies.
#[cfg(any(test, feature = "memory"))]
impl SessionRegistry for MemorySessionRegistry {
    type Error = MemoryRegistryError;

    /// Adds `session_id` to the user's list, creating the list on first use.
    ///
    /// Registering an id that is already present leaves the list unchanged.
    async fn register(
        &self,
        user_id: &crate::authn::ids::UserId,
        session_id: &SessionId,
    ) -> Result<(), Self::Error> {
        let mut sessions = self
            .valid
            .entry(Arc::from(user_id.to_string()))
            .or_default();
        if !sessions.contains(session_id) {
            sessions.push(*session_id);
        }
        Ok(())
    }

    /// Returns `true` only when the user has an entry that contains
    /// `session_id`.
    async fn is_valid(
        &self,
        user_id: &crate::authn::ids::UserId,
        session_id: &SessionId,
    ) -> Result<bool, Self::Error> {
        Ok(self
            .valid
            .get(user_id.to_string().as_str())
            .is_some_and(|sessions| sessions.contains(session_id)))
    }

    /// Removes the user's entry, and with it every registered session id.
    async fn invalidate_user(
        &self,
        user_id: &crate::authn::ids::UserId,
    ) -> Result<(), Self::Error> {
        self.valid.remove(user_id.to_string().as_str());
        Ok(())
    }

    /// Removes `session_id` from the user's list and drops the user's entry
    /// once no sessions remain.
    ///
    /// Dropping the empty entry keeps the map from accumulating one dead key
    /// per user whose sessions were all invalidated individually. Unknown
    /// users and unknown session ids are a no-op.
    async fn invalidate_session(
        &self,
        user_id: &crate::authn::ids::UserId,
        session_id: &SessionId,
    ) -> Result<(), Self::Error> {
        let key = user_id.to_string();
        if let Some(mut sessions) = self.valid.get_mut(key.as_str()) {
            sessions.retain(|s| s != session_id);
        }
        // The `get_mut` guard is released by now, so taking the entry again
        // cannot self-deadlock. `remove_if` evaluates the predicate on the
        // entry's current value, so a list refilled by a concurrent
        // `register` is left in place.
        self.valid
            .remove_if(key.as_str(), |_, sessions| sessions.is_empty());
        Ok(())
    }

    /// Returns a copy of the user's registered session ids, or an empty list
    /// when the user has no entry.
    async fn active_sessions(
        &self,
        user_id: &crate::authn::ids::UserId,
    ) -> Result<Vec<SessionId>, Self::Error> {
        Ok(self
            .valid
            .get(user_id.to_string().as_str())
            .map(|sessions| sessions.clone())
            .unwrap_or_default())
    }
}
#[cfg(any(test, feature = "memory"))]
impl HealthCheck for MemorySessionRegistry {
    /// Always reports [`HealthStatus::Healthy`]; the registry's state is not
    /// inspected.
    fn check(&self) -> Pin<Box<dyn Future<Output = HealthStatus> + Send + '_>> {
        let always_healthy = async { HealthStatus::Healthy };
        Box::pin(always_healthy)
    }
}
// Out-of-line test modules (bodies live in separate files); both are
// compiled only for test builds.
#[cfg(test)]
mod memory_store_clock_tests;
#[cfg(test)]
mod memory_registry_tests;