1use openssl::{
2 ssl::{
3 AlpnError, ClientHelloResponse, NameType, SniError, Ssl, SslAcceptor, SslContextBuilder,
4 SslMethod, SslOptions, SslRef, SslVerifyMode,
5 },
6 x509::{verify::X509VerifyFlags, X509VerifyResult},
7};
8use rustls_pki_types::{CertificateDer, DnsName, ServerName};
9use std::{
10 borrow::Cow,
11 pin::Pin,
12 sync::{Arc, Mutex, MutexGuard, OnceLock},
13 task::{ready, Poll},
14};
15
16use crate::{
17 AsHandle, LocalAddress, PeekableStream, PeerCred, RemoteAddress, ResolvedTarget, SslError,
18 SslVersion, Stream, StreamMetadata, TlsCert, TlsClientCertVerify, TlsDriver, TlsHandshake,
19 TlsParameters, TlsServerCertVerify, TlsServerParameterProvider, TlsServerParameters, Transport,
20};
21
22use super::tokio_stream::TokioStream;
23
/// Per-connection state shared between `upgrade_server` and the OpenSSL SNI/ALPN callbacks.
///
/// An `Arc<Mutex<HandshakeData>>` is attached to the server-side `Ssl` through ex_data
/// (see `get_ssl_ex_data_index`), so the callbacks can reach it via `HandshakeData::from_ssl`.
#[derive(Debug, Clone)]
struct HandshakeData {
    /// Server-side ALPN protocol list in wire format (length-prefixed entries), stored by the
    /// SNI callback and consumed (`take`n) by the ALPN callback.
    server_alpn: Option<Vec<u8>>,
    /// Handshake details recorded by the callbacks (SNI host name, selected ALPN protocol).
    handshake: TlsHandshake,
    /// Raw pointer to the transport being handshaken; the SNI callback dereferences it to pass
    /// the stream to the server parameter provider. It must stay valid for as long as OpenSSL
    /// can invoke that callback.
    stream: *const Box<dyn Stream + Send>,
}

// SAFETY: the raw `stream` pointer is the only field that keeps this type from being `Send`.
// Its pointee is a `Box<dyn Stream + Send>`, itself `Send`. NOTE(review): soundness relies on
// the pointer being dereferenced only while that transport is alive and not otherwise being
// accessed concurrently — confirm against `upgrade_server` and the SNI callback.
unsafe impl Send for HandshakeData {}
32
33impl HandshakeData {
34 fn from_ssl(ssl: &SslRef) -> Option<MutexGuard<Self>> {
35 let mutex = ssl.ex_data(get_ssl_ex_data_index())?;
36 mutex.lock().ok()
37 }
38}
39
40static SSL_EX_DATA_INDEX: OnceLock<openssl::ex_data::Index<Ssl, Arc<Mutex<HandshakeData>>>> =
41 OnceLock::new();
42
43fn get_ssl_ex_data_index() -> openssl::ex_data::Index<Ssl, Arc<Mutex<HandshakeData>>> {
44 *SSL_EX_DATA_INDEX
45 .get_or_init(|| Ssl::new_ex_index().expect("Failed to create SSL ex_data index"))
46}
47
/// TLS driver implementing [`TlsDriver`] on top of OpenSSL and `tokio-openssl`.
#[derive(Default)]
pub struct OpensslDriver;
51
/// TLS stream produced by [`OpensslDriver`]: a `tokio_openssl::SslStream` wrapped around the
/// boxed underlying transport.
///
/// `AsyncRead`/`AsyncWrite` are generated by `derive_io` from the field attributes below;
/// shutdown is routed through the custom [`poll_shutdown`] hook instead of the inner stream's.
#[derive(derive_io::AsyncRead, derive_io::AsyncWrite)]
pub struct TlsStream(
    #[read]
    #[write(poll_shutdown=poll_shutdown)]
    tokio_openssl::SslStream<Box<dyn Stream + Send>>,
);
62
63fn poll_shutdown(
64 this: Pin<&mut tokio_openssl::SslStream<Box<dyn Stream + Send>>>,
65 cx: &mut std::task::Context<'_>,
66) -> std::task::Poll<std::io::Result<()>> {
67 use tokio::io::AsyncWrite;
68 let res = ready!(this.poll_shutdown(cx));
69 if let Err(e) = &res {
70 if e.kind() == std::io::ErrorKind::NotConnected {
72 return Poll::Ready(Ok(()));
73 }
74
75 if let Some(ssl_err) = e
77 .get_ref()
78 .and_then(|e| e.downcast_ref::<openssl::ssl::Error>())
79 {
80 if ssl_err.code() == openssl::ssl::ErrorCode::SYSCALL {
81 return Poll::Ready(Ok(()));
82 }
83 }
84 }
85 Poll::Ready(res)
86}
87
/// `webpki_root_certs` roots parsed into OpenSSL certificates, built once on first use by
/// `init_client` for `TlsCert::Webpki` / `TlsCert::WebpkiPlus`. Roots that fail to parse are
/// skipped.
static WEBPKI_ROOTS: OnceLock<Vec<openssl::x509::X509>> = OnceLock::new();
90
91impl TlsDriver for OpensslDriver {
92 type Stream = TlsStream;
93 type ClientParams = openssl::ssl::Ssl;
94 type ServerParams = openssl::ssl::SslContext;
95 const DRIVER_NAME: &'static str = "openssl";
96
97 fn init_client(
98 params: &TlsParameters,
99 name: Option<ServerName>,
100 ) -> Result<Self::ClientParams, SslError> {
101 let TlsParameters {
102 server_cert_verify,
103 root_cert,
104 cert,
105 key,
106 crl,
107 min_protocol_version,
108 max_protocol_version,
109 alpn,
110 sni_override,
111 enable_keylog,
112 } = params;
113
114 let mut ssl = SslContextBuilder::new(SslMethod::tls_client())?;
116
117 ssl.clear_options(SslOptions::from_bits_retain(1 << 7));
119
120 match root_cert {
122 TlsCert::Custom(root) | TlsCert::SystemPlus(root) | TlsCert::WebpkiPlus(root) => {
123 for root in root {
124 let root = openssl::x509::X509::from_der(root.as_ref())?;
125 ssl.cert_store_mut().add_cert(root)?;
126 }
127 }
128 _ => {}
129 }
130
131 match root_cert {
132 TlsCert::Webpki | TlsCert::WebpkiPlus(_) => {
133 let webpki_roots = WEBPKI_ROOTS.get_or_init(|| {
134 let webpki_roots = webpki_root_certs::TLS_SERVER_ROOT_CERTS;
135 let mut roots = Vec::new();
136 for root in webpki_roots {
137 if let Ok(root) = openssl::x509::X509::from_der(root.as_ref()) {
139 roots.push(root);
140 }
141 }
142 roots
143 });
144 for root in webpki_roots {
145 ssl.cert_store_mut().add_cert(root.clone())?;
146 }
147 }
148 _ => {}
149 }
150
151 if matches!(root_cert, TlsCert::SystemPlus(_) | TlsCert::System) {
153 let probe = openssl_probe::probe();
155 ssl.load_verify_locations(probe.cert_file.as_deref(), probe.cert_dir.as_deref())?;
156 }
157
158 match server_cert_verify {
160 TlsServerCertVerify::Insecure => {
161 ssl.set_verify(SslVerifyMode::NONE);
162 }
163 TlsServerCertVerify::IgnoreHostname => {
164 ssl.set_verify(SslVerifyMode::PEER);
165 }
166 TlsServerCertVerify::VerifyFull => {
167 ssl.set_verify(SslVerifyMode::PEER);
168 if let Some(hostname) = sni_override {
169 ssl.verify_param_mut().set_host(hostname)?;
170 } else if let Some(ServerName::DnsName(hostname)) = &name {
171 ssl.verify_param_mut().set_host(hostname.as_ref())?;
172 } else if let Some(ServerName::IpAddress(ip)) = &name {
173 ssl.verify_param_mut().set_ip((*ip).into())?;
174 }
175 }
176 }
177
178 if !crl.is_empty() {
180 use foreign_types::ForeignTypeRef;
182 let ptr = ssl.cert_store_mut().as_ptr();
183
184 extern "C" {
185 pub fn X509_STORE_add_crl(
186 store: *mut openssl_sys::X509_STORE,
187 x: *mut openssl_sys::X509_CRL,
188 ) -> openssl_sys::c_int;
189 }
190
191 for crl in crl {
192 let crl = openssl::x509::X509Crl::from_der(crl.as_ref())?;
193 let crl_ptr = crl.as_ptr();
194 let res = unsafe { X509_STORE_add_crl(ptr, crl_ptr) };
195 if res != 1 {
196 return Err(std::io::Error::new(
197 std::io::ErrorKind::Other,
198 "Failed to add CRL to store",
199 )
200 .into());
201 }
202 }
203
204 ssl.verify_param_mut()
205 .set_flags(X509VerifyFlags::CRL_CHECK | X509VerifyFlags::CRL_CHECK_ALL)?;
206 ssl.cert_store_mut()
207 .set_flags(X509VerifyFlags::CRL_CHECK | X509VerifyFlags::CRL_CHECK_ALL)?;
208 }
209
210 if let (Some(cert), Some(key)) = (cert.as_ref(), key.as_ref()) {
212 let builder = openssl::x509::X509::from_der(cert.as_ref())?;
213 ssl.set_certificate(&builder)?;
214 let builder = openssl::pkey::PKey::private_key_from_der(key.secret_der())?;
215 ssl.set_private_key(&builder)?;
216 }
217
218 ssl.set_min_proto_version(min_protocol_version.map(|s| s.into()))?;
219 ssl.set_max_proto_version(max_protocol_version.map(|s| s.into()))?;
220
221 if *enable_keylog {
223 if let Ok(path) = std::env::var("SSLKEYLOGFILE") {
224 ssl.set_keylog_callback(move |_ssl, msg| {
225 let Ok(mut file) = std::fs::OpenOptions::new().append(true).open(&path) else {
226 return;
227 };
228 let _ = std::io::Write::write_all(&mut file, msg.as_bytes());
229 });
230 }
231 }
232
233 let mut ssl = openssl::ssl::Ssl::new(&ssl.build())?;
234 ssl.set_connect_state();
235
236 if let Some(hostname) = sni_override {
238 ssl.set_hostname(hostname)?;
239 } else if let Some(ServerName::DnsName(hostname)) = &name {
240 ssl.set_hostname(hostname.as_ref())?;
241 }
242
243 if !alpn.is_empty() {
244 ssl.set_alpn_protos(&alpn.as_bytes())?;
245 }
246
247 Ok(ssl)
248 }
249
250 fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError> {
251 let TlsServerParameters {
252 client_cert_verify,
253 min_protocol_version,
254 max_protocol_version,
255 server_certificate,
256 alpn: _alpn,
258 } = params;
259
260 let mut ssl = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls_server())?;
261 let cert = openssl::x509::X509::from_der(server_certificate.cert.as_ref())?;
262 let key = openssl::pkey::PKey::private_key_from_der(server_certificate.key.secret_der())?;
263 ssl.set_certificate(&cert)?;
264 ssl.set_private_key(&key)?;
265 ssl.set_min_proto_version(min_protocol_version.map(|s| s.into()))?;
266 ssl.set_max_proto_version(max_protocol_version.map(|s| s.into()))?;
267 match client_cert_verify {
268 TlsClientCertVerify::Ignore => ssl.set_verify(SslVerifyMode::NONE),
269 TlsClientCertVerify::Optional(root) => {
270 ssl.set_verify(SslVerifyMode::PEER);
271 for root in root {
272 let root = openssl::x509::X509::from_der(root.as_ref())?;
273 ssl.cert_store_mut().add_cert(root)?;
274 }
275 }
276 TlsClientCertVerify::Validate(root) => {
277 ssl.set_verify(SslVerifyMode::PEER | SslVerifyMode::FAIL_IF_NO_PEER_CERT);
278 for root in root {
279 let root = openssl::x509::X509::from_der(root.as_ref())?;
280 ssl.cert_store_mut().add_cert(root)?;
281 }
282 }
283 }
284 create_alpn_callback(&mut ssl);
285
286 Ok(ssl.build().into_context())
287 }
288
289 async fn upgrade_client<S: Stream>(
290 params: Self::ClientParams,
291 stream: S,
292 ) -> Result<(Self::Stream, TlsHandshake), SslError> {
293 let stream = stream
294 .downcast::<TokioStream>()
295 .map_err(|_| crate::SslError::SslUnsupported)?;
296 let TokioStream::Tcp(stream) = stream else {
297 return Err(crate::SslError::SslUnsupported);
298 };
299
300 let mut stream =
301 tokio_openssl::SslStream::new(params, Box::new(stream) as Box<dyn Stream + Send>)?;
302 let res = Pin::new(&mut stream).do_handshake().await;
303 if res.is_err() && stream.ssl().verify_result() != X509VerifyResult::OK {
304 return Err(SslError::OpenSslErrorVerify(stream.ssl().verify_result()));
305 }
306
307 let alpn = stream
308 .ssl()
309 .selected_alpn_protocol()
310 .map(|p| Cow::Owned(p.to_vec()));
311
312 res.map_err(SslError::OpenSslError)?;
313 let cert = stream
314 .ssl()
315 .peer_certificate()
316 .map(|cert| cert.to_der())
317 .transpose()?;
318 let cert = cert.map(CertificateDer::from);
319 let version = match stream.ssl().version2() {
320 Some(openssl::ssl::SslVersion::TLS1) => Some(SslVersion::Tls1),
321 Some(openssl::ssl::SslVersion::TLS1_1) => Some(SslVersion::Tls1_1),
322 Some(openssl::ssl::SslVersion::TLS1_2) => Some(SslVersion::Tls1_2),
323 Some(openssl::ssl::SslVersion::TLS1_3) => Some(SslVersion::Tls1_3),
324 _ => None,
325 };
326 Ok((
327 TlsStream(stream),
328 TlsHandshake {
329 alpn,
330 sni: None,
331 cert,
332 version,
333 },
334 ))
335 }
336
337 async fn upgrade_server<S: Stream>(
338 params: TlsServerParameterProvider,
339 stream: S,
340 ) -> Result<(Self::Stream, TlsHandshake), SslError> {
341 let stream = stream.boxed();
342
343 let mut ssl = SslContextBuilder::new(SslMethod::tls_server())?;
344 create_alpn_callback(&mut ssl);
345 create_sni_callback(&mut ssl, params);
346 ssl.set_client_hello_callback(move |ssl_ref, _alert| {
347 ssl_ref.set_verify(SslVerifyMode::PEER);
353 Ok(ClientHelloResponse::SUCCESS)
354 });
355
356 let mut ssl = Ssl::new(&ssl.build())?;
357 ssl.set_accept_state();
358 let handshake = Arc::new(Mutex::new(HandshakeData {
359 server_alpn: None,
360 handshake: TlsHandshake::default(),
361 stream: &stream as *const _,
362 }));
363 ssl.set_ex_data(get_ssl_ex_data_index(), handshake.clone());
364
365 let mut stream = tokio_openssl::SslStream::new(ssl, stream)?;
366
367 let res = Pin::new(&mut stream).do_handshake().await;
368 res.map_err(SslError::OpenSslError)?;
369
370 let mut handshake = std::mem::take(&mut handshake.lock().unwrap().handshake);
371 let cert = stream
372 .ssl()
373 .peer_certificate()
374 .and_then(|c| c.to_der().ok());
375 if let Some(cert) = cert {
376 handshake.cert = Some(CertificateDer::from(cert));
377 }
378 let version = match stream.ssl().version2() {
379 Some(openssl::ssl::SslVersion::TLS1) => Some(SslVersion::Tls1),
380 Some(openssl::ssl::SslVersion::TLS1_1) => Some(SslVersion::Tls1_1),
381 Some(openssl::ssl::SslVersion::TLS1_2) => Some(SslVersion::Tls1_2),
382 Some(openssl::ssl::SslVersion::TLS1_3) => Some(SslVersion::Tls1_3),
383 _ => None,
384 };
385 handshake.version = version;
386 Ok((TlsStream(stream), handshake))
387 }
388
389 fn unclean_shutdown(_this: Self::Stream) -> Result<(), Self::Stream> {
390 Ok(())
392 }
393}
394
/// Chooses an ALPN protocol from two wire-format protocol lists.
///
/// Both `server` and `client` are sequences of length-prefixed entries: one length byte
/// followed by that many bytes. The server's order of preference wins. The first server entry
/// that also appears in the client list is returned, borrowed from `client`.
///
/// Empty entries are ignored, because an empty protocol name is never a valid selection
/// (RFC 7301). Returns `None` if nothing matches, or if a length byte overruns its list before
/// a match is found.
fn ssl_select_next_proto<'b>(server: &[u8], client: &'b [u8]) -> Option<&'b [u8]> {
    /// Splits the first length-prefixed entry off `packet`, returning `(entry, remainder)`, or
    /// `None` if the length byte claims more bytes than remain.
    fn split_entry(packet: &[u8]) -> Option<(&[u8], &[u8])> {
        let (&len, rest) = packet.split_first()?;
        let len = usize::from(len);
        if rest.len() < len {
            return None;
        }
        Some(rest.split_at(len))
    }

    let mut server_packet = server;
    while !server_packet.is_empty() {
        let (server_proto, server_rest) = split_entry(server_packet)?;
        server_packet = server_rest;
        if server_proto.is_empty() {
            continue;
        }

        let mut client_packet = client;
        while !client_packet.is_empty() {
            let (client_proto, client_rest) = split_entry(client_packet)?;
            if client_proto == server_proto {
                return Some(client_proto);
            }
            client_packet = client_rest;
        }
    }
    None
}
413
414fn create_alpn_callback(ssl: &mut SslContextBuilder) {
416 ssl.set_alpn_select_callback(|ssl_ref, alpn| {
417 let Some(mut handshake) = HandshakeData::from_ssl(ssl_ref) else {
418 return Err(AlpnError::ALERT_FATAL);
419 };
420
421 if let Some(server) = handshake.server_alpn.take() {
422 eprintln!("server: {:?} alpn: {:?}", server, alpn);
423 let Some(selected) = ssl_select_next_proto(&server, alpn) else {
424 return Err(AlpnError::NOACK);
425 };
426 handshake.handshake.alpn = Some(Cow::Owned(selected.to_vec()));
427
428 Ok(selected)
429 } else {
430 Err(AlpnError::NOACK)
431 }
432 })
433}
434
435fn create_sni_callback(ssl: &mut SslContextBuilder, params: TlsServerParameterProvider) {
437 ssl.set_servername_callback(move |ssl_ref, _alert| {
438 let Some(mut handshake) = HandshakeData::from_ssl(ssl_ref) else {
439 return Ok(());
440 };
441
442 if let Some(servername) = ssl_ref.servername_raw(NameType::HOST_NAME) {
443 handshake.handshake.sni = DnsName::try_from(servername).ok().map(|s| s.to_owned());
444 }
445 let name = handshake.handshake.sni.as_ref().map(|s| s.borrow());
446
447 let params = unsafe {
451 let stream = handshake.stream.as_ref().unwrap();
452 params.lookup(name, stream)
457 };
458
459 if !params.alpn.is_empty() {
460 handshake.server_alpn = Some(params.alpn.as_bytes().to_vec());
461 }
462 drop(handshake);
463
464 let Ok(ssl) = OpensslDriver::init_server(¶ms) else {
465 return Err(SniError::ALERT_FATAL);
466 };
467 let Ok(_) = ssl_ref.set_ssl_context(&ssl) else {
468 return Err(SniError::ALERT_FATAL);
469 };
470 Ok(())
471 });
472}
473
474impl From<SslVersion> for openssl::ssl::SslVersion {
475 fn from(val: SslVersion) -> Self {
476 match val {
477 SslVersion::Tls1 => openssl::ssl::SslVersion::TLS1,
478 SslVersion::Tls1_1 => openssl::ssl::SslVersion::TLS1_1,
479 SslVersion::Tls1_2 => openssl::ssl::SslVersion::TLS1_2,
480 SslVersion::Tls1_3 => openssl::ssl::SslVersion::TLS1_3,
481 }
482 }
483}
484
485impl AsHandle for TlsStream {
486 #[cfg(windows)]
487 fn as_handle(&self) -> std::os::windows::io::BorrowedSocket {
488 self.0.get_ref().as_handle()
489 }
490
491 #[cfg(unix)]
492 fn as_fd(&self) -> std::os::fd::BorrowedFd {
493 self.0.get_ref().as_fd()
494 }
495}
496
497impl PeekableStream for TlsStream {
498 #[cfg(feature = "tokio")]
499 fn poll_peek(
500 mut self: Pin<&mut Self>,
501 cx: &mut std::task::Context<'_>,
502 buf: &mut tokio::io::ReadBuf<'_>,
503 ) -> std::task::Poll<std::io::Result<usize>> {
504 let buf = unsafe { &mut *(buf.unfilled_mut() as *mut _ as *mut [u8]) };
510 Pin::new(&mut self.0)
511 .poll_peek(cx, buf)
512 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
513 }
514}
515
516impl StreamMetadata for TlsStream {
517 fn transport(&self) -> Transport {
518 self.0.get_ref().transport()
519 }
520}
521
522impl PeerCred for TlsStream {
523 #[cfg(all(unix, feature = "tokio"))]
524 fn peer_cred(&self) -> std::io::Result<tokio::net::unix::UCred> {
525 self.0.get_ref().peer_cred()
526 }
527}
528
529impl LocalAddress for TlsStream {
530 fn local_address(&self) -> std::io::Result<ResolvedTarget> {
531 self.0.get_ref().local_address()
532 }
533}
534
535impl RemoteAddress for TlsStream {
536 fn remote_address(&self) -> std::io::Result<ResolvedTarget> {
537 self.0.get_ref().remote_address()
538 }
539}
540
#[cfg(test)]
mod tests {
    use super::*;

    // All inputs are ALPN wire-format lists: each protocol name is preceded by a single
    // length byte.

    #[test]
    fn test_ssl_select_next_proto() {
        let picked = ssl_select_next_proto(b"\x02h2\x08http/1.1", b"\x08http/1.1");
        assert_eq!(picked, Some(&b"http/1.1"[..]));
    }

    #[test]
    fn test_ssl_select_next_proto_empty() {
        assert!(ssl_select_next_proto(b"", b"").is_none());
    }

    #[test]
    fn test_ssl_select_next_proto_invalid_length() {
        // The length byte claims 8 bytes but only 2 follow.
        assert!(ssl_select_next_proto(b"\x08h2", b"\x08http/1.1").is_none());
    }

    #[test]
    fn test_ssl_select_next_proto_zero_length() {
        // A zero-length entry followed by bytes that do not form a valid entry.
        assert!(ssl_select_next_proto(b"\x00h2", b"\x08http/1.1").is_none());
    }

    #[test]
    fn test_ssl_select_next_proto_truncated() {
        // The second server entry is one byte short.
        assert!(ssl_select_next_proto(b"\x02h2\x08http/1", b"\x08http/1.1").is_none());
    }

    #[test]
    fn test_ssl_select_next_proto_overflow() {
        // A length of 255 runs far past the end of the buffer.
        assert!(ssl_select_next_proto(b"\xFFh2", b"\x08http/1.1").is_none());
    }

    #[test]
    fn test_ssl_select_next_proto_no_match() {
        assert!(ssl_select_next_proto(b"\x02h2", b"\x08http/1.1").is_none());
    }

    #[test]
    fn test_ssl_select_next_proto_multiple_server() {
        let picked = ssl_select_next_proto(b"\x02h2\x06spdy/2\x08http/1.1", b"\x08http/1.1");
        assert_eq!(picked, Some(&b"http/1.1"[..]));
    }

    #[test]
    fn test_ssl_select_next_proto_multiple_client() {
        let picked = ssl_select_next_proto(b"\x08http/1.1", b"\x02h2\x06spdy/2\x08http/1.1");
        assert_eq!(picked, Some(&b"http/1.1"[..]));
    }

    #[test]
    fn test_ssl_select_next_proto_first_match() {
        // The server's preference order decides, not the client's.
        let picked = ssl_select_next_proto(
            b"\x02h2\x06spdy/2\x08http/1.1",
            b"\x06spdy/2\x02h2\x08http/1.1",
        );
        assert_eq!(picked, Some(&b"h2"[..]));
    }

    #[test]
    fn test_ssl_select_next_proto_first_match_2() {
        let picked = ssl_select_next_proto(
            b"\x06spdy/2\x02h2\x08http/1.1",
            b"\x02h2\x06spdy/2\x08http/1.1",
        );
        assert_eq!(picked, Some(&b"spdy/2"[..]));
    }
}