async_tungstenite/
bytes.rs

1//! Provides abstractions to use `AsyncRead` and `AsyncWrite` with
2//! a [`WebSocketStream`](crate::WebSocketStream) or a [`WebSocketSender`](crate::WebSocketSender).
3
4use std::{
5    fmt, io,
6    pin::Pin,
7    task::{Context, Poll},
8};
9
10use futures_core::stream::Stream;
11
12use crate::{tungstenite::Bytes, Message, WsError};
13
14/// Treat a websocket [sender](Sender) as an `AsyncWrite` implementation.
15///
16/// Every write sends a binary message. If you want to group writes together, consider wrapping
17/// this with a `BufWriter`.
18pub struct ByteWriter<S> {
19    sender: S,
20    state: State,
21}
22
23impl<S> ByteWriter<S> {
24    /// Create a new `ByteWriter` from a [sender](Sender) that accepts a websocket [`Message`].
25    #[inline(always)]
26    pub fn new(sender: S) -> Self
27    where
28        S: Sender,
29    {
30        Self {
31            sender,
32            state: State::Open,
33        }
34    }
35
36    /// Get the underlying [sender](Sender) back.
37    #[inline(always)]
38    pub fn into_inner(self) -> S {
39        self.sender
40    }
41}
42
43impl<S> fmt::Debug for ByteWriter<S>
44where
45    S: fmt::Debug,
46{
47    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48        f.debug_struct("ByteWriter")
49            .field("sender", &self.sender)
50            .field("state", &"..")
51            .finish()
52    }
53}
54
55enum State {
56    Open,
57    Closing(Option<Message>),
58}
59
60impl State {
61    fn close(&mut self) -> &mut Option<Message> {
62        match self {
63            State::Open => {
64                *self = State::Closing(Some(Message::Close(None)));
65                if let State::Closing(msg) = self {
66                    msg
67                } else {
68                    unreachable!()
69                }
70            }
71            State::Closing(msg) => msg,
72        }
73    }
74}
75
76/// Sends bytes as a websocket [`Message`].
77///
78/// It's implemented for [`WebSocketStream`](crate::WebSocketStream)
79/// and [`WebSocketSender`](crate::WebSocketSender).
80/// It's also implemeted for every `Sink` type that accepts
81/// a websocket [`Message`] and returns [`WsError`] type as
82/// an error when `futures-03-sink` feature is enabled.
83pub trait Sender: private::SealedSender {}
84
85pub(crate) mod private {
86    use super::*;
87
88    pub trait SealedSender {
89        fn poll_write(
90            self: Pin<&mut Self>,
91            cx: &mut Context<'_>,
92            buf: &[u8],
93        ) -> Poll<Result<usize, WsError>>;
94
95        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>>;
96
97        fn poll_close(
98            self: Pin<&mut Self>,
99            cx: &mut Context<'_>,
100            msg: &mut Option<Message>,
101        ) -> Poll<Result<(), WsError>>;
102    }
103
104    impl<S> Sender for S where S: SealedSender {}
105}
106
#[cfg(feature = "futures-03-sink")]
impl<S> private::SealedSender for S
where
    S: futures_util::Sink<Message, Error = WsError> + Unpin,
{
    /// Waits for the sink to be ready, then queues a copy of `buf` as one binary message.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, WsError>> {
        match <S as futures_util::Sink<Message>>::poll_ready(self.as_mut(), cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            Poll::Ready(Ok(())) => {}
        }

        let payload = buf.to_owned();
        let sent = <S as futures_util::Sink<Message>>::start_send(self, Message::binary(payload));
        // The whole buffer is accepted whenever the message was queued.
        Poll::Ready(sent.map(|()| buf.len()))
    }

    /// Delegates to the sink's own flush.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
        <S as futures_util::Sink<Message>>::poll_flush(self, cx)
    }

    /// Delegates to the sink's own close; the close message slot is not used.
    fn poll_close(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        _msg: &mut Option<Message>,
    ) -> Poll<Result<(), WsError>> {
        <S as futures_util::Sink<Message>>::poll_close(self, cx)
    }
}
137
138impl<S> futures_io::AsyncWrite for ByteWriter<S>
139where
140    S: Sender + Unpin,
141{
142    fn poll_write(
143        mut self: Pin<&mut Self>,
144        cx: &mut Context<'_>,
145        buf: &[u8],
146    ) -> Poll<io::Result<usize>> {
147        <S as private::SealedSender>::poll_write(Pin::new(&mut self.sender), cx, buf)
148            .map_err(convert_err)
149    }
150
151    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
152        <S as private::SealedSender>::poll_flush(Pin::new(&mut self.sender), cx)
153            .map_err(convert_err)
154    }
155
156    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
157        let me = self.get_mut();
158        let msg = me.state.close();
159        <S as private::SealedSender>::poll_close(Pin::new(&mut me.sender), cx, msg)
160            .map_err(convert_err)
161    }
162}
163
#[cfg(feature = "tokio-runtime")]
impl<S> tokio::io::AsyncWrite for ByteWriter<S>
where
    S: Sender + Unpin,
{
    /// Forwards `buf` to the wrapped sender, converting websocket errors to `io::Error`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let me = self.get_mut();
        let sender = Pin::new(&mut me.sender);
        <S as private::SealedSender>::poll_write(sender, cx, buf).map_err(convert_err)
    }

    /// Flushes the wrapped sender, converting websocket errors to `io::Error`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let me = self.get_mut();
        let sender = Pin::new(&mut me.sender);
        <S as private::SealedSender>::poll_flush(sender, cx).map_err(convert_err)
    }

    /// Marks the writer as closing and drives the wrapped sender's close.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let Self { sender, state } = self.get_mut();
        let close_msg = state.close();
        <S as private::SealedSender>::poll_close(Pin::new(sender), cx, close_msg)
            .map_err(convert_err)
    }
}
190
191/// Treat a websocket [stream](Stream) as an `AsyncRead` implementation.
192///
193/// This also works with any other `Stream` of `Message`, such as a `SplitStream`.
194///
195/// Each read will only return data from one message. If you want to combine data from multiple
196/// messages into one read, consider wrapping this in a `BufReader`.
197#[derive(Debug)]
198pub struct ByteReader<S> {
199    stream: S,
200    bytes: Option<Bytes>,
201}
202
203impl<S> ByteReader<S> {
204    /// Create a new `ByteReader` from a [stream](Stream) that returns a WebSocket [`Message`].
205    #[inline(always)]
206    pub fn new(stream: S) -> Self {
207        Self {
208            stream,
209            bytes: None,
210        }
211    }
212}
213
214fn poll_read_helper<S>(
215    mut s: Pin<&mut ByteReader<S>>,
216    cx: &mut Context<'_>,
217    buf_len: usize,
218) -> Poll<io::Result<Option<Bytes>>>
219where
220    S: Stream<Item = Result<Message, WsError>> + Unpin,
221{
222    Poll::Ready(Ok(Some(match s.bytes {
223        None => match Pin::new(&mut s.stream).poll_next(cx) {
224            Poll::Pending => return Poll::Pending,
225            Poll::Ready(None) => return Poll::Ready(Ok(None)),
226            Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(convert_err(e))),
227            Poll::Ready(Some(Ok(msg))) => {
228                let bytes = msg.into_data();
229                if bytes.len() > buf_len {
230                    s.bytes.insert(bytes).split_to(buf_len)
231                } else {
232                    bytes
233                }
234            }
235        },
236        Some(ref mut bytes) if bytes.len() > buf_len => bytes.split_to(buf_len),
237        Some(ref mut bytes) => {
238            let bytes = bytes.clone();
239            s.bytes = None;
240            bytes
241        }
242    })))
243}
244
245impl<S> futures_io::AsyncRead for ByteReader<S>
246where
247    S: Stream<Item = Result<Message, WsError>> + Unpin,
248{
249    fn poll_read(
250        self: Pin<&mut Self>,
251        cx: &mut Context<'_>,
252        buf: &mut [u8],
253    ) -> Poll<io::Result<usize>> {
254        poll_read_helper(self, cx, buf.len()).map_ok(|bytes| {
255            bytes.map_or(0, |bytes| {
256                buf[..bytes.len()].copy_from_slice(&bytes);
257                bytes.len()
258            })
259        })
260    }
261}
262
#[cfg(feature = "tokio-runtime")]
impl<S> tokio::io::AsyncRead for ByteReader<S>
where
    S: Stream<Item = Result<Message, WsError>> + Unpin,
{
    /// Appends the next chunk to `buf`; leaves `buf` untouched when the helper yields
    /// no chunk.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        match poll_read_helper(self, cx, buf.remaining()) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Ready(Ok(chunk)) => {
                if let Some(data) = chunk {
                    buf.put_slice(&data);
                }
                Poll::Ready(Ok(()))
            }
        }
    }
}
280
281fn convert_err(e: WsError) -> io::Error {
282    match e {
283        WsError::Io(io) => io,
284        _ => io::Error::new(io::ErrorKind::Other, e),
285    }
286}