1use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
2use rustls::client::WebPkiServerVerifier;
3use rustls::server::{Acceptor, WebPkiClientVerifier};
4use rustls::{
5 ClientConfig, ClientConnection, DigitallySignedStruct, RootCertStore, ServerConfig,
6 SignatureScheme,
7};
8use rustls_pki_types::{
9 CertificateDer, CertificateRevocationListDer, DnsName, ServerName, UnixTime,
10};
11use rustls_platform_verifier::Verifier;
12use rustls_tokio_stream::TlsStream;
13use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf};
14
15use super::tokio_stream::TokioStream;
16use crate::{
17 AsHandle, LocalAddress, PeerCred, RemoteAddress, ResolvedTarget, RewindStream, SslError,
18 SslVersion, Stream, StreamMetadata, TlsClientCertVerify, TlsDriver, TlsHandshake,
19 TlsServerParameterProvider, TlsServerParameters, Transport,
20};
21use crate::{TlsCert, TlsParameters, TlsServerCertVerify};
22use std::borrow::Cow;
23use std::mem::MaybeUninit;
24use std::net::{IpAddr, Ipv4Addr};
25use std::sync::Arc;
26
/// [`TlsDriver`] implementation backed by `rustls` (via `rustls-tokio-stream`).
///
/// The driver itself is stateless; all configuration travels through the
/// client/server parameter types of the [`TlsDriver`] impl.
#[derive(Default)]
pub struct RustlsDriver;
29
30impl TlsDriver for RustlsDriver {
31 type Stream = TlsStream;
32 type ClientParams = ClientConnection;
33 type ServerParams = Arc<ServerConfig>;
34 const DRIVER_NAME: &'static str = "rustls";
35
36 fn init_client(
37 params: &TlsParameters,
38 name: Option<ServerName>,
39 ) -> Result<Self::ClientParams, SslError> {
40 let _ = ::rustls::crypto::ring::default_provider().install_default();
41
42 let TlsParameters {
43 server_cert_verify,
44 root_cert,
45 cert,
46 key,
47 crl,
48 min_protocol_version: _,
49 max_protocol_version: _,
50 alpn,
51 enable_keylog,
52 sni_override,
53 } = params;
54
55 let verifier = make_verifier(server_cert_verify, root_cert, crl.clone())?;
56
57 let config = ClientConfig::builder()
58 .dangerous()
59 .with_custom_certificate_verifier(verifier);
60
61 let mut config = if let (Some(cert), Some(key)) = (cert, key) {
63 config
64 .with_client_auth_cert(vec![cert.clone()], key.clone_key())
65 .map_err(|_| {
66 std::io::Error::new(
67 std::io::ErrorKind::InvalidInput,
68 "Failed to set client auth cert",
69 )
70 })?
71 } else {
72 config.with_no_client_auth()
73 };
74
75 config.alpn_protocols = alpn.as_vec_vec();
77
78 if *enable_keylog {
80 config.key_log = Arc::new(rustls::KeyLogFile::new());
81 }
82
83 let name = if let Some(sni_override) = sni_override {
84 ServerName::try_from(sni_override.to_string())?
85 } else if let Some(name) = name {
86 name.to_owned()
87 } else {
88 config.enable_sni = false;
89 ServerName::IpAddress(IpAddr::V4(Ipv4Addr::from_bits(0)).into())
90 };
91
92 Ok(ClientConnection::new(Arc::new(config), name)?)
93 }
94
95 fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError> {
96 let builder = match ¶ms.client_cert_verify {
97 TlsClientCertVerify::Ignore => ServerConfig::builder().with_no_client_auth(),
98 TlsClientCertVerify::Optional(certs) => {
99 let mut roots = RootCertStore::empty();
100 roots.add_parsable_certificates(
101 certs.iter().map(|c| CertificateDer::from_slice(c.as_ref())),
102 );
103 ServerConfig::builder().with_client_cert_verifier(
104 WebPkiClientVerifier::builder(roots.into())
105 .allow_unauthenticated()
106 .build()?,
107 )
108 }
109 TlsClientCertVerify::Validate(certs) => {
110 let mut roots = RootCertStore::empty();
111 roots.add_parsable_certificates(
112 certs.iter().map(|c| CertificateDer::from_slice(c.as_ref())),
113 );
114 ServerConfig::builder()
115 .with_client_cert_verifier(WebPkiClientVerifier::builder(roots.into()).build()?)
116 }
117 };
118
119 let mut config = builder.with_single_cert(
120 vec![params.server_certificate.cert.clone()],
121 params.server_certificate.key.clone_key(),
122 )?;
123
124 config.alpn_protocols = params.alpn.as_vec_vec();
125
126 Ok(Arc::new(config))
127 }
128
129 async fn upgrade_client<S: Stream>(
130 params: Self::ClientParams,
131 stream: S,
132 ) -> Result<(Self::Stream, TlsHandshake), SslError> {
133 let stream = stream
135 .downcast::<TokioStream>()
136 .map_err(|_| crate::SslError::SslUnsupported)?;
137 let TokioStream::Tcp(stream) = stream else {
138 return Err(crate::SslError::SslUnsupported);
139 };
140
141 let mut stream = TlsStream::new_client_side(stream, params, None);
142 match stream.handshake().await {
143 Ok(handshake) => {
144 let cert = stream
145 .connection()
146 .and_then(|c| c.peer_certificates())
147 .and_then(|c| c.first().map(|cert| cert.to_owned()));
148 let version = stream.connection().and_then(|c| c.protocol_version());
149 Ok((
150 stream,
151 TlsHandshake {
152 alpn: handshake.alpn.map(|alpn| Cow::Owned(alpn.to_vec())),
153 sni: handshake.sni.and_then(|s| DnsName::try_from(s).ok()),
154 cert,
155 version: match version {
156 Some(rustls::ProtocolVersion::TLSv1_0) => Some(SslVersion::Tls1),
157 Some(rustls::ProtocolVersion::TLSv1_1) => Some(SslVersion::Tls1_1),
158 Some(rustls::ProtocolVersion::TLSv1_2) => Some(SslVersion::Tls1_2),
159 Some(rustls::ProtocolVersion::TLSv1_3) => Some(SslVersion::Tls1_3),
160 _ => None,
161 },
162 },
163 ))
164 }
165 Err(e) => {
166 let kind = e.kind();
167 if let Some(e2) = e.into_inner() {
168 match e2.downcast::<::rustls::Error>() {
169 Ok(e) => Err(crate::SslError::RustlsError(*e)),
170 Err(e) => Err(std::io::Error::new(kind, e).into()),
171 }
172 } else {
173 Err(std::io::Error::from(kind).into())
174 }
175 }
176 }
177 }
178
179 async fn upgrade_server<S: Stream>(
180 params: TlsServerParameterProvider,
181 stream: S,
182 ) -> Result<(Self::Stream, TlsHandshake), SslError> {
183 let (stream, mut acceptor) = match stream.downcast::<RewindStream<TokioStream>>() {
184 Ok(stream) => {
185 let (stream, buffer) = stream.into_inner();
186 let mut acceptor = Acceptor::default();
187 acceptor.read_tls(&mut buffer.as_slice())?;
188 (stream, acceptor)
189 }
190 Err(stream) => {
191 let Ok(stream) = stream.downcast::<TokioStream>() else {
192 return Err(crate::SslError::SslUnsupported);
193 };
194 (stream, Acceptor::default())
195 }
196 };
197
198 let TokioStream::Tcp(mut stream) = stream else {
199 return Err(crate::SslError::SslUnsupported);
200 };
201
202 let mut buf = [MaybeUninit::uninit(); 1024];
203 let accepted = loop {
204 match acceptor.accept() {
205 Ok(Some(accept)) => break accept,
206 Ok(None) => {
207 let mut buf = ReadBuf::uninit(&mut buf);
208 stream.read_buf(&mut buf).await?;
209 acceptor.read_tls(&mut buf.filled())?;
210 }
211 Err((e, mut b)) => {
212 let mut buf = [0_u8; 1024];
213 loop {
214 let w = b.write(&mut buf.as_mut_slice())?;
215 if w == 0 {
216 break;
217 }
218 stream.write_all(&buf[..w]).await?;
219 }
220 return Err(e.into());
221 }
222 }
223 };
224
225 let hello = accepted.client_hello();
226 let server_name = hello
227 .server_name()
228 .and_then(|name| DnsName::try_from(name).ok());
229
230 let params = params.lookup(server_name, &stream);
231 let config = RustlsDriver::init_server(¶ms)
232 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
233 let conn = match accepted.into_connection(config) {
234 Ok(conn) => conn,
235 Err((e, mut b)) => {
236 let mut buf = [0_u8; 1024];
237 loop {
238 let w = b.write(&mut buf.as_mut_slice())?;
239 if w == 0 {
240 break;
241 }
242 stream.write_all(&buf[..w]).await?;
243 }
244 return Err(e.into());
245 }
246 };
247 let mut stream = TlsStream::new_server_side_from(stream, conn, None);
248
249 match stream.handshake().await {
250 Ok(handshake) => {
251 let cert = stream
252 .connection()
253 .and_then(|c| c.peer_certificates())
254 .and_then(|c| c.first().map(|cert| cert.to_owned()));
255 let version = stream.connection().and_then(|c| c.protocol_version());
256 Ok((
257 stream,
258 TlsHandshake {
259 alpn: handshake.alpn.map(|alpn| Cow::Owned(alpn.to_vec())),
260 sni: handshake
261 .sni
262 .and_then(|s| DnsName::try_from(s.to_string()).ok()),
263 cert,
264 version: match version {
265 Some(rustls::ProtocolVersion::TLSv1_0) => Some(SslVersion::Tls1),
266 Some(rustls::ProtocolVersion::TLSv1_1) => Some(SslVersion::Tls1_1),
267 Some(rustls::ProtocolVersion::TLSv1_2) => Some(SslVersion::Tls1_2),
268 Some(rustls::ProtocolVersion::TLSv1_3) => Some(SslVersion::Tls1_3),
269 _ => None,
270 },
271 },
272 ))
273 }
274 Err(e) => {
275 let kind = e.kind();
276 if let Some(e2) = e.into_inner() {
277 match e2.downcast::<::rustls::Error>() {
278 Ok(e) => Err(crate::SslError::RustlsError(*e)),
279 Err(e) => Err(std::io::Error::new(kind, e).into()),
280 }
281 } else {
282 Err(std::io::Error::from(kind).into())
283 }
284 }
285 }
286 }
287
288 fn unclean_shutdown(this: Self::Stream) -> Result<(), Self::Stream> {
289 this.try_into_inner().map(drop)
291 }
292}
293
294fn make_roots(
295 root_certs: &[CertificateDer<'static>],
296 webpki: bool,
297) -> Result<RootCertStore, crate::SslError> {
298 let mut roots = RootCertStore::empty();
299 if webpki {
300 let webpki_roots = webpki_roots::TLS_SERVER_ROOTS;
301 roots.extend(webpki_roots.iter().cloned());
302 }
303 let (loaded, ignored) = roots.add_parsable_certificates(root_certs.iter().cloned());
304 if !root_certs.is_empty() && (loaded == 0 || ignored > 0) {
305 return Err(
306 std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid certificate").into(),
307 );
308 }
309 Ok(roots)
310}
311
312fn make_verifier(
313 server_cert_verify: &TlsServerCertVerify,
314 root_cert: &TlsCert,
315 crls: Vec<CertificateRevocationListDer<'static>>,
316) -> Result<Arc<dyn ServerCertVerifier>, crate::SslError> {
317 if *server_cert_verify == TlsServerCertVerify::Insecure {
318 return Ok(Arc::new(NullVerifier));
319 }
320
321 if matches!(
322 root_cert,
323 TlsCert::Webpki | TlsCert::WebpkiPlus(_) | TlsCert::Custom(_)
324 ) {
325 let roots = match root_cert {
326 TlsCert::Webpki => make_roots(&[], true),
327 TlsCert::Custom(roots) => make_roots(roots, false),
328 TlsCert::WebpkiPlus(roots) => make_roots(roots, true),
329 _ => unreachable!(),
330 }?;
331
332 let verifier = WebPkiServerVerifier::builder(Arc::new(roots))
333 .with_crls(crls)
334 .build()?;
335 if *server_cert_verify == TlsServerCertVerify::IgnoreHostname {
336 return Ok(Arc::new(IgnoreHostnameVerifier::new(verifier)));
337 }
338 return Ok(verifier);
339 }
340
341 let verifier: Arc<dyn ServerCertVerifier> = if let TlsCert::SystemPlus(roots) = root_cert {
344 let roots = make_roots(roots, false)?;
345 let v1 = WebPkiServerVerifier::builder(Arc::new(roots))
346 .with_crls(crls)
347 .build()?;
348 let v2 = Arc::new(Verifier::new());
349 Arc::new(ChainingVerifier::new(v1, v2))
350 } else {
351 Arc::new(ErrorFilteringVerifier::new(Arc::new(Verifier::new())))
352 };
353
354 let verifier: Arc<dyn ServerCertVerifier> =
355 if *server_cert_verify == TlsServerCertVerify::IgnoreHostname {
356 Arc::new(IgnoreHostnameVerifier::new(verifier))
357 } else {
358 verifier
359 };
360
361 Ok(verifier)
362}
363
364#[derive(Debug)]
365struct IgnoreHostnameVerifier {
366 verifier: Arc<dyn ServerCertVerifier>,
367}
368
369impl IgnoreHostnameVerifier {
370 fn new(verifier: Arc<dyn ServerCertVerifier>) -> Self {
371 Self { verifier }
372 }
373}
374
375impl ServerCertVerifier for IgnoreHostnameVerifier {
376 fn verify_server_cert(
377 &self,
378 end_entity: &CertificateDer<'_>,
379 intermediates: &[CertificateDer<'_>],
380 server_name: &ServerName,
381 ocsp_response: &[u8],
382 now: UnixTime,
383 ) -> Result<ServerCertVerified, rustls::Error> {
384 match self.verifier.verify_server_cert(
385 end_entity,
386 intermediates,
387 server_name,
388 ocsp_response,
389 now,
390 ) {
391 Ok(res) => Ok(res),
392 Err(rustls::Error::InvalidCertificate(
394 rustls::CertificateError::NotValidForName
395 | rustls::CertificateError::NotValidForNameContext { .. },
396 )) => Ok(ServerCertVerified::assertion()),
397 Err(e) => Err(e),
398 }
399 }
400
401 fn verify_tls12_signature(
402 &self,
403 message: &[u8],
404 cert: &CertificateDer<'_>,
405 dss: &DigitallySignedStruct,
406 ) -> Result<HandshakeSignatureValid, rustls::Error> {
407 self.verifier.verify_tls12_signature(message, cert, dss)
408 }
409
410 fn verify_tls13_signature(
411 &self,
412 message: &[u8],
413 cert: &CertificateDer<'_>,
414 dss: &DigitallySignedStruct,
415 ) -> Result<HandshakeSignatureValid, rustls::Error> {
416 self.verifier.verify_tls13_signature(message, cert, dss)
417 }
418
419 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
420 self.verifier.supported_verify_schemes()
421 }
422}
423
424#[derive(Debug)]
425struct ChainingVerifier {
426 verifier1: Arc<dyn ServerCertVerifier>,
427 verifier2: Arc<dyn ServerCertVerifier>,
428}
429
430impl ChainingVerifier {
431 fn new(verifier1: Arc<dyn ServerCertVerifier>, verifier2: Arc<dyn ServerCertVerifier>) -> Self {
432 Self {
433 verifier1,
434 verifier2,
435 }
436 }
437}
438
439impl ServerCertVerifier for ChainingVerifier {
440 fn verify_server_cert(
441 &self,
442 end_entity: &CertificateDer<'_>,
443 intermediates: &[CertificateDer<'_>],
444 server_name: &ServerName,
445 ocsp_response: &[u8],
446 now: UnixTime,
447 ) -> Result<ServerCertVerified, rustls::Error> {
448 let res = self.verifier1.verify_server_cert(
449 end_entity,
450 intermediates,
451 server_name,
452 ocsp_response,
453 now,
454 );
455 if let Ok(res) = res {
456 return Ok(res);
457 }
458
459 let res2 = self.verifier2.verify_server_cert(
460 end_entity,
461 intermediates,
462 server_name,
463 ocsp_response,
464 now,
465 );
466 if let Ok(res) = res2 {
467 return Ok(res);
468 }
469
470 res
471 }
472
473 fn verify_tls12_signature(
474 &self,
475 message: &[u8],
476 cert: &CertificateDer<'_>,
477 dss: &DigitallySignedStruct,
478 ) -> Result<HandshakeSignatureValid, rustls::Error> {
479 let res = self.verifier1.verify_tls12_signature(message, cert, dss);
480 if let Ok(res) = res {
481 return Ok(res);
482 }
483
484 let res2 = self.verifier2.verify_tls12_signature(message, cert, dss);
485 if let Ok(res) = res2 {
486 return Ok(res);
487 }
488
489 res
490 }
491
492 fn verify_tls13_signature(
493 &self,
494 message: &[u8],
495 cert: &CertificateDer<'_>,
496 dss: &DigitallySignedStruct,
497 ) -> Result<HandshakeSignatureValid, rustls::Error> {
498 let res = self.verifier1.verify_tls13_signature(message, cert, dss);
499 if let Ok(res) = res {
500 return Ok(res);
501 }
502
503 let res2 = self.verifier2.verify_tls13_signature(message, cert, dss);
504 if let Ok(res) = res2 {
505 return Ok(res);
506 }
507
508 res
509 }
510
511 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
512 self.verifier1.supported_verify_schemes()
513 }
514}
515
516#[derive(Debug)]
517struct NullVerifier;
518
519impl ServerCertVerifier for NullVerifier {
520 fn verify_server_cert(
521 &self,
522 _end_entity: &CertificateDer<'_>,
523 _intermediates: &[CertificateDer<'_>],
524 _server_name: &ServerName,
525 _ocsp_response: &[u8],
526 _now: UnixTime,
527 ) -> Result<ServerCertVerified, rustls::Error> {
528 Ok(ServerCertVerified::assertion())
529 }
530
531 fn verify_tls12_signature(
532 &self,
533 _message: &[u8],
534 _cert: &CertificateDer<'_>,
535 _dss: &DigitallySignedStruct,
536 ) -> Result<HandshakeSignatureValid, rustls::Error> {
537 Ok(HandshakeSignatureValid::assertion())
538 }
539
540 fn verify_tls13_signature(
541 &self,
542 _message: &[u8],
543 _cert: &CertificateDer<'_>,
544 _dss: &DigitallySignedStruct,
545 ) -> Result<HandshakeSignatureValid, rustls::Error> {
546 Ok(HandshakeSignatureValid::assertion())
547 }
548
549 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
550 use SignatureScheme::*;
551 vec![
552 RSA_PKCS1_SHA1,
553 ECDSA_SHA1_Legacy,
554 RSA_PKCS1_SHA256,
555 ECDSA_NISTP256_SHA256,
556 RSA_PKCS1_SHA384,
557 ECDSA_NISTP384_SHA384,
558 RSA_PKCS1_SHA512,
559 ECDSA_NISTP521_SHA512,
560 RSA_PSS_SHA256,
561 RSA_PSS_SHA384,
562 RSA_PSS_SHA512,
563 ED25519,
564 ED448,
565 ]
566 }
567}
568
569#[derive(Debug)]
570struct ErrorFilteringVerifier {
571 verifier: Arc<dyn ServerCertVerifier>,
572}
573
574impl ErrorFilteringVerifier {
575 fn new(verifier: Arc<dyn ServerCertVerifier>) -> Self {
576 Self { verifier }
577 }
578
579 fn filter_err<T>(res: Result<T, rustls::Error>) -> Result<T, rustls::Error> {
580 match res {
581 Ok(res) => Ok(res),
582 #[cfg(target_vendor = "apple")]
588 Err(rustls::Error::InvalidCertificate(rustls::CertificateError::Other(e)))
589 if e.to_string().contains("-67901") =>
590 {
591 Err(rustls::Error::InvalidCertificate(
592 rustls::CertificateError::UnknownIssuer,
593 ))
594 }
595 Err(e) => Err(e),
596 }
597 }
598}
599
600impl ServerCertVerifier for ErrorFilteringVerifier {
601 fn verify_server_cert(
602 &self,
603 end_entity: &CertificateDer<'_>,
604 intermediates: &[CertificateDer<'_>],
605 server_name: &ServerName,
606 ocsp_response: &[u8],
607 now: UnixTime,
608 ) -> Result<ServerCertVerified, rustls::Error> {
609 Self::filter_err(self.verifier.verify_server_cert(
610 end_entity,
611 intermediates,
612 server_name,
613 ocsp_response,
614 now,
615 ))
616 }
617
618 fn verify_tls12_signature(
619 &self,
620 message: &[u8],
621 cert: &CertificateDer<'_>,
622 dss: &DigitallySignedStruct,
623 ) -> Result<HandshakeSignatureValid, rustls::Error> {
624 Self::filter_err(self.verifier.verify_tls12_signature(message, cert, dss))
625 }
626
627 fn verify_tls13_signature(
628 &self,
629 message: &[u8],
630 cert: &CertificateDer<'_>,
631 dss: &DigitallySignedStruct,
632 ) -> Result<HandshakeSignatureValid, rustls::Error> {
633 Self::filter_err(self.verifier.verify_tls13_signature(message, cert, dss))
634 }
635
636 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
637 self.verifier.supported_verify_schemes()
638 }
639}
640
641impl LocalAddress for TlsStream {
642 fn local_address(&self) -> std::io::Result<ResolvedTarget> {
643 self.local_addr().map(|addr| ResolvedTarget::from(addr))
644 }
645}
646
647impl RemoteAddress for TlsStream {
648 fn remote_address(&self) -> std::io::Result<ResolvedTarget> {
649 self.peer_addr().map(|addr| ResolvedTarget::from(addr))
650 }
651}
652
653impl PeerCred for TlsStream {
654 #[cfg(all(unix, feature = "tokio"))]
655 fn peer_cred(&self) -> std::io::Result<tokio::net::unix::UCred> {
656 Err(std::io::Error::new(
657 std::io::ErrorKind::Unsupported,
658 "TCP streams do not support peer credentials",
659 ))
660 }
661}
662
impl StreamMetadata for TlsStream {
    /// TLS streams from this driver always run over TCP.
    fn transport(&self) -> Transport {
        Transport::Tcp
    }
}
668
669impl AsHandle for TlsStream {
670 #[cfg(windows)]
671 fn as_handle(&self) -> std::os::windows::io::BorrowedSocket {
672 std::os::windows::io::AsSocket::as_socket(self.tcp_stream().unwrap())
673 }
674
675 #[cfg(unix)]
676 fn as_fd(&self) -> std::os::fd::BorrowedFd {
677 std::os::fd::AsFd::as_fd(self.tcp_stream().unwrap())
678 }
679}