endpoint_libs/libs/ws/
session.rs1use eyre::Result;
2use futures::StreamExt;
3use futures::{Sink, SinkExt, Stream};
4use serde_json::Value;
5use std::collections::HashSet;
6use std::sync::Arc;
7use tokio::sync::mpsc;
8use tokio_tungstenite::tungstenite::Message;
9use tracing::*;
10
11use crate::libs::error_code::ErrorCode;
12use crate::libs::toolbox::{RequestContext, TOOLBOX};
13
14use super::{request_error_to_resp, WebsocketServer, WsConnection, WsRequestValue};
15pub struct WsClientSession<WS> {
16 conn_info: Arc<WsConnection>,
17 conn: WS,
18 rx: mpsc::Receiver<Message>,
19 server: Arc<WebsocketServer>,
20}
21impl<
22 WS: Sink<Message, Error = tokio_tungstenite::tungstenite::Error>
23 + Stream<Item = std::result::Result<Message, tokio_tungstenite::tungstenite::Error>>
24 + Unpin,
25 > WsClientSession<WS>
26{
27 pub fn new(
28 conn_info: Arc<WsConnection>,
29 conn: WS,
30 rx: mpsc::Receiver<Message>,
31 server: Arc<WebsocketServer>,
32 ) -> Self {
33 Self {
34 conn_info,
35 conn,
36 rx,
37 server,
38 }
39 }
40
41 pub fn conn(&self) -> &WS {
42 &self.conn
43 }
44 pub async fn run(mut self) {
45 let addr = self.conn_info.address;
46 let conn_id = self.conn_info.connection_id;
47 if let Err(err) = self.run_loop().await {
48 error!(?err, ?addr, ?conn_id, "Failed to run websocket session");
49 }
50
51 }
60 fn handle_message(&mut self, msg: Message) -> Result<bool> {
62 let addr = &self.conn_info.address;
63 let mut context = RequestContext::from_conn(&self.conn_info);
64
65 let obj: Result<WsRequestValue, _> = match msg {
66 Message::Text(t) => {
67 debug!(?addr, "Handling request {}", t);
68
69 serde_json::from_str(&t)
70 }
71 Message::Binary(b) => {
72 debug!(?addr, "Handling request <BIN>");
73 serde_json::from_slice(&b)
74 }
75 Message::Ping(_) => {
76 return Ok(true);
77 }
78 Message::Pong(_) => {
79 return Ok(true);
80 }
81 Message::Close(_) => {
82 info!(?addr, "Receive side terminated");
83 return Ok(false);
84 }
85 _ => {
86 warn!(?addr, "Strange pattern {:?}", msg);
87 return Ok(true);
88 }
89 };
90 let req = match obj {
91 Ok(req) => req,
92 Err(err) => {
93 self.server.toolbox.send(
94 context.connection_id,
95 request_error_to_resp(
96 &context,
97 ErrorCode::new(100400), err.to_string(),
99 ),
100 );
101 return Ok(true);
102 }
103 };
104 context.seq = req.seq;
105 context.method = req.method;
106 context.user_id = self.conn_info.get_user_id();
107 context.role = self.conn_info.get_role();
108
109 let Some(allowed_roles) = self.server.allowed_roles.get(&req.method) else {
111 return Ok(true);
112 };
113
114 let allowed = check_roles(context.role, allowed_roles);
115 if !allowed {
116 self.server.toolbox.send(
117 context.connection_id,
118 request_error_to_resp(
119 &context,
120 ErrorCode::new(100403), "Forbidden",
122 ),
123 );
124 return Ok(true);
125 }
126
127 let handler = self.server.handlers.get(&req.method);
128 let handler = match handler {
129 Some(handler) => handler,
130 None => {
131 self.server.toolbox.send(
132 context.connection_id,
133 request_error_to_resp(
134 &context,
135 ErrorCode::new(100501), Value::Null,
137 ),
138 );
139 return Ok(true);
140 }
141 };
142 let handler = handler.handler.clone();
143 let toolbox = self.server.toolbox.clone();
144 tokio::task::spawn_local(async move {
145 TOOLBOX
146 .scope(
147 toolbox.clone(),
148 handler.handle(&toolbox, context, req.params),
149 )
150 .await;
151 });
152
153 Ok(true)
154 }
155 async fn run_loop(&mut self) -> Result<()> {
156 let conn_id = self.conn_info.connection_id;
157 loop {
158 tokio::select! {
159 msg = self.rx.recv() => {
160 if let Some(msg) = msg {
162 self.send_message(msg).await?;
163 if self.server.config.header_only {
164 break;
165 }
166 } else {
167 info!(?conn_id, "Receive side terminated");
168 break;
169 }
170 }
171 msg = self.conn.next() => {
172 if let Some(msg) = msg {
173 let msg = msg?;
174 if !self.handle_message(msg)? {
176 break;
177 }
178 } else {
179 info!(?conn_id, "Send side terminated");
180 break;
181 }
182 }
183 }
184 }
185
186 Ok(())
187 }
188 async fn send_message(&mut self, msg: Message) -> Result<()> {
189 self.conn.send(msg).await?;
191 Ok(())
192 }
193}
194
/// Reports whether `role` may call a method restricted to `allowed_roles`.
///
/// `None` means the method has no role restriction, so every role passes.
/// `Some(set)` admits only the roles contained in `set`, so an empty set
/// rejects every role.
fn check_roles(role: u32, allowed_roles: &Option<HashSet<u32>>) -> bool {
    match allowed_roles {
        Some(roles) => roles.contains(&role),
        None => true,
    }
}
201
#[cfg(test)]
mod tests {
    use super::check_roles;
    use std::collections::HashSet;

    /// Roles inside the allow-list pass; a role outside it is rejected.
    #[test]
    fn check_roles_allowed() {
        let allowed_roles: HashSet<u32> = HashSet::from([1, 2, 3]);
        for role in [1, 2, 3] {
            assert!(check_roles(role, &Some(allowed_roles.clone())));
        }
        assert!(!check_roles(4, &Some(allowed_roles)));
    }

    /// A method without a role restriction admits any role.
    #[test]
    fn check_roles_none() {
        assert!(check_roles(1, &None));
    }

    /// An empty allow-list admits no role at all.
    #[test]
    fn check_roles_empty() {
        let allowed_roles: HashSet<u32> = HashSet::new();
        assert!(!check_roles(1, &Some(allowed_roles)));
    }
}