use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
use a2a_protocol_types::error::{A2aError, A2aResult};
use tokio::sync::RwLock;
use crate::call_context::CallContext;
use crate::interceptor::ServerInterceptor;
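/// Configuration for [`RateLimitInterceptor`]: each caller may make at most
/// `requests_per_window` requests within every `window_secs`-second window.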
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
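    /// Maximum number of requests a single caller may make per window.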
pub requests_per_window: u64,
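    /// Length of one rate-limiting window, in seconds. Must be non-zero, since
    /// window indexing divides by this value.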
pub window_secs: u64,
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
requests_per_window: 100,
window_secs: 60,
}
}
}
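/// Per-caller request counter for the current fixed window. Atomics allow the
/// fast path to update an existing bucket under a shared read lock.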
struct CallerBucket {
window_start: AtomicU64,
count: AtomicU64,
}
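/// Server interceptor that enforces per-caller, fixed-window rate limiting.
///
/// Callers are keyed by authenticated identity when available, otherwise by
/// the first `x-forwarded-for` address, otherwise by a shared `"anonymous"`
/// key. Existing callers are counted with atomics under a shared read lock;
/// a new caller takes the write lock once to create its bucket. Stale buckets
/// are swept periodically as checks come in.
///
/// A minimal usage sketch (paths and error handling elided; assumes a Tokio
/// async context):
///
/// ```ignore
/// let limiter = RateLimitInterceptor::new(RateLimitConfig {
///     requests_per_window: 100,
///     window_secs: 60,
/// });
/// let ctx = CallContext::new("message/send").with_caller_identity("user-1".to_owned());
/// limiter.before(&ctx).await?; // returns Err(...) once the caller exceeds the limit
/// ```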
pub struct RateLimitInterceptor {
config: RateLimitConfig,
buckets: RwLock<HashMap<String, CallerBucket>>,
check_count: AtomicU64,
}
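/// Number of rate-limit checks between sweeps of stale caller buckets.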
const CLEANUP_INTERVAL: u64 = 256;
impl std::fmt::Debug for RateLimitInterceptor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RateLimitInterceptor")
.field("config", &self.config)
.finish_non_exhaustive()
}
}
impl RateLimitInterceptor {
#[must_use]
pub fn new(config: RateLimitConfig) -> Self {
Self {
config,
buckets: RwLock::new(HashMap::new()),
check_count: AtomicU64::new(0),
}
}
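    /// Derives the rate-limiting key for a call: the authenticated caller
    /// identity when present, otherwise the first address in the
    /// `x-forwarded-for` header, otherwise a shared `"anonymous"` bucket.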
fn caller_key(ctx: &CallContext) -> String {
if let Some(identity) = ctx.caller_identity() {
return identity.to_owned();
}
if let Some(xff) = ctx.http_headers().get("x-forwarded-for") {
if let Some(ip) = xff.split(',').next() {
return ip.trim().to_string();
}
}
"anonymous".to_string()
}
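    /// Maps a Unix timestamp in seconds to its fixed-window index.
    /// `window_secs` is assumed to be non-zero; a zero value would panic here
    /// with a division by zero.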
const fn window_number(&self, now_secs: u64) -> u64 {
now_secs / self.config.window_secs
}
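    /// Evicts caller buckets whose window is older than the previous window,
    /// so the map does not grow without bound as callers come and go.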
async fn cleanup_stale_buckets(&self) {
let now_secs = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let current_window = self.window_number(now_secs);
let mut buckets = self.buckets.write().await;
buckets.retain(|_, bucket| {
bucket.window_start.load(Ordering::Relaxed) >= current_window.saturating_sub(1)
});
}
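    /// Records one request for `key` and returns an error once the caller has
    /// exceeded the configured limit for the current window.
    ///
    /// Fast path: an existing bucket is updated with atomics under a shared
    /// read lock. Slow path: a write lock is taken to insert a bucket for a
    /// new caller, re-checking in case another task inserted it first.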
#[allow(clippy::too_many_lines)]
async fn check(&self, key: &str) -> A2aResult<()> {
let now_secs = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let current_window = self.window_number(now_secs);
let count = self.check_count.fetch_add(1, Ordering::Relaxed);
if count > 0 && count.is_multiple_of(CLEANUP_INTERVAL) {
self.cleanup_stale_buckets().await;
}
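        // Fast path: an existing caller's bucket is updated with atomics under
        // the shared read lock only.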
{
let buckets = self.buckets.read().await;
if let Some(bucket) = buckets.get(key) {
loop {
let bucket_window = bucket.window_start.load(Ordering::Acquire);
if bucket_window == current_window {
let count = bucket.count.fetch_add(1, Ordering::Relaxed) + 1;
if count > self.config.requests_per_window {
return Err(A2aError::internal(format!(
"rate limit exceeded: {} requests per {} seconds",
self.config.requests_per_window, self.config.window_secs
)));
}
return Ok(());
}
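                    // Stale window: try to claim the new one. The reset below
                    // races with concurrent increments, so requests landing
                    // exactly at a window boundary may be slightly over- or
                    // under-counted.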
if bucket
.window_start
.compare_exchange(
bucket_window,
current_window,
Ordering::AcqRel,
Ordering::Acquire,
)
.is_ok()
{
bucket.count.store(1, Ordering::Release);
return Ok(());
}
}
}
}
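        // Slow path: no bucket was found under the read lock. Take the write
        // lock and re-check, since another task may have inserted the bucket
        // in the meantime.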
let mut buckets = self.buckets.write().await;
if let Some(bucket) = buckets.get(key) {
let bucket_window = bucket.window_start.load(Ordering::Acquire);
if bucket_window == current_window {
let count = bucket.count.fetch_add(1, Ordering::Relaxed) + 1;
if count > self.config.requests_per_window {
return Err(A2aError::internal(format!(
"rate limit exceeded: {} requests per {} seconds",
self.config.requests_per_window, self.config.window_secs
)));
}
} else {
bucket.window_start.store(current_window, Ordering::Release);
bucket.count.store(1, Ordering::Release);
}
return Ok(());
}
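        // No bucket exists (first request, or the bucket was swept): create
        // one with a count of 1.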
buckets.insert(
key.to_string(),
CallerBucket {
window_start: AtomicU64::new(current_window),
count: AtomicU64::new(1),
},
);
drop(buckets);
Ok(())
}
}
impl ServerInterceptor for RateLimitInterceptor {
fn before<'a>(
&'a self,
ctx: &'a CallContext,
) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
Box::pin(async move {
let key = Self::caller_key(ctx);
self.check(&key).await
})
}
fn after<'a>(
&'a self,
_ctx: &'a CallContext,
) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
Box::pin(async { Ok(()) })
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
fn make_ctx(identity: Option<&str>) -> CallContext {
let mut ctx = CallContext::new("message/send");
if let Some(id) = identity {
ctx = ctx.with_caller_identity(id.to_owned());
}
ctx
}
#[tokio::test]
async fn allows_requests_within_limit() {
let limiter = RateLimitInterceptor::new(RateLimitConfig {
requests_per_window: 5,
window_secs: 60,
});
let ctx = make_ctx(Some("user-1"));
for _ in 0..5 {
assert!(limiter.before(&ctx).await.is_ok());
}
}
#[tokio::test]
async fn rejects_requests_over_limit() {
let limiter = RateLimitInterceptor::new(RateLimitConfig {
requests_per_window: 3,
window_secs: 60,
});
let ctx = make_ctx(Some("user-2"));
for _ in 0..3 {
assert!(limiter.before(&ctx).await.is_ok());
}
let result = limiter.before(&ctx).await;
assert!(result.is_err());
}
#[tokio::test]
async fn different_callers_have_separate_limits() {
let limiter = RateLimitInterceptor::new(RateLimitConfig {
requests_per_window: 2,
window_secs: 60,
});
let ctx_a = make_ctx(Some("alice"));
let ctx_b = make_ctx(Some("bob"));
assert!(limiter.before(&ctx_a).await.is_ok());
assert!(limiter.before(&ctx_a).await.is_ok());
assert!(limiter.before(&ctx_a).await.is_err());
assert!(limiter.before(&ctx_b).await.is_ok());
assert!(limiter.before(&ctx_b).await.is_ok());
}
#[tokio::test]
async fn anonymous_fallback_when_no_identity() {
let limiter = RateLimitInterceptor::new(RateLimitConfig {
requests_per_window: 1,
window_secs: 60,
});
let ctx = make_ctx(None);
assert!(limiter.before(&ctx).await.is_ok());
assert!(limiter.before(&ctx).await.is_err());
}
#[tokio::test]
async fn uses_x_forwarded_for_when_no_identity() {
let limiter = RateLimitInterceptor::new(RateLimitConfig {
requests_per_window: 1,
window_secs: 60,
});
let mut headers = HashMap::new();
headers.insert(
"x-forwarded-for".to_string(),
"10.0.0.1, 10.0.0.2".to_string(),
);
let ctx = CallContext::new("message/send").with_http_headers(headers);
assert!(limiter.before(&ctx).await.is_ok());
assert!(limiter.before(&ctx).await.is_err());
}
#[tokio::test]
async fn concurrent_rate_limit_checks() {
use std::sync::Arc;
let limiter = Arc::new(RateLimitInterceptor::new(RateLimitConfig {
requests_per_window: 100,
window_secs: 60,
}));
let mut handles = Vec::new();
for _ in 0..200 {
let lim = Arc::clone(&limiter);
handles.push(tokio::spawn(async move {
let ctx =
CallContext::new("message/send").with_caller_identity("concurrent-user".into());
lim.before(&ctx).await
}));
}
let mut ok_count = 0;
let mut err_count = 0;
for handle in handles {
match handle.await.unwrap() {
Ok(()) => ok_count += 1,
Err(_) => err_count += 1,
}
}
assert_eq!(ok_count, 100, "expected 100 allowed, got {ok_count}");
assert_eq!(err_count, 100, "expected 100 rejected, got {err_count}");
}
#[tokio::test]
async fn stale_bucket_cleanup() {
let limiter = RateLimitInterceptor::new(RateLimitConfig {
requests_per_window: 10,
window_secs: 60,
});
let ctx_a = make_ctx(Some("stale-a"));
let ctx_b = make_ctx(Some("stale-b"));
assert!(limiter.before(&ctx_a).await.is_ok());
assert!(limiter.before(&ctx_b).await.is_ok());
assert_eq!(limiter.buckets.read().await.len(), 2);
limiter.cleanup_stale_buckets().await;
assert_eq!(
limiter.buckets.read().await.len(),
2,
"current-window buckets should not be evicted"
);
}
#[test]
fn debug_format_includes_config() {
let limiter = RateLimitInterceptor::new(RateLimitConfig {
requests_per_window: 42,
window_secs: 10,
});
let debug = format!("{limiter:?}");
assert!(
debug.contains("RateLimitInterceptor"),
"Debug output should contain struct name"
);
assert!(
debug.contains("config"),
"Debug output should contain config field"
);
}
#[test]
fn default_config_values() {
let config = RateLimitConfig::default();
assert_eq!(config.requests_per_window, 100);
assert_eq!(config.window_secs, 60);
}
#[tokio::test]
async fn after_hook_is_noop() {
let limiter = RateLimitInterceptor::new(RateLimitConfig::default());
let ctx = make_ctx(Some("user"));
let result = limiter.after(&ctx).await;
        assert!(result.is_ok(), "after hook should return Ok(())");
}
#[test]
fn window_number_correctness() {
let limiter = RateLimitInterceptor::new(RateLimitConfig {
requests_per_window: 10,
window_secs: 60,
});
assert_eq!(limiter.window_number(0), 0);
assert_eq!(limiter.window_number(59), 0);
assert_eq!(limiter.window_number(60), 1);
assert_eq!(limiter.window_number(120), 2);
assert_eq!(limiter.window_number(61), 1);
}
#[tokio::test]
async fn cleanup_stale_buckets_removes_old_entries() {
let limiter = RateLimitInterceptor::new(RateLimitConfig {
requests_per_window: 100,
window_secs: 60,
});
{
let mut buckets = limiter.buckets.write().await;
buckets.insert(
"ancient-user".to_string(),
CallerBucket {
                    window_start: AtomicU64::new(0),
                    count: AtomicU64::new(5),
},
);
}
assert_eq!(limiter.buckets.read().await.len(), 1);
limiter.cleanup_stale_buckets().await;
assert_eq!(
limiter.buckets.read().await.len(),
0,
"ancient bucket should be evicted"
);
}
#[tokio::test]
async fn check_triggers_cleanup_at_interval() {
let limiter = RateLimitInterceptor::new(RateLimitConfig {
requests_per_window: 10000,
window_secs: 60,
});
{
let mut buckets = limiter.buckets.write().await;
buckets.insert(
"stale-for-cleanup".to_string(),
CallerBucket {
window_start: AtomicU64::new(0),
count: AtomicU64::new(1),
},
);
}
limiter
.check_count
.store(CLEANUP_INTERVAL, Ordering::Relaxed);
let ctx = make_ctx(Some("cleanup-trigger-user"));
assert!(limiter.before(&ctx).await.is_ok());
let buckets = limiter.buckets.read().await;
let has_stale = buckets.contains_key("stale-for-cleanup");
drop(buckets);
assert!(
!has_stale,
"stale bucket should be cleaned up after CLEANUP_INTERVAL checks"
);
}
    // Exercises the slow-path insert for a new caller followed by fast-path
    // counting; the slow-path double-check branch itself needs a racing insert
    // and is not reached here.
    #[tokio::test]
    async fn slow_path_insert_then_fast_path_limit() {
let limiter = RateLimitInterceptor::new(RateLimitConfig {
requests_per_window: 2,
window_secs: 60,
});
let ctx = make_ctx(Some("race-user"));
assert!(limiter.before(&ctx).await.is_ok());
assert!(limiter.before(&ctx).await.is_ok());
assert!(limiter.before(&ctx).await.is_err());
}
    // The pre-inserted bucket is found on the read-lock fast path, so this
    // exercises the fast-path stale-window reset rather than the slow-path
    // double-check.
    #[tokio::test]
    async fn preexisting_stale_bucket_resets_on_check() {
let limiter = RateLimitInterceptor::new(RateLimitConfig {
requests_per_window: 10,
window_secs: 60,
});
        let key = "preexisting-stale";
{
let mut buckets = limiter.buckets.write().await;
buckets.insert(
key.to_string(),
CallerBucket {
                    window_start: AtomicU64::new(1),
                    count: AtomicU64::new(5),
},
);
}
let result = limiter.check(key).await;
assert!(
result.is_ok(),
"slow-path stale-window reset should succeed"
);
assert_eq!(
limiter
.buckets
.read()
.await
.get(key)
.expect("bucket should exist")
.count
.load(Ordering::Relaxed),
1,
"count should be reset to 1 after window advance"
);
}
    // The pre-inserted bucket is found on the read-lock fast path, so the
    // rejection happens there rather than on the write-lock slow path.
    #[tokio::test]
    async fn preexisting_bucket_rate_limit_exceeded() {
let limiter = RateLimitInterceptor::new(RateLimitConfig {
requests_per_window: 1,
window_secs: 60,
});
let now_secs = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
let current_window = limiter.window_number(now_secs);
        let key = "preexisting-exceeded";
{
let mut buckets = limiter.buckets.write().await;
buckets.insert(
key.to_string(),
CallerBucket {
window_start: AtomicU64::new(current_window),
                    count: AtomicU64::new(1),
                },
);
}
let result = limiter.check(key).await;
assert!(
result.is_err(),
"slow-path should reject when count exceeds limit"
);
}
#[tokio::test]
async fn fast_path_rate_limit_exceeded() {
let limiter = RateLimitInterceptor::new(RateLimitConfig {
requests_per_window: 2,
window_secs: 60,
});
let ctx = make_ctx(Some("fast-path-user"));
assert!(limiter.before(&ctx).await.is_ok());
assert!(limiter.before(&ctx).await.is_ok());
let result = limiter.before(&ctx).await;
assert!(
result.is_err(),
"fast-path should reject when count exceeds limit"
);
let err = result.unwrap_err();
assert!(
err.to_string().contains("rate limit exceeded"),
"error message should mention rate limit exceeded, got: {err}"
);
}
#[tokio::test]
async fn fast_path_window_advancement_resets_count() {
let limiter = RateLimitInterceptor::new(RateLimitConfig {
requests_per_window: 1,
window_secs: 60,
});
let key = "fast-path-window-advance";
{
let mut buckets = limiter.buckets.write().await;
buckets.insert(
key.to_string(),
CallerBucket {
                    window_start: AtomicU64::new(1),
                    count: AtomicU64::new(999),
},
);
}
let result = limiter.check(key).await;
        assert!(
            result.is_ok(),
            "fast-path window advance should return Ok(())"
        );
assert_eq!(
limiter
.buckets
.read()
.await
.get(key)
.expect("bucket should exist")
.count
.load(Ordering::Relaxed),
1,
"count should be reset to 1 after window advance"
);
}
#[tokio::test]
async fn cleanup_does_not_run_on_first_call() {
let limiter = RateLimitInterceptor::new(RateLimitConfig {
requests_per_window: 10000,
window_secs: 60,
});
{
let mut buckets = limiter.buckets.write().await;
buckets.insert(
"stale-first-call".to_string(),
CallerBucket {
window_start: AtomicU64::new(0),
count: AtomicU64::new(1),
},
);
}
let ctx = make_ctx(Some("first-caller"));
assert!(limiter.before(&ctx).await.is_ok());
assert!(
limiter
.buckets
.read()
.await
.contains_key("stale-first-call"),
"stale bucket should not be cleaned up on the very first call"
);
}
#[tokio::test]
async fn x_forwarded_for_single_ip() {
let limiter = RateLimitInterceptor::new(RateLimitConfig {
requests_per_window: 1,
window_secs: 60,
});
let mut headers = HashMap::new();
headers.insert("x-forwarded-for".to_string(), "192.168.1.1".to_string());
let ctx = CallContext::new("message/send").with_http_headers(headers);
assert!(limiter.before(&ctx).await.is_ok());
assert!(limiter.before(&ctx).await.is_err());
}
}