use crate::auth::jwt_issuer::{JwtIssuer, RefreshTokenClaims, TokenPair, TokenSubject, TokenType};
use crate::auth::storage::RefreshTokenStore;
use crate::error::{Result, TidewayError};
use async_trait::async_trait;
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
/// Application-provided lookup of users during token refresh.
///
/// `TokenRefreshFlow` calls this trait to confirm that the subject of a
/// refresh token still resolves to a user and to obtain the profile data
/// that is copied into the newly issued access token.
#[async_trait]
pub trait UserLoader: Send + Sync {
/// Application-specific user representation returned by [`load_user`](Self::load_user).
type User: Send + Sync;
/// Looks up the user identified by `user_id` (the `sub` claim of the refresh token).
///
/// Returning `Ok(None)` makes the refresh flow reject the request as
/// unauthorized; returning `Err` aborts the refresh with that error.
async fn load_user(&self, user_id: &str) -> Result<Option<Self::User>>;
/// Returns the e-mail address to attach to the access token subject, if any.
fn user_email(&self, user: &Self::User) -> Option<String>;
/// Returns the display name to attach to the access token subject, if any.
fn user_name(&self, user: &Self::User) -> Option<String>;
/// Returns extra claims to attach to the access token subject.
///
/// The default implementation adds no custom claims.
async fn custom_claims(&self, _user: &Self::User) -> Option<serde_json::Value> {
None
}
}
/// Refresh-token flow: validates a presented refresh token, advances the
/// token family's generation in the store, and issues a new token pair.
///
/// `S` persists token-family state (revocation flags and generations) and
/// `L` resolves the token subject to an application user.
pub struct TokenRefreshFlow<S, L>
where
S: RefreshTokenStore,
L: UserLoader,
{
// Issues access tokens and rotates refresh tokens; also supplies the
// issuer/audience values that presented tokens are validated against.
issuer: JwtIssuer,
// Backing store for family revocation and generation tracking.
store: S,
// Resolves the `sub` claim to a user and its profile data.
user_loader: L,
// Key used to verify the signature of presented refresh tokens.
decoding_key: DecodingKey,
// Algorithm / issuer / audience rules applied when decoding tokens.
validation: Validation,
}
impl<S, L> TokenRefreshFlow<S, L>
where
S: RefreshTokenStore,
L: UserLoader,
{
pub fn new(issuer: JwtIssuer, store: S, user_loader: L, secret: &[u8]) -> Self {
let mut validation = Validation::new(issuer.algorithm());
validation.set_issuer(&[issuer.issuer()]);
if let Some(aud) = issuer.audience() {
validation.set_audience(&[aud]);
}
Self {
decoding_key: DecodingKey::from_secret(secret),
issuer,
store,
user_loader,
validation,
}
}
pub fn with_rsa_public_key(
issuer: JwtIssuer,
store: S,
user_loader: L,
public_key_pem: &[u8],
) -> Result<Self> {
let mut validation = Validation::new(Algorithm::RS256);
validation.set_issuer(&[issuer.issuer()]);
if let Some(aud) = issuer.audience() {
validation.set_audience(&[aud]);
}
let decoding_key = DecodingKey::from_rsa_pem(public_key_pem)
.map_err(|e| TidewayError::Internal(format!("Invalid RSA public key: {}", e)))?;
Ok(Self {
issuer,
store,
user_loader,
decoding_key,
validation,
})
}
pub async fn refresh(&self, refresh_token: &str) -> Result<TokenPair> {
let claims =
decode::<RefreshTokenClaims>(refresh_token, &self.decoding_key, &self.validation)
.map_err(|e| {
tracing::warn!(
target: "auth.token.invalid",
error = %e,
"Invalid refresh token presented"
);
TidewayError::Unauthorized(format!("Invalid refresh token: {}", e))
})?
.claims;
let user_id = &claims.standard.sub;
let family = &claims.family;
if claims.token_type != TokenType::Refresh {
tracing::warn!(
target: "auth.token.invalid",
user_id = %user_id,
token_type = ?claims.token_type,
"Wrong token type used for refresh"
);
return Err(TidewayError::Unauthorized("Invalid token type".into()));
}
if self.store.is_family_revoked(family).await? {
tracing::warn!(
target: "auth.token.invalid",
user_id = %user_id,
family = %family,
"Attempted use of revoked token family"
);
return Err(TidewayError::Unauthorized("Token has been revoked".into()));
}
let user = match self.user_loader.load_user(user_id).await? {
Some(u) => u,
None => {
tracing::warn!(
target: "auth.token.invalid",
user_id = %user_id,
family = %family,
"Token refresh failed: user not found or disabled"
);
return Err(TidewayError::Unauthorized("User not found".into()));
}
};
let new_generation = claims
.generation
.checked_add(1)
.ok_or_else(|| TidewayError::Unauthorized("Invalid refresh token generation".into()))?;
let advanced = self
.store
.compare_and_swap_family_generation(family, claims.generation, new_generation)
.await?;
if !advanced {
if self.store.is_family_revoked(family).await? {
tracing::warn!(
target: "auth.token.invalid",
user_id = %user_id,
family = %family,
"Attempted use of revoked token family"
);
return Err(TidewayError::Unauthorized("Token has been revoked".into()));
}
let stored_gen = self.store.get_family_generation(family).await?.unwrap_or(0);
if claims.generation < stored_gen {
tracing::error!(
target: "auth.token.reuse_detected",
user_id = %user_id,
family = %family,
presented_generation = claims.generation,
expected_generation = stored_gen,
"SECURITY: Refresh token reuse detected - possible token theft"
);
self.store.revoke_family(family).await?;
return Err(TidewayError::Unauthorized("Token reuse detected".into()));
}
tracing::warn!(
target: "auth.token.invalid",
user_id = %user_id,
family = %family,
presented_generation = claims.generation,
stored_generation = stored_gen,
"Refresh token generation mismatch"
);
return Err(TidewayError::Unauthorized(
"Refresh token already used".into(),
));
}
let email = self.user_loader.user_email(&user);
let name = self.user_loader.user_name(&user);
let custom_claims = self.user_loader.custom_claims(&user).await;
let (access_token, expires_in) = {
let mut subject = TokenSubject::new(user_id);
if let Some(ref e) = email {
subject = subject.with_email(e);
}
if let Some(ref n) = name {
subject = subject.with_name(n);
}
if let Some(custom) = custom_claims {
self.issuer
.issue_access_token(subject.with_custom(custom))?
} else {
self.issuer.issue_access_token(subject)?
}
};
let family = claims.family.clone();
let new_refresh_token = self.issuer.rotate_refresh_token(&claims)?;
tracing::info!(
target: "auth.token.refresh",
user_id = %user_id,
family = %family,
generation = new_generation,
"Token refreshed successfully"
);
Ok(TokenPair {
access_token,
refresh_token: new_refresh_token,
expires_in,
token_type: "Bearer",
family,
})
}
pub async fn revoke(&self, refresh_token: &str) -> Result<()> {
let claims =
decode::<RefreshTokenClaims>(refresh_token, &self.decoding_key, &self.validation)
.map_err(|e| TidewayError::Unauthorized(format!("Invalid refresh token: {}", e)))?
.claims;
self.store.revoke_family(&claims.family).await?;
tracing::info!(
target: "auth.token.revoked",
user_id = %claims.standard.sub,
family = %claims.family,
"Token family revoked (logout)"
);
Ok(())
}
pub async fn revoke_all(&self, user_id: &str) -> Result<()> {
self.store.revoke_all_for_user(user_id).await?;
tracing::warn!(
target: "auth.token.revoke_all",
user_id = %user_id,
"All tokens revoked for user (security event)"
);
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::jwt_issuer::JwtIssuerConfig;
    use crate::auth::storage::token::test::InMemoryRefreshTokenStore;

    /// Signing secret shared by the issuer and (normally) the flow under test.
    const SECRET: &[u8] = b"test-secret-key-32-bytes-long!!";

    /// Flow type exercised by every test in this module.
    type TestFlow = TokenRefreshFlow<InMemoryRefreshTokenStore, TestUserLoader>;

    /// Loader that resolves every user id to the same fixed user.
    struct TestUserLoader;

    struct TestUser {
        email: String,
        name: String,
    }

    #[async_trait]
    impl UserLoader for TestUserLoader {
        type User = TestUser;

        async fn load_user(&self, _user_id: &str) -> Result<Option<Self::User>> {
            Ok(Some(TestUser {
                email: "test@example.com".to_string(),
                name: "Test User".to_string(),
            }))
        }

        fn user_email(&self, user: &Self::User) -> Option<String> {
            Some(user.email.clone())
        }

        fn user_name(&self, user: &Self::User) -> Option<String> {
            Some(user.name.clone())
        }
    }

    /// Issuer for "test-app" that signs with [`SECRET`].
    fn make_issuer() -> JwtIssuer {
        let config =
            JwtIssuerConfig::with_secret(String::from_utf8_lossy(SECRET).to_string(), "test-app");
        JwtIssuer::new(config).unwrap()
    }

    /// Flow backed by a fresh in-memory store, verifying with `secret`.
    fn make_flow(issuer: &JwtIssuer, secret: &[u8]) -> TestFlow {
        TokenRefreshFlow::new(
            issuer.clone(),
            InMemoryRefreshTokenStore::new(),
            TestUserLoader,
            secret,
        )
    }

    /// First token pair for "user-123", issued directly by `issuer`.
    fn issue_initial(issuer: &JwtIssuer) -> TokenPair {
        issuer.issue(TokenSubject::new("user-123"), false).unwrap()
    }

    #[tokio::test]
    async fn test_refresh_flow() {
        let issuer = make_issuer();
        let flow = make_flow(&issuer, SECRET);
        let initial = issue_initial(&issuer);

        let refreshed = flow.refresh(&initial.refresh_token).await.unwrap();

        assert!(!refreshed.access_token.is_empty());
        assert!(!refreshed.refresh_token.is_empty());
        assert_ne!(refreshed.refresh_token, initial.refresh_token);
    }

    #[tokio::test]
    async fn test_reuse_detection() {
        let issuer = make_issuer();
        let flow = make_flow(&issuer, SECRET);
        let initial = issue_initial(&issuer);

        // The first use consumes the token; presenting it again must fail.
        let _refreshed = flow.refresh(&initial.refresh_token).await.unwrap();
        let second_use = flow.refresh(&initial.refresh_token).await;

        assert!(second_use.is_err());
    }

    #[tokio::test]
    async fn test_revoke_token() {
        let issuer = make_issuer();
        let flow = make_flow(&issuer, SECRET);
        let initial = issue_initial(&issuer);

        let refreshed = flow.refresh(&initial.refresh_token).await.unwrap();
        flow.revoke(&refreshed.refresh_token).await.unwrap();
        let after_revoke = flow.refresh(&refreshed.refresh_token).await;

        assert!(after_revoke.is_err());
        assert!(after_revoke.unwrap_err().to_string().contains("revoked"));
    }

    #[tokio::test]
    async fn test_revoke_all_for_user() {
        let issuer = make_issuer();
        let store = InMemoryRefreshTokenStore::new();
        let initial = issue_initial(&issuer);

        // The family must be linked to the user before the store is handed
        // to the flow, otherwise `revoke_all` has nothing to revoke.
        store
            .associate_family_with_user(&initial.family, "user-123")
            .await
            .unwrap();
        let flow = TokenRefreshFlow::new(issuer.clone(), store, TestUserLoader, SECRET);

        let refreshed = flow.refresh(&initial.refresh_token).await.unwrap();
        flow.revoke_all("user-123").await.unwrap();
        let after_revoke = flow.refresh(&refreshed.refresh_token).await;

        assert!(after_revoke.is_err());
    }

    #[tokio::test]
    async fn test_invalid_token_rejected() {
        let issuer = make_issuer();
        let flow = make_flow(&issuer, SECRET);

        let outcome = flow.refresh("not-a-valid-token").await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_access_token_rejected_for_refresh() {
        let issuer = make_issuer();
        let flow = make_flow(&issuer, SECRET);
        let tokens = issue_initial(&issuer);

        let outcome = flow.refresh(&tokens.access_token).await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_wrong_secret_rejected() {
        let wrong_secret = b"wrong-secret-key-32-bytes-long!";
        let issuer = make_issuer();
        // The flow verifies with a different secret than the issuer signs with.
        let flow = make_flow(&issuer, wrong_secret);
        let tokens = issue_initial(&issuer);

        let outcome = flow.refresh(&tokens.refresh_token).await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_chained_refresh() {
        let issuer = make_issuer();
        let flow = make_flow(&issuer, SECRET);
        let initial = issue_initial(&issuer);

        let refresh1 = flow.refresh(&initial.refresh_token).await.unwrap();
        let refresh2 = flow.refresh(&refresh1.refresh_token).await.unwrap();
        let refresh3 = flow.refresh(&refresh2.refresh_token).await.unwrap();

        // Every rotation yields a new token...
        assert_ne!(initial.refresh_token, refresh1.refresh_token);
        assert_ne!(refresh1.refresh_token, refresh2.refresh_token);
        assert_ne!(refresh2.refresh_token, refresh3.refresh_token);
        // ...while the family stays the same.
        assert_eq!(initial.family, refresh1.family);
        assert_eq!(refresh1.family, refresh2.family);
        assert_eq!(refresh2.family, refresh3.family);
    }

    #[tokio::test]
    async fn test_concurrent_refresh_same_token_only_one_succeeds() {
        let issuer = make_issuer();
        let flow = make_flow(&issuer, SECRET);
        let initial = issue_initial(&issuer);

        let (r1, r2) = tokio::join!(
            flow.refresh(&initial.refresh_token),
            flow.refresh(&initial.refresh_token)
        );

        let success_count = usize::from(r1.is_ok()) + usize::from(r2.is_ok());
        assert_eq!(
            success_count, 1,
            "exactly one concurrent refresh should succeed"
        );
    }
}