use std::io::{self, Read, Write};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use compcol::deflate::Deflate;
use compcol::limit::LimitedDecoder;
use compcol::vec::compress_to_vec;
use compcol::{Algorithm, Decoder, Status};
use purecrypto::hash::{Digest, Sha1};
use crate::error::{Error, Result};
use crate::tls::TlsStream;
use crate::url::Url;
/// GUID appended to `Sec-WebSocket-Key` before hashing (RFC 6455 §1.3).
const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Trailer of an empty stored DEFLATE block. It is stripped from outgoing
/// compressed messages and re-appended to incoming ones before inflating
/// (RFC 7692 §7.2).
const DEFLATE_TAIL: [u8; 4] = [0x00, 0x00, 0xFF, 0xFF];
// Frame opcodes (RFC 6455 §5.2); values >= 0x8 are control frames.
const OPCODE_CONT: u8 = 0x0;
const OPCODE_TEXT: u8 = 0x1;
const OPCODE_BINARY: u8 = 0x2;
const OPCODE_CLOSE: u8 = 0x8;
const OPCODE_PING: u8 = 0x9;
const OPCODE_PONG: u8 = 0xA;
/// Upper bound on a single frame payload, on a reassembled message, and on the
/// inflated size of a compressed message.
const MAX_PAYLOAD_BYTES: u64 = 64 * 1024 * 1024;
/// Wall-clock budget for reading the server's handshake response head.
const HANDSHAKE_DEADLINE: Duration = Duration::from_secs(60);
/// `Sec-WebSocket-Extensions` value offered in the opening handshake.
///
/// `client_max_window_bits` is deliberately NOT offered. Offering it lets the
/// server answer with a smaller LZ77 window (RFC 7692 §7.1.2.2) that the client
/// must then respect. Outgoing messages are compressed by `deflate_message` via
/// `compress_to_vec`, which takes no window-size parameter, so such a
/// restriction could not be honoured. Without the offer the server is not
/// allowed to send the parameter at all.
const PMD_OFFER: &str =
    "permessage-deflate; client_no_context_takeover; server_no_context_takeover";
/// Negotiated permessage-deflate (RFC 7692) state, used to inflate incoming
/// compressed messages.
struct Pmd {
    // Parsed from the server response but not read anywhere in this module:
    // outgoing messages are always compressed independently by
    // `deflate_message`, hence the dead_code allowance.
    #[allow(dead_code)]
    client_no_context_takeover: bool,
    /// When true the decoder is reset before every message, because the server
    /// agreed not to reuse its compression context across messages.
    server_no_context_takeover: bool,
    /// Inflater reused across messages; it is only reset when
    /// `server_no_context_takeover` is set (presumably so it keeps the sliding
    /// window between messages otherwise — TODO confirm against compcol docs).
    decoder: <Deflate as Algorithm>::Decoder,
}
impl Pmd {
    /// Inflates the payload of one compressed message.
    ///
    /// Appends `DEFLATE_TAIL` to `compressed` (RFC 7692 §7.2.2). The output is
    /// capped at `MAX_PAYLOAD_BYTES` through a `LimitedDecoder`. The decoder
    /// is put back into `self` whether or not inflation succeeds.
    ///
    /// # Errors
    /// `Error::BadResponse` if the decoder reports an error (including when
    /// the limit is exceeded, assuming `LimitedDecoder` reports it that way).
    fn inflate_message(&mut self, compressed: &[u8]) -> Result<Vec<u8>> {
        if self.server_no_context_takeover {
            self.decoder.reset();
        }
        let mut input = Vec::with_capacity(compressed.len() + DEFLATE_TAIL.len());
        input.extend_from_slice(compressed);
        input.extend_from_slice(&DEFLATE_TAIL);
        // LimitedDecoder takes the decoder by value: leave a placeholder in
        // `self.decoder` and restore the real one afterwards.
        let taken = std::mem::replace(&mut self.decoder, Deflate::decoder());
        let mut limited = LimitedDecoder::new(taken, MAX_PAYLOAD_BYTES);
        let result = Self::run_inflate(&mut limited, &input);
        self.decoder = limited.into_inner();
        result
    }
    /// Drives `limited` over `input` in 64 KiB output chunks and collects
    /// everything it produces.
    ///
    /// Stops on `StreamEnd`, or on `InputEmpty` once all input is consumed or
    /// a call made no progress (guards against spinning forever).
    fn run_inflate(
        limited: &mut LimitedDecoder<<Deflate as Algorithm>::Decoder>,
        input: &[u8],
    ) -> Result<Vec<u8>> {
        let mut out: Vec<u8> = Vec::new();
        let mut scratch = vec![0u8; 64 * 1024];
        let mut consumed = 0usize;
        loop {
            let before_consumed = consumed;
            let before_written = out.len();
            let (p, status) = limited
                .decode(&input[consumed..], &mut scratch)
                .map_err(|e| {
                    Error::BadResponse(format!("permessage-deflate inflate failed: {e}"))
                })?;
            out.extend_from_slice(&scratch[..p.written]);
            consumed += p.consumed;
            match status {
                Status::StreamEnd => break,
                // Scratch buffer filled up: call again to drain more output.
                Status::OutputFull => continue,
                Status::InputEmpty => {
                    if consumed >= input.len()
                        || (consumed == before_consumed && out.len() == before_written)
                    {
                        break;
                    }
                }
            }
        }
        Ok(out)
    }
}
fn deflate_message(payload: &[u8]) -> Result<Vec<u8>> {
let mut out = compress_to_vec::<Deflate>(payload)
.map_err(|e| Error::BadResponse(format!("permessage-deflate deflate failed: {e}")))?;
if out.ends_with(&DEFLATE_TAIL) {
out.truncate(out.len() - DEFLATE_TAIL.len());
}
Ok(out)
}
/// One-shot fetch of the first data message from a `ws://`/`wss://` URL using
/// the default network configuration.
///
/// Delegates to `fetch_with`; the result is empty if the server closes before
/// sending data.
pub fn fetch(url: &Url) -> Result<Vec<u8>> {
    let defaults = crate::net::NetConfig::default();
    fetch_with(url, &defaults)
}
/// One-shot WebSocket fetch.
///
/// Connects to `url`, performs the opening handshake (offering
/// permessage-deflate, no subprotocols) and reads the first data message. It
/// then sends a Close frame and returns the payload. The result is empty if
/// the server closed first.
///
/// For `wss://` the TLS options are built the same way `WebSocket::open`
/// builds them, so `cfg.verify` is honoured on this path too.
///
/// # Errors
/// `Error::UnsupportedScheme` for anything other than `ws`/`wss`. Connection,
/// TLS, handshake and framing errors are propagated.
pub(crate) fn fetch_with(url: &Url, cfg: &crate::net::NetConfig) -> Result<Vec<u8>> {
    match url.scheme.as_str() {
        "ws" => {
            let mut sock = tcp_connect(url, cfg)?;
            let (mut pmd, _proto) = handshake(&mut sock, url, &[])?;
            read_data_and_close(&mut sock, pmd.as_mut())
        }
        "wss" => {
            let tcp = tcp_connect(url, cfg)?;
            // Previously `connect_over` was used here, which ignored
            // `cfg.verify`; mirror `WebSocket::open` instead.
            let mut opts = crate::tls::TlsOpts::verifying();
            opts.verify = cfg.verify;
            let mut tls = crate::tls::connect_over_tls(tcp, &url.host, opts)?;
            let (mut pmd, _proto) = handshake(&mut tls, url, &[])?;
            read_data_and_close(&mut tls, pmd.as_mut())
        }
        other => Err(Error::UnsupportedScheme(other.to_string())),
    }
}
/// Opens a connection to `url.host:url.port` through `cfg` and gives it
/// 60-second read and write timeouts.
fn tcp_connect(
    url: &Url,
    cfg: &crate::net::NetConfig,
) -> Result<Box<dyn crate::net::NetStream>> {
    let io_timeout = Some(Duration::from_secs(60));
    let conn = cfg.connect(&url.host, url.port)?;
    conn.set_read_timeout(io_timeout)?;
    conn.set_write_timeout(io_timeout)?;
    Ok(conn)
}
/// Write timeout for WebSocket sockets. `WebSocket::open` also uses it as the
/// socket read timeout, and `WsTransport::new` uses it as the initial read
/// deadline (the one in effect during the handshake).
const SEND_TIMEOUT: Duration = Duration::from_secs(60);
/// A complete WebSocket data message (reassembled and, if needed, inflated).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsMessage {
    /// A TEXT message.
    Text(String),
    /// A BINARY message.
    Binary(Vec<u8>),
}
impl WsMessage {
    /// Borrows the payload bytes, whatever the message kind.
    pub fn as_bytes(&self) -> &[u8] {
        match self {
            Self::Text(text) => text.as_bytes(),
            Self::Binary(bytes) => bytes.as_slice(),
        }
    }
    /// The text of a TEXT message; `None` for BINARY.
    pub fn as_text(&self) -> Option<&str> {
        if let Self::Text(text) = self {
            Some(text.as_str())
        } else {
            None
        }
    }
    /// Consumes the message and returns its payload bytes.
    pub fn into_bytes(self) -> Vec<u8> {
        match self {
            Self::Text(text) => Vec::from(text),
            Self::Binary(bytes) => bytes,
        }
    }
}
/// Status code and reason from a received Close frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WsClose {
    /// Close status code (first two payload bytes, big-endian).
    pub code: u16,
    /// Reason text, decoded lossily from the remaining payload bytes.
    pub reason: String,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WsOpcode {
Continuation,
Text,
Binary,
Close,
Ping,
Pong,
}
impl WsOpcode {
fn to_u8(self) -> u8 {
match self {
WsOpcode::Continuation => OPCODE_CONT,
WsOpcode::Text => OPCODE_TEXT,
WsOpcode::Binary => OPCODE_BINARY,
WsOpcode::Close => OPCODE_CLOSE,
WsOpcode::Ping => OPCODE_PING,
WsOpcode::Pong => OPCODE_PONG,
}
}
fn from_u8(op: u8) -> Result<WsOpcode> {
Ok(match op {
OPCODE_CONT => WsOpcode::Continuation,
OPCODE_TEXT => WsOpcode::Text,
OPCODE_BINARY => WsOpcode::Binary,
OPCODE_CLOSE => WsOpcode::Close,
OPCODE_PING => WsOpcode::Ping,
OPCODE_PONG => WsOpcode::Pong,
other => return Err(Error::BadResponse(format!("unknown WS opcode 0x{other:x}"))),
})
}
fn is_control(self) -> bool {
matches!(self, WsOpcode::Close | WsOpcode::Ping | WsOpcode::Pong)
}
}
/// Anything `recv_event` can yield: a data message, a control frame, or the
/// end of the connection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsEvent {
    /// Complete TEXT message (UTF-8 validated).
    Text(String),
    /// Complete BINARY message.
    Binary(Vec<u8>),
    /// PING payload (already answered with a PONG when auto-pong is enabled).
    Ping(Vec<u8>),
    /// PONG payload.
    Pong(Vec<u8>),
    /// Connection closed. `Some` when the Close frame carried a status code;
    /// `None` when it did not, or when the connection was already closed.
    Close(Option<WsClose>),
}
/// One raw frame as returned by `recv_frame`, with no reassembly, validation or
/// decompression applied.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WsFrame {
    /// FIN bit: last frame of a message.
    pub fin: bool,
    pub opcode: WsOpcode,
    /// Payload exactly as received.
    pub payload: Vec<u8>,
}
/// `Read + Write + Send` combined into one trait, so a stream that cannot be
/// split into separate read/write handles can be boxed as `dyn ReadWrite`.
trait ReadWrite: Read + Write + Send {}
// Blanket impl: every such stream (including unsized ones) qualifies.
impl<T: Read + Write + Send + ?Sized> ReadWrite for T {}
/// How the underlying connection is split between reading and writing.
enum WsTransportKind {
    /// Plain TCP: the read handle is a clone of the write socket, and each
    /// has its own lock.
    Plain {
        read: Mutex<Box<dyn crate::net::NetStream>>,
        write: Mutex<Box<dyn crate::net::NetStream>>,
    },
    /// TLS connection produced by `into_concurrent` (presumably safe for
    /// concurrent read/write — TODO confirm in `crate::tls`).
    Tls(Box<crate::tls::TlsConn>),
    /// Fallback when the socket could not be cloned: one lock serialises all
    /// reads and writes.
    Shared(Mutex<Box<dyn ReadWrite>>),
}
/// Longest single blocking read in `WsTransport::read_polled`. This bounds how
/// long a read waits before it re-checks the shutdown flag.
const SHUTDOWN_POLL: Duration = Duration::from_millis(250);
/// Connection shared (via `Arc`) by a `WsReader` and a `WsWriter`.
struct WsTransport {
    kind: WsTransportKind,
    /// Set by `WsShutdown::shutdown`. Polled reads (all variants except
    /// `Shared`) fail once it is set.
    shutdown: Arc<AtomicBool>,
    /// Deadline applied to each polled read; `None` waits indefinitely.
    read_timeout: Mutex<Option<Duration>>,
}
/// Whether `e` is the error a socket read returns when its read timeout
/// elapses. Platforms differ: Unix typically reports `WouldBlock`, Windows
/// may report `TimedOut`.
fn is_read_timeout(e: &io::Error) -> bool {
    let kind = e.kind();
    kind == io::ErrorKind::WouldBlock || kind == io::ErrorKind::TimedOut
}
impl WsTransport {
    /// Wraps `kind` with a fresh shutdown flag and an initial read deadline of
    /// `SEND_TIMEOUT`.
    fn new(kind: WsTransportKind) -> WsTransport {
        WsTransport {
            kind,
            shutdown: Arc::new(AtomicBool::new(false)),
            read_timeout: Mutex::new(Some(SEND_TIMEOUT)),
        }
    }
    /// Reads from the connection. `Shared` streams read directly, with no
    /// shutdown polling and no deadline. The other variants go through
    /// `read_polled`.
    fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
        match &self.kind {
            WsTransportKind::Shared(s) => s.lock().unwrap().read(buf),
            _ => self.read_polled(buf),
        }
    }
    /// Blocking read performed in slices of at most `SHUTDOWN_POLL`.
    ///
    /// Between slices the shutdown flag is re-checked, and the overall
    /// `read_timeout` deadline is still honoured. Returns `ConnectionAborted`
    /// once shutdown is requested, and the underlying timeout error once the
    /// deadline has passed.
    fn read_polled(&self, buf: &mut [u8]) -> io::Result<usize> {
        // The deadline is fixed once, at entry; `None` means no deadline.
        let deadline = (*self.read_timeout.lock().unwrap()).map(|d| Instant::now() + d);
        // Timeout currently programmed on the socket, so it is only changed
        // when the slice length changes.
        let mut programmed: Option<Duration> = None;
        loop {
            if self.shutdown.load(Ordering::SeqCst) {
                return Err(io::Error::new(
                    io::ErrorKind::ConnectionAborted,
                    "websocket shut down by WsShutdown handle",
                ));
            }
            // Remaining time clamped into [1 ms, SHUTDOWN_POLL]. The 1 ms floor
            // presumably exists because socket APIs (e.g. std's
            // `TcpStream::set_read_timeout`) reject a zero duration.
            let tick = match deadline {
                Some(dl) => dl
                    .checked_duration_since(Instant::now())
                    .unwrap_or(Duration::from_millis(1))
                    .min(SHUTDOWN_POLL)
                    .max(Duration::from_millis(1)),
                None => SHUTDOWN_POLL,
            };
            if programmed != Some(tick) {
                self.kind.set_read_timeout(Some(tick))?;
                programmed = Some(tick);
            }
            match self.kind.read(buf) {
                Ok(n) => return Ok(n),
                // A slice timed out: give up only once the overall deadline passed.
                Err(e) if is_read_timeout(&e) => match deadline {
                    Some(dl) if Instant::now() >= dl => return Err(e),
                    _ => continue,
                },
                Err(e) => return Err(e),
            }
        }
    }
    /// Writes all of `data` through the write side of the connection.
    fn write_all(&self, data: &[u8]) -> io::Result<()> {
        self.kind.write_all(data)
    }
    /// Flushes the write side of the connection.
    fn flush(&self) -> io::Result<()> {
        self.kind.flush()
    }
    /// Records the deadline used by later polled reads. Nothing is sent to the
    /// socket here, and this always returns `Ok`. `Shared` reads do not
    /// consult it.
    fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
        *self.read_timeout.lock().unwrap() = dur;
        Ok(())
    }
}
impl WsTransportKind {
fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
match self {
WsTransportKind::Plain { read, .. } => read.lock().unwrap().read(buf),
WsTransportKind::Tls(c) => c.read(buf),
WsTransportKind::Shared(s) => s.lock().unwrap().read(buf),
}
}
fn write_all(&self, data: &[u8]) -> io::Result<()> {
match self {
WsTransportKind::Plain { write, .. } => write.lock().unwrap().write_all(data),
WsTransportKind::Tls(c) => c.write(data),
WsTransportKind::Shared(s) => s.lock().unwrap().write_all(data),
}
}
fn flush(&self) -> io::Result<()> {
match self {
WsTransportKind::Plain { write, .. } => write.lock().unwrap().flush(),
WsTransportKind::Tls(c) => c.flush(),
WsTransportKind::Shared(s) => s.lock().unwrap().flush(),
}
}
fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
match self {
WsTransportKind::Plain { read, .. } => read.lock().unwrap().set_read_timeout(dur),
WsTransportKind::Tls(c) => c.set_read_timeout(dur),
WsTransportKind::Shared(_) => Ok(()), }
}
}
struct TransportIo<'a>(&'a WsTransport);
impl Read for TransportIo<'_> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl Write for TransportIo<'_> {
fn write(&mut self, data: &[u8]) -> io::Result<usize> {
self.0.write_all(data)?;
Ok(data.len())
}
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}
/// A client WebSocket connection, made of a reader half and a writer half
/// (separable with `split`).
pub struct WebSocket {
    reader: WsReader,
    writer: WsWriter,
    /// Subprotocol selected by the server during the handshake, if any.
    subprotocol: Option<String>,
}
/// Receiving half of a WebSocket connection.
pub struct WsReader {
    transport: Arc<WsTransport>,
    /// Bytes read from the transport but not yet consumed by `read_frame`.
    rxbuf: Vec<u8>,
    /// permessage-deflate state, present when the extension was negotiated.
    pmd: Option<Pmd>,
    /// True when permessage-deflate was negotiated.
    compression: bool,
    /// Answer incoming PINGs automatically.
    auto_pong: bool,
    /// Set once a Close frame has been sent (shared with the writer).
    send_closed: Arc<AtomicBool>,
    /// Set once a Close frame has been received (shared with the writer).
    recv_closed: Arc<AtomicBool>,
    /// Spare socket clone used to build `WsShutdown` handles.
    shutdown: Option<ShutdownSock>,
}
/// Sending half of a WebSocket connection.
pub struct WsWriter {
    transport: Arc<WsTransport>,
    /// Compress outgoing data messages (set when permessage-deflate was negotiated).
    compress: bool,
    /// Set once a Close frame has been sent (shared with the reader).
    send_closed: Arc<AtomicBool>,
    /// Set once a Close frame has been received (shared with the reader).
    recv_closed: Arc<AtomicBool>,
    /// Spare socket clone used to build `WsShutdown` handles.
    shutdown: Option<ShutdownSock>,
}
/// Extra socket clone, used only to shut the connection down from another
/// thread.
type ShutdownSock = Arc<Mutex<Box<dyn crate::net::NetStream>>>;
/// Cloneable handle that aborts a WebSocket connection, e.g. to unblock a
/// reader waiting in `recv` on another thread.
#[derive(Clone)]
pub struct WsShutdown {
    sock: ShutdownSock,
    /// The transport's shutdown flag, checked by polled reads.
    flag: Arc<AtomicBool>,
}
impl WsShutdown {
pub fn shutdown(&self) -> Result<()> {
self.flag.store(true, Ordering::SeqCst);
self.sock
.lock()
.unwrap()
.shutdown(std::net::Shutdown::Both)
.map_err(Error::Io)
}
}
impl WebSocket {
    /// Connects to `url` using a default `crate::net::Client`.
    pub fn connect(url: &str) -> Result<WebSocket> {
        crate::net::Client::new().websocket(url)
    }
    /// Like `connect`, but also offers `subprotocols` in the handshake.
    pub fn connect_with_subprotocols(url: &str, subprotocols: &[&str]) -> Result<WebSocket> {
        crate::net::Client::new().websocket_with_subprotocols(url, subprotocols)
    }
    /// Opens a `ws://` or `wss://` connection and performs the opening handshake.
    ///
    /// The socket gets `SEND_TIMEOUT` read and write timeouts. If the socket can
    /// be cloned, reads and writes use independent handles (`Plain`/`Tls`), and
    /// another clone is kept for `WsShutdown` when cloning succeeds again.
    /// Otherwise a single locked stream is used (`Shared`) and no shutdown
    /// handle is available. After the handshake, `read_timeout` becomes the
    /// transport's read deadline.
    ///
    /// # Errors
    /// `Error::UnsupportedScheme` for other schemes. Connect, TLS, timeout
    /// configuration and handshake errors are propagated.
    pub(crate) fn open(
        url: &Url,
        cfg: &crate::net::NetConfig,
        read_timeout: Option<Duration>,
        subprotocols: &[String],
    ) -> Result<WebSocket> {
        let (transport, shutdown_sock): (Arc<WsTransport>, Option<Box<dyn crate::net::NetStream>>) =
            match url.scheme.as_str() {
                "ws" => {
                    let data = cfg.connect(&url.host, url.port)?;
                    data.set_write_timeout(Some(SEND_TIMEOUT))
                        .map_err(Error::Io)?;
                    data.set_read_timeout(Some(SEND_TIMEOUT))
                        .map_err(Error::Io)?;
                    match data.try_clone_box() {
                        Ok(read) => {
                            read.set_read_timeout(Some(SEND_TIMEOUT))
                                .map_err(Error::Io)?;
                            let shutdown = data.try_clone_box().ok();
                            (
                                Arc::new(WsTransport::new(WsTransportKind::Plain {
                                    read: Mutex::new(read),
                                    write: Mutex::new(data),
                                })),
                                shutdown,
                            )
                        }
                        // Cannot clone: fall back to one locked stream.
                        Err(_) => (
                            Arc::new(WsTransport::new(WsTransportKind::Shared(Mutex::new(
                                Box::new(data),
                            )))),
                            None,
                        ),
                    }
                }
                "wss" => {
                    let data = cfg.connect(&url.host, url.port)?;
                    data.set_write_timeout(Some(SEND_TIMEOUT))
                        .map_err(Error::Io)?;
                    data.set_read_timeout(Some(SEND_TIMEOUT))
                        .map_err(Error::Io)?;
                    // Clones are taken before TLS takes ownership of `data`.
                    let read_clone = data.try_clone_box().ok();
                    let shutdown = data.try_clone_box().ok();
                    let mut opts = crate::tls::TlsOpts::verifying();
                    opts.verify = cfg.verify;
                    let tls = crate::tls::connect_over_tls(data, &url.host, opts)?;
                    match read_clone {
                        Some(read) => {
                            read.set_read_timeout(Some(SEND_TIMEOUT))
                                .map_err(Error::Io)?;
                            (
                                Arc::new(WsTransport::new(WsTransportKind::Tls(Box::new(
                                    tls.into_concurrent(read),
                                )))),
                                shutdown,
                            )
                        }
                        None => (
                            Arc::new(WsTransport::new(WsTransportKind::Shared(Mutex::new(
                                Box::new(tls),
                            )))),
                            None,
                        ),
                    }
                }
                other => return Err(Error::UnsupportedScheme(other.to_string())),
            };
        let (pmd, subprotocol) = handshake(&mut TransportIo(&transport), url, subprotocols)?;
        // WsTransport::set_read_timeout only records the value and always
        // returns Ok. NOTE(review): in the `Shared` fallback, reads bypass this
        // deadline, so `read_timeout` has no effect there; the socket keeps
        // SEND_TIMEOUT.
        let _ = transport.set_read_timeout(read_timeout);
        let compression = pmd.is_some();
        let send_closed = Arc::new(AtomicBool::new(false));
        let recv_closed = Arc::new(AtomicBool::new(false));
        let shutdown: Option<ShutdownSock> = shutdown_sock.map(|s| Arc::new(Mutex::new(s)));
        Ok(WebSocket {
            reader: WsReader {
                transport: Arc::clone(&transport),
                rxbuf: Vec::new(),
                pmd,
                compression,
                auto_pong: true,
                send_closed: Arc::clone(&send_closed),
                recv_closed: Arc::clone(&recv_closed),
                shutdown: shutdown.clone(),
            },
            writer: WsWriter {
                transport,
                compress: compression,
                send_closed,
                recv_closed,
                shutdown,
            },
            subprotocol,
        })
    }
    /// Splits the connection into independently usable reader and writer halves.
    pub fn split(self) -> (WsReader, WsWriter) {
        (self.reader, self.writer)
    }
    /// Handle that can abort the connection from another thread. `None` if
    /// no spare socket clone was available.
    pub fn shutdown_handle(&self) -> Option<WsShutdown> {
        self.reader.shutdown_handle()
    }
    /// Subprotocol selected by the server, if any.
    pub fn subprotocol(&self) -> Option<&str> {
        self.subprotocol.as_deref()
    }
    /// Whether permessage-deflate was negotiated.
    pub fn compression_enabled(&self) -> bool {
        self.reader.compression
    }
    /// True once a Close frame has been sent or received.
    pub fn is_closed(&self) -> bool {
        self.writer.is_closed()
    }
    /// Sets the deadline for each underlying read (`None` waits indefinitely).
    pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<()> {
        self.reader.set_read_timeout(dur)
    }
    /// Sends a TEXT message.
    pub fn send_text(&mut self, text: &str) -> Result<()> {
        self.writer.send_text(text)
    }
    /// Sends a BINARY message.
    pub fn send_binary(&mut self, data: &[u8]) -> Result<()> {
        self.writer.send_binary(data)
    }
    /// Sends `msg` as TEXT or BINARY according to its variant.
    pub fn send(&mut self, msg: &WsMessage) -> Result<()> {
        self.writer.send(msg)
    }
    /// Sends a PING with `payload` (at most 125 bytes).
    pub fn ping(&mut self, payload: &[u8]) -> Result<()> {
        self.writer.ping(payload)
    }
    /// Receives the next data message; `Ok(None)` once closed.
    pub fn recv(&mut self) -> Result<Option<WsMessage>> {
        self.reader.recv()
    }
    /// Enables or disables automatic PONG replies.
    pub fn set_auto_pong(&mut self, on: bool) {
        self.reader.set_auto_pong(on);
    }
    /// Receives the next event, including PING/PONG and Close.
    pub fn recv_event(&mut self) -> Result<WsEvent> {
        self.reader.recv_event()
    }
    /// Sends a PONG with `payload` (at most 125 bytes).
    pub fn send_pong(&mut self, payload: &[u8]) -> Result<()> {
        self.writer.send_pong(payload)
    }
    /// Sends one raw, uncompressed frame.
    pub fn send_frame(&mut self, fin: bool, opcode: WsOpcode, payload: &[u8]) -> Result<()> {
        self.writer.send_frame(fin, opcode, payload)
    }
    /// Reads one raw frame, with no reassembly or control handling.
    pub fn recv_frame(&mut self) -> Result<WsFrame> {
        self.reader.recv_frame()
    }
    /// Sends an empty Close frame; a no-op if one was already sent.
    pub fn close(&mut self) -> Result<()> {
        self.writer.close()
    }
    /// Sends a Close frame with `code` and `reason`.
    pub fn close_with(&mut self, code: u16, reason: &str) -> Result<()> {
        self.writer.close_with(code, reason)
    }
}
impl WsReader {
    /// Enables or disables automatic PONG replies to received PINGs (enabled
    /// by default in `WebSocket::open`).
    pub fn set_auto_pong(&mut self, on: bool) {
        self.auto_pong = on;
    }
    /// Sets the deadline applied to each underlying transport read (`None`
    /// waits indefinitely). The transport is shared with the writer half.
    pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<()> {
        self.transport.set_read_timeout(dur).map_err(Error::Io)
    }
    /// Returns a handle that can abort this connection from another thread, or
    /// `None` if no spare socket clone was available.
    pub fn shutdown_handle(&self) -> Option<WsShutdown> {
        self.shutdown.as_ref().map(|s| WsShutdown {
            sock: Arc::clone(s),
            flag: Arc::clone(&self.transport.shutdown),
        })
    }
    /// True once a Close frame has been sent or received.
    pub fn is_closed(&self) -> bool {
        self.send_closed.load(Ordering::SeqCst) || self.recv_closed.load(Ordering::SeqCst)
    }
    /// Receives the next data message, skipping PING/PONG events.
    /// Returns `Ok(None)` once the connection is closed.
    pub fn recv(&mut self) -> Result<Option<WsMessage>> {
        loop {
            match self.recv_event()? {
                WsEvent::Text(s) => return Ok(Some(WsMessage::Text(s))),
                WsEvent::Binary(b) => return Ok(Some(WsMessage::Binary(b))),
                WsEvent::Ping(_) | WsEvent::Pong(_) => continue,
                WsEvent::Close(_) => return Ok(None),
            }
        }
    }
    /// Receives the next event. This is a complete data message (reassembled
    /// from fragments, inflated when RSV1 was set on the first frame), a
    /// PING/PONG, or a Close.
    ///
    /// PINGs are answered automatically when auto-pong is on. On Close, an
    /// empty Close frame is sent back unless one was already sent, and the
    /// connection is marked closed. Later calls return `Close(None)`.
    ///
    /// # Errors
    /// Protocol violations (bad control frames, stray continuations, RSV1
    /// without permessage-deflate, oversized or non-UTF-8 messages) yield
    /// `Error::BadResponse`. Transport failures yield I/O errors.
    pub fn recv_event(&mut self) -> Result<WsEvent> {
        if self.recv_closed.load(Ordering::SeqCst) {
            return Ok(WsEvent::Close(None));
        }
        // Opcode of the fragmented message in progress, if any.
        let mut frag_opcode: Option<u8> = None;
        // RSV1 on the first frame marks the whole message as compressed.
        let mut compressed = false;
        let mut buf: Vec<u8> = Vec::new();
        loop {
            let frame = read_frame(&mut TransportIo(self.transport.as_ref()), &mut self.rxbuf)?;
            if frame.opcode >= 0x8 {
                validate_control_frame(&frame)?;
                // Control frames may arrive between fragments. Returning an event
                // then would drop the partially assembled `buf`, so PING/PONG
                // are handled but not surfaced in that case.
                let mid_message = frag_opcode.is_some();
                match frame.opcode {
                    OPCODE_PING => {
                        if self.auto_pong {
                            let pong = build_client_frame(OPCODE_PONG, &frame.payload)?;
                            self.transport.write_all(&pong).map_err(Error::Io)?;
                            self.transport.flush().map_err(Error::Io)?;
                        }
                        if mid_message {
                            continue;
                        }
                        return Ok(WsEvent::Ping(frame.payload));
                    }
                    OPCODE_PONG => {
                        if mid_message {
                            continue;
                        }
                        return Ok(WsEvent::Pong(frame.payload));
                    }
                    OPCODE_CLOSE => {
                        // Reply with an empty Close only if we have not sent one
                        // yet; write failures are ignored (best effort).
                        if !self.send_closed.swap(true, Ordering::SeqCst) {
                            if let Ok(close) = build_client_frame(OPCODE_CLOSE, &[]) {
                                let _ = self.transport.write_all(&close);
                                let _ = self.transport.flush();
                            }
                        }
                        self.recv_closed.store(true, Ordering::SeqCst);
                        return Ok(WsEvent::Close(parse_close_payload(&frame.payload)));
                    }
                    other => {
                        return Err(Error::BadResponse(format!(
                            "unknown WS control opcode 0x{other:x}"
                        )));
                    }
                }
            }
            match frame.opcode {
                OPCODE_TEXT | OPCODE_BINARY => {
                    if frag_opcode.is_some() {
                        return Err(Error::BadResponse(
                            "new data frame began while a fragmented message was in progress"
                                .into(),
                        ));
                    }
                    if frame.rsv1 {
                        if self.pmd.is_none() {
                            return Err(Error::BadResponse(
                                "RSV1 set on a WS frame but permessage-deflate was not negotiated"
                                    .into(),
                            ));
                        }
                        compressed = true;
                    }
                    accumulate(&mut buf, &frame.payload)?;
                    if frame.fin {
                        return self.finish_event(frame.opcode, buf, compressed);
                    }
                    frag_opcode = Some(frame.opcode);
                }
                OPCODE_CONT => {
                    let opcode = frag_opcode.ok_or_else(|| {
                        Error::BadResponse("continuation frame with no message in progress".into())
                    })?;
                    // RSV1 is only meaningful on the first frame of a message.
                    if frame.rsv1 {
                        return Err(Error::BadResponse(
                            "RSV1 set on a WS continuation frame".into(),
                        ));
                    }
                    accumulate(&mut buf, &frame.payload)?;
                    if frame.fin {
                        return self.finish_event(opcode, buf, compressed);
                    }
                }
                other => {
                    return Err(Error::BadResponse(format!("unknown WS opcode 0x{other:x}")));
                }
            }
        }
    }
    /// Inflates/validates a reassembled message and converts it into a `WsEvent`.
    fn finish_event(&mut self, opcode: u8, payload: Vec<u8>, compressed: bool) -> Result<WsEvent> {
        match finish_data_message(opcode, payload, compressed, self.pmd.as_mut())? {
            Message::Data { opcode, payload } if opcode == OPCODE_TEXT => {
                let s = String::from_utf8(payload).map_err(|_| {
                    Error::BadResponse("WS TEXT message payload is not valid UTF-8".into())
                })?;
                Ok(WsEvent::Text(s))
            }
            Message::Data { payload, .. } => Ok(WsEvent::Binary(payload)),
            // `finish_data_message` only ever returns `Data`; kept for exhaustiveness.
            Message::Closed => Ok(WsEvent::Close(None)),
        }
    }
    /// Reads one raw frame. No control handling, reassembly, decompression or
    /// close-state tracking is done here.
    pub fn recv_frame(&mut self) -> Result<WsFrame> {
        let frame = read_frame(&mut TransportIo(self.transport.as_ref()), &mut self.rxbuf)?;
        Ok(WsFrame {
            fin: frame.fin,
            opcode: WsOpcode::from_u8(frame.opcode)?,
            payload: frame.payload,
        })
    }
}
impl WsWriter {
pub fn is_closed(&self) -> bool {
self.send_closed.load(Ordering::SeqCst) || self.recv_closed.load(Ordering::SeqCst)
}
pub fn shutdown_handle(&self) -> Option<WsShutdown> {
self.shutdown.as_ref().map(|s| WsShutdown {
sock: Arc::clone(s),
flag: Arc::clone(&self.transport.shutdown),
})
}
fn ensure_open(&self) -> Result<()> {
if self.send_closed.load(Ordering::SeqCst) {
return Err(Error::BadResponse(
"websocket: the connection is closed".into(),
));
}
Ok(())
}
pub fn send_text(&mut self, text: &str) -> Result<()> {
self.ensure_open()?;
send_message(
&mut TransportIo(self.transport.as_ref()),
OPCODE_TEXT,
text.as_bytes(),
self.compress,
)
}
pub fn send_binary(&mut self, data: &[u8]) -> Result<()> {
self.ensure_open()?;
send_message(
&mut TransportIo(self.transport.as_ref()),
OPCODE_BINARY,
data,
self.compress,
)
}
pub fn send(&mut self, msg: &WsMessage) -> Result<()> {
match msg {
WsMessage::Text(s) => self.send_text(s),
WsMessage::Binary(b) => self.send_binary(b),
}
}
pub fn ping(&mut self, payload: &[u8]) -> Result<()> {
self.ensure_open()?;
self.send_control(OPCODE_PING, payload)
}
pub fn send_pong(&mut self, payload: &[u8]) -> Result<()> {
self.ensure_open()?;
self.send_control(OPCODE_PONG, payload)
}
pub fn send_frame(&mut self, fin: bool, opcode: WsOpcode, payload: &[u8]) -> Result<()> {
self.ensure_open()?;
if opcode.is_control() {
if !fin {
return Err(Error::BadResponse(
"websocket: control frames cannot be fragmented (fin must be true)".into(),
));
}
if payload.len() > MAX_CONTROL_PAYLOAD {
return Err(Error::BadResponse(format!(
"websocket: control frame payload too large: {} bytes (max {MAX_CONTROL_PAYLOAD})",
payload.len()
)));
}
}
let frame = build_client_frame_inner(fin, opcode.to_u8(), payload, false)?;
self.transport.write_all(&frame).map_err(Error::Io)?;
self.transport.flush().map_err(Error::Io)?;
Ok(())
}
fn send_control(&mut self, opcode: u8, payload: &[u8]) -> Result<()> {
if payload.len() > MAX_CONTROL_PAYLOAD {
return Err(Error::BadResponse(format!(
"websocket: control frame payload too large: {} bytes (max {MAX_CONTROL_PAYLOAD})",
payload.len()
)));
}
let frame = build_client_frame(opcode, payload)?;
self.transport.write_all(&frame).map_err(Error::Io)?;
self.transport.flush().map_err(Error::Io)?;
Ok(())
}
pub fn close(&mut self) -> Result<()> {
if self.send_closed.swap(true, Ordering::SeqCst) {
return Ok(());
}
let frame = build_client_frame(OPCODE_CLOSE, &[])?;
self.transport.write_all(&frame).map_err(Error::Io)?;
self.transport.flush().map_err(Error::Io)?;
Ok(())
}
pub fn close_with(&mut self, code: u16, reason: &str) -> Result<()> {
let mut payload = Vec::with_capacity(2 + reason.len());
payload.extend_from_slice(&code.to_be_bytes());
payload.extend_from_slice(reason.as_bytes());
if payload.len() > MAX_CONTROL_PAYLOAD {
return Err(Error::BadResponse(format!(
"websocket: close reason too long: {} bytes (max {})",
payload.len(),
MAX_CONTROL_PAYLOAD - 2
)));
}
if self.send_closed.swap(true, Ordering::SeqCst) {
return Ok(());
}
let frame = build_client_frame(OPCODE_CLOSE, &payload)?;
self.transport.write_all(&frame).map_err(Error::Io)?;
self.transport.flush().map_err(Error::Io)?;
Ok(())
}
}
/// Performs the client side of the RFC 6455 opening handshake on `stream`.
///
/// Sends a `GET` upgrade request with a random `Sec-WebSocket-Key` and the
/// permessage-deflate offer (`PMD_OFFER`). When `subprotocols` is non-empty, a
/// `Sec-WebSocket-Protocol` header listing them is added. The response must be
/// a `101` with `Upgrade: websocket`, a `Connection` header containing
/// `upgrade`, and the expected `Sec-WebSocket-Accept`.
///
/// Returns the negotiated permessage-deflate state (if accepted) and the
/// subprotocol selected by the server (if any).
///
/// # Errors
/// `Error::BadResponse` for any malformed or non-conforming response, including
/// a selected subprotocol that was not offered (RFC 6455 §4.1). I/O errors are
/// propagated.
fn handshake<S: Read + Write>(
    stream: &mut S,
    url: &Url,
    subprotocols: &[String],
) -> Result<(Option<Pmd>, Option<String>)> {
    let key_bytes: [u8; 16] = random_16()?;
    let key_b64 = base64_encode(&key_bytes);
    // The port is omitted from Host when it is the scheme's default.
    let host_header =
        if (url.scheme == "ws" && url.port == 80) || (url.scheme == "wss" && url.port == 443) {
            url.host.clone()
        } else {
            format!("{}:{}", url.host, url.port)
        };
    let path = if url.path.is_empty() {
        "/"
    } else {
        url.path.as_str()
    };
    let proto_header = if subprotocols.is_empty() {
        String::new()
    } else {
        format!("Sec-WebSocket-Protocol: {}\r\n", subprotocols.join(", "))
    };
    let req = format!(
        "GET {path} HTTP/1.1\r\n\
         Host: {host_header}\r\n\
         Upgrade: websocket\r\n\
         Connection: Upgrade\r\n\
         Sec-WebSocket-Key: {key_b64}\r\n\
         Sec-WebSocket-Version: 13\r\n\
         Sec-WebSocket-Extensions: {PMD_OFFER}\r\n\
         {proto_header}\
         \r\n"
    );
    stream.write_all(req.as_bytes())?;
    stream.flush()?;
    let buf = read_handshake_head(stream, HANDSHAKE_DEADLINE)?;
    let head = std::str::from_utf8(&buf)
        .map_err(|_| Error::BadResponse("non-utf8 handshake response".into()))?;
    let mut lines = head.split("\r\n");
    let status_line = lines
        .next()
        .ok_or_else(|| Error::BadResponse("empty handshake response".into()))?;
    if !(status_line.starts_with("HTTP/1.1 101") || status_line.starts_with("HTTP/1.0 101")) {
        return Err(Error::BadResponse(format!(
            "expected 101 Switching Protocols, got: {status_line:?}"
        )));
    }
    let mut upgrade_ok = false;
    let mut connection_ok = false;
    let mut accept_value: Option<String> = None;
    let mut extensions_value: Option<String> = None;
    let mut subprotocol_value: Option<String> = None;
    for line in lines {
        if line.is_empty() {
            break;
        }
        let (k, v) = match line.split_once(':') {
            Some((k, v)) => (k.trim(), v.trim()),
            None => continue,
        };
        if k.eq_ignore_ascii_case("upgrade") {
            if v.eq_ignore_ascii_case("websocket") {
                upgrade_ok = true;
            }
        } else if k.eq_ignore_ascii_case("connection") {
            if v.split(',')
                .any(|t| t.trim().eq_ignore_ascii_case("upgrade"))
            {
                connection_ok = true;
            }
        } else if k.eq_ignore_ascii_case("sec-websocket-accept") {
            accept_value = Some(v.to_string());
        } else if k.eq_ignore_ascii_case("sec-websocket-protocol") {
            subprotocol_value = Some(v.to_string());
        } else if k.eq_ignore_ascii_case("sec-websocket-extensions") {
            // Repeated extension headers are equivalent to one comma-joined list.
            match &mut extensions_value {
                Some(existing) => {
                    existing.push_str(", ");
                    existing.push_str(v);
                }
                None => extensions_value = Some(v.to_string()),
            }
        }
    }
    if !upgrade_ok {
        return Err(Error::BadResponse(
            "missing or wrong Upgrade header in handshake response".into(),
        ));
    }
    if !connection_ok {
        return Err(Error::BadResponse(
            "missing or wrong Connection header in handshake response".into(),
        ));
    }
    let accept = accept_value
        .ok_or_else(|| Error::BadResponse("missing Sec-WebSocket-Accept header".into()))?;
    let expected = derive_accept(&key_b64);
    if accept != expected {
        return Err(Error::BadResponse(format!(
            "Sec-WebSocket-Accept mismatch: got {accept:?}, expected {expected:?}"
        )));
    }
    // RFC 6455 §4.1: the server may only select a subprotocol the client
    // offered; otherwise the client must fail the connection. An empty header
    // value is treated as "no subprotocol selected".
    let subprotocol = match subprotocol_value {
        Some(selected) if selected.is_empty() => None,
        Some(selected) => {
            if !subprotocols.iter().any(|offered| offered == &selected) {
                return Err(Error::BadResponse(format!(
                    "server selected subprotocol {selected:?} that was not offered"
                )));
            }
            Some(selected)
        }
        None => None,
    };
    let pmd = extensions_value.as_deref().and_then(parse_pmd_response);
    Ok((pmd, subprotocol))
}
/// Reads the HTTP response head (status line and headers, through the blank
/// line) from `stream`.
///
/// Reads one byte at a time, so nothing after the terminating `\r\n\r\n` is
/// consumed; the first WebSocket frame byte stays in the stream.
///
/// # Errors
/// `Error::UnexpectedEof` on EOF. `Error::BadResponse` when the head exceeds
/// 64 KiB or `deadline` has elapsed (checked after each byte, so a blocking
/// read can overshoot it by up to the socket's read timeout).
fn read_handshake_head<S: Read>(stream: &mut S, deadline: Duration) -> Result<Vec<u8>> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    const MAX_HEAD: usize = 64 * 1024;
    let started = Instant::now();
    let mut head: Vec<u8> = Vec::with_capacity(512);
    let mut byte = [0u8; 1];
    loop {
        if stream.read(&mut byte)? == 0 {
            return Err(Error::UnexpectedEof);
        }
        head.push(byte[0]);
        if head.ends_with(TERMINATOR) {
            return Ok(head);
        }
        if head.len() > MAX_HEAD {
            return Err(Error::BadResponse("handshake response too large".into()));
        }
        if started.elapsed() > deadline {
            return Err(Error::BadResponse(
                "handshake response timed out (header read exceeded deadline)".into(),
            ));
        }
    }
}
/// Extracts permessage-deflate parameters from the server's
/// `Sec-WebSocket-Extensions` value.
///
/// The value is a comma-separated list of extensions, each a `;`-separated
/// list whose first element is the name. Only the first `permessage-deflate`
/// entry is used: the two `*_no_context_takeover` flags are recorded and other
/// parameters are ignored. Returns `None` when the extension is absent.
fn parse_pmd_response(value: &str) -> Option<Pmd> {
    value.split(',').find_map(|extension| {
        let mut parts = extension.split(';').map(str::trim);
        let name = parts.next()?;
        if !name.eq_ignore_ascii_case("permessage-deflate") {
            return None;
        }
        let mut client_no_context_takeover = false;
        let mut server_no_context_takeover = false;
        for part in parts.filter(|p| !p.is_empty()) {
            // Parameter name, ignoring any `=value`.
            let key = part.split_once('=').map_or(part, |(k, _)| k).trim();
            if key.eq_ignore_ascii_case("client_no_context_takeover") {
                client_no_context_takeover = true;
            } else if key.eq_ignore_ascii_case("server_no_context_takeover") {
                server_no_context_takeover = true;
            }
        }
        Some(Pmd {
            client_no_context_takeover,
            server_no_context_takeover,
            decoder: Deflate::decoder(),
        })
    })
}
/// Outcome of the one-shot `read_message` path.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Message {
    /// A complete data message (`opcode` is TEXT or BINARY).
    Data { opcode: u8, payload: Vec<u8> },
    /// The peer sent a Close frame.
    Closed,
}
/// Maximum payload of a control frame (RFC 6455 §5.5).
const MAX_CONTROL_PAYLOAD: usize = 125;
fn read_message<S: Read + Write>(stream: &mut S, mut pmd: Option<&mut Pmd>) -> Result<Message> {
let mut frag_opcode: Option<u8> = None;
let mut compressed = false;
let mut buf: Vec<u8> = Vec::new();
let mut rx: Vec<u8> = Vec::new();
loop {
let frame = read_frame(stream, &mut rx)?;
if frame.opcode >= 0x8 {
if frame.rsv1 {
return Err(Error::BadResponse("RSV1 set on a WS control frame".into()));
}
if !frame.fin {
return Err(Error::BadResponse(
"fragmented control frame (FIN=0 on a control opcode)".into(),
));
}
if frame.payload.len() > MAX_CONTROL_PAYLOAD {
return Err(Error::BadResponse(format!(
"control frame payload too large: {} bytes (max {MAX_CONTROL_PAYLOAD})",
frame.payload.len()
)));
}
match frame.opcode {
OPCODE_PING => {
let pong = build_client_frame(OPCODE_PONG, &frame.payload)?;
stream.write_all(&pong)?;
stream.flush()?;
continue;
}
OPCODE_PONG => continue,
OPCODE_CLOSE => {
if let Ok(close) = build_client_frame(OPCODE_CLOSE, &[]) {
let _ = stream.write_all(&close);
let _ = stream.flush();
}
return Ok(Message::Closed);
}
other => {
return Err(Error::BadResponse(format!(
"unknown WS control opcode 0x{other:x}"
)));
}
}
}
match frame.opcode {
OPCODE_TEXT | OPCODE_BINARY => {
if frag_opcode.is_some() {
return Err(Error::BadResponse(
"new data frame began while a fragmented message was in progress".into(),
));
}
if frame.rsv1 {
if pmd.is_none() {
return Err(Error::BadResponse(
"RSV1 set on a WS frame but permessage-deflate was not negotiated"
.into(),
));
}
compressed = true;
}
accumulate(&mut buf, &frame.payload)?;
if frame.fin {
return finish_data_message(frame.opcode, buf, compressed, pmd.as_deref_mut());
}
frag_opcode = Some(frame.opcode);
}
OPCODE_CONT => {
let opcode = frag_opcode.ok_or_else(|| {
Error::BadResponse("continuation frame with no message in progress".into())
})?;
if frame.rsv1 {
return Err(Error::BadResponse(
"RSV1 set on a WS continuation frame".into(),
));
}
accumulate(&mut buf, &frame.payload)?;
if frame.fin {
return finish_data_message(opcode, buf, compressed, pmd.as_deref_mut());
}
}
other => {
return Err(Error::BadResponse(format!("unknown WS opcode 0x{other:x}")));
}
}
}
}
/// Finalises a reassembled data message.
///
/// Inflates the payload when `compressed` is set and checks that TEXT payloads
/// are valid UTF-8.
///
/// # Errors
/// `Error::BadResponse` if the message is compressed but no permessage-deflate
/// state is available, if inflation fails, or if a TEXT payload is not UTF-8.
fn finish_data_message(
    opcode: u8,
    payload: Vec<u8>,
    compressed: bool,
    pmd: Option<&mut Pmd>,
) -> Result<Message> {
    let payload = match (compressed, pmd) {
        (false, _) => payload,
        (true, Some(state)) => state.inflate_message(&payload)?,
        (true, None) => {
            return Err(Error::BadResponse(
                "compressed WS message without negotiated permessage-deflate".into(),
            ))
        }
    };
    let is_text = opcode == OPCODE_TEXT;
    if is_text && std::str::from_utf8(&payload).is_err() {
        return Err(Error::BadResponse(
            "WS TEXT message payload is not valid UTF-8".into(),
        ));
    }
    Ok(Message::Data { opcode, payload })
}
fn accumulate(buf: &mut Vec<u8>, chunk: &[u8]) -> Result<()> {
let total = buf.len() as u64 + chunk.len() as u64;
if total > MAX_PAYLOAD_BYTES {
return Err(Error::BadResponse(format!(
"reassembled WS message too large: {total} bytes (max {MAX_PAYLOAD_BYTES})"
)));
}
buf.extend_from_slice(chunk);
Ok(())
}
fn validate_control_frame(frame: &Frame) -> Result<()> {
if frame.rsv1 {
return Err(Error::BadResponse("RSV1 set on a WS control frame".into()));
}
if !frame.fin {
return Err(Error::BadResponse(
"fragmented control frame (FIN=0 on a control opcode)".into(),
));
}
if frame.payload.len() > MAX_CONTROL_PAYLOAD {
return Err(Error::BadResponse(format!(
"control frame payload too large: {} bytes (max {MAX_CONTROL_PAYLOAD})",
frame.payload.len()
)));
}
Ok(())
}
/// Decodes a Close frame body into its code and reason. Returns `None` when
/// the body is shorter than the 2-byte status code. The reason is decoded
/// lossily.
fn parse_close_payload(payload: &[u8]) -> Option<WsClose> {
    match payload {
        [hi, lo, rest @ ..] => Some(WsClose {
            code: u16::from_be_bytes([*hi, *lo]),
            reason: String::from_utf8_lossy(rest).into_owned(),
        }),
        _ => None,
    }
}
fn send_message<S: Write>(
stream: &mut S,
opcode: u8,
payload: &[u8],
compress: bool,
) -> Result<()> {
if opcode != OPCODE_TEXT && opcode != OPCODE_BINARY {
return Err(Error::BadResponse(format!(
"send_message expects a data opcode (text/binary), got 0x{opcode:x}"
)));
}
let frame = if compress {
let compressed = deflate_message(payload)?;
build_client_frame_rsv1(opcode, &compressed)?
} else {
build_client_frame(opcode, payload)?
};
stream.write_all(&frame)?;
stream.flush()?;
Ok(())
}
/// Reads the first data message from `stream`, then sends a Close frame
/// (best effort).
///
/// Returns an empty payload if the peer closed first; in that case
/// `read_message` has already answered the peer's Close.
fn read_data_and_close<S: Read + Write>(stream: &mut S, pmd: Option<&mut Pmd>) -> Result<Vec<u8>> {
    let Message::Data { payload, .. } = read_message(stream, pmd)? else {
        return Ok(Vec::new());
    };
    // Write errors are ignored: the payload is already in hand.
    match build_client_frame(OPCODE_CLOSE, &[]) {
        Ok(close) => {
            let _ = stream.write_all(&close);
            let _ = stream.flush();
        }
        Err(_) => {}
    }
    Ok(payload)
}
/// A decoded incoming frame, as produced by `read_frame`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Frame {
    /// FIN bit: final fragment of a message.
    fin: bool,
    /// RSV1 bit: marks a permessage-deflate compressed message.
    rsv1: bool,
    /// Raw 4-bit opcode.
    opcode: u8,
    /// Payload as received (server frames are never masked).
    payload: Vec<u8>,
}
/// Reads one frame from `stream`, using `rx` as the carry-over buffer.
///
/// `rx` may already hold bytes from earlier reads. Only the returned frame's
/// bytes are drained from its front, so anything after it stays buffered for
/// the next call.
///
/// # Errors
/// `Error::BadResponse` for RSV2/RSV3 bits, a masked server frame, or a payload
/// larger than `MAX_PAYLOAD_BYTES`. `fill_to` errors (including EOF) are
/// propagated.
fn read_frame<S: Read>(stream: &mut S, rx: &mut Vec<u8>) -> Result<Frame> {
    fill_to(stream, rx, 2)?;
    let (first, second) = (rx[0], rx[1]);
    if (first & 0x30) != 0 {
        return Err(Error::BadResponse(
            "non-zero RSV2/RSV3 bits on incoming WS frame".into(),
        ));
    }
    // Servers must never mask their frames (RFC 6455 §5.1).
    if (second & 0x80) != 0 {
        return Err(Error::BadResponse(
            "server-to-client frame is masked".into(),
        ));
    }
    let mut header_len = 2usize;
    let declared_len: u64 = match second & 0x7F {
        126 => {
            fill_to(stream, rx, header_len + 2)?;
            let ext = u16::from_be_bytes([rx[header_len], rx[header_len + 1]]);
            header_len += 2;
            u64::from(ext)
        }
        127 => {
            fill_to(stream, rx, header_len + 8)?;
            let mut ext = [0u8; 8];
            ext.copy_from_slice(&rx[header_len..header_len + 8]);
            header_len += 8;
            u64::from_be_bytes(ext)
        }
        short => u64::from(short),
    };
    // Checked before buffering so a hostile length cannot force a huge read.
    if declared_len > MAX_PAYLOAD_BYTES {
        return Err(Error::BadResponse(format!(
            "WS payload too large: {declared_len} bytes"
        )));
    }
    let frame_end = header_len + declared_len as usize;
    fill_to(stream, rx, frame_end)?;
    let payload = rx[header_len..frame_end].to_vec();
    rx.drain(..frame_end);
    Ok(Frame {
        fin: (first & 0x80) != 0,
        rsv1: (first & 0x40) != 0,
        opcode: first & 0x0F,
        payload,
    })
}
fn fill_to<S: Read>(stream: &mut S, rx: &mut Vec<u8>, need: usize) -> Result<()> {
let mut tmp = [0u8; 16 * 1024];
while rx.len() < need {
let n = stream.read(&mut tmp)?;
if n == 0 {
return Err(Error::UnexpectedEof);
}
rx.extend_from_slice(&tmp[..n]);
}
Ok(())
}
/// Builds a final (FIN=1), masked, uncompressed client frame.
fn build_client_frame(opcode: u8, payload: &[u8]) -> Result<Vec<u8>> {
    let (fin, rsv1) = (true, false);
    build_client_frame_inner(fin, opcode, payload, rsv1)
}
/// Builds a final (FIN=1), masked client frame with RSV1 set. This marks a
/// permessage-deflate compressed payload.
fn build_client_frame_rsv1(opcode: u8, payload: &[u8]) -> Result<Vec<u8>> {
    let (fin, rsv1) = (true, true);
    build_client_frame_inner(fin, opcode, payload, rsv1)
}
/// Serialises one client-to-server frame.
///
/// Clients must mask every frame (RFC 6455 §5.3). A fresh 4-byte key is taken
/// from `random_16`, and the payload is XORed with it. The length uses the
/// 7-bit, 16-bit or 64-bit form as required.
///
/// # Errors
/// Propagates `random_16` failures (no entropy source).
fn build_client_frame_inner(fin: bool, opcode: u8, payload: &[u8], rsv1: bool) -> Result<Vec<u8>> {
    let entropy = random_16()?;
    let key = [entropy[0], entropy[1], entropy[2], entropy[3]];
    let len = payload.len();
    let mut frame = Vec::with_capacity(2 + 8 + 4 + len);
    let mut first = opcode & 0x0F;
    if fin {
        first |= 0x80;
    }
    if rsv1 {
        first |= 0x40;
    }
    frame.push(first);
    // 0x80 in the second byte is the MASK bit.
    match len {
        0..=125 => frame.push(0x80 | len as u8),
        126..=0xFFFF => {
            frame.push(0x80 | 126);
            frame.extend_from_slice(&(len as u16).to_be_bytes());
        }
        _ => {
            frame.push(0x80 | 127);
            frame.extend_from_slice(&(len as u64).to_be_bytes());
        }
    }
    frame.extend_from_slice(&key);
    frame.extend(payload.iter().zip(key.iter().cycle()).map(|(b, k)| b ^ k));
    Ok(frame)
}
/// Computes the `Sec-WebSocket-Accept` value for a client key.
///
/// The value is the base64 encoding of the SHA-1 digest of the key followed
/// by the fixed `WS_GUID`.
fn derive_accept(key_b64: &str) -> String {
    let digest = {
        let mut sha = Sha1::new();
        for part in [key_b64.as_bytes(), WS_GUID.as_bytes()] {
            sha.update(part);
        }
        sha.finalize()
    };
    base64_encode(digest.as_ref())
}
fn random_16() -> Result<[u8; 16]> {
use purecrypto::rng::{OsRng, RngCore};
let mut out = [0u8; 16];
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
OsRng.fill_bytes(&mut out);
}))
.map_err(|_| Error::BadResponse("websocket: no secure entropy source available".into()))?;
Ok(out)
}
/// Encodes `input` as standard padded base64 (RFC 4648 alphabet `A-Z a-z 0-9 + /`).
///
/// Every full 3-byte group becomes four characters. A trailing 1-byte group
/// becomes two characters plus `==`, and a trailing 2-byte group becomes three
/// characters plus `=`.
pub(crate) fn base64_encode(input: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Maps the low six bits of `v` to its alphabet character.
    let sextet = |v: u32| ALPHABET[(v & 0x3F) as usize] as char;

    let mut encoded = String::with_capacity(input.len().div_ceil(3) * 4);
    let groups = input.chunks_exact(3);
    let tail = groups.remainder();
    for group in groups {
        let bits = (u32::from(group[0]) << 16) | (u32::from(group[1]) << 8) | u32::from(group[2]);
        encoded.push(sextet(bits >> 18));
        encoded.push(sextet(bits >> 12));
        encoded.push(sextet(bits >> 6));
        encoded.push(sextet(bits));
    }
    match *tail {
        [a] => {
            let bits = u32::from(a) << 16;
            encoded.push(sextet(bits >> 18));
            encoded.push(sextet(bits >> 12));
            encoded.push_str("==");
        }
        [a, b] => {
            let bits = (u32::from(a) << 16) | (u32::from(b) << 8);
            encoded.push(sextet(bits >> 18));
            encoded.push(sextet(bits >> 12));
            encoded.push(sextet(bits >> 6));
            encoded.push('=');
        }
        _ => {}
    }
    encoded
}
// NOTE(review): this empty generic stub only mentions `TlsStream<S>`. It
// appears to exist solely to keep the `crate::tls::TlsStream` import
// referenced ("in scope for docs", per its name). Confirm whether the type is
// used elsewhere in the file before removing either the stub or the import.
#[allow(dead_code)]
fn _tlsstream_in_scope_for_docs<S: Read + Write>(_: TlsStream<S>) {}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use std::sync::atomic::AtomicBool;
    use std::sync::{Arc, Mutex};

    /// In-memory duplex stream: reads are served from `inbound`, and every
    /// write is appended to `sent` so a test can inspect what the client wrote.
    struct MockStream {
        inbound: Cursor<Vec<u8>>,
        sent: Vec<u8>,
    }

    impl MockStream {
        fn new(inbound: Vec<u8>) -> Self {
            MockStream {
                inbound: Cursor::new(inbound),
                sent: Vec::new(),
            }
        }
    }

    impl Read for MockStream {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            self.inbound.read(buf)
        }
    }

    impl Write for MockStream {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.sent.extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    /// Encodes one unmasked server-to-client frame with all RSV bits clear,
    /// choosing the 7-bit, 16-bit or 64-bit length form by payload size.
    fn server_frame(fin: bool, opcode: u8, payload: &[u8]) -> Vec<u8> {
        let mut out = Vec::new();
        let b0 = if fin { 0x80 } else { 0x00 } | (opcode & 0x0F);
        out.push(b0);
        let n = payload.len();
        if n < 126 {
            out.push(n as u8);
        } else if n <= u16::MAX as usize {
            out.push(126);
            out.extend_from_slice(&(n as u16).to_be_bytes());
        } else {
            out.push(127);
            out.extend_from_slice(&(n as u64).to_be_bytes());
        }
        out.extend_from_slice(payload);
        out
    }

    /// Same as [`server_frame`] but with RSV1 (0x40) set, marking the frame as
    /// permessage-deflate compressed.
    fn server_frame_rsv1(fin: bool, opcode: u8, payload: &[u8]) -> Vec<u8> {
        let mut out = server_frame(fin, opcode, payload);
        out[0] |= 0x40;
        out
    }

    /// Deflates `data` and strips a trailing `00 00 FF FF` marker, producing the
    /// wire form that `Pmd::inflate_message` expects (it re-appends the tail).
    fn pmd_compress(data: &[u8]) -> Vec<u8> {
        let mut out = compress_to_vec::<Deflate>(data).expect("deflate encode");
        if out.ends_with(&DEFLATE_TAIL) {
            out.truncate(out.len() - DEFLATE_TAIL.len());
        }
        out
    }

    /// Negotiated permessage-deflate state with no context takeover in either
    /// direction and a fresh decoder.
    fn test_pmd() -> Pmd {
        Pmd {
            client_no_context_takeover: true,
            server_no_context_takeover: true,
            decoder: Deflate::decoder(),
        }
    }

    /// Parses a concatenation of client frames into `(opcode, unmasked payload)`
    /// pairs, asserting that every frame is masked. FIN and RSV bits are ignored.
    fn decode_sent(sent: &[u8]) -> Vec<(u8, Vec<u8>)> {
        let mut frames = Vec::new();
        let mut i = 0;
        while i < sent.len() {
            let opcode = sent[i] & 0x0F;
            let masked = (sent[i + 1] & 0x80) != 0;
            assert!(masked, "client frame must be masked");
            let len7 = sent[i + 1] & 0x7F;
            i += 2;
            let len = match len7 {
                0..=125 => len7 as usize,
                126 => {
                    let l = u16::from_be_bytes([sent[i], sent[i + 1]]) as usize;
                    i += 2;
                    l
                }
                127 => {
                    let mut b = [0u8; 8];
                    b.copy_from_slice(&sent[i..i + 8]);
                    i += 8;
                    u64::from_be_bytes(b) as usize
                }
                _ => unreachable!(),
            };
            let mask = [sent[i], sent[i + 1], sent[i + 2], sent[i + 3]];
            i += 4;
            let mut payload = sent[i..i + len].to_vec();
            i += len;
            for (j, b) in payload.iter_mut().enumerate() {
                *b ^= mask[j & 3];
            }
            frames.push((opcode, payload));
        }
        frames
    }

    /// Reader that sleeps `per_read` and then yields a single `b'X'`. It never
    /// produces a header terminator, so it is used to exercise read deadlines.
    struct DripStream {
        per_read: Duration,
    }

    impl Read for DripStream {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            std::thread::sleep(self.per_read);
            if buf.is_empty() {
                return Ok(0);
            }
            buf[0] = b'X';
            Ok(1)
        }
    }

    #[test]
    fn handshake_head_reads_up_to_terminator_without_overreading() {
        let head = b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\n\r\n";
        let frame = [0x81u8, 0x02, b'h', b'i'];
        let mut inbound = head.to_vec();
        inbound.extend_from_slice(&frame);
        let mut s = MockStream::new(inbound);
        let got = read_handshake_head(&mut s, Duration::from_secs(60)).expect("reads header");
        assert_eq!(
            &got, head,
            "must capture exactly the header, no frame bytes"
        );
        let mut rest = Vec::new();
        s.read_to_end(&mut rest).expect("drain remainder");
        assert_eq!(rest, frame, "first-frame bytes must not be consumed");
    }

    #[test]
    fn handshake_head_deadline_trips_on_slow_drip() {
        let mut s = DripStream {
            per_read: Duration::from_millis(5),
        };
        let err = read_handshake_head(&mut s, Duration::from_millis(20))
            .expect_err("slow drip must hit the deadline");
        match err {
            Error::BadResponse(m) => assert!(m.contains("timed out"), "unexpected message: {m}"),
            other => panic!("wrong error: {other:?}"),
        }
    }

    #[test]
    fn base64_encode_hello() {
        assert_eq!(base64_encode(b"hello"), "aGVsbG8=");
    }

    #[test]
    fn base64_encode_rfc4648_vectors() {
        assert_eq!(base64_encode(b""), "");
        assert_eq!(base64_encode(b"f"), "Zg==");
        assert_eq!(base64_encode(b"fo"), "Zm8=");
        assert_eq!(base64_encode(b"foo"), "Zm9v");
        assert_eq!(base64_encode(b"foob"), "Zm9vYg==");
        assert_eq!(base64_encode(b"fooba"), "Zm9vYmE=");
        assert_eq!(base64_encode(b"foobar"), "Zm9vYmFy");
    }

    #[test]
    fn rfc6455_accept_derivation() {
        let key = "dGhlIHNhbXBsZSBub25jZQ==";
        let expected = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";
        assert_eq!(derive_accept(key), expected);
    }

    #[test]
    fn parse_short_text_frame() {
        let bytes = [0x81, 0x05, b'h', b'e', b'l', b'l', b'o'];
        let mut cur = Cursor::new(&bytes[..]);
        let f = read_frame(&mut cur, &mut Vec::new()).expect("frame parses");
        assert!(f.fin);
        assert_eq!(f.opcode, OPCODE_TEXT);
        assert_eq!(f.payload, b"hello");
    }

    #[test]
    fn parse_16bit_length_frame() {
        let mut bytes: Vec<u8> = vec![0x82, 126, 0x00, 200];
        bytes.extend(std::iter::repeat_n(b'A', 200));
        let mut cur = Cursor::new(bytes);
        let f = read_frame(&mut cur, &mut Vec::new()).expect("frame parses");
        assert_eq!(f.opcode, OPCODE_BINARY);
        assert_eq!(f.payload.len(), 200);
        assert!(f.payload.iter().all(|&b| b == b'A'));
    }

    #[test]
    fn read_frame_resumes_after_midframe_error() {
        /// Yields one byte per read and fails exactly once (with `WouldBlock`)
        /// when `fail_at` bytes have been delivered.
        struct Flaky {
            data: Vec<u8>,
            pos: usize,
            fail_at: usize,
            failed: bool,
        }
        impl Read for Flaky {
            fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
                if !self.failed && self.pos >= self.fail_at {
                    self.failed = true;
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::WouldBlock,
                        "timeout",
                    ));
                }
                if self.pos >= self.data.len() {
                    return Ok(0);
                }
                buf[0] = self.data[self.pos];
                self.pos += 1;
                Ok(1)
            }
        }
        let mut r = Flaky {
            data: vec![0x81, 0x05, b'h', b'e', b'l', b'l', b'o'],
            pos: 0,
            fail_at: 3,
            failed: false,
        };
        let mut rx = Vec::new();
        assert!(read_frame(&mut r, &mut rx).is_err());
        let f = read_frame(&mut r, &mut rx).expect("resumes after the error");
        assert_eq!(f.opcode, OPCODE_TEXT);
        assert_eq!(f.payload, b"hello");
        assert!(rx.is_empty(), "buffer fully consumed");
    }

    #[test]
    fn reject_masked_server_frame() {
        let bytes = [0x81, 0x80, 0, 0, 0, 0];
        let mut cur = Cursor::new(&bytes[..]);
        let err = read_frame(&mut cur, &mut Vec::new())
            .expect_err("masked server frame must be rejected");
        match err {
            Error::BadResponse(_) => {}
            other => panic!("wrong error: {other:?}"),
        }
    }

    #[test]
    fn build_close_frame_short_payload() {
        let frame = build_client_frame(OPCODE_CLOSE, &[]).unwrap();
        assert_eq!(frame.len(), 6);
        assert_eq!(frame[0], 0x88);
        assert_eq!(frame[1], 0x80);
    }

    #[test]
    fn build_text_frame_masks_payload() {
        let payload = b"hi";
        let frame = build_client_frame(OPCODE_TEXT, payload).unwrap();
        assert_eq!(frame.len(), 8);
        assert_eq!(frame[0], 0x81);
        assert_eq!(frame[1], 0x82);
        let mask = [frame[2], frame[3], frame[4], frame[5]];
        let unmasked: Vec<u8> = frame[6..]
            .iter()
            .enumerate()
            .map(|(i, &b)| b ^ mask[i & 3])
            .collect();
        assert_eq!(unmasked, payload);
    }

    #[test]
    fn build_frame_uses_16bit_length_for_medium_payload() {
        let payload = vec![0u8; 200];
        let frame = build_client_frame(OPCODE_BINARY, &payload).unwrap();
        assert_eq!(frame[0], 0x82);
        assert_eq!(frame[1], 0x80 | 126);
        let len = u16::from_be_bytes([frame[2], frame[3]]);
        assert_eq!(len, 200);
        assert_eq!(frame.len(), 208);
    }

    #[test]
    fn build_frame_uses_64bit_length_for_large_payload() {
        let payload = vec![0u8; 70_000];
        let frame = build_client_frame(OPCODE_BINARY, &payload).unwrap();
        assert_eq!(frame[1], 0x80 | 127);
        let len = u64::from_be_bytes([
            frame[2], frame[3], frame[4], frame[5], frame[6], frame[7], frame[8], frame[9],
        ]);
        assert_eq!(len, 70_000);
    }

    #[test]
    fn random_16_is_nonzero() {
        let r = random_16().expect("OS entropy available in the test environment");
        assert_ne!(r, [0u8; 16]);
    }

    #[test]
    fn random_16_is_not_constant() {
        let a = random_16().unwrap();
        let b = random_16().unwrap();
        assert_ne!(a, b);
    }

    #[test]
    fn reassembles_fragmented_text_message() {
        let mut inbound = server_frame(false, OPCODE_TEXT, b"Hel");
        inbound.extend(server_frame(false, OPCODE_CONT, b"lo "));
        inbound.extend(server_frame(true, OPCODE_CONT, b"world"));
        let mut s = MockStream::new(inbound);
        let msg = read_message(&mut s, None).expect("reassembles");
        assert_eq!(
            msg,
            Message::Data {
                opcode: OPCODE_TEXT,
                payload: b"Hello world".to_vec(),
            }
        );
    }

    #[test]
    fn invalid_utf8_text_message_is_rejected() {
        let inbound = server_frame(true, OPCODE_TEXT, &[0xff, 0xfe]);
        let mut s = MockStream::new(inbound);
        let err = read_message(&mut s, None).expect_err("invalid utf-8 TEXT must be rejected");
        match err {
            Error::BadResponse(m) => assert!(m.contains("UTF-8"), "unexpected message: {m}"),
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn invalid_utf8_binary_message_is_accepted() {
        let inbound = server_frame(true, OPCODE_BINARY, &[0xff, 0xfe]);
        let mut s = MockStream::new(inbound);
        let msg = read_message(&mut s, None).expect("binary is not utf-8 validated");
        assert_eq!(
            msg,
            Message::Data {
                opcode: OPCODE_BINARY,
                payload: vec![0xff, 0xfe],
            }
        );
    }

    #[test]
    fn valid_utf8_text_message_passes() {
        // A two-byte UTF-8 sequence ("é") split across fragments must be
        // validated on the reassembled message, not per fragment.
        let mut inbound = server_frame(false, OPCODE_TEXT, &[0xc3]);
        inbound.extend(server_frame(true, OPCODE_CONT, &[0xa9]));
        let mut s = MockStream::new(inbound);
        let msg = read_message(&mut s, None).expect("valid utf-8 across fragments");
        assert_eq!(
            msg,
            Message::Data {
                opcode: OPCODE_TEXT,
                payload: "é".as_bytes().to_vec(),
            }
        );
    }

    #[test]
    fn ping_between_fragments_gets_pong_and_message_completes() {
        let mut inbound = server_frame(false, OPCODE_TEXT, b"foo");
        inbound.extend(server_frame(true, OPCODE_PING, b"pingdata"));
        inbound.extend(server_frame(true, OPCODE_CONT, b"bar"));
        let mut s = MockStream::new(inbound);
        let msg = read_message(&mut s, None).expect("completes despite ping");
        assert_eq!(
            msg,
            Message::Data {
                opcode: OPCODE_TEXT,
                payload: b"foobar".to_vec(),
            }
        );
        let sent = decode_sent(&s.sent);
        assert_eq!(sent.len(), 1, "exactly one pong expected");
        assert_eq!(sent[0].0, OPCODE_PONG);
        assert_eq!(sent[0].1, b"pingdata");
    }

    #[test]
    fn close_is_answered_and_returns_closed() {
        let inbound = server_frame(true, OPCODE_CLOSE, &[]);
        let mut s = MockStream::new(inbound);
        let msg = read_message(&mut s, None).expect("handles close");
        assert_eq!(msg, Message::Closed);
        let sent = decode_sent(&s.sent);
        assert_eq!(sent.len(), 1, "exactly one close reply expected");
        assert_eq!(sent[0].0, OPCODE_CLOSE);
    }

    #[test]
    fn unsolicited_pong_is_ignored_then_data_returns() {
        let mut inbound = server_frame(true, OPCODE_PONG, b"x");
        inbound.extend(server_frame(true, OPCODE_TEXT, b"hi"));
        let mut s = MockStream::new(inbound);
        let msg = read_message(&mut s, None).expect("ignores pong");
        assert_eq!(
            msg,
            Message::Data {
                opcode: OPCODE_TEXT,
                payload: b"hi".to_vec(),
            }
        );
        assert!(s.sent.is_empty(), "unsolicited pong must not be answered");
    }

    #[test]
    fn send_message_produces_masked_frame() {
        let mut s = MockStream::new(Vec::new());
        send_message(&mut s, OPCODE_TEXT, b"hello", false).expect("sends");
        assert_eq!(s.sent[0], 0x81);
        assert_eq!(s.sent[1], 0x80 | 5);
        assert_eq!(s.sent.len(), 2 + 4 + 5);
        let decoded = decode_sent(&s.sent);
        assert_eq!(decoded, vec![(OPCODE_TEXT, b"hello".to_vec())]);
    }

    #[test]
    fn send_message_rejects_control_opcode() {
        let mut s = MockStream::new(Vec::new());
        let err = send_message(&mut s, OPCODE_PING, b"x", false)
            .expect_err("control opcode must be rejected for send_message");
        match err {
            Error::BadResponse(_) => {}
            other => panic!("wrong error: {other:?}"),
        }
        assert!(s.sent.is_empty());
    }

    #[test]
    fn oversized_cumulative_fragmented_payload_is_rejected() {
        let mut buf = vec![0u8; (MAX_PAYLOAD_BYTES - 1) as usize];
        accumulate(&mut buf, &[0u8]).expect("exactly at the cap is allowed");
        let err = accumulate(&mut buf, &[0u8]).expect_err("over the cap must be rejected");
        match err {
            Error::BadResponse(_) => {}
            other => panic!("wrong error: {other:?}"),
        }
    }

    #[test]
    fn fragmented_control_frame_is_rejected() {
        let inbound = server_frame(false, OPCODE_PING, b"x");
        let mut s = MockStream::new(inbound);
        let err = read_message(&mut s, None).expect_err("fragmented control must be rejected");
        match err {
            Error::BadResponse(_) => {}
            other => panic!("wrong error: {other:?}"),
        }
    }

    #[test]
    fn oversized_control_frame_is_rejected() {
        let inbound = server_frame(true, OPCODE_PING, &[0u8; 126]);
        let mut s = MockStream::new(inbound);
        let err = read_message(&mut s, None).expect_err("oversized control must be rejected");
        match err {
            Error::BadResponse(_) => {}
            other => panic!("wrong error: {other:?}"),
        }
    }

    #[test]
    fn new_data_frame_during_fragmentation_is_rejected() {
        let mut inbound = server_frame(false, OPCODE_TEXT, b"a");
        inbound.extend(server_frame(true, OPCODE_TEXT, b"b"));
        let mut s = MockStream::new(inbound);
        let err =
            read_message(&mut s, None).expect_err("interleaved new data frame must be rejected");
        match err {
            Error::BadResponse(_) => {}
            other => panic!("wrong error: {other:?}"),
        }
    }

    #[test]
    fn lone_continuation_frame_is_rejected() {
        let inbound = server_frame(true, OPCODE_CONT, b"x");
        let mut s = MockStream::new(inbound);
        let err = read_message(&mut s, None).expect_err("lone continuation must be rejected");
        match err {
            Error::BadResponse(_) => {}
            other => panic!("wrong error: {other:?}"),
        }
    }

    #[test]
    fn read_data_and_close_returns_reassembled_payload_and_sends_close() {
        let mut inbound = server_frame(false, OPCODE_BINARY, &[1, 2, 3]);
        inbound.extend(server_frame(true, OPCODE_CONT, &[4, 5]));
        let mut s = MockStream::new(inbound);
        let payload = read_data_and_close(&mut s, None).expect("reads message");
        assert_eq!(payload, vec![1, 2, 3, 4, 5]);
        let sent = decode_sent(&s.sent);
        assert_eq!(sent.last().map(|f| f.0), Some(OPCODE_CLOSE));
    }

    /// Decodes only the first client frame in `sent`, returning its opcode,
    /// whether RSV1 is set, and the unmasked payload.
    fn decode_first_frame_with_rsv1(sent: &[u8]) -> (u8, bool, Vec<u8>) {
        let opcode = sent[0] & 0x0F;
        let rsv1 = (sent[0] & 0x40) != 0;
        let masked = (sent[1] & 0x80) != 0;
        assert!(masked, "client frame must be masked");
        let len7 = sent[1] & 0x7F;
        let mut i = 2;
        let len = match len7 {
            0..=125 => len7 as usize,
            126 => {
                let l = u16::from_be_bytes([sent[i], sent[i + 1]]) as usize;
                i += 2;
                l
            }
            127 => {
                let mut b = [0u8; 8];
                b.copy_from_slice(&sent[i..i + 8]);
                i += 8;
                u64::from_be_bytes(b) as usize
            }
            _ => unreachable!(),
        };
        let mask = [sent[i], sent[i + 1], sent[i + 2], sent[i + 3]];
        i += 4;
        let mut payload = sent[i..i + len].to_vec();
        for (j, b) in payload.iter_mut().enumerate() {
            *b ^= mask[j & 3];
        }
        (opcode, rsv1, payload)
    }

    /// Reference inflater independent of `Pmd`: re-appends the sync tail and
    /// decodes with a fresh decoder until the stream ends or no progress is made.
    fn pmd_inflate(compressed: &[u8]) -> Vec<u8> {
        let mut input = compressed.to_vec();
        input.extend_from_slice(&DEFLATE_TAIL);
        let mut dec = Deflate::decoder();
        let mut out = Vec::new();
        let mut scratch = vec![0u8; 32 * 1024];
        let mut consumed = 0usize;
        loop {
            let before_c = consumed;
            let before_w = out.len();
            let (p, status) = dec
                .decode(&input[consumed..], &mut scratch)
                .expect("inflate");
            out.extend_from_slice(&scratch[..p.written]);
            consumed += p.consumed;
            match status {
                Status::StreamEnd => break,
                Status::OutputFull => continue,
                Status::InputEmpty => {
                    if consumed >= input.len() || (consumed == before_c && out.len() == before_w) {
                        break;
                    }
                }
            }
        }
        out
    }

    #[test]
    fn handshake_offers_permessage_deflate() {
        struct Recorder {
            request: Vec<u8>,
            response: Cursor<Vec<u8>>,
        }
        impl Read for Recorder {
            fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
                self.response.read(buf)
            }
        }
        impl Write for Recorder {
            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
                self.request.extend_from_slice(buf);
                Ok(buf.len())
            }
            fn flush(&mut self) -> std::io::Result<()> {
                Ok(())
            }
        }
        let resp = b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: wrong\r\n\r\n".to_vec();
        let mut rec = Recorder {
            request: Vec::new(),
            response: Cursor::new(resp),
        };
        let url = Url::parse("ws://example.com/chat").expect("url");
        // The canned response carries a bogus accept value, so the handshake
        // outcome is ignored; only the request bytes it wrote are inspected.
        let _ = handshake(&mut rec, &url, &[]);
        let req = String::from_utf8(rec.request).expect("utf8 request");
        assert!(
            req.contains("Sec-WebSocket-Extensions: permessage-deflate"),
            "request must offer permessage-deflate, got:\n{req}"
        );
        assert!(req.contains("client_no_context_takeover"));
        assert!(req.contains("server_no_context_takeover"));
    }

    #[test]
    fn parse_pmd_response_enables_compression() {
        let pmd = parse_pmd_response("permessage-deflate; server_no_context_takeover")
            .expect("permessage-deflate accepted");
        assert!(pmd.server_no_context_takeover);
        assert!(!pmd.client_no_context_takeover);
        let pmd2 = parse_pmd_response(
            "permessage-deflate; client_no_context_takeover; server_no_context_takeover",
        )
        .expect("accepted with both flags");
        assert!(pmd2.client_no_context_takeover);
        assert!(pmd2.server_no_context_takeover);
    }

    #[test]
    fn parse_pmd_response_without_extension_is_none() {
        assert!(parse_pmd_response("some-other-extension").is_none());
        assert!(parse_pmd_response("").is_none());
        assert!(parse_pmd_response("foo; bar=1").is_none());
    }

    #[test]
    fn inflate_compressed_message_decodes_to_original() {
        let original = b"the quick brown fox jumps over the lazy dog, the quick brown fox";
        let compressed = pmd_compress(original);
        assert!(
            compressed.len() < original.len(),
            "fixture should actually compress"
        );
        let inbound = server_frame_rsv1(true, OPCODE_TEXT, &compressed);
        let mut s = MockStream::new(inbound);
        let mut pmd = test_pmd();
        let msg = read_message(&mut s, Some(&mut pmd)).expect("decodes compressed message");
        assert_eq!(
            msg,
            Message::Data {
                opcode: OPCODE_TEXT,
                payload: original.to_vec(),
            }
        );
    }

    #[test]
    fn send_message_compressed_round_trips() {
        let input = b"hello hello hello permessage-deflate round trip";
        let mut s = MockStream::new(Vec::new());
        send_message(&mut s, OPCODE_TEXT, input, true).expect("sends compressed");
        let (opcode, rsv1, payload) = decode_first_frame_with_rsv1(&s.sent);
        assert_eq!(opcode, OPCODE_TEXT);
        assert!(rsv1, "compressed frame must have RSV1 set");
        assert_ne!(payload, input, "payload should be compressed, not raw");
        assert_eq!(pmd_inflate(&payload), input);
    }

    #[test]
    fn rsv1_without_negotiation_is_rejected() {
        let inbound = server_frame_rsv1(true, OPCODE_TEXT, b"whatever");
        let mut s = MockStream::new(inbound);
        let err = read_message(&mut s, None).expect_err("RSV1 without PMD must be rejected");
        match err {
            Error::BadResponse(_) => {}
            other => panic!("wrong error: {other:?}"),
        }
    }

    #[test]
    fn rsv2_or_rsv3_always_rejected() {
        let mut inbound = server_frame(true, OPCODE_TEXT, b"x");
        inbound[0] |= 0x20;
        let mut s = MockStream::new(inbound);
        let mut pmd = test_pmd();
        let err = read_message(&mut s, Some(&mut pmd)).expect_err("RSV2 must be rejected");
        assert!(matches!(err, Error::BadResponse(_)));
        let mut inbound = server_frame(true, OPCODE_TEXT, b"x");
        inbound[0] |= 0x10;
        let mut s = MockStream::new(inbound);
        let mut pmd = test_pmd();
        let err = read_message(&mut s, Some(&mut pmd)).expect_err("RSV3 must be rejected");
        assert!(matches!(err, Error::BadResponse(_)));
    }

    #[test]
    fn rsv1_on_control_frame_is_rejected() {
        let inbound = server_frame_rsv1(true, OPCODE_PING, b"x");
        let mut s = MockStream::new(inbound);
        let mut pmd = test_pmd();
        let err =
            read_message(&mut s, Some(&mut pmd)).expect_err("RSV1 on control must be rejected");
        assert!(matches!(err, Error::BadResponse(_)));
    }

    #[test]
    fn compressed_bomb_exceeding_cap_is_rejected() {
        let huge = vec![0u8; (MAX_PAYLOAD_BYTES + (1 << 20)) as usize];
        let compressed = pmd_compress(&huge);
        assert!(
            (compressed.len() as u64) < MAX_PAYLOAD_BYTES,
            "fixture must be much smaller than the cap"
        );
        let inbound = server_frame_rsv1(true, OPCODE_BINARY, &compressed);
        let mut s = MockStream::new(inbound);
        let mut pmd = test_pmd();
        let err =
            read_message(&mut s, Some(&mut pmd)).expect_err("compression bomb must be rejected");
        match err {
            Error::BadResponse(msg) => {
                assert!(msg.contains("permessage-deflate"), "got {msg:?}")
            }
            other => panic!("wrong error: {other:?}"),
        }
    }

    #[test]
    fn fragmented_compressed_message_reassembles_and_inflates() {
        let original = b"fragmented compressed payload, split across two frames on the wire";
        let compressed = pmd_compress(original);
        assert!(compressed.len() >= 4, "need enough bytes to split");
        let mid = compressed.len() / 2;
        let mut inbound = server_frame_rsv1(false, OPCODE_TEXT, &compressed[..mid]);
        inbound.extend(server_frame(true, OPCODE_CONT, &compressed[mid..]));
        let mut s = MockStream::new(inbound);
        let mut pmd = test_pmd();
        let msg = read_message(&mut s, Some(&mut pmd)).expect("reassembles + inflates");
        assert_eq!(
            msg,
            Message::Data {
                opcode: OPCODE_TEXT,
                payload: original.to_vec(),
            }
        );
    }

    #[test]
    fn rsv1_on_continuation_frame_is_rejected() {
        let original = b"continuation rsv1 should be rejected here";
        let compressed = pmd_compress(original);
        let mid = compressed.len() / 2;
        let mut inbound = server_frame_rsv1(false, OPCODE_TEXT, &compressed[..mid]);
        inbound.extend(server_frame_rsv1(true, OPCODE_CONT, &compressed[mid..]));
        let mut s = MockStream::new(inbound);
        let mut pmd = test_pmd();
        let err = read_message(&mut s, Some(&mut pmd))
            .expect_err("RSV1 on continuation must be rejected");
        assert!(matches!(err, Error::BadResponse(_)));
    }

    #[test]
    fn uncompressed_message_passes_through_when_pmd_negotiated() {
        let inbound = server_frame(true, OPCODE_TEXT, b"plain text, no rsv1");
        let mut s = MockStream::new(inbound);
        let mut pmd = test_pmd();
        let msg = read_message(&mut s, Some(&mut pmd)).expect("plain message");
        assert_eq!(
            msg,
            Message::Data {
                opcode: OPCODE_TEXT,
                payload: b"plain text, no rsv1".to_vec(),
            }
        );
    }

    /// Clonable mock stream whose buffers live behind `Arc<Mutex<_>>`. A test
    /// keeps one clone to inspect writes after moving another into a `WebSocket`.
    #[derive(Clone)]
    struct SharedMock {
        inbound: Arc<Mutex<Cursor<Vec<u8>>>>,
        sent: Arc<Mutex<Vec<u8>>>,
    }

    impl SharedMock {
        fn new(inbound: Vec<u8>) -> Self {
            SharedMock {
                inbound: Arc::new(Mutex::new(Cursor::new(inbound))),
                sent: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Snapshot of every byte written so far.
        fn sent(&self) -> Vec<u8> {
            self.sent.lock().unwrap().clone()
        }
    }

    impl Read for SharedMock {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            self.inbound.lock().unwrap().read(buf)
        }
    }

    impl Write for SharedMock {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.sent.lock().unwrap().extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    /// Assembles a `WebSocket` directly over `mock`, skipping the handshake.
    /// Compression is enabled on both halves exactly when `pmd` is `Some`.
    fn ws_over(mock: SharedMock, pmd: Option<Pmd>) -> WebSocket {
        let compression = pmd.is_some();
        let transport = Arc::new(WsTransport::new(WsTransportKind::Shared(Mutex::new(
            Box::new(mock),
        ))));
        let send_closed = Arc::new(AtomicBool::new(false));
        let recv_closed = Arc::new(AtomicBool::new(false));
        WebSocket {
            reader: WsReader {
                transport: Arc::clone(&transport),
                rxbuf: Vec::new(),
                pmd,
                compression,
                auto_pong: true,
                send_closed: Arc::clone(&send_closed),
                recv_closed: Arc::clone(&recv_closed),
                shutdown: None,
            },
            writer: WsWriter {
                transport,
                compress: compression,
                send_closed,
                recv_closed,
                shutdown: None,
            },
            subprotocol: None,
        }
    }

    #[test]
    fn websocket_recv_maps_text_binary_and_close() {
        let mut inbound = Vec::new();
        inbound.extend(server_frame(true, OPCODE_TEXT, b"hello"));
        inbound.extend(server_frame(true, OPCODE_BINARY, &[1, 2, 3]));
        inbound.extend(server_frame(true, OPCODE_CLOSE, &[]));
        let mut ws = ws_over(SharedMock::new(inbound), None);
        assert_eq!(ws.recv().unwrap(), Some(WsMessage::Text("hello".into())));
        assert_eq!(ws.recv().unwrap(), Some(WsMessage::Binary(vec![1, 2, 3])));
        assert_eq!(ws.recv().unwrap(), None);
        assert!(ws.is_closed());
        assert_eq!(ws.recv().unwrap(), None);
    }

    #[test]
    fn websocket_send_text_writes_one_masked_text_frame() {
        let mock = SharedMock::new(Vec::new());
        let mut ws = ws_over(mock.clone(), None);
        ws.send_text("hi there").unwrap();
        let frames = decode_sent(&mock.sent());
        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].0, OPCODE_TEXT);
        assert_eq!(frames[0].1, b"hi there");
    }

    #[test]
    fn websocket_close_sends_close_frame_and_blocks_further_sends() {
        let mock = SharedMock::new(Vec::new());
        let mut ws = ws_over(mock.clone(), None);
        ws.close().unwrap();
        let frames = decode_sent(&mock.sent());
        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].0, OPCODE_CLOSE);
        assert!(ws.is_closed());
        assert!(ws.send_text("nope").is_err());
        ws.close().unwrap();
    }

    #[test]
    fn websocket_recv_autoresponds_ping_with_pong() {
        let mut inbound = Vec::new();
        inbound.extend(server_frame(true, OPCODE_PING, b"pingdata"));
        inbound.extend(server_frame(true, OPCODE_TEXT, b"after ping"));
        let mock = SharedMock::new(inbound);
        let mut ws = ws_over(mock.clone(), None);
        assert_eq!(
            ws.recv().unwrap(),
            Some(WsMessage::Text("after ping".into()))
        );
        let frames = decode_sent(&mock.sent());
        assert!(
            frames
                .iter()
                .any(|(op, pl)| *op == OPCODE_PONG && pl.as_slice() == b"pingdata"),
            "expected an automatic PONG echoing the ping data, got {frames:?}"
        );
    }

    #[test]
    fn websocket_recv_inflates_compressed_message() {
        let payload = "compress me ".repeat(8);
        let inbound = server_frame_rsv1(true, OPCODE_TEXT, &pmd_compress(payload.as_bytes()));
        let mut ws = ws_over(SharedMock::new(inbound), Some(test_pmd()));
        assert!(ws.compression_enabled());
        assert_eq!(ws.recv().unwrap(), Some(WsMessage::Text(payload)));
    }

    #[test]
    fn websocket_recv_event_surfaces_ping_pong_and_close_code() {
        let mut inbound = Vec::new();
        inbound.extend(server_frame(true, OPCODE_PING, b"pp"));
        inbound.extend(server_frame(true, OPCODE_PONG, b"qq"));
        let mut close_payload = 1001u16.to_be_bytes().to_vec();
        close_payload.extend_from_slice(b"bye");
        inbound.extend(server_frame(true, OPCODE_CLOSE, &close_payload));
        let mock = SharedMock::new(inbound);
        let mut ws = ws_over(mock.clone(), None);
        assert_eq!(ws.recv_event().unwrap(), WsEvent::Ping(b"pp".to_vec()));
        assert_eq!(ws.recv_event().unwrap(), WsEvent::Pong(b"qq".to_vec()));
        assert_eq!(
            ws.recv_event().unwrap(),
            WsEvent::Close(Some(WsClose {
                code: 1001,
                reason: "bye".to_string(),
            }))
        );
        assert!(ws.is_closed());
        let frames = decode_sent(&mock.sent());
        assert!(frames
            .iter()
            .any(|(op, pl)| *op == OPCODE_PONG && pl == b"pp"));
    }

    #[test]
    fn websocket_no_autopong_when_disabled() {
        let inbound = server_frame(true, OPCODE_PING, b"hi");
        let mock = SharedMock::new(inbound);
        let mut ws = ws_over(mock.clone(), None);
        ws.set_auto_pong(false);
        assert_eq!(ws.recv_event().unwrap(), WsEvent::Ping(b"hi".to_vec()));
        let frames = decode_sent(&mock.sent());
        assert!(
            !frames.iter().any(|(op, _)| *op == OPCODE_PONG),
            "auto-pong was disabled but a PONG was sent: {frames:?}"
        );
    }

    #[test]
    fn websocket_send_pong_and_close_with_write_expected_frames() {
        let mock = SharedMock::new(Vec::new());
        let mut ws = ws_over(mock.clone(), None);
        ws.send_pong(b"keepalive").unwrap();
        ws.close_with(1000, "done").unwrap();
        let frames = decode_sent(&mock.sent());
        assert_eq!(frames.len(), 2);
        assert_eq!(frames[0].0, OPCODE_PONG);
        assert_eq!(frames[0].1, b"keepalive");
        assert_eq!(frames[1].0, OPCODE_CLOSE);
        let mut expected = 1000u16.to_be_bytes().to_vec();
        expected.extend_from_slice(b"done");
        assert_eq!(frames[1].1, expected);
        assert!(ws.is_closed());
    }

    #[test]
    fn websocket_send_frame_fragments_a_message() {
        let mock = SharedMock::new(Vec::new());
        let mut ws = ws_over(mock.clone(), None);
        ws.send_frame(false, WsOpcode::Text, b"ab").unwrap();
        ws.send_frame(true, WsOpcode::Continuation, b"cd").unwrap();
        let raw = mock.sent();
        assert_eq!(raw[0] & 0x80, 0x00, "first fragment must have FIN=0");
        let frames = decode_sent(&raw);
        assert_eq!(frames.len(), 2);
        assert_eq!(frames[0].0, OPCODE_TEXT);
        assert_eq!(frames[0].1, b"ab");
        assert_eq!(frames[1].0, OPCODE_CONT);
        assert_eq!(frames[1].1, b"cd");
        assert!(ws.send_frame(false, WsOpcode::Ping, b"x").is_err());
    }

    #[test]
    fn websocket_recv_frame_returns_raw_frame() {
        let inbound = server_frame(true, OPCODE_TEXT, b"raw");
        let mut ws = ws_over(SharedMock::new(inbound), None);
        let f = ws.recv_frame().unwrap();
        assert_eq!(
            f,
            WsFrame {
                fin: true,
                opcode: WsOpcode::Text,
                payload: b"raw".to_vec(),
            }
        );
    }
}