use std::convert::TryInto;
use rand::RngCore;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use super::split::{WebSocketReadHalf, WebSocketWriteHalf};
use super::FrameType;
#[allow(unused_imports)] use super::WebSocket;
use crate::error::WebSocketError;
/// `u16::MAX` widened to `usize` (65 535).
const U16_MAX: usize = u16::MAX as usize;
/// One below [`U16_MAX`] (65 534).
// Allowed to go unused so a payload-length match arm can stop naming this bound
// without tripping `dead_code`.
#[allow(dead_code)]
const U16_MAX_MINUS_ONE: usize = U16_MAX - 1;
/// `u64::MAX - 1` as a `usize`; because the cast truncates, this equals `usize::MAX - 1`
/// on every target width.
// Allowed to go unused for the same reason as `U16_MAX_MINUS_ONE`.
#[allow(dead_code)]
const U64_MAX_MINUS_ONE: usize = (u64::MAX as usize) - 1;
/// One WebSocket frame.
///
/// Data frames (`Text`, `Binary`) carry explicit `continuation` and `fin` flags so that a
/// message can be split over several frames. Control frames (`Close`, `Ping`, `Pong`) carry
/// only an optional payload.
#[derive(Debug, Clone)]
pub enum Frame {
    /// A text data frame.
    Text {
        /// The frame's payload as a UTF-8 string.
        payload: String,
        /// `true` when this frame continues an earlier fragment (wire opcode `0x0`
        /// instead of `0x1`).
        continuation: bool,
        /// `true` when this is the last fragment of its message (the FIN bit).
        fin: bool,
    },
    /// A binary data frame.
    Binary {
        /// The frame's raw payload bytes.
        payload: Vec<u8>,
        /// `true` when this frame continues an earlier fragment (wire opcode `0x0`
        /// instead of `0x2`).
        continuation: bool,
        /// `true` when this is the last fragment of its message (the FIN bit).
        fin: bool,
    },
    /// A Close control frame.
    Close {
        /// Status code and UTF-8 reason; `None` stands for an empty close body.
        payload: Option<(u16, String)>,
    },
    /// A Ping control frame.
    Ping {
        /// Application data; `None` stands for an empty payload.
        payload: Option<Vec<u8>>,
    },
    /// A Pong control frame.
    Pong {
        /// Application data; `None` stands for an empty payload.
        payload: Option<Vec<u8>>,
    },
}
impl Frame {
    /// Creates a complete, unfragmented text frame (`continuation: false`, `fin: true`).
    pub fn text(payload: String) -> Self {
        Self::Text {
            payload,
            continuation: false,
            fin: true,
        }
    }

    /// Returns `true` if this is a [`Frame::Text`].
    pub fn is_text(&self) -> bool {
        matches!(self, Self::Text { .. })
    }

    /// Borrows `(payload, continuation, fin)` of a text frame; `None` for any other variant.
    pub fn as_text(&self) -> Option<(&String, &bool, &bool)> {
        match self {
            Self::Text {
                payload,
                continuation,
                fin,
            } => Some((payload, continuation, fin)),
            _ => None,
        }
    }

    /// Mutably borrows `(payload, continuation, fin)` of a text frame; `None` for any other
    /// variant.
    pub fn as_text_mut(&mut self) -> Option<(&mut String, &mut bool, &mut bool)> {
        match self {
            Self::Text {
                payload,
                continuation,
                fin,
            } => Some((payload, continuation, fin)),
            _ => None,
        }
    }

    /// Consumes the frame, returning `(payload, continuation, fin)` if it is a text frame.
    pub fn into_text(self) -> Option<(String, bool, bool)> {
        match self {
            Self::Text {
                payload,
                continuation,
                fin,
            } => Some((payload, continuation, fin)),
            _ => None,
        }
    }

    /// Creates a complete, unfragmented binary frame (`continuation: false`, `fin: true`).
    pub fn binary(payload: Vec<u8>) -> Self {
        Self::Binary {
            payload,
            continuation: false,
            fin: true,
        }
    }

    /// Returns `true` if this is a [`Frame::Binary`].
    pub fn is_binary(&self) -> bool {
        matches!(self, Self::Binary { .. })
    }

    /// Borrows `(payload, continuation, fin)` of a binary frame; `None` for any other variant.
    pub fn as_binary(&self) -> Option<(&Vec<u8>, &bool, &bool)> {
        match self {
            Self::Binary {
                payload,
                continuation,
                fin,
            } => Some((payload, continuation, fin)),
            _ => None,
        }
    }

    /// Mutably borrows `(payload, continuation, fin)` of a binary frame; `None` for any other
    /// variant.
    pub fn as_binary_mut(&mut self) -> Option<(&mut Vec<u8>, &mut bool, &mut bool)> {
        match self {
            Self::Binary {
                payload,
                continuation,
                fin,
            } => Some((payload, continuation, fin)),
            _ => None,
        }
    }

    /// Consumes the frame, returning `(payload, continuation, fin)` if it is a binary frame.
    pub fn into_binary(self) -> Option<(Vec<u8>, bool, bool)> {
        match self {
            Self::Binary {
                payload,
                continuation,
                fin,
            } => Some((payload, continuation, fin)),
            _ => None,
        }
    }

    /// Creates a Close frame with an optional `(status code, reason)` payload.
    pub fn close(payload: Option<(u16, String)>) -> Self {
        Self::Close { payload }
    }

    /// Returns `true` if this is a [`Frame::Close`], with or without a payload.
    pub fn is_close(&self) -> bool {
        // Not `self.as_close().is_some()`: that is `false` for a payload-less Close frame.
        matches!(self, Self::Close { .. })
    }

    /// Borrows the `(status code, reason)` of a Close frame.
    ///
    /// Returns `None` both for other variants and for a Close frame without a payload;
    /// use [`Frame::is_close`] to test the variant itself.
    pub fn as_close(&self) -> Option<&(u16, String)> {
        match self {
            Self::Close { payload } => payload.as_ref(),
            _ => None,
        }
    }

    /// Mutably borrows the `(status code, reason)` of a Close frame, if present.
    pub fn as_close_mut(&mut self) -> Option<&mut (u16, String)> {
        match self {
            Self::Close { payload } => payload.as_mut(),
            _ => None,
        }
    }

    /// Consumes the frame, returning the `(status code, reason)` of a Close frame, if present.
    pub fn into_close(self) -> Option<(u16, String)> {
        match self {
            Self::Close { payload } => payload,
            _ => None,
        }
    }

    /// Creates a Ping frame; `None` is sent as an empty payload.
    pub fn ping(payload: Option<Vec<u8>>) -> Self {
        Self::Ping { payload }
    }

    /// Returns `true` if this is a [`Frame::Ping`], with or without a payload.
    pub fn is_ping(&self) -> bool {
        // Not `self.as_ping().is_some()`: that is `false` for an empty ping, which
        // `read_from_websocket` decodes as `Ping { payload: None }`.
        matches!(self, Self::Ping { .. })
    }

    /// Borrows the payload of a Ping frame; `None` for other variants or an empty ping.
    pub fn as_ping(&self) -> Option<&Vec<u8>> {
        match self {
            Self::Ping { payload } => payload.as_ref(),
            _ => None,
        }
    }

    /// Mutably borrows the payload of a Ping frame, if present.
    pub fn as_ping_mut(&mut self) -> Option<&mut Vec<u8>> {
        match self {
            Self::Ping { payload } => payload.as_mut(),
            _ => None,
        }
    }

    /// Consumes the frame, returning the payload of a Ping frame, if present.
    pub fn into_ping(self) -> Option<Vec<u8>> {
        match self {
            Self::Ping { payload } => payload,
            _ => None,
        }
    }

    /// Creates a Pong frame; `None` is sent as an empty payload.
    pub fn pong(payload: Option<Vec<u8>>) -> Self {
        Self::Pong { payload }
    }

    /// Returns `true` if this is a [`Frame::Pong`], with or without a payload.
    pub fn is_pong(&self) -> bool {
        // Not `self.as_pong().is_some()`: that is `false` for an empty pong.
        matches!(self, Self::Pong { .. })
    }

    /// Borrows the payload of a Pong frame; `None` for other variants or an empty pong.
    pub fn as_pong(&self) -> Option<&Vec<u8>> {
        match self {
            Self::Pong { payload } => payload.as_ref(),
            _ => None,
        }
    }

    /// Mutably borrows the payload of a Pong frame, if present.
    pub fn as_pong_mut(&mut self) -> Option<&mut Vec<u8>> {
        match self {
            Self::Pong { payload } => payload.as_mut(),
            _ => None,
        }
    }

    /// Consumes the frame, returning the payload of a Pong frame, if present.
    pub fn into_pong(self) -> Option<Vec<u8>> {
        match self {
            Self::Pong { payload } => payload,
            _ => None,
        }
    }

    /// Returns the frame with its `continuation` flag replaced.
    ///
    /// Control frames (Close/Ping/Pong) have no such flag and are returned unchanged.
    pub fn set_continuation(self, continuation: bool) -> Self {
        match self {
            Self::Text { payload, fin, .. } => Self::Text {
                payload,
                continuation,
                fin,
            },
            Self::Binary { payload, fin, .. } => Self::Binary {
                payload,
                continuation,
                fin,
            },
            _ => self,
        }
    }

    /// Returns the frame with its `fin` flag replaced.
    ///
    /// Control frames (Close/Ping/Pong) have no such flag and are returned unchanged.
    pub fn set_fin(self, fin: bool) -> Self {
        match self {
            Self::Text {
                payload,
                continuation,
                ..
            } => Self::Text {
                payload,
                continuation,
                fin,
            },
            Self::Binary {
                payload,
                continuation,
                ..
            } => Self::Binary {
                payload,
                continuation,
                fin,
            },
            _ => self,
        }
    }

    /// Encodes this frame and writes it to `write_half`, then flushes the stream.
    ///
    /// The frame is always masked, as RFC 6455 requires for client-to-server frames. The
    /// masking key is drawn from `write_half.rng`.
    ///
    /// # Errors
    ///
    /// - [`WebSocketError::ControlFrameTooLargeError`] if a Close/Ping/Pong payload exceeds
    ///   125 bytes.
    /// - [`WebSocketError::PayloadTooLargeError`] if the payload length does not fit the
    ///   63-bit extended length field.
    /// - [`WebSocketError::WriteError`] if writing or flushing the stream fails.
    pub(super) async fn send(
        self,
        write_half: &mut WebSocketWriteHalf,
    ) -> Result<(), WebSocketError> {
        // MASK bit of the second header byte.
        const MASK_BIT: u8 = 0b1000_0000;

        let is_control = self.is_control();
        let opcode = self.opcode();
        let fin = self.fin();
        let mut payload = match self {
            Self::Text { payload, .. } => payload.into_bytes(),
            Self::Binary { payload, .. } => payload,
            Self::Close {
                payload: Some((status_code, reason)),
            } => {
                let mut body = Vec::with_capacity(2 + reason.len());
                body.extend_from_slice(&status_code.to_be_bytes());
                body.extend_from_slice(reason.as_bytes());
                body
            }
            Self::Close { payload: None } => Vec::new(),
            Self::Ping { payload } | Self::Pong { payload } => payload.unwrap_or_default(),
        };
        if is_control && payload.len() > 125 {
            return Err(WebSocketError::ControlFrameTooLargeError);
        }

        let payload_len = payload.len();
        // Worst-case header: 2 bytes + 8-byte extended length + 4-byte masking key.
        let mut raw_frame = Vec::with_capacity(payload_len + 14);
        raw_frame.push(opcode | fin);
        // RFC 6455 §5.2 requires the minimal length encoding, so every length up to and
        // including 65 535 must use the 16-bit form.
        match payload_len {
            0..=125 => raw_frame.push(MASK_BIT | payload_len as u8),
            126..=U16_MAX => {
                raw_frame.push(MASK_BIT | 126);
                raw_frame.extend_from_slice(&(payload_len as u16).to_be_bytes());
            }
            _ => {
                let extended_len: u64 = payload_len
                    .try_into()
                    .map_err(|_| WebSocketError::PayloadTooLargeError)?;
                // The most significant bit of the 64-bit length MUST be 0.
                if extended_len > i64::MAX as u64 {
                    return Err(WebSocketError::PayloadTooLargeError);
                }
                raw_frame.push(MASK_BIT | 127);
                raw_frame.extend_from_slice(&extended_len.to_be_bytes());
            }
        }

        let mut masking_key = [0_u8; 4];
        write_half.rng.fill_bytes(&mut masking_key);
        for (byte, key) in payload.iter_mut().zip(masking_key.iter().cycle()) {
            *byte ^= *key;
        }
        raw_frame.extend_from_slice(&masking_key);
        raw_frame.extend_from_slice(&payload);

        write_half
            .stream
            .write_all(&raw_frame)
            .await
            .map_err(WebSocketError::WriteError)?;
        write_half
            .stream
            .flush()
            .await
            .map_err(WebSocketError::WriteError)?;
        Ok(())
    }

    /// Returns `true` for control frames (Close, Ping, Pong).
    fn is_control(&self) -> bool {
        matches!(
            self,
            Self::Close { .. } | Self::Ping { .. } | Self::Pong { .. }
        )
    }

    /// Returns the 4-bit wire opcode. Continuation fragments of data frames use `0x0`.
    fn opcode(&self) -> u8 {
        match self {
            Self::Text {
                continuation: true, ..
            }
            | Self::Binary {
                continuation: true, ..
            } => 0x0,
            Self::Text { .. } => 0x1,
            Self::Binary { .. } => 0x2,
            Self::Close { .. } => 0x8,
            Self::Ping { .. } => 0x9,
            Self::Pong { .. } => 0xA,
        }
    }

    /// Returns the FIN bit already shifted into position for the first header byte
    /// (`0x80` when set, `0` otherwise).
    fn fin(&self) -> u8 {
        match self {
            Self::Text { fin, .. } | Self::Binary { fin, .. } => u8::from(*fin) << 7,
            // Control frames are never fragmented.
            Self::Close { .. } | Self::Ping { .. } | Self::Pong { .. } => 0b1000_0000,
        }
    }

    /// Reads and decodes one frame from `read_half`.
    ///
    /// A continuation frame (opcode `0x0`) is typed according to
    /// `read_half.last_frame_type`.
    ///
    /// # Errors
    ///
    /// - [`WebSocketError::ReadError`] if the stream fails or ends mid-frame.
    /// - [`WebSocketError::ReceivedMaskedFrameError`] if the MASK bit is set.
    /// - [`WebSocketError::PayloadTooLargeError`] if the 64-bit length does not fit in `usize`.
    /// - [`WebSocketError::InvalidFrameError`] for a reserved opcode, invalid UTF-8, a
    ///   malformed close body, a continuation after a control frame, a fragmented or
    ///   oversized control frame, or a 64-bit length with its most significant bit set.
    pub(super) async fn read_from_websocket(
        read_half: &mut WebSocketReadHalf,
    ) -> Result<Self, WebSocketError> {
        let fin_and_opcode = read_half
            .stream
            .read_u8()
            .await
            .map_err(WebSocketError::ReadError)?;
        let fin = fin_and_opcode & 0b1000_0000 != 0;
        let opcode = fin_and_opcode & 0b0000_1111;

        let mask_and_payload_len_first_byte = read_half
            .stream
            .read_u8()
            .await
            .map_err(WebSocketError::ReadError)?;
        if mask_and_payload_len_first_byte & 0b1000_0000 != 0 {
            return Err(WebSocketError::ReceivedMaskedFrameError);
        }

        let payload_len_first_byte = mask_and_payload_len_first_byte & 0b0111_1111;
        let payload_len: usize = match payload_len_first_byte {
            0..=125 => payload_len_first_byte as usize,
            126 => read_half
                .stream
                .read_u16()
                .await
                .map_err(WebSocketError::ReadError)? as usize,
            127 => {
                let wire_len = read_half
                    .stream
                    .read_u64()
                    .await
                    .map_err(WebSocketError::ReadError)?;
                // RFC 6455 §5.2: the most significant bit of the 64-bit length MUST be 0.
                if wire_len > i64::MAX as u64 {
                    return Err(WebSocketError::InvalidFrameError);
                }
                // Checked, not `as`: a truncating cast on a narrower `usize` would read the
                // wrong number of bytes and desynchronize the stream.
                wire_len
                    .try_into()
                    .map_err(|_| WebSocketError::PayloadTooLargeError)?
            }
            // The length byte was masked to 7 bits above.
            _ => unreachable!(),
        };

        // RFC 6455 §5.5: control frames (opcodes 0x8..=0xF) must not be fragmented and carry
        // at most 125 bytes. Checked before the payload buffer is allocated.
        if opcode & 0b1000 != 0 && (!fin || payload_len > 125) {
            return Err(WebSocketError::InvalidFrameError);
        }

        // NOTE(review): `payload_len` is peer-controlled and allocated up front, so a hostile
        // peer can demand a huge buffer. Consider enforcing a maximum frame/message size.
        let mut payload = vec![0; payload_len];
        read_half
            .stream
            .read_exact(&mut payload)
            .await
            .map_err(WebSocketError::ReadError)?;

        // NOTE(review): each text fragment is UTF-8-validated on its own, so a multi-byte
        // character split across fragments is rejected. RFC 6455 only requires the
        // reassembled message to be valid UTF-8.
        match opcode {
            0x0 => match read_half.last_frame_type {
                FrameType::Text => Ok(Self::Text {
                    payload: String::from_utf8(payload)
                        .map_err(|_| WebSocketError::InvalidFrameError)?,
                    continuation: true,
                    fin,
                }),
                FrameType::Binary => Ok(Self::Binary {
                    payload,
                    continuation: true,
                    fin,
                }),
                FrameType::Control => Err(WebSocketError::InvalidFrameError),
            },
            0x1 => Ok(Self::Text {
                payload: String::from_utf8(payload)
                    .map_err(|_| WebSocketError::InvalidFrameError)?,
                continuation: false,
                fin,
            }),
            0x2 => Ok(Self::Binary {
                payload,
                continuation: false,
                fin,
            }),
            0x3..=0x7 => Err(WebSocketError::InvalidFrameError),
            0x8 if payload_len == 0 => Ok(Self::Close { payload: None }),
            // A non-empty close body must start with a 2-byte status code.
            0x8 if payload_len < 2 => Err(WebSocketError::InvalidFrameError),
            0x8 => {
                let reason = payload.split_off(2);
                let status_code = u16::from_be_bytes([payload[0], payload[1]]);
                let reason =
                    String::from_utf8(reason).map_err(|_| WebSocketError::InvalidFrameError)?;
                Ok(Self::Close {
                    payload: Some((status_code, reason)),
                })
            }
            0x9 if payload_len == 0 => Ok(Self::Ping { payload: None }),
            0x9 => Ok(Self::Ping {
                payload: Some(payload),
            }),
            0xA if payload_len == 0 => Ok(Self::Pong { payload: None }),
            0xA => Ok(Self::Pong {
                payload: Some(payload),
            }),
            0xB..=0xFF => Err(WebSocketError::InvalidFrameError),
        }
    }
}
impl From<String> for Frame {
fn from(s: String) -> Self {
Self::text(s)
}
}
impl From<Vec<u8>> for Frame {
fn from(v: Vec<u8>) -> Self {
Self::binary(v)
}
}