use std::{
fmt,
future::poll_fn,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
task::{Context, Poll},
};
use actix_http::ws::{CloseReason, Item, Message};
use actix_web::web::Bytes;
use bytestring::ByteString;
use futures_sink::Sink;
use tokio::sync::mpsc::Sender;
use tokio_util::sync::PollSender;
/// Largest payload, in bytes, this module allows in a control frame (ping, pong, close).
const MAX_CONTROL_PAYLOAD_BYTES: usize = 125;
/// Bytes taken by the status code that precedes the reason text in a close payload.
const CLOSE_CODE_BYTES: usize = 2;
/// Bytes left for the close reason text once the status code has been accounted for.
const MAX_CLOSE_REASON_BYTES: usize = MAX_CONTROL_PAYLOAD_BYTES - CLOSE_CODE_BYTES;
/// Cloneable handle for pushing outgoing WebSocket messages into a channel.
///
/// Every clone shares one `closed` flag, so closing through any handle makes
/// all of the others stop sending the next time they are used.
#[derive(Clone)]
pub struct Session {
    /// Shared "session closed" flag; cloning the handle clones the `Arc`, not the flag.
    closed: Arc<AtomicBool>,
    /// This handle's sender; becomes `None` once the handle observes the closed flag.
    inner: Option<PollSender<Message>>,
}
/// Error returned by `Session` operations when the session can no longer send.
///
/// It is produced both when the session was explicitly closed (through any
/// clone) and when the underlying channel refuses a message. It is a
/// zero-sized marker, so it is cheap to copy, compare, and hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Closed;

impl fmt::Display for Closed {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Session is closed")
    }
}

impl std::error::Error for Closed {}
impl Session {
    /// Builds an open session that forwards every outgoing message into `inner`.
    pub(super) fn new(inner: Sender<Message>) -> Self {
        let sender = PollSender::new(inner);
        Self {
            inner: Some(sender),
            closed: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Releases this handle's sender if any clone has flagged the session as closed.
    fn pre_check(&mut self) {
        if !self.closed.load(Ordering::Relaxed) {
            return;
        }
        self.inner = None;
    }

    /// Reserves channel capacity, enqueues `msg`, then flushes.
    ///
    /// This does not consult the shared `closed` flag; callers decide whether to.
    ///
    /// # Errors
    ///
    /// Returns [`Closed`] if this handle has no sender left, or if any of the
    /// reserve / send / flush steps on the underlying sender fails.
    async fn send_message_inner(&mut self, msg: Message) -> Result<(), Closed> {
        let sender = match self.inner.as_mut() {
            Some(sender) => sender,
            None => return Err(Closed),
        };

        // Wait for room in the channel before handing over the message.
        let ready = poll_fn(|cx| Pin::new(&mut *sender).poll_ready(cx)).await;
        ready.map_err(|_| Closed)?;

        Pin::new(&mut *sender)
            .start_send(msg)
            .map_err(|_| Closed)?;

        let flushed = poll_fn(|cx| Pin::new(&mut *sender).poll_flush(cx)).await;
        flushed.map_err(|_| Closed)
    }

    /// Honors a close performed by another clone, then sends `msg`.
    async fn send_message(&mut self, msg: Message) -> Result<(), Closed> {
        self.pre_check();
        self.send_message_inner(msg).await
    }

    /// Copies at most `MAX_CONTROL_PAYLOAD_BYTES` leading bytes of `payload`
    /// into an owned buffer for a control frame.
    fn truncated_control_payload(payload: &[u8]) -> Bytes {
        let len = payload.len().min(MAX_CONTROL_PAYLOAD_BYTES);
        Bytes::copy_from_slice(&payload[..len])
    }

    /// Sends a text message.
    ///
    /// # Errors
    ///
    /// Returns [`Closed`] if the session was closed or the channel rejects the message.
    pub async fn text(&mut self, msg: impl Into<ByteString>) -> Result<(), Closed> {
        let frame = Message::Text(msg.into());
        self.send_message(frame).await
    }

    /// Sends a binary message.
    ///
    /// # Errors
    ///
    /// Returns [`Closed`] if the session was closed or the channel rejects the message.
    pub async fn binary(&mut self, msg: impl Into<Bytes>) -> Result<(), Closed> {
        let frame = Message::Binary(msg.into());
        self.send_message(frame).await
    }

    /// Sends a ping whose payload is `msg`, silently cut to `MAX_CONTROL_PAYLOAD_BYTES`.
    ///
    /// # Errors
    ///
    /// Returns [`Closed`] if the session was closed or the channel rejects the message.
    pub async fn ping(&mut self, msg: &[u8]) -> Result<(), Closed> {
        let frame = Message::Ping(Self::truncated_control_payload(msg));
        self.send_message(frame).await
    }

    /// Sends a pong whose payload is `msg`, silently cut to `MAX_CONTROL_PAYLOAD_BYTES`.
    ///
    /// # Errors
    ///
    /// Returns [`Closed`] if the session was closed or the channel rejects the message.
    pub async fn pong(&mut self, msg: &[u8]) -> Result<(), Closed> {
        let frame = Message::Pong(Self::truncated_control_payload(msg));
        self.send_message(frame).await
    }

    /// Sends one fragment of a continuation (fragmented) message.
    ///
    /// # Errors
    ///
    /// Returns [`Closed`] if the session was closed or the channel rejects the message.
    pub async fn continuation(&mut self, msg: Item) -> Result<(), Closed> {
        let frame = Message::Continuation(msg);
        self.send_message(frame).await
    }

    /// Closes the session for every clone and queues a close frame carrying `reason`.
    ///
    /// A reason description longer than `MAX_CLOSE_REASON_BYTES` is shortened
    /// to the nearest UTF-8 character boundary at or below that limit.
    ///
    /// # Errors
    ///
    /// Returns [`Closed`] if the session had already been closed, or if the
    /// close frame cannot be delivered to the channel.
    pub async fn close(mut self, reason: Option<CloseReason>) -> Result<(), Closed> {
        self.pre_check();

        let reason = reason.map(|mut reason| {
            if let Some(desc) = reason.description.as_mut() {
                if desc.len() > MAX_CLOSE_REASON_BYTES {
                    // Step back to a char boundary so a multi-byte character is never split.
                    let cut = (0..=MAX_CLOSE_REASON_BYTES)
                        .rev()
                        .find(|&idx| desc.is_char_boundary(idx))
                        .unwrap_or(0);
                    desc.truncate(cut);
                }
            }
            reason
        });

        if self.inner.is_none() {
            return Err(Closed);
        }

        // Flag every clone as closed before the close frame is queued.
        self.closed.store(true, Ordering::Relaxed);
        self.send_message_inner(Message::Close(reason)).await
    }
}
impl Sink<Message> for Session {
    type Error = Closed;

    /// Waits for channel capacity; fails with [`Closed`] once the session is closed.
    fn poll_ready(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        self.pre_check();
        match self.inner.as_mut() {
            Some(sender) => Pin::new(sender).poll_ready(cx).map_err(|_| Closed),
            None => Poll::Ready(Err(Closed)),
        }
    }

    /// Hands `item` to the underlying sender; fails with [`Closed`] once the session is closed.
    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        self.pre_check();
        let sender = self.inner.as_mut().ok_or(Closed)?;
        Pin::new(sender).start_send(item).map_err(|_| Closed)
    }

    /// Flushes the underlying sender; fails with [`Closed`] once the session is closed.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        self.pre_check();
        match self.inner.as_mut() {
            Some(sender) => Pin::new(sender).poll_flush(cx).map_err(|_| Closed),
            None => Poll::Ready(Err(Closed)),
        }
    }

    /// Marks every clone as closed and closes this handle's sender.
    ///
    /// Succeeds immediately if this handle has already released its sender.
    fn poll_close(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        // Set the shared flag first so other clones stop sending even if this
        // handle has nothing left to close.
        self.closed.store(true, Ordering::Relaxed);
        match self.inner.as_mut() {
            Some(sender) => Pin::new(sender).poll_close(cx).map_err(|_| Closed),
            None => Poll::Ready(Ok(())),
        }
    }
}
#[cfg(test)]
mod tests {
    use actix_http::ws::{CloseCode, CloseReason, Message};
    use futures_util::SinkExt;

    use super::Session;

    #[tokio::test]
    async fn session_implements_sink() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);
        let mut session = Session::new(tx);
        session
            .send(Message::Text("hello from sink".into()))
            .await
            .unwrap();
        match rx.recv().await {
            Some(Message::Text(msg)) => {
                let text: &str = msg.as_ref();
                assert_eq!(text, "hello from sink");
            }
            other => panic!("expected text frame, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn sink_close_closes_all_clones() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);
        let mut session = Session::new(tx);
        let mut clone = session.clone();
        SinkExt::close(&mut session).await.unwrap();
        assert!(clone.text("should fail").await.is_err());
        assert!(rx.recv().await.is_none());
    }

    #[tokio::test]
    async fn close_sends_close_frame_and_closes_all_clones() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);
        let session = Session::new(tx);
        let mut clone = session.clone();
        session.close(None).await.unwrap();
        assert!(clone.text("should fail").await.is_err());
        match rx.recv().await {
            Some(Message::Close(None)) => {}
            other => panic!("expected close frame, got: {other:?}"),
        }
    }

    /// Oversized ping payloads are cut to 125 bytes; short pong payloads pass through.
    #[tokio::test]
    async fn control_frames_are_truncated_to_125_bytes() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);
        let mut session = Session::new(tx);

        session.ping(&[7u8; 200]).await.unwrap();
        session.pong(b"short").await.unwrap();

        match rx.recv().await {
            Some(Message::Ping(payload)) => {
                assert_eq!(payload.len(), 125);
                assert!(payload.iter().all(|&b| b == 7));
            }
            other => panic!("expected ping frame, got: {other:?}"),
        }
        match rx.recv().await {
            Some(Message::Pong(payload)) => assert_eq!(&payload[..], &b"short"[..]),
            other => panic!("expected pong frame, got: {other:?}"),
        }
    }

    /// A long close reason is cut to at most 123 bytes without splitting a UTF-8 character.
    #[tokio::test]
    async fn close_truncates_reason_on_char_boundary() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);
        let session = Session::new(tx);
        // 'é' is two bytes, so 100 of them is 200 bytes. The nearest char
        // boundary at or below 123 is 122.
        let reason = CloseReason {
            code: CloseCode::Normal,
            description: Some("é".repeat(100)),
        };
        session.close(Some(reason)).await.unwrap();
        match rx.recv().await {
            Some(Message::Close(Some(reason))) => {
                let desc = reason
                    .description
                    .expect("description should survive truncation");
                assert_eq!(desc.len(), 122);
                assert!(desc.chars().all(|c| c == 'é'));
            }
            other => panic!("expected close frame with reason, got: {other:?}"),
        }
    }

    /// Once the receiving side is gone, sends fail with `Closed` instead of hanging.
    #[tokio::test]
    async fn send_fails_after_receiver_dropped() {
        let (tx, rx) = tokio::sync::mpsc::channel(8);
        let mut session = Session::new(tx);
        drop(rx);
        assert!(session.text("nobody listening").await.is_err());
    }
}