blitz_ws/
client.rs

1//! Utilities to connect to a WebSocket as a client
2
3use std::{
4    io::{Read, Write},
5    net::{SocketAddr, TcpStream, ToSocketAddrs},
6    result::Result as StdResult,
7};
8
9use http::{request::Parts, HeaderName, Uri};
10
11use crate::{
12    error::{Error, Result, UrlError},
13    handshake::{
14        client::{generate_key, ClientHandshake, Request, Response},
15        core::HandshakeError,
16    },
17    protocol::{config::WebSocketConfig, websocket::WebSocket},
18    stream::{Mode, NoDelay, SimplifiedStream},
19};
20
21/// Connect to the given WebSocket in blocking mode.
22///
23/// Uses a websocket configuration passed as an argument to the function. Calling it with `None` is
24/// equal to calling `connect()` function.
25///
26/// The URL may be either ws:// or wss://.
27/// To support wss:// URLs, you must activate the TLS feature on the crate level. Please refer to the
28/// project's [README][readme] for more information on available features.
29///
30/// This function "just works" for those who wants a simple blocking solution
31/// similar to `std::net::TcpStream`. If you want a non-blocking or other
32/// custom stream, call `client` instead.
33///
34/// This function uses `native_tls` or `rustls` to do TLS depending on the feature flags enabled. If
35/// you want to use other TLS libraries, use `client` instead. There is no need to enable any of
36/// the `*-tls` features if you don't call `connect` since it's the only function that uses them.
37///
38/// [readme]: https://github.com/risuleia/blitz/#features
39pub fn connect_with_config<Req: IntoClientRequest>(
40    req: Req,
41    config: Option<WebSocketConfig>,
42    max_redirects: u8,
43) -> Result<(WebSocket<SimplifiedStream<TcpStream>>, Response)> {
44    fn try_client_handshake(
45        request: Request,
46        config: Option<WebSocketConfig>,
47    ) -> Result<(WebSocket<SimplifiedStream<TcpStream>>, Response)> {
48        let uri = request.uri();
49        let mode = uri_mode(uri)?;
50
51        #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
52        if let Mode::Tls = mode {
53            return Err(Error::Url(UrlError::TlsFeatureNotEnabled));
54        }
55
56        let host = request.uri().host().ok_or(Error::Url(UrlError::MissingHost))?;
57        let host = if host.starts_with('[') { &host[1..host.len() - 1] } else { host };
58        let port = uri.port_u16().unwrap_or(match mode {
59            Mode::Plain => 80,
60            Mode::Tls => 443,
61        });
62        let addresses = (host, port).to_socket_addrs()?;
63
64        let mut stream = connect_to_some(addresses.as_slice(), request.uri())?;
65        NoDelay::set_nodelay(&mut stream, true)?;
66
67        #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
68        let client = client_with_config(request, SimplifiedStream::Plain(stream), config);
69
70        #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
71        let client = crate::tls::client_tls_with_config(request, stream, config, None);
72
73        client.map_err(|e| match e {
74            HandshakeError::Failure(f) => f,
75            HandshakeError::Interrupted(_) => panic!("Bug: blockign handshake not blocked"),
76        })
77    }
78
79    fn create_req(parts: &Parts, uri: &Uri) -> Request {
80        let mut builder =
81            Request::builder().uri(uri.clone()).method(parts.method.clone()).version(parts.version);
82
83        *builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone();
84        builder.body(()).expect("Failed to create `Request`")
85    }
86
87    let (parts, _) = req.into_client_request()?.into_parts();
88    let mut uri = parts.uri.clone();
89
90    for attempt in 0..=max_redirects {
91        let request = create_req(&parts, &uri);
92
93        match try_client_handshake(request, config) {
94            Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => {
95                if let Some(location) = res.headers().get("Location") {
96                    uri = location.to_str()?.parse::<Uri>()?;
97                    continue;
98                } else {
99                    return Err(Error::Http(res));
100                }
101            }
102            other => return other,
103        }
104    }
105
106    panic!("Bug in redirect handler")
107}
108
109/// Connect to the given WebSocket in blocking mode.
110///
111/// The URL may be either ws:// or wss://.
112/// To support wss:// URLs, feature `native-tls` or `rustls-tls` must be turned on.
113///
114/// This function "just works" for those who wants a simple blocking solution
115/// similar to `std::net::TcpStream`. If you want a non-blocking or other
116/// custom stream, call `client` instead.
117///
118/// This function uses `native_tls` or `rustls` to do TLS depending on the feature flags enabled. If
119/// you want to use other TLS libraries, use `client` instead. There is no need to enable any of
120/// the `*-tls` features if you don't call `connect` since it's the only function that uses them.
121pub fn connect<Req: IntoClientRequest>(
122    req: Req,
123) -> Result<(WebSocket<SimplifiedStream<TcpStream>>, Response)> {
124    connect_with_config(req, None, 3)
125}
126
127fn connect_to_some(addresses: &[SocketAddr], uri: &Uri) -> Result<TcpStream> {
128    for address in addresses {
129        if let Ok(stream) = TcpStream::connect(address) {
130            return Ok(stream);
131        }
132    }
133
134    Err(Error::Url(UrlError::UnableToConnect(uri.to_string())))
135}
136
137/// Do the client handshake over the given stream given a web socket configuration. Passing `None`
138/// as configuration is equal to calling `client()` function.
139///
140/// Use this function if you need a nonblocking handshake support or if you
141/// want to use a custom stream like `mio::net::TcpStream` or `openssl::ssl::SslStream`.
142/// Any stream supporting `Read + Write` will do.
143pub fn client_with_config<Stream, Req>(
144    req: Req,
145    stream: Stream,
146    config: Option<WebSocketConfig>,
147) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
148where
149    Stream: Read + Write,
150    Req: IntoClientRequest,
151{
152    ClientHandshake::start(stream, req.into_client_request()?, config)?.handshake()
153}
154
155/// Do the client handshake over the given stream.
156///
157/// Use this function if you need a nonblocking handshake support or if you
158/// want to use a custom stream like `mio::net::TcpStream` or `openssl::ssl::SslStream`.
159/// Any stream supporting `Read + Write` will do.
160pub fn client<Stream, Req>(
161    req: Req,
162    stream: Stream,
163) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
164where
165    Stream: Read + Write,
166    Req: IntoClientRequest,
167{
168    client_with_config(req, stream, None)
169}
170
171/// Get the mode of the given URL.
172///
173/// This function may be used to ease the creation of custom TLS streams
174/// in non-blocking algorithms or for use with TLS libraries other than `native_tls` or `rustls`.
175pub fn uri_mode(uri: &Uri) -> Result<Mode> {
176    match uri.scheme_str() {
177        Some("ws") => Ok(Mode::Plain),
178        Some("wss") => Ok(Mode::Tls),
179        _ => Err(Error::Url(UrlError::UnsupportedScheme)),
180    }
181}
182
183/// Trait for converting various types into HTTP requests used for a client connection.
184///
185/// This trait is implemented by default for string slices, strings, `http::Uri` and
186/// `http::Request<()>`. Note that the implementation for `http::Request<()>` is trivial and will
187/// simply take your request and pass it as is further without altering any headers or URLs, so
188/// be aware of this. If you just want to connect to the endpoint with a certain URL, better pass
189/// a regular string containing the URL in which case `tungstenite-rs` will take care for generating
190/// the proper `http::Request<()>` for you.
191pub trait IntoClientRequest {
192    /// Convert into a `Request` that can be used for a client connection.
193    fn into_client_request(self) -> Result<Request>;
194}
195
196impl IntoClientRequest for &str {
197    fn into_client_request(self) -> Result<Request> {
198        self.parse::<Uri>()?.into_client_request()
199    }
200}
201
202impl IntoClientRequest for &String {
203    fn into_client_request(self) -> Result<Request> {
204        <&str as IntoClientRequest>::into_client_request(self)
205    }
206}
207
208impl IntoClientRequest for String {
209    fn into_client_request(self) -> Result<Request> {
210        <&str as IntoClientRequest>::into_client_request(&self)
211    }
212}
213
214impl IntoClientRequest for &Uri {
215    fn into_client_request(self) -> Result<Request> {
216        self.clone().into_client_request()
217    }
218}
219
220impl IntoClientRequest for Uri {
221    fn into_client_request(self) -> Result<Request> {
222        let authority = self.authority().ok_or(Error::Url(UrlError::MissingHost))?.as_str();
223        let host = authority
224            .find('@')
225            .map(|index| authority.split_at(index + 1).1)
226            .unwrap_or_else(|| authority);
227
228        if host.is_empty() {
229            return Err(Error::Url(UrlError::EmptyHost));
230        }
231
232        let req = Request::builder()
233            .method("GET")
234            .header("Host", host)
235            .header("Connection", "Upgrade")
236            .header("Upgrade", "websocket")
237            .header("Sec-WebSocket-Version", "13")
238            .header("Sec-WebSocket-Key", generate_key())
239            .uri(self)
240            .body(())?;
241
242        Ok(req)
243    }
244}
245
#[cfg(feature = "url")]
impl IntoClientRequest for &url::Url {
    fn into_client_request(self) -> Result<Request> {
        // Go through the textual form so the `&str` implementation does the parsing.
        let text: &str = self.as_str();
        text.into_client_request()
    }
}
252
#[cfg(feature = "url")]
impl IntoClientRequest for url::Url {
    fn into_client_request(self) -> Result<Request> {
        // Delegate to the borrowed-`Url` implementation.
        <&url::Url as IntoClientRequest>::into_client_request(&self)
    }
}
259
260impl IntoClientRequest for Request {
261    fn into_client_request(self) -> Result<Request> {
262        Ok(self)
263    }
264}
265
266impl IntoClientRequest for httparse::Request<'_, '_> {
267    fn into_client_request(self) -> Result<Request> {
268        use crate::handshake::headers::FromHttparse;
269        Request::from_httparse(self)
270    }
271}
272
273/// Builder for a custom [`IntoClientRequest`] with options to add
274/// custom additional headers and sub protocols.
275///
276/// # Example
277///
278/// ```rust no_run
279/// # use crate::*;
280/// use http::Uri;
281/// use blitz::{connect, ClientRequestBuilder};
282///
283/// let uri: Uri = "ws://localhost:3012/socket".parse().unwrap();
284/// let token = "my_jwt_token";
285/// let builder = ClientRequestBuilder::new(uri)
286///     .with_header("Authorization", format!("Bearer {token}"))
287///     .with_subprotocol("my_sub_protocol");
288/// let socket = connect(builder).unwrap();
289/// ```
290#[derive(Debug, Clone)]
291pub struct ClientRequestBuilder {
292    uri: Uri,
293    /// Additional [`Request`] handshake headers
294    additional_headers: Vec<(String, String)>,
295    /// Handshake subprotocols
296    subprotocols: Vec<String>,
297}
298
299impl ClientRequestBuilder {
300    /// Initializes an empty request builder
301    #[must_use]
302    pub const fn new(uri: Uri) -> Self {
303        Self { uri, additional_headers: Vec::new(), subprotocols: Vec::new() }
304    }
305
306    /// Adds (`key`, `value`) as an additional header to the handshake request
307    pub fn with_header<K, V>(mut self, key: K, value: V) -> Self
308    where
309        K: Into<String>,
310        V: Into<String>,
311    {
312        self.additional_headers.push((key.into(), value.into()));
313        self
314    }
315
316    /// Adds `protocol` to the handshake request subprotocols (`Sec-WebSocket-Protocol`)
317    pub fn with_subprotocol<P>(mut self, protocol: P) -> Self
318    where
319        P: Into<String>,
320    {
321        self.subprotocols.push(protocol.into());
322        self
323    }
324}
325
326impl IntoClientRequest for ClientRequestBuilder {
327    fn into_client_request(self) -> Result<Request> {
328        let mut req = self.uri.into_client_request()?;
329        let headers = req.headers_mut();
330
331        for (k, v) in self.additional_headers {
332            let key = HeaderName::try_from(k)?;
333            let value = v.parse()?;
334
335            headers.append(key, value);
336        }
337
338        if !self.subprotocols.is_empty() {
339            let protocols = self.subprotocols.join(", ").parse()?;
340            headers.append("Sec-WebSocket-Protocol", protocols);
341        }
342
343        Ok(req)
344    }
345}