endpoint-libs 1.7.8

Common dependencies to be used with Pathscale projects, projects that use [endpoint-gen](https://github.com/pathscale/endpoint-gen), and projects that use honey_id-types.
Documentation
use eyre::Result;
use serde_json::Value;
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::mpsc;
use tracing::*;

use crate::libs::ws::WsMessage as Message;

use crate::libs::error_code::ErrorCode;
use crate::libs::toolbox::{RequestContext, TOOLBOX};

use super::{
    StreamError, WebsocketServer, WsConnection, WsRequestValue, WsStream, request_error_to_resp,
};

pub struct WsClientSession {
    conn_info: Arc<WsConnection>,
    conn: Box<dyn WsStream>,
    rx: mpsc::Receiver<Message>,
    server: Arc<WebsocketServer>,
}

impl WsClientSession {
    pub fn new(
        conn_info: Arc<WsConnection>,
        conn: Box<dyn WsStream>,
        rx: mpsc::Receiver<Message>,
        server: Arc<WebsocketServer>,
    ) -> Self {
        Self {
            conn_info,
            conn,
            rx,
            server,
        }
    }

    pub fn conn(&self) -> &dyn WsStream {
        self.conn.as_ref()
    }

    pub async fn run(mut self) {
        let addr = self.conn_info.address;
        let conn_id = self.conn_info.connection_id;
        if let Err(err) = self.run_loop().await {
            error!(
                ws_server = true,
                ?err,
                ?addr,
                ?conn_id,
                "Failed to run websocket session"
            );
        }
    }

    fn handle_message(&mut self, msg: Message) -> Result<bool> {
        let addr = &self.conn_info.address;
        let mut context = RequestContext::from_conn(&self.conn_info);

        #[allow(unreachable_patterns)]
        let obj: Result<WsRequestValue, _> = match msg {
            Message::Text(t) => {
                debug!(ws_server = true, ?addr, "Handling request {}", t);
                serde_json::from_str(&t)
            }
            Message::Binary(b) => {
                debug!(ws_server = true, ?addr, "Handling request <BIN>");
                serde_json::from_slice(&b)
            }
            Message::Ping(_) => {
                return Ok(true);
            }
            Message::Pong(_) => {
                return Ok(true);
            }
            Message::Close(_) => {
                debug!(ws_server = true, ?addr, "Receive side terminated");
                return Ok(false);
            }
            _ => {
                warn!(ws_server = true, ?addr, "Strange pattern {:?}", msg);
                return Ok(true);
            }
        };
        let req = match obj {
            Ok(req) => req,
            Err(err) => {
                self.server.toolbox.send(
                    context.connection_id,
                    request_error_to_resp(&context, ErrorCode::BAD_REQUEST, err.to_string()),
                );
                return Ok(true);
            }
        };
        context.seq = req.seq;
        context.method = req.method;
        context.user_id = self.conn_info.get_user_id();
        context.roles = self.conn_info.get_roles();

        let Some(endpoint) = self.server.handlers.get(&req.method) else {
            self.server.toolbox.send(
                context.connection_id,
                request_error_to_resp(&context, ErrorCode::NOT_IMPLEMENTED, Value::Null),
            );
            return Ok(true);
        };

        if !check_roles(&context.roles, &endpoint.allowed_roles) {
            self.server.toolbox.send(
                context.connection_id,
                request_error_to_resp(&context, ErrorCode::FORBIDDEN, "Forbidden"),
            );
            return Ok(true);
        }

        let handler = endpoint.handler.clone();
        let toolbox = self.server.toolbox.clone();
        tokio::task::spawn_local(async move {
            TOOLBOX
                .scope(
                    toolbox.clone(),
                    handler.handle(&toolbox, context, req.params),
                )
                .await;
        });

        Ok(true)
    }

    async fn run_loop(&mut self) -> Result<()> {
        let conn_id = self.conn_info.connection_id;
        loop {
            while let Ok(msg) = self.rx.try_recv() {
                if !self.send_message(msg).await {
                    return Ok(());
                }
                if self.server.config.header_only {
                    return Ok(());
                }
            }

            tokio::select! {
                msg = self.rx.recv() => {
                    if let Some(msg) = msg {
                        if !self.send_message(msg).await {
                            break;
                        }
                        if self.server.config.header_only {
                            break;
                        }
                    } else {
                        debug!(ws_server = true, ?conn_id, "Outbound channel closed");
                        break;
                    }
                }
                msg = self.conn.recv() => {
                    if let Some(msg_result) = msg {
                        let msg = match msg_result {
                            Ok(m) => m,
                            Err(StreamError::Closed) => {
                                debug!(ws_server = true, ?conn_id, "WS receive: connection closed");
                                break;
                            }
                            Err(StreamError::Protocol(e)) => {
                                warn!(ws_server = true, ?conn_id, err=%e, "WS protocol error on receive");
                                break;
                            }
                            Err(StreamError::WriteBufferFull) => {
                                warn!(ws_server = true, ?conn_id, "WS write buffer full on receive");
                                break;
                            }
                            Err(StreamError::Other(e)) => {
                                error!(ws_server = true, ?conn_id, err=%e, "WS receive error");
                                break;
                            }
                        };
                        if !self.handle_message(msg)? {
                            break;
                        }
                    } else {
                        debug!(ws_server = true, ?conn_id, "Inbound stream ended");
                        break;
                    }
                }
            }
        }

        Ok(())
    }

    async fn send_message(&mut self, msg: Message) -> bool {
        let conn_id = self.conn_info.connection_id;
        match self.conn.send(msg).await {
            Ok(()) => true,
            Err(StreamError::Closed) => {
                debug!(ws_server = true, ?conn_id, "WS send: connection closed");
                false
            }
            Err(StreamError::WriteBufferFull) => {
                warn!(ws_server = true, ?conn_id, "WS send: write buffer full");
                false
            }
            Err(StreamError::Protocol(e)) => {
                warn!(ws_server = true, ?conn_id, err=%e, "WS send: protocol error");
                false
            }
            Err(StreamError::Other(e)) => {
                error!(ws_server = true, ?conn_id, err=%e, "WS send error");
                false
            }
        }
    }
}

fn check_roles(actual_roles: &[u32], allowed_roles: &HashSet<u32>) -> bool {
    if allowed_roles.is_empty() || actual_roles.is_empty() {
        return false;
    }
    actual_roles.iter().any(|role| allowed_roles.contains(role))
}

#[cfg(test)]
mod tests {
    #[test]
    fn check_roles_allowed() {
        use super::check_roles;
        use std::collections::HashSet;

        let allowed_roles: HashSet<u32> = [1, 2, 3].iter().cloned().collect();
        assert!(check_roles(&[1], &allowed_roles.clone()));
        assert!(check_roles(&[2], &allowed_roles.clone()));
        assert!(check_roles(&[1, 2], &allowed_roles.clone()));
        assert!(check_roles(&[4, 2], &allowed_roles.clone()));

        assert!(!check_roles(&[4], &allowed_roles.clone()));
    }

    #[test]
    fn check_roles_empty() {
        use super::check_roles;
        use std::collections::HashSet;

        let allowed_roles: HashSet<u32> = HashSet::new();
        assert!(!check_roles(&[1], &allowed_roles));
    }
}