rocket_ws_community/
websocket.rs

1use std::io;
2
3use rocket::data::{IoHandler, IoStream};
4use rocket::futures::{self, future::BoxFuture, stream::SplitStream, SinkExt, StreamExt};
5use rocket::http::Status;
6use rocket::request::{FromRequest, Outcome, Request};
7use rocket::response::{self, Responder, Response};
8
9use crate::result::{Error, Result};
10use crate::stream::DuplexStream;
11use crate::{Config, Message};
12
13/// A request guard identifying WebSocket requests. Converts into a [`Channel`]
14/// or [`MessageStream`].
15///
16/// For example usage, see the [crate docs](crate#usage).
17///
18/// ## Details
19///
20/// This is the entrypoint to the library. Every WebSocket response _must_
21/// initiate via the `WebSocket` request guard. The guard identifies valid
22/// WebSocket connection requests and, if the request is valid, succeeds to be
23/// converted into a streaming WebSocket response via
24/// [`Stream!`](crate::Stream!), [`WebSocket::channel()`], or
25/// [`WebSocket::stream()`]. The connection can be configured via
26/// [`WebSocket::config()`]; see [`Config`] for details on configuring a
27/// connection.
28///
29/// ### Forwarding
30///
31/// If the incoming request is not a valid WebSocket request, the guard
32/// forwards with a status of `BadRequest`. The guard never fails.
33pub struct WebSocket {
34    config: Config,
35    key: String,
36}
37
38impl WebSocket {
39    /// Change the default connection configuration to `config`.
40    ///
41    /// # Example
42    ///
43    /// ```rust
44    /// # extern crate rocket_ws_community as rocket_ws;
45    /// # use rocket::get;
46    /// # use rocket_ws as ws;
47    /// #
48    /// #[get("/echo")]
49    /// fn echo_stream(ws: ws::WebSocket) -> ws::Stream!['static] {
50    ///     let ws = ws.config(ws::Config {
51    ///         max_send_queue: Some(5),
52    ///         ..Default::default()
53    ///     });
54    ///
55    ///     ws::Stream! { ws =>
56    ///         for await message in ws {
57    ///             yield message?;
58    ///         }
59    ///     }
60    /// }
61    /// ```
62    pub fn config(mut self, config: Config) -> Self {
63        self.config = config;
64        self
65    }
66
67    /// Create a read/write channel to the client and call `handler` with it.
68    ///
69    /// This method takes a `FnOnce`, `handler`, that consumes a read/write
70    /// WebSocket channel, [`DuplexStream`] to the client. See [`DuplexStream`]
71    /// for details on how to make use of the channel.
72    ///
73    /// The `handler` must return a `Box`ed and `Pin`ned future: calling
74    /// [`Box::pin()`] with a future does just this as is the preferred
75    /// mechanism to create a `Box<Pin<Future>>`. The future must return a
76    /// [`Result<()>`](crate::result::Result). The WebSocket connection is
77    /// closed successfully if the future returns `Ok` and with an error if
78    /// the future returns `Err`.
79    ///
80    /// # Lifetimes
81    ///
82    /// The `Channel` may borrow from the request. If it does, the lifetime
83    /// should be specified as something other than `'static`. Otherwise, the
84    /// `'static` lifetime should be used.
85    ///
86    /// # Example
87    ///
88    /// ```rust
89    /// # extern crate rocket_ws_community as rocket_ws;
90    /// # use rocket::get;
91    /// # use rocket_ws as ws;
92    /// use rocket::futures::{SinkExt, StreamExt};
93    ///
94    /// #[get("/hello/<name>")]
95    /// fn hello(ws: ws::WebSocket, name: &str) -> ws::Channel<'_> {
96    ///     ws.channel(move |mut stream| Box::pin(async move {
97    ///         let message = format!("Hello, {}!", name);
98    ///         let _ = stream.send(message.into()).await;
99    ///         Ok(())
100    ///     }))
101    /// }
102    ///
103    /// #[get("/echo")]
104    /// fn echo(ws: ws::WebSocket) -> ws::Channel<'static> {
105    ///     ws.channel(move |mut stream| Box::pin(async move {
106    ///         while let Some(message) = stream.next().await {
107    ///             let _ = stream.send(message?).await;
108    ///         }
109    ///
110    ///         Ok(())
111    ///     }))
112    /// }
113    /// ```
114    pub fn channel<'r, F>(self, handler: F) -> Channel<'r>
115    where
116        F: FnOnce(DuplexStream) -> BoxFuture<'r, Result<()>> + Send + 'r,
117    {
118        Channel {
119            ws: self,
120            handler: Box::new(handler),
121        }
122    }
123
124    /// Create a stream that consumes client [`Message`]s and emits its own.
125    ///
126    /// This method takes a `FnOnce` `stream` that consumes a read-only stream
127    /// and returns a stream of [`Message`]s. While the returned stream can be
128    /// constructed in any manner, the [`Stream!`](crate::Stream!) macro is the
129    /// preferred method. In any case, the stream must be `Send`.
130    ///
131    /// The returned stream must emit items of type `Result<Message>`. Items
132    /// that are `Ok(Message)` are sent to the client while items of type
133    /// `Err(Error)` result in the connection being closed and the remainder of
134    /// the stream discarded.
135    ///
136    /// # Example
137    ///
138    /// ```rust
139    /// # extern crate rocket_ws_community as rocket_ws;
140    /// # use rocket::get;
141    /// # use rocket_ws as ws;
142    ///
143    /// // Use `Stream!`, which internally calls `WebSocket::stream()`.
144    /// #[get("/echo?stream")]
145    /// fn echo_stream(ws: ws::WebSocket) -> ws::Stream!['static] {
146    ///     ws::Stream! { ws =>
147    ///         for await message in ws {
148    ///             yield message?;
149    ///         }
150    ///     }
151    /// }
152    ///
153    /// // Use a raw stream.
154    /// #[get("/echo?compose")]
155    /// fn echo_compose(ws: ws::WebSocket) -> ws::Stream!['static] {
156    ///     ws.stream(|io| io)
157    /// }
158    /// ```
159    pub fn stream<'r, F, S>(self, stream: F) -> MessageStream<'r, S>
160    where
161        F: FnOnce(SplitStream<DuplexStream>) -> S + Send + 'r,
162        S: futures::Stream<Item = Result<Message>> + Send + 'r,
163    {
164        MessageStream {
165            ws: self,
166            handler: Box::new(stream),
167        }
168    }
169
170    /// Returns the server's fully computed and encoded WebSocket handshake
171    /// accept key.
172    ///
173    /// > The server takes the value of the `Sec-WebSocket-Key` sent in the
174    /// > handshake request, appends `258EAFA5-E914-47DA-95CA-C5AB0DC85B11`,
175    /// > SHA-1 of the new value, and is then base64 encoded.
176    /// >
177    /// > -- [`Sec-WebSocket-Accept`]
178    ///
179    /// This is the value returned via the [`Sec-WebSocket-Accept`] header
180    /// during the acceptance response.
181    ///
182    /// [`Sec-WebSocket-Accept`]:
183    /// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-WebSocket-Accept
184    ///
185    /// # Example
186    ///
187    /// ```rust
188    /// # extern crate rocket_ws_community as rocket_ws;
189    /// # use rocket::get;
190    /// # use rocket_ws as ws;
191    /// #
192    /// #[get("/echo")]
193    /// fn echo_stream(ws: ws::WebSocket) -> ws::Stream!['static] {
194    ///     let accept_key = ws.accept_key();
195    ///     ws.stream(|io| io)
196    /// }
197    /// ```
198    pub fn accept_key(&self) -> &str {
199        &self.key
200    }
201}
202
203/// A streaming channel, returned by [`WebSocket::channel()`].
204///
205/// `Channel` has no methods or functionality beyond its trait implementations.
206pub struct Channel<'r> {
207    ws: WebSocket,
208    handler: Box<dyn FnOnce(DuplexStream) -> BoxFuture<'r, Result<()>> + Send + 'r>,
209}
210
211/// A [`Stream`](futures::Stream) of [`Message`]s, returned by
212/// [`WebSocket::stream()`], used via [`Stream!`].
213///
214/// This type should not be used directly. Instead, it is used via the
215/// [`Stream!`] macro, which expands to both the type itself and an expression
216/// which evaluates to this type. See [`Stream!`] for details.
217///
218/// [`Stream!`]: crate::Stream!
219// TODO: Get rid of this or `Channel` via a single `enum`.
220pub struct MessageStream<'r, S> {
221    ws: WebSocket,
222    handler: Box<dyn FnOnce(SplitStream<DuplexStream>) -> S + Send + 'r>,
223}
224
225#[rocket::async_trait]
226impl<'r> FromRequest<'r> for WebSocket {
227    type Error = std::convert::Infallible;
228
229    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
230        use crate::tungstenite::handshake::derive_accept_key;
231        use rocket::http::uncased::eq;
232
233        let headers = req.headers();
234        let is_upgrade = headers
235            .get("Connection")
236            .any(|h| h.split(',').any(|v| eq(v.trim(), "upgrade")));
237
238        let is_ws = headers
239            .get("Upgrade")
240            .any(|h| h.split(',').any(|v| eq(v.trim(), "websocket")));
241
242        let is_13 = headers.get_one("Sec-WebSocket-Version") == Some("13");
243        let key = headers
244            .get_one("Sec-WebSocket-Key")
245            .map(|k| derive_accept_key(k.as_bytes()));
246        match key {
247            Some(key) if is_upgrade && is_ws && is_13 => Outcome::Success(WebSocket {
248                key,
249                config: Config::default(),
250            }),
251            Some(_) | None => Outcome::Forward(Status::BadRequest),
252        }
253    }
254}
255
256impl<'r, 'o: 'r> Responder<'r, 'o> for Channel<'o> {
257    fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
258        Response::build()
259            .raw_header("Sec-Websocket-Version", "13")
260            .raw_header("Sec-WebSocket-Accept", self.ws.key.clone())
261            .upgrade("websocket", self)
262            .ok()
263    }
264}
265
266impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S>
267where
268    S: futures::Stream<Item = Result<Message>> + Send + 'o,
269{
270    fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
271        Response::build()
272            .raw_header("Sec-Websocket-Version", "13")
273            .raw_header("Sec-WebSocket-Accept", self.ws.key.clone())
274            .upgrade("websocket", self)
275            .ok()
276    }
277}
278
279#[rocket::async_trait]
280impl IoHandler for Channel<'_> {
281    async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
282        let stream = DuplexStream::new(io, self.ws.config).await;
283        let result = (self.handler)(stream).await;
284        handle_result(result).map(|_| ())
285    }
286}
287
288#[rocket::async_trait]
289impl<'r, S> IoHandler for MessageStream<'r, S>
290where
291    S: futures::Stream<Item = Result<Message>> + Send + 'r,
292{
293    async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
294        let (mut sink, source) = DuplexStream::new(io, self.ws.config).await.split();
295        let stream = (self.handler)(source);
296        rocket::tokio::pin!(stream);
297        while let Some(msg) = stream.next().await {
298            let result = match msg {
299                Ok(msg) if msg.is_close() => return Ok(()),
300                Ok(msg) => sink.send(msg).await,
301                Err(e) => Err(e),
302            };
303
304            if !handle_result(result)? {
305                return Ok(());
306            }
307        }
308
309        Ok(())
310    }
311}
312
313/// Returns `Ok(true)` if processing should continue, `Ok(false)` if processing
314/// has terminated without error, and `Err(e)` if an error has occurred.
315fn handle_result(result: Result<()>) -> io::Result<bool> {
316    match result {
317        Ok(_) => Ok(true),
318        Err(Error::ConnectionClosed) => Ok(false),
319        Err(Error::Io(e)) => Err(e),
320        Err(e) => Err(io::Error::other(e)),
321    }
322}