use chrono::Utc;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use mockforge_core::Error;
/// Record describing one revoked token, as stored by `TokenRevocationStore`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RevokedToken {
/// Identifier of the revoked token; the store also uses it as the map key.
pub token_id: String,
/// User the token was revoked for, if the caller supplied one.
pub user_id: Option<String>,
/// Unix timestamp in seconds (`Utc::now().timestamp()`) of the revocation.
pub revoked_at: i64,
/// Caller-supplied reason for the revocation.
pub reason: String,
/// Unix timestamp in seconds after which `TokenRevocationStore::cleanup_expired`
/// drops this record; `None` means the record is never dropped by cleanup.
pub expires_at: Option<i64>,
}
/// In-memory registry of revoked tokens.
///
/// Both maps sit behind `Arc<RwLock<..>>`, so cloning the store yields a
/// handle to the same underlying data rather than an independent copy.
#[derive(Debug, Clone)]
pub struct TokenRevocationStore {
// Revocation records keyed by token id.
revoked_tokens: Arc<RwLock<HashMap<String, RevokedToken>>>,
// Per-user index: user id -> ids of the tokens revoked with that user id.
user_revoked_tokens: Arc<RwLock<HashMap<String, Vec<String>>>>,
}
impl TokenRevocationStore {
    // Lock ordering: whenever both maps are needed, `revoked_tokens` is locked
    // BEFORE `user_revoked_tokens`. Every method below follows this order so
    // that two concurrent calls can never wait on each other's lock.

    /// Creates an empty revocation store.
    pub fn new() -> Self {
        Self {
            revoked_tokens: Arc::new(RwLock::new(HashMap::new())),
            user_revoked_tokens: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Records `token_id` as revoked, replacing any earlier record for it.
    ///
    /// * `user_id` - when present, the token is also added to that user's
    ///   index so it can later be reached via [`Self::revoke_user_tokens`].
    /// * `reason` - free-form text stored on the record.
    /// * `expires_at` - Unix timestamp (seconds) after which
    ///   [`Self::cleanup_expired`] may drop the record; `None` keeps it.
    pub async fn revoke_token(
        &self,
        token_id: String,
        user_id: Option<String>,
        reason: String,
        expires_at: Option<i64>,
    ) {
        let revoked = RevokedToken {
            token_id: token_id.clone(),
            user_id: user_id.clone(),
            revoked_at: Utc::now().timestamp(),
            reason,
            expires_at,
        };

        let mut tokens = self.revoked_tokens.write().await;
        tokens.insert(token_id.clone(), revoked);

        if let Some(uid) = user_id {
            let mut user_tokens = self.user_revoked_tokens.write().await;
            let ids = user_tokens.entry(uid).or_default();
            // Revoking the same token twice must not duplicate it in the index.
            if !ids.contains(&token_id) {
                ids.push(token_id);
            }
        }
    }

    /// Refreshes `revoked_at` and `reason` on every record indexed under
    /// `user_id`.
    ///
    /// Only tokens previously passed to [`Self::revoke_token`] with this user
    /// id are affected; an unknown user id is a no-op.
    pub async fn revoke_user_tokens(&self, user_id: String, reason: String) {
        // Same lock order as `revoke_token`; taking the user index first would
        // allow the two methods to deadlock against each other.
        let mut tokens = self.revoked_tokens.write().await;
        let user_tokens = self.user_revoked_tokens.read().await;

        let Some(token_ids) = user_tokens.get(&user_id) else {
            return;
        };

        // One timestamp for the whole batch.
        let now = Utc::now().timestamp();
        for token_id in token_ids {
            if let Some(revoked) = tokens.get_mut(token_id) {
                revoked.revoked_at = now;
                revoked.reason.clone_from(&reason);
            }
        }
    }

    /// Returns the revocation record for `token_id`, or `None` if the store
    /// holds none.
    ///
    /// `expires_at` is not consulted here: a record whose expiry has passed
    /// is still returned until [`Self::cleanup_expired`] removes it.
    pub async fn is_revoked(&self, token_id: &str) -> Option<RevokedToken> {
        self.revoked_tokens.read().await.get(token_id).cloned()
    }

    /// Alias for [`Self::is_revoked`].
    pub async fn get_revocation_status(&self, token_id: &str) -> Option<RevokedToken> {
        self.is_revoked(token_id).await
    }

    /// Drops every record whose `expires_at` is at or before the current
    /// time, and prunes the per-user index so it does not retain ids of
    /// records that no longer exist.
    pub async fn cleanup_expired(&self) {
        let now = Utc::now().timestamp();

        let mut tokens = self.revoked_tokens.write().await;
        let mut user_tokens = self.user_revoked_tokens.write().await;

        tokens.retain(|_, revoked| revoked.expires_at.is_none_or(|exp| exp > now));

        // Remove dangling ids, then users left with an empty list.
        user_tokens.retain(|_, ids| {
            ids.retain(|id| tokens.contains_key(id));
            !ids.is_empty()
        });
    }
}
impl Default for TokenRevocationStore {
fn default() -> Self {
Self::new()
}
}
/// Tracks signing keys and their rotation status.
///
/// The key map sits behind `Arc<RwLock<..>>`, so clones of this value refer
/// to the same set of keys.
#[derive(Debug, Clone)]
pub struct KeyRotationState {
// Known keys, indexed by their `kid`.
active_keys: Arc<RwLock<HashMap<String, ActiveKey>>>,
// Seconds added to the current timestamp in `rotate_key` to compute when a
// superseded key stops being reported as active.
grace_period_seconds: i64,
}
/// Metadata for one key tracked by `KeyRotationState`.
#[derive(Debug, Clone)]
pub struct ActiveKey {
/// Key identifier; also the key under which the entry is stored.
pub kid: String,
/// Unix timestamp in seconds at which this entry was created.
pub created_at: i64,
/// Unix timestamp in seconds from which the key no longer counts as active
/// (`get_active_keys` keeps a key only while this is in the future);
/// `None` means no retirement is scheduled.
pub inactive_at: Option<i64>,
/// Flag that `KeyRotationState::get_primary_key` searches for.
pub is_primary: bool,
}
impl KeyRotationState {
    /// Creates an empty key set.
    ///
    /// `grace_period_seconds` is how long a superseded key stays active after
    /// [`Self::rotate_key`] replaces it.
    pub fn new(grace_period_seconds: i64) -> Self {
        Self {
            active_keys: Arc::new(RwLock::new(HashMap::new())),
            grace_period_seconds,
        }
    }

    /// Registers `kid` with no scheduled retirement, replacing any existing
    /// entry with the same id.
    ///
    /// Existing keys are left untouched, so passing `is_primary = true` while
    /// another key is already primary leaves two primaries; use
    /// [`Self::rotate_key`] to switch the primary key.
    pub async fn add_key(&self, kid: String, is_primary: bool) {
        let mut keys = self.active_keys.write().await;
        keys.insert(
            kid.clone(),
            ActiveKey {
                kid,
                created_at: Utc::now().timestamp(),
                inactive_at: None,
                is_primary,
            },
        );
    }

    /// Makes `new_kid` the sole primary key.
    ///
    /// Every existing key loses its primary flag. Keys without a retirement
    /// deadline are scheduled to become inactive after the grace period; keys
    /// that already have a deadline keep it, so repeated rotations neither
    /// extend an old key's lifetime nor revive an expired one.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`.
    pub async fn rotate_key(&self, new_kid: String) -> Result<(), Error> {
        let mut keys = self.active_keys.write().await;

        let now = Utc::now().timestamp();
        let retire_at = now.saturating_add(self.grace_period_seconds);

        for key in keys.values_mut() {
            key.is_primary = false;
            if key.inactive_at.is_none() {
                key.inactive_at = Some(retire_at);
            }
        }

        keys.insert(
            new_kid.clone(),
            ActiveKey {
                kid: new_kid,
                created_at: now,
                inactive_at: None,
                is_primary: true,
            },
        );
        Ok(())
    }

    /// Returns copies of all keys that have no retirement deadline or whose
    /// deadline is still in the future.
    pub async fn get_active_keys(&self) -> Vec<ActiveKey> {
        let now = Utc::now().timestamp();
        let keys = self.active_keys.read().await;
        keys.values()
            .filter(|key| key.inactive_at.is_none_or(|inactive_at| inactive_at > now))
            .cloned()
            .collect()
    }

    /// Returns a copy of a key flagged as primary, if any.
    ///
    /// `inactive_at` is not checked. If several keys are flagged, which one is
    /// returned depends on map iteration order.
    pub async fn get_primary_key(&self) -> Option<ActiveKey> {
        let keys = self.active_keys.read().await;
        keys.values().find(|key| key.is_primary).cloned()
    }

    /// Removes every key whose retirement deadline is at or before now.
    pub async fn cleanup_old_keys(&self) {
        let now = Utc::now().timestamp();
        let mut keys = self.active_keys.write().await;
        keys.retain(|_, key| key.inactive_at.is_none_or(|inactive_at| inactive_at > now));
    }
}
/// Holds an adjustable clock offset, in seconds, applied to timestamps.
///
/// The offset sits behind `Arc<RwLock<..>>`, so clones observe each other's
/// `set_skew` calls; the two flags are plain copies.
#[derive(Debug, Clone)]
pub struct ClockSkewState {
// Current offset in seconds; may be negative.
skew_seconds: Arc<RwLock<i64>>,
// When false, `apply_issuance_skew` returns its input unchanged.
apply_to_issuance: bool,
// When false, `apply_validation_skew` returns its input unchanged.
apply_to_validation: bool,
}
impl ClockSkewState {
    /// Creates a state with zero skew that applies to both issuance and
    /// validation.
    pub fn new() -> Self {
        Self {
            skew_seconds: Arc::new(RwLock::new(0)),
            apply_to_issuance: true,
            apply_to_validation: true,
        }
    }

    /// Sets the skew, in seconds. Negative values move the clock backwards.
    pub async fn set_skew(&self, skew_seconds: i64) {
        *self.skew_seconds.write().await = skew_seconds;
    }

    /// Returns the currently configured skew, in seconds.
    pub async fn get_skew(&self) -> i64 {
        *self.skew_seconds.read().await
    }

    /// Returns the current Unix time in seconds, shifted by the skew.
    ///
    /// The addition saturates at the `i64` bounds instead of overflowing.
    pub async fn get_adjusted_time(&self) -> i64 {
        let skew = self.get_skew().await;
        Utc::now().timestamp().saturating_add(skew)
    }

    /// Returns `timestamp + skew` when issuance skew is enabled, otherwise
    /// `timestamp` unchanged.
    ///
    /// The addition saturates, so an extreme skew cannot panic (debug builds)
    /// or wrap around (release builds).
    pub async fn apply_issuance_skew(&self, timestamp: i64) -> i64 {
        if !self.apply_to_issuance {
            return timestamp;
        }
        timestamp.saturating_add(self.get_skew().await)
    }

    /// Returns `timestamp - skew` when validation skew is enabled, otherwise
    /// `timestamp` unchanged.
    ///
    /// The subtraction saturates, so an extreme skew cannot panic (debug
    /// builds) or wrap around (release builds).
    pub async fn apply_validation_skew(&self, timestamp: i64) -> i64 {
        if !self.apply_to_validation {
            return timestamp;
        }
        timestamp.saturating_sub(self.get_skew().await)
    }
}
impl Default for ClockSkewState {
fn default() -> Self {
Self::new()
}
}
/// Bundles the three token-lifecycle components into a single value.
#[derive(Debug, Clone)]
pub struct TokenLifecycleManager {
/// Registry of revoked tokens.
pub revocation: TokenRevocationStore,
/// Key set and rotation bookkeeping.
pub key_rotation: KeyRotationState,
/// Adjustable clock offset.
pub clock_skew: ClockSkewState,
}
impl TokenLifecycleManager {
    /// Assembles a manager from freshly created components.
    ///
    /// `grace_period_seconds` is passed through to [`KeyRotationState::new`];
    /// the revocation store starts empty and the clock skew starts at its
    /// `new()` defaults.
    pub fn new(grace_period_seconds: i64) -> Self {
        let revocation = TokenRevocationStore::new();
        let key_rotation = KeyRotationState::new(grace_period_seconds);
        let clock_skew = ClockSkewState::new();

        Self {
            revocation,
            key_rotation,
            clock_skew,
        }
    }
}
impl Default for TokenLifecycleManager {
    /// Builds a manager whose key-rotation grace period is 3600 seconds.
    fn default() -> Self {
        const DEFAULT_GRACE_PERIOD_SECONDS: i64 = 3600;
        Self::new(DEFAULT_GRACE_PERIOD_SECONDS)
    }
}
/// Derives a token identifier from the raw token string.
///
/// The identifier is the SHA-256 digest of the token's UTF-8 bytes, rendered
/// as lowercase hexadecimal. Nothing is parsed out of the token; the whole
/// string is hashed.
pub fn extract_token_id(token: &str) -> String {
    let digest = Sha256::digest(token.as_bytes());
    format!("{:x}", digest)
}