1use std::sync::Arc;
2use std::time::Duration;
3
4use crate::url::Url;
5use tokio::sync::RwLock;
6use tokio::time::timeout as tokio_timeout;
7
8use crate::cookie::CookieJar;
9use crate::headers::Headers;
10use crate::request::IntoUrl;
11use crate::timeouts::Timeouts;
12use crate::transport::connector::{AlpnProtocol, BoringConnector};
13use crate::transport::h1_h2::Client;
14
15use super::handshake::{
16 build_handshake_request, map_websocket_url, perform_handshake, HandshakeTimeouts,
17};
18use super::{WebSocket, WebSocketConfig, WebSocketError, WebSocketResult};
19
20pub struct WebSocketBuilder<'a> {
21 parts: Option<WebSocketClientParts<'a>>,
22 url: Option<Url>,
23 headers: Headers,
24 subprotocols: Vec<String>,
25 config: WebSocketConfig,
26 timeouts: HandshakeTimeouts,
27 error: Option<WebSocketError>,
28}
29
30pub(crate) struct WebSocketClientParts<'a> {
31 pub(crate) connector: &'a BoringConnector,
32 pub(crate) insecure_connector: &'a BoringConnector,
33 pub(crate) default_headers: &'a Headers,
34 pub(crate) timeouts: &'a Timeouts,
35 pub(crate) cookie_store: Option<&'a Arc<RwLock<CookieJar>>>,
36 pub(crate) danger_accept_invalid_certs: bool,
37 pub(crate) localhost_allows_invalid_certs: bool,
38}
39
40impl<'a> WebSocketBuilder<'a> {
41 pub(crate) fn from_client_parts(parts: WebSocketClientParts<'a>, url: impl IntoUrl) -> Self {
42 let (url, error) = match url.into_url() {
43 Ok(url) => (Some(url), None),
44 Err(err) => (
45 None,
46 Some(WebSocketError::Protocol {
47 url: "<invalid>".to_string(),
48 message: err.to_string(),
49 }),
50 ),
51 };
52
53 Self {
54 timeouts: HandshakeTimeouts {
55 connect: parts.timeouts.connect,
56 handshake: parts.timeouts.ttfb,
57 },
58 parts: Some(parts),
59 url,
60 headers: Headers::new(),
61 subprotocols: Vec::new(),
62 config: WebSocketConfig::default(),
63 error,
64 }
65 }
66
67 pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
68 self.headers.insert(name, value);
69 self
70 }
71
72 pub fn headers(mut self, headers: impl Into<Headers>) -> Self {
73 self.headers = headers.into();
74 self
75 }
76
77 pub fn subprotocol(mut self, value: impl Into<String>) -> Self {
78 self.subprotocols.push(value.into());
79 self
80 }
81
82 pub fn subprotocols<I, S>(mut self, values: I) -> Self
83 where
84 I: IntoIterator<Item = S>,
85 S: Into<String>,
86 {
87 self.subprotocols.extend(values.into_iter().map(Into::into));
88 self
89 }
90
91 pub fn max_message_size(mut self, bytes: usize) -> Self {
92 self.config.max_message_size = bytes;
93 self
94 }
95
96 pub fn max_frame_size(mut self, bytes: usize) -> Self {
97 self.config.max_frame_size = bytes;
98 self
99 }
100
101 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
102 self.timeouts.connect = Some(timeout);
103 self
104 }
105
106 pub fn handshake_timeout(mut self, timeout: Duration) -> Self {
107 self.timeouts.handshake = Some(timeout);
108 self
109 }
110
111 pub fn read_timeout(mut self, timeout: Duration) -> Self {
112 self.config.read_timeout = Some(timeout);
113 self
114 }
115
116 pub fn write_timeout(mut self, timeout: Duration) -> Self {
117 self.config.write_timeout = Some(timeout);
118 self
119 }
120
121 pub async fn connect(self) -> WebSocketResult<WebSocket> {
122 if let Some(error) = self.error {
123 return Err(error);
124 }
125
126 let parts = self.parts.ok_or_else(|| WebSocketError::Protocol {
127 url: "<unknown>".to_string(),
128 message: "missing WebSocket client parts".to_string(),
129 })?;
130 let original_url = self.url.ok_or_else(|| WebSocketError::Protocol {
131 url: "<unknown>".to_string(),
132 message: "missing WebSocket URL".to_string(),
133 })?;
134 let ws_url = map_websocket_url(original_url)?;
135 let cookie_header = build_cookie_header(parts.cookie_store, &ws_url.http_equivalent).await;
136 let request = build_handshake_request(
137 ws_url.clone(),
138 parts.default_headers,
139 &self.headers,
140 &self.subprotocols,
141 cookie_header,
142 )?;
143
144 let connector = connector_for_url(&parts, &ws_url.uri);
145 let connect_fut = async {
146 if ws_url.secure {
147 connector.connect_h1_only(&ws_url.uri).await
148 } else {
149 connector.connect(&ws_url.uri).await
150 }
151 };
152 let stream = match self.timeouts.connect {
153 Some(duration) => tokio_timeout(duration, connect_fut)
154 .await
155 .map_err(|_| WebSocketError::Timeout {
156 url: ws_url.original.to_string(),
157 operation: format!("connect after {:?}", duration),
158 })?
159 .map_err(|err| WebSocketError::protocol(&ws_url.original, err.to_string()))?,
160 None => connect_fut
161 .await
162 .map_err(|err| WebSocketError::protocol(&ws_url.original, err.to_string()))?,
163 };
164
165 if ws_url.secure && matches!(stream.alpn_protocol(), AlpnProtocol::H2) {
166 return Err(WebSocketError::protocol(
167 &ws_url.original,
168 format!(
169 "wss WebSocket negotiated h2 for {} despite HTTP/1.1-only ALPN",
170 ws_url.original
171 ),
172 ));
173 }
174
175 let response = perform_handshake(
176 stream,
177 &request,
178 &self.subprotocols,
179 self.timeouts.handshake,
180 )
181 .await?;
182
183 store_cookies(
184 parts.cookie_store,
185 &response.headers,
186 &request.url.http_equivalent,
187 )
188 .await;
189
190 Ok(WebSocket::new(
191 response.stream,
192 request.url.original,
193 response.protocol,
194 self.config,
195 response.buffered,
196 ))
197 }
198}
199
200impl Client {
201 pub(crate) fn websocket_with_parts<'a>(
208 parts: WebSocketClientParts<'a>,
209 url: impl IntoUrl,
210 ) -> WebSocketBuilder<'a> {
211 WebSocketBuilder::from_client_parts(parts, url)
212 }
213}
214
215async fn build_cookie_header(
216 cookie_store: Option<&Arc<RwLock<CookieJar>>>,
217 http_equivalent_url: &Url,
218) -> Option<String> {
219 let jar = cookie_store?;
220 jar.read()
221 .await
222 .build_cookie_header(http_equivalent_url.as_str())
223}
224
225async fn store_cookies(
226 cookie_store: Option<&Arc<RwLock<CookieJar>>>,
227 headers: &Headers,
228 http_equivalent_url: &Url,
229) {
230 if let Some(jar) = cookie_store {
231 jar.write()
232 .await
233 .store_from_headers(headers, http_equivalent_url.as_str());
234 }
235}
236
237fn connector_for_url<'a>(
238 parts: &'a WebSocketClientParts<'a>,
239 uri: &http::Uri,
240) -> &'a BoringConnector {
241 if parts.danger_accept_invalid_certs {
242 return parts.insecure_connector;
243 }
244
245 if parts.localhost_allows_invalid_certs {
246 if let Some(host) = uri.host() {
247 if is_localhost(host) {
248 return parts.insecure_connector;
249 }
250 }
251 }
252
253 parts.connector
254}
255
/// Returns `true` when `host` names the local machine.
///
/// Accepted forms:
/// - `localhost`, compared ASCII case-insensitively (host names are
///   case-insensitive);
/// - any IPv4 loopback address (`127.0.0.0/8`, e.g. `127.0.0.1`);
/// - the IPv6 loopback address `::1` in any textual notation, bare or wrapped
///   in square brackets. `http::Uri::host` returns IPv6 literals *with*
///   their brackets (`"[::1]"`), so an exact match on `"::1"` alone would
///   never recognise a real `wss://[::1]/` URI.
///
/// Everything else is rejected, including names that merely contain
/// `localhost` (such as `localhost.example.com`) and bracketed non-IPv6 text.
fn is_localhost(host: &str) -> bool {
    if host.eq_ignore_ascii_case("localhost") {
        return true;
    }

    match host
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'))
    {
        // A bracketed URI IP-literal may only hold an IPv6 address.
        Some(literal) => matches!(
            literal.parse::<std::net::Ipv6Addr>(),
            Ok(addr) if addr.is_loopback()
        ),
        // A bare host may be either address family (also keeps bare "::1").
        None => matches!(
            host.parse::<std::net::IpAddr>(),
            Ok(addr) if addr.is_loopback()
        ),
    }
}