use crate::authn::ids::{DeviceId, TenantId, UserId};
use crate::session::data::{AuthState, SessionData};
use axess_identity::define_id;
use axess_rng::SecureRng;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use chrono::{DateTime, Duration, Utc};
use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;
// Identifier of a single stored refresh-token record, declared through the
// `define_id!` macro from `axess_identity`. In this file it is constructed
// with `RefreshTokenId::try_new` from a UUID-formatted string.
define_id! {
    pub RefreshTokenId
}
// Identifier shared by every token in one rotation chain ("family"), declared
// through the `define_id!` macro from `axess_identity`. A token issued without
// an explicit family gets a family id derived from its own token id.
define_id! {
    pub TokenFamilyId
}
/// Stored record of a refresh token.
///
/// The record carries only the hash of the token value; the plaintext is
/// returned separately to the caller at issue/rotation time.
#[derive(Debug, Clone)]
pub struct RefreshToken {
    /// Identifier of this record (UUID-formatted string built from 16 random bytes).
    pub id: RefreshTokenId,
    /// User the token was issued to.
    pub user_id: UserId,
    /// Tenant the token was issued under.
    pub tenant_id: TenantId,
    /// Unpadded base64url hash of the plaintext token, as produced by `hash_token`.
    pub token_hash: String,
    /// Instant the token was built.
    pub issued_at: DateTime<Utc>,
    /// `issued_at` plus the configured TTL; a refresh is rejected once `now >= expires_at`.
    pub expires_at: DateTime<Utc>,
    /// `false` on newly built tokens. Presenting a revoked token to
    /// `refresh_session` is treated as a reuse/compromise signal.
    pub revoked: bool,
    /// Optional device descriptor. When present, `refresh_session` requires the
    /// caller to supply an equal value.
    pub device_info: Option<String>,
    /// Rotation family. Tokens built in this file always carry `Some(..)`.
    pub family_id: Option<TokenFamilyId>,
    /// Optional device binding, copied into the session created on refresh.
    pub device_id: Option<DeviceId>,
}
/// Tunables for issuing, hashing and rotating refresh tokens.
#[derive(Debug, Clone)]
pub struct RefreshTokenConfig {
    /// Lifetime added to the issue time to compute `expires_at`.
    pub ttl: Duration,
    /// Cap on active tokens per user; `issue_refresh_token` evicts tokens once
    /// the active count reaches this value.
    pub max_per_user: usize,
    /// When `true`, `refresh_session` issues a replacement token and rotates
    /// out the presented one.
    pub rotation: bool,
    /// Optional key for hashing token values. `None` or an empty vector selects
    /// plain SHA-256; a non-empty value selects the keyed MAC path.
    pub hash_pepper: Option<Vec<u8>>,
}
/// Defaults: 30-day lifetime, at most 10 tokens per user, rotation enabled,
/// no hashing pepper.
impl Default for RefreshTokenConfig {
    fn default() -> Self {
        let ttl = Duration::days(30);
        let max_per_user = 10;
        Self {
            ttl,
            max_per_user,
            rotation: true,
            hash_pepper: None,
        }
    }
}
/// Failures reported by the refresh-token operations in this module.
///
/// `E` is the backing store's error type (`RefreshTokenStore::Error`).
#[derive(Debug, thiserror::Error)]
pub enum RefreshError<E: std::error::Error + Send + Sync + 'static> {
    /// The store has no record matching the presented token's hash.
    #[error("refresh token not found")]
    NotFound,
    /// The current time is at or past the record's `expires_at`.
    #[error("refresh token expired")]
    Expired,
    /// The record is flagged as revoked.
    #[error("refresh token revoked")]
    Revoked,
    /// The record has device info and the caller supplied none or a different value.
    #[error("device mismatch")]
    DeviceMismatch,
    /// The caller-supplied status check returned `false` for the token's user.
    #[error("account is not active")]
    AccountInactive,
    /// The underlying store operation failed.
    #[error("store error: {0}")]
    Store(#[source] E),
}
/// Persistence backend for refresh tokens.
///
/// Every method returns a `Send` future. Atomicity and ordering guarantees are
/// up to the implementation and are not visible from this trait definition.
pub trait RefreshTokenStore: Send + Sync {
    /// Error type produced by the store; surfaced to callers as `RefreshError::Store`.
    type Error: std::error::Error + Send + Sync + 'static;
    /// Persists a token record. (Not called by the functions in this file.)
    fn store_token(
        &self,
        token: &RefreshToken,
    ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;
    /// Looks up a record by its token hash. Callers in this file treat
    /// `Ok(None)` as "not found".
    fn find_token(
        &self,
        token_hash: &str,
    ) -> impl std::future::Future<Output = Result<Option<RefreshToken>, Self::Error>> + Send;
    /// Revokes the single token with the given id.
    fn revoke_token(
        &self,
        token_id: &RefreshTokenId,
    ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;
    /// Revokes tokens belonging to `user_id`. (Not called by the functions in this file.)
    fn revoke_user_tokens(
        &self,
        user_id: &UserId,
    ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;
    /// Returns the user's active tokens.
    ///
    /// NOTE(review): `issue_refresh_token` evicts from the front of the
    /// returned list, so implementations presumably return oldest first — confirm.
    fn active_tokens(
        &self,
        user_id: &UserId,
    ) -> impl std::future::Future<Output = Result<Vec<RefreshToken>, Self::Error>> + Send;
    /// Revokes the tokens of `user_id` that belong to `family_id`. Called when a
    /// revoked token is presented again.
    fn revoke_family(
        &self,
        user_id: &UserId,
        family_id: &TokenFamilyId,
    ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;
    /// Stores `new_token` together with the eviction of `evict_ids`.
    ///
    /// NOTE(review): presumably expected to be atomic — confirm against implementations.
    fn issue_with_eviction(
        &self,
        evict_ids: &[RefreshTokenId],
        new_token: &RefreshToken,
    ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;
    /// Replaces the token `parent_id` with `new_token` during rotation.
    ///
    /// NOTE(review): presumably revokes the parent and stores the child in one
    /// step — confirm against implementations.
    fn rotate_token(
        &self,
        parent_id: &RefreshTokenId,
        new_token: &RefreshToken,
    ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;
    /// Hook called by `refresh_session` after a reused revoked token caused its
    /// family to be revoked. `compromised_devices` lists the (tenant, device)
    /// pairs gathered from the family.
    ///
    /// The default implementation only logs a warning.
    fn on_token_compromise(
        &self,
        user_id: &UserId,
        family_id: &TokenFamilyId,
        compromised_devices: &[(TenantId, DeviceId)],
    ) -> impl std::future::Future<Output = ()> + Send {
        // The warning is emitted when this method is called, not when the
        // returned future is awaited; the future itself does nothing.
        tracing::warn!(
            %user_id,
            %family_id,
            device_count = compromised_devices.len(),
            "refresh token compromise detected (rotated-out token was reused); \
override on_token_compromise + call cascade_revoke_devices to revoke linked devices"
        );
        async {}
    }
}
/// Hashes a plaintext token value for storage and lookup.
///
/// With a non-empty `pepper` the value is authenticated with the MAC built by
/// `crate::hmac::new_signer`; with `None` or an empty pepper it is hashed with
/// plain SHA-256. Either way the result is encoded as unpadded URL-safe base64.
fn hash_token(plaintext: &str, pepper: Option<&[u8]>) -> String {
    let input = plaintext.as_bytes();

    // An empty pepper is treated exactly like no pepper at all.
    if let Some(key) = pepper.filter(|p| !p.is_empty()) {
        use hmac::Mac;
        let mut signer = crate::hmac::new_signer(key);
        signer.update(input);
        return URL_SAFE_NO_PAD.encode(signer.finalize().into_bytes());
    }

    URL_SAFE_NO_PAD.encode(Sha256::digest(input))
}
/// Draws 32 bytes from `rng` and returns them as an unpadded URL-safe base64
/// string (43 characters).
fn generate_token_value(rng: &impl SecureRng) -> String {
    let mut entropy = [0u8; 32];
    rng.fill_bytes(&mut entropy);
    let encoded = URL_SAFE_NO_PAD.encode(entropy);
    encoded
}
/// Inputs describing the token to issue.
pub struct IssueRequest<'a> {
    /// User the token is issued to.
    pub user_id: &'a UserId,
    /// Tenant the token is issued under.
    pub tenant_id: &'a TenantId,
    /// Optional device descriptor stored on the record and compared on refresh.
    pub device_info: Option<String>,
    /// Existing rotation family to join; `None` starts a new family whose id is
    /// derived from the new token's id.
    pub family_id: Option<TokenFamilyId>,
    /// Optional device binding stored on the record.
    pub device_id: Option<DeviceId>,
}
/// Issues a new refresh token for the user in `req`, evicting existing tokens
/// when the per-user cap would otherwise be exceeded.
///
/// When the user already holds `config.max_per_user` or more active tokens,
/// enough tokens are taken from the front of the store's `active_tokens` list
/// to leave room for the new one. The eviction ids and the new record are
/// handed to the store in a single `issue_with_eviction` call.
///
/// NOTE(review): taking from the front assumes the store returns the oldest
/// tokens first — confirm against the store implementations.
///
/// Returns the plaintext token (to give to the client) and the stored record.
///
/// # Errors
///
/// Returns `RefreshError::Store` if either store call fails.
pub async fn issue_refresh_token<S: RefreshTokenStore>(
    req: IssueRequest<'_>,
    config: &RefreshTokenConfig,
    store: &S,
    rng: &impl SecureRng,
    now: DateTime<Utc>,
) -> Result<(String, RefreshToken), RefreshError<S::Error>> {
    let live = match store.active_tokens(req.user_id).await {
        Ok(tokens) => tokens,
        Err(e) => return Err(RefreshError::Store(e)),
    };

    // Number of tokens to drop so that, after adding one, the cap holds.
    let surplus = live
        .len()
        .checked_sub(config.max_per_user)
        .map_or(0, |over| over + 1);
    let evict_ids: Vec<RefreshTokenId> = live.iter().take(surplus).map(|t| t.id).collect();

    let (plaintext, record) = build_refresh_token(req, config, rng, now);
    if let Err(e) = store.issue_with_eviction(&evict_ids, &record).await {
        return Err(RefreshError::Store(e));
    }
    Ok((plaintext, record))
}
/// Validates a presented refresh token and builds a fresh authenticated session.
///
/// Checks run in this order:
///
/// 1. The plaintext is hashed with `hash_token` and looked up in the store.
/// 2. A revoked record is treated as reuse of a rotated-out token: if it has a
///    family, the family's device targets are collected, the family is
///    revoked, and `on_token_compromise` is invoked. The call then fails with
///    `RefreshError::Revoked`.
/// 3. The token must not be expired (`now >= expires_at` is rejected).
/// 4. If the record carries device info, the caller must supply an equal value.
///
/// On success the session is `Authenticated` for the record's user and tenant
/// with `authn_time = now`, no completed factors, no fingerprint, and the
/// record's `device_id`. When `config.rotation` is enabled, a replacement
/// token in the same family is built and swapped in via `rotate_token`; its
/// plaintext and record are returned as the second tuple element.
///
/// # Errors
///
/// * `RefreshError::NotFound` — no record matches the token hash.
/// * `RefreshError::Revoked` — the record is revoked.
/// * `RefreshError::Expired` — the record has expired.
/// * `RefreshError::DeviceMismatch` — device info is missing or differs.
/// * `RefreshError::Store` — a store call failed.
pub async fn refresh_session<S: RefreshTokenStore>(
    plaintext: &str,
    store: &S,
    config: &RefreshTokenConfig,
    rng: &impl SecureRng,
    now: DateTime<Utc>,
    device_info: Option<&str>,
) -> Result<(SessionData, Option<(String, RefreshToken)>), RefreshError<S::Error>> {
    let token_hash = hash_token(plaintext, config.hash_pepper.as_deref());
    let record = store
        .find_token(&token_hash)
        .await
        .map_err(RefreshError::Store)?
        .ok_or(RefreshError::NotFound)?;

    if record.revoked {
        if let Some(ref fid) = record.family_id {
            tracing::warn!(
                family_id = %fid,
                user_id = %record.user_id,
                "revoked refresh token reused; revoking entire family (compromise signal)"
            );
            // Collect device targets before revoking: the helper enumerates
            // the family's still-active tokens.
            let compromised_devices = collect_family_device_targets(store, fid, &record).await;
            store
                .revoke_family(&record.user_id, fid)
                .await
                .map_err(RefreshError::Store)?;
            store
                .on_token_compromise(&record.user_id, fid, &compromised_devices)
                .await;
        }
        return Err(RefreshError::Revoked);
    }

    if now >= record.expires_at {
        return Err(RefreshError::Expired);
    }

    if let Some(ref stored_device) = record.device_info {
        let Some(current) = device_info else {
            return Err(RefreshError::DeviceMismatch);
        };
        // Compare fixed-length digests in constant time instead of comparing
        // the raw strings directly.
        let stored_hash = Sha256::digest(stored_device.as_bytes());
        let current_hash = Sha256::digest(current.as_bytes());
        if !bool::from(stored_hash.ct_eq(&current_hash)) {
            return Err(RefreshError::DeviceMismatch);
        }
    }

    let session = SessionData {
        version: crate::session::data::SESSION_DATA_VERSION,
        auth_state: AuthState::Authenticated {
            user_id: record.user_id,
            tenant_id: record.tenant_id,
            authn_time: now,
            factors_completed: Vec::new(),
        },
        fingerprint: None,
        device_id: record.device_id,
        custom: serde_json::Value::default(),
    };

    let new_token = if config.rotation {
        // The replacement inherits device info, family and device binding from
        // the presented token.
        let (new_plaintext, new_record) = build_refresh_token(
            IssueRequest {
                user_id: &record.user_id,
                tenant_id: &record.tenant_id,
                device_info: record.device_info.clone(),
                family_id: record.family_id,
                device_id: record.device_id,
            },
            config,
            rng,
            now,
        );
        store
            .rotate_token(&record.id, &new_record)
            .await
            .map_err(RefreshError::Store)?;
        Some((new_plaintext, new_record))
    } else {
        None
    };

    Ok((session, new_token))
}
/// Like `refresh_session`, but first consults a caller-supplied account status
/// check for tokens that are neither revoked nor expired.
///
/// The token is looked up once here to decide whether the status check should
/// run, then the full flow is delegated to `refresh_session` (which performs
/// its own lookup). Revoked tokens skip the status check so the reuse handling
/// in `refresh_session` always runs.
///
/// `status_check` receives the token's user id and resolves to `true` when the
/// refresh may proceed.
///
/// # Errors
///
/// Everything `refresh_session` can return, plus
/// `RefreshError::AccountInactive` when `status_check` resolves to `false`.
pub async fn refresh_session_with_status_check<S, FStat, Fut>(
    plaintext: &str,
    store: &S,
    config: &RefreshTokenConfig,
    rng: &impl SecureRng,
    now: DateTime<Utc>,
    device_info: Option<&str>,
    status_check: FStat,
) -> Result<(SessionData, Option<(String, RefreshToken)>), RefreshError<S::Error>>
where
    S: RefreshTokenStore,
    FStat: FnOnce(&UserId) -> Fut,
    Fut: std::future::Future<Output = bool>,
{
    let token_hash = hash_token(plaintext, config.hash_pepper.as_deref());
    let record = match store.find_token(&token_hash).await {
        Ok(Some(found)) => found,
        Ok(None) => return Err(RefreshError::NotFound),
        Err(e) => return Err(RefreshError::Store(e)),
    };

    // Only live (non-revoked) tokens go through the expiry and status gates;
    // revoked ones fall straight through to the delegated flow.
    if !record.revoked {
        if record.expires_at <= now {
            return Err(RefreshError::Expired);
        }
        let account_ok = status_check(&record.user_id).await;
        if !account_ok {
            tracing::warn!(
                user_id = %record.user_id,
                "refresh refused; caller-supplied status check returned false"
            );
            return Err(RefreshError::AccountInactive);
        }
    }

    refresh_session(plaintext, store, config, rng, now, device_info).await
}
/// Gathers the distinct `(tenant, device)` pairs tied to a token family.
///
/// The presented record's own device (if any) comes first, followed by the
/// devices of the user's active tokens that belong to `family_id`. If the
/// active tokens cannot be listed, the failure is logged and only the
/// presented record's device is returned.
async fn collect_family_device_targets<S: RefreshTokenStore>(
    store: &S,
    family_id: &TokenFamilyId,
    seen_record: &RefreshToken,
) -> Vec<(TenantId, DeviceId)> {
    let mut targets: Vec<(TenantId, DeviceId)> = seen_record
        .device_id
        .as_ref()
        .map(|did| (seen_record.tenant_id, *did))
        .into_iter()
        .collect();

    let siblings = match store.active_tokens(&seen_record.user_id).await {
        Ok(list) => list,
        Err(e) => {
            // Best effort: enumeration failure must not block the revocation.
            tracing::warn!(
                error = %e,
                family_id = %family_id,
                user_id = %seen_record.user_id,
                "failed to enumerate active siblings for device cascade; \
                 family revocation will proceed with the seen-record device only"
            );
            return targets;
        }
    };

    let family_pairs = siblings
        .into_iter()
        .filter(|t| t.family_id.as_ref() == Some(family_id))
        .filter_map(|t| {
            let tenant = t.tenant_id;
            t.device_id.map(|did| (tenant, did))
        });
    for pair in family_pairs {
        if !targets.contains(&pair) {
            targets.push(pair);
        }
    }
    targets
}
fn build_refresh_token(
req: IssueRequest<'_>,
config: &RefreshTokenConfig,
rng: &impl SecureRng,
now: DateTime<Utc>,
) -> (String, RefreshToken) {
let plaintext = generate_token_value(rng);
let token_hash = hash_token(&plaintext, config.hash_pepper.as_deref());
let mut id_bytes = [0u8; 16];
rng.fill_bytes(&mut id_bytes);
let id = RefreshTokenId::try_new(uuid::Uuid::from_bytes(id_bytes).to_string())
.expect("UUID is non-empty");
let effective_family = req.family_id.unwrap_or_else(|| {
TokenFamilyId::try_new(id.to_string()).expect("UUID-derived family id is non-empty")
});
let record = RefreshToken {
id,
user_id: *req.user_id,
tenant_id: *req.tenant_id,
token_hash,
issued_at: now,
expires_at: now + config.ttl,
revoked: false,
device_info: req.device_info,
family_id: Some(effective_family),
device_id: req.device_id,
};
(plaintext, record)
}
/// Revokes the single refresh token identified by its plaintext value.
///
/// The plaintext is hashed with the configured pepper, the matching record is
/// looked up, and that record's id is passed to `revoke_token`.
///
/// # Errors
///
/// * `RefreshError::NotFound` — no record matches the token hash.
/// * `RefreshError::Store` — a store call failed.
pub async fn revoke_refresh_token<S: RefreshTokenStore>(
    plaintext: &str,
    store: &S,
    config: &RefreshTokenConfig,
) -> Result<(), RefreshError<S::Error>> {
    let token_hash = hash_token(plaintext, config.hash_pepper.as_deref());
    let found = store
        .find_token(&token_hash)
        .await
        .map_err(RefreshError::Store)?;
    let Some(record) = found else {
        return Err(RefreshError::NotFound);
    };
    store
        .revoke_token(&record.id)
        .await
        .map_err(RefreshError::Store)
}
// Out-of-line test modules, compiled only for test builds.
#[cfg(test)]
mod atomic_rotation;
#[cfg(test)]
mod basics;
// Additionally requires the `device` feature.
#[cfg(all(test, feature = "device"))]
mod device_cascade;
#[cfg(test)]
mod status_check;
#[cfg(test)]
mod test_support;
// Unit tests for the pure helpers in this module (hashing, token-value
// generation, config defaults and error display).
#[cfg(test)]
mod refresh_unit_tests {
    use super::*;
    use crate::testing::mock_random::MockRng;

    /// Without a pepper the hash is deterministic and equals the unpadded
    /// base64url encoding of the SHA-256 digest of the input.
    #[test]
    fn hash_token_no_pepper_is_sha256() {
        let h1 = hash_token("test-token-value", None);
        let h2 = hash_token("test-token-value", None);
        assert_eq!(h1, h2, "hash must be deterministic");
        assert!(!h1.is_empty());
        // Pin the algorithm, not just determinism.
        let expected = URL_SAFE_NO_PAD.encode(Sha256::digest(b"test-token-value"));
        assert_eq!(h1, expected, "unpeppered hash must be plain SHA-256");
        assert_eq!(h1.len(), 43, "32-byte digest → 43 base64url chars");
    }

    /// A pepper must change the resulting hash.
    #[test]
    fn hash_token_with_pepper_differs_from_without() {
        let without = hash_token("test-token", None);
        let with = hash_token("test-token", Some(b"my-secret-pepper"));
        assert_ne!(without, with, "peppered hash must differ from unpeppered");
    }

    /// An empty pepper behaves exactly like no pepper.
    #[test]
    fn hash_token_empty_pepper_falls_back_to_sha256() {
        let none_pepper = hash_token("token", None);
        let empty_pepper = hash_token("token", Some(b""));
        assert_eq!(
            none_pepper, empty_pepper,
            "empty pepper must fall back to SHA-256 path"
        );
    }

    /// Different inputs hash to different values.
    #[test]
    fn hash_token_different_inputs_different_hashes() {
        let h1 = hash_token("token-a", None);
        let h2 = hash_token("token-b", None);
        assert_ne!(h1, h2);
    }

    /// Different peppers hash the same input to different values.
    #[test]
    fn hash_token_different_peppers_different_hashes() {
        let h1 = hash_token("same-token", Some(b"pepper-1"));
        let h2 = hash_token("same-token", Some(b"pepper-2"));
        assert_ne!(h1, h2);
    }

    /// Token values are 43-character strings over the URL-safe base64 alphabet.
    #[test]
    fn generate_token_value_produces_nonempty_base64() {
        let rng = MockRng::new(42);
        let token = generate_token_value(&rng);
        assert!(!token.is_empty());
        assert_eq!(token.len(), 43, "32 bytes → 43 base64url chars");
        assert!(
            token
                .bytes()
                .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_'),
            "token must use only the URL-safe base64 alphabet, got {token:?}"
        );
    }

    /// The same RNG seed yields the same token value.
    #[test]
    fn generate_token_value_is_deterministic() {
        let t1 = generate_token_value(&MockRng::new(99));
        let t2 = generate_token_value(&MockRng::new(99));
        assert_eq!(t1, t2);
    }

    /// Different RNG seeds yield different token values.
    #[test]
    fn generate_token_value_different_seeds_differ() {
        let t1 = generate_token_value(&MockRng::new(1));
        let t2 = generate_token_value(&MockRng::new(2));
        assert_ne!(t1, t2);
    }

    /// Pins the documented defaults of `RefreshTokenConfig`.
    #[test]
    fn default_config_has_safe_defaults() {
        let cfg = RefreshTokenConfig::default();
        assert_eq!(cfg.ttl, Duration::days(30));
        assert_eq!(cfg.max_per_user, 10);
        assert!(cfg.rotation, "rotation must be on by default");
        assert!(cfg.hash_pepper.is_none(), "no pepper by default");
    }

    /// Every error variant renders a message containing its key phrase.
    #[test]
    fn refresh_error_display_per_variant() {
        use std::io;
        let variants: Vec<(RefreshError<io::Error>, &str)> = vec![
            (RefreshError::NotFound, "not found"),
            (RefreshError::Expired, "expired"),
            (RefreshError::Revoked, "revoked"),
            (RefreshError::DeviceMismatch, "device mismatch"),
            (RefreshError::AccountInactive, "not active"),
            (RefreshError::Store(io::Error::other("boom")), "store error"),
        ];
        for (err, expected_substr) in variants {
            let msg = err.to_string();
            assert!(
                msg.contains(expected_substr),
                "RefreshError display for {:?} must contain {expected_substr:?}, got {msg:?}",
                std::mem::discriminant(&err)
            );
        }
    }
}