cyfs_lib/ws/
request.rs

1use super::packet::*;
2use super::session::*;
3use cyfs_base::{bucky_time_now, BuckyError, BuckyErrorCode, BuckyResult};
4use cyfs_debug::Mutex;
5
6use async_trait::async_trait;
7use futures::future::{AbortHandle, Abortable};
8use futures::prelude::*;
9use lru_time_cache::LruCache;
10use std::sync::{
11    atomic::{AtomicU32, Ordering},
12    Arc,
13};
14use std::time::Duration;
15
/// Default lifetime of a pending ws request: 6000 seconds (100 minutes).
///
/// NOTE(review): the original spelled this as `60 * 10 * 10` seconds; confirm
/// 100 minutes (rather than 10) is the intended timeout.
const WS_REQUEST_DEFAULT_TIMEOUT: Duration = Duration::from_secs(100 * 60);
18
19#[async_trait]
20pub trait WebSocketRequestHandler: Send + Sync + 'static {
21    async fn on_request(
22        &self,
23        requestor: Arc<WebSocketRequestManager>,
24        cmd: u16,
25        content: Vec<u8>,
26    ) -> BuckyResult<Option<Vec<u8>>> {
27        self.process_string_request(requestor, cmd, content).await
28    }
29
30    async fn process_string_request(
31        &self,
32        requestor: Arc<WebSocketRequestManager>,
33        cmd: u16,
34        content: Vec<u8>,
35    ) -> BuckyResult<Option<Vec<u8>>> {
36        let content = String::from_utf8(content).map_err(|e| {
37            let msg = format!(
38                "decode ws packet as string failed! sid={}, cmd={}, {}",
39                requestor.sid(),
40                cmd,
41                e
42            );
43            error!("{}", msg);
44
45            BuckyError::new(BuckyErrorCode::InvalidFormat, msg)
46        })?;
47
48        self.on_string_request(requestor, cmd, content)
49            .await
50            .map(|v| v.map(|v| v.into_bytes()))
51    }
52
53    async fn on_string_request(
54        &self,
55        _requestor: Arc<WebSocketRequestManager>,
56        _cmd: u16,
57        _content: String,
58    ) -> BuckyResult<Option<String>> {
59        unimplemented!();
60    }
61
62    async fn on_session_begin(&self, session: &Arc<WebSocketSession>);
63    async fn on_session_end(&self, session: &Arc<WebSocketSession>);
64
65    fn clone_handler(&self) -> Box<dyn WebSocketRequestHandler>;
66}
67
68struct RequestItem {
69    seq: u16,
70    send_tick: u64,
71    resp: Option<BuckyResult<Vec<u8>>>,
72    waker: Option<AbortHandle>,
73}
74
75impl RequestItem {
76    fn new(seq: u16) -> Self {
77        Self {
78            seq,
79            send_tick: bucky_time_now(),
80            resp: None,
81            waker: None,
82        }
83    }
84
85    fn resp(&mut self, code: BuckyErrorCode) {
86        if let Some(waker) = self.waker.take() {
87            if self.resp.is_none() {
88                self.resp = Some(Err(BuckyError::from(code)));
89            } else {
90                warn!(
91                    "end ws request with {:?} but already has resp! send_tick={}, seq={}",
92                    code, self.send_tick, self.seq
93                );
94                unreachable!();
95            }
96
97            waker.abort();
98        }
99    }
100
101    fn timeout(&mut self) {
102        self.resp(BuckyErrorCode::Timeout);
103    }
104
105    fn abort(&mut self) {
106        self.resp(BuckyErrorCode::Aborted);
107    }
108}
109
110impl Drop for RequestItem {
111    fn drop(&mut self) {
112        // info!("will drop ws request! seq={}", self.seq);
113        self.abort();
114    }
115}
116
117struct WebSocketRequestContainer {
118    list: LruCache<u16, Arc<Mutex<RequestItem>>>,
119    next_seq: u16,
120}
121
122impl WebSocketRequestContainer {
123    fn new() -> Self {
124        let list = LruCache::with_expiry_duration(WS_REQUEST_DEFAULT_TIMEOUT);
125
126        Self { list, next_seq: 1 }
127    }
128
129    fn new_request(
130        &mut self,
131        sid: u32,
132    ) -> (
133        u16,
134        Arc<Mutex<RequestItem>>,
135        Vec<(u16, Arc<Mutex<RequestItem>>)>,
136    ) {
137        let seq = self.next_seq;
138        self.next_seq += 1;
139        if self.next_seq == u16::MAX {
140            warn!("ws request seq roll back! sid={}", sid);
141            self.next_seq = 1;
142        }
143
144        let req_item = RequestItem::new(seq);
145
146        let req_item = Arc::new(Mutex::new(req_item));
147        let (old, mut list) = self.list.notify_insert(seq, req_item.clone());
148
149        if let Some(old) = old {
150            // 正常情况下不应该到这里,除非短时间内发了超大量的request,导致seq回环
151            let seq;
152            {
153                let old_item = old.lock().unwrap();
154                error!(
155                    "replace old with same seq! sid={}, seq={}, send_tick={}",
156                    sid, old_item.seq, old_item.send_tick
157                );
158                seq = old_item.seq;
159            }
160
161            // FIXME 先用超时对待
162            list.push((seq, old));
163        }
164
165        (seq, req_item, list)
166    }
167
168    /*
169    fn bind_waker(&mut self, seq: u16, waker: AbortHandle) {
170        let (item, list) = self.list.notify_get_mut(&seq);
171        if let Some(item) = item {
172            let mut item = item.lock().unwrap();
173            assert!(item.waker.is_none());
174            item.waker = Some(waker);
175        } else {
176            unreachable!();
177        }
178        if !list.is_empty() {
179            self.on_timeout(list);
180        }
181    }
182    */
183
184    fn remove_request(&mut self, seq: u16) -> Option<Arc<Mutex<RequestItem>>> {
185        assert!(seq > 0);
186
187        self.list.remove(&seq)
188    }
189
190    fn check_timeout(&mut self) -> Vec<(u16, Arc<Mutex<RequestItem>>)> {
191        // 直接清除过期的元素,不能迭代这些元素,否则会导致这些元素被更新时间戳
192        let (_, list) = self.list.notify_get(&0);
193
194        list
195    }
196
197    // 清空所有元素
198    fn clear(&mut self) {
199        for (seq, item) in self.list.iter() {
200            info!("will abort ws request: seq={}", seq);
201            item.lock().unwrap().abort();
202        }
203
204        self.list.clear();
205    }
206
207    fn on_timeout(sid: u32, list: Vec<(u16, Arc<Mutex<RequestItem>>)>) {
208        for (seq, item) in list {
209            warn!("ws request droped on timeout! sid={}, seq={}", sid, seq);
210
211            let mut item = item.lock().unwrap();
212            if item.waker.is_some() {
213                item.timeout();
214            } else {
215                // timeout的同时收到了应答,发生了竞争
216                warn!(
217                    "ws request timeout but already waked! sid={}, seq={}",
218                    sid, seq
219                );
220            }
221        }
222    }
223}
224
225pub struct WebSocketRequestManager {
226    reqs: Arc<Mutex<WebSocketRequestContainer>>,
227    session: Arc<Mutex<Option<Arc<WebSocketSession>>>>,
228    sid: AtomicU32,
229    monitor_canceler: Arc<Mutex<Option<AbortHandle>>>,
230    handler: Box<dyn WebSocketRequestHandler>,
231}
232
233impl Drop for WebSocketRequestManager {
234    fn drop(&mut self) {
235        let mut monitor_canceler = self.monitor_canceler.lock().unwrap();
236        if let Some(canceler) = monitor_canceler.take() {
237            info!("will stop ws request monitor: sid={}", self.sid());
238            canceler.abort();
239        }
240
241        self.reqs.lock().unwrap().clear();
242    }
243}
244
245impl WebSocketRequestManager {
246    pub fn new(handler: Box<dyn WebSocketRequestHandler>) -> Self {
247        let reqs = WebSocketRequestContainer::new();
248
249        Self {
250            reqs: Arc::new(Mutex::new(reqs)),
251            session: Arc::new(Mutex::new(None)),
252            sid: AtomicU32::new(0),
253            monitor_canceler: Arc::new(Mutex::new(None)),
254            handler,
255        }
256    }
257
258    pub fn sid(&self) -> u32 {
259        self.sid.load(Ordering::Relaxed)
260    }
261
262    pub fn session(&self) -> Option<Arc<WebSocketSession>> {
263        self.session.lock().unwrap().clone()
264    }
265
266    pub fn is_session_valid(&self) -> bool {
267        self.session.lock().unwrap().is_some()
268    }
269
270    pub fn bind_session(&self, session: Arc<WebSocketSession>) {
271        {
272            let mut local = self.session.lock().unwrap();
273            assert!(local.is_none());
274
275            self.sid.store(session.sid(), Ordering::SeqCst);
276            *local = Some(session);
277        }
278
279        self.monitor();
280    }
281
282    pub fn unbind_session(&self) {
283        self.stop_monitor();
284
285        // 强制所有pending的请求为超时
286        self.reqs.lock().unwrap().clear();
287
288        let _ = {
289            let mut local = self.session.lock().unwrap();
290            assert!(local.is_some());
291
292            debug!(
293                "ws request manager unbind session! sid={}",
294                local.as_ref().unwrap().sid()
295            );
296            local.take()
297        };
298    }
299
300    // 收到了msg
301    pub async fn on_msg(
302        requestor: Arc<WebSocketRequestManager>,
303        packet: WSPacket,
304    ) -> BuckyResult<()> {
305        let cmd = packet.header.cmd;
306        if cmd > 0 {
307            let seq = packet.header.seq;
308
309            let resp = requestor
310                .handler
311                .on_request(requestor.clone(), cmd, packet.content)
312                .await?;
313
314            // 如果seq==0,表示不需要应答,那么应该返回None
315            if resp.is_none() {
316                assert!(seq == 0);
317            } else {
318                assert!(seq > 0);
319
320                // 发起应答,cmd需要设置为0
321                let resp_packet = WSPacket::new_from_bytes(seq, 0, resp.unwrap());
322                let buf = resp_packet.encode();
323                requestor.post_to_session(buf).await?;
324            }
325        } else {
326            requestor.on_resp(packet).await?;
327        }
328        Ok(())
329    }
330
331    // 发送一个字符串请求
332    pub async fn post_req(&self, cmd: u16, msg: String) -> BuckyResult<String> {
333        let content = self.post_bytes_req(cmd, msg.into_bytes()).await?;
334
335        match String::from_utf8(content) {
336            Ok(v) => Ok(v),
337            Err(e) => {
338                let msg = format!(
339                    "decode ws resp as string failed! sid={}, cmd={}, {}",
340                    self.sid(),
341                    cmd,
342                    e
343                );
344                error!("{}", msg);
345
346                Err(BuckyError::new(BuckyErrorCode::InvalidFormat, msg))
347            }
348        }
349    }
350
351    // 发送一个请求并等待应答
352    pub async fn post_bytes_req(&self, cmd: u16, msg: Vec<u8>) -> BuckyResult<Vec<u8>> {
353        let (seq, item, timeout_list) = self.reqs.lock().unwrap().new_request(self.sid());
354        assert!(seq > 0);
355
356        // 首先处理超时的
357        if !timeout_list.is_empty() {
358            WebSocketRequestContainer::on_timeout(self.sid(), timeout_list);
359        }
360
361        // Init waker before send the packet
362        let (abort_handle, abort_registration) = AbortHandle::new_pair();
363        {
364            let mut item = item.lock().unwrap();
365            assert!(item.waker.is_none());
366            item.waker = Some(abort_handle);
367        }
368
369        let packet = WSPacket::new_from_bytes(seq, cmd, msg);
370        let buf = packet.encode();
371        if let Err(e) = self.post_to_session(buf).await {
372            self.reqs.lock().unwrap().remove_request(seq);
373
374            return Err(e);
375        }
376
377        // info!("request send complete, now will wait for resp! cmd={}", cmd);
378
379        // 等待唤醒
380        let future = Abortable::new(async_std::future::pending::<()>(), abort_registration);
381        future.await.unwrap_err();
382
383        // 应答
384        let mut item = item.lock().unwrap();
385        if let Some(resp) = item.resp.take() {
386            resp
387        } else {
388            unreachable!(
389                "ws request item waked up without resp: sid={}, seq={}",
390                self.sid(),
391                item.seq
392            );
393        }
394    }
395
396    // 不带应答的请求
397    async fn post_req_without_resp(&self, cmd: u16, msg: String) -> BuckyResult<()> {
398        self.post_bytes_req_without_resp(cmd, msg.into_bytes())
399            .await
400    }
401
402    async fn post_bytes_req_without_resp(&self, cmd: u16, msg: Vec<u8>) -> BuckyResult<()> {
403        let packet = WSPacket::new_from_bytes(0, cmd, msg);
404        let buf = packet.encode();
405
406        self.post_to_session(buf).await
407    }
408
409    // 收到了应答
410    async fn on_resp(&self, packet: WSPacket) -> BuckyResult<()> {
411        assert!(packet.header.cmd == 0);
412        assert!(packet.header.seq > 0);
413
414        let seq = packet.header.seq;
415        let ret = self.reqs.lock().unwrap().remove_request(seq);
416        if ret.is_none() {
417            let msg = format!(
418                "ws request recv resp but already been removed! sid={}, seq={}",
419                self.sid(),
420                seq
421            );
422
423            warn!("{}", msg);
424            return Err(BuckyError::new(BuckyErrorCode::NotFound, msg));
425        }
426
427        let item = ret.unwrap();
428
429        // 保存应答并唤醒
430        let mut item = item.lock().unwrap();
431        if let Some(waker) = item.waker.take() {
432            if item.resp.is_none() {
433                item.resp = Some(Ok(packet.content));
434            } else {
435                warn!(
436                    "ws request recv resp but already has local resp! sid={}, seq={}",
437                    self.sid(),
438                    seq
439                );
440                unreachable!();
441            }
442
443            drop(item);
444
445            waker.abort();
446        } else {
447            warn!(
448                "ws request recv resp but already timeout! sid={}, seq={}",
449                self.sid(),
450                seq
451            );
452        }
453
454        Ok(())
455    }
456
457    async fn post_to_session(&self, msg: Vec<u8>) -> BuckyResult<()> {
458        let ret = self.session.lock().unwrap().clone();
459        if ret.is_none() {
460            warn!("ws session not exists: {}", self.sid());
461            return Err(BuckyError::from(BuckyErrorCode::NotConnected));
462        }
463
464        let session = ret.unwrap();
465        session.post_msg(msg).await.map_err(|e| e)?;
466        Ok(())
467    }
468
469    fn monitor(&self) {
470        let reqs = self.reqs.clone();
471        let sid = self.sid();
472
473        let (fut, handle) = future::abortable(async move {
474            let mut interval = async_std::stream::interval(Duration::from_secs(15));
475            while let Some(_) = interval.next().await {
476                let list = reqs.lock().unwrap().check_timeout();
477
478                if !list.is_empty() {
479                    WebSocketRequestContainer::on_timeout(sid, list);
480                }
481            }
482        });
483
484        // 保存canceler,用以session结束时候取消
485        let mut monitor_canceler = self.monitor_canceler.lock().unwrap();
486        assert!(monitor_canceler.is_none());
487        *monitor_canceler = Some(handle);
488
489        async_std::task::spawn(async move {
490            match fut.await {
491                Ok(_) => {
492                    info!("ws request monitor complete, sid={}", sid);
493                    // 不应该到这里,只有被abort一种可能
494                    unreachable!();
495                }
496                Err(_aborted) => {
497                    info!("ws request monitor breaked, sid={}", sid);
498                }
499            };
500        });
501    }
502
503    fn stop_monitor(&self) {
504        let mut monitor_canceler = self.monitor_canceler.lock().unwrap();
505        if let Some(canceler) = monitor_canceler.take() {
506            debug!("will stop ws request monitor: sid={}", self.sid());
507            canceler.abort();
508        }
509    }
510}
511
#[cfg(test)]
mod tests {
    use futures::future::{AbortHandle, Abortable};
    use std::time::Duration;

    /// Checks two things about the wake-up mechanism used by `post_bytes_req`:
    /// - an `AbortHandle` fired before its `Abortable` is first polled still
    ///   wakes the future;
    /// - a second, late abort from another task is harmless.
    async fn test_wakeup() {
        let (abort_handle, abort_registration) = AbortHandle::new_pair();

        // Fire before the future is ever polled, like a response that arrives
        // before the requester starts waiting.
        abort_handle.abort();

        // Late second abort from another task; must not panic.
        async_std::task::spawn(async move {
            async_std::task::sleep(Duration::from_millis(20)).await;
            abort_handle.abort();
        });

        // Must resolve immediately with `Aborted`
        let future = Abortable::new(async_std::future::pending::<()>(), abort_registration);
        future.await.unwrap_err();

        println!("future wait complete!");

        // Give the spawned task time to perform its late abort
        async_std::task::sleep(Duration::from_millis(50)).await;
    }

    #[test]
    fn test() {
        async_std::task::block_on(async move {
            test_wakeup().await;
        })
    }
}