use bytes::BytesMut;
use std::io;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// Selects how [`translate_masking`] rewrites the masking of each frame it copies.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MaskMode {
    /// Every output frame is emitted unmasked (MASK bit cleared, no key).
    Unmask,
    /// Every output frame is emitted masked with a freshly generated key.
    Mask,
}
/// Decoded fields of one WebSocket frame header, as returned by [`parse_header`].
#[derive(Debug)]
#[derive(Clone, Copy)]
#[derive(PartialEq, Eq)]
pub(crate) struct FrameHeader {
    /// First header byte, kept verbatim (FIN, RSV1-3 and opcode bits).
    pub fin_rsv_opcode: u8,
    /// Whether the MASK bit of the second header byte was set.
    pub masked: bool,
    /// Masking key from the header; all zeroes when `masked` is false.
    pub mask_key: [u8; 4],
    /// Payload length in bytes, decoded from the 7-, 16- or 64-bit form.
    pub payload_len: u64,
    /// Bytes occupied by the header, including extended length and masking key.
    pub header_len: usize,
}
/// Decodes the WebSocket frame header at the start of `buf` (RFC 6455 §5.2).
///
/// A header consists of two fixed bytes, then an optional big-endian extended
/// payload length (2 bytes when the 7-bit length is 126, 8 bytes when it is
/// 127), then a 4-byte masking key if the MASK bit is set. Only the header is
/// examined; any bytes after it are ignored.
///
/// Returns `Ok(None)` when `buf` is too short to contain the complete header,
/// so the caller can read more input and try again. When the MASK bit is
/// clear the returned `mask_key` is all zeroes.
///
/// # Errors
///
/// Returns [`io::ErrorKind::InvalidData`] if the 64-bit extended length has
/// its most significant bit set.
pub(crate) fn parse_header(
    buf: &[u8],
) -> io::Result<Option<FrameHeader>> {
    let (fin_rsv_opcode, len_byte) = match buf {
        [first, second, ..] => (*first, *second),
        _ => return Ok(None),
    };
    let masked = (len_byte & 0x80) == 0x80;
    let mut header_len = 2usize;

    let payload_len = match len_byte & 0x7f {
        126 => {
            let ext = match buf.get(header_len..header_len + 2) {
                Some(ext) => ext,
                None => return Ok(None),
            };
            header_len += 2;
            u64::from(u16::from_be_bytes([ext[0], ext[1]]))
        }
        127 => {
            let ext = match buf.get(header_len..header_len + 8) {
                Some(ext) => ext,
                None => return Ok(None),
            };
            header_len += 8;
            let mut raw = [0u8; 8];
            raw.copy_from_slice(ext);
            let len = u64::from_be_bytes(raw);
            // RFC 6455 requires the top bit of the 64-bit length to be zero.
            if (len >> 63) == 1 {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "ws: 64-bit frame length has high bit set",
                ));
            }
            len
        }
        short => u64::from(short),
    };

    let mut mask_key = [0u8; 4];
    if masked {
        match buf.get(header_len..header_len + 4) {
            Some(key) => mask_key.copy_from_slice(key),
            None => return Ok(None),
        }
        header_len += 4;
    }

    Ok(Some(FrameHeader {
        fin_rsv_opcode,
        masked,
        mask_key,
        payload_len,
        header_len,
    }))
}
/// Appends a WebSocket frame header to `out` (RFC 6455 §5.2).
///
/// Writes `fin_rsv_opcode` unchanged, then the payload length in the shortest
/// encoding that fits: 7-bit inline (< 126), 16-bit big-endian (marker 126),
/// or 64-bit big-endian (marker 127). The MASK bit is set exactly when
/// `mask_key` is `Some`, and the key bytes follow the length.
///
/// Existing contents of `out` are left in place; the header is appended.
///
/// NOTE(review): `payload_len` is assumed to have its top bit clear (values
/// from `parse_header` always do). A larger value is emitted as-is, which
/// yields a header RFC 6455 forbids.
pub(crate) fn emit_header(
    out: &mut Vec<u8>,
    fin_rsv_opcode: u8,
    payload_len: u64,
    mask_key: Option<[u8; 4]>,
) {
    let mask_flag: u8 = match mask_key {
        Some(_) => 0x80,
        None => 0x00,
    };
    out.push(fin_rsv_opcode);
    match payload_len {
        0..=125 => out.push(mask_flag | payload_len as u8),
        126..=0xffff => {
            out.push(mask_flag | 126);
            out.extend_from_slice(&(payload_len as u16).to_be_bytes());
        }
        _ => {
            out.push(mask_flag | 127);
            out.extend_from_slice(&payload_len.to_be_bytes());
        }
    }
    if let Some(key) = mask_key {
        out.extend_from_slice(&key);
    }
}
/// Returns a fresh 4-byte masking key filled from `rand_core::OsRng`.
///
/// [`translate_masking`] calls this once per frame in [`MaskMode::Mask`].
///
/// NOTE(review): `RngCore::fill_bytes` has no error return; confirm against
/// the pinned rand_core version how an OS RNG failure is surfaced (likely a panic).
fn random_mask() -> [u8; 4] {
    use rand_core::{OsRng, RngCore};
    let mut key = [0u8; 4];
    RngCore::fill_bytes(&mut OsRng, &mut key);
    key
}
/// Where [`translate_masking`] currently is within the frame stream.
enum State {
    /// Expecting the next frame header.
    Header,
    /// Copying the payload of the frame whose header was just emitted.
    Payload {
        /// Payload bytes of the current frame not yet copied.
        remaining: u64,
        /// Incoming mask key XOR outgoing mask key; all zeroes means the
        /// payload passes through unchanged.
        eff: [u8; 4],
        /// Payload bytes of the current frame already copied; chooses
        /// which `eff` byte applies to the next payload byte.
        offset: u64,
    },
}
/// Copies WebSocket frames from `reader` to `writer`, rewriting each frame's
/// masking according to `mode`.
///
/// Each frame's header is re-emitted with the same first byte (FIN/RSV/opcode)
/// and payload length; only the MASK bit and masking key change:
///
/// * [`MaskMode::Mask`]: every output frame carries a fresh key from
///   `random_mask`.
/// * [`MaskMode::Unmask`]: every output frame is unmasked.
///
/// Input frames may be masked or not in either mode. The payload is XORed
/// once with `incoming_key ^ outgoing_key`, so a masked input is unmasked and
/// re-masked in a single pass. Frames are not otherwise validated: opcodes,
/// RSV bits and control-frame sizes pass through untouched.
///
/// Output is flushed before every blocking read, including the read that
/// observes EOF. A buffered `writer` therefore never holds translated frames
/// while this function waits for input, or after it returns `Ok(())`.
///
/// Returns `Ok(())` when `reader` reaches EOF exactly on a frame boundary.
///
/// # Errors
///
/// * [`io::ErrorKind::UnexpectedEof`] if input ends inside a header or payload.
/// * [`io::ErrorKind::InvalidData`] from `parse_header` for a 64-bit length
///   with its high bit set.
/// * Any I/O error from `reader` or `writer`.
pub async fn translate_masking<R, W>(
    reader: &mut R,
    writer: &mut W,
    mode: MaskMode,
) -> io::Result<()>
where
    R: AsyncRead + Unpin,
    W: AsyncWrite + Unpin,
{
    let mut acc = BytesMut::with_capacity(16 * 1024);
    let mut read_buf = vec![0u8; 64 * 1024];
    let mut state = State::Header;
    // True when bytes were handed to `writer` since its last flush.
    let mut unflushed = false;
    loop {
        match &mut state {
            State::Header => match parse_header(&acc)? {
                Some(h) => {
                    let out_key = match mode {
                        MaskMode::Mask => Some(random_mask()),
                        MaskMode::Unmask => None,
                    };
                    let in_key =
                        if h.masked { Some(h.mask_key) } else { None };
                    // Combined key: undoes the incoming mask and applies the
                    // outgoing one in one XOR.
                    let mut eff = [0u8; 4];
                    for j in 0..4 {
                        eff[j] = in_key.map_or(0, |k| k[j])
                            ^ out_key.map_or(0, |k| k[j]);
                    }
                    let mut hdr = Vec::with_capacity(14);
                    emit_header(
                        &mut hdr,
                        h.fin_rsv_opcode,
                        h.payload_len,
                        out_key,
                    );
                    writer.write_all(&hdr).await?;
                    unflushed = true;
                    let _ = acc.split_to(h.header_len);
                    state = State::Payload {
                        remaining: h.payload_len,
                        eff,
                        offset: 0,
                    };
                }
                None => {
                    flush_pending(&mut *writer, &mut unflushed).await?;
                    let n = reader.read(&mut read_buf).await?;
                    if n == 0 {
                        if acc.is_empty() {
                            return Ok(());
                        }
                        return Err(io::Error::new(
                            io::ErrorKind::UnexpectedEof,
                            "ws: partial frame header at EOF",
                        ));
                    }
                    acc.extend_from_slice(&read_buf[..n]);
                }
            },
            State::Payload { remaining, eff, offset } => {
                if *remaining == 0 {
                    state = State::Header;
                    continue;
                }
                if acc.is_empty() {
                    flush_pending(&mut *writer, &mut unflushed).await?;
                    let n = reader.read(&mut read_buf).await?;
                    if n == 0 {
                        return Err(io::Error::new(
                            io::ErrorKind::UnexpectedEof,
                            "ws: truncated frame payload at EOF",
                        ));
                    }
                    acc.extend_from_slice(&read_buf[..n]);
                }
                // Bounded by acc.len(), so the cast back to usize is lossless.
                let take =
                    std::cmp::min(*remaining, acc.len() as u64) as usize;
                let mut chunk = acc.split_to(take);
                if *eff != [0u8; 4] {
                    // The mask index is relative to the payload start, so
                    // rotate the key by how much of this payload was already sent.
                    let phase = (*offset % 4) as usize;
                    for (b, k) in
                        chunk.iter_mut().zip(eff.iter().cycle().skip(phase))
                    {
                        *b ^= *k;
                    }
                }
                writer.write_all(&chunk).await?;
                unflushed = true;
                *offset += take as u64;
                *remaining -= take as u64;
            }
        }
    }
}

/// Flushes `writer` if anything was written since the last flush, then
/// clears `unflushed`. A buffered writer then does not sit on translated
/// frames while the caller blocks waiting for more input.
async fn flush_pending<W>(writer: &mut W, unflushed: &mut bool) -> io::Result<()>
where
    W: AsyncWrite + Unpin,
{
    if *unflushed {
        writer.flush().await?;
        *unflushed = false;
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use tokio::io::{AsyncRead, ReadBuf};

    /// Builds one complete frame: a header from `emit_header`, then `payload`
    /// XORed with `mask_key` when one is given.
    fn build_frame(
        fin_rsv_opcode: u8,
        payload: &[u8],
        mask_key: Option<[u8; 4]>,
    ) -> Vec<u8> {
        let mut out = Vec::new();
        emit_header(&mut out, fin_rsv_opcode, payload.len() as u64, mask_key);
        match mask_key {
            Some(k) => out.extend(
                payload.iter().enumerate().map(|(i, b)| b ^ k[i % 4]),
            ),
            None => out.extend_from_slice(payload),
        }
        out
    }

    /// XORs `masked` with `key`, starting at key index 0.
    fn strip_mask(key: [u8; 4], masked: &[u8]) -> Vec<u8> {
        masked
            .iter()
            .enumerate()
            .map(|(i, b)| b ^ key[i % 4])
            .collect()
    }

    /// Runs `translate_masking` over an in-memory input, panicking on error.
    async fn run_translate(input: Vec<u8>, mode: MaskMode) -> Vec<u8> {
        let mut reader = std::io::Cursor::new(input);
        let mut output: Vec<u8> = Vec::new();
        translate_masking(&mut reader, &mut output, mode)
            .await
            .expect("translate clean");
        output
    }

    /// Reader that yields at most one byte per `poll_read`. Every header and
    /// payload is therefore split across as many reads as possible.
    struct Trickle {
        data: Vec<u8>,
        pos: usize,
    }

    impl AsyncRead for Trickle {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            if self.pos < self.data.len() && buf.remaining() > 0 {
                let byte = self.data[self.pos];
                buf.put_slice(&[byte]);
                self.pos += 1;
            }
            Poll::Ready(Ok(()))
        }
    }

    #[test]
    fn parse_header_short_form() {
        let frame = build_frame(0x81, b"hello", None);
        let h = parse_header(&frame).unwrap().unwrap();
        assert_eq!(h.fin_rsv_opcode, 0x81);
        assert!(!h.masked);
        assert_eq!(h.payload_len, 5);
        assert_eq!(h.header_len, 2);
    }

    #[test]
    fn parse_header_needs_more_bytes() {
        assert!(parse_header(&[0x81]).unwrap().is_none());
        assert!(parse_header(&[0x81, 126, 0x00]).unwrap().is_none());
        assert!(parse_header(&[0x81, 0x82, 0xaa, 0xbb]).unwrap().is_none());
    }

    #[test]
    fn parse_header_126_extended_length() {
        let payload = vec![0x5a; 200];
        let frame = build_frame(0x82, &payload, None);
        let h = parse_header(&frame).unwrap().unwrap();
        assert_eq!(h.payload_len, 200);
        assert_eq!(h.header_len, 4);
    }

    #[test]
    fn parse_header_127_extended_length() {
        let mut frame = vec![0x82, 127];
        frame.extend_from_slice(&70_000u64.to_be_bytes());
        let h = parse_header(&frame).unwrap().unwrap();
        assert_eq!(h.payload_len, 70_000);
        assert_eq!(h.header_len, 10);
    }

    #[test]
    fn parse_header_127_masked() {
        let mut frame = vec![0x82, 127 | 0x80];
        frame.extend_from_slice(&70_000u64.to_be_bytes());
        frame.extend_from_slice(&[1, 2, 3, 4]);
        let h = parse_header(&frame).unwrap().unwrap();
        assert!(h.masked);
        assert_eq!(h.mask_key, [1, 2, 3, 4]);
        assert_eq!(h.header_len, 14);
    }

    #[test]
    fn parse_header_rejects_high_bit_64() {
        let mut frame = vec![0x82, 127];
        frame.extend_from_slice(&0x8000_0000_0000_0001u64.to_be_bytes());
        assert!(parse_header(&frame).is_err());
    }

    #[test]
    fn emit_header_round_trips_each_length_form() {
        let cases: &[(u64, usize)] =
            &[(0, 2), (125, 2), (126, 4), (65_535, 4), (65_536, 10)];
        let keys: &[Option<[u8; 4]>] = &[None, Some([9, 8, 7, 6])];
        for &(len, base_len) in cases {
            for &key in keys {
                let mut hdr = Vec::new();
                emit_header(&mut hdr, 0x82, len, key);
                let h = parse_header(&hdr).unwrap().unwrap();
                let key_len = if key.is_some() { 4 } else { 0 };
                assert_eq!(h.header_len, base_len + key_len);
                assert_eq!(hdr.len(), h.header_len);
                assert_eq!(h.payload_len, len);
                assert_eq!(h.masked, key.is_some());
                if let Some(k) = key {
                    assert_eq!(h.mask_key, k);
                }
            }
        }
    }

    #[tokio::test]
    async fn unmask_strips_mask_and_recovers_payload() {
        let mask = [0x12, 0x34, 0x56, 0x78];
        let frame = build_frame(0x81, b"cross-proto-ping", Some(mask));
        let out = run_translate(frame, MaskMode::Unmask).await;
        let h = parse_header(&out).unwrap().unwrap();
        assert_eq!(h.fin_rsv_opcode, 0x81);
        assert!(!h.masked, "output must drop the mask bit");
        assert_eq!(h.payload_len, 16);
        assert_eq!(&out[h.header_len..], b"cross-proto-ping");
    }

    #[tokio::test]
    async fn unmask_passes_unmasked_input_through_unchanged() {
        let mut wire = build_frame(0x81, b"plain", None);
        wire.extend(build_frame(0x82, &[7u8; 300], None));
        let out = run_translate(wire.clone(), MaskMode::Unmask).await;
        assert_eq!(out, wire);
    }

    #[tokio::test]
    async fn mask_adds_mask_and_payload_round_trips() {
        let frame = build_frame(0x81, b"to-h1-backend", None);
        let out = run_translate(frame, MaskMode::Mask).await;
        let h = parse_header(&out).unwrap().unwrap();
        assert!(h.masked, "output must set the mask bit");
        assert_eq!(h.payload_len, 13);
        let unmasked = strip_mask(h.mask_key, &out[h.header_len..]);
        assert_eq!(unmasked, b"to-h1-backend");
    }

    #[tokio::test]
    async fn mask_remasks_already_masked_input() {
        let in_mask = [0x0f, 0xf0, 0x55, 0xaa];
        let frame = build_frame(0x82, b"re-mask me please", Some(in_mask));
        let out = run_translate(frame, MaskMode::Mask).await;
        let h = parse_header(&out).unwrap().unwrap();
        assert!(h.masked);
        assert_eq!(h.payload_len, 17);
        let unmasked = strip_mask(h.mask_key, &out[h.header_len..]);
        assert_eq!(unmasked, b"re-mask me please");
    }

    #[tokio::test]
    async fn unmask_sub_4_byte_payload() {
        let mask = [0xde, 0xad, 0xbe, 0xef];
        let frame = build_frame(0x82, b"ab", Some(mask));
        let out = run_translate(frame, MaskMode::Unmask).await;
        let h = parse_header(&out).unwrap().unwrap();
        assert!(!h.masked);
        assert_eq!(&out[h.header_len..], b"ab");
    }

    #[tokio::test]
    async fn unmask_control_frame_mid_stream() {
        let mask = [1, 2, 3, 4];
        let mut wire = build_frame(0x81, b"first", Some(mask));
        wire.extend(build_frame(0x89, b"pong-me", Some(mask)));
        wire.extend(build_frame(0x81, b"third", Some(mask)));
        let out = run_translate(wire, MaskMode::Unmask).await;
        let mut off = 0;
        let expect: &[(u8, &[u8])] = &[
            (0x81, b"first"),
            (0x89, b"pong-me"),
            (0x81, b"third"),
        ];
        for (opcode, payload) in expect {
            let h = parse_header(&out[off..]).unwrap().unwrap();
            assert_eq!(h.fin_rsv_opcode, *opcode);
            assert!(!h.masked);
            let start = off + h.header_len;
            let end = start + h.payload_len as usize;
            assert_eq!(&out[start..end], *payload);
            off = end;
        }
        assert_eq!(off, out.len());
    }

    #[tokio::test]
    async fn unmask_large_payload_spans_chunks() {
        let mask = [0x11, 0x22, 0x33, 0x44];
        let payload: Vec<u8> =
            (0..200_000u32).map(|i| (i % 251) as u8).collect();
        let frame = build_frame(0x82, &payload, Some(mask));
        let out = run_translate(frame, MaskMode::Unmask).await;
        let h = parse_header(&out).unwrap().unwrap();
        assert!(!h.masked);
        assert_eq!(h.payload_len, payload.len() as u64);
        assert_eq!(&out[h.header_len..], &payload[..]);
    }

    #[tokio::test]
    async fn byte_at_a_time_input_matches_bulk_input() {
        let mask = [0xa1, 0xb2, 0xc3, 0xd4];
        let payload: Vec<u8> = (0..300u32).map(|i| (i % 256) as u8).collect();
        let mut wire = build_frame(0x82, &payload, Some(mask));
        wire.extend(build_frame(0x89, b"hi", Some(mask)));

        let mut reader = Trickle {
            data: wire.clone(),
            pos: 0,
        };
        let mut trickled: Vec<u8> = Vec::new();
        translate_masking(&mut reader, &mut trickled, MaskMode::Unmask)
            .await
            .expect("translate clean");

        let bulk = run_translate(wire, MaskMode::Unmask).await;
        assert_eq!(trickled, bulk);

        let h = parse_header(&trickled).unwrap().unwrap();
        assert_eq!(h.payload_len, 300);
        assert_eq!(
            &trickled[h.header_len..h.header_len + 300],
            &payload[..]
        );
    }

    #[tokio::test]
    async fn zero_length_frame_translates() {
        let frame = build_frame(0x88, b"", Some([9, 9, 9, 9]));
        let out = run_translate(frame, MaskMode::Unmask).await;
        let h = parse_header(&out).unwrap().unwrap();
        assert_eq!(h.fin_rsv_opcode, 0x88);
        assert!(!h.masked);
        assert_eq!(h.payload_len, 0);
        assert_eq!(out.len(), h.header_len);
    }

    #[tokio::test]
    async fn empty_input_is_clean_eof() {
        let out = run_translate(Vec::new(), MaskMode::Mask).await;
        assert!(out.is_empty());
    }

    #[tokio::test]
    async fn partial_header_at_eof_errors() {
        let mut reader = std::io::Cursor::new(vec![0x81u8]);
        let mut out: Vec<u8> = Vec::new();
        let err = translate_masking(&mut reader, &mut out, MaskMode::Unmask)
            .await
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn truncated_payload_at_eof_errors() {
        let mut frame = build_frame(0x82, b"abcdef", Some([1, 2, 3, 4]));
        frame.truncate(frame.len() - 2);
        let mut reader = std::io::Cursor::new(frame);
        let mut out: Vec<u8> = Vec::new();
        let err = translate_masking(&mut reader, &mut out, MaskMode::Unmask)
            .await
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }
}