1use super::subscription::Subscription;
2use crate::compression::CompressedPayload;
3use crate::websocket::auth::{AuthContext, AuthDeny};
4use crate::websocket::rate_limiter::{RateLimitResult, WebSocketRateLimiter};
5use bytes::Bytes;
6use dashmap::DashMap;
7use futures_util::stream::SplitSink;
8use futures_util::SinkExt;
9use hyperstack_auth::Limits;
10use std::collections::HashMap;
11use std::net::SocketAddr;
12use std::sync::Arc;
13use std::time::{Duration, SystemTime};
14use tokio::net::TcpStream;
15use tokio::sync::{mpsc, RwLock};
16use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
17use tokio_util::sync::CancellationToken;
18use tracing::{debug, info, warn};
19use uuid::Uuid;
20
/// Write half of a WebSocket stream running over a `TcpStream`, as obtained by splitting the
/// stream into a sink and a stream (`SplitSink`). Frames are sent as tungstenite `Message`s.
pub type WebSocketSender = SplitSink<WebSocketStream<TcpStream>, Message>;
22
/// Reasons why a message could not be handed to a client's outbound queue.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SendError {
    /// No client with the requested id is registered.
    ClientNotFound,
    /// The client's outbound queue was full; the client is dropped rather than buffered further.
    ClientBackpressured,
    /// The client is gone: it was disconnected by the manager or its queue is closed.
    ClientDisconnected,
}

impl std::fmt::Display for SendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match *self {
            Self::ClientNotFound => "client not found",
            Self::ClientBackpressured => "client backpressured and disconnected",
            Self::ClientDisconnected => "client disconnected",
        };
        f.write_str(text)
    }
}

/// `SendError` carries no underlying cause, so the default `source()` (`None`) is used.
impl std::error::Error for SendError {}
45
/// Length of the fixed accounting window shared by [`EgressTracker`] and [`MessageRateTracker`].
const TRACKER_WINDOW: Duration = Duration::from_secs(60);

/// Reports whether a window that began at `window_start` must be restarted at `now`.
///
/// This is true once a full [`TRACKER_WINDOW`] has elapsed. It is also true when the wall clock
/// has moved backwards past `window_start`. In that case the elapsed time cannot be measured,
/// and keeping the old counters would leave the client throttled until the clock caught up
/// again.
fn tracker_window_expired(window_start: SystemTime, now: SystemTime) -> bool {
    match now.duration_since(window_start) {
        Ok(elapsed) => elapsed >= TRACKER_WINDOW,
        Err(_) => true,
    }
}

/// Per-client count of outbound bytes within the current one-minute window.
#[derive(Debug)]
struct EgressTracker {
    /// Bytes accepted since `window_start`.
    bytes_this_minute: u64,
    /// Wall-clock time at which the current window began.
    window_start: SystemTime,
}

/// Per-client count of inbound messages within the current one-minute window.
#[derive(Debug)]
struct MessageRateTracker {
    /// Messages accepted since `window_start`.
    messages_this_minute: u32,
    /// Wall-clock time at which the current window began.
    window_start: SystemTime,
}

impl MessageRateTracker {
    /// Starts an empty window at the current time.
    fn new() -> Self {
        Self {
            messages_this_minute: 0,
            window_start: SystemTime::now(),
        }
    }

    /// Clears the counter and restarts the window if it has expired.
    fn maybe_reset_window(&mut self) {
        let now = SystemTime::now();
        if tracker_window_expired(self.window_start, now) {
            self.messages_this_minute = 0;
            self.window_start = now;
        }
    }

    /// Counts one message if that keeps the window total at or below `limit`.
    ///
    /// Returns `false`, without counting, when the budget is already used up.
    fn record_message(&mut self, limit: u32) -> bool {
        self.maybe_reset_window();
        // Equivalent to `count + 1 > limit`, but cannot overflow when `limit == u32::MAX`.
        if self.messages_this_minute >= limit {
            return false;
        }
        self.messages_this_minute += 1;
        true
    }

    /// Messages counted in the current window (after expiring a stale window).
    fn current_usage(&mut self) -> u32 {
        self.maybe_reset_window();
        self.messages_this_minute
    }
}

impl EgressTracker {
    /// Starts an empty window at the current time.
    fn new() -> Self {
        Self {
            bytes_this_minute: 0,
            window_start: SystemTime::now(),
        }
    }

    /// Clears the counter and restarts the window if it has expired.
    fn maybe_reset_window(&mut self) {
        let now = SystemTime::now();
        if tracker_window_expired(self.window_start, now) {
            self.bytes_this_minute = 0;
            self.window_start = now;
        }
    }

    /// Counts `bytes` if the window total stays at or below `limit`.
    ///
    /// Returns `false`, without counting, when the bytes would exceed the budget. The addition
    /// is checked, so limits close to `u64::MAX` cannot overflow the counter.
    fn record_bytes(&mut self, bytes: usize, limit: u64) -> bool {
        self.maybe_reset_window();
        let bytes = u64::try_from(bytes).unwrap_or(u64::MAX);
        match self.bytes_this_minute.checked_add(bytes) {
            Some(total) if total <= limit => {
                self.bytes_this_minute = total;
                true
            }
            _ => false,
        }
    }

    /// Bytes counted in the current window (after expiring a stale window).
    fn current_usage(&mut self) -> u64 {
        self.maybe_reset_window();
        self.bytes_this_minute
    }
}
129
/// Server-side state for one registered WebSocket client.
#[derive(Debug)]
pub struct ClientInfo {
    /// Identifier the client was registered under.
    pub id: Uuid,
    /// Most recently stored subscription (set via `ClientManager::update_subscription`).
    pub subscription: Option<Subscription>,
    /// Last recorded activity; compared against a timeout by `is_stale`.
    pub last_seen: SystemTime,
    /// Producer side of the client's bounded outbound message queue.
    pub sender: mpsc::Sender<Message>,
    /// Cancellation tokens of the client's active subscriptions, keyed by subscription key.
    subscriptions: Arc<RwLock<HashMap<String, CancellationToken>>>,
    /// Authentication context; `None` for unauthenticated connections. Its `limits` drive the
    /// per-client egress and inbound-message budgets.
    pub auth_context: Option<AuthContext>,
    /// Remote socket address of the connection.
    pub remote_addr: SocketAddr,
    /// Outbound byte accounting for the current window.
    egress_tracker: std::sync::Mutex<EgressTracker>,
    /// Inbound message accounting for the current window.
    message_rate_tracker: std::sync::Mutex<MessageRateTracker>,
}
147
148impl ClientInfo {
149 pub fn new(
150 id: Uuid,
151 sender: mpsc::Sender<Message>,
152 auth_context: Option<AuthContext>,
153 remote_addr: SocketAddr,
154 ) -> Self {
155 Self {
156 id,
157 subscription: None,
158 last_seen: SystemTime::now(),
159 sender,
160 subscriptions: Arc::new(RwLock::new(HashMap::new())),
161 auth_context,
162 remote_addr,
163 egress_tracker: std::sync::Mutex::new(EgressTracker::new()),
164 message_rate_tracker: std::sync::Mutex::new(MessageRateTracker::new()),
165 }
166 }
167
168 pub fn record_egress(&self, bytes: usize) -> Option<u64> {
170 if let Ok(mut tracker) = self.egress_tracker.lock() {
171 if let Some(ref ctx) = self.auth_context {
172 if let Some(limit) = ctx.limits.max_bytes_per_minute {
173 if tracker.record_bytes(bytes, limit) {
174 return Some(tracker.current_usage());
175 } else {
176 return None; }
178 }
179 }
180 return Some(tracker.current_usage());
182 }
183 None
184 }
185
186 pub fn record_inbound_message(&self) -> Option<u32> {
188 if let Ok(mut tracker) = self.message_rate_tracker.lock() {
189 if let Some(ref ctx) = self.auth_context {
190 if let Some(limit) = ctx.limits.max_messages_per_minute {
191 if tracker.record_message(limit) {
192 return Some(tracker.current_usage());
193 } else {
194 return None;
195 }
196 }
197 }
198
199 return Some(tracker.current_usage());
200 }
201
202 None
203 }
204
205 pub fn update_last_seen(&mut self) {
206 self.last_seen = SystemTime::now();
207 }
208
209 pub fn is_stale(&self, timeout: Duration) -> bool {
210 self.last_seen.elapsed().unwrap_or(Duration::MAX) > timeout
211 }
212
213 pub async fn add_subscription(&self, sub_key: String, token: CancellationToken) -> bool {
214 let mut subs = self.subscriptions.write().await;
215 if let Some(old_token) = subs.insert(sub_key.clone(), token) {
216 old_token.cancel();
217 debug!("Replaced existing subscription: {}", sub_key);
218 false
219 } else {
220 true
221 }
222 }
223
224 pub async fn remove_subscription(&self, sub_key: &str) -> bool {
225 let mut subs = self.subscriptions.write().await;
226 if let Some(token) = subs.remove(sub_key) {
227 token.cancel();
228 debug!("Cancelled subscription: {}", sub_key);
229 true
230 } else {
231 debug!("Subscription not found for cancellation: {}", sub_key);
232 false
233 }
234 }
235
236 pub async fn cancel_all_subscriptions(&self) {
237 let subs = self.subscriptions.read().await;
238 for (sub_key, token) in subs.iter() {
239 token.cancel();
240 debug!("Cancelled subscription on disconnect: {}", sub_key);
241 }
242 }
243
244 pub async fn subscription_count(&self) -> usize {
245 self.subscriptions.read().await.len()
246 }
247}
248
/// Connection, queueing and fallback-limit settings applied by `ClientManager`.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum registered clients sharing one remote IP address; `None` disables the check.
    pub max_connections_per_ip: Option<usize>,
    /// Maximum registered clients sharing one auth metering key; `None` disables the check.
    pub max_connections_per_metering_key: Option<usize>,
    /// Maximum registered clients sharing one auth origin; only checked when the auth context
    /// carries an origin.
    pub max_connections_per_origin: Option<usize>,
    /// Idle time after which stale-client cleanup removes a client.
    pub client_timeout: Duration,
    /// Capacity of each client's outbound message queue. Should be at least 1, because
    /// `tokio::sync::mpsc::channel` rejects a capacity of 0.
    pub message_queue_size: usize,
    /// NOTE(review): not read anywhere in this module; presumably consumed elsewhere — verify.
    pub max_reconnect_attempts: Option<u32>,
    /// Reported as the retry-after hint when a client exceeds its inbound message budget.
    /// NOTE(review): the per-client trackers in this module count over a fixed 60-second
    /// window regardless of this value.
    pub message_rate_window: Duration,
    /// NOTE(review): set by the builders and `from_env`, but not otherwise read in this module;
    /// the egress tracker uses a fixed 60-second window.
    pub egress_rate_window: Duration,
    /// Fallback limits used when an auth context leaves `max_connections`, `max_subscriptions`
    /// or `max_snapshot_rows` unset.
    /// NOTE(review): the per-minute message/byte fields of these defaults are not consulted by
    /// the per-client egress/inbound trackers in this module.
    pub default_limits: Option<Limits>,
}
275
276impl Default for RateLimitConfig {
277 fn default() -> Self {
278 Self {
279 max_connections_per_ip: None,
280 max_connections_per_metering_key: None,
281 max_connections_per_origin: None,
282 client_timeout: Duration::from_secs(300),
283 message_queue_size: 512,
284 max_reconnect_attempts: None,
285 message_rate_window: Duration::from_secs(60),
286 egress_rate_window: Duration::from_secs(60),
287 default_limits: None,
288 }
289 }
290}
291
292impl RateLimitConfig {
293 pub fn from_env() -> Self {
308 let mut config = Self::default();
309
310 if let Ok(val) = std::env::var("HYPERSTACK_WS_MAX_CONNECTIONS_PER_IP") {
311 if let Ok(max) = val.parse() {
312 config.max_connections_per_ip = Some(max);
313 }
314 }
315
316 if let Ok(val) = std::env::var("HYPERSTACK_WS_MAX_CONNECTIONS_PER_METERING_KEY") {
317 if let Ok(max) = val.parse() {
318 config.max_connections_per_metering_key = Some(max);
319 }
320 }
321
322 if let Ok(val) = std::env::var("HYPERSTACK_WS_MAX_CONNECTIONS_PER_ORIGIN") {
323 if let Ok(max) = val.parse() {
324 config.max_connections_per_origin = Some(max);
325 }
326 }
327
328 if let Ok(val) = std::env::var("HYPERSTACK_WS_CLIENT_TIMEOUT_SECS") {
329 if let Ok(secs) = val.parse() {
330 config.client_timeout = Duration::from_secs(secs);
331 }
332 }
333
334 if let Ok(val) = std::env::var("HYPERSTACK_WS_MESSAGE_QUEUE_SIZE") {
335 if let Ok(size) = val.parse() {
336 config.message_queue_size = size;
337 }
338 }
339
340 if let Ok(val) = std::env::var("HYPERSTACK_WS_RATE_LIMIT_WINDOW_SECS") {
341 if let Ok(secs) = val.parse() {
342 config.message_rate_window = Duration::from_secs(secs);
343 config.egress_rate_window = Duration::from_secs(secs);
344 }
345 }
346
347 let mut default_limits = Limits::default();
349 let mut has_default_limits = false;
350
351 if let Ok(val) = std::env::var("HYPERSTACK_WS_DEFAULT_MAX_CONNECTIONS") {
352 if let Ok(max) = val.parse() {
353 default_limits.max_connections = Some(max);
354 has_default_limits = true;
355 }
356 }
357
358 if let Ok(val) = std::env::var("HYPERSTACK_WS_DEFAULT_MAX_SUBSCRIPTIONS") {
359 if let Ok(max) = val.parse() {
360 default_limits.max_subscriptions = Some(max);
361 has_default_limits = true;
362 }
363 }
364
365 if let Ok(val) = std::env::var("HYPERSTACK_WS_DEFAULT_MAX_SNAPSHOT_ROWS") {
366 if let Ok(max) = val.parse() {
367 default_limits.max_snapshot_rows = Some(max);
368 has_default_limits = true;
369 }
370 }
371
372 if let Ok(val) = std::env::var("HYPERSTACK_WS_DEFAULT_MAX_MESSAGES_PER_MINUTE") {
373 if let Ok(max) = val.parse() {
374 default_limits.max_messages_per_minute = Some(max);
375 has_default_limits = true;
376 }
377 }
378
379 if let Ok(val) = std::env::var("HYPERSTACK_WS_DEFAULT_MAX_BYTES_PER_MINUTE") {
380 if let Ok(max) = val.parse() {
381 default_limits.max_bytes_per_minute = Some(max);
382 has_default_limits = true;
383 }
384 }
385
386 if has_default_limits {
387 config.default_limits = Some(default_limits);
388 }
389
390 config
391 }
392
393 pub fn with_max_connections_per_ip(mut self, max: usize) -> Self {
395 self.max_connections_per_ip = Some(max);
396 self
397 }
398
399 pub fn with_timeout(mut self, timeout: Duration) -> Self {
401 self.client_timeout = timeout;
402 self
403 }
404
405 pub fn with_message_queue_size(mut self, size: usize) -> Self {
407 self.message_queue_size = size;
408 self
409 }
410
411 pub fn with_rate_limit_window(mut self, window: Duration) -> Self {
413 self.message_rate_window = window;
414 self.egress_rate_window = window;
415 self
416 }
417
418 pub fn with_default_limits(mut self, limits: Limits) -> Self {
423 self.default_limits = Some(limits);
424 self
425 }
426}
427
/// Registry of connected WebSocket clients, together with the limits enforced on them.
///
/// Cloning is cheap and clones share state: the client map and the optional rate limiter sit
/// behind `Arc`s, so every clone sees the same clients. The `RateLimitConfig` is copied per
/// clone.
#[derive(Clone)]
pub struct ClientManager {
    /// Registered clients keyed by connection id.
    clients: Arc<DashMap<Uuid, ClientInfo>>,
    /// Connection, queue and fallback-limit settings.
    rate_limit_config: RateLimitConfig,
    /// Optional handshake/connection rate limiter consulted by `check_connection_allowed`.
    rate_limiter: Option<Arc<WebSocketRateLimiter>>,
}
443
444impl ClientManager {
445 pub fn new() -> Self {
446 Self::with_config(RateLimitConfig::default())
447 }
448
449 pub fn with_config(config: RateLimitConfig) -> Self {
451 Self {
452 clients: Arc::new(DashMap::new()),
453 rate_limit_config: config,
454 rate_limiter: None,
455 }
456 }
457
458 pub fn from_env() -> Self {
462 Self::with_config(RateLimitConfig::from_env())
463 }
464
465 pub fn with_timeout(mut self, timeout: Duration) -> Self {
467 self.rate_limit_config.client_timeout = timeout;
468 self
469 }
470
471 pub fn with_message_queue_size(mut self, queue_size: usize) -> Self {
473 self.rate_limit_config.message_queue_size = queue_size;
474 self
475 }
476
477 pub fn with_max_connections_per_ip(mut self, max: usize) -> Self {
479 self.rate_limit_config.max_connections_per_ip = Some(max);
480 self
481 }
482
483 pub fn with_rate_limit_window(mut self, window: Duration) -> Self {
485 self.rate_limit_config.message_rate_window = window;
486 self.rate_limit_config.egress_rate_window = window;
487 self
488 }
489
490 pub fn with_default_limits(mut self, limits: Limits) -> Self {
495 self.rate_limit_config.default_limits = Some(limits);
496 self
497 }
498
499 pub fn with_rate_limiter(mut self, rate_limiter: Arc<WebSocketRateLimiter>) -> Self {
501 self.rate_limiter = Some(rate_limiter);
502 self
503 }
504
505 pub fn rate_limiter(&self) -> Option<&WebSocketRateLimiter> {
507 self.rate_limiter.as_ref().map(|r| r.as_ref())
508 }
509
510 pub fn rate_limit_config(&self) -> &RateLimitConfig {
512 &self.rate_limit_config
513 }
514
515 pub fn add_client(
521 &self,
522 client_id: Uuid,
523 mut ws_sender: WebSocketSender,
524 auth_context: Option<AuthContext>,
525 remote_addr: SocketAddr,
526 ) {
527 let (client_tx, mut client_rx) =
528 mpsc::channel::<Message>(self.rate_limit_config.message_queue_size);
529 let client_info = ClientInfo::new(client_id, client_tx, auth_context, remote_addr);
530
531 let clients_ref = self.clients.clone();
532 tokio::spawn(async move {
533 while let Some(message) = client_rx.recv().await {
534 if let Err(e) = ws_sender.send(message).await {
535 warn!("Failed to send message to client {}: {}", client_id, e);
536 break;
537 }
538 }
539 clients_ref.remove(&client_id);
540 debug!("WebSocket sender task for client {} stopped", client_id);
541 });
542
543 self.clients.insert(client_id, client_info);
544 info!("Client {} registered from {}", client_id, remote_addr);
545 }
546
547 pub fn remove_client(&self, client_id: Uuid) {
549 if self.clients.remove(&client_id).is_some() {
550 info!("Client {} removed", client_id);
551 }
552 }
553
554 pub fn update_client_auth(&self, client_id: Uuid, auth_context: AuthContext) -> bool {
558 if let Some(mut client) = self.clients.get_mut(&client_id) {
559 client.auth_context = Some(auth_context);
560 debug!("Updated auth context for client {}", client_id);
561 true
562 } else {
563 false
564 }
565 }
566
567 pub fn check_and_remove_expired(&self, client_id: Uuid) -> bool {
572 if let Some(client) = self.clients.get(&client_id) {
573 if let Some(ref ctx) = client.auth_context {
574 let now = std::time::SystemTime::now()
575 .duration_since(std::time::UNIX_EPOCH)
576 .unwrap_or_default()
577 .as_secs();
578 if ctx.expires_at <= now {
579 warn!(
580 "Client {} token expired (expired at {}), disconnecting",
581 client_id, ctx.expires_at
582 );
583 self.clients.remove(&client_id);
584 return true;
585 }
586 }
587 }
588 false
589 }
590
591 pub fn client_count(&self) -> usize {
596 self.clients.len()
597 }
598
599 pub fn send_to_client(&self, client_id: Uuid, data: Arc<Bytes>) -> Result<(), SendError> {
608 if self.check_and_remove_expired(client_id) {
610 return Err(SendError::ClientDisconnected);
611 }
612
613 if let Some(client) = self.clients.get(&client_id) {
615 if client.record_egress(data.len()).is_none() {
616 warn!("Client {} exceeded egress limit, disconnecting", client_id);
617 self.clients.remove(&client_id);
618 return Err(SendError::ClientDisconnected);
619 }
620 } else {
621 return Err(SendError::ClientNotFound);
622 }
623
624 let sender = {
625 let client = self
626 .clients
627 .get(&client_id)
628 .ok_or(SendError::ClientNotFound)?;
629 client.sender.clone()
630 };
631
632 let msg = Message::Binary((*data).clone());
633 match sender.try_send(msg) {
634 Ok(()) => Ok(()),
635 Err(mpsc::error::TrySendError::Full(_)) => {
636 warn!(
637 "Client {} backpressured (queue full), disconnecting",
638 client_id
639 );
640 self.clients.remove(&client_id);
641 Err(SendError::ClientBackpressured)
642 }
643 Err(mpsc::error::TrySendError::Closed(_)) => {
644 debug!("Client {} channel closed", client_id);
645 self.clients.remove(&client_id);
646 Err(SendError::ClientDisconnected)
647 }
648 }
649 }
650
651 pub async fn send_to_client_async(
660 &self,
661 client_id: Uuid,
662 data: Arc<Bytes>,
663 ) -> Result<(), SendError> {
664 if self.check_and_remove_expired(client_id) {
666 return Err(SendError::ClientDisconnected);
667 }
668
669 if let Some(client) = self.clients.get(&client_id) {
671 if client.record_egress(data.len()).is_none() {
672 warn!("Client {} exceeded egress limit, disconnecting", client_id);
673 self.clients.remove(&client_id);
674 return Err(SendError::ClientDisconnected);
675 }
676 } else {
677 return Err(SendError::ClientNotFound);
678 }
679
680 let sender = {
681 let client = self
682 .clients
683 .get(&client_id)
684 .ok_or(SendError::ClientNotFound)?;
685 client.sender.clone()
686 };
687
688 let msg = Message::Binary((*data).clone());
689 sender
690 .send(msg)
691 .await
692 .map_err(|_| SendError::ClientDisconnected)
693 }
694
695 pub async fn send_text_to_client(
700 &self,
701 client_id: Uuid,
702 text: String,
703 ) -> Result<(), SendError> {
704 if self.check_and_remove_expired(client_id) {
706 return Err(SendError::ClientDisconnected);
707 }
708
709 let sender = {
710 let client = self
711 .clients
712 .get(&client_id)
713 .ok_or(SendError::ClientNotFound)?;
714 client.sender.clone()
715 };
716
717 let msg = Message::Text(text.into());
718 sender
719 .send(msg)
720 .await
721 .map_err(|_| SendError::ClientDisconnected)
722 }
723
724 pub async fn send_compressed_async(
729 &self,
730 client_id: Uuid,
731 payload: CompressedPayload,
732 ) -> Result<(), SendError> {
733 if self.check_and_remove_expired(client_id) {
735 return Err(SendError::ClientDisconnected);
736 }
737
738 let (sender, bytes_to_record) = {
739 let client = self
740 .clients
741 .get(&client_id)
742 .ok_or(SendError::ClientNotFound)?;
743
744 let bytes = match &payload {
745 CompressedPayload::Compressed(bytes) => bytes.len(),
746 CompressedPayload::Uncompressed(bytes) => bytes.len(),
747 };
748
749 (client.sender.clone(), bytes)
750 };
751
752 if let Some(client) = self.clients.get(&client_id) {
754 if client.record_egress(bytes_to_record).is_none() {
755 warn!("Client {} exceeded egress limit, disconnecting", client_id);
756 self.clients.remove(&client_id);
757 return Err(SendError::ClientDisconnected);
758 }
759 }
760
761 let msg = match payload {
762 CompressedPayload::Compressed(bytes) => Message::Binary(bytes),
763 CompressedPayload::Uncompressed(bytes) => Message::Binary(bytes),
764 };
765 sender
766 .send(msg)
767 .await
768 .map_err(|_| SendError::ClientDisconnected)
769 }
770
771 pub fn update_subscription(&self, client_id: Uuid, subscription: Subscription) -> bool {
773 if let Some(mut client) = self.clients.get_mut(&client_id) {
774 client.subscription = Some(subscription);
775 client.update_last_seen();
776 debug!("Updated subscription for client {}", client_id);
777 true
778 } else {
779 warn!(
780 "Failed to update subscription for unknown client {}",
781 client_id
782 );
783 false
784 }
785 }
786
787 pub fn update_client_last_seen(&self, client_id: Uuid) {
789 if let Some(mut client) = self.clients.get_mut(&client_id) {
790 client.update_last_seen();
791 }
792 }
793
794 #[allow(clippy::result_large_err)]
796 pub fn check_inbound_message_allowed(&self, client_id: Uuid) -> Result<(), AuthDeny> {
797 if self.check_and_remove_expired(client_id) {
798 return Err(AuthDeny::new(
799 crate::websocket::auth::AuthErrorCode::TokenExpired,
800 "Authentication token expired",
801 ));
802 }
803
804 let Some(client) = self.clients.get(&client_id) else {
805 return Err(AuthDeny::new(
806 crate::websocket::auth::AuthErrorCode::InternalError,
807 "Client not found",
808 ));
809 };
810
811 if client.record_inbound_message().is_some() {
812 Ok(())
813 } else {
814 self.clients.remove(&client_id);
815 Err(AuthDeny::rate_limited(
816 self.rate_limit_config.message_rate_window,
817 "inbound websocket messages",
818 )
819 .with_context(format!(
820 "client {} exceeded the inbound message budget",
821 client_id
822 )))
823 }
824 }
825
826 pub fn get_subscription(&self, client_id: Uuid) -> Option<Subscription> {
828 self.clients
829 .get(&client_id)
830 .and_then(|c| c.subscription.clone())
831 }
832
833 pub fn has_client(&self, client_id: Uuid) -> bool {
835 self.clients.contains_key(&client_id)
836 }
837
838 pub async fn add_client_subscription(
839 &self,
840 client_id: Uuid,
841 sub_key: String,
842 token: CancellationToken,
843 ) -> bool {
844 if let Some(client) = self.clients.get(&client_id) {
845 client.add_subscription(sub_key, token).await
846 } else {
847 false
848 }
849 }
850
851 pub async fn remove_client_subscription(&self, client_id: Uuid, sub_key: &str) -> bool {
852 if let Some(client) = self.clients.get(&client_id) {
853 client.remove_subscription(sub_key).await
854 } else {
855 false
856 }
857 }
858
859 pub async fn cancel_all_client_subscriptions(&self, client_id: Uuid) {
860 if let Some(client) = self.clients.get(&client_id) {
861 client.cancel_all_subscriptions().await;
862 }
863 }
864
865 pub fn cleanup_stale_clients(&self) -> usize {
867 let timeout = self.rate_limit_config.client_timeout;
868 let mut stale_clients = Vec::new();
869
870 for entry in self.clients.iter() {
871 if entry.value().is_stale(timeout) {
872 stale_clients.push(*entry.key());
873 }
874 }
875
876 let removed_count = stale_clients.len();
877 for client_id in stale_clients {
878 self.clients.remove(&client_id);
879 info!("Removed stale client {}", client_id);
880 }
881
882 removed_count
883 }
884
885 pub fn start_cleanup_task(&self) {
887 let client_manager = self.clone();
888
889 tokio::spawn(async move {
890 let mut interval = tokio::time::interval(Duration::from_secs(30));
891
892 loop {
893 interval.tick().await;
894 let removed = client_manager.cleanup_stale_clients();
895 if removed > 0 {
896 info!("Cleaned up {} stale clients", removed);
897 }
898 }
899 });
900 }
901
902 pub async fn check_connection_allowed(
910 &self,
911 remote_addr: SocketAddr,
912 auth_context: &Option<AuthContext>,
913 ) -> Result<(), AuthDeny> {
914 if let Some(ref rate_limiter) = self.rate_limiter {
916 match rate_limiter.check_handshake(remote_addr).await {
918 RateLimitResult::Allowed { .. } => {}
919 RateLimitResult::Denied { retry_after, limit } => {
920 return Err(AuthDeny::rate_limited(retry_after, "websocket handshakes")
921 .with_context(format!(
922 "handshake rate limit of {} per minute exceeded for {}",
923 limit, remote_addr
924 )));
925 }
926 }
927
928 if let Some(ref ctx) = auth_context {
930 match rate_limiter
931 .check_connection_for_subject(&ctx.subject)
932 .await
933 {
934 RateLimitResult::Allowed { .. } => {}
935 RateLimitResult::Denied { retry_after, limit } => {
936 return Err(AuthDeny::rate_limited(retry_after, "websocket connections")
937 .with_context(format!(
938 "connection rate limit for subject {} of {} per minute exceeded",
939 ctx.subject, limit
940 )));
941 }
942 }
943
944 match rate_limiter
946 .check_connection_for_metering_key(&ctx.metering_key)
947 .await
948 {
949 RateLimitResult::Allowed { .. } => {}
950 RateLimitResult::Denied { retry_after, limit } => {
951 return Err(AuthDeny::rate_limited(
952 retry_after,
953 "metered websocket connections",
954 )
955 .with_context(format!(
956 "connection rate limit for metering key {} of {} per minute exceeded",
957 ctx.metering_key, limit
958 )));
959 }
960 }
961 }
962 }
963
964 if let Some(max_per_ip) = self.rate_limit_config.max_connections_per_ip {
966 let current_ip_connections = self.count_connections_for_ip(&remote_addr);
967 if current_ip_connections >= max_per_ip {
968 return Err(AuthDeny::connection_limit_exceeded(
969 &format!("ip {}", remote_addr.ip()),
970 current_ip_connections,
971 max_per_ip,
972 ));
973 }
974 }
975
976 if let Some(ctx) = auth_context {
977 let max_connections = ctx.limits.max_connections.or_else(|| {
979 self.rate_limit_config
980 .default_limits
981 .as_ref()
982 .and_then(|l| l.max_connections)
983 });
984 if let Some(max_connections) = max_connections {
985 let current_connections = self.count_connections_for_subject(&ctx.subject);
986 if current_connections >= max_connections as usize {
987 return Err(AuthDeny::connection_limit_exceeded(
988 &format!("subject {}", ctx.subject),
989 current_connections,
990 max_connections as usize,
991 ));
992 }
993 }
994
995 if let Some(max_per_metering_key) =
997 self.rate_limit_config.max_connections_per_metering_key
998 {
999 let current_metering_connections =
1000 self.count_connections_for_metering_key(&ctx.metering_key);
1001 if current_metering_connections >= max_per_metering_key {
1002 return Err(AuthDeny::connection_limit_exceeded(
1003 &format!("metering key {}", ctx.metering_key),
1004 current_metering_connections,
1005 max_per_metering_key,
1006 ));
1007 }
1008 }
1009
1010 if let Some(max_per_origin) = self.rate_limit_config.max_connections_per_origin {
1012 if let Some(ref origin) = ctx.origin {
1013 let current_origin_connections = self.count_connections_for_origin(origin);
1014 if current_origin_connections >= max_per_origin {
1015 return Err(AuthDeny::connection_limit_exceeded(
1016 &format!("origin {}", origin),
1017 current_origin_connections,
1018 max_per_origin,
1019 ));
1020 }
1021 }
1022 }
1023 }
1024 Ok(())
1025 }
1026
1027 fn count_connections_for_ip(&self, remote_addr: &SocketAddr) -> usize {
1029 let ip = remote_addr.ip();
1030 self.clients
1031 .iter()
1032 .filter(|entry| entry.value().remote_addr.ip() == ip)
1033 .count()
1034 }
1035
1036 fn count_connections_for_subject(&self, subject: &str) -> usize {
1038 self.clients
1039 .iter()
1040 .filter(|entry| {
1041 entry
1042 .value()
1043 .auth_context
1044 .as_ref()
1045 .map(|ctx| ctx.subject == subject)
1046 .unwrap_or(false)
1047 })
1048 .count()
1049 }
1050
1051 fn count_connections_for_metering_key(&self, metering_key: &str) -> usize {
1053 self.clients
1054 .iter()
1055 .filter(|entry| {
1056 entry
1057 .value()
1058 .auth_context
1059 .as_ref()
1060 .map(|ctx| ctx.metering_key == metering_key)
1061 .unwrap_or(false)
1062 })
1063 .count()
1064 }
1065
1066 fn count_connections_for_origin(&self, origin: &str) -> usize {
1068 self.clients
1069 .iter()
1070 .filter(|entry| {
1071 entry
1072 .value()
1073 .auth_context
1074 .as_ref()
1075 .and_then(|ctx| ctx.origin.as_ref())
1076 .map(|o| o == origin)
1077 .unwrap_or(false)
1078 })
1079 .count()
1080 }
1081
1082 pub async fn check_subscription_allowed(&self, client_id: Uuid) -> Result<(), AuthDeny> {
1086 if let Some(client) = self.clients.get(&client_id) {
1087 let current_subs = client.subscription_count().await;
1088
1089 if let Some(ref ctx) = client.auth_context {
1091 let max_subs = ctx.limits.max_subscriptions.or_else(|| {
1092 self.rate_limit_config
1093 .default_limits
1094 .as_ref()
1095 .and_then(|l| l.max_subscriptions)
1096 });
1097 if let Some(max_subs) = max_subs {
1098 if current_subs >= max_subs as usize {
1099 return Err(AuthDeny::new(
1100 crate::websocket::auth::AuthErrorCode::SubscriptionLimitExceeded,
1101 format!(
1102 "Subscription limit exceeded: {} of {} subscriptions for client {}",
1103 current_subs, max_subs, client_id
1104 ),
1105 )
1106 .with_suggested_action(
1107 "Unsubscribe from an existing view before creating another subscription",
1108 ));
1109 }
1110 }
1111 }
1112 }
1113 Ok(())
1114 }
1115
1116 pub fn get_metering_key(&self, client_id: Uuid) -> Option<String> {
1118 self.clients.get(&client_id).and_then(|client| {
1119 client
1120 .auth_context
1121 .as_ref()
1122 .map(|ctx| ctx.metering_key.clone())
1123 })
1124 }
1125
1126 pub fn get_auth_context(&self, client_id: Uuid) -> Option<AuthContext> {
1128 self.clients
1129 .get(&client_id)
1130 .and_then(|client| client.auth_context.clone())
1131 }
1132
1133 #[allow(clippy::result_large_err)]
1137 pub fn check_snapshot_allowed(
1138 &self,
1139 client_id: Uuid,
1140 requested_rows: u32,
1141 ) -> Result<(), AuthDeny> {
1142 if let Some(client) = self.clients.get(&client_id) {
1143 if let Some(ref ctx) = client.auth_context {
1144 let max_rows = ctx.limits.max_snapshot_rows.or_else(|| {
1145 self.rate_limit_config
1146 .default_limits
1147 .as_ref()
1148 .and_then(|l| l.max_snapshot_rows)
1149 });
1150 if let Some(max_rows) = max_rows {
1151 if requested_rows > max_rows {
1152 return Err(AuthDeny::new(
1153 crate::websocket::auth::AuthErrorCode::SnapshotLimitExceeded,
1154 format!(
1155 "Snapshot limit exceeded: requested {} rows, max allowed is {} for client {}",
1156 requested_rows, max_rows, client_id
1157 ),
1158 )
1159 .with_suggested_action(
1160 "Request fewer rows or lower the snapshotLimit on the subscription",
1161 ));
1162 }
1163 }
1164 }
1165 }
1166 Ok(())
1167 }
1168}
1169
1170impl Default for ClientManager {
1171 fn default() -> Self {
1172 Self::new()
1173 }
1174}
1175
1176#[cfg(test)]
1177mod tests {
1178 use super::*;
1179 use crate::websocket::auth::AuthContext;
1180 use hyperstack_auth::{KeyClass, Limits};
1181 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
1182
1183 fn create_test_auth_context(subject: &str, limits: Limits) -> AuthContext {
1184 AuthContext {
1185 subject: subject.to_string(),
1186 issuer: "test-issuer".to_string(),
1187 key_class: KeyClass::Publishable,
1188 metering_key: format!("meter-{}", subject),
1189 deployment_id: None,
1190 expires_at: u64::MAX,
1191 scope: "read".to_string(),
1192 limits,
1193 plan: None,
1194 origin: None,
1195 client_ip: None,
1196 jti: uuid::Uuid::new_v4().to_string(),
1197 }
1198 }
1199
1200 fn create_test_socket_addr(ip: &str) -> SocketAddr {
1201 SocketAddr::new(
1202 ip.parse::<IpAddr>()
1203 .unwrap_or(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))),
1204 12345,
1205 )
1206 }
1207
1208 #[test]
1209 fn test_egress_tracker_basic() {
1210 let mut tracker = EgressTracker::new();
1211
1212 assert!(tracker.record_bytes(500, 1000));
1214 assert_eq!(tracker.current_usage(), 500);
1215
1216 assert!(tracker.record_bytes(400, 1000));
1218 assert_eq!(tracker.current_usage(), 900);
1219
1220 assert!(!tracker.record_bytes(200, 1000));
1222 assert_eq!(tracker.current_usage(), 900); }
1224
1225 #[test]
1226 fn test_egress_tracker_window_reset() {
1227 let mut tracker = EgressTracker::new();
1228
1229 assert!(tracker.record_bytes(100, 100));
1231 assert!(!tracker.record_bytes(1, 100));
1232
1233 tracker.bytes_this_minute = 0;
1235 tracker.window_start = SystemTime::now() - Duration::from_secs(61);
1236
1237 assert!(tracker.record_bytes(50, 100));
1239 }
1240
1241 #[test]
1242 fn test_message_rate_tracker_basic() {
1243 let mut tracker = MessageRateTracker::new();
1244
1245 assert!(tracker.record_message(2));
1246 assert_eq!(tracker.current_usage(), 1);
1247
1248 assert!(tracker.record_message(2));
1249 assert_eq!(tracker.current_usage(), 2);
1250
1251 assert!(!tracker.record_message(2));
1252 assert_eq!(tracker.current_usage(), 2);
1253 }
1254
1255 #[tokio::test]
1256 async fn test_client_inbound_message_limit() {
1257 let (tx, _rx) = mpsc::channel(1);
1258 let client = ClientInfo::new(
1259 Uuid::new_v4(),
1260 tx,
1261 Some(create_test_auth_context(
1262 "user-1",
1263 Limits {
1264 max_messages_per_minute: Some(2),
1265 ..Default::default()
1266 },
1267 )),
1268 create_test_socket_addr("127.0.0.1"),
1269 );
1270
1271 assert_eq!(client.record_inbound_message(), Some(1));
1272 assert_eq!(client.record_inbound_message(), Some(2));
1273 assert_eq!(client.record_inbound_message(), None);
1274 }
1275
1276 #[tokio::test]
1277 async fn test_no_limits() {
1278 let manager = ClientManager::new();
1279 let addr = create_test_socket_addr("127.0.0.1");
1280
1281 assert!(manager.check_connection_allowed(addr, &None).await.is_ok());
1283
1284 let auth_context = create_test_auth_context("test", Limits::default());
1286 assert!(manager
1287 .check_connection_allowed(addr, &Some(auth_context))
1288 .await
1289 .is_ok());
1290 }
1291
1292 #[tokio::test]
1293 async fn test_per_subject_connection_limit() {
1294 let manager = ClientManager::new();
1295
1296 let limits = Limits {
1297 max_connections: Some(2),
1298 ..Default::default()
1299 };
1300
1301 let auth_context = create_test_auth_context("user-1", limits);
1302 let addr = create_test_socket_addr("127.0.0.1");
1303
1304 assert!(manager
1306 .check_connection_allowed(addr, &Some(auth_context.clone()))
1307 .await
1308 .is_ok());
1309 }
1310
1311 #[tokio::test]
1312 async fn test_per_ip_connection_limit() {
1313 let manager = ClientManager::new().with_max_connections_per_ip(2);
1314 let addr = create_test_socket_addr("192.168.1.1");
1315
1316 assert!(manager.check_connection_allowed(addr, &None).await.is_ok());
1318 }
1319
1320 #[test]
1322 fn rate_limit_config_default() {
1323 let config = RateLimitConfig::default();
1324 assert!(config.max_connections_per_ip.is_none());
1325 assert_eq!(config.client_timeout, Duration::from_secs(300));
1326 assert_eq!(config.message_queue_size, 512);
1327 assert!(config.max_reconnect_attempts.is_none());
1328 assert_eq!(config.message_rate_window, Duration::from_secs(60));
1329 assert_eq!(config.egress_rate_window, Duration::from_secs(60));
1330 }
1331
1332 #[test]
1333 fn rate_limit_config_builder_methods() {
1334 let config = RateLimitConfig::default()
1335 .with_max_connections_per_ip(10)
1336 .with_timeout(Duration::from_secs(600))
1337 .with_message_queue_size(1024)
1338 .with_rate_limit_window(Duration::from_secs(120));
1339
1340 assert_eq!(config.max_connections_per_ip, Some(10));
1341 assert_eq!(config.client_timeout, Duration::from_secs(600));
1342 assert_eq!(config.message_queue_size, 1024);
1343 assert_eq!(config.message_rate_window, Duration::from_secs(120));
1344 assert_eq!(config.egress_rate_window, Duration::from_secs(120));
1345 }
1346
1347 #[tokio::test]
1348 async fn client_manager_with_config() {
1349 let config = RateLimitConfig::default()
1350 .with_max_connections_per_ip(5)
1351 .with_timeout(Duration::from_secs(120))
1352 .with_message_queue_size(256);
1353
1354 let manager = ClientManager::with_config(config);
1355 let addr = create_test_socket_addr("10.0.0.1");
1356
1357 assert_eq!(manager.rate_limit_config().max_connections_per_ip, Some(5));
1359 assert_eq!(
1360 manager.rate_limit_config().client_timeout,
1361 Duration::from_secs(120)
1362 );
1363 assert_eq!(manager.rate_limit_config().message_queue_size, 256);
1364
1365 assert!(manager.check_connection_allowed(addr, &None).await.is_ok());
1367 }
1368
1369 #[tokio::test]
1370 async fn client_manager_builder_pattern() {
1371 let manager = ClientManager::new()
1372 .with_max_connections_per_ip(10)
1373 .with_timeout(Duration::from_secs(180))
1374 .with_message_queue_size(1024)
1375 .with_rate_limit_window(Duration::from_secs(90));
1376
1377 assert_eq!(manager.rate_limit_config().max_connections_per_ip, Some(10));
1378 assert_eq!(
1379 manager.rate_limit_config().client_timeout,
1380 Duration::from_secs(180)
1381 );
1382 assert_eq!(manager.rate_limit_config().message_queue_size, 1024);
1383 assert_eq!(
1384 manager.rate_limit_config().message_rate_window,
1385 Duration::from_secs(90)
1386 );
1387 }
1388
1389 #[tokio::test]
1391 async fn connection_limit_enforcement_with_actual_clients() {
1392 let manager = ClientManager::new().with_max_connections_per_ip(2);
1393 let addr1 = create_test_socket_addr("192.168.1.1");
1394 let addr2 = create_test_socket_addr("192.168.1.2");
1395
1396 let auth1 = create_test_auth_context("user-1", Limits::default());
1398 assert!(manager
1399 .check_connection_allowed(addr1, &Some(auth1.clone()))
1400 .await
1401 .is_ok());
1402
1403 let auth2 = create_test_auth_context("user-2", Limits::default());
1408 assert!(manager
1409 .check_connection_allowed(addr1, &Some(auth2.clone()))
1410 .await
1411 .is_ok());
1412
1413 let auth3 = create_test_auth_context("user-3", Limits::default());
1415 assert!(manager
1416 .check_connection_allowed(addr2, &Some(auth3.clone()))
1417 .await
1418 .is_ok());
1419 }
1420
1421 #[tokio::test]
1423 async fn subscription_limit_enforcement() {
1424 let manager = ClientManager::new();
1425 let addr = create_test_socket_addr("127.0.0.1");
1426
1427 let auth = create_test_auth_context(
1429 "user-1",
1430 Limits {
1431 max_subscriptions: Some(2),
1432 ..Default::default()
1433 },
1434 );
1435
1436 assert!(manager
1438 .check_connection_allowed(addr, &Some(auth.clone()))
1439 .await
1440 .is_ok());
1441
1442 assert_eq!(auth.limits.max_subscriptions, Some(2));
1445 }
1446
1447 #[tokio::test]
1449 async fn snapshot_limit_enforcement() {
1450 let manager = ClientManager::new();
1451 let addr = create_test_socket_addr("127.0.0.1");
1452
1453 let auth = create_test_auth_context(
1454 "user-1",
1455 Limits {
1456 max_snapshot_rows: Some(1000),
1457 ..Default::default()
1458 },
1459 );
1460
1461 assert!(manager
1462 .check_connection_allowed(addr, &Some(auth.clone()))
1463 .await
1464 .is_ok());
1465
1466 }
1469
1470 #[tokio::test]
1472 async fn test_rate_limiter_integration() {
1473 use crate::websocket::rate_limiter::{RateLimiterConfig, WebSocketRateLimiter};
1474
1475 let rate_limiter = Arc::new(WebSocketRateLimiter::new(RateLimiterConfig::default()));
1476 let manager = ClientManager::new().with_rate_limiter(rate_limiter);
1477 let addr = create_test_socket_addr("127.0.0.1");
1478
1479 let auth = create_test_auth_context("user-1", Limits::default());
1481 assert!(manager
1482 .check_connection_allowed(addr, &Some(auth))
1483 .await
1484 .is_ok());
1485 }
1486}