use crate::websocket::frame::{Frame, Opcode};
use crate::websocket::message::WebsocketMessage;
use std::collections::VecDeque;
use std::{io, mem};
use crate::stream::ConnectionStream;
use crate::tii_error::{RequestHeadParsingError, TiiError, TiiResult};
use crate::util::{unwrap_poison, unwrap_some};
use crate::{error_log, trace_log, warn_log};
use std::io::{Cursor, ErrorKind, Read, Write};
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering::SeqCst;
use std::sync::{Arc, Mutex};
use std::time::Duration;
/// Connection state shared (through an `Arc`) by every `WebsocketSender` clone and the
/// `WebsocketReceiver` created together by `new_web_socket_stream`.
#[derive(Debug)]
struct WebSocketGuard {
/// Set once the connection is considered closed: after a Close frame was sent or received,
/// or after a read/write/protocol error on either side. Once set it is never cleared.
closed: AtomicBool,
/// Held by the sender/receiver `close`, `binary`, `text`, `ping` and `pong` methods while
/// they write a frame, so that concurrently written frames do not interleave on the wire.
write_mutex: Mutex<()>,
/// The underlying connection; frames are both read from and written to this stream.
stream: Box<dyn ConnectionStream>,
}
/// Sending half of a websocket connection.
///
/// Cloning is cheap (it only bumps the `Arc` reference count) and all clones write to the
/// same connection and observe the same closed flag.
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct WebsocketSender(Arc<WebSocketGuard>);
pub fn new_web_socket_stream(
connection: &dyn ConnectionStream,
) -> (WebsocketSender, WebsocketReceiver) {
let guard = Arc::new(WebSocketGuard {
closed: AtomicBool::new(false),
write_mutex: Mutex::new(()),
stream: connection.new_ref(),
});
let sender = WebsocketSender(guard.clone());
let receiver = WebsocketReceiver {
guard,
state: Vec::new(),
cursor: Default::default(),
unhandled_messages: Default::default(),
};
(sender, receiver)
}
impl WebsocketSender {
#[must_use]
pub fn is_closed(&self) -> bool {
self.0.closed.load(SeqCst)
}
pub fn send(&self, message: WebsocketMessage) -> TiiResult<()> {
match message {
WebsocketMessage::Text(txt) => self.text(txt),
WebsocketMessage::Binary(bin) => self.binary(bin),
WebsocketMessage::Ping => self.ping(),
WebsocketMessage::Pong => self.pong(),
}
}
pub fn close(&self) -> TiiResult<()> {
let _g = unwrap_poison(self.0.write_mutex.lock())?;
if self.0.closed.swap(true, SeqCst) {
return Ok(()); }
Frame::new(Opcode::Close, Vec::new()).write_to(self.0.stream.as_stream_write())
}
pub fn binary(&self, message: impl Into<Vec<u8>>) -> TiiResult<()> {
let _g = unwrap_poison(self.0.write_mutex.lock())?;
Frame::new(Opcode::Binary, message.into()).write_to(self.0.stream.as_stream_write())
}
pub fn text(&self, message: impl ToString) -> TiiResult<()> {
let _g = unwrap_poison(self.0.write_mutex.lock())?;
Frame::new(Opcode::Text, message.to_string().into_bytes())
.write_to(self.0.stream.as_stream_write())
}
pub fn ping(&self) -> TiiResult<()> {
let _g = unwrap_poison(self.0.write_mutex.lock())?;
Frame::new(Opcode::Ping, Vec::new()).write_to(self.0.stream.as_stream_write())
}
pub fn pong(&self) -> TiiResult<()> {
let _g = unwrap_poison(self.0.write_mutex.lock())?;
Frame::new(Opcode::Ping, Vec::new()).write_to(self.0.stream.as_stream_write())
}
pub fn peer_addr(&self) -> TiiResult<String> {
Ok(self.0.stream.peer_addr()?)
}
}
/// Receiving half of a websocket connection; also usable as a byte stream through `Read`.
#[derive(Debug)]
pub struct WebsocketReceiver {
/// State shared with the `WebsocketSender` half (closed flag, write mutex, stream).
guard: Arc<WebSocketGuard>,
/// Frames of a fragmented data message received so far; the message is assembled once a
/// frame with `fin` set arrives.
state: Vec<Frame>,
/// Leftover bytes of a message that did not fit into the caller's buffer in `Read::read`.
cursor: Cursor<Vec<u8>>,
/// Messages the `Read` impl could not deliver as bytes (`WebsocketMessage::bytes` returned
/// `None`); handed out by `unhandled` and, first, by `read_message`/`read_message_timeout`.
unhandled_messages: VecDeque<WebsocketMessage>,
}
/// Outcome of `WebsocketReceiver::read_message_timeout`.
#[derive(Debug)]
pub enum ReadMessageTimeoutResult {
/// A complete message (or a Ping/Pong control message) was read.
Message(WebsocketMessage),
/// No frame started arriving before the timeout elapsed.
Timeout,
/// The connection is closed (it was already marked closed, or a Close frame was received).
Closed,
}
impl WebsocketReceiver {
pub fn close(&self) -> TiiResult<()> {
let _g = unwrap_poison(self.guard.write_mutex.lock())?;
if self.guard.closed.swap(true, SeqCst) {
return Ok(()); }
Frame::new(Opcode::Close, Vec::new()).write_to(self.guard.stream.as_stream_write())
}
pub fn unhandled(&mut self) -> Option<WebsocketMessage> {
self.unhandled_messages.pop_front()
}
pub fn read_message(&mut self) -> TiiResult<Option<WebsocketMessage>> {
if let Some(message) = self.unhandled_messages.pop_front() {
return Ok(Some(message));
}
self.read_next_frame()
}
pub fn read_message_timeout(
&mut self,
timeout: Option<Duration>,
) -> TiiResult<ReadMessageTimeoutResult> {
if let Some(message) = self.unhandled_messages.pop_front() {
return Ok(ReadMessageTimeoutResult::Message(message));
}
if self.guard.stream.available() == 0 {
if self.guard.closed.load(SeqCst) {
return Ok(ReadMessageTimeoutResult::Closed);
}
let old_timeout = self.guard.stream.get_read_timeout()?.as_ref().cloned();
if let Err(err) = self.guard.stream.set_read_timeout(timeout) {
self.guard.closed.store(true, SeqCst);
error_log!("WebsocketReceiver::read_message_timeout error setting timeout for 1st byte of next frame {}", &err);
return Err(TiiError::from(err));
}
let res = self.guard.stream.ensure_readable();
let res2 = self.guard.stream.set_read_timeout(old_timeout);
if let Err(err) = res2 {
self.guard.closed.store(true, SeqCst);
error_log!("WebsocketReceiver::read_message_timeout error setting timeout back to read timeout after waiting for 1st byte of next frame {}", &err);
return Err(TiiError::from(err));
}
if let Err(err) = res {
if matches!(err.kind(), ErrorKind::WouldBlock | ErrorKind::TimedOut) {
return Ok(ReadMessageTimeoutResult::Timeout);
}
self.guard.closed.store(true, SeqCst);
error_log!("WebsocketReceiver::read_message_timeout error while waiting for 1st byte of next frame {}", &err);
return Err(TiiError::from(err));
}
}
match self.read_next_frame() {
Ok(Some(message)) => Ok(ReadMessageTimeoutResult::Message(message)),
Ok(None) => Ok(ReadMessageTimeoutResult::Closed),
Err(err) => Err(err),
}
}
fn read_next_frame(&mut self) -> TiiResult<Option<WebsocketMessage>> {
if self.guard.closed.load(SeqCst) {
return Ok(None);
}
let as_read = self.guard.stream.as_stream_read();
while self.state.last().map(|f| !f.fin).unwrap_or(true) {
let frame = Frame::from_stream(as_read).inspect_err(|e| {
self.guard.closed.store(true, SeqCst);
error_log!("WebsocketReceiver::read_next_frame Frame::from_stream error: {}", e);
})?;
if frame.opcode == Opcode::Ping {
return Ok(Some(WebsocketMessage::Ping));
}
if frame.opcode == Opcode::Pong {
return Ok(Some(WebsocketMessage::Pong));
}
if frame.opcode == Opcode::Close {
self.guard.closed.store(true, SeqCst);
if self.state.is_empty() {
return Ok(None);
}
return Err(TiiError::RequestHeadParsing(
RequestHeadParsingError::WebSocketClosedDuringPendingMessage,
));
}
self.state.push(frame);
}
let frames = mem::take(&mut self.state);
let frame_type = unwrap_some(frames.first()).opcode;
let size = frames.iter().map(|f| f.payload.len()).sum();
let mut payload = Vec::with_capacity(size);
for (idx, frame) in frames.into_iter().enumerate() {
if idx != 0 && frame.opcode != Opcode::Continuation {
return Err(TiiError::RequestHeadParsing(
RequestHeadParsingError::UnexpectedWebSocketOpcode,
));
}
payload.extend_from_slice(frame.payload.as_slice());
}
match frame_type {
Opcode::Text => {
let payload = String::from_utf8(payload).map_err(|e| {
self.guard.closed.store(true, SeqCst);
TiiError::RequestHeadParsing(RequestHeadParsingError::WebSocketTextMessageIsNotUtf8(
e.into_bytes(),
))
})?;
Ok(Some(WebsocketMessage::Text(payload)))
}
Opcode::Binary => Ok(Some(WebsocketMessage::Binary(payload))),
_ => {
self.guard.closed.store(true, SeqCst);
Err(TiiError::RequestHeadParsing(RequestHeadParsingError::UnexpectedWebSocketOpcode))
}
}
}
}
/// Byte-stream view of the receiver: the payloads of consecutive Text/Binary messages are
/// concatenated. Messages without a byte payload (`WebsocketMessage::bytes` returns `None`)
/// are queued and can be retrieved with `WebsocketReceiver::unhandled`.
impl Read for WebsocketReceiver {
    /// Reads message bytes into `buf`; returns `Ok(0)` only for an empty `buf` or once the
    /// connection is closed.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // An empty buffer cannot make progress. Without this early return the loop below
        // would consume message after message into the cursor, never drain it, and overwrite
        // it with the next message, silently losing data.
        if buf.is_empty() {
            return Ok(0);
        }

        loop {
            let cnt = self.cursor.read(buf)?;
            if cnt != 0 {
                return Ok(cnt);
            }

            let message = match self.read_next_frame() {
                Ok(Some(message)) => message,
                Ok(None) => return Ok(0),
                Err(err) => return Err(err.into()),
            };

            match message.bytes() {
                Some(bytes) => {
                    // An empty message contributes no bytes. Returning `Ok(0)` here would
                    // signal EOF on a still-open connection, so keep reading.
                    if bytes.is_empty() {
                        continue;
                    }
                    if bytes.len() <= buf.len() {
                        unwrap_some(buf.get_mut(..bytes.len())).copy_from_slice(bytes);
                        return Ok(bytes.len());
                    }
                    // Too large for `buf`: stash it and serve it from the cursor.
                    self.cursor = Cursor::new(bytes.to_vec());
                }
                None => self.unhandled_messages.push_back(message),
            }
        }
    }
}
impl Write for WebsocketSender {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if self.0.closed.load(SeqCst) {
return Err(io::Error::from(ErrorKind::ConnectionReset));
}
Frame::write_unowned_payload_frame(self.0.stream.as_stream_write(), Opcode::Binary, buf)
.inspect_err(|e| {
self.0.closed.store(true, SeqCst);
error_log!("WebsocketSender::write error: {}", e);
})?;
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
/// Best-effort close on teardown: when the last `Arc` to the shared guard goes away (all
/// sender clones and the receiver are dropped), a Close frame is written unless the
/// connection was already marked closed. Write failures are only logged.
impl Drop for WebSocketGuard {
    fn drop(&mut self) {
        trace_log!("WebsocketReceiver::drop");
        let already_closed = self.closed.load(SeqCst);
        if already_closed {
            trace_log!("WebsocketReceiver::drop already closed");
        } else {
            trace_log!("WebsocketReceiver::drop closing...");
            // `&mut self` guarantees exclusive access here, so the write mutex is not needed.
            let close_frame = Frame::new(Opcode::Close, Vec::new());
            match close_frame.write_to(self.stream.as_stream_write()) {
                Ok(_) => {}
                Err(err) => {
                    warn_log!("WebsocketSender::drop error: {}", err);
                }
            }
            trace_log!("WebsocketReceiver::drop closed.");
        }
    }
}