1use openssl::{
2 ssl::{
3 AlpnError, ClientHelloResponse, NameType, SniError, Ssl, SslAcceptor, SslContextBuilder,
4 SslMethod, SslOptions, SslRef, SslVerifyMode,
5 },
6 x509::{verify::X509VerifyFlags, X509VerifyResult},
7};
8use rustls_pki_types::{CertificateDer, DnsName, ServerName};
9use std::{
10 borrow::Cow,
11 pin::Pin,
12 sync::{Arc, Mutex, MutexGuard, OnceLock},
13 task::{ready, Poll},
14};
15
16use crate::{
17 AsHandle, LocalAddress, PeekableStream, PeerCred, RemoteAddress, ResolvedTarget, SslError,
18 SslVersion, Stream, StreamMetadata, TlsCert, TlsClientCertVerify, TlsDriver, TlsHandshake,
19 TlsParameters, TlsServerCertVerify, TlsServerParameterProvider, TlsServerParameters, Transport,
20};
21
22use super::tokio_stream::TokioStream;
23
24#[derive(Debug, Clone)]
25struct HandshakeData {
26 server_alpn: Option<Vec<u8>>,
27 handshake: TlsHandshake,
28 stream: *const Box<dyn Stream + Send>,
29}
30
31unsafe impl Send for HandshakeData {}
32
33impl HandshakeData {
34 fn from_ssl(ssl: &SslRef) -> Option<MutexGuard<Self>> {
35 let mutex = ssl.ex_data(get_ssl_ex_data_index())?;
36 mutex.lock().ok()
37 }
38}
39
40static SSL_EX_DATA_INDEX: OnceLock<openssl::ex_data::Index<Ssl, Arc<Mutex<HandshakeData>>>> =
41 OnceLock::new();
42
43fn get_ssl_ex_data_index() -> openssl::ex_data::Index<Ssl, Arc<Mutex<HandshakeData>>> {
44 *SSL_EX_DATA_INDEX
45 .get_or_init(|| Ssl::new_ex_index().expect("Failed to create SSL ex_data index"))
46}
47
/// OpenSSL-backed implementation of `TlsDriver`.
///
/// It is a stateless marker type: all configuration is passed through the
/// driver's client/server parameter types.
#[derive(Debug, Clone, Copy, Default)]
pub struct OpensslDriver;
51
/// TLS stream produced by `OpensslDriver`.
///
/// It wraps a `tokio_openssl::SslStream` over the boxed transport.
/// `AsyncRead`/`AsyncWrite` are derived by `derive_io` and forward to the inner
/// stream. Shutdown is presumably routed through the free function
/// `poll_shutdown`, which treats `NotConnected` I/O errors and OpenSSL `SYSCALL`
/// errors as a successful shutdown.
#[derive(derive_io::AsyncRead, derive_io::AsyncWrite)]
pub struct TlsStream(
    #[read]
    #[write(poll_shutdown=poll_shutdown)]
    tokio_openssl::SslStream<Box<dyn Stream + Send>>,
);
62
63fn poll_shutdown(
64 this: Pin<&mut tokio_openssl::SslStream<Box<dyn Stream + Send>>>,
65 cx: &mut std::task::Context<'_>,
66) -> std::task::Poll<std::io::Result<()>> {
67 use tokio::io::AsyncWrite;
68 let res = ready!(this.poll_shutdown(cx));
69 if let Err(e) = &res {
70 if e.kind() == std::io::ErrorKind::NotConnected {
72 return Poll::Ready(Ok(()));
73 }
74
75 if let Some(ssl_err) = e
77 .get_ref()
78 .and_then(|e| e.downcast_ref::<openssl::ssl::Error>())
79 {
80 if ssl_err.code() == openssl::ssl::ErrorCode::SYSCALL {
81 return Poll::Ready(Ok(()));
82 }
83 }
84 }
85 Poll::Ready(res)
86}
87
88static WEBPKI_ROOTS: OnceLock<Vec<openssl::x509::X509>> = OnceLock::new();
90
91impl TlsDriver for OpensslDriver {
92 type Stream = TlsStream;
93 type ClientParams = openssl::ssl::Ssl;
94 type ServerParams = openssl::ssl::SslContext;
95 const DRIVER_NAME: &'static str = "openssl";
96
97 fn init_client(
98 params: &TlsParameters,
99 name: Option<ServerName>,
100 ) -> Result<Self::ClientParams, SslError> {
101 let TlsParameters {
102 server_cert_verify,
103 root_cert,
104 cert,
105 key,
106 crl,
107 min_protocol_version,
108 max_protocol_version,
109 alpn,
110 sni_override,
111 enable_keylog,
112 } = params;
113
114 let mut ssl = SslContextBuilder::new(SslMethod::tls_client())?;
116
117 ssl.clear_options(SslOptions::from_bits_retain(1 << 7));
119
120 match root_cert {
122 TlsCert::Custom(root) | TlsCert::SystemPlus(root) | TlsCert::WebpkiPlus(root) => {
123 for root in root {
124 let root = openssl::x509::X509::from_der(root.as_ref())?;
125 ssl.cert_store_mut().add_cert(root)?;
126 }
127 }
128 _ => {}
129 }
130
131 match root_cert {
132 TlsCert::Webpki | TlsCert::WebpkiPlus(_) => {
133 let webpki_roots = WEBPKI_ROOTS.get_or_init(|| {
134 let webpki_roots = webpki_root_certs::TLS_SERVER_ROOT_CERTS;
135 let mut roots = Vec::new();
136 for root in webpki_roots {
137 if let Ok(root) = openssl::x509::X509::from_der(root.as_ref()) {
139 roots.push(root);
140 }
141 }
142 roots
143 });
144 for root in webpki_roots {
145 ssl.cert_store_mut().add_cert(root.clone())?;
146 }
147 }
148 _ => {}
149 }
150
151 if matches!(root_cert, TlsCert::SystemPlus(_) | TlsCert::System) {
153 let probe = openssl_probe::probe();
155 ssl.load_verify_locations(probe.cert_file.as_deref(), probe.cert_dir.as_deref())?;
156 }
157
158 match server_cert_verify {
160 TlsServerCertVerify::Insecure => {
161 ssl.set_verify(SslVerifyMode::NONE);
162 }
163 TlsServerCertVerify::IgnoreHostname => {
164 ssl.set_verify(SslVerifyMode::PEER);
165 }
166 TlsServerCertVerify::VerifyFull => {
167 ssl.set_verify(SslVerifyMode::PEER);
168 if let Some(hostname) = sni_override {
169 ssl.verify_param_mut().set_host(hostname)?;
170 } else if let Some(ServerName::DnsName(hostname)) = &name {
171 ssl.verify_param_mut().set_host(hostname.as_ref())?;
172 } else if let Some(ServerName::IpAddress(ip)) = &name {
173 ssl.verify_param_mut().set_ip((*ip).into())?;
174 }
175 }
176 }
177
178 if !crl.is_empty() {
180 use foreign_types::ForeignTypeRef;
182 let ptr = ssl.cert_store_mut().as_ptr();
183
184 extern "C" {
185 pub fn X509_STORE_add_crl(
186 store: *mut openssl_sys::X509_STORE,
187 x: *mut openssl_sys::X509_CRL,
188 ) -> openssl_sys::c_int;
189 }
190
191 for crl in crl {
192 let crl = openssl::x509::X509Crl::from_der(crl.as_ref())?;
193 let crl_ptr = crl.as_ptr();
194 let res = unsafe { X509_STORE_add_crl(ptr, crl_ptr) };
195 if res != 1 {
196 return Err(std::io::Error::other("Failed to add CRL to store").into());
197 }
198 }
199
200 ssl.verify_param_mut()
201 .set_flags(X509VerifyFlags::CRL_CHECK | X509VerifyFlags::CRL_CHECK_ALL)?;
202 ssl.cert_store_mut()
203 .set_flags(X509VerifyFlags::CRL_CHECK | X509VerifyFlags::CRL_CHECK_ALL)?;
204 }
205
206 if let (Some(cert), Some(key)) = (cert.as_ref(), key.as_ref()) {
208 let builder = openssl::x509::X509::from_der(cert.as_ref())?;
209 ssl.set_certificate(&builder)?;
210 let builder = openssl::pkey::PKey::private_key_from_der(key.secret_der())?;
211 ssl.set_private_key(&builder)?;
212 }
213
214 ssl.set_min_proto_version(min_protocol_version.map(|s| s.into()))?;
215 ssl.set_max_proto_version(max_protocol_version.map(|s| s.into()))?;
216
217 if *enable_keylog {
219 if let Ok(path) = std::env::var("SSLKEYLOGFILE") {
220 ssl.set_keylog_callback(move |_ssl, msg| {
221 let Ok(mut file) = std::fs::OpenOptions::new().append(true).open(&path) else {
222 return;
223 };
224 let _ = std::io::Write::write_all(&mut file, msg.as_bytes());
225 });
226 }
227 }
228
229 let mut ssl = openssl::ssl::Ssl::new(&ssl.build())?;
230 ssl.set_connect_state();
231
232 if let Some(hostname) = sni_override {
234 ssl.set_hostname(hostname)?;
235 } else if let Some(ServerName::DnsName(hostname)) = &name {
236 ssl.set_hostname(hostname.as_ref())?;
237 }
238
239 if !alpn.is_empty() {
240 ssl.set_alpn_protos(&alpn.as_bytes())?;
241 }
242
243 Ok(ssl)
244 }
245
246 fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError> {
247 let TlsServerParameters {
248 client_cert_verify,
249 min_protocol_version,
250 max_protocol_version,
251 server_certificate,
252 alpn: _alpn,
254 } = params;
255
256 let mut ssl = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls_server())?;
257 let cert = openssl::x509::X509::from_der(server_certificate.cert.as_ref())?;
258 let key = openssl::pkey::PKey::private_key_from_der(server_certificate.key.secret_der())?;
259 ssl.set_certificate(&cert)?;
260 ssl.set_private_key(&key)?;
261 ssl.set_min_proto_version(min_protocol_version.map(|s| s.into()))?;
262 ssl.set_max_proto_version(max_protocol_version.map(|s| s.into()))?;
263 match client_cert_verify {
264 TlsClientCertVerify::Ignore => ssl.set_verify(SslVerifyMode::NONE),
265 TlsClientCertVerify::Optional(root) => {
266 ssl.set_verify(SslVerifyMode::PEER);
267 for root in root {
268 let root = openssl::x509::X509::from_der(root.as_ref())?;
269 ssl.cert_store_mut().add_cert(root)?;
270 }
271 }
272 TlsClientCertVerify::Validate(root) => {
273 ssl.set_verify(SslVerifyMode::PEER | SslVerifyMode::FAIL_IF_NO_PEER_CERT);
274 for root in root {
275 let root = openssl::x509::X509::from_der(root.as_ref())?;
276 ssl.cert_store_mut().add_cert(root)?;
277 }
278 }
279 }
280 create_alpn_callback(&mut ssl);
281
282 Ok(ssl.build().into_context())
283 }
284
285 async fn upgrade_client<S: Stream>(
286 params: Self::ClientParams,
287 stream: S,
288 ) -> Result<(Self::Stream, TlsHandshake), SslError> {
289 let stream = stream
290 .downcast::<TokioStream>()
291 .map_err(|_| crate::SslError::SslUnsupported)?;
292 let TokioStream::Tcp(stream) = stream else {
293 return Err(crate::SslError::SslUnsupported);
294 };
295
296 let mut stream =
297 tokio_openssl::SslStream::new(params, Box::new(stream) as Box<dyn Stream + Send>)?;
298 let res = Pin::new(&mut stream).do_handshake().await;
299 if res.is_err() && stream.ssl().verify_result() != X509VerifyResult::OK {
300 return Err(SslError::OpenSslErrorVerify(stream.ssl().verify_result()));
301 }
302
303 let alpn = stream
304 .ssl()
305 .selected_alpn_protocol()
306 .map(|p| Cow::Owned(p.to_vec()));
307
308 res.map_err(SslError::OpenSslError)?;
309 let cert = stream
310 .ssl()
311 .peer_certificate()
312 .map(|cert| cert.to_der())
313 .transpose()?;
314 let cert = cert.map(CertificateDer::from);
315 let version = match stream.ssl().version2() {
316 Some(openssl::ssl::SslVersion::TLS1) => Some(SslVersion::Tls1),
317 Some(openssl::ssl::SslVersion::TLS1_1) => Some(SslVersion::Tls1_1),
318 Some(openssl::ssl::SslVersion::TLS1_2) => Some(SslVersion::Tls1_2),
319 Some(openssl::ssl::SslVersion::TLS1_3) => Some(SslVersion::Tls1_3),
320 _ => None,
321 };
322 Ok((
323 TlsStream(stream),
324 TlsHandshake {
325 alpn,
326 sni: None,
327 cert,
328 version,
329 },
330 ))
331 }
332
333 async fn upgrade_server<S: Stream>(
334 params: TlsServerParameterProvider,
335 stream: S,
336 ) -> Result<(Self::Stream, TlsHandshake), SslError> {
337 let stream = stream.boxed();
338
339 let mut ssl = SslContextBuilder::new(SslMethod::tls_server())?;
340 create_alpn_callback(&mut ssl);
341 create_sni_callback(&mut ssl, params);
342 ssl.set_client_hello_callback(move |ssl_ref, _alert| {
343 ssl_ref.set_verify(SslVerifyMode::PEER);
349 Ok(ClientHelloResponse::SUCCESS)
350 });
351
352 let mut ssl = Ssl::new(&ssl.build())?;
353 ssl.set_accept_state();
354 let handshake = Arc::new(Mutex::new(HandshakeData {
355 server_alpn: None,
356 handshake: TlsHandshake::default(),
357 stream: &stream as *const _,
358 }));
359 ssl.set_ex_data(get_ssl_ex_data_index(), handshake.clone());
360
361 let mut stream = tokio_openssl::SslStream::new(ssl, stream)?;
362
363 let res = Pin::new(&mut stream).do_handshake().await;
364 res.map_err(SslError::OpenSslError)?;
365
366 let mut handshake = std::mem::take(&mut handshake.lock().unwrap().handshake);
367 let cert = stream
368 .ssl()
369 .peer_certificate()
370 .and_then(|c| c.to_der().ok());
371 if let Some(cert) = cert {
372 handshake.cert = Some(CertificateDer::from(cert));
373 }
374 let version = match stream.ssl().version2() {
375 Some(openssl::ssl::SslVersion::TLS1) => Some(SslVersion::Tls1),
376 Some(openssl::ssl::SslVersion::TLS1_1) => Some(SslVersion::Tls1_1),
377 Some(openssl::ssl::SslVersion::TLS1_2) => Some(SslVersion::Tls1_2),
378 Some(openssl::ssl::SslVersion::TLS1_3) => Some(SslVersion::Tls1_3),
379 _ => None,
380 };
381 handshake.version = version;
382 Ok((TlsStream(stream), handshake))
383 }
384
385 fn unclean_shutdown(_this: Self::Stream) -> Result<(), Self::Stream> {
386 Ok(())
388 }
389}
390
/// Picks the ALPN protocol to use from two length-prefixed protocol lists.
///
/// Server preference wins: the first `server` protocol that also appears in
/// `client` is returned, as a slice of `client`. Returns `None` when there is no
/// common protocol, or as soon as a malformed (truncated) entry is reached in
/// either list.
fn ssl_select_next_proto<'b>(server: &[u8], client: &'b [u8]) -> Option<&'b [u8]> {
    for wanted in alpn_entries(server) {
        let wanted = wanted?;
        for offered in alpn_entries(client) {
            let offered = offered?;
            if offered == wanted {
                return Some(offered);
            }
        }
    }
    None
}

/// Iterates over the entries of a length-prefixed ALPN list.
///
/// Each entry is yielded as `Some(bytes)`. A truncated entry is yielded once as
/// `None`, and iteration then stops.
fn alpn_entries(list: &[u8]) -> impl Iterator<Item = Option<&[u8]>> + '_ {
    let mut rest = Some(list);
    std::iter::from_fn(move || {
        let (&len, tail) = rest.take()?.split_first()?;
        let len = usize::from(len);
        if tail.len() < len {
            return Some(None);
        }
        let (entry, remainder) = tail.split_at(len);
        rest = Some(remainder);
        Some(Some(entry))
    })
}
409
410fn create_alpn_callback(ssl: &mut SslContextBuilder) {
412 ssl.set_alpn_select_callback(|ssl_ref, alpn| {
413 let Some(mut handshake) = HandshakeData::from_ssl(ssl_ref) else {
414 return Err(AlpnError::ALERT_FATAL);
415 };
416
417 if let Some(server) = handshake.server_alpn.take() {
418 eprintln!("server: {server:?} alpn: {alpn:?}");
419 let Some(selected) = ssl_select_next_proto(&server, alpn) else {
420 return Err(AlpnError::NOACK);
421 };
422 handshake.handshake.alpn = Some(Cow::Owned(selected.to_vec()));
423
424 Ok(selected)
425 } else {
426 Err(AlpnError::NOACK)
427 }
428 })
429}
430
431fn create_sni_callback(ssl: &mut SslContextBuilder, params: TlsServerParameterProvider) {
433 ssl.set_servername_callback(move |ssl_ref, _alert| {
434 let Some(mut handshake) = HandshakeData::from_ssl(ssl_ref) else {
435 return Ok(());
436 };
437
438 if let Some(servername) = ssl_ref.servername_raw(NameType::HOST_NAME) {
439 handshake.handshake.sni = DnsName::try_from(servername).ok().map(|s| s.to_owned());
440 }
441 let name = handshake.handshake.sni.as_ref().map(|s| s.borrow());
442
443 let params = unsafe {
447 let stream = handshake.stream.as_ref().unwrap();
448 params.lookup(name, stream)
453 };
454
455 if !params.alpn.is_empty() {
456 handshake.server_alpn = Some(params.alpn.as_bytes().to_vec());
457 }
458 drop(handshake);
459
460 let Ok(ssl) = OpensslDriver::init_server(¶ms) else {
461 return Err(SniError::ALERT_FATAL);
462 };
463 let Ok(_) = ssl_ref.set_ssl_context(&ssl) else {
464 return Err(SniError::ALERT_FATAL);
465 };
466 Ok(())
467 });
468}
469
470impl From<SslVersion> for openssl::ssl::SslVersion {
471 fn from(val: SslVersion) -> Self {
472 match val {
473 SslVersion::Tls1 => openssl::ssl::SslVersion::TLS1,
474 SslVersion::Tls1_1 => openssl::ssl::SslVersion::TLS1_1,
475 SslVersion::Tls1_2 => openssl::ssl::SslVersion::TLS1_2,
476 SslVersion::Tls1_3 => openssl::ssl::SslVersion::TLS1_3,
477 }
478 }
479}
480
481impl AsHandle for TlsStream {
482 #[cfg(windows)]
483 fn as_handle(&self) -> std::os::windows::io::BorrowedSocket {
484 self.0.get_ref().as_handle()
485 }
486
487 #[cfg(unix)]
488 fn as_fd(&self) -> std::os::fd::BorrowedFd {
489 self.0.get_ref().as_fd()
490 }
491}
492
impl PeekableStream for TlsStream {
    /// Peeks at data available from the TLS stream without consuming it, by
    /// delegating to the inner `tokio_openssl::SslStream::poll_peek`.
    ///
    /// On success, returns the number of bytes written into the unfilled part of
    /// `buf`. OpenSSL errors are wrapped with `std::io::Error::other`.
    /// NOTE(review): `buf` is neither marked initialized nor advanced here, so
    /// `buf.filled()` does not grow. Confirm that callers rely only on the
    /// returned count.
    #[cfg(feature = "tokio")]
    fn poll_peek(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<usize>> {
        // SAFETY: `unfilled_mut` requires that already-initialized bytes are not
        // de-initialized; presumably the peek only writes into this region.
        // NOTE(review): the region may be uninitialized, so viewing it as
        // `&mut [u8]` relies on it never being read before it is written.
        // Confirm this is acceptable.
        let buf = unsafe { &mut *(buf.unfilled_mut() as *mut _ as *mut [u8]) };
        Pin::new(&mut self.0)
            .poll_peek(cx, buf)
            .map_err(std::io::Error::other)
    }
}
511
512impl StreamMetadata for TlsStream {
513 fn transport(&self) -> Transport {
514 self.0.get_ref().transport()
515 }
516}
517
518impl PeerCred for TlsStream {
519 #[cfg(all(unix, feature = "tokio"))]
520 fn peer_cred(&self) -> std::io::Result<tokio::net::unix::UCred> {
521 self.0.get_ref().peer_cred()
522 }
523}
524
525impl LocalAddress for TlsStream {
526 fn local_address(&self) -> std::io::Result<ResolvedTarget> {
527 self.0.get_ref().local_address()
528 }
529}
530
531impl RemoteAddress for TlsStream {
532 fn remote_address(&self) -> std::io::Result<ResolvedTarget> {
533 self.0.get_ref().remote_address()
534 }
535}
536
/// Unit tests for the ALPN selection helper.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ssl_select_next_proto() {
        let server = b"\x02h2\x08http/1.1";
        let client = b"\x08http/1.1";
        let selected = ssl_select_next_proto(server, client);
        assert_eq!(selected, Some(b"http/1.1".as_slice()));
    }

    #[test]
    fn test_ssl_select_next_proto_empty() {
        let server = b"";
        let client = b"";
        let selected = ssl_select_next_proto(server, client);
        assert_eq!(selected, None);
    }

    #[test]
    fn test_ssl_select_next_proto_empty_client() {
        let server = b"\x02h2";
        let client = b"";
        let selected = ssl_select_next_proto(server, client);
        assert_eq!(selected, None);
    }

    #[test]
    fn test_ssl_select_next_proto_invalid_length() {
        let server = b"\x08h2";
        let client = b"\x08http/1.1";
        let selected = ssl_select_next_proto(server, client);
        assert_eq!(selected, None);
    }

    #[test]
    fn test_ssl_select_next_proto_zero_length() {
        let server = b"\x00h2";
        let client = b"\x08http/1.1";
        let selected = ssl_select_next_proto(server, client);
        assert_eq!(selected, None);
    }

    #[test]
    fn test_ssl_select_next_proto_truncated() {
        let server = b"\x02h2\x08http/1";
        let client = b"\x08http/1.1";
        let selected = ssl_select_next_proto(server, client);
        assert_eq!(selected, None);
    }

    #[test]
    fn test_ssl_select_next_proto_client_truncated() {
        let server = b"\x02h2";
        let client = b"\x05ab";
        let selected = ssl_select_next_proto(server, client);
        assert_eq!(selected, None);
    }

    #[test]
    fn test_ssl_select_next_proto_overflow() {
        let server = b"\xFFh2";
        let client = b"\x08http/1.1";
        let selected = ssl_select_next_proto(server, client);
        assert_eq!(selected, None);
    }

    #[test]
    fn test_ssl_select_next_proto_no_match() {
        let server = b"\x02h2";
        let client = b"\x08http/1.1";
        let selected = ssl_select_next_proto(server, client);
        assert_eq!(selected, None);
    }

    #[test]
    fn test_ssl_select_next_proto_multiple_server() {
        let server = b"\x02h2\x06spdy/2\x08http/1.1";
        let client = b"\x08http/1.1";
        let selected = ssl_select_next_proto(server, client);
        assert_eq!(selected, Some(b"http/1.1".as_slice()));
    }

    #[test]
    fn test_ssl_select_next_proto_multiple_client() {
        let server = b"\x08http/1.1";
        let client = b"\x02h2\x06spdy/2\x08http/1.1";
        let selected = ssl_select_next_proto(server, client);
        assert_eq!(selected, Some(b"http/1.1".as_slice()));
    }

    #[test]
    fn test_ssl_select_next_proto_first_match() {
        let server = b"\x02h2\x06spdy/2\x08http/1.1";
        let client = b"\x06spdy/2\x02h2\x08http/1.1";
        let selected = ssl_select_next_proto(server, client);
        assert_eq!(selected, Some(b"h2".as_slice()));
    }

    #[test]
    fn test_ssl_select_next_proto_first_match_2() {
        let server = b"\x06spdy/2\x02h2\x08http/1.1";
        let client = b"\x02h2\x06spdy/2\x08http/1.1";
        let selected = ssl_select_next_proto(server, client);
        assert_eq!(selected, Some(b"spdy/2".as_slice()));
    }
}