// fraiseql_core/security/oidc/replay_cache.rs
1//! JWT replay prevention via a jti (JWT ID) replay cache.
2//!
3//! Each validated JWT token carries a unique `jti` claim. This module stores
4//! seen `jti` values and rejects any token whose `jti` has been seen before,
5//! preventing stolen-token replay attacks within the token's remaining validity
6//! window.
7//!
8//! # Backends
9//!
10//! - **Redis** (`jwt-replay` feature): distributed, survives server restarts.
11//! - **Memory** (always available): single-process, resets on restart; suitable for testing or
12//!   single-instance deployments.
13//!
14//! # Failure policy
15//!
16//! When Redis is unavailable, behavior is controlled by [`FailurePolicy`]:
17//! - [`FailurePolicy::FailOpen`] (default): accept the token and log a warning. Prevents auth
18//!   outages during Redis downtime at the cost of reduced replay protection.
19//! - [`FailurePolicy::FailClosed`]: reject the token. Maximum security, but any Redis hiccup will
20//!   cause auth failures.
21
22use std::{
23    sync::atomic::{AtomicU64, Ordering},
24    time::Duration,
25};
26
27use async_trait::async_trait;
28use tracing::warn;
29
// ============================================================================
// Error type
// ============================================================================

/// Error returned by [`ReplayCacheBackend::check_and_record`].
#[derive(Debug)]
#[non_exhaustive]
pub enum ReplayCacheError {
    /// The `jti` was already seen — this is a replayed token.
    Replayed,
    /// The backend returned an unexpected error.
    Backend(String),
}

impl std::fmt::Display for ReplayCacheError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Replayed => {
                f.write_str("JWT token has already been used (jti replay detected)")
            },
            Self::Backend(reason) => write!(f, "Replay cache backend error: {reason}"),
        }
    }
}

impl std::error::Error for ReplayCacheError {}
45
// ============================================================================
// Failure policy
// ============================================================================

/// Policy controlling what happens when the replay-cache backend is unavailable.
///
/// Defaults to [`FailurePolicy::FailOpen`].
//
// `PartialEq`/`Eq` are derived so callers can compare policies directly
// instead of reaching for `matches!` on a trivially comparable unit enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum FailurePolicy {
    /// Accept the token and log a warning. Prevents auth outages during backend
    /// downtime at the cost of reduced replay protection during the outage.
    #[default]
    FailOpen,
    /// Reject the token. Maximum security, but any backend hiccup causes auth
    /// failures.
    FailClosed,
}
62
// ============================================================================
// Backend trait
// ============================================================================

/// A backend that stores and checks seen JWT IDs.
#[async_trait]
pub trait ReplayCacheBackend: Send + Sync {
    /// Check whether `jti` has been seen before, and record it with the given
    /// `ttl` if not.
    ///
    /// Implementations should perform the check and the record as one atomic
    /// step (the Redis backend does this with `SET … NX`), so that concurrent
    /// calls with the same fresh `jti` admit at most one caller.
    ///
    /// # Errors
    ///
    /// - `Ok(())` — first time this `jti` has been seen (token accepted).
    /// - [`ReplayCacheError::Replayed`] — the `jti` was already stored.
    /// - [`ReplayCacheError::Backend`] — transient storage failure.
    async fn check_and_record(&self, jti: &str, ttl: Duration) -> Result<(), ReplayCacheError>;
}
84
// ============================================================================
// ReplayCache — the public façade
// ============================================================================

/// Process-wide counter: JWT tokens rejected as replays.
static JWT_REPLAY_REJECTED_TOTAL: AtomicU64 = AtomicU64::new(0);
/// Process-wide counter: replay-cache backend errors (Redis failures, etc.).
static JWT_REPLAY_CACHE_ERRORS_TOTAL: AtomicU64 = AtomicU64::new(0);

/// Number of JWT replay rejections recorded since process start.
#[must_use]
pub fn jwt_replay_rejected_total() -> u64 {
    JWT_REPLAY_REJECTED_TOTAL.load(Ordering::Relaxed)
}

/// Number of replay-cache backend errors recorded since process start.
#[must_use]
pub fn jwt_replay_cache_errors_total() -> u64 {
    JWT_REPLAY_CACHE_ERRORS_TOTAL.load(Ordering::Relaxed)
}
105
/// JWT replay prevention cache.
///
/// Wraps a [`ReplayCacheBackend`] with a configurable [`FailurePolicy`] and
/// process-wide counters (see [`jwt_replay_rejected_total`] and
/// [`jwt_replay_cache_errors_total`]).
pub struct ReplayCache {
    // Storage backend (memory, Redis, …); boxed so it can be chosen at runtime.
    backend: Box<dyn ReplayCacheBackend>,
    // What to do on backend errors; `new` starts with fail-open.
    policy:  FailurePolicy,
}
114
115impl ReplayCache {
116    /// Create a new `ReplayCache` wrapping the given backend.
117    #[must_use]
118    pub fn new(backend: impl ReplayCacheBackend + 'static) -> Self {
119        Self {
120            backend: Box::new(backend),
121            policy:  FailurePolicy::FailOpen,
122        }
123    }
124
125    /// Set the failure policy for backend errors.
126    #[must_use]
127    pub const fn with_policy(mut self, policy: FailurePolicy) -> Self {
128        self.policy = policy;
129        self
130    }
131
132    /// Check and record the given `jti` with the given TTL.
133    ///
134    /// # Errors
135    ///
136    /// Returns `Err(ReplayCacheError::Replayed)` when replay is detected.
137    /// Backend errors are handled according to the configured [`FailurePolicy`].
138    pub async fn check_and_record(&self, jti: &str, ttl: Duration) -> Result<(), ReplayCacheError> {
139        match self.backend.check_and_record(jti, ttl).await {
140            Ok(()) => Ok(()),
141            Err(ReplayCacheError::Replayed) => {
142                JWT_REPLAY_REJECTED_TOTAL.fetch_add(1, Ordering::Relaxed);
143                Err(ReplayCacheError::Replayed)
144            },
145            Err(ReplayCacheError::Backend(msg)) => {
146                JWT_REPLAY_CACHE_ERRORS_TOTAL.fetch_add(1, Ordering::Relaxed);
147                match self.policy {
148                    FailurePolicy::FailOpen => {
149                        warn!(
150                            error = %msg,
151                            "JWT replay cache backend error — failing open (token accepted). \
152                             Replay protection is degraded while the backend is unavailable."
153                        );
154                        Ok(())
155                    },
156                    FailurePolicy::FailClosed => Err(ReplayCacheError::Backend(msg)),
157                }
158            },
159        }
160    }
161}
162
163// ============================================================================
164// In-memory backend (always compiled in; useful for tests + single-process)
165// ============================================================================
166
167/// In-memory JWT replay cache backend.
168///
169/// Uses a `DashMap` for lock-free concurrent access. TTL is enforced by storing
170/// the expiry timestamp alongside each entry and lazily evicting on lookup.
171///
172/// **Not distributed**: each process has its own cache. Use the Redis backend
173/// for multi-instance deployments.
174pub struct MemoryReplayCache {
175    store: dashmap::DashMap<String, std::time::Instant>,
176}
177
178impl MemoryReplayCache {
179    /// Create a new in-memory replay cache.
180    #[must_use]
181    pub fn new() -> Self {
182        Self {
183            store: dashmap::DashMap::new(),
184        }
185    }
186}
187
188impl Default for MemoryReplayCache {
189    fn default() -> Self {
190        Self::new()
191    }
192}
193
194#[async_trait]
195impl ReplayCacheBackend for MemoryReplayCache {
196    async fn check_and_record(&self, jti: &str, ttl: Duration) -> Result<(), ReplayCacheError> {
197        let now = std::time::Instant::now();
198        let expiry = now + ttl;
199
200        // Remove expired entry if present (lazy eviction).
201        if let Some(existing) = self.store.get(jti) {
202            if *existing > now {
203                // Still valid — this is a replay.
204                return Err(ReplayCacheError::Replayed);
205            }
206            drop(existing);
207        }
208
209        // Insert (or re-insert after expiry).
210        self.store.insert(jti.to_string(), expiry);
211        Ok(())
212    }
213}
214
// ============================================================================
// Redis backend (compiled in with `jwt-replay` feature)
// ============================================================================

/// Redis-backed JWT replay cache backend.
///
/// Uses `SET key 1 EX {ttl_secs} NX` (SET if Not eXists) to atomically record
/// a `jti` and detect replays: if the key was not set because it already existed,
/// the token is a replay.
#[cfg(feature = "jwt-replay")]
pub struct RedisReplayCache {
    // Cloneable multiplexed connection handle; each request clones it cheaply.
    pool:       redis::aio::ConnectionManager,
    // Prepended to every jti to namespace keys (e.g. per tenant).
    key_prefix: String,
}
229
#[cfg(feature = "jwt-replay")]
impl RedisReplayCache {
    /// Connect to Redis and create the replay cache with the default
    /// `fraiseql:jti:` key prefix.
    ///
    /// # Errors
    ///
    /// Returns an error if the Redis URL is invalid or the connection fails.
    pub async fn new(redis_url: &str) -> Result<Self, ReplayCacheError> {
        Self::with_prefix(redis_url, "fraiseql:jti:").await
    }

    /// Connect to Redis with a custom key prefix (useful for multi-tenant isolation).
    ///
    /// # Errors
    ///
    /// Returns an error if the Redis URL is invalid or the connection fails.
    pub async fn with_prefix(redis_url: &str, key_prefix: &str) -> Result<Self, ReplayCacheError> {
        let client = redis::Client::open(redis_url)
            .map_err(|e| ReplayCacheError::Backend(format!("invalid Redis URL: {e}")))?;
        let pool = client
            .get_connection_manager()
            .await
            .map_err(|e| ReplayCacheError::Backend(format!("Redis connection failed: {e}")))?;
        let key_prefix = key_prefix.to_string();
        Ok(Self { pool, key_prefix })
    }

    /// Build the namespaced Redis key for a `jti`.
    fn key(&self, jti: &str) -> String {
        let mut key = String::with_capacity(self.key_prefix.len() + jti.len());
        key.push_str(&self.key_prefix);
        key.push_str(jti);
        key
    }
}
263
#[cfg(feature = "jwt-replay")]
#[async_trait]
impl ReplayCacheBackend for RedisReplayCache {
    async fn check_and_record(&self, jti: &str, ttl: Duration) -> Result<(), ReplayCacheError> {
        use redis::AsyncCommands;

        let redis_key = self.key(jti);
        // EX takes whole seconds; clamp to at least 1 so sub-second TTLs still
        // produce an expiring key.
        let ttl_secs = ttl.as_secs().max(1);
        let mut conn = self.pool.clone();

        // SET key 1 EX ttl_secs NX — true when the key was freshly set (first
        // use), false when it already existed (replay).
        let options = redis::SetOptions::default()
            .conditional_set(redis::ExistenceCheck::NX)
            .with_expiration(redis::SetExpiry::EX(ttl_secs));
        let first_use: bool = conn
            .set_options(&redis_key, 1u8, options)
            .await
            .map_err(|e| ReplayCacheError::Backend(format!("Redis SET NX failed: {e}")))?;

        if first_use {
            Ok(())
        } else {
            Err(ReplayCacheError::Replayed)
        }
    }
}
294
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable

    use super::*;

    /// TTL used by every test; the value itself is irrelevant as long as it is
    /// comfortably longer than the test run.
    const TTL: Duration = Duration::from_secs(900);

    /// Backend stub whose `check_and_record` always reports a backend failure.
    struct AlwaysErrorBackend;

    #[async_trait]
    impl ReplayCacheBackend for AlwaysErrorBackend {
        async fn check_and_record(
            &self,
            _jti: &str,
            _ttl: Duration,
        ) -> Result<(), ReplayCacheError> {
            Err(ReplayCacheError::Backend("simulated error".to_string()))
        }
    }

    /// Convenience: a fresh cache over the in-memory backend.
    fn memory_cache() -> ReplayCache {
        ReplayCache::new(MemoryReplayCache::new())
    }

    #[tokio::test]
    async fn test_first_use_accepted() {
        let cache = memory_cache();
        let result = cache.check_and_record("jti-abc", TTL).await;
        assert!(result.is_ok(), "first use should be accepted");
    }

    #[tokio::test]
    async fn test_replay_rejected() {
        let cache = memory_cache();
        cache.check_and_record("jti-abc", TTL).await.unwrap();
        let second = cache.check_and_record("jti-abc", TTL).await;
        assert!(
            matches!(second, Err(ReplayCacheError::Replayed)),
            "second use of same jti should be rejected"
        );
    }

    #[tokio::test]
    async fn test_different_jtis_accepted_independently() {
        let cache = memory_cache();
        cache.check_and_record("jti-1", TTL).await.unwrap();
        let result = cache.check_and_record("jti-2", TTL).await;
        assert!(result.is_ok(), "different jti should be accepted");
    }

    #[tokio::test]
    async fn test_fail_open_policy_on_backend_error() {
        let cache = ReplayCache::new(AlwaysErrorBackend).with_policy(FailurePolicy::FailOpen);
        let result = cache.check_and_record("jti-xyz", TTL).await;
        assert!(result.is_ok(), "fail-open should accept on backend error");
    }

    #[tokio::test]
    async fn test_fail_closed_policy_on_backend_error() {
        let cache = ReplayCache::new(AlwaysErrorBackend).with_policy(FailurePolicy::FailClosed);
        let result = cache.check_and_record("jti-xyz", TTL).await;
        assert!(result.is_err(), "fail-closed should reject on backend error");
    }

    #[tokio::test]
    async fn test_replay_counter_increments() {
        let before = jwt_replay_rejected_total();
        let cache = memory_cache();
        cache.check_and_record("jti-counter", TTL).await.unwrap();
        let _ = cache.check_and_record("jti-counter", TTL).await;
        let after = jwt_replay_rejected_total();
        assert!(after >= before + 1, "replay counter should have incremented");
    }
}