gel_stream/common/
rustls.rs

1use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
2use rustls::client::WebPkiServerVerifier;
3use rustls::server::{Acceptor, WebPkiClientVerifier};
4use rustls::{
5    ClientConfig, ClientConnection, DigitallySignedStruct, RootCertStore, ServerConfig,
6    SignatureScheme,
7};
8use rustls_pki_types::{
9    CertificateDer, CertificateRevocationListDer, DnsName, ServerName, UnixTime,
10};
11use rustls_platform_verifier::Verifier;
12use rustls_tokio_stream::TlsStream;
13use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf};
14
15use super::tokio_stream::TokioStream;
16use crate::{
17    AsHandle, LocalAddress, PeerCred, RemoteAddress, ResolvedTarget, RewindStream, SslError,
18    SslVersion, Stream, StreamMetadata, TlsClientCertVerify, TlsDriver, TlsHandshake,
19    TlsServerParameterProvider, TlsServerParameters, Transport,
20};
21use crate::{TlsCert, TlsParameters, TlsServerCertVerify};
22use std::borrow::Cow;
23use std::mem::MaybeUninit;
24use std::net::{IpAddr, Ipv4Addr};
25use std::sync::Arc;
26
27#[derive(Default)]
28pub struct RustlsDriver;
29
30impl TlsDriver for RustlsDriver {
31    type Stream = TlsStream;
32    type ClientParams = ClientConnection;
33    type ServerParams = Arc<ServerConfig>;
34    const DRIVER_NAME: &'static str = "rustls";
35
36    fn init_client(
37        params: &TlsParameters,
38        name: Option<ServerName>,
39    ) -> Result<Self::ClientParams, SslError> {
40        let _ = ::rustls::crypto::ring::default_provider().install_default();
41
42        let TlsParameters {
43            server_cert_verify,
44            root_cert,
45            cert,
46            key,
47            crl,
48            min_protocol_version: _,
49            max_protocol_version: _,
50            alpn,
51            enable_keylog,
52            sni_override,
53        } = params;
54
55        let verifier = make_verifier(server_cert_verify, root_cert, crl.clone())?;
56
57        let config = ClientConfig::builder()
58            .dangerous()
59            .with_custom_certificate_verifier(verifier);
60
61        // Load client certificate and key if provided
62        let mut config = if let (Some(cert), Some(key)) = (cert, key) {
63            config
64                .with_client_auth_cert(vec![cert.clone()], key.clone_key())
65                .map_err(|_| {
66                    std::io::Error::new(
67                        std::io::ErrorKind::InvalidInput,
68                        "Failed to set client auth cert",
69                    )
70                })?
71        } else {
72            config.with_no_client_auth()
73        };
74
75        // Configure ALPN if provided
76        config.alpn_protocols = alpn.as_vec_vec();
77
78        // Configure keylog if provided
79        if *enable_keylog {
80            config.key_log = Arc::new(rustls::KeyLogFile::new());
81        }
82
83        let name = if let Some(sni_override) = sni_override {
84            ServerName::try_from(sni_override.to_string())?
85        } else if let Some(name) = name {
86            name.to_owned()
87        } else {
88            config.enable_sni = false;
89            ServerName::IpAddress(IpAddr::V4(Ipv4Addr::from_bits(0)).into())
90        };
91
92        Ok(ClientConnection::new(Arc::new(config), name)?)
93    }
94
95    fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError> {
96        let builder = match &params.client_cert_verify {
97            TlsClientCertVerify::Ignore => ServerConfig::builder().with_no_client_auth(),
98            TlsClientCertVerify::Optional(certs) => {
99                let mut roots = RootCertStore::empty();
100                roots.add_parsable_certificates(
101                    certs.iter().map(|c| CertificateDer::from_slice(c.as_ref())),
102                );
103                ServerConfig::builder().with_client_cert_verifier(
104                    WebPkiClientVerifier::builder(roots.into())
105                        .allow_unauthenticated()
106                        .build()?,
107                )
108            }
109            TlsClientCertVerify::Validate(certs) => {
110                let mut roots = RootCertStore::empty();
111                roots.add_parsable_certificates(
112                    certs.iter().map(|c| CertificateDer::from_slice(c.as_ref())),
113                );
114                ServerConfig::builder()
115                    .with_client_cert_verifier(WebPkiClientVerifier::builder(roots.into()).build()?)
116            }
117        };
118
119        let mut config = builder.with_single_cert(
120            vec![params.server_certificate.cert.clone()],
121            params.server_certificate.key.clone_key(),
122        )?;
123
124        config.alpn_protocols = params.alpn.as_vec_vec();
125
126        Ok(Arc::new(config))
127    }
128
129    async fn upgrade_client<S: Stream>(
130        params: Self::ClientParams,
131        stream: S,
132    ) -> Result<(Self::Stream, TlsHandshake), SslError> {
133        // Note that we only support Tokio TcpStream for rustls.
134        let stream = stream
135            .downcast::<TokioStream>()
136            .map_err(|_| crate::SslError::SslUnsupported)?;
137        let TokioStream::Tcp(stream) = stream else {
138            return Err(crate::SslError::SslUnsupported);
139        };
140
141        let mut stream = TlsStream::new_client_side(stream, params, None);
142        match stream.handshake().await {
143            Ok(handshake) => {
144                let cert = stream
145                    .connection()
146                    .and_then(|c| c.peer_certificates())
147                    .and_then(|c| c.first().map(|cert| cert.to_owned()));
148                let version = stream.connection().and_then(|c| c.protocol_version());
149                Ok((
150                    stream,
151                    TlsHandshake {
152                        alpn: handshake.alpn.map(|alpn| Cow::Owned(alpn.to_vec())),
153                        sni: handshake.sni.and_then(|s| DnsName::try_from(s).ok()),
154                        cert,
155                        version: match version {
156                            Some(rustls::ProtocolVersion::TLSv1_0) => Some(SslVersion::Tls1),
157                            Some(rustls::ProtocolVersion::TLSv1_1) => Some(SslVersion::Tls1_1),
158                            Some(rustls::ProtocolVersion::TLSv1_2) => Some(SslVersion::Tls1_2),
159                            Some(rustls::ProtocolVersion::TLSv1_3) => Some(SslVersion::Tls1_3),
160                            _ => None,
161                        },
162                    },
163                ))
164            }
165            Err(e) => {
166                let kind = e.kind();
167                if let Some(e2) = e.into_inner() {
168                    match e2.downcast::<::rustls::Error>() {
169                        Ok(e) => Err(crate::SslError::RustlsError(*e)),
170                        Err(e) => Err(std::io::Error::new(kind, e).into()),
171                    }
172                } else {
173                    Err(std::io::Error::from(kind).into())
174                }
175            }
176        }
177    }
178
179    async fn upgrade_server<S: Stream>(
180        params: TlsServerParameterProvider,
181        stream: S,
182    ) -> Result<(Self::Stream, TlsHandshake), SslError> {
183        let (stream, mut acceptor) = match stream.downcast::<RewindStream<TokioStream>>() {
184            Ok(stream) => {
185                let (stream, buffer) = stream.into_inner();
186                let mut acceptor = Acceptor::default();
187                acceptor.read_tls(&mut buffer.as_slice())?;
188                (stream, acceptor)
189            }
190            Err(stream) => {
191                let Ok(stream) = stream.downcast::<TokioStream>() else {
192                    return Err(crate::SslError::SslUnsupported);
193                };
194                (stream, Acceptor::default())
195            }
196        };
197
198        let TokioStream::Tcp(mut stream) = stream else {
199            return Err(crate::SslError::SslUnsupported);
200        };
201
202        let mut buf = [MaybeUninit::uninit(); 1024];
203        let accepted = loop {
204            match acceptor.accept() {
205                Ok(Some(accept)) => break accept,
206                Ok(None) => {
207                    let mut buf = ReadBuf::uninit(&mut buf);
208                    stream.read_buf(&mut buf).await?;
209                    acceptor.read_tls(&mut buf.filled())?;
210                }
211                Err((e, mut b)) => {
212                    let mut buf = [0_u8; 1024];
213                    loop {
214                        let w = b.write(&mut buf.as_mut_slice())?;
215                        if w == 0 {
216                            break;
217                        }
218                        stream.write_all(&buf[..w]).await?;
219                    }
220                    return Err(e.into());
221                }
222            }
223        };
224
225        let hello = accepted.client_hello();
226        let server_name = hello
227            .server_name()
228            .and_then(|name| DnsName::try_from(name).ok());
229
230        let params = params.lookup(server_name, &stream);
231        let config = RustlsDriver::init_server(&params)
232            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
233        let conn = match accepted.into_connection(config) {
234            Ok(conn) => conn,
235            Err((e, mut b)) => {
236                let mut buf = [0_u8; 1024];
237                loop {
238                    let w = b.write(&mut buf.as_mut_slice())?;
239                    if w == 0 {
240                        break;
241                    }
242                    stream.write_all(&buf[..w]).await?;
243                }
244                return Err(e.into());
245            }
246        };
247        let mut stream = TlsStream::new_server_side_from(stream, conn, None);
248
249        match stream.handshake().await {
250            Ok(handshake) => {
251                let cert = stream
252                    .connection()
253                    .and_then(|c| c.peer_certificates())
254                    .and_then(|c| c.first().map(|cert| cert.to_owned()));
255                let version = stream.connection().and_then(|c| c.protocol_version());
256                Ok((
257                    stream,
258                    TlsHandshake {
259                        alpn: handshake.alpn.map(|alpn| Cow::Owned(alpn.to_vec())),
260                        sni: handshake
261                            .sni
262                            .and_then(|s| DnsName::try_from(s.to_string()).ok()),
263                        cert,
264                        version: match version {
265                            Some(rustls::ProtocolVersion::TLSv1_0) => Some(SslVersion::Tls1),
266                            Some(rustls::ProtocolVersion::TLSv1_1) => Some(SslVersion::Tls1_1),
267                            Some(rustls::ProtocolVersion::TLSv1_2) => Some(SslVersion::Tls1_2),
268                            Some(rustls::ProtocolVersion::TLSv1_3) => Some(SslVersion::Tls1_3),
269                            _ => None,
270                        },
271                    },
272                ))
273            }
274            Err(e) => {
275                let kind = e.kind();
276                if let Some(e2) = e.into_inner() {
277                    match e2.downcast::<::rustls::Error>() {
278                        Ok(e) => Err(crate::SslError::RustlsError(*e)),
279                        Err(e) => Err(std::io::Error::new(kind, e).into()),
280                    }
281                } else {
282                    Err(std::io::Error::from(kind).into())
283                }
284            }
285        }
286    }
287
288    fn unclean_shutdown(this: Self::Stream) -> Result<(), Self::Stream> {
289        // Skip the shutdown logic by tearing this down into its parts.
290        this.try_into_inner().map(drop)
291    }
292}
293
294fn make_roots(
295    root_certs: &[CertificateDer<'static>],
296    webpki: bool,
297) -> Result<RootCertStore, crate::SslError> {
298    let mut roots = RootCertStore::empty();
299    if webpki {
300        let webpki_roots = webpki_roots::TLS_SERVER_ROOTS;
301        roots.extend(webpki_roots.iter().cloned());
302    }
303    let (loaded, ignored) = roots.add_parsable_certificates(root_certs.iter().cloned());
304    if !root_certs.is_empty() && (loaded == 0 || ignored > 0) {
305        return Err(
306            std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid certificate").into(),
307        );
308    }
309    Ok(roots)
310}
311
312fn make_verifier(
313    server_cert_verify: &TlsServerCertVerify,
314    root_cert: &TlsCert,
315    crls: Vec<CertificateRevocationListDer<'static>>,
316) -> Result<Arc<dyn ServerCertVerifier>, crate::SslError> {
317    if *server_cert_verify == TlsServerCertVerify::Insecure {
318        return Ok(Arc::new(NullVerifier));
319    }
320
321    if matches!(
322        root_cert,
323        TlsCert::Webpki | TlsCert::WebpkiPlus(_) | TlsCert::Custom(_)
324    ) {
325        let roots = match root_cert {
326            TlsCert::Webpki => make_roots(&[], true),
327            TlsCert::Custom(roots) => make_roots(roots, false),
328            TlsCert::WebpkiPlus(roots) => make_roots(roots, true),
329            _ => unreachable!(),
330        }?;
331
332        let verifier = WebPkiServerVerifier::builder(Arc::new(roots))
333            .with_crls(crls)
334            .build()?;
335        if *server_cert_verify == TlsServerCertVerify::IgnoreHostname {
336            return Ok(Arc::new(IgnoreHostnameVerifier::new(verifier)));
337        }
338        return Ok(verifier);
339    }
340
341    // We need to work around macOS returning `certificate is not standards compliant: -67901`
342    // when using the system verifier.
343    let verifier: Arc<dyn ServerCertVerifier> = if let TlsCert::SystemPlus(roots) = root_cert {
344        let roots = make_roots(roots, false)?;
345        let v1 = WebPkiServerVerifier::builder(Arc::new(roots))
346            .with_crls(crls)
347            .build()?;
348        let v2 = Arc::new(Verifier::new());
349        Arc::new(ChainingVerifier::new(v1, v2))
350    } else {
351        Arc::new(ErrorFilteringVerifier::new(Arc::new(Verifier::new())))
352    };
353
354    let verifier: Arc<dyn ServerCertVerifier> =
355        if *server_cert_verify == TlsServerCertVerify::IgnoreHostname {
356            Arc::new(IgnoreHostnameVerifier::new(verifier))
357        } else {
358            verifier
359        };
360
361    Ok(verifier)
362}
363
364#[derive(Debug)]
365struct IgnoreHostnameVerifier {
366    verifier: Arc<dyn ServerCertVerifier>,
367}
368
369impl IgnoreHostnameVerifier {
370    fn new(verifier: Arc<dyn ServerCertVerifier>) -> Self {
371        Self { verifier }
372    }
373}
374
375impl ServerCertVerifier for IgnoreHostnameVerifier {
376    fn verify_server_cert(
377        &self,
378        end_entity: &CertificateDer<'_>,
379        intermediates: &[CertificateDer<'_>],
380        server_name: &ServerName,
381        ocsp_response: &[u8],
382        now: UnixTime,
383    ) -> Result<ServerCertVerified, rustls::Error> {
384        match self.verifier.verify_server_cert(
385            end_entity,
386            intermediates,
387            server_name,
388            ocsp_response,
389            now,
390        ) {
391            Ok(res) => Ok(res),
392            // This works because the name check is the last step in the verify process
393            Err(rustls::Error::InvalidCertificate(
394                rustls::CertificateError::NotValidForName
395                | rustls::CertificateError::NotValidForNameContext { .. },
396            )) => Ok(ServerCertVerified::assertion()),
397            Err(e) => Err(e),
398        }
399    }
400
401    fn verify_tls12_signature(
402        &self,
403        message: &[u8],
404        cert: &CertificateDer<'_>,
405        dss: &DigitallySignedStruct,
406    ) -> Result<HandshakeSignatureValid, rustls::Error> {
407        self.verifier.verify_tls12_signature(message, cert, dss)
408    }
409
410    fn verify_tls13_signature(
411        &self,
412        message: &[u8],
413        cert: &CertificateDer<'_>,
414        dss: &DigitallySignedStruct,
415    ) -> Result<HandshakeSignatureValid, rustls::Error> {
416        self.verifier.verify_tls13_signature(message, cert, dss)
417    }
418
419    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
420        self.verifier.supported_verify_schemes()
421    }
422}
423
424#[derive(Debug)]
425struct ChainingVerifier {
426    verifier1: Arc<dyn ServerCertVerifier>,
427    verifier2: Arc<dyn ServerCertVerifier>,
428}
429
430impl ChainingVerifier {
431    fn new(verifier1: Arc<dyn ServerCertVerifier>, verifier2: Arc<dyn ServerCertVerifier>) -> Self {
432        Self {
433            verifier1,
434            verifier2,
435        }
436    }
437}
438
439impl ServerCertVerifier for ChainingVerifier {
440    fn verify_server_cert(
441        &self,
442        end_entity: &CertificateDer<'_>,
443        intermediates: &[CertificateDer<'_>],
444        server_name: &ServerName,
445        ocsp_response: &[u8],
446        now: UnixTime,
447    ) -> Result<ServerCertVerified, rustls::Error> {
448        let res = self.verifier1.verify_server_cert(
449            end_entity,
450            intermediates,
451            server_name,
452            ocsp_response,
453            now,
454        );
455        if let Ok(res) = res {
456            return Ok(res);
457        }
458
459        let res2 = self.verifier2.verify_server_cert(
460            end_entity,
461            intermediates,
462            server_name,
463            ocsp_response,
464            now,
465        );
466        if let Ok(res) = res2 {
467            return Ok(res);
468        }
469
470        res
471    }
472
473    fn verify_tls12_signature(
474        &self,
475        message: &[u8],
476        cert: &CertificateDer<'_>,
477        dss: &DigitallySignedStruct,
478    ) -> Result<HandshakeSignatureValid, rustls::Error> {
479        let res = self.verifier1.verify_tls12_signature(message, cert, dss);
480        if let Ok(res) = res {
481            return Ok(res);
482        }
483
484        let res2 = self.verifier2.verify_tls12_signature(message, cert, dss);
485        if let Ok(res) = res2 {
486            return Ok(res);
487        }
488
489        res
490    }
491
492    fn verify_tls13_signature(
493        &self,
494        message: &[u8],
495        cert: &CertificateDer<'_>,
496        dss: &DigitallySignedStruct,
497    ) -> Result<HandshakeSignatureValid, rustls::Error> {
498        let res = self.verifier1.verify_tls13_signature(message, cert, dss);
499        if let Ok(res) = res {
500            return Ok(res);
501        }
502
503        let res2 = self.verifier2.verify_tls13_signature(message, cert, dss);
504        if let Ok(res) = res2 {
505            return Ok(res);
506        }
507
508        res
509    }
510
511    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
512        self.verifier1.supported_verify_schemes()
513    }
514}
515
516#[derive(Debug)]
517struct NullVerifier;
518
519impl ServerCertVerifier for NullVerifier {
520    fn verify_server_cert(
521        &self,
522        _end_entity: &CertificateDer<'_>,
523        _intermediates: &[CertificateDer<'_>],
524        _server_name: &ServerName,
525        _ocsp_response: &[u8],
526        _now: UnixTime,
527    ) -> Result<ServerCertVerified, rustls::Error> {
528        Ok(ServerCertVerified::assertion())
529    }
530
531    fn verify_tls12_signature(
532        &self,
533        _message: &[u8],
534        _cert: &CertificateDer<'_>,
535        _dss: &DigitallySignedStruct,
536    ) -> Result<HandshakeSignatureValid, rustls::Error> {
537        Ok(HandshakeSignatureValid::assertion())
538    }
539
540    fn verify_tls13_signature(
541        &self,
542        _message: &[u8],
543        _cert: &CertificateDer<'_>,
544        _dss: &DigitallySignedStruct,
545    ) -> Result<HandshakeSignatureValid, rustls::Error> {
546        Ok(HandshakeSignatureValid::assertion())
547    }
548
549    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
550        use SignatureScheme::*;
551        vec![
552            RSA_PKCS1_SHA1,
553            ECDSA_SHA1_Legacy,
554            RSA_PKCS1_SHA256,
555            ECDSA_NISTP256_SHA256,
556            RSA_PKCS1_SHA384,
557            ECDSA_NISTP384_SHA384,
558            RSA_PKCS1_SHA512,
559            ECDSA_NISTP521_SHA512,
560            RSA_PSS_SHA256,
561            RSA_PSS_SHA384,
562            RSA_PSS_SHA512,
563            ED25519,
564            ED448,
565        ]
566    }
567}
568
569#[derive(Debug)]
570struct ErrorFilteringVerifier {
571    verifier: Arc<dyn ServerCertVerifier>,
572}
573
574impl ErrorFilteringVerifier {
575    fn new(verifier: Arc<dyn ServerCertVerifier>) -> Self {
576        Self { verifier }
577    }
578
579    fn filter_err<T>(res: Result<T, rustls::Error>) -> Result<T, rustls::Error> {
580        match res {
581            Ok(res) => Ok(res),
582            // On macOS, the system verifier returns `certificate is not
583            // standards compliant: -67901` for self-signed certificates that
584            // have too long of a validity period. It's probably better if we
585            // eventually have the WebPki verifier handle certs as a fallback to
586            // ensure a better error is returned.
587            #[cfg(target_vendor = "apple")]
588            Err(rustls::Error::InvalidCertificate(rustls::CertificateError::Other(e)))
589                if e.to_string().contains("-67901") =>
590            {
591                Err(rustls::Error::InvalidCertificate(
592                    rustls::CertificateError::UnknownIssuer,
593                ))
594            }
595            Err(e) => Err(e),
596        }
597    }
598}
599
600impl ServerCertVerifier for ErrorFilteringVerifier {
601    fn verify_server_cert(
602        &self,
603        end_entity: &CertificateDer<'_>,
604        intermediates: &[CertificateDer<'_>],
605        server_name: &ServerName,
606        ocsp_response: &[u8],
607        now: UnixTime,
608    ) -> Result<ServerCertVerified, rustls::Error> {
609        Self::filter_err(self.verifier.verify_server_cert(
610            end_entity,
611            intermediates,
612            server_name,
613            ocsp_response,
614            now,
615        ))
616    }
617
618    fn verify_tls12_signature(
619        &self,
620        message: &[u8],
621        cert: &CertificateDer<'_>,
622        dss: &DigitallySignedStruct,
623    ) -> Result<HandshakeSignatureValid, rustls::Error> {
624        Self::filter_err(self.verifier.verify_tls12_signature(message, cert, dss))
625    }
626
627    fn verify_tls13_signature(
628        &self,
629        message: &[u8],
630        cert: &CertificateDer<'_>,
631        dss: &DigitallySignedStruct,
632    ) -> Result<HandshakeSignatureValid, rustls::Error> {
633        Self::filter_err(self.verifier.verify_tls13_signature(message, cert, dss))
634    }
635
636    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
637        self.verifier.supported_verify_schemes()
638    }
639}
640
641impl LocalAddress for TlsStream {
642    fn local_address(&self) -> std::io::Result<ResolvedTarget> {
643        self.local_addr().map(|addr| ResolvedTarget::from(addr))
644    }
645}
646
647impl RemoteAddress for TlsStream {
648    fn remote_address(&self) -> std::io::Result<ResolvedTarget> {
649        self.peer_addr().map(|addr| ResolvedTarget::from(addr))
650    }
651}
652
653impl PeerCred for TlsStream {
654    #[cfg(all(unix, feature = "tokio"))]
655    fn peer_cred(&self) -> std::io::Result<tokio::net::unix::UCred> {
656        Err(std::io::Error::new(
657            std::io::ErrorKind::Unsupported,
658            "TCP streams do not support peer credentials",
659        ))
660    }
661}
662
663impl StreamMetadata for TlsStream {
664    fn transport(&self) -> Transport {
665        Transport::Tcp
666    }
667}
668
669impl AsHandle for TlsStream {
670    #[cfg(windows)]
671    fn as_handle(&self) -> std::os::windows::io::BorrowedSocket {
672        std::os::windows::io::AsSocket::as_socket(self.tcp_stream().unwrap())
673    }
674
675    #[cfg(unix)]
676    fn as_fd(&self) -> std::os::fd::BorrowedFd {
677        std::os::fd::AsFd::as_fd(self.tcp_stream().unwrap())
678    }
679}