cyfs_lib/ws/
session.rs

1use super::packet::*;
2use super::request::*;
3use async_std::future::TimeoutError;
4use cyfs_base::{BuckyError, BuckyErrorCode, BuckyResult};
5use cyfs_debug::Mutex;
6
7use async_std::channel::{Receiver, Sender};
8use async_std::io::{Read, Write};
9use async_tungstenite::{tungstenite::Message, WebSocketStream};
10use futures::future::Either;
11use futures::future::{AbortHandle, Aborted};
12use futures_util::sink::*;
13use futures_util::StreamExt;
14use http_types::Url;
15use std::marker::Send;
16use std::net::SocketAddr;
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19
/// Idle interval of the session loop: when neither an incoming nor an outgoing
/// message shows up within this time, the loop wakes up, optionally sends a ping
/// and checks the keep-alive deadline.
pub const WS_PING_INTERVAL_IN_SECS: Duration = Duration::from_secs(30);

/// Maximum time a connection may go without receiving any message before the
/// session loop gives up with a timeout error.
///
/// Debug and release builds are configured separately, but both currently use
/// the same value (10 minutes).
#[cfg(debug_assertions)]
pub const WS_ALIVE_TIMEOUT_IN_SECS: Duration = Duration::from_secs(60 * 10);

/// Release-build counterpart of the keep-alive timeout above.
#[cfg(not(debug_assertions))]
pub const WS_ALIVE_TIMEOUT_IN_SECS: Duration = Duration::from_secs(60 * 10);
29
/// Cancellation bookkeeping shared between `WebSocketSession::run` and
/// `WebSocketSession::stop`.
struct WebSocketCancelerState {
    /// Abort handle of the running session loop. Stored by `run` when the loop
    /// is about to start and taken (at most once) by `stop`.
    canceler: Option<AbortHandle>,
    /// Set by `stop`. When it is already `true` by the time `run` reaches the
    /// point of starting the loop, the loop is never awaited.
    stopped: bool,
}
34
/// One WebSocket connection together with the plumbing needed to send
/// messages on it, dispatch incoming ones and stop it from the outside.
pub struct WebSocketSession {
    /// Session id, used in every log line of this type.
    sid: u32,

    // Connection info.
    // NOTE(review): which of the two addresses is local and which is remote is
    // not visible in this file — confirm against the code that creates sessions.
    conn_info: (SocketAddr, SocketAddr),
    /// Free-form origin label, only used for logging here.
    source: String,

    // Sending side of the outgoing message queue. It is `Some` only while
    // `run` is active; `post_msg` fails with `NotConnected` otherwise.
    tx: Mutex<Option<Sender<Message>>>,

    /// Handler notified through `on_session_begin` / `on_session_end`.
    handler: Box<dyn WebSocketRequestHandler>,
    /// Request manager that gets bound to this session for the duration of
    /// `run` and receives every decoded incoming packet.
    requestor: Arc<WebSocketRequestManager>,

    /// Abort handle and stop flag used to cancel the session loop.
    canceler: Mutex<WebSocketCancelerState>,
}
50
51impl Drop for WebSocketSession {
52    fn drop(&mut self) {
53        warn!("ws session dropped! sid={}", self.sid);
54    }
55}
56
57impl WebSocketSession {
58    pub fn new(
59        sid: u32,
60        source: String,
61        conn_info: (SocketAddr, SocketAddr),
62        handler: Box<dyn WebSocketRequestHandler>,
63    ) -> Self {
64        info!("new ws session: sid={}, source={}", sid, source);
65
66        Self {
67            sid,
68            conn_info,
69            source,
70            tx: Mutex::new(None),
71            handler: handler.clone_handler(),
72            requestor: Arc::new(WebSocketRequestManager::new(handler)),
73            canceler: Mutex::new(WebSocketCancelerState {
74                canceler: None,
75                stopped: false,
76            }),
77        }
78    }
79
80    pub fn is_valid(&self) -> bool {
81        self.requestor.is_session_valid()
82    }
83
84    pub fn requestor(&self) -> &Arc<WebSocketRequestManager> {
85        &self.requestor
86    }
87
88    pub fn sid(&self) -> u32 {
89        self.sid
90    }
91
92    pub fn conn_info(&self) -> &(SocketAddr, SocketAddr) {
93        &self.conn_info
94    }
95
96    pub fn stop(&self) {
97        let canceler = {
98            let mut state = self.canceler.lock().unwrap();
99            state.stopped = true;
100            state.canceler.take()
101        };
102
103        if let Some(canceler) = canceler {
104            info!("will stop ws session: {}", self.sid);
105            canceler.abort();
106        }
107    }
108
109    pub async fn post_msg(&self, msg: Vec<u8>) -> BuckyResult<()> {
110        let tx = self.tx.lock().unwrap().clone();
111        if let Some(tx) = tx {
112            let msg = Message::binary(msg);
113            if let Err(e) = tx.send(msg).await {
114                warn!("session tx already closed! sid={}, {}", self.sid, e);
115                Err(BuckyError::from(BuckyErrorCode::NotConnected))
116            } else {
117                Ok(())
118            }
119        } else {
120            // session已经结束,直接忽略
121            warn!("session tx not exists! sid={}", self.sid);
122            Err(BuckyError::from(BuckyErrorCode::NotConnected))
123        }
124    }
125
126    pub async fn run_client<S>(session: Arc<Self>, service_url: &Url, stream: S) -> BuckyResult<()>
127    where
128        S: Read + Write + Unpin + Send + 'static,
129    {
130        let (stream, _) = async_tungstenite::client_async(service_url, stream)
131            .await
132            .map_err(|e| {
133                let msg = format!("ws connect error: service_url={}, err={}", service_url, e);
134                error!("{}", msg);
135
136                BuckyError::new(BuckyErrorCode::Unknown, msg)
137            })?;
138
139        Self::run(session, stream, false).await
140    }
141
142    pub async fn run_server<S>(session: Arc<Self>, stream: S) -> BuckyResult<()>
143    where
144        S: Read + Write + Unpin + Send + 'static,
145    {
146        let stream = async_tungstenite::accept_async(stream).await.map_err(|e| {
147            let msg = format!("ws accept error: err={}", e);
148            error!("{}", msg);
149
150            BuckyError::new(BuckyErrorCode::Unknown, msg)
151        })?;
152
153        Self::run(session, stream, true).await
154    }
155
156    async fn run<S>(
157        session: Arc<Self>,
158        stream: WebSocketStream<S>,
159        as_server: bool,
160    ) -> BuckyResult<()>
161    where
162        S: Read + Write + Unpin + Send + 'static,
163    {
164        let (tx, rx) = async_std::channel::bounded::<Message>(1024);
165
166        // 保存sender
167        {
168            let mut current = session.tx.lock().unwrap();
169            assert!(current.is_none());
170            *current = Some(tx.clone());
171        }
172
173        // 初始化请求管理器
174        session.requestor.bind_session(session.clone());
175
176        // 正式通知session启动了
177        session.handler.on_session_begin(&session).await;
178
179        let (fut, handle) =
180            futures::future::abortable(Self::run_loop(session.clone(), stream, rx, as_server));
181
182        let stopped = {
183            let mut state = session.canceler.lock().unwrap();
184            assert!(state.canceler.is_none());
185            if !state.stopped {
186                state.canceler = Some(handle);
187            } else {
188                warn!(
189                    "ws session start but already been stopped! sid={}",
190                    session.sid
191                );
192            }
193            state.stopped
194        };
195
196        let ret = if !stopped {
197            match fut.await {
198                Ok(ret) => ret,
199                Err(Aborted) => Err(BuckyError::from(BuckyErrorCode::Aborted)),
200            }
201        } else {
202            Err(BuckyError::from(BuckyErrorCode::Aborted))
203        };
204
205        session.handler.on_session_end(&session).await;
206
207        // 通知session结束
208        session.requestor.unbind_session();
209
210        // 终止发送
211        {
212            let tx = session.tx.lock().unwrap().take();
213            assert!(tx.is_some());
214        }
215
216        ret
217    }
218
219    async fn run_loop<S>(
220        session: Arc<Self>,
221        stream: WebSocketStream<S>,
222        rx: Receiver<Message>,
223        with_ping: bool,
224    ) -> BuckyResult<()>
225    where
226        S: Read + Write + Unpin + Send + 'static,
227    {
228        let (mut outgoing, mut incoming) = stream.split();
229
230        // 记录最后一次活动时间
231        let mut last_alive = Instant::now();
232
233        let ret = loop {
234            // trace!("try recv from ws session: {}", session.sid());
235
236            let send_recv = futures::future::select(incoming.next(), rx.recv());
237            let ret = async_std::future::timeout(WS_PING_INTERVAL_IN_SECS, send_recv).await;
238
239            // trace!("recv sth. from ws session: {}, ret={:?}", session.sid(), ret);
240
241            match ret {
242                Err(TimeoutError { .. }) => {
243                    if with_ping {
244                        let msg = Message::Ping(Vec::new());
245                        if let Err(e) = outgoing.send(msg).await {
246                            let msg =
247                                format!("ws send msg error: sid={}, err={}", session.sid(), e);
248                            warn!("{}", msg);
249
250                            break Err(BuckyError::new(BuckyErrorCode::ConnectionAborted, msg));
251                        }
252                    }
253
254                    // 检查连接是否还在活跃
255                    let now = Instant::now();
256                    if now - last_alive >= WS_ALIVE_TIMEOUT_IN_SECS {
257                        let msg = format!("ws session alive timeout: sid={}", session.sid());
258                        error!("{}", msg);
259
260                        break Err(BuckyError::new(BuckyErrorCode::Timeout, msg));
261                    }
262
263                    continue;
264                }
265                Ok(ret) => {
266                    match ret {
267                        Either::Left((ret, _fut)) => {
268                            if ret.is_none() {
269                                info!(
270                                    "ws recv complete, sid={}, source={}",
271                                    session.sid(),
272                                    session.source
273                                );
274                                break Ok(());
275                            }
276
277                            match ret.unwrap() {
278                                Ok(msg) => {
279                                    if msg.is_close() {
280                                        info!(
281                                            "ws rx closed msg: sid={}, source={}",
282                                            session.sid(),
283                                            session.source
284                                        );
285                                        break Ok(());
286                                    }
287
288                                    // 收到了有效消息,需要更新最后活跃时刻
289                                    last_alive = Instant::now();
290
291                                    // 如果收到ping后,那么需要答复pong
292                                    if msg.is_ping() {
293                                        // 会自动发送pong
294                                        /*
295                                        trace!(
296                                            "ws recv ping: sid={}, is_server={}",
297                                            session.sid(),
298                                            as_server
299                                        );
300                                        */
301                                        continue;
302                                    } else if msg.is_pong() {
303                                        /*
304                                        trace!(
305                                            "ws recv pong: sid={}, is_server={}",
306                                            session.sid(),
307                                            as_server
308                                        );
309                                        */
310                                        continue;
311                                    }
312
313                                    async_std::task::spawn(Self::process_msg(
314                                        session.requestor.clone(),
315                                        msg,
316                                    ));
317                                }
318
319                                Err(e) => {
320                                    let msg =
321                                        format!("ws recv error: sid={}, err={}", session.sid(), e);
322                                    warn!("{}", msg);
323
324                                    break Err(BuckyError::new(
325                                        BuckyErrorCode::ConnectionAborted,
326                                        msg,
327                                    ));
328                                }
329                            }
330                        }
331                        Either::Right((ret, _fut)) => match ret {
332                            Ok(msg) => {
333                                if let Err(e) = outgoing.send(msg).await {
334                                    let msg = format!(
335                                        "ws send msg error: sid={}, err={}",
336                                        session.sid(),
337                                        e
338                                    );
339                                    warn!("{}", msg);
340
341                                    break Err(BuckyError::new(
342                                        BuckyErrorCode::ConnectionAborted,
343                                        msg,
344                                    ));
345                                }
346                            }
347                            Err(e) => {
348                                info!("ws send msg stopped: {}", e);
349                                break Ok(());
350                            }
351                        },
352                    }
353                }
354            }
355        };
356
357        ret
358    }
359
360    async fn process_msg(requestor: Arc<WebSocketRequestManager>, msg: Message) -> BuckyResult<()> {
361        let data = msg.into_data();
362        let packet = WSPacket::decode(data)?;
363
364        match WebSocketRequestManager::on_msg(requestor, packet).await {
365            Ok(_) => {
366                // 处理消息成功了
367            }
368            Err(e) => {
369                error!("process ws request error: {}", e);
370
371                /*
372                // 处理消息失败,不需要终止当前session
373                // 只要包格式正确,session就可以继续使用
374                *has_err.lock().unwrap() = Some(e);
375
376                let mut abort_state = abort_state.lock().unwrap();
377                abort_state.is_abort = true;
378                if let Some(abort_handle) = abort_state.handle.take() {
379                    warn!("will abort ws session: {}", requestor.sid());
380                    abort_handle.abort();
381                }
382                */
383            }
384        }
385
386        Ok(())
387    }
388}