use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::io::{self, ErrorKind};
/// WebSocket frame opcode, i.e. the low nibble of a frame's first byte.
///
/// Each discriminant equals its wire value, so `op as u8` yields the encoded nibble.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum OpCode {
    /// Continuation of a fragmented message.
    Continue = 0x0,
    /// UTF-8 text data.
    Text = 0x1,
    /// Arbitrary binary data.
    Binary = 0x2,
    /// Connection close control frame.
    Close = 0x8,
    /// Ping control frame.
    Ping = 0x9,
    /// Pong control frame.
    Pong = 0xA,
}

impl OpCode {
    /// Decodes the opcode held in the low four bits of `byte`; the high four bits are ignored.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::InvalidData` error ("invalid opcode") for reserved or
    /// unassigned nibble values.
    pub fn from_u8(byte: u8) -> Result<Self, io::Error> {
        let nibble = byte & 0x0F;
        let decoded = match nibble {
            0x0 => Some(Self::Continue),
            0x1 => Some(Self::Text),
            0x2 => Some(Self::Binary),
            0x8 => Some(Self::Close),
            0x9 => Some(Self::Ping),
            0xA => Some(Self::Pong),
            _ => None,
        };
        decoded.ok_or_else(|| io::Error::new(ErrorKind::InvalidData, "invalid opcode"))
    }

    /// Returns `true` for the control opcodes: `Close`, `Ping` and `Pong`.
    pub fn is_control(&self) -> bool {
        // Control opcodes all have bit 3 set (0x8..=0xF); data opcodes are below 0x8.
        (*self as u8) & 0x08 != 0
    }
}
/// A single WebSocket frame as produced by `Frame::parse` or consumed by `Frame::encode`.
#[derive(Debug, Clone)]
pub struct Frame {
/// FIN bit: `true` when this is the final fragment of a message.
pub fin: bool,
/// Frame opcode (low nibble of the first header byte).
pub opcode: OpCode,
/// Masking key. `encode` masks the payload with it when `Some`; after `parse` it holds
/// the key that was used to unmask `payload`.
pub mask: Option<[u8; 4]>,
/// Application payload, always stored unmasked.
pub payload: Bytes,
}
impl Frame {
pub fn text(data: impl Into<String>) -> Self {
Self {
fin: true,
opcode: OpCode::Text,
mask: None,
payload: Bytes::from(data.into().into_bytes()),
}
}
pub fn binary(data: impl Into<Bytes>) -> Self {
Self {
fin: true,
opcode: OpCode::Binary,
mask: None,
payload: data.into(),
}
}
pub fn close(code: Option<u16>, reason: Option<&str>) -> Self {
let mut payload = BytesMut::new();
if let Some(code) = code {
payload.put_u16(code);
if let Some(reason) = reason {
payload.put_slice(reason.as_bytes());
}
}
Self {
fin: true,
opcode: OpCode::Close,
mask: None,
payload: payload.freeze(),
}
}
pub fn ping(data: impl Into<Bytes>) -> Self {
Self {
fin: true,
opcode: OpCode::Ping,
mask: None,
payload: data.into(),
}
}
pub fn pong(data: impl Into<Bytes>) -> Self {
Self {
fin: true,
opcode: OpCode::Pong,
mask: None,
payload: data.into(),
}
}
pub fn parse(buf: &mut BytesMut) -> Result<Option<Self>, io::Error> {
Self::parse_with_limits(buf, None)
}
pub fn parse_with_limits(
buf: &mut BytesMut,
max_frame_size: Option<usize>,
) -> Result<Option<Self>, io::Error> {
if buf.len() < 2 {
return Ok(None); }
let first = buf[0];
let fin = (first & 0x80) != 0;
let opcode = OpCode::from_u8(first)?;
let second = buf[1];
let masked = (second & 0x80) != 0;
let mut payload_len = (second & 0x7F) as u64;
let mut header_len = 2;
if payload_len == 126 {
if buf.len() < 4 {
return Ok(None); }
payload_len = u16::from_be_bytes([buf[2], buf[3]]) as u64;
header_len += 2;
} else if payload_len == 127 {
if buf.len() < 10 {
return Ok(None); }
payload_len = u64::from_be_bytes([
buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], buf[8], buf[9],
]);
header_len += 8;
}
if let Some(max_size) = max_frame_size {
if payload_len > max_size as u64 {
return Err(io::Error::new(
ErrorKind::InvalidData,
format!("Frame size {} exceeds maximum {}", payload_len, max_size),
));
}
}
let mask = if masked {
if buf.len() < header_len + 4 {
return Ok(None);
}
let mask_bytes = [
buf[header_len],
buf[header_len + 1],
buf[header_len + 2],
buf[header_len + 3],
];
header_len += 4;
Some(mask_bytes)
} else {
None
};
let total_len = header_len + payload_len as usize;
if buf.len() < total_len {
return Ok(None);
}
buf.advance(header_len);
let mut payload = buf.split_to(payload_len as usize);
if let Some(mask_key) = mask {
for (i, byte) in payload.iter_mut().enumerate() {
*byte ^= mask_key[i % 4];
}
}
Ok(Some(Frame {
fin,
opcode,
mask,
payload: payload.freeze(),
}))
}
pub fn encode(&self) -> Bytes {
let payload_len = self.payload.len();
let mut buf = BytesMut::new();
let mut first = self.opcode as u8;
if self.fin {
first |= 0x80;
}
buf.put_u8(first);
let mut second = 0u8;
if self.mask.is_some() {
second |= 0x80;
}
if payload_len < 126 {
second |= payload_len as u8;
buf.put_u8(second);
} else if payload_len <= u16::MAX as usize {
second |= 126;
buf.put_u8(second);
buf.put_u16(payload_len as u16);
} else {
second |= 127;
buf.put_u8(second);
buf.put_u64(payload_len as u64);
}
if let Some(mask_key) = self.mask {
buf.put_slice(&mask_key);
}
if let Some(mask_key) = self.mask {
let mut masked = self.payload.to_vec();
for (i, byte) in masked.iter_mut().enumerate() {
*byte ^= mask_key[i % 4];
}
buf.put_slice(&masked);
} else {
buf.put_slice(&self.payload);
}
buf.freeze()
}
}
/// A complete WebSocket message, decoded from a single frame by `Message::from_frame`.
#[derive(Debug, Clone)]
pub enum Message {
/// Text message; the payload has been validated as UTF-8.
Text(String),
/// Binary message.
Binary(Bytes),
/// Ping control message with its application data.
Ping(Bytes),
/// Pong control message with its application data.
Pong(Bytes),
/// Close control message; `None` when the frame body did not contain a status code.
Close(Option<CloseFrame>),
}
/// Decoded body of a close frame.
#[derive(Debug, Clone)]
pub struct CloseFrame {
/// Status code taken from the first two bytes of the close body.
pub code: u16,
/// Text following the status code; empty when no reason was sent.
pub reason: String,
}
impl Message {
pub fn from_frame(frame: Frame) -> Result<Self, io::Error> {
Self::from_frame_with_limit(frame, None)
}
pub fn from_frame_with_limit(
frame: Frame,
max_message_size: Option<usize>,
) -> Result<Self, io::Error> {
if let Some(max_size) = max_message_size {
if frame.payload.len() > max_size {
return Err(io::Error::new(
ErrorKind::InvalidData,
format!(
"Message size {} exceeds maximum {}",
frame.payload.len(),
max_size
),
));
}
}
match frame.opcode {
OpCode::Text => {
let text = String::from_utf8(frame.payload.to_vec())
.map_err(|_| io::Error::new(ErrorKind::InvalidData, "invalid UTF-8"))?;
Ok(Message::Text(text))
}
OpCode::Binary => Ok(Message::Binary(frame.payload)),
OpCode::Ping => Ok(Message::Ping(frame.payload)),
OpCode::Pong => Ok(Message::Pong(frame.payload)),
OpCode::Close => {
if frame.payload.len() >= 2 {
let mut buf = frame.payload.clone();
let code = buf.get_u16();
let reason = if buf.has_remaining() {
String::from_utf8(buf.to_vec()).unwrap_or_default()
} else {
String::new()
};
Ok(Message::Close(Some(CloseFrame { code, reason })))
} else {
Ok(Message::Close(None))
}
}
OpCode::Continue => Err(io::Error::new(
ErrorKind::InvalidData,
"unexpected continuation frame",
)),
}
}
pub fn to_frame(&self) -> Frame {
match self {
Message::Text(text) => Frame::text(text.clone()),
Message::Binary(data) => Frame::binary(data.clone()),
Message::Ping(data) => Frame::ping(data.clone()),
Message::Pong(data) => Frame::pong(data.clone()),
Message::Close(close_frame) => {
if let Some(cf) = close_frame {
Frame::close(Some(cf.code), Some(&cf.reason))
} else {
Frame::close(None, None)
}
}
}
}
pub fn to_fragmented_frames(&self, max_chunk_size: usize) -> Vec<Frame> {
let (opcode, payload) = match self {
Message::Text(text) => (OpCode::Text, Bytes::from(text.as_bytes().to_vec())),
Message::Binary(data) => (OpCode::Binary, data.clone()),
_ => return vec![self.to_frame()],
};
if payload.len() <= max_chunk_size {
return vec![self.to_frame()];
}
let mut frames = Vec::new();
let mut offset = 0;
while offset < payload.len() {
let end = (offset + max_chunk_size).min(payload.len());
let chunk = payload.slice(offset..end);
let is_first = offset == 0;
let is_last = end >= payload.len();
let frame = Frame {
fin: is_last,
opcode: if is_first { opcode } else { OpCode::Continue },
mask: None,
payload: chunk,
};
frames.push(frame);
offset = end;
}
frames
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `frame` and parses it back, asserting that a complete frame was produced.
    fn roundtrip(frame: &Frame) -> Frame {
        let mut buf = BytesMut::from(frame.encode().as_ref());
        Frame::parse(&mut buf).unwrap().unwrap()
    }

    #[test]
    fn test_frame_text_encode_decode() {
        let decoded = roundtrip(&Frame::text("Hello, WebSocket!"));
        assert!(decoded.fin);
        assert_eq!(decoded.opcode, OpCode::Text);
        assert_eq!(decoded.payload, Bytes::from("Hello, WebSocket!"));
    }

    #[test]
    fn test_frame_binary_encode_decode() {
        let data = vec![1, 2, 3, 4, 5];
        let decoded = roundtrip(&Frame::binary(Bytes::from(data.clone())));
        assert_eq!(decoded.opcode, OpCode::Binary);
        assert_eq!(decoded.payload, Bytes::from(data));
    }

    #[test]
    fn test_frame_close_encode_decode() {
        let decoded = roundtrip(&Frame::close(Some(1000), Some("Normal closure")));
        assert_eq!(decoded.opcode, OpCode::Close);
        assert!(decoded.payload.len() >= 2);
    }

    #[test]
    fn test_frame_masking() {
        let mut frame = Frame::text("Test");
        frame.mask = Some([1, 2, 3, 4]);
        let decoded = roundtrip(&frame);
        assert_eq!(decoded.payload, Bytes::from("Test"));
    }

    #[test]
    fn test_masked_frame_with_16bit_length() {
        let payload: Vec<u8> = (0..300u32).map(|i| (i % 251) as u8).collect();
        let mut frame = Frame::binary(Bytes::from(payload.clone()));
        frame.mask = Some([0xDE, 0xAD, 0xBE, 0xEF]);
        let decoded = roundtrip(&frame);
        assert_eq!(decoded.payload, Bytes::from(payload));
        assert_eq!(decoded.mask, Some([0xDE, 0xAD, 0xBE, 0xEF]));
    }

    #[test]
    fn test_message_from_frame() {
        let message = Message::from_frame(Frame::text("Hello")).unwrap();
        match message {
            Message::Text(text) => assert_eq!(text, "Hello"),
            _ => panic!("Expected text message"),
        }
    }

    #[test]
    fn test_extended_payload_length_16bit() {
        let decoded = roundtrip(&Frame::binary(Bytes::from(vec![0u8; 1000])));
        assert_eq!(decoded.payload.len(), 1000);
        assert_eq!(decoded.opcode, OpCode::Binary);
    }

    #[test]
    fn test_extended_payload_length_64bit() {
        let decoded = roundtrip(&Frame::binary(Bytes::from(vec![0u8; 70000])));
        assert_eq!(decoded.payload.len(), 70000);
        assert_eq!(decoded.opcode, OpCode::Binary);
    }

    #[test]
    fn test_partial_frame_parsing() {
        let encoded = Frame::text("Hello, WebSocket!").encode();
        let mut buf = BytesMut::from(&encoded[..1]);
        assert!(Frame::parse(&mut buf).unwrap().is_none());
        buf = BytesMut::from(&encoded[..3]);
        assert!(Frame::parse(&mut buf).unwrap().is_none());
        buf = BytesMut::from(encoded.as_ref());
        assert!(Frame::parse(&mut buf).unwrap().is_some());
    }

    #[test]
    fn test_parse_consumes_one_frame_at_a_time() {
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&Frame::text("one").encode());
        buf.extend_from_slice(&Frame::binary(Bytes::from(vec![9u8, 8, 7])).encode());
        let first = Frame::parse(&mut buf).unwrap().unwrap();
        assert_eq!(first.opcode, OpCode::Text);
        assert_eq!(first.payload, Bytes::from("one"));
        let second = Frame::parse(&mut buf).unwrap().unwrap();
        assert_eq!(second.opcode, OpCode::Binary);
        assert_eq!(second.payload, Bytes::from(vec![9u8, 8, 7]));
        assert!(buf.is_empty());
        assert!(Frame::parse(&mut buf).unwrap().is_none());
    }

    #[test]
    fn test_parse_with_limits_rejects_oversized_frame() {
        let encoded = Frame::binary(Bytes::from(vec![0u8; 100])).encode();
        let mut buf = BytesMut::from(encoded.as_ref());
        let err = Frame::parse_with_limits(&mut buf, Some(10)).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
    }

    #[test]
    fn test_empty_payload() {
        let decoded = roundtrip(&Frame::text(""));
        assert_eq!(decoded.payload.len(), 0);
        assert_eq!(decoded.opcode, OpCode::Text);
    }

    #[test]
    fn test_ping_pong_frames() {
        let ping_data = Bytes::from("ping");
        let ping = Frame::ping(ping_data.clone());
        assert_eq!(ping.opcode, OpCode::Ping);
        assert_eq!(ping.payload, ping_data);
        let pong = Frame::pong(ping_data.clone());
        assert_eq!(pong.opcode, OpCode::Pong);
        assert_eq!(pong.payload, ping_data);
    }

    #[test]
    fn test_close_frame_with_reason() {
        let frame = Frame::close(Some(1000), Some("Normal closure"));
        match Message::from_frame(frame).unwrap() {
            Message::Close(Some(close_frame)) => {
                assert_eq!(close_frame.code, 1000);
                assert_eq!(close_frame.reason, "Normal closure");
            }
            _ => panic!("Expected close message with frame"),
        }
    }

    #[test]
    fn test_close_frame_without_reason() {
        let frame = Frame::close(Some(1001), None);
        match Message::from_frame(frame).unwrap() {
            Message::Close(Some(close_frame)) => {
                assert_eq!(close_frame.code, 1001);
                assert_eq!(close_frame.reason, "");
            }
            _ => panic!("Expected close message"),
        }
    }

    #[test]
    fn test_close_frame_empty() {
        let frame = Frame::close(None, None);
        match Message::from_frame(frame).unwrap() {
            Message::Close(None) => {}
            _ => panic!("Expected empty close message"),
        }
    }

    #[test]
    fn test_invalid_utf8_in_text_frame() {
        let mut frame = Frame::text("test");
        frame.payload = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        assert!(Message::from_frame(frame).is_err());
    }

    #[test]
    fn test_message_to_frame_round_trip() {
        let messages = vec![
            Message::Text("Hello".to_string()),
            Message::Binary(Bytes::from(vec![1, 2, 3])),
            Message::Ping(Bytes::from("ping")),
            Message::Pong(Bytes::from("pong")),
            Message::Close(Some(CloseFrame {
                code: 1000,
                reason: "bye".to_string(),
            })),
        ];
        for msg in messages {
            let decoded = roundtrip(&msg.to_frame());
            let msg_back = Message::from_frame(decoded).unwrap();
            assert_eq!(
                std::mem::discriminant(&msg),
                std::mem::discriminant(&msg_back)
            );
        }
    }

    #[test]
    fn test_fragmented_frames_split_payload() {
        let msg = Message::Binary(Bytes::from(vec![0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9]));
        let frames = msg.to_fragmented_frames(4);
        assert_eq!(frames.len(), 3);
        assert_eq!(frames[0].opcode, OpCode::Binary);
        assert_eq!(frames[1].opcode, OpCode::Continue);
        assert_eq!(frames[2].opcode, OpCode::Continue);
        let fins: Vec<bool> = frames.iter().map(|f| f.fin).collect();
        assert_eq!(fins, vec![false, false, true]);
        let lens: Vec<usize> = frames.iter().map(|f| f.payload.len()).collect();
        assert_eq!(lens, vec![4, 4, 2]);
        let joined: Vec<u8> = frames
            .iter()
            .flat_map(|f| f.payload.iter().copied())
            .collect();
        assert_eq!(joined, (0u8..10).collect::<Vec<_>>());
    }

    #[test]
    fn test_fragmented_frames_small_or_control_message_single_frame() {
        let frames = Message::Text("hi".to_string()).to_fragmented_frames(10);
        assert_eq!(frames.len(), 1);
        assert!(frames[0].fin);
        assert_eq!(frames[0].opcode, OpCode::Text);

        let frames = Message::Ping(Bytes::from(vec![0u8; 10])).to_fragmented_frames(2);
        assert_eq!(frames.len(), 1);
        assert!(frames[0].fin);
        assert_eq!(frames[0].opcode, OpCode::Ping);
    }

    #[test]
    fn test_opcode_is_control() {
        assert!(OpCode::Close.is_control());
        assert!(OpCode::Ping.is_control());
        assert!(OpCode::Pong.is_control());
        assert!(!OpCode::Text.is_control());
        assert!(!OpCode::Binary.is_control());
        assert!(!OpCode::Continue.is_control());
    }

    #[test]
    fn test_invalid_opcode() {
        assert!(OpCode::from_u8(0xFF).is_err());
    }
}