use std::{cell::Cell, cmp, io, task::Poll};
use crate::codec::{Decoder, Encoder};
use crate::io::{Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use crate::util::{BufMut, PoolRef, Ready};
use super::{CloseCode, CloseReason, Codec, Frame, Item, Message};
bitflags::bitflags! {
    /// Per-connection state bits tracked by [`WsTransport`].
    struct Flags: u8 {
        /// A Close frame has already been queued by `shutdown`.
        const CLOSED = 1 << 0;
        /// A fragmented (continuation) binary message is in progress.
        const CONTINUATION = 1 << 1;
        /// A protocol error was detected while reading; selects the close code.
        const PROTO_ERR = 1 << 2;
    }
}
/// Filter layer that carries a byte stream inside WebSocket binary frames.
///
/// Incoming frames are decoded into raw payload bytes; outgoing bytes are
/// encoded as binary frames.
pub struct WsTransport {
    /// WebSocket frame encoder/decoder.
    codec: Codec,
    /// Connection state bits (closed / continuation / protocol error).
    flags: Cell<Flags>,
    /// Memory pool of the wrapped io, used to size read and write buffers.
    pool: PoolRef,
}
impl WsTransport {
    /// Layers a WebSocket transport filter over `io`, using `codec` for framing.
    ///
    /// The filter shares the memory pool of the wrapped `io` and starts with
    /// no state flags set.
    pub fn create<F: Filter>(io: Io<F>, codec: Codec) -> Io<Layer<WsTransport, F>> {
        let transport = WsTransport {
            pool: io.memory_pool(),
            codec,
            flags: Cell::new(Flags::empty()),
        };
        io.add_filter(transport)
    }

    /// Sets every bit of `flags` in the current state.
    fn insert_flags(&self, flags: Flags) {
        self.flags.set(self.flags.get() | flags);
    }

    /// Clears every bit of `flags` from the current state.
    fn remove_flags(&self, flags: Flags) {
        self.flags.set(self.flags.get() - flags);
    }

    /// Checks that a fragmented message is in progress.
    ///
    /// # Errors
    /// If `Flags::CONTINUATION` is not set, this records `Flags::PROTO_ERR` and
    /// returns an `io::ErrorKind::Other` error carrying `err_message`.
    fn continuation_must_start(&self, err_message: &'static str) -> io::Result<()> {
        if !self.flags.get().contains(Flags::CONTINUATION) {
            self.insert_flags(Flags::PROTO_ERR);
            return Err(io::Error::new(io::ErrorKind::Other, err_message));
        }
        Ok(())
    }
}
impl FilterLayer for WsTransport {
    /// Queues a single WebSocket Close frame on the outgoing buffer.
    ///
    /// The close code is `CloseCode::Protocol` when a protocol error was
    /// recorded, otherwise `CloseCode::Normal`. Only the first call emits a
    /// frame; later calls see `Flags::CLOSED` and return immediately. The
    /// encoder's result is discarded. Always reports `Poll::Ready(())`.
    #[inline]
    fn shutdown(&self, buf: &WriteBuf<'_>) -> io::Result<Poll<()>> {
        let snapshot = self.flags.get();
        if snapshot.contains(Flags::CLOSED) {
            return Ok(Poll::Ready(()));
        }
        self.insert_flags(Flags::CLOSED);

        let code = if snapshot.contains(Flags::PROTO_ERR) {
            CloseCode::Protocol
        } else {
            CloseCode::Normal
        };
        // Best effort: a failure to encode the Close frame does not fail shutdown.
        let _ = buf.with_dst(|dst| {
            let reason = CloseReason {
                code,
                description: None,
            };
            self.codec.encode_vec(Message::Close(Some(reason)), dst)
        });
        Ok(Poll::Ready(()))
    }

    /// Decodes WebSocket frames from the source buffer into raw payload bytes.
    ///
    /// Binary and binary-continuation payloads are appended to the destination
    /// buffer. A Ping queues a Pong on the write buffer, and a Pong is ignored.
    /// A Close requests shutdown and stops decoding; undecoded bytes stay in the
    /// source buffer.
    ///
    /// Returns the number of payload bytes appended, or `0` when there is no
    /// source buffer.
    ///
    /// # Errors
    /// Decode failures, continuation frames without a started message, and any
    /// Text frame set `Flags::PROTO_ERR` and return an `io::ErrorKind::Other`
    /// error.
    ///
    /// NOTE(review): on these error paths the taken `src`/`dst` buffers are not
    /// handed back via `set_src`/`set_dst`.
    fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
        let mut src = match buf.take_src() {
            Some(src) => src,
            None => return Ok(0),
        };
        let mut dst = buf.take_dst();
        let initial_len = dst.len();

        // Marks the protocol error and builds the error returned to the caller.
        let reject = |msg: &'static str| {
            self.insert_flags(Flags::PROTO_ERR);
            io::Error::new(io::ErrorKind::Other, msg)
        };

        loop {
            // Make sure there is room before decoding the next frame.
            self.pool.resize_read_buf(&mut dst);

            let decoded = self.codec.decode_vec(&mut src).map_err(|err| {
                log::trace!("Failed to decode ws codec frames: {:?}", err);
                self.insert_flags(Flags::PROTO_ERR);
                io::Error::new(io::ErrorKind::Other, err)
            })?;
            let frame = match decoded {
                Some(frame) => frame,
                // Not enough bytes for another complete frame.
                None => break,
            };

            match frame {
                Frame::Binary(payload) => dst.extend_from_slice(&payload),
                Frame::Continuation(Item::FirstBinary(payload)) => {
                    self.insert_flags(Flags::CONTINUATION);
                    dst.extend_from_slice(&payload);
                }
                Frame::Continuation(Item::Continue(payload)) => {
                    self.continuation_must_start("Continuation frame is not started")?;
                    dst.extend_from_slice(&payload);
                }
                Frame::Continuation(Item::Last(payload)) => {
                    self.continuation_must_start(
                        "Continuation frame is not started, last frame is received",
                    )?;
                    dst.extend_from_slice(&payload);
                    self.remove_flags(Flags::CONTINUATION);
                }
                Frame::Continuation(Item::FirstText(_)) => {
                    return Err(reject(
                        "WebSocket Text continuation frames are not supported",
                    ));
                }
                Frame::Text(_) => {
                    return Err(reject("WebSockets Text frames are not supported"));
                }
                Frame::Ping(msg) => {
                    // Best effort: reply with a Pong carrying the ping payload.
                    let _ = buf.with_write_buf(|wb| {
                        wb.with_dst(|out| self.codec.encode_vec(Message::Pong(msg), out))
                    });
                }
                Frame::Pong(_) => (),
                Frame::Close(_) => {
                    buf.want_shutdown();
                    break;
                }
            }
        }

        let produced = dst.len() - initial_len;
        buf.set_dst(Some(dst));
        buf.set_src(Some(src));
        Ok(produced)
    }

    /// Wraps all pending outgoing bytes into one WebSocket Binary frame.
    ///
    /// If the destination has less spare capacity than the pool's low write
    /// watermark, capacity is reserved up to `max(high, len + 12)` first. The
    /// `+ 12` is presumably headroom for the frame header — confirm.
    ///
    /// NOTE(review): the reserve amount subtracts the spare capacity, so it
    /// assumes high watermark >= low watermark. The encoder's result is discarded.
    fn process_write_buf(&self, buf: &WriteBuf<'_>) -> io::Result<()> {
        let src = match buf.take_src() {
            Some(src) => src,
            None => return Ok(()),
        };
        buf.with_dst(|dst| {
            let (high, low) = self.pool.write_params().unpack();
            let free = dst.remaining_mut();
            if free < low {
                let target = cmp::max(high, dst.len() + 12);
                dst.reserve(target - free);
            }
            let _ = self.codec.encode_vec(Message::Binary(src.freeze()), dst);
        });
        Ok(())
    }
}
/// Factory that layers a [`WsTransport`] filter over an `Io`.
pub struct WsTransportFactory {
    /// Codec handed to every transport created by this factory.
    codec: Codec,
}
impl WsTransportFactory {
pub fn new(codec: Codec) -> Self {
Self { codec }
}
}
impl<F: Filter> FilterFactory<F> for WsTransportFactory {
type Filter = WsTransport;
type Error = io::Error;
type Future = Ready<Io<Layer<Self::Filter, F>>, Self::Error>;
fn create(self, io: Io<F>) -> Self::Future {
Ready::Ok(WsTransport::create(io, self.codec))
}
}