#![cfg_attr(target_arch = "wasm32", allow(dead_code))]
use bytes::Bytes;
use crate::{close::CloseCode, WebSocketError};
/// Frame opcode: the 4-bit frame type carried in the low nibble of the first
/// header byte.
///
/// `Continuation`, `Text` and `Binary` carry message data, while `Close`,
/// `Ping` and `Pong` are control frames (see [`OpCode::is_control`]).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OpCode {
    /// Continues a fragmented message (wire value `0x0`).
    Continuation,
    /// UTF-8 text data (wire value `0x1`).
    Text,
    /// Binary data (wire value `0x2`).
    Binary,
    /// Connection close (wire value `0x8`).
    Close,
    /// Ping (wire value `0x9`).
    Ping,
    /// Pong (wire value `0xA`).
    Pong,
}

impl OpCode {
    /// Returns `true` for the control opcodes `Close`, `Ping` and `Pong`,
    /// and `false` for the data opcodes.
    pub fn is_control(&self) -> bool {
        match self {
            OpCode::Close | OpCode::Ping | OpCode::Pong => true,
            OpCode::Continuation | OpCode::Text | OpCode::Binary => false,
        }
    }
}
impl TryFrom<u8> for OpCode {
    type Error = WebSocketError;

    /// Decodes a wire opcode value.
    ///
    /// Any value other than `0x0`-`0x2` and `0x8`-`0xA` is rejected with
    /// `WebSocketError::InvalidOpCode` carrying the offending byte.
    fn try_from(byte: u8) -> Result<Self, Self::Error> {
        let opcode = match byte {
            0x0 => Self::Continuation,
            0x1 => Self::Text,
            0x2 => Self::Binary,
            0x8 => Self::Close,
            0x9 => Self::Ping,
            0xA => Self::Pong,
            other => return Err(WebSocketError::InvalidOpCode(other)),
        };
        Ok(opcode)
    }
}
impl From<OpCode> for u8 {
fn from(val: OpCode) -> Self {
match val {
OpCode::Continuation => 0x0,
OpCode::Text => 0x1,
OpCode::Binary => 0x2,
OpCode::Close => 0x8,
OpCode::Ping => 0x9,
OpCode::Pong => 0xA,
}
}
}
impl From<Frame> for (OpCode, Bytes) {
fn from(val: Frame) -> Self {
(val.opcode, val.payload)
}
}
impl<T> From<(OpCode, T)> for Frame
where
T: Into<Bytes>,
{
fn from((opcode, payload): (OpCode, T)) -> Self {
Self {
fin: true,
opcode,
mask: None,
payload: payload.into(),
is_compressed: false,
}
}
}
/// A single WebSocket frame: header flags, optional masking key and payload.
///
/// Build frames with the constructors on `Frame` (`text`, `binary`, `close`,
/// ...); the raw fields are crate-internal.
#[derive(Clone, Debug)]
pub struct Frame {
    /// FIN bit: `true` when this frame is the last one of its message.
    pub(crate) fin: bool,
    /// Frame type.
    pub(crate) opcode: OpCode,
    /// Emitted as the RSV1 header bit by `write_head` (presumably for
    /// permessage-deflate; confirm against the codec).
    pub(super) is_compressed: bool,
    /// Masking key written after the length field by `write_head`, if set.
    pub(super) mask: Option<[u8; 4]>,
    /// Frame payload bytes.
    pub(crate) payload: Bytes,
}
/// Upper bound, in bytes, on an encoded frame header.
///
/// The largest header that `Frame::write_head` produces is 14 bytes: 2 fixed
/// bytes, an 8-byte extended payload length and a 4-byte masking key. The
/// value 16 leaves slack above that.
pub(crate) const MAX_HEAD_SIZE: usize = 16;
impl Frame {
pub fn text(payload: impl Into<Bytes>) -> Self {
Self {
fin: true,
opcode: OpCode::Text,
mask: None,
payload: payload.into(),
is_compressed: false,
}
}
pub fn binary(payload: impl Into<Bytes>) -> Self {
Self {
fin: true,
opcode: OpCode::Binary,
mask: None,
payload: payload.into(),
is_compressed: false,
}
}
pub fn ping(payload: impl Into<Bytes>) -> Self {
Self {
fin: true,
opcode: OpCode::Ping,
mask: None,
payload: payload.into(),
is_compressed: false,
}
}
pub fn pong(payload: impl Into<Bytes>) -> Self {
Self {
fin: true,
opcode: OpCode::Pong,
mask: None,
payload: payload.into(),
is_compressed: false,
}
}
pub fn continuation(payload: impl Into<Bytes>) -> Self {
Self {
fin: true,
opcode: OpCode::Continuation,
mask: None,
payload: payload.into(),
is_compressed: false,
}
}
pub fn with_fin(mut self, fin: bool) -> Self {
self.fin = fin;
self
}
pub fn close(code: CloseCode, reason: impl AsRef<[u8]>) -> Self {
let code16 = u16::from(code);
let reason: &[u8] = reason.as_ref();
let mut payload = Vec::with_capacity(2 + reason.len());
payload.extend_from_slice(&code16.to_be_bytes());
payload.extend_from_slice(reason);
Self {
fin: true,
opcode: OpCode::Close,
mask: None,
payload: payload.into(),
is_compressed: false,
}
}
pub(crate) fn into_fragments(self, partition: usize) -> impl Iterator<Item = Frame> {
struct Split {
index: usize,
max_size: usize,
frame: Option<Frame>,
}
impl Iterator for Split {
type Item = Frame;
fn next(&mut self) -> Option<Self::Item> {
let mut frame = self.frame.take()?;
if frame.payload.len() <= self.max_size {
if self.index != 0 {
frame.set_fin(true);
frame.opcode = OpCode::Continuation;
}
Some(frame)
} else {
let is_first = self.index == 0;
self.index += 1;
let chunk = frame.payload.split_to(self.max_size);
let opcode = if is_first {
frame.opcode
} else {
OpCode::Continuation
};
let mask = frame.mask;
self.frame = Some(frame);
Some(Frame::new(false, opcode, mask, chunk))
}
}
}
Split {
index: 0,
max_size: partition,
frame: Some(self),
}
}
pub(crate) fn close_raw<T: Into<Bytes>>(payload: T) -> Self {
Self {
fin: true,
opcode: OpCode::Close,
mask: None,
is_compressed: false,
payload: payload.into(),
}
}
pub(super) fn new(
fin: bool,
opcode: OpCode,
mask: Option<[u8; 4]>,
payload: impl Into<Bytes>,
) -> Self {
Self {
fin,
opcode,
mask,
payload: payload.into(),
is_compressed: false,
}
}
#[inline(always)]
pub fn opcode(&self) -> OpCode {
self.opcode
}
#[inline(always)]
pub fn payload(&self) -> &Bytes {
&self.payload
}
#[inline(always)]
pub fn payload_mut(&mut self) -> &mut Bytes {
&mut self.payload
}
#[inline(always)]
pub fn into_payload(self) -> Bytes {
self.payload
}
#[inline(always)]
pub fn into_parts(self) -> (OpCode, bool, Bytes) {
(self.opcode, self.fin, self.payload)
}
#[inline]
pub fn into_parts_str(&self) -> Result<(OpCode, &str), std::str::Utf8Error> {
let text = std::str::from_utf8(&self.payload)?;
Ok((self.opcode, text))
}
#[inline(always)]
pub fn is_fin(&self) -> bool {
self.fin
}
#[inline(always)]
pub fn set_fin(&mut self, fin: bool) {
self.fin = fin;
}
#[inline(always)]
pub fn set_mask(&mut self, mask: Option<[u8; 4]>) {
self.mask = mask;
}
#[inline(always)]
pub fn with_mask(mut self, mask: [u8; 4]) -> Self {
self.mask = Some(mask);
self
}
#[inline(always)]
pub fn set_random_mask(&mut self) {
self.mask = Some(rand::random());
}
#[inline(always)]
pub fn with_random_mask(mut self) -> Self {
self.mask = Some(rand::random());
self
}
#[inline]
pub fn as_str(&self) -> &str {
std::str::from_utf8(&self.payload).expect("frame payload is not valid UTF-8")
}
pub fn close_code(&self) -> Option<CloseCode> {
let code = CloseCode::from(u16::from_be_bytes(self.payload.get(0..2)?.try_into().ok()?));
Some(code)
}
pub fn close_reason(&self) -> Result<Option<&str>, WebSocketError> {
if self.payload.is_empty() {
return Ok(None);
}
let reason = self.payload.get(2..).ok_or(WebSocketError::InvalidUTF8)?;
std::str::from_utf8(reason)
.map(Some)
.map_err(|_| WebSocketError::InvalidUTF8)
}
#[inline(always)]
pub fn is_utf8(&self) -> bool {
std::str::from_utf8(&self.payload).is_ok()
}
#[inline]
pub(super) fn set_random_mask_if_not_set(&mut self) {
if self.mask.is_none() {
let mask: [u8; 4] = rand::random();
self.mask = Some(mask);
}
}
#[inline]
pub(super) fn write_head(&self, dst: &mut bytes::BytesMut) {
use bytes::BufMut;
let compression = u8::from(self.is_compressed);
let first_byte = (self.fin as u8) << 7 | compression << 6 | u8::from(self.opcode);
let len = self.payload.len();
if len < 126 {
dst.put_u8(first_byte);
dst.put_u8(len as u8 | if self.mask.is_some() { 0x80 } else { 0 });
} else if len < 65536 {
dst.put_u8(first_byte);
dst.put_u8(126 | if self.mask.is_some() { 0x80 } else { 0 });
dst.put_u16(len as u16);
} else {
dst.put_u8(first_byte);
dst.put_u8(127 | if self.mask.is_some() { 0x80 } else { 0 });
dst.put_u64(len as u64);
}
if let Some(mask) = self.mask {
dst.put_slice(&mask);
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::close::CloseCode;
    use bytes::{Bytes, BytesMut};
    use wasm_bindgen_test::wasm_bindgen_test;

    mod opcode_tests {
        use super::*;

        #[test]
        #[wasm_bindgen_test]
        fn test_is_control() {
            assert!(OpCode::Close.is_control());
            assert!(OpCode::Ping.is_control());
            assert!(OpCode::Pong.is_control());
            assert!(!OpCode::Continuation.is_control());
            assert!(!OpCode::Text.is_control());
            assert!(!OpCode::Binary.is_control());
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_try_from_u8_valid() {
            assert_eq!(OpCode::try_from(0x0).unwrap(), OpCode::Continuation);
            assert_eq!(OpCode::try_from(0x1).unwrap(), OpCode::Text);
            assert_eq!(OpCode::try_from(0x2).unwrap(), OpCode::Binary);
            assert_eq!(OpCode::try_from(0x8).unwrap(), OpCode::Close);
            assert_eq!(OpCode::try_from(0x9).unwrap(), OpCode::Ping);
            assert_eq!(OpCode::try_from(0xA).unwrap(), OpCode::Pong);
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_try_from_u8_invalid() {
            for &code in &[0x3, 0x4, 0x5, 0x6, 0x7, 0xB, 0xC, 0xD, 0xE, 0xF] {
                assert!(OpCode::try_from(code).is_err());
            }
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_from_opcode_to_u8() {
            assert_eq!(u8::from(OpCode::Continuation), 0x0);
            assert_eq!(u8::from(OpCode::Text), 0x1);
            assert_eq!(u8::from(OpCode::Binary), 0x2);
            assert_eq!(u8::from(OpCode::Close), 0x8);
            assert_eq!(u8::from(OpCode::Ping), 0x9);
            assert_eq!(u8::from(OpCode::Pong), 0xA);
        }
    }

    mod frame_tests {
        use super::*;

        // Encodes only the header of `frame` into a fresh buffer.
        fn head(frame: &Frame) -> Vec<u8> {
            let mut dst = BytesMut::new();
            frame.write_head(&mut dst);
            dst.to_vec()
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_frame_text() {
            let text = "Hello, WebSocket!";
            let frame = Frame::text(text);
            assert_eq!(frame.opcode(), OpCode::Text);
            assert_eq!(frame.payload().as_ref(), text.as_bytes());
            assert!(frame.is_fin());
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_frame_binary() {
            let data = vec![0x01, 0x02, 0x03];
            let frame = Frame::binary(data.clone());
            assert_eq!(frame.opcode(), OpCode::Binary);
            assert_eq!(frame.payload().as_ref(), &data[..]);
            assert!(frame.is_fin());
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_frame_close() {
            let reason = "Normal closure";
            let frame = Frame::close(CloseCode::Normal, reason);
            assert_eq!(frame.opcode(), OpCode::Close);
            assert!(frame.is_fin());
            let mut expected_payload = Vec::new();
            expected_payload.extend_from_slice(&1000u16.to_be_bytes());
            expected_payload.extend_from_slice(reason.as_bytes());
            assert_eq!(frame.payload().as_ref(), &expected_payload[..]);
            assert_eq!(frame.close_code(), Some(CloseCode::Normal));
            assert_eq!(frame.close_reason().unwrap(), Some(reason));
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_frame_close_raw() {
            let payload = vec![0x03, 0xE8];
            let frame = Frame::close_raw(payload.clone());
            assert_eq!(frame.opcode(), OpCode::Close);
            assert_eq!(frame.payload().as_ref(), &payload[..]);
            assert!(frame
                .close_reason()
                .is_ok_and(|reason| reason.is_some_and(|reason| reason.is_empty())));
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_frame_empty_close() {
            let frame = Frame::close_raw(vec![]);
            assert_eq!(frame.opcode(), OpCode::Close);
            assert!(frame.payload().is_empty());
            assert!(frame.close_code().is_none());
            assert!(frame.close_reason().is_ok_and(|reason| reason.is_none()));
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_frame_close_one_byte_payload() {
            // Too short to hold a 2-byte close code.
            let frame = Frame::close_raw(vec![0x03_u8]);
            assert!(frame.close_code().is_none());
            assert!(frame.close_reason().is_err());
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_frame_ping() {
            let payload = b"Ping payload";
            let frame = Frame::ping(&payload[..]);
            assert_eq!(frame.opcode(), OpCode::Ping);
            assert_eq!(frame.payload().as_ref(), &payload[..]);
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_frame_pong() {
            let payload = b"Pong payload";
            let frame = Frame::pong(&payload[..]);
            assert_eq!(frame.opcode(), OpCode::Pong);
            assert_eq!(frame.payload().as_ref(), &payload[..]);
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_frame_continuation() {
            let payload = b"continuation data";
            let frame = Frame::continuation(&payload[..]);
            assert_eq!(frame.opcode(), OpCode::Continuation);
            assert_eq!(frame.payload().as_ref(), &payload[..]);
            assert!(frame.is_fin());
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_frame_with_fin() {
            let frame = Frame::text("fragment").with_fin(false);
            assert!(!frame.is_fin());
            assert_eq!(frame.opcode(), OpCode::Text);
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_frame_fragmentation() {
            let first = Frame::text("Hello, ").with_fin(false);
            let middle = Frame::continuation("World").with_fin(false);
            let last = Frame::continuation("!");
            assert!(!first.is_fin());
            assert_eq!(first.opcode(), OpCode::Text);
            assert!(!middle.is_fin());
            assert_eq!(middle.opcode(), OpCode::Continuation);
            assert!(last.is_fin());
            assert_eq!(last.opcode(), OpCode::Continuation);
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_into_fragments_splits_payload() {
            let fragments: Vec<Frame> = Frame::text("Hello, World").into_fragments(5).collect();
            assert_eq!(fragments.len(), 3);

            assert_eq!(fragments[0].opcode(), OpCode::Text);
            assert!(!fragments[0].is_fin());
            assert_eq!(fragments[0].payload().as_ref(), &b"Hello"[..]);

            assert_eq!(fragments[1].opcode(), OpCode::Continuation);
            assert!(!fragments[1].is_fin());
            assert_eq!(fragments[1].payload().as_ref(), &b", Wor"[..]);

            assert_eq!(fragments[2].opcode(), OpCode::Continuation);
            assert!(fragments[2].is_fin());
            assert_eq!(fragments[2].payload().as_ref(), &b"ld"[..]);
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_into_fragments_payload_fits() {
            let fragments: Vec<Frame> = Frame::binary(vec![1u8, 2, 3]).into_fragments(3).collect();
            assert_eq!(fragments.len(), 1);
            assert_eq!(fragments[0].opcode(), OpCode::Binary);
            assert!(fragments[0].is_fin());
            assert_eq!(fragments[0].payload().as_ref(), &[1u8, 2, 3][..]);
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_write_head_short_lengths() {
            assert_eq!(head(&Frame::text("hi")), vec![0x81, 0x02]);
            assert_eq!(head(&Frame::binary(vec![0u8; 125])), vec![0x82, 125]);
            assert_eq!(
                head(&Frame::continuation("x").with_fin(false)),
                vec![0x00, 0x01]
            );
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_write_head_16_bit_length() {
            assert_eq!(
                head(&Frame::binary(vec![0u8; 126]).with_mask([1, 2, 3, 4])),
                vec![0x82, 0xFE, 0x00, 0x7E, 1, 2, 3, 4]
            );
            assert_eq!(
                head(&Frame::binary(vec![0u8; 65535])),
                vec![0x82, 126, 0xFF, 0xFF]
            );
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_write_head_64_bit_length() {
            let encoded = head(&Frame::binary(vec![0u8; 65536]).with_mask([9, 9, 9, 9]));
            assert_eq!(
                encoded,
                vec![0x82, 0xFF, 0, 0, 0, 0, 0, 1, 0, 0, 9, 9, 9, 9]
            );
            assert!(encoded.len() <= MAX_HEAD_SIZE);
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_write_head_compressed_sets_rsv1() {
            let mut frame = Frame::text("a");
            frame.is_compressed = true;
            assert_eq!(head(&frame), vec![0xC1, 0x01]);
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_frame_from_tuple() {
            let frame = Frame::from((OpCode::Text, Bytes::from("Test")));
            let (opcode, payload): (OpCode, Bytes) = frame.into();
            assert_eq!(opcode, OpCode::Text);
            assert_eq!(payload, Bytes::from("Test"));
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_frame_from_tuple_bytes() {
            let opcode = OpCode::Binary;
            let payload = Bytes::from_static(b"\xDE\xAD\xBE\xEF");
            let frame = Frame::from((opcode, payload.clone()));
            assert_eq!(frame.opcode(), OpCode::Binary);
            assert_eq!(frame.payload().as_ref(), payload.as_ref());
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_frame_new() {
            let payload = BytesMut::from("Test payload");
            let frame = Frame::new(true, OpCode::Text, None, payload.clone());
            assert!(frame.is_fin());
            assert_eq!(frame.opcode(), OpCode::Text);
            assert_eq!(frame.payload().as_ref(), payload.as_ref());
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_frame_as_str() {
            let frame = Frame::text("Hello, World!");
            assert_eq!(frame.as_str(), "Hello, World!");
        }

        #[test]
        #[wasm_bindgen_test]
        fn test_frame_is_utf8() {
            let valid_utf8 = Frame::text("Hello, 世界");
            assert!(valid_utf8.is_utf8());
            let invalid_utf8 = Frame::binary(vec![0xFF, 0xFE, 0xFD]);
            assert!(!invalid_utf8.is_utf8());
        }
    }
}