1use std::collections::HashMap;
2use std::net::SocketAddr;
3use tokio::io::{AsyncRead, AsyncWrite};
4use tokio::sync::mpsc;
5use tokio_tungstenite::WebSocketStream;
6
/// Why a client connection ended, as reported to the connection manager and
/// forwarded to the lobby.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DisconnectReason {
    /// The peer sent a WebSocket close frame.
    Clean,
    /// The stream ended without a close frame.
    ForceClosed,
    /// The connection failed (for example a stream error or a ping timeout);
    /// carries a human-readable description.
    Error(String),
}
13
/// Instructions from the connection manager to a single connection handler.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServerMessage {
    /// Write this string to the peer as a text frame.
    Text(String),
    /// Send a close frame to the peer and stop the handler.
    Disconnect,
}
18
/// Control commands accepted by the connection manager and the TCP listener.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ManagerCommand {
    /// Stop: the manager asks every client to disconnect and exits; the
    /// listener stops accepting new connections.
    Shutdown,
}
22
23pub enum ClientEvent<S = tokio::net::TcpStream>
24where
25 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
26{
27 Connected {
28 address: SocketAddr,
29 tx: mpsc::Sender<ServerMessage>,
30 ws: WebSocketStream<S>,
31 },
32 Message {
33 address: SocketAddr,
34 text: String,
35 },
36 Disconnected {
37 address: SocketAddr,
38 reason: DisconnectReason,
39 },
40}
41
42pub struct ClientHandle {
43 pub tx: mpsc::Sender<ServerMessage>,
44}
45
46#[derive(Debug)]
49pub enum LobbyEvent {
50 ClientConnected {
52 address: SocketAddr,
53 message_tx: mpsc::Sender<go_fish_web::ServerMessage>,
56 },
57 ClientMessage { address: SocketAddr, message: go_fish_web::ClientMessage },
58 ClientDisconnected { address: SocketAddr, reason: DisconnectReason },
59 Hook {
62 lobby_id: String,
63 player_name: String,
64 request: go_fish_web::ClientHookRequest,
65 },
66}
67
68use bytes::Bytes;
71use futures_util::{SinkExt, StreamExt};
72use std::time::Duration;
73use tokio_tungstenite::tungstenite::Message;
74use tracing::{debug, error, info, instrument, warn};
75
76#[instrument(skip(ws, event_tx, msg_rx), fields(%address))]
77pub async fn run_connection_handler<S, T>(
78 address: SocketAddr,
79 mut ws: WebSocketStream<S>,
80 event_tx: mpsc::Sender<ClientEvent<T>>,
81 mut msg_rx: mpsc::Receiver<ServerMessage>,
82) where
83 S: AsyncRead + AsyncWrite + Unpin,
84 T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
85{
86 let mut ping_interval = tokio::time::interval(Duration::from_secs(45));
87 ping_interval.tick().await; let mut pending_pings: u32 = 0;
89 const MAX_PENDING_PINGS: u32 = 3;
90
91 loop {
92 tokio::select! {
93 frame = ws.next() => {
94 match frame {
95 Some(Ok(Message::Text(text))) => {
96 debug!(event = "client_message_received", %text);
97 if event_tx.send(ClientEvent::Message { address, text: text.to_string() }).await.is_err() {
98 break;
100 }
101 }
102 Some(Ok(Message::Close(_))) => {
103 let _ = event_tx.send(ClientEvent::Disconnected {
104 address,
105 reason: DisconnectReason::Clean,
106 }).await;
107 break;
108 }
109 Some(Ok(Message::Pong(_))) => {
110 pending_pings = 0;
111 }
112 Some(Ok(_)) => {
113 continue;
115 }
116 Some(Err(e)) => {
117 error!(event = "websocket_stream_error", error = %e);
118 let _ = event_tx.send(ClientEvent::Disconnected {
119 address,
120 reason: DisconnectReason::Error(e.to_string()),
121 }).await;
122 break;
123 }
124 None => {
125 warn!(event = "websocket_force_closed", %address);
127 let _ = event_tx.send(ClientEvent::Disconnected {
128 address,
129 reason: DisconnectReason::ForceClosed,
130 }).await;
131 break;
132 }
133 }
134 }
135 msg = msg_rx.recv() => {
136 match msg {
137 None => {
138 break;
140 }
141 Some(ServerMessage::Disconnect) => {
142 let _ = ws.send(Message::Close(None)).await;
143 break;
144 }
145 Some(ServerMessage::Text(t)) => {
146 if ws.send(Message::Text(t.into())).await.is_err() {
147 break;
148 }
149 }
150 }
151 }
152 _ = ping_interval.tick() => {
153 if pending_pings >= MAX_PENDING_PINGS {
154 info!(event = "client_ping_timeout", %address);
155 let _ = event_tx.send(ClientEvent::Disconnected {
156 address,
157 reason: DisconnectReason::Error("ping timeout".to_string()),
158 }).await;
159 break;
160 }
161 let _ = ws.send(Message::Ping(Bytes::new())).await;
162 pending_pings += 1;
163 }
164 }
165 }
166}
167
168pub struct ConnectionManager<S = tokio::net::TcpStream>
171where
172 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
173{
174 clients: HashMap<SocketAddr, ClientHandle>,
175 max_client_connections: usize,
176 event_rx: mpsc::Receiver<ClientEvent<S>>,
177 event_tx: mpsc::Sender<ClientEvent<S>>,
178 command_rx: mpsc::Receiver<ManagerCommand>,
179 command_tx: mpsc::Sender<ManagerCommand>,
180 lobby_tx: mpsc::Sender<LobbyEvent>,
181}
182
183impl<S> ConnectionManager<S>
184where
185 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
186{
187 pub fn new(
188 lobby_tx: mpsc::Sender<LobbyEvent>,
189 max_client_connections: usize,
190 ) -> Self {
191 let (event_tx, event_rx) = mpsc::channel::<ClientEvent<S>>(64);
192 let (command_tx, command_rx) = mpsc::channel::<ManagerCommand>(8);
193 ConnectionManager {
194 clients: HashMap::new(),
195 max_client_connections,
196 event_rx,
197 event_tx,
198 command_rx,
199 command_tx,
200 lobby_tx,
201 }
202 }
203
204 pub fn event_tx(&self) -> mpsc::Sender<ClientEvent<S>> {
205 self.event_tx.clone()
206 }
207
208 pub fn command_tx(&self) -> mpsc::Sender<ManagerCommand> {
209 self.command_tx.clone()
210 }
211
212 #[instrument(skip(self))]
213 pub async fn run(mut self) {
214 loop {
215 tokio::select! {
216 cmd = self.command_rx.recv() => {
217 match cmd {
218 Some(ManagerCommand::Shutdown) | None => {
219 for (address, handle) in &self.clients {
221 if handle.tx.send(ServerMessage::Disconnect).await.is_err() {
222 warn!(event = "shutdown_disconnect_failed", %address);
223 }
224 }
225 break;
226 }
227 }
228 }
229 event = self.event_rx.recv() => {
230 match event {
231 None => break,
232 Some(ClientEvent::Connected { address, tx, ws }) => {
233 if self.clients.len() >= self.max_client_connections {
234 let mut ws = ws;
235 ws.close(None).await.ok();
236 warn!(event = "client_rejected_max_connections", %address,
237 connections = self.clients.len(),
238 max_connections = self.max_client_connections);
239 continue;
240 }
241 let (handler_tx, handler_rx) = mpsc::channel::<ServerMessage>(32);
243 let (web_tx, mut web_rx) = mpsc::channel::<go_fish_web::ServerMessage>(64);
244 let serializer_tx = handler_tx.clone();
245 tokio::spawn(async move {
246 while let Some(msg) = web_rx.recv().await {
247 match serde_json::to_string(&msg) {
248 Ok(json) => {
249 if serializer_tx.send(ServerMessage::Text(json)).await.is_err() {
250 break;
251 }
252 }
253 Err(e) => {
254 warn!(event = "serialize_outbound_failed", error = %e);
255 }
256 }
257 }
258 });
259 self.clients.insert(address, ClientHandle { tx: handler_tx });
260 let event_tx = self.event_tx.clone();
261 tokio::spawn(run_connection_handler(address, ws, event_tx, handler_rx));
262 drop(tx);
263 info!(event = "client_connected", %address,
264 connections = self.clients.len(),
265 max_connections = self.max_client_connections);
266 if self.lobby_tx.send(LobbyEvent::ClientConnected { address, message_tx: web_tx }).await.is_err() {
267 warn!(event = "lobby_forward_failed", %address, message = "ClientConnected");
268 }
269 }
270 Some(ClientEvent::Message { address, text }) => {
271 match serde_json::from_str::<go_fish_web::ClientMessage>(&text) {
272 Ok(message) => {
273 if self.lobby_tx.send(LobbyEvent::ClientMessage { address, message }).await.is_err() {
274 warn!(event = "lobby_forward_failed", %address, message = "ClientMessage");
275 }
276 }
277 Err(e) => {
278 warn!(event = "client_message_parse_failed", %address, error = %e, raw = %text);
279 if let Some(handle) = self.clients.get(&address) {
280 let error_json = serde_json::to_string(
281 &go_fish_web::ServerMessage::Error("invalid message".to_string())
282 ).unwrap_or_else(|_| r#"{"Error":"invalid message"}"#.to_string());
283 if handle.tx.send(ServerMessage::Text(error_json)).await.is_err() {
284 warn!(event = "send_failed", %address);
285 }
286 }
287 }
288 }
289 }
290 Some(ClientEvent::Disconnected { address, reason }) => {
291 self.clients.remove(&address);
292 info!(event = "client_disconnected", %address, reason = ?reason);
293 if self.lobby_tx.send(LobbyEvent::ClientDisconnected { address, reason }).await.is_err() {
294 warn!(event = "lobby_forward_failed", %address, message = "ClientDisconnected");
295 }
296 }
297 }
298 }
299 }
300 }
301 }
302}
303
304pub async fn run_tcp_listener(
307 addr: SocketAddr,
308 event_tx: mpsc::Sender<ClientEvent>,
309 command_rx: mpsc::Receiver<ManagerCommand>,
310) {
311 let listener = match tokio::net::TcpListener::bind(addr).await {
312 Ok(l) => l,
313 Err(e) => {
314 error!(event = "tcp_bind_failed", error = %e);
315 return;
316 }
317 };
318 info!(event = "tcp_listener_bound", %addr);
319 run_tcp_listener_inner(listener, event_tx, command_rx).await
320}
321
322#[instrument(skip(event_tx, command_rx))]
323pub async fn run_tcp_listener_inner(
324 listener: tokio::net::TcpListener,
325 event_tx: mpsc::Sender<ClientEvent>,
326 mut command_rx: mpsc::Receiver<ManagerCommand>,
327) {
328 let addr = listener.local_addr().unwrap_or_else(|_| "0.0.0.0:0".parse().unwrap());
329 loop {
330 tokio::select! {
331 cmd = command_rx.recv() => {
332 match cmd {
333 Some(ManagerCommand::Shutdown) | None => {
334 info!(event = "tcp_listener_shutdown", %addr);
335 break;
336 }
337 }
338 }
339 accept = listener.accept() => {
340 let (stream, address) = match accept {
341 Ok(pair) => pair,
342 Err(e) => {
343 error!(event = "tcp_accept_failed", error = %e);
344 continue;
345 }
346 };
347
348 match tokio_tungstenite::accept_async(stream).await {
349 Ok(ws) => {
350 let (tx, _rx) = mpsc::channel::<ServerMessage>(32);
351 if event_tx
352 .send(ClientEvent::Connected { address, tx, ws })
353 .await
354 .is_err()
355 {
356 break;
357 }
358 }
359 Err(e) => {
360 error!(event = "websocket_handshake_failed", %address, error = %e);
361 }
362 }
363 }
364 }
365 }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::SinkExt;
    use proptest::prelude::*;
    use std::net::SocketAddr;
    use tokio::io::duplex;
    use tokio::sync::mpsc;
    use tokio::time::{timeout, Duration};
    use tokio_tungstenite::WebSocketStream;
    use tungstenite::protocol::Role;

    /// Builds a connected server/client WebSocket pair over an in-memory pipe.
    async fn make_ws_pair() -> (
        WebSocketStream<tokio::io::DuplexStream>,
        WebSocketStream<tokio::io::DuplexStream>,
    ) {
        let (server_io, client_io) = duplex(65536);
        let server_ws = WebSocketStream::from_raw_socket(server_io, Role::Server, None).await;
        let client_ws = WebSocketStream::from_raw_socket(client_io, Role::Client, None).await;
        (server_ws, client_ws)
    }

    fn test_addr() -> SocketAddr {
        "127.0.0.1:12345".parse().unwrap()
    }

    #[tokio::test]
    async fn binary_frame_produces_no_message_event() {
        let (server_ws, mut client_ws) = make_ws_pair().await;
        let (event_tx, mut event_rx) = mpsc::channel::<ClientEvent>(16);
        let (_msg_tx, msg_rx) = mpsc::channel::<ServerMessage>(16);

        let addr = test_addr();
        let handler = tokio::spawn(run_connection_handler(addr, server_ws, event_tx, msg_rx));

        client_ws
            .send(tungstenite::Message::Binary(vec![1, 2, 3].into()))
            .await
            .unwrap();
        client_ws
            .send(tungstenite::Message::Close(None))
            .await
            .unwrap();

        timeout(Duration::from_secs(2), handler)
            .await
            .expect("handler timed out")
            .expect("handler panicked");

        let mut got_message = false;
        while let Ok(event) = event_rx.try_recv() {
            if matches!(event, ClientEvent::Message { .. }) {
                got_message = true;
            }
        }
        assert!(!got_message, "binary frame should not produce a ClientEvent::Message");
    }

    #[tokio::test]
    async fn ping_frame_produces_no_message_event() {
        let (server_ws, mut client_ws) = make_ws_pair().await;
        let (event_tx, mut event_rx) = mpsc::channel::<ClientEvent>(16);
        let (_msg_tx, msg_rx) = mpsc::channel::<ServerMessage>(16);

        let addr = test_addr();
        let handler = tokio::spawn(run_connection_handler(addr, server_ws, event_tx, msg_rx));

        client_ws
            .send(tungstenite::Message::Ping(vec![].into()))
            .await
            .unwrap();
        client_ws
            .send(tungstenite::Message::Close(None))
            .await
            .unwrap();

        timeout(Duration::from_secs(2), handler)
            .await
            .expect("handler timed out")
            .expect("handler panicked");

        let mut got_message = false;
        while let Ok(event) = event_rx.try_recv() {
            if matches!(event, ClientEvent::Message { .. }) {
                got_message = true;
            }
        }
        assert!(!got_message, "ping frame should not produce a ClientEvent::Message");
    }

    #[tokio::test]
    async fn close_frame_sends_clean_disconnect() {
        let (server_ws, mut client_ws) = make_ws_pair().await;
        let (event_tx, mut event_rx) = mpsc::channel::<ClientEvent>(16);
        let (_msg_tx, msg_rx) = mpsc::channel::<ServerMessage>(16);

        let addr = test_addr();
        let handler = tokio::spawn(run_connection_handler(addr, server_ws, event_tx, msg_rx));

        client_ws
            .send(tungstenite::Message::Close(None))
            .await
            .unwrap();

        timeout(Duration::from_secs(2), handler)
            .await
            .expect("handler timed out")
            .expect("handler panicked");

        let mut found_clean = false;
        while let Ok(event) = event_rx.try_recv() {
            if let ClientEvent::Disconnected { reason: DisconnectReason::Clean, .. } = event {
                found_clean = true;
            }
        }
        assert!(found_clean, "expected DisconnectReason::Clean after Close frame");
    }

    fn start_manager() -> (
        mpsc::Sender<ClientEvent<tokio::io::DuplexStream>>,
        mpsc::Sender<ManagerCommand>,
        tokio::task::JoinHandle<()>,
        mpsc::Receiver<LobbyEvent>,
    ) {
        start_manager_with_limit(2)
    }

    /// Spawns a manager over in-memory streams admitting at most
    /// `max_client_connections` clients; returns its senders, task handle and
    /// the lobby receiver.
    fn start_manager_with_limit(
        max_client_connections: usize,
    ) -> (
        mpsc::Sender<ClientEvent<tokio::io::DuplexStream>>,
        mpsc::Sender<ManagerCommand>,
        tokio::task::JoinHandle<()>,
        mpsc::Receiver<LobbyEvent>,
    ) {
        let (lobby_tx, lobby_rx) = mpsc::channel::<LobbyEvent>(64);
        let manager: ConnectionManager<tokio::io::DuplexStream> =
            ConnectionManager::new(lobby_tx, max_client_connections);
        let event_tx = manager.event_tx();
        let command_tx = manager.command_tx();
        let handle = tokio::spawn(manager.run());
        (event_tx, command_tx, handle, lobby_rx)
    }

    /// Registers a new in-memory client at `addr` with the manager and returns
    /// the client's end of the WebSocket.
    async fn connect_client(
        event_tx: &mpsc::Sender<ClientEvent<tokio::io::DuplexStream>>,
        addr: SocketAddr,
    ) -> WebSocketStream<tokio::io::DuplexStream> {
        let (server_io, client_io) = duplex(65536);
        let server_ws = WebSocketStream::from_raw_socket(server_io, Role::Server, None).await;
        let client_ws = WebSocketStream::from_raw_socket(client_io, Role::Client, None).await;
        let (tx, _rx) = mpsc::channel::<ServerMessage>(1);
        event_tx
            .send(ClientEvent::Connected { address: addr, tx, ws: server_ws })
            .await
            .unwrap();
        client_ws
    }

    #[tokio::test]
    async fn invalid_json_sends_error() {
        let (event_tx, command_tx, manager_handle, _lobby_rx) = start_manager();
        let addr: SocketAddr = "127.0.0.1:10001".parse().unwrap();

        let mut client_ws = connect_client(&event_tx, addr).await;
        tokio::time::sleep(Duration::from_millis(10)).await;

        client_ws
            .send(tungstenite::Message::Text("not valid json".into()))
            .await
            .unwrap();

        let reply = timeout(Duration::from_secs(2), client_ws.next())
            .await
            .expect("timed out waiting for error reply")
            .expect("stream ended")
            .expect("ws error");

        if let tungstenite::Message::Text(t) = reply {
            let parsed: serde_json::Value = serde_json::from_str(&t).unwrap();
            assert!(parsed.get("Error").is_some(), "expected Error variant in response");
        } else {
            panic!("expected Text message, got {:?}", reply);
        }

        command_tx.send(ManagerCommand::Shutdown).await.unwrap();
        let _ = timeout(Duration::from_secs(2), manager_handle).await;
    }

    #[tokio::test]
    async fn disconnection_removes_client() {
        let (event_tx, command_tx, manager_handle, _lobby_rx) = start_manager();
        let addr: SocketAddr = "127.0.0.1:10004".parse().unwrap();

        let mut client_ws = connect_client(&event_tx, addr).await;
        tokio::time::sleep(Duration::from_millis(10)).await;

        client_ws.send(tungstenite::Message::Close(None)).await.unwrap();
        tokio::time::sleep(Duration::from_millis(10)).await;

        // An invalid message for a removed client must not produce an error reply.
        event_tx
            .send(ClientEvent::Message { address: addr, text: "ghost".into() })
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(10)).await;

        let next = timeout(Duration::from_millis(200), client_ws.next()).await;
        match next {
            Ok(Some(Ok(tungstenite::Message::Close(_)))) | Ok(None) | Err(_) => {}
            Ok(Some(Err(_))) => {}
            Ok(Some(Ok(tungstenite::Message::Text(t)))) => {
                panic!("disconnected client received unexpected message: {t}");
            }
            other => panic!("unexpected: {other:?}"),
        }

        command_tx.send(ManagerCommand::Shutdown).await.unwrap();
        let _ = timeout(Duration::from_secs(2), manager_handle).await;
    }

    #[tokio::test]
    async fn disconnect_does_not_affect_remaining_clients() {
        let (event_tx, command_tx, manager_handle, mut lobby_rx) = start_manager();
        let addr_a: SocketAddr = "127.0.0.1:10005".parse().unwrap();
        let addr_b: SocketAddr = "127.0.0.1:10006".parse().unwrap();

        let mut client_a = connect_client(&event_tx, addr_a).await;
        let mut client_b = connect_client(&event_tx, addr_b).await;
        tokio::time::sleep(Duration::from_millis(10)).await;

        client_a.send(tungstenite::Message::Close(None)).await.unwrap();
        tokio::time::sleep(Duration::from_millis(10)).await;

        let valid_json = serde_json::to_string(&go_fish_web::ClientMessage::Identity).unwrap();
        client_b
            .send(tungstenite::Message::Text(valid_json.into()))
            .await
            .unwrap();

        // B's message must still reach the lobby after A has gone.
        let forwarded = timeout(Duration::from_secs(2), async {
            while let Some(event) = lobby_rx.recv().await {
                if matches!(event, LobbyEvent::ClientMessage { address, .. } if address == addr_b) {
                    return true;
                }
            }
            false
        })
        .await
        .unwrap_or(false);
        assert!(forwarded, "message from the remaining client was not forwarded to the lobby");

        command_tx.send(ManagerCommand::Shutdown).await.unwrap();
        let _ = timeout(Duration::from_secs(2), manager_handle).await;
    }

    #[tokio::test]
    async fn handshake_failure_does_not_stop_server() {
        use tokio::io::AsyncWriteExt;
        use tokio::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (lobby_tx, _lobby_rx) = mpsc::channel::<LobbyEvent>(64);
        let max_client_connections = 2;
        let manager: ConnectionManager<tokio::net::TcpStream> =
            ConnectionManager::new(lobby_tx, max_client_connections);
        let event_tx = manager.event_tx();
        let command_tx = manager.command_tx();
        let manager_handle = tokio::spawn(manager.run());
        let (listener_cmd_tx, listener_cmd_rx) = mpsc::channel::<ManagerCommand>(1);
        tokio::spawn(run_tcp_listener_inner(listener, event_tx.clone(), listener_cmd_rx));

        let mut plain = tokio::net::TcpStream::connect(addr).await.unwrap();
        plain.write_all(b"NOT A WEBSOCKET\r\n\r\n").await.unwrap();
        drop(plain);
        tokio::time::sleep(Duration::from_millis(50)).await;

        let url = format!("ws://{}", addr);
        let (_ws, _) = tokio_tungstenite::connect_async(&url)
            .await
            .expect("valid WS connection should succeed after handshake failure");
        tokio::time::sleep(Duration::from_millis(50)).await;

        listener_cmd_tx.send(ManagerCommand::Shutdown).await.unwrap();
        command_tx.send(ManagerCommand::Shutdown).await.unwrap();
        let result = timeout(Duration::from_secs(2), manager_handle).await;
        assert!(result.is_ok(), "manager should shut down within 2 seconds");
    }

    /// Runs an async property-test body on a fresh current-thread runtime.
    macro_rules! prop_async {
        ($body:expr) => {{
            let rt = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap();
            rt.block_on(async move { $body })
        }};
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(20))]
        #[test]
        fn prop_connection_registration(
            a in 1u8..=254u8,
            b in 0u8..=255u8,
            c in 0u8..=255u8,
            d in 1u8..=254u8,
            port in 1024u16..=49151u16,
        ) {
            prop_async!({
                let addr: SocketAddr = format!("{a}.{b}.{c}.{d}:{port}").parse().unwrap();
                let (event_tx, command_tx, manager_handle, mut lobby_rx) = start_manager();

                let mut client_ws = connect_client(&event_tx, addr).await;
                tokio::time::sleep(Duration::from_millis(10)).await;

                let valid_json =
                    serde_json::to_string(&go_fish_web::ClientMessage::Identity).unwrap();
                client_ws.send(tungstenite::Message::Text(valid_json.into())).await.unwrap();

                // The manager must announce the client, then forward its message.
                let connected = timeout(Duration::from_secs(2), lobby_rx.recv()).await
                    .map_err(|_| TestCaseError::fail("timed out waiting for ClientConnected"))?
                    .ok_or_else(|| TestCaseError::fail("lobby channel closed"))?;
                prop_assert!(
                    matches!(connected, LobbyEvent::ClientConnected { address, .. } if address == addr),
                    "expected ClientConnected for {addr}, got {connected:?}"
                );
                let forwarded = timeout(Duration::from_secs(2), lobby_rx.recv()).await
                    .map_err(|_| TestCaseError::fail("timed out waiting for ClientMessage"))?
                    .ok_or_else(|| TestCaseError::fail("lobby channel closed"))?;
                prop_assert!(
                    matches!(forwarded, LobbyEvent::ClientMessage { address, .. } if address == addr),
                    "expected ClientMessage from {addr}, got {forwarded:?}"
                );

                command_tx.send(ManagerCommand::Shutdown).await.unwrap();
                let _ = timeout(Duration::from_secs(2), manager_handle).await;
                Ok::<(), TestCaseError>(())
            }).unwrap();
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(20))]
        #[test]
        fn prop_invalid_json_returns_error(msg in "[a-zA-Z0-9 ]{1,64}") {
            prop_assume!(serde_json::from_str::<go_fish_web::ClientMessage>(&msg).is_err());
            prop_async!({
                let addr: SocketAddr = "127.0.0.1:20001".parse().unwrap();
                let (event_tx, command_tx, manager_handle, _lobby_rx) = start_manager();

                let mut client_ws = connect_client(&event_tx, addr).await;
                tokio::time::sleep(Duration::from_millis(10)).await;

                client_ws.send(tungstenite::Message::Text(msg.clone().into())).await.unwrap();

                let reply = timeout(Duration::from_secs(2), client_ws.next()).await;
                command_tx.send(ManagerCommand::Shutdown).await.unwrap();
                let _ = timeout(Duration::from_secs(2), manager_handle).await;

                match reply {
                    Ok(Some(Ok(tungstenite::Message::Text(t)))) => {
                        let parsed: serde_json::Value = serde_json::from_str(&t)
                            .map_err(|_| TestCaseError::fail("response was not valid JSON"))?;
                        prop_assert!(parsed.get("Error").is_some(), "expected Error variant");
                    }
                    _ => return Err(TestCaseError::fail("did not receive error response")),
                }
                Ok(())
            }).unwrap();
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(20))]
        #[test]
        fn prop_disconnection_removes_client(_msg in "[a-zA-Z0-9]{1,32}") {
            prop_async!({
                let addr: SocketAddr = "127.0.0.1:22001".parse().unwrap();
                let (event_tx, command_tx, manager_handle, _lobby_rx) = start_manager();

                let mut client_ws = connect_client(&event_tx, addr).await;
                tokio::time::sleep(Duration::from_millis(10)).await;

                client_ws.send(tungstenite::Message::Close(None)).await.unwrap();
                tokio::time::sleep(Duration::from_millis(20)).await;

                event_tx
                    .send(ClientEvent::Message { address: addr, text: "ghost".into() })
                    .await
                    .unwrap();
                tokio::time::sleep(Duration::from_millis(20)).await;

                let next = timeout(Duration::from_millis(100), client_ws.next()).await;
                if let Ok(Some(Ok(tungstenite::Message::Text(t)))) = next {
                    return Err(TestCaseError::fail(format!(
                        "disconnected client received unexpected message: {t}"
                    )));
                }

                command_tx.send(ManagerCommand::Shutdown).await.unwrap();
                let _ = timeout(Duration::from_secs(2), manager_handle).await;
                Ok(())
            }).unwrap();
        }
    }

    #[tokio::test]
    async fn max_connections_rejects_with_close_frame() {
        let (event_tx, command_tx, manager_handle, _lobby_rx) = start_manager_with_limit(2);
        let addr_a: SocketAddr = "127.0.0.1:11001".parse().unwrap();
        let addr_b: SocketAddr = "127.0.0.1:11002".parse().unwrap();
        let addr_c: SocketAddr = "127.0.0.1:11003".parse().unwrap();

        let _client_a = connect_client(&event_tx, addr_a).await;
        let _client_b = connect_client(&event_tx, addr_b).await;
        tokio::time::sleep(Duration::from_millis(10)).await;

        let mut rejected = connect_client(&event_tx, addr_c).await;

        let msg = timeout(Duration::from_secs(2), rejected.next())
            .await
            .expect("timed out waiting for close frame from rejected client")
            .expect("stream ended without a message")
            .expect("ws error on rejected client");

        assert!(
            matches!(msg, tungstenite::Message::Close(_)),
            "expected Close frame for rejected client, got {msg:?}"
        );

        command_tx.send(ManagerCommand::Shutdown).await.unwrap();
        let _ = timeout(Duration::from_secs(2), manager_handle).await;
    }

    #[tokio::test]
    async fn max_connections_does_not_affect_existing_clients() {
        let (event_tx, command_tx, manager_handle, mut lobby_rx) = start_manager_with_limit(1);
        let addr_a: SocketAddr = "127.0.0.1:11004".parse().unwrap();
        let addr_b: SocketAddr = "127.0.0.1:11005".parse().unwrap();

        let mut client_a = connect_client(&event_tx, addr_a).await;
        tokio::time::sleep(Duration::from_millis(10)).await;

        let mut rejected = connect_client(&event_tx, addr_b).await;
        tokio::time::sleep(Duration::from_millis(10)).await;

        let msg = timeout(Duration::from_millis(500), rejected.next())
            .await
            .expect("timed out waiting for close frame")
            .expect("stream ended without a message")
            .expect("ws error");
        assert!(matches!(msg, tungstenite::Message::Close(_)));

        let valid_json = serde_json::to_string(&go_fish_web::ClientMessage::Identity).unwrap();
        client_a
            .send(tungstenite::Message::Text(valid_json.into()))
            .await
            .unwrap();

        let connected = timeout(Duration::from_secs(2), lobby_rx.recv())
            .await
            .expect("timed out waiting for lobby event")
            .expect("lobby channel closed");
        assert!(
            matches!(connected, LobbyEvent::ClientConnected { address, .. } if address == addr_a),
            "expected ClientConnected for addr_a, got {connected:?}"
        );

        // The rejection produces no lobby event, so the next one must be A's
        // message: the existing client is still being served.
        let forwarded = timeout(Duration::from_secs(2), lobby_rx.recv())
            .await
            .expect("timed out waiting for forwarded message")
            .expect("lobby channel closed");
        assert!(
            matches!(forwarded, LobbyEvent::ClientMessage { address, .. } if address == addr_a),
            "expected ClientMessage from addr_a, got {forwarded:?}"
        );

        command_tx.send(ManagerCommand::Shutdown).await.unwrap();
        let _ = timeout(Duration::from_secs(2), manager_handle).await;
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(20))]
        #[test]
        fn prop_max_connections_rejects_excess(limit in 1usize..=3usize) {
            prop_async!({
                let (event_tx, command_tx, manager_handle, _lobby_rx) =
                    start_manager_with_limit(limit);

                let mut live_clients = vec![];
                for i in 0..limit {
                    let addr: SocketAddr =
                        format!("127.0.0.2:{}", 10000 + i as u16).parse().unwrap();
                    live_clients.push(connect_client(&event_tx, addr).await);
                }
                tokio::time::sleep(Duration::from_millis(10)).await;

                let overflow_addr: SocketAddr = "127.0.0.2:19999".parse().unwrap();
                let mut rejected = connect_client(&event_tx, overflow_addr).await;

                let msg = timeout(Duration::from_secs(2), rejected.next()).await
                    .map_err(|_| TestCaseError::fail("timed out waiting for close frame"))?
                    .ok_or_else(|| TestCaseError::fail("stream ended without a message"))?
                    .map_err(|e| TestCaseError::fail(format!("ws error: {e}")))?;

                prop_assert!(
                    matches!(msg, tungstenite::Message::Close(_)),
                    "expected Close frame, got {msg:?}"
                );

                command_tx.send(ManagerCommand::Shutdown).await.unwrap();
                let _ = timeout(Duration::from_secs(2), manager_handle).await;
                Ok(())
            }).unwrap();
        }
    }
}