1#[cfg(feature = "use-rustls-no-provider")]
2use tokio_rustls::TlsConnector as RustlsConnector;
3#[cfg(feature = "use-rustls-no-provider")]
4use tokio_rustls::rustls::{
5 self, ClientConfig, RootCertStore,
6 pki_types::{CertificateDer, InvalidDnsNameError, PrivateKeyDer, ServerName, pem::PemObject},
7};
8
9#[cfg(feature = "use-rustls-no-provider")]
10use std::convert::TryFrom;
11#[cfg(feature = "use-rustls-no-provider")]
12use std::sync::Arc;
13
14use crate::{AsyncReadWrite, TlsConfiguration};
15
16#[cfg(feature = "use-native-tls")]
17use tokio_native_tls::TlsConnector as NativeTlsConnector;
18
19#[cfg(feature = "use-native-tls")]
20use tokio_native_tls::native_tls::{Error as NativeTlsError, Identity};
21
22use std::io;
23use std::net::AddrParseError;
24
25#[derive(Debug, thiserror::Error)]
26pub enum Error {
27 #[error("Addr")]
29 Addr(#[from] AddrParseError),
30 #[error("I/O: {0}")]
32 Io(#[from] io::Error),
33 #[cfg(feature = "use-rustls-no-provider")]
34 #[error("Web Pki: {0}")]
36 WebPki(#[from] webpki::Error),
37 #[cfg(feature = "use-rustls-no-provider")]
39 #[error("DNS name")]
40 DNSName(#[from] InvalidDnsNameError),
41 #[cfg(feature = "use-rustls-no-provider")]
42 #[error("TLS error: {0}")]
44 TLS(#[from] rustls::Error),
45 #[cfg(feature = "use-rustls-no-provider")]
46 #[error("PEM parsing error: {0}")]
47 Pem(#[from] rustls::pki_types::pem::Error),
48 #[cfg(feature = "use-rustls-no-provider")]
49 #[error("No valid CA certificate provided")]
51 NoValidCertInChain,
52 #[cfg(feature = "use-rustls-no-provider")]
53 #[error("No valid certificate for client authentication in chain")]
55 NoValidClientCertInChain,
56 #[cfg(feature = "use-rustls-no-provider")]
57 #[error("No valid key in chain")]
59 NoValidKeyInChain,
60 #[cfg(feature = "use-native-tls")]
61 #[error("Native TLS error {0}")]
62 NativeTls(#[from] NativeTlsError),
63}
64
/// Builds a [`RustlsConnector`] from a rustls-flavoured [`TlsConfiguration`].
///
/// For `TlsConfiguration::Simple`, a fresh [`ClientConfig`] is assembled from the
/// PEM-encoded CA bundle, the optional PEM client certificate chain and private key,
/// and the optional list of ALPN protocols. For `TlsConfiguration::Rustls`, the
/// caller-supplied shared config is reused as-is.
///
/// # Errors
///
/// - [`Error::Pem`] if a PEM section is malformed.
/// - [`Error::NoValidCertInChain`] if no CA certificate could be added to the root store.
/// - [`Error::NoValidClientCertInChain`] if client auth is configured but the chain is empty.
/// - [`Error::NoValidKeyInChain`] if client auth is configured but no private key is present.
/// - [`Error::TLS`] if rustls rejects the client certificate/key pair.
///
/// # Panics
///
/// Panics if called with a configuration that belongs to another TLS backend.
#[cfg(feature = "use-rustls-no-provider")]
pub fn rustls_connector(tls_config: &TlsConfiguration) -> Result<RustlsConnector, Error> {
    let client_config = match tls_config {
        TlsConfiguration::Simple { ca, alpn, client_auth } => {
            let ca_certs = CertificateDer::pem_slice_iter(ca).collect::<Result<Vec<_>, _>>()?;

            let mut roots = RootCertStore::empty();
            // Unparsable certificates are skipped; only fail if nothing usable remains.
            roots.add_parsable_certificates(ca_certs);
            if roots.is_empty() {
                return Err(Error::NoValidCertInChain);
            }

            // The builder is created before any client-auth material is parsed.
            let builder = ClientConfig::builder().with_root_certificates(roots);

            let mut built = match client_auth.as_ref() {
                None => builder.with_no_client_auth(),
                Some(client) => {
                    let chain = CertificateDer::pem_slice_iter(&client.0)
                        .collect::<Result<Vec<_>, _>>()?;
                    if chain.is_empty() {
                        return Err(Error::NoValidClientCertInChain);
                    }

                    // "No PEM items" gets its own variant; any other PEM failure is surfaced as-is.
                    let key = PrivateKeyDer::from_pem_slice(&client.1).map_err(|err| match err {
                        rustls::pki_types::pem::Error::NoItemsFound => Error::NoValidKeyInChain,
                        other => Error::Pem(other),
                    })?;

                    builder.with_client_auth_cert(chain, key)?
                }
            };

            if let Some(protocols) = alpn.as_ref() {
                built.alpn_protocols.extend_from_slice(protocols);
            }

            Arc::new(built)
        }
        TlsConfiguration::Rustls(shared) => shared.clone(),
        #[allow(unreachable_patterns)]
        _ => unreachable!("This cannot be called for other TLS backends than Rustls"),
    };

    Ok(RustlsConnector::from(client_config))
}
120
/// Builds a `native_tls::TlsConnector` for the native-tls flavoured
/// [`TlsConfiguration`] variants.
///
/// - `SimpleNative`: trusts the PEM-encoded `ca` and, if `client_auth` is set,
///   loads a PKCS#12 client identity with the given password.
/// - `Native`: uses `native_tls::TlsConnector::new()`.
/// - `NativeConnector`: clones the caller-supplied connector.
///
/// Returns `Ok(None)` for any other variant so each caller can apply its own
/// handling for unsupported backends.
///
/// # Errors
///
/// Returns [`Error::NativeTls`] if the CA certificate, the client identity or the
/// connector itself cannot be built.
#[cfg(feature = "use-native-tls")]
fn build_native_tls_connector(
    tls_config: &TlsConfiguration,
) -> Result<Option<native_tls::TlsConnector>, Error> {
    let connector = match tls_config {
        TlsConfiguration::SimpleNative { ca, client_auth } => {
            let cert = native_tls::Certificate::from_pem(ca)?;

            let mut connector_builder = native_tls::TlsConnector::builder();
            connector_builder.add_root_certificate(cert);

            if let Some((der, password)) = client_auth {
                let identity = Identity::from_pkcs12(der, password)?;
                connector_builder.identity(identity);
            }

            connector_builder.build()?
        }
        TlsConfiguration::Native => native_tls::TlsConnector::new()?,
        TlsConfiguration::NativeConnector(connector) => connector.to_owned(),
        #[allow(unreachable_patterns)]
        _ => return Ok(None),
    };

    Ok(Some(connector))
}

/// Builds a tokio-native-tls connector from a native-tls flavoured [`TlsConfiguration`].
///
/// The body contains no `.await`; the function is `async` in its signature only.
///
/// # Errors
///
/// Returns [`Error::NativeTls`] if the connector cannot be built.
///
/// # Panics
///
/// Panics if called with a configuration that belongs to another TLS backend.
#[cfg(feature = "use-native-tls")]
pub async fn native_tls_connector(
    tls_config: &TlsConfiguration,
) -> Result<NativeTlsConnector, Error> {
    match build_native_tls_connector(tls_config)? {
        Some(connector) => Ok(connector.into()),
        None => unreachable!("This cannot be called for other TLS backends than Native TLS"),
    }
}

/// Builds the native-tls connector used for secure websocket connections.
///
/// Accepts the same configurations as `native_tls_connector`, but synchronously.
///
/// # Errors
///
/// Returns [`Error::NativeTls`] if the connector cannot be built.
///
/// # Panics
///
/// Panics if `tls_config` is not a native-tls configuration.
#[cfg(all(
    feature = "websocket",
    feature = "use-native-tls",
    not(feature = "use-rustls-no-provider")
))]
pub fn websocket_tls_connector(
    tls_config: &TlsConfiguration,
) -> Result<tokio_native_tls::TlsConnector, Error> {
    match build_native_tls_connector(tls_config)? {
        Some(connector) => Ok(connector.into()),
        None => panic!("Unknown or not enabled TLS backend configuration"),
    }
}
186
/// Builds the rustls connector used for secure websocket connections.
///
/// Delegates to [`rustls_connector`] for the `Simple` and `Rustls`
/// configuration variants.
///
/// # Errors
///
/// Propagates any error returned by [`rustls_connector`].
///
/// # Panics
///
/// The fallback arm panics for any configuration that is not a rustls variant.
#[cfg(all(
    feature = "websocket",
    feature = "use-rustls-no-provider",
    not(feature = "use-native-tls")
))]
pub fn websocket_tls_connector(
    tls_config: &TlsConfiguration,
) -> Result<tokio_rustls::TlsConnector, Error> {
    match tls_config {
        TlsConfiguration::Simple { .. } | TlsConfiguration::Rustls(_) => {
            rustls_connector(tls_config)
        }
        // Only reachable if other backends' variants are compiled in.
        #[allow(unreachable_patterns)]
        _ => panic!("Unknown or not enabled TLS backend configuration"),
    }
}
206
207pub async fn tls_connect(
219 addr: &str,
220 _port: u16,
221 tls_config: &TlsConfiguration,
222 tcp: Box<dyn AsyncReadWrite>,
223) -> Result<Box<dyn AsyncReadWrite>, Error> {
224 let tls: Box<dyn AsyncReadWrite> = match tls_config {
225 #[cfg(feature = "use-rustls-no-provider")]
226 TlsConfiguration::Simple { .. } | TlsConfiguration::Rustls(_) => {
227 let connector = rustls_connector(tls_config)?;
228 let domain = ServerName::try_from(addr)?.to_owned();
229 Box::new(connector.connect(domain, tcp).await?)
230 }
231 #[cfg(feature = "use-native-tls")]
232 TlsConfiguration::Native
233 | TlsConfiguration::NativeConnector(_)
234 | TlsConfiguration::SimpleNative { .. } => {
235 let connector = native_tls_connector(tls_config).await?;
236 Box::new(connector.connect(addr, tcp).await?)
237 }
238 #[allow(unreachable_patterns)]
239 _ => panic!("Unknown or not enabled TLS backend configuration"),
240 };
241 Ok(tls)
242}