use core::{
fmt,
future::Future,
pin::Pin,
task::{ready, Context, Poll},
};
use alloc::sync::{Arc, Weak};
use std::{error, io, sync::Mutex};
use bytes::{Bytes, BytesMut};
use futures_core::stream::Stream;
use pin_project_lite::pin_project;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use super::{
codec::{Codec, Message},
error::ProtocolError,
proto::CloseReason,
};
pin_project! {
    /// Adapter that turns a stream of raw byte chunks into a stream of decoded websocket
    /// [`Message`]s.
    ///
    /// Chunks pulled from `stream` are appended to `buf` and decoded with `codec`; see the
    /// [`Stream`] impl for how decode errors, input errors and end of input are reported.
    pub struct RequestStream<S> {
        // Inner source of byte chunks; structurally pinned so it can be polled through
        // `Pin<&mut Self>`.
        #[pin]
        stream: S,
        // Bytes received from `stream` that are waiting to be decoded.
        buf: BytesMut,
        // Frame decoder; `response_stream` duplicates it for the outbound side.
        codec: Codec,
    }
}
impl<S, T, E> RequestStream<S>
where
    S: Stream<Item = Result<T, E>>,
    T: AsRef<[u8]>,
{
    /// Wrap `stream`, decoding it with a codec built by `Codec::new()`.
    pub fn new(stream: S) -> Self {
        let codec = Codec::new();
        Self::with_codec(stream, codec)
    }

    /// Wrap `stream`, decoding its byte chunks with the supplied `codec`.
    pub fn with_codec(stream: S, codec: Codec) -> Self {
        let buf = BytesMut::new();
        Self { stream, buf, codec }
    }

    /// Mutable access to the wrapped input stream.
    ///
    /// Bytes already pulled from it and buffered inside `self` are not visible through this
    /// reference.
    #[inline]
    pub fn inner_mut(&mut self) -> &mut S {
        &mut self.stream
    }

    /// Mutable access to the decoding codec.
    #[inline]
    pub fn codec_mut(&mut self) -> &mut Codec {
        &mut self.codec
    }

    /// Create the outbound half of the connection: a [`ResponseStream`] of encoded frames and
    /// the [`ResponseSender`] that feeds it.
    ///
    /// The sender encodes with `self.codec.duplicate()`, and the channel between the two is
    /// bounded by that duplicate's `capacity()`.
    ///
    /// # Panics
    /// tokio's bounded `channel` panics when its capacity is 0, so this panics if
    /// `Codec::capacity()` returns 0.
    pub fn response_stream(&self) -> (ResponseStream, ResponseSender) {
        let outbound_codec = self.codec.duplicate();
        let (tx, rx) = channel(outbound_codec.capacity());
        (ResponseStream(rx), ResponseSender::new(tx, outbound_codec))
    }
}
/// Error yielded by [`RequestStream`].
pub enum WsError<E> {
    /// Failure reported by the websocket codec while decoding, including
    /// [`ProtocolError::UnexpectedEof`] produced by the `Stream` impl when the input ends.
    Protocol(ProtocolError),
    /// Error yielded by the wrapped input stream, passed through unchanged.
    Stream(E),
}
impl<E> fmt::Debug for WsError<E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
Self::Protocol(ref e) => fmt::Debug::fmt(e, f),
Self::Stream(..) => f.write_str("Input Stream error"),
}
}
}
impl<E> fmt::Display for WsError<E> {
    /// Renders exactly the same text as the `Debug` impl of [`WsError`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(self, f)
    }
}
// Marker impl so `WsError` can be used as a standard error; it relies on the `Debug` and
// `Display` impls above and keeps the default `source()` (returns `None`).
impl<E> error::Error for WsError<E> {}
impl<E> From<ProtocolError> for WsError<E> {
fn from(e: ProtocolError) -> Self {
Self::Protocol(e)
}
}
impl<S, T, E> Stream for RequestStream<S>
where
    S: Stream<Item = Result<T, E>>,
    T: AsRef<[u8]>,
{
    type Item = Result<Message, WsError<E>>;

    /// Decode the next websocket [`Message`], pulling more chunks from the inner stream until a
    /// complete frame is buffered.
    ///
    /// Yields:
    /// - `Some(Ok(msg))` for every decoded message;
    /// - `Some(Err(WsError::Protocol(_)))` when the codec rejects the buffered bytes;
    /// - `Some(Err(WsError::Stream(_)))` when the inner stream yields an error;
    /// - `Some(Err(WsError::Protocol(ProtocolError::UnexpectedEof)))` once, when the inner
    ///   stream ends while a partial frame is still buffered;
    /// - `None` when the inner stream ends with nothing left to decode.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();
        loop {
            // Drain already-buffered bytes first: a single chunk may carry several frames.
            if let Some(msg) = this.codec.decode(this.buf)? {
                return Poll::Ready(Some(Ok(msg)));
            }
            match ready!(this.stream.as_mut().poll_next(cx)) {
                Some(res) => {
                    let item = res.map_err(WsError::Stream)?;
                    this.buf.extend_from_slice(item.as_ref());
                }
                // Input ended on a frame boundary: terminate so consumers stop polling,
                // instead of yielding an error on every subsequent poll.
                None if this.buf.is_empty() => return Poll::Ready(None),
                None => {
                    // Input ended mid-frame. Drop the partial bytes so the truncation is
                    // reported once.
                    // NOTE(review): polling again after this error re-polls the inner stream
                    // past its end; that relies on `S` returning `None` again.
                    this.buf.clear();
                    return Poll::Ready(Some(Err(WsError::Protocol(ProtocolError::UnexpectedEof))));
                }
            }
        }
    }
}
/// Stream of encoded websocket frames produced by a [`ResponseSender`].
///
/// It is backed by the receiving half of a bounded tokio channel. It ends (yields `None`) once
/// every `Sender` of that channel is dropped and the buffered frames have been drained.
pub struct ResponseStream(Receiver<Item>);

/// Item carried from [`ResponseSender`] to [`ResponseStream`]: one encoded chunk of bytes.
type Item = io::Result<Bytes>;

impl Stream for ResponseStream {
    type Item = Item;

    #[inline]
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `ResponseStream` is `Unpin`, so the pin can be unwrapped to reach the receiver.
        let Self(rx) = self.get_mut();
        rx.poll_recv(cx)
    }
}
/// Handle that encodes websocket messages and queues the resulting frames on the paired
/// [`ResponseStream`].
///
/// Created by [`RequestStream::response_stream`]. The encoder and channel sender live behind an
/// `Arc`, so [`ResponseWeakSender`] handles can refer to the same shared state.
#[derive(Debug)]
pub struct ResponseSender {
    inner: Arc<_ResponseSender>,
}
impl ResponseSender {
    /// Build a sender around the channel producer `tx` that encodes with `codec`.
    ///
    /// The scratch encode buffer is pre-sized to `codec.max_size()`.
    fn new(tx: Sender<Item>, codec: Codec) -> Self {
        let buf = BytesMut::with_capacity(codec.max_size());
        let encoder = Mutex::new(Encoder { codec, buf });
        let shared = _ResponseSender { encoder, tx };
        Self { inner: Arc::new(shared) }
    }

    /// Create a [`ResponseWeakSender`] that points at the same shared state without keeping it
    /// alive.
    pub fn downgrade(&self) -> ResponseWeakSender {
        let inner = Arc::downgrade(&self.inner);
        ResponseWeakSender { inner }
    }

    /// Encode a [`Message::Text`] frame and queue it on the [`ResponseStream`].
    ///
    /// # Errors
    /// Returns [`ProtocolError::BadOpCode`] without sending anything if `txt` is not valid
    /// UTF-8. Otherwise it returns whatever sending the message returns.
    #[inline]
    pub async fn text(&self, txt: impl Into<Bytes>) -> Result<(), ProtocolError> {
        let bytes: Bytes = txt.into();
        if core::str::from_utf8(&bytes).is_err() {
            return Err(ProtocolError::BadOpCode);
        }
        self.send(Message::Text(bytes)).await
    }

    /// Encode a [`Message::Binary`] frame and queue it on the [`ResponseStream`].
    #[inline]
    pub fn binary(&self, bin: impl Into<Bytes>) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
        let payload: Bytes = bin.into();
        self.send(Message::Binary(payload))
    }

    /// Encode a [`Message::Continuation`] frame and queue it on the [`ResponseStream`].
    #[inline]
    pub fn continuation(&self, item: super::codec::Item) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
        let msg = Message::Continuation(item);
        self.send(msg)
    }

    /// Encode a [`Message::Ping`] frame and queue it on the [`ResponseStream`].
    #[inline]
    pub fn ping(&self, bin: impl Into<Bytes>) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
        let payload: Bytes = bin.into();
        self.send(Message::Ping(payload))
    }

    /// Encode a [`Message::Pong`] frame and queue it on the [`ResponseStream`].
    #[inline]
    pub fn pong(&self, bin: impl Into<Bytes>) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
        let payload: Bytes = bin.into();
        self.send(Message::Pong(payload))
    }

    /// Encode a [`Message::Close`] frame with an optional `reason` and queue it.
    ///
    /// NOTE(review): whether later sends fail after this depends on `Codec` (see
    /// `send_closed`). Confirm against the codec implementation.
    pub async fn close(&mut self, reason: Option<impl Into<CloseReason>>) -> Result<(), ProtocolError> {
        let reason: Option<CloseReason> = reason.map(Into::into);
        self.send(Message::Close(reason)).await
    }

    /// Forward `msg` to the shared encoder/channel state.
    #[inline]
    fn send(&self, msg: Message) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
        _ResponseSender::send(&self.inner, msg)
    }
}
/// Non-owning handle to a [`ResponseSender`]'s shared state, obtained with
/// [`ResponseSender::downgrade`]. It does not keep that state alive on its own.
#[derive(Debug)]
pub struct ResponseWeakSender {
    inner: Weak<_ResponseSender>,
}
impl ResponseWeakSender {
    /// Try to obtain a usable [`ResponseSender`] again.
    ///
    /// Returns `None` when the shared state has already been dropped, or when
    /// `Codec::send_closed()` reports `true` (presumably meaning a close frame was already
    /// encoded; confirm against the codec).
    pub fn upgrade(&self) -> Option<ResponseSender> {
        let inner = self.inner.upgrade()?;
        let closed = inner.encoder.lock().unwrap().codec.send_closed();
        if closed {
            None
        } else {
            Some(ResponseSender { inner })
        }
    }
}
/// State shared by a [`ResponseSender`] and its [`ResponseWeakSender`]s.
#[derive(Debug)]
struct _ResponseSender {
    // Codec plus scratch buffer. It sits behind a mutex because every send method takes only
    // `&self`.
    encoder: Mutex<Encoder>,
    // Producer half of the channel drained by `ResponseStream`.
    tx: Sender<Item>,
}
/// Outbound codec together with the reusable buffer that frames are encoded into.
#[derive(Debug)]
struct Encoder {
    codec: Codec,
    // Scratch buffer; each encoded frame is split off and frozen before it is sent.
    buf: BytesMut,
}
impl _ResponseSender {
    /// Encode `msg` and queue the resulting frame bytes on the response channel.
    ///
    /// # Errors
    /// - [`ProtocolError::UnexpectedEof`] when the channel is closed, i.e. the
    ///   [`ResponseStream`] receiver was dropped or closed.
    /// - Any error returned by `Codec::encode`; in that case nothing is queued and the
    ///   reserved slot is released.
    async fn send(&self, msg: Message) -> Result<(), ProtocolError> {
        // Reserve channel capacity before touching the encoder. `reserve` may wait, and a
        // `std::sync::Mutex` guard must never be held across an `.await`.
        let permit = self.tx.reserve().await.map_err(|_| ProtocolError::UnexpectedEof)?;
        {
            let mut encoder = self.encoder.lock().unwrap();
            let Encoder { codec, buf } = &mut *encoder;
            codec.encode(msg, buf)?;
            // Hand the frame to the channel while still holding the lock.
            // Channel order is decided when `Permit::send` runs, not when the permit was
            // reserved. Sending after unlocking would let a concurrent sender overtake, so
            // frames could leave in a different order than they were encoded (e.g. a data
            // frame landing after a close frame). `Permit::send` is synchronous, so no await
            // happens under the lock.
            permit.send(Ok(buf.split().freeze()));
        }
        Ok(())
    }
}