// async_ws/connection/reader.rs

1use crate::connection::waker::{new_waker, Wakers};
2use crate::connection::WsConnectionInner;
3use crate::message::WsMessageKind;
4use futures::{AsyncRead, AsyncWrite};
5use std::io;
6use std::ops::DerefMut;
7use std::pin::Pin;
8use std::sync::{Arc, Mutex};
9use std::task::{Context, Poll};
10
11pub struct WsMessageReader<T: AsyncRead + AsyncWrite + Unpin> {
12    kind: WsMessageKind,
13    parent: Option<Arc<Mutex<(WsConnectionInner<T>, Wakers)>>>,
14}
15
16impl<T: AsyncRead + AsyncWrite + Unpin> WsMessageReader<T> {
17    pub(crate) fn new(
18        kind: WsMessageKind,
19        parent: &Arc<Mutex<(WsConnectionInner<T>, Wakers)>>,
20    ) -> Self {
21        Self {
22            kind,
23            parent: Some(parent.clone()),
24        }
25    }
26    pub fn kind(&self) -> WsMessageKind {
27        self.kind
28    }
29}
30
31impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for WsMessageReader<T> {
32    fn poll_read(
33        mut self: Pin<&mut Self>,
34        cx: &mut Context<'_>,
35        buf: &mut [u8],
36    ) -> Poll<io::Result<usize>> {
37        if buf.len() == 0 {
38            return Poll::Ready(Ok(0));
39        }
40        if let Some(parent) = &self.parent {
41            let waker = new_waker(Arc::downgrade(parent));
42            let mut guard = parent.lock().unwrap();
43            let (inner, wakers) = guard.deref_mut();
44            wakers.reader_waker = Some(cx.waker().clone());
45            let n = match inner.poll_read(&mut Context::from_waker(&waker), buf) {
46                Poll::Ready(r) => match r {
47                    Ok(r) => r,
48                    Err(err) => {
49                        wakers.wake();
50                        return Poll::Ready(Err(err));
51                    }
52                },
53                Poll::Pending => return Poll::Pending,
54            };
55            if n == 0 {
56                inner.detach_reader();
57                drop(guard);
58                self.parent.take();
59            }
60            return Poll::Ready(Ok(n));
61        }
62        Poll::Ready(Ok(0))
63    }
64}
65
66impl<T: AsyncRead + AsyncWrite + Unpin> Drop for WsMessageReader<T> {
67    fn drop(&mut self) {
68        if let Some(parent) = self.parent.take() {
69            let mut guard = parent.lock().unwrap();
70            let (inner, wakers) = guard.deref_mut();
71            inner.detach_reader();
72            wakers.reader_waker.take();
73        }
74    }
75}