Skip to main content

actix_ws/
codec.rs

//! Typed WebSocket messages via pluggable codecs.
//!
//! This module provides a small framework for exchanging typed values over WebSocket text and
//! binary messages. Concrete codecs can be implemented by user code or enabled via crate
//! features.
//!
//! # Feature Flags
//!
//! - `serde-json`: enables the `JsonCodec` type (requires `serde` + `serde_json`).
9
10use std::{
11    fmt,
12    future::poll_fn,
13    marker::PhantomData,
14    pin::Pin,
15    task::{Context, Poll},
16};
17
18use actix_http::ws::{CloseReason, ProtocolError};
19use actix_web::web::Bytes;
20use bytestring::ByteString;
21use futures_core::Stream;
22
23use crate::{AggregatedMessage, AggregatedMessageStream, Closed, MessageStream, Session};
24
25#[cfg(feature = "serde-json")]
26mod json;
27
28#[cfg(feature = "serde-json")]
29#[cfg_attr(docsrs, doc(cfg(feature = "serde-json")))]
30pub use self::json::JsonCodec;
31
32/// A codec that can translate between typed values and WebSocket messages.
33pub trait MessageCodec<T> {
34    /// Codec-specific error type.
35    type Error;
36
37    /// Encodes a value into a WebSocket text or binary message.
38    fn encode(&self, item: &T) -> Result<EncodedMessage, Self::Error>;
39
40    /// Decodes an incoming WebSocket message into a typed value or a control message.
41    fn decode(&self, msg: AggregatedMessage) -> Result<CodecMessage<T>, Self::Error>;
42}
43
44/// WebSocket messages that can be sent by a codec.
45#[derive(Debug, Clone, PartialEq, Eq)]
46pub enum EncodedMessage {
47    /// Text message.
48    Text(ByteString),
49
50    /// Binary message.
51    Binary(Bytes),
52}
53
54/// Typed message yielded by a [`CodecMessageStream`].
55#[derive(Debug)]
56pub enum CodecMessage<T> {
57    /// Successfully decoded application message.
58    Item(T),
59
60    /// Ping message.
61    Ping(Bytes),
62
63    /// Pong message.
64    Pong(Bytes),
65
66    /// Close message with optional reason.
67    Close(Option<CloseReason>),
68}
69
70/// Errors returned by [`CodecSession::send()`].
71#[derive(Debug)]
72pub enum CodecSendError<E> {
73    /// The session is closed.
74    Closed(Closed),
75
76    /// The codec failed to encode the outgoing value.
77    Codec(E),
78}
79
80impl<E> fmt::Display for CodecSendError<E>
81where
82    E: fmt::Display,
83{
84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85        match self {
86            CodecSendError::Closed(_) => f.write_str("session is closed"),
87            CodecSendError::Codec(err) => write!(f, "codec error: {err}"),
88        }
89    }
90}
91
92impl<E> std::error::Error for CodecSendError<E>
93where
94    E: std::error::Error + 'static,
95{
96    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
97        match self {
98            CodecSendError::Closed(err) => Some(err),
99            CodecSendError::Codec(err) => Some(err),
100        }
101    }
102}
103
104/// Errors returned by [`CodecMessageStream`].
105#[derive(Debug)]
106pub enum CodecStreamError<E> {
107    /// The WebSocket stream failed to decode frames.
108    Protocol(ProtocolError),
109
110    /// The codec failed to decode an application message.
111    Codec(E),
112}
113
114impl<E> fmt::Display for CodecStreamError<E>
115where
116    E: fmt::Display,
117{
118    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
119        match self {
120            CodecStreamError::Protocol(err) => write!(f, "protocol error: {err}"),
121            CodecStreamError::Codec(err) => write!(f, "codec error: {err}"),
122        }
123    }
124}
125
126impl<E> std::error::Error for CodecStreamError<E>
127where
128    E: std::error::Error + 'static,
129{
130    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
131        match self {
132            CodecStreamError::Protocol(err) => Some(err),
133            CodecStreamError::Codec(err) => Some(err),
134        }
135    }
136}
137
138/// A [`Session`] wrapper that can send typed messages using a codec.
139pub struct CodecSession<T, C> {
140    session: Session,
141    codec: C,
142    _phantom: PhantomData<fn() -> T>,
143}
144
145impl<T, C> CodecSession<T, C>
146where
147    C: MessageCodec<T>,
148{
149    /// Constructs a new codec session wrapper.
150    pub fn new(session: Session, codec: C) -> Self {
151        Self {
152            session,
153            codec,
154            _phantom: PhantomData,
155        }
156    }
157
158    /// Returns a reference to the underlying session.
159    pub fn session(&self) -> &Session {
160        &self.session
161    }
162
163    /// Returns a mutable reference to the underlying session.
164    pub fn session_mut(&mut self) -> &mut Session {
165        &mut self.session
166    }
167
168    /// Returns a reference to the underlying codec.
169    pub fn codec(&self) -> &C {
170        &self.codec
171    }
172
173    /// Returns a mutable reference to the underlying codec.
174    pub fn codec_mut(&mut self) -> &mut C {
175        &mut self.codec
176    }
177
178    /// Consumes this wrapper and returns the underlying [`Session`].
179    pub fn into_inner(self) -> Session {
180        self.session
181    }
182
183    /// Encodes `item` and sends it as a WebSocket message.
184    ///
185    /// This method only sends text or binary frames. Use the underlying [`Session`] for control
186    /// frames (ping/pong/close).
187    pub async fn send(&mut self, item: &T) -> Result<(), CodecSendError<C::Error>> {
188        let msg = self.codec.encode(item).map_err(CodecSendError::Codec)?;
189
190        match msg {
191            EncodedMessage::Text(text) => self
192                .session
193                .text(text)
194                .await
195                .map_err(CodecSendError::Closed),
196
197            EncodedMessage::Binary(bin) => self
198                .session
199                .binary(bin)
200                .await
201                .map_err(CodecSendError::Closed),
202        }
203    }
204
205    /// Sends a close frame, consuming the codec session.
206    pub async fn close(self, reason: Option<CloseReason>) -> Result<(), Closed> {
207        self.session.close(reason).await
208    }
209}
210
211/// A [`Stream`] of typed messages decoded from an [`AggregatedMessageStream`].
212pub struct CodecMessageStream<T, C> {
213    stream: AggregatedMessageStream,
214    codec: C,
215    _phantom: PhantomData<fn() -> T>,
216}
217
218impl<T, C> CodecMessageStream<T, C>
219where
220    C: MessageCodec<T>,
221{
222    /// Constructs a new codec message stream wrapper.
223    pub fn new(stream: AggregatedMessageStream, codec: C) -> Self {
224        Self {
225            stream,
226            codec,
227            _phantom: PhantomData,
228        }
229    }
230
231    /// Returns a reference to the underlying codec.
232    pub fn codec(&self) -> &C {
233        &self.codec
234    }
235
236    /// Returns a mutable reference to the underlying codec.
237    pub fn codec_mut(&mut self) -> &mut C {
238        &mut self.codec
239    }
240
241    /// Consumes this wrapper and returns the underlying stream.
242    pub fn into_inner(self) -> AggregatedMessageStream {
243        self.stream
244    }
245
246    /// Waits for the next item from the codec message stream.
247    ///
248    /// This is a convenience for calling the [`Stream`](Stream::poll_next()) implementation.
249    #[must_use]
250    pub async fn recv(&mut self) -> Option<<Self as Stream>::Item> {
251        // `CodecMessageStream` is not necessarily `Unpin` (depends on codec type) but it is safe
252        // to pin it for the duration of this future since it is borrowed for the await.
253        poll_fn(|cx| unsafe { Pin::new_unchecked(&mut *self) }.poll_next(cx)).await
254    }
255}
256
257impl<T, C> Stream for CodecMessageStream<T, C>
258where
259    C: MessageCodec<T>,
260{
261    type Item = Result<CodecMessage<T>, CodecStreamError<C::Error>>;
262
263    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
264        // SAFETY: We will not move out of any fields. `AggregatedMessageStream` is polled by
265        // pinning its field, and the codec is only accessed by reference.
266        let this = unsafe { self.get_unchecked_mut() };
267
268        let msg = match Pin::new(&mut this.stream).poll_next(cx) {
269            Poll::Ready(Some(Ok(msg))) => msg,
270            Poll::Ready(Some(Err(err))) => {
271                return Poll::Ready(Some(Err(CodecStreamError::Protocol(err))));
272            }
273            Poll::Ready(None) => return Poll::Ready(None),
274            Poll::Pending => return Poll::Pending,
275        };
276
277        match this.codec.decode(msg) {
278            Ok(item) => Poll::Ready(Some(Ok(item))),
279            Err(err) => Poll::Ready(Some(Err(CodecStreamError::Codec(err)))),
280        }
281    }
282}
283
284impl MessageStream {
285    /// Wraps this message stream with `codec`, aggregating continuation frames before decoding.
286    #[must_use]
287    pub fn with_codec<T, C>(self, codec: C) -> CodecMessageStream<T, C>
288    where
289        C: MessageCodec<T>,
290    {
291        self.aggregate_continuations().with_codec(codec)
292    }
293}
294
295impl AggregatedMessageStream {
296    /// Wraps this aggregated message stream with `codec`.
297    #[must_use]
298    pub fn with_codec<T, C>(self, codec: C) -> CodecMessageStream<T, C>
299    where
300        C: MessageCodec<T>,
301    {
302        CodecMessageStream::new(self, codec)
303    }
304}
305
306impl Session {
307    /// Wraps this session with `codec` so it can send typed messages.
308    #[must_use]
309    pub fn with_codec<T, C>(self, codec: C) -> CodecSession<T, C>
310    where
311        C: MessageCodec<T>,
312    {
313        CodecSession::new(self, codec)
314    }
315}
316
#[cfg(all(test, feature = "serde-json"))]
mod tests {
    use actix_http::ws::Message;
    use actix_web::web::Bytes;
    use serde::{Deserialize, Serialize};

    use super::{CodecMessage, CodecStreamError, EncodedMessage, JsonCodec};
    use crate::{stream::tests::payload_pair, MessageStream, Session};

    /// Minimal payload type used to exercise the JSON codec.
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct TestMsg {
        a: u32,
    }

    /// With the default codec configuration, a sent value arrives as a text frame.
    #[tokio::test]
    async fn json_session_encodes_text_frames_by_default() {
        let (tx, mut frames) = tokio::sync::mpsc::channel(1);
        let mut typed = Session::new(tx).with_codec::<TestMsg, _>(JsonCodec::default());

        typed.send(&TestMsg { a: 123 }).await.unwrap();

        match frames.recv().await.unwrap() {
            Message::Text(text) => {
                let json: &str = text.as_ref();
                assert_eq!(json, r#"{"a":123}"#);
            }
            other => panic!("expected text frame, got: {other:?}"),
        }
    }

    /// A codec switched to binary mode sends the same JSON as a binary frame.
    #[tokio::test]
    async fn json_session_can_encode_binary_frames() {
        let (tx, mut frames) = tokio::sync::mpsc::channel(1);
        let mut typed = Session::new(tx).with_codec::<TestMsg, _>(JsonCodec::default().binary());

        typed.send(&TestMsg { a: 123 }).await.unwrap();

        match frames.recv().await.unwrap() {
            Message::Binary(bytes) => assert_eq!(bytes, Bytes::from_static(br#"{"a":123}"#)),
            other => panic!("expected binary frame, got: {other:?}"),
        }
    }

    /// JSON payloads are decoded whether they arrive in a text or a binary frame.
    #[tokio::test]
    async fn json_stream_decodes_text_and_binary_frames() {
        let (mut tx, rx) = payload_pair(8);
        let mut typed = MessageStream::new(rx).with_codec::<TestMsg, _>(JsonCodec::default());

        tx.send(Message::Text(r#"{"a":1}"#.into())).await;
        match typed.recv().await.unwrap().unwrap() {
            CodecMessage::Item(TestMsg { a }) => assert_eq!(a, 1),
            other => panic!("expected decoded item, got: {other:?}"),
        }

        tx.send(Message::Binary(Bytes::from_static(br#"{"a":2}"#)))
            .await;
        match typed.recv().await.unwrap().unwrap() {
            CodecMessage::Item(TestMsg { a }) => assert_eq!(a, 2),
            other => panic!("expected decoded item, got: {other:?}"),
        }
    }

    /// A ping frame is surfaced as `CodecMessage::Ping` with its payload intact.
    #[tokio::test]
    async fn json_stream_passes_through_control_frames() {
        let (mut tx, rx) = payload_pair(8);
        let mut typed = MessageStream::new(rx).with_codec::<TestMsg, _>(JsonCodec::default());

        tx.send(Message::Ping(Bytes::from_static(b"hi"))).await;
        match typed.recv().await.unwrap().unwrap() {
            CodecMessage::Ping(bytes) => assert_eq!(bytes, Bytes::from_static(b"hi")),
            other => panic!("expected ping, got: {other:?}"),
        }
    }

    /// An undecodable payload yields a codec error, and later messages still decode.
    #[tokio::test]
    async fn json_stream_yields_codec_error_on_invalid_payload_and_continues() {
        let (mut tx, rx) = payload_pair(8);
        let mut typed = MessageStream::new(rx).with_codec::<TestMsg, _>(JsonCodec::default());

        tx.send(Message::Text("not json".into())).await;
        match typed.recv().await.unwrap() {
            Err(CodecStreamError::Codec(_)) => {}
            other => panic!("expected codec error, got: {other:?}"),
        }

        tx.send(Message::Text(r#"{"a":9}"#.into())).await;
        match typed.recv().await.unwrap().unwrap() {
            CodecMessage::Item(TestMsg { a }) => assert_eq!(a, 9),
            other => panic!("expected decoded item, got: {other:?}"),
        }
    }

    /// Both `EncodedMessage` variants can be constructed from simple payloads.
    #[test]
    fn encoded_message_is_lightweight() {
        let _ = EncodedMessage::Text("hello".into());
        let _ = EncodedMessage::Binary(Bytes::from_static(b"hello"));
    }
}