use crate::bytes::{BufMut, Bytes, BytesMut};
use crate::codec::{Decoder, Encoder};
use crate::util::{EntropySource, OsEntropy};
use std::io;
/// WebSocket frame opcode (RFC 6455 §5.2).
///
/// The `u8` discriminant of each variant is the 4-bit value used on the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum Opcode {
    /// Continuation fragment of a message (0x0).
    Continuation = 0x0,
    /// Text data frame (0x1).
    Text = 0x1,
    /// Binary data frame (0x2).
    Binary = 0x2,
    /// Connection close control frame (0x8).
    Close = 0x8,
    /// Ping control frame (0x9).
    Ping = 0x9,
    /// Pong control frame (0xA).
    Pong = 0xA,
}

impl Opcode {
    /// Returns `true` for the control opcodes: Close, Ping and Pong.
    #[must_use]
    pub const fn is_control(self) -> bool {
        match self {
            Self::Close | Self::Ping | Self::Pong => true,
            Self::Continuation | Self::Text | Self::Binary => false,
        }
    }

    /// Returns `true` for the data opcodes: Continuation, Text and Binary.
    ///
    /// Every opcode is either a data or a control opcode, never both.
    #[must_use]
    pub const fn is_data(self) -> bool {
        !self.is_control()
    }

    /// Parses a raw 4-bit opcode value.
    ///
    /// # Errors
    ///
    /// Returns [`WsError::InvalidOpcode`] carrying `value` when it is not one of
    /// the six opcodes defined by RFC 6455.
    pub fn from_u8(value: u8) -> Result<Self, WsError> {
        let opcode = match value {
            0x0 => Self::Continuation,
            0x1 => Self::Text,
            0x2 => Self::Binary,
            0x8 => Self::Close,
            0x9 => Self::Ping,
            0xA => Self::Pong,
            unknown => return Err(WsError::InvalidOpcode(unknown)),
        };
        Ok(opcode)
    }
}
/// A single WebSocket frame (RFC 6455 §5.2).
///
/// `masked` and `mask_key` record how a frame arrived on the wire; the decoder
/// fills them in. The encoder ignores them and masks according to the codec's
/// [`Role`]. The constructors on this type build unfragmented (`fin = true`),
/// unmasked frames with every RSV bit clear.
#[derive(Debug, Clone)]
#[allow(clippy::struct_excessive_bools)] // one field per header flag mirrors the wire layout
pub struct Frame {
    /// FIN bit: `true` when this is the final fragment of a message.
    pub fin: bool,
    /// RSV1 extension bit.
    pub rsv1: bool,
    /// RSV2 extension bit.
    pub rsv2: bool,
    /// RSV3 extension bit.
    pub rsv3: bool,
    /// Frame type.
    pub opcode: Opcode,
    /// Whether the frame carried the MASK bit when it was decoded.
    pub masked: bool,
    /// Masking key read from the wire, if the frame was masked.
    pub mask_key: Option<[u8; 4]>,
    /// Payload bytes (already unmasked for decoded frames).
    pub payload: Bytes,
}

impl Frame {
    /// Builds a final, unmasked frame with all RSV bits clear.
    ///
    /// All public constructors share this helper so the header defaults live in one place.
    fn unfragmented(opcode: Opcode, payload: Bytes) -> Self {
        Self {
            fin: true,
            rsv1: false,
            rsv2: false,
            rsv3: false,
            opcode,
            masked: false,
            mask_key: None,
            payload,
        }
    }

    /// Creates a single-frame text message.
    ///
    /// The payload is not checked for UTF-8 validity here.
    #[must_use]
    pub fn text(payload: impl Into<Bytes>) -> Self {
        Self::unfragmented(Opcode::Text, payload.into())
    }

    /// Creates a single-frame binary message.
    #[must_use]
    pub fn binary(payload: impl Into<Bytes>) -> Self {
        Self::unfragmented(Opcode::Binary, payload.into())
    }

    /// Creates a Ping control frame.
    ///
    /// The encoder rejects control payloads longer than 125 bytes.
    #[must_use]
    pub fn ping(payload: impl Into<Bytes>) -> Self {
        Self::unfragmented(Opcode::Ping, payload.into())
    }

    /// Creates a Pong control frame.
    ///
    /// The encoder rejects control payloads longer than 125 bytes.
    #[must_use]
    pub fn pong(payload: impl Into<Bytes>) -> Self {
        Self::unfragmented(Opcode::Pong, payload.into())
    }

    /// Creates a Close frame with an optional status code and reason.
    ///
    /// The payload is the big-endian `code` followed by the UTF-8 bytes of `reason`.
    /// A reason cannot be encoded without a code, because the reason follows the
    /// code on the wire. So when `code` is `None`, any `reason` is dropped and the
    /// payload is empty.
    ///
    /// # Panics
    ///
    /// - if `code` is not accepted by [`CloseCode::is_valid_code`];
    /// - if the 2-byte code plus `reason` exceeds the 125-byte control-frame limit.
    #[must_use]
    pub fn close(code: Option<u16>, reason: Option<&str>) -> Self {
        let code = match code {
            Some(c) => {
                assert!(
                    CloseCode::is_valid_code(c),
                    "close code {c} is not valid for use in a Close frame (RFC 6455 §7.4)"
                );
                c
            }
            // Fail closed: without a code there is nowhere to put a reason.
            None => return Self::unfragmented(Opcode::Close, Bytes::new()),
        };

        // A missing reason contributes zero bytes, so `total` is then 2 and the limit always holds.
        let reason = reason.unwrap_or("");
        let total = 2 + reason.len();
        assert!(
            total <= 125,
            "close frame payload ({total} bytes) exceeds 125-byte control frame limit (RFC 6455 §5.5)"
        );

        let mut payload = BytesMut::with_capacity(total);
        payload.put_u16(code);
        payload.put_slice(reason.as_bytes());
        Self::unfragmented(Opcode::Close, payload.freeze())
    }
}
/// Errors raised while encoding or decoding WebSocket frames.
#[derive(Debug)]
pub enum WsError {
    /// Transport-level I/O failure.
    Io(io::Error),
    /// The 4-bit opcode is not defined by RFC 6455; carries the raw value.
    InvalidOpcode(u8),
    /// A framing rule was broken; the message names the rule.
    ProtocolViolation(&'static str),
    /// RSV1/RSV2/RSV3 was set while reserved-bit validation is enabled.
    ReservedBitsSet,
    /// A frame's declared payload length exceeds the configured maximum.
    PayloadTooLarge {
        /// Declared payload length, in bytes.
        size: u64,
        /// Configured limit, in bytes.
        max: usize,
    },
    /// A control frame exceeded the 125-byte payload limit.
    ///
    /// Carries the length that triggered the check. When decoding, this is the
    /// 7-bit length field (126 or 127), not the full extended length.
    ControlFrameTooLarge(usize),
    /// A control frame had its FIN bit clear.
    FragmentedControlFrame,
    /// A server received a frame without the MASK bit.
    UnmaskedClientFrame,
    /// A client received a frame with the MASK bit set.
    MaskedServerFrame,
    /// A text payload was not valid UTF-8.
    InvalidUtf8,
    /// A Close payload was malformed.
    ///
    /// Covers a 1-byte payload, an unacceptable status code, or a non-UTF-8 reason.
    InvalidClosePayload,
}

impl WsError {
    /// Maps this error to the close status code that best describes it.
    ///
    /// - Size violations map to `MessageTooBig` (1009).
    /// - Invalid UTF-8 maps to `InvalidPayload` (1007).
    /// - I/O failures map to `Abnormal` (1006). This is a local-only code that must
    ///   not be placed in an outgoing Close frame.
    /// - Every other framing error maps to `ProtocolError` (1002).
    #[must_use]
    pub fn as_close_code(&self) -> CloseCode {
        // Exhaustive on purpose: a new variant must be mapped explicitly rather
        // than silently defaulting to ProtocolError.
        match self {
            Self::PayloadTooLarge { .. } | Self::ControlFrameTooLarge(_) => {
                CloseCode::MessageTooBig
            }
            Self::InvalidUtf8 => CloseCode::InvalidPayload,
            Self::Io(_) => CloseCode::Abnormal,
            Self::InvalidOpcode(_)
            | Self::ProtocolViolation(_)
            | Self::ReservedBitsSet
            | Self::FragmentedControlFrame
            | Self::UnmaskedClientFrame
            | Self::MaskedServerFrame
            | Self::InvalidClosePayload => CloseCode::ProtocolError,
        }
    }
}

impl std::fmt::Display for WsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(e) => write!(f, "I/O error: {e}"),
            Self::InvalidOpcode(op) => write!(f, "invalid opcode: 0x{op:X}"),
            Self::ProtocolViolation(msg) => write!(f, "protocol violation: {msg}"),
            Self::ReservedBitsSet => write!(f, "reserved bits set without extension"),
            Self::PayloadTooLarge { size, max } => {
                write!(f, "payload too large: {size} bytes (max: {max})")
            }
            Self::ControlFrameTooLarge(size) => {
                write!(
                    f,
                    "control frame payload too large: {size} bytes (max: 125)"
                )
            }
            Self::FragmentedControlFrame => write!(f, "control frame cannot be fragmented"),
            Self::UnmaskedClientFrame => write!(f, "client frame must be masked"),
            Self::MaskedServerFrame => write!(f, "server frame should not be masked"),
            Self::InvalidUtf8 => write!(f, "invalid UTF-8 in text frame"),
            Self::InvalidClosePayload => write!(f, "invalid close frame payload"),
        }
    }
}

impl std::error::Error for WsError {
    /// Exposes the wrapped I/O error; all other variants have no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(e) => Some(e),
            _ => None,
        }
    }
}

impl From<io::Error> for WsError {
    /// Wraps an I/O error so `?` can propagate it from codec code.
    fn from(err: io::Error) -> Self {
        Self::Io(err)
    }
}
/// Which side of a WebSocket connection a codec serves.
///
/// The role decides masking (RFC 6455 §5.3): clients mask outgoing frames, and
/// servers require incoming frames to be masked.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Role {
    /// Connection initiator; masks every frame it sends.
    Client,
    /// Connection acceptor; expects every received frame to be masked.
    Server,
}
/// Position of the incremental frame decoder within the current frame.
///
/// Each in-progress state carries the header fields parsed so far. This lets
/// decoding resume where it stopped when more bytes arrive.
#[derive(Debug)]
enum DecodeState {
/// Waiting for the two fixed header bytes of a new frame.
Header,
/// Fixed header parsed. Waiting for the extended payload length field.
/// `bytes_needed` is 2 (16-bit length) or 8 (64-bit length).
ExtendedLength {
fin: bool,
rsv1: bool,
rsv2: bool,
rsv3: bool,
opcode: Opcode,
masked: bool,
bytes_needed: usize,
},
/// Payload length known; waiting for the 4-byte masking key.
MaskKey {
fin: bool,
rsv1: bool,
rsv2: bool,
rsv3: bool,
opcode: Opcode,
payload_len: u64,
},
/// Waiting for `payload_len` payload bytes.
/// `mask_key` is `None` for unmasked frames.
Payload {
fin: bool,
rsv1: bool,
rsv2: bool,
rsv3: bool,
opcode: Opcode,
mask_key: Option<[u8; 4]>,
payload_len: u64,
},
/// A decode error occurred; every later decode attempt fails.
Poisoned,
}
/// Stateful WebSocket frame codec (RFC 6455 §5).
///
/// The [`Role`] decides masking:
/// - a client masks every frame it encodes and rejects masked frames it decodes;
/// - a server never masks outgoing frames and rejects unmasked incoming ones.
#[derive(Debug)]
pub struct FrameCodec {
    /// Upper bound, in bytes, on a single decoded frame's payload.
    max_payload_size: usize,
    /// Side of the connection this codec is used on.
    role: Role,
    /// Where the incremental decoder currently is within a frame.
    state: DecodeState,
    /// When `true`, decoding fails on any set RSV1/RSV2/RSV3 bit.
    validate_reserved_bits: bool,
}

impl FrameCodec {
    /// Default per-frame payload limit (16 MiB) applied by [`FrameCodec::new`].
    pub const DEFAULT_MAX_PAYLOAD_SIZE: usize = 16 * 1024 * 1024;

    /// Creates a codec for `role`.
    ///
    /// Uses the default payload limit, enables reserved-bit validation, and
    /// starts the decoder at a frame boundary.
    #[must_use]
    pub fn new(role: Role) -> Self {
        Self {
            role,
            state: DecodeState::Header,
            max_payload_size: Self::DEFAULT_MAX_PAYLOAD_SIZE,
            validate_reserved_bits: true,
        }
    }

    /// Shorthand for `FrameCodec::new(Role::Client)`.
    #[must_use]
    pub fn client() -> Self {
        Self::new(Role::Client)
    }

    /// Shorthand for `FrameCodec::new(Role::Server)`.
    #[must_use]
    pub fn server() -> Self {
        Self::new(Role::Server)
    }

    /// Builder: sets the largest payload, in bytes, the decoder will accept.
    #[must_use]
    pub fn max_payload_size(self, size: usize) -> Self {
        Self {
            max_payload_size: size,
            ..self
        }
    }

    /// Builder: enables or disables rejection of frames with RSV bits set.
    #[must_use]
    pub fn validate_reserved_bits(self, validate: bool) -> Self {
        Self {
            validate_reserved_bits: validate,
            ..self
        }
    }

    /// Serializes `frame` into `dst`, taking the masking key from `entropy`.
    ///
    /// The key is drawn, and the payload masked, only when this codec is a client.
    /// The frame's own `masked`/`mask_key` fields are ignored.
    ///
    /// # Errors
    ///
    /// - [`WsError::FragmentedControlFrame`] for a control frame without FIN.
    /// - [`WsError::ControlFrameTooLarge`] for a control payload over 125 bytes.
    /// - [`WsError::InvalidClosePayload`] when a Close payload fails validation.
    ///
    /// All validation happens before anything is written, so `dst` is untouched on error.
    pub(crate) fn encode_with_entropy(
        &self,
        frame: &Frame,
        dst: &mut BytesMut,
        entropy: &dyn EntropySource,
    ) -> Result<(), WsError> {
        let len = frame.payload.len();
        let control = frame.opcode.is_control();
        if control && !frame.fin {
            return Err(WsError::FragmentedControlFrame);
        }
        if control && len > 125 {
            return Err(WsError::ControlFrameTooLarge(len));
        }
        if frame.opcode == Opcode::Close {
            validate_close_payload(&frame.payload)?;
        }

        let masking = matches!(self.role, Role::Client);

        // Byte 0: FIN | RSV1 | RSV2 | RSV3 | opcode (4 bits).
        let first_byte = (u8::from(frame.fin) << 7)
            | (u8::from(frame.rsv1) << 6)
            | (u8::from(frame.rsv2) << 5)
            | (u8::from(frame.rsv3) << 4)
            | frame.opcode as u8;

        // Length encoding width: inline 7-bit, 16-bit (marker 126), or 64-bit (marker 127).
        let ext_len_bytes: usize = if len <= 125 {
            0
        } else if len <= 0xFFFF {
            2
        } else {
            8
        };
        let key_bytes: usize = if masking { 4 } else { 0 };
        dst.reserve(2 + ext_len_bytes + key_bytes + len);

        let mask_bit: u8 = if masking { 0x80 } else { 0x00 };
        dst.put_u8(first_byte);
        match ext_len_bytes {
            // Each `as` cast is lossless: `len` is within the chosen encoding's range.
            0 => dst.put_u8(mask_bit | (len as u8)),
            2 => {
                dst.put_u8(mask_bit | 126);
                dst.put_u16(len as u16);
            }
            _ => {
                dst.put_u8(mask_bit | 127);
                dst.put_u64(len as u64);
            }
        }

        if !masking {
            dst.put_slice(&frame.payload);
            return Ok(());
        }

        // Client frames: write the key, then mask the payload in place (RFC 6455 §5.3).
        let key = generate_mask_key(entropy);
        dst.put_slice(&key);
        let body_start = dst.len();
        dst.put_slice(&frame.payload);
        apply_mask(&mut dst[body_start..], key);
        Ok(())
    }
}
impl Decoder for FrameCodec {
    type Item = Frame;
    type Error = WsError;

    /// Decodes at most one frame from `src`.
    ///
    /// Returns `Ok(None)` while more bytes are needed. Any error poisons the codec,
    /// because the stream position can no longer be trusted. After that, every
    /// call fails with `WsError::ProtocolViolation`.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        let outcome = self.decode_inner(src);
        if outcome.is_err() {
            self.state = DecodeState::Poisoned;
        }
        outcome
    }
}
impl FrameCodec {
/// Incremental frame decoder driven by `self.state`.
///
/// Each loop iteration handles the current state:
/// - it returns `Ok(None)` when `src` does not yet hold enough bytes for that state;
/// - otherwise it consumes those bytes with `split_to`, stores what it parsed in
///   the next state, and continues.
///
/// A fully buffered payload yields `Ok(Some(frame))` and resets the state to `Header`.
///
/// In this file it is called only from `Decoder::decode`. That caller overwrites
/// `self.state` with `DecodeState::Poisoned` whenever this returns an error, so the
/// `self.state = DecodeState::Header` assignments on the error paths below are
/// superseded.
///
/// # Errors
/// Returns a `WsError` in these cases:
/// - an RFC 6455 framing violation in the header or extended length;
/// - an invalid Close payload;
/// - a payload length above `max_payload_size`.
#[allow(clippy::too_many_lines)] fn decode_inner(&mut self, src: &mut BytesMut) -> Result<Option<Frame>, WsError> {
loop {
match &self.state {
// An earlier error left the stream at an unknown position; refuse to continue.
DecodeState::Poisoned => {
return Err(WsError::ProtocolViolation(
"codec is poisoned after a fatal error",
));
}
DecodeState::Header => {
// Both fixed header bytes (RFC 6455 §5.2) are needed before parsing anything.
if src.len() < 2 {
return Ok(None);
}
// Peek only; nothing is consumed until every header check below passes.
let first_byte = src[0];
let second_byte = src[1];
let fin = (first_byte & 0x80) != 0;
let rsv1 = (first_byte & 0x40) != 0;
let rsv2 = (first_byte & 0x20) != 0;
let rsv3 = (first_byte & 0x10) != 0;
let opcode_raw = first_byte & 0x0F;
let masked = (second_byte & 0x80) != 0;
let payload_len_7 = second_byte & 0x7F;
if self.validate_reserved_bits && (rsv1 || rsv2 || rsv3) {
return Err(WsError::ReservedBitsSet);
}
let opcode = Opcode::from_u8(opcode_raw)?;
// A server must receive masked frames; a client must receive unmasked ones.
match self.role {
Role::Server if !masked => return Err(WsError::UnmaskedClientFrame),
Role::Client if masked => return Err(WsError::MaskedServerFrame),
_ => {}
}
// Control frames must not be fragmented and carry at most 125 bytes (§5.5).
// A 7-bit length of 126/127 announces an extended length, so it is rejected here.
if opcode.is_control() {
if !fin {
return Err(WsError::FragmentedControlFrame);
}
if payload_len_7 > 125 {
return Err(WsError::ControlFrameTooLarge(payload_len_7 as usize));
}
}
// Header validated; consume the two fixed bytes.
let _ = src.split_to(2);
match payload_len_7 {
// The 7-bit field is the payload length itself.
0..=125 => {
let payload_len = u64::from(payload_len_7);
if payload_len > self.max_payload_size as u64 {
return Err(WsError::PayloadTooLarge {
size: payload_len,
max: self.max_payload_size,
});
}
if masked {
self.state = DecodeState::MaskKey {
fin,
rsv1,
rsv2,
rsv3,
opcode,
payload_len,
};
} else {
self.state = DecodeState::Payload {
fin,
rsv1,
rsv2,
rsv3,
opcode,
mask_key: None,
payload_len,
};
}
}
// 126: a 16-bit extended length follows.
126 => {
self.state = DecodeState::ExtendedLength {
fin,
rsv1,
rsv2,
rsv3,
opcode,
masked,
bytes_needed: 2,
};
}
// 127: a 64-bit extended length follows.
127 => {
self.state = DecodeState::ExtendedLength {
fin,
rsv1,
rsv2,
rsv3,
opcode,
masked,
bytes_needed: 8,
};
}
// `payload_len_7` was masked with 0x7F, so it cannot exceed 127.
_ => unreachable!(),
}
}
DecodeState::ExtendedLength {
fin,
rsv1,
rsv2,
rsv3,
opcode,
masked,
bytes_needed,
} => {
// Wait until the whole 2- or 8-byte extended length field is buffered.
if src.len() < *bytes_needed {
return Ok(None);
}
let payload_len = if *bytes_needed == 2 {
let bytes = src.split_to(2);
let len = u64::from(u16::from_be_bytes([bytes[0], bytes[1]]));
// Minimal encoding: the 16-bit form is only valid for lengths >= 126.
if len < 126 {
self.state = DecodeState::Header;
return Err(WsError::ProtocolViolation(
"non-minimal payload length encoding",
));
}
len
} else {
let bytes = src.split_to(8);
let raw = u64::from_be_bytes([
bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6],
bytes[7],
]);
// The 64-bit form must have its most significant bit clear...
if raw & (1u64 << 63) != 0 {
self.state = DecodeState::Header;
return Err(WsError::ProtocolViolation(
"most significant bit of 64-bit payload length must be 0",
));
}
// ...and is only valid for lengths that do not fit in 16 bits.
if raw < 65536 {
self.state = DecodeState::Header;
return Err(WsError::ProtocolViolation(
"non-minimal payload length encoding",
));
}
raw
};
// Enforce the configured per-frame limit before buffering any payload.
if payload_len > self.max_payload_size as u64 {
self.state = DecodeState::Header;
return Err(WsError::PayloadTooLarge {
size: payload_len,
max: self.max_payload_size,
});
}
// Copy the fields out of the borrowed state before replacing it.
let fin = *fin;
let rsv1 = *rsv1;
let rsv2 = *rsv2;
let rsv3 = *rsv3;
let opcode = *opcode;
let masked = *masked;
if masked {
self.state = DecodeState::MaskKey {
fin,
rsv1,
rsv2,
rsv3,
opcode,
payload_len,
};
} else {
self.state = DecodeState::Payload {
fin,
rsv1,
rsv2,
rsv3,
opcode,
mask_key: None,
payload_len,
};
}
}
DecodeState::MaskKey {
fin,
rsv1,
rsv2,
rsv3,
opcode,
payload_len,
} => {
// The 4-byte masking key follows the length fields.
if src.len() < 4 {
return Ok(None);
}
let mask_bytes = src.split_to(4);
let mut mask_key = [0u8; 4];
mask_key.copy_from_slice(&mask_bytes);
let fin = *fin;
let rsv1 = *rsv1;
let rsv2 = *rsv2;
let rsv3 = *rsv3;
let opcode = *opcode;
let payload_len = *payload_len;
self.state = DecodeState::Payload {
fin,
rsv1,
rsv2,
rsv3,
opcode,
mask_key: Some(mask_key),
payload_len,
};
}
DecodeState::Payload {
fin,
rsv1,
rsv2,
rsv3,
opcode,
mask_key,
payload_len,
} => {
// payload_len was already checked against max_payload_size (a usize), so this
// conversion is not expected to fail; it is kept as a defensive check.
let payload_len_usize =
usize::try_from(*payload_len).map_err(|_| WsError::PayloadTooLarge {
size: *payload_len,
max: usize::MAX,
})?;
// Wait for the entire payload; partial payloads are never returned.
if src.len() < payload_len_usize {
return Ok(None);
}
let mut payload = src.split_to(payload_len_usize);
// Unmask in place (RFC 6455 §5.3) before any validation.
if let Some(key) = mask_key {
apply_mask(&mut payload, *key);
}
// Close frames must carry a well-formed status code and UTF-8 reason.
if *opcode == Opcode::Close {
validate_close_payload(&payload)?;
}
let frame = Frame {
fin: *fin,
rsv1: *rsv1,
rsv2: *rsv2,
rsv3: *rsv3,
opcode: *opcode,
masked: mask_key.is_some(),
mask_key: *mask_key,
payload: payload.freeze(),
};
// Ready for the next frame's header.
self.state = DecodeState::Header;
return Ok(Some(frame));
}
}
}
}
}
impl Encoder<Frame> for FrameCodec {
    type Error = WsError;

    /// Encodes `frame` into `dst`.
    ///
    /// Masking keys (client role only) come from the default `OsEntropy` source.
    fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let entropy = OsEntropy;
        self.encode_with_entropy(&frame, dst, &entropy)
    }
}
/// XORs `payload` in place with the repeating 4-byte `mask_key` (RFC 6455 §5.3).
///
/// Masking is its own inverse: applying the same key twice restores the input.
pub fn apply_mask(payload: &mut [u8], mask_key: [u8; 4]) {
    let mut words = payload.chunks_exact_mut(4);
    for word in &mut words {
        for (byte, key_byte) in word.iter_mut().zip(mask_key.iter()) {
            *byte ^= *key_byte;
        }
    }
    // The leftover bytes start at a multiple of 4, so they line up with key[0..].
    for (byte, key_byte) in words.into_remainder().iter_mut().zip(mask_key.iter()) {
        *byte ^= *key_byte;
    }
}
/// Draws a fresh 4-byte masking key from `entropy`.
fn generate_mask_key(entropy: &dyn EntropySource) -> [u8; 4] {
    let mut mask = [0_u8; 4];
    entropy.fill_bytes(&mut mask[..]);
    mask
}
/// Checks a Close frame payload (RFC 6455 §5.5.1).
///
/// An empty payload is allowed. Otherwise the payload must start with a 2-byte
/// big-endian status code accepted by `CloseCode::is_valid_received_code`. Any
/// remaining bytes (the reason) must be valid UTF-8.
///
/// # Errors
///
/// Returns [`WsError::InvalidClosePayload`] on any violation, including a
/// 1-byte payload.
fn validate_close_payload(payload: &[u8]) -> Result<(), WsError> {
    let (code_bytes, reason) = match payload {
        [] => return Ok(()),
        [_] => return Err(WsError::InvalidClosePayload),
        [hi, lo, rest @ ..] => ([*hi, *lo], rest),
    };
    if !CloseCode::is_valid_received_code(u16::from_be_bytes(code_bytes)) {
        return Err(WsError::InvalidClosePayload);
    }
    // An empty reason is trivially valid UTF-8.
    std::str::from_utf8(reason)
        .map(|_| ())
        .map_err(|_| WsError::InvalidClosePayload)
}
/// WebSocket close status codes.
///
/// Covers RFC 6455 §7.4.1 plus the IANA-registered codes 1012–1014. The `u16`
/// discriminant is the on-wire value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum CloseCode {
    /// 1000: normal closure.
    Normal = 1000,
    /// 1001: endpoint is going away.
    GoingAway = 1001,
    /// 1002: protocol error.
    ProtocolError = 1002,
    /// 1003: received a data type it cannot accept.
    Unsupported = 1003,
    /// 1004: reserved.
    Reserved = 1004,
    /// 1005: no status code was present (local use only).
    NoStatusReceived = 1005,
    /// 1006: connection closed without a Close frame (local use only).
    Abnormal = 1006,
    /// 1007: message data inconsistent with its type (e.g. non-UTF-8 text).
    InvalidPayload = 1007,
    /// 1008: policy violation.
    PolicyViolation = 1008,
    /// 1009: message too big to process.
    MessageTooBig = 1009,
    /// 1010: client expected the server to negotiate an extension.
    MandatoryExtension = 1010,
    /// 1011: unexpected server condition.
    InternalError = 1011,
    /// 1012: service restart (IANA registry).
    ServiceRestart = 1012,
    /// 1013: try again later (IANA registry).
    TryAgainLater = 1013,
    /// 1014: bad gateway (IANA registry).
    BadGateway = 1014,
    /// 1015: TLS handshake failure (local use only).
    TlsHandshake = 1015,
}

impl CloseCode {
    /// Returns `true` if this code may be placed in a Close frame sent on the wire.
    #[must_use]
    pub const fn is_sendable(self) -> bool {
        match self {
            Self::Reserved | Self::NoStatusReceived | Self::Abnormal | Self::TlsHandshake => false,
            Self::Normal
            | Self::GoingAway
            | Self::ProtocolError
            | Self::Unsupported
            | Self::InvalidPayload
            | Self::PolicyViolation
            | Self::MessageTooBig
            | Self::MandatoryExtension
            | Self::InternalError
            | Self::ServiceRestart
            | Self::TryAgainLater
            | Self::BadGateway => true,
        }
    }

    /// Maps a raw status code to a named variant.
    ///
    /// Returns `None` for codes without a variant (including 3000–4999).
    #[must_use]
    pub fn from_u16(code: u16) -> Option<Self> {
        let named = match code {
            1000 => Self::Normal,
            1001 => Self::GoingAway,
            1002 => Self::ProtocolError,
            1003 => Self::Unsupported,
            1004 => Self::Reserved,
            1005 => Self::NoStatusReceived,
            1006 => Self::Abnormal,
            1007 => Self::InvalidPayload,
            1008 => Self::PolicyViolation,
            1009 => Self::MessageTooBig,
            1010 => Self::MandatoryExtension,
            1011 => Self::InternalError,
            1012 => Self::ServiceRestart,
            1013 => Self::TryAgainLater,
            1014 => Self::BadGateway,
            1015 => Self::TlsHandshake,
            _ => return None,
        };
        Some(named)
    }

    /// Returns `true` if `code` may be sent in an outgoing Close frame.
    ///
    /// Accepts the sendable protocol codes plus the registered (3000–3999) and
    /// private (4000–4999) ranges.
    #[must_use]
    pub fn is_valid_code(code: u16) -> bool {
        (1000..=1003).contains(&code)
            || (1007..=1014).contains(&code)
            || (3000..=4999).contains(&code)
    }

    /// Returns `true` if `code` is acceptable in a received Close frame.
    ///
    /// Compared with [`CloseCode::is_valid_code`], this also tolerates the
    /// unassigned 1016–2999 range. Those codes may be defined by future
    /// specifications.
    #[must_use]
    pub fn is_valid_received_code(code: u16) -> bool {
        (1000..=1003).contains(&code)
            || (1007..=1014).contains(&code)
            || (1016..=4999).contains(&code)
    }
}

impl From<CloseCode> for u16 {
    /// Returns the on-wire status code.
    fn from(code: CloseCode) -> Self {
        code as u16
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the frame codec. They cover:
    // - opcode and close-code helpers;
    // - masking;
    // - encode/decode round trips;
    // - header and Close-payload validation;
    // - codec poisoning after an error.
    use super::*;

    #[test]
    fn test_opcode_is_control() {
        assert!(!Opcode::Continuation.is_control());
        assert!(!Opcode::Text.is_control());
        assert!(!Opcode::Binary.is_control());
        assert!(Opcode::Close.is_control());
        assert!(Opcode::Ping.is_control());
        assert!(Opcode::Pong.is_control());
    }

    #[test]
    fn test_opcode_from_u8() {
        assert_eq!(Opcode::from_u8(0x0).unwrap(), Opcode::Continuation);
        assert_eq!(Opcode::from_u8(0x1).unwrap(), Opcode::Text);
        assert_eq!(Opcode::from_u8(0x2).unwrap(), Opcode::Binary);
        assert_eq!(Opcode::from_u8(0x8).unwrap(), Opcode::Close);
        assert_eq!(Opcode::from_u8(0x9).unwrap(), Opcode::Ping);
        assert_eq!(Opcode::from_u8(0xA).unwrap(), Opcode::Pong);
        assert!(Opcode::from_u8(0x3).is_err());
        assert!(Opcode::from_u8(0xF).is_err());
    }

    #[test]
    fn test_apply_mask() {
        let mask_key = [0x37, 0xfa, 0x21, 0x3d];
        let mut payload = b"Hello".to_vec();
        let original = payload.clone();
        apply_mask(&mut payload, mask_key);
        assert_ne!(payload, original);
        apply_mask(&mut payload, mask_key);
        assert_eq!(payload, original);
    }

    #[test]
    fn test_encode_decode_text_frame() {
        let mut encoder = FrameCodec::client();
        let mut decoder = FrameCodec::server();
        let frame = Frame::text("Hello, WebSocket!");
        let mut buf = BytesMut::new();
        encoder.encode(frame, &mut buf).unwrap();
        let parsed = decoder.decode(&mut buf).unwrap().unwrap();
        assert!(parsed.fin);
        assert_eq!(parsed.opcode, Opcode::Text);
        assert_eq!(parsed.payload.as_ref(), b"Hello, WebSocket!");
    }

    #[test]
    fn test_encode_decode_binary_frame() {
        let mut encoder = FrameCodec::client();
        let mut decoder = FrameCodec::server();
        let payload: Bytes = vec![0x00, 0x01, 0x02, 0xFF].into();
        let frame = Frame::binary(payload.clone());
        let mut buf = BytesMut::new();
        encoder.encode(frame, &mut buf).unwrap();
        let parsed = decoder.decode(&mut buf).unwrap().unwrap();
        assert!(parsed.fin);
        assert_eq!(parsed.opcode, Opcode::Binary);
        assert_eq!(parsed.payload, payload);
    }

    #[test]
    fn test_encode_decode_ping_pong() {
        let mut encoder = FrameCodec::client();
        let mut decoder = FrameCodec::server();
        let ping_request = Frame::ping("ping-data");
        let mut buf = BytesMut::new();
        encoder.encode(ping_request, &mut buf).unwrap();
        let ping_received = decoder.decode(&mut buf).unwrap().unwrap();
        assert!(ping_received.fin);
        assert_eq!(ping_received.opcode, Opcode::Ping);
        assert_eq!(ping_received.payload.as_ref(), b"ping-data");
        let pong_response = Frame::pong("pong-data");
        let mut buf = BytesMut::new();
        encoder.encode(pong_response, &mut buf).unwrap();
        let pong_response = decoder.decode(&mut buf).unwrap().unwrap();
        assert!(pong_response.fin);
        assert_eq!(pong_response.opcode, Opcode::Pong);
        assert_eq!(pong_response.payload.as_ref(), b"pong-data");
    }

    #[test]
    fn test_encode_decode_close_frame() {
        let mut encoder = FrameCodec::client();
        let mut decoder = FrameCodec::server();
        let close = Frame::close(Some(1000), Some("goodbye"));
        let mut buf = BytesMut::new();
        encoder.encode(close, &mut buf).unwrap();
        let close_frame = decoder.decode(&mut buf).unwrap().unwrap();
        assert!(close_frame.fin);
        assert_eq!(close_frame.opcode, Opcode::Close);
        let payload = close_frame.payload;
        assert!(payload.len() >= 2);
        let code = u16::from_be_bytes([payload[0], payload[1]]);
        assert_eq!(code, 1000);
        let reason = std::str::from_utf8(&payload[2..]).unwrap();
        assert_eq!(reason, "goodbye");
    }

    #[test]
    fn test_payload_length_126() {
        let mut encoder = FrameCodec::client();
        let mut decoder = FrameCodec::server();
        let frame = Frame::binary(Bytes::from(vec![0u8; 200]));
        let mut buf = BytesMut::new();
        encoder.encode(frame, &mut buf).unwrap();
        let parsed = decoder.decode(&mut buf).unwrap().unwrap();
        assert_eq!(parsed.payload.len(), 200);
    }

    #[test]
    fn test_payload_length_127() {
        let mut encoder = FrameCodec::client();
        let mut decoder = FrameCodec::server();
        let frame = Frame::binary(Bytes::from(vec![0u8; 70_000]));
        let mut buf = BytesMut::new();
        encoder.encode(frame, &mut buf).unwrap();
        let parsed = decoder.decode(&mut buf).unwrap().unwrap();
        assert_eq!(parsed.payload.len(), 70_000);
    }

    #[test]
    fn test_client_masking() {
        let mut client_codec = FrameCodec::client();
        let mut server_codec = FrameCodec::server();
        let frame = Frame::text("masked message");
        let mut buf = BytesMut::new();
        client_codec.encode(frame, &mut buf).unwrap();
        assert!(buf[1] & 0x80 != 0);
        let parsed = server_codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(parsed.payload.as_ref(), b"masked message");
    }

    // Deterministic entropy source: repeats the same four bytes, so masking keys are predictable.
    #[derive(Debug, Clone, Copy)]
    struct FixedEntropy([u8; 4]);

    impl EntropySource for FixedEntropy {
        fn fill_bytes(&self, dest: &mut [u8]) {
            for (idx, byte) in dest.iter_mut().enumerate() {
                *byte = self.0[idx % self.0.len()];
            }
        }

        fn next_u64(&self) -> u64 {
            u64::from_le_bytes([
                self.0[0], self.0[1], self.0[2], self.0[3], self.0[0], self.0[1], self.0[2],
                self.0[3],
            ])
        }

        fn fork(&self, _task_id: crate::types::TaskId) -> std::sync::Arc<dyn EntropySource> {
            std::sync::Arc::new(*self)
        }

        fn source_id(&self) -> &'static str {
            "fixed"
        }
    }

    #[test]
    fn client_masking_uses_supplied_entropy_source() {
        let client_codec = FrameCodec::client();
        let mut server_codec = FrameCodec::server();
        let mut buf = BytesMut::new();
        let entropy = FixedEntropy([0x10, 0x20, 0x30, 0x40]);
        client_codec
            .encode_with_entropy(&Frame::text("mask-me"), &mut buf, &entropy)
            .unwrap();
        assert_eq!(&buf[2..6], &[0x10, 0x20, 0x30, 0x40]);
        let parsed = server_codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(parsed.payload.as_ref(), b"mask-me");
    }

    #[test]
    fn test_control_frame_too_large() {
        let mut codec = FrameCodec::server();
        let payload = Bytes::from(vec![0u8; 130]);
        let mut frame = Frame::ping(Bytes::new());
        frame.payload = payload;
        let mut buf = BytesMut::new();
        let result = codec.encode(frame, &mut buf);
        assert!(matches!(result, Err(WsError::ControlFrameTooLarge(_))));
    }

    #[test]
    fn test_fragmented_control_frame_rejected() {
        let mut codec = FrameCodec::server();
        let mut frame = Frame::ping("data");
        frame.fin = false;
        let mut buf = BytesMut::new();
        let result = codec.encode(frame, &mut buf);
        assert!(matches!(result, Err(WsError::FragmentedControlFrame)));
    }

    #[test]
    fn test_partial_frame_returns_none() {
        let mut encoder = FrameCodec::client();
        let mut decoder = FrameCodec::server();
        let frame = Frame::text("Hello");
        let mut buf = BytesMut::new();
        encoder.encode(frame, &mut buf).unwrap();
        let partial = buf.split_to(3);
        let mut partial = BytesMut::from(partial.as_ref());
        assert!(decoder.decode(&mut partial).unwrap().is_none());
    }

    #[test]
    fn test_empty_payload() {
        let mut encoder = FrameCodec::client();
        let mut decoder = FrameCodec::server();
        let frame = Frame::binary(Bytes::new());
        let mut buf = BytesMut::new();
        encoder.encode(frame, &mut buf).unwrap();
        let parsed = decoder.decode(&mut buf).unwrap().unwrap();
        assert!(parsed.payload.is_empty());
    }

    #[test]
    fn test_close_code_is_sendable() {
        assert!(CloseCode::Normal.is_sendable());
        assert!(CloseCode::GoingAway.is_sendable());
        assert!(CloseCode::ProtocolError.is_sendable());
        assert!(CloseCode::ServiceRestart.is_sendable());
        assert!(CloseCode::TryAgainLater.is_sendable());
        assert!(CloseCode::BadGateway.is_sendable());
        assert!(!CloseCode::Reserved.is_sendable());
        assert!(!CloseCode::NoStatusReceived.is_sendable());
        assert!(!CloseCode::Abnormal.is_sendable());
        assert!(!CloseCode::TlsHandshake.is_sendable());
    }

    #[test]
    fn test_close_code_from_u16_iana_registered() {
        assert_eq!(CloseCode::from_u16(1012), Some(CloseCode::ServiceRestart));
        assert_eq!(CloseCode::from_u16(1013), Some(CloseCode::TryAgainLater));
        assert_eq!(CloseCode::from_u16(1014), Some(CloseCode::BadGateway));
    }

    #[test]
    fn test_is_valid_code_iana_registered() {
        assert!(CloseCode::is_valid_code(1012));
        assert!(CloseCode::is_valid_code(1013));
        assert!(CloseCode::is_valid_code(1014));
    }

    #[test]
    fn test_invalid_opcode_from_u8() {
        for &op in &[0x03, 0x04, 0x05, 0x06, 0x07] {
            let result = Opcode::from_u8(op);
            assert!(matches!(result, Err(WsError::InvalidOpcode(v)) if v == op));
        }
        for &op in &[0x0B, 0x0C, 0x0D, 0x0E, 0x0F] {
            let result = Opcode::from_u8(op);
            assert!(matches!(result, Err(WsError::InvalidOpcode(v)) if v == op));
        }
    }

    #[test]
    fn test_opcode_is_data() {
        assert!(Opcode::Text.is_data());
        assert!(Opcode::Binary.is_data());
        assert!(Opcode::Continuation.is_data());
        assert!(!Opcode::Close.is_data());
        assert!(!Opcode::Ping.is_data());
        assert!(!Opcode::Pong.is_data());
    }

    #[test]
    fn test_close_frame_with_code_and_reason() {
        let frame = Frame::close(Some(1000), Some("goodbye"));
        assert_eq!(frame.opcode, Opcode::Close);
        assert!(frame.fin);
        assert_eq!(frame.payload.len(), 9);
        assert_eq!(&frame.payload[..2], &1000u16.to_be_bytes());
        assert_eq!(&frame.payload[2..], b"goodbye");
    }

    #[test]
    fn test_close_frame_code_only() {
        let frame = Frame::close(Some(1001), None);
        assert_eq!(frame.payload.len(), 2);
        assert_eq!(&frame.payload[..], &1001u16.to_be_bytes());
    }

    #[test]
    fn test_close_frame_no_payload() {
        let frame = Frame::close(None, None);
        assert!(frame.payload.is_empty());
    }

    #[test]
    fn test_ws_error_display_variants() {
        let err = WsError::InvalidOpcode(0x0F);
        assert!(err.to_string().contains("0xF"));
        let err = WsError::ReservedBitsSet;
        assert!(err.to_string().contains("reserved bits"));
        let err = WsError::PayloadTooLarge {
            size: 10_000,
            max: 1024,
        };
        assert!(err.to_string().contains("10000"));
        assert!(err.to_string().contains("1024"));
        let err = WsError::ControlFrameTooLarge(200);
        assert!(err.to_string().contains("200"));
        let err = WsError::FragmentedControlFrame;
        assert!(err.to_string().contains("fragmented"));
        let err = WsError::UnmaskedClientFrame;
        assert!(err.to_string().contains("masked"));
        let err = WsError::InvalidUtf8;
        assert!(err.to_string().contains("UTF-8"));
        let err = WsError::InvalidClosePayload;
        assert!(err.to_string().contains("close"));
    }

    #[test]
    fn test_roundtrip_server_to_client() {
        let mut encoder = FrameCodec::server();
        let mut decoder = FrameCodec::client();
        let frame = Frame::text("server says hi");
        let mut buf = BytesMut::new();
        encoder.encode(frame, &mut buf).unwrap();
        let parsed = decoder.decode(&mut buf).unwrap().unwrap();
        assert_eq!(parsed.opcode, Opcode::Text);
        assert!(!parsed.masked);
        assert_eq!(parsed.payload.as_ref(), b"server says hi");
    }

    #[test]
    fn test_decode_reserved_bits_rejected() {
        let mut codec = FrameCodec::client();
        let mut buf = BytesMut::new();
        buf.put_u8(0xC1);
        buf.put_u8(0x05);
        buf.put_slice(b"Hello");
        let result = codec.decode(&mut buf);
        assert!(matches!(result, Err(WsError::ReservedBitsSet)));
    }

    #[test]
    fn test_decode_unmasked_client_frame_rejected() {
        let mut codec = FrameCodec::server();
        let mut buf = BytesMut::new();
        buf.put_u8(0x81);
        buf.put_u8(0x05);
        buf.put_slice(b"Hello");
        let result = codec.decode(&mut buf);
        assert!(matches!(result, Err(WsError::UnmaskedClientFrame)));
    }

    #[test]
    fn test_decode_fragmented_control_rejected() {
        let mut codec = FrameCodec::client();
        let mut buf = BytesMut::new();
        buf.put_u8(0x09);
        buf.put_u8(0x04);
        buf.put_slice(b"ping");
        let result = codec.decode(&mut buf);
        assert!(matches!(result, Err(WsError::FragmentedControlFrame)));
    }

    #[test]
    fn test_decode_control_frame_extended_length_rejected() {
        let mut codec = FrameCodec::client();
        let mut buf = BytesMut::new();
        buf.put_u8(0x89);
        buf.put_u8(0x7E);
        buf.put_u16(200);
        let result = codec.decode(&mut buf);
        assert!(matches!(result, Err(WsError::ControlFrameTooLarge(_))));
    }

    #[test]
    fn test_decode_multiple_frames_single_buffer() {
        let mut encoder = FrameCodec::server();
        let mut decoder = FrameCodec::client();
        let mut buf = BytesMut::new();
        encoder.encode(Frame::text("first"), &mut buf).unwrap();
        encoder
            .encode(Frame::binary(Bytes::from("second")), &mut buf)
            .unwrap();
        let frame1 = decoder.decode(&mut buf).unwrap().unwrap();
        assert_eq!(frame1.opcode, Opcode::Text);
        assert_eq!(frame1.payload.as_ref(), b"first");
        let frame2 = decoder.decode(&mut buf).unwrap().unwrap();
        assert_eq!(frame2.opcode, Opcode::Binary);
        assert_eq!(frame2.payload.as_ref(), b"second");
        assert!(decoder.decode(&mut buf).unwrap().is_none());
    }

    #[test]
    fn test_close_frame_reason_without_code_fails_closed() {
        let frame = Frame::close(None, Some("going away"));
        assert_eq!(frame.opcode, Opcode::Close);
        assert!(frame.payload.is_empty());
    }

    #[test]
    fn test_decode_non_minimal_2byte_length_rejected() {
        let mut codec = FrameCodec::client();
        let mut buf = BytesMut::new();
        buf.put_u8(0x82);
        buf.put_u8(0x7E);
        buf.put_u16(100);
        buf.put_slice(&[0u8; 100]);
        let result = codec.decode(&mut buf);
        assert!(matches!(result, Err(WsError::ProtocolViolation(_))));
    }

    #[test]
    fn test_decode_non_minimal_8byte_length_rejected() {
        let mut codec = FrameCodec::client();
        let mut buf = BytesMut::new();
        buf.put_u8(0x82);
        buf.put_u8(0x7F);
        buf.put_u64(200);
        buf.put_slice(&[0u8; 200]);
        let result = codec.decode(&mut buf);
        assert!(matches!(result, Err(WsError::ProtocolViolation(_))));
    }

    #[test]
    fn test_decode_8byte_length_msb_set_rejected() {
        // Lift the payload limit so the MSB check, not the size check, must fire.
        let mut codec = FrameCodec::client().max_payload_size(usize::MAX);
        let mut buf = BytesMut::new();
        buf.put_u8(0x82);
        buf.put_u8(0x7F);
        buf.put_u64(0x8000_0000_0000_0100);
        let result = codec.decode(&mut buf);
        assert!(matches!(result, Err(WsError::ProtocolViolation(_))));
    }

    #[test]
    fn test_decode_valid_2byte_length_accepted() {
        let mut encoder = FrameCodec::server();
        let mut decoder = FrameCodec::client();
        // 126 is the smallest length that must use the 16-bit encoding.
        let payload = Bytes::from(vec![0xABu8; 126]);
        let frame = Frame::binary(payload);
        let mut buf = BytesMut::new();
        encoder.encode(frame, &mut buf).unwrap();
        let parsed = decoder.decode(&mut buf).unwrap().unwrap();
        assert_eq!(parsed.payload.len(), 126);
    }

    #[test]
    fn decode_close_frame_1byte_payload_rejected() {
        let mut codec = FrameCodec::client();
        let mut buf = BytesMut::new();
        buf.put_u8(0x88);
        buf.put_u8(0x01);
        buf.put_u8(0xFF);
        let result = codec.decode(&mut buf);
        assert!(matches!(result, Err(WsError::InvalidClosePayload)));
    }

    #[test]
    fn decode_close_frame_empty_payload_accepted() {
        let mut codec = FrameCodec::client();
        let mut buf = BytesMut::new();
        buf.put_u8(0x88);
        buf.put_u8(0x00);
        let frame = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(frame.opcode, Opcode::Close);
        assert!(frame.payload.is_empty());
    }

    #[test]
    fn decode_close_frame_2byte_payload_accepted() {
        let mut codec = FrameCodec::client();
        let mut buf = BytesMut::new();
        buf.put_u8(0x88);
        buf.put_u8(0x02);
        buf.put_u16(1000);
        let frame = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(frame.opcode, Opcode::Close);
        assert_eq!(frame.payload.len(), 2);
    }

    #[test]
    fn decode_close_frame_invalid_code_rejected() {
        let mut codec = FrameCodec::client();
        let mut buf = BytesMut::new();
        buf.put_u8(0x88);
        buf.put_u8(0x02);
        buf.put_u16(1005);
        let result = codec.decode(&mut buf);
        assert!(matches!(result, Err(WsError::InvalidClosePayload)));
    }

    #[test]
    fn decode_close_frame_invalid_code_poisons_codec() {
        let mut codec = FrameCodec::client();
        let mut buf = BytesMut::new();
        buf.put_u8(0x88);
        buf.put_u8(0x02);
        buf.put_u16(1005);
        let result = codec.decode(&mut buf);
        assert!(matches!(result, Err(WsError::InvalidClosePayload)));
        buf.put_u8(0x81);
        buf.put_u8(0x05);
        buf.put_slice(b"hello");
        let result2 = codec.decode(&mut buf);
        assert!(
            matches!(&result2, Err(WsError::ProtocolViolation(msg)) if msg.contains("poisoned")),
            "codec must be poisoned after close validation error, got: {result2:?}"
        );
    }

    #[test]
    fn decode_close_frame_invalid_utf8_reason_rejected() {
        let mut codec = FrameCodec::client();
        let mut buf = BytesMut::new();
        buf.put_u8(0x88);
        buf.put_u8(0x04);
        buf.put_u16(1000);
        buf.put_slice(&[0xF0, 0x28]);
        let result = codec.decode(&mut buf);
        assert!(matches!(result, Err(WsError::InvalidClosePayload)));
    }

    #[test]
    fn decode_close_frame_custom_code_with_utf8_reason_accepted() {
        let mut codec = FrameCodec::client();
        let reason = "custom";
        let mut buf = BytesMut::new();
        buf.put_u8(0x88);
        buf.put_u8((2 + reason.len()) as u8);
        buf.put_u16(4000);
        buf.put_slice(reason.as_bytes());
        let frame = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(frame.opcode, Opcode::Close);
        assert_eq!(frame.payload.len(), 2 + reason.len());
        assert_eq!(
            u16::from_be_bytes([frame.payload[0], frame.payload[1]]),
            4000
        );
        assert_eq!(&frame.payload[2..], reason.as_bytes());
    }

    #[test]
    fn encode_manual_close_frame_invalid_code_rejected() {
        let mut codec = FrameCodec::server();
        let frame = Frame {
            fin: true,
            rsv1: false,
            rsv2: false,
            rsv3: false,
            opcode: Opcode::Close,
            masked: false,
            mask_key: None,
            payload: {
                let mut payload = BytesMut::with_capacity(2);
                payload.put_u16(1005);
                payload.freeze()
            },
        };
        let result = codec.encode(frame, &mut BytesMut::new());
        assert!(matches!(result, Err(WsError::InvalidClosePayload)));
    }

    #[test]
    fn encode_manual_close_frame_invalid_utf8_reason_rejected() {
        let mut codec = FrameCodec::server();
        let frame = Frame {
            fin: true,
            rsv1: false,
            rsv2: false,
            rsv3: false,
            opcode: Opcode::Close,
            masked: false,
            mask_key: None,
            payload: Bytes::from_static(&[0x03, 0xE8, 0xF0, 0x28]),
        };
        let result = codec.encode(frame, &mut BytesMut::new());
        assert!(matches!(result, Err(WsError::InvalidClosePayload)));
    }

    #[test]
    fn encode_manual_close_frame_one_byte_payload_rejected() {
        let mut codec = FrameCodec::server();
        let frame = Frame {
            fin: true,
            rsv1: false,
            rsv2: false,
            rsv3: false,
            opcode: Opcode::Close,
            masked: false,
            mask_key: None,
            payload: Bytes::from_static(&[0xFF]),
        };
        let result = codec.encode(frame, &mut BytesMut::new());
        assert!(matches!(result, Err(WsError::InvalidClosePayload)));
    }

    #[test]
    #[should_panic(expected = "is not valid for use in a Close frame")]
    fn close_frame_code_1005_panics() {
        let _ = Frame::close(Some(1005), None);
    }

    #[test]
    #[should_panic(expected = "is not valid for use in a Close frame")]
    fn close_frame_code_1006_panics() {
        let _ = Frame::close(Some(1006), None);
    }

    #[test]
    #[should_panic(expected = "is not valid for use in a Close frame")]
    fn close_frame_code_1015_panics() {
        let _ = Frame::close(Some(1015), None);
    }

    #[test]
    #[should_panic(expected = "is not valid for use in a Close frame")]
    fn close_frame_code_1004_panics() {
        let _ = Frame::close(Some(1004), None);
    }

    #[test]
    #[should_panic(expected = "is not valid for use in a Close frame")]
    fn close_frame_unassigned_received_only_code_panics() {
        let _ = Frame::close(Some(1016), None);
    }

    #[test]
    fn received_close_code_1015_rejected() {
        assert!(!CloseCode::is_valid_received_code(1015));
    }

    #[test]
    fn received_close_code_unassigned_accepted() {
        assert!(CloseCode::is_valid_received_code(1016));
        assert!(CloseCode::is_valid_received_code(2000));
        assert!(CloseCode::is_valid_received_code(2999));
    }

    #[test]
    fn close_frame_iana_registered_codes_accepted() {
        let _ = Frame::close(Some(1012), None);
        let _ = Frame::close(Some(1013), None);
        let _ = Frame::close(Some(1014), None);
    }

    #[test]
    fn close_frame_valid_codes_accepted() {
        let _ = Frame::close(Some(1000), Some("normal"));
        let _ = Frame::close(Some(1001), None);
        let _ = Frame::close(Some(1002), None);
        let _ = Frame::close(Some(1003), None);
        let _ = Frame::close(Some(1007), None);
        let _ = Frame::close(Some(1008), None);
        let _ = Frame::close(Some(1009), None);
        let _ = Frame::close(Some(1010), None);
        let _ = Frame::close(Some(1011), None);
        let _ = Frame::close(Some(4000), Some("app error"));
    }

    #[test]
    fn payload_too_large_rejected_in_7bit_path() {
        let mut codec = FrameCodec::client().max_payload_size(50);
        let mut buf = BytesMut::new();
        buf.put_u8(0x82);
        buf.put_u8(100);
        buf.put_slice(&[0u8; 100]);
        let result = codec.decode(&mut buf);
        assert!(matches!(
            result,
            Err(WsError::PayloadTooLarge { size: 100, max: 50 })
        ));
    }

    #[test]
    fn mask_involution_empty_payload() {
        let mut payload = Vec::new();
        apply_mask(&mut payload, [0xAA, 0xBB, 0xCC, 0xDD]);
        assert!(payload.is_empty());
    }

    #[test]
    fn mask_involution_all_key_bytes_exercised() {
        let mask_key = [0x11, 0x22, 0x33, 0x44];
        // Five bytes: one full key cycle plus one byte that wraps back to key[0].
        let mut payload = vec![0x00; 5];
        apply_mask(&mut payload, mask_key);
        assert_eq!(payload, vec![0x11, 0x22, 0x33, 0x44, 0x11]);
        apply_mask(&mut payload, mask_key);
        assert_eq!(payload, vec![0x00; 5]);
    }

    #[test]
    fn codec_is_poisoned_after_decode_error() {
        let mut codec = FrameCodec::client();
        let mut bad_buf = BytesMut::new();
        bad_buf.put_u8(0xC1);
        bad_buf.put_u8(0x05);
        bad_buf.put_slice(b"Hello");
        let err = codec.decode(&mut bad_buf);
        assert!(matches!(err, Err(WsError::ReservedBitsSet)));
        let mut good_buf = BytesMut::new();
        good_buf.put_u8(0x81);
        good_buf.put_u8(0x02);
        good_buf.put_slice(b"OK");
        let err2 = codec.decode(&mut good_buf);
        assert!(matches!(err2, Err(WsError::ProtocolViolation(msg)) if msg.contains("poisoned")));
    }

    #[test]
    fn opcode_debug_clone_copy_hash_eq() {
        use std::collections::HashSet;
        let op = Opcode::Text;
        let dbg = format!("{op:?}");
        assert!(dbg.contains("Text"), "{dbg}");
        let copied = op;
        let cloned = op;
        assert_eq!(copied, cloned);
        let mut set = HashSet::new();
        set.insert(Opcode::Text);
        set.insert(Opcode::Binary);
        assert_eq!(set.len(), 2);
        assert!(set.contains(&Opcode::Text));
    }

    #[test]
    fn frame_debug_clone() {
        let f = Frame::text("hello");
        let dbg = format!("{f:?}");
        assert!(dbg.contains("Frame"), "{dbg}");
        // Call `clone()` explicitly: a plain `let` binding would only move `f` and
        // leave the `Clone` derive untested.
        let cloned = f.clone();
        assert_eq!(cloned.opcode, Opcode::Text);
        assert_eq!(cloned.opcode, f.opcode);
        assert_eq!(cloned.fin, f.fin);
        assert_eq!(cloned.payload, f.payload);
    }

    #[test]
    fn role_debug_clone_copy_eq() {
        let r = Role::Client;
        let dbg = format!("{r:?}");
        assert!(dbg.contains("Client"), "{dbg}");
        let copied = r;
        let cloned = r;
        assert_eq!(copied, cloned);
        assert_ne!(r, Role::Server);
    }

    #[test]
    fn close_code_debug_clone_copy_eq() {
        let c = CloseCode::Normal;
        let dbg = format!("{c:?}");
        assert!(dbg.contains("Normal"), "{dbg}");
        let copied = c;
        let cloned = c;
        assert_eq!(copied, cloned);
        assert_ne!(c, CloseCode::GoingAway);
    }
}