use crate::codec::apply_mask;
use bytes::{BufMut, BytesMut};
use std::fmt::Debug;
/// WebSocket frame opcode: the low four bits of a frame's first header byte
/// (RFC 6455 §5.2).
///
/// Every 4-bit value has a variant, so any nibble maps to a valid `OpCode`.
/// `RNC*` variants are the reserved non-control opcodes (3–7) and `RC*`
/// variants are the reserved control opcodes (11–15).
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
#[repr(u8)]
pub enum OpCode {
    /// Continuation frame.
    Continue = 0,
    /// Text frame.
    Text = 1,
    /// Binary frame.
    Binary = 2,
    RNC3 = 3,
    RNC4 = 4,
    RNC5 = 5,
    RNC6 = 6,
    RNC7 = 7,
    /// Connection close frame.
    Close = 8,
    /// Ping frame.
    Ping = 9,
    /// Pong frame.
    Pong = 10,
    RC11 = 11,
    RC12 = 12,
    RC13 = 13,
    RC14 = 14,
    RC15 = 15,
}

/// Defaults to [`OpCode::Text`].
impl Default for OpCode {
    fn default() -> Self {
        Self::Text
    }
}

impl OpCode {
    /// Returns the numeric opcode value (0–15).
    pub fn as_u8(&self) -> u8 {
        *self as u8
    }

    /// Returns `true` only for [`OpCode::Close`].
    pub fn is_close(&self) -> bool {
        matches!(self, Self::Close)
    }

    /// Returns `true` for the data opcodes `Text`, `Binary` and `Continue`.
    pub fn is_data(&self) -> bool {
        matches!(self, Self::Text | Self::Binary | Self::Continue)
    }

    /// Returns `true` for every opcode RFC 6455 reserves for future use:
    /// 3–7 (non-control, `RNC3`–`RNC7`) and 11–15 (control, `RC11`–`RC15`).
    pub fn is_reserved(&self) -> bool {
        // Both ranges are inclusive so that `RNC6` and `RNC7` count as reserved.
        matches!(self.as_u8(), 3..=7 | 11..=15)
    }
}
#[inline]
pub(crate) fn parse_opcode(val: u8) -> OpCode {
unsafe { std::mem::transmute(val & 0b00001111) }
}
/// Returns whether bit `bit_idx` of `source[byte_idx]` is set.
///
/// Bits are numbered from the most significant end: `bit_idx == 0` tests
/// `0x80` and `bit_idx == 7` tests `0x01`.
///
/// # Panics
///
/// Panics if `bit_idx > 7` or if `byte_idx` is out of bounds for `source`.
#[inline]
pub(crate) fn get_bit(source: &[u8], byte_idx: usize, bit_idx: u8) -> bool {
    assert!(bit_idx < 8, "bit index {} out of range 0..8", bit_idx);
    let mask = 0b1000_0000u8 >> bit_idx;
    // Checked indexing keeps this safe fn sound: an out-of-range `byte_idx`
    // panics instead of reading past the end of the slice.
    source[byte_idx] & mask == mask
}
/// Sets (`val == true`) or clears (`val == false`) bit `bit_idx` of
/// `source[byte_idx]`, leaving the other bits untouched.
///
/// Bits are numbered from the most significant end (`0` → `0x80`, `7` → `0x01`).
///
/// # Panics
///
/// Panics if `bit_idx > 7` or if `byte_idx` is out of bounds.
#[inline]
pub(crate) fn set_bit(source: &mut [u8], byte_idx: usize, bit_idx: u8, val: bool) {
    let bit: u8 = match bit_idx {
        0..=7 => 0b1000_0000 >> bit_idx,
        _ => unreachable!(),
    };
    let byte = &mut source[byte_idx];
    if val {
        *byte |= bit;
    } else {
        *byte &= !bit;
    }
}
/// Generates the read-only header accessors shared by `HeaderView` and `Header`.
///
/// The implementing type must be a tuple struct whose field `0` dereferences to
/// `[u8]` and holds the raw header bytes:
/// - byte 0 holds FIN/RSV1-3 and the opcode;
/// - byte 1 holds the MASK flag and the 7-bit length code;
/// - any extended length bytes come next, followed by the optional 4-byte masking key.
macro_rules! impl_get {
    () => {
        /// Reads bit `bit_idx` (0 = most significant) of header byte `byte_idx`.
        #[inline]
        fn get_bit(&self, byte_idx: usize, bit_idx: u8) -> bool {
            get_bit(&self.0, byte_idx, bit_idx)
        }

        /// FIN flag (most significant bit of byte 0).
        #[inline]
        pub fn fin(&self) -> bool {
            self.get_bit(0, 0)
        }

        /// RSV1 flag (bit 1 of byte 0, counting from the most significant bit).
        #[inline]
        pub fn rsv1(&self) -> bool {
            self.get_bit(0, 1)
        }

        /// RSV2 flag (bit 2 of byte 0, counting from the most significant bit).
        #[inline]
        pub fn rsv2(&self) -> bool {
            self.get_bit(0, 2)
        }

        /// RSV3 flag (bit 3 of byte 0, counting from the most significant bit).
        #[inline]
        pub fn rsv3(&self) -> bool {
            self.get_bit(0, 3)
        }

        /// Opcode stored in the low four bits of byte 0.
        ///
        /// # Panics
        ///
        /// Panics if the header is empty.
        #[inline]
        pub fn opcode(&self) -> OpCode {
            // Checked indexing: an empty buffer panics instead of reading out of bounds.
            parse_opcode(self.0[0])
        }

        /// Whether the MASK flag (most significant bit of byte 1) is set.
        #[inline]
        pub fn masked(&self) -> bool {
            self.get_bit(1, 0)
        }

        /// Number of bytes taken by the length field, counting byte 1 itself:
        /// 1 for a 7-bit length, 3 for a 16-bit extended length, 9 for a 64-bit one.
        #[inline]
        fn len_bytes(&self) -> usize {
            // The MASK flag shares byte 1; only the low 7 bits select the encoding.
            match self.0[1] & 0x7F {
                126 => 3,
                127 => 9,
                _ => 1,
            }
        }

        /// Payload length decoded from byte 1 and any extended length bytes
        /// (big-endian).
        ///
        /// # Panics
        ///
        /// Panics if the header is shorter than the length encoding it declares.
        #[inline]
        pub fn payload_len(&self) -> u64 {
            let header: &[u8] = &self.0;
            // Byte 1 is read below, so at least two bytes are required.
            assert!(header.len() >= 2);
            match header[1] & 0x7F {
                126 => {
                    assert!(header.len() >= 4);
                    u16::from_be_bytes([header[2], header[3]]) as u64
                }
                127 => {
                    assert!(header.len() >= 10);
                    let mut ext = [0u8; 8];
                    ext.copy_from_slice(&header[2..10]);
                    u64::from_be_bytes(ext)
                }
                len => len as u64,
            }
        }

        /// The 4-byte masking key, or `None` when the MASK flag is clear.
        ///
        /// The key is read from the four bytes that follow the length field.
        ///
        /// # Panics
        ///
        /// Panics if the MASK flag is set but the buffer is too short to hold the key.
        #[inline]
        pub fn masking_key(&self) -> Option<[u8; 4]> {
            if !self.masked() {
                return None;
            }
            let start = 1 + self.len_bytes();
            let mut key = [0u8; 4];
            key.copy_from_slice(&self.0[start..start + 4]);
            Some(key)
        }
    };
}
/// Total encoded size, in bytes, of a frame header.
///
/// The size is the sum of:
/// - 1 byte for FIN/RSV/opcode;
/// - the length field: 1, 3 or 9 bytes, depending on `payload_len`;
/// - 4 more bytes when `mask` is set, for the masking key.
pub fn header_len(mask: bool, payload_len: u64) -> usize {
    let length_field = match payload_len {
        0..=125 => 1,
        126..=65535 => 3,
        _ => 9,
    };
    let key_field = if mask { 4 } else { 0 };
    1 + length_field + key_field
}
/// Packs the FIN/RSV flags and the opcode into the first byte of a frame header.
///
/// Layout, most significant bit first: `FIN RSV1 RSV2 RSV3 OPCODE[4]`.
#[inline]
const fn first_byte(fin: bool, rsv1: bool, rsv2: bool, rsv3: bool, opcode: OpCode) -> u8 {
    let flags = ((fin as u8) << 7) | ((rsv1 as u8) << 6) | ((rsv2 as u8) << 5) | ((rsv3 as u8) << 4);
    flags | opcode as u8
}
/// Encodes a frame header into the front of `buf` and returns the encoded prefix.
///
/// The header is written as: byte 0 (flags and opcode), then byte 1 (MASK flag
/// and length code), then any big-endian extended length, then the masking key
/// when `mask_key` is `Some`.
///
/// # Panics
///
/// Panics if `buf` is shorter than the encoded header (at most 14 bytes).
#[allow(clippy::too_many_arguments)]
pub fn ctor_header<M: Into<Option<[u8; 4]>>>(
    buf: &mut [u8],
    fin: bool,
    rsv1: bool,
    rsv2: bool,
    rsv3: bool,
    mask_key: M,
    opcode: OpCode,
    payload_len: u64,
) -> &[u8] {
    let mask = mask_key.into();
    // Offset just past byte 1 and any extended length bytes.
    let length_end = match payload_len {
        0..=125 => {
            buf[1] = payload_len as u8;
            2
        }
        126..=65535 => {
            buf[1] = 126;
            buf[2..4].copy_from_slice(&(payload_len as u16).to_be_bytes());
            4
        }
        _ => {
            buf[1] = 127;
            buf[2..10].copy_from_slice(&payload_len.to_be_bytes());
            10
        }
    };
    buf[0] = first_byte(fin, rsv1, rsv2, rsv3, opcode);
    let total = match mask {
        Some(key) => {
            set_bit(buf, 1, 0, true);
            let key_end = length_end + 4;
            buf[length_end..key_end].copy_from_slice(&key);
            key_end
        }
        None => {
            set_bit(buf, 1, 0, false);
            length_end
        }
    };
    &buf[..total]
}
/// Cross-checks `ctor_header` against `Header::new` and checks that the header
/// getters decode what was encoded. It uses random flags, opcodes, masks and
/// payload lengths.
#[test]
fn test_header() {
    fn rand_mask() -> Option<[u8; 4]> {
        fastrand::bool().then(|| fastrand::u32(..).to_be_bytes())
    }
    fn rand_code() -> OpCode {
        // SAFETY: `OpCode` is `repr(u8)` with a variant for every value in 0..16.
        unsafe { std::mem::transmute(fastrand::u8(0..16)) }
    }
    /// Picks a length from one of the three wire encodings with equal probability.
    /// A uniform draw over `u64` would almost never hit the 7-bit or 16-bit forms.
    fn rand_len() -> u64 {
        match fastrand::u8(0..3) {
            0 => fastrand::u64(0..=125),
            1 => fastrand::u64(126..=65535),
            _ => fastrand::u64(65536..),
        }
    }
    // Check the encoding boundaries explicitly, then random lengths.
    let boundaries = [0u64, 125, 126, 65535, 65536, u64::MAX];
    let lengths = boundaries
        .iter()
        .copied()
        .chain((0..1000).map(|_| rand_len()));
    let mut buf = [0u8; 14];
    for payload_len in lengths {
        let fin = fastrand::bool();
        let rsv1 = fastrand::bool();
        let rsv2 = fastrand::bool();
        let rsv3 = fastrand::bool();
        let mask_key = rand_mask();
        let opcode = rand_code();
        let slice = ctor_header(
            &mut buf,
            fin,
            rsv1,
            rsv2,
            rsv3,
            mask_key,
            opcode,
            payload_len,
        );
        let header = Header::new(fin, rsv1, rsv2, rsv3, mask_key, opcode, payload_len);
        assert_eq!(slice, header.as_bytes());
        // The getters must decode exactly what was encoded.
        assert_eq!(header.fin(), fin);
        assert_eq!(header.rsv1(), rsv1);
        assert_eq!(header.rsv2(), rsv2);
        assert_eq!(header.rsv3(), rsv3);
        assert_eq!(header.opcode(), opcode);
        assert_eq!(header.masked(), mask_key.is_some());
        assert_eq!(header.masking_key(), mask_key);
        assert_eq!(header.payload_len(), payload_len);
    }
}
/// The FIN/RSV flags and opcode of a frame header, without the length or
/// masking-key fields.
#[derive(Debug, Clone, Copy)]
pub struct SimplifiedHeader {
    /// FIN bit.
    pub fin: bool,
    /// RSV1 bit.
    pub rsv1: bool,
    /// RSV2 bit.
    pub rsv2: bool,
    /// RSV3 bit.
    pub rsv3: bool,
    /// Frame opcode.
    pub code: OpCode,
}

impl<'a> From<HeaderView<'a>> for SimplifiedHeader {
    /// Copies the flag bits and opcode out of a borrowed header view.
    fn from(view: HeaderView<'a>) -> Self {
        let (fin, rsv1, rsv2, rsv3) = (view.fin(), view.rsv1(), view.rsv2(), view.rsv3());
        let code = view.opcode();
        Self {
            fin,
            rsv1,
            rsv2,
            rsv3,
            code,
        }
    }
}
/// A borrowed, read-only view over the raw bytes of a frame header.
///
/// Exposes the same getters as `Header` through `impl_get!`.
#[derive(Debug, Clone, Copy)]
pub struct HeaderView<'a>(pub(crate) &'a [u8]);

impl<'a> HeaderView<'a> {
    impl_get!();
}
/// An owned frame header: the raw header bytes stored in a [`BytesMut`].
///
/// The read accessors come from `impl_get!`. The setters below edit the bytes
/// in place, resizing the buffer when the payload-length encoding changes.
#[derive(Debug, Clone)]
pub struct Header(pub(crate) BytesMut);

impl Header {
    impl_get! {}

    /// Borrows the encoded header bytes.
    pub fn as_bytes(&self) -> &[u8] {
        &self.0[..]
    }

    /// Writes bit `bit_idx` (0 = most significant) of header byte `byte_idx`.
    #[inline]
    fn set_bit(&mut self, byte_idx: usize, bit_idx: u8, val: bool) {
        set_bit(&mut self.0, byte_idx, bit_idx, val)
    }

    /// Sets or clears the FIN flag.
    #[inline]
    pub fn set_fin(&mut self, val: bool) {
        self.set_bit(0, 0, val)
    }

    /// Sets or clears the RSV1 flag.
    #[inline]
    pub fn set_rsv1(&mut self, val: bool) {
        self.set_bit(0, 1, val)
    }

    /// Sets or clears the RSV2 flag.
    #[inline]
    pub fn set_rsv2(&mut self, val: bool) {
        self.set_bit(0, 2, val)
    }

    /// Sets or clears the RSV3 flag.
    #[inline]
    pub fn set_rsv3(&mut self, val: bool) {
        self.set_bit(0, 3, val)
    }

    /// Replaces the opcode nibble of byte 0 and keeps FIN and RSV1-3 unchanged.
    #[inline]
    pub fn set_opcode(&mut self, code: OpCode) {
        let first = &mut self.0[0];
        *first = (*first & 0xF0) | code.as_u8()
    }

    /// Sets or clears the MASK flag only; this does not add or remove key bytes.
    #[inline]
    pub fn set_mask(&mut self, mask: bool) {
        self.set_bit(1, 0, mask);
    }

    /// Re-encodes the payload length.
    ///
    /// The buffer is resized to fit the new length-field size. The MASK flag and
    /// any existing masking key are preserved, and the key is moved to follow
    /// the new length field.
    #[inline]
    pub fn set_payload_len(&mut self, len: u64) {
        // Capture the key before the length field (and therefore its offset) changes.
        let key = self.masking_key();
        let key_len: usize = if key.is_some() { 4 } else { 0 };
        let buf = &mut self.0;
        let mask_flag = buf[1] & 0b1000_0000;
        let (len_code, ext_len): (u8, usize) = match len {
            0..=125 => (len as u8, 0),
            126..=65535 => (126, 2),
            _ => (127, 8),
        };
        buf[1] = mask_flag | len_code;
        let key_start = 2 + ext_len;
        buf.resize(key_start + key_len, 0);
        // The extended length is the low `ext_len` bytes of the big-endian u64.
        let be = len.to_be_bytes();
        buf[2..key_start].copy_from_slice(&be[8 - ext_len..]);
        if let Some(key) = key {
            buf[key_start..].copy_from_slice(&key);
        }
    }

    /// Wraps already-encoded header bytes without validating them.
    pub fn raw(data: BytesMut) -> Self {
        Self(data)
    }

    /// Builds a header from its individual fields.
    ///
    /// When `mask_key` is `Some`, the MASK flag is set and the key is written
    /// after the length field.
    pub fn new<M: Into<Option<[u8; 4]>>>(
        fin: bool,
        rsv1: bool,
        rsv2: bool,
        rsv3: bool,
        mask_key: M,
        opcode: OpCode,
        payload_len: u64,
    ) -> Self {
        let key = mask_key.into();
        let total = header_len(key.is_some(), payload_len);
        assert!(total >= 2);
        let mut header = Self(BytesMut::zeroed(total));
        header.0[0] = first_byte(fin, rsv1, rsv2, rsv3, opcode);
        header.set_mask(key.is_some());
        header.set_payload_len(payload_len);
        if let Some(key) = key {
            header.0[(total - 4)..total].copy_from_slice(&key);
        }
        header
    }
}
/// A frame that owns both its header and its payload bytes.
///
/// When the header carries a masking key, `payload` holds the masked bytes.
#[derive(Debug, Clone)]
pub struct OwnedFrame {
    pub(crate) header: Header,
    pub(crate) payload: BytesMut,
}

impl OwnedFrame {
    /// Builds a frame with FIN set and RSV1-3 clear around a copy of `data`.
    ///
    /// When `mask` is `Some`, the copied payload is masked with that key.
    #[inline]
    pub fn new(code: OpCode, mask: impl Into<Option<[u8; 4]>>, data: &[u8]) -> Self {
        let header = Header::new(true, false, false, false, mask, code, data.len() as u64);
        let mut payload = BytesMut::from(data);
        if let Some(key) = header.masking_key() {
            apply_mask(&mut payload, key);
        }
        Self { header, payload }
    }

    /// Assembles a frame from an existing header and payload without checking them.
    #[inline]
    pub fn with_raw(header: Header, payload: BytesMut) -> Self {
        Self { header, payload }
    }

    /// Builds a text frame from `data`.
    #[inline]
    pub fn text_frame(mask: impl Into<Option<[u8; 4]>>, data: &str) -> Self {
        Self::new(OpCode::Text, mask, data.as_bytes())
    }

    /// Builds a binary frame from `data`.
    #[inline]
    pub fn binary_frame(mask: impl Into<Option<[u8; 4]>>, data: &[u8]) -> Self {
        Self::new(OpCode::Binary, mask, data)
    }

    /// Builds a ping frame.
    ///
    /// # Panics
    ///
    /// Panics if `data` is longer than 125 bytes.
    #[inline]
    pub fn ping_frame(mask: impl Into<Option<[u8; 4]>>, data: &[u8]) -> Self {
        assert!(data.len() <= 125);
        Self::new(OpCode::Ping, mask, data)
    }

    /// Builds a pong frame.
    ///
    /// # Panics
    ///
    /// Panics if `data` is longer than 125 bytes.
    #[inline]
    pub fn pong_frame(mask: impl Into<Option<[u8; 4]>>, data: &[u8]) -> Self {
        assert!(data.len() <= 125);
        Self::new(OpCode::Pong, mask, data)
    }

    /// Builds a close frame.
    ///
    /// The payload is the big-endian status `code` followed by `data`. It is
    /// empty when `code` is `None`.
    ///
    /// # Panics
    ///
    /// Panics if `data` is longer than 123 bytes, or if `data` is non-empty
    /// while `code` is `None`.
    #[inline]
    pub fn close_frame(
        mask: impl Into<Option<[u8; 4]>>,
        code: impl Into<Option<u16>>,
        data: &[u8],
    ) -> Self {
        assert!(data.len() <= 123);
        let status: Option<u16> = code.into();
        assert!(status.is_some() || data.is_empty());
        let mut payload = BytesMut::with_capacity(2 + data.len());
        match status {
            Some(status) => {
                payload.put_u16(status);
                payload.extend_from_slice(data);
            }
            None => {}
        }
        Self::new(OpCode::Close, mask, &payload)
    }

    /// Unmasks the payload in place and drops the key from the header.
    ///
    /// The key is removed by clearing the MASK flag and truncating the last
    /// four header bytes. Returns the removed key, or `None` if the frame was
    /// not masked.
    #[inline]
    pub fn unmask(&mut self) -> Option<[u8; 4]> {
        let key = self.header.masking_key()?;
        apply_mask(&mut self.payload, key);
        self.header.set_mask(false);
        let without_key = self.header.0.len() - 4;
        self.header.0.truncate(without_key);
        Some(key)
    }

    /// Masks the payload with `mask`, replacing any existing key.
    pub fn mask(&mut self, mask: [u8; 4]) {
        // Return to the plain payload first so the new key applies to clear bytes.
        self.unmask();
        self.header.set_mask(true);
        self.header.0.extend_from_slice(&mask);
        apply_mask(&mut self.payload, mask);
    }

    /// Appends `data` to the payload and updates the header length.
    ///
    /// If the frame was masked, it is re-masked with the same key afterwards.
    pub fn extend_from_slice(&mut self, data: &[u8]) {
        let previous_key = self.unmask();
        self.payload.extend_from_slice(data);
        self.header.set_payload_len(self.payload.len() as u64);
        if let Some(key) = previous_key {
            self.mask(key);
        }
    }

    /// Borrows the header.
    #[inline]
    pub fn header(&self) -> &Header {
        &self.header
    }

    /// Mutably borrows the header.
    #[inline]
    pub fn header_mut(&mut self) -> &mut Header {
        &mut self.header
    }

    /// Borrows the payload, which is masked if the header carries a key.
    #[inline]
    pub fn payload(&self) -> &BytesMut {
        &self.payload
    }

    /// Splits the frame into its header and payload.
    #[inline]
    pub fn parts(self) -> (Header, BytesMut) {
        (self.header, self.payload)
    }
}