rocket_ws_community/websocket.rs
1use std::io;
2
3use rocket::data::{IoHandler, IoStream};
4use rocket::futures::{self, future::BoxFuture, stream::SplitStream, SinkExt, StreamExt};
5use rocket::http::Status;
6use rocket::request::{FromRequest, Outcome, Request};
7use rocket::response::{self, Responder, Response};
8
9use crate::result::{Error, Result};
10use crate::stream::DuplexStream;
11use crate::{Config, Message};
12
13/// A request guard identifying WebSocket requests. Converts into a [`Channel`]
14/// or [`MessageStream`].
15///
16/// For example usage, see the [crate docs](crate#usage).
17///
18/// ## Details
19///
20/// This is the entrypoint to the library. Every WebSocket response _must_
21/// initiate via the `WebSocket` request guard. The guard identifies valid
22/// WebSocket connection requests and, if the request is valid, succeeds to be
23/// converted into a streaming WebSocket response via
24/// [`Stream!`](crate::Stream!), [`WebSocket::channel()`], or
25/// [`WebSocket::stream()`]. The connection can be configured via
26/// [`WebSocket::config()`]; see [`Config`] for details on configuring a
27/// connection.
28///
29/// ### Forwarding
30///
31/// If the incoming request is not a valid WebSocket request, the guard
32/// forwards with a status of `BadRequest`. The guard never fails.
33pub struct WebSocket {
34 config: Config,
35 key: String,
36}
37
38impl WebSocket {
39 /// Change the default connection configuration to `config`.
40 ///
41 /// # Example
42 ///
43 /// ```rust
44 /// # extern crate rocket_ws_community as rocket_ws;
45 /// # use rocket::get;
46 /// # use rocket_ws as ws;
47 /// #
48 /// #[get("/echo")]
49 /// fn echo_stream(ws: ws::WebSocket) -> ws::Stream!['static] {
50 /// let ws = ws.config(ws::Config {
51 /// max_send_queue: Some(5),
52 /// ..Default::default()
53 /// });
54 ///
55 /// ws::Stream! { ws =>
56 /// for await message in ws {
57 /// yield message?;
58 /// }
59 /// }
60 /// }
61 /// ```
62 pub fn config(mut self, config: Config) -> Self {
63 self.config = config;
64 self
65 }
66
67 /// Create a read/write channel to the client and call `handler` with it.
68 ///
69 /// This method takes a `FnOnce`, `handler`, that consumes a read/write
70 /// WebSocket channel, [`DuplexStream`] to the client. See [`DuplexStream`]
71 /// for details on how to make use of the channel.
72 ///
73 /// The `handler` must return a `Box`ed and `Pin`ned future: calling
74 /// [`Box::pin()`] with a future does just this as is the preferred
75 /// mechanism to create a `Box<Pin<Future>>`. The future must return a
76 /// [`Result<()>`](crate::result::Result). The WebSocket connection is
77 /// closed successfully if the future returns `Ok` and with an error if
78 /// the future returns `Err`.
79 ///
80 /// # Lifetimes
81 ///
82 /// The `Channel` may borrow from the request. If it does, the lifetime
83 /// should be specified as something other than `'static`. Otherwise, the
84 /// `'static` lifetime should be used.
85 ///
86 /// # Example
87 ///
88 /// ```rust
89 /// # extern crate rocket_ws_community as rocket_ws;
90 /// # use rocket::get;
91 /// # use rocket_ws as ws;
92 /// use rocket::futures::{SinkExt, StreamExt};
93 ///
94 /// #[get("/hello/<name>")]
95 /// fn hello(ws: ws::WebSocket, name: &str) -> ws::Channel<'_> {
96 /// ws.channel(move |mut stream| Box::pin(async move {
97 /// let message = format!("Hello, {}!", name);
98 /// let _ = stream.send(message.into()).await;
99 /// Ok(())
100 /// }))
101 /// }
102 ///
103 /// #[get("/echo")]
104 /// fn echo(ws: ws::WebSocket) -> ws::Channel<'static> {
105 /// ws.channel(move |mut stream| Box::pin(async move {
106 /// while let Some(message) = stream.next().await {
107 /// let _ = stream.send(message?).await;
108 /// }
109 ///
110 /// Ok(())
111 /// }))
112 /// }
113 /// ```
114 pub fn channel<'r, F>(self, handler: F) -> Channel<'r>
115 where
116 F: FnOnce(DuplexStream) -> BoxFuture<'r, Result<()>> + Send + 'r,
117 {
118 Channel {
119 ws: self,
120 handler: Box::new(handler),
121 }
122 }
123
124 /// Create a stream that consumes client [`Message`]s and emits its own.
125 ///
126 /// This method takes a `FnOnce` `stream` that consumes a read-only stream
127 /// and returns a stream of [`Message`]s. While the returned stream can be
128 /// constructed in any manner, the [`Stream!`](crate::Stream!) macro is the
129 /// preferred method. In any case, the stream must be `Send`.
130 ///
131 /// The returned stream must emit items of type `Result<Message>`. Items
132 /// that are `Ok(Message)` are sent to the client while items of type
133 /// `Err(Error)` result in the connection being closed and the remainder of
134 /// the stream discarded.
135 ///
136 /// # Example
137 ///
138 /// ```rust
139 /// # extern crate rocket_ws_community as rocket_ws;
140 /// # use rocket::get;
141 /// # use rocket_ws as ws;
142 ///
143 /// // Use `Stream!`, which internally calls `WebSocket::stream()`.
144 /// #[get("/echo?stream")]
145 /// fn echo_stream(ws: ws::WebSocket) -> ws::Stream!['static] {
146 /// ws::Stream! { ws =>
147 /// for await message in ws {
148 /// yield message?;
149 /// }
150 /// }
151 /// }
152 ///
153 /// // Use a raw stream.
154 /// #[get("/echo?compose")]
155 /// fn echo_compose(ws: ws::WebSocket) -> ws::Stream!['static] {
156 /// ws.stream(|io| io)
157 /// }
158 /// ```
159 pub fn stream<'r, F, S>(self, stream: F) -> MessageStream<'r, S>
160 where
161 F: FnOnce(SplitStream<DuplexStream>) -> S + Send + 'r,
162 S: futures::Stream<Item = Result<Message>> + Send + 'r,
163 {
164 MessageStream {
165 ws: self,
166 handler: Box::new(stream),
167 }
168 }
169
170 /// Returns the server's fully computed and encoded WebSocket handshake
171 /// accept key.
172 ///
173 /// > The server takes the value of the `Sec-WebSocket-Key` sent in the
174 /// > handshake request, appends `258EAFA5-E914-47DA-95CA-C5AB0DC85B11`,
175 /// > SHA-1 of the new value, and is then base64 encoded.
176 /// >
177 /// > -- [`Sec-WebSocket-Accept`]
178 ///
179 /// This is the value returned via the [`Sec-WebSocket-Accept`] header
180 /// during the acceptance response.
181 ///
182 /// [`Sec-WebSocket-Accept`]:
183 /// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-WebSocket-Accept
184 ///
185 /// # Example
186 ///
187 /// ```rust
188 /// # extern crate rocket_ws_community as rocket_ws;
189 /// # use rocket::get;
190 /// # use rocket_ws as ws;
191 /// #
192 /// #[get("/echo")]
193 /// fn echo_stream(ws: ws::WebSocket) -> ws::Stream!['static] {
194 /// let accept_key = ws.accept_key();
195 /// ws.stream(|io| io)
196 /// }
197 /// ```
198 pub fn accept_key(&self) -> &str {
199 &self.key
200 }
201}
202
203/// A streaming channel, returned by [`WebSocket::channel()`].
204///
205/// `Channel` has no methods or functionality beyond its trait implementations.
206pub struct Channel<'r> {
207 ws: WebSocket,
208 handler: Box<dyn FnOnce(DuplexStream) -> BoxFuture<'r, Result<()>> + Send + 'r>,
209}
210
211/// A [`Stream`](futures::Stream) of [`Message`]s, returned by
212/// [`WebSocket::stream()`], used via [`Stream!`].
213///
214/// This type should not be used directly. Instead, it is used via the
215/// [`Stream!`] macro, which expands to both the type itself and an expression
216/// which evaluates to this type. See [`Stream!`] for details.
217///
218/// [`Stream!`]: crate::Stream!
219// TODO: Get rid of this or `Channel` via a single `enum`.
220pub struct MessageStream<'r, S> {
221 ws: WebSocket,
222 handler: Box<dyn FnOnce(SplitStream<DuplexStream>) -> S + Send + 'r>,
223}
224
225#[rocket::async_trait]
226impl<'r> FromRequest<'r> for WebSocket {
227 type Error = std::convert::Infallible;
228
229 async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
230 use crate::tungstenite::handshake::derive_accept_key;
231 use rocket::http::uncased::eq;
232
233 let headers = req.headers();
234 let is_upgrade = headers
235 .get("Connection")
236 .any(|h| h.split(',').any(|v| eq(v.trim(), "upgrade")));
237
238 let is_ws = headers
239 .get("Upgrade")
240 .any(|h| h.split(',').any(|v| eq(v.trim(), "websocket")));
241
242 let is_13 = headers.get_one("Sec-WebSocket-Version") == Some("13");
243 let key = headers
244 .get_one("Sec-WebSocket-Key")
245 .map(|k| derive_accept_key(k.as_bytes()));
246 match key {
247 Some(key) if is_upgrade && is_ws && is_13 => Outcome::Success(WebSocket {
248 key,
249 config: Config::default(),
250 }),
251 Some(_) | None => Outcome::Forward(Status::BadRequest),
252 }
253 }
254}
255
256impl<'r, 'o: 'r> Responder<'r, 'o> for Channel<'o> {
257 fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
258 Response::build()
259 .raw_header("Sec-Websocket-Version", "13")
260 .raw_header("Sec-WebSocket-Accept", self.ws.key.clone())
261 .upgrade("websocket", self)
262 .ok()
263 }
264}
265
266impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S>
267where
268 S: futures::Stream<Item = Result<Message>> + Send + 'o,
269{
270 fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
271 Response::build()
272 .raw_header("Sec-Websocket-Version", "13")
273 .raw_header("Sec-WebSocket-Accept", self.ws.key.clone())
274 .upgrade("websocket", self)
275 .ok()
276 }
277}
278
279#[rocket::async_trait]
280impl IoHandler for Channel<'_> {
281 async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
282 let stream = DuplexStream::new(io, self.ws.config).await;
283 let result = (self.handler)(stream).await;
284 handle_result(result).map(|_| ())
285 }
286}
287
288#[rocket::async_trait]
289impl<'r, S> IoHandler for MessageStream<'r, S>
290where
291 S: futures::Stream<Item = Result<Message>> + Send + 'r,
292{
293 async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
294 let (mut sink, source) = DuplexStream::new(io, self.ws.config).await.split();
295 let stream = (self.handler)(source);
296 rocket::tokio::pin!(stream);
297 while let Some(msg) = stream.next().await {
298 let result = match msg {
299 Ok(msg) if msg.is_close() => return Ok(()),
300 Ok(msg) => sink.send(msg).await,
301 Err(e) => Err(e),
302 };
303
304 if !handle_result(result)? {
305 return Ok(());
306 }
307 }
308
309 Ok(())
310 }
311}
312
313/// Returns `Ok(true)` if processing should continue, `Ok(false)` if processing
314/// has terminated without error, and `Err(e)` if an error has occurred.
315fn handle_result(result: Result<()>) -> io::Result<bool> {
316 match result {
317 Ok(_) => Ok(true),
318 Err(Error::ConnectionClosed) => Ok(false),
319 Err(Error::Io(e)) => Err(e),
320 Err(e) => Err(io::Error::other(e)),
321 }
322}