degen_websockets/
websocket_server.rs

1 
2
3 
4use degen_logger;
5use futures_util::StreamExt;
6use futures_util::stream::SplitSink;
7use futures_util::future::join_all;
8 
9 
10use tokio_tungstenite::WebSocketStream;
11 
12use serde::Serialize;
13use serde_json;
14
15 
16use tokio::net::{TcpListener, TcpStream};
17
18
19use futures::SinkExt; 
20use std::collections::HashMap; 
21
22
23use std::thread; 
24 
25use std::sync::Arc; 
26use tokio::sync::{RwLock,Mutex};
27 
28use std::collections::HashSet;
29 
30use tokio::time::{interval,Duration};
31
32use tokio_tungstenite::tungstenite::Message;
33
34use crate::{util::{rand::generate_random_uuid, logtypes::CustomLogStyle} };
35 
36use super::reliable_message_subsystem::ReliableMessageSubsystem;
37
38 use tokio::sync::mpsc::{channel, Sender, Receiver};
39 
40use super::websocket_messages::{
41    SocketMessage,
42    SocketMessageDestination,
43    InboundMessage,
44    OutboundMessage,
45    
46    SocketMessageError,
47    
48    MessageReliability,
49    MessageReliabilityType,
50    OutboundMessageDestination
51    
52    };
53 
54 
55type ClientsMap = Arc<RwLock<HashMap<String, ClientConnection>>>;
56 
57
58type RoomsMap = Arc<RwLock<HashMap<String, HashSet<String>>>>;
59 
60
61type TxSink = Arc<Mutex<SplitSink<WebSocketStream<TcpStream>, Message>>>;
62
63
64type RxSink = Arc<Mutex<SplitSink<WebSocketStream<TcpStream>, Message>>>;
65
66 
67
68
/// Internal events that connection tasks raise for the server's event loop.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WebsocketSystemEvent {
    /// A client acknowledged the reliable message identified by `reliable_msg_uuid`.
    ReceivedMessageAck { reliable_msg_uuid: String },
}
73
/// Errors surfaced by the websocket server.
#[derive(Debug)]
pub enum WebsocketServerError {
    /// A message could not be queued or delivered.
    SendMessageError,
    /// A serde_json failure, carried as its rendered text.
    SerdeJsonError(String),
    /// A websocket transport failure, carried as its rendered text.
    TokioError(String),
    /// A socket-message conversion failure, carried as its rendered text.
    SocketMessageErr(String),
}

impl std::fmt::Display for WebsocketServerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::SendMessageError => write!(f, "Could not send message"),
            Self::SerdeJsonError(error) => write!(f, "Serde json error: {}", error),
            Self::TokioError(error) => write!(f, "Tokio Error: {}", error),
            // Include the carried detail, as the other payload variants do,
            // so the underlying cause is not lost when the error is printed.
            Self::SocketMessageErr(error) => write!(f, "SocketMessage Error: {}", error),
        }
    }
}
95 
96 
97  impl From<crossbeam_channel::TrySendError<OutboundMessage>> for WebsocketServerError {
98        fn from(error: crossbeam_channel::TrySendError<OutboundMessage>) -> Self {
99        // Here you need to decide how you want to construct your GalaxyServerError
100        WebsocketServerError::SendMessageError//( error )
101    }
102      
103  }
104  
105  
106 
107impl From<serde_json::Error> for WebsocketServerError {
108    fn from(err: serde_json::Error) -> Self {
109        // You may want to customize this to better suit your needs
110        WebsocketServerError::SerdeJsonError(format!("Serialization error: {}", err))
111    }
112}
113 
114impl From<tokio_tungstenite::tungstenite::Error> for WebsocketServerError {
115      fn from(err: tokio_tungstenite::tungstenite::Error) -> Self {
116        WebsocketServerError::TokioError(format!("Tokio error: {}", err))
117    }
118}
119
120impl From<SocketMessageError> for WebsocketServerError {
121     fn from(err: SocketMessageError) -> Self {
122        // You may want to customize this to better suit your needs
123        WebsocketServerError::SocketMessageErr(format!("SocketMessageError: {}", err))
124    }
125    
126}
127
128
129impl std::error::Error for WebsocketServerError {}
130
131
132
133#[derive(Clone)]
134pub struct ClientConnection {
135    pub client_socket_uuid: String,
136    pub addr: String,
137    pub tx_sink: TxSink,
138}
139
140impl ClientConnection {
141
142    pub fn new( addr:String, client_tx: SplitSink<WebSocketStream<tokio::net::TcpStream>, Message> ) -> Self{
143
144        Self {  
145            client_socket_uuid: generate_random_uuid(),
146            addr: addr.clone(), 
147            tx_sink: Arc::new(Mutex::new( client_tx )) 
148        }
149
150
151    }
152
153    pub async fn send_message(&self, msg: Message) -> Result<(), tokio_tungstenite::tungstenite::error::Error> {
154        self.tx_sink.lock().await.send(msg).await
155    } 
156
157}
158
159
160
161 
162
163pub struct WebsocketServer{
164
165    clients: ClientsMap,
166
167    rooms: RoomsMap, // room name -> Set[client_uuid]
168
169
170    
171    global_recv_tx: Sender<InboundMessage>, //passed to each client connection 
172    global_recv_rx: Option<Receiver<InboundMessage>>,
173
174    global_send_tx: Sender<OutboundMessage>,
175    global_send_rx: Option<Receiver<OutboundMessage>>, 
176    
177    pending_reliable_messages: Arc<RwLock<HashMap<String,OutboundMessage>>>,
178    
179    
180    ws_server_events_tx: Sender<WebsocketSystemEvent>,
181    ws_server_events_rx: Option<Receiver<WebsocketSystemEvent>>, 
182 
183}
184
185impl WebsocketServer {
186
187
188    pub fn new() -> Self {
189
190
191        let (global_recv_tx, global_recv_rx): (Sender<InboundMessage>, Receiver<InboundMessage>) = channel(2500);
192     
193        let (global_send_tx, global_send_rx): (Sender<OutboundMessage>, Receiver<OutboundMessage>) = channel(2500);
194       
195    
196        let (ws_server_events_tx, ws_server_events_rx): (Sender<WebsocketSystemEvent>, Receiver<WebsocketSystemEvent>) = channel(2500);
197        
198        
199        Self {
200            clients: Arc::new(RwLock::new(HashMap::new())),
201            rooms: Arc::new(RwLock::new(HashMap::new())),
202            global_recv_tx,
203            global_recv_rx:Some(global_recv_rx),
204            global_send_tx,
205            global_send_rx:Some(global_send_rx),
206            pending_reliable_messages: Arc::new(RwLock::new(HashMap::new())),
207            ws_server_events_tx,
208            ws_server_events_rx:Some(ws_server_events_rx)
209       
210        }
211    }
212    
213  
214    pub async fn start(&mut self, url:Option<String>) -> std::io::Result<()> {
215            
216        let clients = Arc::clone(&self.clients); 
217        let rooms = Arc::clone(&self.rooms);
218        let pending_reliable_messages = Arc::clone(&self.pending_reliable_messages);
219     
220     
221        let global_recv_tx = self.global_recv_tx.clone();
222
223        let global_send_rx = self.global_send_rx.take().unwrap();
224        
225        let global_send_tx = self.global_send_tx.clone();
226        
227        let ws_server_events_tx = self.ws_server_events_tx.clone();
228        let ws_server_events_rx = self.ws_server_events_rx.take().unwrap();
229        
230        
231        let addr: String = url.unwrap_or_else(|| "127.0.0.1:8080".to_string());
232        // Create the event loop and TCP listener we'll accept connections on.
233        let try_socket = TcpListener::bind(&addr).await;
234        let listener = try_socket.expect("Failed to bind");
235      
236        
237        degen_logger::log(format!("Listening on: {}", addr)  , CustomLogStyle::Info  ) ; 
238    
239        let accept_connections = 
240          Self::try_accept_new_connections(  
241            Arc::clone(&clients),
242             listener,
243             global_recv_tx,
244             global_send_tx.clone(),
245             ws_server_events_tx
246               );
247               
248               
249        let send_outbound_messages = Self::try_send_outbound_messages( 
250            Arc::clone(&clients) , 
251            Arc::clone(&rooms),
252            global_send_rx
253         );
254         
255        let resend_reliable_messages = ReliableMessageSubsystem::resend_reliable_messages(
256            Arc::clone(&pending_reliable_messages),  
257            global_send_tx
258        );
259         
260        let handle_server_events = Self::handle_server_events(
261                ws_server_events_rx ,
262                Arc::clone(&pending_reliable_messages   )                     
263            );
264
265        let accept_conn_future =accept_connections;
266        
267        let send_outbound_messages_future =  send_outbound_messages;
268        
269        let resend_reliable_messages_future =  resend_reliable_messages;
270        
271        let handle_server_events_future =  handle_server_events;
272        
273        tokio::select! {
274            _ = accept_conn_future => eprintln!("accept_conn_handle finished"),
275            _ = send_outbound_messages_future => eprintln!("send_outbound_messages_handle finished"),
276            _ = resend_reliable_messages_future => eprintln!("resend_reliable_messages_handle finished"),
277            _ = handle_server_events_future => eprintln!("handle_server_events_handle finished"),
278        }
279            
280            
281       
282        
283          degen_logger::log( format!("WS WARN: TOKIO SELECT DROPPED")  , CustomLogStyle::Error  ) ; 
284    
285        Ok(())
286    }
287 
288
289    //recv'd client messages are fed into here 
290    pub fn take_recv_channel(&mut self) -> Option<Receiver<InboundMessage>> {
291        self.global_recv_rx.take() 
292    }
293
294    pub fn get_send_channel(&self) -> Sender<OutboundMessage> {
295        self.global_send_tx.clone()
296    }
297    
298    
299    
300   
301    pub async fn send_socket_message(&self,  socket_message: SocketMessage, destination: OutboundMessageDestination)
302     -> Result<(), WebsocketServerError> {
303         
304        let reliability_type = socket_message.clone().reliability_type;
305       
306         
307         let outbound_message = OutboundMessage { 
308             destination, 
309             message:socket_message
310             };
311         
312       
313        if let MessageReliabilityType::Reliable(msg_uuid) = reliability_type {
314             self.pending_reliable_messages.write().await.insert(msg_uuid, outbound_message.clone());
315         }
316         
317          self.send_outbound_message(outbound_message)?;
318        
319         
320        Ok(())
321    }
322    
323     
324    
325
326    pub fn send_outbound_message(&self, msg:OutboundMessage) 
327     -> Result<(), WebsocketServerError> {
328 
329       
330            
331            
332        self.get_send_channel().try_send( msg )  .map_err(|_| WebsocketServerError::SendMessageError) //.map_err(|_|   Err(WebsocketServerError::SendMessageError) )
333
334       
335    }
336    
337
338 
339
340
341async fn get_cloned_clients(clients: &ClientsMap) -> Vec<ClientConnection> {
342    let clients_map = clients.read().await;
343    clients_map.values().cloned().collect()
344}
345
346
347async fn get_cloned_clients_in_room(clients: &ClientsMap, rooms: &RoomsMap, room_name: String ) -> Vec<ClientConnection>  {
348
349    let client_connection_uuids = Vec::new();
350
351
352    let rooms = rooms.read().await;
353
354    match rooms.get(&room_name) {
355        Some(uuid_set) => {}
356        None => {}
357    }
358
359    return Self::get_cloned_clients_filtered(clients, client_connection_uuids).await;
360}
361
362
363 async fn get_cloned_clients_filtered(clients: &ClientsMap, client_connection_uuids: Vec<String>  ) -> Vec<ClientConnection> {
364    let clients_map = clients.read().await;
365
366    let mut filtered_clients: Vec<ClientConnection> = Vec::new();
367    
368    for uuid in client_connection_uuids {
369        if let Some(client_conn) = clients_map.get(&uuid) {
370            filtered_clients.push(client_conn.clone());
371        }
372    }
373
374    filtered_clients
375}
376
377
378 async fn get_cloned_client_specific(clients: &ClientsMap, socket_connection_uuid: String  ) -> Vec<ClientConnection> {
379    let clients_map = clients.read().await;
380
381    let mut filtered_clients: Vec<ClientConnection> = Vec::new();
382     
383    if let Some(client_conn) = clients_map.get(&socket_connection_uuid) {
384        filtered_clients.push(client_conn.clone());
385    }
386
387
388    filtered_clients
389}
390
391
392
393 
394
395//need to put a loop AROUND the while let 
396async fn handle_server_events( 
397    mut ws_event_rx: Receiver<WebsocketSystemEvent>,
398    pending_reliable_messages: Arc<RwLock<HashMap<String, OutboundMessage>>>,
399) -> std::io::Result<()>  {
400    
401    loop {
402         
403    while let Some(evt) = ws_event_rx.recv().await { //pass control abck to executor 
404        
405         
406          
407        degen_logger::log( format!("ws server handling server event")  , CustomLogStyle::Info  ) ; 
408          
409        match evt {
410            
411          
412            
413            WebsocketSystemEvent::ReceivedMessageAck { reliable_msg_uuid } => {
414                //pending_reliable_messages.write().await.
415                Self::clear_pending_reliable_message(
416                    Arc::clone(&pending_reliable_messages),
417                    reliable_msg_uuid
418                ).await;
419            }
420            
421        };
422        
423    }
424    }
425    
426    
427   // Ok(())
428}
429
430
431//when we receive an ACK with this message uuid, we clear 
432pub async fn clear_pending_reliable_message( 
433    pending_reliable_messages: Arc<RwLock<HashMap<String, OutboundMessage>>>,
434    message_uuid: String,
435    
436) {
437    let mut messages = pending_reliable_messages.write().await ;
438    messages.remove(&message_uuid) ;
439}
440 
441
442
443pub async fn try_send_outbound_messages( 
444    clients_map: ClientsMap, 
445    rooms_map: RoomsMap,
446    mut global_send_rx: Receiver<OutboundMessage> 
447) -> std::io::Result<()> {
448 
449
450    loop {      
451        
452        //await gives control back to executor 
453        match global_send_rx.recv().await {
454            Some(msg) => {
455                     
456               // let message = msg;
457                let clients_map = Arc::clone(&clients_map);
458                let rooms_map = Arc::clone(&rooms_map);
459
460              
461
462                Self::broadcast(clients_map, rooms_map, msg).await;
463
464
465            }
466          
467            None   =>  {} ,
468        }
469    }
470
471   // Ok(())
472
473}
474
475
476
477
478    pub async fn add_client_to_room(&self, client_connection_uuid:String, room_name: String ) {
479        let mut rooms = self.rooms.write().await;
480    
481        let room_clients = rooms.entry(room_name).or_insert_with(HashSet::new);
482        room_clients.insert(client_connection_uuid);
483    }
484    
485    pub async fn remove_client_from_room(&self, client_connection_uuid:String, room_name: String ) {
486        let mut rooms = self.rooms.write().await;
487    
488        if let Some(room_clients) = rooms.get_mut(&room_name) {
489            room_clients.remove(&client_connection_uuid);
490            
491            // Optionally, you can remove the room if it's now empty
492            if room_clients.is_empty() {
493                rooms.remove(&room_name);
494            }
495        }
496    }
497    
498 
499
500
501
502
503
504
505pub async fn broadcast( 
506    clients_map: ClientsMap, 
507    rooms_map:RoomsMap,  
508    outbound_message: OutboundMessage
509) -> Result<(), WebsocketServerError> {
510       
511   
512    
513      degen_logger::log(  format!("ws_server broadcasting msg: {} ", outbound_message.message  ) , CustomLogStyle::Info  ) ; 
514
515    let socket_message = outbound_message.message;
516         
517
518    let client_connections = match outbound_message.destination {
519
520       OutboundMessageDestination::All =>  Self::get_cloned_clients(&clients_map).await,
521       OutboundMessageDestination::Room(room_name) => Self::get_cloned_clients_in_room(&clients_map,&rooms_map,room_name).await,
522       OutboundMessageDestination::SocketConn(socket_connection_uuid) => Self::get_cloned_client_specific(&clients_map,socket_connection_uuid).await,
523     
524    };
525
526 
527    Self::broadcast_to_connections(client_connections, socket_message).await
528}
529
530
531pub async fn broadcast_to_connections
532( connections: Vec<ClientConnection>, socket_message: SocketMessage)
533 -> Result<(), WebsocketServerError> 
534  {
535         
536
537   let message = socket_message.to_message()?;
538 
539    //Could cause thread lock issue !? 
540    let send_futures: Vec<_> = {
541
542        connections 
543            .iter()
544            .map(|client| {
545                let message = message.clone();
546                client.send_message(message)
547            })
548            .collect()
549
550    };
551
552    let results = join_all(send_futures).await;
553
554    for result in results {
555        if let Err(err) = result {
556           
557            
558                degen_logger::log(   format!("Failed to send a message: {}", err) , CustomLogStyle::Error  ) ; 
559            
560            
561            return Err(WebsocketServerError::SendMessageError);
562        }
563    }
564    Ok(())
565}
566 
567 
568 
569
570pub async fn try_accept_new_connections(
571    clients_map: ClientsMap, 
572    listener: TcpListener, 
573    global_recv_tx: Sender<InboundMessage>, //so we can tell outer process about what we recv'd
574    
575    global_send_tx: Sender<OutboundMessage>, //so we can send ACK packets if got reliable 
576
577    ws_server_events_tx: Sender<WebsocketSystemEvent>
578) -> std::io::Result<()> {
579 
580                                                  
581     while let Ok((stream, _)) = listener.accept().await {  //pass control back to executor
582        let clients_map =  Arc::clone(&clients_map);
583         let new_client_thread = tokio::spawn(
584            Self::accept_connection(
585                clients_map, 
586                stream, 
587                global_recv_tx.clone(),
588                
589                global_send_tx.clone(),
590                ws_server_events_tx.clone()
591                
592            ) 
593            );
594            
595          
596    }
597    Ok(())
598}
599 
600 
601
602async fn accept_connection(
603    clients: ClientsMap, 
604    raw_stream: TcpStream,
605     global_socket_tx: Sender<InboundMessage>,
606     
607     outbound_messages_tx: Sender<OutboundMessage> ,//for sending ack packets
608     ws_server_events_tx: Sender<WebsocketSystemEvent>
609    ) {
610
611    let addr = raw_stream
612        .peer_addr()
613        .expect("connected streams should have a peer address")
614        .to_string();
615
616    let ws_stream = tokio_tungstenite::accept_async(raw_stream)
617        .await
618        .expect("Error during the websocket handshake occurred");
619
620    
621    
622        degen_logger::log(  format!("New WebSocket connection: {}", addr)  , CustomLogStyle::Info  ) ; 
623            
624            
625
626    let (   client_tx, mut client_rx) = ws_stream.split();   //this is how i can read and write to this client 
627    
628
629    let new_client_connection = ClientConnection::new( addr.clone(),  client_tx  );
630    
631    let client_uuid = new_client_connection.client_socket_uuid.clone();
632    
633    let client_socket_uuid = new_client_connection.client_socket_uuid.clone();
634
635    clients.write().await.insert(
636         client_socket_uuid.clone(), 
637         new_client_connection 
638        );
639
640   
641
642         //in this new thread for the socket connection, recv'd messages are constantly collected
643    while let Some(msg) = client_rx.next().await {  //pass control back to executor
644        match msg {
645            Ok(msg) => {
646                if msg.is_text() || msg.is_binary() {
647                    //let data = msg.clone().into_data();
648                    //println!("Received a message from {}: {:?}", addr, data);
649                    // here you can consume your messages
650
651                    let inbound_msg_result = InboundMessage::from_message(
652                        client_uuid.clone(),
653                        msg
654                    );
655
656                    if let Ok(inbound_msg) = inbound_msg_result {
657                      global_socket_tx.try_send( inbound_msg.clone()  );
658                      
659                      let socket_message = inbound_msg.message;
660                      if let MessageReliabilityType::Reliable( msg_uuid ) =  socket_message.reliability_type {
661                        let ack_message =  SocketMessage::create_reliability_ack( msg_uuid ) ;
662 
663                                
664    
665                                 degen_logger::log(  format!("server learned of reliable msg")  , CustomLogStyle::Info  ) ; 
666 
667                        
668                              let send_ack_result =  outbound_messages_tx.try_send (
669                                        OutboundMessage {
670                                            destination: OutboundMessageDestination::SocketConn( client_socket_uuid.clone( )),
671                                            message:ack_message.clone()
672                                        }
673                                    );
674                                    
675                        match send_ack_result {
676                             Ok( .. ) => {    degen_logger::log(  format!("Server sent ack {}", ack_message)  , CustomLogStyle::Info  )   },
677                             Err(e) => println!("{}",e)
678                         }     
679                      }
680                      
681                      if let SocketMessageDestination::AckToReliableMsg( reliable_msg_uuid ) = socket_message.destination   {
682                         let handle_ack_result =  ws_server_events_tx.try_send( 
683                           WebsocketSystemEvent::ReceivedMessageAck { reliable_msg_uuid }   
684                          );    
685                          
686                         match handle_ack_result {
687                             Ok( .. ) => {  degen_logger::log(  format!("Server handling ack")  , CustomLogStyle::Info  )     },
688                             Err(e) => println!("{}",e)
689                         }                  
690                          
691                      }                    
692                      
693                      
694                    }
695                }
696            }
697            Err(e) => {
698                eprintln!(
699                    "an error occurred while processing incoming messages for {}: {:?}",
700                    client_socket_uuid.clone(),
701                    e
702                );
703                break;
704            }
705        }
706    }
707
708    // Remove the client from the map once it has disconnected.
709    clients.write().await.remove(&addr);
710}
711 
712
713
714 
715   
716
717
718} 
719
720
721
722
723