1use std::sync::Arc;
14
15use async_trait::async_trait;
16use chrono::{DateTime, Utc};
17use dashmap::DashMap;
18use serde::Deserialize;
19use tracing::{debug, info, warn};
20
21#[derive(Debug, Clone, Deserialize)]
27pub struct TokenRevocationConfig {
28 #[serde(default)]
30 pub enabled: bool,
31
32 #[serde(default = "default_backend")]
34 pub backend: String,
35
36 #[serde(default = "default_true")]
38 pub require_jti: bool,
39
40 #[serde(default)]
44 pub fail_open: bool,
45
46 pub redis_url: Option<String>,
48}
49
/// Serde default for `TokenRevocationConfig::backend`: the in-memory store.
fn default_backend() -> String {
    String::from("memory")
}

/// Serde default helper yielding `true` (used for `require_jti`).
const fn default_true() -> bool {
    let enabled = true;
    enabled
}
56
/// Backend that records revoked token IDs (`jti` claims) and answers lookups.
///
/// The trait requires `Send + Sync` so one store can be shared as an
/// `Arc<dyn RevocationStore>` (as `TokenRevocationManager` does).
#[async_trait]
pub trait RevocationStore: Send + Sync {
    /// Returns `Ok(true)` if `jti` is currently recorded as revoked.
    ///
    /// # Errors
    /// Returns [`RevocationError`] if the backend cannot be queried.
    async fn is_revoked(&self, jti: &str) -> Result<bool, RevocationError>;

    /// Records `jti` as revoked for `ttl_secs` seconds.
    ///
    /// # Errors
    /// Returns [`RevocationError`] if the backend rejects the write.
    async fn revoke(&self, jti: &str, ttl_secs: u64) -> Result<(), RevocationError>;

    /// Bulk operation keyed by the token subject (`sub`); returns a count of affected entries.
    ///
    /// NOTE(review): the in-tree backends interpret this differently (the in-memory
    /// store only counts matching entries, the Redis store deletes matching keys) —
    /// confirm the intended contract.
    ///
    /// # Errors
    /// Returns [`RevocationError`] if the backend operation fails.
    async fn revoke_all_for_user(&self, sub: &str) -> Result<u64, RevocationError>;
}
77
/// Error produced by a [`RevocationStore`] backend.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum RevocationError {
    /// The backend failed; the payload is a human-readable description.
    #[error("revocation store error: {0}")]
    Backend(String),
}
86
/// Process-local [`RevocationStore`] backed by a concurrent `DashMap`.
///
/// Revocations live only in this process's memory: they are not visible to other
/// instances and do not survive a restart.
pub struct InMemoryRevocationStore {
    /// `jti` -> (subject, expiry). An entry whose expiry is not in the future is
    /// treated as not revoked.
    entries: DashMap<String, (String, DateTime<Utc>)>,
}
96
97impl InMemoryRevocationStore {
98 #[must_use]
100 pub fn new() -> Self {
101 Self {
102 entries: DashMap::new(),
103 }
104 }
105
106 pub fn cleanup_expired(&self) {
108 let now = Utc::now();
109 self.entries.retain(|_, (_, exp)| *exp > now);
110 }
111}
112
113impl Default for InMemoryRevocationStore {
114 fn default() -> Self {
115 Self::new()
116 }
117}
118
119#[async_trait]
123impl RevocationStore for InMemoryRevocationStore {
124 async fn is_revoked(&self, jti: &str) -> Result<bool, RevocationError> {
125 if let Some(entry) = self.entries.get(jti) {
126 let (_, expires_at) = entry.value();
127 if *expires_at > Utc::now() {
128 return Ok(true);
129 }
130 drop(entry);
132 self.entries.remove(jti);
133 }
134 Ok(false)
135 }
136
137 async fn revoke(&self, jti: &str, ttl_secs: u64) -> Result<(), RevocationError> {
138 let expires_at = Utc::now() + chrono::Duration::seconds(ttl_secs.cast_signed());
139 self.entries.insert(jti.to_string(), (String::new(), expires_at));
141 Ok(())
142 }
143
144 async fn revoke_all_for_user(&self, sub: &str) -> Result<u64, RevocationError> {
145 let mut count = 0u64;
148 for entry in &self.entries {
149 let (s, _) = entry.value();
150 if s == sub {
151 count += 1;
152 }
153 }
154 Ok(count)
157 }
158}
159
/// [`RevocationStore`] backed by Redis, so revocations are visible to every
/// instance talking to the same server.
///
/// Each revoked `jti` is stored under the key `{key_prefix}{jti}` with a TTL.
#[cfg(feature = "redis-rate-limiting")]
pub struct RedisRevocationStore {
    /// Client used to open a multiplexed async connection for each operation.
    client: redis::Client,
    /// Namespace prepended to every key (`new` sets it to `fraiseql:revoked:`).
    key_prefix: String,
}
173
#[cfg(feature = "redis-rate-limiting")]
impl RedisRevocationStore {
    /// Creates a store for `redis_url`, namespacing keys under `fraiseql:revoked:`.
    ///
    /// NOTE(review): `redis::Client::open` is expected to only validate the URL;
    /// connections are opened per operation later — confirm for the pinned version.
    ///
    /// # Errors
    /// Returns [`RevocationError::Backend`] if `redis::Client::open` rejects `redis_url`.
    pub fn new(redis_url: &str) -> Result<Self, RevocationError> {
        redis::Client::open(redis_url)
            .map(|client| Self {
                client,
                key_prefix: String::from("fraiseql:revoked:"),
            })
            .map_err(|e| RevocationError::Backend(format!("Redis connection error: {e}")))
    }
}
190
#[cfg(feature = "redis-rate-limiting")]
#[async_trait]
impl RevocationStore for RedisRevocationStore {
    /// Returns whether the key `{prefix}{jti}` exists (Redis drops it when its TTL elapses).
    async fn is_revoked(&self, jti: &str) -> Result<bool, RevocationError> {
        use redis::AsyncCommands;
        let mut conn = self
            .client
            .get_multiplexed_async_connection()
            .await
            .map_err(|e| RevocationError::Backend(format!("Redis: {e}")))?;
        let key = format!("{}{jti}", self.key_prefix);
        let exists: bool = conn
            .exists(&key)
            .await
            .map_err(|e| RevocationError::Backend(format!("Redis EXISTS: {e}")))?;
        Ok(exists)
    }

    /// Writes `{prefix}{jti}` with `SET EX ttl_secs`.
    ///
    /// A TTL of `0` is a no-op: Redis rejects `SET EX 0`, and an entry that expires
    /// immediately would never be observed as revoked anyway (this matches the
    /// in-memory store, where a zero TTL yields an already-expired entry).
    async fn revoke(&self, jti: &str, ttl_secs: u64) -> Result<(), RevocationError> {
        use redis::AsyncCommands;
        if ttl_secs == 0 {
            return Ok(());
        }
        let mut conn = self
            .client
            .get_multiplexed_async_connection()
            .await
            .map_err(|e| RevocationError::Backend(format!("Redis: {e}")))?;
        let key = format!("{}{jti}", self.key_prefix);
        let _: () = conn
            .set_ex(&key, "1", ttl_secs)
            .await
            .map_err(|e| RevocationError::Backend(format!("Redis SET EX: {e}")))?;
        Ok(())
    }

    /// Deletes every key matching `{prefix}user:{sub}:*` and returns how many were deleted.
    ///
    /// Uses cursor-based `SCAN` rather than `KEYS`, which is O(N) over the whole
    /// keyspace and blocks the server. `sub` is glob-escaped so subjects containing
    /// `*`, `?`, `[`, `]` or `\` cannot match other users' keys. The count is taken
    /// from DEL's replies, so keys that vanish mid-scan or are returned twice by
    /// `SCAN` are not over-counted.
    ///
    /// NOTE(review): deleting these keys removes markers rather than adding them,
    /// and `revoke` never writes keys of this `user:` shape — confirm intended semantics.
    async fn revoke_all_for_user(&self, sub: &str) -> Result<u64, RevocationError> {
        // Keys examined per SCAN round-trip (a hint to Redis, not a hard limit).
        const SCAN_BATCH: u64 = 500;

        // Backslash-escape Redis glob metacharacters so `raw` matches literally.
        fn escape_glob(raw: &str) -> String {
            let mut out = String::with_capacity(raw.len());
            for c in raw.chars() {
                if matches!(c, '*' | '?' | '[' | ']' | '\\') {
                    out.push('\\');
                }
                out.push(c);
            }
            out
        }

        let mut conn = self
            .client
            .get_multiplexed_async_connection()
            .await
            .map_err(|e| RevocationError::Backend(format!("Redis: {e}")))?;
        let pattern = format!("{}user:{}:*", self.key_prefix, escape_glob(sub));

        let mut cursor: u64 = 0;
        let mut deleted: u64 = 0;
        loop {
            let (next_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
                .arg(cursor)
                .arg("MATCH")
                .arg(&pattern)
                .arg("COUNT")
                .arg(SCAN_BATCH)
                .query_async(&mut conn)
                .await
                .map_err(|e| RevocationError::Backend(format!("Redis SCAN: {e}")))?;
            if !keys.is_empty() {
                let removed: u64 = redis::cmd("DEL")
                    .arg(&keys)
                    .query_async(&mut conn)
                    .await
                    .map_err(|e| RevocationError::Backend(format!("Redis DEL: {e}")))?;
                deleted = deleted.saturating_add(removed);
            }
            // A returned cursor of 0 means the full iteration is complete.
            if next_cursor == 0 {
                break;
            }
            cursor = next_cursor;
        }
        Ok(deleted)
    }
}
252
/// Applies revocation policy (jti requirement, fail-open/closed) on top of a
/// [`RevocationStore`].
pub struct TokenRevocationManager {
    /// Backend consulted for revoked `jti`s.
    store: Arc<dyn RevocationStore>,
    /// When `true`, tokens without a non-empty `jti` are rejected.
    require_jti: bool,
    /// When `true`, store errors let the request through instead of rejecting it.
    fail_open: bool,
}
263
264impl TokenRevocationManager {
265 #[must_use]
267 pub fn new(store: Arc<dyn RevocationStore>, require_jti: bool, fail_open: bool) -> Self {
268 Self {
269 store,
270 require_jti,
271 fail_open,
272 }
273 }
274
275 pub async fn check_token(&self, jti: Option<&str>) -> Result<(), TokenRejection> {
286 let jti = match jti {
287 Some(j) if !j.is_empty() => j,
288 _ => {
289 if self.require_jti {
290 return Err(TokenRejection::MissingJti);
291 }
292 return Ok(());
294 },
295 };
296
297 match self.store.is_revoked(jti).await {
298 Ok(true) => Err(TokenRejection::Revoked),
299 Ok(false) => Ok(()),
300 Err(e) => {
301 warn!(error = %e, jti = %jti, "Revocation store check failed");
302 if self.fail_open {
303 debug!("fail_open=true — allowing request despite store error");
304 Ok(())
305 } else {
306 Err(TokenRejection::StoreUnavailable)
307 }
308 },
309 }
310 }
311
312 pub async fn revoke(&self, jti: &str, ttl_secs: u64) -> Result<(), RevocationError> {
318 self.store.revoke(jti, ttl_secs).await
319 }
320
321 pub async fn revoke_all_for_user(&self, sub: &str) -> Result<u64, RevocationError> {
327 self.store.revoke_all_for_user(sub).await
328 }
329
330 #[must_use]
332 pub const fn require_jti(&self) -> bool {
333 self.require_jti
334 }
335}
336
337impl std::fmt::Debug for TokenRevocationManager {
338 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339 f.debug_struct("TokenRevocationManager")
340 .field("require_jti", &self.require_jti)
341 .field("fail_open", &self.fail_open)
342 .finish_non_exhaustive()
343 }
344}
345
/// Reason `TokenRevocationManager::check_token` refused a token.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum TokenRejection {
    /// The token's `jti` is recorded as revoked.
    Revoked,
    /// The token has no (or an empty) `jti` and `require_jti` is enabled.
    MissingJti,
    /// The revocation store returned an error and `fail_open` is disabled.
    StoreUnavailable,
}
357
358pub fn revocation_manager_from_schema(
364 schema: &fraiseql_core::schema::CompiledSchema,
365) -> Option<Arc<TokenRevocationManager>> {
366 let security = schema.security.as_ref()?;
367 let revocation_val = security.additional.get("token_revocation")?;
368 let config: TokenRevocationConfig = serde_json::from_value(revocation_val.clone())
369 .map_err(|e| {
370 warn!(error = %e, "Failed to parse security.token_revocation config");
371 })
372 .ok()?;
373
374 if !config.enabled {
375 return None;
376 }
377
378 let store: Arc<dyn RevocationStore> = match config.backend.as_str() {
379 #[cfg(feature = "redis-rate-limiting")]
380 "redis" => {
381 let url = config.redis_url.as_deref().unwrap_or("redis://localhost:6379");
382 match RedisRevocationStore::new(url) {
383 Ok(s) => {
384 info!(backend = "redis", "Token revocation store initialized");
385 Arc::new(s)
386 },
387 Err(e) => {
388 warn!(error = %e, "Failed to init Redis revocation store — falling back to in-memory");
389 Arc::new(InMemoryRevocationStore::new())
390 },
391 }
392 },
393 #[cfg(not(feature = "redis-rate-limiting"))]
394 "redis" => {
395 warn!(
396 "token_revocation.backend = \"redis\" but the `redis-rate-limiting` feature is \
397 not compiled in. Falling back to in-memory."
398 );
399 Arc::new(InMemoryRevocationStore::new())
400 },
401 "memory" | "env" => {
402 info!(backend = "memory", "Token revocation store initialized (in-memory)");
403 Arc::new(InMemoryRevocationStore::new())
404 },
405 other => {
406 warn!(backend = %other, "Unknown revocation backend — falling back to in-memory");
407 Arc::new(InMemoryRevocationStore::new())
408 },
409 };
410
411 Some(Arc::new(TokenRevocationManager::new(
412 store,
413 config.require_jti,
414 config.fail_open,
415 )))
416}
417
#[cfg(test)]
mod tests {
    // unwrap keeps assertions terse; any failure surfaces as a test panic.
    #![allow(clippy::unwrap_used)]

    use super::*;

    /// Fresh in-memory store behind the trait object the manager consumes.
    fn memory_store() -> Arc<dyn RevocationStore> {
        Arc::new(InMemoryRevocationStore::new())
    }

    #[tokio::test]
    async fn revoke_then_check_is_revoked() {
        let store = memory_store();
        store.revoke("jti-1", 3600).await.unwrap();
        let revoked = store.is_revoked("jti-1").await.unwrap();
        assert!(revoked);
    }

    #[tokio::test]
    async fn non_revoked_jti_passes() {
        let store = memory_store();
        let revoked = store.is_revoked("jti-unknown").await.unwrap();
        assert!(!revoked);
    }

    #[tokio::test]
    async fn expired_entry_not_revoked() {
        let store = InMemoryRevocationStore::new();
        // A zero TTL produces an entry whose expiry is already "now".
        store.revoke("jti-expired", 0).await.unwrap();
        let revoked = store.is_revoked("jti-expired").await.unwrap();
        assert!(!revoked);
    }

    #[tokio::test]
    async fn cleanup_removes_expired() {
        let store = InMemoryRevocationStore::new();
        store.revoke("jti-a", 0).await.unwrap();
        store.revoke("jti-b", 3600).await.unwrap();
        store.cleanup_expired();
        // Only the long-lived entry should survive the sweep.
        let remaining = store.entries.len();
        assert_eq!(remaining, 1);
    }

    #[tokio::test]
    async fn manager_rejects_revoked_token() {
        let store = memory_store();
        store.revoke("jti-x", 3600).await.unwrap();
        let mgr = TokenRevocationManager::new(store, true, false);
        let verdict = mgr.check_token(Some("jti-x")).await;
        assert_eq!(verdict, Err(TokenRejection::Revoked));
    }

    #[tokio::test]
    async fn manager_allows_non_revoked_token() {
        let mgr = TokenRevocationManager::new(memory_store(), true, false);
        if let Err(e) = mgr.check_token(Some("jti-ok")).await {
            panic!("expected Ok for non-revoked token: {e:?}");
        }
    }

    #[tokio::test]
    async fn manager_rejects_missing_jti_when_required() {
        let mgr = TokenRevocationManager::new(memory_store(), true, false);
        let verdict = mgr.check_token(None).await;
        assert_eq!(verdict, Err(TokenRejection::MissingJti));
    }

    #[tokio::test]
    async fn manager_allows_missing_jti_when_not_required() {
        let mgr = TokenRevocationManager::new(memory_store(), false, false);
        let verdict = mgr.check_token(None).await;
        assert!(
            verdict.is_ok(),
            "missing jti should be allowed when jti is not required"
        );
    }

    #[tokio::test]
    async fn manager_allows_empty_jti_when_not_required() {
        let mgr = TokenRevocationManager::new(memory_store(), false, false);
        let verdict = mgr.check_token(Some("")).await;
        assert!(
            verdict.is_ok(),
            "empty jti should be allowed when jti is not required"
        );
    }
}