async_ws/connection/
writer.rs

1use crate::connection::waker::{new_waker, Wakers};
2use crate::connection::WsConnectionInner;
3use crate::message::WsMessageKind;
4use futures::{AsyncRead, AsyncWrite};
5use std::io;
6use std::ops::DerefMut;
7use std::pin::Pin;
8use std::sync::{Arc, Mutex};
9use std::task::{Context, Poll};
10
11pub struct WsMessageWriter<T: AsyncRead + AsyncWrite + Unpin> {
12    kind: WsMessageKind,
13    parent: Option<Arc<Mutex<(WsConnectionInner<T>, Wakers)>>>,
14}
15
16impl<T: AsyncRead + AsyncWrite + Unpin> WsMessageWriter<T> {
17    pub(crate) fn new(
18        kind: WsMessageKind,
19        parent: &Arc<Mutex<(WsConnectionInner<T>, Wakers)>>,
20    ) -> Self {
21        Self {
22            kind,
23            parent: Some(parent.clone()),
24        }
25    }
26    pub fn kind(&self) -> WsMessageKind {
27        self.kind
28    }
29}
30
31impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WsMessageWriter<T> {
32    fn poll_write(
33        self: Pin<&mut Self>,
34        cx: &mut Context<'_>,
35        buf: &[u8],
36    ) -> Poll<io::Result<usize>> {
37        match &self.parent {
38            Some(parent) => {
39                let waker = new_waker(Arc::downgrade(parent));
40                let mut guard = parent.lock().unwrap();
41                let (inner, wakers) = guard.deref_mut();
42                wakers.writer_waker = Some(cx.waker().clone());
43                let p = inner.poll_write(&mut Context::from_waker(&waker), buf);
44                wakers.wake_on_err(&p);
45                p
46            }
47            None => Poll::Ready(Err(io::Error::from(io::ErrorKind::BrokenPipe))),
48        }
49    }
50
51    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
52        match &self.parent {
53            Some(parent) => {
54                let waker = new_waker(Arc::downgrade(parent));
55                let mut guard = parent.lock().unwrap();
56                let (inner, wakers) = guard.deref_mut();
57                wakers.writer_waker = Some(cx.waker().clone());
58                let p = inner.poll_flush(&mut Context::from_waker(&waker));
59                wakers.wake_on_err(&p);
60                p
61            }
62            None => Poll::Ready(Err(io::Error::from(io::ErrorKind::BrokenPipe))),
63        }
64    }
65
66    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
67        match &self.parent {
68            Some(parent) => {
69                let waker = new_waker(Arc::downgrade(parent));
70                let mut guard = parent.lock().unwrap();
71                let (inner, wakers) = guard.deref_mut();
72                wakers.writer_waker = Some(cx.waker().clone());
73                let p = inner.poll_close_writer(&mut Context::from_waker(&waker));
74                wakers.wake_on_err(&p);
75                if let Poll::Ready(Ok(())) = &p {
76                    inner.detach_writer();
77                    drop(guard);
78                    self.parent.take();
79                }
80                p
81            }
82            None => Poll::Ready(Err(io::Error::from(io::ErrorKind::BrokenPipe))),
83        }
84    }
85}
86
87impl<T: AsyncRead + AsyncWrite + Unpin> Drop for WsMessageWriter<T> {
88    fn drop(&mut self) {
89        if let Some(parent) = self.parent.take() {
90            let mut guard = parent.lock().unwrap();
91            let (inner, wakers) = guard.deref_mut();
92            inner.detach_writer();
93            wakers.writer_waker.take();
94        }
95    }
96}