endpoint_libs/libs/ws/session.rs

1use eyre::Result;
2use futures::StreamExt;
3use futures::{Sink, SinkExt, Stream};
4use serde_json::Value;
5use std::collections::HashSet;
6use std::sync::Arc;
7use tokio::sync::mpsc;
8use tokio_tungstenite::tungstenite::Message;
9use tracing::*;
10
11use crate::libs::error_code::ErrorCode;
12use crate::libs::toolbox::{RequestContext, TOOLBOX};
13
14use super::{request_error_to_resp, WebsocketServer, WsConnection, WsRequestValue};
15pub struct WsClientSession<WS> {
16    conn_info: Arc<WsConnection>,
17    conn: WS,
18    rx: mpsc::Receiver<Message>,
19    server: Arc<WebsocketServer>,
20}
21impl<
22        WS: Sink<Message, Error = tokio_tungstenite::tungstenite::Error>
23            + Stream<Item = std::result::Result<Message, tokio_tungstenite::tungstenite::Error>>
24            + Unpin,
25    > WsClientSession<WS>
26{
27    pub fn new(
28        conn_info: Arc<WsConnection>,
29        conn: WS,
30        rx: mpsc::Receiver<Message>,
31        server: Arc<WebsocketServer>,
32    ) -> Self {
33        Self {
34            conn_info,
35            conn,
36            rx,
37            server,
38        }
39    }
40
41    pub fn conn(&self) -> &WS {
42        &self.conn
43    }
44    pub async fn run(mut self) {
45        let addr = self.conn_info.address;
46        let conn_id = self.conn_info.connection_id;
47        if let Err(err) = self.run_loop().await {
48            error!(?err, ?addr, ?conn_id, "Failed to run websocket session");
49        }
50
51        // if let Err(err) = self.handler.handle_drop().await {
52        //     error!(
53        //         ?err,
54        //         ?addr,
55        //         ?conn_id,
56        //         "Failed to handle websocket session drop"
57        //     );
58        // }
59    }
60    // if continue, returns true
61    fn handle_message(&mut self, msg: Message) -> Result<bool> {
62        let addr = &self.conn_info.address;
63        let mut context = RequestContext::from_conn(&self.conn_info);
64
65        let obj: Result<WsRequestValue, _> = match msg {
66            Message::Text(t) => {
67                debug!(?addr, "Handling request {}", t);
68
69                serde_json::from_str(&t)
70            }
71            Message::Binary(b) => {
72                debug!(?addr, "Handling request <BIN>");
73                serde_json::from_slice(&b)
74            }
75            Message::Ping(_) => {
76                return Ok(true);
77            }
78            Message::Pong(_) => {
79                return Ok(true);
80            }
81            Message::Close(_) => {
82                info!(?addr, "Receive side terminated");
83                return Ok(false);
84            }
85            _ => {
86                warn!(?addr, "Strange pattern {:?}", msg);
87                return Ok(true);
88            }
89        };
90        let req = match obj {
91            Ok(req) => req,
92            Err(err) => {
93                self.server.toolbox.send(
94                    context.connection_id,
95                    request_error_to_resp(&context, ErrorCode::BAD_REQUEST, err.to_string()),
96                );
97                return Ok(true);
98            }
99        };
100        context.seq = req.seq;
101        context.method = req.method;
102        context.user_id = self.conn_info.get_user_id();
103        context.roles = Arc::new(self.conn_info.get_roles());
104
105        // Check roles
106        let Some(allowed_roles) = self.server.allowed_roles.get(&req.method) else {
107            return Ok(true);
108        };
109
110        let allowed = check_roles(&context.roles, allowed_roles);
111        if !allowed {
112            self.server.toolbox.send(
113                context.connection_id,
114                request_error_to_resp(&context, ErrorCode::FORBIDDEN, "Forbidden"),
115            );
116            return Ok(true);
117        }
118
119        let handler = self.server.handlers.get(&req.method);
120        let handler = match handler {
121            Some(handler) => handler,
122            None => {
123                self.server.toolbox.send(
124                    context.connection_id,
125                    request_error_to_resp(&context, ErrorCode::NOT_IMPLEMENTED, Value::Null),
126                );
127                return Ok(true);
128            }
129        };
130        let handler = handler.handler.clone();
131        let toolbox = self.server.toolbox.clone();
132        tokio::task::spawn_local(async move {
133            TOOLBOX
134                .scope(
135                    toolbox.clone(),
136                    handler.handle(&toolbox, context, req.params),
137                )
138                .await;
139        });
140
141        Ok(true)
142    }
143    async fn run_loop(&mut self) -> Result<()> {
144        let conn_id = self.conn_info.connection_id;
145        loop {
146            tokio::select! {
147                msg = self.rx.recv() => {
148                    // info!(?conn_id, ?msg, "Received message to send");
149                    if let Some(msg) = msg {
150                        self.send_message(msg).await?;
151                        if self.server.config.header_only {
152                            break;
153                        }
154                    } else {
155                        info!(?conn_id, "Receive side terminated");
156                        break;
157                    }
158                }
159                msg = self.conn.next() => {
160                    if let Some(msg) = msg {
161                        let msg = msg?;
162                        // info!(?conn_id, ?msg, "Received message");
163                        if !self.handle_message(msg)? {
164                            break;
165                        }
166                    } else {
167                        info!(?conn_id, "Send side terminated");
168                        break;
169                    }
170                }
171            }
172        }
173
174        Ok(())
175    }
176    async fn send_message(&mut self, msg: Message) -> Result<()> {
177        // info!(?msg, "Sending message");
178        self.conn.send(msg).await?;
179        Ok(())
180    }
181}
182
/// Reports whether the caller holds at least one permitted role.
///
/// Returns `true` iff some element of `actual_roles` is in `allowed_roles`.
/// As a consequence, an empty `allowed_roles` set denies everyone, and a
/// caller with no roles is always denied.
fn check_roles(actual_roles: &[u32], allowed_roles: &HashSet<u32>) -> bool {
    actual_roles
        .iter()
        .any(|held| allowed_roles.contains(held))
}
194
#[cfg(test)]
mod tests {
    use super::check_roles;
    use std::collections::HashSet;

    #[test]
    fn check_roles_allowed() {
        let allowed_roles = HashSet::from([1u32, 2, 3]);

        // Any overlap between the caller's roles and the allow-list grants access.
        let granted: [&[u32]; 4] = [&[1], &[2], &[1, 2], &[4, 2]];
        for roles in granted {
            assert!(check_roles(roles, &allowed_roles));
        }

        // No overlap denies access.
        assert!(!check_roles(&[4], &allowed_roles));
    }

    #[test]
    fn check_roles_empty() {
        // An empty allow-list denies every caller.
        let allowed_roles: HashSet<u32> = HashSet::new();
        assert!(!check_roles(&[1], &allowed_roles));
    }
}