1use core::{
2 fmt,
3 future::Future,
4 pin::Pin,
5 task::{ready, Context, Poll},
6};
7
8use alloc::sync::{Arc, Weak};
9
10use std::{error, io, sync::Mutex};
11
12use bytes::{Bytes, BytesMut};
13use futures_core::stream::Stream;
14use pin_project_lite::pin_project;
15use tokio::sync::mpsc::{channel, Receiver, Sender};
16
17use super::{
18 codec::{Codec, Message},
19 error::ProtocolError,
20 proto::CloseReason,
21};
22
pin_project! {
    /// Adapter that turns a stream of raw byte chunks into a stream of decoded
    /// websocket [`Message`]s.
    ///
    /// Received chunks are appended to an internal buffer, and the codec is asked to
    /// decode a message from that buffer whenever this stream is polled.
    pub struct RequestStream<S> {
        // Source of raw bytes; projected as `Pin<&mut S>` so it may be `!Unpin`.
        #[pin]
        stream: S,
        // Bytes received from `stream` that the codec has not consumed yet.
        buf: BytesMut,
        // Decoder state; `response_stream` also derives the response encoder from it.
        codec: Codec,
    }
}
34
35impl<S, T, E> RequestStream<S>
36where
37 S: Stream<Item = Result<T, E>>,
38 T: AsRef<[u8]>,
39{
40 pub fn new(stream: S) -> Self {
41 Self::with_codec(stream, Codec::new())
42 }
43
44 pub fn with_codec(stream: S, codec: Codec) -> Self {
45 Self {
46 stream,
47 buf: BytesMut::new(),
48 codec,
49 }
50 }
51
52 #[inline]
53 pub fn inner_mut(&mut self) -> &mut S {
54 &mut self.stream
55 }
56
57 #[inline]
58 pub fn codec_mut(&mut self) -> &mut Codec {
59 &mut self.codec
60 }
61
62 pub fn response_stream(&self) -> (ResponseStream, ResponseSender) {
66 let codec = self.codec.duplicate();
67 let cap = codec.capacity();
68 let (tx, rx) = channel(cap);
69 (ResponseStream(rx), ResponseSender::new(tx, codec))
70 }
71}
72
/// Error type yielded by the [`RequestStream`] stream.
pub enum WsError<E> {
    /// A websocket protocol error, e.g. reported by the codec while decoding.
    Protocol(ProtocolError),
    /// The wrapped input stream yielded an error of type `E`.
    Stream(E),
}
77
78impl<E> fmt::Debug for WsError<E> {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 match *self {
81 Self::Protocol(ref e) => fmt::Debug::fmt(e, f),
82 Self::Stream(..) => f.write_str("Input Stream error"),
83 }
84 }
85}
86
87impl<E> fmt::Display for WsError<E> {
88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 match *self {
90 Self::Protocol(ref e) => fmt::Debug::fmt(e, f),
91 Self::Stream(..) => f.write_str("Input Stream error"),
92 }
93 }
94}
95
// `source()` keeps its default (`None`): `E` has no `Error` bound, so the wrapped
// stream error cannot be exposed as a source.
impl<E> error::Error for WsError<E> {}
97
98impl<E> From<ProtocolError> for WsError<E> {
99 fn from(e: ProtocolError) -> Self {
100 Self::Protocol(e)
101 }
102}
103
104impl<S, T, E> Stream for RequestStream<S>
105where
106 S: Stream<Item = Result<T, E>>,
107 T: AsRef<[u8]>,
108{
109 type Item = Result<Message, WsError<E>>;
110
111 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
112 let mut this = self.project();
113
114 loop {
115 if let Some(msg) = this.codec.decode(this.buf)? {
116 return Poll::Ready(Some(Ok(msg)));
117 }
118 match ready!(this.stream.as_mut().poll_next(cx)) {
119 Some(res) => {
120 let item = res.map_err(WsError::Stream)?;
121 this.buf.extend_from_slice(item.as_ref())
122 }
123 None => return Poll::Ready(None),
124 }
125 }
126 }
127}
128
/// Stream of encoded response frames, fed by a paired [`ResponseSender`].
///
/// Yields `Ok(bytes)` for messages encoded by `ResponseSender::send` and `Err(io::Error)`
/// for errors pushed through `ResponseSender::send_error`. It wraps a tokio bounded
/// receiver, so it ends once all senders are dropped and buffered items are drained.
pub struct ResponseStream(Receiver<Item>);
130
/// Item carried over the response channel: an encoded frame or an I/O error.
type Item = io::Result<Bytes>;
132
133impl Stream for ResponseStream {
134 type Item = Item;
135
136 #[inline]
137 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
138 self.get_mut().0.poll_recv(cx)
139 }
140}
141
/// Handle that encodes websocket messages and queues them for a [`ResponseStream`].
///
/// The shared encoder and channel sender live behind an `Arc`. Handles recovered via
/// `downgrade`/`upgrade` share the same state.
#[derive(Debug)]
pub struct ResponseSender {
    inner: Arc<_ResponseSender>,
}
147
148impl ResponseSender {
149 fn new(tx: Sender<Item>, codec: Codec) -> Self {
150 Self {
151 inner: Arc::new(_ResponseSender {
152 encoder: Mutex::new(Encoder {
153 codec,
154 buf: BytesMut::with_capacity(codec.max_size()),
155 }),
156 tx,
157 }),
158 }
159 }
160
161 pub fn downgrade(&self) -> ResponseWeakSender {
163 ResponseWeakSender {
164 inner: Arc::downgrade(&self.inner),
165 }
166 }
167
168 #[inline]
170 pub fn send(&self, msg: Message) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
171 self.inner.send(msg)
172 }
173
174 #[inline]
228 pub fn send_error(&self, err: io::Error) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
229 self.inner.send_error(err)
230 }
231
232 #[inline]
234 pub fn text(&self, txt: impl Into<String>) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
235 self.send(Message::Text(Bytes::from(txt.into())))
236 }
237
238 #[inline]
240 pub fn binary(&self, bin: impl Into<Bytes>) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
241 self.send(Message::Binary(bin.into()))
242 }
243
244 pub async fn close(self, reason: Option<impl Into<CloseReason>>) -> Result<(), ProtocolError> {
247 self.send(Message::Close(reason.map(Into::into))).await
248 }
249}
250
/// Non-owning counterpart of [`ResponseSender`], created by `ResponseSender::downgrade`.
#[derive(Debug)]
pub struct ResponseWeakSender {
    inner: Weak<_ResponseSender>,
}
256
257impl ResponseWeakSender {
258 pub fn upgrade(&self) -> Option<ResponseSender> {
261 self.inner.upgrade().map(|inner| ResponseSender { inner })
262 }
263}
264
/// State shared by a [`ResponseSender`] and any handles upgraded from its weak references.
#[derive(Debug)]
struct _ResponseSender {
    // Mutex-guarded so that handles sharing this state encode one message at a time.
    encoder: Mutex<Encoder>,
    // Sending half of the channel read by `ResponseStream`.
    tx: Sender<Item>,
}
270
/// Codec plus scratch buffer used to encode outgoing messages.
#[derive(Debug)]
struct Encoder {
    codec: Codec,
    // Messages are encoded into this buffer and then split off as a frozen `Bytes`.
    buf: BytesMut,
}
276
277impl _ResponseSender {
278 async fn send(&self, msg: Message) -> Result<(), ProtocolError> {
281 let permit = self.tx.reserve().await.map_err(|_| ProtocolError::Closed)?;
282 let buf = {
283 let mut encoder = self.encoder.lock().unwrap();
284 let Encoder { codec, buf } = &mut *encoder;
285 codec.encode(msg, buf)?;
286 buf.split().freeze()
287 };
288 permit.send(Ok(buf));
289 Ok(())
290 }
291
292 async fn send_error(&self, err: io::Error) -> Result<(), ProtocolError> {
297 self.tx.send(Err(err)).await.map_err(|_| ProtocolError::Closed)
298 }
299}