#![cfg(feature="__io__")]
use crate::io::{AsyncRead, AsyncWrite};
use crate::sync::RwLock;
use crate::{Config, Message};
use std::{sync::Arc, io::Error};
/// Marker trait for I/O types that can serve as the transport of a `Connection`.
///
/// It declares no methods; it only bundles the `AsyncRead`, `AsyncWrite`,
/// `Unpin` and `'static` requirements under one name.
pub trait UnderlyingConnection: AsyncRead + AsyncWrite + Unpin + 'static {}

/// Blanket implementation: every type meeting the bounds is an `UnderlyingConnection`.
impl<T> UnderlyingConnection for T
where
    T: AsyncRead + AsyncWrite + Unpin + 'static,
{
}
/// A WebSocket connection layered over an underlying I/O object `C`.
///
/// `Connection::new` hands out two values that refer to the same underlying
/// I/O object and the same "closed" flag: the `Connection` itself and a `Closer`.
pub struct Connection<C>
where
    C: UnderlyingConnection,
{
    /// Shared flag that is `true` once the connection has been marked closed.
    __closed__: Arc<RwLock<bool>>,
    /// The underlying I/O object; reached through `UnsafeCell::get` without any
    /// compiler-checked exclusivity between the holders of this `Arc`.
    conn: Arc<std::cell::UnsafeCell<C>>,
    /// Configuration handed to message reading and writing.
    config: Config,
    /// Running total of what `Message::write` reported since the last flush
    /// (presumably a byte count — confirm against `Message::write`).
    n_buffered: usize,
}
/// Returns the current value of the shared "closed" flag.
///
/// Awaits a read lock on `__closed__` and copies the `bool` out of it.
#[inline(always)]
async fn read_closed(__closed__: &RwLock<bool>) -> bool {
    let flag = __closed__.read().await;
    *flag
}
/// Marks the connection as closed.
///
/// Awaits a write lock on `__closed__` and stores `true` in it. Nothing in
/// this function ever resets the flag to `false`.
#[inline(always)]
async fn set_closed(__closed__: &RwLock<bool>) {
    let mut flag = __closed__.write().await;
    *flag = true;
}
/// Diagnostic banner shown when a connection is used after it was closed.
///
/// It is printed to stderr by the `underlying!` macro when the closed flag is
/// already set, and it is used as the `expect`/`panic!` message in
/// `Connection::split`.
const ALREADY_CLOSED_MESSAGE: &str = "\n\
|--------------------------------------------\n\
| WebSocket connection is already closed! |\n\
| |\n\
| Maybe you spawned tasks using connection |\n\
| and NOT waiting the tasks to finish? |\n\
| |\n\
| This is NOT supported because it may |\n\
| cause resource leak due to something like |\n\
| an infinite loop or a dead lock in the |\n\
| WebSocket handler. |\n\
| If you're doing it, please wait |\n\
| (e.g. join, select, await, ...) the tasks |\n\
| in the handler! |\n\
--------------------------------------------|\n\
";
// Yields the underlying I/O object only while the connection is not closed.
//
// The first two arms expand to an `async` block that must be awaited and that
// resolves to `Result<&mut _, std::io::Error>`:
//   * `underlying!(connection)`            - for a `&mut Connection<_>`
//   * `underlying!(unless closed, conn)`   - for a flag plus an existing `&mut` I/O reference
// The `@@checked` arm is an internal helper shared by both.
macro_rules! underlying {
// Arm 1: takes a `&mut Connection<_>` expression.
($this:expr) => {async {
// Compile-time type assertion only: `$this` must be a `&mut Connection<_>`.
let _: &mut Connection<_> = $this;
// The closure runs only when the closed flag is `false`.
let conn = (!read_closed(&$this.__closed__).await).then(|| {
// NOTE(review): dereferences the `UnsafeCell` pointer into a `&mut`; nothing
// here proves exclusivity against the other holder of the `Arc` - presumably
// callers must not use the `Connection` and its `Closer` at the same time.
unsafe {&mut *$this.conn.get()}
});
underlying!(@@checked conn)
}};
// Arm 2: takes the identifier of the closed flag and the identifier of a `&mut` I/O reference.
(unless $__closed__:ident, $conn:ident) => {async {
// Compile-time type assertion only: `$conn` must be some `&mut _`.
let _: &mut _ = $conn;
// `Some($conn)` while the flag is `false`, `None` once it is `true`.
let conn = (!read_closed(&$__closed__).await).then_some($conn);
underlying!(@@checked conn)
}};
// Internal arm: turns `Option<&mut _>` into `Result<&mut _, std::io::Error>`.
(@@checked $maybe_conn:expr) => {{
// Compile-time type assertion only.
let _: Option<&mut _> = $maybe_conn;
$maybe_conn.ok_or_else(|| {
// On `None`, print the banner to stderr and build a `ConnectionReset` error.
eprintln!("{ALREADY_CLOSED_MESSAGE}");
::std::io::Error::new(
::std::io::ErrorKind::ConnectionReset,
"WebSocket connection is already closed"
)
})
}};
}
#[inline(always)]
async fn to_checked_parts<C: UnderlyingConnection>(connection: &mut Connection<C>) -> Result<(&mut C, &Config, &mut usize), Error> {
let conn = underlying!(connection).await?;
return Ok((conn, &connection.config, &mut connection.n_buffered))
}
/// Writes `message` to `conn` and flushes immediately afterwards.
///
/// A successful flush resets `n_buffered` to zero.
///
/// # Errors
///
/// Propagates any error from writing the message or from flushing.
#[inline]
pub(super) async fn send(
    message: Message,
    conn: &mut (impl AsyncWrite + Unpin),
    config: &Config,
    n_buffered: &mut usize,
) -> Result<(), Error> {
    message.write(conn, config).await?;
    // The flush result is exactly this function's result.
    flush(conn, n_buffered).await
}
/// Writes `message` to `conn` without necessarily flushing.
///
/// The value reported by `Message::write` is added to `n_buffered`. A flush is
/// performed only when the new total exceeds `config.write_buffer_size`.
///
/// Returns the value reported by `Message::write` for this message.
///
/// # Errors
///
/// Propagates any error from writing the message or from flushing.
///
/// # Panics
///
/// Panics when the new total exceeds `config.max_write_buffer_size`.
#[inline]
pub(super) async fn write(
    message: Message,
    conn: &mut (impl AsyncWrite + Unpin),
    config: &Config,
    n_buffered: &mut usize,
) -> Result<usize, Error> {
    let written = message.write(conn, config).await?;
    *n_buffered += written;

    let buffered = *n_buffered;
    if buffered > config.write_buffer_size {
        // The hard limit is checked only once the soft limit has been passed.
        if buffered > config.max_write_buffer_size {
            panic!("Buffered messages is larger than `max_write_buffer_size`");
        }
        flush(conn, n_buffered).await?;
    }

    Ok(written)
}
/// Flushes `conn` and, on success, resets `n_buffered` to zero.
///
/// # Errors
///
/// Returns the flush error; `n_buffered` is left untouched in that case.
#[inline]
pub(super) async fn flush(
    conn: &mut (impl AsyncWrite + Unpin),
    n_buffered: &mut usize,
) -> Result<(), Error> {
    conn.flush().await?;
    *n_buffered = 0;
    Ok(())
}
// Trait implementations for `Connection`, grouped in an anonymous const block.
const _: () = {
    use std::fmt;

    // NOTE(review): both impls are unconditional in `C` — neither `C: Send` nor
    // `C: Sync` is required — while `conn` is an `Arc<UnsafeCell<C>>` shared
    // with the `Closer`. Soundness therefore presumably rests on the
    // `Connection` and its `Closer` never being used at the same time; confirm.
    unsafe impl<C: UnderlyingConnection> Send for Connection<C> {}
    unsafe impl<C: UnderlyingConnection> Sync for Connection<C> {}

    /// Debug output under the struct name `"WebSocket Connection"`, listing
    /// the underlying I/O object, the config and the buffered counter.
    impl<C> fmt::Debug for Connection<C>
    where
        C: UnderlyingConnection + fmt::Debug,
    {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            // NOTE(review): shared borrow through the `UnsafeCell`; assumes no
            // `&mut` to the same object is live while formatting — confirm.
            let underlying: &C = unsafe { &*self.conn.get() };

            let mut out = f.debug_struct("WebSocket Connection");
            out.field("underlying", &underlying);
            out.field("config", &self.config);
            out.field("n_buffered", &self.n_buffered);
            out.finish()
        }
    }
};
/// Handle that can send a final Close message on a connection.
///
/// It wraps a second `Connection` value; `Connection::new` builds it from the
/// same underlying I/O object and the same "closed" flag as the `Connection`
/// it returns alongside.
pub struct Closer<C>(Connection<C>)
where
    C: UnderlyingConnection;
use crate::{CloseCode, CloseFrame};
impl<C: UnderlyingConnection> Closer<C> {
    /// Sends a Close message with `CloseCode::Normal` and no reason, unless the
    /// connection is already marked closed.
    pub async fn send_close_if_not_closed(self) {
        let normal_close = CloseFrame {
            code: CloseCode::Normal,
            reason: None,
        };
        self.send_close_if_not_closed_with(normal_close).await
    }

    /// Sends `frame` as a Close message, unless the connection is already
    /// marked closed.
    ///
    /// A failure to send is reported on stderr and otherwise ignored.
    pub async fn send_close_if_not_closed_with(mut self, frame: CloseFrame) {
        #[cfg(debug_assertions)]
        {
            // More than one holder of the closed flag means some other value
            // sharing it is still alive while the closer runs.
            let holders = Arc::strong_count(&self.0.__closed__);
            if holders != 1 {
                eprintln!("\n\
                    Unexpected state of WebSocket closer found!\n\
                    \n\
                    First use `Connection` in a `Handler`,\n\
                    and next use `Closer` to ensure to \n\
                    send a close message to client.\n\
                ")
            }
        }

        if self.0.is_closed().await {
            return;
        }

        let outcome = self.0.send(Message::Close(Some(frame))).await;
        if let Err(e) = outcome {
            eprintln!("failed to send a close message: {e}")
        }
    }
}
impl<C: UnderlyingConnection> Connection<C> {
    /// Wraps `conn` and returns a `Connection` together with its `Closer`.
    ///
    /// Both returned values share the same underlying I/O object and the same
    /// "closed" flag; each gets its own copy of `config` and starts with
    /// `n_buffered == 0`.
    pub fn new(conn: C, config: Config) -> (Self, Closer<C>) {
        let shared_conn = Arc::new(std::cell::UnsafeCell::new(conn));
        let closed_flag = Arc::new(RwLock::new(false));

        let connection = Self {
            conn: Arc::clone(&shared_conn),
            __closed__: Arc::clone(&closed_flag),
            config: config.clone(),
            n_buffered: 0,
        };
        let closer = Closer(Self {
            conn: shared_conn,
            __closed__: closed_flag,
            config,
            n_buffered: 0,
        });

        (connection, closer)
    }

    /// Returns `true` once the connection has been marked closed.
    pub async fn is_closed(&self) -> bool {
        let flag = self.__closed__.read().await;
        *flag
    }

    /// Marks the connection as closed. No I/O is performed here.
    pub(crate) async fn close(&mut self) {
        let mut flag = self.__closed__.write().await;
        *flag = true;
    }
}
impl<C: UnderlyingConnection> Connection<C> {
    /// Reads the next message from the connection.
    ///
    /// A received Ping is answered right away with a Pong carrying the same
    /// payload (sent and flushed via [`Connection::send`]) and is reported to
    /// the caller as `Ok(None)`. Anything else that `Message::read_from`
    /// yields is returned unchanged.
    ///
    /// NOTE(review): `Ok(None)` therefore covers both "a Ping was handled" and
    /// whatever `Message::read_from` means by `None` — confirm callers cope.
    ///
    /// # Errors
    ///
    /// Fails when the connection is already marked closed, when reading fails,
    /// or when sending the Pong fails.
    #[inline]
    pub async fn recv(&mut self) -> Result<Option<Message>, Error> {
        let (conn, config, _) = to_checked_parts(self).await?;
        match Message::read_from(conn, config).await? {
            Some(Message::Ping(payload)) => {
                // The payload is owned here, so it is moved into the Pong
                // rather than cloned.
                self.send(Message::Pong(payload)).await?;
                Ok(None)
            }
            other => Ok(other),
        }
    }

    /// Writes `message` and flushes immediately.
    ///
    /// After a Close message has been sent successfully, the connection is
    /// marked closed.
    ///
    /// # Errors
    ///
    /// Fails when the connection is already marked closed, or when writing or
    /// flushing fails.
    #[inline]
    pub async fn send(&mut self, message: impl Into<Message>) -> Result<(), Error> {
        let message = message.into();
        let (conn, config, n_buffered) = to_checked_parts(self).await?;
        let closing = matches!(message, Message::Close(_));
        send(message, conn, config, n_buffered).await?;
        if closing {
            self.close().await;
        }
        Ok(())
    }

    /// Writes `message` without necessarily flushing; see the free function
    /// `write` for the buffering thresholds.
    ///
    /// Returns the value reported by `Message::write` for this message.
    ///
    /// If `message` is a Close message, the connection is flushed before it is
    /// marked closed. Once the closed flag is set, [`Connection::flush`] is
    /// rejected by the closed check, so a Close frame left unflushed could no
    /// longer be pushed out through this API.
    ///
    /// # Errors
    ///
    /// Fails when the connection is already marked closed, or when writing or
    /// flushing fails.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as the free function `write`: the
    /// buffered total exceeding `config.max_write_buffer_size`.
    pub async fn write(&mut self, message: impl Into<Message>) -> Result<usize, Error> {
        let message = message.into();
        let (conn, config, n_buffered) = to_checked_parts(self).await?;
        let closing = matches!(message, Message::Close(_));
        let n = write(message, conn, config, n_buffered).await?;
        if closing {
            // Push the Close frame out while flushing is still permitted.
            flush(conn, n_buffered).await?;
            self.close().await;
        }
        Ok(n)
    }

    /// Flushes the underlying connection and resets the buffered counter.
    ///
    /// # Errors
    ///
    /// Fails when the connection is already marked closed, or when flushing
    /// fails.
    pub async fn flush(&mut self) -> Result<(), Error> {
        let (conn, _, n_buffered) = to_checked_parts(self).await?;
        flush(conn, n_buffered).await
    }
}
/// Splitting a `Connection` into a read half and a write half.
pub mod split {
    use super::*;

    /// An I/O object that can be divided into a read half and a write half,
    /// both tied to the `'split` borrow of the object.
    pub trait Splitable<'split>: AsyncRead + AsyncWrite + Unpin + Sized {
        /// Type of the reading half.
        type ReadHalf: AsyncRead + Unpin;
        /// Type of the writing half.
        type WriteHalf: AsyncWrite + Unpin;
        /// Divides `self` into its two halves.
        fn split(&'split mut self) -> (Self::ReadHalf, Self::WriteHalf);
    }

    impl<C: UnderlyingConnection> Connection<C>
    where
        C: for<'s> Splitable<'s>,
    {
        /// Consumes the connection and returns its read and write halves.
        ///
        /// Both halves keep using the connection's original "closed" flag, the
        /// one also held by the `Closer` created in `Connection::new`. A Close
        /// message sent through the `WriteHalf` is therefore visible to the
        /// `Closer`, which then does not send a second Close frame.
        ///
        /// # Panics
        ///
        /// Panics when the closed flag cannot be read without waiting, or when
        /// the connection is already marked closed.
        pub fn split(self) -> (
            ReadHalf<<C as Splitable<'static>>::ReadHalf>,
            WriteHalf<<C as Splitable<'static>>::WriteHalf>,
        ) {
            if *self.__closed__.try_read().expect(ALREADY_CLOSED_MESSAGE) {
                panic!("{ALREADY_CLOSED_MESSAGE}")
            }

            // NOTE(review): this produces a `&'static mut C` out of the
            // `Arc<UnsafeCell<C>>`, and this function drops its own `Arc`
            // handle on return. The halves therefore presumably rely on the
            // `Closer`'s handle keeping the object alive, and on the `Closer`
            // not being used while the halves are alive — confirm.
            let conn = unsafe {&mut *self.conn.get()};
            let (r, w) = conn.split();

            // Share the original flag instead of creating a fresh one, so that
            // a close performed through the halves is seen by the `Closer`.
            let closed_for_reader = self.__closed__.clone();
            (
                ReadHalf {
                    __closed__: closed_for_reader,
                    conn: r,
                    config: self.config.clone()
                },
                WriteHalf {
                    __closed__: self.__closed__,
                    conn: w,
                    config: self.config,
                    n_buffered: self.n_buffered
                },
            )
        }
    }

    // `Splitable` for the `io_futures` feature: delegates to the `split`
    // provided through the `AsyncRead` name imported from `crate::io`.
    #[cfg(feature="io_futures")]
    const _: () = {
        impl<'split, T: AsyncRead + AsyncWrite + Unpin + 'split> Splitable<'split> for T {
            type ReadHalf = futures_util::io::ReadHalf<&'split mut T>;
            type WriteHalf = futures_util::io::WriteHalf<&'split mut T>;
            fn split(&'split mut self) -> (Self::ReadHalf, Self::WriteHalf) {
                AsyncRead::split(self)
            }
        }
    };

    // `Splitable` for the `io_tokio` feature: both halves hold one side of a
    // `BiLock` around the `&mut T` and lock it for every poll.
    #[cfg(feature="io_tokio")]
    const _: () = {
        impl<'split, T: AsyncRead + AsyncWrite + Unpin + 'split> Splitable<'split> for T {
            type ReadHalf = TokioIoReadHalf<'split, T>;
            type WriteHalf = TokioIoWriteHalf<'split, T>;
            fn split(&'split mut self) -> (Self::ReadHalf, Self::WriteHalf) {
                let (r, w) = futures_util::lock::BiLock::new(self);
                (TokioIoReadHalf(r), TokioIoWriteHalf(w))
            }
        }

        /// Reading side of a `BiLock`-guarded `&mut T`.
        pub struct TokioIoReadHalf<'split, T>(futures_util::lock::BiLock<&'split mut T>);
        /// Writing side of a `BiLock`-guarded `&mut T`.
        pub struct TokioIoWriteHalf<'split, T>(futures_util::lock::BiLock<&'split mut T>);

        /// Polls `lock`; once it is acquired, runs `f` on the pinned inner
        /// value. Returns `Poll::Pending` while the lock is not available.
        fn lock_and_then<T, U, E>(
            lock: &futures_util::lock::BiLock<T>,
            cx: &mut std::task::Context<'_>,
            f: impl FnOnce(std::pin::Pin<&mut T>, &mut std::task::Context<'_>) -> std::task::Poll<Result<U, E>>
        ) -> std::task::Poll<Result<U, E>> {
            let mut l = futures_util::ready!(lock.poll_lock(cx));
            f(l.as_pin_mut(), cx)
        }

        impl<'split, T: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for TokioIoReadHalf<'split, T> {
            #[inline]
            fn poll_read(
                self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
                buf: &mut tokio::io::ReadBuf<'_>
            ) -> std::task::Poll<std::io::Result<()>> {
                lock_and_then(&self.0, cx, |l, cx| l.poll_read(cx, buf))
            }
        }

        impl<'split, T: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for TokioIoWriteHalf<'split, T> {
            #[inline]
            fn poll_write(
                self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
                buf: &[u8]
            ) -> std::task::Poll<std::io::Result<usize>> {
                lock_and_then(&self.0, cx, |l, cx| l.poll_write(cx, buf))
            }
            #[inline]
            fn poll_flush(
                self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>
            ) -> std::task::Poll<std::io::Result<()>> {
                lock_and_then(&self.0, cx, |l, cx| l.poll_flush(cx))
            }
            fn poll_shutdown(
                self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>
            ) -> std::task::Poll<std::io::Result<()>> {
                lock_and_then(&self.0, cx, |l, cx| l.poll_shutdown(cx))
            }
        }
    };

    /// Reading half of a split connection.
    pub struct ReadHalf<C: AsyncRead + Unpin> {
        /// "Closed" flag shared with the matching `WriteHalf`.
        __closed__: Arc<RwLock<bool>>,
        /// Reading half of the underlying I/O object.
        conn: C,
        /// Configuration handed to `Message::read_from`.
        config: Config,
    }

    impl<C: AsyncRead + Unpin> ReadHalf<C> {
        /// Reads the next message and returns whatever `Message::read_from`
        /// yields. Unlike `Connection::recv`, a Ping is not answered here.
        ///
        /// # Errors
        ///
        /// Fails when the connection is already marked closed, or when
        /// reading fails.
        #[inline]
        pub async fn recv(&mut self) -> Result<Option<Message>, Error> {
            let Self { __closed__, conn, config } = self;
            let conn = underlying!(unless __closed__, conn).await?;
            Message::read_from(conn, config).await
        }
    }

    /// Writing half of a split connection.
    pub struct WriteHalf<C: AsyncWrite + Unpin> {
        /// "Closed" flag shared with the matching `ReadHalf`.
        __closed__: Arc<RwLock<bool>>,
        /// Writing half of the underlying I/O object.
        conn: C,
        /// Configuration handed to `Message::write`.
        config: Config,
        /// Running total of what `Message::write` reported since the last flush.
        n_buffered: usize,
    }

    impl<C: AsyncWrite + Unpin> WriteHalf<C> {
        /// Writes `message` and flushes immediately. After a Close message has
        /// been sent successfully, the shared flag is set to closed.
        ///
        /// # Errors
        ///
        /// Fails when the connection is already marked closed, or when writing
        /// or flushing fails.
        #[inline]
        pub async fn send(&mut self, message: impl Into<Message>) -> Result<(), Error> {
            let message = message.into();
            let Self { __closed__, conn, config, n_buffered } = self;
            let conn = underlying!(unless __closed__, conn).await?;
            let closing = matches!(message, Message::Close(_));
            send(message, conn, config, n_buffered).await?;
            if closing {set_closed(__closed__).await}
            Ok(())
        }

        /// Writes `message` without necessarily flushing; see the free
        /// function `write` for the buffering thresholds.
        ///
        /// If `message` is a Close message, the half is flushed before the
        /// shared flag is set. Once the flag is set, `flush` is rejected by
        /// the closed check, so a Close frame left unflushed could no longer
        /// be pushed out through this API.
        ///
        /// # Errors
        ///
        /// Fails when the connection is already marked closed, or when writing
        /// or flushing fails.
        ///
        /// # Panics
        ///
        /// Panics under the same condition as the free function `write`.
        pub async fn write(&mut self, message: impl Into<Message>) -> Result<usize, Error> {
            let message = message.into();
            let Self { __closed__, conn, config, n_buffered } = self;
            let conn = underlying!(unless __closed__, conn).await?;
            let closing = matches!(message, Message::Close(_));
            let n = write(message, conn, config, n_buffered).await?;
            if closing {
                // Push the Close frame out while flushing is still permitted.
                flush(conn, n_buffered).await?;
                set_closed(__closed__).await
            }
            Ok(n)
        }

        /// Flushes the writing half and resets the buffered counter.
        ///
        /// # Errors
        ///
        /// Fails when the connection is already marked closed, or when
        /// flushing fails.
        pub async fn flush(&mut self) -> Result<(), Error> {
            let Self { __closed__, conn, n_buffered, config:_ } = self;
            let conn = underlying!(unless __closed__, conn).await?;
            flush(conn, n_buffered).await
        }
    }
}