use core::{fmt::Debug, ops::Deref, str::Utf8Error};
use futures::{Sink, SinkExt, Stream, StreamExt};
use rand_core::RngCore;
use crate::{
WebSocket, WebSocketCloseStatusCode, WebSocketOptions, WebSocketReceiveMessageType,
WebSocketSendMessageType, WebSocketSubProtocol, WebSocketType,
};
/// Contents of a close frame received from the peer, as surfaced by
/// [`ReadResult::Close`].
pub struct CloseMessage<'a> {
/// Status code reported for the close frame.
pub status_code: WebSocketCloseStatusCode,
/// Bytes decoded from the close frame, borrowed from the caller's buffer.
/// Presumably the human-readable close reason; it is NOT validated as UTF-8
/// by the framer, hence the raw byte slice.
pub reason: &'a [u8],
}
/// A message produced by [`Framer::read`]. Every payload borrows from the
/// buffer that was passed to `read`.
pub enum ReadResult<'a> {
/// Payload of a binary message whose final frame has been received.
Binary(&'a [u8]),
/// Payload of a text message whose final frame has been received; it has
/// been checked with `core::str::from_utf8`.
Text(&'a str),
/// Payload of a received pong frame.
Pong(&'a [u8]),
/// Payload of a received ping frame. `Framer::read` has already sent the
/// matching pong on the stream by the time this is returned.
Ping(&'a [u8]),
/// The peer asked to close the connection. `Framer::read` has already sent
/// the close reply on the stream by the time this is returned.
Close(CloseMessage<'a>),
}
/// Errors reported by [`Framer`]. `E` is the error type of the underlying
/// stream / sink.
#[derive(Debug)]
pub enum FramerError<E> {
/// The underlying stream or sink returned an error.
Io(E),
/// NOTE(review): not constructed in the visible part of this file; presumably
/// a frame larger than the available buffer, carrying a size in bytes.
FrameTooLarge(usize),
/// A completed text message was not valid UTF-8.
Utf8(Utf8Error),
/// NOTE(review): not constructed in the visible part of this file; presumably
/// an HTTP header parse failure during the opening handshake.
HttpHeader(httparse::Error),
/// The WebSocket protocol layer returned an error.
WebSocket(crate::Error),
/// The stream ended before the opening handshake completed.
Disconnected,
/// The supplied buffer is too small for the incoming data. Carries a size in
/// bytes (`Framer::read` reports the length of the chunk that did not fit).
RxBufferTooSmall(usize),
}
/// Drives a [`WebSocket`] over an async `Stream` / `Sink` pair using a single
/// caller-supplied buffer.
///
/// The framer keeps only bookkeeping indices; the bytes themselves live in the
/// buffer passed to `connect` / `read`, so the same buffer is expected to be
/// passed on successive calls.
pub struct Framer<TRng, TWebSocketType>
where
TRng: RngCore,
TWebSocketType: WebSocketType,
{
// Protocol state machine that encodes and decodes frames.
websocket: WebSocket<TRng, TWebSocketType>,
// Number of payload bytes of the message currently being assembled at the
// front of the caller's buffer; reset to 0 once a message is handed out.
frame_cursor: usize,
// Number of received-but-not-yet-decoded bytes parked at the very end of the
// caller's buffer.
rx_remainder_len: usize,
}
impl<TRng> Framer<TRng, crate::Client>
where
    TRng: RngCore,
{
    /// Performs the client side of the WebSocket opening handshake.
    ///
    /// The handshake request is built into the front of `buffer` and sent on
    /// `stream`; the rest of `buffer` is used as receive space. Chunks are read
    /// from `stream` until the response header has been accepted. Any bytes
    /// that follow the header in the same chunk(s) are parked at the end of
    /// `buffer` so that a later `read` call (with the same buffer) can decode
    /// them.
    ///
    /// A response header that arrives split over several chunks is
    /// accumulated in the receive space and re-parsed after every chunk
    /// instead of panicking.
    ///
    /// # Returns
    ///
    /// The sub protocol reported by `client_accept`, if any.
    ///
    /// # Errors
    ///
    /// * [`FramerError::WebSocket`] - building the request or validating the
    ///   response failed.
    /// * [`FramerError::Io`] - the stream / sink returned an error.
    /// * [`FramerError::Disconnected`] - the stream ended before the response
    ///   header was complete.
    /// * [`FramerError::RxBufferTooSmall`] - the receive space (the part of
    ///   `buffer` after the request) cannot hold the partial header or the
    ///   bytes trailing the header; the value is the number of bytes needed.
    pub async fn connect<'a, B, E>(
        &mut self,
        stream: &mut (impl Stream<Item = Result<B, E>> + Sink<&'a [u8], Error = E> + Unpin),
        buffer: &'a mut [u8],
        websocket_options: &WebSocketOptions<'_>,
    ) -> Result<Option<WebSocketSubProtocol>, FramerError<E>>
    where
        B: AsRef<[u8]>,
    {
        let (tx_len, web_socket_key) = self
            .websocket
            .client_connect(websocket_options, buffer)
            .map_err(FramerError::WebSocket)?;

        let (tx_buf, rx_buf) = buffer.split_at_mut(tx_len);
        stream.send(tx_buf).await.map_err(FramerError::Io)?;
        stream.flush().await.map_err(FramerError::Io)?;

        // Number of response bytes accumulated at the FRONT of `rx_buf` while
        // the header is still incomplete. Zero means nothing is stashed and
        // the incoming chunk can be parsed in place.
        let mut pending_len = 0usize;

        loop {
            let chunk = match stream.next().await {
                Some(chunk) => chunk.map_err(FramerError::Io)?,
                None => return Err(FramerError::Disconnected),
            };
            let chunk: &[u8] = chunk.as_ref();

            // Once part of the header has been stashed, every further chunk is
            // appended so the parser always sees one contiguous response.
            let buffered = pending_len > 0;
            if buffered {
                if chunk.len() > rx_buf.len() - pending_len {
                    return Err(FramerError::RxBufferTooSmall(
                        pending_len.saturating_add(chunk.len()),
                    ));
                }
                let end = pending_len + chunk.len();
                rx_buf[pending_len..end].copy_from_slice(chunk);
                pending_len = end;
            }

            // NOTE(review): assumes `client_accept` may be called again after
            // it reported `HttpHeaderIncomplete` - confirm against WebSocket.
            let accepted = {
                let response: &[u8] = if buffered {
                    &rx_buf[..pending_len]
                } else {
                    chunk
                };
                self.websocket.client_accept(&web_socket_key, response)
            };

            match accepted {
                Ok((header_len, sub_protocol)) => {
                    let total_len = if buffered { pending_len } else { chunk.len() };
                    let remaining_len = total_len - header_len;
                    if remaining_len > 0 {
                        if remaining_len > rx_buf.len() {
                            return Err(FramerError::RxBufferTooSmall(remaining_len));
                        }
                        // Park the bytes that follow the header at the very
                        // end of the buffer, where `read` expects them.
                        let rx_start = rx_buf.len() - remaining_len;
                        if buffered {
                            // Source and destination may overlap inside
                            // `rx_buf`; `copy_within` handles that.
                            rx_buf.copy_within(header_len..total_len, rx_start);
                        } else {
                            rx_buf[rx_start..].copy_from_slice(&chunk[header_len..]);
                        }
                        self.rx_remainder_len = remaining_len;
                    }
                    return Ok(sub_protocol);
                }
                Err(crate::Error::HttpHeaderIncomplete) => {
                    // First partial chunk: stash it so the next chunk can be
                    // appended. Later chunks were already appended above.
                    if !buffered {
                        if chunk.len() > rx_buf.len() {
                            return Err(FramerError::RxBufferTooSmall(chunk.len()));
                        }
                        rx_buf[..chunk.len()].copy_from_slice(chunk);
                        pending_len = chunk.len();
                    }
                }
                Err(e) => return Err(FramerError::WebSocket(e)),
            }
        }
    }
}
impl<TRng, TWebSocketType> Framer<TRng, TWebSocketType>
where
    TRng: RngCore,
    TWebSocketType: WebSocketType,
{
    /// Creates a framer around `websocket` with no buffered data.
    pub fn new(websocket: WebSocket<TRng, TWebSocketType>) -> Self {
        Self {
            websocket,
            frame_cursor: 0,
            rx_remainder_len: 0,
        }
    }

    /// Encodes `from` as a frame of `message_type` into `to` without sending
    /// it.
    ///
    /// `end_of_message` is forwarded to the WebSocket layer unchanged.
    ///
    /// # Returns
    ///
    /// The number of bytes written to `to`.
    ///
    /// # Errors
    ///
    /// [`FramerError::WebSocket`] if the WebSocket layer rejects the write.
    pub fn encode<E>(
        &mut self,
        message_type: WebSocketSendMessageType,
        end_of_message: bool,
        from: &[u8],
        to: &mut [u8],
    ) -> Result<usize, FramerError<E>> {
        self.websocket
            .write(message_type, end_of_message, from, to)
            .map_err(FramerError::WebSocket)
    }

    /// Encodes `frame_buf` as a frame of `message_type` into `tx_buf`, then
    /// sends and flushes it on `tx`.
    ///
    /// # Errors
    ///
    /// * [`FramerError::WebSocket`] - the WebSocket layer rejected the write.
    /// * [`FramerError::Io`] - sending or flushing on `tx` failed.
    pub async fn write<'b, E>(
        &mut self,
        tx: &mut (impl Sink<&'b [u8], Error = E> + Unpin),
        tx_buf: &'b mut [u8],
        message_type: WebSocketSendMessageType,
        end_of_message: bool,
        frame_buf: &[u8],
    ) -> Result<(), FramerError<E>>
    where
        E: Debug,
    {
        let len = self
            .websocket
            .write(message_type, end_of_message, frame_buf, tx_buf)
            .map_err(FramerError::WebSocket)?;
        tx.send(&tx_buf[..len]).await.map_err(FramerError::Io)?;
        tx.flush().await.map_err(FramerError::Io)?;
        Ok(())
    }

    /// Encodes a close frame with `close_status` and the optional
    /// `status_description` into `tx_buf`, then sends and flushes it on `tx`.
    ///
    /// # Errors
    ///
    /// * [`FramerError::WebSocket`] - the WebSocket layer rejected the close.
    /// * [`FramerError::Io`] - sending or flushing on `tx` failed.
    pub async fn close<'b, E>(
        &mut self,
        tx: &mut (impl Sink<&'b [u8], Error = E> + Unpin),
        tx_buf: &'b mut [u8],
        close_status: WebSocketCloseStatusCode,
        status_description: Option<&str>,
    ) -> Result<(), FramerError<E>>
    where
        E: Debug,
    {
        let len = self
            .websocket
            .close(close_status, status_description, tx_buf)
            .map_err(FramerError::WebSocket)?;
        tx.send(&tx_buf[..len]).await.map_err(FramerError::Io)?;
        tx.flush().await.map_err(FramerError::Io)?;
        Ok(())
    }

    /// Decodes the next frame and returns a message if one is ready.
    ///
    /// `buffer` is used as follows: undecoded received bytes are parked at
    /// its end, decoded payload is assembled from its start, and the space
    /// just in front of the parked bytes is used to build Pong / close
    /// replies. Pass the same buffer on every call, because the framer only
    /// remembers offsets into it.
    ///
    /// If no undecoded bytes are left over from a previous call, one chunk is
    /// awaited from `stream` first. Ping frames are answered with a Pong and
    /// close requests with a close reply before the corresponding
    /// [`ReadResult`] is returned.
    ///
    /// # Returns
    ///
    /// * `Some(Ok(_))` - a complete message or control frame.
    /// * `Some(Err(_))` - see Errors.
    /// * `None` - the stream ended or the close handshake completed.
    ///   NOTE(review): `None` is also returned after a non-final fragment of
    ///   a data message has been consumed. Callers that support fragmented
    ///   messages must call `read` again in that case.
    ///
    /// # Errors
    ///
    /// * [`FramerError::RxBufferTooSmall`] - `buffer` cannot hold the
    ///   incoming chunk next to already assembled payload, or cannot hold a
    ///   control payload together with its reply; the value is the buffer
    ///   size that would have been needed.
    /// * [`FramerError::Io`] - the stream / sink returned an error.
    /// * [`FramerError::WebSocket`] - the WebSocket layer rejected the data.
    /// * [`FramerError::Utf8`] - a completed text message is not valid UTF-8.
    ///
    /// # Panics
    ///
    /// Panics if the WebSocket layer reports a close request without a status
    /// code.
    pub async fn read<'a, B: Deref<Target = [u8]>, E>(
        &mut self,
        stream: &mut (impl Stream<Item = Result<B, E>> + Sink<&'a [u8], Error = E> + Unpin),
        buffer: &'a mut [u8],
    ) -> Option<Result<ReadResult<'a>, FramerError<E>>>
    where
        E: Debug,
    {
        // Space reserved for the header of a Pong / close reply in addition
        // to its payload. Presumably the largest possible frame header
        // (2 bytes + 8 length bytes + 4 mask bytes) - TODO confirm.
        const REPLY_HEADER_RESERVE: usize = 14;

        if self.rx_remainder_len == 0 {
            match stream.next().await {
                Some(Ok(input)) => {
                    // The chunk is parked at the end of the buffer; it must
                    // not overwrite payload already assembled at the front.
                    let available = buffer.len().saturating_sub(self.frame_cursor);
                    if available < input.len() {
                        return Some(Err(FramerError::RxBufferTooSmall(
                            self.frame_cursor.saturating_add(input.len()),
                        )));
                    }
                    let rx_start = buffer.len() - input.len();
                    buffer[rx_start..].copy_from_slice(&input);
                    self.rx_remainder_len = input.len()
                }
                Some(Err(e)) => {
                    return Some(Err(FramerError::Io(e)));
                }
                None => return None,
            }
        }

        // Guard against a buffer that is shorter than the one the parked
        // bytes were stored in (would otherwise underflow / mis-slice).
        let cursor = self.frame_cursor;
        let rx_start = match buffer.len().checked_sub(self.rx_remainder_len) {
            Some(rx_start) if rx_start >= cursor => rx_start,
            _ => {
                return Some(Err(FramerError::RxBufferTooSmall(
                    cursor.saturating_add(self.rx_remainder_len),
                )))
            }
        };
        let (frame_buf, rx_buf) = buffer.split_at_mut(rx_start);
        let rx_len = rx_buf.len();

        // Decode behind the payload assembled so far. Every range below is
        // expressed relative to `cursor`, so decoding at offset 0 would
        // overwrite earlier fragments of the same message.
        let ws_result = match self.websocket.read(rx_buf, &mut frame_buf[cursor..]) {
            Ok(ws_result) => ws_result,
            Err(e) => return Some(Err(FramerError::WebSocket(e))),
        };
        self.rx_remainder_len -= ws_result.len_from;

        match ws_result.message_type {
            WebSocketReceiveMessageType::Binary => {
                self.frame_cursor += ws_result.len_to;
                if ws_result.end_of_message {
                    let range = 0..self.frame_cursor;
                    self.frame_cursor = 0;
                    return Some(Ok(ReadResult::Binary(&frame_buf[range])));
                }
            }
            WebSocketReceiveMessageType::Text => {
                self.frame_cursor += ws_result.len_to;
                if ws_result.end_of_message {
                    let range = 0..self.frame_cursor;
                    self.frame_cursor = 0;
                    return Some(
                        core::str::from_utf8(&frame_buf[range])
                            .map(ReadResult::Text)
                            .map_err(FramerError::Utf8),
                    );
                }
            }
            WebSocketReceiveMessageType::CloseMustReply => {
                let range = cursor..cursor + ws_result.len_to;
                let tx_buf_len = ws_result.len_to + REPLY_HEADER_RESERVE;
                // The reply is built at the tail of `frame_buf`; it must not
                // overlap the payload it echoes.
                let split_at = match frame_buf.len().checked_sub(tx_buf_len) {
                    Some(split_at) if split_at >= range.end => split_at,
                    _ => {
                        return Some(Err(FramerError::RxBufferTooSmall(
                            range.end + tx_buf_len + rx_len,
                        )))
                    }
                };
                let (frame_buf, tx_buf) = frame_buf.split_at_mut(split_at);
                match self.websocket.write(
                    WebSocketSendMessageType::CloseReply,
                    true,
                    &frame_buf[range.start..range.end],
                    tx_buf,
                ) {
                    Ok(len) => match stream.send(&tx_buf[..len]).await {
                        Ok(()) => {
                            self.frame_cursor = 0;
                            let status_code = ws_result
                                .close_status
                                .expect("close message must have code");
                            let reason = &frame_buf[range];
                            return Some(Ok(ReadResult::Close(CloseMessage {
                                status_code,
                                reason,
                            })));
                        }
                        Err(e) => return Some(Err(FramerError::Io(e))),
                    },
                    Err(e) => return Some(Err(FramerError::WebSocket(e))),
                }
            }
            WebSocketReceiveMessageType::CloseCompleted => return None,
            WebSocketReceiveMessageType::Pong => {
                let range = cursor..cursor + ws_result.len_to;
                return Some(Ok(ReadResult::Pong(&frame_buf[range])));
            }
            WebSocketReceiveMessageType::Ping => {
                let range = cursor..cursor + ws_result.len_to;
                let tx_buf_len = ws_result.len_to + REPLY_HEADER_RESERVE;
                // Same layout as the close reply: Pong is built at the tail
                // of `frame_buf`, behind the ping payload.
                let split_at = match frame_buf.len().checked_sub(tx_buf_len) {
                    Some(split_at) if split_at >= range.end => split_at,
                    _ => {
                        return Some(Err(FramerError::RxBufferTooSmall(
                            range.end + tx_buf_len + rx_len,
                        )))
                    }
                };
                let (frame_buf, tx_buf) = frame_buf.split_at_mut(split_at);
                match self.websocket.write(
                    WebSocketSendMessageType::Pong,
                    true,
                    &frame_buf[range.start..range.end],
                    tx_buf,
                ) {
                    Ok(len) => match stream.send(&tx_buf[..len]).await {
                        Ok(()) => {
                            return Some(Ok(ReadResult::Ping(&frame_buf[range])));
                        }
                        Err(e) => return Some(Err(FramerError::Io(e))),
                    },
                    Err(e) => return Some(Err(FramerError::WebSocket(e))),
                }
            }
        }

        // Reached only after a non-final Binary / Text fragment: the payload
        // stays in the buffer and `frame_cursor` remembers how much of it.
        None
    }
}