use super::frame::Role;
/// Reasons a frame could not be encoded.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EncodeError {
    /// A control frame (ping, pong or close) was asked to carry this many
    /// payload bytes, which is over the 125-byte control-frame limit. For close
    /// frames the count includes the 2-byte status code.
    ControlPayloadTooLarge(usize),
}

impl std::fmt::Display for EncodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::ControlPayloadTooLarge(len) => write!(
                f,
                "control frame payload too large: {len} bytes (max 125)"
            ),
        }
    }
}

impl std::error::Error for EncodeError {}
/// An encoded WebSocket frame header stored inline, without heap allocation.
///
/// A header is at most 14 bytes:
/// - the 2 fixed bytes;
/// - an optional 2- or 8-byte extended payload length;
/// - an optional 4-byte masking key.
///
/// Only the first [`len`](Self::len) bytes of the backing array are meaningful.
#[derive(Clone, Copy)]
pub struct FrameHeader {
    // Backing storage; bytes at index `len` and beyond are padding.
    bytes: [u8; 14],
    // Number of valid bytes at the front of `bytes`.
    len: u8,
}

impl FrameHeader {
    /// Returns the encoded header bytes.
    pub fn as_bytes(&self) -> &[u8] {
        &self.bytes[..usize::from(self.len)]
    }

    /// Returns the number of encoded header bytes.
    pub fn len(&self) -> usize {
        usize::from(self.len)
    }

    /// Returns `true` if the header holds no bytes.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }
}

impl std::fmt::Debug for FrameHeader {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Print only the live bytes; the tail of the fixed array is padding
        // and would make the output misleading.
        f.debug_struct("FrameHeader")
            .field("bytes", &self.as_bytes())
            .finish()
    }
}
pub struct FrameWriter {
role: Role,
}
impl FrameWriter {
#[must_use]
pub fn new(role: Role) -> Self {
Self { role }
}
pub fn encode_text(&self, payload: &[u8], dst: &mut [u8]) -> usize {
self.encode(0x81, payload, dst) }
pub fn encode_binary(&self, payload: &[u8], dst: &mut [u8]) -> usize {
self.encode(0x82, payload, dst) }
pub fn encode_ping(&self, payload: &[u8], dst: &mut [u8]) -> Result<usize, EncodeError> {
if payload.len() > 125 {
return Err(EncodeError::ControlPayloadTooLarge(payload.len()));
}
Ok(self.encode(0x89, payload, dst)) }
pub fn encode_pong(&self, payload: &[u8], dst: &mut [u8]) -> Result<usize, EncodeError> {
if payload.len() > 125 {
return Err(EncodeError::ControlPayloadTooLarge(payload.len()));
}
Ok(self.encode(0x8A, payload, dst)) }
pub fn encode_close(
&self,
code: u16,
reason: &[u8],
dst: &mut [u8],
) -> Result<usize, EncodeError> {
let payload_len = 2 + reason.len();
if payload_len > 125 {
return Err(EncodeError::ControlPayloadTooLarge(payload_len));
}
let mut close_payload = [0u8; 125];
close_payload[..2].copy_from_slice(&code.to_be_bytes());
close_payload[2..payload_len].copy_from_slice(reason);
Ok(self.encode(0x88, &close_payload[..payload_len], dst))
}
#[must_use]
pub fn max_encoded_len(&self, payload_len: usize) -> usize {
let header = if payload_len <= 125 {
2
} else if payload_len <= 65535 {
4
} else {
10
};
let mask = if self.role == Role::Client { 4 } else { 0 };
header + mask + payload_len
}
pub fn encode_empty_close(&self, dst: &mut [u8]) -> usize {
self.encode(0x88, &[], dst) }
pub fn encode_close_code(
&self,
code: super::message::CloseCode,
reason: &str,
dst: &mut [u8],
) -> Result<usize, EncodeError> {
assert!(
code != super::message::CloseCode::NoStatus,
"CloseCode::NoStatus cannot be sent on the wire โ use encode_empty_close()"
);
self.encode_close(code.as_u16(), reason.as_bytes(), dst)
}
pub fn build_header(&self, byte0: u8, payload_len: usize) -> (FrameHeader, Option<[u8; 4]>) {
let mask_bit: u8 = if self.role == Role::Client { 0x80 } else { 0 };
let mut hdr = FrameHeader {
bytes: [0; 14],
len: 0,
};
hdr.bytes[0] = byte0;
hdr.len = 1;
if payload_len <= 125 {
hdr.bytes[1] = mask_bit | (payload_len as u8);
hdr.len = 2;
} else if payload_len <= 65535 {
hdr.bytes[1] = mask_bit | 0x7E;
hdr.bytes[2..4].copy_from_slice(&(payload_len as u16).to_be_bytes());
hdr.len = 4;
} else {
hdr.bytes[1] = mask_bit | 0x7F;
hdr.bytes[2..10].copy_from_slice(&(payload_len as u64).to_be_bytes());
hdr.len = 10;
}
let mask_key = if self.role == Role::Client {
let mask = generate_mask();
hdr.bytes[hdr.len as usize..hdr.len as usize + 4].copy_from_slice(&mask);
hdr.len += 4;
Some(mask)
} else {
None
};
(hdr, mask_key)
}
pub fn encode_text_into(&self, payload: &[u8], dst: &mut crate::buf::WriteBuf) {
self.encode_into(0x81, payload, dst);
}
pub fn encode_binary_into(&self, payload: &[u8], dst: &mut crate::buf::WriteBuf) {
self.encode_into(0x82, payload, dst);
}
pub fn encode_ping_into(
&self,
payload: &[u8],
dst: &mut crate::buf::WriteBuf,
) -> Result<(), EncodeError> {
if payload.len() > 125 {
return Err(EncodeError::ControlPayloadTooLarge(payload.len()));
}
self.encode_into(0x89, payload, dst);
Ok(())
}
pub fn encode_pong_into(
&self,
payload: &[u8],
dst: &mut crate::buf::WriteBuf,
) -> Result<(), EncodeError> {
if payload.len() > 125 {
return Err(EncodeError::ControlPayloadTooLarge(payload.len()));
}
self.encode_into(0x8A, payload, dst);
Ok(())
}
pub fn encode_close_into(
&self,
code: u16,
reason: &[u8],
dst: &mut crate::buf::WriteBuf,
) -> Result<(), EncodeError> {
let payload_len = 2 + reason.len();
if payload_len > 125 {
return Err(EncodeError::ControlPayloadTooLarge(payload_len));
}
dst.clear();
dst.append(&code.to_be_bytes());
dst.append(reason);
let (hdr, mask_key) = self.build_header(0x88, payload_len);
if let Some(mask) = mask_key {
super::mask::apply_mask(dst.data_mut(), mask);
}
dst.prepend(hdr.as_bytes());
Ok(())
}
fn encode_into(&self, byte0: u8, payload: &[u8], dst: &mut crate::buf::WriteBuf) {
dst.clear();
dst.append(payload);
let (hdr, mask_key) = self.build_header(byte0, payload.len());
if let Some(mask) = mask_key {
super::mask::apply_mask(dst.data_mut(), mask);
}
dst.prepend(hdr.as_bytes());
}
fn encode(&self, byte0: u8, payload: &[u8], dst: &mut [u8]) -> usize {
let mask_bit: u8 = if self.role == Role::Client { 0x80 } else { 0 };
let payload_len = payload.len();
let mut offset = 0;
dst[offset] = byte0;
offset += 1;
if payload_len <= 125 {
dst[offset] = mask_bit | (payload_len as u8);
offset += 1;
} else if payload_len <= 65535 {
dst[offset] = mask_bit | 0x7E;
offset += 1;
dst[offset..offset + 2].copy_from_slice(&(payload_len as u16).to_be_bytes());
offset += 2;
} else {
dst[offset] = mask_bit | 0x7F;
offset += 1;
dst[offset..offset + 8].copy_from_slice(&(payload_len as u64).to_be_bytes());
offset += 8;
}
if self.role == Role::Client {
let mask = generate_mask();
dst[offset..offset + 4].copy_from_slice(&mask);
offset += 4;
dst[offset..offset + payload_len].copy_from_slice(payload);
super::mask::apply_mask(&mut dst[offset..offset + payload_len], mask);
} else {
dst[offset..offset + payload_len].copy_from_slice(payload);
}
offset + payload_len
}
}
/// Returns a fresh 4-byte masking key read from the operating system's random
/// source through `getrandom`.
///
/// Client masking keys must be unpredictable (RFC 6455 section 5.3). That is
/// why the OS source is used rather than a fast userspace PRNG.
///
/// # Panics
/// Panics if the OS random source cannot be read.
fn generate_mask() -> [u8; 4] {
    let mut key = [0u8; 4];
    getrandom::fill(&mut key[..]).expect("OS randomness unavailable");
    key
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encode_text_server() {
        let writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(5)];
        let n = writer.encode_text(b"Hello", &mut dst);
        assert_eq!(n, 7);
        assert_eq!(dst[0], 0x81);
        assert_eq!(dst[1], 0x05);
        assert_eq!(&dst[2..7], b"Hello");
    }

    #[test]
    fn encode_binary_server() {
        let writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(4)];
        let n = writer.encode_binary(&[0xDE, 0xAD, 0xBE, 0xEF], &mut dst);
        assert_eq!(n, 6);
        assert_eq!(dst[0], 0x82);
        assert_eq!(dst[1], 0x04);
        assert_eq!(&dst[2..6], &[0xDE, 0xAD, 0xBE, 0xEF]);
    }

    #[test]
    fn encode_close_server() {
        let writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(9)];
        let n = writer.encode_close(1000, b"goodbye", &mut dst).unwrap();
        assert_eq!(n, 2 + 9);
        assert_eq!(dst[0], 0x88);
        assert_eq!(dst[1], 9);
        assert_eq!(&dst[2..4], &1000u16.to_be_bytes());
        assert_eq!(&dst[4..n], b"goodbye");
    }

    #[test]
    fn encode_empty_close_server() {
        let writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(0)];
        let n = writer.encode_empty_close(&mut dst);
        assert_eq!(n, 2);
        assert_eq!(&dst[..2], &[0x88, 0x00]);
    }

    #[test]
    fn encode_ping_server() {
        let writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(4)];
        let n = writer.encode_ping(b"ping", &mut dst).unwrap();
        assert_eq!(n, 6);
        assert_eq!(dst[0], 0x89);
        assert_eq!(&dst[2..n], b"ping");
    }

    #[test]
    fn encode_pong_server() {
        let writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(4)];
        let n = writer.encode_pong(b"pong", &mut dst).unwrap();
        assert_eq!(n, 6);
        assert_eq!(dst[0], 0x8A);
        assert_eq!(&dst[2..n], b"pong");
    }

    #[test]
    fn encode_client_is_masked() {
        let writer = FrameWriter::new(Role::Client);
        let payload = b"Hello";
        let mut dst = vec![0u8; writer.max_encoded_len(payload.len())];
        let n = writer.encode_text(payload, &mut dst);
        assert_eq!(n, 11);
        assert_eq!(dst[0], 0x81);
        assert_eq!(dst[1] & 0x80, 0x80);
        assert_eq!(dst[1] & 0x7F, 5);
        // Unmask with the transmitted key. Asserting that the masked bytes
        // differ from the plaintext would fail for an all-zero random key.
        let key = [dst[2], dst[3], dst[4], dst[5]];
        let unmasked: Vec<u8> = dst[6..n]
            .iter()
            .enumerate()
            .map(|(i, b)| b ^ key[i % 4])
            .collect();
        assert_eq!(unmasked, payload);
    }

    #[test]
    fn encode_16bit_length() {
        let writer = FrameWriter::new(Role::Server);
        let payload = vec![0x42; 256];
        let mut dst = vec![0u8; writer.max_encoded_len(256)];
        let n = writer.encode_binary(&payload, &mut dst);
        assert_eq!(n, 4 + 256);
        assert_eq!(dst[1] & 0x7F, 126);
        let len = u16::from_be_bytes([dst[2], dst[3]]);
        assert_eq!(len, 256);
    }

    #[test]
    fn encode_64bit_length() {
        let writer = FrameWriter::new(Role::Server);
        let payload = vec![0x42; 65536];
        let mut dst = vec![0u8; writer.max_encoded_len(payload.len())];
        let n = writer.encode_binary(&payload, &mut dst);
        assert_eq!(n, 10 + 65536);
        assert_eq!(dst[1] & 0x7F, 127);
        let mut be = [0u8; 8];
        be.copy_from_slice(&dst[2..10]);
        assert_eq!(u64::from_be_bytes(be), 65536);
        assert!(dst[10..n].iter().all(|&b| b == 0x42));
    }

    #[test]
    fn max_encoded_len_small() {
        let server = FrameWriter::new(Role::Server);
        assert_eq!(server.max_encoded_len(0), 2);
        assert_eq!(server.max_encoded_len(125), 2 + 125);
        assert_eq!(server.max_encoded_len(126), 4 + 126);
        let client = FrameWriter::new(Role::Client);
        assert_eq!(client.max_encoded_len(0), 2 + 4);
        assert_eq!(client.max_encoded_len(125), 2 + 4 + 125);
    }

    #[test]
    fn max_encoded_len_extended_boundaries() {
        let server = FrameWriter::new(Role::Server);
        assert_eq!(server.max_encoded_len(65535), 4 + 65535);
        assert_eq!(server.max_encoded_len(65536), 10 + 65536);
        let client = FrameWriter::new(Role::Client);
        assert_eq!(client.max_encoded_len(65536), 10 + 4 + 65536);
    }

    #[test]
    fn build_header_server() {
        let writer = FrameWriter::new(Role::Server);
        let (hdr, mask) = writer.build_header(0x82, 300);
        assert!(mask.is_none());
        // 300 = 0x012C, encoded via the 16-bit extended length.
        assert_eq!(hdr.as_bytes(), &[0x82, 0x7E, 0x01, 0x2C]);
    }

    #[test]
    fn build_header_client_embeds_mask() {
        let writer = FrameWriter::new(Role::Client);
        let (hdr, mask) = writer.build_header(0x81, 5);
        let mask = mask.expect("client headers carry a masking key");
        assert_eq!(hdr.len(), 6);
        assert_eq!(&hdr.as_bytes()[..2], &[0x81, 0x85]);
        assert_eq!(&hdr.as_bytes()[2..6], &mask);
    }

    #[test]
    fn round_trip_server() {
        use crate::ws::{FrameReader, Message};
        let writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(5)];
        let n = writer.encode_text(b"Hello", &mut dst);
        let mut reader = FrameReader::builder().role(Role::Client).build();
        reader.read(&dst[..n]).unwrap();
        assert!(matches!(
            reader.next().unwrap().unwrap(),
            Message::Text("Hello")
        ));
    }

    #[test]
    fn round_trip_client() {
        use crate::ws::{FrameReader, Message};
        let writer = FrameWriter::new(Role::Client);
        let mut dst = vec![0u8; writer.max_encoded_len(5)];
        let n = writer.encode_text(b"Hello", &mut dst);
        let mut reader = FrameReader::builder().role(Role::Server).build();
        reader.read(&dst[..n]).unwrap();
        assert!(matches!(
            reader.next().unwrap().unwrap(),
            Message::Text("Hello")
        ));
    }

    #[test]
    fn encode_close_code_round_trip() {
        use crate::ws::{CloseCode, FrameReader, Message};
        let writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; 64];
        let n = writer
            .encode_close_code(CloseCode::Normal, "goodbye", &mut dst)
            .unwrap();
        let mut reader = FrameReader::builder().role(Role::Client).build();
        reader.read(&dst[..n]).unwrap();
        match reader.next().unwrap().unwrap() {
            Message::Close(cf) => {
                assert_eq!(cf.code, CloseCode::Normal);
                assert_eq!(cf.reason, "goodbye");
            }
            other => panic!("expected Close, got {other:?}"),
        }
    }

    #[test]
    fn ping_too_large_returns_err() {
        let writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; 256];
        assert!(matches!(
            writer.encode_ping(&[0; 126], &mut dst),
            Err(super::EncodeError::ControlPayloadTooLarge(126))
        ));
    }

    #[test]
    fn ping_at_limit_is_ok() {
        let writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(125)];
        let n = writer.encode_ping(&[0x11; 125], &mut dst).unwrap();
        assert_eq!(n, 2 + 125);
        assert_eq!(dst[1], 125);
    }

    #[test]
    fn pong_too_large_returns_err() {
        let writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; 256];
        assert_eq!(
            writer.encode_pong(&[0; 126], &mut dst),
            Err(EncodeError::ControlPayloadTooLarge(126))
        );
    }

    #[test]
    fn close_reason_limit_counts_status_code() {
        let writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; 256];
        // 124-byte reason + 2-byte code = 126: one byte over the limit.
        assert_eq!(
            writer.encode_close(1000, &[b'x'; 124], &mut dst),
            Err(EncodeError::ControlPayloadTooLarge(126))
        );
        // 123-byte reason + 2-byte code = 125: exactly the limit.
        let n = writer.encode_close(1000, &[b'x'; 123], &mut dst).unwrap();
        assert_eq!(n, 2 + 125);
    }
}