// endpoint_libs/libs/ws/session.rs

1use eyre::Result;
2use futures::StreamExt;
3use futures::{Sink, SinkExt, Stream};
4use serde_json::Value;
5use std::collections::HashSet;
6use std::sync::Arc;
7use tokio::sync::mpsc;
8use tokio_tungstenite::tungstenite::Message;
9use tracing::*;
10
11use crate::libs::error_code::ErrorCode;
12use crate::libs::toolbox::{RequestContext, TOOLBOX};
13
14use super::{request_error_to_resp, WebsocketServer, WsConnection, WsRequestValue};
15pub struct WsClientSession<WS> {
16    conn_info: Arc<WsConnection>,
17    conn: WS,
18    rx: mpsc::Receiver<Message>,
19    server: Arc<WebsocketServer>,
20}
21impl<
22        WS: Sink<Message, Error = tokio_tungstenite::tungstenite::Error>
23            + Stream<Item = std::result::Result<Message, tokio_tungstenite::tungstenite::Error>>
24            + Unpin,
25    > WsClientSession<WS>
26{
27    pub fn new(
28        conn_info: Arc<WsConnection>,
29        conn: WS,
30        rx: mpsc::Receiver<Message>,
31        server: Arc<WebsocketServer>,
32    ) -> Self {
33        Self {
34            conn_info,
35            conn,
36            rx,
37            server,
38        }
39    }
40
41    pub fn conn(&self) -> &WS {
42        &self.conn
43    }
44    pub async fn run(mut self) {
45        let addr = self.conn_info.address;
46        let conn_id = self.conn_info.connection_id;
47        if let Err(err) = self.run_loop().await {
48            error!(?err, ?addr, ?conn_id, "Failed to run websocket session");
49        }
50
51        // if let Err(err) = self.handler.handle_drop().await {
52        //     error!(
53        //         ?err,
54        //         ?addr,
55        //         ?conn_id,
56        //         "Failed to handle websocket session drop"
57        //     );
58        // }
59    }
60    // if continue, returns true
61    fn handle_message(&mut self, msg: Message) -> Result<bool> {
62        let addr = &self.conn_info.address;
63        let mut context = RequestContext::from_conn(&self.conn_info);
64
65        let obj: Result<WsRequestValue, _> = match msg {
66            Message::Text(t) => {
67                debug!(?addr, "Handling request {}", t);
68
69                serde_json::from_str(&t)
70            }
71            Message::Binary(b) => {
72                debug!(?addr, "Handling request <BIN>");
73                serde_json::from_slice(&b)
74            }
75            Message::Ping(_) => {
76                return Ok(true);
77            }
78            Message::Pong(_) => {
79                return Ok(true);
80            }
81            Message::Close(_) => {
82                info!(?addr, "Receive side terminated");
83                return Ok(false);
84            }
85            _ => {
86                warn!(?addr, "Strange pattern {:?}", msg);
87                return Ok(true);
88            }
89        };
90        let req = match obj {
91            Ok(req) => req,
92            Err(err) => {
93                self.server.toolbox.send(
94                    context.connection_id,
95                    request_error_to_resp(
96                        &context,
97                        ErrorCode::new(100400), // BadRequest
98                        err.to_string(),
99                    ),
100                );
101                return Ok(true);
102            }
103        };
104        context.seq = req.seq;
105        context.method = req.method;
106        context.user_id = self.conn_info.get_user_id();
107        context.role = self.conn_info.get_role();
108
109        // Check roles
110        let Some(allowed_roles) = self.server.allowed_roles.get(&req.method) else {
111            return Ok(true);
112        };
113
114        let allowed = check_roles(context.role, allowed_roles);
115        if !allowed {
116            self.server.toolbox.send(
117                context.connection_id,
118                request_error_to_resp(
119                    &context,
120                    ErrorCode::new(100403), // Forbidden
121                    "Forbidden",
122                ),
123            );
124            return Ok(true);
125        }
126
127        let handler = self.server.handlers.get(&req.method);
128        let handler = match handler {
129            Some(handler) => handler,
130            None => {
131                self.server.toolbox.send(
132                    context.connection_id,
133                    request_error_to_resp(
134                        &context,
135                        ErrorCode::new(100501), // Not Implemented
136                        Value::Null,
137                    ),
138                );
139                return Ok(true);
140            }
141        };
142        let handler = handler.handler.clone();
143        let toolbox = self.server.toolbox.clone();
144        tokio::task::spawn_local(async move {
145            TOOLBOX
146                .scope(
147                    toolbox.clone(),
148                    handler.handle(&toolbox, context, req.params),
149                )
150                .await;
151        });
152
153        Ok(true)
154    }
155    async fn run_loop(&mut self) -> Result<()> {
156        let conn_id = self.conn_info.connection_id;
157        loop {
158            tokio::select! {
159                msg = self.rx.recv() => {
160                    // info!(?conn_id, ?msg, "Received message to send");
161                    if let Some(msg) = msg {
162                        self.send_message(msg).await?;
163                        if self.server.config.header_only {
164                            break;
165                        }
166                    } else {
167                        info!(?conn_id, "Receive side terminated");
168                        break;
169                    }
170                }
171                msg = self.conn.next() => {
172                    if let Some(msg) = msg {
173                        let msg = msg?;
174                        // info!(?conn_id, ?msg, "Received message");
175                        if !self.handle_message(msg)? {
176                            break;
177                        }
178                    } else {
179                        info!(?conn_id, "Send side terminated");
180                        break;
181                    }
182                }
183            }
184        }
185
186        Ok(())
187    }
188    async fn send_message(&mut self, msg: Message) -> Result<()> {
189        // info!(?msg, "Sending message");
190        self.conn.send(msg).await?;
191        Ok(())
192    }
193}
194
/// Reports whether `role` may call a method restricted to `allowed_roles`.
///
/// `None` means the method has no restriction, so every role passes. `Some(set)` admits only
/// the roles contained in `set`, so an empty set admits nobody.
fn check_roles(role: u32, allowed_roles: &Option<HashSet<u32>>) -> bool {
    match allowed_roles {
        Some(roles) => roles.contains(&role),
        None => true, // no restriction configured
    }
}
201
#[cfg(test)]
mod tests {
    use super::check_roles;
    use std::collections::HashSet;

    /// Builds the `Some(set)` restriction form accepted by `check_roles`.
    fn restricted_to(ids: &[u32]) -> Option<HashSet<u32>> {
        Some(ids.iter().copied().collect())
    }

    #[test]
    fn check_roles_allowed() {
        let restriction = restricted_to(&[1, 2, 3]);
        for role in 1..=3 {
            assert!(check_roles(role, &restriction), "role {role} should be allowed");
        }
        assert!(!check_roles(4, &restriction));
    }

    #[test]
    fn check_roles_none() {
        // No restriction: every role passes.
        assert!(check_roles(1, &None));
    }

    #[test]
    fn check_roles_empty() {
        // An empty restriction set admits no role at all.
        assert!(!check_roles(1, &restricted_to(&[])));
    }
}