1use super::packet::*;
2use super::session::*;
3use cyfs_base::{bucky_time_now, BuckyError, BuckyErrorCode, BuckyResult};
4use cyfs_debug::Mutex;
5
6use async_trait::async_trait;
7use futures::future::{AbortHandle, Abortable};
8use futures::prelude::*;
9use lru_time_cache::LruCache;
10use std::sync::{
11 atomic::{AtomicU32, Ordering},
12 Arc,
13};
14use std::time::Duration;
15
/// Expiry handed to the pending-request LRU cache: 6000 seconds, i.e. 100 minutes.
// NOTE(review): the original spelling `60 * 10 * 10` reads like "10 minutes" with an extra
// factor of 10 — confirm that 100 minutes really is the intended default.
const WS_REQUEST_DEFAULT_TIMEOUT: Duration = Duration::from_secs(100 * 60);
18
19#[async_trait]
20pub trait WebSocketRequestHandler: Send + Sync + 'static {
21 async fn on_request(
22 &self,
23 requestor: Arc<WebSocketRequestManager>,
24 cmd: u16,
25 content: Vec<u8>,
26 ) -> BuckyResult<Option<Vec<u8>>> {
27 self.process_string_request(requestor, cmd, content).await
28 }
29
30 async fn process_string_request(
31 &self,
32 requestor: Arc<WebSocketRequestManager>,
33 cmd: u16,
34 content: Vec<u8>,
35 ) -> BuckyResult<Option<Vec<u8>>> {
36 let content = String::from_utf8(content).map_err(|e| {
37 let msg = format!(
38 "decode ws packet as string failed! sid={}, cmd={}, {}",
39 requestor.sid(),
40 cmd,
41 e
42 );
43 error!("{}", msg);
44
45 BuckyError::new(BuckyErrorCode::InvalidFormat, msg)
46 })?;
47
48 self.on_string_request(requestor, cmd, content)
49 .await
50 .map(|v| v.map(|v| v.into_bytes()))
51 }
52
53 async fn on_string_request(
54 &self,
55 _requestor: Arc<WebSocketRequestManager>,
56 _cmd: u16,
57 _content: String,
58 ) -> BuckyResult<Option<String>> {
59 unimplemented!();
60 }
61
62 async fn on_session_begin(&self, session: &Arc<WebSocketSession>);
63 async fn on_session_end(&self, session: &Arc<WebSocketSession>);
64
65 fn clone_handler(&self) -> Box<dyn WebSocketRequestHandler>;
66}
67
68struct RequestItem {
69 seq: u16,
70 send_tick: u64,
71 resp: Option<BuckyResult<Vec<u8>>>,
72 waker: Option<AbortHandle>,
73}
74
75impl RequestItem {
76 fn new(seq: u16) -> Self {
77 Self {
78 seq,
79 send_tick: bucky_time_now(),
80 resp: None,
81 waker: None,
82 }
83 }
84
85 fn resp(&mut self, code: BuckyErrorCode) {
86 if let Some(waker) = self.waker.take() {
87 if self.resp.is_none() {
88 self.resp = Some(Err(BuckyError::from(code)));
89 } else {
90 warn!(
91 "end ws request with {:?} but already has resp! send_tick={}, seq={}",
92 code, self.send_tick, self.seq
93 );
94 unreachable!();
95 }
96
97 waker.abort();
98 }
99 }
100
101 fn timeout(&mut self) {
102 self.resp(BuckyErrorCode::Timeout);
103 }
104
105 fn abort(&mut self) {
106 self.resp(BuckyErrorCode::Aborted);
107 }
108}
109
110impl Drop for RequestItem {
111 fn drop(&mut self) {
112 self.abort();
114 }
115}
116
117struct WebSocketRequestContainer {
118 list: LruCache<u16, Arc<Mutex<RequestItem>>>,
119 next_seq: u16,
120}
121
122impl WebSocketRequestContainer {
123 fn new() -> Self {
124 let list = LruCache::with_expiry_duration(WS_REQUEST_DEFAULT_TIMEOUT);
125
126 Self { list, next_seq: 1 }
127 }
128
129 fn new_request(
130 &mut self,
131 sid: u32,
132 ) -> (
133 u16,
134 Arc<Mutex<RequestItem>>,
135 Vec<(u16, Arc<Mutex<RequestItem>>)>,
136 ) {
137 let seq = self.next_seq;
138 self.next_seq += 1;
139 if self.next_seq == u16::MAX {
140 warn!("ws request seq roll back! sid={}", sid);
141 self.next_seq = 1;
142 }
143
144 let req_item = RequestItem::new(seq);
145
146 let req_item = Arc::new(Mutex::new(req_item));
147 let (old, mut list) = self.list.notify_insert(seq, req_item.clone());
148
149 if let Some(old) = old {
150 let seq;
152 {
153 let old_item = old.lock().unwrap();
154 error!(
155 "replace old with same seq! sid={}, seq={}, send_tick={}",
156 sid, old_item.seq, old_item.send_tick
157 );
158 seq = old_item.seq;
159 }
160
161 list.push((seq, old));
163 }
164
165 (seq, req_item, list)
166 }
167
168 fn remove_request(&mut self, seq: u16) -> Option<Arc<Mutex<RequestItem>>> {
185 assert!(seq > 0);
186
187 self.list.remove(&seq)
188 }
189
190 fn check_timeout(&mut self) -> Vec<(u16, Arc<Mutex<RequestItem>>)> {
191 let (_, list) = self.list.notify_get(&0);
193
194 list
195 }
196
197 fn clear(&mut self) {
199 for (seq, item) in self.list.iter() {
200 info!("will abort ws request: seq={}", seq);
201 item.lock().unwrap().abort();
202 }
203
204 self.list.clear();
205 }
206
207 fn on_timeout(sid: u32, list: Vec<(u16, Arc<Mutex<RequestItem>>)>) {
208 for (seq, item) in list {
209 warn!("ws request droped on timeout! sid={}, seq={}", sid, seq);
210
211 let mut item = item.lock().unwrap();
212 if item.waker.is_some() {
213 item.timeout();
214 } else {
215 warn!(
217 "ws request timeout but already waked! sid={}, seq={}",
218 sid, seq
219 );
220 }
221 }
222 }
223}
224
225pub struct WebSocketRequestManager {
226 reqs: Arc<Mutex<WebSocketRequestContainer>>,
227 session: Arc<Mutex<Option<Arc<WebSocketSession>>>>,
228 sid: AtomicU32,
229 monitor_canceler: Arc<Mutex<Option<AbortHandle>>>,
230 handler: Box<dyn WebSocketRequestHandler>,
231}
232
233impl Drop for WebSocketRequestManager {
234 fn drop(&mut self) {
235 let mut monitor_canceler = self.monitor_canceler.lock().unwrap();
236 if let Some(canceler) = monitor_canceler.take() {
237 info!("will stop ws request monitor: sid={}", self.sid());
238 canceler.abort();
239 }
240
241 self.reqs.lock().unwrap().clear();
242 }
243}
244
245impl WebSocketRequestManager {
246 pub fn new(handler: Box<dyn WebSocketRequestHandler>) -> Self {
247 let reqs = WebSocketRequestContainer::new();
248
249 Self {
250 reqs: Arc::new(Mutex::new(reqs)),
251 session: Arc::new(Mutex::new(None)),
252 sid: AtomicU32::new(0),
253 monitor_canceler: Arc::new(Mutex::new(None)),
254 handler,
255 }
256 }
257
258 pub fn sid(&self) -> u32 {
259 self.sid.load(Ordering::Relaxed)
260 }
261
262 pub fn session(&self) -> Option<Arc<WebSocketSession>> {
263 self.session.lock().unwrap().clone()
264 }
265
266 pub fn is_session_valid(&self) -> bool {
267 self.session.lock().unwrap().is_some()
268 }
269
270 pub fn bind_session(&self, session: Arc<WebSocketSession>) {
271 {
272 let mut local = self.session.lock().unwrap();
273 assert!(local.is_none());
274
275 self.sid.store(session.sid(), Ordering::SeqCst);
276 *local = Some(session);
277 }
278
279 self.monitor();
280 }
281
282 pub fn unbind_session(&self) {
283 self.stop_monitor();
284
285 self.reqs.lock().unwrap().clear();
287
288 let _ = {
289 let mut local = self.session.lock().unwrap();
290 assert!(local.is_some());
291
292 debug!(
293 "ws request manager unbind session! sid={}",
294 local.as_ref().unwrap().sid()
295 );
296 local.take()
297 };
298 }
299
300 pub async fn on_msg(
302 requestor: Arc<WebSocketRequestManager>,
303 packet: WSPacket,
304 ) -> BuckyResult<()> {
305 let cmd = packet.header.cmd;
306 if cmd > 0 {
307 let seq = packet.header.seq;
308
309 let resp = requestor
310 .handler
311 .on_request(requestor.clone(), cmd, packet.content)
312 .await?;
313
314 if resp.is_none() {
316 assert!(seq == 0);
317 } else {
318 assert!(seq > 0);
319
320 let resp_packet = WSPacket::new_from_bytes(seq, 0, resp.unwrap());
322 let buf = resp_packet.encode();
323 requestor.post_to_session(buf).await?;
324 }
325 } else {
326 requestor.on_resp(packet).await?;
327 }
328 Ok(())
329 }
330
331 pub async fn post_req(&self, cmd: u16, msg: String) -> BuckyResult<String> {
333 let content = self.post_bytes_req(cmd, msg.into_bytes()).await?;
334
335 match String::from_utf8(content) {
336 Ok(v) => Ok(v),
337 Err(e) => {
338 let msg = format!(
339 "decode ws resp as string failed! sid={}, cmd={}, {}",
340 self.sid(),
341 cmd,
342 e
343 );
344 error!("{}", msg);
345
346 Err(BuckyError::new(BuckyErrorCode::InvalidFormat, msg))
347 }
348 }
349 }
350
351 pub async fn post_bytes_req(&self, cmd: u16, msg: Vec<u8>) -> BuckyResult<Vec<u8>> {
353 let (seq, item, timeout_list) = self.reqs.lock().unwrap().new_request(self.sid());
354 assert!(seq > 0);
355
356 if !timeout_list.is_empty() {
358 WebSocketRequestContainer::on_timeout(self.sid(), timeout_list);
359 }
360
361 let (abort_handle, abort_registration) = AbortHandle::new_pair();
363 {
364 let mut item = item.lock().unwrap();
365 assert!(item.waker.is_none());
366 item.waker = Some(abort_handle);
367 }
368
369 let packet = WSPacket::new_from_bytes(seq, cmd, msg);
370 let buf = packet.encode();
371 if let Err(e) = self.post_to_session(buf).await {
372 self.reqs.lock().unwrap().remove_request(seq);
373
374 return Err(e);
375 }
376
377 let future = Abortable::new(async_std::future::pending::<()>(), abort_registration);
381 future.await.unwrap_err();
382
383 let mut item = item.lock().unwrap();
385 if let Some(resp) = item.resp.take() {
386 resp
387 } else {
388 unreachable!(
389 "ws request item waked up without resp: sid={}, seq={}",
390 self.sid(),
391 item.seq
392 );
393 }
394 }
395
396 async fn post_req_without_resp(&self, cmd: u16, msg: String) -> BuckyResult<()> {
398 self.post_bytes_req_without_resp(cmd, msg.into_bytes())
399 .await
400 }
401
402 async fn post_bytes_req_without_resp(&self, cmd: u16, msg: Vec<u8>) -> BuckyResult<()> {
403 let packet = WSPacket::new_from_bytes(0, cmd, msg);
404 let buf = packet.encode();
405
406 self.post_to_session(buf).await
407 }
408
409 async fn on_resp(&self, packet: WSPacket) -> BuckyResult<()> {
411 assert!(packet.header.cmd == 0);
412 assert!(packet.header.seq > 0);
413
414 let seq = packet.header.seq;
415 let ret = self.reqs.lock().unwrap().remove_request(seq);
416 if ret.is_none() {
417 let msg = format!(
418 "ws request recv resp but already been removed! sid={}, seq={}",
419 self.sid(),
420 seq
421 );
422
423 warn!("{}", msg);
424 return Err(BuckyError::new(BuckyErrorCode::NotFound, msg));
425 }
426
427 let item = ret.unwrap();
428
429 let mut item = item.lock().unwrap();
431 if let Some(waker) = item.waker.take() {
432 if item.resp.is_none() {
433 item.resp = Some(Ok(packet.content));
434 } else {
435 warn!(
436 "ws request recv resp but already has local resp! sid={}, seq={}",
437 self.sid(),
438 seq
439 );
440 unreachable!();
441 }
442
443 drop(item);
444
445 waker.abort();
446 } else {
447 warn!(
448 "ws request recv resp but already timeout! sid={}, seq={}",
449 self.sid(),
450 seq
451 );
452 }
453
454 Ok(())
455 }
456
457 async fn post_to_session(&self, msg: Vec<u8>) -> BuckyResult<()> {
458 let ret = self.session.lock().unwrap().clone();
459 if ret.is_none() {
460 warn!("ws session not exists: {}", self.sid());
461 return Err(BuckyError::from(BuckyErrorCode::NotConnected));
462 }
463
464 let session = ret.unwrap();
465 session.post_msg(msg).await.map_err(|e| e)?;
466 Ok(())
467 }
468
469 fn monitor(&self) {
470 let reqs = self.reqs.clone();
471 let sid = self.sid();
472
473 let (fut, handle) = future::abortable(async move {
474 let mut interval = async_std::stream::interval(Duration::from_secs(15));
475 while let Some(_) = interval.next().await {
476 let list = reqs.lock().unwrap().check_timeout();
477
478 if !list.is_empty() {
479 WebSocketRequestContainer::on_timeout(sid, list);
480 }
481 }
482 });
483
484 let mut monitor_canceler = self.monitor_canceler.lock().unwrap();
486 assert!(monitor_canceler.is_none());
487 *monitor_canceler = Some(handle);
488
489 async_std::task::spawn(async move {
490 match fut.await {
491 Ok(_) => {
492 info!("ws request monitor complete, sid={}", sid);
493 unreachable!();
495 }
496 Err(_aborted) => {
497 info!("ws request monitor breaked, sid={}", sid);
498 }
499 };
500 });
501 }
502
503 fn stop_monitor(&self) {
504 let mut monitor_canceler = self.monitor_canceler.lock().unwrap();
505 if let Some(canceler) = monitor_canceler.take() {
506 debug!("will stop ws request monitor: sid={}", self.sid());
507 canceler.abort();
508 }
509 }
510}
511
#[cfg(test)]
mod tests {
    use futures::future::{AbortHandle, Abortable};

    /// Checks that an `Abortable` whose handle was aborted *before* its first poll resolves
    /// immediately with `Err(Aborted)`. `post_bytes_req` relies on this when a response
    /// arrives before the waiter starts awaiting. The test also checks that a late second
    /// `abort()`, issued after completion, is harmless.
    async fn test_wakeup() {
        let (abort_handle, abort_registration) = AbortHandle::new_pair();

        abort_handle.abort();

        async_std::task::spawn(async move {
            async_std::task::sleep(std::time::Duration::from_millis(100)).await;
            abort_handle.abort();
        });

        let future = Abortable::new(async_std::future::pending::<()>(), abort_registration);
        future.await.unwrap_err();

        println!("future wait complete!");

        // Keep the runtime alive long enough for the spawned late `abort()` to run.
        async_std::task::sleep(std::time::Duration::from_millis(300)).await;
    }

    #[test]
    fn test() {
        async_std::task::block_on(async move {
            test_wakeup().await;
        })
    }
}
541}