gel_stream/common/
openssl.rs

1use openssl::{
2    ssl::{
3        AlpnError, ClientHelloResponse, NameType, SniError, Ssl, SslAcceptor, SslContextBuilder,
4        SslMethod, SslOptions, SslRef, SslVerifyMode,
5    },
6    x509::{verify::X509VerifyFlags, X509VerifyResult},
7};
8use rustls_pki_types::{CertificateDer, ServerName};
9use std::{
10    borrow::Cow,
11    io::IoSlice,
12    pin::Pin,
13    sync::{Arc, Mutex, MutexGuard, OnceLock},
14    task::{ready, Poll},
15};
16use tokio::{
17    io::{AsyncRead, AsyncWrite, ReadBuf},
18    net::TcpStream,
19};
20
21use crate::{
22    RewindStream, SslError, SslVersion, Stream, TlsCert, TlsClientCertVerify, TlsDriver,
23    TlsHandshake, TlsParameters, TlsServerCertVerify, TlsServerParameterProvider,
24    TlsServerParameters,
25};
26
27use super::tokio_stream::TokioStream;
28
29#[derive(Debug, Clone, Default)]
30struct HandshakeData {
31    server_alpn: Option<Vec<u8>>,
32    handshake: TlsHandshake,
33}
34
35impl HandshakeData {
36    fn from_ssl(ssl: &SslRef) -> Option<MutexGuard<Self>> {
37        let mutex = ssl.ex_data(get_ssl_ex_data_index())?;
38        mutex.lock().ok()
39    }
40}
41
42static SSL_EX_DATA_INDEX: OnceLock<openssl::ex_data::Index<Ssl, Arc<Mutex<HandshakeData>>>> =
43    OnceLock::new();
44
45fn get_ssl_ex_data_index() -> openssl::ex_data::Index<Ssl, Arc<Mutex<HandshakeData>>> {
46    *SSL_EX_DATA_INDEX
47        .get_or_init(|| Ssl::new_ex_index().expect("Failed to create SSL ex_data index"))
48}
49
/// [`TlsDriver`] implementation backed by OpenSSL (through `tokio-openssl`).
#[derive(Default)]
pub struct OpensslDriver;
53
54pub struct TlsStream(tokio_openssl::SslStream<TcpStream>);
55
56impl AsyncRead for TlsStream {
57    #[inline(always)]
58    fn poll_read(
59        mut self: Pin<&mut Self>,
60        cx: &mut std::task::Context<'_>,
61        buf: &mut ReadBuf<'_>,
62    ) -> std::task::Poll<std::io::Result<()>> {
63        Pin::new(&mut self.0).poll_read(cx, buf)
64    }
65}
66
67impl AsyncWrite for TlsStream {
68    #[inline(always)]
69    fn poll_write(
70        mut self: Pin<&mut Self>,
71        cx: &mut std::task::Context<'_>,
72        buf: &[u8],
73    ) -> std::task::Poll<std::io::Result<usize>> {
74        Pin::new(&mut self.0).poll_write(cx, buf)
75    }
76
77    #[inline(always)]
78    fn poll_write_vectored(
79        mut self: Pin<&mut Self>,
80        cx: &mut std::task::Context<'_>,
81        bufs: &[IoSlice<'_>],
82    ) -> std::task::Poll<std::io::Result<usize>> {
83        Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
84    }
85
86    #[inline(always)]
87    fn is_write_vectored(&self) -> bool {
88        self.0.is_write_vectored()
89    }
90
91    #[inline(always)]
92    fn poll_flush(
93        mut self: Pin<&mut Self>,
94        cx: &mut std::task::Context<'_>,
95    ) -> std::task::Poll<std::io::Result<()>> {
96        Pin::new(&mut self.0).poll_flush(cx)
97    }
98
99    #[inline(always)]
100    fn poll_shutdown(
101        mut self: Pin<&mut Self>,
102        cx: &mut std::task::Context<'_>,
103    ) -> std::task::Poll<std::io::Result<()>> {
104        let res = ready!(Pin::new(&mut self.0).poll_shutdown(cx));
105        if let Err(e) = &res {
106            // Swallow NotConnected errors here
107            if e.kind() == std::io::ErrorKind::NotConnected {
108                return Poll::Ready(Ok(()));
109            }
110
111            // Treat OpenSSL syscall errors during shutdown as graceful
112            if let Some(ssl_err) = e
113                .get_ref()
114                .and_then(|e| e.downcast_ref::<openssl::ssl::Error>())
115            {
116                if ssl_err.code() == openssl::ssl::ErrorCode::SYSCALL {
117                    return Poll::Ready(Ok(()));
118                }
119            }
120        }
121        Poll::Ready(res)
122    }
123}
124
125/// Cache for the WebPKI roots
126static WEBPKI_ROOTS: OnceLock<Vec<openssl::x509::X509>> = OnceLock::new();
127
128impl TlsDriver for OpensslDriver {
129    type Stream = TlsStream;
130    type ClientParams = openssl::ssl::Ssl;
131    type ServerParams = openssl::ssl::SslContext;
132    const DRIVER_NAME: &'static str = "openssl";
133
134    fn init_client(
135        params: &TlsParameters,
136        name: Option<ServerName>,
137    ) -> Result<Self::ClientParams, SslError> {
138        let TlsParameters {
139            server_cert_verify,
140            root_cert,
141            cert,
142            key,
143            crl,
144            min_protocol_version,
145            max_protocol_version,
146            alpn,
147            sni_override,
148            enable_keylog,
149        } = params;
150
151        // let mut ssl = SslConnector::builder(SslMethod::tls_client())?;
152        let mut ssl = SslContextBuilder::new(SslMethod::tls_client())?;
153
154        // Clear SSL_OP_IGNORE_UNEXPECTED_EOF
155        ssl.clear_options(SslOptions::from_bits_retain(1 << 7));
156
157        // Load additional root certs
158        match root_cert {
159            TlsCert::Custom(root) | TlsCert::SystemPlus(root) | TlsCert::WebpkiPlus(root) => {
160                for root in root {
161                    let root = openssl::x509::X509::from_der(root.as_ref())?;
162                    ssl.cert_store_mut().add_cert(root)?;
163                }
164            }
165            _ => {}
166        }
167
168        match root_cert {
169            TlsCert::Webpki | TlsCert::WebpkiPlus(_) => {
170                let webpki_roots = WEBPKI_ROOTS.get_or_init(|| {
171                    let webpki_roots = webpki_root_certs::TLS_SERVER_ROOT_CERTS;
172                    let mut roots = Vec::new();
173                    for root in webpki_roots {
174                        // Don't expect the roots to fail to load
175                        if let Ok(root) = openssl::x509::X509::from_der(root.as_ref()) {
176                            roots.push(root);
177                        }
178                    }
179                    roots
180                });
181                for root in webpki_roots {
182                    ssl.cert_store_mut().add_cert(root.clone())?;
183                }
184            }
185            _ => {}
186        }
187
188        // Load CA certificates from system for System/SystemPlus
189        if matches!(root_cert, TlsCert::SystemPlus(_) | TlsCert::System) {
190            // DANGER! Don't use the environment variable setter functions!
191            let probe = openssl_probe::probe();
192            ssl.load_verify_locations(probe.cert_file.as_deref(), probe.cert_dir.as_deref())?;
193        }
194
195        // Configure hostname verification
196        match server_cert_verify {
197            TlsServerCertVerify::Insecure => {
198                ssl.set_verify(SslVerifyMode::NONE);
199            }
200            TlsServerCertVerify::IgnoreHostname => {
201                ssl.set_verify(SslVerifyMode::PEER);
202            }
203            TlsServerCertVerify::VerifyFull => {
204                ssl.set_verify(SslVerifyMode::PEER);
205                if let Some(hostname) = sni_override {
206                    ssl.verify_param_mut().set_host(hostname)?;
207                } else if let Some(ServerName::DnsName(hostname)) = &name {
208                    ssl.verify_param_mut().set_host(hostname.as_ref())?;
209                } else if let Some(ServerName::IpAddress(ip)) = &name {
210                    ssl.verify_param_mut().set_ip((*ip).into())?;
211                }
212            }
213        }
214
215        // Load CRL
216        if !crl.is_empty() {
217            // The openssl crate doesn't yet have add_crl, so we need to use the raw FFI
218            use foreign_types::ForeignTypeRef;
219            let ptr = ssl.cert_store_mut().as_ptr();
220
221            extern "C" {
222                pub fn X509_STORE_add_crl(
223                    store: *mut openssl_sys::X509_STORE,
224                    x: *mut openssl_sys::X509_CRL,
225                ) -> openssl_sys::c_int;
226            }
227
228            for crl in crl {
229                let crl = openssl::x509::X509Crl::from_der(crl.as_ref())?;
230                let crl_ptr = crl.as_ptr();
231                let res = unsafe { X509_STORE_add_crl(ptr, crl_ptr) };
232                if res != 1 {
233                    return Err(std::io::Error::new(
234                        std::io::ErrorKind::Other,
235                        "Failed to add CRL to store",
236                    )
237                    .into());
238                }
239            }
240
241            ssl.verify_param_mut()
242                .set_flags(X509VerifyFlags::CRL_CHECK | X509VerifyFlags::CRL_CHECK_ALL)?;
243            ssl.cert_store_mut()
244                .set_flags(X509VerifyFlags::CRL_CHECK | X509VerifyFlags::CRL_CHECK_ALL)?;
245        }
246
247        // Load certificate chain and private key
248        if let (Some(cert), Some(key)) = (cert.as_ref(), key.as_ref()) {
249            let builder = openssl::x509::X509::from_der(cert.as_ref())?;
250            ssl.set_certificate(&builder)?;
251            let builder = openssl::pkey::PKey::private_key_from_der(key.secret_der())?;
252            ssl.set_private_key(&builder)?;
253        }
254
255        ssl.set_min_proto_version(min_protocol_version.map(|s| s.into()))?;
256        ssl.set_max_proto_version(max_protocol_version.map(|s| s.into()))?;
257
258        // Configure key log filename
259        if *enable_keylog {
260            if let Ok(path) = std::env::var("SSLKEYLOGFILE") {
261                ssl.set_keylog_callback(move |_ssl, msg| {
262                    let Ok(mut file) = std::fs::OpenOptions::new().append(true).open(&path) else {
263                        return;
264                    };
265                    let _ = std::io::Write::write_all(&mut file, msg.as_bytes());
266                });
267            }
268        }
269
270        let mut ssl = openssl::ssl::Ssl::new(&ssl.build())?;
271        ssl.set_connect_state();
272
273        // Set hostname if it's not an IP address
274        if let Some(hostname) = sni_override {
275            ssl.set_hostname(hostname)?;
276        } else if let Some(ServerName::DnsName(hostname)) = &name {
277            ssl.set_hostname(hostname.as_ref())?;
278        }
279
280        if !alpn.is_empty() {
281            ssl.set_alpn_protos(&alpn.as_bytes())?;
282        }
283
284        Ok(ssl)
285    }
286
287    fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError> {
288        let TlsServerParameters {
289            client_cert_verify,
290            min_protocol_version,
291            max_protocol_version,
292            server_certificate,
293            // Handled elsewhere
294            alpn: _alpn,
295        } = params;
296
297        let mut ssl = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls_server())?;
298        let cert = openssl::x509::X509::from_der(server_certificate.cert.as_ref())?;
299        let key = openssl::pkey::PKey::private_key_from_der(server_certificate.key.secret_der())?;
300        ssl.set_certificate(&cert)?;
301        ssl.set_private_key(&key)?;
302        ssl.set_min_proto_version(min_protocol_version.map(|s| s.into()))?;
303        ssl.set_max_proto_version(max_protocol_version.map(|s| s.into()))?;
304        match client_cert_verify {
305            TlsClientCertVerify::Ignore => ssl.set_verify(SslVerifyMode::NONE),
306            TlsClientCertVerify::Optional(root) => {
307                ssl.set_verify(SslVerifyMode::PEER);
308                for root in root {
309                    let root = openssl::x509::X509::from_der(root.as_ref())?;
310                    ssl.cert_store_mut().add_cert(root)?;
311                }
312            }
313            TlsClientCertVerify::Validate(root) => {
314                ssl.set_verify(SslVerifyMode::PEER | SslVerifyMode::FAIL_IF_NO_PEER_CERT);
315                for root in root {
316                    let root = openssl::x509::X509::from_der(root.as_ref())?;
317                    ssl.cert_store_mut().add_cert(root)?;
318                }
319            }
320        }
321        create_alpn_callback(&mut ssl);
322
323        Ok(ssl.build().into_context())
324    }
325
326    async fn upgrade_client<S: Stream>(
327        params: Self::ClientParams,
328        stream: S,
329    ) -> Result<(Self::Stream, TlsHandshake), SslError> {
330        let stream = stream
331            .downcast::<TokioStream>()
332            .map_err(|_| crate::SslError::SslUnsupportedByClient)?;
333        let TokioStream::Tcp(stream) = stream else {
334            return Err(crate::SslError::SslUnsupportedByClient);
335        };
336
337        let mut stream = tokio_openssl::SslStream::new(params, stream)?;
338        let res = Pin::new(&mut stream).do_handshake().await;
339        if res.is_err() && stream.ssl().verify_result() != X509VerifyResult::OK {
340            return Err(SslError::OpenSslErrorVerify(stream.ssl().verify_result()));
341        }
342
343        let alpn = stream
344            .ssl()
345            .selected_alpn_protocol()
346            .map(|p| Cow::Owned(p.to_vec()));
347
348        res.map_err(SslError::OpenSslError)?;
349        let cert = stream
350            .ssl()
351            .peer_certificate()
352            .map(|cert| cert.to_der())
353            .transpose()?;
354        let cert = cert.map(CertificateDer::from);
355        let version = match stream.ssl().version2() {
356            Some(openssl::ssl::SslVersion::TLS1) => Some(SslVersion::Tls1),
357            Some(openssl::ssl::SslVersion::TLS1_1) => Some(SslVersion::Tls1_1),
358            Some(openssl::ssl::SslVersion::TLS1_2) => Some(SslVersion::Tls1_2),
359            Some(openssl::ssl::SslVersion::TLS1_3) => Some(SslVersion::Tls1_3),
360            _ => None,
361        };
362        Ok((
363            TlsStream(stream),
364            TlsHandshake {
365                alpn,
366                sni: None,
367                cert,
368                version,
369            },
370        ))
371    }
372
373    async fn upgrade_server<S: Stream>(
374        params: TlsServerParameterProvider,
375        stream: S,
376    ) -> Result<(Self::Stream, TlsHandshake), SslError> {
377        let stream = stream
378            .downcast::<RewindStream<TokioStream>>()
379            .map_err(|_| crate::SslError::SslUnsupportedByClient)?;
380        let (stream, buffer) = stream.into_inner();
381        if !buffer.is_empty() {
382            // TODO: We should also be able to support rewinding
383            return Err(crate::SslError::SslUnsupportedByClient);
384        }
385        let TokioStream::Tcp(stream) = stream else {
386            return Err(crate::SslError::SslUnsupportedByClient);
387        };
388
389        let handshake = Arc::new(Mutex::new(HandshakeData::default()));
390
391        let mut ssl = SslContextBuilder::new(SslMethod::tls_server())?;
392        create_alpn_callback(&mut ssl);
393        create_sni_callback(&mut ssl, params);
394        ssl.set_client_hello_callback(move |ssl_ref, _alert| {
395            // TODO: We need to check the clienthello for the SNI and determine
396            // if we should verify the certificate or not. For now, just always
397            // request a certificate. Note that if we return RETRY, we'll have
398            // another chance to respond later (ie: when we implement async lookup
399            // for TLS parameters).
400            ssl_ref.set_verify(SslVerifyMode::PEER);
401            Ok(ClientHelloResponse::SUCCESS)
402        });
403
404        let mut ssl = Ssl::new(&ssl.build())?;
405        ssl.set_accept_state();
406        ssl.set_ex_data(get_ssl_ex_data_index(), handshake.clone());
407
408        let mut stream = tokio_openssl::SslStream::new(ssl, stream)?;
409        let res = Pin::new(&mut stream).do_handshake().await;
410        res.map_err(SslError::OpenSslError)?;
411
412        let mut handshake = std::mem::take(&mut handshake.lock().unwrap().handshake);
413        let cert = stream
414            .ssl()
415            .peer_certificate()
416            .and_then(|c| c.to_der().ok());
417        if let Some(cert) = cert {
418            handshake.cert = Some(CertificateDer::from(cert));
419        }
420        let version = match stream.ssl().version2() {
421            Some(openssl::ssl::SslVersion::TLS1) => Some(SslVersion::Tls1),
422            Some(openssl::ssl::SslVersion::TLS1_1) => Some(SslVersion::Tls1_1),
423            Some(openssl::ssl::SslVersion::TLS1_2) => Some(SslVersion::Tls1_2),
424            Some(openssl::ssl::SslVersion::TLS1_3) => Some(SslVersion::Tls1_3),
425            _ => None,
426        };
427        handshake.version = version;
428        Ok((TlsStream(stream), handshake))
429    }
430
431    fn unclean_shutdown(_this: Self::Stream) -> Result<(), Self::Stream> {
432        // Do nothing
433        Ok(())
434    }
435}
436
/// Picks the first protocol in the server's preference list that the client also
/// offers. Both lists use the ALPN wire format: each entry is a length byte
/// followed by that many bytes.
///
/// Returns the matching entry as a slice of `client`. Returns `None` if nothing
/// matches, or as soon as a truncated entry is encountered while scanning.
fn ssl_select_next_proto<'b>(server: &[u8], client: &'b [u8]) -> Option<&'b [u8]> {
    /// Splits the leading length-prefixed entry off `list`, returning
    /// `(entry, remainder)`, or `None` if the entry runs past the end of `list`.
    fn split_entry(list: &[u8]) -> Option<(&[u8], &[u8])> {
        let (&len, rest) = list.split_first()?;
        let len = usize::from(len);
        (rest.len() >= len).then(|| rest.split_at(len))
    }

    let mut server_rest = server;
    while !server_rest.is_empty() {
        let (wanted, next_server) = split_entry(server_rest)?;
        let mut client_rest = client;
        while !client_rest.is_empty() {
            let (offered, next_client) = split_entry(client_rest)?;
            if offered == wanted {
                return Some(offered);
            }
            client_rest = next_client;
        }
        server_rest = next_server;
    }
    None
}
455
456/// Create an ALPN callback for the [`SslContextBuilder`].
457fn create_alpn_callback(ssl: &mut SslContextBuilder) {
458    ssl.set_alpn_select_callback(|ssl_ref, alpn| {
459        let Some(mut handshake) = HandshakeData::from_ssl(ssl_ref) else {
460            return Err(AlpnError::ALERT_FATAL);
461        };
462
463        if let Some(server) = handshake.server_alpn.take() {
464            eprintln!("server: {:?} alpn: {:?}", server, alpn);
465            let Some(selected) = ssl_select_next_proto(&server, alpn) else {
466                return Err(AlpnError::NOACK);
467            };
468            handshake.handshake.alpn = Some(Cow::Owned(selected.to_vec()));
469
470            Ok(selected)
471        } else {
472            Err(AlpnError::NOACK)
473        }
474    })
475}
476
477/// Create an SNI callback for the [`SslContextBuilder`].
478fn create_sni_callback(ssl: &mut SslContextBuilder, params: TlsServerParameterProvider) {
479    ssl.set_servername_callback(move |ssl_ref, _alert| {
480        let Some(mut handshake) = HandshakeData::from_ssl(ssl_ref) else {
481            return Ok(());
482        };
483
484        if let Some(servername) = ssl_ref.servername_raw(NameType::HOST_NAME) {
485            handshake.handshake.sni =
486                Some(Cow::Owned(String::from_utf8_lossy(servername).to_string()));
487        }
488
489        let params = params.lookup(None);
490        if !params.alpn.is_empty() {
491            handshake.server_alpn = Some(params.alpn.as_bytes().to_vec());
492        }
493        drop(handshake);
494
495        let Ok(ssl) = OpensslDriver::init_server(&params) else {
496            return Err(SniError::ALERT_FATAL);
497        };
498        let Ok(_) = ssl_ref.set_ssl_context(&ssl) else {
499            return Err(SniError::ALERT_FATAL);
500        };
501        Ok(())
502    });
503}
504
505impl From<SslVersion> for openssl::ssl::SslVersion {
506    fn from(val: SslVersion) -> Self {
507        match val {
508            SslVersion::Tls1 => openssl::ssl::SslVersion::TLS1,
509            SslVersion::Tls1_1 => openssl::ssl::SslVersion::TLS1_1,
510            SslVersion::Tls1_2 => openssl::ssl::SslVersion::TLS1_2,
511            SslVersion::Tls1_3 => openssl::ssl::SslVersion::TLS1_3,
512        }
513    }
514}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the ALPN selector on `server`/`client` and asserts the chosen protocol.
    fn check(server: &[u8], client: &[u8], expected: Option<&[u8]>) {
        assert_eq!(ssl_select_next_proto(server, client), expected);
    }

    #[test]
    fn test_ssl_select_next_proto() {
        check(b"\x02h2\x08http/1.1", b"\x08http/1.1", Some(b"http/1.1".as_slice()));
    }

    #[test]
    fn test_ssl_select_next_proto_empty() {
        check(b"", b"", None);
    }

    #[test]
    fn test_ssl_select_next_proto_invalid_length() {
        // Claims 8 bytes but only has 2
        check(b"\x08h2", b"\x08http/1.1", None);
    }

    #[test]
    fn test_ssl_select_next_proto_zero_length() {
        // Zero length but has data
        check(b"\x00h2", b"\x08http/1.1", None);
    }

    #[test]
    fn test_ssl_select_next_proto_truncated() {
        // Second protocol truncated
        check(b"\x02h2\x08http/1", b"\x08http/1.1", None);
    }

    #[test]
    fn test_ssl_select_next_proto_overflow() {
        // Length that would overflow buffer
        check(b"\xFFh2", b"\x08http/1.1", None);
    }

    #[test]
    fn test_ssl_select_next_proto_no_match() {
        check(b"\x02h2", b"\x08http/1.1", None);
    }

    #[test]
    fn test_ssl_select_next_proto_multiple_server() {
        check(
            b"\x02h2\x06spdy/2\x08http/1.1",
            b"\x08http/1.1",
            Some(b"http/1.1".as_slice()),
        );
    }

    #[test]
    fn test_ssl_select_next_proto_multiple_client() {
        check(
            b"\x08http/1.1",
            b"\x02h2\x06spdy/2\x08http/1.1",
            Some(b"http/1.1".as_slice()),
        );
    }

    #[test]
    fn test_ssl_select_next_proto_first_match() {
        // Server preference wins over client order.
        check(
            b"\x02h2\x06spdy/2\x08http/1.1",
            b"\x06spdy/2\x02h2\x08http/1.1",
            Some(b"h2".as_slice()),
        );
    }

    #[test]
    fn test_ssl_select_next_proto_first_match_2() {
        check(
            b"\x06spdy/2\x02h2\x08http/1.1",
            b"\x02h2\x06spdy/2\x08http/1.1",
            Some(b"spdy/2".as_slice()),
        );
    }
}