use crate::error::Error;
use crate::websocket_opcode::Opcode;
/// A single WebSocket frame: header flags, opcode and payload.
///
/// The payload is stored as the application sees it; masking is applied only while
/// encoding and is removed by the decoder.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Frame {
    /// FIN bit (0x80 of the first header byte): set on the final fragment of a message.
    pub fin: bool,
    /// RSV1 bit (0x40 of the first header byte).
    pub rsv1: bool,
    /// RSV2 bit (0x20 of the first header byte).
    pub rsv2: bool,
    /// RSV3 bit (0x10 of the first header byte).
    pub rsv3: bool,
    /// Frame opcode (low four bits of the first header byte).
    pub opcode: Opcode,
    /// Unmasked payload bytes.
    pub payload: Vec<u8>,
}
/// A frame produced by the decoder, together with transport-level details of how it
/// arrived on the wire.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DecodedFrame {
    /// The decoded frame; its payload has already been unmasked.
    pub frame: Frame,
    /// Whether the MASK bit was set in the received header.
    pub masked: bool,
}
impl Frame {
pub fn new(opcode: Opcode, payload: Vec<u8>) -> Self {
Self {
fin: true,
rsv1: false,
rsv2: false,
rsv3: false,
opcode,
payload,
}
}
pub fn text(payload: &str) -> Self {
Self::new(Opcode::Text, payload.as_bytes().to_vec())
}
pub fn binary(payload: Vec<u8>) -> Self {
Self::new(Opcode::Binary, payload)
}
pub fn ping(payload: Vec<u8>) -> Result<Self, Error> {
if payload.len() > 125 {
return Err(Error::invalid_input(format!(
"ping payload exceeds 125 bytes: {} bytes",
payload.len()
)));
}
Ok(Self::new(Opcode::Ping, payload))
}
pub fn pong(payload: Vec<u8>) -> Result<Self, Error> {
if payload.len() > 125 {
return Err(Error::invalid_input(format!(
"pong payload exceeds 125 bytes: {} bytes",
payload.len()
)));
}
Ok(Self::new(Opcode::Pong, payload))
}
pub fn close(code: Option<u16>, reason: &str) -> Result<Self, Error> {
let payload = match code {
Some(c) => {
if reason.len() > 123 {
return Err(Error::invalid_input(format!(
"close reason exceeds 123 bytes: {} bytes",
reason.len()
)));
}
let mut p = Vec::with_capacity(2 + reason.len());
p.extend_from_slice(&c.to_be_bytes());
p.extend_from_slice(reason.as_bytes());
p
}
None => Vec::new(),
};
Ok(Self::new(Opcode::Close, payload))
}
pub fn encode(&self, masking_key: [u8; 4]) -> Vec<u8> {
self.encode_internal(true, masking_key)
}
#[allow(dead_code)]
pub fn encode_unmasked(&self) -> Vec<u8> {
self.encode_internal(false, [0; 4])
}
fn encode_internal(&self, masked: bool, masking_key: [u8; 4]) -> Vec<u8> {
let payload_len = self.payload.len();
let header_size =
2 + if payload_len >= 65536 {
8
} else if payload_len >= 126 {
2
} else {
0
} + if masked { 4 } else { 0 };
let mut buf = Vec::with_capacity(header_size + payload_len);
let byte0 = (if self.fin { 0x80 } else { 0 })
| (if self.rsv1 { 0x40 } else { 0 })
| (if self.rsv2 { 0x20 } else { 0 })
| (if self.rsv3 { 0x10 } else { 0 })
| self.opcode.as_u8();
buf.push(byte0);
let mask_bit = if masked { 0x80 } else { 0 };
if payload_len >= 65536 {
buf.push(mask_bit | 127);
buf.extend_from_slice(&(payload_len as u64).to_be_bytes());
} else if payload_len >= 126 {
buf.push(mask_bit | 126);
buf.extend_from_slice(&(payload_len as u16).to_be_bytes());
} else {
buf.push(mask_bit | payload_len as u8);
}
if masked {
buf.extend_from_slice(&masking_key);
}
if masked {
for (i, byte) in self.payload.iter().enumerate() {
buf.push(byte ^ masking_key[i % 4]);
}
} else {
buf.extend_from_slice(&self.payload);
}
buf
}
}
/// Incremental WebSocket frame decoder.
///
/// Bytes are appended as they arrive. Complete frames are then extracted from the front
/// of the internal buffer.
#[derive(Debug)]
pub struct FrameDecoder {
    /// Bytes received but not yet consumed as part of a complete frame.
    buf: Vec<u8>,
}
impl FrameDecoder {
pub fn new() -> Self {
Self { buf: Vec::new() }
}
pub fn feed(&mut self, data: &[u8]) {
self.buf.extend_from_slice(data);
}
pub fn decode(&mut self) -> Result<Option<Frame>, Error> {
self.decode_with_info()
.map(|opt| opt.map(|decoded| decoded.frame))
}
pub fn decode_with_info(&mut self) -> Result<Option<DecodedFrame>, Error> {
if self.buf.len() < 2 {
return Ok(None);
}
let byte0 = self.buf[0];
let byte1 = self.buf[1];
let fin = (byte0 & 0x80) != 0;
let rsv1 = (byte0 & 0x40) != 0;
let rsv2 = (byte0 & 0x20) != 0;
let rsv3 = (byte0 & 0x10) != 0;
let opcode_value = byte0 & 0x0F;
let opcode = Opcode::from_u8(opcode_value)
.ok_or_else(|| Error::protocol_violation(format!("unknown opcode: {opcode_value}")))?;
let masked = (byte1 & 0x80) != 0;
let payload_len_7 = byte1 & 0x7F;
let (payload_len, header_len): (usize, usize) = match payload_len_7 {
127 => {
if self.buf.len() < 10 {
return Ok(None);
}
if self.buf[2] & 0x80 != 0 {
return Err(Error::protocol_violation(
"64-bit payload length MSB must be 0",
));
}
let len = u64::from_be_bytes([
self.buf[2],
self.buf[3],
self.buf[4],
self.buf[5],
self.buf[6],
self.buf[7],
self.buf[8],
self.buf[9],
]);
let len = usize::try_from(len)
.map_err(|_| Error::protocol_violation("payload length too large"))?;
if len <= 65535 {
return Err(Error::protocol_violation(
"64-bit payload length must be > 65535 (non-minimal encoding)",
));
}
(len, 10)
}
126 => {
if self.buf.len() < 4 {
return Ok(None);
}
let len = u16::from_be_bytes([self.buf[2], self.buf[3]]) as usize;
if len < 126 {
return Err(Error::protocol_violation(
"16-bit payload length must be >= 126 (non-minimal encoding)",
));
}
(len, 4)
}
_ => (payload_len_7 as usize, 2),
};
let masking_key_len = if masked { 4 } else { 0 };
let total_len = header_len
.checked_add(masking_key_len)
.and_then(|len| len.checked_add(payload_len))
.ok_or_else(|| Error::protocol_violation("payload length too large"))?;
if self.buf.len() < total_len {
return Ok(None);
}
let masking_key = if masked {
[
self.buf[header_len],
self.buf[header_len + 1],
self.buf[header_len + 2],
self.buf[header_len + 3],
]
} else {
[0; 4]
};
let payload_start = header_len + masking_key_len;
let mut payload = self.buf[payload_start..payload_start + payload_len].to_vec();
if masked {
for (i, byte) in payload.iter_mut().enumerate() {
*byte ^= masking_key[i % 4];
}
}
if opcode.is_control() {
if !fin {
return Err(Error::protocol_violation(
"control frame must not be fragmented",
));
}
if payload_len > 125 {
return Err(Error::protocol_violation("control frame payload too large"));
}
}
self.buf.drain(..total_len);
Ok(Some(DecodedFrame {
frame: Frame {
fin,
rsv1,
rsv2,
rsv3,
opcode,
payload,
},
masked,
}))
}
pub fn clear(&mut self) {
self.buf.clear();
}
#[allow(dead_code)]
pub fn buffer_len(&self) -> usize {
self.buf.len()
}
}
impl Default for FrameDecoder {
fn default() -> Self {
Self::new()
}
}