use std::convert::Infallible;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use bytes::Bytes;
use fastwebsockets::FragmentCollectorRead;
use fastwebsockets::Frame as FastFrame;
use fastwebsockets::OpCode as FastOpCode;
use fastwebsockets::Role as FastRole;
use fastwebsockets::WebSocketError as FastWebSocketError;
use fastwebsockets::WebSocketWrite as FastWriteHalf;
use futures_util::sink::Sink;
use futures_util::stream::Stream;
use openwire_core::websocket::{
validate_close_frame, validate_outbound_engine_frame, BoxEngineSink, BoxEngineStream,
EngineFrame, Role, WebSocketChannel, WebSocketEngine, WebSocketEngineConfig,
WebSocketEngineError,
};
use openwire_core::{BoxConnection, BoxFuture, WireError, WireErrorKind};
use openwire_tokio::TokioIo;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio::sync::Mutex;
/// WebSocket engine backed by the `fastwebsockets` crate.
///
/// The type is a field-less unit struct, so cloning it or building it through
/// [`Default`] carries no state.
#[derive(Clone, Debug, Default)]
pub struct FastWebSocketsEngine;

impl FastWebSocketsEngine {
    /// Creates a new engine value.
    #[must_use]
    pub fn new() -> Self {
        Self
    }

    /// Creates a new engine already wrapped in an [`Arc`], for callers that
    /// want a shareable handle.
    #[must_use]
    pub fn shared() -> Arc<Self> {
        Arc::new(Self)
    }
}
impl WebSocketEngine for FastWebSocketsEngine {
    /// Wraps an already-upgraded connection in a frame sink/stream pair.
    ///
    /// The configuration is checked with `validate_config` first; any
    /// rejection is returned as the future's error. On success the connection
    /// is split into a write half (exposed as `send`) and a read half wrapped
    /// in a `FragmentCollectorRead` (exposed as `recv`).
    fn upgrade(
        &self,
        io: BoxConnection,
        config: WebSocketEngineConfig,
    ) -> BoxFuture<Result<WebSocketChannel, WebSocketEngineError>> {
        Box::pin(async move {
            validate_config(&config)?;
            let max_message_size = config.max_message_size;

            let socket =
                fastwebsockets::WebSocket::after_handshake(TokioIo::new(io), FastRole::Client);
            let (mut reader, writer) = socket.split(tokio::io::split);

            // Switch off the read half's auto-close / auto-pong options; Ping,
            // Pong and Close frames are converted to `EngineFrame`s and handed
            // to the consumer of `recv` instead.
            reader.set_auto_close(false);
            reader.set_auto_pong(false);
            reader.set_max_message_size(max_message_size);

            let send: BoxEngineSink = Box::pin(FastEngineSink::new(writer));
            let recv: BoxEngineStream = Box::pin(FastEngineStream::new(
                FragmentCollectorRead::new(reader),
                max_message_size,
            ));
            Ok(WebSocketChannel { send, recv })
        })
    }
}
/// Checks that `config` describes something this engine can serve.
///
/// # Errors
///
/// Returns `WebSocketEngineError::UnsupportedExtension` when the role is not
/// `Role::Client`, or when any entry of `config.extensions` is non-empty (the
/// error then carries all entries joined with `", "`).
fn validate_config(config: &WebSocketEngineConfig) -> Result<(), WebSocketEngineError> {
    if config.role != Role::Client {
        return Err(WebSocketEngineError::UnsupportedExtension(
            "fastwebsockets engine only supports client role".into(),
        ));
    }

    // Empty extension names are tolerated; anything else is unsupported.
    let only_blank_extensions = config.extensions.iter().all(|name| name.is_empty());
    if !only_blank_extensions {
        let requested = config.extensions.join(", ");
        return Err(WebSocketEngineError::UnsupportedExtension(requested));
    }

    Ok(())
}
/// Boxed, sendable future for one sink operation (a frame write or a flush),
/// resolving to `Ok(())` or the mapped engine error.
type BoxOpFuture = Pin<Box<dyn Future<Output = Result<(), WebSocketEngineError>> + Send>>;
/// Boxed, sendable future for one read on the stream. It resolves to the next
/// stream item: `Some(Ok(frame))`, `Some(Err(error))`, or `None`.
type BoxReadFuture =
Pin<Box<dyn Future<Output = Option<Result<EngineFrame, WebSocketEngineError>>> + Send>>;
/// `Sink` adapter over the fastwebsockets write half.
///
/// It holds at most one accepted frame at a time and drives writes and flushes
/// through boxed futures stored between polls.
struct FastEngineSink<W> {
/// Write half behind a tokio mutex in an `Arc`, so each boxed future can own
/// a handle to it.
inner: Arc<Mutex<FastWriteHalf<W>>>,
/// Frame accepted by `start_send` that has not yet been turned into a write future.
buffered: Option<EngineFrame>,
/// In-flight write of a previously buffered frame, if any.
write_fut: Option<BoxOpFuture>,
/// In-flight flush of the write half, if any.
flush_fut: Option<BoxOpFuture>,
}
impl<W> FastEngineSink<W>
where
W: AsyncWrite + Unpin + Send + 'static,
{
fn new(inner: FastWriteHalf<W>) -> Self {
Self {
inner: Arc::new(Mutex::new(inner)),
buffered: None,
write_fut: None,
flush_fut: None,
}
}
fn poll_pending(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WebSocketEngineError>> {
if self.write_fut.is_none() {
if let Some(frame) = self.buffered.take() {
let inner = Arc::clone(&self.inner);
self.write_fut = Some(Box::pin(async move {
let mut writer = inner.lock_owned().await;
writer
.write_frame(engine_to_fast(frame))
.await
.map_err(map_error)
}));
}
}
if let Some(fut) = self.write_fut.as_mut() {
match fut.as_mut().poll(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(result) => {
self.write_fut = None;
result?;
}
}
}
if let Some(fut) = self.flush_fut.as_mut() {
match fut.as_mut().poll(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(result) => {
self.flush_fut = None;
result?;
}
}
}
Poll::Ready(Ok(()))
}
fn start_flush(&mut self) {
if self.flush_fut.is_some() {
return;
}
let inner = Arc::clone(&self.inner);
self.flush_fut = Some(Box::pin(async move {
let mut writer = inner.lock_owned().await;
writer.flush().await.map_err(map_error)
}));
}
}
impl<W> Sink<EngineFrame> for FastEngineSink<W>
where
    W: AsyncWrite + Unpin + Send + 'static,
{
    type Error = WebSocketEngineError;

    /// Ready once any in-flight write and flush have completed.
    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.as_mut().get_mut().poll_pending(cx)
    }

    /// Validates `item` and buffers it; the write itself starts on the next poll.
    ///
    /// # Errors
    ///
    /// Fails if a frame is already buffered, or if
    /// `validate_outbound_engine_frame` rejects the frame.
    fn start_send(mut self: Pin<&mut Self>, item: EngineFrame) -> Result<(), Self::Error> {
        let me = self.as_mut().get_mut();
        if me.buffered.is_some() {
            return Err(closed_sink_error("write already buffered"));
        }
        validate_outbound_engine_frame(&item)?;
        me.buffered = Some(item);
        Ok(())
    }

    /// Writes any buffered frame, then flushes the writer.
    ///
    /// If a flush queued by an earlier call is still in flight and no write is
    /// outstanding, that flush already covers everything written so far. Its
    /// completion finishes this call instead of queuing a redundant second
    /// flush, which could otherwise keep `poll_flush` pending indefinitely on
    /// a writer whose flush never completes on its first poll.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let me = self.as_mut().get_mut();
        let flush_in_flight =
            me.flush_fut.is_some() && me.write_fut.is_none() && me.buffered.is_none();
        match me.poll_pending(cx) {
            Poll::Ready(Ok(())) if !flush_in_flight => {
                me.start_flush();
                me.poll_pending(cx)
            }
            // Pending, an error, or the completion of the in-flight flush.
            settled => settled,
        }
    }

    /// Closing only flushes; nothing else is done to the writer here.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.as_mut().poll_flush(cx)
    }
}
/// `Stream` adapter over a fastwebsockets `FragmentCollectorRead`.
struct FastEngineStream<R> {
/// Reader behind a tokio mutex in an `Arc`, so the boxed read future can own
/// a handle to it.
inner: Arc<Mutex<FragmentCollectorRead<R>>>,
/// In-flight read, kept between polls until it resolves.
read_fut: Option<BoxReadFuture>,
/// Limit reported in `PayloadTooLarge` when a read fails with `FrameTooLarge`.
max_message_size: usize,
}
impl<R> FastEngineStream<R>
where
R: AsyncRead + Unpin + Send + 'static,
{
fn new(inner: FragmentCollectorRead<R>, max_message_size: usize) -> Self {
Self {
inner: Arc::new(Mutex::new(inner)),
read_fut: None,
max_message_size,
}
}
fn start_read(&mut self) {
if self.read_fut.is_some() {
return;
}
let inner = Arc::clone(&self.inner);
let max_message_size = self.max_message_size;
self.read_fut = Some(Box::pin(async move {
let mut reader = inner.lock_owned().await;
let mut noop_send = |_| async { Ok::<(), Infallible>(()) };
match reader.read_frame::<_, Infallible>(&mut noop_send).await {
Ok(frame) => Some(fast_to_engine(frame)),
Err(FastWebSocketError::ConnectionClosed) => None,
Err(error) => Some(Err(map_error_with_limit(error, max_message_size))),
}
}));
}
}
impl<R> Stream for FastEngineStream<R>
where
    R: AsyncRead + Unpin + Send + 'static,
{
    type Item = Result<EngineFrame, WebSocketEngineError>;

    /// Starts a read if none is in flight, then polls it.
    ///
    /// The stored future is dropped as soon as it resolves, so the next call
    /// begins a fresh read.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.as_mut().get_mut();
        this.start_read();
        match this.read_fut.as_mut() {
            None => Poll::Ready(None),
            Some(fut) => {
                let outcome = fut.as_mut().poll(cx);
                if outcome.is_ready() {
                    this.read_fut = None;
                }
                outcome
            }
        }
    }
}
/// Converts an engine-level frame into an owned fastwebsockets frame.
///
/// A `Close` with code 1005 and an empty reason becomes a close frame with an
/// empty payload; every other `Close` carries its code and reason.
fn engine_to_fast(frame: EngineFrame) -> FastFrame<'static> {
    match frame {
        EngineFrame::Text(text) => {
            let utf8 = text.into_bytes();
            FastFrame::text(utf8.into())
        }
        EngineFrame::Binary(data) => {
            let copy = data.to_vec();
            FastFrame::binary(copy.into())
        }
        EngineFrame::Ping(data) => {
            let copy = data.to_vec();
            FastFrame::new(true, FastOpCode::Ping, None, copy.into())
        }
        EngineFrame::Pong(data) => {
            let copy = data.to_vec();
            FastFrame::pong(copy.into())
        }
        EngineFrame::Close { code, reason } => {
            if code == 1005 && reason.is_empty() {
                FastFrame::new(true, FastOpCode::Close, None, Vec::<u8>::new().into())
            } else {
                FastFrame::close(code, reason.as_bytes())
            }
        }
    }
}
/// Converts a received fastwebsockets frame into an engine-level frame.
///
/// # Errors
///
/// Returns `InvalidUtf8` for a text payload that is not valid UTF-8, whatever
/// `parse_close_payload` reports for a close frame, and `InvalidFrame` if a
/// continuation frame is received.
fn fast_to_engine(frame: FastFrame<'_>) -> Result<EngineFrame, WebSocketEngineError> {
    // Copies the payload into an owned `Bytes` for the binary-like variants.
    let owned_payload = || Bytes::from(frame.payload.to_vec());

    match frame.opcode {
        FastOpCode::Text => String::from_utf8(frame.payload.to_vec())
            .map(EngineFrame::Text)
            .map_err(|_| WebSocketEngineError::InvalidUtf8),
        FastOpCode::Binary => Ok(EngineFrame::Binary(owned_payload())),
        FastOpCode::Ping => Ok(EngineFrame::Ping(owned_payload())),
        FastOpCode::Pong => Ok(EngineFrame::Pong(owned_payload())),
        FastOpCode::Close => parse_close_payload(&frame.payload)
            .map(|(code, reason)| EngineFrame::Close { code, reason }),
        FastOpCode::Continuation => Err(WebSocketEngineError::InvalidFrame(
            "fragment collector returned continuation frame".into(),
        )),
    }
}
/// Splits a close-frame payload into its status code and UTF-8 reason.
///
/// An empty payload yields code 1005 with an empty reason. Otherwise the first
/// two bytes are read as a big-endian code and the rest as the reason, and the
/// pair is passed to `validate_close_frame`.
///
/// # Errors
///
/// Returns `InvalidFrame` for a one-byte payload, `InvalidUtf8` when the
/// reason is not valid UTF-8, and whatever `validate_close_frame` rejects.
fn parse_close_payload(payload: &[u8]) -> Result<(u16, String), WebSocketEngineError> {
    match payload {
        [] => Ok((1005, String::new())),
        [_] => Err(WebSocketEngineError::InvalidFrame(
            "close payload of length 1".into(),
        )),
        [high, low, reason_bytes @ ..] => {
            let code = u16::from_be_bytes([*high, *low]);
            let reason = match std::str::from_utf8(reason_bytes) {
                Ok(text) => text.to_string(),
                Err(_) => return Err(WebSocketEngineError::InvalidUtf8),
            };
            validate_close_frame(code, &reason)?;
            Ok((code, reason))
        }
    }
}
/// Maps a fastwebsockets error to an engine error, passing 0 as the size limit.
///
/// The limit is only consulted by the `FrameTooLarge` arm of
/// `map_error_with_limit`; every other error maps the same regardless of it.
fn map_error(error: FastWebSocketError) -> WebSocketEngineError {
map_error_with_limit(error, 0)
}
fn map_error_with_limit(
error: FastWebSocketError,
max_message_size: usize,
) -> WebSocketEngineError {
match error {
FastWebSocketError::IoError(io) => protocol_io_error("fastwebsockets IO error", io),
FastWebSocketError::InvalidUTF8 => WebSocketEngineError::InvalidUtf8,
FastWebSocketError::PingFrameTooLarge => WebSocketEngineError::PayloadTooLarge {
limit: 125,
received: 126,
},
FastWebSocketError::FrameTooLarge => WebSocketEngineError::PayloadTooLarge {
limit: max_message_size,
received: max_message_size.saturating_add(1),
},
other => WebSocketEngineError::InvalidFrame(other.to_string()),
}
}
fn protocol_io_error(message: &'static str, error: std::io::Error) -> WebSocketEngineError {
WebSocketEngineError::Io(WireError::with_source(
WireErrorKind::Protocol,
message,
error,
))
}
/// Builds a protocol-kind `WireError` with `message`, wrapped in
/// `WebSocketEngineError::Io`.
fn closed_sink_error(message: &'static str) -> WebSocketEngineError {
    let wire = WireError::new(WireErrorKind::Protocol, message);
    WebSocketEngineError::Io(wire)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A close with code 1005 and no reason must become a final close frame
    /// whose payload is empty.
    #[test]
    fn no_status_close_ack_maps_to_empty_fastwebsockets_close() {
        let ack = EngineFrame::Close {
            code: 1005,
            reason: String::new(),
        };
        let wire = engine_to_fast(ack);

        assert!(wire.fin);
        assert_eq!(wire.opcode, FastOpCode::Close);
        assert!(wire.payload.is_empty());
    }
}