//! flnet 0.7.2
//!
//! Network setup and communication.
//! Documentation
use async_trait::async_trait;
use futures::{
    lock::Mutex,
    stream::{SplitSink, SplitStream},
    SinkExt, StreamExt,
};
use std::sync::Arc;
use tokio::{
    net::{TcpListener, TcpStream},
    select,
    sync::oneshot,
    task::JoinHandle,
};
use tokio_tungstenite::{accept_async, tungstenite::Message, WebSocketStream};

use crate::websocket::{WSError, WSSError, WSServerInput, WSServerMessage, WSServerOutput};
use flmodules::broker::{Broker, Subsystem, SubsystemHandler};

/// WebSocket server subsystem: accepts TCP connections, upgrades them to
/// WebSockets and relays messages through a `Broker<WSServerMessage>`.
pub struct WebSocketServer {
    // Accepted connections. `WSServerInput::Message` / `WSServerInput::Close`
    // look a connection up by using its id as the index into this vector.
    connections: Arc<Mutex<Vec<WSConnection>>>,
    // Task running the accept loop; aborted on `WSServerInput::Stop`.
    conn_thread: JoinHandle<()>,
}

impl WebSocketServer {
    pub async fn new(port: u16) -> Result<Broker<WSServerMessage>, WSSError> {
        let server = TcpListener::bind(format!("0.0.0.0:{}", port)).await?;
        let connections = Arc::new(Mutex::new(Vec::new()));
        let connections_cl = Arc::clone(&connections);
        let mut broker = Broker::new();
        let mut broker_cl = broker.clone();
        let conn_thread = tokio::spawn(async move {
            let mut connection_id = 0;
            loop {
                if let Ok((stream, _)) = server.accept().await {
                    let broker_cl2 = broker_cl.clone();
                    match WSConnection::new(stream, broker_cl2, connection_id).await {
                        Ok(conn) => {
                            log::trace!("Got new connection");
                            connections_cl.lock().await.push(conn);
                            broker_cl
                                .emit_msg(WSServerMessage::Output(WSServerOutput::NewConnection(
                                    connection_id,
                                )))
                                .expect("Error sending connect message");
                        }
                        Err(e) => log::error!("Error while getting connection: {:?}", e),
                    }
                    connection_id += 1;
                }
            }
        });

        broker
            .add_subsystem(Subsystem::Handler(Box::new(WebSocketServer {
                connections,
                conn_thread,
            })))
            .await?;

        Ok(broker)
    }
}

#[async_trait]
impl SubsystemHandler<WSServerMessage> for WebSocketServer {
    async fn messages(&mut self, from_broker: Vec<WSServerMessage>) -> Vec<WSServerMessage> {
        for msg in from_broker {
            if let WSServerMessage::Input(msg_in) = msg {
                match msg_in {
                    WSServerInput::Message(id, msg) => {
                        let mut connections = self.connections.lock().await;
                        if let Some(conn) = connections.get_mut(id) {
                            if let Err(e) = conn.send(msg).await {
                                log::error!("Error while sending: {e}");
                                conn.close();
                            }
                        }
                    }
                    WSServerInput::Close(id) => {
                        let mut connections = self.connections.lock().await;
                        if let Some(conn) = connections.get_mut(id) {
                            conn.close();
                        }
                    }
                    WSServerInput::Stop => {
                        log::warn!("Stopping thread");
                        self.conn_thread.abort();
                        return vec![WSServerMessage::Output(WSServerOutput::Stopped)];
                    }
                }
            }
        }
        vec![]
    }
}

/// One accepted WebSocket connection, as held by the server.
pub struct WSConnection {
    // Write half of the split WebSocket; outgoing text frames are sent here.
    websocket: SplitSink<WebSocketStream<TcpStream>, Message>,
    // Signals the reader task (spawned by `loop_read`) to stop.
    // Taken, leaving `None`, by `close`.
    tx: Option<oneshot::Sender<bool>>,
}

impl WSConnection {
    async fn new(
        stream: TcpStream,
        broker: Broker<WSServerMessage>,
        id: usize,
    ) -> Result<WSConnection, WSError> {
        let websocket = accept_async(stream)
            .await
            .map_err(|e| WSError::Underlying(e.to_string()))?;
        let (wr, rd) = websocket.split();
        let (tx, rx) = oneshot::channel();
        let uwsc = WSConnection {
            websocket: wr,
            tx: Some(tx),
        };

        WSConnection::loop_read(broker, rd, rx, id).await;
        Ok(uwsc)
    }

    async fn loop_read(
        mut broker: Broker<WSServerMessage>,
        mut ws: SplitStream<WebSocketStream<TcpStream>>,
        mut rx: oneshot::Receiver<bool>,
        id: usize,
    ) {
        tokio::spawn(async move {
            loop {
                select! {
                    _ = (&mut rx) => {
                        broker
                        .emit_msg(WSServerMessage::Output(WSServerOutput::Disconnection(id)))
                        .expect("While sending message to broker.");
                        return;
                    },
                    ws_out = ws.next() =>
                        if let Some(msg_ws) = ws_out {
                            if let Some(out) = match msg_ws {
                                Ok(msg) => match msg {
                                    Message::Text(s) => {
                                        Some(WSServerMessage::Output(WSServerOutput::Message(id, s)))
                                    }
                                    Message::Close(_) => {
                                        Some(WSServerMessage::Output(WSServerOutput::Disconnection(id)))
                                    }
                                    _ => None,
                                },
                                Err(e) => {
                                    log::warn!("Closing connection because of error: {e:?}");
                                    Some(WSServerMessage::Output(WSServerOutput::Disconnection(id)))
                                }
                            } {
                                broker
                                    .emit_msg(out.clone())
                                    .expect("While sending message to broker.");
                                if matches!(
                                    out,
                                    WSServerMessage::Output(WSServerOutput::Disconnection(_))
                                ) {
                                    return;
                                }
                            }
                        }
                }
            }
        });
    }

    async fn send(&mut self, msg: String) -> Result<(), WSError> {
        self.websocket
            .send(Message::Text(msg))
            .await
            .map_err(|e| WSError::Underlying(e.to_string()))
    }

    fn close(&mut self) {
        if let Some(tx) = self.tx.take() {
            tx.send(true)
                .err()
                .map(|_| log::error!("Closing Websocket failed"));
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::web_socket_client::WebSocketClient;
    use crate::websocket::{WSClientInput, WSClientMessage, WSClientOutput};
    use flarch::start_logging;
    use flmodules::broker::Destination;
    use std::{sync::mpsc::Receiver, time::Duration};

    /// How long to wait for a tapped message before failing, instead of
    /// letting a lost message hang the test forever.
    const TAP_TIMEOUT: Duration = Duration::from_secs(5);

    /// Sends `txt` from `client` and asserts the server reports it as coming
    /// from connection `ch_index`.
    async fn send_client_server(
        client: &mut Broker<WSClientMessage>,
        server_tap: &Receiver<WSServerMessage>,
        ch_index: usize,
        txt: String,
    ) {
        client
            .emit_msg_dest(
                Destination::NoTap,
                WSClientInput::Message(txt.clone()).into(),
            )
            .unwrap();
        assert_eq!(
            server_tap.recv_timeout(TAP_TIMEOUT).unwrap(),
            WSServerMessage::Output(WSServerOutput::Message(ch_index, txt))
        );
    }

    /// Sends `txt` from the server to connection `ch_index` and asserts the
    /// client receives it.
    async fn send_server_client(
        server: &mut Broker<WSServerMessage>,
        client_tap: &Receiver<WSClientMessage>,
        ch_index: usize,
        txt: String,
    ) {
        server
            .emit_msg_dest(
                Destination::NoTap,
                WSServerInput::Message(ch_index, txt.clone()).into(),
            )
            .unwrap();
        assert_eq!(
            client_tap.recv_timeout(TAP_TIMEOUT).unwrap(),
            WSClientOutput::Message(txt).into()
        );
    }

    // `multi_thread` because the std `Receiver` waits below block a worker thread.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_server() {
        start_logging();
        let mut server = WebSocketServer::new(8080).await.unwrap();
        let (server_tap, _) = server.get_tap_sync().await.unwrap();

        let mut client1 = WebSocketClient::connect("ws://localhost:8080")
            .await
            .unwrap();
        let (client1_tap, _) = client1.get_tap_sync().await.unwrap();
        log::debug!(
            "Server reply from client 1: {:?}",
            server_tap.recv_timeout(TAP_TIMEOUT)
        );

        let mut client2 = WebSocketClient::connect("ws://localhost:8080")
            .await
            .unwrap();
        let (client2_tap, _) = client2.get_tap_sync().await.unwrap();
        log::debug!(
            "Server reply from client 2: {:?}",
            server_tap.recv_timeout(TAP_TIMEOUT)
        );

        for _ in 1..=2 {
            send_client_server(&mut client1, &server_tap, 0, "Hello 1".to_string()).await;
            send_server_client(&mut server, &client1_tap, 0, "there 1".to_string()).await;

            send_client_server(&mut client2, &server_tap, 1, "Hello 2".to_string()).await;
            send_server_client(&mut server, &client2_tap, 1, "there 2".to_string()).await;
        }

        // `emit_msg` is synchronous (it is used that way by the server code above).
        client1.emit_msg(WSClientInput::Disconnect.into()).unwrap();
        client2.emit_msg(WSClientInput::Disconnect.into()).unwrap();
        server
            .emit_msg(WSServerMessage::Input(WSServerInput::Stop))
            .unwrap();
    }
}