use crate::lease::LeaseAcquisition;
use crate::lease::RenewalLease;
use crate::session::SessionFamilyId;
use crate::session::SessionFamilyRecord;
use crate::session::SessionId;
use crate::session::SessionLookup;
use crate::session::SessionRecord;
use crate::session::SessionRefreshRecord;
use crate::session::SessionTouch;
use crate::tokens::RefreshTokenHash;
use crate::tokens::RefreshTokenHashRef;
/// Result alias for repository operations; the error type is always [`RepositoryError`].
pub type RepositoryResult<T> = Result<T, RepositoryError>;
/// Errors produced by repository operations; the error half of [`RepositoryResult`].
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum RepositoryError {
/// A session was not found.
#[error("session not found")]
SessionNotFound,
/// A session family was not found.
#[error("session family not found")]
SessionFamilyNotFound,
/// A concurrent update to a session was detected.
#[error("concurrent session update detected")]
Conflict,
/// Persisted session state was invalid.
#[error("invalid persisted session state")]
InvalidState,
/// Backend failure; the display text of this variant is exactly `message`.
#[error("{message}")]
Backend {
/// Text used verbatim as this error's display output.
message: String,
},
}
impl RepositoryError {
    /// Builds a [`RepositoryError::Backend`] from anything convertible into a `String`.
    ///
    /// The converted text is stored unchanged in the variant's `message` field.
    #[must_use]
    pub fn backend(message: impl Into<String>) -> Self {
        let message = message.into();
        Self::Backend { message }
    }
}
/// Input for [`SessionRepository::create_session`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CreateSession {
/// Session record supplied for creation.
pub session: SessionRecord,
/// Refresh-token hash supplied together with `session`.
pub refresh_token_hash: RefreshTokenHash,
}
/// Input for [`SessionRepository::rotate_refresh_token`].
///
/// NOTE(review): how an implementation uses each field is not defined in this file; the field
/// notes below are read off the names and types and should be confirmed against the backends.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RotateRefreshToken {
/// Identifier of the session the rotation applies to.
pub session_id: SessionId,
/// Session family record passed along with the rotation.
pub family: SessionFamilyRecord,
/// Renewal lease passed along with the rotation (presumably one obtained through
/// `try_acquire_renewal_lease` — TODO confirm).
pub lease: RenewalLease,
/// Hash of the refresh token being replaced.
pub previous_refresh_token_hash: RefreshTokenHash,
/// Hash of the refresh token that replaces it.
pub next_refresh_token_hash: RefreshTokenHash,
/// Session record to use after the rotation.
pub next_session: SessionRecord,
}
/// Success value of [`SessionRepository::rotate_refresh_token`].
///
/// These outcomes travel in the `Ok` side of the result, separate from [`RepositoryError`].
/// NOTE(review): the exact condition behind each variant is decided by implementations that are
/// not visible here; the variant notes are inferred from the names.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RotateRefreshTokenOutcome {
/// The refresh token was rotated.
Rotated,
/// Presumably: the targeted session did not exist — confirm against implementations.
SessionMissing,
/// Presumably: the supplied renewal lease could not be used — confirm against implementations.
LeaseUnavailable,
/// Presumably: the supplied previous hash did not match the stored one — confirm against
/// implementations.
RefreshTokenMismatch,
}
/// Scope argument accepted by [`SessionRepository::revoke_session`].
///
/// NOTE(review): the precise effect of each scope is implementation-defined and not visible in
/// this file; the variant notes are inferred from the names.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RevokeSessionScope {
/// Presumably limits revocation to the addressed session.
CurrentSession,
/// Presumably extends revocation to the addressed session's whole family.
SessionFamily,
}
/// Asynchronous repository contract for sessions, session families, refresh-token hashes and
/// renewal leases.
///
/// Implementors must be `Send + Sync`. Every method returns a `Send` future that resolves to a
/// [`RepositoryResult`], so failures are reported as [`RepositoryError`] values.
/// NOTE(review): which error variant each method may produce is not specified in this file.
pub trait SessionRepository: Send + Sync {
/// Creates a session from the record and refresh-token hash carried by `input`.
fn create_session(
&self,
input: CreateSession,
) -> impl std::future::Future<Output = RepositoryResult<()>> + Send;
/// Looks up a session by refresh-token hash.
///
/// The returned future borrows both `self` and `refresh_token_hash` for `'a`. It resolves to an
/// `Option`; presumably `None` means no session matched — confirm against implementations.
fn find_session_by_refresh_token_hash<'a>(
&'a self,
refresh_token_hash: RefreshTokenHashRef<'a>,
) -> impl std::future::Future<Output = RepositoryResult<Option<SessionLookup>>> + Send + 'a;
/// Looks up the session record for `session_id`; resolves to an `Option` (presumably `None`
/// when absent).
fn find_session(
&self,
session_id: SessionId,
) -> impl std::future::Future<Output = RepositoryResult<Option<SessionRecord>>> + Send;
/// Looks up the session family record for `family_id`; resolves to an `Option` (presumably
/// `None` when absent).
fn find_family(
&self,
family_id: SessionFamilyId,
) -> impl std::future::Future<Output = RepositoryResult<Option<SessionFamilyRecord>>> + Send;
/// Looks up the refresh record for `session_id`; resolves to an `Option` (presumably `None`
/// when absent).
fn find_refresh_record(
&self,
session_id: SessionId,
) -> impl std::future::Future<Output = RepositoryResult<Option<SessionRefreshRecord>>> + Send;
/// Attempts to acquire `lease` for `session_id`, resolving to a [`LeaseAcquisition`] that
/// describes the result of the attempt.
fn try_acquire_renewal_lease(
&self,
session_id: SessionId,
lease: RenewalLease,
) -> impl std::future::Future<Output = RepositoryResult<LeaseAcquisition>> + Send;
/// Rotates a session's refresh token using `input`.
///
/// Resolves to a [`RotateRefreshTokenOutcome`] on the `Ok` side, so callers must inspect the
/// outcome in addition to handling errors.
fn rotate_refresh_token(
&self,
input: RotateRefreshToken,
) -> impl std::future::Future<Output = RepositoryResult<RotateRefreshTokenOutcome>> + Send;
/// Revokes the session `session_id` with the given `scope`.
fn revoke_session(
&self,
session_id: SessionId,
scope: RevokeSessionScope,
) -> impl std::future::Future<Output = RepositoryResult<()>> + Send;
/// Revokes the session family `family_id`.
fn revoke_family(
&self,
family_id: SessionFamilyId,
) -> impl std::future::Future<Output = RepositoryResult<()>> + Send;
/// Applies `touch` to a session.
///
/// NOTE(review): `SessionTouch` is defined in `crate::session`; what a touch updates is not
/// visible here.
fn touch_session(
&self,
touch: SessionTouch,
) -> impl std::future::Future<Output = RepositoryResult<()>> + Send;
}
#[cfg(test)]
mod tests {
    // Smoke checks: the contract types can be constructed and every trait method can be
    // implemented with `async fn` and awaited.

    use super::*;
    use crate::lease::{LeaseAcquisition, LeaseId, LeaseTtl, RenewalLease};
    use crate::session::{Session, SessionFamilyRecord};
    use std::time::{Duration, SystemTime};

    /// Stateless stub: every method resolves immediately to a fixed success value.
    #[derive(Debug, Clone, Copy, Default)]
    struct ContractRepository;

    impl SessionRepository for ContractRepository {
        async fn create_session(&self, _input: CreateSession) -> RepositoryResult<()> {
            Ok(())
        }

        async fn find_session_by_refresh_token_hash<'a>(
            &'a self,
            _refresh_token_hash: RefreshTokenHashRef<'a>,
        ) -> RepositoryResult<Option<SessionLookup>> {
            Ok(None)
        }

        async fn find_session(
            &self,
            _session_id: SessionId,
        ) -> RepositoryResult<Option<SessionRecord>> {
            Ok(None)
        }

        async fn find_family(
            &self,
            _family_id: SessionFamilyId,
        ) -> RepositoryResult<Option<SessionFamilyRecord>> {
            Ok(None)
        }

        async fn find_refresh_record(
            &self,
            _session_id: SessionId,
        ) -> RepositoryResult<Option<SessionRefreshRecord>> {
            Ok(None)
        }

        async fn try_acquire_renewal_lease(
            &self,
            _session_id: SessionId,
            lease: RenewalLease,
        ) -> RepositoryResult<LeaseAcquisition> {
            // Echo the requested lease back as acquired.
            let acquisition = LeaseAcquisition::Acquired(lease);
            Ok(acquisition)
        }

        async fn rotate_refresh_token(
            &self,
            _input: RotateRefreshToken,
        ) -> RepositoryResult<RotateRefreshTokenOutcome> {
            Ok(RotateRefreshTokenOutcome::Rotated)
        }

        async fn revoke_session(
            &self,
            _session_id: SessionId,
            _scope: RevokeSessionScope,
        ) -> RepositoryResult<()> {
            Ok(())
        }

        async fn revoke_family(&self, _family_id: SessionFamilyId) -> RepositoryResult<()> {
            Ok(())
        }

        async fn touch_session(&self, _touch: SessionTouch) -> RepositoryResult<()> {
            Ok(())
        }
    }

    /// Compile-time check that `T` implements [`SessionRepository`].
    fn assert_session_repository<T: SessionRepository>(_repository: &T) {}

    /// Fixed reference instant: 1000 seconds after the Unix epoch.
    fn sample_time() -> SystemTime {
        let offset = Duration::from_secs(1_000);
        SystemTime::UNIX_EPOCH + offset
    }

    /// Builds a refresh-token hash from `value`, panicking if construction fails.
    fn sample_hash(value: &str) -> RefreshTokenHash {
        RefreshTokenHash::new(value)
            .unwrap_or_else(|error| panic!("expected valid refresh-token hash: {error}"))
    }

    /// Builds a session in a fresh family using the reference instant and one hour after it.
    fn sample_session() -> SessionRecord {
        let now = sample_time();
        let one_hour_later = now + Duration::from_secs(3_600);
        Session::new(SessionFamilyId::new(), "subject-123", now, one_hour_later)
    }

    /// Builds a renewal lease for `session_id` from the reference instant and a 30 second TTL.
    fn sample_lease(session_id: SessionId) -> RenewalLease {
        let ttl = LeaseTtl::new(Duration::from_secs(30));
        RenewalLease::from_ttl(session_id, LeaseId::new(), sample_time(), ttl)
    }

    #[test]
    fn backend_error_constructor_keeps_message() {
        let expected = RepositoryError::Backend {
            message: String::from("safe backend summary"),
        };

        assert_eq!(RepositoryError::backend("safe backend summary"), expected);
    }

    #[tokio::test]
    async fn repository_trait_contracts_are_callable() {
        let repository = ContractRepository;
        assert_session_repository(&repository);

        // Fixtures.
        let session = sample_session();
        let session_id = session.session_id;
        let family_id = session.family_id;
        let family = SessionFamilyRecord::new(family_id, session.subject_id.clone(), sample_time());
        let lease = sample_lease(session_id);

        let create_input = CreateSession {
            session: session.clone(),
            refresh_token_hash: sample_hash("active-refresh-hash"),
        };
        let rotate_input = RotateRefreshToken {
            session_id,
            family,
            lease,
            previous_refresh_token_hash: sample_hash("previous-refresh-hash"),
            next_refresh_token_hash: sample_hash("next-refresh-hash"),
            next_session: session
                .clone()
                .touched(sample_time() + Duration::from_secs(10)),
        };
        let touch = SessionTouch::new(session_id, sample_time() + Duration::from_secs(20));

        // Exercise every trait method once, in declaration order.
        let created = repository.create_session(create_input).await;
        assert_eq!(created, Ok(()));

        let found_by_hash = repository
            .find_session_by_refresh_token_hash("lookup-refresh-hash")
            .await;
        assert!(matches!(found_by_hash, Ok(None)));

        let found_session = repository.find_session(session_id).await;
        assert!(matches!(found_session, Ok(None)));

        let found_family = repository.find_family(family_id).await;
        assert!(matches!(found_family, Ok(None)));

        let found_refresh_record = repository.find_refresh_record(session_id).await;
        assert!(matches!(found_refresh_record, Ok(None)));

        let acquisition = repository
            .try_acquire_renewal_lease(session_id, lease)
            .await;
        assert_eq!(acquisition, Ok(LeaseAcquisition::Acquired(lease)));

        let rotation = repository.rotate_refresh_token(rotate_input).await;
        assert_eq!(rotation, Ok(RotateRefreshTokenOutcome::Rotated));

        let session_revocation = repository
            .revoke_session(session_id, RevokeSessionScope::CurrentSession)
            .await;
        assert_eq!(session_revocation, Ok(()));

        let family_revocation = repository.revoke_family(family_id).await;
        assert_eq!(family_revocation, Ok(()));

        let touched = repository.touch_session(touch).await;
        assert_eq!(touched, Ok(()));
    }
}