Skip to main content

hyperstack_server/websocket/
client_manager.rs

1use super::subscription::Subscription;
2use crate::compression::CompressedPayload;
3use crate::websocket::auth::{AuthContext, AuthDeny};
4use crate::websocket::rate_limiter::{RateLimitResult, WebSocketRateLimiter};
5use bytes::Bytes;
6use dashmap::DashMap;
7use futures_util::stream::SplitSink;
8use futures_util::SinkExt;
9use hyperstack_auth::Limits;
10use std::collections::HashMap;
11use std::net::SocketAddr;
12use std::sync::Arc;
13use std::time::{Duration, SystemTime};
14use tokio::net::TcpStream;
15use tokio::sync::{mpsc, RwLock};
16use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
17use tokio_util::sync::CancellationToken;
18use tracing::{debug, info, warn};
19use uuid::Uuid;
20
21pub type WebSocketSender = SplitSink<WebSocketStream<TcpStream>, Message>;
22
/// Reasons a `ClientManager` send operation can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SendError {
    /// No client with the given id is registered.
    ClientNotFound,
    /// The client's outbound queue was full; the client has been dropped from the registry.
    ClientBackpressured,
    /// The client is gone: its channel is closed, its token expired, or it exceeded its
    /// egress budget.
    ClientDisconnected,
}

impl std::fmt::Display for SendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match *self {
            Self::ClientNotFound => "client not found",
            Self::ClientBackpressured => "client backpressured and disconnected",
            Self::ClientDisconnected => "client disconnected",
        };
        f.write_str(text)
    }
}

impl std::error::Error for SendError {}
45
/// Length of the fixed window over which the per-client trackers count usage.
///
/// NOTE(review): `RateLimitConfig::{message_rate_window, egress_rate_window}` are not
/// consulted by these trackers; they always count over this fixed 60 s window.
const TRACKER_WINDOW: Duration = Duration::from_secs(60);

/// Egress tracking for a client: bytes sent within the current fixed window.
#[derive(Debug)]
struct EgressTracker {
    /// Bytes sent in the current window
    bytes_this_minute: u64,
    /// Start of the current window. `Instant` is monotonic, so wall-clock adjustments
    /// can neither stall nor prematurely restart the window.
    window_start: std::time::Instant,
}

/// Inbound message-rate tracking for a client: messages received in the current window.
#[derive(Debug)]
struct MessageRateTracker {
    /// Messages accepted in the current window
    messages_this_minute: u32,
    /// Start of the current window (monotonic clock)
    window_start: std::time::Instant,
}

impl MessageRateTracker {
    /// Starts an empty window at the current instant.
    fn new() -> Self {
        Self {
            messages_this_minute: 0,
            window_start: std::time::Instant::now(),
        }
    }

    /// Clears the counter once `TRACKER_WINDOW` has elapsed since the window started.
    fn maybe_reset_window(&mut self) {
        if self.window_start.elapsed() >= TRACKER_WINDOW {
            self.messages_this_minute = 0;
            self.window_start = std::time::Instant::now();
        }
    }

    /// Counts one message; returns `false` (without counting it) if `limit` is already reached.
    fn record_message(&mut self, limit: u32) -> bool {
        self.maybe_reset_window();
        // `>=` is equivalent to `count + 1 > limit` but cannot overflow at `u32::MAX`.
        if self.messages_this_minute >= limit {
            return false;
        }
        self.messages_this_minute += 1;
        true
    }

    /// Messages counted in the current window (after rolling the window if it expired).
    fn current_usage(&mut self) -> u32 {
        self.maybe_reset_window();
        self.messages_this_minute
    }
}

impl EgressTracker {
    /// Starts an empty window at the current instant.
    fn new() -> Self {
        Self {
            bytes_this_minute: 0,
            window_start: std::time::Instant::now(),
        }
    }

    /// Clears the byte count once `TRACKER_WINDOW` has elapsed since the window started.
    fn maybe_reset_window(&mut self) {
        if self.window_start.elapsed() >= TRACKER_WINDOW {
            self.bytes_this_minute = 0;
            self.window_start = std::time::Instant::now();
        }
    }

    /// Record bytes sent, returning `true` if the new total stays within `limit`.
    ///
    /// On `false` nothing is recorded.
    fn record_bytes(&mut self, bytes: usize, limit: u64) -> bool {
        self.maybe_reset_window();
        // Saturate instead of wrapping: a total that overflows u64 exceeds any limit.
        let bytes = u64::try_from(bytes).unwrap_or(u64::MAX);
        match self.bytes_this_minute.checked_add(bytes) {
            Some(total) if total <= limit => {
                self.bytes_this_minute = total;
                true
            }
            _ => false,
        }
    }

    /// Bytes counted in the current window (after rolling the window if it expired).
    fn current_usage(&mut self) -> u64 {
        self.maybe_reset_window();
        self.bytes_this_minute
    }
}
129
130/// Information about a connected client
131#[derive(Debug)]
132pub struct ClientInfo {
133    pub id: Uuid,
134    pub subscription: Option<Subscription>,
135    pub last_seen: SystemTime,
136    pub sender: mpsc::Sender<Message>,
137    subscriptions: Arc<RwLock<HashMap<String, CancellationToken>>>,
138    /// Authentication context for this client
139    pub auth_context: Option<AuthContext>,
140    /// Client's IP address for rate limiting
141    pub remote_addr: SocketAddr,
142    /// Egress tracking for rate limiting
143    egress_tracker: std::sync::Mutex<EgressTracker>,
144    /// Inbound message-rate tracking for rate limiting
145    message_rate_tracker: std::sync::Mutex<MessageRateTracker>,
146}
147
148impl ClientInfo {
149    pub fn new(
150        id: Uuid,
151        sender: mpsc::Sender<Message>,
152        auth_context: Option<AuthContext>,
153        remote_addr: SocketAddr,
154    ) -> Self {
155        Self {
156            id,
157            subscription: None,
158            last_seen: SystemTime::now(),
159            sender,
160            subscriptions: Arc::new(RwLock::new(HashMap::new())),
161            auth_context,
162            remote_addr,
163            egress_tracker: std::sync::Mutex::new(EgressTracker::new()),
164            message_rate_tracker: std::sync::Mutex::new(MessageRateTracker::new()),
165        }
166    }
167
168    /// Record bytes sent, returning true if within limit
169    pub fn record_egress(&self, bytes: usize) -> Option<u64> {
170        if let Ok(mut tracker) = self.egress_tracker.lock() {
171            if let Some(ref ctx) = self.auth_context {
172                if let Some(limit) = ctx.limits.max_bytes_per_minute {
173                    if tracker.record_bytes(bytes, limit) {
174                        return Some(tracker.current_usage());
175                    } else {
176                        return None; // Limit exceeded
177                    }
178                }
179            }
180            // No limit set, return current usage (0)
181            return Some(tracker.current_usage());
182        }
183        None
184    }
185
186    /// Record an inbound client message, returning true if within limit.
187    pub fn record_inbound_message(&self) -> Option<u32> {
188        if let Ok(mut tracker) = self.message_rate_tracker.lock() {
189            if let Some(ref ctx) = self.auth_context {
190                if let Some(limit) = ctx.limits.max_messages_per_minute {
191                    if tracker.record_message(limit) {
192                        return Some(tracker.current_usage());
193                    } else {
194                        return None;
195                    }
196                }
197            }
198
199            return Some(tracker.current_usage());
200        }
201
202        None
203    }
204
205    pub fn update_last_seen(&mut self) {
206        self.last_seen = SystemTime::now();
207    }
208
209    pub fn is_stale(&self, timeout: Duration) -> bool {
210        self.last_seen.elapsed().unwrap_or(Duration::MAX) > timeout
211    }
212
213    pub async fn add_subscription(&self, sub_key: String, token: CancellationToken) -> bool {
214        let mut subs = self.subscriptions.write().await;
215        if let Some(old_token) = subs.insert(sub_key.clone(), token) {
216            old_token.cancel();
217            debug!("Replaced existing subscription: {}", sub_key);
218            false
219        } else {
220            true
221        }
222    }
223
224    pub async fn remove_subscription(&self, sub_key: &str) -> bool {
225        let mut subs = self.subscriptions.write().await;
226        if let Some(token) = subs.remove(sub_key) {
227            token.cancel();
228            debug!("Cancelled subscription: {}", sub_key);
229            true
230        } else {
231            debug!("Subscription not found for cancellation: {}", sub_key);
232            false
233        }
234    }
235
236    pub async fn cancel_all_subscriptions(&self) {
237        let subs = self.subscriptions.read().await;
238        for (sub_key, token) in subs.iter() {
239            token.cancel();
240            debug!("Cancelled subscription on disconnect: {}", sub_key);
241        }
242    }
243
244    pub async fn subscription_count(&self) -> usize {
245        self.subscriptions.read().await.len()
246    }
247}
248
249/// Configuration for rate limiting in ClientManager
250///
251/// These settings control various rate limits at the connection level.
252/// Per-subject limits are controlled via AuthContext.Limits.
253#[derive(Debug, Clone)]
254pub struct RateLimitConfig {
255    /// Global maximum connections per IP address
256    pub max_connections_per_ip: Option<usize>,
257    /// Global maximum connections per metering key
258    pub max_connections_per_metering_key: Option<usize>,
259    /// Global maximum connections per origin
260    pub max_connections_per_origin: Option<usize>,
261    /// Default connection timeout for stale client cleanup
262    pub client_timeout: Duration,
263    /// Message queue size per client
264    pub message_queue_size: usize,
265    /// Maximum reconnect attempts per client (optional global default)
266    pub max_reconnect_attempts: Option<u32>,
267    /// Rate limit window duration for message counting
268    pub message_rate_window: Duration,
269    /// Rate limit window duration for egress tracking
270    pub egress_rate_window: Duration,
271    /// Default limits applied when auth token doesn't specify limits
272    /// These act as server-wide fallback limits for all connections
273    pub default_limits: Option<Limits>,
274}
275
276impl Default for RateLimitConfig {
277    fn default() -> Self {
278        Self {
279            max_connections_per_ip: None,
280            max_connections_per_metering_key: None,
281            max_connections_per_origin: None,
282            client_timeout: Duration::from_secs(300),
283            message_queue_size: 512,
284            max_reconnect_attempts: None,
285            message_rate_window: Duration::from_secs(60),
286            egress_rate_window: Duration::from_secs(60),
287            default_limits: None,
288        }
289    }
290}
291
292impl RateLimitConfig {
293    /// Load configuration from environment variables
294    ///
295    /// Environment variables:
296    /// - `HYPERSTACK_WS_MAX_CONNECTIONS_PER_IP` - Max connections per IP (default: unlimited)
297    /// - `HYPERSTACK_WS_MAX_CONNECTIONS_PER_METERING_KEY` - Max connections per metering key (default: unlimited)
298    /// - `HYPERSTACK_WS_MAX_CONNECTIONS_PER_ORIGIN` - Max connections per origin (default: unlimited)
299    /// - `HYPERSTACK_WS_CLIENT_TIMEOUT_SECS` - Client timeout in seconds (default: 300)
300    /// - `HYPERSTACK_WS_MESSAGE_QUEUE_SIZE` - Message queue size per client (default: 512)
301    /// - `HYPERSTACK_WS_RATE_LIMIT_WINDOW_SECS` - Rate limit window in seconds (default: 60)
302    /// - `HYPERSTACK_WS_DEFAULT_MAX_CONNECTIONS` - Default max connections per subject (fallback when token has no limit)
303    /// - `HYPERSTACK_WS_DEFAULT_MAX_SUBSCRIPTIONS` - Default max subscriptions per connection (fallback when token has no limit)
304    /// - `HYPERSTACK_WS_DEFAULT_MAX_SNAPSHOT_ROWS` - Default max snapshot rows per request (fallback when token has no limit)
305    /// - `HYPERSTACK_WS_DEFAULT_MAX_MESSAGES_PER_MINUTE` - Default max messages per minute (fallback when token has no limit)
306    /// - `HYPERSTACK_WS_DEFAULT_MAX_BYTES_PER_MINUTE` - Default max bytes per minute (fallback when token has no limit)
307    pub fn from_env() -> Self {
308        let mut config = Self::default();
309
310        if let Ok(val) = std::env::var("HYPERSTACK_WS_MAX_CONNECTIONS_PER_IP") {
311            if let Ok(max) = val.parse() {
312                config.max_connections_per_ip = Some(max);
313            }
314        }
315
316        if let Ok(val) = std::env::var("HYPERSTACK_WS_MAX_CONNECTIONS_PER_METERING_KEY") {
317            if let Ok(max) = val.parse() {
318                config.max_connections_per_metering_key = Some(max);
319            }
320        }
321
322        if let Ok(val) = std::env::var("HYPERSTACK_WS_MAX_CONNECTIONS_PER_ORIGIN") {
323            if let Ok(max) = val.parse() {
324                config.max_connections_per_origin = Some(max);
325            }
326        }
327
328        if let Ok(val) = std::env::var("HYPERSTACK_WS_CLIENT_TIMEOUT_SECS") {
329            if let Ok(secs) = val.parse() {
330                config.client_timeout = Duration::from_secs(secs);
331            }
332        }
333
334        if let Ok(val) = std::env::var("HYPERSTACK_WS_MESSAGE_QUEUE_SIZE") {
335            if let Ok(size) = val.parse() {
336                config.message_queue_size = size;
337            }
338        }
339
340        if let Ok(val) = std::env::var("HYPERSTACK_WS_RATE_LIMIT_WINDOW_SECS") {
341            if let Ok(secs) = val.parse() {
342                config.message_rate_window = Duration::from_secs(secs);
343                config.egress_rate_window = Duration::from_secs(secs);
344            }
345        }
346
347        // Load default limits from environment (fallback when auth token doesn't specify limits)
348        let mut default_limits = Limits::default();
349        let mut has_default_limits = false;
350
351        if let Ok(val) = std::env::var("HYPERSTACK_WS_DEFAULT_MAX_CONNECTIONS") {
352            if let Ok(max) = val.parse() {
353                default_limits.max_connections = Some(max);
354                has_default_limits = true;
355            }
356        }
357
358        if let Ok(val) = std::env::var("HYPERSTACK_WS_DEFAULT_MAX_SUBSCRIPTIONS") {
359            if let Ok(max) = val.parse() {
360                default_limits.max_subscriptions = Some(max);
361                has_default_limits = true;
362            }
363        }
364
365        if let Ok(val) = std::env::var("HYPERSTACK_WS_DEFAULT_MAX_SNAPSHOT_ROWS") {
366            if let Ok(max) = val.parse() {
367                default_limits.max_snapshot_rows = Some(max);
368                has_default_limits = true;
369            }
370        }
371
372        if let Ok(val) = std::env::var("HYPERSTACK_WS_DEFAULT_MAX_MESSAGES_PER_MINUTE") {
373            if let Ok(max) = val.parse() {
374                default_limits.max_messages_per_minute = Some(max);
375                has_default_limits = true;
376            }
377        }
378
379        if let Ok(val) = std::env::var("HYPERSTACK_WS_DEFAULT_MAX_BYTES_PER_MINUTE") {
380            if let Ok(max) = val.parse() {
381                default_limits.max_bytes_per_minute = Some(max);
382                has_default_limits = true;
383            }
384        }
385
386        if has_default_limits {
387            config.default_limits = Some(default_limits);
388        }
389
390        config
391    }
392
393    /// Set the maximum connections per IP
394    pub fn with_max_connections_per_ip(mut self, max: usize) -> Self {
395        self.max_connections_per_ip = Some(max);
396        self
397    }
398
399    /// Set the client timeout
400    pub fn with_timeout(mut self, timeout: Duration) -> Self {
401        self.client_timeout = timeout;
402        self
403    }
404
405    /// Set the message queue size
406    pub fn with_message_queue_size(mut self, size: usize) -> Self {
407        self.message_queue_size = size;
408        self
409    }
410
411    /// Set the rate limit window (applies to both message and egress windows)
412    pub fn with_rate_limit_window(mut self, window: Duration) -> Self {
413        self.message_rate_window = window;
414        self.egress_rate_window = window;
415        self
416    }
417
418    /// Set default limits applied when auth token doesn't specify limits
419    ///
420    /// These limits act as server-wide fallbacks for connections
421    /// where the authentication token doesn't include explicit limits.
422    pub fn with_default_limits(mut self, limits: Limits) -> Self {
423        self.default_limits = Some(limits);
424        self
425    }
426}
427
428/// Manages all connected WebSocket clients using lock-free DashMap.
429///
430/// Key design decisions:
431/// - Uses DashMap for lock-free concurrent access to client registry
432/// - Uses try_send instead of send to never block on slow clients
433/// - Disconnects clients that are backpressured (queue full) to prevent cascade failures
434/// - All public methods are non-blocking or use fine-grained per-key locks
435/// - Supports configurable rate limiting per IP, subject, and global defaults
436#[derive(Clone)]
437pub struct ClientManager {
438    clients: Arc<DashMap<Uuid, ClientInfo>>,
439    rate_limit_config: RateLimitConfig,
440    /// Optional WebSocket rate limiter for granular rate control
441    rate_limiter: Option<Arc<WebSocketRateLimiter>>,
442}
443
444impl ClientManager {
445    pub fn new() -> Self {
446        Self::with_config(RateLimitConfig::default())
447    }
448
449    /// Create a new ClientManager with the given rate limit configuration
450    pub fn with_config(config: RateLimitConfig) -> Self {
451        Self {
452            clients: Arc::new(DashMap::new()),
453            rate_limit_config: config,
454            rate_limiter: None,
455        }
456    }
457
458    /// Load configuration from environment variables
459    ///
460    /// See `RateLimitConfig::from_env` for supported variables.
461    pub fn from_env() -> Self {
462        Self::with_config(RateLimitConfig::from_env())
463    }
464
465    /// Set the client timeout for stale client cleanup
466    pub fn with_timeout(mut self, timeout: Duration) -> Self {
467        self.rate_limit_config.client_timeout = timeout;
468        self
469    }
470
471    /// Set the message queue size per client
472    pub fn with_message_queue_size(mut self, queue_size: usize) -> Self {
473        self.rate_limit_config.message_queue_size = queue_size;
474        self
475    }
476
477    /// Set a global limit on connections per IP address
478    pub fn with_max_connections_per_ip(mut self, max: usize) -> Self {
479        self.rate_limit_config.max_connections_per_ip = Some(max);
480        self
481    }
482
483    /// Set the rate limit window duration
484    pub fn with_rate_limit_window(mut self, window: Duration) -> Self {
485        self.rate_limit_config.message_rate_window = window;
486        self.rate_limit_config.egress_rate_window = window;
487        self
488    }
489
490    /// Set default limits applied when auth token doesn't specify limits
491    ///
492    /// These limits act as server-wide fallbacks for connections
493    /// where the authentication token doesn't include explicit limits.
494    pub fn with_default_limits(mut self, limits: Limits) -> Self {
495        self.rate_limit_config.default_limits = Some(limits);
496        self
497    }
498
499    /// Set a WebSocket rate limiter for granular rate control
500    pub fn with_rate_limiter(mut self, rate_limiter: Arc<WebSocketRateLimiter>) -> Self {
501        self.rate_limiter = Some(rate_limiter);
502        self
503    }
504
505    /// Get the rate limiter if configured
506    pub fn rate_limiter(&self) -> Option<&WebSocketRateLimiter> {
507        self.rate_limiter.as_ref().map(|r| r.as_ref())
508    }
509
510    /// Get the current rate limit configuration
511    pub fn rate_limit_config(&self) -> &RateLimitConfig {
512        &self.rate_limit_config
513    }
514
515    /// Add a new client connection.
516    ///
517    /// Spawns a dedicated sender task for this client that reads from its mpsc channel
518    /// and writes to the WebSocket. If the WebSocket write fails, the client is automatically
519    /// removed from the registry.
520    pub fn add_client(
521        &self,
522        client_id: Uuid,
523        mut ws_sender: WebSocketSender,
524        auth_context: Option<AuthContext>,
525        remote_addr: SocketAddr,
526    ) {
527        let (client_tx, mut client_rx) =
528            mpsc::channel::<Message>(self.rate_limit_config.message_queue_size);
529        let client_info = ClientInfo::new(client_id, client_tx, auth_context, remote_addr);
530
531        let clients_ref = self.clients.clone();
532        tokio::spawn(async move {
533            while let Some(message) = client_rx.recv().await {
534                if let Err(e) = ws_sender.send(message).await {
535                    warn!("Failed to send message to client {}: {}", client_id, e);
536                    break;
537                }
538            }
539            clients_ref.remove(&client_id);
540            debug!("WebSocket sender task for client {} stopped", client_id);
541        });
542
543        self.clients.insert(client_id, client_info);
544        info!("Client {} registered from {}", client_id, remote_addr);
545    }
546
547    /// Remove a client from the registry.
548    pub fn remove_client(&self, client_id: Uuid) {
549        if self.clients.remove(&client_id).is_some() {
550            info!("Client {} removed", client_id);
551        }
552    }
553
554    /// Update the auth context for a client.
555    ///
556    /// Used for in-band auth refresh without reconnecting.
557    pub fn update_client_auth(&self, client_id: Uuid, auth_context: AuthContext) -> bool {
558        if let Some(mut client) = self.clients.get_mut(&client_id) {
559            client.auth_context = Some(auth_context);
560            debug!("Updated auth context for client {}", client_id);
561            true
562        } else {
563            false
564        }
565    }
566
567    /// Check if a client's token has expired.
568    ///
569    /// Returns true if the client has an auth context and it has expired.
570    /// If expired, the client is removed from the registry.
571    pub fn check_and_remove_expired(&self, client_id: Uuid) -> bool {
572        if let Some(client) = self.clients.get(&client_id) {
573            if let Some(ref ctx) = client.auth_context {
574                let now = std::time::SystemTime::now()
575                    .duration_since(std::time::UNIX_EPOCH)
576                    .unwrap_or_default()
577                    .as_secs();
578                if ctx.expires_at <= now {
579                    warn!(
580                        "Client {} token expired (expired at {}), disconnecting",
581                        client_id, ctx.expires_at
582                    );
583                    self.clients.remove(&client_id);
584                    return true;
585                }
586            }
587        }
588        false
589    }
590
591    /// Get the current number of connected clients.
592    ///
593    /// This is lock-free and returns an approximate count (may be slightly stale
594    /// under high concurrency, which is fine for max_clients checks).
595    pub fn client_count(&self) -> usize {
596        self.clients.len()
597    }
598
599    /// Send data to a specific client (non-blocking).
600    ///
601    /// This method NEVER blocks. If the client's queue is full, the client is
602    /// considered too slow and is disconnected to prevent cascade failures.
603    /// Use this for live streaming updates.
604    ///
605    /// For initial snapshots where you expect to send many messages at once,
606    /// use `send_to_client_async` instead which will wait for queue space.
607    pub fn send_to_client(&self, client_id: Uuid, data: Arc<Bytes>) -> Result<(), SendError> {
608        // Check if client token has expired before sending
609        if self.check_and_remove_expired(client_id) {
610            return Err(SendError::ClientDisconnected);
611        }
612
613        // Check egress limits
614        if let Some(client) = self.clients.get(&client_id) {
615            if client.record_egress(data.len()).is_none() {
616                warn!("Client {} exceeded egress limit, disconnecting", client_id);
617                self.clients.remove(&client_id);
618                return Err(SendError::ClientDisconnected);
619            }
620        } else {
621            return Err(SendError::ClientNotFound);
622        }
623
624        let sender = {
625            let client = self
626                .clients
627                .get(&client_id)
628                .ok_or(SendError::ClientNotFound)?;
629            client.sender.clone()
630        };
631
632        let msg = Message::Binary((*data).clone());
633        match sender.try_send(msg) {
634            Ok(()) => Ok(()),
635            Err(mpsc::error::TrySendError::Full(_)) => {
636                warn!(
637                    "Client {} backpressured (queue full), disconnecting",
638                    client_id
639                );
640                self.clients.remove(&client_id);
641                Err(SendError::ClientBackpressured)
642            }
643            Err(mpsc::error::TrySendError::Closed(_)) => {
644                debug!("Client {} channel closed", client_id);
645                self.clients.remove(&client_id);
646                Err(SendError::ClientDisconnected)
647            }
648        }
649    }
650
651    /// Send data to a specific client (async, waits for queue space).
652    ///
653    /// This method will wait if the client's queue is full, allowing the client
654    /// time to catch up. Use this for initial snapshots where you need to send
655    /// many messages at once.
656    ///
657    /// For live streaming updates, use `send_to_client` instead which will
658    /// disconnect slow clients rather than blocking.
659    pub async fn send_to_client_async(
660        &self,
661        client_id: Uuid,
662        data: Arc<Bytes>,
663    ) -> Result<(), SendError> {
664        // Check if client token has expired before sending
665        if self.check_and_remove_expired(client_id) {
666            return Err(SendError::ClientDisconnected);
667        }
668
669        // Check egress limits
670        if let Some(client) = self.clients.get(&client_id) {
671            if client.record_egress(data.len()).is_none() {
672                warn!("Client {} exceeded egress limit, disconnecting", client_id);
673                self.clients.remove(&client_id);
674                return Err(SendError::ClientDisconnected);
675            }
676        } else {
677            return Err(SendError::ClientNotFound);
678        }
679
680        let sender = {
681            let client = self
682                .clients
683                .get(&client_id)
684                .ok_or(SendError::ClientNotFound)?;
685            client.sender.clone()
686        };
687
688        let msg = Message::Binary((*data).clone());
689        sender
690            .send(msg)
691            .await
692            .map_err(|_| SendError::ClientDisconnected)
693    }
694
695    /// Send a text message to a specific client (async).
696    ///
697    /// This method sends a text message directly to the client's WebSocket.
698    /// Used for control messages like auth refresh responses.
699    pub async fn send_text_to_client(
700        &self,
701        client_id: Uuid,
702        text: String,
703    ) -> Result<(), SendError> {
704        // Check if client token has expired before sending
705        if self.check_and_remove_expired(client_id) {
706            return Err(SendError::ClientDisconnected);
707        }
708
709        let sender = {
710            let client = self
711                .clients
712                .get(&client_id)
713                .ok_or(SendError::ClientNotFound)?;
714            client.sender.clone()
715        };
716
717        let msg = Message::Text(text.into());
718        sender
719            .send(msg)
720            .await
721            .map_err(|_| SendError::ClientDisconnected)
722    }
723
724    /// Send a potentially compressed payload to a client (async).
725    ///
726    /// Compressed payloads are sent as binary frames (raw gzip).
727    /// Uncompressed payloads are sent as text frames (JSON).
728    pub async fn send_compressed_async(
729        &self,
730        client_id: Uuid,
731        payload: CompressedPayload,
732    ) -> Result<(), SendError> {
733        // Check if client token has expired before sending
734        if self.check_and_remove_expired(client_id) {
735            return Err(SendError::ClientDisconnected);
736        }
737
738        let (sender, bytes_to_record) = {
739            let client = self
740                .clients
741                .get(&client_id)
742                .ok_or(SendError::ClientNotFound)?;
743
744            let bytes = match &payload {
745                CompressedPayload::Compressed(bytes) => bytes.len(),
746                CompressedPayload::Uncompressed(bytes) => bytes.len(),
747            };
748
749            (client.sender.clone(), bytes)
750        };
751
752        // Check egress limits
753        if let Some(client) = self.clients.get(&client_id) {
754            if client.record_egress(bytes_to_record).is_none() {
755                warn!("Client {} exceeded egress limit, disconnecting", client_id);
756                self.clients.remove(&client_id);
757                return Err(SendError::ClientDisconnected);
758            }
759        }
760
761        let msg = match payload {
762            CompressedPayload::Compressed(bytes) => Message::Binary(bytes),
763            CompressedPayload::Uncompressed(bytes) => Message::Binary(bytes),
764        };
765        sender
766            .send(msg)
767            .await
768            .map_err(|_| SendError::ClientDisconnected)
769    }
770
771    /// Update the subscription for a client.
772    pub fn update_subscription(&self, client_id: Uuid, subscription: Subscription) -> bool {
773        if let Some(mut client) = self.clients.get_mut(&client_id) {
774            client.subscription = Some(subscription);
775            client.update_last_seen();
776            debug!("Updated subscription for client {}", client_id);
777            true
778        } else {
779            warn!(
780                "Failed to update subscription for unknown client {}",
781                client_id
782            );
783            false
784        }
785    }
786
787    /// Update the last_seen timestamp for a client.
788    pub fn update_client_last_seen(&self, client_id: Uuid) {
789        if let Some(mut client) = self.clients.get_mut(&client_id) {
790            client.update_last_seen();
791        }
792    }
793
794    /// Check whether an inbound message is allowed for a client.
795    #[allow(clippy::result_large_err)]
796    pub fn check_inbound_message_allowed(&self, client_id: Uuid) -> Result<(), AuthDeny> {
797        if self.check_and_remove_expired(client_id) {
798            return Err(AuthDeny::new(
799                crate::websocket::auth::AuthErrorCode::TokenExpired,
800                "Authentication token expired",
801            ));
802        }
803
804        let Some(client) = self.clients.get(&client_id) else {
805            return Err(AuthDeny::new(
806                crate::websocket::auth::AuthErrorCode::InternalError,
807                "Client not found",
808            ));
809        };
810
811        if client.record_inbound_message().is_some() {
812            Ok(())
813        } else {
814            self.clients.remove(&client_id);
815            Err(AuthDeny::rate_limited(
816                self.rate_limit_config.message_rate_window,
817                "inbound websocket messages",
818            )
819            .with_context(format!(
820                "client {} exceeded the inbound message budget",
821                client_id
822            )))
823        }
824    }
825
826    /// Get the subscription for a client.
827    pub fn get_subscription(&self, client_id: Uuid) -> Option<Subscription> {
828        self.clients
829            .get(&client_id)
830            .and_then(|c| c.subscription.clone())
831    }
832
833    /// Check if a client exists.
834    pub fn has_client(&self, client_id: Uuid) -> bool {
835        self.clients.contains_key(&client_id)
836    }
837
838    pub async fn add_client_subscription(
839        &self,
840        client_id: Uuid,
841        sub_key: String,
842        token: CancellationToken,
843    ) -> bool {
844        if let Some(client) = self.clients.get(&client_id) {
845            client.add_subscription(sub_key, token).await
846        } else {
847            false
848        }
849    }
850
851    pub async fn remove_client_subscription(&self, client_id: Uuid, sub_key: &str) -> bool {
852        if let Some(client) = self.clients.get(&client_id) {
853            client.remove_subscription(sub_key).await
854        } else {
855            false
856        }
857    }
858
859    pub async fn cancel_all_client_subscriptions(&self, client_id: Uuid) {
860        if let Some(client) = self.clients.get(&client_id) {
861            client.cancel_all_subscriptions().await;
862        }
863    }
864
865    /// Remove stale clients that haven't been seen within the timeout period.
866    pub fn cleanup_stale_clients(&self) -> usize {
867        let timeout = self.rate_limit_config.client_timeout;
868        let mut stale_clients = Vec::new();
869
870        for entry in self.clients.iter() {
871            if entry.value().is_stale(timeout) {
872                stale_clients.push(*entry.key());
873            }
874        }
875
876        let removed_count = stale_clients.len();
877        for client_id in stale_clients {
878            self.clients.remove(&client_id);
879            info!("Removed stale client {}", client_id);
880        }
881
882        removed_count
883    }
884
885    /// Start a background task that periodically cleans up stale clients.
886    pub fn start_cleanup_task(&self) {
887        let client_manager = self.clone();
888
889        tokio::spawn(async move {
890            let mut interval = tokio::time::interval(Duration::from_secs(30));
891
892            loop {
893                interval.tick().await;
894                let removed = client_manager.cleanup_stale_clients();
895                if removed > 0 {
896                    info!("Cleaned up {} stale clients", removed);
897                }
898            }
899        });
900    }
901
902    /// ENFORCEMENT HOOKS
903    ///
904    /// These methods provide hooks for enforcing limits based on auth context.
905    /// They check limits before allowing operations and return errors if limits are exceeded.
906    /// Check if a connection is allowed for the given auth context.
907    ///
908    /// Returns Ok(()) if the connection is allowed, or an error with a reason if not.
909    pub async fn check_connection_allowed(
910        &self,
911        remote_addr: SocketAddr,
912        auth_context: &Option<AuthContext>,
913    ) -> Result<(), AuthDeny> {
914        // Check rate limiter first if configured
915        if let Some(ref rate_limiter) = self.rate_limiter {
916            // Check handshake rate limit for IP
917            match rate_limiter.check_handshake(remote_addr).await {
918                RateLimitResult::Allowed { .. } => {}
919                RateLimitResult::Denied { retry_after, limit } => {
920                    return Err(AuthDeny::rate_limited(retry_after, "websocket handshakes")
921                        .with_context(format!(
922                            "handshake rate limit of {} per minute exceeded for {}",
923                            limit, remote_addr
924                        )));
925                }
926            }
927
928            // Check connection rate limit for subject
929            if let Some(ref ctx) = auth_context {
930                match rate_limiter
931                    .check_connection_for_subject(&ctx.subject)
932                    .await
933                {
934                    RateLimitResult::Allowed { .. } => {}
935                    RateLimitResult::Denied { retry_after, limit } => {
936                        return Err(AuthDeny::rate_limited(retry_after, "websocket connections")
937                            .with_context(format!(
938                                "connection rate limit for subject {} of {} per minute exceeded",
939                                ctx.subject, limit
940                            )));
941                    }
942                }
943
944                // Check connection rate limit for metering key
945                match rate_limiter
946                    .check_connection_for_metering_key(&ctx.metering_key)
947                    .await
948                {
949                    RateLimitResult::Allowed { .. } => {}
950                    RateLimitResult::Denied { retry_after, limit } => {
951                        return Err(AuthDeny::rate_limited(
952                            retry_after,
953                            "metered websocket connections",
954                        )
955                        .with_context(format!(
956                            "connection rate limit for metering key {} of {} per minute exceeded",
957                            ctx.metering_key, limit
958                        )));
959                    }
960                }
961            }
962        }
963
964        // Check global per-IP connection limit
965        if let Some(max_per_ip) = self.rate_limit_config.max_connections_per_ip {
966            let current_ip_connections = self.count_connections_for_ip(&remote_addr);
967            if current_ip_connections >= max_per_ip {
968                return Err(AuthDeny::connection_limit_exceeded(
969                    &format!("ip {}", remote_addr.ip()),
970                    current_ip_connections,
971                    max_per_ip,
972                ));
973            }
974        }
975
976        if let Some(ctx) = auth_context {
977            // Check max connections per subject (use token limits, fallback to default limits)
978            let max_connections = ctx.limits.max_connections.or_else(|| {
979                self.rate_limit_config
980                    .default_limits
981                    .as_ref()
982                    .and_then(|l| l.max_connections)
983            });
984            if let Some(max_connections) = max_connections {
985                let current_connections = self.count_connections_for_subject(&ctx.subject);
986                if current_connections >= max_connections as usize {
987                    return Err(AuthDeny::connection_limit_exceeded(
988                        &format!("subject {}", ctx.subject),
989                        current_connections,
990                        max_connections as usize,
991                    ));
992                }
993            }
994
995            // Check global max connections per metering key
996            if let Some(max_per_metering_key) =
997                self.rate_limit_config.max_connections_per_metering_key
998            {
999                let current_metering_connections =
1000                    self.count_connections_for_metering_key(&ctx.metering_key);
1001                if current_metering_connections >= max_per_metering_key {
1002                    return Err(AuthDeny::connection_limit_exceeded(
1003                        &format!("metering key {}", ctx.metering_key),
1004                        current_metering_connections,
1005                        max_per_metering_key,
1006                    ));
1007                }
1008            }
1009
1010            // Check global max connections per origin
1011            if let Some(max_per_origin) = self.rate_limit_config.max_connections_per_origin {
1012                if let Some(ref origin) = ctx.origin {
1013                    let current_origin_connections = self.count_connections_for_origin(origin);
1014                    if current_origin_connections >= max_per_origin {
1015                        return Err(AuthDeny::connection_limit_exceeded(
1016                            &format!("origin {}", origin),
1017                            current_origin_connections,
1018                            max_per_origin,
1019                        ));
1020                    }
1021                }
1022            }
1023        }
1024        Ok(())
1025    }
1026
1027    /// Count connections from a specific IP address
1028    fn count_connections_for_ip(&self, remote_addr: &SocketAddr) -> usize {
1029        let ip = remote_addr.ip();
1030        self.clients
1031            .iter()
1032            .filter(|entry| entry.value().remote_addr.ip() == ip)
1033            .count()
1034    }
1035
1036    /// Count connections for a specific subject
1037    fn count_connections_for_subject(&self, subject: &str) -> usize {
1038        self.clients
1039            .iter()
1040            .filter(|entry| {
1041                entry
1042                    .value()
1043                    .auth_context
1044                    .as_ref()
1045                    .map(|ctx| ctx.subject == subject)
1046                    .unwrap_or(false)
1047            })
1048            .count()
1049    }
1050
1051    /// Count connections for a specific metering key
1052    fn count_connections_for_metering_key(&self, metering_key: &str) -> usize {
1053        self.clients
1054            .iter()
1055            .filter(|entry| {
1056                entry
1057                    .value()
1058                    .auth_context
1059                    .as_ref()
1060                    .map(|ctx| ctx.metering_key == metering_key)
1061                    .unwrap_or(false)
1062            })
1063            .count()
1064    }
1065
1066    /// Count connections for a specific origin
1067    fn count_connections_for_origin(&self, origin: &str) -> usize {
1068        self.clients
1069            .iter()
1070            .filter(|entry| {
1071                entry
1072                    .value()
1073                    .auth_context
1074                    .as_ref()
1075                    .and_then(|ctx| ctx.origin.as_ref())
1076                    .map(|o| o == origin)
1077                    .unwrap_or(false)
1078            })
1079            .count()
1080    }
1081
1082    /// Check if a subscription is allowed for the given client.
1083    ///
1084    /// Returns Ok(()) if the subscription is allowed, or an error with a reason if not.
1085    pub async fn check_subscription_allowed(&self, client_id: Uuid) -> Result<(), AuthDeny> {
1086        if let Some(client) = self.clients.get(&client_id) {
1087            let current_subs = client.subscription_count().await;
1088
1089            // Check max subscriptions per connection (use token limits, fallback to default limits)
1090            if let Some(ref ctx) = client.auth_context {
1091                let max_subs = ctx.limits.max_subscriptions.or_else(|| {
1092                    self.rate_limit_config
1093                        .default_limits
1094                        .as_ref()
1095                        .and_then(|l| l.max_subscriptions)
1096                });
1097                if let Some(max_subs) = max_subs {
1098                    if current_subs >= max_subs as usize {
1099                        return Err(AuthDeny::new(
1100                            crate::websocket::auth::AuthErrorCode::SubscriptionLimitExceeded,
1101                            format!(
1102                                "Subscription limit exceeded: {} of {} subscriptions for client {}",
1103                                current_subs, max_subs, client_id
1104                            ),
1105                        )
1106                        .with_suggested_action(
1107                            "Unsubscribe from an existing view before creating another subscription",
1108                        ));
1109                    }
1110                }
1111            }
1112        }
1113        Ok(())
1114    }
1115
1116    /// Get metering key for a client
1117    pub fn get_metering_key(&self, client_id: Uuid) -> Option<String> {
1118        self.clients.get(&client_id).and_then(|client| {
1119            client
1120                .auth_context
1121                .as_ref()
1122                .map(|ctx| ctx.metering_key.clone())
1123        })
1124    }
1125
1126    /// Get auth context for a client.
1127    pub fn get_auth_context(&self, client_id: Uuid) -> Option<AuthContext> {
1128        self.clients
1129            .get(&client_id)
1130            .and_then(|client| client.auth_context.clone())
1131    }
1132
1133    /// Check if a snapshot request is allowed (based on max_snapshot_rows limit)
1134    ///
1135    /// Uses token limits if available, falls back to default limits from RateLimitConfig.
1136    #[allow(clippy::result_large_err)]
1137    pub fn check_snapshot_allowed(
1138        &self,
1139        client_id: Uuid,
1140        requested_rows: u32,
1141    ) -> Result<(), AuthDeny> {
1142        if let Some(client) = self.clients.get(&client_id) {
1143            if let Some(ref ctx) = client.auth_context {
1144                let max_rows = ctx.limits.max_snapshot_rows.or_else(|| {
1145                    self.rate_limit_config
1146                        .default_limits
1147                        .as_ref()
1148                        .and_then(|l| l.max_snapshot_rows)
1149                });
1150                if let Some(max_rows) = max_rows {
1151                    if requested_rows > max_rows {
1152                        return Err(AuthDeny::new(
1153                            crate::websocket::auth::AuthErrorCode::SnapshotLimitExceeded,
1154                            format!(
1155                                "Snapshot limit exceeded: requested {} rows, max allowed is {} for client {}",
1156                                requested_rows, max_rows, client_id
1157                            ),
1158                        )
1159                        .with_suggested_action(
1160                            "Request fewer rows or lower the snapshotLimit on the subscription",
1161                        ));
1162                    }
1163                }
1164            }
1165        }
1166        Ok(())
1167    }
1168}
1169
1170impl Default for ClientManager {
1171    fn default() -> Self {
1172        Self::new()
1173    }
1174}
1175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::websocket::auth::AuthContext;
    use hyperstack_auth::{KeyClass, Limits};
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    fn create_test_auth_context(subject: &str, limits: Limits) -> AuthContext {
        AuthContext {
            subject: subject.to_string(),
            issuer: "test-issuer".to_string(),
            key_class: KeyClass::Publishable,
            metering_key: format!("meter-{}", subject),
            deployment_id: None,
            expires_at: u64::MAX,
            scope: "read".to_string(),
            limits,
            plan: None,
            origin: None,
            client_ip: None,
            jti: uuid::Uuid::new_v4().to_string(),
        }
    }

    fn create_test_socket_addr(ip: &str) -> SocketAddr {
        SocketAddr::new(
            ip.parse::<IpAddr>()
                .unwrap_or(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))),
            12345,
        )
    }

    /// Insert a client directly into the manager's registry, bypassing the WebSocket
    /// handshake, and return its id. The receiving half of the client's channel is
    /// dropped immediately; these tests never send to the client.
    fn register_test_client(
        manager: &ClientManager,
        auth_context: Option<AuthContext>,
        remote_addr: SocketAddr,
    ) -> Uuid {
        let client_id = Uuid::new_v4();
        let (tx, _rx) = mpsc::channel(1);
        manager.clients.insert(
            client_id,
            ClientInfo::new(client_id, tx, auth_context, remote_addr),
        );
        client_id
    }

    #[test]
    fn test_egress_tracker_basic() {
        let mut tracker = EgressTracker::new();

        // Should allow bytes within limit
        assert!(tracker.record_bytes(500, 1000));
        assert_eq!(tracker.current_usage(), 500);

        // Should allow more bytes within limit
        assert!(tracker.record_bytes(400, 1000));
        assert_eq!(tracker.current_usage(), 900);

        // Should reject bytes over limit
        assert!(!tracker.record_bytes(200, 1000));
        assert_eq!(tracker.current_usage(), 900); // Usage shouldn't increase
    }

    #[test]
    fn test_egress_tracker_window_reset() {
        let mut tracker = EgressTracker::new();

        // Use up the limit
        assert!(tracker.record_bytes(100, 100));
        assert!(!tracker.record_bytes(1, 100));

        // Reset the window
        tracker.bytes_this_minute = 0;
        tracker.window_start = SystemTime::now() - Duration::from_secs(61);

        // Should allow after window reset
        assert!(tracker.record_bytes(50, 100));
    }

    #[test]
    fn test_message_rate_tracker_basic() {
        let mut tracker = MessageRateTracker::new();

        assert!(tracker.record_message(2));
        assert_eq!(tracker.current_usage(), 1);

        assert!(tracker.record_message(2));
        assert_eq!(tracker.current_usage(), 2);

        assert!(!tracker.record_message(2));
        assert_eq!(tracker.current_usage(), 2);
    }

    #[tokio::test]
    async fn test_client_inbound_message_limit() {
        let (tx, _rx) = mpsc::channel(1);
        let client = ClientInfo::new(
            Uuid::new_v4(),
            tx,
            Some(create_test_auth_context(
                "user-1",
                Limits {
                    max_messages_per_minute: Some(2),
                    ..Default::default()
                },
            )),
            create_test_socket_addr("127.0.0.1"),
        );

        assert_eq!(client.record_inbound_message(), Some(1));
        assert_eq!(client.record_inbound_message(), Some(2));
        assert_eq!(client.record_inbound_message(), None);
    }

    #[tokio::test]
    async fn inbound_message_check_within_budget() {
        let manager = ClientManager::new();
        let auth = create_test_auth_context(
            "user-1",
            Limits {
                max_messages_per_minute: Some(2),
                ..Default::default()
            },
        );
        let client_id = register_test_client(&manager, Some(auth), create_test_socket_addr("127.0.0.1"));

        // Messages within the budget are accepted and the client stays registered.
        assert!(manager.check_inbound_message_allowed(client_id).is_ok());
        assert!(manager.check_inbound_message_allowed(client_id).is_ok());
        assert!(manager.has_client(client_id));

        // Unknown clients are rejected.
        assert!(manager.check_inbound_message_allowed(Uuid::new_v4()).is_err());
    }

    #[tokio::test]
    async fn client_lookup_helpers() {
        let manager = ClientManager::new();
        let auth = create_test_auth_context("user-1", Limits::default());
        let client_id = register_test_client(&manager, Some(auth), create_test_socket_addr("127.0.0.1"));

        assert!(manager.has_client(client_id));
        assert_eq!(
            manager.get_metering_key(client_id).as_deref(),
            Some("meter-user-1")
        );
        assert_eq!(
            manager.get_auth_context(client_id).map(|ctx| ctx.subject),
            Some("user-1".to_string())
        );

        let unknown = Uuid::new_v4();
        assert!(!manager.has_client(unknown));
        assert!(manager.get_metering_key(unknown).is_none());
        assert!(manager.get_auth_context(unknown).is_none());
    }

    #[tokio::test]
    async fn test_no_limits() {
        let manager = ClientManager::new();
        let addr = create_test_socket_addr("127.0.0.1");

        // No auth context - should succeed
        assert!(manager.check_connection_allowed(addr, &None).await.is_ok());

        // Auth context with no limits - should succeed
        let auth_context = create_test_auth_context("test", Limits::default());
        assert!(manager
            .check_connection_allowed(addr, &Some(auth_context))
            .await
            .is_ok());
    }

    #[tokio::test]
    async fn test_per_subject_connection_limit() {
        let manager = ClientManager::new();

        let limits = Limits {
            max_connections: Some(2),
            ..Default::default()
        };

        let auth_context = create_test_auth_context("user-1", limits);
        let addr = create_test_socket_addr("127.0.0.1");

        // First connection should succeed (no clients yet)
        assert!(manager
            .check_connection_allowed(addr, &Some(auth_context.clone()))
            .await
            .is_ok());

        // With two connections registered for user-1, a third one is rejected.
        register_test_client(&manager, Some(auth_context.clone()), addr);
        register_test_client(&manager, Some(auth_context.clone()), addr);
        assert!(manager
            .check_connection_allowed(addr, &Some(auth_context.clone()))
            .await
            .is_err());

        // Another subject is not affected by user-1's connections.
        let other = create_test_auth_context(
            "user-2",
            Limits {
                max_connections: Some(2),
                ..Default::default()
            },
        );
        assert!(manager
            .check_connection_allowed(addr, &Some(other))
            .await
            .is_ok());
    }

    #[tokio::test]
    async fn test_per_ip_connection_limit() {
        let manager = ClientManager::new().with_max_connections_per_ip(2);
        let addr = create_test_socket_addr("192.168.1.1");

        // Should succeed when no connections from that IP
        assert!(manager.check_connection_allowed(addr, &None).await.is_ok());

        // One registered connection is still under the limit of two.
        register_test_client(&manager, None, addr);
        assert!(manager.check_connection_allowed(addr, &None).await.is_ok());

        // Two registered connections reach the limit.
        register_test_client(&manager, None, addr);
        assert!(manager.check_connection_allowed(addr, &None).await.is_err());

        // A different IP is unaffected.
        assert!(manager
            .check_connection_allowed(create_test_socket_addr("192.168.1.2"), &None)
            .await
            .is_ok());
    }

    // Tests for RateLimitConfig
    #[test]
    fn rate_limit_config_default() {
        let config = RateLimitConfig::default();
        assert!(config.max_connections_per_ip.is_none());
        assert_eq!(config.client_timeout, Duration::from_secs(300));
        assert_eq!(config.message_queue_size, 512);
        assert!(config.max_reconnect_attempts.is_none());
        assert_eq!(config.message_rate_window, Duration::from_secs(60));
        assert_eq!(config.egress_rate_window, Duration::from_secs(60));
    }

    #[test]
    fn rate_limit_config_builder_methods() {
        let config = RateLimitConfig::default()
            .with_max_connections_per_ip(10)
            .with_timeout(Duration::from_secs(600))
            .with_message_queue_size(1024)
            .with_rate_limit_window(Duration::from_secs(120));

        assert_eq!(config.max_connections_per_ip, Some(10));
        assert_eq!(config.client_timeout, Duration::from_secs(600));
        assert_eq!(config.message_queue_size, 1024);
        assert_eq!(config.message_rate_window, Duration::from_secs(120));
        assert_eq!(config.egress_rate_window, Duration::from_secs(120));
    }

    #[tokio::test]
    async fn client_manager_with_config() {
        let config = RateLimitConfig::default()
            .with_max_connections_per_ip(5)
            .with_timeout(Duration::from_secs(120))
            .with_message_queue_size(256);

        let manager = ClientManager::with_config(config);
        let addr = create_test_socket_addr("10.0.0.1");

        // Check that the configuration was applied
        assert_eq!(manager.rate_limit_config().max_connections_per_ip, Some(5));
        assert_eq!(
            manager.rate_limit_config().client_timeout,
            Duration::from_secs(120)
        );
        assert_eq!(manager.rate_limit_config().message_queue_size, 256);

        // Should allow when under limit
        assert!(manager.check_connection_allowed(addr, &None).await.is_ok());
    }

    #[tokio::test]
    async fn client_manager_builder_pattern() {
        let manager = ClientManager::new()
            .with_max_connections_per_ip(10)
            .with_timeout(Duration::from_secs(180))
            .with_message_queue_size(1024)
            .with_rate_limit_window(Duration::from_secs(90));

        assert_eq!(manager.rate_limit_config().max_connections_per_ip, Some(10));
        assert_eq!(
            manager.rate_limit_config().client_timeout,
            Duration::from_secs(180)
        );
        assert_eq!(manager.rate_limit_config().message_queue_size, 1024);
        assert_eq!(
            manager.rate_limit_config().message_rate_window,
            Duration::from_secs(90)
        );
    }

    // Integration test: Connection limits are enforced against registered clients
    #[tokio::test]
    async fn connection_limit_enforcement_with_actual_clients() {
        let manager = ClientManager::new().with_max_connections_per_ip(2);
        let addr1 = create_test_socket_addr("192.168.1.1");
        let addr2 = create_test_socket_addr("192.168.1.2");

        // First connection from IP1 should succeed
        let auth1 = create_test_auth_context("user-1", Limits::default());
        assert!(manager
            .check_connection_allowed(addr1, &Some(auth1.clone()))
            .await
            .is_ok());
        register_test_client(&manager, Some(auth1), addr1);

        // Same IP, different auth context: still under the IP limit
        let auth2 = create_test_auth_context("user-2", Limits::default());
        assert!(manager
            .check_connection_allowed(addr1, &Some(auth2.clone()))
            .await
            .is_ok());
        register_test_client(&manager, Some(auth2), addr1);

        // A third connection from IP1 is rejected regardless of subject
        let auth3 = create_test_auth_context("user-3", Limits::default());
        assert!(manager
            .check_connection_allowed(addr1, &Some(auth3.clone()))
            .await
            .is_err());

        // Different IP - should succeed regardless
        assert!(manager
            .check_connection_allowed(addr2, &Some(auth3.clone()))
            .await
            .is_ok());
    }

    // Test subscription limit enforcement
    #[tokio::test]
    async fn subscription_limit_enforcement() {
        let manager = ClientManager::new();
        let addr = create_test_socket_addr("127.0.0.1");

        // Create auth context with subscription limit
        let auth = create_test_auth_context(
            "user-1",
            Limits {
                max_subscriptions: Some(2),
                ..Default::default()
            },
        );

        // Check should pass initially
        assert!(manager
            .check_connection_allowed(addr, &Some(auth.clone()))
            .await
            .is_ok());
        assert_eq!(auth.limits.max_subscriptions, Some(2));

        let client_id = register_test_client(&manager, Some(auth), addr);
        assert!(manager.check_subscription_allowed(client_id).await.is_ok());

        // Filling the two allowed subscriptions blocks a third one.
        let _ = manager
            .add_client_subscription(client_id, "view-a".to_string(), CancellationToken::new())
            .await;
        let _ = manager
            .add_client_subscription(client_id, "view-b".to_string(), CancellationToken::new())
            .await;
        assert!(manager.check_subscription_allowed(client_id).await.is_err());

        // Freeing one subscription allows a new one again.
        let _ = manager.remove_client_subscription(client_id, "view-a").await;
        assert!(manager.check_subscription_allowed(client_id).await.is_ok());

        // Unknown clients are not limited.
        assert!(manager
            .check_subscription_allowed(Uuid::new_v4())
            .await
            .is_ok());
    }

    // Test snapshot limit enforcement
    #[tokio::test]
    async fn snapshot_limit_enforcement() {
        let manager = ClientManager::new();
        let addr = create_test_socket_addr("127.0.0.1");

        let auth = create_test_auth_context(
            "user-1",
            Limits {
                max_snapshot_rows: Some(1000),
                ..Default::default()
            },
        );

        assert!(manager
            .check_connection_allowed(addr, &Some(auth.clone()))
            .await
            .is_ok());

        let client_id = register_test_client(&manager, Some(auth), addr);

        // The cap is inclusive: exactly max_snapshot_rows is allowed, one more is not.
        assert!(manager.check_snapshot_allowed(client_id, 1000).is_ok());
        assert!(manager.check_snapshot_allowed(client_id, 1001).is_err());

        // Unknown clients are not limited.
        assert!(manager
            .check_snapshot_allowed(Uuid::new_v4(), u32::MAX)
            .is_ok());
    }

    // Test WebSocketRateLimiter integration
    #[tokio::test]
    async fn test_rate_limiter_integration() {
        use crate::websocket::rate_limiter::{RateLimiterConfig, WebSocketRateLimiter};

        let rate_limiter = Arc::new(WebSocketRateLimiter::new(RateLimiterConfig::default()));
        let manager = ClientManager::new().with_rate_limiter(rate_limiter);
        let addr = create_test_socket_addr("127.0.0.1");

        // Should allow connections when rate limiter is configured
        let auth = create_test_auth_context("user-1", Limits::default());
        assert!(manager
            .check_connection_allowed(addr, &Some(auth))
            .await
            .is_ok());
    }
}