use crate::{
WebSocket, WebSocketCloseStatusCode, WebSocketOptions, WebSocketReceiveMessageType,
WebSocketSendMessageType, WebSocketState, WebSocketType,
};
use core::{cmp::min, str::Utf8Error};
use rand_core::RngCore;
#[cfg(feature = "std")]
use std::io::{Read, Write};
#[cfg(feature = "std")]
use std::io::Error as IoError;
/// Error returned by this module's `Read` / `Write` traits when the `std`
/// feature is disabled (with `std`, `std::io::Error` is used instead).
#[cfg(not(feature = "std"))]
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum IoError {
    /// A `read` from the stream failed.
    Read,
    /// A `write_all` to the stream failed.
    Write,
}
/// Minimal byte source used in place of `std::io::Read` when the `std`
/// feature is disabled.
#[cfg(not(feature = "std"))]
pub trait Read {
    /// Reads bytes into `dest` and returns how many bytes were placed there.
    ///
    /// The framer treats a return value of `0` as end of input.
    fn read(&mut self, dest: &mut [u8]) -> Result<usize, IoError>;
}
/// Minimal byte sink used in place of `std::io::Write` when the `std`
/// feature is disabled.
#[cfg(not(feature = "std"))]
pub trait Write {
    /// Writes every byte of `data`, or reports an [`IoError`].
    fn write_all(&mut self, data: &[u8]) -> Result<(), IoError>;
}
/// Errors produced by `Framer` operations.
#[derive(Debug)]
pub enum FramerError {
    /// Reading from or writing to the underlying stream failed.
    Io(IoError),
    /// A message did not fit in the caller's frame buffer; carries that
    /// buffer's length.
    FrameTooLarge(usize),
    /// A text message was not valid UTF-8.
    Utf8(Utf8Error),
    /// The websocket protocol layer reported an error.
    WebSocket(crate::Error),
}

impl core::fmt::Display for FramerError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            // `IoError` is only guaranteed to be `Debug` (the no_std enum has no
            // `Display`), so it is formatted with `{:?}` in both configurations.
            FramerError::Io(err) => write!(f, "I/O error: {:?}", err),
            FramerError::FrameTooLarge(len) => {
                write!(f, "message does not fit in a {}-byte frame buffer", len)
            }
            FramerError::Utf8(err) => write!(f, "text message is not valid UTF-8: {}", err),
            FramerError::WebSocket(err) => write!(f, "websocket error: {:?}", err),
        }
    }
}

#[cfg(feature = "std")]
impl std::error::Error for FramerError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            FramerError::Io(err) => Some(err),
            FramerError::Utf8(err) => Some(err),
            FramerError::FrameTooLarge(_) | FramerError::WebSocket(_) => None,
        }
    }
}
impl From<IoError> for FramerError {
fn from(err: IoError) -> Self {
FramerError::Io(err)
}
}
impl From<Utf8Error> for FramerError {
fn from(err: Utf8Error) -> Self {
FramerError::Utf8(err)
}
}
impl From<crate::Error> for FramerError {
fn from(err: crate::Error) -> Self {
FramerError::WebSocket(err)
}
}
/// Drives a [`WebSocket`] over a byte stream using caller-provided buffers.
///
/// Raw bytes from the stream are read into `read_buf`, and outgoing frames are
/// encoded into `write_buf` before being written to the stream.
pub struct Framer<'a, TRng, TWebSocketType>
where
    TRng: RngCore,
    TWebSocketType: WebSocketType,
{
    /// Protocol state machine that encodes and decodes frames.
    websocket: &'a mut WebSocket<TRng, TWebSocketType>,

    /// Buffer that raw bytes from the stream are read into.
    read_buf: &'a mut [u8],
    /// Number of valid bytes currently held in `read_buf`.
    read_len: usize,
    /// Offset in `read_buf` of the first byte not yet handed to the websocket.
    read_cursor: usize,

    /// Buffer that outgoing frames are encoded into before being written.
    write_buf: &'a mut [u8],

    /// Number of payload bytes of the message currently being assembled that
    /// have already been copied into the caller's frame buffer.
    frame_cursor: usize,
}
/// Operations only available to client-side framers.
impl<'a, TRng> Framer<'a, TRng, crate::Client>
where
    TRng: RngCore,
{
    /// Performs the client side of the opening handshake over `stream`.
    ///
    /// The websocket encodes the upgrade request into `write_buf`, which is
    /// written to `stream` in full. A single `read` into `read_buf` is then
    /// passed to the websocket together with the generated key so it can
    /// validate the server's response.
    ///
    /// # Errors
    ///
    /// Returns [`FramerError::Io`] if writing to or reading from `stream` fails,
    /// and [`FramerError::WebSocket`] if the websocket layer reports an error
    /// while building the request or accepting the response.
    pub fn connect<TStream>(
        &mut self,
        stream: &mut TStream,
        websocket_options: &WebSocketOptions,
    ) -> Result<(), FramerError>
    where
        TStream: Read + Write,
    {
        let (request_len, web_socket_key) = self
            .websocket
            .client_connect(websocket_options, self.write_buf)?;
        stream.write_all(&self.write_buf[..request_len])?;

        // NOTE(review): the whole handshake response is assumed to arrive in a
        // single `read`. `read_cursor`/`read_len` are not updated here, so any
        // bytes that follow the response in this read are not kept for the
        // next `read_text`/`read_binary` call — confirm this is acceptable.
        let response_len = stream.read(self.read_buf)?;
        self.websocket
            .client_accept(&web_socket_key, &self.read_buf[..response_len])?;
        Ok(())
    }
}
/// Operations shared by client and server framers.
impl<'a, TRng, TWebSocketType> Framer<'a, TRng, TWebSocketType>
where
    TRng: RngCore,
    TWebSocketType: WebSocketType,
{
    /// Creates a framer over caller-provided buffers.
    ///
    /// `read_buf` receives raw bytes from the stream, `write_buf` holds frames
    /// encoded by `websocket` before they are written, and all cursors start
    /// at zero.
    pub fn new(
        read_buf: &'a mut [u8],
        write_buf: &'a mut [u8],
        websocket: &'a mut WebSocket<TRng, TWebSocketType>,
    ) -> Self {
        Self {
            read_buf,
            write_buf,
            read_cursor: 0,
            frame_cursor: 0,
            read_len: 0,
            websocket,
        }
    }

    /// Returns the current state of the wrapped websocket.
    pub fn state(&self) -> WebSocketState {
        self.websocket.state
    }

    /// Has the websocket encode a close message carrying `close_status` and the
    /// optional `status_description` into `write_buf`, then writes it to
    /// `stream`.
    ///
    /// # Errors
    ///
    /// Returns [`FramerError::WebSocket`] if encoding fails and
    /// [`FramerError::Io`] if writing to `stream` fails.
    pub fn close(
        &mut self,
        stream: &mut impl Write,
        close_status: WebSocketCloseStatusCode,
        status_description: Option<&str>,
    ) -> Result<(), FramerError> {
        let len = self
            .websocket
            .close(close_status, status_description, self.write_buf)?;
        stream.write_all(&self.write_buf[..len])?;
        Ok(())
    }

    /// Has the websocket encode `frame_buf` as a `message_type` frame into
    /// `write_buf`, passing `end_of_message` through, then writes it to
    /// `stream`.
    ///
    /// # Errors
    ///
    /// Returns [`FramerError::WebSocket`] if encoding fails and
    /// [`FramerError::Io`] if writing to `stream` fails.
    pub fn write(
        &mut self,
        stream: &mut impl Write,
        message_type: WebSocketSendMessageType,
        end_of_message: bool,
        frame_buf: &[u8],
    ) -> Result<(), FramerError> {
        let len = self
            .websocket
            .write(message_type, end_of_message, frame_buf, self.write_buf)?;
        stream.write_all(&self.write_buf[..len])?;
        Ok(())
    }

    /// Reads from `stream` until a complete text message has been assembled in
    /// `frame_buf`, and returns it as a `&str`.
    ///
    /// Returns `Ok(None)` when the stream reports end of input (a zero-length
    /// read) or a close message is received. Pings that arrive while waiting
    /// are answered with pongs.
    ///
    /// # Errors
    ///
    /// Returns [`FramerError::FrameTooLarge`] if the message does not fit in
    /// `frame_buf`, [`FramerError::Utf8`] if it is not valid UTF-8, and
    /// [`FramerError::Io`] / [`FramerError::WebSocket`] for stream or protocol
    /// failures.
    pub fn read_text<'b, TStream>(
        &mut self,
        stream: &mut TStream,
        frame_buf: &'b mut [u8],
    ) -> Result<Option<&'b str>, FramerError>
    where
        TStream: Read + Write,
    {
        self.next(stream, frame_buf, WebSocketReceiveMessageType::Text)?
            .map(core::str::from_utf8)
            .transpose()
            .map_err(FramerError::from)
    }

    /// Reads from `stream` until a complete binary message has been assembled
    /// in `frame_buf`, and returns the filled prefix of `frame_buf`.
    ///
    /// Returns `Ok(None)` when the stream reports end of input (a zero-length
    /// read) or a close message is received. Pings that arrive while waiting
    /// are answered with pongs.
    ///
    /// # Errors
    ///
    /// Returns [`FramerError::FrameTooLarge`] if the message does not fit in
    /// `frame_buf`, and [`FramerError::Io`] / [`FramerError::WebSocket`] for
    /// stream or protocol failures.
    pub fn read_binary<'b, TStream>(
        &mut self,
        stream: &mut TStream,
        frame_buf: &'b mut [u8],
    ) -> Result<Option<&'b [u8]>, FramerError>
    where
        TStream: Read + Write,
    {
        self.next(stream, frame_buf, WebSocketReceiveMessageType::Binary)
    }

    /// Feeds buffered and freshly read bytes to the websocket until a complete
    /// message of `message_type` is assembled in `frame_buf`.
    ///
    /// Message payload accumulates at `frame_buf[..self.frame_cursor]` across
    /// fragments and across calls. `CloseMustReply` is answered with a
    /// `CloseReply`, and `Ping` with a `Pong`. Any other message type is
    /// skipped: its payload is not kept because `frame_cursor` is not advanced.
    fn next<'b, TStream>(
        &mut self,
        stream: &mut TStream,
        frame_buf: &'b mut [u8],
        message_type: WebSocketReceiveMessageType,
    ) -> Result<Option<&'b [u8]>, FramerError>
    where
        TStream: Read + Write,
    {
        loop {
            // Refill when the buffered bytes are exhausted, or when nothing has
            // been consumed from `read_buf` yet (including the initial state).
            // NOTE(review): the `read_cursor == 0` case also discards bytes left
            // by a previous call that failed before consuming anything.
            if self.read_cursor == 0 || self.read_cursor == self.read_len {
                self.read_len = stream.read(self.read_buf)?;
                self.read_cursor = 0;
            }

            // A zero-length read is treated as end of input.
            if self.read_len == 0 {
                return Ok(None);
            }

            while self.read_cursor != self.read_len {
                // No room left for more payload of the message being assembled.
                if self.frame_cursor == frame_buf.len() {
                    return Err(FramerError::FrameTooLarge(frame_buf.len()));
                }

                // NOTE(review): only `read_buf[read_cursor..read_len]` is offered;
                // a frame header split across two stream reads is not
                // reassembled here — confirm how `WebSocket::read` handles a
                // truncated header.
                let ws_result = self.websocket.read(
                    &self.read_buf[self.read_cursor..self.read_len],
                    &mut frame_buf[self.frame_cursor..],
                )?;
                self.read_cursor += ws_result.len_from;

                match ws_result.message_type {
                    x if x == message_type => {
                        self.frame_cursor += ws_result.len_to;
                        if ws_result.end_of_message {
                            let frame = &frame_buf[..self.frame_cursor];
                            self.frame_cursor = 0;
                            return Ok(Some(frame));
                        }
                    }
                    WebSocketReceiveMessageType::CloseMustReply => {
                        self.send_back(
                            stream,
                            frame_buf,
                            ws_result.len_to,
                            WebSocketSendMessageType::CloseReply,
                        )?;
                        return Ok(None);
                    }
                    WebSocketReceiveMessageType::CloseCompleted => return Ok(None),
                    WebSocketReceiveMessageType::Ping => {
                        self.send_back(
                            stream,
                            frame_buf,
                            ws_result.len_to,
                            WebSocketSendMessageType::Pong,
                        )?;
                    }
                    _ => {}
                }
            }
        }
    }

    /// Echoes the `len_to` payload bytes that the websocket just wrote at
    /// `frame_buf[self.frame_cursor..]` back to the peer as a single final
    /// frame of `send_message_type`.
    fn send_back(
        &mut self,
        stream: &mut impl Write,
        frame_buf: &[u8],
        len_to: usize,
        send_message_type: WebSocketSendMessageType,
    ) -> Result<(), FramerError> {
        // NOTE(review): the echoed payload is clamped to `write_buf`'s length,
        // but `write_buf` must also hold the frame header, so a clamped payload
        // presumably still cannot be encoded — confirm against `WebSocket::write`.
        let payload_len = min(self.write_buf.len(), len_to);
        let payload = &frame_buf[self.frame_cursor..self.frame_cursor + payload_len];
        let len = self
            .websocket
            .write(send_message_type, true, payload, self.write_buf)?;
        stream.write_all(&self.write_buf[..len])?;
        Ok(())
    }
}