use std::{
collections::VecDeque,
hint::unreachable_unchecked,
io,
mem::{replace, take},
pin::Pin,
task::{ready, Context, Poll},
};
use bytes::{Buf, BytesMut};
use futures_core::Stream;
use futures_sink::Sink;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::{codec::FramedRead, io::poll_write_buf};
#[cfg(any(feature = "client", feature = "server"))]
use super::types::Limits;
use super::{
codec::WebSocketProtocol,
types::{Frame, Message, OpCode, Payload, Role, StreamState},
Config,
};
use crate::{CloseCode, Error};
/// A frame that has already been serialized and is waiting in the outgoing queue.
///
/// Header, optional masking key and payload are stored separately; they are
/// chained together (without copying into one buffer) when the queue is flushed.
#[derive(Debug)]
struct EncodedFrame {
    /// Encoded frame header; only the first `header_len` bytes are meaningful.
    header: [u8; 10],
    /// Number of valid bytes at the start of `header`.
    header_len: u8,
    /// Masking key written right after the header; `Some` only for client-role streams.
    mask: Option<[u8; 4]>,
    /// Frame payload, already masked with `mask` when that is `Some`.
    payload: Payload,
}
/// A WebSocket connection running over an async I/O object `T`.
///
/// Incoming bytes are decoded by a `FramedRead` codec and surfaced as a
/// [`Stream`] of [`Message`]s; outgoing messages go through the [`Sink`]
/// implementation, which encodes and queues frames and writes them on flush.
// The type name intentionally repeats its module's name, which this pedantic lint flags.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct WebSocketStream<T> {
    /// Underlying I/O object wrapped in the WebSocket frame decoder.
    inner: FramedRead<T, WebSocketProtocol>,
    /// User configuration (e.g. the maximum outgoing frame size).
    config: Config,
    /// Progress of the closing handshake.
    state: StreamState,
    /// Payload accumulated so far for a fragmented message being reassembled.
    partial_payload: BytesMut,
    /// Opcode of the fragmented message being reassembled; `Continuation` when none.
    partial_opcode: OpCode,
    /// Scratch buffer that frame headers are encoded into before being queued.
    header_buf: [u8; 10],
    /// Encoded frames waiting to be written to `inner`.
    frame_queue: VecDeque<EncodedFrame>,
    /// Bytes of the front frame of `frame_queue` already written, used to resume partial writes.
    bytes_written: usize,
}
// SAFETY: `Sync` only grants shared (`&`) access from several threads. In this
// module every operation that mutates the stream (frame queue, partial payload,
// state) goes through `Pin<&mut Self>` / `&mut self`, so what stays reachable via
// `&WebSocketStream<T>` is read-only access such as the derived `Debug` impl.
// That access includes the wrapped `T`, so `T` itself must be `Sync`; without
// this bound a `!Sync` `T` (e.g. one containing a `RefCell`) could be reached
// concurrently.
// NOTE(review): confirm which field makes the auto-trait `!Sync` (presumably
// `Payload`) and that its `Debug` impl does not mutate interior state.
unsafe impl<T> Sync for WebSocketStream<T> where T: Sync {}
impl<T> WebSocketStream<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
#[cfg(any(feature = "client", feature = "server"))]
pub(crate) fn from_raw_stream(stream: T, role: Role, config: Config, limits: Limits) -> Self {
Self {
inner: FramedRead::new(stream, WebSocketProtocol::new(role, limits)),
config,
state: StreamState::Active,
partial_payload: BytesMut::new(),
partial_opcode: OpCode::Continuation,
header_buf: [0; 10],
frame_queue: VecDeque::with_capacity(1),
bytes_written: 0,
}
}
#[cfg(any(feature = "client", feature = "server"))]
pub(crate) fn from_framed<U>(
framed: FramedRead<T, U>,
role: Role,
config: Config,
limits: Limits,
) -> Self {
Self {
inner: framed.map_decoder(|_| WebSocketProtocol::new(role, limits)),
config,
state: StreamState::Active,
partial_payload: BytesMut::new(),
partial_opcode: OpCode::Continuation,
header_buf: [0; 10],
frame_queue: VecDeque::with_capacity(1),
bytes_written: 0,
}
}
fn poll_next_frame(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Frame, Error>>> {
if self.state == StreamState::CloseAcknowledged {
return Poll::Ready(None);
} else if self.state == StreamState::ClosedByPeer {
ready!(self.as_mut().poll_flush(cx))?;
self.state = StreamState::CloseAcknowledged;
return Poll::Ready(None);
}
if !self.frame_queue.is_empty() {
_ = self.as_mut().poll_flush(cx)?;
}
let frame = match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
Some(Ok(frame)) => frame,
Some(Err(e)) => {
if self.state == StreamState::ClosedByUs {
self.state = StreamState::CloseAcknowledged;
} else {
self.state = StreamState::ClosedByPeer;
match &e {
Error::Protocol(e) => self.queue_frame(Frame::from(e)),
Error::PayloadTooLong { max_len, .. } => self.queue_frame(
Message::close(
Some(CloseCode::MESSAGE_TOO_BIG),
&format!("max length: {max_len}"),
)
.into(),
),
_ => {}
}
}
return Poll::Ready(Some(Err(e)));
}
None => return Poll::Ready(None),
};
match frame.opcode {
OpCode::Close => match self.state {
StreamState::Active => {
self.state = StreamState::ClosedByPeer;
let mut frame = frame.clone();
frame.payload.truncate(2);
self.queue_frame(frame);
}
StreamState::ClosedByPeer | StreamState::CloseAcknowledged => unsafe {
unreachable_unchecked()
},
StreamState::ClosedByUs => {
self.state = StreamState::CloseAcknowledged;
}
},
OpCode::Ping if self.state == StreamState::Active => {
let mut frame = frame.clone();
frame.opcode = OpCode::Pong;
self.queue_frame(frame);
}
_ => {}
}
Poll::Ready(Some(Ok(frame)))
}
fn queue_frame(&mut self, frame: Frame) {
if frame.opcode == OpCode::Close && self.state != StreamState::ClosedByPeer {
self.state = StreamState::ClosedByUs;
}
let (frame, mask): (Frame, Option<[u8; 4]>) = if self.inner.decoder().role == Role::Client {
#[cfg(feature = "client")]
{
let mut frame = frame;
let mut payload = BytesMut::from(frame.payload);
let mask = crate::rand::get_mask();
crate::mask::frame(&mask, &mut payload, 0);
frame.payload = Payload::from(payload);
(frame, Some(mask))
}
#[cfg(not(feature = "client"))]
{
unsafe { std::hint::unreachable_unchecked() }
}
} else {
(frame, None)
};
let header_len = frame.encode(&mut self.header_buf);
if mask.is_some() {
self.header_buf[1] |= 1 << 7;
}
self.frame_queue.push_back(EncodedFrame {
header: self.header_buf,
header_len,
mask,
payload: frame.payload,
});
}
}
impl<T> Stream for WebSocketStream<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
type Item = Result<Message, Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let max_len = self.inner.decoder().limits.max_payload_len;
loop {
let (opcode, payload, fin) = match ready!(self.as_mut().poll_next_frame(cx)?) {
Some(frame) => (frame.opcode, frame.payload, frame.is_final),
None => return Poll::Ready(None),
};
let len = self.partial_payload.len() + payload.len();
if opcode != OpCode::Continuation {
if fin {
return Poll::Ready(Some(Ok(Message { opcode, payload })));
}
self.partial_opcode = opcode;
self.partial_payload = BytesMut::from(payload);
} else if len > max_len {
return Poll::Ready(Some(Err(Error::PayloadTooLong { len, max_len })));
} else {
self.partial_payload.extend_from_slice(&payload);
}
if fin {
break;
}
}
let opcode = replace(&mut self.partial_opcode, OpCode::Continuation);
let mut payload = Payload::from(take(&mut self.partial_payload));
payload.set_utf8_validated(opcode == OpCode::Text);
Poll::Ready(Some(Ok(Message { opcode, payload })))
}
}
impl<T> Sink<Message> for WebSocketStream<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    type Error = Error;

    /// Applies write backpressure.
    ///
    /// Ready immediately while fewer than `FLUSH_THRESHOLD` encoded bytes are
    /// queued; otherwise the queue is flushed first.
    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // NOTE(review): 8096 looks like a typo for 8192 (8 KiB); the value is kept
        // unchanged to preserve the existing buffering behaviour.
        const FLUSH_THRESHOLD: usize = 8096;

        // Encoded size of everything still queued: header + masking key + payload.
        let pending_bytes: usize = self
            .frame_queue
            .iter()
            .map(|f| usize::from(f.header_len) + f.mask.map_or(0, |m| m.len()) + f.payload.len())
            .sum();

        if pending_bytes >= FLUSH_THRESHOLD {
            self.as_mut().poll_flush(cx)
        } else {
            Poll::Ready(Ok(()))
        }
    }

    /// Encodes `item` and queues it for sending; nothing is written until flushed.
    ///
    /// Data messages larger than `config.frame_size` are split into fragments;
    /// control messages are always sent as a single frame.
    ///
    /// # Errors
    ///
    /// Returns `Error::AlreadyClosed` unless the stream is still `Active`.
    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        if self.state != StreamState::Active {
            return Err(Error::AlreadyClosed);
        }

        if item.opcode.is_control() || item.payload.len() <= self.config.frame_size {
            let frame: Frame = item.into();
            self.queue_frame(frame);
        } else {
            for frame in item.into_frames(self.config.frame_size) {
                self.queue_frame(frame);
            }
        }

        Ok(())
    }

    /// Writes every queued frame to the underlying I/O object, then flushes it.
    ///
    /// A write interrupted by `Pending` resumes where it stopped on the next call;
    /// progress within the front frame is tracked in `bytes_written`.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors; a writer that accepts zero bytes yields an
    /// `io::ErrorKind::WriteZero` error.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Split the borrow so the queue, the writer and the progress counter can be
        // used independently.
        let this = self.get_mut();
        let frame_queue = &mut this.frame_queue;
        let io = this.inner.get_mut();
        let bytes_written = &mut this.bytes_written;

        while let Some(frame) = frame_queue.front() {
            let header = &frame.header[..usize::from(frame.header_len)];
            let mask = frame
                .mask
                .as_ref()
                .map(<[u8; 4]>::as_slice)
                .unwrap_or_default();
            let mut buf = header.chain(mask).chain(&*frame.payload);
            // Skip the part of this frame that an earlier, interrupted call already wrote.
            buf.advance(*bytes_written);

            while buf.has_remaining() {
                let n = ready!(poll_write_buf(Pin::new(io), cx, &mut buf))?;
                if n == 0 {
                    return Poll::Ready(Err(Error::Io(io::ErrorKind::WriteZero.into())));
                }
                *bytes_written += n;
            }

            // Front frame fully written; move on to the next one.
            frame_queue.pop_front();
            *bytes_written = 0;
        }

        ready!(Pin::new(io).poll_flush(cx))?;
        Poll::Ready(Ok(()))
    }

    /// Performs the closing handshake and shuts down the underlying writer.
    ///
    /// Queues a Close frame if the stream is still `Active`. It then drains
    /// incoming frames, discarding messages and errors, until the stream ends.
    /// Finally it flushes and calls `poll_shutdown` on the I/O object.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // `queue_frame` moves an `Active` stream to `ClosedByUs` for a Close frame,
        // so re-polling after `Pending` does not queue a second Close.
        if self.state == StreamState::Active {
            self.queue_frame(Frame::DEFAULT_CLOSE);
        }
        while ready!(self.as_mut().poll_next(cx)).is_some() {}
        ready!(self.as_mut().poll_flush(cx))?;
        Pin::new(self.inner.get_mut())
            .poll_shutdown(cx)
            .map_err(Error::Io)
    }
}