1use std::{
5 fmt, io,
6 pin::Pin,
7 task::{Context, Poll},
8};
9
10use futures_core::stream::Stream;
11
12use crate::{tungstenite::Bytes, Message, WsError};
13
14pub struct ByteWriter<S> {
19 sender: S,
20 state: State,
21}
22
23impl<S> ByteWriter<S> {
24 #[inline(always)]
26 pub fn new(sender: S) -> Self
27 where
28 S: Sender,
29 {
30 Self {
31 sender,
32 state: State::Open,
33 }
34 }
35
36 #[inline(always)]
38 pub fn into_inner(self) -> S {
39 self.sender
40 }
41}
42
43impl<S> fmt::Debug for ByteWriter<S>
44where
45 S: fmt::Debug,
46{
47 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48 f.debug_struct("ByteWriter")
49 .field("sender", &self.sender)
50 .field("state", &"..")
51 .finish()
52 }
53}
54
55enum State {
56 Open,
57 Closing(Option<Message>),
58}
59
60impl State {
61 fn close(&mut self) -> &mut Option<Message> {
62 match self {
63 State::Open => {
64 *self = State::Closing(Some(Message::Close(None)));
65 if let State::Closing(msg) = self {
66 msg
67 } else {
68 unreachable!()
69 }
70 }
71 State::Closing(msg) => msg,
72 }
73 }
74}
75
76pub trait Sender: private::SealedSender {}
84
85pub(crate) mod private {
86 use super::*;
87
88 pub trait SealedSender {
89 fn poll_write(
90 self: Pin<&mut Self>,
91 cx: &mut Context<'_>,
92 buf: &[u8],
93 ) -> Poll<Result<usize, WsError>>;
94
95 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>>;
96
97 fn poll_close(
98 self: Pin<&mut Self>,
99 cx: &mut Context<'_>,
100 msg: &mut Option<Message>,
101 ) -> Poll<Result<(), WsError>>;
102 }
103
104 impl<S> Sender for S where S: SealedSender {}
105}
106
/// Lets any `futures_util::Sink<Message, Error = WsError>` act as a sender for a
/// [`ByteWriter`]: each non-empty write becomes exactly one binary message.
#[cfg(feature = "futures-03-sink")]
impl<S> private::SealedSender for S
where
    S: futures_util::Sink<Message, Error = WsError> + Unpin,
{
    /// Waits for the sink to become ready, then enqueues all of `buf` as one binary message.
    ///
    /// Once the sink is ready the whole buffer is accepted. An empty `buf` is reported as a
    /// zero-length write without touching the sink.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, WsError>> {
        use std::task::ready;

        // An empty write carries no data. Sending it as an empty binary message would put a
        // zero-length payload on the wire. A reader that maps each message to one read could
        // mistake that payload for end-of-stream.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }

        ready!(self.as_mut().poll_ready(cx))?;
        let len = buf.len();
        self.start_send(Message::binary(buf.to_owned()))?;
        Poll::Ready(Ok(len))
    }

    /// Flushes the underlying sink.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
        <S as futures_util::Sink<_>>::poll_flush(self, cx)
    }

    /// Closes the underlying sink. The pending close-message slot is ignored; closing is left
    /// entirely to the sink's own `poll_close`.
    fn poll_close(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        _: &mut Option<Message>,
    ) -> Poll<Result<(), WsError>> {
        <S as futures_util::Sink<_>>::poll_close(self, cx)
    }
}
137
138impl<S> futures_io::AsyncWrite for ByteWriter<S>
139where
140 S: Sender + Unpin,
141{
142 fn poll_write(
143 mut self: Pin<&mut Self>,
144 cx: &mut Context<'_>,
145 buf: &[u8],
146 ) -> Poll<io::Result<usize>> {
147 <S as private::SealedSender>::poll_write(Pin::new(&mut self.sender), cx, buf)
148 .map_err(convert_err)
149 }
150
151 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
152 <S as private::SealedSender>::poll_flush(Pin::new(&mut self.sender), cx)
153 .map_err(convert_err)
154 }
155
156 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
157 let me = self.get_mut();
158 let msg = me.state.close();
159 <S as private::SealedSender>::poll_close(Pin::new(&mut me.sender), cx, msg)
160 .map_err(convert_err)
161 }
162}
163
/// `tokio::io::AsyncWrite` for [`ByteWriter`].
///
/// Every operation is delegated to the wrapped sender, and its `WsError`s are turned into
/// `io::Error`s by `convert_err`.
#[cfg(feature = "tokio-runtime")]
impl<S> tokio::io::AsyncWrite for ByteWriter<S>
where
    S: Sender + Unpin,
{
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let sender = Pin::new(&mut self.get_mut().sender);
        private::SealedSender::poll_write(sender, cx, buf).map_err(convert_err)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let sender = Pin::new(&mut self.get_mut().sender);
        private::SealedSender::poll_flush(sender, cx).map_err(convert_err)
    }

    /// Marks the writer as closing and hands the sender the pending close-message slot.
    /// The first close call made while open seeds that slot with `Message::Close(None)`.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let Self { sender, state } = self.get_mut();
        let pending = state.close();
        private::SealedSender::poll_close(Pin::new(sender), cx, pending).map_err(convert_err)
    }
}
190
191#[derive(Debug)]
198pub struct ByteReader<S> {
199 stream: S,
200 bytes: Option<Bytes>,
201}
202
203impl<S> ByteReader<S> {
204 #[inline(always)]
206 pub fn new(stream: S) -> Self {
207 Self {
208 stream,
209 bytes: None,
210 }
211 }
212}
213
214fn poll_read_helper<S>(
215 mut s: Pin<&mut ByteReader<S>>,
216 cx: &mut Context<'_>,
217 buf_len: usize,
218) -> Poll<io::Result<Option<Bytes>>>
219where
220 S: Stream<Item = Result<Message, WsError>> + Unpin,
221{
222 Poll::Ready(Ok(Some(match s.bytes {
223 None => match Pin::new(&mut s.stream).poll_next(cx) {
224 Poll::Pending => return Poll::Pending,
225 Poll::Ready(None) => return Poll::Ready(Ok(None)),
226 Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(convert_err(e))),
227 Poll::Ready(Some(Ok(msg))) => {
228 let bytes = msg.into_data();
229 if bytes.len() > buf_len {
230 s.bytes.insert(bytes).split_to(buf_len)
231 } else {
232 bytes
233 }
234 }
235 },
236 Some(ref mut bytes) if bytes.len() > buf_len => bytes.split_to(buf_len),
237 Some(ref mut bytes) => {
238 let bytes = bytes.clone();
239 s.bytes = None;
240 bytes
241 }
242 })))
243}
244
245impl<S> futures_io::AsyncRead for ByteReader<S>
246where
247 S: Stream<Item = Result<Message, WsError>> + Unpin,
248{
249 fn poll_read(
250 self: Pin<&mut Self>,
251 cx: &mut Context<'_>,
252 buf: &mut [u8],
253 ) -> Poll<io::Result<usize>> {
254 poll_read_helper(self, cx, buf.len()).map_ok(|bytes| {
255 bytes.map_or(0, |bytes| {
256 buf[..bytes.len()].copy_from_slice(&bytes);
257 bytes.len()
258 })
259 })
260 }
261}
262
/// `tokio::io::AsyncRead` for [`ByteReader`].
///
/// Each call appends one chunk from `poll_read_helper`, sized to the buffer's remaining
/// capacity. At the end of the stream nothing is appended.
#[cfg(feature = "tokio-runtime")]
impl<S> tokio::io::AsyncRead for ByteReader<S>
where
    S: Stream<Item = Result<Message, WsError>> + Unpin,
{
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf,
    ) -> Poll<io::Result<()>> {
        let limit = buf.remaining();
        match poll_read_helper(self, cx, limit) {
            Poll::Ready(Ok(Some(data))) => {
                buf.put_slice(&data);
                Poll::Ready(Ok(()))
            }
            Poll::Ready(Ok(None)) => Poll::Ready(Ok(())),
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Pending => Poll::Pending,
        }
    }
}
280
281fn convert_err(e: WsError) -> io::Error {
282 match e {
283 WsError::Io(io) => io,
284 _ => io::Error::new(io::ErrorKind::Other, e),
285 }
286}