endpoint_libs/libs/ws/
session.rs1use eyre::Result;
2use futures::StreamExt;
3use futures::{Sink, SinkExt, Stream};
4use serde_json::Value;
5use std::collections::HashSet;
6use std::sync::Arc;
7use tokio::sync::mpsc;
8use tokio_tungstenite::tungstenite::Message;
9use tracing::*;
10
11use crate::libs::error_code::ErrorCode;
12use crate::libs::toolbox::{RequestContext, TOOLBOX};
13
14use super::{request_error_to_resp, WebsocketServer, WsConnection, WsRequestValue};
15pub struct WsClientSession<WS> {
16 conn_info: Arc<WsConnection>,
17 conn: WS,
18 rx: mpsc::Receiver<Message>,
19 server: Arc<WebsocketServer>,
20}
21impl<
22 WS: Sink<Message, Error = tokio_tungstenite::tungstenite::Error>
23 + Stream<Item = std::result::Result<Message, tokio_tungstenite::tungstenite::Error>>
24 + Unpin,
25 > WsClientSession<WS>
26{
27 pub fn new(
28 conn_info: Arc<WsConnection>,
29 conn: WS,
30 rx: mpsc::Receiver<Message>,
31 server: Arc<WebsocketServer>,
32 ) -> Self {
33 Self {
34 conn_info,
35 conn,
36 rx,
37 server,
38 }
39 }
40
41 pub fn conn(&self) -> &WS {
42 &self.conn
43 }
44 pub async fn run(mut self) {
45 let addr = self.conn_info.address;
46 let conn_id = self.conn_info.connection_id;
47 if let Err(err) = self.run_loop().await {
48 error!(?err, ?addr, ?conn_id, "Failed to run websocket session");
49 }
50
51 }
60 fn handle_message(&mut self, msg: Message) -> Result<bool> {
62 let addr = &self.conn_info.address;
63 let mut context = RequestContext::from_conn(&self.conn_info);
64
65 let obj: Result<WsRequestValue, _> = match msg {
66 Message::Text(t) => {
67 debug!(?addr, "Handling request {}", t);
68
69 serde_json::from_str(&t)
70 }
71 Message::Binary(b) => {
72 debug!(?addr, "Handling request <BIN>");
73 serde_json::from_slice(&b)
74 }
75 Message::Ping(_) => {
76 return Ok(true);
77 }
78 Message::Pong(_) => {
79 return Ok(true);
80 }
81 Message::Close(_) => {
82 info!(?addr, "Receive side terminated");
83 return Ok(false);
84 }
85 _ => {
86 warn!(?addr, "Strange pattern {:?}", msg);
87 return Ok(true);
88 }
89 };
90 let req = match obj {
91 Ok(req) => req,
92 Err(err) => {
93 self.server.toolbox.send(
94 context.connection_id,
95 request_error_to_resp(&context, ErrorCode::BAD_REQUEST, err.to_string()),
96 );
97 return Ok(true);
98 }
99 };
100 context.seq = req.seq;
101 context.method = req.method;
102 context.user_id = self.conn_info.get_user_id();
103 context.roles = Arc::new(self.conn_info.get_roles());
104
105 let Some(allowed_roles) = self.server.allowed_roles.get(&req.method) else {
107 return Ok(true);
108 };
109
110 let allowed = check_roles(&context.roles, allowed_roles);
111 if !allowed {
112 self.server.toolbox.send(
113 context.connection_id,
114 request_error_to_resp(&context, ErrorCode::FORBIDDEN, "Forbidden"),
115 );
116 return Ok(true);
117 }
118
119 let handler = self.server.handlers.get(&req.method);
120 let handler = match handler {
121 Some(handler) => handler,
122 None => {
123 self.server.toolbox.send(
124 context.connection_id,
125 request_error_to_resp(&context, ErrorCode::NOT_IMPLEMENTED, Value::Null),
126 );
127 return Ok(true);
128 }
129 };
130 let handler = handler.handler.clone();
131 let toolbox = self.server.toolbox.clone();
132 tokio::task::spawn_local(async move {
133 TOOLBOX
134 .scope(
135 toolbox.clone(),
136 handler.handle(&toolbox, context, req.params),
137 )
138 .await;
139 });
140
141 Ok(true)
142 }
143 async fn run_loop(&mut self) -> Result<()> {
144 let conn_id = self.conn_info.connection_id;
145 loop {
146 tokio::select! {
147 msg = self.rx.recv() => {
148 if let Some(msg) = msg {
150 self.send_message(msg).await?;
151 if self.server.config.header_only {
152 break;
153 }
154 } else {
155 info!(?conn_id, "Receive side terminated");
156 break;
157 }
158 }
159 msg = self.conn.next() => {
160 if let Some(msg) = msg {
161 let msg = msg?;
162 if !self.handle_message(msg)? {
164 break;
165 }
166 } else {
167 info!(?conn_id, "Send side terminated");
168 break;
169 }
170 }
171 }
172 }
173
174 Ok(())
175 }
176 async fn send_message(&mut self, msg: Message) -> Result<()> {
177 self.conn.send(msg).await?;
179 Ok(())
180 }
181}
182
/// Returns `true` when at least one role in `actual_roles` is present in `allowed_roles`.
///
/// An empty `allowed_roles` set denies everyone, because no role can match it.
/// An empty `actual_roles` slice is likewise always denied.
fn check_roles(actual_roles: &[u32], allowed_roles: &HashSet<u32>) -> bool {
    actual_roles.iter().any(|role| allowed_roles.contains(role))
}
194
#[cfg(test)]
mod tests {
    use super::check_roles;
    use std::collections::HashSet;

    /// A caller is allowed as soon as any one of its roles is in the allowed set.
    #[test]
    fn check_roles_allowed() {
        let allowed_roles: HashSet<u32> = [1, 2, 3].iter().copied().collect();
        assert!(check_roles(&[1], &allowed_roles));
        assert!(check_roles(&[2], &allowed_roles));
        assert!(check_roles(&[1, 2], &allowed_roles));
        assert!(check_roles(&[4, 2], &allowed_roles));

        assert!(!check_roles(&[4], &allowed_roles));
        // A caller with no roles at all is never allowed.
        assert!(!check_roles(&[], &allowed_roles));
    }

    /// An empty allowed set denies every caller.
    #[test]
    fn check_roles_empty() {
        let allowed_roles: HashSet<u32> = HashSet::new();
        assert!(!check_roles(&[1], &allowed_roles));
    }
}