1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::sync::RwLock;
6use tracing::{debug, warn};
7
/// Limits enforced by a single sliding-window rate limit.
///
/// A bucket governed by this window admits at most `max_requests + burst`
/// requests whose timestamps fall within the last `window_duration`.
#[derive(Debug, Clone, Copy)]
pub struct RateLimitWindow {
    /// Baseline number of requests admitted per window.
    pub max_requests: u32,
    /// Length of the sliding window; requests older than this no longer count.
    pub window_duration: Duration,
    /// Extra requests admitted on top of `max_requests` within the same window.
    pub burst: u32,
}
18
19impl RateLimitWindow {
20 pub fn new(max_requests: u32, window_duration: Duration) -> Self {
22 Self {
23 max_requests,
24 window_duration,
25 burst: 0,
26 }
27 }
28
29 pub fn with_burst(mut self, burst: u32) -> Self {
31 self.burst = burst;
32 self
33 }
34}
35
36impl Default for RateLimitWindow {
37 fn default() -> Self {
38 Self {
39 max_requests: 100,
40 window_duration: Duration::from_secs(60),
41 burst: 10,
42 }
43 }
44}
45
/// Outcome of a single rate-limit check.
#[derive(Debug, Clone)]
pub enum RateLimitResult {
    /// The request was admitted (and, when limiting is enabled, recorded).
    ///
    /// `remaining` is how many further requests, burst included, the bucket would
    /// currently admit; it is `u32::MAX` when limiting is disabled. `reset_at` is
    /// the check time plus the window duration.
    Allowed { remaining: u32, reset_at: Instant },
    /// The request was rejected and not recorded.
    ///
    /// `retry_after` is the time until the oldest recorded request leaves the
    /// window. `limit` is the window's `max_requests`, not including burst.
    Denied { retry_after: Duration, limit: u32 },
}
54
/// Sliding-log state for one rate-limited key.
#[derive(Debug)]
struct RateLimitBucket {
    /// Timestamps of admitted requests, appended in call order. The oldest entry
    /// is first as long as callers pass non-decreasing instants.
    requests: Vec<Instant>,
    /// Limits this bucket enforces.
    window: RateLimitWindow,
}
63
64impl RateLimitBucket {
65 fn new(window: RateLimitWindow) -> Self {
66 Self {
67 requests: Vec::with_capacity((window.max_requests + window.burst) as usize),
68 window,
69 }
70 }
71
72 fn prune_expired(&mut self, now: Instant) {
73 let cutoff = now - self.window.window_duration;
74 self.requests.retain(|&t| t > cutoff);
75 }
76
77 fn check_and_record(&mut self, now: Instant) -> RateLimitResult {
79 self.prune_expired(now);
80
81 let limit = self.window.max_requests + self.window.burst;
82 let current_count = self.requests.len() as u32;
83
84 if current_count >= limit {
85 if let Some(oldest) = self.requests.first() {
87 let retry_after =
88 (*oldest + self.window.window_duration).saturating_duration_since(now);
89 RateLimitResult::Denied {
90 retry_after,
91 limit: self.window.max_requests,
92 }
93 } else {
94 RateLimitResult::Denied {
95 retry_after: self.window.window_duration,
96 limit: self.window.max_requests,
97 }
98 }
99 } else {
100 self.requests.push(now);
101 let reset_at = now + self.window.window_duration;
102 RateLimitResult::Allowed {
103 remaining: limit - current_count - 1,
104 reset_at,
105 }
106 }
107 }
108}
109
/// Per-category limits enforced by [`WebSocketRateLimiter`].
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
    /// Handshake checks, keyed by client IP address.
    pub handshake_per_ip: RateLimitWindow,
    /// Connection checks, keyed by subject string.
    pub connections_per_subject: RateLimitWindow,
    /// Connection checks, keyed by metering key.
    pub connections_per_metering_key: RateLimitWindow,
    /// Subscription checks, keyed per connection (client id).
    pub subscriptions_per_connection: RateLimitWindow,
    /// Message checks, keyed per connection (client id).
    pub messages_per_connection: RateLimitWindow,
    /// Snapshot checks, keyed per connection (client id).
    pub snapshots_per_connection: RateLimitWindow,
    /// Master switch. When `false`, every check returns `Allowed` without
    /// recording anything.
    pub enabled: bool,
}
128
129impl Default for RateLimiterConfig {
130 fn default() -> Self {
131 Self {
132 handshake_per_ip: RateLimitWindow::new(60, Duration::from_secs(60)).with_burst(10),
133 connections_per_subject: RateLimitWindow::new(30, Duration::from_secs(60))
134 .with_burst(5),
135 connections_per_metering_key: RateLimitWindow::new(100, Duration::from_secs(60))
136 .with_burst(20),
137 subscriptions_per_connection: RateLimitWindow::new(120, Duration::from_secs(60))
138 .with_burst(10),
139 messages_per_connection: RateLimitWindow::new(1000, Duration::from_secs(60))
140 .with_burst(100),
141 snapshots_per_connection: RateLimitWindow::new(30, Duration::from_secs(60))
142 .with_burst(5),
143 enabled: true,
144 }
145 }
146}
147
148impl RateLimiterConfig {
149 pub fn from_env() -> Self {
151 let mut config = Self::default();
152
153 if let (Ok(max), Ok(secs)) = (
155 std::env::var("HYPERSTACK_RATE_LIMIT_HANDSHAKE_PER_IP_MAX"),
156 std::env::var("HYPERSTACK_RATE_LIMIT_HANDSHAKE_PER_IP_WINDOW_SECS"),
157 ) {
158 if let (Ok(max), Ok(secs)) = (max.parse(), secs.parse()) {
159 config.handshake_per_ip = RateLimitWindow::new(max, Duration::from_secs(secs));
160 }
161 }
162
163 if let (Ok(max), Ok(secs)) = (
165 std::env::var("HYPERSTACK_RATE_LIMIT_CONNECTIONS_PER_SUBJECT_MAX"),
166 std::env::var("HYPERSTACK_RATE_LIMIT_CONNECTIONS_PER_SUBJECT_WINDOW_SECS"),
167 ) {
168 if let (Ok(max), Ok(secs)) = (max.parse(), secs.parse()) {
169 config.connections_per_subject =
170 RateLimitWindow::new(max, Duration::from_secs(secs));
171 }
172 }
173
174 if let (Ok(max), Ok(secs)) = (
176 std::env::var("HYPERSTACK_RATE_LIMIT_CONNECTIONS_PER_METERING_KEY_MAX"),
177 std::env::var("HYPERSTACK_RATE_LIMIT_CONNECTIONS_PER_METERING_KEY_WINDOW_SECS"),
178 ) {
179 if let (Ok(max), Ok(secs)) = (max.parse(), secs.parse()) {
180 config.connections_per_metering_key =
181 RateLimitWindow::new(max, Duration::from_secs(secs));
182 }
183 }
184
185 if let (Ok(max), Ok(secs)) = (
187 std::env::var("HYPERSTACK_RATE_LIMIT_SUBSCRIPTIONS_PER_CONNECTION_MAX"),
188 std::env::var("HYPERSTACK_RATE_LIMIT_SUBSCRIPTIONS_PER_CONNECTION_WINDOW_SECS"),
189 ) {
190 if let (Ok(max), Ok(secs)) = (max.parse(), secs.parse()) {
191 config.subscriptions_per_connection =
192 RateLimitWindow::new(max, Duration::from_secs(secs));
193 }
194 }
195
196 if let (Ok(max), Ok(secs)) = (
198 std::env::var("HYPERSTACK_RATE_LIMIT_MESSAGES_PER_CONNECTION_MAX"),
199 std::env::var("HYPERSTACK_RATE_LIMIT_MESSAGES_PER_CONNECTION_WINDOW_SECS"),
200 ) {
201 if let (Ok(max), Ok(secs)) = (max.parse(), secs.parse()) {
202 config.messages_per_connection =
203 RateLimitWindow::new(max, Duration::from_secs(secs));
204 }
205 }
206
207 if let (Ok(max), Ok(secs)) = (
209 std::env::var("HYPERSTACK_RATE_LIMIT_SNAPSHOTS_PER_CONNECTION_MAX"),
210 std::env::var("HYPERSTACK_RATE_LIMIT_SNAPSHOTS_PER_CONNECTION_WINDOW_SECS"),
211 ) {
212 if let (Ok(max), Ok(secs)) = (max.parse(), secs.parse()) {
213 config.snapshots_per_connection =
214 RateLimitWindow::new(max, Duration::from_secs(secs));
215 }
216 }
217
218 if let Ok(enabled) = std::env::var("HYPERSTACK_RATE_LIMITING_ENABLED") {
220 config.enabled = enabled.parse().unwrap_or(true);
221 }
222
223 config
224 }
225
226 pub fn disabled() -> Self {
228 Self {
229 enabled: false,
230 ..Default::default()
231 }
232 }
233}
234
/// Sliding-window rate limiter for WebSocket handshakes, connections and
/// per-connection traffic.
///
/// Each map holds one bucket per key behind an `Arc<RwLock<..>>`. Cloning the
/// limiter shares the maps rather than copying them.
#[derive(Debug)]
pub struct WebSocketRateLimiter {
    config: RateLimiterConfig,
    /// Handshake buckets keyed by the client IP rendered as a string (port ignored).
    ip_buckets: Arc<RwLock<HashMap<String, RateLimitBucket>>>,
    /// Connection buckets keyed by subject.
    subject_buckets: Arc<RwLock<HashMap<String, RateLimitBucket>>>,
    /// Connection buckets keyed by metering key.
    metering_key_buckets: Arc<RwLock<HashMap<String, RateLimitBucket>>>,
    /// Subscription buckets keyed by client id.
    subscription_buckets: Arc<RwLock<HashMap<uuid::Uuid, RateLimitBucket>>>,
    /// Message buckets keyed by client id.
    message_buckets: Arc<RwLock<HashMap<uuid::Uuid, RateLimitBucket>>>,
    /// Snapshot buckets keyed by client id.
    snapshot_buckets: Arc<RwLock<HashMap<uuid::Uuid, RateLimitBucket>>>,
}
252
253impl WebSocketRateLimiter {
254 pub fn new(config: RateLimiterConfig) -> Self {
256 Self {
257 config,
258 ip_buckets: Arc::new(RwLock::new(HashMap::new())),
259 subject_buckets: Arc::new(RwLock::new(HashMap::new())),
260 metering_key_buckets: Arc::new(RwLock::new(HashMap::new())),
261 subscription_buckets: Arc::new(RwLock::new(HashMap::new())),
262 message_buckets: Arc::new(RwLock::new(HashMap::new())),
263 snapshot_buckets: Arc::new(RwLock::new(HashMap::new())),
264 }
265 }
266
267 pub async fn check_handshake(&self, addr: SocketAddr) -> RateLimitResult {
269 if !self.config.enabled {
270 return RateLimitResult::Allowed {
271 remaining: u32::MAX,
272 reset_at: Instant::now() + Duration::from_secs(60),
273 };
274 }
275
276 let ip = addr.ip().to_string();
277 let mut buckets = self.ip_buckets.write().await;
278 let bucket = buckets
279 .entry(ip.clone())
280 .or_insert_with(|| RateLimitBucket::new(self.config.handshake_per_ip));
281
282 let result = bucket.check_and_record(Instant::now());
283
284 match &result {
285 RateLimitResult::Denied { retry_after, limit } => {
286 warn!(
287 ip = %ip,
288 retry_after_secs = retry_after.as_secs(),
289 limit = limit,
290 "Rate limit exceeded for handshake"
291 );
292 }
293 RateLimitResult::Allowed { remaining, .. } => {
294 debug!(
295 ip = %ip,
296 remaining = remaining,
297 "Handshake rate limit check passed"
298 );
299 }
300 }
301
302 result
303 }
304
305 pub async fn check_connection_for_subject(&self, subject: &str) -> RateLimitResult {
307 if !self.config.enabled {
308 return RateLimitResult::Allowed {
309 remaining: u32::MAX,
310 reset_at: Instant::now() + Duration::from_secs(60),
311 };
312 }
313
314 let mut buckets = self.subject_buckets.write().await;
315 let bucket = buckets
316 .entry(subject.to_string())
317 .or_insert_with(|| RateLimitBucket::new(self.config.connections_per_subject));
318
319 bucket.check_and_record(Instant::now())
320 }
321
322 pub async fn check_connection_for_metering_key(&self, metering_key: &str) -> RateLimitResult {
324 if !self.config.enabled {
325 return RateLimitResult::Allowed {
326 remaining: u32::MAX,
327 reset_at: Instant::now() + Duration::from_secs(60),
328 };
329 }
330
331 let mut buckets = self.metering_key_buckets.write().await;
332 let bucket = buckets
333 .entry(metering_key.to_string())
334 .or_insert_with(|| RateLimitBucket::new(self.config.connections_per_metering_key));
335
336 bucket.check_and_record(Instant::now())
337 }
338
339 pub async fn check_subscription(&self, client_id: uuid::Uuid) -> RateLimitResult {
341 if !self.config.enabled {
342 return RateLimitResult::Allowed {
343 remaining: u32::MAX,
344 reset_at: Instant::now() + Duration::from_secs(60),
345 };
346 }
347
348 let mut buckets = self.subscription_buckets.write().await;
349 let bucket = buckets
350 .entry(client_id)
351 .or_insert_with(|| RateLimitBucket::new(self.config.subscriptions_per_connection));
352
353 bucket.check_and_record(Instant::now())
354 }
355
356 pub async fn check_message(&self, client_id: uuid::Uuid) -> RateLimitResult {
358 if !self.config.enabled {
359 return RateLimitResult::Allowed {
360 remaining: u32::MAX,
361 reset_at: Instant::now() + Duration::from_secs(60),
362 };
363 }
364
365 let mut buckets = self.message_buckets.write().await;
366 let bucket = buckets
367 .entry(client_id)
368 .or_insert_with(|| RateLimitBucket::new(self.config.messages_per_connection));
369
370 bucket.check_and_record(Instant::now())
371 }
372
373 pub async fn check_snapshot(&self, client_id: uuid::Uuid) -> RateLimitResult {
375 if !self.config.enabled {
376 return RateLimitResult::Allowed {
377 remaining: u32::MAX,
378 reset_at: Instant::now() + Duration::from_secs(60),
379 };
380 }
381
382 let mut buckets = self.snapshot_buckets.write().await;
383 let bucket = buckets
384 .entry(client_id)
385 .or_insert_with(|| RateLimitBucket::new(self.config.snapshots_per_connection));
386
387 bucket.check_and_record(Instant::now())
388 }
389
390 pub async fn cleanup_stale_buckets(&self) {
392 let now = Instant::now();
393
394 {
396 let mut buckets = self.ip_buckets.write().await;
397 buckets.retain(|_, bucket| {
398 bucket.prune_expired(now);
399 !bucket.requests.is_empty()
400 });
401 }
402
403 {
405 let mut buckets = self.subject_buckets.write().await;
406 buckets.retain(|_, bucket| {
407 bucket.prune_expired(now);
408 !bucket.requests.is_empty()
409 });
410 }
411
412 {
414 let mut buckets = self.metering_key_buckets.write().await;
415 buckets.retain(|_, bucket| {
416 bucket.prune_expired(now);
417 !bucket.requests.is_empty()
418 });
419 }
420
421 }
424
425 pub async fn remove_client_buckets(&self, client_id: uuid::Uuid) {
427 let mut sub_buckets = self.subscription_buckets.write().await;
428 sub_buckets.remove(&client_id);
429 drop(sub_buckets);
430
431 let mut msg_buckets = self.message_buckets.write().await;
432 msg_buckets.remove(&client_id);
433 drop(msg_buckets);
434
435 let mut snap_buckets = self.snapshot_buckets.write().await;
436 snap_buckets.remove(&client_id);
437 }
438
439 pub fn start_cleanup_task(&self) {
441 let limiter = self.clone();
442 tokio::spawn(async move {
443 let mut interval = tokio::time::interval(Duration::from_secs(60));
444 loop {
445 interval.tick().await;
446 limiter.cleanup_stale_buckets().await;
447 }
448 });
449 }
450}
451
452impl Clone for WebSocketRateLimiter {
453 fn clone(&self) -> Self {
454 Self {
455 config: self.config.clone(),
456 ip_buckets: Arc::clone(&self.ip_buckets),
457 subject_buckets: Arc::clone(&self.subject_buckets),
458 metering_key_buckets: Arc::clone(&self.metering_key_buckets),
459 subscription_buckets: Arc::clone(&self.subscription_buckets),
460 message_buckets: Arc::clone(&self.message_buckets),
461 snapshot_buckets: Arc::clone(&self.snapshot_buckets),
462 }
463 }
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// Spells out the same limits as `RateLimiterConfig::default()`. Presumably
    /// this keeps the tests independent of future default changes.
    fn test_config() -> RateLimiterConfig {
        RateLimiterConfig {
            enabled: true,
            handshake_per_ip: RateLimitWindow::new(60, Duration::from_secs(60)).with_burst(10),
            connections_per_subject: RateLimitWindow::new(30, Duration::from_secs(60))
                .with_burst(5),
            connections_per_metering_key: RateLimitWindow::new(100, Duration::from_secs(60))
                .with_burst(20),
            subscriptions_per_connection: RateLimitWindow::new(120, Duration::from_secs(60))
                .with_burst(10),
            messages_per_connection: RateLimitWindow::new(1000, Duration::from_secs(60))
                .with_burst(100),
            snapshots_per_connection: RateLimitWindow::new(30, Duration::from_secs(60))
                .with_burst(5),
        }
    }

    /// With no burst, every admitted handshake decrements `remaining` by one.
    #[tokio::test]
    async fn test_rate_limiter_allows_within_limit() {
        let config = RateLimiterConfig {
            handshake_per_ip: RateLimitWindow::new(5, Duration::from_secs(60)),
            ..test_config()
        };
        let limiter = WebSocketRateLimiter::new(config);

        let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap();

        for i in 0..5 {
            let result = limiter.check_handshake(addr).await;
            match result {
                RateLimitResult::Allowed { remaining, .. } => {
                    assert_eq!(
                        remaining,
                        4 - i,
                        "Request {} should have {} remaining",
                        i,
                        4 - i
                    );
                }
                RateLimitResult::Denied { .. } => {
                    panic!("Request {} should be allowed", i);
                }
            }
        }
    }

    /// Once `max_requests` handshakes are recorded, the next one is denied.
    #[tokio::test]
    async fn test_rate_limiter_denies_over_limit() {
        let config = RateLimiterConfig {
            handshake_per_ip: RateLimitWindow::new(2, Duration::from_secs(60)),
            ..test_config()
        };
        let limiter = WebSocketRateLimiter::new(config);

        let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap();

        // Use up both allowed slots.
        limiter.check_handshake(addr).await;
        limiter.check_handshake(addr).await;

        let result = limiter.check_handshake(addr).await;
        assert!(
            matches!(result, RateLimitResult::Denied { .. }),
            "Third request should be denied"
        );
    }

    /// Burst adds capacity on top of `max_requests` (2 + 2 = 4 here).
    #[tokio::test]
    async fn test_rate_limiter_with_burst() {
        let config = RateLimiterConfig {
            handshake_per_ip: RateLimitWindow::new(2, Duration::from_secs(60)).with_burst(2),
            ..test_config()
        };
        let limiter = WebSocketRateLimiter::new(config);

        let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap();

        for i in 0..4 {
            let result = limiter.check_handshake(addr).await;
            assert!(
                matches!(result, RateLimitResult::Allowed { .. }),
                "Request {} should be allowed with burst",
                i
            );
        }

        let result = limiter.check_handshake(addr).await;
        assert!(
            matches!(result, RateLimitResult::Denied { .. }),
            "Fifth request should be denied"
        );
    }

    /// A disabled limiter admits every request regardless of volume.
    #[tokio::test]
    async fn test_rate_limiter_disabled() {
        let limiter = WebSocketRateLimiter::new(RateLimiterConfig::disabled());

        let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap();

        for _ in 0..100 {
            let result = limiter.check_handshake(addr).await;
            assert!(
                matches!(result, RateLimitResult::Allowed { .. }),
                "Should be allowed when disabled"
            );
        }
    }

    /// Subject limits are tracked per subject: exhausting one does not affect another.
    #[tokio::test]
    async fn test_subject_rate_limiting() {
        let config = RateLimiterConfig {
            connections_per_subject: RateLimitWindow::new(3, Duration::from_secs(60)),
            ..test_config()
        };
        let limiter = WebSocketRateLimiter::new(config);

        for i in 0..3 {
            let result = limiter.check_connection_for_subject("user-123").await;
            assert!(
                matches!(result, RateLimitResult::Allowed { remaining, .. } if remaining == 2 - i),
                "Connection {} should be allowed",
                i
            );
        }

        let result = limiter.check_connection_for_subject("user-123").await;
        assert!(
            matches!(result, RateLimitResult::Denied { .. }),
            "Fourth connection should be denied"
        );

        let result = limiter.check_connection_for_subject("user-456").await;
        assert!(
            matches!(result, RateLimitResult::Allowed { .. }),
            "Different subject should be allowed"
        );
    }

    /// Buckets whose only request is older than the window are removed by cleanup.
    #[tokio::test]
    async fn test_cleanup_stale_buckets_removes_expired_buckets() {
        let limiter = WebSocketRateLimiter::new(test_config());
        // NOTE(review): `Instant - Duration` panics when the result is not
        // representable. That can happen on platforms whose monotonic clock
        // starts at boot, if the machine has been up for less than 600 seconds.
        let stale_request = Instant::now() - Duration::from_secs(600);

        {
            let mut buckets = limiter.ip_buckets.write().await;
            let mut bucket = RateLimitBucket::new(limiter.config.handshake_per_ip);
            bucket.requests.push(stale_request);
            buckets.insert("127.0.0.1".to_string(), bucket);
        }

        {
            let mut buckets = limiter.subject_buckets.write().await;
            let mut bucket = RateLimitBucket::new(limiter.config.connections_per_subject);
            bucket.requests.push(stale_request);
            buckets.insert("user-123".to_string(), bucket);
        }

        {
            let mut buckets = limiter.metering_key_buckets.write().await;
            let mut bucket = RateLimitBucket::new(limiter.config.connections_per_metering_key);
            bucket.requests.push(stale_request);
            buckets.insert("meter-123".to_string(), bucket);
        }

        limiter.cleanup_stale_buckets().await;

        assert!(limiter.ip_buckets.read().await.is_empty());
        assert!(limiter.subject_buckets.read().await.is_empty());
        assert!(limiter.metering_key_buckets.read().await.is_empty());
    }

    /// `remove_client_buckets` drops the subscription, message and snapshot
    /// buckets for one client.
    #[tokio::test]
    async fn test_remove_client_buckets_clears_connection_specific_state() {
        let limiter = WebSocketRateLimiter::new(test_config());
        let client_id = uuid::Uuid::new_v4();

        // Each check lazily creates the client's bucket in its map.
        let _ = limiter.check_subscription(client_id).await;
        let _ = limiter.check_message(client_id).await;
        let _ = limiter.check_snapshot(client_id).await;

        assert!(limiter
            .subscription_buckets
            .read()
            .await
            .contains_key(&client_id));
        assert!(limiter
            .message_buckets
            .read()
            .await
            .contains_key(&client_id));
        assert!(limiter
            .snapshot_buckets
            .read()
            .await
            .contains_key(&client_id));

        limiter.remove_client_buckets(client_id).await;

        assert!(!limiter
            .subscription_buckets
            .read()
            .await
            .contains_key(&client_id));
        assert!(!limiter
            .message_buckets
            .read()
            .await
            .contains_key(&client_id));
        assert!(!limiter
            .snapshot_buckets
            .read()
            .await
            .contains_key(&client_id));
    }
}