fraiseql_core/security/oidc/
replay_cache.rs1use std::{
23 sync::atomic::{AtomicU64, Ordering},
24 time::Duration,
25};
26
27use async_trait::async_trait;
28use tracing::warn;
29
30#[derive(Debug, thiserror::Error)]
36#[non_exhaustive]
37pub enum ReplayCacheError {
38 #[error("JWT token has already been used (jti replay detected)")]
40 Replayed,
41 #[error("Replay cache backend error: {0}")]
43 Backend(String),
44}
45
/// How [`ReplayCache`] reacts when its backend reports
/// [`ReplayCacheError::Backend`].
///
/// Genuine replays ([`ReplayCacheError::Replayed`]) are always rejected; this
/// policy only governs infrastructure failures.
#[derive(Debug, Clone, Copy, Default)]
#[non_exhaustive]
pub enum FailurePolicy {
    /// Log a warning and accept the token, trading replay protection for
    /// availability while the backend is down. This is the default.
    #[default]
    FailOpen,

    /// Reject the token by handing the backend error back to the caller.
    FailClosed,
}
62
63#[async_trait]
69pub trait ReplayCacheBackend: Send + Sync {
70 async fn check_and_record(&self, jti: &str, ttl: Duration) -> Result<(), ReplayCacheError>;
83}
84
/// Count of tokens rejected because their `jti` had already been used.
static JWT_REPLAY_REJECTED_TOTAL: AtomicU64 = AtomicU64::new(0);
/// Count of replay-cache backend failures, whatever the failure policy did.
static JWT_REPLAY_CACHE_ERRORS_TOTAL: AtomicU64 = AtomicU64::new(0);

/// Reads one of the metrics counters above.
///
/// `Relaxed` is sufficient: each counter is an independent tally and no
/// other memory is published through it.
fn read_counter(counter: &AtomicU64) -> u64 {
    counter.load(Ordering::Relaxed)
}

/// Number of JWTs rejected as replays since the process started.
#[must_use]
pub fn jwt_replay_rejected_total() -> u64 {
    read_counter(&JWT_REPLAY_REJECTED_TOTAL)
}

/// Number of replay-cache backend errors observed since the process started.
#[must_use]
pub fn jwt_replay_cache_errors_total() -> u64 {
    read_counter(&JWT_REPLAY_CACHE_ERRORS_TOTAL)
}
105
106pub struct ReplayCache {
111 backend: Box<dyn ReplayCacheBackend>,
112 policy: FailurePolicy,
113}
114
115impl ReplayCache {
116 #[must_use]
118 pub fn new(backend: impl ReplayCacheBackend + 'static) -> Self {
119 Self {
120 backend: Box::new(backend),
121 policy: FailurePolicy::FailOpen,
122 }
123 }
124
125 #[must_use]
127 pub const fn with_policy(mut self, policy: FailurePolicy) -> Self {
128 self.policy = policy;
129 self
130 }
131
132 pub async fn check_and_record(&self, jti: &str, ttl: Duration) -> Result<(), ReplayCacheError> {
139 match self.backend.check_and_record(jti, ttl).await {
140 Ok(()) => Ok(()),
141 Err(ReplayCacheError::Replayed) => {
142 JWT_REPLAY_REJECTED_TOTAL.fetch_add(1, Ordering::Relaxed);
143 Err(ReplayCacheError::Replayed)
144 },
145 Err(ReplayCacheError::Backend(msg)) => {
146 JWT_REPLAY_CACHE_ERRORS_TOTAL.fetch_add(1, Ordering::Relaxed);
147 match self.policy {
148 FailurePolicy::FailOpen => {
149 warn!(
150 error = %msg,
151 "JWT replay cache backend error — failing open (token accepted). \
152 Replay protection is degraded while the backend is unavailable."
153 );
154 Ok(())
155 },
156 FailurePolicy::FailClosed => Err(ReplayCacheError::Backend(msg)),
157 }
158 },
159 }
160 }
161}
162
163pub struct MemoryReplayCache {
175 store: dashmap::DashMap<String, std::time::Instant>,
176}
177
178impl MemoryReplayCache {
179 #[must_use]
181 pub fn new() -> Self {
182 Self {
183 store: dashmap::DashMap::new(),
184 }
185 }
186}
187
188impl Default for MemoryReplayCache {
189 fn default() -> Self {
190 Self::new()
191 }
192}
193
194#[async_trait]
195impl ReplayCacheBackend for MemoryReplayCache {
196 async fn check_and_record(&self, jti: &str, ttl: Duration) -> Result<(), ReplayCacheError> {
197 let now = std::time::Instant::now();
198 let expiry = now + ttl;
199
200 if let Some(existing) = self.store.get(jti) {
202 if *existing > now {
203 return Err(ReplayCacheError::Replayed);
205 }
206 drop(existing);
207 }
208
209 self.store.insert(jti.to_string(), expiry);
211 Ok(())
212 }
213}
214
/// Redis-backed [`ReplayCacheBackend`], available with the `jwt-replay` feature.
///
/// Each recorded `jti` is stored under the key `key_prefix + jti`.
#[cfg(feature = "jwt-replay")]
pub struct RedisReplayCache {
    /// Connection handle; cloned for each command.
    pool: redis::aio::ConnectionManager,
    /// Namespace prepended to every `jti` to form its Redis key.
    key_prefix: String,
}

#[cfg(feature = "jwt-replay")]
impl RedisReplayCache {
    /// Connects to `redis_url` using the key prefix `"fraiseql:jti:"`.
    ///
    /// # Errors
    ///
    /// Returns [`ReplayCacheError::Backend`] if the URL cannot be parsed or
    /// the connection cannot be established.
    pub async fn new(redis_url: &str) -> Result<Self, ReplayCacheError> {
        Self::with_prefix(redis_url, "fraiseql:jti:").await
    }

    /// Connects to `redis_url`, namespacing every key with `key_prefix`.
    ///
    /// # Errors
    ///
    /// Returns [`ReplayCacheError::Backend`] if the URL cannot be parsed or
    /// the connection cannot be established.
    pub async fn with_prefix(redis_url: &str, key_prefix: &str) -> Result<Self, ReplayCacheError> {
        let pool = redis::Client::open(redis_url)
            .map_err(|e| ReplayCacheError::Backend(format!("invalid Redis URL: {e}")))?
            .get_connection_manager()
            .await
            .map_err(|e| ReplayCacheError::Backend(format!("Redis connection failed: {e}")))?;

        Ok(Self {
            pool,
            key_prefix: key_prefix.to_owned(),
        })
    }

    /// Builds the Redis key for `jti` by appending it to the configured prefix.
    fn key(&self, jti: &str) -> String {
        [self.key_prefix.as_str(), jti].concat()
    }
}
263
#[cfg(feature = "jwt-replay")]
#[async_trait]
impl ReplayCacheBackend for RedisReplayCache {
    /// Records `jti` with a single atomic `SET key 1 NX EX <secs>` command.
    ///
    /// `EX` only takes whole seconds, so `ttl` is rounded **up** (minimum one
    /// second). Rounding down would let the key expire before `ttl` elapses
    /// and reopen a replay window for the remaining fraction of a second.
    ///
    /// # Errors
    ///
    /// * [`ReplayCacheError::Replayed`] if the key already exists.
    /// * [`ReplayCacheError::Backend`] if the Redis command fails.
    async fn check_and_record(&self, jti: &str, ttl: Duration) -> Result<(), ReplayCacheError> {
        use redis::AsyncCommands;

        let key = self.key(jti);
        // Round partial seconds up; saturate so a huge `ttl` cannot overflow.
        let ttl_secs = ttl
            .as_secs()
            .saturating_add(u64::from(ttl.subsec_nanos() > 0))
            .max(1);
        // `set_options` needs `&mut`, so work on a clone of the shared handle.
        let mut conn = self.pool.clone();

        // `was_set` is false when NX prevented the write, i.e. the key already existed.
        let was_set: bool = conn
            .set_options(
                &key,
                1u8,
                redis::SetOptions::default()
                    .conditional_set(redis::ExistenceCheck::NX)
                    .with_expiration(redis::SetExpiry::EX(ttl_secs)),
            )
            .await
            .map_err(|e| ReplayCacheError::Backend(format!("Redis SET NX failed: {e}")))?;

        if was_set {
            Ok(())
        } else {
            Err(ReplayCacheError::Replayed)
        }
    }
}
294
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Tests unwrap to assert success; a panic is the failure report.

    use super::*;

    /// TTL used by tests that do not exercise expiry.
    const TTL: Duration = Duration::from_secs(900);

    /// Backend whose every call fails, used to exercise the failure policies.
    struct AlwaysErrorBackend;

    #[async_trait]
    impl ReplayCacheBackend for AlwaysErrorBackend {
        async fn check_and_record(
            &self,
            _jti: &str,
            _ttl: Duration,
        ) -> Result<(), ReplayCacheError> {
            Err(ReplayCacheError::Backend("simulated error".to_string()))
        }
    }

    #[tokio::test]
    async fn test_first_use_accepted() {
        let cache = ReplayCache::new(MemoryReplayCache::new());
        let result = cache.check_and_record("jti-abc", TTL).await;
        assert!(result.is_ok(), "first use should be accepted");
    }

    #[tokio::test]
    async fn test_replay_rejected() {
        let cache = ReplayCache::new(MemoryReplayCache::new());
        cache.check_and_record("jti-abc", TTL).await.unwrap();
        let result = cache.check_and_record("jti-abc", TTL).await;
        assert!(
            matches!(result, Err(ReplayCacheError::Replayed)),
            "second use of same jti should be rejected"
        );
    }

    #[tokio::test]
    async fn test_different_jtis_accepted_independently() {
        let cache = ReplayCache::new(MemoryReplayCache::new());
        cache.check_and_record("jti-1", TTL).await.unwrap();
        let result = cache.check_and_record("jti-2", TTL).await;
        assert!(result.is_ok(), "different jti should be accepted");
    }

    #[tokio::test]
    async fn test_expired_jti_accepted_again() {
        let cache = ReplayCache::new(MemoryReplayCache::new());
        // A zero TTL expires immediately, so the second use is no longer a replay.
        cache.check_and_record("jti-expired", Duration::ZERO).await.unwrap();
        let result = cache.check_and_record("jti-expired", Duration::ZERO).await;
        assert!(result.is_ok(), "expired jti should be accepted again");
    }

    #[tokio::test]
    async fn test_default_policy_is_fail_open() {
        let cache = ReplayCache::new(AlwaysErrorBackend);
        let result = cache.check_and_record("jti-xyz", TTL).await;
        assert!(result.is_ok(), "default policy should fail open");
    }

    #[tokio::test]
    async fn test_fail_open_policy_on_backend_error() {
        let cache = ReplayCache::new(AlwaysErrorBackend).with_policy(FailurePolicy::FailOpen);
        let result = cache.check_and_record("jti-xyz", TTL).await;
        assert!(result.is_ok(), "fail-open should accept on backend error");
    }

    #[tokio::test]
    async fn test_fail_closed_policy_on_backend_error() {
        let cache = ReplayCache::new(AlwaysErrorBackend).with_policy(FailurePolicy::FailClosed);
        let result = cache.check_and_record("jti-xyz", TTL).await;
        assert!(
            matches!(&result, Err(ReplayCacheError::Backend(msg)) if msg == "simulated error"),
            "fail-closed should surface the backend error, got {result:?}"
        );
    }

    #[tokio::test]
    async fn test_replay_counter_increments() {
        let before = jwt_replay_rejected_total();
        let cache = ReplayCache::new(MemoryReplayCache::new());
        cache.check_and_record("jti-counter", TTL).await.unwrap();
        let _ = cache.check_and_record("jti-counter", TTL).await;
        let after = jwt_replay_rejected_total();
        // `>` rather than `==`: other tests may bump the shared counter concurrently.
        assert!(after > before, "replay counter should have incremented");
    }

    #[tokio::test]
    async fn test_backend_error_counter_increments() {
        let before = jwt_replay_cache_errors_total();
        let cache = ReplayCache::new(AlwaysErrorBackend);
        let _ = cache.check_and_record("jti-err", TTL).await;
        let after = jwt_replay_cache_errors_total();
        assert!(after > before, "backend error counter should have incremented");
    }
}