#![allow(clippy::all)]
#[cfg(test)]
mod tests {
use super::super::frame::{Frame, FrameCodec, Opcode, WsError};
use crate::bytes::{Bytes, BytesMut};
use crate::codec::{Decoder, Encoder};
use crate::util::EntropySource;
/// Reproducible entropy for tests: hands out bytes from a fixed 16-byte table,
/// walking it in order from a cursor and wrapping back to the start.
#[derive(Debug)]
struct DeterministicEntropy {
    /// Bytes that are replayed cyclically.
    sequence: [u8; 16],
    /// Position of the next byte to hand out; always used modulo 16.
    counter: std::sync::atomic::AtomicUsize,
}

impl Clone for DeterministicEntropy {
    /// Copies the table and snapshots the current cursor into a fresh atomic
    /// (`AtomicUsize` itself does not implement `Clone`).
    fn clone(&self) -> Self {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let cursor = self.counter.load(Ordering::Relaxed);
        Self {
            sequence: self.sequence,
            counter: AtomicUsize::new(cursor),
        }
    }
}

impl DeterministicEntropy {
    /// Builds a table whose entry `i` is `low_byte(seed) ^ i`.
    ///
    /// Only the low 8 bits of `seed` influence the table; the cursor starts at 0.
    fn new(seed: u64) -> Self {
        let low_byte = (seed & 0xFF) as u8;
        let mut sequence = [0u8; 16];
        for (position, slot) in (0u8..).zip(sequence.iter_mut()) {
            *slot = low_byte ^ position;
        }
        Self {
            sequence,
            counter: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// Returns the next four table bytes (wrapping past the end) and advances
    /// the cursor by 4.
    fn next_key(&self) -> [u8; 4] {
        use std::sync::atomic::Ordering;
        let start = self.counter.fetch_add(4, Ordering::Relaxed) % 16;
        let mut key = [0u8; 4];
        for (step, out) in key.iter_mut().enumerate() {
            *out = self.sequence[(start + step) % 16];
        }
        key
    }
}
impl EntropySource for DeterministicEntropy {
    /// Copies the next `dest.len()` bytes of the cyclic table into `dest` and
    /// advances the cursor by the same amount.
    ///
    /// The start position is taken from the value returned by a single `fetch_add`.
    /// Reading the cursor and advancing it therefore happen in one atomic step. With a
    /// separate `load` followed by `fetch_add`, two callers sharing this instance could
    /// be handed overlapping ranges.
    fn fill_bytes(&self, dest: &mut [u8]) {
        let start = self
            .counter
            .fetch_add(dest.len(), std::sync::atomic::Ordering::Relaxed);
        for (offset, byte) in dest.iter_mut().enumerate() {
            // `fetch_add` wraps on overflow; 2^usize::BITS is a multiple of 16, so a
            // wrapping add keeps the index consistent with the cycle.
            *byte = self.sequence[start.wrapping_add(offset) % 16];
        }
    }

    /// Little-endian `u64` assembled from the next 8 table bytes.
    fn next_u64(&self) -> u64 {
        let mut bytes = [0u8; 8];
        self.fill_bytes(&mut bytes);
        u64::from_le_bytes(bytes)
    }

    /// Returns an independent copy (same table, same cursor position); the task
    /// id is ignored.
    fn fork(&self, _task_id: crate::types::TaskId) -> std::sync::Arc<dyn EntropySource> {
        std::sync::Arc::new(self.clone())
    }

    fn source_id(&self) -> &'static str {
        "deterministic"
    }
}
/// Returns the 4-byte masking-key of the frame whose header starts at `encoded[0]`.
///
/// Returns `None` when the buffer is shorter than two bytes, when the MASK bit
/// (0x80 of byte 1) is clear, or when the buffer ends before the key. The key
/// follows the 2 fixed header bytes plus the extended payload length, which is
/// absent for 7-bit lengths (0..=125), 2 bytes for code 126 and 8 bytes for code 127.
fn extract_mask_key(encoded: &[u8]) -> Option<[u8; 4]> {
    let len_byte = *encoded.get(1)?;
    if (len_byte & 0x80) == 0 {
        return None;
    }
    let key_start: usize = match len_byte & 0x7F {
        126 => 2 + 2,
        127 => 2 + 8,
        _ => 2,
    };
    let key = encoded.get(key_start..key_start + 4)?;
    Some([key[0], key[1], key[2], key[3]])
}
/// `true` when `encoded` holds at least two bytes and the MASK bit (0x80 of
/// the second header byte) is set.
fn has_mask_bit(encoded: &[u8]) -> bool {
    match encoded {
        [_, second, ..] => (second & 0x80) != 0,
        _ => false,
    }
}
#[test]
fn client_text_frame_must_be_masked() {
    let text = "Hello, WebSocket!";
    let mut codec = FrameCodec::client();
    let frame = Frame::text(text);
    let mut buf = BytesMut::new();
    codec.encode(frame, &mut buf).unwrap();
    assert!(has_mask_bit(&buf), "Client text frame missing mask bit");
    let mask_key = extract_mask_key(&buf);
    assert!(mask_key.is_some(), "Client text frame missing masking-key");
    // The payload follows 2 fixed header bytes, the extended length (0, 2 or 8 bytes
    // for length codes <=125, 126 and 127 respectively) and the 4-byte masking-key.
    let payload_start = match buf[1] & 0x7F {
        0..=125 => 2 + 4,
        126 => 2 + 2 + 4,
        _ => 2 + 8 + 4,
    };
    let masked_payload = &buf[payload_start..];
    assert_ne!(masked_payload, text.as_bytes(), "Payload not masked");
    // The payload must be masked with the advertised key: XOR-ing it back
    // (byte i with key[i % 4]) has to yield the original text.
    let key = mask_key.expect("presence asserted above");
    let unmasked: Vec<u8> = masked_payload
        .iter()
        .zip(key.iter().cycle())
        .map(|(byte, key_byte)| byte ^ key_byte)
        .collect();
    assert_eq!(
        unmasked,
        text.as_bytes(),
        "Payload does not unmask to the original text"
    );
}
#[test]
fn client_binary_frame_must_be_masked() {
    let original: Vec<u8> = vec![0x00, 0x01, 0x02, 0xFF, 0xAA, 0xBB];
    let mut client = FrameCodec::client();
    let mut wire = BytesMut::new();
    client
        .encode(Frame::binary(Bytes::copy_from_slice(&original)), &mut wire)
        .unwrap();

    assert!(has_mask_bit(&wire), "Client binary frame missing mask bit");
    let key = extract_mask_key(&wire);
    assert!(key.is_some(), "Client binary frame missing masking-key");

    // Short payload (< 126 bytes): 2 header bytes + 4 key bytes precede it.
    let body = &wire[6..];
    assert_ne!(body, &original, "Binary payload not masked");
}
#[test]
fn client_ping_frame_must_be_masked() {
    let mut client = FrameCodec::client();
    let mut wire = BytesMut::new();
    client.encode(Frame::ping("ping-test"), &mut wire).unwrap();

    assert!(has_mask_bit(&wire), "Client ping frame missing mask bit");
    let key = extract_mask_key(&wire);
    assert!(key.is_some(), "Client ping frame missing masking-key");

    // A server must decode and unmask it back to the original ping payload.
    let mut server = FrameCodec::server();
    let mut inbound = BytesMut::from(wire.as_ref());
    let decoded = server.decode(&mut inbound).unwrap().unwrap();
    assert_eq!(decoded.opcode, Opcode::Ping);
    assert_eq!(decoded.payload.as_ref(), b"ping-test");
}
#[test]
fn client_pong_frame_must_be_masked() {
    let mut client = FrameCodec::client();
    let mut wire = BytesMut::new();
    client.encode(Frame::pong("pong-response"), &mut wire).unwrap();

    // Control frames are masked just like data frames when sent by a client.
    assert!(has_mask_bit(&wire), "Client pong frame missing mask bit");
    let key = extract_mask_key(&wire);
    assert!(key.is_some(), "Client pong frame missing masking-key");
}
#[test]
fn client_close_frame_must_be_masked() {
    let mut client = FrameCodec::client();
    let mut wire = BytesMut::new();
    client
        .encode(Frame::close(Some(1000), Some("goodbye")), &mut wire)
        .unwrap();

    assert!(has_mask_bit(&wire), "Client close frame missing mask bit");
    assert!(
        extract_mask_key(&wire).is_some(),
        "Client close frame missing masking-key"
    );

    // After unmasking, the close body is a 2-byte big-endian status code
    // followed by the UTF-8 reason text.
    let mut server = FrameCodec::server();
    let mut inbound = BytesMut::from(wire.as_ref());
    let decoded = server.decode(&mut inbound).unwrap().unwrap();
    assert_eq!(decoded.opcode, Opcode::Close);

    let body = decoded.payload;
    assert!(body.len() >= 2);
    let status = u16::from_be_bytes([body[0], body[1]]);
    assert_eq!(status, 1000);
    assert_eq!(std::str::from_utf8(&body[2..]).unwrap(), "goodbye");
}
#[test]
fn server_frames_must_not_be_masked() {
    let mut server = FrameCodec::server();
    let outgoing = [
        Frame::text("server message"),
        Frame::binary(vec![1, 2, 3, 4]),
        Frame::ping("server ping"),
        Frame::pong("server pong"),
        Frame::close(Some(1000), Some("server close")),
    ];

    // Neither the MASK bit nor a masking-key may appear in server output.
    for frame in outgoing.iter() {
        let mut wire = BytesMut::new();
        server.encode(frame.clone(), &mut wire).unwrap();
        assert!(
            !has_mask_bit(&wire),
            "Server frame incorrectly masked: {frame:?}"
        );
        assert_eq!(
            extract_mask_key(&wire),
            None,
            "Server frame has unexpected masking-key: {frame:?}"
        );
    }
}
#[test]
fn server_rejects_unmasked_client_frames() {
    // 0x81 = FIN + text opcode; 0x05 = MASK bit clear, 5-byte length; then raw payload.
    let mut inbound = BytesMut::new();
    inbound.put_u8(0x81);
    inbound.put_u8(0x05);
    inbound.put_slice(b"hello");

    let mut server_codec = FrameCodec::server();
    let result = server_codec.decode(&mut inbound);
    assert!(
        matches!(result, Err(WsError::UnmaskedClientFrame)),
        "Server must reject unmasked client frame, got: {result:?}"
    );
}
#[test]
fn client_rejects_masked_server_frames() {
    // 0x81 = FIN + text opcode; 0x85 = MASK bit set, 5-byte length;
    // then a 4-byte masking-key and the payload.
    let mut inbound = BytesMut::new();
    inbound.put_u8(0x81);
    inbound.put_u8(0x85);
    inbound.put_slice(&[0x12, 0x34, 0x56, 0x78]);
    inbound.put_slice(b"hello");

    let mut client_codec = FrameCodec::client();
    let result = client_codec.decode(&mut inbound);
    assert!(
        matches!(result, Err(WsError::MaskedServerFrame)),
        "Client must reject masked server frame, got: {result:?}"
    );
}
#[test]
fn masking_algorithm_xor_correctness() {
    let original = b"WebSocket masking test with longer payload to exercise all key positions!";
    let mask_key: [u8; 4] = [0x37, 0xFA, 0x21, 0x3D];
    // Reference masking: byte i is XOR-ed with key byte i % 4.
    let expected_masked: Vec<u8> = original
        .iter()
        .zip(mask_key.iter().cycle())
        .map(|(byte, key_byte)| byte ^ key_byte)
        .collect();

    // Exercises the seeded constructor; the key it yields is not used below.
    let _ = DeterministicEntropy::new(0x123456789ABCDEF0).next_key();

    // Entropy whose 16-byte table is `mask_key` repeated four times, so drawing
    // from the start of the table yields `mask_key`.
    let mut table = [0u8; 16];
    for (slot, key_byte) in table.iter_mut().zip(mask_key.iter().cycle()) {
        *slot = *key_byte;
    }
    let entropy_with_target_key = DeterministicEntropy {
        sequence: table,
        counter: std::sync::atomic::AtomicUsize::new(0),
    };

    let codec = FrameCodec::client();
    let frame = Frame::text(std::str::from_utf8(original).unwrap());
    let mut buf = BytesMut::new();
    codec
        .encode_with_entropy(&frame, &mut buf, &entropy_with_target_key)
        .unwrap();

    let actual_key = extract_mask_key(&buf).unwrap();
    assert_eq!(actual_key, mask_key, "Mask key mismatch");
    // Payload is < 126 bytes, so 2 header bytes + 4 key bytes precede it.
    let actual_masked = &buf[6..];
    assert_eq!(
        actual_masked, &expected_masked,
        "Masked payload doesn't match RFC 6455 algorithm"
    );

    // A server must undo the masking exactly.
    let mut server_codec = FrameCodec::server();
    let mut decode_buf = BytesMut::from(buf.as_ref());
    let decoded = server_codec.decode(&mut decode_buf).unwrap().unwrap();
    assert_eq!(
        decoded.payload.as_ref(),
        original,
        "Unmasking failed to restore original"
    );
}
#[test]
fn masking_involution_property() {
    use super::super::frame::apply_mask;

    // Lengths 0..=5 cover every key-offset boundary; the last case mixes bit patterns.
    let test_cases = [
        b"" as &[u8],
        b"A",
        b"AB",
        b"ABC",
        b"ABCD",
        b"ABCDE",
        b"Hello, WebSocket world!",
        &[0x00, 0xFF, 0x80, 0x7F, 0x55, 0xAA],
    ];
    let key = [0x12, 0x34, 0x56, 0x78];

    for original in test_cases.iter().copied() {
        let snapshot = original.to_vec();
        let mut payload = snapshot.clone();

        apply_mask(&mut payload, key);
        // Every key byte is non-zero, so any non-empty input must change.
        if !original.is_empty() {
            assert_ne!(payload, snapshot, "Masking should change non-empty payload");
        }

        apply_mask(&mut payload, key);
        assert_eq!(
            payload, snapshot,
            "Double masking should restore original for: {original:?}"
        );
    }
}
#[test]
fn client_uses_fresh_mask_keys() {
    let mut codec = FrameCodec::client();
    let mut used_keys = std::collections::HashSet::new();

    for i in 0..20 {
        let mut buf = BytesMut::new();
        codec
            .encode(Frame::text(format!("message {i}")), &mut buf)
            .unwrap();
        let mask_key = extract_mask_key(&buf).expect("Frame should have mask key");
        // `insert` returns false when the key was already present.
        let first_use = used_keys.insert(mask_key);
        assert!(first_use, "Mask key reused: {mask_key:?} (frame {i})");
    }

    assert_eq!(used_keys.len(), 20, "All mask keys should be unique");
}
#[test]
fn mask_keys_have_sufficient_entropy() {
    // With 4 * SAMPLES key bytes from a uniform source, a given byte value is absent
    // with probability (255/256)^(4 * SAMPLES). For 1000 samples that is ~1.6e-7,
    // so demanding MIN_DISTINCT distinct values never trips on a healthy generator
    // but still catches constant or low-entropy keys. At 100 samples the expected
    // distinct count is only ~202 (sd ~5), too close to any useful floor.
    const SAMPLES: usize = 1000;
    const MIN_DISTINCT: usize = 250;

    let mut codec = FrameCodec::client();
    let mut byte_counts = [0usize; 256];
    for i in 0..SAMPLES {
        // Payload content is irrelevant; truncating `i` merely varies it.
        let frame = Frame::binary(vec![i as u8; 10]);
        let mut buf = BytesMut::new();
        codec.encode(frame, &mut buf).unwrap();
        let mask_key = extract_mask_key(&buf).expect("Frame should have mask key");
        for &byte in &mask_key {
            byte_counts[usize::from(byte)] += 1;
        }
    }

    let unique_bytes = byte_counts.iter().filter(|&&count| count > 0).count();
    assert!(
        unique_bytes >= MIN_DISTINCT,
        "Poor entropy in mask keys: only {unique_bytes}/256 byte values seen"
    );
}
#[test]
fn client_server_roundtrip_preserves_payload() {
    let mut client_codec = FrameCodec::client();
    let mut server_codec = FrameCodec::server();
    let large_text = "A".repeat(1000);
    // (payload, opcode) pairs. Each frame is built to match its opcode so the
    // decoded opcode can be checked against it.
    let test_payloads = vec![
        (String::new(), Opcode::Text),
        ("Hello".to_string(), Opcode::Text),
        (
            "WebSocket test with special chars: üñíçødé".to_string(),
            Opcode::Text,
        ),
        // 1000 bytes forces the 16-bit extended length.
        (large_text, Opcode::Text),
        (String::new(), Opcode::Binary),
    ];
    let binary_payloads = [
        vec![],
        vec![0x00, 0x01, 0x02, 0xFF],
        vec![0x00; 500],
    ];

    for (payload, opcode) in &test_payloads {
        let frame = match opcode {
            Opcode::Binary => Frame::binary(payload.clone().into_bytes()),
            _ => Frame::text(payload.clone()),
        };
        let mut buf = BytesMut::new();
        client_codec.encode(frame, &mut buf).unwrap();
        assert!(has_mask_bit(&buf), "Client frame should be masked");
        let decoded = server_codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(
            decoded.opcode, *opcode,
            "Opcode mismatch for payload: {payload:?}"
        );
        assert_eq!(
            decoded.payload.as_ref(),
            payload.as_bytes(),
            "Payload mismatch for: {payload:?}"
        );
        assert!(
            decoded.masked,
            "Decoded frame should indicate it was masked"
        );
        assert!(
            decoded.mask_key.is_some(),
            "Decoded frame should have mask key"
        );
    }

    for payload_data in &binary_payloads {
        let frame = Frame::binary(payload_data.clone());
        let mut buf = BytesMut::new();
        client_codec.encode(frame, &mut buf).unwrap();
        assert!(has_mask_bit(&buf), "Client frame should be masked");
        let decoded = server_codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(
            decoded.opcode,
            Opcode::Binary,
            "Opcode mismatch for binary payload"
        );
        assert_eq!(
            decoded.payload.as_ref(),
            payload_data.as_slice(),
            "Binary payload mismatch"
        );
        assert!(
            decoded.masked,
            "Decoded frame should indicate it was masked"
        );
        assert!(
            decoded.mask_key.is_some(),
            "Decoded frame should have mask key"
        );
    }
}
#[test]
fn server_client_roundtrip_no_masking() {
    let mut server_codec = FrameCodec::server();
    let mut client_codec = FrameCodec::client();
    let frames = [
        Frame::text("server to client"),
        Frame::binary(b"binary data".to_vec()),
        Frame::ping("ping from server"),
        Frame::pong("pong from server"),
    ];

    // Server output travels unmasked, and the client must decode it unchanged.
    for frame in frames.iter() {
        let mut wire = BytesMut::new();
        server_codec.encode(frame.clone(), &mut wire).unwrap();
        assert!(!has_mask_bit(&wire), "Server frame should not be masked");

        let decoded = client_codec.decode(&mut wire).unwrap().unwrap();
        assert_eq!(decoded.opcode, frame.opcode);
        assert_eq!(decoded.payload, frame.payload);
        assert!(
            !decoded.masked,
            "Server frame should not be marked as masked"
        );
        assert_eq!(
            decoded.mask_key, None,
            "Server frame should have no mask key"
        );
    }
}
#[test]
fn masking_empty_payload() {
    let mut client = FrameCodec::client();
    let mut wire = BytesMut::new();
    client.encode(Frame::text(""), &mut wire).unwrap();

    // The MASK bit and key are required even with nothing to mask.
    assert!(has_mask_bit(&wire), "Empty frame should still have mask bit");
    assert!(
        extract_mask_key(&wire).is_some(),
        "Empty frame should still have mask key"
    );

    let mut server = FrameCodec::server();
    let decoded = server.decode(&mut wire).unwrap().unwrap();
    assert_eq!(decoded.payload.len(), 0);
    assert!(decoded.masked);
}
#[test]
fn masking_large_payload() {
    // 70_000 bytes exceeds u16::MAX, so the 64-bit extended length (code 127) is used.
    let payload_size = 70_000;
    let large_payload = "X".repeat(payload_size);

    let mut client = FrameCodec::client();
    let mut wire = BytesMut::new();
    client.encode(Frame::text(large_payload.clone()), &mut wire).unwrap();
    assert!(has_mask_bit(&wire), "Large frame should be masked");

    let mut server = FrameCodec::server();
    let decoded = server.decode(&mut wire).unwrap().unwrap();
    assert_eq!(decoded.payload.len(), payload_size);
    assert_eq!(decoded.payload.as_ref(), large_payload.as_bytes());
    assert!(decoded.masked);
}
#[test]
fn control_frame_masking_with_max_payload() {
    // 125 bytes is the largest payload a control frame may carry.
    let max_control_payload = "A".repeat(125);

    let mut client_codec = FrameCodec::client();
    let mut wire = BytesMut::new();
    client_codec
        .encode(Frame::ping(max_control_payload.clone()), &mut wire)
        .unwrap();
    assert!(
        has_mask_bit(&wire),
        "Max-size control frame should be masked"
    );

    let mut server_codec = FrameCodec::server();
    let decoded = server_codec.decode(&mut wire).unwrap().unwrap();
    assert_eq!(decoded.opcode, Opcode::Ping);
    assert_eq!(decoded.payload.len(), 125);
    assert_eq!(decoded.payload.as_ref(), max_control_payload.as_bytes());
    assert!(decoded.masked);
}
#[test]
fn server_close_for_unmasked_frame_uses_protocol_error() {
    use super::super::frame::CloseCode;

    // An unmasked client frame must map to the protocol-error close code.
    let close_code = WsError::UnmaskedClientFrame.as_close_code();
    assert_eq!(close_code as u16, CloseCode::ProtocolError as u16);
}
#[test]
fn protocol_error_close_code_is_valid() {
    use super::super::frame::CloseCode;

    // The code must be both sendable and recognised as valid on the wire.
    assert!(CloseCode::ProtocolError.is_sendable());
    assert!(CloseCode::is_valid_code(CloseCode::ProtocolError as u16));

    // A close frame carrying 1002 can be constructed.
    let frame = Frame::close(Some(1002), Some("Protocol Error"));
    assert_eq!(frame.opcode, Opcode::Close);
}
#[test]
fn rfc_6455_example_masking_vector() {
    use super::super::frame::apply_mask;

    // Example from RFC 6455 §5.7: "Hello" masked with key 37 fa 21 3d.
    let mask_key = [0x37, 0xfa, 0x21, 0x3d];
    let expected_masked = [0x7f, 0x9f, 0x4d, 0x51, 0x58];
    let mut payload = b"Hello".to_vec();

    apply_mask(&mut payload, mask_key);
    assert_eq!(payload, expected_masked, "RFC 6455 test vector failed");

    // Masking is an involution: applying it again restores the plaintext.
    apply_mask(&mut payload, mask_key);
    assert_eq!(payload, b"Hello");
}
#[test]
fn comprehensive_masking_conformance_validation() {
    let mut client = FrameCodec::client();
    let mut server = FrameCodec::server();

    // Client -> server: every frame type carries the MASK bit and a key,
    // and survives decoding unchanged.
    let client_frames = [
        Frame::text("client text"),
        Frame::binary(b"client binary".to_vec()),
        Frame::ping("client ping"),
        Frame::pong("client pong"),
        Frame::close(Some(1000), Some("client close")),
    ];
    for frame in client_frames.iter() {
        let mut wire = BytesMut::new();
        client.encode(frame.clone(), &mut wire).unwrap();
        assert!(
            has_mask_bit(&wire),
            "❌ Client frame missing mask bit: {frame:?}"
        );
        assert!(
            extract_mask_key(&wire).is_some(),
            "❌ Client frame missing mask key: {frame:?}"
        );

        let mut inbound = BytesMut::from(wire.as_ref());
        let decoded = server.decode(&mut inbound).unwrap().unwrap();
        assert_eq!(
            decoded.opcode, frame.opcode,
            "❌ Opcode mismatch: {frame:?}"
        );
        assert_eq!(
            decoded.payload, frame.payload,
            "❌ Payload mismatch: {frame:?}"
        );
    }

    // Server -> client: no MASK bit and no key; a fresh client decoder per
    // frame must still decode it unchanged.
    let server_frames = [
        Frame::text("server text"),
        Frame::binary(b"server binary".to_vec()),
        Frame::ping("server ping"),
        Frame::pong("server pong"),
        Frame::close(Some(1000), Some("server close")),
    ];
    for frame in server_frames.iter() {
        let mut wire = BytesMut::new();
        server.encode(frame.clone(), &mut wire).unwrap();
        assert!(
            !has_mask_bit(&wire),
            "❌ Server frame incorrectly masked: {frame:?}"
        );
        assert_eq!(
            extract_mask_key(&wire),
            None,
            "❌ Server frame has mask key: {frame:?}"
        );

        let mut fresh_client = FrameCodec::client();
        let mut inbound = BytesMut::from(wire.as_ref());
        let decoded = fresh_client.decode(&mut inbound).unwrap().unwrap();
        assert_eq!(
            decoded.opcode, frame.opcode,
            "❌ Opcode mismatch: {frame:?}"
        );
        assert_eq!(
            decoded.payload, frame.payload,
            "❌ Payload mismatch: {frame:?}"
        );
    }
}
}