// specter/websocket/client.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use crate::url::Url;
5use tokio::sync::RwLock;
6use tokio::time::timeout as tokio_timeout;
7
8use crate::cookie::CookieJar;
9use crate::headers::Headers;
10use crate::request::IntoUrl;
11use crate::timeouts::Timeouts;
12use crate::transport::connector::{AlpnProtocol, BoringConnector};
13use crate::transport::h1_h2::Client;
14
15use super::handshake::{
16    build_handshake_request, map_websocket_url, perform_handshake, HandshakeTimeouts,
17};
18use super::{WebSocket, WebSocketConfig, WebSocketError, WebSocketResult};
19
20pub struct WebSocketBuilder<'a> {
21    parts: Option<WebSocketClientParts<'a>>,
22    url: Option<Url>,
23    headers: Headers,
24    subprotocols: Vec<String>,
25    config: WebSocketConfig,
26    timeouts: HandshakeTimeouts,
27    error: Option<WebSocketError>,
28}
29
30pub(crate) struct WebSocketClientParts<'a> {
31    pub(crate) connector: &'a BoringConnector,
32    pub(crate) insecure_connector: &'a BoringConnector,
33    pub(crate) default_headers: &'a Headers,
34    pub(crate) timeouts: &'a Timeouts,
35    pub(crate) cookie_store: Option<&'a Arc<RwLock<CookieJar>>>,
36    pub(crate) danger_accept_invalid_certs: bool,
37    pub(crate) localhost_allows_invalid_certs: bool,
38}
39
40impl<'a> WebSocketBuilder<'a> {
41    pub(crate) fn from_client_parts(parts: WebSocketClientParts<'a>, url: impl IntoUrl) -> Self {
42        let (url, error) = match url.into_url() {
43            Ok(url) => (Some(url), None),
44            Err(err) => (
45                None,
46                Some(WebSocketError::Protocol {
47                    url: "<invalid>".to_string(),
48                    message: err.to_string(),
49                }),
50            ),
51        };
52
53        Self {
54            timeouts: HandshakeTimeouts {
55                connect: parts.timeouts.connect,
56                handshake: parts.timeouts.ttfb,
57            },
58            parts: Some(parts),
59            url,
60            headers: Headers::new(),
61            subprotocols: Vec::new(),
62            config: WebSocketConfig::default(),
63            error,
64        }
65    }
66
67    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
68        self.headers.insert(name, value);
69        self
70    }
71
72    pub fn headers(mut self, headers: impl Into<Headers>) -> Self {
73        self.headers = headers.into();
74        self
75    }
76
77    pub fn subprotocol(mut self, value: impl Into<String>) -> Self {
78        self.subprotocols.push(value.into());
79        self
80    }
81
82    pub fn subprotocols<I, S>(mut self, values: I) -> Self
83    where
84        I: IntoIterator<Item = S>,
85        S: Into<String>,
86    {
87        self.subprotocols.extend(values.into_iter().map(Into::into));
88        self
89    }
90
91    pub fn max_message_size(mut self, bytes: usize) -> Self {
92        self.config.max_message_size = bytes;
93        self
94    }
95
96    pub fn max_frame_size(mut self, bytes: usize) -> Self {
97        self.config.max_frame_size = bytes;
98        self
99    }
100
101    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
102        self.timeouts.connect = Some(timeout);
103        self
104    }
105
106    pub fn handshake_timeout(mut self, timeout: Duration) -> Self {
107        self.timeouts.handshake = Some(timeout);
108        self
109    }
110
111    pub fn read_timeout(mut self, timeout: Duration) -> Self {
112        self.config.read_timeout = Some(timeout);
113        self
114    }
115
116    pub fn write_timeout(mut self, timeout: Duration) -> Self {
117        self.config.write_timeout = Some(timeout);
118        self
119    }
120
121    pub async fn connect(self) -> WebSocketResult<WebSocket> {
122        if let Some(error) = self.error {
123            return Err(error);
124        }
125
126        let parts = self.parts.ok_or_else(|| WebSocketError::Protocol {
127            url: "<unknown>".to_string(),
128            message: "missing WebSocket client parts".to_string(),
129        })?;
130        let original_url = self.url.ok_or_else(|| WebSocketError::Protocol {
131            url: "<unknown>".to_string(),
132            message: "missing WebSocket URL".to_string(),
133        })?;
134        let ws_url = map_websocket_url(original_url)?;
135        let cookie_header = build_cookie_header(parts.cookie_store, &ws_url.http_equivalent).await;
136        let request = build_handshake_request(
137            ws_url.clone(),
138            parts.default_headers,
139            &self.headers,
140            &self.subprotocols,
141            cookie_header,
142        )?;
143
144        let connector = connector_for_url(&parts, &ws_url.uri);
145        let connect_fut = async {
146            if ws_url.secure {
147                connector.connect_h1_only(&ws_url.uri).await
148            } else {
149                connector.connect(&ws_url.uri).await
150            }
151        };
152        let stream = match self.timeouts.connect {
153            Some(duration) => tokio_timeout(duration, connect_fut)
154                .await
155                .map_err(|_| WebSocketError::Timeout {
156                    url: ws_url.original.to_string(),
157                    operation: format!("connect after {:?}", duration),
158                })?
159                .map_err(|err| WebSocketError::protocol(&ws_url.original, err.to_string()))?,
160            None => connect_fut
161                .await
162                .map_err(|err| WebSocketError::protocol(&ws_url.original, err.to_string()))?,
163        };
164
165        if ws_url.secure && matches!(stream.alpn_protocol(), AlpnProtocol::H2) {
166            return Err(WebSocketError::protocol(
167                &ws_url.original,
168                format!(
169                    "wss WebSocket negotiated h2 for {} despite HTTP/1.1-only ALPN",
170                    ws_url.original
171                ),
172            ));
173        }
174
175        let response = perform_handshake(
176            stream,
177            &request,
178            &self.subprotocols,
179            self.timeouts.handshake,
180        )
181        .await?;
182
183        store_cookies(
184            parts.cookie_store,
185            &response.headers,
186            &request.url.http_equivalent,
187        )
188        .await;
189
190        Ok(WebSocket::new(
191            response.stream,
192            request.url.original,
193            response.protocol,
194            self.config,
195            response.buffered,
196        ))
197    }
198}
199
200impl Client {
201    /// Integration shim for `transport::h1_h2`.
202    ///
203    /// `Client` fields are private to `transport::h1_h2`, so that module should expose
204    /// the public `Client::websocket(url)` method by calling this constructor with its
205    /// private fields. Keeping the actual builder in `src/websocket/client.rs` avoids
206    /// routing upgraded streams through the normal HTTP execute/pool path.
207    pub(crate) fn websocket_with_parts<'a>(
208        parts: WebSocketClientParts<'a>,
209        url: impl IntoUrl,
210    ) -> WebSocketBuilder<'a> {
211        WebSocketBuilder::from_client_parts(parts, url)
212    }
213}
214
215async fn build_cookie_header(
216    cookie_store: Option<&Arc<RwLock<CookieJar>>>,
217    http_equivalent_url: &Url,
218) -> Option<String> {
219    let jar = cookie_store?;
220    jar.read()
221        .await
222        .build_cookie_header(http_equivalent_url.as_str())
223}
224
225async fn store_cookies(
226    cookie_store: Option<&Arc<RwLock<CookieJar>>>,
227    headers: &Headers,
228    http_equivalent_url: &Url,
229) {
230    if let Some(jar) = cookie_store {
231        jar.write()
232            .await
233            .store_from_headers(headers, http_equivalent_url.as_str());
234    }
235}
236
237fn connector_for_url<'a>(
238    parts: &'a WebSocketClientParts<'a>,
239    uri: &http::Uri,
240) -> &'a BoringConnector {
241    if parts.danger_accept_invalid_certs {
242        return parts.insecure_connector;
243    }
244
245    if parts.localhost_allows_invalid_certs {
246        if let Some(host) = uri.host() {
247            if is_localhost(host) {
248                return parts.insecure_connector;
249            }
250        }
251    }
252
253    parts.connector
254}
255
/// Reports whether `host` names the local machine.
///
/// Accepts the name `localhost` in any ASCII case, any IPv4 or IPv6 loopback
/// address, and IPv6 loopback literals in the bracketed form that URI
/// authorities use (`[::1]`). Everything else, including names that merely
/// start or end with `localhost`, is rejected.
fn is_localhost(host: &str) -> bool {
    // A bracketed host can only be an IPv6 literal; compare by address rather
    // than by text so every spelling of the loopback address is recognised.
    if let Some(literal) = host
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'))
    {
        return literal
            .parse::<std::net::Ipv6Addr>()
            .map(|addr| addr.is_loopback())
            .unwrap_or(false);
    }

    host.eq_ignore_ascii_case("localhost")
        || host
            .parse::<std::net::IpAddr>()
            .map(|addr| addr.is_loopback())
            .unwrap_or(false)
}