1use futures::FutureExt;
2use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
3use rustls::client::WebPkiServerVerifier;
4use rustls::server::{Acceptor, ClientHello, WebPkiClientVerifier};
5use rustls::{
6 ClientConfig, ClientConnection, DigitallySignedStruct, RootCertStore, ServerConfig,
7 SignatureScheme,
8};
9use rustls_pki_types::{
10 CertificateDer, CertificateRevocationListDer, DnsName, ServerName, UnixTime,
11};
12use rustls_platform_verifier::Verifier;
13use rustls_tokio_stream::TlsStream;
14
15use super::tokio_stream::TokioStream;
16use crate::{
17 RewindStream, SslError, SslVersion, Stream, TlsClientCertVerify, TlsDriver, TlsHandshake,
18 TlsServerParameterProvider, TlsServerParameters,
19};
20use crate::{TlsCert, TlsParameters, TlsServerCertVerify};
21use std::borrow::Cow;
22use std::net::{IpAddr, Ipv4Addr};
23use std::sync::Arc;
24
25#[derive(Default)]
26pub struct RustlsDriver;
27
28impl TlsDriver for RustlsDriver {
29 type Stream = TlsStream;
30 type ClientParams = ClientConnection;
31 type ServerParams = Arc<ServerConfig>;
32 const DRIVER_NAME: &'static str = "rustls";
33
34 fn init_client(
35 params: &TlsParameters,
36 name: Option<ServerName>,
37 ) -> Result<Self::ClientParams, SslError> {
38 let _ = ::rustls::crypto::ring::default_provider().install_default();
39
40 let TlsParameters {
41 server_cert_verify,
42 root_cert,
43 cert,
44 key,
45 crl,
46 min_protocol_version: _,
47 max_protocol_version: _,
48 alpn,
49 enable_keylog,
50 sni_override,
51 } = params;
52
53 let verifier = make_verifier(server_cert_verify, root_cert, crl.clone())?;
54
55 let config = ClientConfig::builder()
56 .dangerous()
57 .with_custom_certificate_verifier(verifier);
58
59 let mut config = if let (Some(cert), Some(key)) = (cert, key) {
61 config
62 .with_client_auth_cert(vec![cert.clone()], key.clone_key())
63 .map_err(|_| {
64 std::io::Error::new(
65 std::io::ErrorKind::InvalidInput,
66 "Failed to set client auth cert",
67 )
68 })?
69 } else {
70 config.with_no_client_auth()
71 };
72
73 config.alpn_protocols = alpn.as_vec_vec();
75
76 if *enable_keylog {
78 config.key_log = Arc::new(rustls::KeyLogFile::new());
79 }
80
81 let name = if let Some(sni_override) = sni_override {
82 ServerName::try_from(sni_override.to_string())?
83 } else if let Some(name) = name {
84 name.to_owned()
85 } else {
86 config.enable_sni = false;
87 ServerName::IpAddress(IpAddr::V4(Ipv4Addr::from_bits(0)).into())
88 };
89
90 Ok(ClientConnection::new(Arc::new(config), name)?)
91 }
92
93 fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError> {
94 let builder = match ¶ms.client_cert_verify {
95 TlsClientCertVerify::Ignore => ServerConfig::builder().with_no_client_auth(),
96 TlsClientCertVerify::Optional(certs) => {
97 let mut roots = RootCertStore::empty();
98 roots.add_parsable_certificates(
99 certs.iter().map(|c| CertificateDer::from_slice(c.as_ref())),
100 );
101 ServerConfig::builder().with_client_cert_verifier(
102 WebPkiClientVerifier::builder(roots.into())
103 .allow_unauthenticated()
104 .build()?,
105 )
106 }
107 TlsClientCertVerify::Validate(certs) => {
108 let mut roots = RootCertStore::empty();
109 roots.add_parsable_certificates(
110 certs.iter().map(|c| CertificateDer::from_slice(c.as_ref())),
111 );
112 ServerConfig::builder()
113 .with_client_cert_verifier(WebPkiClientVerifier::builder(roots.into()).build()?)
114 }
115 };
116
117 let mut config = builder.with_single_cert(
118 vec![params.server_certificate.cert.clone()],
119 params.server_certificate.key.clone_key(),
120 )?;
121
122 config.alpn_protocols = params.alpn.as_vec_vec();
123
124 Ok(Arc::new(config))
125 }
126
127 async fn upgrade_client<S: Stream>(
128 params: Self::ClientParams,
129 stream: S,
130 ) -> Result<(Self::Stream, TlsHandshake), SslError> {
131 let stream = stream
133 .downcast::<TokioStream>()
134 .map_err(|_| crate::SslError::SslUnsupportedByClient)?;
135 let TokioStream::Tcp(stream) = stream else {
136 return Err(crate::SslError::SslUnsupportedByClient);
137 };
138
139 let mut stream = TlsStream::new_client_side(stream, params, None);
140 match stream.handshake().await {
141 Ok(handshake) => {
142 let cert = stream
143 .connection()
144 .and_then(|c| c.peer_certificates())
145 .and_then(|c| c.first().map(|cert| cert.to_owned()));
146 let version = stream.connection().and_then(|c| c.protocol_version());
147 Ok((
148 stream,
149 TlsHandshake {
150 alpn: handshake.alpn.map(|alpn| Cow::Owned(alpn.to_vec())),
151 sni: handshake.sni.map(|sni| Cow::Owned(sni.to_string())),
152 cert,
153 version: match version {
154 Some(rustls::ProtocolVersion::TLSv1_0) => Some(SslVersion::Tls1),
155 Some(rustls::ProtocolVersion::TLSv1_1) => Some(SslVersion::Tls1_1),
156 Some(rustls::ProtocolVersion::TLSv1_2) => Some(SslVersion::Tls1_2),
157 Some(rustls::ProtocolVersion::TLSv1_3) => Some(SslVersion::Tls1_3),
158 _ => None,
159 },
160 },
161 ))
162 }
163 Err(e) => {
164 let kind = e.kind();
165 if let Some(e2) = e.into_inner() {
166 match e2.downcast::<::rustls::Error>() {
167 Ok(e) => Err(crate::SslError::RustlsError(*e)),
168 Err(e) => Err(std::io::Error::new(kind, e).into()),
169 }
170 } else {
171 Err(std::io::Error::from(kind).into())
172 }
173 }
174 }
175 }
176
177 async fn upgrade_server<S: Stream>(
178 params: TlsServerParameterProvider,
179 stream: S,
180 ) -> Result<(Self::Stream, TlsHandshake), SslError> {
181 let stream = stream
182 .downcast::<RewindStream<TokioStream>>()
183 .map_err(|_| crate::SslError::SslUnsupportedByClient)?;
184 let (stream, buffer) = stream.into_inner();
185 let TokioStream::Tcp(stream) = stream else {
186 return Err(crate::SslError::SslUnsupportedByClient);
187 };
188
189 let mut acceptor = Acceptor::default();
190 acceptor.read_tls(&mut buffer.as_slice())?;
191 let server_config_provider = Arc::new(move |client_hello: ClientHello| {
192 let params = params.clone();
193 let server_name = client_hello
194 .server_name()
195 .map(|name| ServerName::DnsName(DnsName::try_from(name.to_string()).unwrap()));
196 async move {
197 let params = params.lookup(server_name);
198 let config = RustlsDriver::init_server(¶ms)
199 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
200 Ok::<_, std::io::Error>(config)
201 }
202 .boxed()
203 });
204 let mut stream = TlsStream::new_server_side_from_acceptor(
205 acceptor,
206 stream,
207 server_config_provider,
208 None,
209 );
210
211 match stream.handshake().await {
212 Ok(handshake) => {
213 let cert = stream
214 .connection()
215 .and_then(|c| c.peer_certificates())
216 .and_then(|c| c.first().map(|cert| cert.to_owned()));
217 let version = stream.connection().and_then(|c| c.protocol_version());
218 Ok((
219 stream,
220 TlsHandshake {
221 alpn: handshake.alpn.map(|alpn| Cow::Owned(alpn.to_vec())),
222 sni: handshake.sni.map(|sni| Cow::Owned(sni.to_string())),
223 cert,
224 version: match version {
225 Some(rustls::ProtocolVersion::TLSv1_0) => Some(SslVersion::Tls1),
226 Some(rustls::ProtocolVersion::TLSv1_1) => Some(SslVersion::Tls1_1),
227 Some(rustls::ProtocolVersion::TLSv1_2) => Some(SslVersion::Tls1_2),
228 Some(rustls::ProtocolVersion::TLSv1_3) => Some(SslVersion::Tls1_3),
229 _ => None,
230 },
231 },
232 ))
233 }
234 Err(e) => {
235 let kind = e.kind();
236 if let Some(e2) = e.into_inner() {
237 match e2.downcast::<::rustls::Error>() {
238 Ok(e) => Err(crate::SslError::RustlsError(*e)),
239 Err(e) => Err(std::io::Error::new(kind, e).into()),
240 }
241 } else {
242 Err(std::io::Error::from(kind).into())
243 }
244 }
245 }
246 }
247
248 fn unclean_shutdown(this: Self::Stream) -> Result<(), Self::Stream> {
249 this.try_into_inner().map(drop)
251 }
252}
253
254fn make_roots(
255 root_certs: &[CertificateDer<'static>],
256 webpki: bool,
257) -> Result<RootCertStore, crate::SslError> {
258 let mut roots = RootCertStore::empty();
259 if webpki {
260 let webpki_roots = webpki_roots::TLS_SERVER_ROOTS;
261 roots.extend(webpki_roots.iter().cloned());
262 }
263 let (loaded, ignored) = roots.add_parsable_certificates(root_certs.iter().cloned());
264 if !root_certs.is_empty() && (loaded == 0 || ignored > 0) {
265 return Err(
266 std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid certificate").into(),
267 );
268 }
269 Ok(roots)
270}
271
272fn make_verifier(
273 server_cert_verify: &TlsServerCertVerify,
274 root_cert: &TlsCert,
275 crls: Vec<CertificateRevocationListDer<'static>>,
276) -> Result<Arc<dyn ServerCertVerifier>, crate::SslError> {
277 if *server_cert_verify == TlsServerCertVerify::Insecure {
278 return Ok(Arc::new(NullVerifier));
279 }
280
281 if matches!(
282 root_cert,
283 TlsCert::Webpki | TlsCert::WebpkiPlus(_) | TlsCert::Custom(_)
284 ) {
285 let roots = match root_cert {
286 TlsCert::Webpki => make_roots(&[], true),
287 TlsCert::Custom(roots) => make_roots(roots, false),
288 TlsCert::WebpkiPlus(roots) => make_roots(roots, true),
289 _ => unreachable!(),
290 }?;
291
292 let verifier = WebPkiServerVerifier::builder(Arc::new(roots))
293 .with_crls(crls)
294 .build()?;
295 if *server_cert_verify == TlsServerCertVerify::IgnoreHostname {
296 return Ok(Arc::new(IgnoreHostnameVerifier::new(verifier)));
297 }
298 return Ok(verifier);
299 }
300
301 let verifier: Arc<dyn ServerCertVerifier> = if let TlsCert::SystemPlus(roots) = root_cert {
304 let roots = make_roots(roots, false)?;
305 let v1 = WebPkiServerVerifier::builder(Arc::new(roots))
306 .with_crls(crls)
307 .build()?;
308 let v2 = Arc::new(Verifier::new());
309 Arc::new(ChainingVerifier::new(v1, v2))
310 } else {
311 Arc::new(ErrorFilteringVerifier::new(Arc::new(Verifier::new())))
312 };
313
314 let verifier: Arc<dyn ServerCertVerifier> =
315 if *server_cert_verify == TlsServerCertVerify::IgnoreHostname {
316 Arc::new(IgnoreHostnameVerifier::new(verifier))
317 } else {
318 verifier
319 };
320
321 Ok(verifier)
322}
323
324#[derive(Debug)]
325struct IgnoreHostnameVerifier {
326 verifier: Arc<dyn ServerCertVerifier>,
327}
328
329impl IgnoreHostnameVerifier {
330 fn new(verifier: Arc<dyn ServerCertVerifier>) -> Self {
331 Self { verifier }
332 }
333}
334
335impl ServerCertVerifier for IgnoreHostnameVerifier {
336 fn verify_server_cert(
337 &self,
338 end_entity: &CertificateDer<'_>,
339 intermediates: &[CertificateDer<'_>],
340 server_name: &ServerName,
341 ocsp_response: &[u8],
342 now: UnixTime,
343 ) -> Result<ServerCertVerified, rustls::Error> {
344 match self.verifier.verify_server_cert(
345 end_entity,
346 intermediates,
347 server_name,
348 ocsp_response,
349 now,
350 ) {
351 Ok(res) => Ok(res),
352 Err(rustls::Error::InvalidCertificate(
354 rustls::CertificateError::NotValidForName
355 | rustls::CertificateError::NotValidForNameContext { .. },
356 )) => Ok(ServerCertVerified::assertion()),
357 Err(e) => Err(e),
358 }
359 }
360
361 fn verify_tls12_signature(
362 &self,
363 message: &[u8],
364 cert: &CertificateDer<'_>,
365 dss: &DigitallySignedStruct,
366 ) -> Result<HandshakeSignatureValid, rustls::Error> {
367 self.verifier.verify_tls12_signature(message, cert, dss)
368 }
369
370 fn verify_tls13_signature(
371 &self,
372 message: &[u8],
373 cert: &CertificateDer<'_>,
374 dss: &DigitallySignedStruct,
375 ) -> Result<HandshakeSignatureValid, rustls::Error> {
376 self.verifier.verify_tls13_signature(message, cert, dss)
377 }
378
379 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
380 self.verifier.supported_verify_schemes()
381 }
382}
383
384#[derive(Debug)]
385struct ChainingVerifier {
386 verifier1: Arc<dyn ServerCertVerifier>,
387 verifier2: Arc<dyn ServerCertVerifier>,
388}
389
390impl ChainingVerifier {
391 fn new(verifier1: Arc<dyn ServerCertVerifier>, verifier2: Arc<dyn ServerCertVerifier>) -> Self {
392 Self {
393 verifier1,
394 verifier2,
395 }
396 }
397}
398
399impl ServerCertVerifier for ChainingVerifier {
400 fn verify_server_cert(
401 &self,
402 end_entity: &CertificateDer<'_>,
403 intermediates: &[CertificateDer<'_>],
404 server_name: &ServerName,
405 ocsp_response: &[u8],
406 now: UnixTime,
407 ) -> Result<ServerCertVerified, rustls::Error> {
408 let res = self.verifier1.verify_server_cert(
409 end_entity,
410 intermediates,
411 server_name,
412 ocsp_response,
413 now,
414 );
415 if let Ok(res) = res {
416 return Ok(res);
417 }
418
419 let res2 = self.verifier2.verify_server_cert(
420 end_entity,
421 intermediates,
422 server_name,
423 ocsp_response,
424 now,
425 );
426 if let Ok(res) = res2 {
427 return Ok(res);
428 }
429
430 res
431 }
432
433 fn verify_tls12_signature(
434 &self,
435 message: &[u8],
436 cert: &CertificateDer<'_>,
437 dss: &DigitallySignedStruct,
438 ) -> Result<HandshakeSignatureValid, rustls::Error> {
439 let res = self.verifier1.verify_tls12_signature(message, cert, dss);
440 if let Ok(res) = res {
441 return Ok(res);
442 }
443
444 let res2 = self.verifier2.verify_tls12_signature(message, cert, dss);
445 if let Ok(res) = res2 {
446 return Ok(res);
447 }
448
449 res
450 }
451
452 fn verify_tls13_signature(
453 &self,
454 message: &[u8],
455 cert: &CertificateDer<'_>,
456 dss: &DigitallySignedStruct,
457 ) -> Result<HandshakeSignatureValid, rustls::Error> {
458 let res = self.verifier1.verify_tls13_signature(message, cert, dss);
459 if let Ok(res) = res {
460 return Ok(res);
461 }
462
463 let res2 = self.verifier2.verify_tls13_signature(message, cert, dss);
464 if let Ok(res) = res2 {
465 return Ok(res);
466 }
467
468 res
469 }
470
471 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
472 self.verifier1.supported_verify_schemes()
473 }
474}
475
476#[derive(Debug)]
477struct NullVerifier;
478
479impl ServerCertVerifier for NullVerifier {
480 fn verify_server_cert(
481 &self,
482 _end_entity: &CertificateDer<'_>,
483 _intermediates: &[CertificateDer<'_>],
484 _server_name: &ServerName,
485 _ocsp_response: &[u8],
486 _now: UnixTime,
487 ) -> Result<ServerCertVerified, rustls::Error> {
488 Ok(ServerCertVerified::assertion())
489 }
490
491 fn verify_tls12_signature(
492 &self,
493 _message: &[u8],
494 _cert: &CertificateDer<'_>,
495 _dss: &DigitallySignedStruct,
496 ) -> Result<HandshakeSignatureValid, rustls::Error> {
497 Ok(HandshakeSignatureValid::assertion())
498 }
499
500 fn verify_tls13_signature(
501 &self,
502 _message: &[u8],
503 _cert: &CertificateDer<'_>,
504 _dss: &DigitallySignedStruct,
505 ) -> Result<HandshakeSignatureValid, rustls::Error> {
506 Ok(HandshakeSignatureValid::assertion())
507 }
508
509 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
510 use SignatureScheme::*;
511 vec![
512 RSA_PKCS1_SHA1,
513 ECDSA_SHA1_Legacy,
514 RSA_PKCS1_SHA256,
515 ECDSA_NISTP256_SHA256,
516 RSA_PKCS1_SHA384,
517 ECDSA_NISTP384_SHA384,
518 RSA_PKCS1_SHA512,
519 ECDSA_NISTP521_SHA512,
520 RSA_PSS_SHA256,
521 RSA_PSS_SHA384,
522 RSA_PSS_SHA512,
523 ED25519,
524 ED448,
525 ]
526 }
527}
528
529#[derive(Debug)]
530struct ErrorFilteringVerifier {
531 verifier: Arc<dyn ServerCertVerifier>,
532}
533
534impl ErrorFilteringVerifier {
535 fn new(verifier: Arc<dyn ServerCertVerifier>) -> Self {
536 Self { verifier }
537 }
538
539 fn filter_err<T>(res: Result<T, rustls::Error>) -> Result<T, rustls::Error> {
540 match res {
541 Ok(res) => Ok(res),
542 #[cfg(target_vendor = "apple")]
548 Err(rustls::Error::InvalidCertificate(rustls::CertificateError::Other(e)))
549 if e.to_string().contains("-67901") =>
550 {
551 Err(rustls::Error::InvalidCertificate(
552 rustls::CertificateError::UnknownIssuer,
553 ))
554 }
555 Err(e) => Err(e),
556 }
557 }
558}
559
560impl ServerCertVerifier for ErrorFilteringVerifier {
561 fn verify_server_cert(
562 &self,
563 end_entity: &CertificateDer<'_>,
564 intermediates: &[CertificateDer<'_>],
565 server_name: &ServerName,
566 ocsp_response: &[u8],
567 now: UnixTime,
568 ) -> Result<ServerCertVerified, rustls::Error> {
569 Self::filter_err(self.verifier.verify_server_cert(
570 end_entity,
571 intermediates,
572 server_name,
573 ocsp_response,
574 now,
575 ))
576 }
577
578 fn verify_tls12_signature(
579 &self,
580 message: &[u8],
581 cert: &CertificateDer<'_>,
582 dss: &DigitallySignedStruct,
583 ) -> Result<HandshakeSignatureValid, rustls::Error> {
584 Self::filter_err(self.verifier.verify_tls12_signature(message, cert, dss))
585 }
586
587 fn verify_tls13_signature(
588 &self,
589 message: &[u8],
590 cert: &CertificateDer<'_>,
591 dss: &DigitallySignedStruct,
592 ) -> Result<HandshakeSignatureValid, rustls::Error> {
593 Self::filter_err(self.verifier.verify_tls13_signature(message, cert, dss))
594 }
595
596 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
597 self.verifier.supported_verify_schemes()
598 }
599}