1
2
3
4use degen_logger;
5use futures_util::StreamExt;
6use futures_util::stream::SplitSink;
7use futures_util::future::join_all;
8
9
10use tokio_tungstenite::WebSocketStream;
11
12use serde::Serialize;
13use serde_json;
14
15
16use tokio::net::{TcpListener, TcpStream};
17
18
19use futures::SinkExt;
20use std::collections::HashMap;
21
22
23use std::thread;
24
25use std::sync::Arc;
26use tokio::sync::{RwLock,Mutex};
27
28use std::collections::HashSet;
29
30use tokio::time::{interval,Duration};
31
32use tokio_tungstenite::tungstenite::Message;
33
34use crate::{util::{rand::generate_random_uuid, logtypes::CustomLogStyle} };
35
36use super::reliable_message_subsystem::ReliableMessageSubsystem;
37
38 use tokio::sync::mpsc::{channel, Sender, Receiver};
39
40use super::websocket_messages::{
41 SocketMessage,
42 SocketMessageDestination,
43 InboundMessage,
44 OutboundMessage,
45
46 SocketMessageError,
47
48 MessageReliability,
49 MessageReliabilityType,
50 OutboundMessageDestination
51
52 };
53
54
55type ClientsMap = Arc<RwLock<HashMap<String, ClientConnection>>>;
56
57
58type RoomsMap = Arc<RwLock<HashMap<String, HashSet<String>>>>;
59
60
61type TxSink = Arc<Mutex<SplitSink<WebSocketStream<TcpStream>, Message>>>;
62
63
64type RxSink = Arc<Mutex<SplitSink<WebSocketStream<TcpStream>, Message>>>;
65
66
67
68
/// Internal events that connection tasks report to the server's event loop.
#[derive(Debug, Clone)]
pub enum WebsocketSystemEvent {
    /// A client acknowledged the reliable message with uuid `reliable_msg_uuid`.
    ReceivedMessageAck { reliable_msg_uuid: String },
}
73
/// Errors produced by websocket server operations.
#[derive(Debug)]
pub enum WebsocketServerError {
    /// A message could not be queued or delivered.
    SendMessageError,
    /// A serde_json error occurred; holds the rendered error text.
    SerdeJsonError(String),
    /// The websocket transport (tungstenite) reported an error; holds the rendered text.
    TokioError(String),
    /// A `SocketMessageError` occurred; holds the rendered error text.
    SocketMessageErr(String),
}

impl std::fmt::Display for WebsocketServerError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            WebsocketServerError::SendMessageError => write!(f, "Could not send message"),
            WebsocketServerError::SerdeJsonError(error) => write!(f, "Serde json error: {}", error),
            WebsocketServerError::TokioError(error) => write!(f, "Tokio Error: {}", error),
            // Include the stored detail so the underlying cause is not lost.
            WebsocketServerError::SocketMessageErr(error) => {
                write!(f, "SocketMessage Error: {}", error)
            }
        }
    }
}
95
96
97 impl From<crossbeam_channel::TrySendError<OutboundMessage>> for WebsocketServerError {
98 fn from(error: crossbeam_channel::TrySendError<OutboundMessage>) -> Self {
99 WebsocketServerError::SendMessageError}
102
103 }
104
105
106
107impl From<serde_json::Error> for WebsocketServerError {
108 fn from(err: serde_json::Error) -> Self {
109 WebsocketServerError::SerdeJsonError(format!("Serialization error: {}", err))
111 }
112}
113
114impl From<tokio_tungstenite::tungstenite::Error> for WebsocketServerError {
115 fn from(err: tokio_tungstenite::tungstenite::Error) -> Self {
116 WebsocketServerError::TokioError(format!("Tokio error: {}", err))
117 }
118}
119
120impl From<SocketMessageError> for WebsocketServerError {
121 fn from(err: SocketMessageError) -> Self {
122 WebsocketServerError::SocketMessageErr(format!("SocketMessageError: {}", err))
124 }
125
126}
127
128
129impl std::error::Error for WebsocketServerError {}
130
131
132
133#[derive(Clone)]
134pub struct ClientConnection {
135 pub client_socket_uuid: String,
136 pub addr: String,
137 pub tx_sink: TxSink,
138}
139
140impl ClientConnection {
141
142 pub fn new( addr:String, client_tx: SplitSink<WebSocketStream<tokio::net::TcpStream>, Message> ) -> Self{
143
144 Self {
145 client_socket_uuid: generate_random_uuid(),
146 addr: addr.clone(),
147 tx_sink: Arc::new(Mutex::new( client_tx ))
148 }
149
150
151 }
152
153 pub async fn send_message(&self, msg: Message) -> Result<(), tokio_tungstenite::tungstenite::error::Error> {
154 self.tx_sink.lock().await.send(msg).await
155 }
156
157}
158
159
160
161
162
163pub struct WebsocketServer{
164
165 clients: ClientsMap,
166
167 rooms: RoomsMap, global_recv_tx: Sender<InboundMessage>, global_recv_rx: Option<Receiver<InboundMessage>>,
173
174 global_send_tx: Sender<OutboundMessage>,
175 global_send_rx: Option<Receiver<OutboundMessage>>,
176
177 pending_reliable_messages: Arc<RwLock<HashMap<String,OutboundMessage>>>,
178
179
180 ws_server_events_tx: Sender<WebsocketSystemEvent>,
181 ws_server_events_rx: Option<Receiver<WebsocketSystemEvent>>,
182
183}
184
185impl WebsocketServer {
186
187
188 pub fn new() -> Self {
189
190
191 let (global_recv_tx, global_recv_rx): (Sender<InboundMessage>, Receiver<InboundMessage>) = channel(2500);
192
193 let (global_send_tx, global_send_rx): (Sender<OutboundMessage>, Receiver<OutboundMessage>) = channel(2500);
194
195
196 let (ws_server_events_tx, ws_server_events_rx): (Sender<WebsocketSystemEvent>, Receiver<WebsocketSystemEvent>) = channel(2500);
197
198
199 Self {
200 clients: Arc::new(RwLock::new(HashMap::new())),
201 rooms: Arc::new(RwLock::new(HashMap::new())),
202 global_recv_tx,
203 global_recv_rx:Some(global_recv_rx),
204 global_send_tx,
205 global_send_rx:Some(global_send_rx),
206 pending_reliable_messages: Arc::new(RwLock::new(HashMap::new())),
207 ws_server_events_tx,
208 ws_server_events_rx:Some(ws_server_events_rx)
209
210 }
211 }
212
213
214 pub async fn start(&mut self, url:Option<String>) -> std::io::Result<()> {
215
216 let clients = Arc::clone(&self.clients);
217 let rooms = Arc::clone(&self.rooms);
218 let pending_reliable_messages = Arc::clone(&self.pending_reliable_messages);
219
220
221 let global_recv_tx = self.global_recv_tx.clone();
222
223 let global_send_rx = self.global_send_rx.take().unwrap();
224
225 let global_send_tx = self.global_send_tx.clone();
226
227 let ws_server_events_tx = self.ws_server_events_tx.clone();
228 let ws_server_events_rx = self.ws_server_events_rx.take().unwrap();
229
230
231 let addr: String = url.unwrap_or_else(|| "127.0.0.1:8080".to_string());
232 let try_socket = TcpListener::bind(&addr).await;
234 let listener = try_socket.expect("Failed to bind");
235
236
237 degen_logger::log(format!("Listening on: {}", addr) , CustomLogStyle::Info ) ;
238
239 let accept_connections =
240 Self::try_accept_new_connections(
241 Arc::clone(&clients),
242 listener,
243 global_recv_tx,
244 global_send_tx.clone(),
245 ws_server_events_tx
246 );
247
248
249 let send_outbound_messages = Self::try_send_outbound_messages(
250 Arc::clone(&clients) ,
251 Arc::clone(&rooms),
252 global_send_rx
253 );
254
255 let resend_reliable_messages = ReliableMessageSubsystem::resend_reliable_messages(
256 Arc::clone(&pending_reliable_messages),
257 global_send_tx
258 );
259
260 let handle_server_events = Self::handle_server_events(
261 ws_server_events_rx ,
262 Arc::clone(&pending_reliable_messages )
263 );
264
265 let accept_conn_future =accept_connections;
266
267 let send_outbound_messages_future = send_outbound_messages;
268
269 let resend_reliable_messages_future = resend_reliable_messages;
270
271 let handle_server_events_future = handle_server_events;
272
273 tokio::select! {
274 _ = accept_conn_future => eprintln!("accept_conn_handle finished"),
275 _ = send_outbound_messages_future => eprintln!("send_outbound_messages_handle finished"),
276 _ = resend_reliable_messages_future => eprintln!("resend_reliable_messages_handle finished"),
277 _ = handle_server_events_future => eprintln!("handle_server_events_handle finished"),
278 }
279
280
281
282
283 degen_logger::log( format!("WS WARN: TOKIO SELECT DROPPED") , CustomLogStyle::Error ) ;
284
285 Ok(())
286 }
287
288
289 pub fn take_recv_channel(&mut self) -> Option<Receiver<InboundMessage>> {
291 self.global_recv_rx.take()
292 }
293
294 pub fn get_send_channel(&self) -> Sender<OutboundMessage> {
295 self.global_send_tx.clone()
296 }
297
298
299
300
301 pub async fn send_socket_message(&self, socket_message: SocketMessage, destination: OutboundMessageDestination)
302 -> Result<(), WebsocketServerError> {
303
304 let reliability_type = socket_message.clone().reliability_type;
305
306
307 let outbound_message = OutboundMessage {
308 destination,
309 message:socket_message
310 };
311
312
313 if let MessageReliabilityType::Reliable(msg_uuid) = reliability_type {
314 self.pending_reliable_messages.write().await.insert(msg_uuid, outbound_message.clone());
315 }
316
317 self.send_outbound_message(outbound_message)?;
318
319
320 Ok(())
321 }
322
323
324
325
326 pub fn send_outbound_message(&self, msg:OutboundMessage)
327 -> Result<(), WebsocketServerError> {
328
329
330
331
332 self.get_send_channel().try_send( msg ) .map_err(|_| WebsocketServerError::SendMessageError) }
336
337
338
339
340
341async fn get_cloned_clients(clients: &ClientsMap) -> Vec<ClientConnection> {
342 let clients_map = clients.read().await;
343 clients_map.values().cloned().collect()
344}
345
346
347async fn get_cloned_clients_in_room(clients: &ClientsMap, rooms: &RoomsMap, room_name: String ) -> Vec<ClientConnection> {
348
349 let client_connection_uuids = Vec::new();
350
351
352 let rooms = rooms.read().await;
353
354 match rooms.get(&room_name) {
355 Some(uuid_set) => {}
356 None => {}
357 }
358
359 return Self::get_cloned_clients_filtered(clients, client_connection_uuids).await;
360}
361
362
363 async fn get_cloned_clients_filtered(clients: &ClientsMap, client_connection_uuids: Vec<String> ) -> Vec<ClientConnection> {
364 let clients_map = clients.read().await;
365
366 let mut filtered_clients: Vec<ClientConnection> = Vec::new();
367
368 for uuid in client_connection_uuids {
369 if let Some(client_conn) = clients_map.get(&uuid) {
370 filtered_clients.push(client_conn.clone());
371 }
372 }
373
374 filtered_clients
375}
376
377
378 async fn get_cloned_client_specific(clients: &ClientsMap, socket_connection_uuid: String ) -> Vec<ClientConnection> {
379 let clients_map = clients.read().await;
380
381 let mut filtered_clients: Vec<ClientConnection> = Vec::new();
382
383 if let Some(client_conn) = clients_map.get(&socket_connection_uuid) {
384 filtered_clients.push(client_conn.clone());
385 }
386
387
388 filtered_clients
389}
390
391
392
393
394
395async fn handle_server_events(
397 mut ws_event_rx: Receiver<WebsocketSystemEvent>,
398 pending_reliable_messages: Arc<RwLock<HashMap<String, OutboundMessage>>>,
399) -> std::io::Result<()> {
400
401 loop {
402
403 while let Some(evt) = ws_event_rx.recv().await { degen_logger::log( format!("ws server handling server event") , CustomLogStyle::Info ) ;
408
409 match evt {
410
411
412
413 WebsocketSystemEvent::ReceivedMessageAck { reliable_msg_uuid } => {
414 Self::clear_pending_reliable_message(
416 Arc::clone(&pending_reliable_messages),
417 reliable_msg_uuid
418 ).await;
419 }
420
421 };
422
423 }
424 }
425
426
427 }
429
430
431pub async fn clear_pending_reliable_message(
433 pending_reliable_messages: Arc<RwLock<HashMap<String, OutboundMessage>>>,
434 message_uuid: String,
435
436) {
437 let mut messages = pending_reliable_messages.write().await ;
438 messages.remove(&message_uuid) ;
439}
440
441
442
443pub async fn try_send_outbound_messages(
444 clients_map: ClientsMap,
445 rooms_map: RoomsMap,
446 mut global_send_rx: Receiver<OutboundMessage>
447) -> std::io::Result<()> {
448
449
450 loop {
451
452 match global_send_rx.recv().await {
454 Some(msg) => {
455
456 let clients_map = Arc::clone(&clients_map);
458 let rooms_map = Arc::clone(&rooms_map);
459
460
461
462 Self::broadcast(clients_map, rooms_map, msg).await;
463
464
465 }
466
467 None => {} ,
468 }
469 }
470
471 }
474
475
476
477
478 pub async fn add_client_to_room(&self, client_connection_uuid:String, room_name: String ) {
479 let mut rooms = self.rooms.write().await;
480
481 let room_clients = rooms.entry(room_name).or_insert_with(HashSet::new);
482 room_clients.insert(client_connection_uuid);
483 }
484
485 pub async fn remove_client_from_room(&self, client_connection_uuid:String, room_name: String ) {
486 let mut rooms = self.rooms.write().await;
487
488 if let Some(room_clients) = rooms.get_mut(&room_name) {
489 room_clients.remove(&client_connection_uuid);
490
491 if room_clients.is_empty() {
493 rooms.remove(&room_name);
494 }
495 }
496 }
497
498
499
500
501
502
503
504
505pub async fn broadcast(
506 clients_map: ClientsMap,
507 rooms_map:RoomsMap,
508 outbound_message: OutboundMessage
509) -> Result<(), WebsocketServerError> {
510
511
512
513 degen_logger::log( format!("ws_server broadcasting msg: {} ", outbound_message.message ) , CustomLogStyle::Info ) ;
514
515 let socket_message = outbound_message.message;
516
517
518 let client_connections = match outbound_message.destination {
519
520 OutboundMessageDestination::All => Self::get_cloned_clients(&clients_map).await,
521 OutboundMessageDestination::Room(room_name) => Self::get_cloned_clients_in_room(&clients_map,&rooms_map,room_name).await,
522 OutboundMessageDestination::SocketConn(socket_connection_uuid) => Self::get_cloned_client_specific(&clients_map,socket_connection_uuid).await,
523
524 };
525
526
527 Self::broadcast_to_connections(client_connections, socket_message).await
528}
529
530
531pub async fn broadcast_to_connections
532( connections: Vec<ClientConnection>, socket_message: SocketMessage)
533 -> Result<(), WebsocketServerError>
534 {
535
536
537 let message = socket_message.to_message()?;
538
539 let send_futures: Vec<_> = {
541
542 connections
543 .iter()
544 .map(|client| {
545 let message = message.clone();
546 client.send_message(message)
547 })
548 .collect()
549
550 };
551
552 let results = join_all(send_futures).await;
553
554 for result in results {
555 if let Err(err) = result {
556
557
558 degen_logger::log( format!("Failed to send a message: {}", err) , CustomLogStyle::Error ) ;
559
560
561 return Err(WebsocketServerError::SendMessageError);
562 }
563 }
564 Ok(())
565}
566
567
568
569
570pub async fn try_accept_new_connections(
571 clients_map: ClientsMap,
572 listener: TcpListener,
573 global_recv_tx: Sender<InboundMessage>, global_send_tx: Sender<OutboundMessage>, ws_server_events_tx: Sender<WebsocketSystemEvent>
578) -> std::io::Result<()> {
579
580
581 while let Ok((stream, _)) = listener.accept().await { let clients_map = Arc::clone(&clients_map);
583 let new_client_thread = tokio::spawn(
584 Self::accept_connection(
585 clients_map,
586 stream,
587 global_recv_tx.clone(),
588
589 global_send_tx.clone(),
590 ws_server_events_tx.clone()
591
592 )
593 );
594
595
596 }
597 Ok(())
598}
599
600
601
602async fn accept_connection(
603 clients: ClientsMap,
604 raw_stream: TcpStream,
605 global_socket_tx: Sender<InboundMessage>,
606
607 outbound_messages_tx: Sender<OutboundMessage> ,ws_server_events_tx: Sender<WebsocketSystemEvent>
609 ) {
610
611 let addr = raw_stream
612 .peer_addr()
613 .expect("connected streams should have a peer address")
614 .to_string();
615
616 let ws_stream = tokio_tungstenite::accept_async(raw_stream)
617 .await
618 .expect("Error during the websocket handshake occurred");
619
620
621
622 degen_logger::log( format!("New WebSocket connection: {}", addr) , CustomLogStyle::Info ) ;
623
624
625
626 let ( client_tx, mut client_rx) = ws_stream.split(); let new_client_connection = ClientConnection::new( addr.clone(), client_tx );
630
631 let client_uuid = new_client_connection.client_socket_uuid.clone();
632
633 let client_socket_uuid = new_client_connection.client_socket_uuid.clone();
634
635 clients.write().await.insert(
636 client_socket_uuid.clone(),
637 new_client_connection
638 );
639
640
641
642 while let Some(msg) = client_rx.next().await { match msg {
645 Ok(msg) => {
646 if msg.is_text() || msg.is_binary() {
647 let inbound_msg_result = InboundMessage::from_message(
652 client_uuid.clone(),
653 msg
654 );
655
656 if let Ok(inbound_msg) = inbound_msg_result {
657 global_socket_tx.try_send( inbound_msg.clone() );
658
659 let socket_message = inbound_msg.message;
660 if let MessageReliabilityType::Reliable( msg_uuid ) = socket_message.reliability_type {
661 let ack_message = SocketMessage::create_reliability_ack( msg_uuid ) ;
662
663
664
665 degen_logger::log( format!("server learned of reliable msg") , CustomLogStyle::Info ) ;
666
667
668 let send_ack_result = outbound_messages_tx.try_send (
669 OutboundMessage {
670 destination: OutboundMessageDestination::SocketConn( client_socket_uuid.clone( )),
671 message:ack_message.clone()
672 }
673 );
674
675 match send_ack_result {
676 Ok( .. ) => { degen_logger::log( format!("Server sent ack {}", ack_message) , CustomLogStyle::Info ) },
677 Err(e) => println!("{}",e)
678 }
679 }
680
681 if let SocketMessageDestination::AckToReliableMsg( reliable_msg_uuid ) = socket_message.destination {
682 let handle_ack_result = ws_server_events_tx.try_send(
683 WebsocketSystemEvent::ReceivedMessageAck { reliable_msg_uuid }
684 );
685
686 match handle_ack_result {
687 Ok( .. ) => { degen_logger::log( format!("Server handling ack") , CustomLogStyle::Info ) },
688 Err(e) => println!("{}",e)
689 }
690
691 }
692
693
694 }
695 }
696 }
697 Err(e) => {
698 eprintln!(
699 "an error occurred while processing incoming messages for {}: {:?}",
700 client_socket_uuid.clone(),
701 e
702 );
703 break;
704 }
705 }
706 }
707
708 clients.write().await.remove(&addr);
710}
711
712
713
714
715
716
717
718}
719
720
721
722
723