1use std::collections::HashMap;
43use std::future::Future;
44use std::pin::Pin;
45use std::sync::atomic::{AtomicU64, Ordering};
46use std::time::{SystemTime, UNIX_EPOCH};
47
48use a2a_protocol_types::error::{A2aError, A2aResult};
49use tokio::sync::RwLock;
50
51use crate::call_context::CallContext;
52use crate::interceptor::ServerInterceptor;
53
/// Configuration for [`RateLimitInterceptor`]: a fixed-window rate limit
/// applied per caller.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum number of requests a single caller may make per window.
    pub requests_per_window: u64,

    /// Length of the fixed window, in seconds.
    pub window_secs: u64,
}
63
64impl Default for RateLimitConfig {
65 fn default() -> Self {
66 Self {
67 requests_per_window: 100,
68 window_secs: 60,
69 }
70 }
71}
72
/// Per-caller counter state for one fixed window.
///
/// Both fields are atomics so the hot path in `check` can update an
/// existing bucket while holding only a *read* lock on the bucket map.
struct CallerBucket {
    /// Window index (epoch seconds / `window_secs`) this bucket counts for.
    window_start: AtomicU64,
    /// Requests observed so far within `window_start`'s window.
    count: AtomicU64,
}
80
/// Server interceptor that enforces a fixed-window, per-caller rate limit.
///
/// Callers are keyed by authenticated identity when present, otherwise by
/// the first `x-forwarded-for` entry, otherwise a shared "anonymous"
/// bucket (see `caller_key`).
pub struct RateLimitInterceptor {
    /// Window length and per-window request cap.
    config: RateLimitConfig,
    /// Per-caller buckets; read-locked on the hot path, write-locked only
    /// to insert first-time callers or evict stale entries.
    buckets: RwLock<HashMap<String, CallerBucket>>,
    /// Total number of `check` invocations; drives periodic cleanup.
    check_count: AtomicU64,
}
96
/// Every `CLEANUP_INTERVAL`-th call to `check` sweeps stale buckets from
/// the map, bounding memory growth from one-off callers.
const CLEANUP_INTERVAL: u64 = 256;
99
/// Manual `Debug` impl: prints only `config` and elides the bucket map
/// (which holds caller identities/IPs) and internal counters via
/// `finish_non_exhaustive`.
impl std::fmt::Debug for RateLimitInterceptor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RateLimitInterceptor")
            .field("config", &self.config)
            .finish_non_exhaustive()
    }
}
107
impl RateLimitInterceptor {
    /// Creates a limiter with the given configuration and no buckets.
    #[must_use]
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            config,
            buckets: RwLock::new(HashMap::new()),
            check_count: AtomicU64::new(0),
        }
    }

    /// Derives the rate-limit key for a call, in priority order:
    /// authenticated caller identity, then the first (client-most) entry
    /// of the `x-forwarded-for` header, then a shared `"anonymous"` bucket.
    ///
    /// NOTE(review): `x-forwarded-for` is client-controlled unless set by a
    /// trusted proxy, so unauthenticated callers may be able to choose
    /// their own bucket — confirm the deployment fronts this with a proxy.
    fn caller_key(ctx: &CallContext) -> String {
        if let Some(identity) = ctx.caller_identity() {
            return identity.to_owned();
        }
        if let Some(xff) = ctx.http_headers().get("x-forwarded-for") {
            if let Some(ip) = xff.split(',').next() {
                return ip.trim().to_string();
            }
        }
        "anonymous".to_string()
    }

    /// Maps a UNIX timestamp (in seconds) to its fixed-window index.
    ///
    /// NOTE(review): divides by `window_secs`; a config with
    /// `window_secs == 0` would panic here — confirm configs are validated
    /// upstream.
    const fn window_number(&self, now_secs: u64) -> u64 {
        now_secs / self.config.window_secs
    }

    /// Evicts buckets older than the previous window. Takes the write
    /// lock, so callers must not already hold a read guard on `buckets`.
    async fn cleanup_stale_buckets(&self) {
        let now_secs = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let current_window = self.window_number(now_secs);

        let mut buckets = self.buckets.write().await;
        // Retain current-window and previous-window buckets; drop the rest.
        buckets.retain(|_, bucket| {
            bucket.window_start.load(Ordering::Relaxed) >= current_window.saturating_sub(1)
        });
    }

    /// Records one request for `key`, returning an error once the caller
    /// exceeds `requests_per_window` within the current window.
    ///
    /// Fast path: an existing bucket is updated with atomics while holding
    /// only the read lock (a CAS advances a stale window). Slow path: the
    /// write lock is taken to insert a bucket for a first-time caller, with
    /// a double-check in case another task inserted it concurrently.
    ///
    /// NOTE(review): on the fast path, the CAS that advances the window and
    /// the subsequent `count.store(1)` are two separate atomic operations,
    /// so a concurrent `fetch_add` between them can observe the previous
    /// window's count — a brief over/under-count at window boundaries.
    /// Presumably an accepted trade-off for lock-free checks; confirm.
    // Kept as one function: the fast and slow paths share the same
    // window/count logic and splitting them would obscure the locking
    // protocol — hence the lint allow.
    #[allow(clippy::too_many_lines)]
    async fn check(&self, key: &str) -> A2aResult<()> {
        let now_secs = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let current_window = self.window_number(now_secs);

        // Periodic stale-bucket sweep (skipped on the very first call).
        let count = self.check_count.fetch_add(1, Ordering::Relaxed);
        if count > 0 && count.is_multiple_of(CLEANUP_INTERVAL) {
            self.cleanup_stale_buckets().await;
        }

        {
            // Fast path: bucket already exists; update under the read lock.
            let buckets = self.buckets.read().await;
            if let Some(bucket) = buckets.get(key) {
                loop {
                    let bucket_window = bucket.window_start.load(Ordering::Acquire);
                    if bucket_window == current_window {
                        // Same window: count this request and enforce the cap.
                        let count = bucket.count.fetch_add(1, Ordering::Relaxed) + 1;
                        if count > self.config.requests_per_window {
                            return Err(A2aError::internal(format!(
                                "rate limit exceeded: {} requests per {} seconds",
                                self.config.requests_per_window, self.config.window_secs
                            )));
                        }
                        return Ok(());
                    }
                    // Stale window: race to advance it. The CAS winner resets
                    // the count; losers re-run the loop and count normally.
                    if bucket
                        .window_start
                        .compare_exchange(
                            bucket_window,
                            current_window,
                            Ordering::AcqRel,
                            Ordering::Acquire,
                        )
                        .is_ok()
                    {
                        bucket.count.store(1, Ordering::Release);
                        return Ok(());
                    }
                }
            }
        }

        // Slow path: no bucket for this key. Take the write lock and
        // re-check, since another task may have inserted it meanwhile.
        let mut buckets = self.buckets.write().await;
        if let Some(bucket) = buckets.get(key) {
            let bucket_window = bucket.window_start.load(Ordering::Acquire);
            if bucket_window == current_window {
                let count = bucket.count.fetch_add(1, Ordering::Relaxed) + 1;
                if count > self.config.requests_per_window {
                    return Err(A2aError::internal(format!(
                        "rate limit exceeded: {} requests per {} seconds",
                        self.config.requests_per_window, self.config.window_secs
                    )));
                }
            } else {
                // Window moved on: reset safely under the exclusive lock.
                bucket.window_start.store(current_window, Ordering::Release);
                bucket.count.store(1, Ordering::Release);
            }
            return Ok(());
        }
        // First request from this caller: create a fresh bucket at count 1.
        buckets.insert(
            key.to_string(),
            CallerBucket {
                window_start: AtomicU64::new(current_window),
                count: AtomicU64::new(1),
            },
        );
        drop(buckets);
        Ok(())
    }
}
240
241impl ServerInterceptor for RateLimitInterceptor {
242 fn before<'a>(
243 &'a self,
244 ctx: &'a CallContext,
245 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
246 Box::pin(async move {
247 let key = Self::caller_key(ctx);
248 self.check(&key).await
249 })
250 }
251
252 fn after<'a>(
253 &'a self,
254 _ctx: &'a CallContext,
255 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
256 Box::pin(async { Ok(()) })
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a `CallContext` for `message/send`, optionally attaching a
    /// caller identity.
    fn make_ctx(identity: Option<&str>) -> CallContext {
        let mut ctx = CallContext::new("message/send");
        if let Some(id) = identity {
            ctx = ctx.with_caller_identity(id.to_owned());
        }
        ctx
    }

    #[tokio::test]
    async fn allows_requests_within_limit() {
        let limiter = RateLimitInterceptor::new(RateLimitConfig {
            requests_per_window: 5,
            window_secs: 60,
        });
        let ctx = make_ctx(Some("user-1"));
        for _ in 0..5 {
            assert!(limiter.before(&ctx).await.is_ok());
        }
    }

    #[tokio::test]
    async fn rejects_requests_over_limit() {
        let limiter = RateLimitInterceptor::new(RateLimitConfig {
            requests_per_window: 3,
            window_secs: 60,
        });
        let ctx = make_ctx(Some("user-2"));
        for _ in 0..3 {
            assert!(limiter.before(&ctx).await.is_ok());
        }
        // The fourth request in the same window must be rejected.
        let result = limiter.before(&ctx).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn different_callers_have_separate_limits() {
        let limiter = RateLimitInterceptor::new(RateLimitConfig {
            requests_per_window: 2,
            window_secs: 60,
        });
        let ctx_a = make_ctx(Some("alice"));
        let ctx_b = make_ctx(Some("bob"));

        assert!(limiter.before(&ctx_a).await.is_ok());
        assert!(limiter.before(&ctx_a).await.is_ok());
        assert!(limiter.before(&ctx_a).await.is_err());
        // Exhausting alice's budget must not affect bob.
        assert!(limiter.before(&ctx_b).await.is_ok());
        assert!(limiter.before(&ctx_b).await.is_ok());
    }

    #[tokio::test]
    async fn anonymous_fallback_when_no_identity() {
        let limiter = RateLimitInterceptor::new(RateLimitConfig {
            requests_per_window: 1,
            window_secs: 60,
        });
        let ctx = make_ctx(None);
        assert!(limiter.before(&ctx).await.is_ok());
        assert!(limiter.before(&ctx).await.is_err());
    }

    #[tokio::test]
    async fn uses_x_forwarded_for_when_no_identity() {
        let limiter = RateLimitInterceptor::new(RateLimitConfig {
            requests_per_window: 1,
            window_secs: 60,
        });
        let mut headers = HashMap::new();
        // Only the first (client-most) entry should be used as the key.
        headers.insert(
            "x-forwarded-for".to_string(),
            "10.0.0.1, 10.0.0.2".to_string(),
        );
        let ctx = CallContext::new("message/send").with_http_headers(headers);
        assert!(limiter.before(&ctx).await.is_ok());
        assert!(limiter.before(&ctx).await.is_err());
    }

    #[tokio::test]
    async fn concurrent_rate_limit_checks() {
        use std::sync::Arc;

        let limiter = Arc::new(RateLimitInterceptor::new(RateLimitConfig {
            requests_per_window: 100,
            window_secs: 60,
        }));

        // 200 concurrent checks against a 100-request budget: exactly 100
        // should pass and 100 should be rejected.
        let mut handles = Vec::new();
        for _ in 0..200 {
            let lim = Arc::clone(&limiter);
            handles.push(tokio::spawn(async move {
                let ctx =
                    CallContext::new("message/send").with_caller_identity("concurrent-user".into());
                lim.before(&ctx).await
            }));
        }

        let mut ok_count = 0;
        let mut err_count = 0;
        for handle in handles {
            match handle.await.unwrap() {
                Ok(()) => ok_count += 1,
                Err(_) => err_count += 1,
            }
        }

        assert_eq!(ok_count, 100, "expected 100 allowed, got {ok_count}");
        assert_eq!(err_count, 100, "expected 100 rejected, got {err_count}");
    }

    #[tokio::test]
    async fn stale_bucket_cleanup() {
        let limiter = RateLimitInterceptor::new(RateLimitConfig {
            requests_per_window: 10,
            window_secs: 60,
        });

        let ctx_a = make_ctx(Some("stale-a"));
        let ctx_b = make_ctx(Some("stale-b"));
        assert!(limiter.before(&ctx_a).await.is_ok());
        assert!(limiter.before(&ctx_b).await.is_ok());

        assert_eq!(limiter.buckets.read().await.len(), 2);

        // Both buckets are in the current window, so cleanup keeps them.
        limiter.cleanup_stale_buckets().await;
        assert_eq!(
            limiter.buckets.read().await.len(),
            2,
            "current-window buckets should not be evicted"
        );
    }

    #[test]
    fn debug_format_includes_config() {
        let limiter = RateLimitInterceptor::new(RateLimitConfig {
            requests_per_window: 42,
            window_secs: 10,
        });
        let debug = format!("{limiter:?}");
        assert!(
            debug.contains("RateLimitInterceptor"),
            "Debug output should contain struct name"
        );
        assert!(
            debug.contains("config"),
            "Debug output should contain config field"
        );
    }

    #[test]
    fn default_config_values() {
        let config = RateLimitConfig::default();
        assert_eq!(config.requests_per_window, 100);
        assert_eq!(config.window_secs, 60);
    }

    #[tokio::test]
    async fn after_hook_is_noop() {
        let limiter = RateLimitInterceptor::new(RateLimitConfig::default());
        let ctx = make_ctx(Some("user"));
        // Asserting success directly avoids comparing against `()`
        // (clippy::unit_cmp).
        let result = limiter.after(&ctx).await;
        assert!(result.is_ok(), "after hook should return Ok(())");
    }

    #[test]
    fn window_number_correctness() {
        let limiter = RateLimitInterceptor::new(RateLimitConfig {
            requests_per_window: 10,
            window_secs: 60,
        });

        assert_eq!(limiter.window_number(0), 0);
        assert_eq!(limiter.window_number(59), 0);
        assert_eq!(limiter.window_number(60), 1);
        assert_eq!(limiter.window_number(120), 2);
        assert_eq!(limiter.window_number(61), 1);
    }

    #[tokio::test]
    async fn cleanup_stale_buckets_removes_old_entries() {
        let limiter = RateLimitInterceptor::new(RateLimitConfig {
            requests_per_window: 100,
            window_secs: 60,
        });

        // Inject a bucket from window 0 (the distant past).
        {
            let mut buckets = limiter.buckets.write().await;
            buckets.insert(
                "ancient-user".to_string(),
                CallerBucket {
                    window_start: AtomicU64::new(0),
                    count: AtomicU64::new(5),
                },
            );
        }
        assert_eq!(limiter.buckets.read().await.len(), 1);

        limiter.cleanup_stale_buckets().await;
        assert_eq!(
            limiter.buckets.read().await.len(),
            0,
            "ancient bucket should be evicted"
        );
    }

    #[tokio::test]
    async fn check_triggers_cleanup_at_interval() {
        let limiter = RateLimitInterceptor::new(RateLimitConfig {
            requests_per_window: 10000,
            window_secs: 60,
        });

        {
            let mut buckets = limiter.buckets.write().await;
            buckets.insert(
                "stale-for-cleanup".to_string(),
                CallerBucket {
                    window_start: AtomicU64::new(0),
                    count: AtomicU64::new(1),
                },
            );
        }

        // Force the counter to the cleanup threshold so the next check
        // sweeps the map.
        limiter
            .check_count
            .store(CLEANUP_INTERVAL, Ordering::Relaxed);

        let ctx = make_ctx(Some("cleanup-trigger-user"));
        assert!(limiter.before(&ctx).await.is_ok());

        let buckets = limiter.buckets.read().await;
        let has_stale = buckets.contains_key("stale-for-cleanup");
        drop(buckets);
        assert!(
            !has_stale,
            "stale bucket should be cleaned up after CLEANUP_INTERVAL checks"
        );
    }

    #[tokio::test]
    async fn slow_path_double_check_same_window() {
        let limiter = RateLimitInterceptor::new(RateLimitConfig {
            requests_per_window: 2,
            window_secs: 60,
        });

        // First call inserts via the slow path; subsequent calls hit the
        // fast path and must still enforce the cap.
        let ctx = make_ctx(Some("race-user"));
        assert!(limiter.before(&ctx).await.is_ok());
        assert!(limiter.before(&ctx).await.is_ok());
        assert!(limiter.before(&ctx).await.is_err());
    }

    #[tokio::test]
    async fn slow_path_double_check_stale_window() {
        let limiter = RateLimitInterceptor::new(RateLimitConfig {
            requests_per_window: 10,
            window_secs: 60,
        });

        // Pre-populate a bucket stuck in an old window with a high count;
        // the next check must reset it rather than reject.
        let key = "slow-path-stale";
        {
            let mut buckets = limiter.buckets.write().await;
            buckets.insert(
                key.to_string(),
                CallerBucket {
                    window_start: AtomicU64::new(1),
                    count: AtomicU64::new(5),
                },
            );
        }

        let result = limiter.check(key).await;
        assert!(
            result.is_ok(),
            "slow-path stale-window reset should succeed"
        );

        assert_eq!(
            limiter
                .buckets
                .read()
                .await
                .get(key)
                .expect("bucket should exist")
                .count
                .load(Ordering::Relaxed),
            1,
            "count should be reset to 1 after window advance"
        );
    }

    #[tokio::test]
    async fn slow_path_rate_limit_exceeded() {
        let limiter = RateLimitInterceptor::new(RateLimitConfig {
            requests_per_window: 1,
            window_secs: 60,
        });

        let now_secs = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let current_window = limiter.window_number(now_secs);

        // A bucket already at the cap in the current window: the next
        // check must be rejected.
        let key = "slow-path-exceeded";
        {
            let mut buckets = limiter.buckets.write().await;
            buckets.insert(
                key.to_string(),
                CallerBucket {
                    window_start: AtomicU64::new(current_window),
                    count: AtomicU64::new(1),
                },
            );
        }

        let result = limiter.check(key).await;
        assert!(
            result.is_err(),
            "slow-path should reject when count exceeds limit"
        );
    }

    #[tokio::test]
    async fn fast_path_rate_limit_exceeded() {
        let limiter = RateLimitInterceptor::new(RateLimitConfig {
            requests_per_window: 2,
            window_secs: 60,
        });

        let ctx = make_ctx(Some("fast-path-user"));
        assert!(limiter.before(&ctx).await.is_ok());
        assert!(limiter.before(&ctx).await.is_ok());
        let result = limiter.before(&ctx).await;
        assert!(
            result.is_err(),
            "fast-path should reject when count exceeds limit"
        );
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("rate limit exceeded"),
            "error message should mention rate limit exceeded, got: {err}"
        );
    }

    #[tokio::test]
    async fn fast_path_window_advancement_resets_count() {
        let limiter = RateLimitInterceptor::new(RateLimitConfig {
            requests_per_window: 1,
            window_secs: 60,
        });

        // Existing bucket in a stale window, far over the cap: advancing
        // the window must reset the count instead of rejecting.
        let key = "fast-path-window-advance";
        {
            let mut buckets = limiter.buckets.write().await;
            buckets.insert(
                key.to_string(),
                CallerBucket {
                    window_start: AtomicU64::new(1),
                    count: AtomicU64::new(999),
                },
            );
        }

        let result = limiter.check(key).await;
        assert!(
            result.is_ok(),
            "fast-path window advance should return Ok(())"
        );

        assert_eq!(
            limiter
                .buckets
                .read()
                .await
                .get(key)
                .expect("bucket should exist")
                .count
                .load(Ordering::Relaxed),
            1,
            "count should be reset to 1 after window advance"
        );
    }

    #[tokio::test]
    async fn cleanup_does_not_run_on_first_call() {
        let limiter = RateLimitInterceptor::new(RateLimitConfig {
            requests_per_window: 10000,
            window_secs: 60,
        });

        {
            let mut buckets = limiter.buckets.write().await;
            buckets.insert(
                "stale-first-call".to_string(),
                CallerBucket {
                    window_start: AtomicU64::new(0),
                    count: AtomicU64::new(1),
                },
            );
        }

        // check_count is 0, so the cleanup guard (`count > 0`) must skip
        // the sweep on the very first check.
        let ctx = make_ctx(Some("first-caller"));
        assert!(limiter.before(&ctx).await.is_ok());

        assert!(
            limiter
                .buckets
                .read()
                .await
                .contains_key("stale-first-call"),
            "stale bucket should not be cleaned up on the very first call"
        );
    }

    #[tokio::test]
    async fn x_forwarded_for_single_ip() {
        let limiter = RateLimitInterceptor::new(RateLimitConfig {
            requests_per_window: 1,
            window_secs: 60,
        });
        let mut headers = HashMap::new();
        headers.insert("x-forwarded-for".to_string(), "192.168.1.1".to_string());
        let ctx = CallContext::new("message/send").with_http_headers(headers);
        assert!(limiter.before(&ctx).await.is_ok());
        assert!(limiter.before(&ctx).await.is_err());
    }
}