use super::internal::{
DEFAULT_PREFIX, ValkeyStoreError, registry_key, revocation_session_channel,
revocation_user_channel,
};
use crate::session::{id::SessionId, store::SessionRegistry};
use axess_clock::{Clock, SystemClock};
use fred::prelude::*;
use std::sync::Arc;
use std::time::Duration;
use tracing::debug;
/// Valkey-backed [`SessionRegistry`] that keeps one sorted set of session ids per user.
#[derive(Clone)]
pub struct ValkeySessionRegistry {
    /// Namespace handed to the registry-key and revocation-channel helpers.
    prefix: Arc<str>,
    /// Expiry (re)applied to a user's registry key on every registration.
    ttl: Duration,
    /// Time source for the millisecond timestamps stored as sorted-set scores.
    clock: Arc<dyn Clock>,
    /// Connection used for registry commands and for publishing revocations.
    pub(super) client: Client,
}
impl ValkeySessionRegistry {
    /// Builds a registry on `client` with [`DEFAULT_PREFIX`], a 24-hour TTL and the system clock.
    pub fn new(client: Client) -> Self {
        Self::with_options(client, DEFAULT_PREFIX, Duration::from_secs(24 * 60 * 60))
    }

    /// Builds a registry with an explicit key `prefix` and registry `ttl`, using the system
    /// clock for registration timestamps.
    pub fn with_options(client: Client, prefix: impl Into<Arc<str>>, ttl: Duration) -> Self {
        let clock: Arc<dyn Clock> = Arc::new(SystemClock);
        Self {
            client,
            prefix: prefix.into(),
            ttl,
            clock,
        }
    }

    /// Returns this registry with `clock` as its time source for registration timestamps.
    pub fn with_clock(self, clock: Arc<dyn Clock>) -> Self {
        Self { clock, ..self }
    }
}
impl SessionRegistry for ValkeySessionRegistry {
type Error = ValkeyStoreError;
async fn register(
&self,
user_id: &crate::authn::ids::UserId,
session_id: &SessionId,
) -> Result<(), Self::Error> {
let key = registry_key(&self.prefix, user_id.to_string().as_str());
let sid_str = session_id.to_string();
let ttl_secs = self.ttl.as_secs().min(i64::MAX as u64) as i64;
let score = self.clock.now().timestamp_millis() as f64;
let pipeline = self.client.pipeline();
pipeline
.zadd::<(), _, _>(&key, None, None, false, false, (score, sid_str))
.await?;
pipeline.expire::<(), _>(&key, ttl_secs, None).await?;
pipeline.all::<()>().await?;
debug!(
user_id = %user_id,
session_id = %session_id,
"session registered in valkey registry"
);
Ok(())
}
async fn is_valid(
&self,
user_id: &crate::authn::ids::UserId,
session_id: &SessionId,
) -> Result<bool, Self::Error> {
let key = registry_key(&self.prefix, user_id.to_string().as_str());
let sid_str = session_id.to_string();
let score: Option<f64> = self.client.zscore(&key, &sid_str).await?;
Ok(score.is_some())
}
async fn invalidate_user(
&self,
user_id: &crate::authn::ids::UserId,
) -> Result<(), Self::Error> {
let key = registry_key(&self.prefix, user_id.to_string().as_str());
self.client.del::<(), _>(&key).await?;
let channel = revocation_user_channel(&self.prefix, user_id.to_string().as_str());
let _: Result<(), _> = self.client.publish(&channel, "1").await;
debug!(
user_id = %user_id,
"all sessions invalidated for user in valkey registry"
);
Ok(())
}
async fn invalidate_session(
&self,
user_id: &crate::authn::ids::UserId,
session_id: &SessionId,
) -> Result<(), Self::Error> {
let key = registry_key(&self.prefix, user_id.to_string().as_str());
let sid_str = session_id.to_string();
self.client.zrem::<(), _, _>(&key, &sid_str).await?;
let channel = revocation_session_channel(&self.prefix, &sid_str);
let _: Result<(), _> = self.client.publish(&channel, "1").await;
debug!(
user_id = %user_id,
session_id = %session_id,
"session removed from valkey registry"
);
Ok(())
}
async fn active_sessions(
&self,
user_id: &crate::authn::ids::UserId,
) -> Result<Vec<SessionId>, Self::Error> {
let key = registry_key(&self.prefix, user_id.to_string().as_str());
let members: Vec<String> = self
.client
.zrange(&key, 0i64, -1i64, None, false, None, false)
.await?;
Ok(members.into_iter().filter_map(|s| s.parse().ok()).collect())
}
async fn watch_revocation(&self, user_id: &crate::authn::ids::UserId, session_id: &SessionId) {
let user_channel = revocation_user_channel(&self.prefix, user_id.to_string().as_str());
let session_channel = revocation_session_channel(&self.prefix, &session_id.to_string());
let subscriber = self.client.clone_new();
if subscriber.init().await.is_err() {
std::future::pending::<()>().await;
return;
}
let mut messages = subscriber.message_rx();
if subscriber
.subscribe(vec![user_channel.clone(), session_channel.clone()])
.await
.is_err()
{
let _ = subscriber.quit().await;
std::future::pending::<()>().await;
return;
}
if let Ok(false) = self.is_valid(user_id, session_id).await {
let _ = subscriber.quit().await;
return;
}
let _ = messages.recv().await;
let _ = subscriber.quit().await;
}
}
use crate::health::{HealthCheck, HealthStatus};
use std::future::Future;
use std::pin::Pin;
/// Longest time a health-check `PING` may take before the registry is reported unhealthy.
const HEALTH_CHECK_TIMEOUT: Duration = Duration::from_secs(2);
impl HealthCheck for ValkeySessionRegistry {
    /// Reports healthy when Valkey answers `PING` within [`HEALTH_CHECK_TIMEOUT`]. Otherwise
    /// reports unhealthy with the error or timeout.
    fn check(&self) -> Pin<Box<dyn Future<Output = HealthStatus> + Send + '_>> {
        Box::pin(async move {
            match tokio::time::timeout(HEALTH_CHECK_TIMEOUT, self.client.ping::<String>(None)).await
            {
                Ok(Ok(_)) => HealthStatus::Healthy,
                Ok(Err(e)) => HealthStatus::Unhealthy(format!("valkey PING failed: {e}")),
                // Render the constant rather than hard-coding it, so the message cannot drift
                // from the real timeout. `Duration`'s Debug shows 2 s as "2s".
                Err(_) => HealthStatus::Unhealthy(format!(
                    "valkey PING timeout ({HEALTH_CHECK_TIMEOUT:?})"
                )),
            }
        })
    }
}