rustls_tokio_stream/
stream.rs

1// Copyright 2018-2023 the Deno authors. All rights reserved. MIT license.
2
3use crate::adapter::clone_error;
4use crate::adapter::clone_result;
5use crate::adapter::read_acceptor;
6use crate::adapter::rustls_to_io_error;
7use crate::adapter::write_acceptor_alert;
8use crate::connection_stream::ConnectionStream;
9use crate::handshake::handshake_task;
10use crate::handshake::HandshakeResult;
11use crate::trace;
12use crate::TestOptions;
13use derive_io::AsyncRead;
14use derive_io::AsyncWrite;
15use futures::task::AtomicWaker;
16use futures::FutureExt;
17use rustls::server::Acceptor;
18use rustls::server::ClientHello;
19use rustls::version::TLS13;
20use rustls::ClientConnection;
21use rustls::Connection;
22use rustls::ServerConfig;
23use rustls::ServerConnection;
24use socket2::SockRef;
25use std::any::Any;
26use std::fmt::Debug;
27use std::future::poll_fn;
28use std::future::Future;
29use std::io;
30use std::io::ErrorKind;
31use std::io::Write;
32use std::num::NonZeroUsize;
33use std::pin::Pin;
34use std::sync::Arc;
35use std::sync::Mutex;
36
37use std::task::ready;
38use std::task::Context;
39use std::task::Poll;
40use std::task::Waker;
41use std::thread::sleep;
42use std::time::Duration;
43use tokio::io::AsyncRead;
44use tokio::io::AsyncWrite;
45use tokio::io::ReadBuf;
46use tokio::net::TcpStream;
47use tokio::spawn;
48use tokio::task::spawn_blocking;
49use tokio::task::JoinError;
50use tokio::task::JoinHandle;
51
/// The handshake may block read and write operations and requires us to track
/// which wakers are pending so that we can wake them to re-poll their
/// operations after the handshake completes.
///
/// Clones share the same state, so the handshake task can hold one copy while
/// the stream holds another.
#[derive(Clone)]
struct DeferredWakers {
  wakers: Arc<Mutex<DeferredWakersInner>>,
}

#[derive(Default)]
enum DeferredWakersInner {
  /// If the deferred wakers have been woken already, we don't want
  /// to re-register them and instead just wake them in place to
  /// prevent races.
  #[default]
  Woke,
  /// No deferred wakers have been woken yet: the most recently registered
  /// `(read, write)` wakers, if any.
  Pending(Option<Waker>, Option<Waker>),
}

impl DeferredWakers {
  /// Wakes the registered read and write wakers (read first) and switches to
  /// the `Woke` state. Calls after the first one are no-ops.
  ///
  /// The state is swapped out in its own statement, so the mutex guard is
  /// dropped before any waker runs. Wakers therefore never execute while the
  /// lock is held; a waker that re-entered this type would otherwise deadlock
  /// or panic on the non-reentrant mutex.
  pub fn wake(&self) {
    let previous = std::mem::take(&mut *self.wakers.lock().unwrap());
    if let DeferredWakersInner::Pending(read, write) = previous {
      read.into_iter().chain(write).for_each(Waker::wake);
    }
  }

  /// Register the read waker if pending, or wake immediately if the deferred wakers have been woken.
  ///
  /// An immediate wake happens only after the lock has been released.
  pub fn set_read_waker(&self, waker: &Waker) {
    let stored = match &mut *self.wakers.lock().unwrap() {
      DeferredWakersInner::Pending(read, _write) => {
        *read = Some(waker.clone());
        true
      }
      DeferredWakersInner::Woke => false,
    };
    // The guard was a temporary of the `let` above, so the lock is free here.
    if !stored {
      waker.wake_by_ref();
    }
  }

  /// Register the write waker if pending, or wake immediately if the deferred wakers have been woken.
  ///
  /// An immediate wake happens only after the lock has been released.
  pub fn set_write_waker(&self, waker: &Waker) {
    let stored = match &mut *self.wakers.lock().unwrap() {
      DeferredWakersInner::Pending(_read, write) => {
        *write = Some(waker.clone());
        true
      }
      DeferredWakersInner::Woke => false,
    };
    // The guard was a temporary of the `let` above, so the lock is free here.
    if !stored {
      waker.wake_by_ref();
    }
  }
}

impl Default for DeferredWakers {
  /// Starts in the `Pending` state with no wakers registered.
  fn default() -> Self {
    Self {
      wakers: Arc::new(Mutex::new(DeferredWakersInner::Pending(None, None))),
    }
  }
}
114
115#[derive(Default)]
116struct HandshakeWatch {
117  handshake: Mutex<Option<io::Result<TlsHandshake>>>,
118  rx_waker: AtomicWaker,
119  tx_waker: AtomicWaker,
120}
121
122#[allow(clippy::large_enum_variant)]
123enum TlsStreamState<S: UnderlyingStream> {
124  /// If we are handshaking, writes are buffered and reads block.
125  // TODO(mmastrac): We should be buffered in the Connection, not the Vec, as this results in a double-copy.
126  Handshaking {
127    handle: JoinHandle<io::Result<HandshakeResult<S>>>,
128    wakers: DeferredWakers,
129    write_buf: Vec<u8>,
130    underlying: Arc<S>,
131  },
132  /// The connection is open.
133  Open(ConnectionStream<S>),
134  /// The connection is closed.
135  Closed,
136  /// The connection is closed because of an error.
137  ClosedError(io::Error),
138}
139
140pub type ServerConfigProvider = Arc<
141  dyn Fn(
142      ClientHello<'_>,
143    ) -> Pin<
144      Box<dyn Future<Output = Result<Arc<ServerConfig>, io::Error>> + Send>,
145    > + Send
146    + Sync,
147>;
148
/// A non-blocking, readiness-based transport that a [`TlsStream`] runs on top of.
/// Implementations for Tokio's TCP and Unix sockets are provided below.
///
/// Every operation takes `&self`, so one transport can be shared through an
/// `Arc` by the handshake task and the stream.
pub trait UnderlyingStream: Debug + Send + Sync + Sized + 'static {
  /// The blocking `std` counterpart of this transport.
  type StdType: Send;

  fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
  fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
  fn try_read(&self, buf: &mut [u8]) -> io::Result<usize>;
  fn try_write(&self, buf: &[u8]) -> io::Result<usize>;
  fn readable(&self) -> impl Future<Output = io::Result<()>> + Send;
  fn writable(&self) -> impl Future<Output = io::Result<()>> + Send;

  /// Shuts down the read half, the write half, or both.
  fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()>;

  /// Converts into the `std` transport if supported. The default returns `None`.
  fn into_std(self) -> Option<std::io::Result<Self::StdType>> {
    None
  }

  /// Returns `Ok` with the value as `S` if `Self` is `S`. Otherwise returns
  /// `self` unchanged in `Err`.
  fn downcast<S: UnderlyingStream>(self) -> Result<S, Self> {
    // `Any` only works through references. Park `self` in an `Option`, view
    // that as `Option<S>`, and empty it through the reference on success.
    let mut slot = Some(self);
    let erased: &mut dyn Any = &mut slot;
    match erased.downcast_mut::<Option<S>>() {
      Some(same_type) => Ok(same_type.take().unwrap()),
      None => Err(slot.take().unwrap()),
    }
  }
}
174
175impl UnderlyingStream for TcpStream {
176  type StdType = std::net::TcpStream;
177  #[inline(always)]
178  fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
179    self.poll_read_ready(cx)
180  }
181  #[inline(always)]
182  fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
183    self.poll_write_ready(cx)
184  }
185  #[inline(always)]
186  fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
187    self.try_read(buf)
188  }
189  #[inline(always)]
190  fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
191    self.try_write(buf)
192  }
193  #[inline(always)]
194  fn readable(&self) -> impl Future<Output = io::Result<()>> + Send {
195    self.readable()
196  }
197  #[inline(always)]
198  fn writable(&self) -> impl Future<Output = io::Result<()>> + Send {
199    self.writable()
200  }
201  #[inline(always)]
202  fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
203    SockRef::from(&self).shutdown(how)
204  }
205  #[inline(always)]
206  fn into_std(self) -> Option<std::io::Result<std::net::TcpStream>> {
207    Some(self.into_std())
208  }
209}
210
211#[cfg(unix)]
212impl UnderlyingStream for tokio::net::UnixStream {
213  type StdType = std::os::unix::net::UnixStream;
214  #[inline(always)]
215  fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
216    self.poll_read_ready(cx)
217  }
218  #[inline(always)]
219  fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
220    self.poll_write_ready(cx)
221  }
222  #[inline(always)]
223  fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
224    self.try_read(buf)
225  }
226  #[inline(always)]
227  fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
228    self.try_write(buf)
229  }
230  #[inline(always)]
231  fn readable(&self) -> impl Future<Output = io::Result<()>> + Send {
232    self.readable()
233  }
234  #[inline(always)]
235  fn writable(&self) -> impl Future<Output = io::Result<()>> + Send {
236    self.writable()
237  }
238  #[inline(always)]
239  fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
240    SockRef::from(&self).shutdown(how)
241  }
242  #[inline(always)]
243  fn into_std(self) -> Option<std::io::Result<std::os::unix::net::UnixStream>> {
244    Some(self.into_std())
245  }
246}
247
248/// An `async` stream that wraps a `rustls` connection and a TCP socket.
249pub struct TlsStream<S: UnderlyingStream> {
250  state: TlsStreamState<S>,
251
252  handshake: Arc<HandshakeWatch>,
253  buffer_size: Option<NonZeroUsize>,
254}
255
256impl<S: UnderlyingStream> Debug for TlsStream<S> {
257  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258    match &self.state {
259      TlsStreamState::Handshaking { .. } => {
260        f.write_str("TlsStream { Handshaking }")
261      }
262      TlsStreamState::Open(..) => f.write_fmt(format_args!(
263        "TlsStream {{ Open, handshake: {:?} }}",
264        self.handshake.handshake.lock().unwrap()
265      )),
266      TlsStreamState::Closed => f.write_str("TlsStream { Closed }"),
267      TlsStreamState::ClosedError(err) => {
268        f.write_fmt(format_args!("TlsStream {{ Closed, error: {:?} }}", err))
269      }
270    }
271  }
272}
273
274/// The handshake results from a TLS connection.
275#[derive(Clone, Debug)]
276pub struct TlsHandshake {
277  pub alpn: Option<Vec<u8>>,
278  pub sni: Option<String>,
279  /// For client-to-server connections, will always return true. For server-to-client connections, returns
280  /// true if the client provided a valid certificate.
281  pub has_peer_certificates: bool,
282  /// The peer certificates from the TLS handshake, if available.
283  pub peer_certificates:
284    Option<Vec<rustls::pki_types::CertificateDer<'static>>>,
285}
286
287impl TlsStream<TcpStream> {
288  pub fn linger(&self) -> Result<Option<Duration>, io::Error> {
289    match &self.state {
290      TlsStreamState::Open(stm) => stm.underlying_stream().linger(),
291      TlsStreamState::Handshaking {
292        underlying: tcp, ..
293      } => tcp.linger(),
294      TlsStreamState::Closed | TlsStreamState::ClosedError(_) => {
295        Err(std::io::ErrorKind::NotConnected.into())
296      }
297    }
298  }
299
300  pub fn set_linger(&self, dur: Option<Duration>) -> Result<(), io::Error> {
301    match &self.state {
302      TlsStreamState::Open(stm) => stm.underlying_stream().set_linger(dur),
303      TlsStreamState::Handshaking {
304        underlying: tcp, ..
305      } => tcp.set_linger(dur),
306      TlsStreamState::Closed | TlsStreamState::ClosedError(_) => {
307        Err(std::io::ErrorKind::NotConnected.into())
308      }
309    }
310  }
311
312  /// Returns the peer address of this socket.
313  pub fn peer_addr(&self) -> Result<std::net::SocketAddr, io::Error> {
314    match &self.state {
315      TlsStreamState::Open(stm) => stm.underlying_stream().peer_addr(),
316      TlsStreamState::Handshaking {
317        underlying: tcp, ..
318      } => tcp.peer_addr(),
319      TlsStreamState::Closed | TlsStreamState::ClosedError(_) => {
320        Err(std::io::ErrorKind::NotConnected.into())
321      }
322    }
323  }
324
325  /// Returns the local address of this socket.
326  pub fn local_addr(&self) -> Result<std::net::SocketAddr, io::Error> {
327    match &self.state {
328      TlsStreamState::Open(stm) => stm.underlying_stream().local_addr(),
329      TlsStreamState::Handshaking {
330        underlying: tcp, ..
331      } => tcp.local_addr(),
332      TlsStreamState::Closed | TlsStreamState::ClosedError(_) => {
333        Err(std::io::ErrorKind::NotConnected.into())
334      }
335    }
336  }
337}
338
339#[cfg(unix)]
340impl TlsStream<tokio::net::UnixStream> {
341  pub fn peer_addr(&self) -> Result<tokio::net::unix::SocketAddr, io::Error> {
342    match &self.state {
343      TlsStreamState::Open(stm) => stm.underlying_stream().peer_addr(),
344      TlsStreamState::Handshaking {
345        underlying: tcp, ..
346      } => tcp.peer_addr(),
347      TlsStreamState::Closed | TlsStreamState::ClosedError(_) => {
348        Err(std::io::ErrorKind::NotConnected.into())
349      }
350    }
351  }
352
353  pub fn local_addr(&self) -> Result<tokio::net::unix::SocketAddr, io::Error> {
354    match &self.state {
355      TlsStreamState::Open(stm) => stm.underlying_stream().local_addr(),
356      TlsStreamState::Handshaking {
357        underlying: tcp, ..
358      } => tcp.local_addr(),
359      TlsStreamState::Closed | TlsStreamState::ClosedError(_) => {
360        Err(std::io::ErrorKind::NotConnected.into())
361      }
362    }
363  }
364}
365
366impl<S: UnderlyingStream + 'static> TlsStream<S> {
367  fn new(
368    tcp: S,
369    mut tls: Connection,
370    buffer_size: Option<NonZeroUsize>,
371    test_options: TestOptions,
372  ) -> Self {
373    tls.set_buffer_limit(buffer_size.map(|s| s.get()));
374    let handshake = Arc::new(HandshakeWatch::default());
375    let wakers = DeferredWakers::default();
376    let wakers_clone = wakers.clone();
377    let tcp = Arc::new(tcp);
378    let tcp_handshake = tcp.clone();
379
380    let handshake_send = handshake.clone();
381    let handle = spawn(async move {
382      let res =
383        send_handshake(tcp_handshake, Ok(tls), test_options, handshake_send)
384          .await;
385
386      // We may have read/writes blocked on the handshake, so wake them all up
387      wakers_clone.wake();
388
389      res
390    });
391
392    Self {
393      state: TlsStreamState::Handshaking {
394        handle,
395        wakers,
396        write_buf: vec![],
397        underlying: tcp,
398      },
399      handshake,
400      buffer_size,
401    }
402  }
403
404  async fn accept(
405    mut acceptor: Acceptor,
406    tcp_handshake: &S,
407    server_config_provider: ServerConfigProvider,
408  ) -> Result<ServerConnection, io::Error> {
409    loop {
410      tcp_handshake.readable().await?;
411      // Stop if connection was closed by client
412      if read_acceptor(tcp_handshake, &mut acceptor)? < 1 {
413        return Err(io::ErrorKind::ConnectionReset.into());
414      }
415
416      let accepted = match acceptor.accept() {
417        Ok(Some(accepted)) => accepted,
418        Ok(None) => continue,
419        Err((e, alert)) => {
420          tcp_handshake.writable().await?;
421          write_acceptor_alert(tcp_handshake, alert)?;
422          return Err(rustls_to_io_error(e));
423        }
424      };
425
426      let config = match server_config_provider(accepted.client_hello()).await {
427        Ok(config) => config,
428        Err(err) => {
429          // This is a bad case. The provider was supposed to give us a config, but instead it failed.
430          //
431          // There's no easy way to reject an acceptor, and we only have an Arc for the stream so we can't close
432          // it. Instead we send a fatal alert manually which is effectively going to close the stream.
433          //
434          // Wireshark packet decode:
435          //     TLSv1.2 Record Layer: Alert (Level: Fatal, Description: Close Notify)
436          //         Content Type: Alert (21)
437          //         Version: TLS 1.2 (0x0303)
438          //         Length: 2
439          //         Alert Message
440          //             Level: Fatal (2)
441          //             Description: Close Notify (0)
442          const FATAL_ALERT: &[u8] = b"\x15\x03\x03\x00\x02\x02\x00";
443          for c in FATAL_ALERT {
444            tcp_handshake.writable().await?;
445            tcp_handshake.try_write(&[*c])?;
446          }
447          return Err(err);
448        }
449      };
450      match accepted.into_connection(config) {
451        Ok(tls) => {
452          return Ok(tls);
453        }
454        Err((e, alert)) => {
455          tcp_handshake.writable().await?;
456          write_acceptor_alert(tcp_handshake, alert)?;
457          return Err(rustls_to_io_error(e));
458        }
459      }
460    }
461  }
462
463  fn new_server_acceptor(
464    acceptor: Acceptor,
465    tcp: S,
466    server_config_provider: ServerConfigProvider,
467    buffer_size: Option<NonZeroUsize>,
468    test_options: TestOptions,
469  ) -> Self {
470    let handshake = Arc::new(HandshakeWatch::default());
471    let wakers = DeferredWakers::default();
472    let wakers_clone = wakers.clone();
473    let tcp = Arc::new(tcp);
474    let tcp_handshake = tcp.clone();
475
476    let handshake_send = handshake.clone();
477
478    let handle = spawn(async move {
479      let tls =
480        Self::accept(acceptor, &tcp_handshake, server_config_provider).await;
481      let res = send_handshake(
482        tcp_handshake,
483        tls.map(rustls::Connection::Server),
484        test_options,
485        handshake_send,
486      )
487      .await;
488
489      // We may have read/writes blocked on the handshake, so wake them all up
490      wakers_clone.wake();
491
492      res
493    });
494
495    Self {
496      state: TlsStreamState::Handshaking {
497        handle,
498        wakers,
499        write_buf: vec![],
500        underlying: tcp,
501      },
502      handshake,
503      buffer_size,
504    }
505  }
506
507  pub fn new_client_side(
508    tcp: S,
509    tls: ClientConnection,
510    buffer_size: Option<NonZeroUsize>,
511  ) -> Self {
512    Self::new(
513      tcp,
514      Connection::Client(tls),
515      buffer_size,
516      TestOptions::default(),
517    )
518  }
519
520  #[cfg(test)]
521  pub(crate) fn new_client_side_test_options(
522    tcp: S,
523    tls_config: Arc<rustls::ClientConfig>,
524    server_name: rustls::pki_types::ServerName<'_>,
525    buffer_size: Option<NonZeroUsize>,
526    test_options: TestOptions,
527  ) -> Self {
528    let tls =
529      ClientConnection::new(tls_config, server_name.to_owned()).unwrap();
530    Self::new(tcp, Connection::Client(tls), buffer_size, test_options)
531  }
532
533  pub fn new_client_side_from(
534    tcp: S,
535    connection: ClientConnection,
536    buffer_size: Option<NonZeroUsize>,
537  ) -> Self {
538    Self::new(
539      tcp,
540      Connection::Client(connection),
541      buffer_size,
542      TestOptions::default(),
543    )
544  }
545
546  #[cfg(test)]
547  pub(crate) fn new_server_side_test_options(
548    tcp: S,
549    tls_config: Arc<ServerConfig>,
550    buffer_size: Option<NonZeroUsize>,
551    test_options: TestOptions,
552  ) -> Self {
553    let tls = ServerConnection::new(tls_config).unwrap();
554    Self::new(tcp, Connection::Server(tls), buffer_size, test_options)
555  }
556
557  pub fn new_server_side(
558    tcp: S,
559    tls_config: Arc<ServerConfig>,
560    buffer_size: Option<NonZeroUsize>,
561  ) -> Self {
562    let tls = ServerConnection::new(tls_config).unwrap();
563    Self::new(
564      tcp,
565      Connection::Server(tls),
566      buffer_size,
567      TestOptions::default(),
568    )
569  }
570
571  /// Create a server-side TLS connection that provides the [`ServerConfig`] dynamically
572  /// based on the [`ClientHello`] message. This may be used to provide a different server
573  /// certificate or ALPN configuration depending on the requested hostname.
574  pub fn new_server_side_acceptor(
575    tcp: S,
576    server_config_provider: ServerConfigProvider,
577    buffer_size: Option<NonZeroUsize>,
578  ) -> Self {
579    Self::new_server_acceptor(
580      Acceptor::default(),
581      tcp,
582      server_config_provider,
583      buffer_size,
584      TestOptions::default(),
585    )
586  }
587
588  /// Create a server-side TLS connection that provides the [`ServerConfig`] dynamically
589  /// based on the [`ClientHello`] message. This may be used to provide a different server
590  /// certificate or ALPN configuration depending on the requested hostname.
591  ///
592  /// This allows the caller to provide an [`Acceptor`] which may be non-default in some
593  /// way, perhaps stuffed with prefix bytes or a full handshake to emulate.
594  pub fn new_server_side_from_acceptor(
595    acceptor: Acceptor,
596    tcp: S,
597    server_config_provider: ServerConfigProvider,
598    buffer_size: Option<NonZeroUsize>,
599  ) -> Self {
600    Self::new_server_acceptor(
601      acceptor,
602      tcp,
603      server_config_provider,
604      buffer_size,
605      TestOptions::default(),
606    )
607  }
608
609  pub fn new_server_side_from(
610    tcp: S,
611    connection: ServerConnection,
612    buffer_size: Option<NonZeroUsize>,
613  ) -> Self {
614    Self::new(
615      tcp,
616      Connection::Server(connection),
617      buffer_size,
618      TestOptions::default(),
619    )
620  }
621
622  /// Attempt to retrieve the inner stream and connection.
623  pub fn try_into_inner(mut self) -> Result<(S, Connection), Self> {
624    match self.state {
625      TlsStreamState::Open(_) => {
626        let TlsStreamState::Open(stm) =
627          std::mem::replace(&mut self.state, TlsStreamState::Closed)
628        else {
629          unreachable!()
630        };
631        Ok(stm.into_inner())
632      }
633      _ => Err(self),
634    }
635  }
636
637  pub fn into_split(self) -> (TlsStreamRead<S>, TlsStreamWrite<S>) {
638    let handshake1 = self.handshake.clone();
639    let handshake2 = self.handshake.clone();
640    let tcp = match &self.state {
641      TlsStreamState::Handshaking {
642        underlying: tcp, ..
643      } => Some(tcp.clone()),
644      TlsStreamState::Open(conn) => Some(conn.underlying_stream().clone()),
645      _ => None,
646    };
647    let (r, w) = tokio::io::split(self);
648    let read = TlsStreamRead {
649      r,
650      handshake: handshake1,
651      tcp: tcp.clone(),
652    };
653    let write = TlsStreamWrite {
654      w,
655      handshake: handshake2,
656      tcp,
657    };
658    (read, write)
659  }
660
661  /// If the stream is open, returns the underlying rustls connection.
662  pub fn connection(&self) -> Option<&rustls::Connection> {
663    match &self.state {
664      TlsStreamState::Open(stm) => Some(stm.connection()),
665      _ => None,
666    }
667  }
668
669  pub async fn into_inner(mut self) -> io::Result<(S, Connection)> {
670    poll_fn(|cx| self.poll_pending_handshake(cx)).await?;
671    match std::mem::replace(&mut self.state, TlsStreamState::Closed) {
672      TlsStreamState::Open(stm) => Ok(stm.into_inner()),
673      TlsStreamState::Closed => Err(ErrorKind::NotConnected.into()),
674      TlsStreamState::ClosedError(err) => Err(err),
675      TlsStreamState::Handshaking { .. } => unreachable!(),
676    }
677  }
678
679  pub fn poll_handshake(
680    &mut self,
681    cx: &mut Context,
682  ) -> Poll<io::Result<TlsHandshake>> {
683    // Transition to the open state if necessary
684    ready!(self.poll_pending_handshake(cx)?);
685
686    // TODO(mmastrac): Handshake shouldn't need to be cloned
687    match &*self.handshake.handshake.lock().unwrap() {
688      None => {
689        // Register both wakers just in case we get split
690        self.handshake.rx_waker.register(cx.waker());
691        self.handshake.tx_waker.register(cx.waker());
692        Poll::Pending
693      }
694      Some(handshake) => Poll::Ready(clone_result(handshake)),
695    }
696  }
697
698  pub async fn handshake(&mut self) -> io::Result<TlsHandshake> {
699    poll_fn(|cx| self.poll_handshake(cx)).await
700  }
701
702  /// Try to get the handshake, if one exists.
703  pub fn try_handshake(&self) -> io::Result<Option<TlsHandshake>> {
704    match &*self.handshake.handshake.lock().unwrap() {
705      None => Ok(None),
706      Some(r) => clone_result(r).map(Some),
707    }
708  }
709
  /// Consumes the `Handshaking` state after the handshake task's `JoinHandle`
  /// has resolved with `join_result`, and installs the next state.
  ///
  /// * Task panicked: state becomes `ClosedError(Other)` and the panic is
  ///   resumed on the calling task.
  /// * Handshake failed: state becomes `ClosedError` holding a copy of the
  ///   error, and the error is returned.
  /// * Handshake succeeded: a `ConnectionStream` is built, any bytes written
  ///   while handshaking are replayed into it, deferred wakers are woken, and
  ///   the state becomes `Open`.
  ///
  /// # Panics
  /// Panics if called when the state is not `Handshaking`, or if the task was
  /// cancelled instead of panicking.
  fn finalize_handshake(
    &mut self,
    join_result: Result<io::Result<HandshakeResult<S>>, JoinError>,
  ) -> io::Result<()> {
    trace!("finalize handshake");
    match std::mem::replace(&mut self.state, TlsStreamState::Closed) {
      TlsStreamState::Handshaking {
        wakers,
        write_buf: buf,
        ..
      } => {
        trace!("join={join_result:?}");
        match join_result {
          Err(err) => {
            // We polled the handle, so we need to update the state to something
            self.state = TlsStreamState::ClosedError(ErrorKind::Other.into());
            if err.is_panic() {
              // Resume the panic on the main task
              std::panic::resume_unwind(err.into_panic());
            } else {
              unreachable!("Task should not have been cancelled");
            }
          }
          Ok(Err(err)) => {
            self.state = TlsStreamState::ClosedError(clone_error(&err));
            Err(err)
          }
          Ok(Ok(result)) => {
            // TODO(mmastrac): if we split ConnectionStream we can remove this Arc and use reclaim2
            let (tcp, tls) = result.into_inner();
            let mut stm = ConnectionStream::new(tcp, tls);
            trace!("hs buf={}", buf.len());
            // We need to save all the data we wrote before the connection. The stream has an internal buffer
            // that matches our buffer, so it can accept it all.
            // NOTE(review): this relies on the connection's buffer limit being at least
            // `buffer_size` for every constructor path — confirm.
            stm.write_buf_fully(&buf);

            // Usually a no-op: the handshake task already woke these.
            wakers.wake();
            self.state = TlsStreamState::Open(stm);
            Ok(())
          }
        }
      }
      // Only called while handshaking.
      _ => unreachable!(),
    }
  }
755
756  /// If the handshake is complete, migrate from a pending handshake to the open state.
757  fn poll_pending_handshake(
758    &mut self,
759    cx: &mut Context<'_>,
760  ) -> Poll<io::Result<()>> {
761    match &mut self.state {
762      TlsStreamState::Handshaking { handle, .. } => {
763        let res = ready!(handle.poll_unpin(cx));
764        Poll::Ready(self.finalize_handshake(res))
765      }
766      _ => Poll::Ready(Ok(())),
767    }
768  }
769
  /// Shuts the connection down, optionally waiting for the handshake to complete.
  ///
  /// * `abort == false`: returns `Pending` until the handshake task has
  ///   finished, then shuts the open connection down.
  /// * `abort == true`: if the handshake is still running, the state is set
  ///   to `Closed` immediately and `Ok(())` is returned.
  ///
  /// Errors from `ConnectionStream::poll_shutdown` are discarded. A stream that
  /// is, or becomes, `ClosedError` reports a copy of its error.
  fn poll_shutdown_or_abort(
    mut self: Pin<&mut Self>,
    cx: &mut Context<'_>,
    abort: bool,
  ) -> Poll<io::Result<()>> {
    let res = if abort {
      // If we're still handshaking, abort
      // NOTE(review): replacing the state drops the JoinHandle, and dropping a
      // tokio JoinHandle detaches the task rather than cancelling it. The
      // handshake task (which holds its own Arc of the underlying stream)
      // keeps running to completion.
      match self.poll_pending_handshake(cx) {
        Poll::Pending => {
          self.state = TlsStreamState::Closed;
          return Poll::Ready(Ok(()));
        }
        Poll::Ready(res) => res,
      }
    } else {
      ready!(self.poll_pending_handshake(cx))
    };

    // The handshake failed: keep the original error (finalize_handshake had
    // stored a clone of it).
    if let Err(err) = res {
      self.state = TlsStreamState::ClosedError(err);
    }

    match &mut self.state {
      // Unreachable: poll_pending_handshake either moved out of Handshaking
      // (Ready) or we already returned above (Pending).
      TlsStreamState::Handshaking { .. } => {
        unreachable!()
      }
      TlsStreamState::Open(stm) => {
        let _res = ready!(stm.poll_shutdown(cx));
        // Because we're in shutdown, we will eat errors
        // TODO: error
        Poll::Ready(Ok(()))
      }
      // Closed: return ready.
      TlsStreamState::Closed => Poll::Ready(Ok(())),
      // Closed by an error: return a copy of that error.
      TlsStreamState::ClosedError(err) => Poll::Ready(Err(clone_error(err))),
    }
  }
810
811  pub async fn close(mut self) -> io::Result<()> {
812    trace!("closing {self:?}");
813    let state = std::mem::replace(&mut self.state, TlsStreamState::Closed);
814    match state {
815      TlsStreamState::Handshaking {
816        handle,
817        wakers,
818        write_buf: buf,
819        ..
820      } => {
821        wakers.wake();
822        match handle.await {
823          Ok(Ok(result)) => {
824            // TODO(mmastrac): if we split ConnectionStream we can remove this Arc and use reclaim2
825            let (tcp, tls) = result.into_inner();
826            let mut stm = ConnectionStream::new(tcp, tls);
827            poll_fn(|cx| stm.poll_write(cx, &buf)).await?;
828            poll_fn(|cx| stm.poll_shutdown(cx)).await?;
829            nonblocking_tcp_drop(stm);
830          }
831          Err(err) => {
832            if err.is_panic() {
833              // Resume the panic on the main task
834              std::panic::resume_unwind(err.into_panic());
835            } else {
836              unreachable!("Task should not have been cancelled");
837            }
838          }
839          Ok(Err(err)) => {
840            return Err(err);
841          }
842        }
843      }
844      TlsStreamState::Open(mut stm) => {
845        poll_fn(|cx| stm.poll_shutdown(cx)).await?;
846        nonblocking_tcp_drop(stm);
847      }
848      TlsStreamState::Closed | TlsStreamState::ClosedError(_) => {
849        // Nothing
850      }
851    }
852
853    Ok(())
854  }
855}
856
857impl<S: UnderlyingStream> TlsStream<S> {
858  /// If the stream is open or handshaking, returns the underlying TCP stream.
859  pub fn underlying_stream(&self) -> Option<&S> {
860    match &self.state {
861      TlsStreamState::Open(stm) => Some(stm.underlying_stream()),
862      TlsStreamState::Handshaking {
863        underlying: tcp, ..
864      } => Some(tcp),
865      _ => None,
866    }
867  }
868}
869
870async fn send_handshake<S: UnderlyingStream>(
871  tcp: Arc<S>,
872  tls: Result<Connection, io::Error>,
873  test_options: TestOptions,
874  handshake: Arc<HandshakeWatch>,
875) -> Result<HandshakeResult<S>, io::Error> {
876  let tls = match tls {
877    Ok(tls) => tls,
878    Err(err) => {
879      *handshake.handshake.lock().unwrap() = Some(Err(clone_error(&err)));
880      handshake.rx_waker.wake();
881      handshake.tx_waker.wake();
882      return Err(err);
883    }
884  };
885
886  #[cfg(test)]
887  if test_options.delay_handshake {
888    tokio::time::sleep(std::time::Duration::from_millis(1000)).await;
889  }
890  let res = handshake_task(tcp, tls, test_options).await;
891  match &res {
892    Ok(res) => {
893      let peer_certificates = res
894        .1
895        .peer_certificates()
896        .map(|certs| certs.iter().map(|cert| cert.clone()).collect());
897      let has_peer_certificates = peer_certificates
898        .as_ref()
899        .map(|c: &Vec<rustls::pki_types::CertificateDer<'static>>| {
900          !c.is_empty()
901        })
902        .unwrap_or_default();
903      let alpn = res.1.alpn_protocol().map(|v| v.to_owned());
904      let sni = match &res.1 {
905        Connection::Server(server) => {
906          server.server_name().map(|s| s.to_owned())
907        }
908        _ => None,
909      };
910      *handshake.handshake.lock().unwrap() = Some(Ok(TlsHandshake {
911        alpn,
912        sni,
913        has_peer_certificates,
914        peer_certificates,
915      }));
916    }
917    Err(err) => {
918      *handshake.handshake.lock().unwrap() = Some(Err(clone_error(err)));
919    }
920  }
921  handshake.rx_waker.wake();
922  handshake.tx_waker.wake();
923  res
924}
925
926/// TLS 1.3 may yield a state where the client has sent a large stream of data and closed
927/// the connection before receiving anything from the server. The server may attempt to
928/// send the final part of its handshake to the client's closed socket, which yields a TCP
929/// reset and then causes the server to throw away its received buffer. This holds a TCP
930/// socket open for a shortly extended period of time if we have a TLS 1.3 client.
931fn nonblocking_tcp_drop<S: UnderlyingStream>(stm: ConnectionStream<S>) {
932  // TODO(mmastrac) A better fix would be detecting that the server has sent at least one post-handshake packet,
933  // which would indicate that it's safe to close at this point.
934  let (inner, tls) = stm.into_inner();
935  if matches!(tls, Connection::Client(_))
936    && tls.protocol_version() == Some(TLS13.version)
937  {
938    if let Ok(tcp) = inner.downcast::<TcpStream>() {
939      if let Ok(tcp) = tcp.into_std() {
940        spawn_blocking(move || {
941          trace!("in drop tcp task");
942          sleep(Duration::from_millis(100));
943          drop(tcp);
944          trace!("done drop tcp task");
945        });
946      }
947    }
948  }
949}
950
951impl<S: UnderlyingStream> AsyncRead for TlsStream<S> {
952  fn poll_read(
953    mut self: Pin<&mut Self>,
954    cx: &mut Context<'_>,
955    buf: &mut ReadBuf<'_>,
956  ) -> Poll<io::Result<()>> {
957    loop {
958      break match &mut self.state {
959        TlsStreamState::Handshaking { handle, wakers, .. } => {
960          // If the handshake completed, we want to finalize it and then continue
961          if handle.is_finished() {
962            // This may return Pending if we've exhausted the co-op budget
963            let res = ready!(handle.poll_unpin(cx));
964            self.finalize_handshake(res)?;
965            continue;
966          }
967
968          // Handshake is still blocking us
969          wakers.set_read_waker(cx.waker());
970
971          Poll::Pending
972        }
973        TlsStreamState::Open(ref mut stm) => {
974          match std::task::ready!(stm.poll_read(cx, buf)) {
975            Ok(_n) => {
976              // TODO: n?
977              Poll::Ready(Ok(()))
978            }
979            Err(err) => Poll::Ready(Err(err)),
980          }
981        }
982        TlsStreamState::Closed => Poll::Ready(Ok(())),
983        TlsStreamState::ClosedError(err) => Poll::Ready(Err(clone_error(err))),
984      };
985    }
986  }
987}
988
989impl<S: UnderlyingStream> AsyncWrite for TlsStream<S> {
990  fn poll_write(
991    mut self: Pin<&mut Self>,
992    cx: &mut Context<'_>,
993    buf: &[u8],
994  ) -> Poll<io::Result<usize>> {
995    // NOTE: Changes to this method may need to be reflected in `poll_write_vectored`
996    let buffer_size = self.buffer_size;
997    loop {
998      break match &mut self.state {
999        TlsStreamState::Handshaking {
1000          handle,
1001          wakers,
1002          write_buf,
1003          ..
1004        } => {
1005          // If the handshake completed, we want to finalize it and then continue
1006          if handle.is_finished() {
1007            // This may return Pending if we've exhausted the co-op budget
1008            let res = ready!(handle.poll_unpin(cx));
1009            self.finalize_handshake(res)?;
1010            continue;
1011          }
1012
1013          if let Some(buffer_size) = buffer_size {
1014            let remaining = buffer_size.get() - write_buf.len();
1015            if remaining == 0 {
1016              // No room to write, so store the waker for whenever the handshake is done
1017              wakers.set_write_waker(cx.waker());
1018              trace!("write limit");
1019              Poll::Pending
1020            } else {
1021              trace!("write buf");
1022              if buf.len() <= remaining {
1023                write_buf.extend_from_slice(buf);
1024                Poll::Ready(Ok(buf.len()))
1025              } else {
1026                write_buf.extend_from_slice(&buf[0..remaining]);
1027                Poll::Ready(Ok(remaining))
1028              }
1029            }
1030          } else {
1031            trace!("write buf");
1032            write_buf.extend_from_slice(buf);
1033            Poll::Ready(Ok(buf.len()))
1034          }
1035        }
1036        TlsStreamState::Open(ref mut stm) => stm.poll_write(cx, buf),
1037        TlsStreamState::Closed => {
1038          Poll::Ready(Err(ErrorKind::NotConnected.into()))
1039        }
1040        TlsStreamState::ClosedError(err) => Poll::Ready(Err(clone_error(err))),
1041      };
1042    }
1043  }
1044
1045  fn poll_write_vectored(
1046    mut self: Pin<&mut Self>,
1047    cx: &mut Context<'_>,
1048    bufs: &[std::io::IoSlice<'_>],
1049  ) -> Poll<Result<usize, io::Error>> {
1050    // NOTE: Changes to this method may need to be reflected in `poll_write`
1051    let buffer_size = self.buffer_size;
1052    loop {
1053      break match &mut self.state {
1054        TlsStreamState::Handshaking {
1055          handle,
1056          wakers,
1057          write_buf,
1058          ..
1059        } => {
1060          // If the handshake completed, we want to finalize it and then continue
1061          if handle.is_finished() {
1062            // This may return Pending if we've exhausted the co-op budget
1063            let res = ready!(handle.poll_unpin(cx));
1064            self.finalize_handshake(res)?;
1065            continue;
1066          }
1067          if let Some(buffer_size) = buffer_size {
1068            let mut remaining = buffer_size.get() - write_buf.len();
1069            if remaining == 0 {
1070              // No room to write, so store the waker for whenever the handshake is done
1071              wakers.set_write_waker(cx.waker());
1072              trace!("write limit");
1073              Poll::Pending
1074            } else {
1075              trace!("write buf");
1076              let mut wrote = 0;
1077              for buf in bufs {
1078                if buf.len() <= remaining {
1079                  write_buf.extend_from_slice(buf);
1080                  wrote += buf.len();
1081                  remaining -= buf.len();
1082                } else {
1083                  write_buf.extend_from_slice(&buf[0..remaining]);
1084                  wrote += remaining;
1085                  break;
1086                }
1087              }
1088
1089              Poll::Ready(Ok(wrote))
1090            }
1091          } else {
1092            trace!("write buf");
1093            Poll::Ready(Ok(write_buf.write_vectored(bufs).unwrap()))
1094          }
1095        }
1096        TlsStreamState::Open(ref mut stm) => stm.poll_write_vectored(cx, bufs),
1097        TlsStreamState::Closed => {
1098          Poll::Ready(Err(ErrorKind::NotConnected.into()))
1099        }
1100        TlsStreamState::ClosedError(err) => Poll::Ready(Err(clone_error(err))),
1101      };
1102    }
1103  }
1104
1105  fn poll_flush(
1106    mut self: Pin<&mut Self>,
1107    cx: &mut Context<'_>,
1108  ) -> Poll<io::Result<()>> {
1109    loop {
1110      break match &mut self.state {
1111        TlsStreamState::Handshaking { wakers, handle, .. } => {
1112          // If the handshake completed, we want to finalize it and then continue
1113          if handle.is_finished() {
1114            // This may return Pending if we've exhausted the co-op budget
1115            let res = ready!(handle.poll_unpin(cx));
1116            self.finalize_handshake(res)?;
1117            continue;
1118          }
1119
1120          wakers.set_write_waker(cx.waker());
1121          Poll::Pending
1122        }
1123        TlsStreamState::Open(stm) => stm.poll_flush(cx),
1124        TlsStreamState::Closed => {
1125          Poll::Ready(Err(ErrorKind::NotConnected.into()))
1126        }
1127        TlsStreamState::ClosedError(err) => Poll::Ready(Err(clone_error(err))),
1128      };
1129    }
1130  }
1131
1132  fn is_write_vectored(&self) -> bool {
1133    // While rustls supports vectored writes, they act more like buffered writes so
1134    // we should prefer upstream producers to pre-aggregate when possible.
1135    false
1136  }
1137
1138  fn poll_shutdown(
1139    self: Pin<&mut Self>,
1140    cx: &mut Context<'_>,
1141  ) -> Poll<Result<(), io::Error>> {
1142    self.poll_shutdown_or_abort(cx, false)
1143  }
1144}
1145
1146impl<S: UnderlyingStream> Drop for TlsStream<S> {
1147  fn drop(&mut self) {
1148    trace!("dropping {self:?}");
1149    let state = std::mem::replace(&mut self.state, TlsStreamState::Closed);
1150    match state {
1151      TlsStreamState::Handshaking {
1152        handle,
1153        write_buf,
1154        underlying: tcp,
1155        ..
1156      } => {
1157        spawn(async move {
1158          trace!("in drop task");
1159          match handle.await {
1160            Ok(Ok(result)) => {
1161              drop(tcp);
1162              // TODO(mmastrac): if we split ConnectionStream we can remove this Arc and use reclaim2
1163              let (tcp, tls) = result.into_inner();
1164              let mut stm = ConnectionStream::new(tcp, tls);
1165              stm.write_buf_fully(&write_buf);
1166              let res = poll_fn(|cx| stm.poll_shutdown(cx)).await;
1167              trace!("shutdown handshake {:?}", res);
1168              nonblocking_tcp_drop(stm);
1169            }
1170            x @ Err(_) => {
1171              trace!("{x:?}");
1172            }
1173            x @ Ok(Err(_)) => {
1174              trace!("{x:?}");
1175            }
1176          }
1177          trace!("done drop task");
1178        });
1179      }
1180      TlsStreamState::Open(mut stm) => {
1181        spawn(async move {
1182          trace!("in drop task");
1183          let res = poll_fn(|cx| stm.poll_shutdown(cx)).await;
1184          trace!("shutdown open {:?}", res);
1185          nonblocking_tcp_drop(stm);
1186          trace!("done drop task");
1187        });
1188      }
1189      TlsStreamState::Closed | TlsStreamState::ClosedError(_) => {
1190        // Nothing
1191      }
1192    }
1193  }
1194}
1195
/// An `async` read half of a stream that wraps a `rustls` connection and an
/// underlying transport `S` (for example a TCP socket).
#[derive(AsyncRead)]
pub struct TlsStreamRead<S: UnderlyingStream> {
  /// The read half of the split [`TlsStream`]; marked `#[read]` for the
  /// derived `AsyncRead` impl.
  #[read]
  r: tokio::io::ReadHalf<TlsStream<S>>,
  /// Handshake state shared with the stream; consulted by `poll_handshake`
  /// and `try_handshake`.
  handshake: Arc<HandshakeWatch>,
  /// The underlying transport, used by `peer_addr` / `local_addr`. `None`
  /// makes those calls fail with `NotConnected`.
  tcp: Option<Arc<S>>,
}
1204
1205impl TlsStreamRead<TcpStream> {
1206  /// Returns the peer address of this socket.
1207  pub fn peer_addr(&self) -> Result<std::net::SocketAddr, io::Error> {
1208    let Some(tcp) = &self.tcp else {
1209      return Err(std::io::ErrorKind::NotConnected.into());
1210    };
1211    tcp.peer_addr()
1212  }
1213
1214  /// Returns the local address of this socket.
1215  pub fn local_addr(&self) -> Result<std::net::SocketAddr, io::Error> {
1216    let Some(tcp) = &self.tcp else {
1217      return Err(std::io::ErrorKind::NotConnected.into());
1218    };
1219    tcp.local_addr()
1220  }
1221}
1222
1223impl<S: UnderlyingStream> TlsStreamRead<S> {
1224  /// Reunites with a previously split `TlsStreamWrite`.
1225  pub fn unsplit(self, other: TlsStreamWrite<S>) -> TlsStream<S> {
1226    self.r.unsplit(other.w)
1227  }
1228
1229  pub fn poll_handshake(
1230    &mut self,
1231    cx: &mut Context,
1232  ) -> Poll<io::Result<TlsHandshake>> {
1233    // TODO(mmastrac): Handshake shouldn't need to be cloned
1234    match &*self.handshake.handshake.lock().unwrap() {
1235      None => {
1236        self.handshake.rx_waker.register(cx.waker());
1237        Poll::Pending
1238      }
1239      Some(handshake) => Poll::Ready(clone_result(handshake)),
1240    }
1241  }
1242
1243  pub async fn handshake(&mut self) -> io::Result<TlsHandshake> {
1244    poll_fn(|cx| self.poll_handshake(cx)).await
1245  }
1246
1247  /// Try to get the handshake, if one exists.
1248  pub fn try_handshake(&self) -> io::Result<Option<TlsHandshake>> {
1249    match &*self.handshake.handshake.lock().unwrap() {
1250      None => Ok(None),
1251      Some(r) => clone_result(r).map(Some),
1252    }
1253  }
1254}
1255
/// An `async` write half of a stream that wraps a `rustls` connection and an
/// underlying transport `S` (for example a TCP socket).
#[derive(AsyncWrite)]
pub struct TlsStreamWrite<S: UnderlyingStream> {
  /// The write half of the split [`TlsStream`]; marked `#[write]` for the
  /// derived `AsyncWrite` impl.
  #[write]
  w: tokio::io::WriteHalf<TlsStream<S>>,
  /// Handshake state shared with the stream; consulted by `poll_handshake`
  /// and `try_handshake`.
  handshake: Arc<HandshakeWatch>,
  /// The underlying transport, used by `peer_addr` / `local_addr`. `None`
  /// makes those calls fail with `NotConnected`.
  tcp: Option<Arc<S>>,
}
1264
1265impl TlsStreamWrite<TcpStream> {
1266  /// Returns the peer address of this socket.
1267  pub fn peer_addr(&self) -> Result<std::net::SocketAddr, io::Error> {
1268    let Some(tcp) = &self.tcp else {
1269      return Err(std::io::ErrorKind::NotConnected.into());
1270    };
1271    tcp.peer_addr()
1272  }
1273
1274  /// Returns the local address of this socket.
1275  pub fn local_addr(&self) -> Result<std::net::SocketAddr, io::Error> {
1276    let Some(tcp) = &self.tcp else {
1277      return Err(std::io::ErrorKind::NotConnected.into());
1278    };
1279    tcp.local_addr()
1280  }
1281}
1282
1283impl<S: UnderlyingStream> TlsStreamWrite<S> {
1284  pub fn poll_handshake(
1285    &mut self,
1286    cx: &mut Context,
1287  ) -> Poll<io::Result<TlsHandshake>> {
1288    // TODO(mmastrac): Handshake shouldn't need to be cloned
1289    match &*self.handshake.handshake.lock().unwrap() {
1290      None => {
1291        self.handshake.tx_waker.register(cx.waker());
1292        Poll::Pending
1293      }
1294      Some(handshake) => Poll::Ready(clone_result(handshake)),
1295    }
1296  }
1297
1298  pub async fn handshake(&mut self) -> io::Result<TlsHandshake> {
1299    poll_fn(|cx| self.poll_handshake(cx)).await
1300  }
1301
1302  /// Try to get the handshake, if one exists.
1303  pub fn try_handshake(&self) -> io::Result<Option<TlsHandshake>> {
1304    match &*self.handshake.handshake.lock().unwrap() {
1305      None => Ok(None),
1306      Some(r) => clone_result(r).map(Some),
1307    }
1308  }
1309}
1310
1311#[cfg(test)]
1312pub(super) mod tests {
1313  use super::*;
1314  use crate::tests::certificate;
1315  use crate::tests::expect_io_error;
1316  use crate::tests::private_key;
1317  use crate::tests::UnsafeVerifier;
1318  use futures::stream::FuturesUnordered;
1319  use futures::FutureExt;
1320  use futures::StreamExt;
1321  use rstest::rstest;
1322  use rustls::version::TLS12;
1323  use rustls::ClientConfig;
1324  use rustls::SupportedProtocolVersion;
1325  use std::io::ErrorKind;
1326  use std::io::IoSlice;
1327  use std::net::Ipv4Addr;
1328  use std::net::SocketAddr;
1329  use std::net::SocketAddrV4;
1330  use std::time::Duration;
1331  use tokio::io::AsyncReadExt;
1332  use tokio::io::AsyncWriteExt;
1333  use tokio::net::TcpListener;
1334  use tokio::net::TcpSocket;
1335  use tokio::spawn;
1336  use tokio::sync::Barrier;
1337
  /// Return type for tests. Spawned-task `JoinHandle`s are awaited with `?`,
  /// relying on the conversion of join errors into `io::Error`.
  type TestResult = Result<(), std::io::Error>;

  /// The stream type under test: TLS over a tokio TCP socket.
  type TlsStream = super::TlsStream<TcpStream>;
1341
1342  fn server_config(alpn: &[&str]) -> ServerConfig {
1343    let mut config = ServerConfig::builder()
1344      .with_no_client_auth()
1345      .with_single_cert(vec![certificate()], private_key())
1346      .expect("Failed to build server config");
1347    config.alpn_protocols =
1348      alpn.iter().map(|v| v.as_bytes().to_owned()).collect();
1349    config
1350  }
1351
1352  fn server_config_protocol(
1353    protocol: &'static SupportedProtocolVersion,
1354  ) -> ServerConfig {
1355    let config = ServerConfig::builder_with_protocol_versions(&[protocol])
1356      .with_no_client_auth()
1357      .with_single_cert(vec![certificate()], private_key())
1358      .expect("Failed to build server config");
1359    config
1360  }
1361
1362  fn client_config(alpn: &[&str]) -> ClientConfig {
1363    let mut config = ClientConfig::builder()
1364      .dangerous()
1365      .with_custom_certificate_verifier(Arc::new(UnsafeVerifier {}))
1366      .with_no_client_auth();
1367    config.alpn_protocols =
1368      alpn.iter().map(|v| v.as_bytes().to_owned()).collect();
1369    config.enable_sni = true;
1370    config
1371  }
1372
1373  async fn tcp_pair() -> (TcpStream, TcpStream) {
1374    let listener = TcpListener::bind(SocketAddr::V4(SocketAddrV4::new(
1375      Ipv4Addr::LOCALHOST,
1376      0,
1377    )))
1378    .await
1379    .unwrap();
1380    let port = listener.local_addr().unwrap().port();
1381    let server = spawn(async move { listener.accept().await.unwrap().0 });
1382    let client = spawn(async move {
1383      TcpSocket::new_v4()
1384        .unwrap()
1385        .connect(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)))
1386        .await
1387        .unwrap()
1388    });
1389
1390    let (server, client) = (server.await.unwrap(), client.await.unwrap());
1391    (server, client)
1392  }
1393
  /// Creates a `(server, client)` TLS pair over loopback TCP with default
  /// configs and no handshake write-buffer limit. The handshakes are not
  /// awaited.
  pub async fn tls_pair() -> (TlsStream, TlsStream) {
    tls_pair_buffer_size(None).await
  }
1397
1398  pub async fn tls_pair_protocol(
1399    buffer_size: Option<NonZeroUsize>,
1400    protocol: &'static SupportedProtocolVersion,
1401  ) -> (TlsStream, TlsStream) {
1402    let (server, client) = tcp_pair().await;
1403    let server = TlsStream::new_server_side(
1404      server,
1405      server_config_protocol(protocol).into(),
1406      None,
1407    );
1408    let client = TlsStream::new_client_side_test_options(
1409      client,
1410      client_config(&[]).into(),
1411      "example.com".try_into().unwrap(),
1412      buffer_size,
1413      TestOptions::default(),
1414    );
1415
1416    (server, client)
1417  }
1418
1419  pub async fn tls_pair_buffer_size(
1420    buffer_size: Option<NonZeroUsize>,
1421  ) -> (TlsStream, TlsStream) {
1422    let (server, client) = tcp_pair().await;
1423    let server =
1424      TlsStream::new_server_side(server, server_config(&[]).into(), None);
1425    let client = TlsStream::new_client_side_test_options(
1426      client,
1427      client_config(&[]).into(),
1428      "example.com".try_into().unwrap(),
1429      buffer_size,
1430      TestOptions::default(),
1431    );
1432
1433    (server, client)
1434  }
1435
1436  async fn tls_with_tcp_server(
1437    delay_handshake: bool,
1438  ) -> (TcpStream, TlsStream) {
1439    let (server, client) = tcp_pair().await;
1440    let client_test_options = TestOptions {
1441      delay_handshake,
1442      ..Default::default()
1443    };
1444    let client = TlsStream::new_client_side_test_options(
1445      client,
1446      client_config(&[]).into(),
1447      "example.com".try_into().unwrap(),
1448      None,
1449      client_test_options,
1450    );
1451    (server, client)
1452  }
1453
1454  async fn tls_pair_slow_handshake(
1455    delay_handshake: bool,
1456    slow_server: bool,
1457    slow_client: bool,
1458    buffer: bool,
1459  ) -> (TlsStream, TlsStream) {
1460    let (server, client) = tcp_pair().await;
1461    let server_test_options = TestOptions {
1462      delay_handshake,
1463      slow_handshake_read: slow_server,
1464      slow_handshake_write: slow_server,
1465    };
1466    let client_test_options = TestOptions {
1467      delay_handshake,
1468      slow_handshake_read: slow_client,
1469      slow_handshake_write: slow_client,
1470    };
1471    let buffer_size = if buffer {
1472      NonZeroUsize::new(1024)
1473    } else {
1474      None
1475    };
1476
1477    let server = TlsStream::new_server_side_test_options(
1478      server,
1479      server_config(&[]).into(),
1480      buffer_size,
1481      server_test_options,
1482    );
1483    let client = TlsStream::new_client_side_test_options(
1484      client,
1485      client_config(&[]).into(),
1486      "example.com".try_into().unwrap(),
1487      buffer_size,
1488      client_test_options,
1489    );
1490
1491    (server, client)
1492  }
1493
1494  async fn tls_pair_alpn(
1495    server_alpn: &[&str],
1496    server_buffer_size: Option<NonZeroUsize>,
1497    client_alpn: &[&str],
1498    client_buffer_size: Option<NonZeroUsize>,
1499  ) -> (TlsStream, TlsStream) {
1500    let (server, client) = tcp_pair().await;
1501    let server = TlsStream::new_server_side(
1502      server,
1503      server_config(server_alpn).into(),
1504      server_buffer_size,
1505    );
1506    let client = TlsStream::new_client_side_test_options(
1507      client,
1508      client_config(client_alpn).into(),
1509      "example.com".try_into().unwrap(),
1510      client_buffer_size,
1511      TestOptions::default(),
1512    );
1513
1514    (server, client)
1515  }
1516
1517  async fn make_config(
1518    alpn: Result<&'static [&'static str], &'static str>,
1519  ) -> Result<Arc<ServerConfig>, io::Error> {
1520    Ok(
1521      server_config(
1522        alpn.map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?,
1523      )
1524      .into(),
1525    )
1526  }
1527
1528  async fn tls_pair_alpn_acceptor(
1529    server_alpn: fn(
1530      ClientHello,
1531    ) -> Result<&'static [&'static str], &'static str>,
1532    server_buffer_size: Option<NonZeroUsize>,
1533    client_alpn: &[&str],
1534    client_buffer_size: Option<NonZeroUsize>,
1535  ) -> (TlsStream, TlsStream) {
1536    let (server, client) = tcp_pair().await;
1537    let server = TlsStream::new_server_side_acceptor(
1538      server,
1539      Arc::new(move |client_hello| {
1540        Box::pin(make_config(server_alpn(client_hello)))
1541      }),
1542      server_buffer_size,
1543    );
1544    let client = TlsStream::new_client_side_test_options(
1545      client,
1546      client_config(client_alpn).into(),
1547      "example.com".try_into().unwrap(),
1548      client_buffer_size,
1549      TestOptions::default(),
1550    );
1551
1552    (server, client)
1553  }
1554
1555  async fn tls_pair_alpn_from_acceptor(
1556    server_alpn: fn(
1557      ClientHello,
1558    ) -> Result<&'static [&'static str], &'static str>,
1559    server_buffer_size: Option<NonZeroUsize>,
1560    client_alpn: &[&str],
1561    client_buffer_size: Option<NonZeroUsize>,
1562  ) -> (TlsStream, TlsStream) {
1563    let (mut server, client) = tcp_pair().await;
1564
1565    // Create the client first because we need the ClientHello. This will
1566    // boot the client's handshake task and write to the socket.
1567    let client = TlsStream::new_client_side_test_options(
1568      client,
1569      client_config(client_alpn).into(),
1570      "example.com".try_into().unwrap(),
1571      client_buffer_size,
1572      TestOptions::default(),
1573    );
1574
1575    // Read 8 bytes from the start of the server connection and then
1576    // feed them to an Acceptor. Pass that acceptor when we create the
1577    // TlsStream which will populate the rest of the ClientHello and
1578    // properly handshake.
1579    let mut prefix = [0; 8];
1580    server
1581      .read_exact(&mut prefix)
1582      .await
1583      .expect("Failed to read prefix");
1584    let mut acceptor = Acceptor::default();
1585    assert_eq!(
1586      acceptor.read_tls(&mut prefix.as_slice()).unwrap(),
1587      prefix.len()
1588    );
1589
1590    let server = TlsStream::new_server_side_from_acceptor(
1591      acceptor,
1592      server,
1593      Arc::new(move |client_hello| {
1594        Box::pin(make_config(server_alpn(client_hello)))
1595      }),
1596      server_buffer_size,
1597    );
1598
1599    (server, client)
1600  }
1601
1602  async fn tls_pair_handshake_buffer_size(
1603    server_buffer_size: Option<NonZeroUsize>,
1604    client_buffer_size: Option<NonZeroUsize>,
1605  ) -> (TlsStream, TlsStream) {
1606    let (mut server, mut client) =
1607      tls_pair_alpn(&[], server_buffer_size, &[], client_buffer_size).await;
1608    let a = spawn(async move {
1609      server.handshake().await.unwrap();
1610      server
1611    });
1612    let b = spawn(async move {
1613      client.handshake().await.unwrap();
1614      client
1615    });
1616    (a.await.unwrap(), b.await.unwrap())
1617  }
1618
  /// Creates a `(server, client)` TLS pair with no buffer limits and waits for
  /// both handshakes to succeed.
  async fn tls_pair_handshake() -> (TlsStream, TlsStream) {
    tls_pair_handshake_buffer_size(None, None).await
  }
1622
1623  async fn expect_eof_read(stm: &mut (impl AsyncReadExt + Unpin)) {
1624    let mut buf = [0_u8; 1];
1625    let e = stm.read(&mut buf).await.expect("Expected no error");
1626    assert_eq!(e, 0, "expected eof");
1627  }
1628
1629  async fn expect_io_error_read(
1630    stm: &mut (impl AsyncReadExt + Unpin),
1631    kind: io::ErrorKind,
1632  ) {
1633    let mut buf = [0_u8; 1];
1634    let e = stm.read(&mut buf).await.expect_err("Expected error");
1635    assert_eq!(e.kind(), kind);
1636  }
1637
1638  /// Test that automatic state transition works: send and receive work as expected without waiting
1639  /// for the handshake
1640  #[rstest]
1641  #[tokio::test]
1642  async fn test_client_server(
1643    #[values(true, false)] server_slow: bool,
1644    #[values(true, false)] client_slow: bool,
1645    #[values(true, false)] buffer: bool,
1646  ) -> TestResult {
1647    let (mut server, mut client) =
1648      tls_pair_slow_handshake(false, server_slow, client_slow, buffer).await;
1649    let a = spawn(async move {
1650      server.write_all(b"hello?").await.unwrap();
1651      let mut buf = [0; 6];
1652      server.read_exact(&mut buf).await.unwrap();
1653      assert_eq!(buf.as_slice(), b"hello!");
1654    });
1655    let b = spawn(async move {
1656      client.write_all(b"hello!").await.unwrap();
1657      let mut buf = [0; 6];
1658      client.read_exact(&mut buf).await.unwrap();
1659    });
1660    a.await?;
1661    b.await?;
1662
1663    Ok(())
1664  }
1665
1666  /// Test that a flush before a handshake completes works.
1667  #[tokio::test]
1668  #[ntest::timeout(60000)]
1669  async fn test_flush_before_handshake() -> TestResult {
1670    let (mut server, mut client) = tls_pair().await;
1671    server.write_all(b"hello?").await.unwrap();
1672    server.flush().await.unwrap();
1673    let mut buf = [0; 6];
1674    assert_eq!(6, client.read_exact(&mut buf).await.unwrap());
1675    Ok(())
1676  }
1677
  /// Test that a client can write a full 1024-byte payload, possibly while
  /// the handshake is still in progress, and then read the server's one-byte
  /// reply. The matrix covers delayed/slow handshakes, with and without a
  /// 1024-byte handshake write buffer.
  #[rstest]
  #[tokio::test(flavor = "multi_thread")]
  #[ntest::timeout(60000)]
  async fn test_read_with_buffered_write(
    #[values(true, false)] delay_handshake: bool,
    #[values(true, false)] slow_server: bool,
    #[values(true, false)] slow_client: bool,
    #[values(true, false)] buffer: bool,
  ) -> TestResult {
    let (mut server, mut client) = tls_pair_slow_handshake(
      delay_handshake,
      slow_server,
      slow_client,
      buffer,
    )
    .await;

    // Server: consume one byte, then reply with one byte.
    let a = tokio::task::spawn(async move {
      server.read_u8().await.unwrap();
      server.write_u8(1).await.unwrap();
    });

    // Client: write a payload as large as the optional handshake buffer, then
    // wait for the reply.
    let b = tokio::task::spawn(async move {
      let buf = [0; 1024];
      client.write_all(&buf).await.unwrap();
      client.read_u8().await.unwrap();
    });

    a.await.unwrap();
    b.await.unwrap();

    Ok(())
  }
1711
1712  /// Test that the handshake works, and we get the correct ALPN negotiated values.
1713  #[tokio::test]
1714  #[ntest::timeout(60000)]
1715  async fn test_client_server_alpn() -> TestResult {
1716    let (mut server, mut client) =
1717      tls_pair_alpn(&["a", "b", "c"], None, &["b"], None).await;
1718    let a = spawn(async move {
1719      let handshake = server.handshake().await.unwrap();
1720      assert_eq!(handshake.alpn, Some("b".as_bytes().to_vec()));
1721      assert_eq!(handshake.sni, Some("example.com".into()));
1722      server.write_all(b"hello?").await.unwrap();
1723      let mut buf = [0; 6];
1724      server.read_exact(&mut buf).await.unwrap();
1725      assert_eq!(buf.as_slice(), b"hello!");
1726    });
1727    let b = spawn(async move {
1728      let handshake = client.handshake().await.unwrap();
1729      assert_eq!(handshake.alpn, Some("b".as_bytes().to_vec()));
1730      client.write_all(b"hello!").await.unwrap();
1731      let mut buf = [0; 6];
1732      client.read_exact(&mut buf).await.unwrap();
1733    });
1734    a.await?;
1735    b.await?;
1736
1737    Ok(())
1738  }
1739
1740  fn alpn_handler(
1741    client_hello: ClientHello,
1742  ) -> Result<&'static [&'static str], &'static str> {
1743    if let Some(alpn) = client_hello.alpn() {
1744      for alpn in alpn {
1745        if alpn == b"a" {
1746          return Ok(&["a"]);
1747        }
1748        if alpn == b"b" {
1749          return Ok(&["b"]);
1750        }
1751      }
1752    }
1753    Err("bad server")
1754  }
1755
1756  /// Test that the handshake works, and we get the correct ALPN negotiated values.
1757  #[rstest]
1758  #[case("a")]
1759  #[case("b")]
1760  #[case("c")]
1761  #[tokio::test]
1762  #[ntest::timeout(60000)]
1763  async fn test_client_server_alpn_acceptor(
1764    #[case] alpn: &'static str,
1765    #[values(true, false)] use_from: bool,
1766  ) -> TestResult {
1767    let (mut server, mut client) = if use_from {
1768      tls_pair_alpn_from_acceptor(alpn_handler, None, &[alpn], None).await
1769    } else {
1770      tls_pair_alpn_acceptor(alpn_handler, None, &[alpn], None).await
1771    };
1772    let a = spawn(async move {
1773      if alpn == "c" {
1774        server.handshake().await.expect_err("expected failure");
1775        return;
1776      }
1777      let handshake = server.handshake().await.unwrap();
1778      assert_eq!(handshake.alpn, Some(alpn.as_bytes().to_vec()));
1779      assert_eq!(handshake.sni, Some("example.com".into()));
1780      server.write_all(b"hello?").await.unwrap();
1781      let mut buf = [0; 6];
1782      server.read_exact(&mut buf).await.unwrap();
1783      assert_eq!(buf.as_slice(), b"hello!");
1784    });
1785    let b = spawn(async move {
1786      if alpn == "c" {
1787        client.handshake().await.expect_err("expected failure");
1788        return;
1789      }
1790      let handshake = client.handshake().await.unwrap();
1791      assert_eq!(handshake.alpn, Some(alpn.as_bytes().to_vec()));
1792      client.write_all(b"hello!").await.unwrap();
1793      let mut buf = [0; 6];
1794      client.read_exact(&mut buf).await.unwrap();
1795    });
1796    a.await?;
1797    b.await?;
1798
1799    Ok(())
1800  }
1801
  /// Test that an ALPN mismatch (the server offers only "a", the client only
  /// "b") fails the handshake on both ends with `InvalidData` and the expected
  /// messages. A subsequent flush must also fail with `InvalidData`.
  #[tokio::test]
  #[ntest::timeout(60000)]
  async fn test_client_server_alpn_mismatch() -> TestResult {
    let (mut server, mut client) =
      tls_pair_alpn(&["a"], None, &["b"], None).await;
    let a = spawn(async move {
      let e = server.handshake().await.expect_err("Expected a failure");
      assert_eq!(e.kind(), ErrorKind::InvalidData);
      assert_eq!(e.to_string(), "peer doesn't support any known protocol");
      // The stream now holds the handshake error and reports it on flush.
      let e = server.flush().await.expect_err("Expected a failure");
      assert_eq!(e.kind(), ErrorKind::InvalidData);
    });
    let b = spawn(async move {
      let e = client.handshake().await.expect_err("Expected a failure");
      assert_eq!(e.kind(), ErrorKind::InvalidData);
      assert_eq!(e.to_string(), "received fatal alert: NoApplicationProtocol");
      let e = client.flush().await.expect_err("Expected a failure");
      assert_eq!(e.kind(), ErrorKind::InvalidData);
    });
    a.await?;
    b.await?;

    Ok(())
  }
1827
  /// Test that `connection()` returns `None` on both ends right after the
  /// streams are created, and `Some` once each handshake has completed.
  #[tokio::test]
  #[ntest::timeout(60000)]
  async fn test_client_server_raw_connection() -> TestResult {
    let (mut server, mut client) =
      tls_pair_alpn(&["a"], None, &["a"], None).await;

    // Not yet available while the handshakes are still pending.
    assert!(server.connection().is_none());
    assert!(client.connection().is_none());

    server.handshake().await?;
    client.handshake().await?;

    assert!(server.connection().is_some());
    assert!(client.connection().is_some());

    Ok(())
  }
1846
  /// Test that `local_addr` / `peer_addr` work on both ends while a delayed,
  /// slow handshake is still running, and keep working once it has completed.
  #[tokio::test]
  async fn test_peer_and_local_addresses() {
    let (server, client) =
      tls_pair_slow_handshake(true, true, true, false).await;
    // Use a barrier to keep the client and server sockets alive until the end
    let barrier = Arc::new(Barrier::new(2));
    let barrier_clone = barrier.clone();
    let a = spawn(async move {
      // Poll every 10ms until the handshake result is available.
      loop {
        tokio::time::sleep(Duration::from_millis(10)).await;
        server.local_addr().unwrap();
        server.peer_addr().unwrap();
        if server.try_handshake().unwrap().is_some() {
          server.local_addr().unwrap();
          server.peer_addr().unwrap();
          break;
        }
      }
      barrier.wait().await;
    });
    let b = spawn(async move {
      loop {
        tokio::time::sleep(Duration::from_millis(10)).await;
        client.local_addr().unwrap();
        client.peer_addr().unwrap();
        if client.try_handshake().unwrap().is_some() {
          client.local_addr().unwrap();
          client.peer_addr().unwrap();
          break;
        }
      }
      barrier_clone.wait().await;
    });
    a.await.unwrap();
    b.await.unwrap();
  }
1883
  /// Test that when the client is dropped immediately, a server that shuts
  /// down right away still reads a clean EOF rather than an error. Covers
  /// every slow/fast handshake combination.
  #[rstest]
  #[case(false, false)]
  #[case(false, true)]
  #[case(true, false)]
  #[case(true, true)]
  #[tokio::test]
  async fn test_client_immediate_close(
    #[case] server_slow: bool,
    #[case] client_slow: bool,
  ) -> TestResult {
    let (mut server, client) =
      tls_pair_slow_handshake(false, server_slow, client_slow, false).await;
    let a = spawn(async move {
      server.shutdown().await.unwrap();
      // While this races the handshake, we are not going to expose a handshake EOF to the stream in a
      // regular read.
      expect_eof_read(&mut server).await;
      drop(server);
    });
    let b = spawn(async move {
      drop(client);
    });
    a.await?;
    b.await?;

    Ok(())
  }
1911
1912  // ---- stream::tests::test_server_immediate_close stdout ----
1913  // w=Ok(242)
1914  // r(4096)=Ok(242)
1915  // w=Ok(127)
1916  // w=Ok(6)
1917  // w=Ok(32)
1918  // w=Ok(913)
1919  // w=Ok(286)
1920  // w=Ok(74)
1921  // r(4096)=Err(Kind(WouldBlock))
1922  // r(4096)=Ok(1438)
1923  // w=Ok(6)
1924  // w=Ok(74)
1925  // w=Ok(24)
1926  // r(4096)=Err(Kind(WouldBlock))
1927  // r(4096)=Ok(80)
1928  // w=Ok(103)
1929  // w=Ok(103)
1930  // w=Ok(103)
1931  // w=Ok(103)
1932  // w=Ok(24)
1933  // r(4096)=Ok(103)
1934  // r*=Kind(WouldBlock)
1935  // r(4096)=Err(Os { code: 54, kind: ConnectionReset, message: "Connection reset by peer" })
1936  // r*=Kind(WouldBlock)
1937  // thread 'stream::tests::test_server_immediate_close' panicked at 'Expected no error: Kind(ConnectionReset)', src/stream.rs:548:38
1938  // note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
1939  // Error: Custom { kind: Other, error: "task panicked" }
1940
1941  #[rstest]
1942  #[case(false, false)]
1943  #[case(false, true)]
1944  #[case(true, false)]
1945  #[case(true, true)]
1946  #[tokio::test]
1947  async fn test_server_immediate_close(
1948    #[case] server_slow: bool,
1949    #[case] client_slow: bool,
1950  ) -> TestResult {
1951    let (server, mut client) =
1952      tls_pair_slow_handshake(false, server_slow, client_slow, false).await;
1953    let a = spawn(async move {
1954      drop(server);
1955    });
1956    let b = spawn(async move {
1957      client.shutdown().await.unwrap();
1958      // While this races the handshake, we are not going to expose a handshake EOF to the stream in a
1959      // regular read.
1960      expect_eof_read(&mut client).await;
1961      drop(client);
1962    });
1963    a.await?;
1964    b.await?;
1965
1966    Ok(())
1967  }
1968
1969  #[rstest]
1970  #[case(false, false)]
1971  #[case(false, true)]
1972  #[case(true, false)]
1973  #[case(true, true)]
1974  #[tokio::test]
1975  async fn test_orderly_shutdown(
1976    #[case] server_slow: bool,
1977    #[case] client_slow: bool,
1978  ) -> TestResult {
1979    let (mut server, mut client) =
1980      tls_pair_slow_handshake(false, server_slow, client_slow, false).await;
1981    let (tx, rx) = tokio::sync::oneshot::channel();
1982    let a = spawn(async move {
1983      server.write_all(b"hello?").await.unwrap();
1984      let mut buf = [0; 6];
1985      server.read_exact(&mut buf).await.unwrap();
1986      assert_eq!(buf.as_slice(), b"hello!");
1987      // Shut down write, but reads are still open
1988      server.shutdown().await.unwrap();
1989      server.read_exact(&mut buf).await.unwrap();
1990      assert_eq!(buf.as_slice(), b"hello*");
1991      // Tell the client to shut down at some point after we've closed the server TCP socket.
1992      drop(server);
1993      tokio::time::sleep(Duration::from_millis(10)).await;
1994      tx.send(()).unwrap();
1995    });
1996    let b = spawn(async move {
1997      client.write_all(b"hello!").await.unwrap();
1998      let mut buf = [0; 6];
1999      client.read_exact(&mut buf).await.unwrap();
2000      assert_eq!(client.read(&mut buf).await.unwrap(), 0);
2001      client.write_all(b"hello*").await.unwrap();
2002      // The server is long gone by the point we get the message, but it's a clean shutdown
2003      rx.await.unwrap();
2004      client.shutdown().await.unwrap();
2005      drop(client);
2006    });
2007    a.await?;
2008    b.await?;
2009
2010    Ok(())
2011  }
2012
2013  #[rstest]
2014  #[case(false, false)]
2015  #[case(false, true)]
2016  #[case(true, false)]
2017  #[case(true, true)]
2018  #[tokio::test]
2019  async fn test_server_shutdown_after_handshake(
2020    #[case] server_slow: bool,
2021    #[case] client_slow: bool,
2022  ) -> TestResult {
2023    let (mut server, mut client) =
2024      tls_pair_slow_handshake(false, server_slow, client_slow, false).await;
2025    let (tx, rx) = tokio::sync::oneshot::channel();
2026    let a = spawn(async move {
2027      // Shut down after the handshake
2028      server.handshake().await.unwrap();
2029      server.shutdown().await.unwrap();
2030      tx.send(()).unwrap();
2031      expect_io_error(
2032        server.write_all(b"hello?").await,
2033        io::ErrorKind::NotConnected,
2034      );
2035    });
2036    let b = spawn(async move {
2037      // assert!(client.get_ref().1.is_handshaking());
2038      client.handshake().await.unwrap();
2039      rx.await.unwrap();
2040      // Can't read -- server shut down
2041      expect_eof_read(&mut client).await;
2042    });
2043    a.await?;
2044    b.await?;
2045
2046    Ok(())
2047  }
2048
2049  #[rstest]
2050  #[case(false, false)]
2051  #[case(false, true)]
2052  #[case(true, false)]
2053  #[case(true, true)]
2054  #[tokio::test]
2055  async fn test_server_shutdown_before_handshake(
2056    #[case] server_slow: bool,
2057    #[case] client_slow: bool,
2058  ) -> TestResult {
2059    let (mut server, mut client) =
2060      tls_pair_slow_handshake(false, server_slow, client_slow, false).await;
2061    let a = spawn(async move {
2062      let mut futures = FuturesUnordered::new();
2063
2064      // The client handshake must complete before the server shutdown is resolved
2065      futures.push(server.shutdown().map(|_| 1).boxed());
2066      futures.push(client.handshake().map(|_| 2).boxed());
2067
2068      assert_eq!(poll_fn(|cx| futures.poll_next_unpin(cx)).await.unwrap(), 2);
2069      assert_eq!(poll_fn(|cx| futures.poll_next_unpin(cx)).await.unwrap(), 1);
2070      drop(futures);
2071
2072      // Can't read -- server shut down
2073      expect_eof_read(&mut client).await;
2074    });
2075    a.await?;
2076
2077    Ok(())
2078  }
2079
2080  #[rstest]
2081  #[case(false, false)]
2082  #[case(false, true)]
2083  #[case(true, false)]
2084  #[case(true, true)]
2085  #[tokio::test]
2086  async fn test_server_dropped(
2087    #[case] server_slow: bool,
2088    #[case] client_slow: bool,
2089  ) -> TestResult {
2090    let (server, mut client) =
2091      tls_pair_slow_handshake(false, server_slow, client_slow, false).await;
2092    // The server will spawn a task to complete the handshake and then go away
2093    drop(server);
2094    client.handshake().await?;
2095    // Can't read -- server shut down (but it was graceful)
2096    expect_eof_read(&mut client).await;
2097    Ok(())
2098  }
2099
2100  #[tokio::test]
2101  #[ntest::timeout(60000)]
2102  async fn test_server_dropped_after_handshake() -> TestResult {
2103    let (server, mut client) = tls_pair_handshake().await;
2104    drop(server);
2105    // Can't read -- server shut down (but it was graceful)
2106    expect_eof_read(&mut client).await;
2107    Ok(())
2108  }
2109
2110  #[tokio::test]
2111  #[ntest::timeout(60000)]
2112  async fn test_server_dropped_after_handshake_with_write() -> TestResult {
2113    let (mut server, mut client) = tls_pair_handshake().await;
2114    server.write_all(b"XYZ").await.unwrap();
2115    drop(server);
2116    // Can't read -- server shut down (but it was graceful)
2117    let mut buf: [u8; 10] = [0; 10];
2118    assert_eq!(client.read(&mut buf).await.unwrap(), 3);
2119    Ok(())
2120  }
2121
2122  #[rstest]
2123  #[case(false, false)]
2124  #[case(false, true)]
2125  #[case(true, false)]
2126  #[case(true, true)]
2127  #[tokio::test]
2128  async fn test_client_dropped(
2129    #[case] server_slow: bool,
2130    #[case] client_slow: bool,
2131  ) -> TestResult {
2132    let (mut server, client) =
2133      tls_pair_slow_handshake(false, server_slow, client_slow, false).await;
2134    drop(client);
2135    // The client will spawn a task to complete the handshake and then go away
2136    server.handshake().await?;
2137    // Can't read -- server shut down (but it was graceful)
2138    expect_eof_read(&mut server).await;
2139    Ok(())
2140  }
2141
2142  #[tokio::test]
2143  async fn test_server_half_crash_before_handshake() -> TestResult {
2144    let (mut server, mut client) = tls_with_tcp_server(false).await;
2145    // This test occasionally shows up as ConnectionReset on Mac -- the delay ensures we wait long enough
2146    // for the handshake to settle.
2147    tokio::time::sleep(Duration::from_millis(100)).await;
2148    <TcpStream as AsyncWriteExt>::shutdown(&mut server).await?;
2149
2150    let expected = ErrorKind::UnexpectedEof;
2151
2152    expect_io_error(client.handshake().await, expected);
2153    // Can't read -- server shut down. Because this happened before the handshake, it's an unexpected EOF.
2154    expect_io_error_read(&mut client, expected).await;
2155    Ok(())
2156  }
2157
2158  #[tokio::test]
2159  async fn test_server_crash_before_handshake() -> TestResult {
2160    let (mut server, mut client) = tls_with_tcp_server(false).await;
2161    <TcpStream as AsyncWriteExt>::shutdown(&mut server).await?;
2162    drop(server);
2163
2164    let expected = ErrorKind::UnexpectedEof;
2165
2166    expect_io_error(client.handshake().await, expected);
2167    // Can't read -- server shut down. Because this happened before the handshake, it's an unexpected EOF.
2168    expect_io_error_read(&mut client, expected).await;
2169    Ok(())
2170  }
2171
2172  #[tokio::test]
2173  async fn test_server_crash_after_handshake() -> TestResult {
2174    let (server, mut client) = tls_pair_handshake().await;
2175
2176    let (mut tcp, _tls) = server.into_inner().await.unwrap();
2177    <TcpStream as AsyncWriteExt>::shutdown(&mut tcp).await?;
2178    drop(tcp);
2179
2180    // Can't read -- server shut down. This is an unexpected EOF.
2181    expect_io_error_read(&mut client, ErrorKind::UnexpectedEof).await;
2182    Ok(())
2183  }
2184
2185  #[rstest]
2186  #[case(true)]
2187  #[case(false)]
2188  #[tokio::test]
2189  async fn large_transfer_no_buffer_limit_or_handshake(
2190    #[case] swap: bool,
2191  ) -> TestResult {
2192    const BUF_SIZE: usize = 64 * 1024;
2193    const BUF_COUNT: usize = 1024;
2194
2195    let (server, client) = tls_pair().await;
2196
2197    let (mut server, mut client) = if swap {
2198      (client, server)
2199    } else {
2200      (server, client)
2201    };
2202
2203    let a = spawn(async move {
2204      // Heap allocate a large buffer and send it
2205      let buf = vec![42; BUF_COUNT * BUF_SIZE];
2206      server.write_all(&buf).await.unwrap();
2207      assert_eq!(server.read_u8().await.unwrap(), 0xff);
2208      server.shutdown().await.unwrap();
2209      server.close().await.unwrap();
2210    });
2211    let b = spawn(async move {
2212      for _ in 0..BUF_COUNT {
2213        tokio::time::sleep(Duration::from_millis(1)).await;
2214        let mut buf = [0; BUF_SIZE];
2215        assert_eq!(BUF_SIZE, client.read_exact(&mut buf).await.unwrap());
2216      }
2217      client.write_u8(0xff).await.unwrap();
2218      expect_eof_read(&mut client).await;
2219    });
2220    a.await?;
2221    b.await?;
2222    Ok(())
2223  }
2224
2225  #[rstest]
2226  #[case(true)]
2227  #[case(false)]
2228  #[tokio::test]
2229  async fn large_transfer_with_buffer_limit(#[case] swap: bool) -> TestResult {
2230    const BUF_SIZE: usize = 10 * 1024;
2231    const BUF_COUNT: usize = 1024;
2232
2233    let (server, client) = tls_pair_handshake_buffer_size(
2234      BUF_SIZE.try_into().ok(),
2235      BUF_SIZE.try_into().ok(),
2236    )
2237    .await;
2238
2239    let (mut server, mut client) = if swap {
2240      (client, server)
2241    } else {
2242      (server, client)
2243    };
2244
2245    let a = spawn(async move {
2246      // Heap allocate a large buffer and send it
2247      let buf = vec![42; BUF_COUNT * BUF_SIZE];
2248      server.write_all(&buf).await.unwrap();
2249      server.shutdown().await.unwrap();
2250      server.close().await.unwrap();
2251    });
2252    let b = spawn(async move {
2253      for _ in 0..BUF_COUNT {
2254        tokio::time::sleep(Duration::from_millis(1)).await;
2255        let mut buf = [0; BUF_SIZE];
2256        assert_eq!(BUF_SIZE, client.read_exact(&mut buf).await.unwrap());
2257      }
2258      expect_eof_read(&mut client).await;
2259    });
2260    a.await?;
2261    b.await?;
2262    Ok(())
2263  }
2264
2265  #[rstest]
2266  #[case(true, &TLS12)]
2267  #[case(false, &TLS12)]
2268  #[case(true, &TLS13)]
2269  #[case(false, &TLS13)]
2270  #[tokio::test]
2271  async fn large_transfer_with_aggressive_close_split(
2272    #[case] swap: bool,
2273    #[case] protocol: &'static SupportedProtocolVersion,
2274  ) -> TestResult {
2275    const BUF_SIZE: usize = 1024;
2276    const BUF_COUNT: usize = 1 * 1024;
2277
2278    let (server, client) =
2279      tls_pair_protocol(NonZeroUsize::new(65536), protocol).await;
2280    let (server, client) = if swap {
2281      (client, server)
2282    } else {
2283      (server, client)
2284    };
2285
2286    let a = spawn(async move {
2287      let (mut r, mut w) = server.into_split();
2288      let barrier = Arc::new(Barrier::new(2));
2289      let barrier2 = barrier.clone();
2290      let a = spawn(async move {
2291        // We want to register a read here to test whether the split read stomps over a write on
2292        // the other half.
2293        tokio::select! {
2294          x = r.read_u8() => { _ = x.expect_err("should have failed") },
2295          _ = barrier.wait() => {}
2296        };
2297        r
2298      });
2299      let b = spawn(async move {
2300        // Heap allocate a large buffer and send it
2301        let mut buf = vec![42; BUF_COUNT * BUF_SIZE];
2302        let mut buf: &mut [u8] = &mut buf;
2303        w.handshake().await.unwrap();
2304        while !buf.is_empty() {
2305          let n = w.write(&buf).await.unwrap();
2306          w.flush().await.unwrap();
2307          buf = &mut buf[n..];
2308          trace!("[TEST] wrote {n}");
2309        }
2310        w.shutdown().await.unwrap();
2311        barrier2.wait().await;
2312        w
2313      });
2314
2315      let r = a.await.unwrap();
2316      let w = b.await.unwrap();
2317      // In TLS1.3, this aggressive close can cause the other side to lose its buffer
2318      // if the handshake is not fully completed because we send a TCP RST if we receive
2319      // anything further.
2320      r.unsplit(w).close().await.unwrap();
2321    });
2322    let b = spawn(async move {
2323      let (mut r, _w) = client.into_split();
2324      let mut buf = vec![0; BUF_SIZE];
2325      for i in 0..BUF_COUNT {
2326        let r = r.read_exact(&mut buf).await;
2327        if let Err(e) = &r {
2328          panic!("Failed to read after {i} of {BUF_COUNT} reads: {e:?}");
2329        };
2330        assert_eq!(BUF_SIZE, r.unwrap());
2331      }
2332      expect_eof_read(&mut r).await;
2333    });
2334    a.await?;
2335    b.await?;
2336    Ok(())
2337  }
2338
2339  #[rstest]
2340  #[case(true)]
2341  #[case(false)]
2342  #[tokio::test(flavor = "current_thread")]
2343  async fn large_transfer_with_shutdown(#[case] swap: bool) -> TestResult {
2344    const BUF_SIZE: usize = 10 * 1024;
2345    const BUF_COUNT: usize = 1024;
2346
2347    let (server, client) = tls_pair_handshake().await;
2348    let (mut server, mut client) = if swap {
2349      (client, server)
2350    } else {
2351      (server, client)
2352    };
2353
2354    let a = spawn(async move {
2355      // Heap allocate a large buffer and send it
2356      let buf = vec![42; BUF_COUNT * BUF_SIZE];
2357      server.write_all(&buf).await.unwrap();
2358      server.shutdown().await.unwrap();
2359      server.close().await.unwrap();
2360    });
2361    let b = spawn(async move {
2362      for _ in 0..BUF_COUNT {
2363        tokio::time::sleep(Duration::from_millis(1)).await;
2364        let mut buf = [0; BUF_SIZE];
2365        assert_eq!(BUF_SIZE, client.read_exact(&mut buf).await.unwrap());
2366      }
2367      expect_eof_read(&mut client).await;
2368    });
2369    a.await?;
2370    b.await?;
2371    Ok(())
2372  }
2373
2374  #[rstest]
2375  #[case(true)]
2376  #[case(false)]
2377  #[tokio::test(flavor = "current_thread")]
2378  #[ntest::timeout(60000)]
2379  async fn large_transfer_no_shutdown(#[case] swap: bool) -> TestResult {
2380    const BUF_SIZE: usize = 10 * 1024;
2381    const BUF_COUNT: usize = 1024;
2382
2383    let (server, client) = tls_pair_handshake().await;
2384    let (mut server, mut client) = if swap {
2385      (client, server)
2386    } else {
2387      (server, client)
2388    };
2389
2390    let a = spawn(async move {
2391      // Heap allocate a large buffer and send it
2392      let buf = vec![42; BUF_COUNT * BUF_SIZE];
2393      server.write_all(&buf).await.unwrap();
2394      server.close().await.unwrap();
2395    });
2396    let b = spawn(async move {
2397      for _ in 0..BUF_COUNT {
2398        tokio::time::sleep(Duration::from_millis(1)).await;
2399        let mut buf = [0; BUF_SIZE];
2400        assert_eq!(BUF_SIZE, client.read_exact(&mut buf).await.unwrap());
2401      }
2402      expect_eof_read(&mut client).await;
2403    });
2404    a.await?;
2405    b.await?;
2406    Ok(())
2407  }
2408
2409  /// One byte read/write, don't check close.
2410  #[rstest]
2411  #[case(true, 1024, 1024, 1024)]
2412  #[case(false, 1024, 1024, 1024)]
2413  #[case(true, 1024, 16, 1024)]
2414  #[case(false, 1024, 16, 1024)]
2415  #[case(true, 1024, 10000, 1)]
2416  #[case(false, 1024, 10000, 1)]
2417  #[case(true, 32, 16, 16)]
2418  #[case(false, 32, 16, 16)]
2419  #[tokio::test]
2420  async fn vectored_stream_write(
2421    #[case] handshake_first: bool,
2422    #[case] expected: usize,
2423    #[case] first: usize,
2424    #[case] second: usize,
2425  ) -> TestResult {
2426    let (mut server, mut client) =
2427      tls_pair_buffer_size(Some(NonZeroUsize::try_from(1024).unwrap())).await;
2428    if handshake_first {
2429      server.handshake().await.unwrap();
2430      server.flush().await.unwrap();
2431      client.handshake().await.unwrap();
2432      client.flush().await.unwrap();
2433    }
2434    let n = client
2435      .write_vectored(&[
2436        IoSlice::new(&vec![1; first]),
2437        IoSlice::new(&vec![2; second]),
2438      ])
2439      .await
2440      .expect("failed to write");
2441    assert_eq!(n, expected);
2442    let mut buf = [0; 2048];
2443    // Note that we need to flush to make progress on writes!
2444    client.flush().await.expect("failed to flush");
2445    // We need the TCP stack to send all the writes -- in release mode this is sometimes too fast
2446    tokio::time::sleep(Duration::from_millis(1)).await;
2447    let n = server.read(&mut buf).await.expect("failed to read");
2448    assert_eq!(n, expected);
2449    Ok(())
2450  }
2451
2452  /// Test that the peer_certificates are not available before handshake.
2453  #[tokio::test]
2454  async fn test_split_peer_certificates_before_handshake() -> TestResult {
2455    let (server, client) = tls_pair().await;
2456
2457    let (server_read, server_write) = server.into_split();
2458    let (client_read, client_write) = client.into_split();
2459
2460    // Test that handshake returns None before completion
2461    assert!(
2462      server_read.try_handshake()?.is_none(),
2463      "Server handshake should be None before completion"
2464    );
2465    assert!(
2466      server_write.try_handshake()?.is_none(),
2467      "Server handshake should be None before completion"
2468    );
2469    assert!(
2470      client_read.try_handshake()?.is_none(),
2471      "Client handshake should be None before completion"
2472    );
2473    assert!(
2474      client_write.try_handshake()?.is_none(),
2475      "Client handshake should be None before completion"
2476    );
2477
2478    Ok(())
2479  }
2480
2481  /// Test that the peer_certificates are available via handshake after completion.
2482  #[tokio::test]
2483  async fn test_split_peer_certificates_access() -> TestResult {
2484    let (server, client) = tls_pair_handshake().await;
2485
2486    let (server_read, server_write) = server.into_split();
2487    let (client_read, client_write) = client.into_split();
2488
2489    // Test that peer_certificates are available via handshake after completion
2490    let server_read_handshake = server_read.try_handshake()?.unwrap();
2491    let server_write_handshake = server_write.try_handshake()?.unwrap();
2492    let client_read_handshake = client_read.try_handshake()?.unwrap();
2493    let client_write_handshake = client_write.try_handshake()?.unwrap();
2494
2495    // Both halves should return the same peer certificates via handshake
2496    assert_eq!(
2497      server_read_handshake.peer_certificates.is_some(),
2498      server_write_handshake.peer_certificates.is_some()
2499    );
2500    assert_eq!(
2501      client_read_handshake.peer_certificates.is_some(),
2502      client_write_handshake.peer_certificates.is_some()
2503    );
2504
2505    if let (Some(read_certs), Some(write_certs)) = (
2506      &server_read_handshake.peer_certificates,
2507      &server_write_handshake.peer_certificates,
2508    ) {
2509      assert_eq!(read_certs.len(), write_certs.len());
2510    }
2511
2512    if let (Some(read_certs), Some(write_certs)) = (
2513      &client_read_handshake.peer_certificates,
2514      &client_write_handshake.peer_certificates,
2515    ) {
2516      assert_eq!(read_certs.len(), write_certs.len());
2517    }
2518
2519    Ok(())
2520  }
2521}