trillium_websockets/
lib.rs

1#![forbid(unsafe_code)]
2#![deny(
3    clippy::dbg_macro,
4    missing_copy_implementations,
5    rustdoc::missing_crate_level_docs,
6    missing_debug_implementations,
7    missing_docs,
8    nonstandard_style,
9    unused_qualifications
10)]
11
12/*!
13# A websocket trillium handler
14
15There are three primary ways to use this crate
16
17## With an async function that receives a [`WebSocketConn`](crate::WebSocketConn)
18
19This is the simplest way to use trillium websockets, but does not
20provide any of the affordances that implementing the
21[`WebSocketHandler`] trait does. It is best for very simple websockets
22or for usages that require moving the WebSocketConn elsewhere in an
23application. The WebSocketConn is fully owned at this point, and will
24disconnect when dropped, not when the async function passed to
25`websocket` completes.
26
27```
28use futures_lite::stream::StreamExt;
29use trillium_websockets::{Message, WebSocketConn, websocket};
30
31let handler = websocket(|mut conn: WebSocketConn| async move {
32    while let Some(Ok(Message::Text(input))) = conn.next().await {
33        conn.send_string(format!("received your message: {}", &input)).await;
34    }
35});
36# // tests at tests/tests.rs for example simplicity
37```
38
39
40## Implementing [`WebSocketHandler`](crate::WebSocketHandler)
41
42[`WebSocketHandler`] provides support for sending outbound messages as a
43stream, and simplifies common patterns like executing async code on
44received messages.
45
46## Using [`JsonWebSocketHandler`](crate::JsonWebSocketHandler)
47
48[`JsonWebSocketHandler`] provides a thin serialization and
49deserialization layer on top of [`WebSocketHandler`] for this common
50use case.  See the [`JsonWebSocketHandler`] documentation for example
51usage. In order to use this trait, the `json` cargo feature must be
52enabled.
53
54*/
55
56mod bidirectional_stream;
57mod websocket_connection;
58mod websocket_handler;
59
60use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
61use bidirectional_stream::{BidirectionalStream, Direction};
62use futures_lite::stream::StreamExt;
63use sha1::{Digest, Sha1};
64use std::{
65    net::IpAddr,
66    ops::{Deref, DerefMut},
67};
68use trillium::{
69    Conn, Handler,
70    KnownHeaderName::{
71        Connection, SecWebsocketAccept, SecWebsocketKey, SecWebsocketProtocol, SecWebsocketVersion,
72        Upgrade as UpgradeHeader,
73    },
74    Status, Upgrade,
75};
76
77pub use async_tungstenite::{
78    self,
79    tungstenite::{
80        self,
81        protocol::{Role, WebSocketConfig},
82        Message,
83    },
84};
85pub use trillium::async_trait;
86pub use websocket_connection::WebSocketConn;
87pub use websocket_handler::WebSocketHandler;
88
// The fixed GUID that `websocket_accept_hash` appends to the client's
// Sec-WebSocket-Key before hashing (the constant defined by RFC 6455).
89const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
90
91#[derive(thiserror::Error, Debug)]
92#[non_exhaustive]
93/// An Error type that represents all exceptional conditions that can be encountered in the
94/// operation of this crate
95pub enum Error {
96    #[error(transparent)]
97    /// an error in the underlying websocket implementation
98    WebSocket(#[from] tungstenite::Error),
99
100    #[cfg(feature = "json")]
101    #[error(transparent)]
102    /// an error in json serialization or deserialization (only with the `json` cargo feature)
103    Json(#[from] serde_json::Error),
104}
105
106/// a Result type for this crate; the error is always [`Error`] and `T` defaults to [`Message`]
107pub type Result<T = Message> = std::result::Result<T, Error>;
108
109#[cfg(feature = "json")]
110mod json;
111
112#[cfg(feature = "json")]
113pub use json::{json_websocket, JsonHandler, JsonWebSocketHandler};
114
115/**
116The trillium handler.
117See crate-level docs for example usage.

Wraps a handler of type `H` together with the handshake options set through
the builder methods. It dereferences to the wrapped handler.
118*/
119#[derive(Debug)]
120pub struct WebSocket<H> {
    // the wrapped handler; its callbacks are driven from `Handler::upgrade`
121    handler: H,
    // subprotocol names this server accepts, matched against the request's
    // Sec-WebSocket-Protocol header during the handshake
122    protocols: Vec<String>,
    // optional protocol configuration handed to `WebSocketConn::new`
123    config: Option<WebSocketConfig>,
    // when true, requests that do not ask for a websocket upgrade are halted
    // with `Status::UpgradeRequired` instead of being passed through
124    required: bool,
125}
126
127impl<H> Deref for WebSocket<H> {
128    type Target = H;
129
130    fn deref(&self) -> &Self::Target {
131        &self.handler
132    }
133}
134
135impl<H> DerefMut for WebSocket<H> {
136    fn deref_mut(&mut self) -> &mut Self::Target {
137        &mut self.handler
138    }
139}
140
141/**
142Builds a new trillium handler from the provided
143WebSocketHandler. Alias for [`WebSocket::new`]
144*/
145pub fn websocket<H>(websocket_handler: H) -> WebSocket<H>
146where
147    H: WebSocketHandler,
148{
149    WebSocket::new(websocket_handler)
150}
151
152impl<H> WebSocket<H>
153where
154    H: WebSocketHandler,
155{
156    /// Build a new WebSocket with an async handler function that
157    /// receives a [`WebSocketConn`]
158    pub fn new(handler: H) -> Self {
159        Self {
160            handler,
161            protocols: Default::default(),
162            config: None,
163            required: false,
164        }
165    }
166
167    /// `protocols` is a sequence of known protocols. On successful handshake,
168    /// the returned response headers contain the first protocol in this list
169    /// which the server also knows.
170    pub fn with_protocols(self, protocols: &[&str]) -> Self {
171        Self {
172            protocols: protocols.iter().map(ToString::to_string).collect(),
173            ..self
174        }
175    }
176
177    /// configure the websocket protocol
178    pub fn with_protocol_config(self, config: WebSocketConfig) -> Self {
179        Self {
180            config: Some(config),
181            ..self
182        }
183    }
184
185    /// configure this handler to halt and send back a [`426 Upgrade
186    /// Required`][Status::UpgradeRequired] if a websocket cannot be negotiated
187    pub fn required(mut self) -> Self {
188        self.required = true;
189        self
190    }
191}
192
// Marker stored in the conn's state by `Handler::run` once the handshake
// response has been prepared; `has_upgrade` looks for it to claim the upgrade.
193struct IsWebsocket;
194
195#[cfg(test)]
196mod tests;
197
198// this is a workaround for the fact that Upgrade is a public struct,
199// so adding peer_ip to that struct would be a breaking change. We
200// stash a copy in state for now.
// `Handler::run` stores it in the conn's state and `Handler::upgrade` takes
// it back out to call `WebSocketConn::set_peer_ip`.
201struct WebsocketPeerIp(Option<IpAddr>);
202
203#[async_trait]
204impl<H> Handler for WebSocket<H>
205where
206    H: WebSocketHandler,
207{
208    async fn run(&self, mut conn: Conn) -> Conn {
209        if !upgrade_requested(&conn) {
210            if self.required {
211                return conn.with_status(Status::UpgradeRequired).halt();
212            } else {
213                return conn;
214            }
215        }
216
217        let websocket_peer_ip = WebsocketPeerIp(conn.peer_ip());
218
219        let Some(sec_websocket_key) = conn.request_headers().get_str(SecWebsocketKey) else {
220            return conn.with_status(Status::BadRequest).halt();
221        };
222        let sec_websocket_accept = websocket_accept_hash(sec_websocket_key);
223
224        let protocol = websocket_protocol(&conn, &self.protocols);
225
226        let headers = conn.response_headers_mut();
227
228        headers.extend([
229            (UpgradeHeader, "websocket"),
230            (Connection, "Upgrade"),
231            (SecWebsocketVersion, "13"),
232        ]);
233
234        headers.insert(SecWebsocketAccept, sec_websocket_accept);
235
236        if let Some(protocol) = protocol {
237            headers.insert(SecWebsocketProtocol, protocol);
238        }
239
240        conn.halt()
241            .with_state(websocket_peer_ip)
242            .with_state(IsWebsocket)
243            .with_status(Status::SwitchingProtocols)
244    }
245
246    fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
247        upgrade.state().contains::<IsWebsocket>()
248    }
249
250    async fn upgrade(&self, mut upgrade: Upgrade) {
251        let peer_ip = upgrade.state.take::<WebsocketPeerIp>().and_then(|i| i.0);
252        let mut conn = WebSocketConn::new(upgrade, self.config, Role::Server).await;
253        conn.set_peer_ip(peer_ip);
254
255        let Some((mut conn, outbound)) = self.handler.connect(conn).await else {
256            return;
257        };
258
259        let inbound = conn.take_inbound_stream();
260
261        let mut stream = std::pin::pin!(BidirectionalStream { inbound, outbound });
262        while let Some(message) = stream.next().await {
263            match message {
264                Direction::Inbound(Ok(Message::Close(close_frame))) => {
265                    self.handler.disconnect(&mut conn, close_frame).await;
266                    break;
267                }
268
269                Direction::Inbound(Ok(message)) => {
270                    self.handler.inbound(message, &mut conn).await;
271                }
272
273                Direction::Outbound(message) => {
274                    if let Err(e) = self.handler.send(message, &mut conn).await {
275                        log::warn!("outbound websocket error: {:?}", e);
276                        break;
277                    }
278                }
279
280                _ => {
281                    self.handler.disconnect(&mut conn, None).await;
282                    break;
283                }
284            }
285        }
286
287        if let Some(err) = conn.close().await.err() {
288            log::warn!("websocket close error: {:?}", err);
289        };
290    }
291}
292
293fn websocket_protocol(conn: &Conn, protocols: &[String]) -> Option<String> {
294    conn.request_headers()
295        .get_str(SecWebsocketProtocol)
296        .and_then(|value| {
297            value
298                .split(',')
299                .map(str::trim)
300                .find(|req_p| protocols.iter().any(|x| x == req_p))
301                .map(|s| s.to_owned())
302        })
303}
304
305fn connection_is_upgrade(conn: &Conn) -> bool {
306    conn.request_headers()
307        .get_str(Connection)
308        .map(|connection| {
309            connection
310                .split(',')
311                .any(|c| c.trim().eq_ignore_ascii_case("upgrade"))
312        })
313        .unwrap_or(false)
314}
315
// Checks the request's Upgrade header against "websocket" using the headers'
// ASCII-case-insensitive comparison.
316fn upgrade_to_websocket(conn: &Conn) -> bool {
317    conn.request_headers()
318        .eq_ignore_ascii_case(UpgradeHeader, "websocket")
319}
320
// A websocket upgrade counts as requested only when both the Connection
// header check and the Upgrade header check pass.
321fn upgrade_requested(conn: &Conn) -> bool {
322    connection_is_upgrade(conn) && upgrade_to_websocket(conn)
323}
324
325/// Generate a random key suitable for Sec-WebSocket-Key
326pub fn websocket_key() -> String {
327    BASE64.encode(fastrand::u128(..).to_ne_bytes())
328}
329
330/// Generate the expected Sec-WebSocket-Accept hash from the Sec-WebSocket-Key
331pub fn websocket_accept_hash(websocket_key: &str) -> String {
332    let hash = Sha1::new()
333        .chain_update(websocket_key)
334        .chain_update(WEBSOCKET_GUID)
335        .finalize();
336    BASE64.encode(&hash[..])
337}