Skip to main content

trillium_client/
websocket.rs

1//! Support for client-side WebSockets
2
3use crate::{Conn, WebSocketConfig, WebSocketConn};
4use std::{
5    borrow::Cow,
6    error::Error,
7    fmt::{self, Display},
8    ops::{Deref, DerefMut},
9};
10use trillium_http::{
11    KnownHeaderName::{
12        Connection, SecWebsocketAccept, SecWebsocketKey, SecWebsocketVersion,
13        Upgrade as UpgradeHeader,
14    },
15    Method, Status, Upgrade, Version,
16};
17pub use trillium_websockets::Message;
18use trillium_websockets::{Role, websocket_accept_hash, websocket_key};
19
20impl Conn {
21    fn set_websocket_upgrade_headers_h1(&mut self) {
22        let headers = self.request_headers_mut();
23        headers.try_insert(UpgradeHeader, "websocket");
24        headers.try_insert(Connection, "upgrade");
25        headers.try_insert(SecWebsocketVersion, "13");
26        headers.try_insert(SecWebsocketKey, websocket_key());
27    }
28
29    /// Attempt to transform this `Conn` into a [`WebSocketConn`].
30    ///
31    /// This is an *execution* method: calling it on a conn that has already been awaited
32    /// returns [`ErrorKind::AlreadyExecuted`]. Build the conn, then call this — don't await
33    /// it yourself first.
34    ///
35    /// Protocol selection follows the conn's [`http_version`][Conn::http_version] hint:
36    /// `Http2` and `Http3` use the extended-CONNECT bootstrap (RFC 8441 over h2, RFC 9220
37    /// over h3); the default uses an h1 `Upgrade` handshake (RFC 6455). If the peer is h2/h3
38    /// but doesn't advertise `SETTINGS_ENABLE_CONNECT_PROTOCOL`, the upgrade hard-errors —
39    /// there is no silent fallback to h1 from a non-capable peer.
40    pub async fn into_websocket(self) -> Result<WebSocketConn, WebSocketUpgradeError> {
41        self.into_websocket_with_config(WebSocketConfig::default())
42            .await
43    }
44
45    /// Like [`Conn::into_websocket`] but with a caller-supplied [`WebSocketConfig`].
46    pub async fn into_websocket_with_config(
47        self,
48        config: WebSocketConfig,
49    ) -> Result<WebSocketConn, WebSocketUpgradeError> {
50        if self.status().is_some() {
51            return Err(WebSocketUpgradeError::new(self, ErrorKind::AlreadyExecuted));
52        }
53
54        match self.http_version() {
55            Version::Http2 | Version::Http3 => self.into_websocket_extended_connect(config).await,
56            _ => self.into_websocket_h1(config).await,
57        }
58    }
59
60    async fn into_websocket_h1(
61        mut self,
62        config: WebSocketConfig,
63    ) -> Result<WebSocketConn, WebSocketUpgradeError> {
64        self.set_websocket_upgrade_headers_h1();
65        if let Err(e) = (&mut self).await {
66            return Err(WebSocketUpgradeError::new(self, e.into()));
67        }
68        let status = self.status().expect("Response did not include status");
69        if status != Status::SwitchingProtocols {
70            return Err(WebSocketUpgradeError::new(self, ErrorKind::Status(status)));
71        }
72        let key = self
73            .request_headers()
74            .get_str(SecWebsocketKey)
75            .expect("Request did not include Sec-WebSocket-Key");
76        let accept_key = websocket_accept_hash(key);
77        if self.response_headers().get_str(SecWebsocketAccept) != Some(&accept_key) {
78            return Err(WebSocketUpgradeError::new(self, ErrorKind::InvalidAccept));
79        }
80        let peer_ip = self.peer_addr().map(|addr| addr.ip());
81        let mut conn = WebSocketConn::new(Upgrade::from(self), Some(config), Role::Client).await;
82        conn.set_peer_ip(peer_ip);
83        Ok(conn)
84    }
85
86    async fn into_websocket_extended_connect(
87        mut self,
88        config: WebSocketConfig,
89    ) -> Result<WebSocketConn, WebSocketUpgradeError> {
90        // Extended CONNECT carries `Sec-WebSocket-Version: 13` and the optional
91        // `Sec-WebSocket-Protocol`, but skips the `Sec-WebSocket-Key` / `Sec-WebSocket-Accept`
92        // SHA1 dance — those are h1-only artifacts. The `Connection: upgrade` /
93        // `Upgrade: websocket` headers are likewise h1-only and would be stripped by
94        // `finalize_headers_h2` / `_h3` even if we set them.
95        self.request_headers_mut()
96            .try_insert(SecWebsocketVersion, "13");
97        self.set_method(Method::Connect);
98        self.protocol = Some(Cow::Borrowed("websocket"));
99
100        // The peer-capability gate (server must have advertised
101        // `SETTINGS_ENABLE_CONNECT_PROTOCOL` before the client may send a `:protocol`
102        // HEADERS) lives inside the h2 client send path, where it can park on the peer's
103        // first SETTINGS *before* putting any HEADERS on the wire. A "not supported"
104        // outcome surfaces here as `Error::ExtendedConnectUnsupported`.
105        if let Err(e) = (&mut self).await {
106            let kind = match e {
107                trillium_http::Error::ExtendedConnectUnsupported => {
108                    ErrorKind::ExtendedConnectUnsupported
109                }
110                other => other.into(),
111            };
112            return Err(WebSocketUpgradeError::new(self, kind));
113        }
114
115        let status = self.status().expect("Response did not include status");
116        if status != Status::Ok {
117            return Err(WebSocketUpgradeError::new(self, ErrorKind::Status(status)));
118        }
119
120        let peer_ip = self.peer_addr().map(|addr| addr.ip());
121        let mut conn = WebSocketConn::new(Upgrade::from(self), Some(config), Role::Client).await;
122        conn.set_peer_ip(peer_ip);
123        Ok(conn)
124    }
125}
126
127/// The kind of error that occurred when attempting a websocket upgrade
128#[derive(thiserror::Error, Debug)]
129#[non_exhaustive]
130pub enum ErrorKind {
131    /// an HTTP error attempting to make the request
132    #[error(transparent)]
133    Http(#[from] trillium_http::Error),
134
135    /// Response didn't have the expected status (101 Switching Protocols for h1, 200 OK for
136    /// h2/h3 extended CONNECT).
137    #[error("Unexpected response status {0} for websocket upgrade")]
138    Status(Status),
139
140    /// Response Sec-WebSocket-Accept was missing or invalid; generally a server bug
141    #[error("Response Sec-WebSocket-Accept was missing or invalid")]
142    InvalidAccept,
143
144    /// `into_websocket` was called on a `Conn` that had already been executed (its status is
145    /// already set). The websocket upgrade *is* the execution; build the conn and call
146    /// `into_websocket` directly without awaiting first.
147    #[error(
148        "Conn::into_websocket called after execution — build the conn and await into_websocket \
149         instead of awaiting the conn separately"
150    )]
151    AlreadyExecuted,
152
153    /// The h2 or h3 peer did not advertise `SETTINGS_ENABLE_CONNECT_PROTOCOL = 1`, so the
154    /// extended-CONNECT bootstrap (RFC 8441 over h2, RFC 9220 over h3) is not available on this
155    /// connection.
156    #[error("peer does not support extended CONNECT")]
157    ExtendedConnectUnsupported,
158}
159
160/// An attempted upgrade to a WebSocket failed.
161///
162/// You can transform this back into the Conn with [`From::from`]/[`Into::into`], if you need to
163/// look at the server response.
164#[derive(Debug)]
165pub struct WebSocketUpgradeError {
166    /// The kind of error that occurred
167    pub kind: ErrorKind,
168    conn: Box<Conn>,
169}
170
171impl WebSocketUpgradeError {
172    fn new(conn: Conn, kind: ErrorKind) -> Self {
173        let conn = Box::new(conn);
174        Self { conn, kind }
175    }
176}
177
178impl From<WebSocketUpgradeError> for Conn {
179    fn from(value: WebSocketUpgradeError) -> Self {
180        *value.conn
181    }
182}
183
184impl Deref for WebSocketUpgradeError {
185    type Target = Conn;
186
187    fn deref(&self) -> &Self::Target {
188        &self.conn
189    }
190}
191impl DerefMut for WebSocketUpgradeError {
192    fn deref_mut(&mut self) -> &mut Self::Target {
193        &mut self.conn
194    }
195}
196
197impl Error for WebSocketUpgradeError {}
198
199impl Display for WebSocketUpgradeError {
200    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201        self.kind.fmt(f)
202    }
203}