1use std::{
2 fmt,
3 future::poll_fn,
4 pin::Pin,
5 sync::{
6 atomic::{AtomicBool, Ordering},
7 Arc,
8 },
9 task::{Context, Poll},
10};
11
12use actix_http::ws::{CloseReason, Item, Message};
13use actix_web::web::Bytes;
14use bytestring::ByteString;
15use futures_sink::Sink;
16use tokio::sync::mpsc::Sender;
17use tokio_util::sync::PollSender;
18
/// Largest payload, in bytes, that `ping` and `pong` will send; longer input is cut to this length.
// NOTE(review): 125 looks like the control-frame payload limit from RFC 6455 section 5.5 — confirm.
const MAX_CONTROL_PAYLOAD_BYTES: usize = 125;
/// Largest close-reason description, in bytes, that `close` will send.
// NOTE(review): the two bytes subtracted are presumably the close status code that precedes the
// description inside the close frame payload — confirm against the frame encoder.
const MAX_CLOSE_REASON_BYTES: usize = MAX_CONTROL_PAYLOAD_BYTES - 2;
24
/// Handle used to push outgoing WebSocket messages into a channel.
///
/// Cloning produces another handle that shares the same `closed` flag, so closing through one
/// handle is observed by every clone the next time it tries to send.
#[derive(Clone)]
pub struct Session {
    /// Sender for outgoing messages. Set to `None` once this handle observes that the shared
    /// `closed` flag is set, after which every send fails with `Closed`.
    inner: Option<PollSender<Message>>,
    /// Closed flag shared by all clones of this handle.
    closed: Arc<AtomicBool>,
}
34
/// Error returned when a message cannot be sent because the session is closed.
#[derive(Debug)]
pub struct Closed;

impl fmt::Display for Closed {
    /// Writes the fixed human-readable description of this error.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(formatter, "Session is closed")
    }
}

// `Closed` wraps no underlying cause, so the trait's default methods are sufficient.
impl std::error::Error for Closed {}
46
47impl Session {
48 pub(super) fn new(inner: Sender<Message>) -> Self {
49 Session {
50 inner: Some(PollSender::new(inner)),
51 closed: Arc::new(AtomicBool::new(false)),
52 }
53 }
54
55 fn pre_check(&mut self) {
56 if self.closed.load(Ordering::Relaxed) {
57 self.inner.take();
58 }
59 }
60
61 async fn send_message_inner(&mut self, msg: Message) -> Result<(), Closed> {
62 if let Some(inner) = self.inner.as_mut() {
63 poll_fn(|cx| Pin::new(&mut *inner).poll_ready(cx))
64 .await
65 .map_err(|_| Closed)?;
66 Pin::new(&mut *inner).start_send(msg).map_err(|_| Closed)?;
67 poll_fn(|cx| Pin::new(&mut *inner).poll_flush(cx))
68 .await
69 .map_err(|_| Closed)
70 } else {
71 Err(Closed)
72 }
73 }
74
75 async fn send_message(&mut self, msg: Message) -> Result<(), Closed> {
76 self.pre_check();
77 self.send_message_inner(msg).await
78 }
79
80 pub async fn text(&mut self, msg: impl Into<ByteString>) -> Result<(), Closed> {
91 self.send_message(Message::Text(msg.into())).await
92 }
93
94 pub async fn binary(&mut self, msg: impl Into<Bytes>) -> Result<(), Closed> {
105 self.send_message(Message::Binary(msg.into())).await
106 }
107
108 pub async fn ping(&mut self, msg: &[u8]) -> Result<(), Closed> {
125 let msg = if msg.len() > MAX_CONTROL_PAYLOAD_BYTES {
126 &msg[..MAX_CONTROL_PAYLOAD_BYTES]
127 } else {
128 msg
129 };
130 self.send_message(Message::Ping(Bytes::copy_from_slice(msg)))
131 .await
132 }
133
134 pub async fn pong(&mut self, msg: &[u8]) -> Result<(), Closed> {
150 let msg = if msg.len() > MAX_CONTROL_PAYLOAD_BYTES {
151 &msg[..MAX_CONTROL_PAYLOAD_BYTES]
152 } else {
153 msg
154 };
155 self.send_message(Message::Pong(Bytes::copy_from_slice(msg)))
156 .await
157 }
158
159 pub async fn continuation(&mut self, msg: Item) -> Result<(), Closed> {
180 self.send_message(Message::Continuation(msg)).await
181 }
182
183 pub async fn close(mut self, reason: Option<CloseReason>) -> Result<(), Closed> {
197 self.pre_check();
198
199 let mut reason = reason;
200
201 if let Some(reason) = reason.as_mut() {
202 if let Some(desc) = reason.description.as_mut() {
203 if desc.len() > MAX_CLOSE_REASON_BYTES {
204 let mut end = MAX_CLOSE_REASON_BYTES;
205 while end > 0 && !desc.is_char_boundary(end) {
206 end -= 1;
207 }
208 desc.truncate(end);
209 }
210 }
211 }
212
213 if self.inner.is_some() {
214 self.closed.store(true, Ordering::Relaxed);
215 self.send_message_inner(Message::Close(reason)).await
216 } else {
217 Err(Closed)
218 }
219 }
220}
221
222impl Sink<Message> for Session {
223 type Error = Closed;
224
225 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
226 self.pre_check();
227 if let Some(inner) = self.inner.as_mut() {
228 match Pin::new(inner).poll_ready(cx) {
229 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
230 Poll::Ready(Err(_)) => Poll::Ready(Err(Closed)),
231 Poll::Pending => Poll::Pending,
232 }
233 } else {
234 Poll::Ready(Err(Closed))
235 }
236 }
237
238 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
239 self.pre_check();
240 if let Some(inner) = self.inner.as_mut() {
241 Pin::new(inner).start_send(item).map_err(|_| Closed)
242 } else {
243 Err(Closed)
244 }
245 }
246
247 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
248 self.pre_check();
249 if let Some(inner) = self.inner.as_mut() {
250 match Pin::new(inner).poll_flush(cx) {
251 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
252 Poll::Ready(Err(_)) => Poll::Ready(Err(Closed)),
253 Poll::Pending => Poll::Pending,
254 }
255 } else {
256 Poll::Ready(Err(Closed))
257 }
258 }
259
260 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
261 self.closed.store(true, Ordering::Relaxed);
262 if let Some(inner) = self.inner.as_mut() {
263 match Pin::new(inner).poll_close(cx) {
264 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
265 Poll::Ready(Err(_)) => Poll::Ready(Err(Closed)),
266 Poll::Pending => Poll::Pending,
267 }
268 } else {
269 Poll::Ready(Ok(()))
270 }
271 }
272}
273
#[cfg(test)]
mod tests {
    use actix_http::ws::Message;
    use futures_util::SinkExt;

    use super::Session;

    /// A message sent through the `Sink` interface arrives on the channel receiver.
    #[tokio::test]
    async fn session_implements_sink() {
        let (sender, mut receiver) = tokio::sync::mpsc::channel(8);
        let mut session = Session::new(sender);

        let outgoing = Message::Text("hello from sink".into());
        session.send(outgoing).await.unwrap();

        match receiver.recv().await {
            Some(Message::Text(payload)) => {
                let received: &str = payload.as_ref();
                assert_eq!(received, "hello from sink");
            }
            other => panic!("expected text frame, got: {other:?}"),
        }
    }

    /// Closing via `SinkExt::close` stops clones from sending and ends the channel without
    /// delivering any message.
    #[tokio::test]
    async fn sink_close_closes_all_clones() {
        let (sender, mut receiver) = tokio::sync::mpsc::channel(8);
        let mut original = Session::new(sender);
        let mut duplicate = original.clone();

        SinkExt::close(&mut original).await.unwrap();

        let send_result = duplicate.text("should fail").await;
        assert!(send_result.is_err());

        let next = receiver.recv().await;
        assert!(next.is_none());
    }

    /// Closing via `Session::close` stops clones from sending and delivers a close message.
    #[tokio::test]
    async fn close_sends_close_frame_and_closes_all_clones() {
        let (sender, mut receiver) = tokio::sync::mpsc::channel(8);
        let original = Session::new(sender);
        let mut duplicate = original.clone();

        original.close(None).await.unwrap();

        let send_result = duplicate.text("should fail").await;
        assert!(send_result.is_err());

        match receiver.recv().await {
            Some(Message::Close(None)) => {}
            other => panic!("expected close frame, got: {other:?}"),
        }
    }
}