use super::frame::{Frame, StreamFrame};
use super::locked_sink::LockedMessageSink;
use super::tungstenite_error_to_io_error;
use bytes::Bytes;
use futures_util::Sink as FutureSink;
use std::io;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{ready, Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::mpsc;
use tracing::{debug, trace, warn};
use tungstenite::Message;
/// A single multiplexed stream sharing one message sink (tungstenite `Message`s) with
/// other streams.
///
/// Incoming payload chunks arrive on `frame_rx` and are exposed through [`AsyncRead`].
/// Outgoing data and the FIN frame are written through `sink` via [`AsyncWrite`].
// Silences clippy's `module_name_repetitions` lint for this public type name.
#[allow(clippy::module_name_repetitions)]
pub struct MuxStream<Sink> {
    /// Incoming payload chunks for this stream. `poll_read` treats a closed channel or an
    /// empty chunk as EOF.
    pub(super) frame_rx: mpsc::Receiver<Bytes>,
    /// Port identifying this end of the stream; passed first when building frames.
    pub our_port: u16,
    /// Port identifying the peer's end of the stream; passed second when building frames.
    pub their_port: u16,
    /// Destination host associated with this stream (only read by `Debug` in this file).
    pub dest_host: Bytes,
    /// Destination port associated with this stream (only read by `Debug` in this file).
    pub dest_port: u16,
    /// Set once `poll_shutdown` has sent a FIN frame. `poll_write` then fails with
    /// `BrokenPipe`, and the value is reported through `dropped_ports_tx` on drop.
    pub(super) fin_sent: AtomicBool,
    /// Shared flag checked by `poll_write`; when set, writes fail with `BrokenPipe`.
    /// NOTE(review): presumably set by the multiplexer task when it forgets this
    /// stream — confirm against the writer of this flag.
    pub(super) stream_removed: Arc<AtomicBool>,
    /// Bytes from the most recently received chunk that did not yet fit into a caller's
    /// read buffer; served before polling `frame_rx` again.
    pub(super) buf: Bytes,
    /// Sink used to send this stream's data and FIN frames.
    pub(super) sink: LockedMessageSink<Sink>,
    /// On drop, receives `(our_port, their_port, fin_sent)` so the owning task can clean
    /// up this stream.
    pub(super) dropped_ports_tx: mpsc::UnboundedSender<(u16, u16, bool)>,
}
impl<Sink> std::fmt::Debug for MuxStream<Sink> {
    /// Renders the stream's port pair, destination and FIN state.
    ///
    /// The channel, pending read buffer, removal flag and sink are not included. This impl
    /// places no `Debug` bound on `Sink`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            our_port,
            their_port,
            dest_host,
            dest_port,
            fin_sent,
            ..
        } = self;
        let mut builder = f.debug_struct("MuxStream");
        builder.field("our_port", our_port);
        builder.field("their_port", their_port);
        builder.field("dest_host", dest_host);
        builder.field("dest_port", dest_port);
        builder.field("fin_sent", fin_sent);
        builder.finish()
    }
}
impl<Sink> Drop for MuxStream<Sink> {
    /// Reports this stream's port pair and whether a FIN was sent to the task listening on
    /// `dropped_ports_tx`.
    ///
    /// If the receiving side is already gone, the send fails and a warning is logged
    /// instead of panicking.
    fn drop(&mut self) {
        let fin_was_sent = self.fin_sent.load(Ordering::Relaxed);
        let notice = (self.our_port, self.their_port, fin_was_sent);
        if self.dropped_ports_tx.send(notice).is_err() {
            warn!("Failed to notify task of dropped port");
        }
    }
}
impl<Sink> AsyncRead for MuxStream<Sink>
where
    Sink: FutureSink<Message, Error = tungstenite::Error> + Send + Sync + Unpin + 'static,
{
    /// Copies stream payload bytes into `buf`.
    ///
    /// Bytes left over from a previously received chunk are served first. Otherwise the
    /// next chunk is awaited on `frame_rx`. A closed channel or an empty chunk marks
    /// end-of-stream. In that case the receiver is closed and `Ok(())` is returned without
    /// filling `buf`, which callers observe as EOF.
    ///
    /// Any part of a chunk that does not fit into `buf` is kept for the next call.
    #[tracing::instrument(skip(cx, buf), level = "trace")]
    #[inline]
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        if self.buf.is_empty() {
            trace!("polling the stream");
            match ready!(self.frame_rx.poll_recv(cx)) {
                Some(chunk) if !chunk.is_empty() => self.buf = chunk,
                // Channel closed or an empty chunk: both signal EOF. Closing the receiver
                // stops further sends (see `tokio::sync::mpsc` "clean shutdown").
                _ => {
                    self.frame_rx.close();
                    return Poll::Ready(Ok(()));
                }
            }
        } else {
            trace!("using the remaining buffer");
        }
        // Hand over as much as fits; the remainder (if any) stays in `self.buf`.
        let take = buf.remaining().min(self.buf.len());
        let head = self.buf.split_to(take);
        buf.put_slice(&head);
        Poll::Ready(Ok(()))
    }
}
impl<Sink> AsyncWrite for MuxStream<Sink>
where
    Sink: FutureSink<Message, Error = tungstenite::Error> + Send + Sync + Unpin + 'static,
{
    /// Sends `buf` to the peer as one stream data frame and reports all of it as written.
    ///
    /// An empty `buf` is a no-op returning `Ok(0)`. No zero-length frame is sent, because
    /// `poll_read` in this file treats an empty payload as EOF.
    ///
    /// # Errors
    /// Returns `BrokenPipe` once a FIN has been sent (`fin_sent`) or the stream has been
    /// flagged as removed (`stream_removed`). Errors from the underlying sink are
    /// propagated.
    #[tracing::instrument(skip(cx, buf), level = "trace")]
    #[inline]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        if self.fin_sent.load(Ordering::Relaxed) || self.stream_removed.load(Ordering::Relaxed) {
            debug!("stream has been closed, returning `BrokenPipe`");
            return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
        }
        if buf.is_empty() {
            // Nothing to send. Emitting a zero-length data frame risks the receiving side
            // interpreting it as end-of-stream (as `poll_read` here would).
            return Poll::Ready(Ok(0));
        }
        ready!(self
            .sink
            .poll_send_stream_buf(cx, buf, self.our_port, self.their_port))?;
        Poll::Ready(Ok(buf.len()))
    }

    /// Flushes the shared sink, converting sink errors to `io::Error`.
    #[tracing::instrument(skip(cx), level = "trace")]
    #[inline]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        ready!(self.sink.poll_flush(cx)).map_err(tungstenite_error_to_io_error)?;
        Poll::Ready(Ok(()))
    }

    /// Sends a FIN frame for this stream (at most once), then flushes the sink.
    ///
    /// # Errors
    /// Errors sending the FIN are returned. A flush failure after the FIN has been handed to
    /// the sink is logged and otherwise ignored, so shutdown still reports success.
    #[tracing::instrument(skip(cx), level = "trace")]
    #[inline]
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        if !self.fin_sent.load(Ordering::Relaxed) {
            let message = Frame::Stream(StreamFrame::new_fin(self.our_port, self.their_port))
                .try_into()
                .expect("Frame should be representable as a message (this is a bug)");
            ready!(self.sink.poll_send_message(cx, &message))
                .map_err(tungstenite_error_to_io_error)?;
            // Recorded only after the send completed, so a `Pending` above retries the
            // send on the next poll instead of skipping the FIN.
            self.fin_sent.store(true, Ordering::Relaxed);
        }
        // Best-effort flush: surface the failure in logs rather than discarding it silently.
        if let Err(e) = ready!(self.sink.poll_flush(cx)) {
            debug!("failed to flush after sending FIN: {e}");
        }
        Poll::Ready(Ok(()))
    }
}