1use openssl::{
2 ssl::{
3 AlpnError, ClientHelloResponse, NameType, SniError, Ssl, SslAcceptor, SslContextBuilder,
4 SslMethod, SslOptions, SslRef, SslVerifyMode,
5 },
6 x509::{verify::X509VerifyFlags, X509VerifyResult},
7};
8use rustls_pki_types::{CertificateDer, ServerName};
9use std::{
10 borrow::Cow,
11 io::IoSlice,
12 pin::Pin,
13 sync::{Arc, Mutex, MutexGuard, OnceLock},
14 task::{ready, Poll},
15};
16use tokio::{
17 io::{AsyncRead, AsyncWrite, ReadBuf},
18 net::TcpStream,
19};
20
21use crate::{
22 RewindStream, SslError, SslVersion, Stream, TlsCert, TlsClientCertVerify, TlsDriver,
23 TlsHandshake, TlsParameters, TlsServerCertVerify, TlsServerParameterProvider,
24 TlsServerParameters,
25};
26
27use super::tokio_stream::TokioStream;
28
/// State gathered while a server-side handshake is in flight.
///
/// `upgrade_server` attaches an `Arc<Mutex<HandshakeData>>` to each `Ssl` as ex_data, and
/// the SNI and ALPN callbacks reach it through [`HandshakeData::from_ssl`].
#[derive(Debug, Clone, Default)]
struct HandshakeData {
    /// Length-prefixed ALPN protocol list offered by the server. It is stored by the SNI
    /// callback and taken (left as `None`) by the ALPN selection callback.
    server_alpn: Option<Vec<u8>>,
    /// Handshake details filled in by the callbacks (SNI name, selected ALPN protocol) and
    /// later moved out by `upgrade_server`.
    handshake: TlsHandshake,
}
34
35impl HandshakeData {
36 fn from_ssl(ssl: &SslRef) -> Option<MutexGuard<Self>> {
37 let mutex = ssl.ex_data(get_ssl_ex_data_index())?;
38 mutex.lock().ok()
39 }
40}
41
/// Lazily-allocated ex_data slot used to attach an `Arc<Mutex<HandshakeData>>` to an `Ssl`.
///
/// Initialised on first use by `get_ssl_ex_data_index`.
static SSL_EX_DATA_INDEX: OnceLock<openssl::ex_data::Index<Ssl, Arc<Mutex<HandshakeData>>>> =
    OnceLock::new();
44
45fn get_ssl_ex_data_index() -> openssl::ex_data::Index<Ssl, Arc<Mutex<HandshakeData>>> {
46 *SSL_EX_DATA_INDEX
47 .get_or_init(|| Ssl::new_ex_index().expect("Failed to create SSL ex_data index"))
48}
49
/// Marker type that implements `TlsDriver` on top of OpenSSL.
#[derive(Default)]
pub struct OpensslDriver;
53
/// A TLS stream over a tokio `TcpStream`, wrapping `tokio_openssl::SslStream`.
pub struct TlsStream(tokio_openssl::SslStream<TcpStream>);
55
56impl AsyncRead for TlsStream {
57 #[inline(always)]
58 fn poll_read(
59 mut self: Pin<&mut Self>,
60 cx: &mut std::task::Context<'_>,
61 buf: &mut ReadBuf<'_>,
62 ) -> std::task::Poll<std::io::Result<()>> {
63 Pin::new(&mut self.0).poll_read(cx, buf)
64 }
65}
66
67impl AsyncWrite for TlsStream {
68 #[inline(always)]
69 fn poll_write(
70 mut self: Pin<&mut Self>,
71 cx: &mut std::task::Context<'_>,
72 buf: &[u8],
73 ) -> std::task::Poll<std::io::Result<usize>> {
74 Pin::new(&mut self.0).poll_write(cx, buf)
75 }
76
77 #[inline(always)]
78 fn poll_write_vectored(
79 mut self: Pin<&mut Self>,
80 cx: &mut std::task::Context<'_>,
81 bufs: &[IoSlice<'_>],
82 ) -> std::task::Poll<std::io::Result<usize>> {
83 Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
84 }
85
86 #[inline(always)]
87 fn is_write_vectored(&self) -> bool {
88 self.0.is_write_vectored()
89 }
90
91 #[inline(always)]
92 fn poll_flush(
93 mut self: Pin<&mut Self>,
94 cx: &mut std::task::Context<'_>,
95 ) -> std::task::Poll<std::io::Result<()>> {
96 Pin::new(&mut self.0).poll_flush(cx)
97 }
98
99 #[inline(always)]
100 fn poll_shutdown(
101 mut self: Pin<&mut Self>,
102 cx: &mut std::task::Context<'_>,
103 ) -> std::task::Poll<std::io::Result<()>> {
104 let res = ready!(Pin::new(&mut self.0).poll_shutdown(cx));
105 if let Err(e) = &res {
106 if e.kind() == std::io::ErrorKind::NotConnected {
108 return Poll::Ready(Ok(()));
109 }
110
111 if let Some(ssl_err) = e
113 .get_ref()
114 .and_then(|e| e.downcast_ref::<openssl::ssl::Error>())
115 {
116 if ssl_err.code() == openssl::ssl::ErrorCode::SYSCALL {
117 return Poll::Ready(Ok(()));
118 }
119 }
120 }
121 Poll::Ready(res)
122 }
123}
124
/// Cache of the bundled webpki root certificates parsed into OpenSSL `X509` objects.
///
/// Filled on first use by `init_client`; bundled entries that fail to parse are skipped.
static WEBPKI_ROOTS: OnceLock<Vec<openssl::x509::X509>> = OnceLock::new();
127
128impl TlsDriver for OpensslDriver {
129 type Stream = TlsStream;
130 type ClientParams = openssl::ssl::Ssl;
131 type ServerParams = openssl::ssl::SslContext;
132 const DRIVER_NAME: &'static str = "openssl";
133
134 fn init_client(
135 params: &TlsParameters,
136 name: Option<ServerName>,
137 ) -> Result<Self::ClientParams, SslError> {
138 let TlsParameters {
139 server_cert_verify,
140 root_cert,
141 cert,
142 key,
143 crl,
144 min_protocol_version,
145 max_protocol_version,
146 alpn,
147 sni_override,
148 enable_keylog,
149 } = params;
150
151 let mut ssl = SslContextBuilder::new(SslMethod::tls_client())?;
153
154 ssl.clear_options(SslOptions::from_bits_retain(1 << 7));
156
157 match root_cert {
159 TlsCert::Custom(root) | TlsCert::SystemPlus(root) | TlsCert::WebpkiPlus(root) => {
160 for root in root {
161 let root = openssl::x509::X509::from_der(root.as_ref())?;
162 ssl.cert_store_mut().add_cert(root)?;
163 }
164 }
165 _ => {}
166 }
167
168 match root_cert {
169 TlsCert::Webpki | TlsCert::WebpkiPlus(_) => {
170 let webpki_roots = WEBPKI_ROOTS.get_or_init(|| {
171 let webpki_roots = webpki_root_certs::TLS_SERVER_ROOT_CERTS;
172 let mut roots = Vec::new();
173 for root in webpki_roots {
174 if let Ok(root) = openssl::x509::X509::from_der(root.as_ref()) {
176 roots.push(root);
177 }
178 }
179 roots
180 });
181 for root in webpki_roots {
182 ssl.cert_store_mut().add_cert(root.clone())?;
183 }
184 }
185 _ => {}
186 }
187
188 if matches!(root_cert, TlsCert::SystemPlus(_) | TlsCert::System) {
190 let probe = openssl_probe::probe();
192 ssl.load_verify_locations(probe.cert_file.as_deref(), probe.cert_dir.as_deref())?;
193 }
194
195 match server_cert_verify {
197 TlsServerCertVerify::Insecure => {
198 ssl.set_verify(SslVerifyMode::NONE);
199 }
200 TlsServerCertVerify::IgnoreHostname => {
201 ssl.set_verify(SslVerifyMode::PEER);
202 }
203 TlsServerCertVerify::VerifyFull => {
204 ssl.set_verify(SslVerifyMode::PEER);
205 if let Some(hostname) = sni_override {
206 ssl.verify_param_mut().set_host(hostname)?;
207 } else if let Some(ServerName::DnsName(hostname)) = &name {
208 ssl.verify_param_mut().set_host(hostname.as_ref())?;
209 } else if let Some(ServerName::IpAddress(ip)) = &name {
210 ssl.verify_param_mut().set_ip((*ip).into())?;
211 }
212 }
213 }
214
215 if !crl.is_empty() {
217 use foreign_types::ForeignTypeRef;
219 let ptr = ssl.cert_store_mut().as_ptr();
220
221 extern "C" {
222 pub fn X509_STORE_add_crl(
223 store: *mut openssl_sys::X509_STORE,
224 x: *mut openssl_sys::X509_CRL,
225 ) -> openssl_sys::c_int;
226 }
227
228 for crl in crl {
229 let crl = openssl::x509::X509Crl::from_der(crl.as_ref())?;
230 let crl_ptr = crl.as_ptr();
231 let res = unsafe { X509_STORE_add_crl(ptr, crl_ptr) };
232 if res != 1 {
233 return Err(std::io::Error::new(
234 std::io::ErrorKind::Other,
235 "Failed to add CRL to store",
236 )
237 .into());
238 }
239 }
240
241 ssl.verify_param_mut()
242 .set_flags(X509VerifyFlags::CRL_CHECK | X509VerifyFlags::CRL_CHECK_ALL)?;
243 ssl.cert_store_mut()
244 .set_flags(X509VerifyFlags::CRL_CHECK | X509VerifyFlags::CRL_CHECK_ALL)?;
245 }
246
247 if let (Some(cert), Some(key)) = (cert.as_ref(), key.as_ref()) {
249 let builder = openssl::x509::X509::from_der(cert.as_ref())?;
250 ssl.set_certificate(&builder)?;
251 let builder = openssl::pkey::PKey::private_key_from_der(key.secret_der())?;
252 ssl.set_private_key(&builder)?;
253 }
254
255 ssl.set_min_proto_version(min_protocol_version.map(|s| s.into()))?;
256 ssl.set_max_proto_version(max_protocol_version.map(|s| s.into()))?;
257
258 if *enable_keylog {
260 if let Ok(path) = std::env::var("SSLKEYLOGFILE") {
261 ssl.set_keylog_callback(move |_ssl, msg| {
262 let Ok(mut file) = std::fs::OpenOptions::new().append(true).open(&path) else {
263 return;
264 };
265 let _ = std::io::Write::write_all(&mut file, msg.as_bytes());
266 });
267 }
268 }
269
270 let mut ssl = openssl::ssl::Ssl::new(&ssl.build())?;
271 ssl.set_connect_state();
272
273 if let Some(hostname) = sni_override {
275 ssl.set_hostname(hostname)?;
276 } else if let Some(ServerName::DnsName(hostname)) = &name {
277 ssl.set_hostname(hostname.as_ref())?;
278 }
279
280 if !alpn.is_empty() {
281 ssl.set_alpn_protos(&alpn.as_bytes())?;
282 }
283
284 Ok(ssl)
285 }
286
287 fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError> {
288 let TlsServerParameters {
289 client_cert_verify,
290 min_protocol_version,
291 max_protocol_version,
292 server_certificate,
293 alpn: _alpn,
295 } = params;
296
297 let mut ssl = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls_server())?;
298 let cert = openssl::x509::X509::from_der(server_certificate.cert.as_ref())?;
299 let key = openssl::pkey::PKey::private_key_from_der(server_certificate.key.secret_der())?;
300 ssl.set_certificate(&cert)?;
301 ssl.set_private_key(&key)?;
302 ssl.set_min_proto_version(min_protocol_version.map(|s| s.into()))?;
303 ssl.set_max_proto_version(max_protocol_version.map(|s| s.into()))?;
304 match client_cert_verify {
305 TlsClientCertVerify::Ignore => ssl.set_verify(SslVerifyMode::NONE),
306 TlsClientCertVerify::Optional(root) => {
307 ssl.set_verify(SslVerifyMode::PEER);
308 for root in root {
309 let root = openssl::x509::X509::from_der(root.as_ref())?;
310 ssl.cert_store_mut().add_cert(root)?;
311 }
312 }
313 TlsClientCertVerify::Validate(root) => {
314 ssl.set_verify(SslVerifyMode::PEER | SslVerifyMode::FAIL_IF_NO_PEER_CERT);
315 for root in root {
316 let root = openssl::x509::X509::from_der(root.as_ref())?;
317 ssl.cert_store_mut().add_cert(root)?;
318 }
319 }
320 }
321 create_alpn_callback(&mut ssl);
322
323 Ok(ssl.build().into_context())
324 }
325
326 async fn upgrade_client<S: Stream>(
327 params: Self::ClientParams,
328 stream: S,
329 ) -> Result<(Self::Stream, TlsHandshake), SslError> {
330 let stream = stream
331 .downcast::<TokioStream>()
332 .map_err(|_| crate::SslError::SslUnsupportedByClient)?;
333 let TokioStream::Tcp(stream) = stream else {
334 return Err(crate::SslError::SslUnsupportedByClient);
335 };
336
337 let mut stream = tokio_openssl::SslStream::new(params, stream)?;
338 let res = Pin::new(&mut stream).do_handshake().await;
339 if res.is_err() && stream.ssl().verify_result() != X509VerifyResult::OK {
340 return Err(SslError::OpenSslErrorVerify(stream.ssl().verify_result()));
341 }
342
343 let alpn = stream
344 .ssl()
345 .selected_alpn_protocol()
346 .map(|p| Cow::Owned(p.to_vec()));
347
348 res.map_err(SslError::OpenSslError)?;
349 let cert = stream
350 .ssl()
351 .peer_certificate()
352 .map(|cert| cert.to_der())
353 .transpose()?;
354 let cert = cert.map(CertificateDer::from);
355 let version = match stream.ssl().version2() {
356 Some(openssl::ssl::SslVersion::TLS1) => Some(SslVersion::Tls1),
357 Some(openssl::ssl::SslVersion::TLS1_1) => Some(SslVersion::Tls1_1),
358 Some(openssl::ssl::SslVersion::TLS1_2) => Some(SslVersion::Tls1_2),
359 Some(openssl::ssl::SslVersion::TLS1_3) => Some(SslVersion::Tls1_3),
360 _ => None,
361 };
362 Ok((
363 TlsStream(stream),
364 TlsHandshake {
365 alpn,
366 sni: None,
367 cert,
368 version,
369 },
370 ))
371 }
372
373 async fn upgrade_server<S: Stream>(
374 params: TlsServerParameterProvider,
375 stream: S,
376 ) -> Result<(Self::Stream, TlsHandshake), SslError> {
377 let stream = stream
378 .downcast::<RewindStream<TokioStream>>()
379 .map_err(|_| crate::SslError::SslUnsupportedByClient)?;
380 let (stream, buffer) = stream.into_inner();
381 if !buffer.is_empty() {
382 return Err(crate::SslError::SslUnsupportedByClient);
384 }
385 let TokioStream::Tcp(stream) = stream else {
386 return Err(crate::SslError::SslUnsupportedByClient);
387 };
388
389 let handshake = Arc::new(Mutex::new(HandshakeData::default()));
390
391 let mut ssl = SslContextBuilder::new(SslMethod::tls_server())?;
392 create_alpn_callback(&mut ssl);
393 create_sni_callback(&mut ssl, params);
394 ssl.set_client_hello_callback(move |ssl_ref, _alert| {
395 ssl_ref.set_verify(SslVerifyMode::PEER);
401 Ok(ClientHelloResponse::SUCCESS)
402 });
403
404 let mut ssl = Ssl::new(&ssl.build())?;
405 ssl.set_accept_state();
406 ssl.set_ex_data(get_ssl_ex_data_index(), handshake.clone());
407
408 let mut stream = tokio_openssl::SslStream::new(ssl, stream)?;
409 let res = Pin::new(&mut stream).do_handshake().await;
410 res.map_err(SslError::OpenSslError)?;
411
412 let mut handshake = std::mem::take(&mut handshake.lock().unwrap().handshake);
413 let cert = stream
414 .ssl()
415 .peer_certificate()
416 .and_then(|c| c.to_der().ok());
417 if let Some(cert) = cert {
418 handshake.cert = Some(CertificateDer::from(cert));
419 }
420 let version = match stream.ssl().version2() {
421 Some(openssl::ssl::SslVersion::TLS1) => Some(SslVersion::Tls1),
422 Some(openssl::ssl::SslVersion::TLS1_1) => Some(SslVersion::Tls1_1),
423 Some(openssl::ssl::SslVersion::TLS1_2) => Some(SslVersion::Tls1_2),
424 Some(openssl::ssl::SslVersion::TLS1_3) => Some(SslVersion::Tls1_3),
425 _ => None,
426 };
427 handshake.version = version;
428 Ok((TlsStream(stream), handshake))
429 }
430
431 fn unclean_shutdown(_this: Self::Stream) -> Result<(), Self::Stream> {
432 Ok(())
434 }
435}
436
/// Picks the ALPN protocol to use from two length-prefixed protocol lists.
///
/// Both `server` and `client` are sequences of entries, each one length byte followed by
/// that many bytes. Server entries are tried in order, so the server's preference wins;
/// the match is returned as a slice of `client`.
///
/// Returns `None` when there is no common protocol, or when a malformed entry (a length
/// running past the end of its list) is reached before a match is found.
fn ssl_select_next_proto<'b>(server: &[u8], client: &'b [u8]) -> Option<&'b [u8]> {
    /// Splits a non-empty list into its first entry and the remainder, or `None` if the
    /// list is empty or the entry's length exceeds the bytes available.
    fn split_entry(list: &[u8]) -> Option<(&[u8], &[u8])> {
        let (&len, tail) = list.split_first()?;
        let len = usize::from(len);
        (tail.len() >= len).then(|| tail.split_at(len))
    }

    let mut server_rest = server;
    while !server_rest.is_empty() {
        let (wanted, server_tail) = split_entry(server_rest)?;

        let mut client_rest = client;
        while !client_rest.is_empty() {
            let (offered, client_tail) = split_entry(client_rest)?;
            if offered == wanted {
                return Some(offered);
            }
            client_rest = client_tail;
        }

        server_rest = server_tail;
    }
    None
}
455
456fn create_alpn_callback(ssl: &mut SslContextBuilder) {
458 ssl.set_alpn_select_callback(|ssl_ref, alpn| {
459 let Some(mut handshake) = HandshakeData::from_ssl(ssl_ref) else {
460 return Err(AlpnError::ALERT_FATAL);
461 };
462
463 if let Some(server) = handshake.server_alpn.take() {
464 eprintln!("server: {:?} alpn: {:?}", server, alpn);
465 let Some(selected) = ssl_select_next_proto(&server, alpn) else {
466 return Err(AlpnError::NOACK);
467 };
468 handshake.handshake.alpn = Some(Cow::Owned(selected.to_vec()));
469
470 Ok(selected)
471 } else {
472 Err(AlpnError::NOACK)
473 }
474 })
475}
476
477fn create_sni_callback(ssl: &mut SslContextBuilder, params: TlsServerParameterProvider) {
479 ssl.set_servername_callback(move |ssl_ref, _alert| {
480 let Some(mut handshake) = HandshakeData::from_ssl(ssl_ref) else {
481 return Ok(());
482 };
483
484 if let Some(servername) = ssl_ref.servername_raw(NameType::HOST_NAME) {
485 handshake.handshake.sni =
486 Some(Cow::Owned(String::from_utf8_lossy(servername).to_string()));
487 }
488
489 let params = params.lookup(None);
490 if !params.alpn.is_empty() {
491 handshake.server_alpn = Some(params.alpn.as_bytes().to_vec());
492 }
493 drop(handshake);
494
495 let Ok(ssl) = OpensslDriver::init_server(¶ms) else {
496 return Err(SniError::ALERT_FATAL);
497 };
498 let Ok(_) = ssl_ref.set_ssl_context(&ssl) else {
499 return Err(SniError::ALERT_FATAL);
500 };
501 Ok(())
502 });
503}
504
505impl From<SslVersion> for openssl::ssl::SslVersion {
506 fn from(val: SslVersion) -> Self {
507 match val {
508 SslVersion::Tls1 => openssl::ssl::SslVersion::TLS1,
509 SslVersion::Tls1_1 => openssl::ssl::SslVersion::TLS1_1,
510 SslVersion::Tls1_2 => openssl::ssl::SslVersion::TLS1_2,
511 SslVersion::Tls1_3 => openssl::ssl::SslVersion::TLS1_3,
512 }
513 }
514}
#[cfg(test)]
mod tests {
    use super::*;

    /// The only protocol the client offers is selected even though the server prefers another.
    #[test]
    fn test_ssl_select_next_proto() {
        let picked = ssl_select_next_proto(b"\x02h2\x08http/1.1", b"\x08http/1.1");
        assert_eq!(picked, Some(&b"http/1.1"[..]));
    }

    /// Two empty lists have nothing in common.
    #[test]
    fn test_ssl_select_next_proto_empty() {
        assert_eq!(ssl_select_next_proto(b"", b""), None);
    }

    /// A server entry whose length byte exceeds the available bytes yields no selection.
    #[test]
    fn test_ssl_select_next_proto_invalid_length() {
        assert_eq!(ssl_select_next_proto(b"\x08h2", b"\x08http/1.1"), None);
    }

    /// A server list starting with a zero length byte yields no selection.
    #[test]
    fn test_ssl_select_next_proto_zero_length() {
        assert_eq!(ssl_select_next_proto(b"\x00h2", b"\x08http/1.1"), None);
    }

    /// A truncated trailing server entry yields no selection.
    #[test]
    fn test_ssl_select_next_proto_truncated() {
        assert_eq!(
            ssl_select_next_proto(b"\x02h2\x08http/1", b"\x08http/1.1"),
            None
        );
    }

    /// The maximum length byte with too little data yields no selection.
    #[test]
    fn test_ssl_select_next_proto_overflow() {
        assert_eq!(ssl_select_next_proto(b"\xFFh2", b"\x08http/1.1"), None);
    }

    /// Well-formed lists with no common protocol yield no selection.
    #[test]
    fn test_ssl_select_next_proto_no_match() {
        assert_eq!(ssl_select_next_proto(b"\x02h2", b"\x08http/1.1"), None);
    }

    /// The match may be the last of several server entries.
    #[test]
    fn test_ssl_select_next_proto_multiple_server() {
        let picked = ssl_select_next_proto(b"\x02h2\x06spdy/2\x08http/1.1", b"\x08http/1.1");
        assert_eq!(picked, Some(&b"http/1.1"[..]));
    }

    /// The match may be the last of several client entries.
    #[test]
    fn test_ssl_select_next_proto_multiple_client() {
        let picked = ssl_select_next_proto(b"\x08http/1.1", b"\x02h2\x06spdy/2\x08http/1.1");
        assert_eq!(picked, Some(&b"http/1.1"[..]));
    }

    /// With several common protocols, the server's first preference wins.
    #[test]
    fn test_ssl_select_next_proto_first_match() {
        let picked = ssl_select_next_proto(
            b"\x02h2\x06spdy/2\x08http/1.1",
            b"\x06spdy/2\x02h2\x08http/1.1",
        );
        assert_eq!(picked, Some(&b"h2"[..]));
    }

    /// Same as above with the preference orders swapped.
    #[test]
    fn test_ssl_select_next_proto_first_match_2() {
        let picked = ssl_select_next_proto(
            b"\x06spdy/2\x02h2\x08http/1.1",
            b"\x02h2\x06spdy/2\x08http/1.1",
        );
        assert_eq!(picked, Some(&b"spdy/2"[..]));
    }
}