use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use async_trait::async_trait;
use super::{AuditEvent, AuditSink, RateLimiter};
/// An [`AuditSink`] decorator that throttles failure events through a [`RateLimiter`].
///
/// Each event's rate-limit key is checked against the limiter. Admitted events are
/// forwarded to the wrapped sink. Rejected events are discarded and counted, so the
/// number of suppressed events can be observed via [`RateLimitedAuditSink::dropped_total`].
#[derive(Debug)]
pub struct RateLimitedAuditSink {
    /// Sink that receives every event the limiter admits.
    inner: Arc<dyn AuditSink>,
    /// Decides, per rate-limit key, whether an event may pass.
    limiter: Arc<dyn RateLimiter>,
    /// Running count of events rejected by the limiter.
    dropped_total: AtomicU64,
}

impl RateLimitedAuditSink {
    /// Wraps `inner` so that every recorded event is first checked against `limiter`.
    ///
    /// The drop counter starts at zero.
    #[must_use]
    pub fn new(inner: Arc<dyn AuditSink>, limiter: Arc<dyn RateLimiter>) -> Self {
        let dropped_total = AtomicU64::new(0);
        Self {
            inner,
            limiter,
            dropped_total,
        }
    }

    /// Returns how many events have been discarded by the limiter so far.
    ///
    /// The counter is read with `Ordering::Relaxed`. The value is a statistics
    /// snapshot and does not synchronize with any other memory operation.
    #[must_use]
    pub fn dropped_total(&self) -> u64 {
        let counter = &self.dropped_total;
        counter.load(Ordering::Relaxed)
    }
}
#[async_trait]
impl AuditSink for RateLimitedAuditSink {
    /// Records `event` only if the limiter admits its rate-limit key.
    ///
    /// A rejected event is never passed to the wrapped sink. Instead, the drop
    /// counter is incremented with `Ordering::Relaxed`, which is sufficient for a
    /// standalone tally.
    async fn record_failure(&self, event: AuditEvent) {
        let rate_key = event.rate_limit_key();

        // Guard clause: a throttled event is counted and then discarded.
        if !self.limiter.allow(&rate_key).await {
            self.dropped_total.fetch_add(1, Ordering::Relaxed);
            return;
        }

        self.inner.record_failure(event).await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::audit::{
        AuditEvent, MemoryAuditSink, MemoryRateLimiter, RateLimiter, VerifyErrorKind,
    };
    use std::collections::BTreeMap;
    use std::time::Duration;
    use time::OffsetDateTime;

    /// Builds a `SignatureInvalid` event for `source_id` with every optional field empty.
    fn fixture(source_id: &str) -> AuditEvent {
        AuditEvent {
            kind: VerifyErrorKind::SignatureInvalid,
            occurred_at: OffsetDateTime::UNIX_EPOCH,
            source_id: source_id.to_owned(),
            client_id_hint: None,
            kid_hint: None,
            metadata: BTreeMap::new(),
        }
    }

    /// Puts a fresh in-memory sink behind `limiter`.
    ///
    /// Returns the memory sink, so tests can inspect what was recorded, together
    /// with the rate-limited wrapper under test.
    fn wire(limiter: Arc<dyn RateLimiter>) -> (Arc<MemoryAuditSink>, RateLimitedAuditSink) {
        let memory = Arc::new(MemoryAuditSink::new());
        let shared = Arc::clone(&memory);
        let inner: Arc<dyn AuditSink> = shared;
        (memory, RateLimitedAuditSink::new(inner, limiter))
    }

    #[tokio::test]
    async fn admitted_event_reaches_inner_sink() {
        let (memory, limited) =
            wire(Arc::new(MemoryRateLimiter::new(10, Duration::from_secs(60))));

        limited.record_failure(fixture("rcw::k1")).await;

        assert_eq!(memory.len(), 1);
        assert_eq!(limited.dropped_total(), 0);
    }

    #[tokio::test]
    async fn dropped_total_only_counts_actual_drops() {
        let (memory, limited) =
            wire(Arc::new(MemoryRateLimiter::new(1, Duration::from_secs(60))));

        // Three events on the same key: only the first is expected to be admitted.
        for _ in 0..3 {
            limited.record_failure(fixture("rcw::k1")).await;
        }
        assert_eq!(memory.len(), 1);
        assert_eq!(limited.dropped_total(), 2);

        // An event on a different key is expected to be admitted,
        // leaving the drop count untouched.
        limited.record_failure(fixture("rcw::k2")).await;
        assert_eq!(memory.len(), 2);
        assert_eq!(limited.dropped_total(), 2);
    }
}