1use core::{
2 fmt,
3 future::Future,
4 pin::Pin,
5 task::{ready, Context, Poll},
6};
7
8use alloc::sync::{Arc, Weak};
9
10use std::{error, io, sync::Mutex};
11
12use bytes::{Bytes, BytesMut};
13use futures_core::stream::Stream;
14use pin_project_lite::pin_project;
15use tokio::sync::mpsc::{channel, Receiver, Sender};
16
17use super::{
18 codec::{Codec, Message},
19 error::ProtocolError,
20 proto::CloseReason,
21};
22
23pin_project! {
24 pub struct RequestStream<S> {
44 #[pin]
45 stream: S,
46 buf: BytesMut,
47 codec: Codec,
48 }
49}
50
51impl<S, T, E> RequestStream<S>
52where
53 S: Stream<Item = Result<T, E>>,
54 T: AsRef<[u8]>,
55{
56 pub fn new(stream: S) -> Self {
57 Self::with_codec(stream, Codec::new())
58 }
59
60 pub fn with_codec(stream: S, codec: Codec) -> Self {
61 Self {
62 stream,
63 buf: BytesMut::new(),
64 codec,
65 }
66 }
67
68 #[inline]
69 pub fn inner_mut(&mut self) -> &mut S {
70 &mut self.stream
71 }
72
73 #[inline]
74 pub fn codec_mut(&mut self) -> &mut Codec {
75 &mut self.codec
76 }
77
78 pub fn response_stream(&self) -> (ResponseStream, ResponseSender) {
82 let codec = self.codec.duplicate();
83 let cap = codec.capacity();
84 let (tx, rx) = channel(cap);
85 (ResponseStream(rx), ResponseSender::new(tx, codec))
86 }
87}
88
89pub enum WsError<E> {
90 Protocol(ProtocolError),
91 Stream(E),
92}
93
94impl<E> fmt::Debug for WsError<E> {
95 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
96 match *self {
97 Self::Protocol(ref e) => fmt::Debug::fmt(e, f),
98 Self::Stream(..) => f.write_str("Input Stream error"),
99 }
100 }
101}
102
103impl<E> fmt::Display for WsError<E> {
104 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
105 match *self {
106 Self::Protocol(ref e) => fmt::Debug::fmt(e, f),
107 Self::Stream(..) => f.write_str("Input Stream error"),
108 }
109 }
110}
111
112impl<E> error::Error for WsError<E> {}
113
114impl<E> From<ProtocolError> for WsError<E> {
115 fn from(e: ProtocolError) -> Self {
116 Self::Protocol(e)
117 }
118}
119
120impl<S, T, E> Stream for RequestStream<S>
121where
122 S: Stream<Item = Result<T, E>>,
123 T: AsRef<[u8]>,
124{
125 type Item = Result<Message, WsError<E>>;
126
127 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
128 let mut this = self.project();
129
130 loop {
131 match this.codec.decode(this.buf)? {
132 Some(msg) => return Poll::Ready(Some(Ok(msg))),
133 None => match ready!(this.stream.as_mut().poll_next(cx)) {
134 Some(res) => {
135 let item = res.map_err(WsError::Stream)?;
136 this.buf.extend_from_slice(item.as_ref())
137 }
138 None => return Poll::Ready(Some(Err(WsError::Protocol(ProtocolError::UnexpectedEof)))),
139 },
140 }
141 }
142 }
143}
144
145pub struct ResponseStream(Receiver<Item>);
146
147type Item = io::Result<Bytes>;
148
149impl Stream for ResponseStream {
150 type Item = Item;
151
152 #[inline]
153 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
154 self.get_mut().0.poll_recv(cx)
155 }
156}
157
158#[derive(Debug)]
160pub struct ResponseSender {
161 inner: Arc<_ResponseSender>,
162}
163
164impl ResponseSender {
165 fn new(tx: Sender<Item>, codec: Codec) -> Self {
166 let buf = BytesMut::with_capacity(codec.max_size());
167 Self {
168 inner: Arc::new(_ResponseSender {
169 encoder: Mutex::new(Encoder { codec, buf }),
170 tx,
171 }),
172 }
173 }
174
175 pub fn downgrade(&self) -> ResponseWeakSender {
177 ResponseWeakSender {
178 inner: Arc::downgrade(&self.inner),
179 }
180 }
181
182 #[inline]
184 pub async fn text(&self, txt: impl Into<Bytes>) -> Result<(), ProtocolError> {
185 let bytes = txt.into();
186 core::str::from_utf8(&bytes).map_err(|_| ProtocolError::BadOpCode)?;
187 self.send(Message::Text(bytes)).await
188 }
189
190 #[inline]
192 pub fn binary(&self, bin: impl Into<Bytes>) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
193 self.send(Message::Binary(bin.into()))
194 }
195
196 #[inline]
198 pub fn continuation(&self, item: super::codec::Item) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
199 self.send(Message::Continuation(item))
200 }
201
202 #[inline]
204 pub fn ping(&self, bin: impl Into<Bytes>) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
205 self.send(Message::Ping(bin.into()))
206 }
207
208 #[inline]
210 pub fn pong(&self, bin: impl Into<Bytes>) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
211 self.send(Message::Pong(bin.into()))
212 }
213
214 pub async fn close(&mut self, reason: Option<impl Into<CloseReason>>) -> Result<(), ProtocolError> {
220 self.send(Message::Close(reason.map(Into::into))).await
221 }
222
223 #[inline]
224 fn send(&self, msg: Message) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
225 self.inner.send(msg)
226 }
227}
228
229#[derive(Debug)]
233pub struct ResponseWeakSender {
234 inner: Weak<_ResponseSender>,
235}
236
237impl ResponseWeakSender {
238 pub fn upgrade(&self) -> Option<ResponseSender> {
241 self.inner.upgrade().and_then(|inner| {
242 let closed = inner.encoder.lock().unwrap().codec.send_closed();
243 (!closed).then(|| ResponseSender { inner })
244 })
245 }
246}
247
248#[derive(Debug)]
249struct _ResponseSender {
250 encoder: Mutex<Encoder>,
251 tx: Sender<Item>,
252}
253
254#[derive(Debug)]
255struct Encoder {
256 codec: Codec,
257 buf: BytesMut,
258}
259
260impl _ResponseSender {
261 async fn send(&self, msg: Message) -> Result<(), ProtocolError> {
264 let permit = self.tx.reserve().await.map_err(|_| ProtocolError::UnexpectedEof)?;
265 let buf = {
266 let mut encoder = self.encoder.lock().unwrap();
267 let Encoder { codec, buf } = &mut *encoder;
268 codec.encode(msg, buf)?;
269 buf.split().freeze()
270 };
271 permit.send(Ok(buf));
272 Ok(())
273 }
274}