Skip to main content

fraiseql_server/
token_revocation.rs

//! Token revocation — reject JWTs whose `jti` claim has been revoked.
//!
//! After JWT signature verification succeeds, the server checks the token's
//! `jti` (JWT ID) claim against a revocation store. If the `jti` is present in
//! the store, the token is rejected with 401.
//!
//! Production backend: Redis (requires the `redis-rate-limiting` feature).
//! A PostgreSQL backend is not implemented in this module; selecting it falls
//! back to the in-memory store. The in-memory backend is intended for testing
//! and single-instance development.
//!
//! Each revocation entry is stored with a TTL equal to the token's remaining
//! lifetime (as implied by its `exp` claim), so revoked JTIs expire on their own
//! and the store stays bounded.
12
13use std::sync::Arc;
14
15use async_trait::async_trait;
16use chrono::{DateTime, Utc};
17use dashmap::DashMap;
18use serde::Deserialize;
19use tracing::{debug, info, warn};
20
21// ───────────────────────────────────────────────────────────────
22// Configuration
23// ───────────────────────────────────────────────────────────────
24
25/// Token revocation configuration embedded in the compiled schema.
26#[derive(Debug, Clone, Deserialize)]
27pub struct TokenRevocationConfig {
28    /// Whether token revocation is enabled.
29    #[serde(default)]
30    pub enabled: bool,
31
32    /// Storage backend: `"redis"` or `"postgres"` or `"memory"`.
33    #[serde(default = "default_backend")]
34    pub backend: String,
35
36    /// Reject JWTs that lack a `jti` claim when revocation is enabled.
37    #[serde(default = "default_true")]
38    pub require_jti: bool,
39
40    /// If the revocation store is unreachable:
41    /// - `false` (default): reject the request (fail-closed)
42    /// - `true`: allow the request (fail-open)
43    #[serde(default)]
44    pub fail_open: bool,
45
46    /// Redis URL (inherited from `[fraiseql.redis]` if not set here).
47    pub redis_url: Option<String>,
48}
49
/// Serde default for `TokenRevocationConfig::backend`: the in-memory store.
fn default_backend() -> String {
    String::from("memory")
}

/// Serde default for `TokenRevocationConfig::require_jti`: enforce the claim.
const fn default_true() -> bool {
    true
}
56
57// ───────────────────────────────────────────────────────────────
58// Trait
59// ───────────────────────────────────────────────────────────────
60
61/// Revocation store abstraction.
62// Reason: used as dyn Trait (Arc<dyn RevocationStore>); async_trait ensures Send bounds and
63// dyn-compatibility async_trait: dyn-dispatch required; remove when RTN + Send is stable (RFC 3425)
64#[async_trait]
65pub trait RevocationStore: Send + Sync {
66    /// Check if a JTI has been revoked.
67    async fn is_revoked(&self, jti: &str) -> Result<bool, RevocationError>;
68
69    /// Revoke a single JTI.  `ttl_secs` is the remaining JWT lifetime —
70    /// the store should auto-expire the entry after this duration.
71    async fn revoke(&self, jti: &str, ttl_secs: u64) -> Result<(), RevocationError>;
72
73    /// Revoke all tokens for a user (by `sub` claim).
74    /// Returns the number of tokens revoked.
75    async fn revoke_all_for_user(&self, sub: &str) -> Result<u64, RevocationError>;
76}
77
78/// Revocation store error.
79#[derive(Debug, thiserror::Error)]
80#[non_exhaustive]
81pub enum RevocationError {
82    /// Backend is unreachable or returned an error.
83    #[error("revocation store error: {0}")]
84    Backend(String),
85}
86
87// ───────────────────────────────────────────────────────────────
88// In-memory backend
89// ───────────────────────────────────────────────────────────────
90
91/// In-memory revocation store for testing and single-instance dev.
92pub struct InMemoryRevocationStore {
93    /// Map of JTI → (sub, `expires_at`).
94    entries: DashMap<String, (String, DateTime<Utc>)>,
95}
96
97impl InMemoryRevocationStore {
98    /// Create a new, empty in-memory revocation store.
99    #[must_use]
100    pub fn new() -> Self {
101        Self {
102            entries: DashMap::new(),
103        }
104    }
105
106    /// Remove expired entries.
107    pub fn cleanup_expired(&self) {
108        let now = Utc::now();
109        self.entries.retain(|_, (_, exp)| *exp > now);
110    }
111}
112
113impl Default for InMemoryRevocationStore {
114    fn default() -> Self {
115        Self::new()
116    }
117}
118
119// Reason: RevocationStore is defined with #[async_trait]; all implementations must match
120// its transformed method signatures to satisfy the trait contract
121// async_trait: dyn-dispatch required; remove when RTN + Send is stable (RFC 3425)
122#[async_trait]
123impl RevocationStore for InMemoryRevocationStore {
124    async fn is_revoked(&self, jti: &str) -> Result<bool, RevocationError> {
125        if let Some(entry) = self.entries.get(jti) {
126            let (_, expires_at) = entry.value();
127            if *expires_at > Utc::now() {
128                return Ok(true);
129            }
130            // Expired — remove lazily.
131            drop(entry);
132            self.entries.remove(jti);
133        }
134        Ok(false)
135    }
136
137    async fn revoke(&self, jti: &str, ttl_secs: u64) -> Result<(), RevocationError> {
138        let expires_at = Utc::now() + chrono::Duration::seconds(ttl_secs.cast_signed());
139        // We store an empty sub — single-JTI revocation doesn't need sub.
140        self.entries.insert(jti.to_string(), (String::new(), expires_at));
141        Ok(())
142    }
143
144    async fn revoke_all_for_user(&self, sub: &str) -> Result<u64, RevocationError> {
145        // For in-memory, we can't revoke unknown future tokens.
146        // We count and mark all entries matching this sub.
147        let mut count = 0u64;
148        for entry in &self.entries {
149            let (s, _) = entry.value();
150            if s == sub {
151                count += 1;
152            }
153        }
154        // In practice, revoke_all_for_user requires a list of known JTIs for the user.
155        // The in-memory store doesn't track sub → JTI mappings beyond what's already stored.
156        Ok(count)
157    }
158}
159
160// ───────────────────────────────────────────────────────────────
161// Redis backend (optional)
162// ───────────────────────────────────────────────────────────────
163
/// Redis-backed JWT revocation store.
///
/// Each revoked JTI becomes the key `<key_prefix><jti>`, written with an expiry so that
/// Redis removes it on its own once the token's lifetime is over.
/// Requires the `redis-rate-limiting` feature.
#[cfg(feature = "redis-rate-limiting")]
pub struct RedisRevocationStore {
    /// Client handle; connections are obtained from it per operation.
    client:     redis::Client,
    /// Namespace prepended to every key this store reads or writes.
    key_prefix: String,
}

#[cfg(feature = "redis-rate-limiting")]
impl RedisRevocationStore {
    /// Create a new Redis-backed revocation store using the `fraiseql:revoked:` key
    /// namespace.
    ///
    /// # Errors
    ///
    /// Returns [`RevocationError::Backend`] if the Redis URL is invalid.
    pub fn new(redis_url: &str) -> Result<Self, RevocationError> {
        redis::Client::open(redis_url)
            .map(|client| Self {
                client,
                key_prefix: String::from("fraiseql:revoked:"),
            })
            .map_err(|e| RevocationError::Backend(format!("Redis connection error: {e}")))
    }
}
190
#[cfg(feature = "redis-rate-limiting")]
// Reason: RevocationStore is defined with #[async_trait]; all implementations must match
// its transformed method signatures to satisfy the trait contract
// async_trait: dyn-dispatch required; remove when RTN + Send is stable (RFC 3425)
#[async_trait]
impl RevocationStore for RedisRevocationStore {
    /// `EXISTS <prefix><jti>`: a present key means the JTI is revoked.
    async fn is_revoked(&self, jti: &str) -> Result<bool, RevocationError> {
        use redis::AsyncCommands;
        // NOTE(review): this presumably runs on every authenticated request and obtains a
        // new multiplexed connection each time; consider caching one on the store
        // (e.g. a `redis::aio::ConnectionManager`).
        let mut conn = self
            .client
            .get_multiplexed_async_connection()
            .await
            .map_err(|e| RevocationError::Backend(format!("Redis: {e}")))?;
        let key = format!("{}{jti}", self.key_prefix);
        let exists: bool = conn
            .exists(&key)
            .await
            .map_err(|e| RevocationError::Backend(format!("Redis EXISTS: {e}")))?;
        Ok(exists)
    }

    async fn revoke(&self, jti: &str, ttl_secs: u64) -> Result<(), RevocationError> {
        use redis::AsyncCommands;
        let mut conn = self
            .client
            .get_multiplexed_async_connection()
            .await
            .map_err(|e| RevocationError::Backend(format!("Redis: {e}")))?;
        let key = format!("{}{jti}", self.key_prefix);
        // Redis rejects `SETEX` with a zero expiry ("invalid expire time"), which would
        // turn revoking a token at the very end of its lifetime into a backend error.
        // Keep the marker for at least one second instead.
        let ttl_secs = ttl_secs.max(1);
        let _: () = conn
            .set_ex(&key, "1", ttl_secs)
            .await
            .map_err(|e| RevocationError::Backend(format!("Redis SET EX: {e}")))?;
        Ok(())
    }

    async fn revoke_all_for_user(&self, sub: &str) -> Result<u64, RevocationError> {
        let mut conn = self
            .client
            .get_multiplexed_async_connection()
            .await
            .map_err(|e| RevocationError::Backend(format!("Redis: {e}")))?;
        // User-keyed entries use prefix: fraiseql:revoked:user:{sub}:*
        // `sub` is escaped so glob metacharacters in it (`*`, `?`, `[`) cannot widen the
        // match to other users' keys.
        // NOTE(review): `revoke` never writes `user:` keys and `is_revoked` never reads
        // them — confirm what writes these entries and that deleting them is intended.
        let pattern = format!("{}user:{}:*", self.key_prefix, escape_redis_glob(sub));

        // Iterate with SCAN rather than KEYS: KEYS walks the entire keyspace in a single
        // blocking call, stalling every other client of the Redis server.
        let mut cursor: u64 = 0;
        let mut deleted: u64 = 0;
        loop {
            let (next_cursor, batch): (u64, Vec<String>) = redis::cmd("SCAN")
                .arg(cursor)
                .arg("MATCH")
                .arg(&pattern)
                .arg("COUNT")
                .arg(500)
                .query_async(&mut conn)
                .await
                .map_err(|e| RevocationError::Backend(format!("Redis SCAN: {e}")))?;
            if !batch.is_empty() {
                // DEL reports how many keys it actually removed, which stays accurate even
                // if SCAN yields a key twice or a key expires between SCAN and DEL.
                let removed: u64 = redis::cmd("DEL")
                    .arg(&batch)
                    .query_async(&mut conn)
                    .await
                    .map_err(|e| RevocationError::Backend(format!("Redis DEL: {e}")))?;
                deleted += removed;
            }
            if next_cursor == 0 {
                break;
            }
            cursor = next_cursor;
        }
        Ok(deleted)
    }
}

#[cfg(feature = "redis-rate-limiting")]
/// Escape Redis glob metacharacters so `s` matches only itself inside a `MATCH` pattern.
fn escape_redis_glob(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        if matches!(ch, '*' | '?' | '[' | ']' | '\\') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
252
253// ───────────────────────────────────────────────────────────────
254// Token Revocation Manager
255// ───────────────────────────────────────────────────────────────
256
257/// High-level token revocation manager wrapping a backend store.
258pub struct TokenRevocationManager {
259    store:       Arc<dyn RevocationStore>,
260    require_jti: bool,
261    fail_open:   bool,
262}
263
264impl TokenRevocationManager {
265    /// Create a new revocation manager.
266    #[must_use]
267    pub fn new(store: Arc<dyn RevocationStore>, require_jti: bool, fail_open: bool) -> Self {
268        Self {
269            store,
270            require_jti,
271            fail_open,
272        }
273    }
274
275    /// Check if a token should be rejected.
276    ///
277    /// Returns `Ok(())` if the token is allowed, or an error reason if rejected.
278    ///
279    /// # Errors
280    ///
281    /// Returns `TokenRejection::MissingJti` if JTI is required but absent.
282    /// Returns `TokenRejection::Revoked` if the token has been revoked.
283    /// Returns `TokenRejection::StoreUnavailable` if the revocation store is unreachable and
284    /// `fail_open` is false.
285    pub async fn check_token(&self, jti: Option<&str>) -> Result<(), TokenRejection> {
286        let jti = match jti {
287            Some(j) if !j.is_empty() => j,
288            _ => {
289                if self.require_jti {
290                    return Err(TokenRejection::MissingJti);
291                }
292                // No JTI and not required — allow through.
293                return Ok(());
294            },
295        };
296
297        match self.store.is_revoked(jti).await {
298            Ok(true) => Err(TokenRejection::Revoked),
299            Ok(false) => Ok(()),
300            Err(e) => {
301                warn!(error = %e, jti = %jti, "Revocation store check failed");
302                if self.fail_open {
303                    debug!("fail_open=true — allowing request despite store error");
304                    Ok(())
305                } else {
306                    Err(TokenRejection::StoreUnavailable)
307                }
308            },
309        }
310    }
311
312    /// Revoke a single token by JTI.
313    ///
314    /// # Errors
315    ///
316    /// Returns `RevocationError` if the underlying revocation store operation fails.
317    pub async fn revoke(&self, jti: &str, ttl_secs: u64) -> Result<(), RevocationError> {
318        self.store.revoke(jti, ttl_secs).await
319    }
320
321    /// Revoke all tokens for a user.
322    ///
323    /// # Errors
324    ///
325    /// Returns `RevocationError` if the underlying revocation store operation fails.
326    pub async fn revoke_all_for_user(&self, sub: &str) -> Result<u64, RevocationError> {
327        self.store.revoke_all_for_user(sub).await
328    }
329
330    /// Whether JTI is required.
331    #[must_use]
332    pub const fn require_jti(&self) -> bool {
333        self.require_jti
334    }
335}
336
337impl std::fmt::Debug for TokenRevocationManager {
338    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339        f.debug_struct("TokenRevocationManager")
340            .field("require_jti", &self.require_jti)
341            .field("fail_open", &self.fail_open)
342            .finish_non_exhaustive()
343    }
344}
345
/// Reason a token failed `TokenRevocationManager::check_token`.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum TokenRejection {
    /// The token's `jti` is present in the revocation store.
    Revoked,
    /// The token carries no (or an empty) `jti` claim while `require_jti` is enabled.
    MissingJti,
    /// The revocation store could not be queried and `fail_open` is disabled, so the
    /// request is rejected rather than allowed through.
    StoreUnavailable,
}
357
358// ───────────────────────────────────────────────────────────────
359// Builder from compiled schema
360// ───────────────────────────────────────────────────────────────
361
362/// Build a `TokenRevocationManager` from the compiled schema's `security.token_revocation` JSON.
363pub fn revocation_manager_from_schema(
364    schema: &fraiseql_core::schema::CompiledSchema,
365) -> Option<Arc<TokenRevocationManager>> {
366    let security = schema.security.as_ref()?;
367    let revocation_val = security.additional.get("token_revocation")?;
368    let config: TokenRevocationConfig = serde_json::from_value(revocation_val.clone())
369        .map_err(|e| {
370            warn!(error = %e, "Failed to parse security.token_revocation config");
371        })
372        .ok()?;
373
374    if !config.enabled {
375        return None;
376    }
377
378    let store: Arc<dyn RevocationStore> = match config.backend.as_str() {
379        #[cfg(feature = "redis-rate-limiting")]
380        "redis" => {
381            let url = config.redis_url.as_deref().unwrap_or("redis://localhost:6379");
382            match RedisRevocationStore::new(url) {
383                Ok(s) => {
384                    info!(backend = "redis", "Token revocation store initialized");
385                    Arc::new(s)
386                },
387                Err(e) => {
388                    warn!(error = %e, "Failed to init Redis revocation store — falling back to in-memory");
389                    Arc::new(InMemoryRevocationStore::new())
390                },
391            }
392        },
393        #[cfg(not(feature = "redis-rate-limiting"))]
394        "redis" => {
395            warn!(
396                "token_revocation.backend = \"redis\" but the `redis-rate-limiting` feature is \
397                 not compiled in. Falling back to in-memory."
398            );
399            Arc::new(InMemoryRevocationStore::new())
400        },
401        "memory" | "env" => {
402            info!(backend = "memory", "Token revocation store initialized (in-memory)");
403            Arc::new(InMemoryRevocationStore::new())
404        },
405        other => {
406            warn!(backend = %other, "Unknown revocation backend — falling back to in-memory");
407            Arc::new(InMemoryRevocationStore::new())
408        },
409    };
410
411    Some(Arc::new(TokenRevocationManager::new(
412        store,
413        config.require_jti,
414        config.fail_open,
415    )))
416}
417
418// ───────────────────────────────────────────────────────────────
419// Tests
420// ───────────────────────────────────────────────────────────────
421
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable

    use super::*;

    /// A fresh in-memory store, type-erased the same way production code holds it.
    fn memory_store() -> Arc<dyn RevocationStore> {
        Arc::new(InMemoryRevocationStore::default())
    }

    #[tokio::test]
    async fn revoke_then_check_is_revoked() {
        let revocations = memory_store();
        revocations.revoke("jti-1", 3600).await.unwrap();
        let revoked = revocations.is_revoked("jti-1").await.unwrap();
        assert!(revoked);
    }

    #[tokio::test]
    async fn non_revoked_jti_passes() {
        let revoked = memory_store().is_revoked("jti-unknown").await.unwrap();
        assert!(!revoked);
    }

    #[tokio::test]
    async fn expired_entry_not_revoked() {
        let revocations = InMemoryRevocationStore::new();
        // A zero-second TTL puts the expiry at "now", i.e. already elapsed on lookup.
        revocations.revoke("jti-expired", 0).await.unwrap();
        let revoked = revocations.is_revoked("jti-expired").await.unwrap();
        assert!(!revoked);
    }

    #[tokio::test]
    async fn cleanup_removes_expired() {
        let revocations = InMemoryRevocationStore::new();
        for (jti, ttl) in [("jti-a", 0_u64), ("jti-b", 3600)] {
            revocations.revoke(jti, ttl).await.unwrap();
        }
        revocations.cleanup_expired();
        // Only the long-lived "jti-b" survives the sweep.
        assert_eq!(revocations.entries.len(), 1);
    }

    #[tokio::test]
    async fn manager_rejects_revoked_token() {
        let store = memory_store();
        store.revoke("jti-x", 3600).await.unwrap();
        let manager = TokenRevocationManager::new(store, true, false);
        let verdict = manager.check_token(Some("jti-x")).await;
        assert_eq!(verdict, Err(TokenRejection::Revoked));
    }

    #[tokio::test]
    async fn manager_allows_non_revoked_token() {
        let manager = TokenRevocationManager::new(memory_store(), true, false);
        if let Err(e) = manager.check_token(Some("jti-ok")).await {
            panic!("expected Ok for non-revoked token: {e:?}");
        }
    }

    #[tokio::test]
    async fn manager_rejects_missing_jti_when_required() {
        let manager = TokenRevocationManager::new(memory_store(), true, false);
        let verdict = manager.check_token(None).await;
        assert_eq!(verdict, Err(TokenRejection::MissingJti));
    }

    #[tokio::test]
    async fn manager_allows_missing_jti_when_not_required() {
        let manager = TokenRevocationManager::new(memory_store(), false, false);
        let verdict = manager.check_token(None).await;
        assert!(verdict.is_ok(), "missing jti should be allowed when jti is not required");
    }

    #[tokio::test]
    async fn manager_allows_empty_jti_when_not_required() {
        let manager = TokenRevocationManager::new(memory_store(), false, false);
        let verdict = manager.check_token(Some("")).await;
        assert!(verdict.is_ok(), "empty jti should be allowed when jti is not required");
    }
}