// endpoint_libs/libs/ws/session.rs

1use eyre::Result;
2use futures::StreamExt;
3use futures::{Sink, SinkExt, Stream};
4use serde_json::Value;
5use std::sync::Arc;
6use tokio::sync::mpsc;
7use tokio_tungstenite::tungstenite::Message;
8use tracing::*;
9
10use crate::libs::error_code::ErrorCode;
11use crate::libs::toolbox::{RequestContext, TOOLBOX};
12
13use super::{request_error_to_resp, WebsocketServer, WsConnection, WsRequestValue};
14pub struct WsClientSession<WS> {
15    conn_info: Arc<WsConnection>,
16    conn: WS,
17    rx: mpsc::Receiver<Message>,
18    server: Arc<WebsocketServer>,
19}
20impl<
21        WS: Sink<Message, Error = tokio_tungstenite::tungstenite::Error>
22            + Stream<Item = std::result::Result<Message, tokio_tungstenite::tungstenite::Error>>
23            + Unpin,
24    > WsClientSession<WS>
25{
26    pub fn new(
27        conn_info: Arc<WsConnection>,
28        conn: WS,
29        rx: mpsc::Receiver<Message>,
30        server: Arc<WebsocketServer>,
31    ) -> Self {
32        Self {
33            conn_info,
34            conn,
35            rx,
36            server,
37        }
38    }
39
40    pub fn conn(&self) -> &WS {
41        &self.conn
42    }
43    pub async fn run(mut self) {
44        let addr = self.conn_info.address;
45        let conn_id = self.conn_info.connection_id;
46        if let Err(err) = self.run_loop().await {
47            error!(?err, ?addr, ?conn_id, "Failed to run websocket session");
48        }
49
50        // if let Err(err) = self.handler.handle_drop().await {
51        //     error!(
52        //         ?err,
53        //         ?addr,
54        //         ?conn_id,
55        //         "Failed to handle websocket session drop"
56        //     );
57        // }
58    }
59    // if continue, returns true
60    fn handle_message(&mut self, msg: Message) -> Result<bool> {
61        let addr = &self.conn_info.address;
62        let mut context = RequestContext::from_conn(&self.conn_info);
63
64        let obj: Result<WsRequestValue, _> = match msg {
65            Message::Text(t) => {
66                debug!(?addr, "Handling request {}", t);
67
68                serde_json::from_str(&t)
69            }
70            Message::Binary(b) => {
71                debug!(?addr, "Handling request <BIN>");
72                serde_json::from_slice(&b)
73            }
74            Message::Ping(_) => {
75                return Ok(true);
76            }
77            Message::Pong(_) => {
78                return Ok(true);
79            }
80            Message::Close(_) => {
81                info!(?addr, "Receive side terminated");
82                return Ok(false);
83            }
84            _ => {
85                warn!(?addr, "Strange pattern {:?}", msg);
86                return Ok(true);
87            }
88        };
89        let req = match obj {
90            Ok(req) => req,
91            Err(err) => {
92                self.server.toolbox.send(
93                    context.connection_id,
94                    request_error_to_resp(
95                        &context,
96                        ErrorCode::new(100400), // BadRequest
97                        err.to_string(),
98                    ),
99                );
100                return Ok(true);
101            }
102        };
103        context.seq = req.seq;
104        context.method = req.method;
105        context.user_id = self.conn_info.get_user_id();
106
107        let handler = self.server.handlers.get(&req.method);
108        let handler = match handler {
109            Some(handler) => handler,
110            None => {
111                self.server.toolbox.send(
112                    context.connection_id,
113                    request_error_to_resp(
114                        &context,
115                        ErrorCode::new(100501), // Not Implemented
116                        Value::Null,
117                    ),
118                );
119                return Ok(true);
120            }
121        };
122        let handler = handler.handler.clone();
123        let toolbox = self.server.toolbox.clone();
124        tokio::task::spawn_local(async move {
125            TOOLBOX
126                .scope(toolbox.clone(), handler.handle(&toolbox, context, req.params))
127                .await;
128        });
129
130        Ok(true)
131    }
132    async fn run_loop(&mut self) -> Result<()> {
133        let conn_id = self.conn_info.connection_id;
134        loop {
135            tokio::select! {
136                msg = self.rx.recv() => {
137                    // info!(?conn_id, ?msg, "Received message to send");
138                    if let Some(msg) = msg {
139                        self.send_message(msg).await?;
140                        if self.server.config.header_only {
141                            break;
142                        }
143                    } else {
144                        info!(?conn_id, "Receive side terminated");
145                        break;
146                    }
147                }
148                msg = self.conn.next() => {
149                    if let Some(msg) = msg {
150                        let msg = msg?;
151                        // info!(?conn_id, ?msg, "Received message");
152                        if !self.handle_message(msg)? {
153                            break;
154                        }
155                    } else {
156                        info!(?conn_id, "Send side terminated");
157                        break;
158                    }
159                }
160            }
161        }
162
163        Ok(())
164    }
165    async fn send_message(&mut self, msg: Message) -> Result<()> {
166        // info!(?msg, "Sending message");
167        self.conn.send(msg).await?;
168        Ok(())
169    }
170}