1use std::{
11 fmt,
12 future::poll_fn,
13 marker::PhantomData,
14 pin::Pin,
15 task::{Context, Poll},
16};
17
18use actix_http::ws::{CloseReason, ProtocolError};
19use actix_web::web::Bytes;
20use bytestring::ByteString;
21use futures_core::Stream;
22
23use crate::{AggregatedMessage, AggregatedMessageStream, Closed, MessageStream, Session};
24
25#[cfg(feature = "serde-json")]
26mod json;
27
28#[cfg(feature = "serde-json")]
29#[cfg_attr(docsrs, doc(cfg(feature = "serde-json")))]
30pub use self::json::JsonCodec;
31
32pub trait MessageCodec<T> {
34 type Error;
36
37 fn encode(&self, item: &T) -> Result<EncodedMessage, Self::Error>;
39
40 fn decode(&self, msg: AggregatedMessage) -> Result<CodecMessage<T>, Self::Error>;
42}
43
44#[derive(Debug, Clone, PartialEq, Eq)]
46pub enum EncodedMessage {
47 Text(ByteString),
49
50 Binary(Bytes),
52}
53
54#[derive(Debug)]
56pub enum CodecMessage<T> {
57 Item(T),
59
60 Ping(Bytes),
62
63 Pong(Bytes),
65
66 Close(Option<CloseReason>),
68}
69
70#[derive(Debug)]
72pub enum CodecSendError<E> {
73 Closed(Closed),
75
76 Codec(E),
78}
79
80impl<E> fmt::Display for CodecSendError<E>
81where
82 E: fmt::Display,
83{
84 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85 match self {
86 CodecSendError::Closed(_) => f.write_str("session is closed"),
87 CodecSendError::Codec(err) => write!(f, "codec error: {err}"),
88 }
89 }
90}
91
92impl<E> std::error::Error for CodecSendError<E>
93where
94 E: std::error::Error + 'static,
95{
96 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
97 match self {
98 CodecSendError::Closed(err) => Some(err),
99 CodecSendError::Codec(err) => Some(err),
100 }
101 }
102}
103
104#[derive(Debug)]
106pub enum CodecStreamError<E> {
107 Protocol(ProtocolError),
109
110 Codec(E),
112}
113
114impl<E> fmt::Display for CodecStreamError<E>
115where
116 E: fmt::Display,
117{
118 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
119 match self {
120 CodecStreamError::Protocol(err) => write!(f, "protocol error: {err}"),
121 CodecStreamError::Codec(err) => write!(f, "codec error: {err}"),
122 }
123 }
124}
125
126impl<E> std::error::Error for CodecStreamError<E>
127where
128 E: std::error::Error + 'static,
129{
130 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
131 match self {
132 CodecStreamError::Protocol(err) => Some(err),
133 CodecStreamError::Codec(err) => Some(err),
134 }
135 }
136}
137
138pub struct CodecSession<T, C> {
140 session: Session,
141 codec: C,
142 _phantom: PhantomData<fn() -> T>,
143}
144
145impl<T, C> CodecSession<T, C>
146where
147 C: MessageCodec<T>,
148{
149 pub fn new(session: Session, codec: C) -> Self {
151 Self {
152 session,
153 codec,
154 _phantom: PhantomData,
155 }
156 }
157
158 pub fn session(&self) -> &Session {
160 &self.session
161 }
162
163 pub fn session_mut(&mut self) -> &mut Session {
165 &mut self.session
166 }
167
168 pub fn codec(&self) -> &C {
170 &self.codec
171 }
172
173 pub fn codec_mut(&mut self) -> &mut C {
175 &mut self.codec
176 }
177
178 pub fn into_inner(self) -> Session {
180 self.session
181 }
182
183 pub async fn send(&mut self, item: &T) -> Result<(), CodecSendError<C::Error>> {
188 let msg = self.codec.encode(item).map_err(CodecSendError::Codec)?;
189
190 match msg {
191 EncodedMessage::Text(text) => self
192 .session
193 .text(text)
194 .await
195 .map_err(CodecSendError::Closed),
196
197 EncodedMessage::Binary(bin) => self
198 .session
199 .binary(bin)
200 .await
201 .map_err(CodecSendError::Closed),
202 }
203 }
204
205 pub async fn close(self, reason: Option<CloseReason>) -> Result<(), Closed> {
207 self.session.close(reason).await
208 }
209}
210
211pub struct CodecMessageStream<T, C> {
213 stream: AggregatedMessageStream,
214 codec: C,
215 _phantom: PhantomData<fn() -> T>,
216}
217
218impl<T, C> CodecMessageStream<T, C>
219where
220 C: MessageCodec<T>,
221{
222 pub fn new(stream: AggregatedMessageStream, codec: C) -> Self {
224 Self {
225 stream,
226 codec,
227 _phantom: PhantomData,
228 }
229 }
230
231 pub fn codec(&self) -> &C {
233 &self.codec
234 }
235
236 pub fn codec_mut(&mut self) -> &mut C {
238 &mut self.codec
239 }
240
241 pub fn into_inner(self) -> AggregatedMessageStream {
243 self.stream
244 }
245
246 #[must_use]
250 pub async fn recv(&mut self) -> Option<<Self as Stream>::Item> {
251 poll_fn(|cx| unsafe { Pin::new_unchecked(&mut *self) }.poll_next(cx)).await
254 }
255}
256
257impl<T, C> Stream for CodecMessageStream<T, C>
258where
259 C: MessageCodec<T>,
260{
261 type Item = Result<CodecMessage<T>, CodecStreamError<C::Error>>;
262
263 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
264 let this = unsafe { self.get_unchecked_mut() };
267
268 let msg = match Pin::new(&mut this.stream).poll_next(cx) {
269 Poll::Ready(Some(Ok(msg))) => msg,
270 Poll::Ready(Some(Err(err))) => {
271 return Poll::Ready(Some(Err(CodecStreamError::Protocol(err))));
272 }
273 Poll::Ready(None) => return Poll::Ready(None),
274 Poll::Pending => return Poll::Pending,
275 };
276
277 match this.codec.decode(msg) {
278 Ok(item) => Poll::Ready(Some(Ok(item))),
279 Err(err) => Poll::Ready(Some(Err(CodecStreamError::Codec(err)))),
280 }
281 }
282}
283
284impl MessageStream {
285 #[must_use]
287 pub fn with_codec<T, C>(self, codec: C) -> CodecMessageStream<T, C>
288 where
289 C: MessageCodec<T>,
290 {
291 self.aggregate_continuations().with_codec(codec)
292 }
293}
294
295impl AggregatedMessageStream {
296 #[must_use]
298 pub fn with_codec<T, C>(self, codec: C) -> CodecMessageStream<T, C>
299 where
300 C: MessageCodec<T>,
301 {
302 CodecMessageStream::new(self, codec)
303 }
304}
305
306impl Session {
307 #[must_use]
309 pub fn with_codec<T, C>(self, codec: C) -> CodecSession<T, C>
310 where
311 C: MessageCodec<T>,
312 {
313 CodecSession::new(self, codec)
314 }
315}
316
#[cfg(all(test, feature = "serde-json"))]
mod tests {
    use actix_http::ws::Message;
    use actix_web::web::Bytes;
    use serde::{Deserialize, Serialize};

    use super::{CodecMessage, EncodedMessage};
    use crate::{codec::CodecStreamError, stream::tests::payload_pair, Session};

    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct TestMsg {
        a: u32,
    }

    #[tokio::test]
    async fn json_session_encodes_text_frames_by_default() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
        let session = Session::new(tx);

        let mut session = session.with_codec::<TestMsg, _>(crate::codec::JsonCodec::default());
        session.send(&TestMsg { a: 123 }).await.unwrap();

        match rx.recv().await.unwrap() {
            Message::Text(text) => {
                let s: &str = text.as_ref();
                assert_eq!(s, r#"{"a":123}"#);
            }
            other => panic!("expected text frame, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn json_session_can_encode_binary_frames() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
        let session = Session::new(tx);

        let mut session =
            session.with_codec::<TestMsg, _>(crate::codec::JsonCodec::default().binary());
        session.send(&TestMsg { a: 123 }).await.unwrap();

        match rx.recv().await.unwrap() {
            Message::Binary(bytes) => assert_eq!(bytes, Bytes::from_static(br#"{"a":123}"#)),
            other => panic!("expected binary frame, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn json_stream_decodes_text_and_binary_frames() {
        let (mut tx, rx) = payload_pair(8);
        let mut stream = crate::MessageStream::new(rx)
            .with_codec::<TestMsg, _>(crate::codec::JsonCodec::default());

        tx.send(Message::Text(r#"{"a":1}"#.into())).await;
        match stream.recv().await.unwrap().unwrap() {
            CodecMessage::Item(TestMsg { a }) => assert_eq!(a, 1),
            other => panic!("expected decoded item, got: {other:?}"),
        }

        tx.send(Message::Binary(Bytes::from_static(br#"{"a":2}"#)))
            .await;
        match stream.recv().await.unwrap().unwrap() {
            CodecMessage::Item(TestMsg { a }) => assert_eq!(a, 2),
            other => panic!("expected decoded item, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn json_stream_passes_through_control_frames() {
        let (mut tx, rx) = payload_pair(8);
        let mut stream = crate::MessageStream::new(rx)
            .with_codec::<TestMsg, _>(crate::codec::JsonCodec::default());

        tx.send(Message::Ping(Bytes::from_static(b"hi"))).await;
        match stream.recv().await.unwrap().unwrap() {
            CodecMessage::Ping(bytes) => assert_eq!(bytes, Bytes::from_static(b"hi")),
            other => panic!("expected ping, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn json_stream_yields_codec_error_on_invalid_payload_and_continues() {
        let (mut tx, rx) = payload_pair(8);
        let mut stream = crate::MessageStream::new(rx)
            .with_codec::<TestMsg, _>(crate::codec::JsonCodec::default());

        tx.send(Message::Text("not json".into())).await;
        match stream.recv().await.unwrap() {
            Err(CodecStreamError::Codec(_)) => {}
            other => panic!("expected codec error, got: {other:?}"),
        }

        tx.send(Message::Text(r#"{"a":9}"#.into())).await;
        match stream.recv().await.unwrap().unwrap() {
            CodecMessage::Item(TestMsg { a }) => assert_eq!(a, 9),
            other => panic!("expected decoded item, got: {other:?}"),
        }
    }

    #[test]
    fn encoded_message_clone_and_eq() {
        let text = EncodedMessage::Text("hello".into());
        let binary = EncodedMessage::Binary(Bytes::from_static(b"hello"));

        // Clones compare equal to their originals.
        assert_eq!(text.clone(), text);
        assert_eq!(binary.clone(), binary);

        // Same bytes but a different frame type must not compare equal.
        assert_ne!(text, binary);
    }
}