Skip to main content

actix_ws/
session.rs

1use std::{
2    fmt,
3    future::poll_fn,
4    pin::Pin,
5    sync::{
6        atomic::{AtomicBool, Ordering},
7        Arc,
8    },
9    task::{Context, Poll},
10};
11
12use actix_http::ws::{CloseReason, Item, Message};
13use actix_web::web::Bytes;
14use bytestring::ByteString;
15use futures_sink::Sink;
16use tokio::sync::mpsc::Sender;
17use tokio_util::sync::PollSender;
18
/// Largest payload any control frame (ping, pong, close) may carry.
///
/// RFC 6455 §5.5: control frames MUST have a payload length of 125 bytes or less.
/// ref. https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5
const MAX_CONTROL_PAYLOAD_BYTES: usize = 125;

/// Largest UTF-8 reason text a close frame may carry.
///
/// A close payload is a 2-byte (`u16`) status code followed by the optional reason, so the
/// reason gets whatever is left of the control-frame budget (123 bytes).
const MAX_CLOSE_REASON_BYTES: usize = MAX_CONTROL_PAYLOAD_BYTES - std::mem::size_of::<u16>();
24
25/// A handle into the websocket session.
26///
27/// This type can be used to send messages into the WebSocket.
28/// It also implements [`Sink<Message>`](futures_sink::Sink) for integration with sink-based APIs.
29#[derive(Clone)]
30pub struct Session {
31    inner: Option<PollSender<Message>>,
32    closed: Arc<AtomicBool>,
33}
34
/// The error returned when using a WebSocket session that has been closed.
///
/// Every `Session` send method, and its `Sink` implementation, yields this error once the
/// session is closed. This happens when this handle or any clone of it closed the session, or
/// when the outgoing channel can no longer accept messages.
///
/// It is a zero-sized marker, so it is `Copy` and can be compared directly, for example
/// `assert_eq!(result, Err(Closed))`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Closed;

impl fmt::Display for Closed {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Session is closed")
    }
}

impl std::error::Error for Closed {}
46
47impl Session {
48    pub(super) fn new(inner: Sender<Message>) -> Self {
49        Session {
50            inner: Some(PollSender::new(inner)),
51            closed: Arc::new(AtomicBool::new(false)),
52        }
53    }
54
55    fn pre_check(&mut self) {
56        if self.closed.load(Ordering::Relaxed) {
57            self.inner.take();
58        }
59    }
60
61    async fn send_message_inner(&mut self, msg: Message) -> Result<(), Closed> {
62        if let Some(inner) = self.inner.as_mut() {
63            poll_fn(|cx| Pin::new(&mut *inner).poll_ready(cx))
64                .await
65                .map_err(|_| Closed)?;
66            Pin::new(&mut *inner).start_send(msg).map_err(|_| Closed)?;
67            poll_fn(|cx| Pin::new(&mut *inner).poll_flush(cx))
68                .await
69                .map_err(|_| Closed)
70        } else {
71            Err(Closed)
72        }
73    }
74
75    async fn send_message(&mut self, msg: Message) -> Result<(), Closed> {
76        self.pre_check();
77        self.send_message_inner(msg).await
78    }
79
80    /// Sends text into the WebSocket.
81    ///
82    /// ```no_run
83    /// # use actix_ws::Session;
84    /// # async fn test(mut session: Session) {
85    /// if session.text("Some text").await.is_err() {
86    ///     // session closed
87    /// }
88    /// # }
89    /// ```
90    pub async fn text(&mut self, msg: impl Into<ByteString>) -> Result<(), Closed> {
91        self.send_message(Message::Text(msg.into())).await
92    }
93
94    /// Sends raw bytes into the WebSocket.
95    ///
96    /// ```no_run
97    /// # use actix_ws::Session;
98    /// # async fn test(mut session: Session) {
99    /// if session.binary(&b"some bytes"[..]).await.is_err() {
100    ///     // session closed
101    /// }
102    /// # }
103    /// ```
104    pub async fn binary(&mut self, msg: impl Into<Bytes>) -> Result<(), Closed> {
105        self.send_message(Message::Binary(msg.into())).await
106    }
107
108    /// Pings the client.
109    ///
110    /// For many applications, it will be important to send regular pings to keep track of if the
111    /// client has disconnected
112    ///
113    /// Ping payloads longer than 125 bytes are truncated to comply with RFC 6455 control frame
114    /// size limits.
115    ///
116    /// ```no_run
117    /// # use actix_ws::Session;
118    /// # async fn test(mut session: Session) {
119    /// if session.ping(b"").await.is_err() {
120    ///     // session is closed
121    /// }
122    /// # }
123    /// ```
124    pub async fn ping(&mut self, msg: &[u8]) -> Result<(), Closed> {
125        let msg = if msg.len() > MAX_CONTROL_PAYLOAD_BYTES {
126            &msg[..MAX_CONTROL_PAYLOAD_BYTES]
127        } else {
128            msg
129        };
130        self.send_message(Message::Ping(Bytes::copy_from_slice(msg)))
131            .await
132    }
133
134    /// Pongs the client.
135    ///
136    /// Pong payloads longer than 125 bytes are truncated to comply with RFC 6455 control frame
137    /// size limits.
138    ///
139    /// ```no_run
140    /// # use actix_ws::{Message, Session};
141    /// # async fn test(mut session: Session, msg: Message) {
142    /// match msg {
143    ///     Message::Ping(bytes) => {
144    ///         let _ = session.pong(&bytes).await;
145    ///     }
146    ///     _ => (),
147    /// }
148    /// # }
149    pub async fn pong(&mut self, msg: &[u8]) -> Result<(), Closed> {
150        let msg = if msg.len() > MAX_CONTROL_PAYLOAD_BYTES {
151            &msg[..MAX_CONTROL_PAYLOAD_BYTES]
152        } else {
153            msg
154        };
155        self.send_message(Message::Pong(Bytes::copy_from_slice(msg)))
156            .await
157    }
158
159    /// Manually controls sending continuations.
160    ///
161    /// Be wary of this method. Continuations represent multiple frames that, when combined, are
162    /// presented as a single message. They are useful when the entire contents of a message are
163    /// not available all at once. However, continuations MUST NOT be interrupted by other Text or
164    /// Binary messages. Control messages such as Ping, Pong, or Close are allowed to interrupt a
165    /// continuation.
166    ///
167    /// Continuations must be initialized with a First variant, and must be terminated by a Last
168    /// variant, with only Continue variants sent in between.
169    ///
170    /// ```no_run
171    /// # use actix_ws::{Item, Session};
172    /// # async fn test(mut session: Session) -> Result<(), Box<dyn std::error::Error>> {
173    /// session.continuation(Item::FirstText("Hello".into())).await?;
174    /// session.continuation(Item::Continue(b", World"[..].into())).await?;
175    /// session.continuation(Item::Last(b"!"[..].into())).await?;
176    /// # Ok(())
177    /// # }
178    /// ```
179    pub async fn continuation(&mut self, msg: Item) -> Result<(), Closed> {
180        self.send_message(Message::Continuation(msg)).await
181    }
182
183    /// Sends a close message, and consumes the session.
184    ///
185    /// All clones will return `Err(Closed)` if used after this call.
186    ///
187    /// Close reason descriptions longer than 123 bytes are truncated to comply with RFC 6455
188    /// control frame size limits.
189    ///
190    /// ```no_run
191    /// # use actix_ws::{Closed, Session};
192    /// # async fn test(mut session: Session) -> Result<(), Closed> {
193    /// session.close(None).await
194    /// # }
195    /// ```
196    pub async fn close(mut self, reason: Option<CloseReason>) -> Result<(), Closed> {
197        self.pre_check();
198
199        let mut reason = reason;
200
201        if let Some(reason) = reason.as_mut() {
202            if let Some(desc) = reason.description.as_mut() {
203                if desc.len() > MAX_CLOSE_REASON_BYTES {
204                    let mut end = MAX_CLOSE_REASON_BYTES;
205                    while end > 0 && !desc.is_char_boundary(end) {
206                        end -= 1;
207                    }
208                    desc.truncate(end);
209                }
210            }
211        }
212
213        if self.inner.is_some() {
214            self.closed.store(true, Ordering::Relaxed);
215            self.send_message_inner(Message::Close(reason)).await
216        } else {
217            Err(Closed)
218        }
219    }
220}
221
222impl Sink<Message> for Session {
223    type Error = Closed;
224
225    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
226        self.pre_check();
227        if let Some(inner) = self.inner.as_mut() {
228            match Pin::new(inner).poll_ready(cx) {
229                Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
230                Poll::Ready(Err(_)) => Poll::Ready(Err(Closed)),
231                Poll::Pending => Poll::Pending,
232            }
233        } else {
234            Poll::Ready(Err(Closed))
235        }
236    }
237
238    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
239        self.pre_check();
240        if let Some(inner) = self.inner.as_mut() {
241            Pin::new(inner).start_send(item).map_err(|_| Closed)
242        } else {
243            Err(Closed)
244        }
245    }
246
247    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
248        self.pre_check();
249        if let Some(inner) = self.inner.as_mut() {
250            match Pin::new(inner).poll_flush(cx) {
251                Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
252                Poll::Ready(Err(_)) => Poll::Ready(Err(Closed)),
253                Poll::Pending => Poll::Pending,
254            }
255        } else {
256            Poll::Ready(Err(Closed))
257        }
258    }
259
260    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
261        self.closed.store(true, Ordering::Relaxed);
262        if let Some(inner) = self.inner.as_mut() {
263            match Pin::new(inner).poll_close(cx) {
264                Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
265                Poll::Ready(Err(_)) => Poll::Ready(Err(Closed)),
266                Poll::Pending => Poll::Pending,
267            }
268        } else {
269            Poll::Ready(Ok(()))
270        }
271    }
272}
273
#[cfg(test)]
mod tests {
    use actix_http::ws::{CloseCode, CloseReason, Message};
    use futures_util::SinkExt;

    use super::{Session, MAX_CLOSE_REASON_BYTES, MAX_CONTROL_PAYLOAD_BYTES};

    #[tokio::test]
    async fn session_implements_sink() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);
        let mut session = Session::new(tx);

        session
            .send(Message::Text("hello from sink".into()))
            .await
            .unwrap();

        match rx.recv().await {
            Some(Message::Text(msg)) => {
                let text: &str = msg.as_ref();
                assert_eq!(text, "hello from sink");
            }
            other => panic!("expected text frame, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn sink_close_closes_all_clones() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);
        let mut session = Session::new(tx);
        let mut clone = session.clone();

        SinkExt::close(&mut session).await.unwrap();
        assert!(clone.text("should fail").await.is_err());

        assert!(rx.recv().await.is_none());
    }

    #[tokio::test]
    async fn close_sends_close_frame_and_closes_all_clones() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);
        let session = Session::new(tx);
        let mut clone = session.clone();

        session.close(None).await.unwrap();
        assert!(clone.text("should fail").await.is_err());

        match rx.recv().await {
            Some(Message::Close(None)) => {}
            other => panic!("expected close frame, got: {other:?}"),
        }
    }

    /// Oversized ping/pong payloads are cut to the RFC 6455 control-frame limit.
    #[tokio::test]
    async fn ping_and_pong_payloads_are_truncated() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);
        let mut session = Session::new(tx);
        let oversized = [0xAB_u8; MAX_CONTROL_PAYLOAD_BYTES + 10];

        session.ping(&oversized).await.unwrap();
        session.pong(&oversized).await.unwrap();

        match rx.recv().await {
            Some(Message::Ping(bytes)) => {
                assert_eq!(&bytes[..], &oversized[..MAX_CONTROL_PAYLOAD_BYTES]);
            }
            other => panic!("expected ping frame, got: {other:?}"),
        }
        match rx.recv().await {
            Some(Message::Pong(bytes)) => {
                assert_eq!(&bytes[..], &oversized[..MAX_CONTROL_PAYLOAD_BYTES]);
            }
            other => panic!("expected pong frame, got: {other:?}"),
        }
    }

    /// Payloads within the limit are forwarded untouched.
    #[tokio::test]
    async fn short_ping_payload_is_unchanged() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);
        let mut session = Session::new(tx);

        session.ping(b"keepalive").await.unwrap();

        match rx.recv().await {
            Some(Message::Ping(bytes)) => assert_eq!(&bytes[..], b"keepalive"),
            other => panic!("expected ping frame, got: {other:?}"),
        }
    }

    /// Close reasons are truncated without splitting a multi-byte UTF-8 character.
    #[tokio::test]
    async fn close_reason_is_truncated_on_char_boundary() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);
        let session = Session::new(tx);
        // 'é' is 2 bytes in UTF-8, so the odd 123-byte limit falls in the middle of a character.
        let description = "é".repeat(100);

        session
            .close(Some(CloseReason {
                code: CloseCode::Normal,
                description: Some(description),
            }))
            .await
            .unwrap();

        match rx.recv().await {
            Some(Message::Close(Some(reason))) => {
                let desc = reason.description.expect("description should be kept");
                assert_eq!(desc.len(), MAX_CLOSE_REASON_BYTES - 1);
                assert!(desc.chars().all(|c| c == 'é'));
            }
            other => panic!("expected close frame with reason, got: {other:?}"),
        }
    }
}