use smol::io::{AsyncRead, AsyncReadExt};
use std::convert::TryFrom;
use crate::{Error, Result};
/// Opcode carried in the low four bits of the first byte of a frame header.
///
/// Each variant's discriminant is pinned to its on-the-wire value, so
/// `opcode as u8` produces exactly the number that `OpCode::try_from`
/// accepts for that variant. Without explicit discriminants `Close`, `Ping`
/// and `Pong` would be numbered 3, 4 and 5 and would be serialised
/// incorrectly by any code that casts the enum to `u8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum OpCode {
    /// Wire value `0x0`.
    Continuation = 0x0,
    /// Wire value `0x1`.
    Text = 0x1,
    /// Wire value `0x2`.
    Binary = 0x2,
    /// Wire value `0x8`.
    Close = 0x8,
    /// Wire value `0x9`.
    Ping = 0x9,
    /// Wire value `0xA`.
    Pong = 0xA,
}
impl TryFrom<u8> for OpCode {
type Error = Error;
fn try_from(value: u8) -> Result<Self> {
match value {
0 => Ok(OpCode::Continuation),
1 => Ok(OpCode::Text),
2 => Ok(OpCode::Binary),
8 => Ok(OpCode::Close),
9 => Ok(OpCode::Ping),
10 => Ok(OpCode::Pong),
_ => Err(Error::Protocol(format!("Invalid OpCode: {}", value))),
}
}
}
/// A single frame: the header flags, the opcode, the optional masking key
/// and the payload bytes.
///
/// `Frame::read_from` fills every field from the wire and `Frame::to_bytes`
/// serialises them back into header bytes followed by the payload.
#[derive(Debug)]
pub struct Frame {
/// FIN flag: bit `0x80` of the first header byte.
pub fin: bool,
/// RSV1 flag: bit `0x40` of the first header byte.
pub rsv1: bool,
/// RSV2 flag: bit `0x20` of the first header byte.
pub rsv2: bool,
/// RSV3 flag: bit `0x10` of the first header byte.
pub rsv3: bool,
/// Frame opcode, taken from the low four bits of the first header byte.
pub opcode: OpCode,
/// Four-byte masking key. `Some` exactly when bit `0x80` of the second
/// header byte is set; `None` otherwise.
pub mask: Option<[u8; 4]>,
/// Payload bytes. `Frame::read_from` stores them after XOR-ing with `mask`
/// (when one is present), i.e. in unmasked form.
pub payload: Vec<u8>,
}
impl Frame {
pub fn new(opcode: OpCode, payload: Vec<u8>) -> Self {
Frame {
fin: true,
rsv1: false,
rsv2: false,
rsv3: false,
opcode,
mask: None,
payload,
}
}
pub fn close(status_code: Option<u16>) -> Self {
let payload = status_code.map(|code| code.to_be_bytes().to_vec()).unwrap_or_default();
Frame::new(OpCode::Close, payload)
}
pub fn is_close(&self) -> bool {
self.opcode == OpCode::Close
}
pub fn is_masked(&self) -> bool {
self.mask.is_some()
}
pub async fn read_from<R: AsyncRead + Unpin>(reader: &mut R) -> Result<Self> {
let mut buf = [0u8; 2];
reader.read_exact(&mut buf).await?;
let first_byte = buf[0];
let second_byte = buf[1];
let fin = (first_byte & 0x80) != 0;
let rsv1 = (first_byte & 0x40) != 0;
let rsv2 = (first_byte & 0x20) != 0;
let rsv3 = (first_byte & 0x10) != 0;
let opcode = OpCode::try_from(first_byte & 0x0F)?;
let masked = (second_byte & 0x80) != 0;
let mut payload_len = (second_byte & 0x7F) as u64;
if payload_len == 126 {
let mut buf = [0u8; 2];
reader.read_exact(&mut buf).await?;
payload_len = u16::from_be_bytes(buf) as u64;
} else if payload_len == 127 {
let mut buf = [0u8; 8];
reader.read_exact(&mut buf).await?;
payload_len = u64::from_be_bytes(buf);
}
let mask = if masked {
let mut mask_bytes = [0u8; 4];
reader.read_exact(&mut mask_bytes).await?;
Some(mask_bytes)
} else {
None
};
let mut payload = vec![0u8; payload_len as usize];
reader.read_exact(&mut payload).await?;
if let Some(mask) = mask {
for (i, byte) in payload.iter_mut().enumerate() {
*byte ^= mask[i % 4];
}
}
Ok(Frame {
fin,
rsv1,
rsv2,
rsv3,
opcode,
mask,
payload,
})
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut bytes = Vec::new();
let mut first_byte = 0;
if self.fin {
first_byte |= 0x80;
}
if self.rsv1 {
first_byte |= 0x40;
}
if self.rsv2 {
first_byte |= 0x20;
}
if self.rsv3 {
first_byte |= 0x10;
}
first_byte |= self.opcode as u8;
bytes.push(first_byte);
let mut second_byte = 0;
if self.mask.is_some() {
second_byte |= 0x80;
}
let payload_len = self.payload.len();
if payload_len < 126 {
second_byte |= payload_len as u8;
bytes.push(second_byte);
} else if payload_len < 65536 {
second_byte |= 126;
bytes.push(second_byte);
bytes.extend_from_slice(&(payload_len as u16).to_be_bytes());
} else {
second_byte |= 127;
bytes.push(second_byte);
bytes.extend_from_slice(&(payload_len as u64).to_be_bytes());
}
if let Some(mask) = self.mask {
bytes.extend_from_slice(&mask);
}
bytes.extend_from_slice(&self.payload);
bytes
}
}