blitz_ws/
tls.rs

1//! Connection helper
2
3use std::io::{Read, Write};
4
5#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
6use crate::error::{Error, UrlError};
7use crate::{
8    client::{client_with_config, uri_mode, IntoClientRequest},
9    error::Result,
10    handshake::{
11        client::{ClientHandshake, Response},
12        core::HandshakeError,
13    },
14    protocol::{config::WebSocketConfig, websocket::WebSocket},
15    stream::SimplifiedStream,
16};
17
/// A connector that can be used when establishing connections, allowing control over whether
/// `native-tls` or `rustls` is used to create a TLS connection. TLS can also be disabled with
/// the `Plain` variant.
///
/// Variants other than `Plain` only exist when the matching cargo feature
/// (`native-tls` / `__rustls-tls`) is enabled.
#[non_exhaustive]
pub enum Connector {
    /// Plain (non-TLS) connector.
    Plain,

    /// `native-tls` TLS connector.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls_crate::TlsConnector),

    /// `rustls` TLS connector, holding a reference-counted client configuration.
    #[cfg(feature = "__rustls-tls")]
    Rustls(std::sync::Arc<rustls::ClientConfig>),
}

/// Formats only the variant name. This lets `Connector` be debug-printed, and embedded in
/// `#[derive(Debug)]` types, without depending on the wrapped backend configuration's own
/// `Debug` output.
impl std::fmt::Debug for Connector {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Connector::Plain => "Plain",
            #[cfg(feature = "native-tls")]
            Connector::NativeTls(_) => "NativeTls",
            #[cfg(feature = "__rustls-tls")]
            Connector::Rustls(_) => "Rustls",
        };
        f.write_str(name)
    }
}
35
36mod encryption {
37    #[cfg(feature = "native-tls")]
38    pub mod native_tls {
39        use crate::{
40            error::{Error, Result, TlsError},
41            stream::{Mode, SimplifiedStream},
42        };
43        use native_tls_crate::{HandshakeError as TlsHandshakeError, TlsConnector};
44        use std::io::{Read, Write};
45
46        pub fn wrap_stream<S>(
47            socket: S,
48            domain: &str,
49            mode: Mode,
50            tls_connection: Option<TlsConnector>,
51        ) -> Result<SimplifiedStream<S>>
52        where
53            S: Read + Write,
54        {
55            match mode {
56                Mode::Plain => Ok(SimplifiedStream::Plain(socket)),
57                Mode::Tls => {
58                    let try_connector = tls_connection.map_or_else(TlsConnector::new, Ok);
59                    let connector = try_connector.map_err(TlsError::Native)?;
60                    let connected = connector.connect(domain, socket);
61
62                    match connected {
63                        Err(e) => match e {
64                            TlsHandshakeError::Failure(f) => Err(Error::Tls(f.into())),
65                            TlsHandshakeError::WouldBlock(_) => {
66                                panic!("Bug: TLS handshake not blocked")
67                            }
68                        },
69                        Ok(s) => Ok(SimplifiedStream::NativeTls(s)),
70                    }
71                }
72            }
73        }
74    }
75
76    #[cfg(feature = "__rustls-tls")]
77    pub mod rustls {
78        use crate::{
79            error::{Result, TlsError},
80            stream::{Mode, SimplifiedStream},
81        };
82        use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned};
83        use rustls_pki_types::ServerName;
84        use std::{
85            io::{Read, Write},
86            sync::Arc,
87        };
88
89        pub fn wrap_stream<S>(
90            socket: S,
91            domain: &str,
92            mode: Mode,
93            tls_connector: Option<Arc<ClientConfig>>,
94        ) -> Result<SimplifiedStream<S>>
95        where
96            S: Read + Write,
97        {
98            match mode {
99                Mode::Plain => Ok(SimplifiedStream::Plain(socket)),
100                Mode::Tls => {
101                    let config = match tls_connector {
102                        Some(config) => config,
103                        None => {
104                            #[allow(unused_mut)]
105                            let mut root_store = RootCertStore::empty();
106
107                            #[cfg(feature = "rustls-tls-native-roots")]
108                            {
109                                let rustls_native_certs::CertificateResult {
110                                    certs, errors, ..
111                                } = rustls_native_certs::load_native_certs();
112
113                                // #[cfg(not(feature = "rustls-tls-webpki-roots"))]
114                                if certs.is_empty() {
115                                    return Err(std::io::Error::new(
116                                        std::io::ErrorKind::NotFound,
117                                        format!("No native root CA certificates found (errors: {errors:?})")
118                                    ).into());
119                                }
120
121                                // let total = certs.len();
122                                // let (num_added, num_ignored) = root_store.add_parsable_certificates(certs);
123                            }
124
125                            #[cfg(feature = "rustls-tls-webpki-roots")]
126                            {
127                                root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
128                            }
129
130                            Arc::new(
131                                ClientConfig::builder()
132                                    .with_root_certificates(root_store)
133                                    .with_no_client_auth(),
134                            )
135                        }
136                    };
137
138                    let domain = ServerName::try_from(domain)
139                        .map_err(|_| TlsError::InvalidDnsName)?
140                        .to_owned();
141
142                    let client = ClientConnection::new(config, domain).map_err(TlsError::Rustls)?;
143                    let stream = StreamOwned::new(client, socket);
144
145                    Ok(SimplifiedStream::Rustls(stream))
146                }
147            }
148        }
149    }
150
151    pub mod plain {
152        use crate::{
153            error::{Error, Result, UrlError},
154            stream::{Mode, SimplifiedStream},
155        };
156        use std::io::{Read, Write};
157
158        pub fn wrap_stream<S>(socket: S, mode: Mode) -> Result<SimplifiedStream<S>>
159        where
160            S: Read + Write,
161        {
162            match mode {
163                Mode::Plain => Ok(SimplifiedStream::Plain(socket)),
164                Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
165            }
166        }
167    }
168}
169
170type TlsErrorHandshake<S> = HandshakeError<ClientHandshake<SimplifiedStream<S>>>;
171
172/// Creates a WebSocket handshake from a request and a stream,
173/// upgrading the stream to TLS if required.
174pub fn client_tls<R, S>(
175    request: R,
176    stream: S,
177) -> Result<(WebSocket<SimplifiedStream<S>>, Response), TlsErrorHandshake<S>>
178where
179    R: IntoClientRequest,
180    S: Read + Write,
181{
182    client_tls_with_config(request, stream, None, None)
183}
184
185/// The same as [`client_tls()`] but one can specify a websocket configuration,
186/// and an optional connector. If no connector is specified, a default one will
187/// be created.
188///
189/// Please refer to [`client_tls()`] for more details.
190pub fn client_tls_with_config<R, S>(
191    request: R,
192    stream: S,
193    config: Option<WebSocketConfig>,
194    connector: Option<Connector>,
195) -> Result<(WebSocket<SimplifiedStream<S>>, Response), TlsErrorHandshake<S>>
196where
197    R: IntoClientRequest,
198    S: Read + Write,
199{
200    let request = request.into_client_request()?;
201
202    #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
203    let domain = match request.uri().host() {
204        Some(d) => Ok(d.to_string()),
205        None => Err(Error::Url(UrlError::MissingHost)),
206    }?;
207
208    let mode = uri_mode(request.uri())?;
209
210    let stream = match connector {
211        Some(conn) => match conn {
212            #[cfg(feature = "native-tls")]
213            Connector::NativeTls(conn) => {
214                self::encryption::native_tls::wrap_stream(stream, &domain, mode, Some(conn))
215            }
216
217            #[cfg(feature = "__rustls-tls")]
218            Connector::Rustls(conn) => {
219                self::encryption::rustls::wrap_stream(stream, &domain, mode, Some(conn))
220            }
221
222            Connector::Plain => self::encryption::plain::wrap_stream(stream, mode),
223        },
224        None => {
225            #[cfg(feature = "native-tls")]
226            {
227                self::encryption::native_tls::wrap_stream(stream, &domain, mode, None)
228            }
229            #[cfg(all(feature = "__rustls-tls", not(feature = "native-tls")))]
230            {
231                self::encryption::rustls::wrap_stream(stream, &domain, mode, None)
232            }
233            #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
234            {
235                self::encryption::plain::wrap_stream(stream, mode)
236            }
237        }
238    }?;
239
240    client_with_config(request, stream, config)
241}