endpoint_libs/libs/ws/
session.rs1use eyre::Result;
2use futures::StreamExt;
3use futures::{Sink, SinkExt, Stream};
4use serde_json::Value;
5use std::sync::Arc;
6use tokio::sync::mpsc;
7use tokio_tungstenite::tungstenite::Message;
8use tracing::*;
9
10use crate::libs::error_code::ErrorCode;
11use crate::libs::toolbox::{RequestContext, TOOLBOX};
12
13use super::{request_error_to_resp, WebsocketServer, WsConnection, WsRequestValue};
14pub struct WsClientSession<WS> {
15 conn_info: Arc<WsConnection>,
16 conn: WS,
17 rx: mpsc::Receiver<Message>,
18 server: Arc<WebsocketServer>,
19}
20impl<
21 WS: Sink<Message, Error = tokio_tungstenite::tungstenite::Error>
22 + Stream<Item = std::result::Result<Message, tokio_tungstenite::tungstenite::Error>>
23 + Unpin,
24 > WsClientSession<WS>
25{
26 pub fn new(
27 conn_info: Arc<WsConnection>,
28 conn: WS,
29 rx: mpsc::Receiver<Message>,
30 server: Arc<WebsocketServer>,
31 ) -> Self {
32 Self {
33 conn_info,
34 conn,
35 rx,
36 server,
37 }
38 }
39
40 pub fn conn(&self) -> &WS {
41 &self.conn
42 }
43 pub async fn run(mut self) {
44 let addr = self.conn_info.address;
45 let conn_id = self.conn_info.connection_id;
46 if let Err(err) = self.run_loop().await {
47 error!(?err, ?addr, ?conn_id, "Failed to run websocket session");
48 }
49
50 }
59 fn handle_message(&mut self, msg: Message) -> Result<bool> {
61 let addr = &self.conn_info.address;
62 let mut context = RequestContext::from_conn(&self.conn_info);
63
64 let obj: Result<WsRequestValue, _> = match msg {
65 Message::Text(t) => {
66 debug!(?addr, "Handling request {}", t);
67
68 serde_json::from_str(&t)
69 }
70 Message::Binary(b) => {
71 debug!(?addr, "Handling request <BIN>");
72 serde_json::from_slice(&b)
73 }
74 Message::Ping(_) => {
75 return Ok(true);
76 }
77 Message::Pong(_) => {
78 return Ok(true);
79 }
80 Message::Close(_) => {
81 info!(?addr, "Receive side terminated");
82 return Ok(false);
83 }
84 _ => {
85 warn!(?addr, "Strange pattern {:?}", msg);
86 return Ok(true);
87 }
88 };
89 let req = match obj {
90 Ok(req) => req,
91 Err(err) => {
92 self.server.toolbox.send(
93 context.connection_id,
94 request_error_to_resp(
95 &context,
96 ErrorCode::new(100400), err.to_string(),
98 ),
99 );
100 return Ok(true);
101 }
102 };
103 context.seq = req.seq;
104 context.method = req.method;
105 context.user_id = self.conn_info.get_user_id();
106
107 let handler = self.server.handlers.get(&req.method);
108 let handler = match handler {
109 Some(handler) => handler,
110 None => {
111 self.server.toolbox.send(
112 context.connection_id,
113 request_error_to_resp(
114 &context,
115 ErrorCode::new(100501), Value::Null,
117 ),
118 );
119 return Ok(true);
120 }
121 };
122 let handler = handler.handler.clone();
123 let toolbox = self.server.toolbox.clone();
124 tokio::task::spawn_local(async move {
125 TOOLBOX
126 .scope(toolbox.clone(), handler.handle(&toolbox, context, req.params))
127 .await;
128 });
129
130 Ok(true)
131 }
132 async fn run_loop(&mut self) -> Result<()> {
133 let conn_id = self.conn_info.connection_id;
134 loop {
135 tokio::select! {
136 msg = self.rx.recv() => {
137 if let Some(msg) = msg {
139 self.send_message(msg).await?;
140 if self.server.config.header_only {
141 break;
142 }
143 } else {
144 info!(?conn_id, "Receive side terminated");
145 break;
146 }
147 }
148 msg = self.conn.next() => {
149 if let Some(msg) = msg {
150 let msg = msg?;
151 if !self.handle_message(msg)? {
153 break;
154 }
155 } else {
156 info!(?conn_id, "Send side terminated");
157 break;
158 }
159 }
160 }
161 }
162
163 Ok(())
164 }
165 async fn send_message(&mut self, msg: Message) -> Result<()> {
166 self.conn.send(msg).await?;
168 Ok(())
169 }
170}