1use std::{
4 io::{Read, Write},
5 net::{SocketAddr, TcpStream, ToSocketAddrs},
6 result::Result as StdResult,
7};
8
9use http::{request::Parts, HeaderName, Uri};
10
11use crate::{
12 error::{Error, Result, UrlError},
13 handshake::{
14 client::{generate_key, ClientHandshake, Request, Response},
15 core::HandshakeError,
16 },
17 protocol::{config::WebSocketConfig, websocket::WebSocket},
18 stream::{Mode, NoDelay, SimplifiedStream},
19};
20
21pub fn connect_with_config<Req: IntoClientRequest>(
40 req: Req,
41 config: Option<WebSocketConfig>,
42 max_redirects: u8,
43) -> Result<(WebSocket<SimplifiedStream<TcpStream>>, Response)> {
44 fn try_client_handshake(
45 request: Request,
46 config: Option<WebSocketConfig>,
47 ) -> Result<(WebSocket<SimplifiedStream<TcpStream>>, Response)> {
48 let uri = request.uri();
49 let mode = uri_mode(uri)?;
50
51 #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
52 if let Mode::Tls = mode {
53 return Err(Error::Url(UrlError::TlsFeatureNotEnabled));
54 }
55
56 let host = request.uri().host().ok_or(Error::Url(UrlError::MissingHost))?;
57 let host = if host.starts_with('[') { &host[1..host.len() - 1] } else { host };
58 let port = uri.port_u16().unwrap_or(match mode {
59 Mode::Plain => 80,
60 Mode::Tls => 443,
61 });
62 let addresses = (host, port).to_socket_addrs()?;
63
64 let mut stream = connect_to_some(addresses.as_slice(), request.uri())?;
65 NoDelay::set_nodelay(&mut stream, true)?;
66
67 #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
68 let client = client_with_config(request, SimplifiedStream::Plain(stream), config);
69
70 #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
71 let client = crate::tls::client_tls_with_config(request, stream, config, None);
72
73 client.map_err(|e| match e {
74 HandshakeError::Failure(f) => f,
75 HandshakeError::Interrupted(_) => panic!("Bug: blockign handshake not blocked"),
76 })
77 }
78
79 fn create_req(parts: &Parts, uri: &Uri) -> Request {
80 let mut builder =
81 Request::builder().uri(uri.clone()).method(parts.method.clone()).version(parts.version);
82
83 *builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone();
84 builder.body(()).expect("Failed to create `Request`")
85 }
86
87 let (parts, _) = req.into_client_request()?.into_parts();
88 let mut uri = parts.uri.clone();
89
90 for attempt in 0..=max_redirects {
91 let request = create_req(&parts, &uri);
92
93 match try_client_handshake(request, config) {
94 Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => {
95 if let Some(location) = res.headers().get("Location") {
96 uri = location.to_str()?.parse::<Uri>()?;
97 continue;
98 } else {
99 return Err(Error::Http(res));
100 }
101 }
102 other => return other,
103 }
104 }
105
106 panic!("Bug in redirect handler")
107}
108
109pub fn connect<Req: IntoClientRequest>(
122 req: Req,
123) -> Result<(WebSocket<SimplifiedStream<TcpStream>>, Response)> {
124 connect_with_config(req, None, 3)
125}
126
127fn connect_to_some(addresses: &[SocketAddr], uri: &Uri) -> Result<TcpStream> {
128 for address in addresses {
129 if let Ok(stream) = TcpStream::connect(address) {
130 return Ok(stream);
131 }
132 }
133
134 Err(Error::Url(UrlError::UnableToConnect(uri.to_string())))
135}
136
137pub fn client_with_config<Stream, Req>(
144 req: Req,
145 stream: Stream,
146 config: Option<WebSocketConfig>,
147) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
148where
149 Stream: Read + Write,
150 Req: IntoClientRequest,
151{
152 ClientHandshake::start(stream, req.into_client_request()?, config)?.handshake()
153}
154
155pub fn client<Stream, Req>(
161 req: Req,
162 stream: Stream,
163) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
164where
165 Stream: Read + Write,
166 Req: IntoClientRequest,
167{
168 client_with_config(req, stream, None)
169}
170
171pub fn uri_mode(uri: &Uri) -> Result<Mode> {
176 match uri.scheme_str() {
177 Some("ws") => Ok(Mode::Plain),
178 Some("wss") => Ok(Mode::Tls),
179 _ => Err(Error::Url(UrlError::UnsupportedScheme)),
180 }
181}
182
/// Conversion into a client handshake [`Request`].
///
/// `connect`, `connect_with_config`, `client` and `client_with_config` take any
/// value implementing this trait. This module implements it for:
/// - string types, [`Uri`] and ready-made [`Request`]s;
/// - `httparse` requests and [`ClientRequestBuilder`];
/// - `url::Url`, behind the `url` feature.
pub trait IntoClientRequest {
    /// Converts `self` into a [`Request`] for the WebSocket client handshake.
    ///
    /// # Errors
    ///
    /// Returns an error when the value cannot be turned into a valid request. The
    /// exact causes depend on the implementation, e.g. an unparsable URI or a
    /// missing host.
    fn into_client_request(self) -> Result<Request>;
}
195
196impl IntoClientRequest for &str {
197 fn into_client_request(self) -> Result<Request> {
198 self.parse::<Uri>()?.into_client_request()
199 }
200}
201
202impl IntoClientRequest for &String {
203 fn into_client_request(self) -> Result<Request> {
204 <&str as IntoClientRequest>::into_client_request(self)
205 }
206}
207
208impl IntoClientRequest for String {
209 fn into_client_request(self) -> Result<Request> {
210 <&str as IntoClientRequest>::into_client_request(&self)
211 }
212}
213
214impl IntoClientRequest for &Uri {
215 fn into_client_request(self) -> Result<Request> {
216 self.clone().into_client_request()
217 }
218}
219
220impl IntoClientRequest for Uri {
221 fn into_client_request(self) -> Result<Request> {
222 let authority = self.authority().ok_or(Error::Url(UrlError::MissingHost))?.as_str();
223 let host = authority
224 .find('@')
225 .map(|index| authority.split_at(index + 1).1)
226 .unwrap_or_else(|| authority);
227
228 if host.is_empty() {
229 return Err(Error::Url(UrlError::EmptyHost));
230 }
231
232 let req = Request::builder()
233 .method("GET")
234 .header("Host", host)
235 .header("Connection", "Upgrade")
236 .header("Upgrade", "websocket")
237 .header("Sec-WebSocket-Version", "13")
238 .header("Sec-WebSocket-Key", generate_key())
239 .uri(self)
240 .body(())?;
241
242 Ok(req)
243 }
244}
245
/// Converts through the URL's string form and the `&str` implementation.
#[cfg(feature = "url")]
impl IntoClientRequest for &url::Url {
    fn into_client_request(self) -> Result<Request> {
        <&str as IntoClientRequest>::into_client_request(self.as_str())
    }
}
252
/// Converts through the URL's string form, like the `&url::Url` implementation.
#[cfg(feature = "url")]
impl IntoClientRequest for url::Url {
    fn into_client_request(self) -> Result<Request> {
        let text = self.as_str();
        text.into_client_request()
    }
}
259
260impl IntoClientRequest for Request {
261 fn into_client_request(self) -> Result<Request> {
262 Ok(self)
263 }
264}
265
266impl IntoClientRequest for httparse::Request<'_, '_> {
267 fn into_client_request(self) -> Result<Request> {
268 use crate::handshake::headers::FromHttparse;
269 Request::from_httparse(self)
270 }
271}
272
/// Builder for a client handshake [`Request`] with extra headers and subprotocols.
///
/// Converting it with [`IntoClientRequest`] works in three steps:
/// 1. start from the request built for `uri`;
/// 2. append every entry of `additional_headers`;
/// 3. if any subprotocols were given, add a `Sec-WebSocket-Protocol` header listing them.
#[derive(Debug, Clone)]
pub struct ClientRequestBuilder {
    /// Target WebSocket URI.
    uri: Uri,
    /// Extra `(name, value)` header pairs, appended in insertion order.
    additional_headers: Vec<(String, String)>,
    /// Subprotocols to offer, joined with `, ` into `Sec-WebSocket-Protocol`.
    subprotocols: Vec<String>,
}
298
299impl ClientRequestBuilder {
300 #[must_use]
302 pub const fn new(uri: Uri) -> Self {
303 Self { uri, additional_headers: Vec::new(), subprotocols: Vec::new() }
304 }
305
306 pub fn with_header<K, V>(mut self, key: K, value: V) -> Self
308 where
309 K: Into<String>,
310 V: Into<String>,
311 {
312 self.additional_headers.push((key.into(), value.into()));
313 self
314 }
315
316 pub fn with_subprotocol<P>(mut self, protocol: P) -> Self
318 where
319 P: Into<String>,
320 {
321 self.subprotocols.push(protocol.into());
322 self
323 }
324}
325
326impl IntoClientRequest for ClientRequestBuilder {
327 fn into_client_request(self) -> Result<Request> {
328 let mut req = self.uri.into_client_request()?;
329 let headers = req.headers_mut();
330
331 for (k, v) in self.additional_headers {
332 let key = HeaderName::try_from(k)?;
333 let value = v.parse()?;
334
335 headers.append(key, value);
336 }
337
338 if !self.subprotocols.is_empty() {
339 let protocols = self.subprotocols.join(", ").parse()?;
340 headers.append("Sec-WebSocket-Protocol", protocols);
341 }
342
343 Ok(req)
344 }
345}