#![cfg(feature = "dpop-replay-cache")]
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::{Duration, Instant};
use lru::LruCache;
use thiserror::Error;
use tokio::sync::Mutex;
/// Default replay window for [`DpopReplayCache`], in whole seconds.
pub const DEFAULT_TTL_SECS: u64 = 60;
/// Default upper bound on the number of `jti` values a [`DpopReplayCache`] retains.
pub const DEFAULT_MAX_SIZE: usize = 10000;
/// Environment variable read by [`DpopReplayCache::from_env`] to override the TTL (seconds).
pub const ENV_TTL_SECS: &'static str = "SOLID_POD_DPOP_REPLAY_TTL_SECS";
/// Environment variable read by [`DpopReplayCache::from_env`] to override the maximum size.
pub const ENV_MAX_SIZE: &'static str = "SOLID_POD_DPOP_REPLAY_MAX_SIZE";
/// Error produced by the replay caches in this module when a DPoP `jti`
/// is presented again inside its replay window.
#[derive(Debug, Error)]
pub enum ReplayError {
    /// The `jti` had already been accepted less than `ttl` ago.
    #[error("DPoP jti already used within TTL window ({ttl:?})")]
    Replayed {
        /// The replay window that was in force when the proof was rejected.
        ttl: Duration,
    },
}
/// Bounded, async-friendly cache of recently accepted DPoP `jti` values.
///
/// The entries live behind an `Arc`, so cloning a `DpopReplayCache` is cheap
/// and every clone observes and updates the same shared state.
#[derive(Debug, Clone)]
pub struct DpopReplayCache {
    /// Shared LRU state, guarded by a `tokio::sync::Mutex`.
    inner: Arc<Mutex<LruCacheInner>>,
    /// Window during which a repeated `jti` is rejected.
    ttl: Duration,
    /// Capacity of the LRU (clamped to at least 1 by `with_config`).
    max_size: usize,
}

/// Mutex-protected state of [`DpopReplayCache`].
#[derive(Debug)]
struct LruCacheInner {
    /// `jti` -> instant at which it was last accepted.
    entries: LruCache<String, Instant>,
}
impl DpopReplayCache {
    /// Builds a cache configured from the process environment.
    ///
    /// * [`ENV_TTL_SECS`] — replay window in whole seconds
    ///   (default [`DEFAULT_TTL_SECS`]).
    /// * [`ENV_MAX_SIZE`] — maximum number of remembered `jti` values
    ///   (default [`DEFAULT_MAX_SIZE`]).
    ///
    /// Surrounding whitespace is trimmed before parsing. A variable falls back
    /// to its default in any of these cases:
    ///
    /// * it is unset or not valid Unicode;
    /// * it does not parse as an unsigned integer;
    /// * it is zero.
    ///
    /// A zero TTL would treat every `jti` as already expired and silently
    /// disable replay detection. A zero size is not a usable LRU capacity.
    pub fn from_env() -> Self {
        let ttl_secs = std::env::var(ENV_TTL_SECS)
            .ok()
            .and_then(|s| s.trim().parse::<u64>().ok())
            .filter(|secs| *secs > 0)
            .unwrap_or(DEFAULT_TTL_SECS);
        let max_size = std::env::var(ENV_MAX_SIZE)
            .ok()
            .and_then(|s| s.trim().parse::<usize>().ok())
            .filter(|n| *n > 0)
            .unwrap_or(DEFAULT_MAX_SIZE);
        Self::with_config(Duration::from_secs(ttl_secs), max_size)
    }

    /// Builds a cache with an explicit replay window and capacity.
    ///
    /// `max_size` is clamped to at least 1 because the underlying LRU cannot
    /// have zero capacity; [`Self::max_size`] reports the clamped value.
    /// A zero `ttl` is accepted as given, in which case no `jti` is ever
    /// rejected.
    pub fn with_config(ttl: Duration, max_size: usize) -> Self {
        // Clamp once so the LRU capacity and the reported `max_size` agree.
        let max_size = max_size.max(1);
        let cap = NonZeroUsize::new(max_size).expect("max_size clamped to >= 1 above");
        Self {
            inner: Arc::new(Mutex::new(LruCacheInner {
                entries: LruCache::new(cap),
            })),
            ttl,
            max_size,
        }
    }

    /// Replay window applied by [`Self::check_and_record`].
    pub fn ttl(&self) -> Duration {
        self.ttl
    }

    /// Maximum number of entries retained (always >= 1).
    pub fn max_size(&self) -> usize {
        self.max_size
    }

    /// Number of `jti` values currently held.
    ///
    /// The count includes expired entries that have not yet been removed by
    /// [`Self::evict_expired`] or displaced by LRU eviction.
    pub async fn len(&self) -> usize {
        self.inner.lock().await.entries.len()
    }

    /// Returns `true` when no `jti` is held.
    pub async fn is_empty(&self) -> bool {
        self.len().await == 0
    }

    /// Records `jti`, or rejects it if it was already accepted less than
    /// `ttl` ago.
    ///
    /// A rejected attempt does not refresh the stored timestamp, so the
    /// window is measured from the last *accepted* use. An entry at least
    /// `ttl` old is overwritten and the `jti` is accepted again.
    ///
    /// When the cache is full, inserting a new `jti` evicts the least
    /// recently recorded one. That evicted `jti` can then no longer be
    /// detected as a replay.
    ///
    /// # Errors
    ///
    /// Returns [`ReplayError::Replayed`] when `jti` is still inside its window.
    pub async fn check_and_record(&self, jti: &str) -> Result<(), ReplayError> {
        let mut guard = self.inner.lock().await;
        // Read the clock only once the lock is held. Sampling it before the
        // (possibly contended) `lock().await` would store a timestamp older
        // than the real insertion time and shorten this jti's replay window
        // by however long the task waited for the lock.
        let now = Instant::now();
        if let Some(&first_seen) = guard.entries.peek(jti) {
            // `peek` does not promote the entry, so a rejected replay leaves
            // both the timestamp and the LRU order untouched.
            if now.saturating_duration_since(first_seen) < self.ttl {
                return Err(ReplayError::Replayed { ttl: self.ttl });
            }
        }
        guard.entries.put(jti.to_string(), now);
        Ok(())
    }

    /// Removes every entry whose age is at least `ttl`.
    ///
    /// Returns the number of entries removed. The scan visits every entry
    /// while holding the lock.
    pub async fn evict_expired(&self) -> usize {
        let mut guard = self.inner.lock().await;
        let now = Instant::now();
        let ttl = self.ttl;
        let expired: Vec<String> = guard
            .entries
            .iter()
            .filter(|&(_, &seen)| now.saturating_duration_since(seen) >= ttl)
            .map(|(jti, _)| jti.clone())
            .collect();
        for jti in &expired {
            guard.entries.pop(jti);
        }
        expired.len()
    }

    /// Spawns a background task that calls [`Self::evict_expired`] every
    /// `period`, forever.
    ///
    /// The first tick of a tokio interval completes immediately. It is
    /// consumed up front, so the first sweep happens one `period` after
    /// spawning.
    ///
    /// The task owns `self`, which keeps the shared cache state alive. It
    /// stops only when the returned handle is aborted or the runtime shuts
    /// down.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because it calls
    /// `tokio::spawn`.
    ///
    /// A zero `period` makes `tokio::time::interval` panic *inside* the
    /// spawned task. In that case no eviction ever runs, and the handle
    /// resolves to a `JoinError`.
    pub fn spawn_evictor(self, period: Duration) -> tokio::task::JoinHandle<()> {
        tokio::spawn(async move {
            let mut ticker = tokio::time::interval(period);
            // Skip the immediate first tick.
            ticker.tick().await;
            loop {
                ticker.tick().await;
                let _ = self.evict_expired().await;
            }
        })
    }
}
impl Default for DpopReplayCache {
fn default() -> Self {
Self::with_config(Duration::from_secs(DEFAULT_TTL_SECS), DEFAULT_MAX_SIZE)
}
}
/// Lock-free tally of DPoP proofs rejected as replays.
///
/// The count lives in a single `AtomicU64` updated with `Relaxed` ordering.
/// The value is a standalone number and is not used to order any other
/// memory accesses.
#[derive(Debug, Default)]
pub struct ReplayRejectedCounter {
    value: std::sync::atomic::AtomicU64,
}

impl ReplayRejectedCounter {
    /// Creates a counter starting at zero.
    ///
    /// This is a `const fn` so it can initialise a `static`.
    pub const fn new() -> Self {
        let value = std::sync::atomic::AtomicU64::new(0);
        Self { value }
    }

    /// Adds one to the counter.
    pub fn increment(&self) {
        use std::sync::atomic::Ordering::Relaxed;
        self.value.fetch_add(1, Relaxed);
    }

    /// Returns the current count.
    pub fn get(&self) -> u64 {
        use std::sync::atomic::Ordering::Relaxed;
        self.value.load(Relaxed)
    }
}

/// Process-wide count of rejected DPoP replays.
///
/// The caches in this module do not increment it themselves. Presumably the
/// code that handles a [`ReplayError`] does so — confirm at the call sites.
pub static DPOP_REPLAY_REJECTED_TOTAL: ReplayRejectedCounter = ReplayRejectedCounter::new();
use std::sync::Mutex as StdMutex;
use std::time::SystemTime;
/// Default number of `jti` values a [`JtiReplayCache`] retains.
pub const JTI_DEFAULT_CAPACITY: usize = 10000;
/// Default replay window for a [`JtiReplayCache`] (five minutes).
pub const JTI_DEFAULT_TTL: Duration = Duration::from_secs(60 * 5);
/// Synchronous, bounded replay cache keyed by `jti`, driven by
/// caller-supplied wall-clock timestamps.
///
/// Unlike [`DpopReplayCache`], it uses a `std::sync::Mutex` and its methods
/// are plain `fn`s that take `now` as an argument. That makes it usable
/// outside async code and easy to drive deterministically. Clones share the
/// same entries through an `Arc`.
#[derive(Debug, Clone)]
pub struct JtiReplayCache {
    /// Shared LRU state, guarded by a standard-library mutex.
    inner: Arc<StdMutex<JtiInner>>,
    /// Window during which a repeated `jti` is rejected.
    ttl: Duration,
    /// Capacity of the LRU (clamped to at least 1 by `new`).
    capacity: usize,
}

/// Mutex-protected state of [`JtiReplayCache`].
#[derive(Debug)]
struct JtiInner {
    /// `jti` -> wall-clock time at which it was last accepted.
    entries: LruCache<String, SystemTime>,
}
impl JtiReplayCache {
pub fn new(capacity: usize, ttl: Duration) -> Self {
let cap = NonZeroUsize::new(capacity.max(1))
.expect("capacity clamped to >= 1 above");
Self {
inner: Arc::new(StdMutex::new(JtiInner {
entries: LruCache::new(cap),
})),
ttl,
capacity: capacity.max(1),
}
}
pub fn ttl(&self) -> Duration {
self.ttl
}
pub fn capacity(&self) -> usize {
self.capacity
}
pub fn len(&self) -> usize {
self.inner
.lock()
.unwrap_or_else(|e| e.into_inner())
.entries
.len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn check_and_insert(
&self,
jti: &str,
now: SystemTime,
) -> Result<(), ReplayError> {
let mut guard = self
.inner
.lock()
.unwrap_or_else(|e| e.into_inner());
if let Some(&first_seen) = guard.entries.peek(jti) {
let age = now
.duration_since(first_seen)
.unwrap_or(Duration::ZERO);
if age < self.ttl {
return Err(ReplayError::Replayed { ttl: self.ttl });
}
}
guard.entries.put(jti.to_string(), now);
Ok(())
}
}
impl Default for JtiReplayCache {
fn default() -> Self {
Self::new(JTI_DEFAULT_CAPACITY, JTI_DEFAULT_TTL)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed wall-clock origin so `JtiReplayCache` tests are deterministic.
    fn t0() -> SystemTime {
        SystemTime::UNIX_EPOCH + Duration::from_secs(1_700_000_000)
    }

    #[tokio::test]
    async fn first_seen_jti_is_accepted() {
        let cache = DpopReplayCache::with_config(Duration::from_secs(60), 16);
        assert!(cache.check_and_record("jti-1").await.is_ok());
        assert_eq!(cache.len().await, 1);
    }

    #[tokio::test]
    async fn replay_within_ttl_is_rejected() {
        let cache = DpopReplayCache::with_config(Duration::from_secs(60), 16);
        cache.check_and_record("jti-1").await.unwrap();
        let err = cache.check_and_record("jti-1").await.unwrap_err();
        assert!(matches!(err, ReplayError::Replayed { .. }));
        assert_eq!(cache.len().await, 1);
    }

    #[tokio::test]
    async fn default_config_matches_constants() {
        let cache = DpopReplayCache::default();
        assert_eq!(cache.ttl(), Duration::from_secs(DEFAULT_TTL_SECS));
        assert_eq!(cache.max_size(), DEFAULT_MAX_SIZE);
    }

    #[tokio::test]
    async fn max_size_clamped_to_at_least_one() {
        let cache = DpopReplayCache::with_config(Duration::from_secs(1), 0);
        assert_eq!(cache.max_size(), 1);
    }

    #[tokio::test]
    async fn zero_ttl_treats_every_entry_as_expired() {
        let cache = DpopReplayCache::with_config(Duration::ZERO, 16);
        cache.check_and_record("jti-1").await.unwrap();
        assert!(cache.check_and_record("jti-1").await.is_ok());
        assert_eq!(cache.len().await, 1);
    }

    #[tokio::test]
    async fn evict_expired_removes_only_expired_entries() {
        // Zero TTL: every entry is already expired.
        let cache = DpopReplayCache::with_config(Duration::ZERO, 16);
        cache.check_and_record("a").await.unwrap();
        cache.check_and_record("b").await.unwrap();
        assert_eq!(cache.evict_expired().await, 2);
        assert!(cache.is_empty().await);

        // Long TTL: nothing is expired yet.
        let cache = DpopReplayCache::with_config(Duration::from_secs(3600), 16);
        cache.check_and_record("a").await.unwrap();
        assert_eq!(cache.evict_expired().await, 0);
        assert_eq!(cache.len().await, 1);
    }

    #[tokio::test]
    async fn capacity_bounds_memory_by_evicting_oldest_jti() {
        let cache = DpopReplayCache::with_config(Duration::from_secs(60), 1);
        cache.check_and_record("a").await.unwrap();
        cache.check_and_record("b").await.unwrap();
        // "a" was displaced by "b", so it is no longer recognised.
        assert!(cache.check_and_record("a").await.is_ok());
        assert_eq!(cache.len().await, 1);
    }

    #[test]
    fn jti_replay_window_is_enforced() {
        let ttl = Duration::from_secs(300);
        let cache = JtiReplayCache::new(16, ttl);
        let start = t0();
        assert!(cache.check_and_insert("jti-1", start).is_ok());
        let err = cache
            .check_and_insert("jti-1", start + Duration::from_secs(299))
            .unwrap_err();
        assert!(matches!(err, ReplayError::Replayed { ttl: t } if t == ttl));
        // The rejected attempt must not refresh the stored timestamp.
        assert!(cache
            .check_and_insert("jti-1", start + Duration::from_secs(300))
            .is_ok());
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn jti_clock_going_backwards_is_treated_as_replay() {
        let cache = JtiReplayCache::new(16, Duration::from_secs(300));
        let start = t0();
        cache.check_and_insert("jti-1", start).unwrap();
        assert!(cache
            .check_and_insert("jti-1", start - Duration::from_secs(10))
            .is_err());
    }

    #[test]
    fn jti_capacity_is_clamped_and_defaults_match_constants() {
        let cache = JtiReplayCache::new(0, Duration::from_secs(1));
        assert_eq!(cache.capacity(), 1);
        assert!(cache.is_empty());

        let cache = JtiReplayCache::default();
        assert_eq!(cache.capacity(), JTI_DEFAULT_CAPACITY);
        assert_eq!(cache.ttl(), JTI_DEFAULT_TTL);
    }

    #[test]
    fn counter_increments() {
        let c = ReplayRejectedCounter::new();
        assert_eq!(c.get(), 0);
        c.increment();
        c.increment();
        assert_eq!(c.get(), 2);
    }
}