use std::{cell::Cell, io, task::Poll};
use crate::codec::{Decoder, Encoder};
use crate::io::{Filter, FilterBuf, FilterLayer, Io, Layer};
use crate::service::{Service, ServiceCtx};
use super::{CloseCode, CloseReason, Codec, Frame, Item, Message};
bitflags::bitflags! {
    /// Internal state bits tracked by the WebSocket transport layer.
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    struct Flags: u8 {
        /// Set by `shutdown` once it has attempted to queue the Close frame,
        /// so that later calls do not queue it again.
        const CLOSED = 1 << 0;
        /// Set when the first binary fragment of a fragmented message is
        /// received; cleared when the last fragment arrives.
        const CONTINUATION = 1 << 1;
        /// Set when decoding fails or an unsupported/out-of-order frame is
        /// received; makes `shutdown` use the protocol-error close code.
        const PROTO_ERR = 1 << 2;
    }
}
/// WebSocket transport filter layer.
///
/// On the read side, payloads of binary frames (including fragmented binary
/// messages) are appended to the read destination buffer; text frames are
/// rejected. On the write side, outgoing pages are handed to the codec's
/// `encode_page`.
#[derive(Clone, Debug)]
pub struct WsTransport {
/// Codec used to decode incoming frames and encode outgoing data.
codec: Codec,
/// Layer state bits (see `Flags`); kept in a `Cell` so they can be updated
/// through `&self`.
flags: Cell<Flags>,
}
impl WsTransport {
    /// Adds a `WsTransport` filter, built from `codec` with no flags set,
    /// on top of `io` and returns the resulting layered io object.
    pub fn create<F: Filter>(io: Io<F>, codec: Codec) -> Io<Layer<WsTransport, F>> {
        let transport = WsTransport {
            codec,
            flags: Cell::new(Flags::empty()),
        };
        io.add_filter(transport)
    }

    /// Sets every bit of `flags` in the stored state.
    fn insert_flags(&self, flags: Flags) {
        self.flags.set(self.flags.get().union(flags));
    }

    /// Clears every bit of `flags` from the stored state.
    fn remove_flags(&self, flags: Flags) {
        self.flags.set(self.flags.get().difference(flags));
    }

    /// Verifies that a fragmented message is currently in progress.
    ///
    /// # Errors
    ///
    /// When `Flags::CONTINUATION` is not set, records `Flags::PROTO_ERR` and
    /// returns an `InvalidData` error carrying `err_message`.
    fn continuation_must_start(&self, err_message: &'static str) -> io::Result<()> {
        let in_progress = self.flags.get().contains(Flags::CONTINUATION);
        if !in_progress {
            self.insert_flags(Flags::PROTO_ERR);
            return Err(io::Error::new(io::ErrorKind::InvalidData, err_message));
        }
        Ok(())
    }
}
impl FilterLayer for WsTransport {
    /// Queues a WebSocket Close frame the first time the layer is shut down.
    ///
    /// The close code is `CloseCode::Protocol` when a protocol error was
    /// recorded earlier and `CloseCode::Normal` otherwise. Encoding is best
    /// effort: its result is discarded and the call always reports
    /// `Poll::Ready`.
    #[inline]
    fn shutdown(&self, buf: &mut FilterBuf<'_>) -> io::Result<Poll<()>> {
        // Snapshot taken before CLOSED is added below.
        let state = self.flags.get();
        if state.contains(Flags::CLOSED) {
            // A previous call already attempted to queue the Close frame.
            return Ok(Poll::Ready(()));
        }
        self.insert_flags(Flags::CLOSED);

        let code = if state.contains(Flags::PROTO_ERR) {
            CloseCode::Protocol
        } else {
            CloseCode::Normal
        };
        // Deliberately ignore the outcome: shutdown proceeds either way.
        let _ = buf.with_write_buffers(|_, _, w_dst| {
            let close = Message::Close(Some(CloseReason {
                code,
                description: None,
            }));
            self.codec.encodev(close, w_dst)
        });

        Ok(Poll::Ready(()))
    }

    /// Decodes every complete frame available in the read source and copies
    /// binary payloads into the read destination buffer.
    ///
    /// Ping frames are answered with a Pong on the write destination, Pong
    /// frames are ignored, and a Close frame requests shutdown and stops
    /// processing.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidData` error, after recording `Flags::PROTO_ERR`,
    /// when decoding fails, when a text frame (single or fragmented) is
    /// received, or when a continuation fragment arrives without a started
    /// fragmented message.
    fn process_read_buf(&self, buf: &mut FilterBuf<'_>) -> io::Result<()> {
        buf.with_buffers(|ctx, r_src, r_dst, _, w_dst| {
            // Nothing to do without a read source buffer.
            let Some(src) = r_src else {
                return Ok(());
            };
            let mut out = r_dst.take().unwrap_or_else(|| ctx.cfg().read_buf().get());

            // Records the protocol violation and builds the matching error.
            let violation = |msg: &'static str| {
                self.insert_flags(Flags::PROTO_ERR);
                io::Error::new(io::ErrorKind::InvalidData, msg)
            };

            loop {
                let frame = match self.codec.decode(src) {
                    Ok(Some(frame)) => frame,
                    // No complete frame left in the source buffer.
                    Ok(None) => break,
                    Err(e) => {
                        log::trace!("Failed to decode ws codec frames: {e:?}");
                        self.insert_flags(Flags::PROTO_ERR);
                        return Err(io::Error::new(io::ErrorKind::InvalidData, e));
                    }
                };

                match frame {
                    Frame::Binary(payload) => out.extend_from_slice(&payload),
                    Frame::Continuation(item) => match item {
                        Item::FirstBinary(payload) => {
                            self.insert_flags(Flags::CONTINUATION);
                            out.extend_from_slice(&payload);
                        }
                        Item::Continue(payload) => {
                            self.continuation_must_start("Continuation frame is not started")?;
                            out.extend_from_slice(&payload);
                        }
                        Item::Last(payload) => {
                            self.continuation_must_start(
                                "Continuation frame is not started, last frame is received",
                            )?;
                            out.extend_from_slice(&payload);
                            self.remove_flags(Flags::CONTINUATION);
                        }
                        Item::FirstText(_) => {
                            return Err(violation(
                                "WebSocket Text continuation frames are not supported",
                            ));
                        }
                    },
                    Frame::Text(_) => {
                        return Err(violation("WebSockets Text frames are not supported"));
                    }
                    Frame::Ping(msg) => {
                        // Best-effort reply; an encoding failure is ignored.
                        let _ = self.codec.encodev(Message::Pong(msg), w_dst);
                    }
                    Frame::Pong(_) => {}
                    Frame::Close(_) => {
                        ctx.wants_shutdown();
                        break;
                    }
                }
            }

            // Hand the (possibly extended) buffer back to the read destination.
            *r_dst = Some(out);
            Ok(())
        })
    }

    /// Drains every pending page from the write source and passes it to the
    /// codec's `encode_page` together with the write destination.
    fn process_write_buf(&self, buf: &mut FilterBuf<'_>) -> io::Result<()> {
        buf.with_buffers(|_, _, _, pending, sink| {
            while let Some(page) = pending.take() {
                self.codec.encode_page(page, sink);
            }
        });

        Ok(())
    }
}
/// Service that wraps an io object in a `WsTransport` filter layer.
#[derive(Clone, Debug)]
pub struct WsTransportService {
/// Codec cloned into every `WsTransport` layer this service creates.
codec: Codec,
}
impl WsTransportService {
pub fn new(codec: Codec) -> Self {
Self { codec }
}
}
impl<F: Filter> Service<Io<F>> for WsTransportService {
    type Response = Io<Layer<WsTransport, F>>;
    type Error = io::Error;

    /// Wraps `io` in a `WsTransport` layer built from a clone of the stored
    /// codec. This call always returns `Ok`.
    async fn call(
        &self,
        io: Io<F>,
        _: ServiceCtx<'_, Self>,
    ) -> Result<Self::Response, Self::Error> {
        let codec = self.codec.clone();
        let layered = WsTransport::create(io, codec);
        Ok(layered)
    }
}