#![allow(clippy::significant_drop_tightening)]
use super::client::{Message, MessageAssembler, WebSocket, WebSocketConfig};
use super::close::{CloseHandshake, CloseReason, CloseState};
use super::frame::{Frame, FrameCodec, Opcode, WsError};
use crate::bytes::{Bytes, BytesMut};
use crate::codec::Decoder;
use crate::cx::Cx;
use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::util::EntropySource;
use parking_lot::Mutex;
use smallvec::SmallVec;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
/// Upper bound on buffered pong payloads awaiting transmission.
const MAX_PENDING_PONGS: usize = 16;

/// Queue `payload` for a later pong, bounded at `MAX_PENDING_PONGS` entries.
/// When the queue is full the oldest payload is silently discarded so a flood
/// of pings cannot grow memory without bound.
fn enqueue_pending_pong(pending_pongs: &mut std::collections::VecDeque<Bytes>, payload: Bytes) {
    if pending_pongs.len() >= MAX_PENDING_PONGS {
        // Drop the stalest entry to make room for the newest one.
        drop(pending_pongs.pop_front());
    }
    pending_pongs.push_back(payload);
}
/// One parked task waiting for the split write permit.
struct WriterWaiter {
    // Monotonically assigned ticket; identifies this waiter in the FIFO queue.
    id: u64,
    // Most recently observed waker for the waiting task.
    waker: Waker,
}
/// State shared (behind a `Mutex`) by the read and write halves of a split
/// WebSocket. Mirrors the fields of `WebSocket` plus the bookkeeping needed
/// to serialize writers across the two halves.
struct WebSocketShared<IO> {
    io: IO,
    codec: FrameCodec,
    // Bytes received from the transport, not yet decoded into frames.
    read_buf: BytesMut,
    // Encoded frame bytes not yet accepted by the transport.
    write_buf: BytesMut,
    close_handshake: CloseHandshake,
    config: WebSocketConfig,
    assembler: MessageAssembler,
    protocol: Option<String>,
    // Ping payloads queued for echoing as pongs (see `enqueue_pending_pong`).
    pending_pongs: std::collections::VecDeque<Bytes>,
    // True while encoded pong bytes may still be sitting in `write_buf`
    // (e.g. after a cancelled flush) and must be flushed by the next `recv`.
    pending_pong_flush: bool,
    entropy: Arc<dyn EntropySource>,
    // True while some task holds the `SplitWritePermit`.
    writer_active: bool,
    // FIFO queue of tasks waiting for the write permit.
    writer_waiters: SmallVec<[WriterWaiter; 2]>,
    // Source of `WriterWaiter::id` tickets; wraps on overflow.
    next_waiter_id: u64,
    // Pair id assigned at `split` time; used by `reunite` to reject halves
    // from different connections.
    id: u64,
}
/// Exclusive right to write to the shared transport; released (and the next
/// waiter woken) on drop.
struct SplitWritePermit<IO> {
    shared: Arc<Mutex<WebSocketShared<IO>>>,
}
impl<IO> Drop for SplitWritePermit<IO> {
    /// Releasing the permit clears the active-writer flag and hands the
    /// wake-up to the head of the FIFO waiter queue. The waker is invoked
    /// after the lock is released to avoid waking a task into a held lock.
    fn drop(&mut self) {
        let mut guard = self.shared.lock();
        guard.writer_active = false;
        let head_waker = guard.writer_waiters.first().map(|w| w.waker.clone());
        drop(guard);
        if let Some(w) = head_waker {
            w.wake();
        }
    }
}
/// Future returned by `acquire_write_permit`; resolves to a `SplitWritePermit`.
struct AcquireWritePermitFuture<'a, IO> {
    shared: &'a Arc<Mutex<WebSocketShared<IO>>>,
    // `Some(id)` once this future has enqueued itself as a waiter.
    waiter_id: Option<u64>,
}
impl<IO> Future for AcquireWritePermitFuture<'_, IO> {
    type Output = SplitWritePermit<IO>;
    /// FIFO acquisition: the permit may be taken only when no writer is
    /// active AND this future is the head of the waiter queue (or the queue
    /// is empty and it never enqueued itself).
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut state = self.shared.lock();
        let must_wait = if state.writer_active {
            true
        } else if let Some(id) = self.waiter_id {
            // Already queued: proceed only if we are the queue head.
            state.writer_waiters.first().is_some_and(|w| w.id != id)
        } else {
            // Not queued yet: do not jump ahead of existing waiters.
            !state.writer_waiters.is_empty()
        };
        if must_wait {
            if let Some(id) = self.waiter_id {
                // Re-polled while queued: refresh the stored waker if the
                // task's waker changed since the last poll.
                if let Some(existing) = state.writer_waiters.iter_mut().find(|w| w.id == id) {
                    if !existing.waker.will_wake(cx.waker()) {
                        existing.waker.clone_from(cx.waker());
                    }
                }
            } else {
                // First pending poll: enqueue at the tail with a fresh ticket.
                let id = state.next_waiter_id;
                state.next_waiter_id = state.next_waiter_id.wrapping_add(1);
                state.writer_waiters.push(WriterWaiter {
                    id,
                    waker: cx.waker().clone(),
                });
                self.waiter_id = Some(id);
            }
            drop(state);
            Poll::Pending
        } else {
            // Claim the permit and dequeue ourselves if we were queued.
            state.writer_active = true;
            if let Some(id) = self.waiter_id {
                if let Some(pos) = state.writer_waiters.iter().position(|w| w.id == id) {
                    state.writer_waiters.remove(pos);
                }
                self.waiter_id = None;
            }
            drop(state);
            Poll::Ready(SplitWritePermit {
                shared: Arc::clone(self.shared),
            })
        }
    }
}
impl<IO> Drop for AcquireWritePermitFuture<'_, IO> {
    /// A queued waiter that is dropped before acquiring must dequeue itself.
    /// If it was the queue head while the permit was free, it may have
    /// absorbed the wake-up intended for the queue, so the wake is forwarded
    /// to the next waiter (outside the lock).
    fn drop(&mut self) {
        if let Some(id) = self.waiter_id {
            let next_waker = {
                let mut state = self.shared.lock();
                let is_head = state.writer_waiters.first().is_some_and(|w| w.id == id);
                if let Some(pos) = state.writer_waiters.iter().position(|w| w.id == id) {
                    state.writer_waiters.remove(pos);
                }
                if is_head && !state.writer_active {
                    state.writer_waiters.first().map(|w| w.waker.clone())
                } else {
                    None
                }
            };
            if let Some(w) = next_waker {
                w.wake();
            }
        }
    }
}
/// Wait until the exclusive write permit for `shared` can be claimed.
/// Acquisition is FIFO with respect to other waiters.
async fn acquire_write_permit<IO>(
    shared: &Arc<Mutex<WebSocketShared<IO>>>,
) -> SplitWritePermit<IO> {
    let acquire = AcquireWritePermitFuture {
        shared,
        waiter_id: None,
    };
    acquire.await
}
/// Acquire the write permit, then drain and flush the shared `write_buf`.
async fn flush_write_buf<IO: AsyncWrite + Unpin>(
    shared: &Arc<Mutex<WebSocketShared<IO>>>,
) -> Result<(), WsError> {
    let permit = acquire_write_permit(shared).await;
    let result = flush_shared_write_buf_with_permit(shared).await;
    // Permit is held for the whole flush; released here.
    drop(permit);
    result
}
/// Drain `write_buf` to the transport, then flush the transport.
///
/// Caller must hold the split write permit. Bytes are removed from
/// `write_buf` only after `poll_write` reports them accepted, so cancelling
/// between polls never loses queued frame bytes. Cancellation (via
/// `Cx::checkpoint`) is honored only while the connection is still open:
/// close-handshake flushes run to completion.
async fn flush_shared_write_buf_with_permit<IO: AsyncWrite + Unpin>(
    shared: &Arc<Mutex<WebSocketShared<IO>>>,
) -> Result<(), WsError> {
    use std::future::poll_fn;
    while {
        // The lock is confined to this condition block; it is never held
        // across an await point.
        let guard = shared.lock();
        !guard.write_buf.is_empty()
    } {
        let is_open = shared.lock().close_handshake.is_open();
        let n = poll_fn(|poll_cx| {
            // Only surface cancellation while the connection is open.
            if is_open && crate::cx::Cx::current().is_some_and(|c| c.checkpoint().is_err()) {
                return Poll::Ready(Err(std::io::Error::new(
                    std::io::ErrorKind::Interrupted,
                    "cancelled",
                )));
            }
            let mut guard = shared.lock();
            if guard.write_buf.is_empty() {
                // Another task drained the buffer between polls.
                return Poll::Ready(Ok(0));
            }
            // Split-borrow so `io` and `write_buf` can be used together.
            let WebSocketShared { io, write_buf, .. } = &mut *guard;
            Pin::new(io).poll_write(poll_cx, &write_buf[..])
        })
        .await?;
        if n == 0 {
            let guard = shared.lock();
            if !guard.write_buf.is_empty() {
                // Transport accepted nothing while data remains: fail rather
                // than loop forever.
                return Err(WsError::Io(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "write returned 0",
                )));
            }
            break;
        }
        // Discard only the bytes the transport actually accepted.
        let mut guard = shared.lock();
        let _ = guard.write_buf.split_to(n);
    }
    let is_open = shared.lock().close_handshake.is_open();
    poll_fn(|poll_cx| {
        if is_open && crate::cx::Cx::current().is_some_and(|c| c.checkpoint().is_err()) {
            return Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::Interrupted,
                "cancelled",
            )));
        }
        let mut guard = shared.lock();
        Pin::new(&mut guard.io).poll_flush(poll_cx)
    })
    .await?;
    Ok(())
}
/// Write a privately owned buffer to the transport under the write permit.
///
/// Any backlog in the shared `write_buf` is flushed first to preserve frame
/// ordering. If the transport accepts `buf` only partially, the unwritten
/// tail is moved into the shared `write_buf` so that a cancellation after
/// the first byte hit the wire leaves the remainder durable for the next
/// flush (keeping frames whole on the wire).
async fn write_owned_buf_with_permit<IO: AsyncWrite + Unpin>(
    shared: &Arc<Mutex<WebSocketShared<IO>>>,
    buf: &mut BytesMut,
) -> Result<(), WsError> {
    use std::future::poll_fn;
    let _permit = acquire_write_permit(shared).await;
    flush_shared_write_buf_with_permit(shared).await?;
    if buf.is_empty() {
        return Ok(());
    }
    let is_open = shared.lock().close_handshake.is_open();
    let n = poll_fn(|poll_cx| {
        // Cancellation is honored only while the connection is open.
        if is_open && crate::cx::Cx::current().is_some_and(|c| c.checkpoint().is_err()) {
            return Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::Interrupted,
                "cancelled",
            )));
        }
        let mut guard = shared.lock();
        Pin::new(&mut guard.io).poll_write(poll_cx, &buf[..])
    })
    .await?;
    if n == 0 {
        return Err(WsError::Io(io::Error::new(
            io::ErrorKind::WriteZero,
            "write returned 0",
        )));
    }
    let _ = buf.split_to(n);
    if !buf.is_empty() {
        // Partial write: park the tail in the shared buffer, then flush it
        // while we still hold the permit.
        {
            let mut guard = shared.lock();
            guard.write_buf.extend_from_slice(&buf[..]);
            buf.clear();
        }
        return flush_shared_write_buf_with_permit(shared).await;
    }
    let is_open = shared.lock().close_handshake.is_open();
    poll_fn(|poll_cx| {
        if is_open && crate::cx::Cx::current().is_some_and(|c| c.checkpoint().is_err()) {
            return Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::Interrupted,
                "cancelled",
            )));
        }
        let mut guard = shared.lock();
        Pin::new(&mut guard.io).poll_flush(poll_cx)
    })
    .await?;
    Ok(())
}
/// Receiving half of a split WebSocket; also echoes pings and close frames.
pub struct WebSocketRead<IO> {
    shared: Arc<Mutex<WebSocketShared<IO>>>,
}
/// Sending half of a split WebSocket.
pub struct WebSocketWrite<IO> {
    shared: Arc<Mutex<WebSocketShared<IO>>>,
}
/// Error returned by [`WebSocketRead::reunite`] when the two halves did not
/// originate from the same `split` call; both halves are handed back intact.
pub struct ReuniteError<IO> {
    pub read: WebSocketRead<IO>,
    pub write: WebSocketWrite<IO>,
}
impl<IO> std::fmt::Debug for ReuniteError<IO> {
    /// Manual impl: the halves are opaque, so only placeholders are printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut s = f.debug_struct("ReuniteError");
        s.field("read", &"WebSocketRead { .. }");
        s.field("write", &"WebSocketWrite { .. }");
        s.finish()
    }
}
impl<IO> std::fmt::Display for ReuniteError<IO> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("attempted to reunite mismatched WebSocket halves")
    }
}
impl<IO> std::error::Error for ReuniteError<IO> {}
impl<IO> WebSocket<IO>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    /// Split this socket into independently usable read and write halves that
    /// share the transport behind a mutex. A process-wide counter tags both
    /// halves with the same pair id so [`WebSocketRead::reunite`] can reject
    /// halves from different connections.
    pub fn split(self) -> (WebSocketRead<IO>, WebSocketWrite<IO>) {
        static COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
        // Relaxed suffices: the id is only compared for equality and never
        // orders other memory operations.
        let pair_id = COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let state = WebSocketShared {
            io: self.io,
            codec: self.codec,
            read_buf: self.read_buf,
            write_buf: self.write_buf,
            close_handshake: self.close_handshake,
            config: self.config,
            assembler: self.assembler,
            protocol: self.protocol,
            pending_pongs: self.pending_pongs,
            pending_pong_flush: false,
            entropy: self.entropy,
            writer_active: false,
            writer_waiters: SmallVec::new(),
            next_waiter_id: 0,
            id: pair_id,
        };
        let shared = Arc::new(Mutex::new(state));
        (
            WebSocketRead {
                shared: Arc::clone(&shared),
            },
            WebSocketWrite { shared },
        )
    }
}
impl<IO> WebSocketRead<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
/// Receive the next complete message.
///
/// Returns `Ok(None)` on orderly or abnormal end of stream. Before reading,
/// any queued pong payloads are encoded and flushed (including pong bytes
/// left in `write_buf` by a previously cancelled attempt, tracked by
/// `pending_pong_flush`). Ping frames are answered automatically; close
/// frames are echoed per the handshake and surfaced as `Message::Close`.
pub async fn recv(&mut self, cx: &Cx) -> Result<Option<Message>, WsError> {
    loop {
        // Cooperative cancellation check at the top of every iteration.
        if cx.checkpoint().is_err() {
            return Err(WsError::Io(io::Error::new(
                io::ErrorKind::Interrupted,
                "cancelled",
            )));
        }
        let flush_pending_pongs = {
            let shared = &mut *self.shared.lock();
            // Start true if a previous flush was cancelled after encoding.
            let mut flush_pending_pongs = shared.pending_pong_flush;
            while let Some(payload) = shared.pending_pongs.pop_front() {
                flush_pending_pongs = true;
                // Mark before flushing so a cancellation mid-flush retries.
                shared.pending_pong_flush = true;
                let pong = Frame::pong(payload);
                // Reborrow to split-borrow `codec` and `write_buf` together.
                let shared = &mut *shared;
                shared
                    .codec
                    .encode_with_entropy(&pong, &mut shared.write_buf, cx.entropy())?;
            }
            flush_pending_pongs
        };
        if flush_pending_pongs {
            flush_write_buf(&self.shared).await?;
            // Flush completed; clear the retry marker.
            self.shared.lock().pending_pong_flush = false;
        }
        let maybe_frame = {
            let shared = &mut *self.shared.lock();
            let (codec, read_buf) = (&mut shared.codec, &mut shared.read_buf);
            codec.decode(read_buf)?
        };
        if let Some(frame) = maybe_frame {
            match frame.opcode {
                Opcode::Ping => {
                    // Queue the echo; it is sent at the top of the next loop.
                    let mut shared = self.shared.lock();
                    enqueue_pending_pong(&mut shared.pending_pongs, frame.payload);
                }
                Opcode::Pong => {
                    // Unsolicited or heartbeat pongs are ignored.
                }
                Opcode::Close => {
                    let response =
                        { self.shared.lock().close_handshake.receive_close(&frame)? };
                    if let Some(response_frame) = response {
                        // Echo the close; only mark the response sent if the
                        // write actually succeeded, so a failure can retry.
                        let send_result = self
                            .send_frame_internal_with_entropy(&response_frame, cx.entropy())
                            .await;
                        send_result?;
                        self.shared.lock().close_handshake.mark_response_sent();
                    }
                    let reason = CloseReason::parse(&frame.payload).ok();
                    return Ok(Some(Message::Close(reason)));
                }
                _ => {
                    // Data frame: feed the assembler; a complete message may
                    // or may not be ready yet.
                    let result = { self.shared.lock().assembler.push_frame(frame) };
                    match result {
                        Ok(Some(msg)) => return Ok(Some(msg)),
                        Ok(None) => {}
                        Err(err) => {
                            // Protocol violation: force-close with the
                            // error's close code before surfacing it.
                            self.shared
                                .lock()
                                .close_handshake
                                .force_close(CloseReason::new(err.as_close_code(), None));
                            return Err(err);
                        }
                    }
                }
            }
        } else {
            // No complete frame buffered; need more bytes from the transport.
            if self.shared.lock().close_handshake.is_closed() {
                return Ok(None);
            }
            let n = self.read_more().await?;
            if n == 0 {
                // EOF without a close handshake: abnormal closure.
                self.shared
                    .lock()
                    .close_handshake
                    .force_close(CloseReason::new(super::CloseCode::Abnormal, None));
                return Ok(None);
            }
        }
    }
}
/// `true` while the close handshake has not started.
#[must_use]
pub fn is_open(&self) -> bool {
    let guard = self.shared.lock();
    guard.close_handshake.is_open()
}
/// `true` once the close handshake has fully completed.
#[must_use]
pub fn is_closed(&self) -> bool {
    let guard = self.shared.lock();
    guard.close_handshake.is_closed()
}
/// Rejoin this read half with its matching write half into a whole socket.
///
/// Fails with [`ReuniteError`] (returning both halves) when the halves come
/// from different `split` calls, or when other clones of the shared state
/// are still alive.
pub fn reunite(self, write: WebSocketWrite<IO>) -> Result<WebSocket<IO>, ReuniteError<IO>> {
    let ids_match = {
        let read_id = self.shared.lock().id;
        let write_id = write.shared.lock().id;
        read_id == write_id
    };
    if !ids_match {
        return Err(ReuniteError::<IO> { read: self, write });
    }
    // Drop the write half so ours is (normally) the last Arc reference.
    drop(write);
    match Arc::try_unwrap(self.shared) {
        Ok(mutex) => {
            let shared = mutex.into_inner();
            Ok(WebSocket {
                io: shared.io,
                codec: shared.codec,
                read_buf: shared.read_buf,
                write_buf: shared.write_buf,
                close_handshake: shared.close_handshake,
                config: shared.config,
                assembler: shared.assembler,
                protocol: shared.protocol,
                pending_pongs: shared.pending_pongs,
                entropy: shared.entropy,
            })
        }
        Err(arc) => {
            // Someone else still holds the Arc; rebuild both halves and
            // report the failure without losing state.
            let read = Self {
                shared: Arc::clone(&arc),
            };
            let write = WebSocketWrite { shared: arc };
            Err(ReuniteError::<IO> { read, write })
        }
    }
}
fn encode_frame_with_entropy(
&self,
frame: &Frame,
entropy: &dyn EntropySource,
) -> Result<(), WsError> {
let mut shared = self.shared.lock();
let shared = &mut *shared;
shared
.codec
.encode_with_entropy(frame, &mut shared.write_buf, entropy)
}
/// Encode `frame` into the shared buffer, then flush it under the write
/// permit. Encoding first keeps the frame bytes durable if the flush is
/// cancelled.
async fn send_frame_internal_with_entropy(
    &self,
    frame: &Frame,
    entropy: &dyn EntropySource,
) -> Result<(), WsError> {
    self.encode_frame_with_entropy(frame, entropy)?;
    flush_write_buf(&self.shared).await
}
// Only exercised by tests today, hence the allow.
#[allow(dead_code)]
async fn send_frame_internal(&self, frame: &Frame) -> Result<(), WsError> {
    // Clone the Arc'd entropy source out of the lock before awaiting.
    let entropy = Arc::clone(&self.shared.lock().entropy);
    self.send_frame_internal_with_entropy(frame, entropy.as_ref())
        .await
}
/// Read more bytes from the transport into `read_buf`.
///
/// Returns the number of bytes read; `0` signals EOF. Cancellation is
/// honored only while the connection is open, so a closing handshake can
/// still drain the peer's final frames.
async fn read_more(&self) -> Result<usize, WsError> {
    use std::future::poll_fn;
    let is_open = self.shared.lock().close_handshake.is_open();
    poll_fn(|poll_cx| {
        if is_open && crate::cx::Cx::current().is_some_and(|c| c.checkpoint().is_err()) {
            return Poll::Ready(Err(WsError::Io(std::io::Error::new(
                std::io::ErrorKind::Interrupted,
                "cancelled",
            ))));
        }
        // Stack staging buffer; bytes are appended to `read_buf` below.
        let mut temp = [0u8; 4096];
        let mut shared = self.shared.lock();
        let mut read_buf = ReadBuf::new(&mut temp);
        match Pin::new(&mut shared.io).poll_read(poll_cx, &mut read_buf) {
            Poll::Ready(Ok(())) => {
                let n = read_buf.filled().len();
                if n > 0 {
                    // Grow in generous chunks to avoid repeated reallocation.
                    if shared.read_buf.capacity() - shared.read_buf.len() < n {
                        shared.read_buf.reserve(8192.max(n));
                    }
                    shared.read_buf.extend_from_slice(&temp[..n]);
                }
                Poll::Ready(Ok(n))
            }
            Poll::Ready(Err(e)) => Poll::Ready(Err(WsError::Io(e))),
            Poll::Pending => Poll::Pending,
        }
    })
    .await
}
}
impl<IO> WebSocketWrite<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
/// Send a message on this half.
///
/// Data frames are rejected with `NotConnected` once the close handshake has
/// started; control frames are still allowed. `Message::Close` is routed
/// through the close handshake instead of being sent as a plain frame.
pub async fn send(&mut self, cx: &Cx, msg: Message) -> Result<(), WsError> {
    // Honor cooperative cancellation before doing any work.
    if cx.checkpoint().is_err() {
        return Err(WsError::Io(io::Error::new(
            io::ErrorKind::Interrupted,
            "cancelled",
        )));
    }
    let open = self.shared.lock().close_handshake.is_open();
    if !open && !msg.is_control() {
        return Err(WsError::Io(io::Error::new(
            io::ErrorKind::NotConnected,
            "connection is closing",
        )));
    }
    match msg {
        Message::Close(reason) => {
            self.initiate_close(reason.unwrap_or_else(CloseReason::normal))
                .await
        }
        other => {
            let frame = Frame::from(other);
            self.send_frame_with_entropy(&frame, cx.entropy()).await
        }
    }
}
pub async fn close(&mut self, reason: CloseReason) -> Result<(), WsError> {
Self::initiate_close(self, reason).await
}
/// Send a ping control frame carrying `payload`.
pub async fn ping(&mut self, payload: impl Into<Bytes>) -> Result<(), WsError> {
    self.send_frame(&Frame::ping(payload)).await
}
/// `true` while the close handshake has not started.
#[must_use]
pub fn is_open(&self) -> bool {
    let guard = self.shared.lock();
    guard.close_handshake.is_open()
}
/// `true` once the close handshake has fully completed.
#[must_use]
pub fn is_closed(&self) -> bool {
    let guard = self.shared.lock();
    guard.close_handshake.is_closed()
}
/// Snapshot of the close-handshake state machine.
#[must_use]
pub fn close_state(&self) -> CloseState {
    let guard = self.shared.lock();
    guard.close_handshake.state()
}
/// Drive the close handshake from this half based on its current state.
async fn initiate_close(&self, reason: CloseReason) -> Result<(), WsError> {
    let state = self.shared.lock().close_handshake.state();
    match state {
        // The peer's close arrived first: our echoed response may still be
        // sitting in write_buf (e.g. after a cancelled recv); flushing it
        // completes our side of the handshake.
        CloseState::CloseReceived => {
            flush_write_buf(&self.shared).await?;
            self.shared.lock().close_handshake.mark_response_sent();
            Ok(())
        }
        // We already sent our close; just make sure it reached the wire.
        CloseState::CloseSent => {
            flush_write_buf(&self.shared).await?;
            Ok(())
        }
        _ => {
            // Fresh close: ask the handshake for the frame to send (if any).
            let maybe_frame = {
                let mut shared = self.shared.lock();
                shared.close_handshake.initiate(reason)
            };
            if let Some(frame) = maybe_frame {
                self.send_frame(&frame).await?;
            }
            Ok(())
        }
    }
}
/// Encode `frame` into a private staging buffer, then write it under the
/// permit. Staging privately means a send that is cancelled before reaching
/// the wire leaves no bytes behind; `write_owned_buf_with_permit` spills any
/// partially written tail into the shared buffer to keep it durable.
async fn send_frame_with_entropy(
    &self,
    frame: &Frame,
    entropy: &dyn EntropySource,
) -> Result<(), WsError> {
    let mut staged = BytesMut::new();
    {
        let guard = &mut *self.shared.lock();
        guard.codec.encode_with_entropy(frame, &mut staged, entropy)?;
    }
    write_owned_buf_with_permit(&self.shared, &mut staged).await
}
/// Send a frame using the connection's own entropy source for masking.
async fn send_frame(&self, frame: &Frame) -> Result<(), WsError> {
    // Clone the Arc'd entropy source out of the lock before awaiting.
    let entropy = Arc::clone(&self.shared.lock().entropy);
    self.send_frame_with_entropy(frame, entropy.as_ref()).await
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::codec::Encoder;
use crate::types::{Budget, RegionId, TaskId};
use crate::util::EntropySource;
use futures_lite::future;
use std::future::Future;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll, Waker};
/// Scriptable in-memory transport for exercising write edge cases.
struct TestIo {
    // Bytes the "peer" will deliver to reads; exhausted => EOF.
    read_data: Vec<u8>,
    read_pos: usize,
    // Everything successfully written to the transport.
    written: Vec<u8>,
    // If set, every write fails with BrokenPipe.
    fail_writes: bool,
    // If set, the first write returns Pending once (self-waking).
    pending_first_write: bool,
    // If set, the first write accepts only this many bytes.
    partial_first_write_len: Option<usize>,
    // If set, the write after the partial one returns Pending once.
    pending_after_partial_write: bool,
}
impl TestIo {
    /// Well-behaved transport that serves `read_data` and accepts all writes.
    fn new(read_data: Vec<u8>) -> Self {
        Self {
            read_data,
            read_pos: 0,
            written: Vec::new(),
            fail_writes: false,
            pending_first_write: false,
            partial_first_write_len: None,
            pending_after_partial_write: false,
        }
    }
    /// Builder: make every write fail with a synthetic BrokenPipe error.
    fn with_write_failure(mut self) -> Self {
        self.fail_writes = true;
        self
    }
    /// Builder: make the first write return Pending once before succeeding.
    fn with_pending_first_write(mut self) -> Self {
        self.pending_first_write = true;
        self
    }
    /// Builder: accept only `len` bytes on the first write, then Pending once.
    fn with_partial_first_write(mut self, len: usize) -> Self {
        self.partial_first_write_len = Some(len);
        self.pending_after_partial_write = true;
        self
    }
}
/// Transport that accepts exactly one byte per write and returns Pending on
/// every other call, maximizing opportunities for two writers to interleave.
struct InterleavingIo {
    written: Vec<u8>,
    pending_next: bool,
}
impl InterleavingIo {
    /// Fresh transport with nothing written and the next write accepted.
    fn new() -> Self {
        Self {
            written: Vec::new(),
            pending_next: false,
        }
    }
}
impl AsyncRead for TestIo {
    /// Serve the scripted bytes; once exhausted, reads fill 0 bytes (EOF).
    fn poll_read(
        mut self: Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let remaining = &self.read_data[self.read_pos..];
        let to_read = remaining.len().min(buf.remaining());
        buf.put_slice(&remaining[..to_read]);
        self.read_pos += to_read;
        Poll::Ready(Ok(()))
    }
}
impl AsyncRead for InterleavingIo {
    /// Always reports immediate EOF; this transport is write-only in tests.
    fn poll_read(
        self: Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
        _buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}
impl AsyncWrite for TestIo {
    /// Apply the scripted behaviors in priority order: hard failure, one-shot
    /// Pending, one-shot partial write, one-shot Pending after the partial,
    /// then normal full acceptance.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        if self.fail_writes {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "synthetic write failure",
            )));
        }
        if self.pending_first_write {
            self.pending_first_write = false;
            // Self-wake so the executor re-polls without external help.
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }
        if let Some(len) = self.partial_first_write_len.take() {
            let to_write = len.min(buf.len());
            self.written.extend_from_slice(&buf[..to_write]);
            return Poll::Ready(Ok(to_write));
        }
        if self.pending_after_partial_write {
            self.pending_after_partial_write = false;
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }
        self.written.extend_from_slice(buf);
        Poll::Ready(Ok(buf.len()))
    }
    fn poll_flush(
        self: Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }
    fn poll_shutdown(
        self: Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}
impl AsyncWrite for InterleavingIo {
    /// Alternate between accepting exactly one byte and returning Pending,
    /// so concurrent writers would interleave unless externally serialized.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        if self.pending_next {
            self.pending_next = false;
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }
        self.pending_next = true;
        self.written.push(buf[0]);
        Poll::Ready(Ok(1))
    }
    fn poll_flush(
        self: Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }
    fn poll_shutdown(
        self: Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}
/// Encode `frame` as a server peer would, returning the raw wire bytes.
fn encode_server_frame(frame: Frame) -> Vec<u8> {
    let mut codec = FrameCodec::server();
    let mut out = BytesMut::new();
    codec
        .encode(frame, &mut out)
        .expect("frame encoding should succeed");
    out.to_vec()
}
/// Encode `frame` as a client would, masking with the given entropy source,
/// so tests can predict the exact bytes a deterministic client emits.
fn encode_client_frame_with_entropy(frame: &Frame, entropy: &dyn EntropySource) -> Vec<u8> {
    let codec = FrameCodec::client();
    let mut out = BytesMut::new();
    codec
        .encode_with_entropy(frame, &mut out, entropy)
        .expect("frame encoding should succeed");
    out.to_vec()
}
// Both halves write concurrently through a transport that accepts one byte at
// a time; the write permit must keep each frame's bytes contiguous on the wire.
#[test]
fn split_writes_do_not_interleave_frame_bytes() {
    future::block_on(async {
        let ws = WebSocket::from_upgraded(InterleavingIo::new(), WebSocketConfig::default());
        let (read, write) = ws.split();
        let read_frame = Frame::binary(Bytes::from_static(b"read-half"));
        let write_frame = Frame::binary(Bytes::from_static(b"write-half"));
        let expected_read = encode_server_frame(read_frame.clone());
        let expected_write = encode_server_frame(write_frame.clone());
        {
            // Use the unmasked server codec so expected bytes are predictable.
            let mut shared = read.shared.lock();
            shared.codec = FrameCodec::server();
        }
        let (read_result, write_result): (Result<(), _>, Result<(), _>) = future::zip(
            read.send_frame_internal(&read_frame),
            write.send_frame(&write_frame),
        )
        .await;
        assert!(read_result.is_ok(), "read half frame send must succeed");
        assert!(write_result.is_ok(), "write half frame send must succeed");
        let ws = read.reunite(write).expect("split halves must reunite");
        let written = ws.io.written;
        // Either ordering is fine as long as neither frame is interleaved.
        let mut read_then_write = expected_read.clone();
        read_then_write.extend_from_slice(&expected_write);
        let mut write_then_read = expected_write;
        write_then_read.extend_from_slice(&expected_read);
        assert!(
            written == read_then_write || written == write_then_read,
            "concurrent writes must preserve full-frame atomicity"
        );
    });
}
// The previous version asserted substrings of a local string literal and
// never exercised the real Display/Debug impls. Build an actual
// ReuniteError from mismatched halves and format it.
#[test]
fn test_reunite_error_display() {
    let ws1 = WebSocket::from_upgraded(TestIo::new(vec![]), WebSocketConfig::default());
    let ws2 = WebSocket::from_upgraded(TestIo::new(vec![]), WebSocketConfig::default());
    let (read1, _write1) = ws1.split();
    let (_read2, write2) = ws2.split();
    let err = read1
        .reunite(write2)
        .expect_err("mismatched halves must fail reunite");
    // Display comes from the hand-written impl.
    let shown = err.to_string();
    assert_eq!(shown, "attempted to reunite mismatched WebSocket halves");
    // Debug prints opaque placeholders for both halves.
    let debugged = format!("{err:?}");
    assert!(debugged.contains("ReuniteError"));
    assert!(debugged.contains("WebSocketRead { .. }"));
    assert!(debugged.contains("WebSocketWrite { .. }"));
}
// After a successful flush, no stale bytes may remain in the shared buffer.
#[test]
fn flush_write_buf_clears_eagerly_for_cancel_safety() {
    future::block_on(async {
        let ws = WebSocket::from_upgraded(TestIo::new(vec![]), WebSocketConfig::default());
        let (read, _write) = ws.split();
        {
            // Simulate leftover bytes from an earlier, interrupted operation.
            let mut shared = read.shared.lock();
            shared.write_buf.extend_from_slice(b"stale-pong-data");
        }
        let result = flush_write_buf(&read.shared).await;
        assert!(result.is_ok());
        let is_empty = read.shared.lock().write_buf.is_empty();
        assert!(
            is_empty,
            "write_buf must be cleared eagerly, not after write completes"
        );
    });
}
// Several queued pong payloads must all be encoded, in FIFO order, and decode
// back to the original payloads.
#[test]
fn multiple_pong_payloads_all_encoded() {
    future::block_on(async {
        let ws = WebSocket::from_upgraded(TestIo::new(vec![]), WebSocketConfig::default());
        let (read, _write) = ws.split();
        {
            let mut shared = read.shared.lock();
            shared
                .pending_pongs
                .push_back(Bytes::from_static(b"pong-a"));
            shared
                .pending_pongs
                .push_back(Bytes::from_static(b"pong-b"));
            shared
                .pending_pongs
                .push_back(Bytes::from_static(b"pong-c"));
            shared.pending_pong_flush = false;
        }
        {
            // Mirror recv's drain-and-encode step directly on the shared state.
            let shared = &mut *read.shared.lock();
            shared.write_buf.clear();
            let pongs: Vec<_> = shared.pending_pongs.drain(..).collect();
            for payload in pongs {
                let pong = Frame::pong(payload);
                shared.codec.encode(pong, &mut shared.write_buf).unwrap();
            }
            shared.pending_pong_flush = true;
        }
        let encoded = {
            let shared = read.shared.lock();
            BytesMut::from(shared.write_buf.as_ref())
        };
        // Decode what was encoded and verify payloads and order round-trip.
        let mut decode_buf = encoded;
        let mut decoder = FrameCodec::server();
        let mut payloads = Vec::new();
        while let Some(frame) = decoder.decode(&mut decode_buf).unwrap() {
            assert_eq!(frame.opcode, Opcode::Pong);
            payloads.push(frame.payload);
        }
        assert_eq!(
            payloads,
            vec![
                Bytes::from_static(b"pong-a"),
                Bytes::from_static(b"pong-b"),
                Bytes::from_static(b"pong-c"),
            ],
            "pending pong payloads must be emitted in receive order"
        );
    });
}
// If a previous recv encoded a pong but was cancelled before flushing
// (pending_pong_flush left set), the next recv must flush those bytes.
#[test]
fn recv_flushes_pongs_left_encoded_by_cancelled_attempt() {
    future::block_on(async {
        let ws = WebSocket::from_upgraded(TestIo::new(vec![]), WebSocketConfig::default());
        let (mut read, _write) = ws.split();
        let cx = test_cx_with_entropy(Arc::new(FixedEntropy([0xAB, 0xCD, 0xEF, 0x01])));
        {
            // Simulate the aftermath of a cancelled flush: pong encoded,
            // marker set, nothing on the wire yet.
            let shared = &mut *read.shared.lock();
            let pong = Frame::pong(Bytes::from_static(b"pong-after-cancel"));
            shared
                .codec
                .encode(pong, &mut shared.write_buf)
                .expect("must encode synthetic pong");
            shared.pending_pong_flush = true;
        }
        let result = read.recv(&cx).await.expect("recv must succeed");
        assert!(
            result.is_none(),
            "EOF should surface once buffered pongs flush"
        );
        let shared = read.shared.lock();
        assert!(
            !shared.pending_pong_flush,
            "recv must clear the deferred pong flush marker after flushing"
        );
        assert!(
            shared.write_buf.is_empty(),
            "recv must flush the encoded pong bytes before returning"
        );
        assert!(
            !shared.io.written.is_empty(),
            "recv must actually flush deferred pong bytes to the transport"
        );
    });
}
// Overflowing the bounded pong queue must evict the oldest payloads:
// 20 pushes into a 16-slot queue keeps payloads 4..20.
#[test]
fn pending_pong_queue_keeps_most_recent_payloads() {
    let mut pending = std::collections::VecDeque::new();
    for n in 0u8..20 {
        enqueue_pending_pong(&mut pending, Bytes::from(vec![n]));
    }
    assert_eq!(pending.len(), MAX_PENDING_PONGS);
    let kept: Vec<u8> = pending
        .into_iter()
        .map(|payload| *payload.first().expect("single-byte payload"))
        .collect();
    assert_eq!(kept, (4u8..20).collect::<Vec<_>>());
}
// Halves from two different split calls carry different pair ids and must
// not reunite.
#[test]
fn reunite_mismatched_halves_returns_error() {
    let ws1 = WebSocket::from_upgraded(TestIo::new(vec![]), WebSocketConfig::default());
    let ws2 = WebSocket::from_upgraded(TestIo::new(vec![]), WebSocketConfig::default());
    let (read1, _write1) = ws1.split();
    let (_read2, write2) = ws2.split();
    let result = read1.reunite(write2);
    assert!(result.is_err(), "mismatched halves must fail reunite");
}
// Halves from the same split call must rejoin into a whole socket.
#[test]
fn reunite_matching_halves_succeeds() {
    let ws = WebSocket::from_upgraded(TestIo::new(vec![]), WebSocketConfig::default());
    let (read, write) = ws.split();
    let result = read.reunite(write);
    assert!(result.is_ok(), "matching halves must reunite successfully");
}
// Holding the permit sets writer_active; dropping it clears the flag.
#[test]
fn writer_permit_serializes_access() {
    future::block_on(async {
        let ws = WebSocket::from_upgraded(TestIo::new(vec![]), WebSocketConfig::default());
        let (read, _write) = ws.split();
        let permit = acquire_write_permit(&read.shared).await;
        assert!(
            read.shared.lock().writer_active,
            "writer_active must be true while permit is held"
        );
        drop(permit);
        assert!(
            !read.shared.lock().writer_active,
            "writer_active must be false after permit is dropped"
        );
    });
}
/// Waker backend that counts how many times it was woken, so tests can
/// observe exactly which waiter a permit release notified.
struct CountingWake {
    wake_count: AtomicUsize,
}
impl CountingWake {
    fn new() -> Arc<Self> {
        Arc::new(Self {
            wake_count: AtomicUsize::new(0),
        })
    }
    // Current number of recorded wake-ups.
    fn count(&self) -> usize {
        self.wake_count.load(Ordering::SeqCst)
    }
}
use std::task::Wake;
// Both wake paths only bump the counter; no task is actually scheduled.
impl Wake for CountingWake {
    fn wake(self: Arc<Self>) {
        self.wake_count.fetch_add(1, Ordering::SeqCst);
    }
    fn wake_by_ref(self: &Arc<Self>) {
        self.wake_count.fetch_add(1, Ordering::SeqCst);
    }
}
// Releasing the permit must wake only the head of the waiter queue, never
// every waiter (no thundering herd).
#[test]
fn writer_permit_release_wakes_first_waiter() {
    future::block_on(async {
        let ws = WebSocket::from_upgraded(TestIo::new(vec![]), WebSocketConfig::default());
        let (read, _write) = ws.split();
        let permit = acquire_write_permit(&read.shared).await;
        let mut first_waiter = Box::pin(acquire_write_permit(&read.shared));
        let mut second_waiter = Box::pin(acquire_write_permit(&read.shared));
        let counter_a = CountingWake::new();
        let counter_b = CountingWake::new();
        let first_task_waker: Waker = Waker::from(Arc::clone(&counter_a));
        let second_task_waker: Waker = Waker::from(Arc::clone(&counter_b));
        let mut first_context = Context::from_waker(&first_task_waker);
        let mut second_context = Context::from_waker(&second_task_waker);
        // Park both waiters behind the held permit, in order.
        assert!(matches!(
            first_waiter.as_mut().poll(&mut first_context),
            Poll::Pending
        ));
        assert!(matches!(
            second_waiter.as_mut().poll(&mut second_context),
            Poll::Pending
        ));
        drop(permit);
        assert!(
            counter_a.count() > 0,
            "first waiter must be woken when permit is released"
        );
        assert_eq!(
            counter_b.count(),
            0,
            "second waiter must NOT be woken when permit is released (no thundering herd)"
        );
    });
}
// Even if the second waiter polls before the first after the permit frees,
// the queued head must still acquire first (FIFO fairness).
#[test]
fn writer_permit_queue_preserves_fifo_when_second_waiter_polls_first() {
    future::block_on(async {
        let ws = WebSocket::from_upgraded(TestIo::new(vec![]), WebSocketConfig::default());
        let (read, _write) = ws.split();
        let permit = acquire_write_permit(&read.shared).await;
        let mut first_waiter = Box::pin(acquire_write_permit(&read.shared));
        let mut second_waiter = Box::pin(acquire_write_permit(&read.shared));
        let first_waker: Waker = std::task::Waker::noop().clone();
        let second_waker: Waker = std::task::Waker::noop().clone();
        let mut first_context = Context::from_waker(&first_waker);
        let mut second_context = Context::from_waker(&second_waker);
        assert!(matches!(
            first_waiter.as_mut().poll(&mut first_context),
            Poll::Pending
        ));
        assert!(matches!(
            second_waiter.as_mut().poll(&mut second_context),
            Poll::Pending
        ));
        drop(permit);
        // Deliberately poll out of queue order.
        assert!(
            matches!(
                second_waiter.as_mut().poll(&mut second_context),
                Poll::Pending
            ),
            "later waiters must not bypass the queued head when the permit becomes free"
        );
        assert!(
            matches!(
                first_waiter.as_mut().poll(&mut first_context),
                Poll::Ready(_)
            ),
            "the head waiter must acquire the permit first"
        );
    });
}
// Sending Message::Close must start the handshake, after which data frames
// are rejected with NotConnected.
#[test]
fn split_send_close_message_initiates_close_handshake() {
    future::block_on(async {
        let ws = WebSocket::from_upgraded(TestIo::new(vec![]), WebSocketConfig::default());
        let (_read, mut write) = ws.split();
        let cx = Cx::for_testing();
        assert!(write.is_open(), "connection should start open");
        write
            .send(&cx, Message::Close(None))
            .await
            .expect("sending close should succeed");
        assert!(
            !write.is_open(),
            "sending Message::Close must transition handshake out of open state"
        );
        let err = write
            .send(&cx, Message::text("late payload"))
            .await
            .expect_err("data frames must be rejected after close initiation");
        assert!(
            matches!(err, WsError::Io(ref e) if e.kind() == io::ErrorKind::NotConnected),
            "expected NotConnected after close initiation, got {err:?}"
        );
    });
}
// A failed close-echo write must leave the handshake in CloseReceived (not
// Closed) so a later attempt can retry the response.
#[test]
fn split_recv_keeps_close_received_state_if_response_send_fails() {
    future::block_on(async {
        let read_data = encode_server_frame(Frame::close(Some(1000), None));
        let ws = WebSocket::from_upgraded(
            TestIo::new(read_data).with_write_failure(),
            WebSocketConfig::default(),
        );
        let (mut read, _write) = ws.split();
        let cx = Cx::for_testing();
        let err = read
            .recv(&cx)
            .await
            .expect_err("close response write should fail");
        assert!(
            matches!(err, WsError::Io(ref e) if e.kind() == io::ErrorKind::BrokenPipe),
            "expected synthetic broken-pipe write failure, got {err:?}"
        );
        assert!(
            !read.is_closed(),
            "failed close response writes must not incorrectly finish the handshake"
        );
        assert_eq!(
            read.shared.lock().close_handshake.state(),
            CloseState::CloseReceived,
            "failed close response writes must leave the handshake waiting for a retry"
        );
    });
}
// A send dropped while parked on the write permit must leave no bytes behind:
// only the later, completed send may appear on the wire.
#[test]
fn cancelled_write_half_send_does_not_flush_frame_later() {
    future::block_on(async {
        let ws = WebSocket::from_upgraded(TestIo::new(vec![]), WebSocketConfig::default());
        let (read, mut write) = ws.split();
        let cx = test_cx_with_entropy(Arc::new(FixedEntropy([0xAA, 0xBB, 0xCC, 0xDD])));
        {
            // Unmasked server codec => predictable expected bytes.
            let mut shared = read.shared.lock();
            shared.codec = FrameCodec::server();
        }
        // Hold the permit so the first send parks before writing anything.
        let permit = acquire_write_permit(&read.shared).await;
        let cancelled = Message::text("cancelled");
        let delivered = Message::text("delivered");
        let mut cancelled_send = Box::pin(write.send(&cx, cancelled));
        let wake_counter = CountingWake::new();
        let task_waker: Waker = Waker::from(Arc::clone(&wake_counter));
        let mut task_cx = Context::from_waker(&task_waker);
        assert!(
            matches!(cancelled_send.as_mut().poll(&mut task_cx), Poll::Pending),
            "first send should park waiting for the write permit"
        );
        // Dropping the parked future models cancellation.
        drop(cancelled_send);
        assert!(
            read.shared.lock().write_buf.is_empty(),
            "dropping a parked split send must not leave bytes in the shared write buffer"
        );
        drop(permit);
        write
            .send(&cx, delivered.clone())
            .await
            .expect("second send should succeed");
        let ws = read.reunite(write).expect("split halves must reunite");
        assert_eq!(
            ws.io.written,
            encode_server_frame(Frame::from(delivered)),
            "later flushes must not emit bytes from a cancelled split send"
        );
    });
}
// Inside Cx::masked, a pending cancellation must be deferred: the send
// completes and the cancellation surfaces only after the mask lifts.
#[test]
fn write_half_send_ignores_cancel_while_masked() {
    let ws = WebSocket::from_upgraded(TestIo::new(vec![]), WebSocketConfig::default());
    let (read, mut write) = ws.split();
    let entropy: Arc<dyn EntropySource> = Arc::new(FixedEntropy([0xDE, 0xAD, 0xBE, 0xEF]));
    let cx = test_cx_with_entropy(Arc::clone(&entropy));
    cx.set_cancel_requested(true);
    let _guard = Cx::set_current(Some(cx.clone()));
    let masked = Message::text("masked");
    cx.masked(|| future::block_on(write.send(&cx, masked.clone())))
        .expect("masked split send should defer cancellation");
    let ws = read.reunite(write).expect("split halves must reunite");
    assert_eq!(
        ws.io.written,
        encode_client_frame_with_entropy(&Frame::from(masked), entropy.as_ref()),
        "masked split send should still flush the original frame"
    );
    assert!(
        cx.is_cancel_requested(),
        "masked send must not clear the pending cancellation"
    );
    assert!(
        cx.checkpoint().is_err(),
        "cancellation must still surface after the mask is released"
    );
}
// Once any byte of a frame reaches the wire, cancelling the send must keep
// the unwritten tail durable so the frame is completed before later sends.
#[test]
fn cancelled_write_half_send_after_partial_write_preserves_tail_for_later_flush() {
    future::block_on(async {
        let ws = WebSocket::from_upgraded(
            TestIo::new(vec![]).with_partial_first_write(1),
            WebSocketConfig::default(),
        );
        let (read, mut write) = ws.split();
        let cx = test_cx_with_entropy(Arc::new(FixedEntropy([0xAA, 0xBB, 0xCC, 0xDD])));
        {
            let mut shared = read.shared.lock();
            shared.codec = FrameCodec::server();
        }
        let cancelled = Message::text("cancelled");
        let delivered = Message::text("delivered");
        let expected_cancelled = encode_server_frame(Frame::from(cancelled.clone()));
        let expected_delivered = encode_server_frame(Frame::from(delivered.clone()));
        let mut cancelled_send = Box::pin(write.send(&cx, cancelled));
        let waker = std::task::Waker::noop().clone();
        let mut poll_cx = Context::from_waker(&waker);
        // Transport accepts 1 byte, then Pending: the tail spills to write_buf.
        assert!(
            matches!(cancelled_send.as_mut().poll(&mut poll_cx), Poll::Pending),
            "send should park after partially writing the frame and buffering the tail"
        );
        drop(cancelled_send);
        assert!(
            !read.shared.lock().write_buf.is_empty(),
            "after any byte hits the wire, the unwritten split-send tail must stay durable"
        );
        write
            .send(&cx, delivered)
            .await
            .expect("later sends should flush the durable tail first");
        let ws = read.reunite(write).expect("split halves must reunite");
        // Wire order: the whole cancelled frame, then the delivered frame.
        let mut expected = expected_cancelled;
        expected.extend_from_slice(&expected_delivered);
        assert_eq!(
            ws.io.written, expected,
            "later flushes must preserve the partially written split frame before the next send"
        );
    });
}
#[test]
fn close_after_cancelled_recv_flushes_pending_echo_without_second_close() {
    future::block_on(async {
        // The peer's close arrives, recv parks while its echoed response is
        // stuck behind a pending write, and the recv future is then dropped.
        // A later close() must finish that pending echo rather than queue a
        // brand-new close frame.
        let peer_close = encode_server_frame(Frame::close(Some(1000), None));
        let ws = WebSocket::from_upgraded(
            TestIo::new(peer_close).with_pending_first_write(),
            WebSocketConfig::default(),
        );
        let (mut read, mut write) = ws.split();
        let mask_source: Arc<dyn EntropySource> =
            Arc::new(FixedEntropy([0x46, 0xD0, 0x1B, 0x0A]));
        let cx = test_cx_with_entropy(Arc::clone(&mask_source));
        let mut parked_recv = Box::pin(read.recv(&cx));
        let noop_waker = std::task::Waker::noop().clone();
        let mut task_cx = Context::from_waker(&noop_waker);
        assert!(
            matches!(parked_recv.as_mut().poll(&mut task_cx), Poll::Pending),
            "recv should park while flushing the echoed close response"
        );
        drop(parked_recv);
        assert_eq!(
            write.close_state(),
            CloseState::CloseReceived,
            "cancelling recv mid-flush must leave the echoed response pending"
        );
        assert!(
            !read.shared.lock().write_buf.is_empty(),
            "the echoed close response should stay buffered for a later retry"
        );
        write
            .close(CloseReason::going_away())
            .await
            .expect("close should finish the pending echoed response");
        assert_eq!(
            write.close_state(),
            CloseState::Closed,
            "finishing the pending echoed response must close the handshake"
        );
        let ws = read.reunite(write).expect("split halves must reunite");
        let expected =
            encode_client_frame_with_entropy(&Frame::close(Some(1000), None), mask_source.as_ref());
        assert_eq!(
            ws.io.written, expected,
            "retrying close after a cancelled recv must not append a second close frame"
        );
    });
}
#[test]
fn close_after_partially_flushed_echo_preserves_tail_without_second_close() {
    future::block_on(async {
        // Variant of the previous scenario: one byte of the echoed close
        // escapes to the transport before recv is dropped. The committed
        // prefix must stay on the wire and the tail must stay buffered so
        // close() completes the same frame.
        let peer_close = encode_server_frame(Frame::close(Some(1000), None));
        let ws = WebSocket::from_upgraded(
            TestIo::new(peer_close).with_partial_first_write(1),
            WebSocketConfig::default(),
        );
        let (mut read, mut write) = ws.split();
        let mask_source: Arc<dyn EntropySource> =
            Arc::new(FixedEntropy([0x46, 0xD0, 0x1B, 0x0A]));
        let cx = test_cx_with_entropy(Arc::clone(&mask_source));
        let expected_close =
            encode_client_frame_with_entropy(&Frame::close(Some(1000), None), mask_source.as_ref());
        let mut parked_recv = Box::pin(read.recv(&cx));
        let noop_waker = std::task::Waker::noop().clone();
        let mut task_cx = Context::from_waker(&noop_waker);
        assert!(
            matches!(parked_recv.as_mut().poll(&mut task_cx), Poll::Pending),
            "recv should park after partially flushing the echoed close response"
        );
        drop(parked_recv);
        assert_eq!(
            write.close_state(),
            CloseState::CloseReceived,
            "partial close-response flush must leave the split handshake awaiting completion"
        );
        assert!(
            !read.shared.lock().write_buf.is_empty(),
            "the echoed split close tail must remain buffered after partial I/O"
        );
        // The guard is a temporary here, so it drops at the end of the assert.
        assert_eq!(
            read.shared.lock().io.written,
            expected_close[..1].to_vec(),
            "only the committed close-frame prefix should hit the transport before retry"
        );
        write
            .close(CloseReason::going_away())
            .await
            .expect("close should flush the durable close tail");
        assert_eq!(
            write.close_state(),
            CloseState::Closed,
            "completing the echoed close tail must close the split handshake"
        );
        let ws = read.reunite(write).expect("split halves must reunite");
        assert_eq!(
            ws.io.written, expected_close,
            "retrying close must finish the original split close frame without appending a second one"
        );
    });
}
#[test]
fn close_retry_flushes_partially_sent_close_without_second_close() {
    future::block_on(async {
        // We initiate the close ourselves; a single byte escapes before the
        // close future is dropped. Retrying close() must flush the remaining
        // bytes of the same frame, not emit a second close.
        let ws = WebSocket::from_upgraded(
            TestIo::new(vec![]).with_partial_first_write(1),
            WebSocketConfig::default(),
        );
        let (_read, mut write) = ws.split();
        let mask_source: Arc<dyn EntropySource> =
            Arc::new(FixedEntropy([0xD2, 0x10, 0x44, 0x9A]));
        write.shared.lock().entropy = Arc::clone(&mask_source);
        let expected_close =
            encode_client_frame_with_entropy(&Frame::close(Some(1001), None), mask_source.as_ref());
        let mut parked_close = Box::pin(write.close(CloseReason::going_away()));
        let noop_waker = std::task::Waker::noop().clone();
        let mut task_cx = Context::from_waker(&noop_waker);
        assert!(
            matches!(parked_close.as_mut().poll(&mut task_cx), Poll::Pending),
            "close should park after partially writing the initiated split close frame"
        );
        drop(parked_close);
        assert_eq!(
            write.close_state(),
            CloseState::CloseSent,
            "cancelling split close after a partial write must keep the handshake in CloseSent"
        );
        assert!(
            !write.shared.lock().write_buf.is_empty(),
            "the initiated split close tail must remain buffered after partial I/O"
        );
        assert_eq!(
            write.shared.lock().io.written,
            expected_close[..1].to_vec(),
            "only the committed split close prefix should hit the transport before retry"
        );
        write
            .close(CloseReason::going_away())
            .await
            .expect("retrying close should flush the durable split close tail");
        // No peer response was fed in, so the handshake stays half-complete.
        assert_eq!(
            write.close_state(),
            CloseState::CloseSent,
            "split close retries should flush bytes without inventing a peer response"
        );
        assert_eq!(
            write.shared.lock().io.written, expected_close,
            "retrying split close must finish the original close frame without appending another"
        );
    });
}
/// Deterministic entropy source that tiles a fixed four-byte pattern,
/// letting tests predict client masking keys exactly.
#[derive(Debug, Clone, Copy)]
struct FixedEntropy([u8; 4]);

impl EntropySource for FixedEntropy {
    fn fill_bytes(&self, dest: &mut [u8]) {
        // Repeat the 4-byte seed across the whole destination buffer.
        for (slot, &seed) in dest.iter_mut().zip(self.0.iter().cycle()) {
            *slot = seed;
        }
    }

    fn next_u64(&self) -> u64 {
        // Two back-to-back copies of the seed, interpreted little-endian.
        let [a, b, c, d] = self.0;
        u64::from_le_bytes([a, b, c, d, a, b, c, d])
    }

    fn fork(&self, _task_id: TaskId) -> Arc<dyn EntropySource> {
        // Forks are independent copies of the same fixed pattern.
        Arc::new(*self)
    }

    fn source_id(&self) -> &'static str {
        "fixed"
    }
}
/// Builds a test `Cx` whose entropy source is the supplied one; region,
/// task, and budget are fixed throwaway test values.
fn test_cx_with_entropy(entropy: Arc<dyn EntropySource>) -> Cx {
    let region = RegionId::new_for_test(0, 0);
    let task = TaskId::new_for_test(0, 0);
    Cx::new_with_observability(region, task, Budget::INFINITE, None, None, Some(entropy))
}
#[test]
fn split_send_uses_cx_entropy_for_client_masking() {
    future::block_on(async {
        // Bytes 2..6 of a short client frame are the masking key, which must
        // come straight from the entropy source carried by the Cx.
        let ws = WebSocket::from_upgraded(TestIo::new(vec![]), WebSocketConfig::default());
        let (read, mut write) = ws.split();
        let cx = test_cx_with_entropy(Arc::new(FixedEntropy([0xDE, 0xAD, 0xBE, 0xEF])));
        write
            .send(&cx, Message::text("hi"))
            .await
            .expect("split send should succeed");
        let ws = read.reunite(write).expect("split halves must reunite");
        let written = ws.io.written;
        assert_eq!(&written[2..6], &[0xDE, 0xAD, 0xBE, 0xEF]);
    });
}
}