gel_stream/common/
rustls.rs

1use futures::FutureExt;
2use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
3use rustls::client::WebPkiServerVerifier;
4use rustls::server::{Acceptor, ClientHello, WebPkiClientVerifier};
5use rustls::{
6    ClientConfig, ClientConnection, DigitallySignedStruct, RootCertStore, ServerConfig,
7    SignatureScheme,
8};
9use rustls_pki_types::{
10    CertificateDer, CertificateRevocationListDer, DnsName, ServerName, UnixTime,
11};
12use rustls_platform_verifier::Verifier;
13use rustls_tokio_stream::TlsStream;
14
15use super::tokio_stream::TokioStream;
16use crate::{
17    RewindStream, SslError, SslVersion, Stream, TlsClientCertVerify, TlsDriver, TlsHandshake,
18    TlsServerParameterProvider, TlsServerParameters,
19};
20use crate::{TlsCert, TlsParameters, TlsServerCertVerify};
21use std::borrow::Cow;
22use std::net::{IpAddr, Ipv4Addr};
23use std::sync::Arc;
24
25#[derive(Default)]
26pub struct RustlsDriver;
27
28impl TlsDriver for RustlsDriver {
29    type Stream = TlsStream;
30    type ClientParams = ClientConnection;
31    type ServerParams = Arc<ServerConfig>;
32    const DRIVER_NAME: &'static str = "rustls";
33
34    fn init_client(
35        params: &TlsParameters,
36        name: Option<ServerName>,
37    ) -> Result<Self::ClientParams, SslError> {
38        let _ = ::rustls::crypto::ring::default_provider().install_default();
39
40        let TlsParameters {
41            server_cert_verify,
42            root_cert,
43            cert,
44            key,
45            crl,
46            min_protocol_version: _,
47            max_protocol_version: _,
48            alpn,
49            enable_keylog,
50            sni_override,
51        } = params;
52
53        let verifier = make_verifier(server_cert_verify, root_cert, crl.clone())?;
54
55        let config = ClientConfig::builder()
56            .dangerous()
57            .with_custom_certificate_verifier(verifier);
58
59        // Load client certificate and key if provided
60        let mut config = if let (Some(cert), Some(key)) = (cert, key) {
61            config
62                .with_client_auth_cert(vec![cert.clone()], key.clone_key())
63                .map_err(|_| {
64                    std::io::Error::new(
65                        std::io::ErrorKind::InvalidInput,
66                        "Failed to set client auth cert",
67                    )
68                })?
69        } else {
70            config.with_no_client_auth()
71        };
72
73        // Configure ALPN if provided
74        config.alpn_protocols = alpn.as_vec_vec();
75
76        // Configure keylog if provided
77        if *enable_keylog {
78            config.key_log = Arc::new(rustls::KeyLogFile::new());
79        }
80
81        let name = if let Some(sni_override) = sni_override {
82            ServerName::try_from(sni_override.to_string())?
83        } else if let Some(name) = name {
84            name.to_owned()
85        } else {
86            config.enable_sni = false;
87            ServerName::IpAddress(IpAddr::V4(Ipv4Addr::from_bits(0)).into())
88        };
89
90        Ok(ClientConnection::new(Arc::new(config), name)?)
91    }
92
93    fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError> {
94        let builder = match &params.client_cert_verify {
95            TlsClientCertVerify::Ignore => ServerConfig::builder().with_no_client_auth(),
96            TlsClientCertVerify::Optional(certs) => {
97                let mut roots = RootCertStore::empty();
98                roots.add_parsable_certificates(
99                    certs.iter().map(|c| CertificateDer::from_slice(c.as_ref())),
100                );
101                ServerConfig::builder().with_client_cert_verifier(
102                    WebPkiClientVerifier::builder(roots.into())
103                        .allow_unauthenticated()
104                        .build()?,
105                )
106            }
107            TlsClientCertVerify::Validate(certs) => {
108                let mut roots = RootCertStore::empty();
109                roots.add_parsable_certificates(
110                    certs.iter().map(|c| CertificateDer::from_slice(c.as_ref())),
111                );
112                ServerConfig::builder()
113                    .with_client_cert_verifier(WebPkiClientVerifier::builder(roots.into()).build()?)
114            }
115        };
116
117        let mut config = builder.with_single_cert(
118            vec![params.server_certificate.cert.clone()],
119            params.server_certificate.key.clone_key(),
120        )?;
121
122        config.alpn_protocols = params.alpn.as_vec_vec();
123
124        Ok(Arc::new(config))
125    }
126
127    async fn upgrade_client<S: Stream>(
128        params: Self::ClientParams,
129        stream: S,
130    ) -> Result<(Self::Stream, TlsHandshake), SslError> {
131        // Note that we only support Tokio TcpStream for rustls.
132        let stream = stream
133            .downcast::<TokioStream>()
134            .map_err(|_| crate::SslError::SslUnsupportedByClient)?;
135        let TokioStream::Tcp(stream) = stream else {
136            return Err(crate::SslError::SslUnsupportedByClient);
137        };
138
139        let mut stream = TlsStream::new_client_side(stream, params, None);
140        match stream.handshake().await {
141            Ok(handshake) => {
142                let cert = stream
143                    .connection()
144                    .and_then(|c| c.peer_certificates())
145                    .and_then(|c| c.first().map(|cert| cert.to_owned()));
146                let version = stream.connection().and_then(|c| c.protocol_version());
147                Ok((
148                    stream,
149                    TlsHandshake {
150                        alpn: handshake.alpn.map(|alpn| Cow::Owned(alpn.to_vec())),
151                        sni: handshake.sni.map(|sni| Cow::Owned(sni.to_string())),
152                        cert,
153                        version: match version {
154                            Some(rustls::ProtocolVersion::TLSv1_0) => Some(SslVersion::Tls1),
155                            Some(rustls::ProtocolVersion::TLSv1_1) => Some(SslVersion::Tls1_1),
156                            Some(rustls::ProtocolVersion::TLSv1_2) => Some(SslVersion::Tls1_2),
157                            Some(rustls::ProtocolVersion::TLSv1_3) => Some(SslVersion::Tls1_3),
158                            _ => None,
159                        },
160                    },
161                ))
162            }
163            Err(e) => {
164                let kind = e.kind();
165                if let Some(e2) = e.into_inner() {
166                    match e2.downcast::<::rustls::Error>() {
167                        Ok(e) => Err(crate::SslError::RustlsError(*e)),
168                        Err(e) => Err(std::io::Error::new(kind, e).into()),
169                    }
170                } else {
171                    Err(std::io::Error::from(kind).into())
172                }
173            }
174        }
175    }
176
177    async fn upgrade_server<S: Stream>(
178        params: TlsServerParameterProvider,
179        stream: S,
180    ) -> Result<(Self::Stream, TlsHandshake), SslError> {
181        let stream = stream
182            .downcast::<RewindStream<TokioStream>>()
183            .map_err(|_| crate::SslError::SslUnsupportedByClient)?;
184        let (stream, buffer) = stream.into_inner();
185        let TokioStream::Tcp(stream) = stream else {
186            return Err(crate::SslError::SslUnsupportedByClient);
187        };
188
189        let mut acceptor = Acceptor::default();
190        acceptor.read_tls(&mut buffer.as_slice())?;
191        let server_config_provider = Arc::new(move |client_hello: ClientHello| {
192            let params = params.clone();
193            let server_name = client_hello
194                .server_name()
195                .map(|name| ServerName::DnsName(DnsName::try_from(name.to_string()).unwrap()));
196            async move {
197                let params = params.lookup(server_name);
198                let config = RustlsDriver::init_server(&params)
199                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
200                Ok::<_, std::io::Error>(config)
201            }
202            .boxed()
203        });
204        let mut stream = TlsStream::new_server_side_from_acceptor(
205            acceptor,
206            stream,
207            server_config_provider,
208            None,
209        );
210
211        match stream.handshake().await {
212            Ok(handshake) => {
213                let cert = stream
214                    .connection()
215                    .and_then(|c| c.peer_certificates())
216                    .and_then(|c| c.first().map(|cert| cert.to_owned()));
217                let version = stream.connection().and_then(|c| c.protocol_version());
218                Ok((
219                    stream,
220                    TlsHandshake {
221                        alpn: handshake.alpn.map(|alpn| Cow::Owned(alpn.to_vec())),
222                        sni: handshake.sni.map(|sni| Cow::Owned(sni.to_string())),
223                        cert,
224                        version: match version {
225                            Some(rustls::ProtocolVersion::TLSv1_0) => Some(SslVersion::Tls1),
226                            Some(rustls::ProtocolVersion::TLSv1_1) => Some(SslVersion::Tls1_1),
227                            Some(rustls::ProtocolVersion::TLSv1_2) => Some(SslVersion::Tls1_2),
228                            Some(rustls::ProtocolVersion::TLSv1_3) => Some(SslVersion::Tls1_3),
229                            _ => None,
230                        },
231                    },
232                ))
233            }
234            Err(e) => {
235                let kind = e.kind();
236                if let Some(e2) = e.into_inner() {
237                    match e2.downcast::<::rustls::Error>() {
238                        Ok(e) => Err(crate::SslError::RustlsError(*e)),
239                        Err(e) => Err(std::io::Error::new(kind, e).into()),
240                    }
241                } else {
242                    Err(std::io::Error::from(kind).into())
243                }
244            }
245        }
246    }
247
248    fn unclean_shutdown(this: Self::Stream) -> Result<(), Self::Stream> {
249        // Skip the shutdown logic by tearing this down into its parts.
250        this.try_into_inner().map(drop)
251    }
252}
253
254fn make_roots(
255    root_certs: &[CertificateDer<'static>],
256    webpki: bool,
257) -> Result<RootCertStore, crate::SslError> {
258    let mut roots = RootCertStore::empty();
259    if webpki {
260        let webpki_roots = webpki_roots::TLS_SERVER_ROOTS;
261        roots.extend(webpki_roots.iter().cloned());
262    }
263    let (loaded, ignored) = roots.add_parsable_certificates(root_certs.iter().cloned());
264    if !root_certs.is_empty() && (loaded == 0 || ignored > 0) {
265        return Err(
266            std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid certificate").into(),
267        );
268    }
269    Ok(roots)
270}
271
272fn make_verifier(
273    server_cert_verify: &TlsServerCertVerify,
274    root_cert: &TlsCert,
275    crls: Vec<CertificateRevocationListDer<'static>>,
276) -> Result<Arc<dyn ServerCertVerifier>, crate::SslError> {
277    if *server_cert_verify == TlsServerCertVerify::Insecure {
278        return Ok(Arc::new(NullVerifier));
279    }
280
281    if matches!(
282        root_cert,
283        TlsCert::Webpki | TlsCert::WebpkiPlus(_) | TlsCert::Custom(_)
284    ) {
285        let roots = match root_cert {
286            TlsCert::Webpki => make_roots(&[], true),
287            TlsCert::Custom(roots) => make_roots(roots, false),
288            TlsCert::WebpkiPlus(roots) => make_roots(roots, true),
289            _ => unreachable!(),
290        }?;
291
292        let verifier = WebPkiServerVerifier::builder(Arc::new(roots))
293            .with_crls(crls)
294            .build()?;
295        if *server_cert_verify == TlsServerCertVerify::IgnoreHostname {
296            return Ok(Arc::new(IgnoreHostnameVerifier::new(verifier)));
297        }
298        return Ok(verifier);
299    }
300
301    // We need to work around macOS returning `certificate is not standards compliant: -67901`
302    // when using the system verifier.
303    let verifier: Arc<dyn ServerCertVerifier> = if let TlsCert::SystemPlus(roots) = root_cert {
304        let roots = make_roots(roots, false)?;
305        let v1 = WebPkiServerVerifier::builder(Arc::new(roots))
306            .with_crls(crls)
307            .build()?;
308        let v2 = Arc::new(Verifier::new());
309        Arc::new(ChainingVerifier::new(v1, v2))
310    } else {
311        Arc::new(ErrorFilteringVerifier::new(Arc::new(Verifier::new())))
312    };
313
314    let verifier: Arc<dyn ServerCertVerifier> =
315        if *server_cert_verify == TlsServerCertVerify::IgnoreHostname {
316            Arc::new(IgnoreHostnameVerifier::new(verifier))
317        } else {
318            verifier
319        };
320
321    Ok(verifier)
322}
323
324#[derive(Debug)]
325struct IgnoreHostnameVerifier {
326    verifier: Arc<dyn ServerCertVerifier>,
327}
328
329impl IgnoreHostnameVerifier {
330    fn new(verifier: Arc<dyn ServerCertVerifier>) -> Self {
331        Self { verifier }
332    }
333}
334
335impl ServerCertVerifier for IgnoreHostnameVerifier {
336    fn verify_server_cert(
337        &self,
338        end_entity: &CertificateDer<'_>,
339        intermediates: &[CertificateDer<'_>],
340        server_name: &ServerName,
341        ocsp_response: &[u8],
342        now: UnixTime,
343    ) -> Result<ServerCertVerified, rustls::Error> {
344        match self.verifier.verify_server_cert(
345            end_entity,
346            intermediates,
347            server_name,
348            ocsp_response,
349            now,
350        ) {
351            Ok(res) => Ok(res),
352            // This works because the name check is the last step in the verify process
353            Err(rustls::Error::InvalidCertificate(
354                rustls::CertificateError::NotValidForName
355                | rustls::CertificateError::NotValidForNameContext { .. },
356            )) => Ok(ServerCertVerified::assertion()),
357            Err(e) => Err(e),
358        }
359    }
360
361    fn verify_tls12_signature(
362        &self,
363        message: &[u8],
364        cert: &CertificateDer<'_>,
365        dss: &DigitallySignedStruct,
366    ) -> Result<HandshakeSignatureValid, rustls::Error> {
367        self.verifier.verify_tls12_signature(message, cert, dss)
368    }
369
370    fn verify_tls13_signature(
371        &self,
372        message: &[u8],
373        cert: &CertificateDer<'_>,
374        dss: &DigitallySignedStruct,
375    ) -> Result<HandshakeSignatureValid, rustls::Error> {
376        self.verifier.verify_tls13_signature(message, cert, dss)
377    }
378
379    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
380        self.verifier.supported_verify_schemes()
381    }
382}
383
384#[derive(Debug)]
385struct ChainingVerifier {
386    verifier1: Arc<dyn ServerCertVerifier>,
387    verifier2: Arc<dyn ServerCertVerifier>,
388}
389
390impl ChainingVerifier {
391    fn new(verifier1: Arc<dyn ServerCertVerifier>, verifier2: Arc<dyn ServerCertVerifier>) -> Self {
392        Self {
393            verifier1,
394            verifier2,
395        }
396    }
397}
398
399impl ServerCertVerifier for ChainingVerifier {
400    fn verify_server_cert(
401        &self,
402        end_entity: &CertificateDer<'_>,
403        intermediates: &[CertificateDer<'_>],
404        server_name: &ServerName,
405        ocsp_response: &[u8],
406        now: UnixTime,
407    ) -> Result<ServerCertVerified, rustls::Error> {
408        let res = self.verifier1.verify_server_cert(
409            end_entity,
410            intermediates,
411            server_name,
412            ocsp_response,
413            now,
414        );
415        if let Ok(res) = res {
416            return Ok(res);
417        }
418
419        let res2 = self.verifier2.verify_server_cert(
420            end_entity,
421            intermediates,
422            server_name,
423            ocsp_response,
424            now,
425        );
426        if let Ok(res) = res2 {
427            return Ok(res);
428        }
429
430        res
431    }
432
433    fn verify_tls12_signature(
434        &self,
435        message: &[u8],
436        cert: &CertificateDer<'_>,
437        dss: &DigitallySignedStruct,
438    ) -> Result<HandshakeSignatureValid, rustls::Error> {
439        let res = self.verifier1.verify_tls12_signature(message, cert, dss);
440        if let Ok(res) = res {
441            return Ok(res);
442        }
443
444        let res2 = self.verifier2.verify_tls12_signature(message, cert, dss);
445        if let Ok(res) = res2 {
446            return Ok(res);
447        }
448
449        res
450    }
451
452    fn verify_tls13_signature(
453        &self,
454        message: &[u8],
455        cert: &CertificateDer<'_>,
456        dss: &DigitallySignedStruct,
457    ) -> Result<HandshakeSignatureValid, rustls::Error> {
458        let res = self.verifier1.verify_tls13_signature(message, cert, dss);
459        if let Ok(res) = res {
460            return Ok(res);
461        }
462
463        let res2 = self.verifier2.verify_tls13_signature(message, cert, dss);
464        if let Ok(res) = res2 {
465            return Ok(res);
466        }
467
468        res
469    }
470
471    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
472        self.verifier1.supported_verify_schemes()
473    }
474}
475
476#[derive(Debug)]
477struct NullVerifier;
478
479impl ServerCertVerifier for NullVerifier {
480    fn verify_server_cert(
481        &self,
482        _end_entity: &CertificateDer<'_>,
483        _intermediates: &[CertificateDer<'_>],
484        _server_name: &ServerName,
485        _ocsp_response: &[u8],
486        _now: UnixTime,
487    ) -> Result<ServerCertVerified, rustls::Error> {
488        Ok(ServerCertVerified::assertion())
489    }
490
491    fn verify_tls12_signature(
492        &self,
493        _message: &[u8],
494        _cert: &CertificateDer<'_>,
495        _dss: &DigitallySignedStruct,
496    ) -> Result<HandshakeSignatureValid, rustls::Error> {
497        Ok(HandshakeSignatureValid::assertion())
498    }
499
500    fn verify_tls13_signature(
501        &self,
502        _message: &[u8],
503        _cert: &CertificateDer<'_>,
504        _dss: &DigitallySignedStruct,
505    ) -> Result<HandshakeSignatureValid, rustls::Error> {
506        Ok(HandshakeSignatureValid::assertion())
507    }
508
509    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
510        use SignatureScheme::*;
511        vec![
512            RSA_PKCS1_SHA1,
513            ECDSA_SHA1_Legacy,
514            RSA_PKCS1_SHA256,
515            ECDSA_NISTP256_SHA256,
516            RSA_PKCS1_SHA384,
517            ECDSA_NISTP384_SHA384,
518            RSA_PKCS1_SHA512,
519            ECDSA_NISTP521_SHA512,
520            RSA_PSS_SHA256,
521            RSA_PSS_SHA384,
522            RSA_PSS_SHA512,
523            ED25519,
524            ED448,
525        ]
526    }
527}
528
529#[derive(Debug)]
530struct ErrorFilteringVerifier {
531    verifier: Arc<dyn ServerCertVerifier>,
532}
533
534impl ErrorFilteringVerifier {
535    fn new(verifier: Arc<dyn ServerCertVerifier>) -> Self {
536        Self { verifier }
537    }
538
539    fn filter_err<T>(res: Result<T, rustls::Error>) -> Result<T, rustls::Error> {
540        match res {
541            Ok(res) => Ok(res),
542            // On macOS, the system verifier returns `certificate is not
543            // standards compliant: -67901` for self-signed certificates that
544            // have too long of a validity period. It's probably better if we
545            // eventually have the WebPki verifier handle certs as a fallback to
546            // ensure a better error is returned.
547            #[cfg(target_vendor = "apple")]
548            Err(rustls::Error::InvalidCertificate(rustls::CertificateError::Other(e)))
549                if e.to_string().contains("-67901") =>
550            {
551                Err(rustls::Error::InvalidCertificate(
552                    rustls::CertificateError::UnknownIssuer,
553                ))
554            }
555            Err(e) => Err(e),
556        }
557    }
558}
559
560impl ServerCertVerifier for ErrorFilteringVerifier {
561    fn verify_server_cert(
562        &self,
563        end_entity: &CertificateDer<'_>,
564        intermediates: &[CertificateDer<'_>],
565        server_name: &ServerName,
566        ocsp_response: &[u8],
567        now: UnixTime,
568    ) -> Result<ServerCertVerified, rustls::Error> {
569        Self::filter_err(self.verifier.verify_server_cert(
570            end_entity,
571            intermediates,
572            server_name,
573            ocsp_response,
574            now,
575        ))
576    }
577
578    fn verify_tls12_signature(
579        &self,
580        message: &[u8],
581        cert: &CertificateDer<'_>,
582        dss: &DigitallySignedStruct,
583    ) -> Result<HandshakeSignatureValid, rustls::Error> {
584        Self::filter_err(self.verifier.verify_tls12_signature(message, cert, dss))
585    }
586
587    fn verify_tls13_signature(
588        &self,
589        message: &[u8],
590        cert: &CertificateDer<'_>,
591        dss: &DigitallySignedStruct,
592    ) -> Result<HandshakeSignatureValid, rustls::Error> {
593        Self::filter_err(self.verifier.verify_tls13_signature(message, cert, dss))
594    }
595
596    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
597        self.verifier.supported_verify_schemes()
598    }
599}