Skip to main content

hyperstack_server/websocket/
rate_limiter.rs

1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::sync::RwLock;
6use tracing::{debug, warn};
7
/// Configuration of a single sliding rate-limit window.
///
/// A key may make up to `max_requests + burst` requests within any trailing span of
/// `window_duration`.
#[derive(Debug, Clone, Copy)]
pub struct RateLimitWindow {
    /// Requests allowed per window, not counting the burst allowance.
    pub max_requests: u32,
    /// Length of the sliding window.
    pub window_duration: Duration,
    /// Extra requests tolerated on top of `max_requests`.
    pub burst: u32,
}

impl RateLimitWindow {
    /// Build a window admitting `max_requests` per `window_duration`, with no burst.
    pub fn new(max_requests: u32, window_duration: Duration) -> Self {
        let burst = 0;
        Self {
            max_requests,
            window_duration,
            burst,
        }
    }

    /// Return this window with its burst allowance replaced by `burst`.
    pub fn with_burst(self, burst: u32) -> Self {
        Self { burst, ..self }
    }
}

impl Default for RateLimitWindow {
    /// 100 requests per 60 seconds, plus a burst of 10.
    fn default() -> Self {
        Self::new(100, Duration::from_secs(60)).with_burst(10)
    }
}
45
/// Outcome of one rate-limit check.
#[derive(Debug, Clone)]
pub enum RateLimitResult {
    /// The request fits within the limit.
    Allowed {
        /// How many more requests fit in the current window.
        remaining: u32,
        /// When the window for the request just admitted ends.
        reset_at: Instant,
    },
    /// The request was rejected because the limit is exhausted.
    Denied {
        /// How long the caller should wait before trying again.
        retry_after: Duration,
        /// The configured `max_requests` for this limit.
        limit: u32,
    },
}
54
55/// A single rate limit bucket using sliding window algorithm
56#[derive(Debug)]
57struct RateLimitBucket {
58    /// Request timestamps in the current window
59    requests: Vec<Instant>,
60    /// Window configuration
61    window: RateLimitWindow,
62}
63
64impl RateLimitBucket {
65    fn new(window: RateLimitWindow) -> Self {
66        Self {
67            requests: Vec::with_capacity((window.max_requests + window.burst) as usize),
68            window,
69        }
70    }
71
72    fn prune_expired(&mut self, now: Instant) {
73        let cutoff = now - self.window.window_duration;
74        self.requests.retain(|&t| t > cutoff);
75    }
76
77    /// Check if a request is allowed and record it
78    fn check_and_record(&mut self, now: Instant) -> RateLimitResult {
79        self.prune_expired(now);
80
81        let limit = self.window.max_requests + self.window.burst;
82        let current_count = self.requests.len() as u32;
83
84        if current_count >= limit {
85            // Calculate retry after time
86            if let Some(oldest) = self.requests.first() {
87                let retry_after =
88                    (*oldest + self.window.window_duration).saturating_duration_since(now);
89                RateLimitResult::Denied {
90                    retry_after,
91                    limit: self.window.max_requests,
92                }
93            } else {
94                RateLimitResult::Denied {
95                    retry_after: self.window.window_duration,
96                    limit: self.window.max_requests,
97                }
98            }
99        } else {
100            self.requests.push(now);
101            let reset_at = now + self.window.window_duration;
102            RateLimitResult::Allowed {
103                remaining: limit - current_count - 1,
104                reset_at,
105            }
106        }
107    }
108}
109
110/// Rate limiter configuration per key type
111#[derive(Debug, Clone)]
112pub struct RateLimiterConfig {
113    /// Rate limit for handshake attempts per IP
114    pub handshake_per_ip: RateLimitWindow,
115    /// Rate limit for connection attempts per subject
116    pub connections_per_subject: RateLimitWindow,
117    /// Rate limit for connection attempts per metering key
118    pub connections_per_metering_key: RateLimitWindow,
119    /// Rate limit for subscription requests per connection
120    pub subscriptions_per_connection: RateLimitWindow,
121    /// Rate limit for messages per connection
122    pub messages_per_connection: RateLimitWindow,
123    /// Rate limit for snapshot requests per connection
124    pub snapshots_per_connection: RateLimitWindow,
125    /// Enable rate limiting (can be disabled for testing)
126    pub enabled: bool,
127}
128
129impl Default for RateLimiterConfig {
130    fn default() -> Self {
131        Self {
132            handshake_per_ip: RateLimitWindow::new(60, Duration::from_secs(60)).with_burst(10),
133            connections_per_subject: RateLimitWindow::new(30, Duration::from_secs(60))
134                .with_burst(5),
135            connections_per_metering_key: RateLimitWindow::new(100, Duration::from_secs(60))
136                .with_burst(20),
137            subscriptions_per_connection: RateLimitWindow::new(120, Duration::from_secs(60))
138                .with_burst(10),
139            messages_per_connection: RateLimitWindow::new(1000, Duration::from_secs(60))
140                .with_burst(100),
141            snapshots_per_connection: RateLimitWindow::new(30, Duration::from_secs(60))
142                .with_burst(5),
143            enabled: true,
144        }
145    }
146}
147
148impl RateLimiterConfig {
149    /// Load configuration from environment variables
150    pub fn from_env() -> Self {
151        let mut config = Self::default();
152
153        // Handshake rate limit
154        if let (Ok(max), Ok(secs)) = (
155            std::env::var("HYPERSTACK_RATE_LIMIT_HANDSHAKE_PER_IP_MAX"),
156            std::env::var("HYPERSTACK_RATE_LIMIT_HANDSHAKE_PER_IP_WINDOW_SECS"),
157        ) {
158            if let (Ok(max), Ok(secs)) = (max.parse(), secs.parse()) {
159                config.handshake_per_ip = RateLimitWindow::new(max, Duration::from_secs(secs));
160            }
161        }
162
163        // Connections per subject
164        if let (Ok(max), Ok(secs)) = (
165            std::env::var("HYPERSTACK_RATE_LIMIT_CONNECTIONS_PER_SUBJECT_MAX"),
166            std::env::var("HYPERSTACK_RATE_LIMIT_CONNECTIONS_PER_SUBJECT_WINDOW_SECS"),
167        ) {
168            if let (Ok(max), Ok(secs)) = (max.parse(), secs.parse()) {
169                config.connections_per_subject =
170                    RateLimitWindow::new(max, Duration::from_secs(secs));
171            }
172        }
173
174        // Connections per metering key
175        if let (Ok(max), Ok(secs)) = (
176            std::env::var("HYPERSTACK_RATE_LIMIT_CONNECTIONS_PER_METERING_KEY_MAX"),
177            std::env::var("HYPERSTACK_RATE_LIMIT_CONNECTIONS_PER_METERING_KEY_WINDOW_SECS"),
178        ) {
179            if let (Ok(max), Ok(secs)) = (max.parse(), secs.parse()) {
180                config.connections_per_metering_key =
181                    RateLimitWindow::new(max, Duration::from_secs(secs));
182            }
183        }
184
185        // Subscriptions per connection
186        if let (Ok(max), Ok(secs)) = (
187            std::env::var("HYPERSTACK_RATE_LIMIT_SUBSCRIPTIONS_PER_CONNECTION_MAX"),
188            std::env::var("HYPERSTACK_RATE_LIMIT_SUBSCRIPTIONS_PER_CONNECTION_WINDOW_SECS"),
189        ) {
190            if let (Ok(max), Ok(secs)) = (max.parse(), secs.parse()) {
191                config.subscriptions_per_connection =
192                    RateLimitWindow::new(max, Duration::from_secs(secs));
193            }
194        }
195
196        // Messages per connection
197        if let (Ok(max), Ok(secs)) = (
198            std::env::var("HYPERSTACK_RATE_LIMIT_MESSAGES_PER_CONNECTION_MAX"),
199            std::env::var("HYPERSTACK_RATE_LIMIT_MESSAGES_PER_CONNECTION_WINDOW_SECS"),
200        ) {
201            if let (Ok(max), Ok(secs)) = (max.parse(), secs.parse()) {
202                config.messages_per_connection =
203                    RateLimitWindow::new(max, Duration::from_secs(secs));
204            }
205        }
206
207        // Snapshots per connection
208        if let (Ok(max), Ok(secs)) = (
209            std::env::var("HYPERSTACK_RATE_LIMIT_SNAPSHOTS_PER_CONNECTION_MAX"),
210            std::env::var("HYPERSTACK_RATE_LIMIT_SNAPSHOTS_PER_CONNECTION_WINDOW_SECS"),
211        ) {
212            if let (Ok(max), Ok(secs)) = (max.parse(), secs.parse()) {
213                config.snapshots_per_connection =
214                    RateLimitWindow::new(max, Duration::from_secs(secs));
215            }
216        }
217
218        // Enable/disable
219        if let Ok(enabled) = std::env::var("HYPERSTACK_RATE_LIMITING_ENABLED") {
220            config.enabled = enabled.parse().unwrap_or(true);
221        }
222
223        config
224    }
225
226    /// Disable rate limiting (useful for testing)
227    pub fn disabled() -> Self {
228        Self {
229            enabled: false,
230            ..Default::default()
231        }
232    }
233}
234
235/// Multi-tenant rate limiter with per-key tracking
236#[derive(Debug)]
237pub struct WebSocketRateLimiter {
238    config: RateLimiterConfig,
239    /// Per-IP handshake rate limits
240    ip_buckets: Arc<RwLock<HashMap<String, RateLimitBucket>>>,
241    /// Per-subject connection rate limits
242    subject_buckets: Arc<RwLock<HashMap<String, RateLimitBucket>>>,
243    /// Per-metering-key connection rate limits
244    metering_key_buckets: Arc<RwLock<HashMap<String, RateLimitBucket>>>,
245    /// Per-connection subscription rate limits
246    subscription_buckets: Arc<RwLock<HashMap<uuid::Uuid, RateLimitBucket>>>,
247    /// Per-connection message rate limits
248    message_buckets: Arc<RwLock<HashMap<uuid::Uuid, RateLimitBucket>>>,
249    /// Per-connection snapshot rate limits
250    snapshot_buckets: Arc<RwLock<HashMap<uuid::Uuid, RateLimitBucket>>>,
251}
252
253impl WebSocketRateLimiter {
254    /// Create a new rate limiter with the given configuration
255    pub fn new(config: RateLimiterConfig) -> Self {
256        Self {
257            config,
258            ip_buckets: Arc::new(RwLock::new(HashMap::new())),
259            subject_buckets: Arc::new(RwLock::new(HashMap::new())),
260            metering_key_buckets: Arc::new(RwLock::new(HashMap::new())),
261            subscription_buckets: Arc::new(RwLock::new(HashMap::new())),
262            message_buckets: Arc::new(RwLock::new(HashMap::new())),
263            snapshot_buckets: Arc::new(RwLock::new(HashMap::new())),
264        }
265    }
266
267    /// Check if handshake is allowed from the given IP
268    pub async fn check_handshake(&self, addr: SocketAddr) -> RateLimitResult {
269        if !self.config.enabled {
270            return RateLimitResult::Allowed {
271                remaining: u32::MAX,
272                reset_at: Instant::now() + Duration::from_secs(60),
273            };
274        }
275
276        let ip = addr.ip().to_string();
277        let mut buckets = self.ip_buckets.write().await;
278        let bucket = buckets
279            .entry(ip.clone())
280            .or_insert_with(|| RateLimitBucket::new(self.config.handshake_per_ip));
281
282        let result = bucket.check_and_record(Instant::now());
283
284        match &result {
285            RateLimitResult::Denied { retry_after, limit } => {
286                warn!(
287                    ip = %ip,
288                    retry_after_secs = retry_after.as_secs(),
289                    limit = limit,
290                    "Rate limit exceeded for handshake"
291                );
292            }
293            RateLimitResult::Allowed { remaining, .. } => {
294                debug!(
295                    ip = %ip,
296                    remaining = remaining,
297                    "Handshake rate limit check passed"
298                );
299            }
300        }
301
302        result
303    }
304
305    /// Check if connection is allowed for the given subject
306    pub async fn check_connection_for_subject(&self, subject: &str) -> RateLimitResult {
307        if !self.config.enabled {
308            return RateLimitResult::Allowed {
309                remaining: u32::MAX,
310                reset_at: Instant::now() + Duration::from_secs(60),
311            };
312        }
313
314        let mut buckets = self.subject_buckets.write().await;
315        let bucket = buckets
316            .entry(subject.to_string())
317            .or_insert_with(|| RateLimitBucket::new(self.config.connections_per_subject));
318
319        bucket.check_and_record(Instant::now())
320    }
321
322    /// Check if connection is allowed for the given metering key
323    pub async fn check_connection_for_metering_key(&self, metering_key: &str) -> RateLimitResult {
324        if !self.config.enabled {
325            return RateLimitResult::Allowed {
326                remaining: u32::MAX,
327                reset_at: Instant::now() + Duration::from_secs(60),
328            };
329        }
330
331        let mut buckets = self.metering_key_buckets.write().await;
332        let bucket = buckets
333            .entry(metering_key.to_string())
334            .or_insert_with(|| RateLimitBucket::new(self.config.connections_per_metering_key));
335
336        bucket.check_and_record(Instant::now())
337    }
338
339    /// Check if subscription is allowed for the given connection
340    pub async fn check_subscription(&self, client_id: uuid::Uuid) -> RateLimitResult {
341        if !self.config.enabled {
342            return RateLimitResult::Allowed {
343                remaining: u32::MAX,
344                reset_at: Instant::now() + Duration::from_secs(60),
345            };
346        }
347
348        let mut buckets = self.subscription_buckets.write().await;
349        let bucket = buckets
350            .entry(client_id)
351            .or_insert_with(|| RateLimitBucket::new(self.config.subscriptions_per_connection));
352
353        bucket.check_and_record(Instant::now())
354    }
355
356    /// Check if message is allowed for the given connection
357    pub async fn check_message(&self, client_id: uuid::Uuid) -> RateLimitResult {
358        if !self.config.enabled {
359            return RateLimitResult::Allowed {
360                remaining: u32::MAX,
361                reset_at: Instant::now() + Duration::from_secs(60),
362            };
363        }
364
365        let mut buckets = self.message_buckets.write().await;
366        let bucket = buckets
367            .entry(client_id)
368            .or_insert_with(|| RateLimitBucket::new(self.config.messages_per_connection));
369
370        bucket.check_and_record(Instant::now())
371    }
372
373    /// Check if snapshot is allowed for the given connection
374    pub async fn check_snapshot(&self, client_id: uuid::Uuid) -> RateLimitResult {
375        if !self.config.enabled {
376            return RateLimitResult::Allowed {
377                remaining: u32::MAX,
378                reset_at: Instant::now() + Duration::from_secs(60),
379            };
380        }
381
382        let mut buckets = self.snapshot_buckets.write().await;
383        let bucket = buckets
384            .entry(client_id)
385            .or_insert_with(|| RateLimitBucket::new(self.config.snapshots_per_connection));
386
387        bucket.check_and_record(Instant::now())
388    }
389
390    /// Clean up stale buckets to prevent memory growth
391    pub async fn cleanup_stale_buckets(&self) {
392        let now = Instant::now();
393
394        // Clean up IP buckets
395        {
396            let mut buckets = self.ip_buckets.write().await;
397            buckets.retain(|_, bucket| {
398                bucket.prune_expired(now);
399                !bucket.requests.is_empty()
400            });
401        }
402
403        // Clean up subject buckets
404        {
405            let mut buckets = self.subject_buckets.write().await;
406            buckets.retain(|_, bucket| {
407                bucket.prune_expired(now);
408                !bucket.requests.is_empty()
409            });
410        }
411
412        // Clean up metering key buckets
413        {
414            let mut buckets = self.metering_key_buckets.write().await;
415            buckets.retain(|_, bucket| {
416                bucket.prune_expired(now);
417                !bucket.requests.is_empty()
418            });
419        }
420
421        // Clean up connection-specific buckets for disconnected clients
422        // These should be explicitly removed when clients disconnect
423    }
424
425    /// Remove all rate limit buckets for a disconnected client
426    pub async fn remove_client_buckets(&self, client_id: uuid::Uuid) {
427        let mut sub_buckets = self.subscription_buckets.write().await;
428        sub_buckets.remove(&client_id);
429        drop(sub_buckets);
430
431        let mut msg_buckets = self.message_buckets.write().await;
432        msg_buckets.remove(&client_id);
433        drop(msg_buckets);
434
435        let mut snap_buckets = self.snapshot_buckets.write().await;
436        snap_buckets.remove(&client_id);
437    }
438
439    /// Start a background task to periodically clean up stale buckets
440    pub fn start_cleanup_task(&self) {
441        let limiter = self.clone();
442        tokio::spawn(async move {
443            let mut interval = tokio::time::interval(Duration::from_secs(60));
444            loop {
445                interval.tick().await;
446                limiter.cleanup_stale_buckets().await;
447            }
448        });
449    }
450}
451
452impl Clone for WebSocketRateLimiter {
453    fn clone(&self) -> Self {
454        Self {
455            config: self.config.clone(),
456            ip_buckets: Arc::clone(&self.ip_buckets),
457            subject_buckets: Arc::clone(&self.subject_buckets),
458            metering_key_buckets: Arc::clone(&self.metering_key_buckets),
459            subscription_buckets: Arc::clone(&self.subscription_buckets),
460            message_buckets: Arc::clone(&self.message_buckets),
461            snapshot_buckets: Arc::clone(&self.snapshot_buckets),
462        }
463    }
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// Explicit copy of the production defaults, so tests don't drift if defaults change.
    fn test_config() -> RateLimiterConfig {
        RateLimiterConfig {
            enabled: true,
            handshake_per_ip: RateLimitWindow::new(60, Duration::from_secs(60)).with_burst(10),
            connections_per_subject: RateLimitWindow::new(30, Duration::from_secs(60))
                .with_burst(5),
            connections_per_metering_key: RateLimitWindow::new(100, Duration::from_secs(60))
                .with_burst(20),
            subscriptions_per_connection: RateLimitWindow::new(120, Duration::from_secs(60))
                .with_burst(10),
            messages_per_connection: RateLimitWindow::new(1000, Duration::from_secs(60))
                .with_burst(100),
            snapshots_per_connection: RateLimitWindow::new(30, Duration::from_secs(60))
                .with_burst(5),
        }
    }

    #[tokio::test]
    async fn test_rate_limiter_allows_within_limit() {
        let config = RateLimiterConfig {
            handshake_per_ip: RateLimitWindow::new(5, Duration::from_secs(60)),
            ..test_config()
        };
        let limiter = WebSocketRateLimiter::new(config);

        let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap();

        // Should allow first 5 requests
        for i in 0..5 {
            let result = limiter.check_handshake(addr).await;
            match result {
                RateLimitResult::Allowed { remaining, .. } => {
                    assert_eq!(
                        remaining,
                        4 - i,
                        "Request {} should have {} remaining",
                        i,
                        4 - i
                    );
                }
                RateLimitResult::Denied { .. } => {
                    panic!("Request {} should be allowed", i);
                }
            }
        }
    }

    #[tokio::test]
    async fn test_rate_limiter_denies_over_limit() {
        let config = RateLimiterConfig {
            handshake_per_ip: RateLimitWindow::new(2, Duration::from_secs(60)),
            ..test_config()
        };
        let limiter = WebSocketRateLimiter::new(config);

        let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap();

        // First 2 should be allowed
        for i in 0..2 {
            let result = limiter.check_handshake(addr).await;
            assert!(
                matches!(result, RateLimitResult::Allowed { .. }),
                "Request {} should be allowed",
                i
            );
        }

        // Third should be denied, reporting the configured limit and a retry within the window
        let result = limiter.check_handshake(addr).await;
        match result {
            RateLimitResult::Denied { retry_after, limit } => {
                assert_eq!(limit, 2);
                assert!(retry_after <= Duration::from_secs(60));
            }
            RateLimitResult::Allowed { .. } => panic!("Third request should be denied"),
        }
    }

    #[tokio::test]
    async fn test_rate_limiter_with_burst() {
        let config = RateLimiterConfig {
            handshake_per_ip: RateLimitWindow::new(2, Duration::from_secs(60)).with_burst(2),
            ..test_config()
        };
        let limiter = WebSocketRateLimiter::new(config);

        let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap();

        // First 4 should be allowed (2 base + 2 burst)
        for i in 0..4 {
            let result = limiter.check_handshake(addr).await;
            assert!(
                matches!(result, RateLimitResult::Allowed { .. }),
                "Request {} should be allowed with burst",
                i
            );
        }

        // Fifth should be denied
        let result = limiter.check_handshake(addr).await;
        assert!(
            matches!(result, RateLimitResult::Denied { .. }),
            "Fifth request should be denied"
        );
    }

    #[tokio::test]
    async fn test_denied_limit_excludes_burst() {
        let config = RateLimiterConfig {
            handshake_per_ip: RateLimitWindow::new(1, Duration::from_secs(60)).with_burst(1),
            ..test_config()
        };
        let limiter = WebSocketRateLimiter::new(config);
        let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap();

        let _ = limiter.check_handshake(addr).await;
        let _ = limiter.check_handshake(addr).await;

        let result = limiter.check_handshake(addr).await;
        assert!(
            matches!(result, RateLimitResult::Denied { limit: 1, .. }),
            "Denied result should report max_requests, not max_requests + burst"
        );
    }

    #[tokio::test]
    async fn test_handshake_limits_are_per_ip_not_per_port() {
        let config = RateLimiterConfig {
            handshake_per_ip: RateLimitWindow::new(1, Duration::from_secs(60)),
            ..test_config()
        };
        let limiter = WebSocketRateLimiter::new(config);

        let first: SocketAddr = "127.0.0.1:1000".parse().unwrap();
        let same_ip_other_port: SocketAddr = "127.0.0.1:2000".parse().unwrap();
        let other_ip: SocketAddr = "127.0.0.2:1000".parse().unwrap();

        assert!(matches!(
            limiter.check_handshake(first).await,
            RateLimitResult::Allowed { .. }
        ));
        assert!(
            matches!(
                limiter.check_handshake(same_ip_other_port).await,
                RateLimitResult::Denied { .. }
            ),
            "Same IP on another port should share the bucket"
        );
        assert!(
            matches!(
                limiter.check_handshake(other_ip).await,
                RateLimitResult::Allowed { .. }
            ),
            "Different IP should have its own bucket"
        );
    }

    #[tokio::test]
    async fn test_rate_limiter_disabled() {
        let limiter = WebSocketRateLimiter::new(RateLimiterConfig::disabled());

        let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap();

        // Should allow unlimited when disabled
        for _ in 0..100 {
            let result = limiter.check_handshake(addr).await;
            assert!(
                matches!(result, RateLimitResult::Allowed { .. }),
                "Should be allowed when disabled"
            );
        }
    }

    #[tokio::test]
    async fn test_disabled_limiter_allows_everything_without_tracking() {
        let limiter = WebSocketRateLimiter::new(RateLimiterConfig::disabled());
        let client_id = uuid::Uuid::new_v4();

        let results = vec![
            limiter.check_handshake("127.0.0.1:1".parse().unwrap()).await,
            limiter.check_connection_for_subject("user").await,
            limiter.check_connection_for_metering_key("meter").await,
            limiter.check_subscription(client_id).await,
            limiter.check_message(client_id).await,
            limiter.check_snapshot(client_id).await,
        ];
        for result in results {
            assert!(
                matches!(result, RateLimitResult::Allowed { remaining, .. } if remaining == u32::MAX),
                "Disabled limiter should report unlimited remaining"
            );
        }

        assert!(limiter.ip_buckets.read().await.is_empty());
        assert!(limiter.subject_buckets.read().await.is_empty());
        assert!(limiter.metering_key_buckets.read().await.is_empty());
        assert!(limiter.subscription_buckets.read().await.is_empty());
        assert!(limiter.message_buckets.read().await.is_empty());
        assert!(limiter.snapshot_buckets.read().await.is_empty());
    }

    #[tokio::test]
    async fn test_subject_rate_limiting() {
        let config = RateLimiterConfig {
            connections_per_subject: RateLimitWindow::new(3, Duration::from_secs(60)),
            ..test_config()
        };
        let limiter = WebSocketRateLimiter::new(config);

        // First 3 connections allowed
        for i in 0..3 {
            let result = limiter.check_connection_for_subject("user-123").await;
            assert!(
                matches!(result, RateLimitResult::Allowed { remaining, .. } if remaining == 2 - i),
                "Connection {} should be allowed",
                i
            );
        }

        // Fourth denied
        let result = limiter.check_connection_for_subject("user-123").await;
        assert!(
            matches!(result, RateLimitResult::Denied { .. }),
            "Fourth connection should be denied"
        );

        // Different subject should still work
        let result = limiter.check_connection_for_subject("user-456").await;
        assert!(
            matches!(result, RateLimitResult::Allowed { .. }),
            "Different subject should be allowed"
        );
    }

    #[tokio::test]
    async fn test_cleanup_stale_buckets_removes_expired_buckets() {
        let limiter = WebSocketRateLimiter::new(test_config());
        let stale_request = Instant::now() - Duration::from_secs(600);

        {
            let mut buckets = limiter.ip_buckets.write().await;
            let mut bucket = RateLimitBucket::new(limiter.config.handshake_per_ip);
            bucket.requests.push(stale_request);
            buckets.insert("127.0.0.1".to_string(), bucket);
        }

        {
            let mut buckets = limiter.subject_buckets.write().await;
            let mut bucket = RateLimitBucket::new(limiter.config.connections_per_subject);
            bucket.requests.push(stale_request);
            buckets.insert("user-123".to_string(), bucket);
        }

        {
            let mut buckets = limiter.metering_key_buckets.write().await;
            let mut bucket = RateLimitBucket::new(limiter.config.connections_per_metering_key);
            bucket.requests.push(stale_request);
            buckets.insert("meter-123".to_string(), bucket);
        }

        limiter.cleanup_stale_buckets().await;

        assert!(limiter.ip_buckets.read().await.is_empty());
        assert!(limiter.subject_buckets.read().await.is_empty());
        assert!(limiter.metering_key_buckets.read().await.is_empty());
    }

    #[tokio::test]
    async fn test_cleanup_keeps_active_buckets() {
        let limiter = WebSocketRateLimiter::new(test_config());
        let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap();

        let _ = limiter.check_handshake(addr).await;
        let _ = limiter.check_connection_for_subject("user-123").await;
        let _ = limiter.check_connection_for_metering_key("meter-123").await;

        limiter.cleanup_stale_buckets().await;

        assert!(limiter.ip_buckets.read().await.contains_key("127.0.0.1"));
        assert!(limiter.subject_buckets.read().await.contains_key("user-123"));
        assert!(limiter
            .metering_key_buckets
            .read()
            .await
            .contains_key("meter-123"));
    }

    #[tokio::test]
    async fn test_remove_client_buckets_clears_connection_specific_state() {
        let limiter = WebSocketRateLimiter::new(test_config());
        let client_id = uuid::Uuid::new_v4();

        let _ = limiter.check_subscription(client_id).await;
        let _ = limiter.check_message(client_id).await;
        let _ = limiter.check_snapshot(client_id).await;

        assert!(limiter
            .subscription_buckets
            .read()
            .await
            .contains_key(&client_id));
        assert!(limiter
            .message_buckets
            .read()
            .await
            .contains_key(&client_id));
        assert!(limiter
            .snapshot_buckets
            .read()
            .await
            .contains_key(&client_id));

        limiter.remove_client_buckets(client_id).await;

        assert!(!limiter
            .subscription_buckets
            .read()
            .await
            .contains_key(&client_id));
        assert!(!limiter
            .message_buckets
            .read()
            .await
            .contains_key(&client_id));
        assert!(!limiter
            .snapshot_buckets
            .read()
            .await
            .contains_key(&client_id));
    }
}