Skip to main content

hyperstack_auth/
revocation.rs

//! Token revocation support
//!
//! Provides functionality to revoke tokens before their natural expiry.
//! Revoked tokens are tracked by their JWT ID (`jti`) claim.

6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, SystemTime, UNIX_EPOCH};
9use tokio::sync::RwLock;
10
/// Bookkeeping stored for each revoked JTI.
#[derive(Clone, Debug)]
struct RevokedEntry {
    /// Unix timestamp (seconds) until which the revocation is retained;
    /// `cleanup_expired(now)` drops the entry once `now >= expires_at`.
    expires_at: u64,
}
16
17/// Token revocation list with automatic cleanup
18#[derive(Clone)]
19pub struct TokenRevocationList {
20    /// Revoked JWT IDs keyed by JTI with expiry tracking
21    revoked: Arc<RwLock<HashMap<String, RevokedEntry>>>,
22    /// Fallback retention window when the token expiry is unavailable
23    max_age: Duration,
24}
25
26impl TokenRevocationList {
27    /// Create a new empty revocation list
28    pub fn new() -> Self {
29        Self {
30            revoked: Arc::new(RwLock::new(HashMap::new())),
31            max_age: Duration::from_secs(86400), // 24 hours default
32        }
33    }
34
35    /// Set the maximum age of revocation entries
36    pub fn with_max_age(mut self, max_age: Duration) -> Self {
37        self.max_age = max_age;
38        self
39    }
40
41    /// Revoke a token by its JTI using `max_age` as a fallback expiry.
42    pub async fn revoke(&self, jti: impl Into<String>) {
43        let expires_at = current_unix_timestamp().saturating_add(self.max_age.as_secs());
44        self.revoke_until(jti, expires_at).await;
45    }
46
47    /// Revoke a token by its JTI until the token naturally expires.
48    pub async fn revoke_until(&self, jti: impl Into<String>, expires_at: u64) {
49        let mut revoked = self.revoked.write().await;
50        revoked.insert(jti.into(), RevokedEntry { expires_at });
51    }
52
53    /// Check if a token is revoked
54    pub async fn is_revoked(&self, jti: &str) -> bool {
55        let revoked = self.revoked.read().await;
56        revoked.contains_key(jti)
57    }
58
59    /// Remove a token from the revocation list
60    pub async fn unrevoke(&self, jti: &str) {
61        let mut revoked = self.revoked.write().await;
62        revoked.remove(jti);
63    }
64
65    /// Get the number of revoked tokens
66    pub async fn len(&self) -> usize {
67        let revoked = self.revoked.read().await;
68        revoked.len()
69    }
70
71    /// Check if the revocation list is empty
72    pub async fn is_empty(&self) -> bool {
73        let revoked = self.revoked.read().await;
74        revoked.is_empty()
75    }
76
77    /// Clear all revoked tokens
78    pub async fn clear(&self) {
79        let mut revoked = self.revoked.write().await;
80        revoked.clear();
81    }
82
83    /// Clean up old revocation entries (should be called periodically)
84    pub async fn cleanup_expired(&self, now: u64) -> usize {
85        let mut revoked = self.revoked.write().await;
86        let before = revoked.len();
87        revoked.retain(|_, entry| entry.expires_at > now);
88        before - revoked.len()
89    }
90}
91
/// Seconds elapsed since the Unix epoch, per the system clock.
///
/// Panics if the system clock reports a time before the epoch.
fn current_unix_timestamp() -> u64 {
    let since_epoch = UNIX_EPOCH
        .elapsed()
        .expect("time should not be before epoch");
    since_epoch.as_secs()
}
98
99impl Default for TokenRevocationList {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
105/// Revocation checker trait for integration with verifiers
106#[async_trait::async_trait]
107pub trait RevocationChecker: Send + Sync {
108    /// Check if a token is revoked
109    async fn is_revoked(&self, jti: &str) -> bool;
110}
111
112#[async_trait::async_trait]
113impl RevocationChecker for TokenRevocationList {
114    async fn is_revoked(&self, jti: &str) -> bool {
115        self.is_revoked(jti).await
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[tokio::test]
    async fn test_revoke_and_check() {
        let list = TokenRevocationList::new();

        assert!(!list.is_revoked("token-1").await);

        list.revoke("token-1").await;
        assert!(list.is_revoked("token-1").await);

        list.unrevoke("token-1").await;
        assert!(!list.is_revoked("token-1").await);
    }

    #[tokio::test]
    async fn test_multiple_tokens() {
        let list = TokenRevocationList::new();

        list.revoke("token-1").await;
        list.revoke("token-2").await;

        assert!(list.is_revoked("token-1").await);
        assert!(list.is_revoked("token-2").await);
        assert!(!list.is_revoked("token-3").await);

        assert_eq!(list.len().await, 2);
    }

    #[tokio::test]
    async fn test_clear() {
        let list = TokenRevocationList::new();

        list.revoke("token-1").await;
        list.revoke("token-2").await;

        list.clear().await;

        assert!(list.is_empty().await);
        assert!(!list.is_revoked("token-1").await);
    }

    #[tokio::test]
    async fn test_cleanup_expired_removes_expired_entries() {
        // Explicit expiries via `revoke_until`; `max_age` plays no part here.
        let list = TokenRevocationList::new();

        list.revoke_until("expired-token", 100).await;
        list.revoke_until("active-token", 200).await;

        let removed = list.cleanup_expired(150).await;

        assert_eq!(removed, 1);
        assert!(!list.is_revoked("expired-token").await);
        assert!(list.is_revoked("active-token").await);
    }

    #[tokio::test]
    async fn test_revoke_uses_max_age_as_fallback_expiry() {
        let list = TokenRevocationList::new().with_max_age(Duration::from_secs(60));

        let before = current_unix_timestamp();
        list.revoke("fallback-token").await;
        let after = current_unix_timestamp();

        // Still inside the retention window: nothing is removed.
        assert_eq!(list.cleanup_expired(before).await, 0);
        assert!(list.is_revoked("fallback-token").await);

        // Past the window (expiry is at most `after + 60`): the entry is removed.
        assert_eq!(list.cleanup_expired(after + 60).await, 1);
        assert!(!list.is_revoked("fallback-token").await);
    }

    #[tokio::test]
    async fn test_clones_share_state() {
        let list = TokenRevocationList::default();
        let handle = list.clone();

        list.revoke("shared-token").await;

        assert!(handle.is_revoked("shared-token").await);
        assert_eq!(handle.len().await, 1);
    }

    #[tokio::test]
    async fn test_revocation_checker_trait_delegates() {
        let list = TokenRevocationList::new();
        list.revoke("token-1").await;

        let checker: &dyn RevocationChecker = &list;
        assert!(checker.is_revoked("token-1").await);
        assert!(!checker.is_revoked("token-2").await);
    }
}