/// A single WebSocket frame as consumed by `builder` and produced by `reader` in this file.
pub struct Frame {
/// FIN flag. `builder` shifts this value into the top bit of the first header byte and
/// `reader` recovers it from that same bit.
pub fin: u8,
/// Opcode. `builder` ORs it into the first header byte as given; `reader` fills it from
/// the low four bits of that byte.
pub op_code: u8,
/// Payload bytes. `reader` stores them with any masking already removed; `builder`
/// never modifies this vector when it masks outgoing data.
pub payload: Vec<u8>,
}
pub mod reader {
use std::sync::Arc;
use crate::core::stream::Stream;
use crate::core::websocket::frame::Frame;
use crate::racoon_debug;
/// Reads one WebSocket frame from `stream` and returns it with the payload unmasked.
///
/// The header is decoded in wire order: the FIN bit and opcode from the first byte, the
/// MASK bit and 7-bit length from the second byte, then the optional 2- or 8-byte extended
/// length and the optional 4-byte masking key. The remaining bits of the first byte are
/// not inspected. Any bytes read past the end of this frame are handed back to the stream
/// through `restore_payload` so that the next call can start from them.
///
/// # Errors
///
/// Returns an error when reading a chunk fails, when the declared payload length exceeds
/// `max_payload_size`, or when that length cannot be represented as `usize` on this
/// platform. The size limit is enforced before any payload byte is requested.
pub async fn read_frame(stream: Arc<Stream>, max_payload_size: u64) -> std::io::Result<Frame> {
    let mut buffer: Vec<u8> = vec![];

    // Two fixed header bytes: FIN/opcode and MASK/length.
    fill_buffer(&stream, &mut buffer, 2).await?;
    let first_byte = buffer[0];
    let second_byte = buffer[1];
    buffer.drain(0..2);

    let fin = fin_bit_to_u8(&first_byte);
    let op_code = opcode_bit_to_u8(&first_byte);
    let mask_bit = bit_mask_to_u8(&second_byte);
    let payload_length = payload_length_to_u8(&second_byte);

    // 0..=125 is the length itself; 126 announces a 16-bit length, anything else (127) a
    // 64-bit length, both big-endian.
    let actual_payload_length: u64 = match payload_length {
        0..=125 => u64::from(payload_length),
        126 => {
            fill_buffer(&stream, &mut buffer, 2).await?;
            let length = u64::from(payload_length_to_u16(&buffer[..2])?);
            buffer.drain(0..2);
            length
        }
        _ => {
            fill_buffer(&stream, &mut buffer, 8).await?;
            let length = payload_length_to_u64(&buffer[..8])?;
            buffer.drain(0..8);
            length
        }
    };

    // Reject oversized frames as soon as the length is known, before waiting on the
    // stream for a masking key or payload that would be discarded anyway.
    if actual_payload_length > max_payload_size {
        return Err(std::io::Error::other(
            "Payload length is more than the maximum allowed size.",
        ));
    }
    // A plain `as usize` cast would silently truncate on targets where usize is narrower
    // than 64 bits and make the reader split the stream at the wrong offset.
    let payload_size = usize::try_from(actual_payload_length).map_err(|_| {
        std::io::Error::other("Payload length does not fit into usize on this platform.")
    })?;

    let masking_key: Option<[u8; 4]> = if mask_bit == 1 {
        // Keep reading until all four key bytes are present: a single chunk may deliver
        // fewer, and slicing `[..4]` on a shorter buffer would panic.
        fill_buffer(&stream, &mut buffer, 4).await?;
        let mut key_bytes = [0u8; 4];
        key_bytes.copy_from_slice(&buffer[..4]);
        buffer.drain(0..4);
        let key = Some(key_bytes);
        racoon_debug!("Websocket masking key: {:?}.", &key);
        key
    } else {
        racoon_debug!("Websocket masking disabled.");
        None
    };

    fill_buffer(&stream, &mut buffer, payload_size).await?;

    // Whatever follows the payload belongs to the next frame.
    let extra_read: Vec<u8> = buffer.split_off(payload_size);

    if let Some(key_bytes) = masking_key {
        // Byte i of the payload is XORed with key byte i % 4.
        for (byte, key_byte) in buffer.iter_mut().zip(key_bytes.iter().cycle()) {
            *byte ^= *key_byte;
        }
    }

    if !extra_read.is_empty() {
        // NOTE(review): the outcome of `restore_payload` is deliberately discarded, as in
        // the original; if it can fail, the next frame would start mid-stream — confirm.
        let _ = stream.restore_payload(&extra_read).await;
    }

    Ok(Frame {
        fin,
        op_code,
        payload: buffer,
    })
}

/// Appends chunks read from `stream` to `buffer` until it holds at least `min_len` bytes.
///
/// NOTE(review): assumes `read_chunk` reports end-of-stream as an error rather than as an
/// endless series of empty chunks; otherwise this loop would never finish — confirm.
async fn fill_buffer(
    stream: &Arc<Stream>,
    buffer: &mut Vec<u8>,
    min_len: usize,
) -> std::io::Result<()> {
    while buffer.len() < min_len {
        let chunk = stream.read_chunk().await?;
        buffer.extend(chunk);
    }
    Ok(())
}
/// Returns the most significant bit of `byte` as 0 or 1 (the FIN flag of a first header byte).
fn fin_bit_to_u8(byte: &u8) -> u8 {
    let top_bit_set = (*byte & 0b1000_0000) != 0;
    u8::from(top_bit_set)
}
/// Returns the low four bits of `byte` (the opcode field of a first header byte).
fn opcode_bit_to_u8(byte: &u8) -> u8 {
    const OPCODE_MASK: u8 = 0x0F;
    *byte & OPCODE_MASK
}
/// Returns the most significant bit of `byte` as 0 or 1 (the MASK flag of a second header byte).
fn bit_mask_to_u8(byte: &u8) -> u8 {
    let flag = *byte;
    (flag & 0x80) >> 7
}
/// Returns the low seven bits of `byte` (the length field of a second header byte).
fn payload_length_to_u8(byte: &u8) -> u8 {
    const LENGTH_MASK: u8 = 0x7F;
    *byte & LENGTH_MASK
}
/// Decodes a big-endian 16-bit extended payload length from exactly two bytes.
///
/// # Errors
///
/// Returns an error naming the actual slice length when `bytes` is not two bytes long.
fn payload_length_to_u16(bytes: &[u8]) -> std::io::Result<u16> {
    match *bytes {
        [high, low] => Ok(u16::from_be_bytes([high, low])),
        // The message now names the u16 conversion; it previously said "u64", which
        // pointed at the wrong code path when diagnosing a malformed header.
        _ => Err(std::io::Error::other(format!(
            "Failed to convert payload length to u16. Bytes of size 2 is expected. But found: {}",
            bytes.len()
        ))),
    }
}
/// Decodes a big-endian 64-bit extended payload length from exactly eight bytes.
///
/// # Errors
///
/// Returns an error naming the actual slice length when `bytes` is not eight bytes long.
fn payload_length_to_u64(bytes: &[u8]) -> std::io::Result<u64> {
    let found = bytes.len();
    if found == 8 {
        // Big-endian: each byte shifts the accumulated value up by eight bits.
        let length = bytes
            .iter()
            .fold(0u64, |acc, &byte| (acc << 8) | u64::from(byte));
        return Ok(length);
    }
    Err(std::io::Error::other(format!(
        "Failed to convert payload length to u64. Bytes of size 8 is expected. But found: {}",
        found
    )))
}
#[cfg(test)]
pub mod test {
    use std::sync::Arc;
    use crate::core::stream::{AbstractStream, TestStreamWrapper};
    use crate::core::websocket::frame::{builder, Frame};

    /// Wraps raw wire bytes in the in-memory test stream used by these tests.
    fn stream_over(wire_bytes: Vec<u8>) -> Arc<Box<dyn AbstractStream + 'static>> {
        let wrapper = TestStreamWrapper::new(wire_bytes, 1024);
        let boxed: Box<dyn AbstractStream + 'static> = Box::new(wrapper);
        Arc::new(boxed)
    }

    /// Asserts that every field of `actual` equals the corresponding field of `expected`.
    fn assert_frames_match(expected: &Frame, actual: &Frame) {
        assert_eq!(expected.fin, actual.fin);
        assert_eq!(expected.op_code, actual.op_code);
        assert_eq!(expected.payload, actual.payload);
    }

    /// An unmasked text frame survives a build/read round trip.
    #[tokio::test]
    async fn test_read_single_frame() {
        let expected = Frame {
            fin: 1,
            op_code: 1,
            payload: "Hello World".as_bytes().to_vec(),
        };
        let stream = stream_over(builder::build(&expected));

        let outcome = super::read_frame(stream, 500).await;
        assert!(outcome.is_ok());
        assert_frames_match(&expected, &outcome.unwrap());
    }

    /// Two masked frames written back to back are read one at a time from the same stream.
    #[tokio::test]
    async fn test_read_multiple_frames() {
        let text_frame = Frame {
            fin: 1,
            op_code: 1,
            payload: "Hello World".as_bytes().to_vec(),
        };
        let mut wire_bytes = builder::build_opt(&text_frame, true);

        let ping_frame = Frame {
            fin: 1,
            op_code: 9,
            payload: "PING".as_bytes().to_vec(),
        };
        let ping_bytes = builder::build_opt(&ping_frame, true);
        wire_bytes.extend(&ping_bytes);

        let stream = stream_over(wire_bytes);

        let first = super::read_frame(stream.clone(), 500).await;
        assert!(first.is_ok());
        assert_frames_match(&text_frame, &first.unwrap());

        let second = super::read_frame(stream, 500).await;
        assert!(second.is_ok());
        assert_frames_match(&ping_frame, &second.unwrap());
    }
}
}
pub mod builder {
use crate::core::websocket::frame::Frame;
/// Serializes `frame` into WebSocket wire bytes, optionally masking the payload.
///
/// Layout written: one byte holding `fin` in the top bit ORed with `op_code`, one byte
/// holding the MASK flag and the 7-bit length (or the marker 126 / 127 followed by a
/// big-endian 16- / 64-bit length), then, when `mask` is true, a random 4-byte masking key,
/// and finally the payload (XORed with the key when masked).
///
/// `frame` itself is never modified.
pub fn build_opt(frame: &Frame, mask: bool) -> Vec<u8> {
    let payload_length = frame.payload.len();
    // Largest possible header: 2 fixed bytes + 8 length bytes + 4 key bytes.
    let mut buffer: Vec<u8> = Vec::with_capacity(payload_length + 14);

    buffer.push((frame.fin << 7) | frame.op_code);

    // The MASK flag must accompany every length encoding. Setting it only for short
    // payloads would emit a masking key the receiver has no reason to expect, so the
    // key bytes would be read as payload.
    let mask_flag: u8 = if mask { 0b1000_0000 } else { 0 };
    if payload_length < 126 {
        buffer.push(mask_flag | payload_length as u8);
    } else if let Ok(short_length) = u16::try_from(payload_length) {
        buffer.push(mask_flag | 126);
        buffer.extend_from_slice(&short_length.to_be_bytes());
    } else {
        buffer.push(mask_flag | 127);
        buffer.extend_from_slice(&(payload_length as u64).to_be_bytes());
    }

    if mask {
        let mask_bytes: [u8; 4] = rand::random();
        buffer.extend_from_slice(&mask_bytes);
        // Payload byte i is XORed with key byte i % 4.
        buffer.extend(
            frame
                .payload
                .iter()
                .zip(mask_bytes.iter().cycle())
                .map(|(byte, key_byte)| byte ^ key_byte),
        );
    } else {
        buffer.extend_from_slice(&frame.payload);
    }

    buffer
}
/// Serializes `frame` without masking; shorthand for `build_opt(frame, false)`.
pub fn build(frame: &Frame) -> Vec<u8> {
    let mask_payload = false;
    build_opt(frame, mask_payload)
}
#[cfg(test)]
pub mod test {
    use std::sync::Arc;
    use crate::core::stream::{AbstractStream, TestStreamWrapper};
    use crate::core::websocket::frame::reader::read_frame;
    use crate::core::websocket::frame::Frame;
    use super::build_opt;

    /// A masked, non-final text frame built here is decoded back to the same fields by the reader.
    #[tokio::test]
    async fn test_frame_build_server() {
        let original = Frame {
            fin: 0,
            op_code: 1,
            payload: "Hello World".as_bytes().to_vec(),
        };
        let wire_bytes = build_opt(&original, true);

        let wrapper = TestStreamWrapper::new(wire_bytes, 1024);
        let boxed: Box<dyn AbstractStream + 'static> = Box::new(wrapper);
        let stream: Arc<Box<dyn AbstractStream + 'static>> = Arc::new(boxed);

        let outcome = read_frame(stream, 1000).await;
        assert!(outcome.is_ok());

        let decoded = outcome.unwrap();
        assert_eq!(decoded.fin, 0);
        assert_eq!(decoded.op_code, 1);
        assert_eq!(decoded.payload, "Hello World".as_bytes().to_vec());
    }
}
}