Skip to main content

rumqttc_core/
tls.rs

1#[cfg(feature = "use-rustls-no-provider")]
2use tokio_rustls::TlsConnector as RustlsConnector;
3#[cfg(feature = "use-rustls-no-provider")]
4use tokio_rustls::rustls::{
5    self, ClientConfig, RootCertStore,
6    pki_types::{CertificateDer, InvalidDnsNameError, PrivateKeyDer, ServerName, pem::PemObject},
7};
8
9#[cfg(feature = "use-rustls-no-provider")]
10use std::convert::TryFrom;
11#[cfg(feature = "use-rustls-no-provider")]
12use std::sync::Arc;
13
14use crate::{AsyncReadWrite, TlsConfiguration};
15
16#[cfg(feature = "use-native-tls")]
17use tokio_native_tls::TlsConnector as NativeTlsConnector;
18
19#[cfg(feature = "use-native-tls")]
20use tokio_native_tls::native_tls::{Error as NativeTlsError, Identity};
21
22use std::io;
23use std::net::AddrParseError;
24
25#[derive(Debug, thiserror::Error)]
26pub enum Error {
27    /// Error parsing IP address
28    #[error("Addr")]
29    Addr(#[from] AddrParseError),
30    /// I/O related error
31    #[error("I/O: {0}")]
32    Io(#[from] io::Error),
33    #[cfg(feature = "use-rustls-no-provider")]
34    /// Certificate/Name validation error
35    #[error("Web Pki: {0}")]
36    WebPki(#[from] webpki::Error),
37    /// Invalid DNS name
38    #[cfg(feature = "use-rustls-no-provider")]
39    #[error("DNS name")]
40    DNSName(#[from] InvalidDnsNameError),
41    #[cfg(feature = "use-rustls-no-provider")]
42    /// Error from rustls module
43    #[error("TLS error: {0}")]
44    TLS(#[from] rustls::Error),
45    #[cfg(feature = "use-rustls-no-provider")]
46    #[error("PEM parsing error: {0}")]
47    Pem(#[from] rustls::pki_types::pem::Error),
48    #[cfg(feature = "use-rustls-no-provider")]
49    /// No valid CA cert found
50    #[error("No valid CA certificate provided")]
51    NoValidCertInChain,
52    #[cfg(feature = "use-rustls-no-provider")]
53    /// No valid client cert found
54    #[error("No valid certificate for client authentication in chain")]
55    NoValidClientCertInChain,
56    #[cfg(feature = "use-rustls-no-provider")]
57    /// No valid key found
58    #[error("No valid key in chain")]
59    NoValidKeyInChain,
60    #[cfg(feature = "use-native-tls")]
61    #[error("Native TLS error {0}")]
62    NativeTls(#[from] NativeTlsError),
63}
64
/// Builds a rustls-backed connector from `tls_config`.
///
/// For [`TlsConfiguration::Simple`], every PEM certificate in `ca` is loaded into a
/// fresh root store. If a `(cert_chain_pem, key_pem)` pair is supplied, it is used
/// for client authentication, and any ALPN protocols are appended. A
/// [`TlsConfiguration::Rustls`] config is reused as-is.
///
/// # Errors
///
/// * [`Error::Pem`] if any PEM input is malformed.
/// * [`Error::NoValidCertInChain`] if no usable CA certificate was found.
/// * [`Error::NoValidClientCertInChain`] if the client PEM holds no certificate.
/// * [`Error::NoValidKeyInChain`] if the key PEM holds no private key.
/// * [`Error::TLS`] if rustls rejects the client certificate/key pair.
///
/// # Panics
///
/// Panics if `tls_config` belongs to a different TLS backend.
#[cfg(feature = "use-rustls-no-provider")]
pub fn rustls_connector(tls_config: &TlsConfiguration) -> Result<RustlsConnector, Error> {
    let client_config = match tls_config {
        TlsConfiguration::Simple {
            ca,
            alpn,
            client_auth,
        } => {
            // Trust only the CA certificate(s) supplied by the caller.
            let mut roots = RootCertStore::empty();
            roots.add_parsable_certificates(
                CertificateDer::pem_slice_iter(ca).collect::<Result<Vec<_>, _>>()?,
            );
            // Unparsable certificates are skipped above, so check the result instead.
            if roots.is_empty() {
                return Err(Error::NoValidCertInChain);
            }

            let builder = ClientConfig::builder().with_root_certificates(roots);

            let mut client_config = match client_auth {
                None => builder.with_no_client_auth(),
                Some((cert_pem, key_pem)) => {
                    let chain = CertificateDer::pem_slice_iter(cert_pem)
                        .collect::<Result<Vec<_>, _>>()?;
                    if chain.is_empty() {
                        return Err(Error::NoValidClientCertInChain);
                    }

                    // A PEM without any key gets its own error; other PEM failures pass through.
                    let key = PrivateKeyDer::from_pem_slice(key_pem).map_err(|err| match err {
                        rustls::pki_types::pem::Error::NoItemsFound => Error::NoValidKeyInChain,
                        other => Error::Pem(other),
                    })?;

                    builder.with_client_auth_cert(chain, key)?
                }
            };

            if let Some(protocols) = alpn {
                client_config.alpn_protocols.extend_from_slice(protocols);
            }

            Arc::new(client_config)
        }
        TlsConfiguration::Rustls(shared) => shared.clone(),
        #[allow(unreachable_patterns)]
        _ => unreachable!("This cannot be called for other TLS backends than Rustls"),
    };

    Ok(RustlsConnector::from(client_config))
}
120
/// Builds a blocking `native_tls` connector from a native-tls [`TlsConfiguration`].
///
/// This is the single place where the CA certificate and the optional PKCS#12
/// client identity are turned into a connector. Both the async
/// [`native_tls_connector`] and the native-tls websocket connector use it, so the
/// two cannot drift apart.
///
/// # Errors
///
/// Returns [`Error::NativeTls`] if the CA certificate, the client identity, or the
/// connector itself cannot be created.
///
/// # Panics
///
/// Panics if `tls_config` is not one of the native-tls variants.
#[cfg(feature = "use-native-tls")]
fn build_native_connector(tls_config: &TlsConfiguration) -> Result<native_tls::TlsConnector, Error> {
    let connector = match tls_config {
        TlsConfiguration::SimpleNative { ca, client_auth } => {
            let cert = native_tls::Certificate::from_pem(ca)?;

            let mut connector_builder = native_tls::TlsConnector::builder();
            connector_builder.add_root_certificate(cert);

            // The client identity is supplied as a PKCS#12 archive plus its password.
            if let Some((der, password)) = client_auth {
                let identity = Identity::from_pkcs12(der, password)?;
                connector_builder.identity(identity);
            }

            connector_builder.build()?
        }
        TlsConfiguration::Native => native_tls::TlsConnector::new()?,
        TlsConfiguration::NativeConnector(connector) => connector.to_owned(),
        #[allow(unreachable_patterns)]
        _ => unreachable!("This cannot be called for other TLS backends than Native TLS"),
    };

    Ok(connector)
}

/// Builds an async (tokio) native-tls connector from `tls_config`.
///
/// # Errors
///
/// Returns [`Error::NativeTls`] if the CA certificate, the client identity, or the
/// connector itself cannot be created.
///
/// # Panics
///
/// Panics if `tls_config` is not one of the native-tls variants.
#[cfg(feature = "use-native-tls")]
pub async fn native_tls_connector(
    tls_config: &TlsConfiguration,
) -> Result<NativeTlsConnector, Error> {
    Ok(build_native_connector(tls_config)?.into())
}

/// Returns the appropriate TLS connector for websocket connections based on the
/// TLS configuration.
///
/// # Errors
///
/// Returns [`Error::NativeTls`] if the native-tls connector cannot be built.
///
/// # Panics
///
/// Panics if `tls_config` is not one of the native-tls variants.
#[cfg(all(
    feature = "websocket",
    feature = "use-native-tls",
    not(feature = "use-rustls-no-provider")
))]
pub fn websocket_tls_connector(
    tls_config: &TlsConfiguration,
) -> Result<tokio_native_tls::TlsConnector, Error> {
    match tls_config {
        TlsConfiguration::Native
        | TlsConfiguration::NativeConnector(_)
        | TlsConfiguration::SimpleNative { .. } => Ok(build_native_connector(tls_config)?.into()),
        // Other backends' variants are presumably compiled out in this feature
        // combination; kept so the function still compiles if that changes.
        #[allow(unreachable_patterns)]
        _ => panic!("Unknown or not enabled TLS backend configuration"),
    }
}
186
/// Returns the appropriate TLS connector for websocket connections based on the
/// TLS configuration.
///
/// Both rustls configuration variants are handed straight to [`rustls_connector`].
///
/// # Errors
///
/// Propagates any error produced by [`rustls_connector`].
///
/// # Panics
///
/// Panics if `tls_config` is not a rustls variant.
#[cfg(all(
    feature = "websocket",
    feature = "use-rustls-no-provider",
    not(feature = "use-native-tls")
))]
pub fn websocket_tls_connector(
    tls_config: &TlsConfiguration,
) -> Result<tokio_rustls::TlsConnector, Error> {
    match tls_config {
        TlsConfiguration::Simple { .. } | TlsConfiguration::Rustls(_) => {
            rustls_connector(tls_config)
        }
        #[allow(unreachable_patterns)]
        _ => panic!("Unknown or not enabled TLS backend configuration"),
    }
}
206
207/// Establishes a TLS stream on top of an already connected transport.
208///
209/// # Errors
210///
211/// Returns any TLS configuration, server-name validation, or handshake error
212/// produced by the selected backend.
213///
214/// # Panics
215///
216/// Panics only if the build enables no backend matching `tls_config`, which
217/// indicates an invalid internal configuration.
218pub async fn tls_connect(
219    addr: &str,
220    _port: u16,
221    tls_config: &TlsConfiguration,
222    tcp: Box<dyn AsyncReadWrite>,
223) -> Result<Box<dyn AsyncReadWrite>, Error> {
224    let tls: Box<dyn AsyncReadWrite> = match tls_config {
225        #[cfg(feature = "use-rustls-no-provider")]
226        TlsConfiguration::Simple { .. } | TlsConfiguration::Rustls(_) => {
227            let connector = rustls_connector(tls_config)?;
228            let domain = ServerName::try_from(addr)?.to_owned();
229            Box::new(connector.connect(domain, tcp).await?)
230        }
231        #[cfg(feature = "use-native-tls")]
232        TlsConfiguration::Native
233        | TlsConfiguration::NativeConnector(_)
234        | TlsConfiguration::SimpleNative { .. } => {
235            let connector = native_tls_connector(tls_config).await?;
236            Box::new(connector.connect(addr, tcp).await?)
237        }
238        #[allow(unreachable_patterns)]
239        _ => panic!("Unknown or not enabled TLS backend configuration"),
240    };
241    Ok(tls)
242}