use std::{
fmt,
future::poll_fn,
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
use actix_http::ws::{CloseReason, ProtocolError};
use actix_web::web::Bytes;
use bytestring::ByteString;
use futures_core::Stream;
use crate::{AggregatedMessage, AggregatedMessageStream, Closed, MessageStream, Session};
#[cfg(feature = "serde-json")]
mod json;
#[cfg(feature = "serde-json")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-json")))]
pub use self::json::JsonCodec;
/// Converts between application-level items of type `T` and WebSocket messages.
///
/// [`encode`](Self::encode) turns a borrowed item into an [`EncodedMessage`], and
/// [`decode`](Self::decode) turns an incoming [`AggregatedMessage`] into a [`CodecMessage`].
pub trait MessageCodec<T> {
/// Error returned when either encoding or decoding fails.
type Error;
/// Encodes `item` into a text or binary payload.
fn encode(&self, item: &T) -> Result<EncodedMessage, Self::Error>;
/// Decodes an aggregated message into an item or one of the other [`CodecMessage`] variants.
fn decode(&self, msg: AggregatedMessage) -> Result<CodecMessage<T>, Self::Error>;
}
/// Payload produced by [`MessageCodec::encode`].
///
/// [`CodecSession::send`] forwards `Text` through `Session::text` and `Binary` through
/// `Session::binary`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EncodedMessage {
/// Payload to be sent with the session's text method.
Text(ByteString),
/// Payload to be sent with the session's binary method.
Binary(Bytes),
}
/// Value produced by [`MessageCodec::decode`] and yielded by [`CodecMessageStream`].
#[derive(Debug)]
pub enum CodecMessage<T> {
/// A decoded application-level item.
Item(T),
/// A ping message together with its payload bytes.
Ping(Bytes),
/// A pong message together with its payload bytes.
Pong(Bytes),
/// A close message with its optional close reason.
Close(Option<CloseReason>),
}
/// Error returned by [`CodecSession::send`].
#[derive(Debug)]
pub enum CodecSendError<E> {
/// The underlying session call failed with [`Closed`].
Closed(Closed),
/// The codec failed to encode the item; `E` is the codec's error type.
Codec(E),
}
impl<E> fmt::Display for CodecSendError<E>
where
E: fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CodecSendError::Closed(_) => f.write_str("session is closed"),
CodecSendError::Codec(err) => write!(f, "codec error: {err}"),
}
}
}
impl<E> std::error::Error for CodecSendError<E>
where
E: std::error::Error + 'static,
{
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
CodecSendError::Closed(err) => Some(err),
CodecSendError::Codec(err) => Some(err),
}
}
}
/// Error item yielded by [`CodecMessageStream`].
#[derive(Debug)]
pub enum CodecStreamError<E> {
/// The underlying aggregated message stream yielded a [`ProtocolError`].
Protocol(ProtocolError),
/// The codec failed to decode a received message; `E` is the codec's error type.
Codec(E),
}
impl<E> fmt::Display for CodecStreamError<E>
where
E: fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CodecStreamError::Protocol(err) => write!(f, "protocol error: {err}"),
CodecStreamError::Codec(err) => write!(f, "codec error: {err}"),
}
}
}
impl<E> std::error::Error for CodecStreamError<E>
where
E: std::error::Error + 'static,
{
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
CodecStreamError::Protocol(err) => Some(err),
CodecStreamError::Codec(err) => Some(err),
}
}
}
/// A [`Session`] paired with a codec `C`, used to send items of type `T`.
pub struct CodecSession<T, C> {
// Wrapped session that performs the actual sends.
session: Session,
// Codec used by `send` to encode items.
codec: C,
// Records the item type in the struct's type without storing a `T`.
_phantom: PhantomData<fn() -> T>,
}
impl<T, C> CodecSession<T, C>
where
C: MessageCodec<T>,
{
pub fn new(session: Session, codec: C) -> Self {
Self {
session,
codec,
_phantom: PhantomData,
}
}
pub fn session(&self) -> &Session {
&self.session
}
pub fn session_mut(&mut self) -> &mut Session {
&mut self.session
}
pub fn codec(&self) -> &C {
&self.codec
}
pub fn codec_mut(&mut self) -> &mut C {
&mut self.codec
}
pub fn into_inner(self) -> Session {
self.session
}
pub async fn send(&mut self, item: &T) -> Result<(), CodecSendError<C::Error>> {
let msg = self.codec.encode(item).map_err(CodecSendError::Codec)?;
match msg {
EncodedMessage::Text(text) => self
.session
.text(text)
.await
.map_err(CodecSendError::Closed),
EncodedMessage::Binary(bin) => self
.session
.binary(bin)
.await
.map_err(CodecSendError::Closed),
}
}
pub async fn close(self, reason: Option<CloseReason>) -> Result<(), Closed> {
self.session.close(reason).await
}
}
/// An [`AggregatedMessageStream`] paired with a codec `C`, yielding [`CodecMessage<T>`] values.
pub struct CodecMessageStream<T, C> {
// Source of aggregated messages.
stream: AggregatedMessageStream,
// Codec used to decode each received message.
codec: C,
// Records the item type in the struct's type without storing a `T`.
_phantom: PhantomData<fn() -> T>,
}
impl<T, C> CodecMessageStream<T, C>
where
C: MessageCodec<T>,
{
pub fn new(stream: AggregatedMessageStream, codec: C) -> Self {
Self {
stream,
codec,
_phantom: PhantomData,
}
}
pub fn codec(&self) -> &C {
&self.codec
}
pub fn codec_mut(&mut self) -> &mut C {
&mut self.codec
}
pub fn into_inner(self) -> AggregatedMessageStream {
self.stream
}
#[must_use]
pub async fn recv(&mut self) -> Option<<Self as Stream>::Item> {
poll_fn(|cx| unsafe { Pin::new_unchecked(&mut *self) }.poll_next(cx)).await
}
}
/// Yields one decoded [`CodecMessage`] (or error) per message of the wrapped stream.
///
/// Protocol errors from the wrapped stream and decode errors from the codec are both yielded as
/// `Some(Err(..))` items; the stream ends only when the wrapped stream does.
impl<T, C: MessageCodec<T>> Stream for CodecMessageStream<T, C> {
    type Item = Result<CodecMessage<T>, CodecStreamError<C::Error>>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // SAFETY: no field is moved out of `this`. The inner stream is re-pinned with
        // `Pin::new` and the codec is only accessed through a shared reference.
        let this = unsafe { self.get_unchecked_mut() };

        let polled = Pin::new(&mut this.stream).poll_next(cx);
        let codec = &this.codec;

        // `Pending` and end-of-stream pass through untouched; only `Some(..)` items are mapped.
        polled.map(|next| {
            next.map(|frame| {
                frame
                    .map_err(CodecStreamError::Protocol)
                    .and_then(|msg| codec.decode(msg).map_err(CodecStreamError::Codec))
            })
        })
    }
}
impl MessageStream {
    /// Aggregates continuation frames and wraps the result in a [`CodecMessageStream`] that
    /// decodes each message with `codec`.
    #[must_use]
    pub fn with_codec<T, C: MessageCodec<T>>(self, codec: C) -> CodecMessageStream<T, C> {
        let aggregated = self.aggregate_continuations();
        aggregated.with_codec(codec)
    }
}
impl AggregatedMessageStream {
    /// Wraps this stream in a [`CodecMessageStream`] that decodes each message with `codec`.
    #[must_use]
    pub fn with_codec<T, C: MessageCodec<T>>(self, codec: C) -> CodecMessageStream<T, C> {
        CodecMessageStream::<T, C>::new(self, codec)
    }
}
impl Session {
    /// Wraps this session in a [`CodecSession`] that encodes outgoing items with `codec`.
    #[must_use]
    pub fn with_codec<T, C: MessageCodec<T>>(self, codec: C) -> CodecSession<T, C> {
        CodecSession::<T, C>::new(self, codec)
    }
}
#[cfg(all(test, feature = "serde-json"))]
mod tests {
    use actix_http::ws::Message;
    use actix_web::web::Bytes;
    use serde::{Deserialize, Serialize};

    use super::{CodecMessage, EncodedMessage};
    use crate::{codec::CodecStreamError, stream::tests::payload_pair, Session};

    /// Minimal payload type used to exercise the JSON codec.
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct TestMsg {
        a: u32,
    }

    /// With the default JSON codec, a sent item arrives as a text frame.
    #[tokio::test]
    async fn json_session_encodes_text_frames_by_default() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
        let codec = crate::codec::JsonCodec::default();
        let mut session = Session::new(tx).with_codec::<TestMsg, _>(codec);

        session.send(&TestMsg { a: 123 }).await.unwrap();

        let frame = rx.recv().await.unwrap();
        match frame {
            Message::Text(text) => {
                let rendered: &str = text.as_ref();
                assert_eq!(rendered, r#"{"a":123}"#);
            }
            other => panic!("expected text frame, got: {other:?}"),
        }
    }

    /// After `.binary()` the JSON codec sends the same JSON as a binary frame.
    #[tokio::test]
    async fn json_session_can_encode_binary_frames() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
        let codec = crate::codec::JsonCodec::default().binary();
        let mut session = Session::new(tx).with_codec::<TestMsg, _>(codec);

        session.send(&TestMsg { a: 123 }).await.unwrap();

        let frame = rx.recv().await.unwrap();
        match frame {
            Message::Binary(bytes) => assert_eq!(bytes, Bytes::from_static(br#"{"a":123}"#)),
            other => panic!("expected binary frame, got: {other:?}"),
        }
    }

    /// Both text and binary frames carrying JSON decode into items.
    #[tokio::test]
    async fn json_stream_decodes_text_and_binary_frames() {
        let (mut tx, rx) = payload_pair(8);
        let codec = crate::codec::JsonCodec::default();
        let mut stream = crate::MessageStream::new(rx).with_codec::<TestMsg, _>(codec);

        // Text frame first.
        tx.send(Message::Text(r#"{"a":1}"#.into())).await;
        let first = stream.recv().await.unwrap().unwrap();
        match first {
            CodecMessage::Item(TestMsg { a }) => assert_eq!(a, 1),
            other => panic!("expected decoded item, got: {other:?}"),
        }

        // Then a binary frame with the same kind of payload.
        tx.send(Message::Binary(Bytes::from_static(br#"{"a":2}"#)))
            .await;
        let second = stream.recv().await.unwrap().unwrap();
        match second {
            CodecMessage::Item(TestMsg { a }) => assert_eq!(a, 2),
            other => panic!("expected decoded item, got: {other:?}"),
        }
    }

    /// A ping frame is surfaced as `CodecMessage::Ping` with its payload intact.
    #[tokio::test]
    async fn json_stream_passes_through_control_frames() {
        let (mut tx, rx) = payload_pair(8);
        let codec = crate::codec::JsonCodec::default();
        let mut stream = crate::MessageStream::new(rx).with_codec::<TestMsg, _>(codec);

        tx.send(Message::Ping(Bytes::from_static(b"hi"))).await;

        let received = stream.recv().await.unwrap().unwrap();
        match received {
            CodecMessage::Ping(bytes) => assert_eq!(bytes, Bytes::from_static(b"hi")),
            other => panic!("expected ping, got: {other:?}"),
        }
    }

    /// An undecodable payload yields a codec error item, and later frames still decode.
    #[tokio::test]
    async fn json_stream_yields_codec_error_on_invalid_payload_and_continues() {
        let (mut tx, rx) = payload_pair(8);
        let codec = crate::codec::JsonCodec::default();
        let mut stream = crate::MessageStream::new(rx).with_codec::<TestMsg, _>(codec);

        tx.send(Message::Text("not json".into())).await;
        let failed = stream.recv().await.unwrap();
        match failed {
            Err(CodecStreamError::Codec(_)) => {}
            other => panic!("expected codec error, got: {other:?}"),
        }

        // The stream must remain usable after the decode failure.
        tx.send(Message::Text(r#"{"a":9}"#.into())).await;
        let recovered = stream.recv().await.unwrap().unwrap();
        match recovered {
            CodecMessage::Item(TestMsg { a }) => assert_eq!(a, 9),
            other => panic!("expected decoded item, got: {other:?}"),
        }
    }

    /// Both `EncodedMessage` variants can be built directly from simple inputs.
    #[test]
    fn encoded_message_is_lightweight() {
        let _ = EncodedMessage::Text("hello".into());
        let _ = EncodedMessage::Binary(Bytes::from_static(b"hello"));
    }
}