// bevy_websocket_adapter 0.1.5
//
// Simple adapter to receive WebSocket messages in your Bevy games as native Rust types.
use crate::shared::{MessageType, SendEnveloppe};
use crossbeam_channel::{unbounded, Receiver, Sender, TryRecvError};
use futures::{join, pending};
use futures_util::{future as ufuture, stream::TryStreamExt, SinkExt, StreamExt};
use log::{debug, trace, warn};
use serde::Serialize;
use std::{
    collections::HashMap,
    sync::{Arc, Mutex},
};
use uuid::Uuid;
use thiserror::Error as TError;
use tokio::{
    net::{TcpListener, ToSocketAddrs},
    runtime::Runtime,
    task::JoinHandle,
};
use crate::shared::{
    NetworkEvent,
    ConnectionHandle
};

/// Error returned by [`Server::listen`] when the server cannot be configured.
///
/// The enum currently has no variants: `listen` only schedules the accept loop
/// on the server's runtime and always returns `Ok(())`.
#[derive(Debug, TError)]
pub enum ServerConfigError {}


pub struct Server {
    rt: Arc<Runtime>,
    server_handle: Option<JoinHandle<()>>,
    sessions_events: Arc<Mutex<HashMap<Uuid, Arc<Receiver<NetworkEvent>>>>>,
    sessions_sinks: Arc<Mutex<HashMap<Uuid, Arc<Sender<tokio_tungstenite::tungstenite::Message>>>>>,
    sessions_handles: Arc<Mutex<HashMap<Uuid, JoinHandle<()>>>>,
}

impl Default for Server {
    fn default() -> Self {
        Self::new()
    }
}

impl Server {
    pub fn new() -> Server {
        Server {
            rt: Arc::new(
                tokio::runtime::Builder::new_multi_thread()
                    .enable_all()
                    .build()
                    .expect("Could not build tokio runtime"),
            ),
            server_handle: None,
            sessions_events: Arc::new(Mutex::new(
                HashMap::<Uuid, Arc<Receiver<NetworkEvent>>>::new(),
            )),
            sessions_sinks: Arc::new(Mutex::new(HashMap::<
                Uuid,
                Arc<Sender<tokio_tungstenite::tungstenite::Message>>,
            >::new())),
            sessions_handles: Arc::new(Mutex::new(HashMap::<Uuid, JoinHandle<()>>::new())),
        }
    }

    pub fn is_running(&self) -> bool {
        self.server_handle.is_some()
    }

    pub fn listen(
        &mut self,
        addr: impl ToSocketAddrs + Send + 'static,
    ) -> Result<(), ServerConfigError> {
        self.start_listen_loop(addr)?;
        Ok(())
    }

    pub fn stop(&mut self) {
        if let Some(conn) = self.server_handle.take() {
            debug!("stopping WS accept loop");
            conn.abort();
        }
        for (k, conn) in self.sessions_handles.lock().unwrap().drain() {
            debug!("aborting session {}", k);
            conn.abort();
        }
    }

    pub fn recv(&self) -> Option<NetworkEvent> {
        let mut sel = crossbeam_channel::Select::new();
        let mut ids = Vec::<Uuid>::new();
        let mut receivers = Vec::new();
        {
            let chs = self.sessions_events.lock().unwrap();
            for (i, rx) in chs.iter() {
                ids.push(*i);
                receivers.push(rx.clone());
            }
        }
        if receivers.is_empty() {
            return None;
        }
        for rx in receivers.iter() {
            sel.recv(&**rx);
        }
        let mut r = None;
        while r.is_none() {
            let msg;
            let index = sel.try_ready();
            if index.is_err() {
                return None;
            }
            {
                let chs = self.sessions_events.lock().unwrap();
                let res = chs.get(&ids[index.unwrap()]);
                res?;
                msg = Some(res.unwrap().recv());
            }
            let sess_id = ids[index.unwrap()];
            r = match msg {
                Some(Err(_e)) => {
                    self.sessions_events.lock().unwrap().remove(&sess_id);
                    self.sessions_handles.lock().unwrap().remove(&sess_id);
                    debug!("connection closed for handle {}", sess_id);
                    None
                }
                Some(Ok(m)) => match m {
                    NetworkEvent::Error(_, _) => {
                        self.sessions_events.lock().unwrap().remove(&sess_id);
                        self.sessions_handles.lock().unwrap().remove(&sess_id);
                        Some(m)
                    }
                    _ => Some(m),
                },
                None => {
                    panic!("none message, this should never happens.")
                }
            };
        }
        r
    }

    fn start_listen_loop(
        &mut self,
        addr: impl ToSocketAddrs + Send + 'static,
    ) -> Result<(), ServerConfigError> {
        let rt = self.rt.clone();
        let sessions_events = self.sessions_events.clone();
        let sessions_handles = self.sessions_handles.clone();
        let sessions_sinks = self.sessions_sinks.clone();

        let listen_loop = async move {
            let try_socket = TcpListener::bind(addr).await;
            let listener = try_socket.expect("Failed to bind");
            while let Ok((socket, addr)) = listener.accept().await {
                debug!("new connection from {:?}", addr);
                let client_handle = ConnectionHandle::new();
                let handle_id = client_handle.id();
                let (ev_tx, ev_rx) = unbounded();
                let (from_handler_tx, from_handler_rx) = unbounded();

                let receiver = Arc::new(ev_rx);
                let sender = Arc::new(from_handler_tx);

                let handle = async move {
                    let ws_stream = tokio_tungstenite::accept_async(socket)
                        .await
                        .expect("Error during the websocket handshake occurred");
                    ev_tx
                        .send(NetworkEvent::Connected(client_handle.clone()))
                        .expect("failed to send network event");
                    let (mut outgoing, incoming) = ws_stream.split();
                    let handle_incoming = incoming.try_for_each(|msg| {
                        match msg {
                            tokio_tungstenite::tungstenite::Message::Binary(bts) => {
                                ev_tx
                                    .send(NetworkEvent::Message(client_handle.clone(), bts))
                                    .expect("failed to send network event");
                            }
                            tokio_tungstenite::tungstenite::Message::Close(_) => {
                                ev_tx
                                    .send(NetworkEvent::Disconnected(client_handle.clone()))
                                    .expect("failed to send network event");
                            }
                            _ => {
                                warn!("unsupported format for message: {:?}", msg);
                            }
                        }

                        ufuture::ok(())
                    });
                    let forward_handle = async move {
                        loop {
                            let req = from_handler_rx.try_recv();
                            match req {
                                Err(TryRecvError::Empty) => {
                                    pending!()
                                }
                                Err(e) => {
                                    warn!(
                                        "failed to forward message to client sink {:?} : {}",
                                        handle_id, e
                                    );
                                }
                                Ok(ev) => {
                                    if let Err(e) = outgoing.send(ev).await {
                                        warn!(
                                            "failed to send message to client {:?} : {}",
                                            handle_id, e
                                        );
                                    }
                                }
                            }
                        }
                    };
                    if let (_, Err(e)) = join!(forward_handle, handle_incoming) {
                        warn!("failure in connection handling: {:?}", e);
                    }
                };

                let session_handle = rt.spawn(handle);
                sessions_handles
                    .lock()
                    .unwrap()
                    .insert(handle_id, session_handle);
                sessions_events
                    .lock()
                    .unwrap()
                    .insert(handle_id, receiver.clone());
                sessions_sinks.lock().unwrap().insert(handle_id, sender);
            }
        };

        trace!("WS server started listening");

        self.server_handle = Some(self.rt.spawn(listen_loop));

        Ok(())
    }

    pub fn send_message<T: MessageType + Serialize + Clone>(
        &self,
        handle: &ConnectionHandle,
        msg: &T,
    ) {
        let sev = SendEnveloppe {
            message_type: T::message_type().to_string(),
            payload: msg.clone(),
        };
        let payload =
            tokio_tungstenite::tungstenite::Message::Binary(serde_json::to_vec(&sev).unwrap());
        self.send_raw_message(handle, payload)
    }

    pub fn send_raw_message(
        &self,
        handle: &ConnectionHandle,
        msg: tokio_tungstenite::tungstenite::Message,
    ) {
        let client;
        {
            let map = self.sessions_sinks.lock().unwrap();
            client = map.get(&handle.id()).cloned();
        }
        if let Some(channel) = client {
            if let Err(e) = channel.send(msg) {
                warn!(
                    "failed to forward message to client {:?} sink: {:?}",
                    handle, e
                );
            }
        } else {
            warn!(
                "trying to send to a non existing client handle {:?}",
                handle
            );
        }
    }

    pub fn broadcast<T: MessageType + Serialize + Clone>(&self, msg: T) {
        let sev = SendEnveloppe {
            message_type: T::message_type().to_string(),
            payload: msg,
        };
        let payload =
            tokio_tungstenite::tungstenite::Message::Binary(serde_json::to_vec(&sev).unwrap());
        let clients;
        {
            let map = self.sessions_sinks.lock().unwrap();
            clients = map.keys().cloned().collect::<Vec<Uuid>>();
        }
        for c in clients {
            self.send_raw_message(&ConnectionHandle { uuid: c }, payload.clone());
        }
    }
}