use bytes::{BufMut, Bytes, BytesMut};
use oxihttp_core::OxiHttpError;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// Largest payload, in bytes, that `read_frame` accepts (64 MiB).
///
/// Frames whose header declares a longer payload are rejected before any
/// payload bytes are read.
const MAX_PAYLOAD_LEN: u64 = 64 << 20;
/// WebSocket frame opcode (RFC 6455 §5.2).
///
/// Each discriminant is the 4-bit wire value, so `opcode as u8` yields the
/// value stored in the low nibble of a frame's first header byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Opcode {
    /// Continues a fragmented message.
    Continuation = 0,
    /// Text data frame.
    Text = 1,
    /// Binary data frame.
    Binary = 2,
    /// Close control frame.
    Close = 8,
    /// Ping control frame.
    Ping = 9,
    /// Pong control frame.
    Pong = 10,
}

impl Opcode {
    /// Decodes a raw opcode value.
    ///
    /// Returns `None` for any value that is not one of the six opcodes above
    /// (reserved or invalid opcodes).
    pub fn from_u8(v: u8) -> Option<Self> {
        const KNOWN: [Opcode; 6] = [
            Opcode::Continuation,
            Opcode::Text,
            Opcode::Binary,
            Opcode::Close,
            Opcode::Ping,
            Opcode::Pong,
        ];
        KNOWN.iter().copied().find(|&op| op as u8 == v)
    }

    /// Returns `true` for control opcodes (`Close`, `Ping`, `Pong`).
    ///
    /// RFC 6455 assigns control frames the opcodes whose 0x8 bit is set, so
    /// a single bit test distinguishes them from data opcodes.
    pub fn is_control(self) -> bool {
        ((self as u8) & 0x8) != 0
    }
}
/// A single WebSocket frame, as returned by `read_frame`.
#[derive(Debug, Clone)]
pub struct Frame {
    /// FIN bit: `true` when this frame is the final fragment of a message.
    pub fin: bool,
    /// The frame's opcode (data or control).
    pub opcode: Opcode,
    /// Frame payload. Frames produced by `read_frame` carry the payload
    /// already unmasked.
    pub payload: Bytes,
}
/// Reads one WebSocket frame from `reader` (RFC 6455 §5.2).
///
/// The header is validated before the payload is consumed:
/// - the RSV bits must be clear (no extensions are supported);
/// - the opcode must be a known [`Opcode`];
/// - control frames must be unfragmented and carry at most 125 payload bytes;
/// - the declared payload length must not exceed `MAX_PAYLOAD_LEN`.
///
/// When the MASK bit is set, the payload is unmasked with the frame's 4-byte
/// key before it is returned. This function neither requires nor rejects
/// masking; enforcing the client/server masking rule is left to the caller.
///
/// # Errors
///
/// Returns `OxiHttpError::Body` if reading from `reader` fails, if the
/// stream ends before the whole frame has arrived, or if the frame violates
/// any of the rules above.
pub async fn read_frame<R: AsyncRead + Unpin>(reader: &mut R) -> Result<Frame, OxiHttpError> {
    // The declared payload length is peer-controlled. Reserve at most this
    // much up front and let the buffer grow only as bytes actually arrive.
    // Otherwise a 10-byte header could force a 64 MiB allocation.
    const INITIAL_PAYLOAD_CAPACITY: u64 = 64 * 1024;

    let mut header = [0u8; 2];
    reader
        .read_exact(&mut header)
        .await
        .map_err(|e| OxiHttpError::Body(format!("WebSocket: read header: {e}")))?;
    let fin = (header[0] & 0x80) != 0;
    let rsv = header[0] & 0x70;
    let opcode_byte = header[0] & 0x0F;
    let masked = (header[1] & 0x80) != 0;
    let len_byte = header[1] & 0x7F;

    if rsv != 0 {
        return Err(OxiHttpError::Body(
            "WebSocket: reserved bits set without extension".into(),
        ));
    }
    let opcode = Opcode::from_u8(opcode_byte)
        .ok_or_else(|| OxiHttpError::Body(format!("WebSocket: unknown opcode {opcode_byte:#x}")))?;
    // RFC 6455 §5.5: control frames must not be fragmented and must fit in
    // the 7-bit length field.
    if opcode.is_control() && (!fin || len_byte > 125) {
        return Err(OxiHttpError::Body(
            "WebSocket: illegal control frame (fragmented or oversized)".into(),
        ));
    }

    let payload_len: u64 = match len_byte {
        0..=125 => u64::from(len_byte),
        126 => {
            let mut b = [0u8; 2];
            reader
                .read_exact(&mut b)
                .await
                .map_err(|e| OxiHttpError::Body(format!("WebSocket: read ext len16: {e}")))?;
            u64::from(u16::from_be_bytes(b))
        }
        127 => {
            let mut b = [0u8; 8];
            reader
                .read_exact(&mut b)
                .await
                .map_err(|e| OxiHttpError::Body(format!("WebSocket: read ext len64: {e}")))?;
            u64::from_be_bytes(b)
        }
        _ => unreachable!("len_byte masked to 7 bits"),
    };
    // This also rejects 64-bit lengths with the most significant bit set,
    // which the RFC forbids.
    if payload_len > MAX_PAYLOAD_LEN {
        return Err(OxiHttpError::Body(format!(
            "WebSocket: payload too large ({payload_len} bytes, max {MAX_PAYLOAD_LEN})"
        )));
    }

    let mask = if masked {
        let mut key = [0u8; 4];
        reader
            .read_exact(&mut key)
            .await
            .map_err(|e| OxiHttpError::Body(format!("WebSocket: read mask key: {e}")))?;
        Some(key)
    } else {
        None
    };

    // `min` bounds the cast to INITIAL_PAYLOAD_CAPACITY, so it cannot truncate.
    let mut payload = Vec::with_capacity(payload_len.min(INITIAL_PAYLOAD_CAPACITY) as usize);
    let mut limited = (&mut *reader).take(payload_len);
    let received = limited
        .read_to_end(&mut payload)
        .await
        .map_err(|e| OxiHttpError::Body(format!("WebSocket: read payload: {e}")))?;
    // `take` reports EOF at the limit *or* when the stream ends early; only
    // the former yields a complete frame.
    if received as u64 != payload_len {
        return Err(OxiHttpError::Body(format!(
            "WebSocket: read payload: stream ended after {received} of {payload_len} bytes"
        )));
    }

    if let Some(key) = mask {
        for (byte, &k) in payload.iter_mut().zip(key.iter().cycle()) {
            *byte ^= k;
        }
    }

    Ok(Frame {
        fin,
        opcode,
        payload: Bytes::from(payload),
    })
}
pub async fn write_frame<W: AsyncWrite + Unpin>(
writer: &mut W,
opcode: Opcode,
payload: &[u8],
fin: bool,
) -> Result<(), OxiHttpError> {
let mut header = BytesMut::with_capacity(10);
let first_byte = if fin {
0x80 | (opcode as u8)
} else {
opcode as u8
};
header.put_u8(first_byte);
let len = payload.len();
if len <= 125 {
header.put_u8(len as u8);
} else if len <= 0xFFFF {
header.put_u8(126);
header.put_u16(len as u16);
} else {
header.put_u8(127);
header.put_u64(len as u64);
}
writer
.write_all(&header)
.await
.map_err(|e| OxiHttpError::Body(format!("WebSocket: write header: {e}")))?;
writer
.write_all(payload)
.await
.map_err(|e| OxiHttpError::Body(format!("WebSocket: write payload: {e}")))?;
writer
.flush()
.await
.map_err(|e| OxiHttpError::Body(format!("WebSocket: flush: {e}")))?;
Ok(())
}
pub async fn write_frame_masked<W: AsyncWrite + Unpin>(
writer: &mut W,
opcode: Opcode,
payload: &[u8],
fin: bool,
mask_key: [u8; 4],
) -> Result<(), OxiHttpError> {
let mut header = BytesMut::with_capacity(14);
let first_byte = if fin {
0x80 | (opcode as u8)
} else {
opcode as u8
};
header.put_u8(first_byte);
let len = payload.len();
if len <= 125 {
header.put_u8(0x80 | len as u8);
} else if len <= 0xFFFF {
header.put_u8(0x80 | 126);
header.put_u16(len as u16);
} else {
header.put_u8(0x80 | 127);
header.put_u64(len as u64);
}
header.put_slice(&mask_key);
writer
.write_all(&header)
.await
.map_err(|e| OxiHttpError::Body(format!("WebSocket: write masked header: {e}")))?;
let masked_payload: Vec<u8> = payload
.iter()
.enumerate()
.map(|(i, &b)| b ^ mask_key[i % 4])
.collect();
writer
.write_all(&masked_payload)
.await
.map_err(|e| OxiHttpError::Body(format!("WebSocket: write masked payload: {e}")))?;
writer
.flush()
.await
.map_err(|e| OxiHttpError::Body(format!("WebSocket: flush masked: {e}")))?;
Ok(())
}