use std::io::{Read, Write};
use std::net::TcpStream;
use std::time::Duration;
use compcol::deflate::Deflate;
use compcol::limit::LimitedDecoder;
use compcol::vec::compress_to_vec;
use compcol::{Algorithm, Decoder, Status};
use purecrypto::hash::{Digest, Sha1};
use crate::error::{Error, Result};
use crate::tls::TlsStream;
use crate::url::Url;
/// GUID concatenated with the client key before hashing (RFC 6455 §1.3 / §4.2.2).
const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Empty stored-block trailer (`00 00 FF FF`) that permessage-deflate strips from
/// outgoing messages and re-appends before inflating incoming ones.
const DEFLATE_TAIL: [u8; 4] = [0, 0, 0xFF, 0xFF];

// Frame opcodes: data (0x0-0x2) and control (0x8-0xA).
const OPCODE_CONT: u8 = 0x00;
const OPCODE_TEXT: u8 = 0x01;
const OPCODE_BINARY: u8 = 0x02;
const OPCODE_CLOSE: u8 = 0x08;
const OPCODE_PING: u8 = 0x09;
const OPCODE_PONG: u8 = 0x0A;

/// Upper bound (64 MiB) on a single frame, a reassembled message, and an inflated message.
const MAX_PAYLOAD_BYTES: u64 = 64 << 20;

/// Value sent in the request's `Sec-WebSocket-Extensions` header.
const PMD_OFFER: &str = concat!(
    "permessage-deflate; ",
    "client_no_context_takeover; ",
    "server_no_context_takeover; ",
    "client_max_window_bits"
);
/// Negotiated permessage-deflate state for one connection.
///
/// Built by `parse_pmd_response` from the server's `Sec-WebSocket-Extensions`
/// header and used by `Pmd::inflate_message` to decompress incoming messages.
struct Pmd {
// Only read by the unit tests; non-test code records the flag but never consults it,
// so the lint is silenced rather than dropping the negotiated value.
#[allow(dead_code)]
client_no_context_takeover: bool,
// When true, the decoder is reset before every message (see `inflate_message`).
server_no_context_takeover: bool,
// Inflate state; persists across messages unless `server_no_context_takeover` is set.
decoder: <Deflate as Algorithm>::Decoder,
}
impl Pmd {
    /// Inflates one complete (reassembled) permessage-deflate message.
    ///
    /// The `00 00 FF FF` tail that the sender stripped is re-appended before decoding.
    /// Output is capped at `MAX_PAYLOAD_BYTES` through a `LimitedDecoder`. The
    /// decoder is always put back into `self`, even when inflation fails.
    ///
    /// # Errors
    /// `Error::BadResponse` when the decoder reports an error, including exceeding the cap.
    fn inflate_message(&mut self, compressed: &[u8]) -> Result<Vec<u8>> {
        // Without context takeover every message starts from an empty window.
        if self.server_no_context_takeover {
            self.decoder.reset();
        }
        let framed: Vec<u8> = compressed
            .iter()
            .chain(DEFLATE_TAIL.iter())
            .copied()
            .collect();
        // Temporarily move the decoder into the size-limiting wrapper.
        let mut limited = LimitedDecoder::new(
            std::mem::replace(&mut self.decoder, Deflate::decoder()),
            MAX_PAYLOAD_BYTES,
        );
        let outcome = Self::run_inflate(&mut limited, &framed);
        self.decoder = limited.into_inner();
        outcome
    }

    /// Drives `limited` over `input` until the stream ends, the input is exhausted,
    /// or a call makes no progress at all.
    fn run_inflate(
        limited: &mut LimitedDecoder<<Deflate as Algorithm>::Decoder>,
        input: &[u8],
    ) -> Result<Vec<u8>> {
        let mut inflated: Vec<u8> = Vec::new();
        let mut chunk = vec![0u8; 64 * 1024];
        let mut pos = 0usize;
        loop {
            let (progress, status) = limited.decode(&input[pos..], &mut chunk).map_err(|e| {
                Error::BadResponse(format!("permessage-deflate inflate failed: {e}"))
            })?;
            inflated.extend_from_slice(&chunk[..progress.written]);
            pos += progress.consumed;
            // A call that neither consumes nor produces anything would spin forever.
            let stalled = progress.consumed == 0 && progress.written == 0;
            let finished = match status {
                Status::StreamEnd => true,
                Status::OutputFull => false,
                Status::InputEmpty => pos >= input.len() || stalled,
            };
            if finished {
                return Ok(inflated);
            }
        }
    }
}
fn deflate_message(payload: &[u8]) -> Result<Vec<u8>> {
let mut out = compress_to_vec::<Deflate>(payload)
.map_err(|e| Error::BadResponse(format!("permessage-deflate deflate failed: {e}")))?;
if out.ends_with(&DEFLATE_TAIL) {
out.truncate(out.len() - DEFLATE_TAIL.len());
}
Ok(out)
}
pub fn fetch(url: &Url) -> Result<Vec<u8>> {
match url.scheme.as_str() {
"ws" => {
let mut sock = tcp_connect(url)?;
let mut pmd = handshake(&mut sock, url)?;
read_data_and_close(&mut sock, pmd.as_mut())
}
"wss" => {
let tcp = tcp_connect(url)?;
let mut tls = crate::tls::connect_over(tcp, &url.host)?;
let mut pmd = handshake(&mut tls, url)?;
read_data_and_close(&mut tls, pmd.as_mut())
}
other => Err(Error::UnsupportedScheme(other.to_string())),
}
}
fn tcp_connect(url: &Url) -> Result<TcpStream> {
let addr = format!("{}:{}", url.host, url.port);
let addrs: Vec<_> = std::net::ToSocketAddrs::to_socket_addrs(&addr)?.collect();
let first = addrs
.into_iter()
.next()
.ok_or_else(|| Error::InvalidUrl(url.host.clone()))?;
let stream = TcpStream::connect_timeout(&first, Duration::from_secs(30))?;
stream.set_read_timeout(Some(Duration::from_secs(60)))?;
stream.set_write_timeout(Some(Duration::from_secs(60)))?;
Ok(stream)
}
/// Performs the RFC 6455 client opening handshake over `stream`.
///
/// Sends the upgrade request, which offers permessage-deflate via `PMD_OFFER`. It then
/// reads the response head, checks the 101 status line, the `Upgrade` and `Connection`
/// headers, and that `Sec-WebSocket-Accept` matches the key sent.
///
/// Returns `Some(Pmd)` when the server accepted permessage-deflate, else `None`.
///
/// # Errors
/// - `Error::BadResponse` for any validation failure or an oversized or non-UTF-8 head.
/// - `Error::UnexpectedEof` if the stream closes early.
/// - I/O and entropy errors are propagated.
fn handshake<S: Read + Write>(stream: &mut S, url: &Url) -> Result<Option<Pmd>> {
    let key_bytes: [u8; 16] = random_16()?;
    let key_b64 = base64_encode(&key_bytes);
    // The port is omitted from Host only when it is the scheme's default.
    let host_header =
        if (url.scheme == "ws" && url.port == 80) || (url.scheme == "wss" && url.port == 443) {
            url.host.clone()
        } else {
            format!("{}:{}", url.host, url.port)
        };
    let path = if url.path.is_empty() {
        "/"
    } else {
        url.path.as_str()
    };
    let req = format!(
        "GET {path} HTTP/1.1\r\n\
         Host: {host_header}\r\n\
         Upgrade: websocket\r\n\
         Connection: Upgrade\r\n\
         Sec-WebSocket-Key: {key_b64}\r\n\
         Sec-WebSocket-Version: 13\r\n\
         Sec-WebSocket-Extensions: {PMD_OFFER}\r\n\
         \r\n"
    );
    stream.write_all(req.as_bytes())?;
    stream.flush()?;
    let buf = read_handshake_head(stream)?;
    let head = std::str::from_utf8(&buf)
        .map_err(|_| Error::BadResponse("non-utf8 handshake response".into()))?;
    let mut lines = head.split("\r\n");
    let status_line = lines
        .next()
        .ok_or_else(|| Error::BadResponse("empty handshake response".into()))?;
    if !(status_line.starts_with("HTTP/1.1 101") || status_line.starts_with("HTTP/1.0 101")) {
        return Err(Error::BadResponse(format!(
            "expected 101 Switching Protocols, got: {status_line:?}"
        )));
    }
    let mut upgrade_ok = false;
    let mut connection_ok = false;
    let mut accept_value: Option<String> = None;
    let mut extensions_value: Option<String> = None;
    for line in lines {
        if line.is_empty() {
            break;
        }
        let (k, v) = match line.split_once(':') {
            Some((k, v)) => (k.trim(), v.trim()),
            None => continue,
        };
        if k.eq_ignore_ascii_case("upgrade") {
            if v.eq_ignore_ascii_case("websocket") {
                upgrade_ok = true;
            }
        } else if k.eq_ignore_ascii_case("connection") {
            // Connection is a comma-separated token list; any "upgrade" token suffices.
            if v.split(',')
                .any(|t| t.trim().eq_ignore_ascii_case("upgrade"))
            {
                connection_ok = true;
            }
        } else if k.eq_ignore_ascii_case("sec-websocket-accept") {
            accept_value = Some(v.to_string());
        } else if k.eq_ignore_ascii_case("sec-websocket-extensions") {
            // Repeated extension headers are merged as one comma-separated list.
            match &mut extensions_value {
                Some(existing) => {
                    existing.push_str(", ");
                    existing.push_str(v);
                }
                None => extensions_value = Some(v.to_string()),
            }
        }
    }
    if !upgrade_ok {
        return Err(Error::BadResponse(
            "missing or wrong Upgrade header in handshake response".into(),
        ));
    }
    if !connection_ok {
        return Err(Error::BadResponse(
            "missing or wrong Connection header in handshake response".into(),
        ));
    }
    let accept = accept_value
        .ok_or_else(|| Error::BadResponse("missing Sec-WebSocket-Accept header".into()))?;
    let expected = derive_accept(&key_b64);
    if accept != expected {
        return Err(Error::BadResponse(format!(
            "Sec-WebSocket-Accept mismatch: got {accept:?}, expected {expected:?}"
        )));
    }
    let pmd = extensions_value.as_deref().and_then(parse_pmd_response);
    Ok(pmd)
}

/// Reads the HTTP response head up to and including the blank line (`\r\n\r\n`).
///
/// Bytes are read one at a time so that no WebSocket frame bytes following the
/// head are consumed. `ErrorKind::Interrupted` is retried rather than treated as
/// a failure. The head is limited to 64 KiB.
fn read_handshake_head<S: Read>(stream: &mut S) -> Result<Vec<u8>> {
    let mut buf: Vec<u8> = Vec::with_capacity(512);
    let mut byte = [0u8; 1];
    loop {
        let n = match stream.read(&mut byte) {
            Ok(n) => n,
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e.into()),
        };
        if n == 0 {
            return Err(Error::UnexpectedEof);
        }
        buf.push(byte[0]);
        if buf.ends_with(b"\r\n\r\n") {
            return Ok(buf);
        }
        if buf.len() > 64 * 1024 {
            return Err(Error::BadResponse("handshake response too large".into()));
        }
    }
}
/// Extracts the first `permessage-deflate` entry from a `Sec-WebSocket-Extensions` value.
///
/// Only the `client_no_context_takeover` and `server_no_context_takeover` flags are
/// recorded. Other parameters are ignored, and any `=value` part is disregarded.
/// Returns `None` when no entry names permessage-deflate.
fn parse_pmd_response(value: &str) -> Option<Pmd> {
    value.split(',').find_map(|extension| {
        let mut parts = extension.split(';').map(str::trim);
        let name = parts.next()?;
        if !name.eq_ignore_ascii_case("permessage-deflate") {
            return None;
        }
        let (mut client_flag, mut server_flag) = (false, false);
        for param in parts.filter(|p| !p.is_empty()) {
            let token = param.split_once('=').map_or(param, |(key, _)| key).trim();
            if token.eq_ignore_ascii_case("client_no_context_takeover") {
                client_flag = true;
            } else if token.eq_ignore_ascii_case("server_no_context_takeover") {
                server_flag = true;
            }
        }
        Some(Pmd {
            client_no_context_takeover: client_flag,
            server_no_context_takeover: server_flag,
            decoder: Deflate::decoder(),
        })
    })
}
/// Outcome of reading one logical message with `read_message`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Message {
// A complete text or binary message. `opcode` is that of the first frame; `payload` is
// reassembled and, if it was compressed, inflated.
Data { opcode: u8, payload: Vec<u8> },
// The server sent a Close frame (a Close reply has already been attempted).
Closed,
}
/// Maximum payload length of a control frame (ping/pong/close); larger ones are rejected.
const MAX_CONTROL_PAYLOAD: usize = 125;
/// Reads frames from `stream` until one complete data message or a Close frame arrives.
///
/// Control frames (opcode >= 0x8) may appear between fragments:
/// - Each must have RSV1 clear, FIN set and a payload of at most `MAX_CONTROL_PAYLOAD` bytes.
/// - A Ping is answered with a Pong that echoes its payload.
/// - A Pong is ignored.
/// - A Close is answered, best-effort, with an empty Close and yields `Message::Closed`.
///
/// Data messages start with a text/binary frame and continue with continuation frames.
/// RSV1 on the first frame marks the message as compressed, which requires a negotiated
/// `pmd`; RSV1 on a continuation frame is rejected. The accumulated size is capped by
/// `accumulate`.
///
/// # Errors
/// `Error::BadResponse` for any protocol violation; I/O errors are propagated.
fn read_message<S: Read + Write>(stream: &mut S, mut pmd: Option<&mut Pmd>) -> Result<Message> {
// Opcode of the fragmented message in progress, if any.
let mut frag_opcode: Option<u8> = None;
// Set from the RSV1 bit of the message's first frame.
let mut compressed = false;
let mut buf: Vec<u8> = Vec::new();
loop {
let frame = read_frame(stream)?;
// Control frames are handled immediately and never touch `buf`.
if frame.opcode >= 0x8 {
if frame.rsv1 {
return Err(Error::BadResponse("RSV1 set on a WS control frame".into()));
}
if !frame.fin {
return Err(Error::BadResponse(
"fragmented control frame (FIN=0 on a control opcode)".into(),
));
}
if frame.payload.len() > MAX_CONTROL_PAYLOAD {
return Err(Error::BadResponse(format!(
"control frame payload too large: {} bytes (max {MAX_CONTROL_PAYLOAD})",
frame.payload.len()
)));
}
match frame.opcode {
OPCODE_PING => {
let pong = build_client_frame(OPCODE_PONG, &frame.payload)?;
stream.write_all(&pong)?;
stream.flush()?;
continue;
}
OPCODE_PONG => continue,
OPCODE_CLOSE => {
// Best-effort reply: write errors are ignored because the peer is closing anyway.
if let Ok(close) = build_client_frame(OPCODE_CLOSE, &[]) {
let _ = stream.write_all(&close);
let _ = stream.flush();
}
return Ok(Message::Closed);
}
other => {
return Err(Error::BadResponse(format!(
"unknown WS control opcode 0x{other:x}"
)));
}
}
}
match frame.opcode {
OPCODE_TEXT | OPCODE_BINARY => {
// A new data message may not interleave with an unfinished fragmented one.
if frag_opcode.is_some() {
return Err(Error::BadResponse(
"new data frame began while a fragmented message was in progress".into(),
));
}
if frame.rsv1 {
if pmd.is_none() {
return Err(Error::BadResponse(
"RSV1 set on a WS frame but permessage-deflate was not negotiated"
.into(),
));
}
compressed = true;
}
accumulate(&mut buf, &frame.payload)?;
if frame.fin {
return finish_data_message(frame.opcode, buf, compressed, pmd.as_deref_mut());
}
frag_opcode = Some(frame.opcode);
}
OPCODE_CONT => {
// The message keeps the opcode of its first frame.
let opcode = frag_opcode.ok_or_else(|| {
Error::BadResponse("continuation frame with no message in progress".into())
})?;
// Per permessage-deflate, RSV1 is only meaningful on a message's first frame.
if frame.rsv1 {
return Err(Error::BadResponse(
"RSV1 set on a WS continuation frame".into(),
));
}
accumulate(&mut buf, &frame.payload)?;
if frame.fin {
return finish_data_message(opcode, buf, compressed, pmd.as_deref_mut());
}
}
other => {
return Err(Error::BadResponse(format!("unknown WS opcode 0x{other:x}")));
}
}
}
}
fn finish_data_message(
opcode: u8,
payload: Vec<u8>,
compressed: bool,
pmd: Option<&mut Pmd>,
) -> Result<Message> {
if compressed {
let pmd = pmd.ok_or_else(|| {
Error::BadResponse("compressed WS message without negotiated permessage-deflate".into())
})?;
let inflated = pmd.inflate_message(&payload)?;
Ok(Message::Data {
opcode,
payload: inflated,
})
} else {
Ok(Message::Data { opcode, payload })
}
}
fn accumulate(buf: &mut Vec<u8>, chunk: &[u8]) -> Result<()> {
let total = buf.len() as u64 + chunk.len() as u64;
if total > MAX_PAYLOAD_BYTES {
return Err(Error::BadResponse(format!(
"reassembled WS message too large: {total} bytes (max {MAX_PAYLOAD_BYTES})"
)));
}
buf.extend_from_slice(chunk);
Ok(())
}
#[allow(dead_code)]
fn send_message<S: Write>(
stream: &mut S,
opcode: u8,
payload: &[u8],
pmd: Option<&mut Pmd>,
) -> Result<()> {
if opcode != OPCODE_TEXT && opcode != OPCODE_BINARY {
return Err(Error::BadResponse(format!(
"send_message expects a data opcode (text/binary), got 0x{opcode:x}"
)));
}
let frame = if pmd.is_some() {
let compressed = deflate_message(payload)?;
build_client_frame_rsv1(opcode, &compressed)?
} else {
build_client_frame(opcode, payload)?
};
stream.write_all(&frame)?;
stream.flush()?;
Ok(())
}
/// Reads the first data message, then sends a Close frame and returns the payload.
///
/// If the server closes first, an empty vector is returned. In that case
/// `read_message` has already replied to the Close.
fn read_data_and_close<S: Read + Write>(stream: &mut S, pmd: Option<&mut Pmd>) -> Result<Vec<u8>> {
    let Message::Data { payload, .. } = read_message(stream, pmd)? else {
        return Ok(Vec::new());
    };
    if let Ok(close_frame) = build_client_frame(OPCODE_CLOSE, &[]) {
        // Best-effort: the payload is already in hand, so close-write errors are ignored.
        let _ = stream.write_all(&close_frame);
        let _ = stream.flush();
    }
    Ok(payload)
}
/// One decoded server-to-client frame as produced by `read_frame`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Frame {
// FIN bit: last frame of a message.
fin: bool,
// RSV1 bit: marks a compressed message when permessage-deflate is negotiated.
rsv1: bool,
// Low four bits of the first header byte.
opcode: u8,
// Unmasked payload bytes.
payload: Vec<u8>,
}
/// Reads and decodes one server-to-client frame.
///
/// The 7-bit, 16-bit and 64-bit payload length encodings are supported.
///
/// # Errors
/// `Error::BadResponse` is returned when:
/// - RSV2 or RSV3 is set;
/// - the frame is masked (servers must not mask);
/// - the length exceeds `MAX_PAYLOAD_BYTES`.
///
/// EOF and I/O errors from `read_exact` are propagated.
fn read_frame<S: Read>(stream: &mut S) -> Result<Frame> {
    let mut head = [0u8; 2];
    read_exact(stream, &mut head)?;
    let [first, second] = head;
    if (first & 0x30) != 0 {
        return Err(Error::BadResponse(
            "non-zero RSV2/RSV3 bits on incoming WS frame".into(),
        ));
    }
    if (second & 0x80) != 0 {
        return Err(Error::BadResponse(
            "server-to-client frame is masked".into(),
        ));
    }
    let payload_len: u64 = match second & 0x7F {
        126 => {
            let mut ext = [0u8; 2];
            read_exact(stream, &mut ext)?;
            u64::from(u16::from_be_bytes(ext))
        }
        127 => {
            let mut ext = [0u8; 8];
            read_exact(stream, &mut ext)?;
            u64::from_be_bytes(ext)
        }
        short => u64::from(short),
    };
    if payload_len > MAX_PAYLOAD_BYTES {
        return Err(Error::BadResponse(format!(
            "WS payload too large: {payload_len} bytes"
        )));
    }
    // Bounded by MAX_PAYLOAD_BYTES above, so the cast cannot truncate on 32/64-bit targets.
    let mut payload = vec![0u8; payload_len as usize];
    if !payload.is_empty() {
        read_exact(stream, &mut payload)?;
    }
    Ok(Frame {
        fin: (first & 0x80) != 0,
        rsv1: (first & 0x40) != 0,
        opcode: first & 0x0F,
        payload,
    })
}
/// Builds a final, masked client frame with RSV1 clear.
fn build_client_frame(opcode: u8, payload: &[u8]) -> Result<Vec<u8>> {
    let rsv1 = false;
    build_client_frame_inner(opcode, payload, rsv1)
}
/// Builds a final, masked client frame with RSV1 set (permessage-deflate compressed).
fn build_client_frame_rsv1(opcode: u8, payload: &[u8]) -> Result<Vec<u8>> {
    let rsv1 = true;
    build_client_frame_inner(opcode, payload, rsv1)
}
/// Encodes one client-to-server frame: FIN always set, optional RSV1, a fresh random
/// 4-byte mask, and the shortest of the 7/16/64-bit length encodings.
///
/// # Errors
/// Propagates the error from `random_16` when no entropy source is available.
fn build_client_frame_inner(opcode: u8, payload: &[u8], rsv1: bool) -> Result<Vec<u8>> {
    let entropy = random_16()?;
    let mask = [entropy[0], entropy[1], entropy[2], entropy[3]];
    let len = payload.len();
    let mut frame = Vec::with_capacity(2 + 8 + 4 + len);
    frame.push(0x80 | (u8::from(rsv1) << 6) | (opcode & 0x0F));
    // The MASK bit (0x80) is always set on the length byte of client frames.
    match len {
        0..=125 => frame.push(0x80 | len as u8),
        126..=0xFFFF => {
            frame.push(0x80 | 126);
            frame.extend_from_slice(&(len as u16).to_be_bytes());
        }
        _ => {
            frame.push(0x80 | 127);
            frame.extend_from_slice(&(len as u64).to_be_bytes());
        }
    }
    frame.extend_from_slice(&mask);
    frame.extend(
        payload
            .iter()
            .zip(mask.iter().cycle())
            .map(|(byte, key)| byte ^ key),
    );
    Ok(frame)
}
fn read_exact<R: Read>(r: &mut R, buf: &mut [u8]) -> Result<()> {
let mut got = 0;
while got < buf.len() {
let n = r.read(&mut buf[got..])?;
if n == 0 {
return Err(Error::UnexpectedEof);
}
got += n;
}
Ok(())
}
/// Computes the expected `Sec-WebSocket-Accept` value: base64(SHA-1(key + `WS_GUID`)).
fn derive_accept(key_b64: &str) -> String {
    let digest = {
        let mut sha = Sha1::new();
        sha.update(key_b64.as_bytes());
        sha.update(WS_GUID.as_bytes());
        sha.finalize()
    };
    base64_encode(digest.as_ref())
}
fn random_16() -> Result<[u8; 16]> {
use purecrypto::rng::{OsRng, RngCore};
let mut out = [0u8; 16];
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
OsRng.fill_bytes(&mut out);
}))
.map_err(|_| Error::BadResponse("websocket: no secure entropy source available".into()))?;
Ok(out)
}
/// Encodes `input` as standard base64 (RFC 4648 §4 alphabet, `=` padding).
pub(crate) fn base64_encode(input: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Picks the 6-bit group of `group` that starts `shift` bits from the right.
    let sextet = |group: u32, shift: u32| ALPHABET[((group >> shift) & 0x3F) as usize] as char;
    let mut encoded = String::with_capacity(input.len().div_ceil(3) * 4);
    for chunk in input.chunks(3) {
        let b0 = u32::from(chunk[0]);
        let b1 = chunk.get(1).copied().map_or(0, u32::from);
        let b2 = chunk.get(2).copied().map_or(0, u32::from);
        let group = (b0 << 16) | (b1 << 8) | b2;
        encoded.push(sextet(group, 18));
        encoded.push(sextet(group, 12));
        encoded.push(if chunk.len() > 1 { sextet(group, 6) } else { '=' });
        encoded.push(if chunk.len() > 2 { sextet(group, 0) } else { '=' });
    }
    encoded
}
// Never called: this function exists only so that the `TlsStream` import is used
// (the `wss` path goes through `crate::tls::connect_over` instead), hence the lint allowance.
#[allow(dead_code)]
fn _tlsstream_in_scope_for_docs<S: Read + Write>(_: TlsStream<S>) {}
// Unit tests for framing, handshake helpers, fragmentation, control-frame handling and
// permessage-deflate. All I/O goes through in-memory streams; no network is used.
#[cfg(test)]
mod tests {
use super::*;
use std::io::Cursor;
// In-memory duplex stream: reads come from `inbound`, writes are recorded in `sent`.
struct MockStream {
inbound: Cursor<Vec<u8>>,
sent: Vec<u8>,
}
impl MockStream {
fn new(inbound: Vec<u8>) -> Self {
MockStream {
inbound: Cursor::new(inbound),
sent: Vec::new(),
}
}
}
impl Read for MockStream {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.inbound.read(buf)
}
}
impl Write for MockStream {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.sent.extend_from_slice(buf);
Ok(buf.len())
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
// Encodes an unmasked server-to-client frame (as a server would send it).
fn server_frame(fin: bool, opcode: u8, payload: &[u8]) -> Vec<u8> {
let mut out = Vec::new();
let b0 = if fin { 0x80 } else { 0x00 } | (opcode & 0x0F);
out.push(b0);
let n = payload.len();
if n < 126 {
out.push(n as u8);
} else if n <= u16::MAX as usize {
out.push(126);
out.extend_from_slice(&(n as u16).to_be_bytes());
} else {
out.push(127);
out.extend_from_slice(&(n as u64).to_be_bytes());
}
out.extend_from_slice(payload);
out
}
// Same as `server_frame`, with the RSV1 bit set.
fn server_frame_rsv1(fin: bool, opcode: u8, payload: &[u8]) -> Vec<u8> {
let mut out = server_frame(fin, opcode, payload);
out[0] |= 0x40; out
}
// Compresses like a permessage-deflate sender: deflate, then drop a trailing 00 00 FF FF.
fn pmd_compress(data: &[u8]) -> Vec<u8> {
let mut out = compress_to_vec::<Deflate>(data).expect("deflate encode");
if out.ends_with(&DEFLATE_TAIL) {
out.truncate(out.len() - DEFLATE_TAIL.len());
}
out
}
// Pmd with both no-context-takeover flags set, so each message uses a fresh decoder.
fn test_pmd() -> Pmd {
Pmd {
client_no_context_takeover: true,
server_no_context_takeover: true,
decoder: Deflate::decoder(),
}
}
// Splits the client's written bytes into (opcode, unmasked payload) pairs,
// asserting every frame is masked.
fn decode_sent(sent: &[u8]) -> Vec<(u8, Vec<u8>)> {
let mut frames = Vec::new();
let mut i = 0;
while i < sent.len() {
let opcode = sent[i] & 0x0F;
let masked = (sent[i + 1] & 0x80) != 0;
assert!(masked, "client frame must be masked");
let len7 = sent[i + 1] & 0x7F;
i += 2;
let len = match len7 {
0..=125 => len7 as usize,
126 => {
let l = u16::from_be_bytes([sent[i], sent[i + 1]]) as usize;
i += 2;
l
}
127 => {
let mut b = [0u8; 8];
b.copy_from_slice(&sent[i..i + 8]);
i += 8;
u64::from_be_bytes(b) as usize
}
_ => unreachable!(),
};
let mask = [sent[i], sent[i + 1], sent[i + 2], sent[i + 3]];
i += 4;
let mut payload = sent[i..i + len].to_vec();
i += len;
for (j, b) in payload.iter_mut().enumerate() {
*b ^= mask[j & 3];
}
frames.push((opcode, payload));
}
frames
}
#[test]
fn base64_encode_hello() {
assert_eq!(base64_encode(b"hello"), "aGVsbG8=");
}
#[test]
fn base64_encode_rfc4648_vectors() {
assert_eq!(base64_encode(b""), "");
assert_eq!(base64_encode(b"f"), "Zg==");
assert_eq!(base64_encode(b"fo"), "Zm8=");
assert_eq!(base64_encode(b"foo"), "Zm9v");
assert_eq!(base64_encode(b"foob"), "Zm9vYg==");
assert_eq!(base64_encode(b"fooba"), "Zm9vYmE=");
assert_eq!(base64_encode(b"foobar"), "Zm9vYmFy");
}
// Sample key/accept pair from RFC 6455 §1.3.
#[test]
fn rfc6455_accept_derivation() {
let key = "dGhlIHNhbXBsZSBub25jZQ==";
let expected = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";
assert_eq!(derive_accept(key), expected);
}
#[test]
fn parse_short_text_frame() {
let bytes = [0x81, 0x05, b'h', b'e', b'l', b'l', b'o'];
let mut cur = Cursor::new(&bytes[..]);
let f = read_frame(&mut cur).expect("frame parses");
assert!(f.fin);
assert_eq!(f.opcode, OPCODE_TEXT);
assert_eq!(f.payload, b"hello");
}
#[test]
fn parse_16bit_length_frame() {
let mut bytes: Vec<u8> = vec![0x82, 126, 0x00, 200];
bytes.extend(std::iter::repeat_n(b'A', 200));
let mut cur = Cursor::new(bytes);
let f = read_frame(&mut cur).expect("frame parses");
assert_eq!(f.opcode, OPCODE_BINARY);
assert_eq!(f.payload.len(), 200);
assert!(f.payload.iter().all(|&b| b == b'A'));
}
#[test]
fn reject_masked_server_frame() {
let bytes = [0x81, 0x80, 0, 0, 0, 0];
let mut cur = Cursor::new(&bytes[..]);
let err = read_frame(&mut cur).expect_err("masked server frame must be rejected");
match err {
Error::BadResponse(_) => {}
other => panic!("wrong error: {other:?}"),
}
}
#[test]
fn build_close_frame_short_payload() {
let frame = build_client_frame(OPCODE_CLOSE, &[]).unwrap();
assert_eq!(frame.len(), 6);
assert_eq!(frame[0], 0x88);
assert_eq!(frame[1], 0x80); }
#[test]
fn build_text_frame_masks_payload() {
let payload = b"hi";
let frame = build_client_frame(OPCODE_TEXT, payload).unwrap();
assert_eq!(frame.len(), 8);
assert_eq!(frame[0], 0x81);
assert_eq!(frame[1], 0x82);
let mask = [frame[2], frame[3], frame[4], frame[5]];
let unmasked: Vec<u8> = frame[6..]
.iter()
.enumerate()
.map(|(i, &b)| b ^ mask[i & 3])
.collect();
assert_eq!(unmasked, payload);
}
#[test]
fn build_frame_uses_16bit_length_for_medium_payload() {
let payload = vec![0u8; 200];
let frame = build_client_frame(OPCODE_BINARY, &payload).unwrap();
assert_eq!(frame[0], 0x82);
assert_eq!(frame[1], 0x80 | 126);
let len = u16::from_be_bytes([frame[2], frame[3]]);
assert_eq!(len, 200);
assert_eq!(frame.len(), 208);
}
#[test]
fn build_frame_uses_64bit_length_for_large_payload() {
let payload = vec![0u8; 70_000];
let frame = build_client_frame(OPCODE_BINARY, &payload).unwrap();
assert_eq!(frame[1], 0x80 | 127);
let len = u64::from_be_bytes([
frame[2], frame[3], frame[4], frame[5], frame[6], frame[7], frame[8], frame[9],
]);
assert_eq!(len, 70_000);
}
// The two entropy tests below are probabilistic; a false failure has probability ~2^-128.
#[test]
fn random_16_is_nonzero() {
let r = random_16().expect("OS entropy available in the test environment");
assert_ne!(r, [0u8; 16]);
}
#[test]
fn random_16_is_not_constant() {
let a = random_16().unwrap();
let b = random_16().unwrap();
assert_ne!(a, b);
}
#[test]
fn reassembles_fragmented_text_message() {
let mut inbound = server_frame(false, OPCODE_TEXT, b"Hel");
inbound.extend(server_frame(false, OPCODE_CONT, b"lo "));
inbound.extend(server_frame(true, OPCODE_CONT, b"world"));
let mut s = MockStream::new(inbound);
let msg = read_message(&mut s, None).expect("reassembles");
assert_eq!(
msg,
Message::Data {
opcode: OPCODE_TEXT,
payload: b"Hello world".to_vec(),
}
);
}
#[test]
fn ping_between_fragments_gets_pong_and_message_completes() {
let mut inbound = server_frame(false, OPCODE_TEXT, b"foo");
inbound.extend(server_frame(true, OPCODE_PING, b"pingdata"));
inbound.extend(server_frame(true, OPCODE_CONT, b"bar"));
let mut s = MockStream::new(inbound);
let msg = read_message(&mut s, None).expect("completes despite ping");
assert_eq!(
msg,
Message::Data {
opcode: OPCODE_TEXT,
payload: b"foobar".to_vec(),
}
);
let sent = decode_sent(&s.sent);
assert_eq!(sent.len(), 1, "exactly one pong expected");
assert_eq!(sent[0].0, OPCODE_PONG);
assert_eq!(sent[0].1, b"pingdata");
}
#[test]
fn close_is_answered_and_returns_closed() {
let inbound = server_frame(true, OPCODE_CLOSE, &[]);
let mut s = MockStream::new(inbound);
let msg = read_message(&mut s, None).expect("handles close");
assert_eq!(msg, Message::Closed);
let sent = decode_sent(&s.sent);
assert_eq!(sent.len(), 1, "exactly one close reply expected");
assert_eq!(sent[0].0, OPCODE_CLOSE);
}
#[test]
fn unsolicited_pong_is_ignored_then_data_returns() {
let mut inbound = server_frame(true, OPCODE_PONG, b"x");
inbound.extend(server_frame(true, OPCODE_TEXT, b"hi"));
let mut s = MockStream::new(inbound);
let msg = read_message(&mut s, None).expect("ignores pong");
assert_eq!(
msg,
Message::Data {
opcode: OPCODE_TEXT,
payload: b"hi".to_vec(),
}
);
assert!(s.sent.is_empty(), "unsolicited pong must not be answered");
}
#[test]
fn send_message_produces_masked_frame() {
let mut s = MockStream::new(Vec::new());
send_message(&mut s, OPCODE_TEXT, b"hello", None).expect("sends");
assert_eq!(s.sent[0], 0x81);
assert_eq!(s.sent[1], 0x80 | 5);
assert_eq!(s.sent.len(), 2 + 4 + 5);
let decoded = decode_sent(&s.sent);
assert_eq!(decoded, vec![(OPCODE_TEXT, b"hello".to_vec())]);
}
#[test]
fn send_message_rejects_control_opcode() {
let mut s = MockStream::new(Vec::new());
let err = send_message(&mut s, OPCODE_PING, b"x", None)
.expect_err("control opcode must be rejected for send_message");
match err {
Error::BadResponse(_) => {}
other => panic!("wrong error: {other:?}"),
}
assert!(s.sent.is_empty());
}
// Allocates a buffer of MAX_PAYLOAD_BYTES (64 MiB) to probe the exact boundary.
#[test]
fn oversized_cumulative_fragmented_payload_is_rejected() {
let mut buf = vec![0u8; (MAX_PAYLOAD_BYTES - 1) as usize];
accumulate(&mut buf, &[0u8]).expect("exactly at the cap is allowed");
let err = accumulate(&mut buf, &[0u8]).expect_err("over the cap must be rejected");
match err {
Error::BadResponse(_) => {}
other => panic!("wrong error: {other:?}"),
}
}
#[test]
fn fragmented_control_frame_is_rejected() {
let inbound = server_frame(false, OPCODE_PING, b"x");
let mut s = MockStream::new(inbound);
let err = read_message(&mut s, None).expect_err("fragmented control must be rejected");
match err {
Error::BadResponse(_) => {}
other => panic!("wrong error: {other:?}"),
}
}
#[test]
fn oversized_control_frame_is_rejected() {
let inbound = server_frame(true, OPCODE_PING, &[0u8; 126]);
let mut s = MockStream::new(inbound);
let err = read_message(&mut s, None).expect_err("oversized control must be rejected");
match err {
Error::BadResponse(_) => {}
other => panic!("wrong error: {other:?}"),
}
}
#[test]
fn new_data_frame_during_fragmentation_is_rejected() {
let mut inbound = server_frame(false, OPCODE_TEXT, b"a");
inbound.extend(server_frame(true, OPCODE_TEXT, b"b"));
let mut s = MockStream::new(inbound);
let err =
read_message(&mut s, None).expect_err("interleaved new data frame must be rejected");
match err {
Error::BadResponse(_) => {}
other => panic!("wrong error: {other:?}"),
}
}
#[test]
fn lone_continuation_frame_is_rejected() {
let inbound = server_frame(true, OPCODE_CONT, b"x");
let mut s = MockStream::new(inbound);
let err = read_message(&mut s, None).expect_err("lone continuation must be rejected");
match err {
Error::BadResponse(_) => {}
other => panic!("wrong error: {other:?}"),
}
}
#[test]
fn read_data_and_close_returns_reassembled_payload_and_sends_close() {
let mut inbound = server_frame(false, OPCODE_BINARY, &[1, 2, 3]);
inbound.extend(server_frame(true, OPCODE_CONT, &[4, 5]));
let mut s = MockStream::new(inbound);
let payload = read_data_and_close(&mut s, None).expect("reads message");
assert_eq!(payload, vec![1, 2, 3, 4, 5]);
let sent = decode_sent(&s.sent);
assert_eq!(sent.last().map(|f| f.0), Some(OPCODE_CLOSE));
}
// Decodes only the first client frame, also reporting its RSV1 bit.
fn decode_first_frame_with_rsv1(sent: &[u8]) -> (u8, bool, Vec<u8>) {
let opcode = sent[0] & 0x0F;
let rsv1 = (sent[0] & 0x40) != 0;
let masked = (sent[1] & 0x80) != 0;
assert!(masked, "client frame must be masked");
let len7 = sent[1] & 0x7F;
let mut i = 2;
let len = match len7 {
0..=125 => len7 as usize,
126 => {
let l = u16::from_be_bytes([sent[i], sent[i + 1]]) as usize;
i += 2;
l
}
127 => {
let mut b = [0u8; 8];
b.copy_from_slice(&sent[i..i + 8]);
i += 8;
u64::from_be_bytes(b) as usize
}
_ => unreachable!(),
};
let mask = [sent[i], sent[i + 1], sent[i + 2], sent[i + 3]];
i += 4;
let mut payload = sent[i..i + len].to_vec();
for (j, b) in payload.iter_mut().enumerate() {
*b ^= mask[j & 3];
}
(opcode, rsv1, payload)
}
// Reference inflater used to check client-side compression independently of `Pmd`.
fn pmd_inflate(compressed: &[u8]) -> Vec<u8> {
let mut input = compressed.to_vec();
input.extend_from_slice(&DEFLATE_TAIL);
let mut dec = Deflate::decoder();
let mut out = Vec::new();
let mut scratch = vec![0u8; 32 * 1024];
let mut consumed = 0usize;
loop {
let before_c = consumed;
let before_w = out.len();
let (p, status) = dec
.decode(&input[consumed..], &mut scratch)
.expect("inflate");
out.extend_from_slice(&scratch[..p.written]);
consumed += p.consumed;
match status {
Status::StreamEnd => break,
Status::OutputFull => continue,
Status::InputEmpty => {
if consumed >= input.len() || (consumed == before_c && out.len() == before_w) {
break;
}
}
}
}
out
}
// Only the request is inspected; the canned response deliberately has a wrong
// accept value, so the handshake result is discarded.
#[test]
fn handshake_offers_permessage_deflate() {
struct Recorder {
request: Vec<u8>,
response: Cursor<Vec<u8>>,
}
impl Read for Recorder {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.response.read(buf)
}
}
impl Write for Recorder {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.request.extend_from_slice(buf);
Ok(buf.len())
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
let resp = b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: wrong\r\n\r\n".to_vec();
let mut rec = Recorder {
request: Vec::new(),
response: Cursor::new(resp),
};
let url = Url::parse("ws://example.com/chat").expect("url");
let _ = handshake(&mut rec, &url); let req = String::from_utf8(rec.request).expect("utf8 request");
assert!(
req.contains("Sec-WebSocket-Extensions: permessage-deflate"),
"request must offer permessage-deflate, got:\n{req}"
);
assert!(req.contains("client_no_context_takeover"));
assert!(req.contains("server_no_context_takeover"));
}
#[test]
fn parse_pmd_response_enables_compression() {
let pmd = parse_pmd_response("permessage-deflate; server_no_context_takeover")
.expect("permessage-deflate accepted");
assert!(pmd.server_no_context_takeover);
assert!(!pmd.client_no_context_takeover);
let pmd2 = parse_pmd_response(
"permessage-deflate; client_no_context_takeover; server_no_context_takeover",
)
.expect("accepted with both flags");
assert!(pmd2.client_no_context_takeover);
assert!(pmd2.server_no_context_takeover);
}
#[test]
fn parse_pmd_response_without_extension_is_none() {
assert!(parse_pmd_response("some-other-extension").is_none());
assert!(parse_pmd_response("").is_none());
assert!(parse_pmd_response("foo; bar=1").is_none());
}
#[test]
fn inflate_compressed_message_decodes_to_original() {
let original = b"the quick brown fox jumps over the lazy dog, the quick brown fox";
let compressed = pmd_compress(original);
assert!(
compressed.len() < original.len(),
"fixture should actually compress"
);
let inbound = server_frame_rsv1(true, OPCODE_TEXT, &compressed);
let mut s = MockStream::new(inbound);
let mut pmd = test_pmd();
let msg = read_message(&mut s, Some(&mut pmd)).expect("decodes compressed message");
assert_eq!(
msg,
Message::Data {
opcode: OPCODE_TEXT,
payload: original.to_vec(),
}
);
}
#[test]
fn send_message_compressed_round_trips() {
let input = b"hello hello hello permessage-deflate round trip";
let mut s = MockStream::new(Vec::new());
let mut pmd = test_pmd();
send_message(&mut s, OPCODE_TEXT, input, Some(&mut pmd)).expect("sends compressed");
let (opcode, rsv1, payload) = decode_first_frame_with_rsv1(&s.sent);
assert_eq!(opcode, OPCODE_TEXT);
assert!(rsv1, "compressed frame must have RSV1 set");
assert_ne!(payload, input, "payload should be compressed, not raw");
assert_eq!(pmd_inflate(&payload), input);
}
#[test]
fn rsv1_without_negotiation_is_rejected() {
let inbound = server_frame_rsv1(true, OPCODE_TEXT, b"whatever");
let mut s = MockStream::new(inbound);
let err = read_message(&mut s, None).expect_err("RSV1 without PMD must be rejected");
match err {
Error::BadResponse(_) => {}
other => panic!("wrong error: {other:?}"),
}
}
#[test]
fn rsv2_or_rsv3_always_rejected() {
let mut inbound = server_frame(true, OPCODE_TEXT, b"x");
inbound[0] |= 0x20;
let mut s = MockStream::new(inbound);
let mut pmd = test_pmd();
let err = read_message(&mut s, Some(&mut pmd)).expect_err("RSV2 must be rejected");
assert!(matches!(err, Error::BadResponse(_)));
let mut inbound = server_frame(true, OPCODE_TEXT, b"x");
inbound[0] |= 0x10;
let mut s = MockStream::new(inbound);
let mut pmd = test_pmd();
let err = read_message(&mut s, Some(&mut pmd)).expect_err("RSV3 must be rejected");
assert!(matches!(err, Error::BadResponse(_)));
}
#[test]
fn rsv1_on_control_frame_is_rejected() {
let inbound = server_frame_rsv1(true, OPCODE_PING, b"x");
let mut s = MockStream::new(inbound);
let mut pmd = test_pmd();
let err =
read_message(&mut s, Some(&mut pmd)).expect_err("RSV1 on control must be rejected");
assert!(matches!(err, Error::BadResponse(_)));
}
// Builds a ~65 MiB zero buffer that deflates to a small frame but inflates past the cap.
#[test]
fn compressed_bomb_exceeding_cap_is_rejected() {
let huge = vec![0u8; (MAX_PAYLOAD_BYTES + (1 << 20)) as usize];
let compressed = pmd_compress(&huge);
assert!(
(compressed.len() as u64) < MAX_PAYLOAD_BYTES,
"fixture must be much smaller than the cap"
);
let inbound = server_frame_rsv1(true, OPCODE_BINARY, &compressed);
let mut s = MockStream::new(inbound);
let mut pmd = test_pmd();
let err =
read_message(&mut s, Some(&mut pmd)).expect_err("compression bomb must be rejected");
match err {
Error::BadResponse(msg) => {
assert!(msg.contains("permessage-deflate"), "got {msg:?}")
}
other => panic!("wrong error: {other:?}"),
}
}
#[test]
fn fragmented_compressed_message_reassembles_and_inflates() {
let original = b"fragmented compressed payload, split across two frames on the wire";
let compressed = pmd_compress(original);
assert!(compressed.len() >= 4, "need enough bytes to split");
let mid = compressed.len() / 2;
let mut inbound = server_frame_rsv1(false, OPCODE_TEXT, &compressed[..mid]);
inbound.extend(server_frame(true, OPCODE_CONT, &compressed[mid..]));
let mut s = MockStream::new(inbound);
let mut pmd = test_pmd();
let msg = read_message(&mut s, Some(&mut pmd)).expect("reassembles + inflates");
assert_eq!(
msg,
Message::Data {
opcode: OPCODE_TEXT,
payload: original.to_vec(),
}
);
}
#[test]
fn rsv1_on_continuation_frame_is_rejected() {
let original = b"continuation rsv1 should be rejected here";
let compressed = pmd_compress(original);
let mid = compressed.len() / 2;
let mut inbound = server_frame_rsv1(false, OPCODE_TEXT, &compressed[..mid]);
inbound.extend(server_frame_rsv1(true, OPCODE_CONT, &compressed[mid..]));
let mut s = MockStream::new(inbound);
let mut pmd = test_pmd();
let err = read_message(&mut s, Some(&mut pmd))
.expect_err("RSV1 on continuation must be rejected");
assert!(matches!(err, Error::BadResponse(_)));
}
#[test]
fn uncompressed_message_passes_through_when_pmd_negotiated() {
let inbound = server_frame(true, OPCODE_TEXT, b"plain text, no rsv1");
let mut s = MockStream::new(inbound);
let mut pmd = test_pmd();
let msg = read_message(&mut s, Some(&mut pmd)).expect("plain message");
assert_eq!(
msg,
Message::Data {
opcode: OPCODE_TEXT,
payload: b"plain text, no rsv1".to_vec(),
}
);
}
}