use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use async_trait::async_trait;
use tokio::sync::Mutex;
use super::interceptor::{A2aDelegationContext, A2aError, A2aInterceptor, InterceptorDecision};
/// Token-bucket state for one rate-limit key: a fractional token balance and
/// the instant at which that balance was last brought up to date.
#[derive(Debug, Clone)]
struct TokenBucket {
    /// Tokens currently available; each admitted request spends one.
    tokens: f64,
    /// When `tokens` was last recomputed from the elapsed time.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full, holding `capacity` tokens.
    fn new(capacity: u32) -> Self {
        let last_refill = Instant::now();
        let tokens = f64::from(capacity);
        Self { tokens, last_refill }
    }

    /// Credits the bucket for the time elapsed since the previous call (at
    /// `rps` tokens per second, never exceeding `burst`), then tries to spend
    /// a single token.
    ///
    /// Returns `true` when a token was spent. Returns `false` when less than
    /// one whole token is available; the refill is still recorded in that case.
    fn try_consume(&mut self, rps: u32, burst: u32) -> bool {
        let now = Instant::now();
        let earned = now.duration_since(self.last_refill).as_secs_f64() * f64::from(rps);
        self.last_refill = now;
        self.tokens = f64::min(self.tokens + earned, f64::from(burst));

        let has_token = self.tokens >= 1.0;
        if has_token {
            self.tokens -= 1.0;
        }
        has_token
    }
}
/// Interceptor that rate-limits delegations with one token bucket per caller.
///
/// Each distinct `caller_id` gets its own bucket, and requests without a
/// `caller_id` all share one anonymous bucket. Buckets refill continuously at
/// `rps` tokens per second up to `burst` tokens, and every admitted request
/// spends one token.
///
/// Clones share the same bucket table, because it lives behind an `Arc`.
#[derive(Debug, Clone)]
pub struct RateLimitInterceptor {
    /// Tokens added to each bucket per second.
    pub rps: u32,
    /// Maximum number of tokens a bucket can hold. A value of `0` means no
    /// request is ever admitted.
    pub burst: u32,
    /// Token buckets, keyed by the string produced by `bucket_key`.
    buckets: Arc<Mutex<HashMap<String, TokenBucket>>>,
}

impl RateLimitInterceptor {
    /// Creates a limiter allowing `rps` sustained requests per second per
    /// caller, with bursts of up to `burst` requests.
    #[must_use]
    pub fn new(rps: u32, burst: u32) -> Self {
        Self { rps, burst, buckets: Arc::new(Mutex::new(HashMap::new())) }
    }

    /// Returns the bucket-table key for the caller of `ctx`.
    ///
    /// Caller ids are namespaced with a `caller:` prefix. No caller id, not
    /// even a literal `"__global__"`, can then map onto the shared key used
    /// for requests that carry no `caller_id`. Without the prefix, such a
    /// caller could drain the anonymous bucket, or be drained by it.
    fn bucket_key(ctx: &A2aDelegationContext) -> String {
        match ctx.caller_id.as_deref() {
            Some(id) => format!("caller:{id}"),
            None => "__global__".to_string(),
        }
    }
}
#[async_trait]
impl A2aInterceptor for RateLimitInterceptor {
    /// Admits or rejects a delegation according to the caller's token bucket.
    ///
    /// On first sight of a caller key, its bucket is created holding `burst`
    /// tokens. Returns `Continue` when a token could be spent. Otherwise it
    /// returns `Reject` with code `-32002` and message `"rate limit exceeded"`.
    /// The code presumably sits in the JSON-RPC server-error range; confirm
    /// against the protocol layer. This method never returns `Err`.
    ///
    /// NOTE(review): buckets are never removed from the table, so memory grows
    /// with the number of distinct caller keys. Confirm that caller ids are
    /// bounded or authenticated upstream.
    async fn before_delegation(
        &self,
        ctx: &mut A2aDelegationContext,
    ) -> Result<InterceptorDecision, A2aError> {
        let (rps, burst) = (self.rps, self.burst);
        let allowed = {
            let mut guard = self.buckets.lock().await;
            guard
                .entry(Self::bucket_key(ctx))
                .or_insert_with(|| TokenBucket::new(burst))
                .try_consume(rps, burst)
        };
        if !allowed {
            return Ok(InterceptorDecision::Reject {
                code: -32002,
                message: "rate limit exceeded".to_string(),
            });
        }
        Ok(InterceptorDecision::Continue)
    }

    /// Post-delegation hook. Rate limiting needs no response processing, so
    /// the response is left untouched and `Ok(())` is always returned.
    async fn after_delegation(
        &self,
        _ctx: &A2aDelegationContext,
        _response: &mut serde_json::Value,
    ) -> Result<(), A2aError> {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a minimal `tasks/send` context, optionally tagged with a caller id.
    fn make_ctx(caller_id: Option<&str>) -> A2aDelegationContext {
        A2aDelegationContext {
            method: "tasks/send".to_string(),
            params: serde_json::json!({}),
            caller_id: caller_id.map(String::from),
            metadata: HashMap::new(),
        }
    }

    /// Sends one request for `ctx` through the limiter and returns its decision.
    async fn decide(
        limiter: &RateLimitInterceptor,
        ctx: &mut A2aDelegationContext,
    ) -> InterceptorDecision {
        limiter.before_delegation(ctx).await.unwrap()
    }

    /// Fails unless the limiter admitted the request.
    fn assert_allowed(decision: InterceptorDecision) {
        assert!(
            matches!(decision, InterceptorDecision::Continue),
            "expected Continue"
        );
    }

    /// Fails unless the limiter rejected the request with the rate-limit
    /// error code and message.
    fn assert_rejected(decision: InterceptorDecision) {
        match decision {
            InterceptorDecision::Reject { code, message } => {
                assert_eq!(code, -32002);
                assert_eq!(message, "rate limit exceeded");
            }
            _ => panic!("expected Reject"),
        }
    }

    #[tokio::test]
    async fn test_first_request_allowed() {
        let limiter = RateLimitInterceptor::new(10, 10);
        let mut ctx = make_ctx(Some("client-1"));
        assert_allowed(decide(&limiter, &mut ctx).await);
    }

    #[tokio::test]
    async fn test_burst_capacity_allows_multiple_requests() {
        let limiter = RateLimitInterceptor::new(1, 5);
        let mut ctx = make_ctx(Some("client-1"));
        for _ in 0..5 {
            assert_allowed(decide(&limiter, &mut ctx).await);
        }
    }

    #[tokio::test]
    async fn test_exceeding_burst_rejects() {
        let limiter = RateLimitInterceptor::new(1, 3);
        let mut ctx = make_ctx(Some("client-1"));
        for _ in 0..3 {
            assert_allowed(decide(&limiter, &mut ctx).await);
        }
        assert_rejected(decide(&limiter, &mut ctx).await);
    }

    #[tokio::test]
    async fn test_per_client_isolation() {
        let limiter = RateLimitInterceptor::new(1, 2);
        let mut ctx1 = make_ctx(Some("client-1"));
        for _ in 0..2 {
            assert_allowed(decide(&limiter, &mut ctx1).await);
        }
        assert_rejected(decide(&limiter, &mut ctx1).await);

        // A different caller still has its own, untouched bucket.
        let mut ctx2 = make_ctx(Some("client-2"));
        assert_allowed(decide(&limiter, &mut ctx2).await);
    }

    #[tokio::test]
    async fn test_no_caller_id_uses_global_bucket() {
        let limiter = RateLimitInterceptor::new(1, 2);
        let mut ctx = make_ctx(None);
        for _ in 0..2 {
            assert_allowed(decide(&limiter, &mut ctx).await);
        }
        assert_rejected(decide(&limiter, &mut ctx).await);
    }

    #[tokio::test]
    async fn test_tokens_refill_over_time() {
        // At 20 rps a new token takes 50 ms. The two back-to-back calls below
        // can therefore observe an empty bucket even on a loaded machine. The
        // 75 ms sleep earns 1.5 tokens, capped at the burst of 1.
        let limiter = RateLimitInterceptor::new(20, 1);
        let mut ctx = make_ctx(Some("client-1"));
        assert_allowed(decide(&limiter, &mut ctx).await);
        assert_rejected(decide(&limiter, &mut ctx).await);
        tokio::time::sleep(tokio::time::Duration::from_millis(75)).await;
        assert_allowed(decide(&limiter, &mut ctx).await);
    }

    #[tokio::test]
    async fn test_after_delegation_is_noop() {
        let limiter = RateLimitInterceptor::new(10, 10);
        let ctx = make_ctx(Some("client-1"));
        let mut response = serde_json::json!({"result": "ok"});
        let result = limiter.after_delegation(&ctx, &mut response).await;
        assert!(result.is_ok());
        assert_eq!(response, serde_json::json!({"result": "ok"}));
    }

    #[tokio::test]
    async fn test_rejection_code_is_minus_32002() {
        // With burst=0 no request can ever be admitted.
        let limiter = RateLimitInterceptor::new(1, 0);
        let mut ctx = make_ctx(Some("client-1"));
        assert_rejected(decide(&limiter, &mut ctx).await);
    }
}