use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use uuid::Uuid;
use crate::error::{FusekiError, FusekiResult};
/// Errors returned by `RefreshTokenStore::rotate`.
///
/// The `#[error(...)]` strings are the `Display` text; converting into `FusekiError`
/// forwards that text as the message of an authentication error.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum RotationError {
/// The presented token string is not in the store (never issued by it, or already
/// removed by `cleanup_expired`).
#[error("Refresh token not found")]
TokenNotFound,
/// The token's `expires_at` instant has been reached.
#[error("Refresh token has expired")]
TokenExpired,
/// The token was already exchanged for a successor. `rotate` treats this as a replay
/// and revokes every token in the family before returning this error.
#[error("Refresh token has already been rotated (possible replay attack)")]
TokenAlreadyRotated,
/// The token was revoked, individually, for its whole user, or as part of a family
/// cascade. `rotate` also returns this when generating the successor token fails.
#[error("Refresh token has been revoked")]
TokenRevoked,
}
impl From<RotationError> for FusekiError {
fn from(e: RotationError) -> Self {
FusekiError::authentication(e.to_string())
}
}
/// A single refresh token record as kept by `RefreshTokenStore`.
///
/// NOTE(review): the derived `Debug` and `Serialize` impls include the raw `token`
/// secret; confirm no caller logs or persists these values unintentionally.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RefreshToken {
/// The opaque secret presented by clients; also the key in the store's token map.
pub token: String,
/// Identifier of the user the token was issued to.
pub user_id: String,
/// Rotation family: shared by a freshly issued token and every successor rotated from it.
pub family_id: Uuid,
/// When this record was created.
pub issued_at: DateTime<Utc>,
/// Instant from which the token is considered expired (`is_expired` uses `>=`).
pub expires_at: DateTime<Utc>,
/// Set once the token has been exchanged for a successor via `rotate`.
pub rotated: bool,
/// Set by explicit, user-wide, or family-cascade revocation.
pub revoked: bool,
/// Hex SHA-256 of the predecessor's token string; `None` for a freshly issued token.
pub previous_token_hash: Option<String>,
/// 0 for a freshly issued token, incremented by one on each rotation.
pub generation: u32,
}
impl RefreshToken {
    /// Returns `true` while the token may still be exchanged: it has been neither
    /// rotated nor revoked, and its expiry instant is still in the future.
    pub fn is_valid(&self) -> bool {
        let consumed_or_revoked = self.rotated || self.revoked;
        !consumed_or_revoked && !self.is_expired()
    }

    /// Returns `true` once the current time has reached or passed `expires_at`.
    pub fn is_expired(&self) -> bool {
        self.expires_at <= Utc::now()
    }
}
/// Produces a new random token string: 32 bytes from `SecureRandom`, encoded as
/// URL-safe base64 without padding (so it can travel in URLs and headers unescaped).
///
/// Currently never fails; the `FusekiResult` return leaves room for fallible RNGs.
fn generate_token_string() -> FusekiResult<String> {
    use scirs2_core::random::SecureRandom;
    let raw = SecureRandom::new().random_bytes(32);
    let encoded = URL_SAFE_NO_PAD.encode(&raw);
    Ok(encoded)
}
/// Returns the hex-encoded SHA-256 digest of `input`'s UTF-8 bytes.
fn sha256_hex(input: &str) -> String {
    let mut hasher = Sha256::new();
    hasher.update(input.as_bytes());
    hex::encode(hasher.finalize())
}
/// In-memory refresh token store with secondary lookups by rotation family and by user.
///
/// Every field is an `Arc`, so a clone shares the same underlying maps as the original.
#[derive(Clone, Debug)]
pub struct RefreshTokenStore {
/// Every stored token record, keyed by its token string.
tokens: Arc<RwLock<HashMap<String, RefreshToken>>>,
/// Token strings belonging to each rotation family, in insertion order.
family_index: Arc<RwLock<HashMap<Uuid, Vec<String>>>>,
/// Token strings issued to each user id, in insertion order.
user_index: Arc<RwLock<HashMap<String, Vec<String>>>>,
}
impl RefreshTokenStore {
pub fn new() -> Self {
Self {
tokens: Arc::new(RwLock::new(HashMap::new())),
family_index: Arc::new(RwLock::new(HashMap::new())),
user_index: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn issue(&self, user_id: &str, ttl: Duration) -> FusekiResult<RefreshToken> {
self.issue_in_family(user_id, ttl, Uuid::new_v4(), None, 0)
.await
}
async fn issue_in_family(
&self,
user_id: &str,
ttl: Duration,
family_id: Uuid,
previous_token_hash: Option<String>,
generation: u32,
) -> FusekiResult<RefreshToken> {
let token_str = generate_token_string()?;
let now = Utc::now();
let token = RefreshToken {
token: token_str.clone(),
user_id: user_id.to_string(),
family_id,
issued_at: now,
expires_at: now + ttl,
rotated: false,
revoked: false,
previous_token_hash,
generation,
};
{
let mut tokens = self.tokens.write().await;
tokens.insert(token_str.clone(), token.clone());
}
{
let mut family = self.family_index.write().await;
family.entry(family_id).or_default().push(token_str.clone());
}
{
let mut user_idx = self.user_index.write().await;
user_idx
.entry(user_id.to_string())
.or_default()
.push(token_str.clone());
}
debug!(
user_id = %user_id,
family_id = %family_id,
generation = generation,
expires_at = %token.expires_at,
"Issued refresh token",
);
Ok(token)
}
pub async fn rotate(
&self,
old_token_str: &str,
ttl: Duration,
) -> Result<RefreshToken, RotationError> {
let (family_id, user_id, generation, already_rotated, revoked, expired) = {
let tokens = self.tokens.read().await;
match tokens.get(old_token_str) {
None => return Err(RotationError::TokenNotFound),
Some(t) => (
t.family_id,
t.user_id.clone(),
t.generation,
t.rotated,
t.revoked,
t.is_expired(),
),
}
};
if revoked {
return Err(RotationError::TokenRevoked);
}
if expired {
return Err(RotationError::TokenExpired);
}
if already_rotated {
warn!(
family_id = %family_id,
user_id = %user_id,
"Replay attack detected: presented already-rotated refresh token; \
cascade-revoking all tokens in family",
);
self.cascade_revoke_family(family_id).await;
return Err(RotationError::TokenAlreadyRotated);
}
{
let mut tokens = self.tokens.write().await;
if let Some(t) = tokens.get_mut(old_token_str) {
t.rotated = true;
}
}
let prev_hash = sha256_hex(old_token_str);
let new_token = self
.issue_in_family(&user_id, ttl, family_id, Some(prev_hash), generation + 1)
.await
.map_err(|e| {
warn!("Failed to issue successor token during rotation: {e}");
RotationError::TokenRevoked
})?;
info!(
user_id = %user_id,
family_id = %family_id,
new_generation = new_token.generation,
"Refresh token rotated successfully",
);
Ok(new_token)
}
pub async fn validate(&self, token_str: &str) -> Option<RefreshToken> {
let tokens = self.tokens.read().await;
tokens
.get(token_str)
.and_then(|t| if t.is_valid() { Some(t.clone()) } else { None })
}
pub async fn revoke(&self, token_str: &str) {
let mut tokens = self.tokens.write().await;
if let Some(t) = tokens.get_mut(token_str) {
t.revoked = true;
debug!(
token_prefix = &token_str[..8.min(token_str.len())],
"Refresh token revoked"
);
}
}
pub async fn revoke_all_for_user(&self, user_id: &str) -> usize {
let token_strings: Vec<String> = {
let user_idx = self.user_index.read().await;
user_idx.get(user_id).cloned().unwrap_or_default()
};
let count = token_strings.len();
if count == 0 {
debug!(user_id = %user_id, "revoke_all_for_user: no tokens found");
return 0;
}
{
let mut tokens = self.tokens.write().await;
for token_str in &token_strings {
if let Some(t) = tokens.get_mut(token_str) {
t.revoked = true;
}
}
}
info!(user_id = %user_id, count = count, "Revoked all refresh tokens for user");
count
}
async fn cascade_revoke_family(&self, family_id: Uuid) {
let token_strings: Vec<String> = {
let family = self.family_index.read().await;
family.get(&family_id).cloned().unwrap_or_default()
};
let count = token_strings.len();
{
let mut tokens = self.tokens.write().await;
for token_str in &token_strings {
if let Some(t) = tokens.get_mut(token_str) {
t.revoked = true;
}
}
}
warn!(
family_id = %family_id,
count = count,
"Cascade-revoked all tokens in family due to replay attack",
);
}
pub async fn cleanup_expired(&self, grace_period: Duration) -> usize {
let cutoff = Utc::now() - grace_period;
let mut removed = 0usize;
let stale_tokens: Vec<String> = {
let tokens = self.tokens.read().await;
tokens
.iter()
.filter(|(_, t)| t.expires_at < cutoff)
.map(|(k, _)| k.clone())
.collect()
};
{
let mut tokens = self.tokens.write().await;
for key in &stale_tokens {
tokens.remove(key);
removed += 1;
}
}
if removed > 0 {
let stale_set: std::collections::HashSet<&str> =
stale_tokens.iter().map(String::as_str).collect();
{
let mut family = self.family_index.write().await;
for vec in family.values_mut() {
vec.retain(|t| !stale_set.contains(t.as_str()));
}
family.retain(|_, v| !v.is_empty());
}
{
let mut user_idx = self.user_index.write().await;
for vec in user_idx.values_mut() {
vec.retain(|t| !stale_set.contains(t.as_str()));
}
user_idx.retain(|_, v| !v.is_empty());
}
debug!(removed = removed, "Cleaned up expired refresh tokens");
}
removed
}
pub async fn len(&self) -> usize {
self.tokens.read().await.len()
}
pub async fn is_empty(&self) -> bool {
self.tokens.read().await.is_empty()
}
}
#[cfg(test)]
// Unit tests for RefreshTokenStore: issuance, validation, rotation, replay detection,
// revocation, concurrent access and expiry cleanup.
mod tests {
use super::*;
use tokio::time::sleep;
// Long enough that tokens issued with it never expire during a test.
const DEFAULT_TTL: Duration = Duration::hours(1);
// Short enough to have expired after the 100 ms sleeps used by the expiry tests.
const SHORT_TTL: Duration = Duration::milliseconds(50);
#[tokio::test]
async fn test_issue_returns_valid_token() {
let store = RefreshTokenStore::new();
let token = store.issue("alice", DEFAULT_TTL).await.unwrap();
assert_eq!(token.user_id, "alice");
assert!(!token.token.is_empty());
assert_eq!(token.generation, 0);
assert!(token.previous_token_hash.is_none());
assert!(!token.rotated);
assert!(!token.revoked);
assert!(token.is_valid());
}
#[tokio::test]
async fn test_validate_active_token_succeeds() {
let store = RefreshTokenStore::new();
let issued = store.issue("bob", DEFAULT_TTL).await.unwrap();
let found = store.validate(&issued.token).await;
assert!(found.is_some());
assert_eq!(found.unwrap().user_id, "bob");
}
#[tokio::test]
async fn test_validate_nonexistent_token_returns_none() {
let store = RefreshTokenStore::new();
let found = store.validate("does-not-exist").await;
assert!(found.is_none());
}
#[tokio::test]
async fn test_issued_token_has_correct_expiry() {
let store = RefreshTokenStore::new();
let ttl = Duration::seconds(300);
// The issuance timestamp is taken inside issue(), so expires_at must fall between
// before + ttl and after + ttl.
let before = Utc::now();
let token = store.issue("carol", ttl).await.unwrap();
let after = Utc::now();
assert!(token.expires_at >= before + ttl);
assert!(token.expires_at <= after + ttl);
}
#[tokio::test]
async fn test_rotate_produces_new_valid_token() {
let store = RefreshTokenStore::new();
let old = store.issue("alice", DEFAULT_TTL).await.unwrap();
let new_token = store.rotate(&old.token, DEFAULT_TTL).await.unwrap();
assert_ne!(new_token.token, old.token);
assert_eq!(new_token.user_id, "alice");
assert_eq!(new_token.family_id, old.family_id);
assert_eq!(new_token.generation, 1);
assert!(new_token.is_valid());
}
#[tokio::test]
async fn test_rotate_invalidates_old_token() {
let store = RefreshTokenStore::new();
let old = store.issue("alice", DEFAULT_TTL).await.unwrap();
let _new_token = store.rotate(&old.token, DEFAULT_TTL).await.unwrap();
let found = store.validate(&old.token).await;
assert!(found.is_none());
}
#[tokio::test]
async fn test_rotate_sets_previous_token_hash() {
let store = RefreshTokenStore::new();
let old = store.issue("alice", DEFAULT_TTL).await.unwrap();
let expected_hash = sha256_hex(&old.token);
let new_token = store.rotate(&old.token, DEFAULT_TTL).await.unwrap();
assert_eq!(
new_token.previous_token_hash.as_deref(),
Some(expected_hash.as_str())
);
}
#[tokio::test]
async fn test_chained_rotation_increments_generation() {
let store = RefreshTokenStore::new();
let t0 = store.issue("alice", DEFAULT_TTL).await.unwrap();
let t1 = store.rotate(&t0.token, DEFAULT_TTL).await.unwrap();
let t2 = store.rotate(&t1.token, DEFAULT_TTL).await.unwrap();
let t3 = store.rotate(&t2.token, DEFAULT_TTL).await.unwrap();
assert_eq!(t3.generation, 3);
assert_eq!(t3.family_id, t0.family_id);
}
#[tokio::test]
async fn test_rotate_nonexistent_token_returns_not_found() {
let store = RefreshTokenStore::new();
let result = store.rotate("no-such-token", DEFAULT_TTL).await;
assert_eq!(result.unwrap_err(), RotationError::TokenNotFound);
}
#[tokio::test]
async fn test_replay_attack_detected_on_rotated_token() {
let store = RefreshTokenStore::new();
let old = store.issue("alice", DEFAULT_TTL).await.unwrap();
let _new_token = store.rotate(&old.token, DEFAULT_TTL).await.unwrap();
// Presenting the already-rotated token a second time is a replay.
let result = store.rotate(&old.token, DEFAULT_TTL).await;
assert_eq!(result.unwrap_err(), RotationError::TokenAlreadyRotated);
}
#[tokio::test]
async fn test_replay_attack_cascade_revokes_all_family_tokens() {
let store = RefreshTokenStore::new();
let t0 = store.issue("alice", DEFAULT_TTL).await.unwrap();
let t1 = store.rotate(&t0.token, DEFAULT_TTL).await.unwrap();
let t2 = store.rotate(&t1.token, DEFAULT_TTL).await.unwrap();
// Replaying t0 must revoke every generation of the family, including the live t2.
let _ = store.rotate(&t0.token, DEFAULT_TTL).await;
assert!(store.validate(&t0.token).await.is_none());
assert!(store.validate(&t1.token).await.is_none());
assert!(store.validate(&t2.token).await.is_none());
}
#[tokio::test]
async fn test_replay_attack_does_not_affect_other_user_tokens() {
let store = RefreshTokenStore::new();
let alice_t0 = store.issue("alice", DEFAULT_TTL).await.unwrap();
let _alice_t1 = store.rotate(&alice_t0.token, DEFAULT_TTL).await.unwrap();
let bob_token = store.issue("bob", DEFAULT_TTL).await.unwrap();
let _ = store.rotate(&alice_t0.token, DEFAULT_TTL).await;
assert!(store.validate(&bob_token.token).await.is_some());
}
#[tokio::test]
async fn test_expired_token_fails_validate() {
let store = RefreshTokenStore::new();
let token = store.issue("alice", SHORT_TTL).await.unwrap();
// Sleep past SHORT_TTL so the token is expired.
sleep(std::time::Duration::from_millis(100)).await;
assert!(store.validate(&token.token).await.is_none());
}
#[tokio::test]
async fn test_expired_token_fails_rotation() {
let store = RefreshTokenStore::new();
let token = store.issue("alice", SHORT_TTL).await.unwrap();
sleep(std::time::Duration::from_millis(100)).await;
let result = store.rotate(&token.token, DEFAULT_TTL).await;
assert_eq!(result.unwrap_err(), RotationError::TokenExpired);
}
#[tokio::test]
async fn test_revoke_single_token() {
let store = RefreshTokenStore::new();
let token = store.issue("alice", DEFAULT_TTL).await.unwrap();
store.revoke(&token.token).await;
assert!(store.validate(&token.token).await.is_none());
}
#[tokio::test]
async fn test_revoked_token_cannot_be_rotated() {
let store = RefreshTokenStore::new();
let token = store.issue("alice", DEFAULT_TTL).await.unwrap();
store.revoke(&token.token).await;
let result = store.rotate(&token.token, DEFAULT_TTL).await;
assert_eq!(result.unwrap_err(), RotationError::TokenRevoked);
}
#[tokio::test]
async fn test_revoke_all_for_user_invalidates_all_sessions() {
let store = RefreshTokenStore::new();
// Three independent issue() calls: three separate families for the same user.
let s1 = store.issue("alice", DEFAULT_TTL).await.unwrap();
let s2 = store.issue("alice", DEFAULT_TTL).await.unwrap();
let s3 = store.issue("alice", DEFAULT_TTL).await.unwrap();
let count = store.revoke_all_for_user("alice").await;
assert_eq!(count, 3);
assert!(store.validate(&s1.token).await.is_none());
assert!(store.validate(&s2.token).await.is_none());
assert!(store.validate(&s3.token).await.is_none());
}
#[tokio::test]
async fn test_revoke_all_for_user_does_not_affect_other_users() {
let store = RefreshTokenStore::new();
let _alice_token = store.issue("alice", DEFAULT_TTL).await.unwrap();
let bob_token = store.issue("bob", DEFAULT_TTL).await.unwrap();
store.revoke_all_for_user("alice").await;
assert!(store.validate(&bob_token.token).await.is_some());
}
#[tokio::test]
async fn test_revoke_all_for_nonexistent_user_returns_zero() {
let store = RefreshTokenStore::new();
let count = store.revoke_all_for_user("nobody").await;
assert_eq!(count, 0);
}
#[tokio::test]
// Ten spawned tasks race to rotate the same token; exactly one may succeed. The others
// may see TokenAlreadyRotated (counted but not asserted) or another error, such as
// TokenRevoked once the family has been cascade-revoked.
async fn test_concurrent_rotation_only_one_succeeds() {
use std::sync::atomic::{AtomicUsize, Ordering};
let store = Arc::new(RefreshTokenStore::new());
let token = store.issue("alice", DEFAULT_TTL).await.unwrap();
let successes = Arc::new(AtomicUsize::new(0));
let replay_errors = Arc::new(AtomicUsize::new(0));
let mut handles = vec![];
for _ in 0..10 {
let store_clone = Arc::clone(&store);
let token_str = token.token.clone();
let successes_clone = Arc::clone(&successes);
let replay_clone = Arc::clone(&replay_errors);
handles.push(tokio::spawn(async move {
match store_clone.rotate(&token_str, DEFAULT_TTL).await {
Ok(_) => {
successes_clone.fetch_add(1, Ordering::Relaxed);
}
Err(RotationError::TokenAlreadyRotated) => {
replay_clone.fetch_add(1, Ordering::Relaxed);
}
Err(_) => {}
}
}));
}
for h in handles {
h.await.unwrap();
}
let total_successes = successes.load(Ordering::Relaxed);
assert_eq!(
total_successes, 1,
"Expected exactly one successful rotation, got {total_successes}"
);
}
#[tokio::test]
async fn test_concurrent_issue_multiple_users() {
let store = Arc::new(RefreshTokenStore::new());
let mut handles = vec![];
for i in 0..20 {
let store_clone = Arc::clone(&store);
handles.push(tokio::spawn(async move {
let user = format!("user_{i}");
store_clone.issue(&user, DEFAULT_TTL).await.unwrap()
}));
}
let mut tokens = vec![];
for h in handles {
tokens.push(h.await.unwrap());
}
assert_eq!(store.len().await, 20);
// All 20 generated token strings must be distinct.
let token_set: std::collections::HashSet<&str> =
tokens.iter().map(|t| t.token.as_str()).collect();
assert_eq!(token_set.len(), 20);
}
#[tokio::test]
async fn test_cleanup_removes_expired_tokens() {
let store = RefreshTokenStore::new();
store.issue("alice", SHORT_TTL).await.unwrap();
store.issue("alice", DEFAULT_TTL).await.unwrap();
sleep(std::time::Duration::from_millis(100)).await;
// Zero grace period: only the SHORT_TTL token is past the cutoff.
let removed = store.cleanup_expired(Duration::zero()).await;
assert_eq!(
removed, 1,
"Expected exactly one expired token to be removed"
);
assert_eq!(store.len().await, 1);
}
#[tokio::test]
async fn test_cleanup_with_grace_period_keeps_recently_expired() {
let store = RefreshTokenStore::new();
store.issue("alice", SHORT_TTL).await.unwrap();
sleep(std::time::Duration::from_millis(100)).await;
// The token expired moments ago, well inside the one-hour grace period.
let removed = store.cleanup_expired(Duration::hours(1)).await;
assert_eq!(
removed, 0,
"Grace period should protect recently expired tokens"
);
assert_eq!(store.len().await, 1);
}
}