1use crate::{
2 sink_unfold,
3 ws::{Frame, LockedWebSocketWrite, Payload},
4 AtomicCloseReason, CloseReason, Packet, Role, StreamType, WispError,
5};
6
7use bytes::{BufMut, Bytes, BytesMut};
8use event_listener::Event;
9use flume as mpsc;
10use futures::{
11 channel::oneshot,
12 ready, select,
13 stream::{self, IntoAsyncRead},
14 task::{noop_waker_ref, Context, Poll},
15 AsyncBufRead, AsyncRead, AsyncWrite, FutureExt, Sink, Stream, TryStreamExt,
16};
17use pin_project_lite::pin_project;
18use std::{
19 pin::Pin,
20 sync::{
21 atomic::{AtomicBool, AtomicU32, Ordering},
22 Arc,
23 },
24};
25
/// Requests sent from stream handles to the multiplexor over its internal
/// channel (see `mux_tx` / `close_channel`).
pub(crate) enum WsEvent {
    /// Ask the multiplexor to send the given CLOSE packet; the outcome is
    /// reported back through the oneshot sender.
    Close(Packet<'static>, oneshot::Sender<Result<(), WispError>>),
    /// Ask the multiplexor to open a new stream. Fields: stream type, then
    /// presumably destination host and port (TODO confirm against the mux's
    /// handler), then the oneshot sender that receives the new stream or an error.
    CreateStream(
        StreamType,
        String,
        u16,
        oneshot::Sender<Result<MuxStream, WispError>>,
    ),
    /// Handled by the multiplexor outside this file; presumably ends the mux's
    /// main future with an optional close reason — confirm.
    EndFut(Option<CloseReason>),
}
36
/// Read half of a multiplexor stream.
pub struct MuxStreamRead {
    /// ID of the stream.
    pub stream_id: u32,
    /// Type of the stream.
    pub stream_type: StreamType,

    // Local role; server-side TCP reads drive flow control (CONTINUE packets).
    role: Role,

    // Websocket writer, used here only to send CONTINUE packets.
    tx: LockedWebSocketWrite,
    // Incoming data payloads for this stream.
    rx: mpsc::Receiver<Bytes>,

    // Closed flag shared with the write half.
    is_closed: Arc<AtomicBool>,
    // Awaited by `read` so a pending read ends when the stream closes.
    is_closed_event: Arc<Event>,
    // Close reason shared with the write half.
    close_reason: Arc<AtomicCloseReason>,

    // Flow-control value shared with the write half; reported in CONTINUE packets.
    flow_control: Arc<AtomicU32>,
    // Number of reads counted since the last CONTINUE packet was sent.
    flow_control_read: AtomicU32,
    // Read count that must be exceeded before a CONTINUE packet is sent.
    target_flow_control: u32,
}
57
58impl MuxStreamRead {
59 pub async fn read(&self) -> Option<Bytes> {
61 if self.is_closed.load(Ordering::Acquire) {
62 return None;
63 }
64 let bytes = select! {
65 x = self.rx.recv_async() => x.ok()?,
66 _ = self.is_closed_event.listen().fuse() => return None
67 };
68 if self.role == Role::Server && self.stream_type == StreamType::Tcp {
69 let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1;
70 if val > self.target_flow_control && !self.is_closed.load(Ordering::Acquire) {
71 self.tx
72 .write_frame(
73 Packet::new_continue(
74 self.stream_id,
75 self.flow_control.fetch_add(val, Ordering::AcqRel) + val,
76 )
77 .into(),
78 )
79 .await
80 .ok()?;
81 self.flow_control_read.store(0, Ordering::Release);
82 }
83 }
84 Some(bytes)
85 }
86
87 pub(crate) fn into_inner_stream(self) -> Pin<Box<dyn Stream<Item = Bytes> + Send>> {
88 Box::pin(stream::unfold(self, |rx| async move {
89 Some((rx.read().await?, rx))
90 }))
91 }
92
93 pub fn into_stream(self) -> MuxStreamIoStream {
95 MuxStreamIoStream {
96 rx: self.into_inner_stream(),
97 }
98 }
99
100 pub fn get_close_reason(&self) -> Option<CloseReason> {
102 if self.is_closed.load(Ordering::Acquire) {
103 Some(self.close_reason.load(Ordering::Acquire))
104 } else {
105 None
106 }
107 }
108}
109
/// Write half of a multiplexor stream.
pub struct MuxStreamWrite {
    /// ID of the stream.
    pub stream_id: u32,
    /// Type of the stream.
    pub stream_type: StreamType,

    // Local role; client-side TCP writes are flow controlled.
    role: Role,
    // Channel to the multiplexor, used to request CLOSE packets.
    mux_tx: mpsc::Sender<WsEvent>,
    // Websocket writer used for data packets.
    tx: LockedWebSocketWrite,

    // Closed flag shared with the read half.
    is_closed: Arc<AtomicBool>,
    // Close reason shared with the read half.
    close_reason: Arc<AtomicCloseReason>,

    // Awaited while flow-control credit is 0; presumably notified by the mux when
    // a CONTINUE packet arrives (the notifier is outside this file — confirm).
    continue_recieved: Arc<Event>,
    // Remaining flow-control credit; decremented after each client TCP write.
    flow_control: Arc<AtomicU32>,
}
127
128impl MuxStreamWrite {
129 pub(crate) async fn write_payload_internal<'a>(
130 &self,
131 header: Frame<'static>,
132 body: Frame<'a>,
133 ) -> Result<(), WispError> {
134 if self.role == Role::Client
135 && self.stream_type == StreamType::Tcp
136 && self.flow_control.load(Ordering::Acquire) == 0
137 {
138 self.continue_recieved.listen().await;
139 }
140 if self.is_closed.load(Ordering::Acquire) {
141 return Err(WispError::StreamAlreadyClosed);
142 }
143
144 self.tx.write_split(header, body).await?;
145
146 if self.role == Role::Client && self.stream_type == StreamType::Tcp {
147 self.flow_control.store(
148 self.flow_control.load(Ordering::Acquire).saturating_sub(1),
149 Ordering::Release,
150 );
151 }
152 Ok(())
153 }
154
155 pub async fn write_payload(&self, data: Payload<'_>) -> Result<(), WispError> {
157 let frame: Frame<'static> = Frame::from(Packet::new_data(
158 self.stream_id,
159 Payload::Bytes(BytesMut::new()),
160 ));
161 self.write_payload_internal(frame, Frame::binary(data))
162 .await
163 }
164
165 pub async fn write<D: AsRef<[u8]>>(&self, data: D) -> Result<(), WispError> {
167 self.write_payload(Payload::Borrowed(data.as_ref())).await
168 }
169
170 pub fn get_close_handle(&self) -> MuxStreamCloser {
182 MuxStreamCloser {
183 stream_id: self.stream_id,
184 close_channel: self.mux_tx.clone(),
185 is_closed: self.is_closed.clone(),
186 close_reason: self.close_reason.clone(),
187 }
188 }
189
190 pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
192 MuxProtocolExtensionStream {
193 stream_id: self.stream_id,
194 tx: self.tx.clone(),
195 is_closed: self.is_closed.clone(),
196 }
197 }
198
199 pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
201 if self.is_closed.load(Ordering::Acquire) {
202 return Err(WispError::StreamAlreadyClosed);
203 }
204 self.is_closed.store(true, Ordering::Release);
205
206 let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
207 self.mux_tx
208 .send_async(WsEvent::Close(
209 Packet::new_close(self.stream_id, reason),
210 tx,
211 ))
212 .await
213 .map_err(|_| WispError::MuxMessageFailedToSend)?;
214 rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??;
215
216 Ok(())
217 }
218
219 pub fn get_close_reason(&self) -> Option<CloseReason> {
221 if self.is_closed.load(Ordering::Acquire) {
222 Some(self.close_reason.load(Ordering::Acquire))
223 } else {
224 None
225 }
226 }
227
228 pub(crate) fn into_inner_sink(
229 self,
230 ) -> Pin<Box<dyn Sink<Payload<'static>, Error = WispError> + Send>> {
231 let handle = self.get_close_handle();
232 Box::pin(sink_unfold::unfold(
233 self,
234 |tx, data| async move {
235 tx.write_payload(data).await?;
236 Ok(tx)
237 },
238 handle,
239 |handle| async move {
240 handle.close(CloseReason::Unknown).await?;
241 Ok(handle)
242 },
243 ))
244 }
245
246 pub fn into_sink(self) -> MuxStreamIoSink {
248 MuxStreamIoSink {
249 tx: self.into_inner_sink(),
250 }
251 }
252}
253
254impl Drop for MuxStreamWrite {
255 fn drop(&mut self) {
256 if !self.is_closed.load(Ordering::Acquire) {
257 self.is_closed.store(true, Ordering::Release);
258 let (tx, _) = oneshot::channel();
259 let _ = self.mux_tx.send(WsEvent::Close(
260 Packet::new_close(self.stream_id, CloseReason::Unknown),
261 tx,
262 ));
263 }
264 }
265}
266
/// A multiplexor stream, made of a read half and a write half.
pub struct MuxStream {
    /// ID of the stream.
    pub stream_id: u32,
    // Read half.
    rx: MuxStreamRead,
    // Write half.
    tx: MuxStreamWrite,
}
274
275impl MuxStream {
276 #[allow(clippy::too_many_arguments)]
277 pub(crate) fn new(
278 stream_id: u32,
279 role: Role,
280 stream_type: StreamType,
281 rx: mpsc::Receiver<Bytes>,
282 mux_tx: mpsc::Sender<WsEvent>,
283 tx: LockedWebSocketWrite,
284 is_closed: Arc<AtomicBool>,
285 is_closed_event: Arc<Event>,
286 close_reason: Arc<AtomicCloseReason>,
287 flow_control: Arc<AtomicU32>,
288 continue_recieved: Arc<Event>,
289 target_flow_control: u32,
290 ) -> Self {
291 Self {
292 stream_id,
293 rx: MuxStreamRead {
294 stream_id,
295 stream_type,
296 role,
297 tx: tx.clone(),
298 rx,
299 is_closed: is_closed.clone(),
300 is_closed_event: is_closed_event.clone(),
301 close_reason: close_reason.clone(),
302 flow_control: flow_control.clone(),
303 flow_control_read: AtomicU32::new(0),
304 target_flow_control,
305 },
306 tx: MuxStreamWrite {
307 stream_id,
308 stream_type,
309 role,
310 mux_tx,
311 tx,
312 is_closed: is_closed.clone(),
313 close_reason: close_reason.clone(),
314 flow_control: flow_control.clone(),
315 continue_recieved: continue_recieved.clone(),
316 },
317 }
318 }
319
320 pub async fn read(&self) -> Option<Bytes> {
322 self.rx.read().await
323 }
324
325 pub async fn write_payload(&self, data: Payload<'_>) -> Result<(), WispError> {
327 self.tx.write_payload(data).await
328 }
329
330 pub async fn write<D: AsRef<[u8]>>(&self, data: D) -> Result<(), WispError> {
332 self.tx.write(data).await
333 }
334
335 pub fn get_close_handle(&self) -> MuxStreamCloser {
347 self.tx.get_close_handle()
348 }
349
350 pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
352 self.tx.get_protocol_extension_stream()
353 }
354
355 pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
357 self.tx.close(reason).await
358 }
359
360 pub fn into_split(self) -> (MuxStreamRead, MuxStreamWrite) {
362 (self.rx, self.tx)
363 }
364
365 pub fn into_io(self) -> MuxStreamIo {
367 MuxStreamIo {
368 rx: self.rx.into_stream(),
369 tx: self.tx.into_sink(),
370 }
371 }
372}
373
/// Cloneable handle that can close a multiplexor stream without owning it.
#[derive(Clone)]
pub struct MuxStreamCloser {
    /// ID of the stream.
    pub stream_id: u32,
    // Channel to the multiplexor used to request the CLOSE packet.
    close_channel: mpsc::Sender<WsEvent>,
    // Closed flag shared with the stream halves.
    is_closed: Arc<AtomicBool>,
    // Close reason shared with the stream halves.
    close_reason: Arc<AtomicCloseReason>,
}
383
384impl MuxStreamCloser {
385 pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
387 if self.is_closed.load(Ordering::Acquire) {
388 return Err(WispError::StreamAlreadyClosed);
389 }
390 self.is_closed.store(true, Ordering::Release);
391
392 let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
393 self.close_channel
394 .send_async(WsEvent::Close(
395 Packet::new_close(self.stream_id, reason),
396 tx,
397 ))
398 .await
399 .map_err(|_| WispError::MuxMessageFailedToSend)?;
400 rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??;
401
402 Ok(())
403 }
404
405 pub fn get_close_reason(&self) -> Option<CloseReason> {
407 if self.is_closed.load(Ordering::Acquire) {
408 Some(self.close_reason.load(Ordering::Acquire))
409 } else {
410 None
411 }
412 }
413}
414
/// Handle for sending protocol-extension packets tagged with a stream's ID.
pub struct MuxProtocolExtensionStream {
    /// ID of the stream.
    pub stream_id: u32,
    // Websocket writer the extension packets are sent on.
    pub(crate) tx: LockedWebSocketWrite,
    // Closed flag of the stream; sending fails once it is set.
    pub(crate) is_closed: Arc<AtomicBool>,
}
422
423impl MuxProtocolExtensionStream {
424 pub async fn send(&self, packet_type: u8, data: Bytes) -> Result<(), WispError> {
426 if self.is_closed.load(Ordering::Acquire) {
427 return Err(WispError::StreamAlreadyClosed);
428 }
429 let mut encoded = BytesMut::with_capacity(1 + 4 + data.len());
430 encoded.put_u8(packet_type);
431 encoded.put_u32_le(self.stream_id);
432 encoded.extend(data);
433 self.tx
434 .write_frame(Frame::binary(Payload::Bytes(encoded)))
435 .await
436 }
437}
438
pin_project! {
    /// Multiplexor stream exposed as a futures `Stream` of received bytes plus a
    /// `Sink<&[u8]>` for sending.
    pub struct MuxStreamIo {
        // Receiving side.
        #[pin]
        rx: MuxStreamIoStream,
        // Sending side.
        #[pin]
        tx: MuxStreamIoSink,
    }
}
448
449impl MuxStreamIo {
450 pub fn into_asyncrw(self) -> MuxStreamAsyncRW {
452 MuxStreamAsyncRW {
453 rx: self.rx.into_asyncread(),
454 tx: self.tx.into_asyncwrite(),
455 }
456 }
457
458 pub fn into_split(self) -> (MuxStreamIoStream, MuxStreamIoSink) {
460 (self.rx, self.tx)
461 }
462}
463
464impl Stream for MuxStreamIo {
465 type Item = Result<Bytes, std::io::Error>;
466 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
467 self.project().rx.poll_next(cx)
468 }
469}
470
471impl Sink<&[u8]> for MuxStreamIo {
472 type Error = std::io::Error;
473 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
474 self.project().tx.poll_ready(cx)
475 }
476 fn start_send(self: Pin<&mut Self>, item: &[u8]) -> Result<(), Self::Error> {
477 self.project().tx.start_send(item)
478 }
479 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
480 self.project().tx.poll_flush(cx)
481 }
482 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
483 self.project().tx.poll_close(cx)
484 }
485}
486
pin_project! {
    /// Receiving side of a multiplexor stream as a futures `Stream` of
    /// `Result<Bytes, std::io::Error>` (it never yields `Err` itself).
    pub struct MuxStreamIoStream {
        // Boxed stream of raw received chunks.
        #[pin]
        rx: Pin<Box<dyn Stream<Item = Bytes> + Send>>,
    }
}
494
495impl MuxStreamIoStream {
496 pub fn into_asyncread(self) -> MuxStreamAsyncRead {
498 MuxStreamAsyncRead::new(self)
499 }
500}
501
502impl Stream for MuxStreamIoStream {
503 type Item = Result<Bytes, std::io::Error>;
504 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
505 self.project().rx.poll_next(cx).map(|x| x.map(Ok))
506 }
507}
508
pin_project! {
    /// Sending side of a multiplexor stream as a futures `Sink<&[u8]>` with
    /// `std::io::Error` errors.
    pub struct MuxStreamIoSink {
        // Boxed sink of payloads; its `WispError`s are wrapped into `io::Error`.
        #[pin]
        tx: Pin<Box<dyn Sink<Payload<'static>, Error = WispError> + Send>>,
    }
}
516
517impl MuxStreamIoSink {
518 pub fn into_asyncwrite(self) -> MuxStreamAsyncWrite {
520 MuxStreamAsyncWrite::new(self)
521 }
522}
523
524impl Sink<&[u8]> for MuxStreamIoSink {
525 type Error = std::io::Error;
526 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
527 self.project()
528 .tx
529 .poll_ready(cx)
530 .map_err(std::io::Error::other)
531 }
532 fn start_send(self: Pin<&mut Self>, item: &[u8]) -> Result<(), Self::Error> {
533 self.project()
534 .tx
535 .start_send(Payload::Bytes(BytesMut::from(item)))
536 .map_err(std::io::Error::other)
537 }
538 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
539 self.project()
540 .tx
541 .poll_flush(cx)
542 .map_err(std::io::Error::other)
543 }
544 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
545 self.project()
546 .tx
547 .poll_close(cx)
548 .map_err(std::io::Error::other)
549 }
550}
551
pin_project! {
    /// Multiplexor stream exposed as `AsyncRead + AsyncBufRead + AsyncWrite`.
    pub struct MuxStreamAsyncRW {
        // Reading side.
        #[pin]
        rx: MuxStreamAsyncRead,
        // Writing side.
        #[pin]
        tx: MuxStreamAsyncWrite,
    }
}
561
562impl MuxStreamAsyncRW {
563 pub fn into_split(self) -> (MuxStreamAsyncRead, MuxStreamAsyncWrite) {
565 (self.rx, self.tx)
566 }
567}
568
569impl AsyncRead for MuxStreamAsyncRW {
570 fn poll_read(
571 self: Pin<&mut Self>,
572 cx: &mut Context<'_>,
573 buf: &mut [u8],
574 ) -> Poll<std::io::Result<usize>> {
575 self.project().rx.poll_read(cx, buf)
576 }
577
578 fn poll_read_vectored(
579 self: Pin<&mut Self>,
580 cx: &mut Context<'_>,
581 bufs: &mut [std::io::IoSliceMut<'_>],
582 ) -> Poll<std::io::Result<usize>> {
583 self.project().rx.poll_read_vectored(cx, bufs)
584 }
585}
586
587impl AsyncBufRead for MuxStreamAsyncRW {
588 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
589 self.project().rx.poll_fill_buf(cx)
590 }
591
592 fn consume(self: Pin<&mut Self>, amt: usize) {
593 self.project().rx.consume(amt)
594 }
595}
596
597impl AsyncWrite for MuxStreamAsyncRW {
598 fn poll_write(
599 self: Pin<&mut Self>,
600 cx: &mut Context<'_>,
601 buf: &[u8],
602 ) -> Poll<std::io::Result<usize>> {
603 self.project().tx.poll_write(cx, buf)
604 }
605
606 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
607 self.project().tx.poll_flush(cx)
608 }
609
610 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
611 self.project().tx.poll_close(cx)
612 }
613}
614
pin_project! {
    /// Reading side of a multiplexor stream as `AsyncRead + AsyncBufRead`.
    pub struct MuxStreamAsyncRead {
        // `IntoAsyncRead` adapter over the chunk stream.
        #[pin]
        rx: IntoAsyncRead<MuxStreamIoStream>,
    }
}
623
624impl MuxStreamAsyncRead {
625 pub(crate) fn new(stream: MuxStreamIoStream) -> Self {
626 Self {
627 rx: stream.into_async_read(),
628 }
630 }
631}
632
633impl AsyncRead for MuxStreamAsyncRead {
634 fn poll_read(
635 self: Pin<&mut Self>,
636 cx: &mut Context<'_>,
637 buf: &mut [u8],
638 ) -> Poll<std::io::Result<usize>> {
639 self.project().rx.poll_read(cx, buf)
640 }
641}
642impl AsyncBufRead for MuxStreamAsyncRead {
643 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
644 self.project().rx.poll_fill_buf(cx)
645 }
646 fn consume(self: Pin<&mut Self>, amt: usize) {
647 self.project().rx.consume(amt)
648 }
649}
650
pin_project! {
    /// Writing side of a multiplexor stream as `AsyncWrite`.
    pub struct MuxStreamAsyncWrite {
        // Underlying byte sink.
        #[pin]
        tx: MuxStreamIoSink,
        // Error from the opportunistic flush in `poll_write`, held until it can be
        // returned to the caller.
        error: Option<std::io::Error>
    }
}
659
660impl MuxStreamAsyncWrite {
661 pub(crate) fn new(sink: MuxStreamIoSink) -> Self {
662 Self {
663 tx: sink,
664 error: None,
665 }
666 }
667}
668
669impl AsyncWrite for MuxStreamAsyncWrite {
670 fn poll_write(
671 mut self: Pin<&mut Self>,
672 cx: &mut Context<'_>,
673 buf: &[u8],
674 ) -> Poll<std::io::Result<usize>> {
675 if let Some(err) = self.error.take() {
676 return Poll::Ready(Err(err));
677 }
678
679 let mut this = self.as_mut().project();
680
681 ready!(this.tx.as_mut().poll_ready(cx))?;
682 match this.tx.as_mut().start_send(buf) {
683 Ok(()) => {
684 let mut cx = Context::from_waker(noop_waker_ref());
685 let cx = &mut cx;
686
687 match this.tx.poll_flush(cx) {
688 Poll::Ready(Err(err)) => {
689 self.error = Some(err);
690 }
691 Poll::Ready(Ok(_)) | Poll::Pending => {}
692 }
693
694 Poll::Ready(Ok(buf.len()))
695 }
696 Err(e) => Poll::Ready(Err(e)),
697 }
698 }
699
700 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
701 self.project().tx.poll_flush(cx)
702 }
703
704 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
705 self.project().tx.poll_close(cx)
706 }
707}