use nanorand::{Rng, WyRand};
use super::proto::{CloseCode, CloseReason, OpCode};
use super::{error::ProtocolError, mask::apply_mask};
use crate::util::{BufMut, BytePage, BytePages, Bytes, BytesMut};
/// Stateless WebSocket frame codec.
///
/// The type carries no data; it only groups the associated functions that parse
/// incoming frames and serialize outgoing ones.
#[derive(Debug)]
pub struct Parser;
impl Parser {
fn parse_metadata(
src: &[u8],
server: bool,
max_size: usize,
) -> Result<Option<(usize, bool, OpCode, usize, Option<u32>)>, ProtocolError> {
let chunk_len = src.len();
let mut idx = 2;
if chunk_len < 2 {
return Ok(None);
}
let first = src[0];
let second = src[1];
let finished = first & 0x80 != 0;
let masked = second & 0x80 != 0;
if !masked && server {
return Err(ProtocolError::UnmaskedFrame);
} else if masked && !server {
return Err(ProtocolError::MaskedFrame);
}
let opcode = OpCode::from(first & 0x0F);
if let OpCode::Bad = opcode {
return Err(ProtocolError::InvalidOpcode(first & 0x0F));
}
let len = second & 0x7F;
let length = if len == 126 {
if chunk_len < 4 {
return Ok(None);
}
let len = usize::from(u16::from_be_bytes(
TryFrom::try_from(&src[idx..idx + 2]).unwrap(),
));
idx += 2;
len
} else if len == 127 {
if chunk_len < 10 {
return Ok(None);
}
let len = u64::from_be_bytes(TryFrom::try_from(&src[idx..idx + 8]).unwrap());
if len > max_size as u64 {
return Err(ProtocolError::Overflow);
}
idx += 8;
len as usize
} else {
len as usize
};
if length > max_size {
return Err(ProtocolError::Overflow);
}
let mask = if server {
if chunk_len < idx + 4 {
return Ok(None);
}
let mask = u32::from_le_bytes(TryFrom::try_from(&src[idx..idx + 4]).unwrap());
idx += 4;
Some(mask)
} else {
None
};
Ok(Some((idx, finished, opcode, length, mask)))
}
pub fn parse(
src: &mut BytesMut,
server: bool,
max_size: usize,
) -> Result<Option<(bool, OpCode, Option<Bytes>)>, ProtocolError> {
let Some((idx, finished, opcode, length, mask)) =
Parser::parse_metadata(src, server, max_size)?
else {
return Ok(None);
};
if src.len() < idx + length {
return Ok(None);
}
src.advance_to(idx);
if length == 0 {
return Ok(Some((finished, opcode, None)));
}
match opcode {
OpCode::Ping | OpCode::Pong if length > 125 => {
return Err(ProtocolError::InvalidLength(length));
}
OpCode::Close if length > 125 => {
log::trace!(
"Received close frame with payload length exceeding 125. Morphing to protocol close frame."
);
return Ok(Some((true, OpCode::Close, None)));
}
_ => (),
}
if let Some(mask) = mask {
apply_mask(&mut src[..length], mask);
}
Ok(Some((finished, opcode, Some(src.split_to(length)))))
}
pub fn parse_close_payload(payload: &[u8]) -> Option<CloseReason> {
if payload.len() >= 2 {
let raw_code = u16::from_be_bytes(TryFrom::try_from(&payload[..2]).unwrap());
let code = CloseCode::from(raw_code);
let description = if payload.len() > 2 {
Some(String::from_utf8_lossy(&payload[2..]).into())
} else {
None
};
Some(CloseReason { code, description })
} else {
None
}
}
pub fn write_message<B>(dst: &mut BytePages, pl: B, op: OpCode, fin: bool, mask: bool)
where
BytePage: From<B>,
{
let payload = BytePage::from(pl);
let one: u8 = if fin {
0x80 | Into::<u8>::into(op)
} else {
op.into()
};
let payload_len = payload.len();
let two = if mask { 0x80 } else { 0 };
if payload_len < 126 {
dst.extend_from_slice(&[one, two | payload_len as u8]);
} else if payload_len <= 65_535 {
dst.extend_from_slice(&[one, two | 0x007e]);
dst.put_u16(payload_len as u16);
} else {
dst.extend_from_slice(&[one, two | 127]);
dst.put_u64(payload_len as u64);
}
if mask {
let mask: u32 = WyRand::new().generate();
let mut buf = BytesMut::from(payload);
apply_mask(&mut buf, mask);
dst.put_u32_le(mask);
dst.append::<BytesMut>(buf);
} else {
dst.append::<BytePage>(payload);
}
}
#[inline]
pub fn write_close(dst: &mut BytePages, reason: Option<CloseReason>, mask: bool) {
let payload = match reason {
None => Bytes::new(),
Some(reason) => {
let mut payload = BytesMut::with_capacity(
reason.description.as_ref().map_or(0, String::len) + 2,
);
payload.put_u16(u16::from(reason.code));
if let Some(description) = reason.description {
payload.extend_from_slice(description.as_bytes());
}
payload.freeze()
}
};
Parser::write_message(dst, payload, OpCode::Close, true, mask);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Result type of `Parser::parse`, aliased to keep the helper signatures short.
    type Parsed = Result<Option<(bool, OpCode, Option<Bytes>)>, ProtocolError>;

    /// Flattened view of a successfully parsed frame.
    struct F {
        finished: bool,
        opcode: OpCode,
        payload: Bytes,
    }

    /// Builds a parser input buffer from raw bytes.
    fn buffer(bytes: &[u8]) -> BytesMut {
        BytesMut::from(bytes)
    }

    /// True when the parser asked for more data.
    fn is_none(frm: &Parsed) -> bool {
        matches!(frm, Ok(None))
    }

    /// Unwraps a parsed frame, treating a missing payload as empty.
    fn extract(frm: Parsed) -> F {
        let Ok(Some((finished, opcode, payload))) = frm else {
            unreachable!("error")
        };
        F {
            finished,
            opcode,
            payload: payload.unwrap_or_else(Bytes::new),
        }
    }

    /// Asserts that `frame` is a non-final Text frame carrying `expected`.
    fn assert_text_fragment(frame: &F, expected: &[u8]) {
        assert!(!frame.finished);
        assert_eq!(frame.opcode, OpCode::Text);
        let payload: &[u8] = frame.payload.as_ref();
        assert_eq!(payload, expected);
    }

    #[test]
    fn test_parse() {
        // Header only: the one-byte payload is still missing.
        let mut incomplete = buffer(&[0x01, 0x01]);
        assert!(is_none(&Parser::parse(&mut incomplete, false, 1024)));

        let mut complete = buffer(&[0x01, 0x01, b'1']);
        let frame = extract(Parser::parse(&mut complete, false, 1024));
        assert_text_fragment(&frame, b"1");
    }

    #[test]
    fn test_parse_length0() {
        let mut buf = buffer(&[0x01, 0x00]);
        let frame = extract(Parser::parse(&mut buf, false, 1024));
        assert_text_fragment(&frame, b"");
    }

    #[test]
    fn test_parse_length2() {
        // Marker 126 without the 16-bit extended length.
        let mut incomplete = buffer(&[0x01, 126]);
        assert!(is_none(&Parser::parse(&mut incomplete, false, 1024)));

        let mut complete = buffer(&[0x01, 126, 0, 4, b'1', b'2', b'3', b'4']);
        let frame = extract(Parser::parse(&mut complete, false, 1024));
        assert_text_fragment(&frame, b"1234");
    }

    #[test]
    fn test_parse_length4() {
        // Marker 127 without the 64-bit extended length.
        let mut incomplete = buffer(&[0x01, 127]);
        assert!(is_none(&Parser::parse(&mut incomplete, false, 1024)));

        let mut complete = buffer(&[
            0x01, 127, 0, 0, 0, 0, 0, 0, 0, 4, b'1', b'2', b'3', b'4',
        ]);
        let frame = extract(Parser::parse(&mut complete, false, 1024));
        assert_text_fragment(&frame, b"1234");
    }

    #[test]
    fn test_parse_frame_mask() {
        // Masked frame: key "0001", one payload byte b'1' (b'1' ^ b'0' == 1).
        let mut buf = buffer(&[0x01, 0x81, b'0', b'0', b'0', b'1', b'1']);
        assert!(Parser::parse(&mut buf, false, 1024).is_err());

        let frame = extract(Parser::parse(&mut buf, true, 1024));
        assert_text_fragment(&frame, &[1u8]);
    }

    #[test]
    fn test_parse_frame_no_mask() {
        let mut buf = buffer(&[0x01, 0x01, 1]);
        assert!(Parser::parse(&mut buf, true, 1024).is_err());

        let frame = extract(Parser::parse(&mut buf, false, 1024));
        assert_text_fragment(&frame, &[1u8]);
    }

    #[test]
    fn test_parse_frame_max_size() {
        let mut buf = buffer(&[0x01, 0x02, 1, 1]);
        assert!(Parser::parse(&mut buf, true, 1).is_err());
        assert!(matches!(
            Parser::parse(&mut buf, false, 0),
            Err(ProtocolError::Overflow)
        ));
    }

    #[test]
    fn test_ping_frame() {
        let mut out = BytePages::default();
        Parser::write_message(&mut out, Bytes::from("data"), OpCode::Ping, true, false);
        assert_eq!(&Bytes::from(out)[..], &b"\x89\x04data"[..]);
    }

    #[test]
    fn test_pong_frame() {
        let mut out = BytePages::default();
        Parser::write_message(&mut out, Bytes::from("data"), OpCode::Pong, true, false);
        assert_eq!(&Bytes::from(out)[..], &b"\x8a\x04data"[..]);
    }

    #[test]
    fn test_close_frame() {
        let mut out = BytePages::default();
        Parser::write_close(&mut out, Some((CloseCode::Normal, "data").into()), false);
        // 0x03E8 == 1000, the close code written for `CloseCode::Normal`.
        assert_eq!(&Bytes::from(out)[..], &b"\x88\x06\x03\xe8data"[..]);
    }

    #[test]
    fn test_empty_close_frame() {
        let mut out = BytePages::default();
        Parser::write_close(&mut out, None, false);
        assert_eq!(&Bytes::from(out)[..], &[0x88u8, 0x00][..]);
    }
}