1use crate::adapter::clone_error;
4use crate::adapter::clone_result;
5use crate::adapter::read_acceptor;
6use crate::adapter::rustls_to_io_error;
7use crate::adapter::write_acceptor_alert;
8use crate::connection_stream::ConnectionStream;
9use crate::handshake::handshake_task;
10use crate::handshake::HandshakeResult;
11use crate::trace;
12use crate::TestOptions;
13use derive_io::AsyncRead;
14use derive_io::AsyncWrite;
15use futures::task::AtomicWaker;
16use futures::FutureExt;
17use rustls::server::Acceptor;
18use rustls::server::ClientHello;
19use rustls::version::TLS13;
20use rustls::ClientConnection;
21use rustls::Connection;
22use rustls::ServerConfig;
23use rustls::ServerConnection;
24use socket2::SockRef;
25use std::any::Any;
26use std::fmt::Debug;
27use std::future::poll_fn;
28use std::future::Future;
29use std::io;
30use std::io::ErrorKind;
31use std::io::Write;
32use std::num::NonZeroUsize;
33use std::pin::Pin;
34use std::sync::Arc;
35use std::sync::Mutex;
36
37use std::task::ready;
38use std::task::Context;
39use std::task::Poll;
40use std::task::Waker;
41use std::thread::sleep;
42use std::time::Duration;
43use tokio::io::AsyncRead;
44use tokio::io::AsyncWrite;
45use tokio::io::ReadBuf;
46use tokio::net::TcpStream;
47use tokio::spawn;
48use tokio::task::spawn_blocking;
49use tokio::task::JoinError;
50use tokio::task::JoinHandle;
51
/// Read- and write-side wakers that are parked while the TLS handshake runs
/// in a background task.
///
/// Clones share the same state. After [`DeferredWakers::wake`] has run, the
/// state stays "woke": any waker registered later is woken immediately
/// instead of being stored.
#[derive(Clone)]
struct DeferredWakers {
  wakers: Arc<Mutex<DeferredWakersInner>>,
}

/// Shared state behind [`DeferredWakers`].
#[derive(Default)]
enum DeferredWakersInner {
  /// `wake` has already run. This is the `Default` so that `mem::take`
  /// leaves the state woke.
  #[default]
  Woke,
  /// Still waiting; holds the most recently registered (read, write) wakers.
  Pending(Option<Waker>, Option<Waker>),
}

impl DeferredWakers {
  /// Switches to the woke state and wakes any stored read/write wakers
  /// (read first). Later calls are no-ops.
  ///
  /// The wakers are taken out under the lock but invoked only after the
  /// lock has been released. A waker that re-enters this type (for example
  /// by registering a new waker) therefore cannot deadlock on the mutex.
  pub fn wake(&self) {
    // The guard is a temporary of this statement and is dropped before the
    // wakers run.
    let previous = std::mem::take(&mut *self.wakers.lock().unwrap());
    if let DeferredWakersInner::Pending(read, write) = previous {
      read.into_iter().chain(write).for_each(Waker::wake);
    }
  }

  /// Stores `waker` as the read-side waker, or wakes it immediately (outside
  /// the lock) if `wake` has already run.
  pub fn set_read_waker(&self, waker: &Waker) {
    {
      let mut state = self.wakers.lock().unwrap();
      if let DeferredWakersInner::Pending(read, _write) = &mut *state {
        *read = Some(waker.clone());
        return;
      }
    }
    waker.wake_by_ref();
  }

  /// Stores `waker` as the write-side waker, or wakes it immediately
  /// (outside the lock) if `wake` has already run.
  pub fn set_write_waker(&self, waker: &Waker) {
    {
      let mut state = self.wakers.lock().unwrap();
      if let DeferredWakersInner::Pending(_read, write) = &mut *state {
        *write = Some(waker.clone());
        return;
      }
    }
    waker.wake_by_ref();
  }
}

impl Default for DeferredWakers {
  /// Creates a handle in the pending state with no wakers registered.
  fn default() -> Self {
    Self {
      wakers: Arc::new(Mutex::new(DeferredWakersInner::Pending(None, None))),
    }
  }
}
114
/// Handshake outcome shared between the handshake task, the owning
/// `TlsStream` and its split halves.
#[derive(Default)]
struct HandshakeWatch {
  /// `None` until the handshake task publishes its result.
  handshake: Mutex<Option<io::Result<TlsHandshake>>>,
  /// Woken when the outcome is published; registered on the read side
  /// (`TlsStreamRead::poll_handshake`) and by `TlsStream::poll_handshake`.
  rx_waker: AtomicWaker,
  /// Woken when the outcome is published; registered on the write side
  /// (`TlsStreamWrite::poll_handshake`) and by `TlsStream::poll_handshake`.
  tx_waker: AtomicWaker,
}
121
/// Lifecycle of a `TlsStream`.
// One variant is much larger than the others; the lint is silenced rather
// than boxing it. NOTE(review): rationale not recorded here — confirm.
#[allow(clippy::large_enum_variant)]
enum TlsStreamState<S: UnderlyingStream> {
  /// The handshake is running in a spawned task.
  Handshaking {
    /// Task driving the handshake; resolves to the handshaken connection.
    handle: JoinHandle<io::Result<HandshakeResult<S>>>,
    /// Read/write wakers parked until the handshake task finishes.
    wakers: DeferredWakers,
    /// Plaintext accepted by writes before the handshake finished; it is
    /// queued into the connection once it opens.
    write_buf: Vec<u8>,
    /// Shared transport handle, kept so socket/address accessors work
    /// during the handshake.
    underlying: Arc<S>,
  },
  /// Handshake completed; data flows through the connection stream.
  Open(ConnectionStream<S>),
  /// Closed without a recorded error.
  Closed,
  /// Closed because of an error; I/O calls return a clone of it.
  ClosedError(io::Error),
}
139
/// Asynchronously chooses the [`ServerConfig`] for an incoming connection
/// based on its [`ClientHello`].
///
/// Used by the acceptor-based server constructors. If the returned future
/// resolves to an error, a fatal alert record is written to the peer and the
/// handshake fails with that error.
pub type ServerConfigProvider = Arc<
  dyn Fn(
      ClientHello<'_>,
    ) -> Pin<
      Box<dyn Future<Output = Result<Arc<ServerConfig>, io::Error>> + Send>,
    > + Send
    + Sync,
>;
148
/// A non-blocking, readiness-based byte transport that a TLS session can run
/// over (for example a Tokio TCP or Unix socket).
pub trait UnderlyingStream: Debug + Send + Sync + Sized + 'static {
  /// The blocking `std` counterpart returned by [`Self::into_std`].
  type StdType: Send;

  /// Polls until the transport is readable.
  fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
  /// Polls until the transport is writable.
  fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
  /// Non-blocking read into `buf`.
  fn try_read(&self, buf: &mut [u8]) -> io::Result<usize>;
  /// Non-blocking write from `buf`.
  fn try_write(&self, buf: &[u8]) -> io::Result<usize>;
  /// Resolves once the transport is readable.
  fn readable(&self) -> impl Future<Output = io::Result<()>> + Send;
  /// Resolves once the transport is writable.
  fn writable(&self) -> impl Future<Output = io::Result<()>> + Send;

  /// Shuts down the read and/or write direction.
  fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()>;

  /// Converts into the `std` type. Defaults to `None` (unsupported).
  fn into_std(self) -> Option<std::io::Result<Self::StdType>> {
    None
  }

  /// Recovers the concrete type `S`, or hands `self` back unchanged when
  /// `Self` is not `S`.
  fn downcast<S: UnderlyingStream>(self) -> Result<S, Self> {
    // `Option` lets the value be moved out through `&mut dyn Any`.
    let mut slot = Some(self);
    match (&mut slot as &mut dyn Any).downcast_mut::<Option<S>>() {
      Some(target) => Ok(target.take().unwrap()),
      None => Err(slot.take().unwrap()),
    }
  }
}
174
175impl UnderlyingStream for TcpStream {
176 type StdType = std::net::TcpStream;
177 #[inline(always)]
178 fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
179 self.poll_read_ready(cx)
180 }
181 #[inline(always)]
182 fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
183 self.poll_write_ready(cx)
184 }
185 #[inline(always)]
186 fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
187 self.try_read(buf)
188 }
189 #[inline(always)]
190 fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
191 self.try_write(buf)
192 }
193 #[inline(always)]
194 fn readable(&self) -> impl Future<Output = io::Result<()>> + Send {
195 self.readable()
196 }
197 #[inline(always)]
198 fn writable(&self) -> impl Future<Output = io::Result<()>> + Send {
199 self.writable()
200 }
201 #[inline(always)]
202 fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
203 SockRef::from(&self).shutdown(how)
204 }
205 #[inline(always)]
206 fn into_std(self) -> Option<std::io::Result<std::net::TcpStream>> {
207 Some(self.into_std())
208 }
209}
210
211#[cfg(unix)]
212impl UnderlyingStream for tokio::net::UnixStream {
213 type StdType = std::os::unix::net::UnixStream;
214 #[inline(always)]
215 fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
216 self.poll_read_ready(cx)
217 }
218 #[inline(always)]
219 fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
220 self.poll_write_ready(cx)
221 }
222 #[inline(always)]
223 fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
224 self.try_read(buf)
225 }
226 #[inline(always)]
227 fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
228 self.try_write(buf)
229 }
230 #[inline(always)]
231 fn readable(&self) -> impl Future<Output = io::Result<()>> + Send {
232 self.readable()
233 }
234 #[inline(always)]
235 fn writable(&self) -> impl Future<Output = io::Result<()>> + Send {
236 self.writable()
237 }
238 #[inline(always)]
239 fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
240 SockRef::from(&self).shutdown(how)
241 }
242 #[inline(always)]
243 fn into_std(self) -> Option<std::io::Result<std::os::unix::net::UnixStream>> {
244 Some(self.into_std())
245 }
246}
247
/// A TLS stream over an [`UnderlyingStream`] whose handshake runs in a
/// spawned background task.
///
/// Reads wait for the handshake to finish. Writes issued during the
/// handshake are buffered (up to `buffer_size` when set) and queued into the
/// connection once it opens.
pub struct TlsStream<S: UnderlyingStream> {
  /// Current lifecycle state.
  state: TlsStreamState<S>,

  /// Handshake outcome, shared with the handshake task and split halves.
  handshake: Arc<HandshakeWatch>,
  /// Cap applied to the TLS connection's buffers and to plaintext buffered
  /// during the handshake; `None` means no cap.
  buffer_size: Option<NonZeroUsize>,
}
255
256impl<S: UnderlyingStream> Debug for TlsStream<S> {
257 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258 match &self.state {
259 TlsStreamState::Handshaking { .. } => {
260 f.write_str("TlsStream { Handshaking }")
261 }
262 TlsStreamState::Open(..) => f.write_fmt(format_args!(
263 "TlsStream {{ Open, handshake: {:?} }}",
264 self.handshake.handshake.lock().unwrap()
265 )),
266 TlsStreamState::Closed => f.write_str("TlsStream { Closed }"),
267 TlsStreamState::ClosedError(err) => {
268 f.write_fmt(format_args!("TlsStream {{ Closed, error: {:?} }}", err))
269 }
270 }
271 }
272}
273
/// Summary of a completed TLS handshake, as published by the handshake task.
#[derive(Clone, Debug)]
pub struct TlsHandshake {
  /// Negotiated ALPN protocol, if any.
  pub alpn: Option<Vec<u8>>,
  /// SNI server name received from the client. Only populated for
  /// server-side connections.
  pub sni: Option<String>,
  /// `true` if the peer presented at least one certificate.
  pub has_peer_certificates: bool,
  /// Certificates presented by the peer, if any.
  pub peer_certificates:
    Option<Vec<rustls::pki_types::CertificateDer<'static>>>,
}
286
287impl TlsStream<TcpStream> {
288 pub fn linger(&self) -> Result<Option<Duration>, io::Error> {
289 match &self.state {
290 TlsStreamState::Open(stm) => stm.underlying_stream().linger(),
291 TlsStreamState::Handshaking {
292 underlying: tcp, ..
293 } => tcp.linger(),
294 TlsStreamState::Closed | TlsStreamState::ClosedError(_) => {
295 Err(std::io::ErrorKind::NotConnected.into())
296 }
297 }
298 }
299
300 pub fn set_linger(&self, dur: Option<Duration>) -> Result<(), io::Error> {
301 match &self.state {
302 TlsStreamState::Open(stm) => stm.underlying_stream().set_linger(dur),
303 TlsStreamState::Handshaking {
304 underlying: tcp, ..
305 } => tcp.set_linger(dur),
306 TlsStreamState::Closed | TlsStreamState::ClosedError(_) => {
307 Err(std::io::ErrorKind::NotConnected.into())
308 }
309 }
310 }
311
312 pub fn peer_addr(&self) -> Result<std::net::SocketAddr, io::Error> {
314 match &self.state {
315 TlsStreamState::Open(stm) => stm.underlying_stream().peer_addr(),
316 TlsStreamState::Handshaking {
317 underlying: tcp, ..
318 } => tcp.peer_addr(),
319 TlsStreamState::Closed | TlsStreamState::ClosedError(_) => {
320 Err(std::io::ErrorKind::NotConnected.into())
321 }
322 }
323 }
324
325 pub fn local_addr(&self) -> Result<std::net::SocketAddr, io::Error> {
327 match &self.state {
328 TlsStreamState::Open(stm) => stm.underlying_stream().local_addr(),
329 TlsStreamState::Handshaking {
330 underlying: tcp, ..
331 } => tcp.local_addr(),
332 TlsStreamState::Closed | TlsStreamState::ClosedError(_) => {
333 Err(std::io::ErrorKind::NotConnected.into())
334 }
335 }
336 }
337}
338
339#[cfg(unix)]
340impl TlsStream<tokio::net::UnixStream> {
341 pub fn peer_addr(&self) -> Result<tokio::net::unix::SocketAddr, io::Error> {
342 match &self.state {
343 TlsStreamState::Open(stm) => stm.underlying_stream().peer_addr(),
344 TlsStreamState::Handshaking {
345 underlying: tcp, ..
346 } => tcp.peer_addr(),
347 TlsStreamState::Closed | TlsStreamState::ClosedError(_) => {
348 Err(std::io::ErrorKind::NotConnected.into())
349 }
350 }
351 }
352
353 pub fn local_addr(&self) -> Result<tokio::net::unix::SocketAddr, io::Error> {
354 match &self.state {
355 TlsStreamState::Open(stm) => stm.underlying_stream().local_addr(),
356 TlsStreamState::Handshaking {
357 underlying: tcp, ..
358 } => tcp.local_addr(),
359 TlsStreamState::Closed | TlsStreamState::ClosedError(_) => {
360 Err(std::io::ErrorKind::NotConnected.into())
361 }
362 }
363 }
364}
365
366impl<S: UnderlyingStream + 'static> TlsStream<S> {
367 fn new(
368 tcp: S,
369 mut tls: Connection,
370 buffer_size: Option<NonZeroUsize>,
371 test_options: TestOptions,
372 ) -> Self {
373 tls.set_buffer_limit(buffer_size.map(|s| s.get()));
374 let handshake = Arc::new(HandshakeWatch::default());
375 let wakers = DeferredWakers::default();
376 let wakers_clone = wakers.clone();
377 let tcp = Arc::new(tcp);
378 let tcp_handshake = tcp.clone();
379
380 let handshake_send = handshake.clone();
381 let handle = spawn(async move {
382 let res =
383 send_handshake(tcp_handshake, Ok(tls), test_options, handshake_send)
384 .await;
385
386 wakers_clone.wake();
388
389 res
390 });
391
392 Self {
393 state: TlsStreamState::Handshaking {
394 handle,
395 wakers,
396 write_buf: vec![],
397 underlying: tcp,
398 },
399 handshake,
400 buffer_size,
401 }
402 }
403
404 async fn accept(
405 mut acceptor: Acceptor,
406 tcp_handshake: &S,
407 server_config_provider: ServerConfigProvider,
408 ) -> Result<ServerConnection, io::Error> {
409 loop {
410 tcp_handshake.readable().await?;
411 if read_acceptor(tcp_handshake, &mut acceptor)? < 1 {
413 return Err(io::ErrorKind::ConnectionReset.into());
414 }
415
416 let accepted = match acceptor.accept() {
417 Ok(Some(accepted)) => accepted,
418 Ok(None) => continue,
419 Err((e, alert)) => {
420 tcp_handshake.writable().await?;
421 write_acceptor_alert(tcp_handshake, alert)?;
422 return Err(rustls_to_io_error(e));
423 }
424 };
425
426 let config = match server_config_provider(accepted.client_hello()).await {
427 Ok(config) => config,
428 Err(err) => {
429 const FATAL_ALERT: &[u8] = b"\x15\x03\x03\x00\x02\x02\x00";
443 for c in FATAL_ALERT {
444 tcp_handshake.writable().await?;
445 tcp_handshake.try_write(&[*c])?;
446 }
447 return Err(err);
448 }
449 };
450 match accepted.into_connection(config) {
451 Ok(tls) => {
452 return Ok(tls);
453 }
454 Err((e, alert)) => {
455 tcp_handshake.writable().await?;
456 write_acceptor_alert(tcp_handshake, alert)?;
457 return Err(rustls_to_io_error(e));
458 }
459 }
460 }
461 }
462
463 fn new_server_acceptor(
464 acceptor: Acceptor,
465 tcp: S,
466 server_config_provider: ServerConfigProvider,
467 buffer_size: Option<NonZeroUsize>,
468 test_options: TestOptions,
469 ) -> Self {
470 let handshake = Arc::new(HandshakeWatch::default());
471 let wakers = DeferredWakers::default();
472 let wakers_clone = wakers.clone();
473 let tcp = Arc::new(tcp);
474 let tcp_handshake = tcp.clone();
475
476 let handshake_send = handshake.clone();
477
478 let handle = spawn(async move {
479 let tls =
480 Self::accept(acceptor, &tcp_handshake, server_config_provider).await;
481 let res = send_handshake(
482 tcp_handshake,
483 tls.map(rustls::Connection::Server),
484 test_options,
485 handshake_send,
486 )
487 .await;
488
489 wakers_clone.wake();
491
492 res
493 });
494
495 Self {
496 state: TlsStreamState::Handshaking {
497 handle,
498 wakers,
499 write_buf: vec![],
500 underlying: tcp,
501 },
502 handshake,
503 buffer_size,
504 }
505 }
506
507 pub fn new_client_side(
508 tcp: S,
509 tls: ClientConnection,
510 buffer_size: Option<NonZeroUsize>,
511 ) -> Self {
512 Self::new(
513 tcp,
514 Connection::Client(tls),
515 buffer_size,
516 TestOptions::default(),
517 )
518 }
519
520 #[cfg(test)]
521 pub(crate) fn new_client_side_test_options(
522 tcp: S,
523 tls_config: Arc<rustls::ClientConfig>,
524 server_name: rustls::pki_types::ServerName<'_>,
525 buffer_size: Option<NonZeroUsize>,
526 test_options: TestOptions,
527 ) -> Self {
528 let tls =
529 ClientConnection::new(tls_config, server_name.to_owned()).unwrap();
530 Self::new(tcp, Connection::Client(tls), buffer_size, test_options)
531 }
532
533 pub fn new_client_side_from(
534 tcp: S,
535 connection: ClientConnection,
536 buffer_size: Option<NonZeroUsize>,
537 ) -> Self {
538 Self::new(
539 tcp,
540 Connection::Client(connection),
541 buffer_size,
542 TestOptions::default(),
543 )
544 }
545
546 #[cfg(test)]
547 pub(crate) fn new_server_side_test_options(
548 tcp: S,
549 tls_config: Arc<ServerConfig>,
550 buffer_size: Option<NonZeroUsize>,
551 test_options: TestOptions,
552 ) -> Self {
553 let tls = ServerConnection::new(tls_config).unwrap();
554 Self::new(tcp, Connection::Server(tls), buffer_size, test_options)
555 }
556
557 pub fn new_server_side(
558 tcp: S,
559 tls_config: Arc<ServerConfig>,
560 buffer_size: Option<NonZeroUsize>,
561 ) -> Self {
562 let tls = ServerConnection::new(tls_config).unwrap();
563 Self::new(
564 tcp,
565 Connection::Server(tls),
566 buffer_size,
567 TestOptions::default(),
568 )
569 }
570
571 pub fn new_server_side_acceptor(
575 tcp: S,
576 server_config_provider: ServerConfigProvider,
577 buffer_size: Option<NonZeroUsize>,
578 ) -> Self {
579 Self::new_server_acceptor(
580 Acceptor::default(),
581 tcp,
582 server_config_provider,
583 buffer_size,
584 TestOptions::default(),
585 )
586 }
587
588 pub fn new_server_side_from_acceptor(
595 acceptor: Acceptor,
596 tcp: S,
597 server_config_provider: ServerConfigProvider,
598 buffer_size: Option<NonZeroUsize>,
599 ) -> Self {
600 Self::new_server_acceptor(
601 acceptor,
602 tcp,
603 server_config_provider,
604 buffer_size,
605 TestOptions::default(),
606 )
607 }
608
609 pub fn new_server_side_from(
610 tcp: S,
611 connection: ServerConnection,
612 buffer_size: Option<NonZeroUsize>,
613 ) -> Self {
614 Self::new(
615 tcp,
616 Connection::Server(connection),
617 buffer_size,
618 TestOptions::default(),
619 )
620 }
621
622 pub fn try_into_inner(mut self) -> Result<(S, Connection), Self> {
624 match self.state {
625 TlsStreamState::Open(_) => {
626 let TlsStreamState::Open(stm) =
627 std::mem::replace(&mut self.state, TlsStreamState::Closed)
628 else {
629 unreachable!()
630 };
631 Ok(stm.into_inner())
632 }
633 _ => Err(self),
634 }
635 }
636
637 pub fn into_split(self) -> (TlsStreamRead<S>, TlsStreamWrite<S>) {
638 let handshake1 = self.handshake.clone();
639 let handshake2 = self.handshake.clone();
640 let tcp = match &self.state {
641 TlsStreamState::Handshaking {
642 underlying: tcp, ..
643 } => Some(tcp.clone()),
644 TlsStreamState::Open(conn) => Some(conn.underlying_stream().clone()),
645 _ => None,
646 };
647 let (r, w) = tokio::io::split(self);
648 let read = TlsStreamRead {
649 r,
650 handshake: handshake1,
651 tcp: tcp.clone(),
652 };
653 let write = TlsStreamWrite {
654 w,
655 handshake: handshake2,
656 tcp,
657 };
658 (read, write)
659 }
660
661 pub fn connection(&self) -> Option<&rustls::Connection> {
663 match &self.state {
664 TlsStreamState::Open(stm) => Some(stm.connection()),
665 _ => None,
666 }
667 }
668
669 pub async fn into_inner(mut self) -> io::Result<(S, Connection)> {
670 poll_fn(|cx| self.poll_pending_handshake(cx)).await?;
671 match std::mem::replace(&mut self.state, TlsStreamState::Closed) {
672 TlsStreamState::Open(stm) => Ok(stm.into_inner()),
673 TlsStreamState::Closed => Err(ErrorKind::NotConnected.into()),
674 TlsStreamState::ClosedError(err) => Err(err),
675 TlsStreamState::Handshaking { .. } => unreachable!(),
676 }
677 }
678
679 pub fn poll_handshake(
680 &mut self,
681 cx: &mut Context,
682 ) -> Poll<io::Result<TlsHandshake>> {
683 ready!(self.poll_pending_handshake(cx)?);
685
686 match &*self.handshake.handshake.lock().unwrap() {
688 None => {
689 self.handshake.rx_waker.register(cx.waker());
691 self.handshake.tx_waker.register(cx.waker());
692 Poll::Pending
693 }
694 Some(handshake) => Poll::Ready(clone_result(handshake)),
695 }
696 }
697
698 pub async fn handshake(&mut self) -> io::Result<TlsHandshake> {
699 poll_fn(|cx| self.poll_handshake(cx)).await
700 }
701
702 pub fn try_handshake(&self) -> io::Result<Option<TlsHandshake>> {
704 match &*self.handshake.handshake.lock().unwrap() {
705 None => Ok(None),
706 Some(r) => clone_result(r).map(Some),
707 }
708 }
709
  /// Leaves the `Handshaking` state using the handshake task's join result.
  ///
  /// * Task failed to join: the state becomes `ClosedError`. If the task
  ///   panicked, the panic is resumed here; a cancelled task is treated as a
  ///   bug (`unreachable!`).
  /// * Handshake failed: the state becomes `ClosedError` with a clone of the
  ///   error, and the error is returned.
  /// * Handshake succeeded: a `ConnectionStream` is built, plaintext
  ///   buffered during the handshake is queued into it, parked wakers are
  ///   woken, and the state becomes `Open`.
  ///
  /// Must only be called in `Handshaking`; any other state is `unreachable!`.
  fn finalize_handshake(
    &mut self,
    join_result: Result<io::Result<HandshakeResult<S>>, JoinError>,
  ) -> io::Result<()> {
    trace!("finalize handshake");
    match std::mem::replace(&mut self.state, TlsStreamState::Closed) {
      TlsStreamState::Handshaking {
        wakers,
        write_buf: buf,
        ..
      } => {
        trace!("join={join_result:?}");
        match join_result {
          Err(err) => {
            // Leave a terminal error state behind before unwinding.
            self.state = TlsStreamState::ClosedError(ErrorKind::Other.into());
            if err.is_panic() {
              // Propagate the handshake task's panic to this caller.
              std::panic::resume_unwind(err.into_panic());
            } else {
              unreachable!("Task should not have been cancelled");
            }
          }
          Ok(Err(err)) => {
            // Keep a copy so later I/O calls report the same failure.
            self.state = TlsStreamState::ClosedError(clone_error(&err));
            Err(err)
          }
          Ok(Ok(result)) => {
            let (tcp, tls) = result.into_inner();
            let mut stm = ConnectionStream::new(tcp, tls);
            trace!("hs buf={}", buf.len());
            // Queue everything that was written during the handshake.
            stm.write_buf_fully(&buf);

            wakers.wake();
            self.state = TlsStreamState::Open(stm);
            Ok(())
          }
        }
      }
      _ => unreachable!(),
    }
  }
755
756 fn poll_pending_handshake(
758 &mut self,
759 cx: &mut Context<'_>,
760 ) -> Poll<io::Result<()>> {
761 match &mut self.state {
762 TlsStreamState::Handshaking { handle, .. } => {
763 let res = ready!(handle.poll_unpin(cx));
764 Poll::Ready(self.finalize_handshake(res))
765 }
766 _ => Poll::Ready(Ok(())),
767 }
768 }
769
  /// Shared body of `poll_shutdown` (`abort == false`) and an abortive close
  /// (`abort == true`; the caller is presumably outside this view).
  ///
  /// When aborting while the handshake is still pending, the state is
  /// replaced with `Closed` and `Ok(())` is returned immediately. This drops
  /// the join handle (detaching the task) and any plaintext buffered during
  /// the handshake. Otherwise it waits for the handshake and then shuts down
  /// an open connection. Errors from the TLS shutdown itself are ignored; a
  /// handshake failure or an earlier close error is returned as `Err`.
  fn poll_shutdown_or_abort(
    mut self: Pin<&mut Self>,
    cx: &mut Context<'_>,
    abort: bool,
  ) -> Poll<io::Result<()>> {
    let res = if abort {
      match self.poll_pending_handshake(cx) {
        // Don't wait for an unfinished handshake when aborting.
        Poll::Pending => {
          self.state = TlsStreamState::Closed;
          return Poll::Ready(Ok(()));
        }
        Poll::Ready(res) => res,
      }
    } else {
      ready!(self.poll_pending_handshake(cx))
    };

    if let Err(err) = res {
      self.state = TlsStreamState::ClosedError(err);
    }

    match &mut self.state {
      TlsStreamState::Handshaking { .. } => {
        // The handshake was completed above.
        unreachable!()
      }
      TlsStreamState::Open(stm) => {
        // Shutdown errors are deliberately discarded.
        let _res = ready!(stm.poll_shutdown(cx));
        Poll::Ready(Ok(()))
      }
      TlsStreamState::Closed => Poll::Ready(Ok(())),
      TlsStreamState::ClosedError(err) => Poll::Ready(Err(clone_error(err))),
    }
  }
810
811 pub async fn close(mut self) -> io::Result<()> {
812 trace!("closing {self:?}");
813 let state = std::mem::replace(&mut self.state, TlsStreamState::Closed);
814 match state {
815 TlsStreamState::Handshaking {
816 handle,
817 wakers,
818 write_buf: buf,
819 ..
820 } => {
821 wakers.wake();
822 match handle.await {
823 Ok(Ok(result)) => {
824 let (tcp, tls) = result.into_inner();
826 let mut stm = ConnectionStream::new(tcp, tls);
827 poll_fn(|cx| stm.poll_write(cx, &buf)).await?;
828 poll_fn(|cx| stm.poll_shutdown(cx)).await?;
829 nonblocking_tcp_drop(stm);
830 }
831 Err(err) => {
832 if err.is_panic() {
833 std::panic::resume_unwind(err.into_panic());
835 } else {
836 unreachable!("Task should not have been cancelled");
837 }
838 }
839 Ok(Err(err)) => {
840 return Err(err);
841 }
842 }
843 }
844 TlsStreamState::Open(mut stm) => {
845 poll_fn(|cx| stm.poll_shutdown(cx)).await?;
846 nonblocking_tcp_drop(stm);
847 }
848 TlsStreamState::Closed | TlsStreamState::ClosedError(_) => {
849 }
851 }
852
853 Ok(())
854 }
855}
856
857impl<S: UnderlyingStream> TlsStream<S> {
858 pub fn underlying_stream(&self) -> Option<&S> {
860 match &self.state {
861 TlsStreamState::Open(stm) => Some(stm.underlying_stream()),
862 TlsStreamState::Handshaking {
863 underlying: tcp, ..
864 } => Some(tcp),
865 _ => None,
866 }
867 }
868}
869
870async fn send_handshake<S: UnderlyingStream>(
871 tcp: Arc<S>,
872 tls: Result<Connection, io::Error>,
873 test_options: TestOptions,
874 handshake: Arc<HandshakeWatch>,
875) -> Result<HandshakeResult<S>, io::Error> {
876 let tls = match tls {
877 Ok(tls) => tls,
878 Err(err) => {
879 *handshake.handshake.lock().unwrap() = Some(Err(clone_error(&err)));
880 handshake.rx_waker.wake();
881 handshake.tx_waker.wake();
882 return Err(err);
883 }
884 };
885
886 #[cfg(test)]
887 if test_options.delay_handshake {
888 tokio::time::sleep(std::time::Duration::from_millis(1000)).await;
889 }
890 let res = handshake_task(tcp, tls, test_options).await;
891 match &res {
892 Ok(res) => {
893 let peer_certificates = res
894 .1
895 .peer_certificates()
896 .map(|certs| certs.iter().map(|cert| cert.clone()).collect());
897 let has_peer_certificates = peer_certificates
898 .as_ref()
899 .map(|c: &Vec<rustls::pki_types::CertificateDer<'static>>| {
900 !c.is_empty()
901 })
902 .unwrap_or_default();
903 let alpn = res.1.alpn_protocol().map(|v| v.to_owned());
904 let sni = match &res.1 {
905 Connection::Server(server) => {
906 server.server_name().map(|s| s.to_owned())
907 }
908 _ => None,
909 };
910 *handshake.handshake.lock().unwrap() = Some(Ok(TlsHandshake {
911 alpn,
912 sni,
913 has_peer_certificates,
914 peer_certificates,
915 }));
916 }
917 Err(err) => {
918 *handshake.handshake.lock().unwrap() = Some(Err(clone_error(err)));
919 }
920 }
921 handshake.rx_waker.wake();
922 handshake.tx_waker.wake();
923 res
924}
925
926fn nonblocking_tcp_drop<S: UnderlyingStream>(stm: ConnectionStream<S>) {
932 let (inner, tls) = stm.into_inner();
935 if matches!(tls, Connection::Client(_))
936 && tls.protocol_version() == Some(TLS13.version)
937 {
938 if let Ok(tcp) = inner.downcast::<TcpStream>() {
939 if let Ok(tcp) = tcp.into_std() {
940 spawn_blocking(move || {
941 trace!("in drop tcp task");
942 sleep(Duration::from_millis(100));
943 drop(tcp);
944 trace!("done drop tcp task");
945 });
946 }
947 }
948 }
949}
950
951impl<S: UnderlyingStream> AsyncRead for TlsStream<S> {
952 fn poll_read(
953 mut self: Pin<&mut Self>,
954 cx: &mut Context<'_>,
955 buf: &mut ReadBuf<'_>,
956 ) -> Poll<io::Result<()>> {
957 loop {
958 break match &mut self.state {
959 TlsStreamState::Handshaking { handle, wakers, .. } => {
960 if handle.is_finished() {
962 let res = ready!(handle.poll_unpin(cx));
964 self.finalize_handshake(res)?;
965 continue;
966 }
967
968 wakers.set_read_waker(cx.waker());
970
971 Poll::Pending
972 }
973 TlsStreamState::Open(ref mut stm) => {
974 match std::task::ready!(stm.poll_read(cx, buf)) {
975 Ok(_n) => {
976 Poll::Ready(Ok(()))
978 }
979 Err(err) => Poll::Ready(Err(err)),
980 }
981 }
982 TlsStreamState::Closed => Poll::Ready(Ok(())),
983 TlsStreamState::ClosedError(err) => Poll::Ready(Err(clone_error(err))),
984 };
985 }
986 }
987}
988
989impl<S: UnderlyingStream> AsyncWrite for TlsStream<S> {
990 fn poll_write(
991 mut self: Pin<&mut Self>,
992 cx: &mut Context<'_>,
993 buf: &[u8],
994 ) -> Poll<io::Result<usize>> {
995 let buffer_size = self.buffer_size;
997 loop {
998 break match &mut self.state {
999 TlsStreamState::Handshaking {
1000 handle,
1001 wakers,
1002 write_buf,
1003 ..
1004 } => {
1005 if handle.is_finished() {
1007 let res = ready!(handle.poll_unpin(cx));
1009 self.finalize_handshake(res)?;
1010 continue;
1011 }
1012
1013 if let Some(buffer_size) = buffer_size {
1014 let remaining = buffer_size.get() - write_buf.len();
1015 if remaining == 0 {
1016 wakers.set_write_waker(cx.waker());
1018 trace!("write limit");
1019 Poll::Pending
1020 } else {
1021 trace!("write buf");
1022 if buf.len() <= remaining {
1023 write_buf.extend_from_slice(buf);
1024 Poll::Ready(Ok(buf.len()))
1025 } else {
1026 write_buf.extend_from_slice(&buf[0..remaining]);
1027 Poll::Ready(Ok(remaining))
1028 }
1029 }
1030 } else {
1031 trace!("write buf");
1032 write_buf.extend_from_slice(buf);
1033 Poll::Ready(Ok(buf.len()))
1034 }
1035 }
1036 TlsStreamState::Open(ref mut stm) => stm.poll_write(cx, buf),
1037 TlsStreamState::Closed => {
1038 Poll::Ready(Err(ErrorKind::NotConnected.into()))
1039 }
1040 TlsStreamState::ClosedError(err) => Poll::Ready(Err(clone_error(err))),
1041 };
1042 }
1043 }
1044
1045 fn poll_write_vectored(
1046 mut self: Pin<&mut Self>,
1047 cx: &mut Context<'_>,
1048 bufs: &[std::io::IoSlice<'_>],
1049 ) -> Poll<Result<usize, io::Error>> {
1050 let buffer_size = self.buffer_size;
1052 loop {
1053 break match &mut self.state {
1054 TlsStreamState::Handshaking {
1055 handle,
1056 wakers,
1057 write_buf,
1058 ..
1059 } => {
1060 if handle.is_finished() {
1062 let res = ready!(handle.poll_unpin(cx));
1064 self.finalize_handshake(res)?;
1065 continue;
1066 }
1067 if let Some(buffer_size) = buffer_size {
1068 let mut remaining = buffer_size.get() - write_buf.len();
1069 if remaining == 0 {
1070 wakers.set_write_waker(cx.waker());
1072 trace!("write limit");
1073 Poll::Pending
1074 } else {
1075 trace!("write buf");
1076 let mut wrote = 0;
1077 for buf in bufs {
1078 if buf.len() <= remaining {
1079 write_buf.extend_from_slice(buf);
1080 wrote += buf.len();
1081 remaining -= buf.len();
1082 } else {
1083 write_buf.extend_from_slice(&buf[0..remaining]);
1084 wrote += remaining;
1085 break;
1086 }
1087 }
1088
1089 Poll::Ready(Ok(wrote))
1090 }
1091 } else {
1092 trace!("write buf");
1093 Poll::Ready(Ok(write_buf.write_vectored(bufs).unwrap()))
1094 }
1095 }
1096 TlsStreamState::Open(ref mut stm) => stm.poll_write_vectored(cx, bufs),
1097 TlsStreamState::Closed => {
1098 Poll::Ready(Err(ErrorKind::NotConnected.into()))
1099 }
1100 TlsStreamState::ClosedError(err) => Poll::Ready(Err(clone_error(err))),
1101 };
1102 }
1103 }
1104
1105 fn poll_flush(
1106 mut self: Pin<&mut Self>,
1107 cx: &mut Context<'_>,
1108 ) -> Poll<io::Result<()>> {
1109 loop {
1110 break match &mut self.state {
1111 TlsStreamState::Handshaking { wakers, handle, .. } => {
1112 if handle.is_finished() {
1114 let res = ready!(handle.poll_unpin(cx));
1116 self.finalize_handshake(res)?;
1117 continue;
1118 }
1119
1120 wakers.set_write_waker(cx.waker());
1121 Poll::Pending
1122 }
1123 TlsStreamState::Open(stm) => stm.poll_flush(cx),
1124 TlsStreamState::Closed => {
1125 Poll::Ready(Err(ErrorKind::NotConnected.into()))
1126 }
1127 TlsStreamState::ClosedError(err) => Poll::Ready(Err(clone_error(err))),
1128 };
1129 }
1130 }
1131
1132 fn is_write_vectored(&self) -> bool {
1133 false
1136 }
1137
1138 fn poll_shutdown(
1139 self: Pin<&mut Self>,
1140 cx: &mut Context<'_>,
1141 ) -> Poll<Result<(), io::Error>> {
1142 self.poll_shutdown_or_abort(cx, false)
1143 }
1144}
1145
/// On drop, spawns a background task that shuts the connection down
/// gracefully instead of closing the transport abruptly.
///
/// * `Handshaking`: the task waits for the handshake, queues plaintext
///   written during the handshake, shuts the TLS session down and passes the
///   connection to `nonblocking_tcp_drop`. Handshake errors and task panics
///   are only traced (the panic is not resumed).
/// * `Open`: the task shuts the TLS session down and calls
///   `nonblocking_tcp_drop`.
/// * `Closed` / `ClosedError`: nothing to do.
///
/// NOTE(review): `tokio::spawn` panics outside a Tokio runtime, so dropping
/// an active stream there will panic — confirm this is acceptable.
impl<S: UnderlyingStream> Drop for TlsStream<S> {
  fn drop(&mut self) {
    trace!("dropping {self:?}");
    let state = std::mem::replace(&mut self.state, TlsStreamState::Closed);
    match state {
      TlsStreamState::Handshaking {
        handle,
        write_buf,
        underlying: tcp,
        ..
      } => {
        spawn(async move {
          trace!("in drop task");
          match handle.await {
            Ok(Ok(result)) => {
              // Release this clone of the transport before unpacking the
              // result. Presumably this lets `ConnectionStream::into_inner`
              // (reached via `nonblocking_tcp_drop`) reclaim sole ownership —
              // NOTE(review): confirm.
              drop(tcp);
              let (tcp, tls) = result.into_inner();
              let mut stm = ConnectionStream::new(tcp, tls);
              stm.write_buf_fully(&write_buf);
              let res = poll_fn(|cx| stm.poll_shutdown(cx)).await;
              trace!("shutdown handshake {:?}", res);
              nonblocking_tcp_drop(stm);
            }
            x @ Err(_) => {
              trace!("{x:?}");
            }
            x @ Ok(Err(_)) => {
              trace!("{x:?}");
            }
          }
          trace!("done drop task");
        });
      }
      TlsStreamState::Open(mut stm) => {
        spawn(async move {
          trace!("in drop task");
          let res = poll_fn(|cx| stm.poll_shutdown(cx)).await;
          trace!("shutdown open {:?}", res);
          nonblocking_tcp_drop(stm);
          trace!("done drop task");
        });
      }
      TlsStreamState::Closed | TlsStreamState::ClosedError(_) => {
        // Nothing left to shut down.
      }
    }
  }
}
1195
/// Read half of a `TlsStream`, produced by `TlsStream::into_split`.
///
/// NOTE(review): `AsyncRead` is generated by `derive_io`; presumably it
/// forwards to the field marked `#[read]` — confirm against the derive_io
/// docs.
#[derive(AsyncRead)]
pub struct TlsStreamRead<S: UnderlyingStream> {
  /// Tokio read half of the shared stream.
  #[read]
  r: tokio::io::ReadHalf<TlsStream<S>>,
  /// Handshake outcome shared with the stream and the write half.
  handshake: Arc<HandshakeWatch>,
  /// Transport handle captured at split time, used for address queries;
  /// `None` if the stream was already closed when it was split.
  tcp: Option<Arc<S>>,
}
1204
1205impl TlsStreamRead<TcpStream> {
1206 pub fn peer_addr(&self) -> Result<std::net::SocketAddr, io::Error> {
1208 let Some(tcp) = &self.tcp else {
1209 return Err(std::io::ErrorKind::NotConnected.into());
1210 };
1211 tcp.peer_addr()
1212 }
1213
1214 pub fn local_addr(&self) -> Result<std::net::SocketAddr, io::Error> {
1216 let Some(tcp) = &self.tcp else {
1217 return Err(std::io::ErrorKind::NotConnected.into());
1218 };
1219 tcp.local_addr()
1220 }
1221}
1222
1223impl<S: UnderlyingStream> TlsStreamRead<S> {
1224 pub fn unsplit(self, other: TlsStreamWrite<S>) -> TlsStream<S> {
1226 self.r.unsplit(other.w)
1227 }
1228
1229 pub fn poll_handshake(
1230 &mut self,
1231 cx: &mut Context,
1232 ) -> Poll<io::Result<TlsHandshake>> {
1233 match &*self.handshake.handshake.lock().unwrap() {
1235 None => {
1236 self.handshake.rx_waker.register(cx.waker());
1237 Poll::Pending
1238 }
1239 Some(handshake) => Poll::Ready(clone_result(handshake)),
1240 }
1241 }
1242
1243 pub async fn handshake(&mut self) -> io::Result<TlsHandshake> {
1244 poll_fn(|cx| self.poll_handshake(cx)).await
1245 }
1246
1247 pub fn try_handshake(&self) -> io::Result<Option<TlsHandshake>> {
1249 match &*self.handshake.handshake.lock().unwrap() {
1250 None => Ok(None),
1251 Some(r) => clone_result(r).map(Some),
1252 }
1253 }
1254}
1255
/// Write half of a `TlsStream`, produced by `TlsStream::into_split`.
///
/// NOTE(review): `AsyncWrite` is generated by `derive_io`; presumably it
/// forwards to the field marked `#[write]` — confirm against the derive_io
/// docs.
#[derive(AsyncWrite)]
pub struct TlsStreamWrite<S: UnderlyingStream> {
  /// Tokio write half of the shared stream.
  #[write]
  w: tokio::io::WriteHalf<TlsStream<S>>,
  /// Handshake outcome shared with the stream and the read half.
  handshake: Arc<HandshakeWatch>,
  /// Transport handle captured at split time, used for address queries;
  /// `None` if the stream was already closed when it was split.
  tcp: Option<Arc<S>>,
}
1264
1265impl TlsStreamWrite<TcpStream> {
1266 pub fn peer_addr(&self) -> Result<std::net::SocketAddr, io::Error> {
1268 let Some(tcp) = &self.tcp else {
1269 return Err(std::io::ErrorKind::NotConnected.into());
1270 };
1271 tcp.peer_addr()
1272 }
1273
1274 pub fn local_addr(&self) -> Result<std::net::SocketAddr, io::Error> {
1276 let Some(tcp) = &self.tcp else {
1277 return Err(std::io::ErrorKind::NotConnected.into());
1278 };
1279 tcp.local_addr()
1280 }
1281}
1282
1283impl<S: UnderlyingStream> TlsStreamWrite<S> {
1284 pub fn poll_handshake(
1285 &mut self,
1286 cx: &mut Context,
1287 ) -> Poll<io::Result<TlsHandshake>> {
1288 match &*self.handshake.handshake.lock().unwrap() {
1290 None => {
1291 self.handshake.tx_waker.register(cx.waker());
1292 Poll::Pending
1293 }
1294 Some(handshake) => Poll::Ready(clone_result(handshake)),
1295 }
1296 }
1297
1298 pub async fn handshake(&mut self) -> io::Result<TlsHandshake> {
1299 poll_fn(|cx| self.poll_handshake(cx)).await
1300 }
1301
1302 pub fn try_handshake(&self) -> io::Result<Option<TlsHandshake>> {
1304 match &*self.handshake.handshake.lock().unwrap() {
1305 None => Ok(None),
1306 Some(r) => clone_result(r).map(Some),
1307 }
1308 }
1309}
1310
1311#[cfg(test)]
1312pub(super) mod tests {
1313 use super::*;
1314 use crate::tests::certificate;
1315 use crate::tests::expect_io_error;
1316 use crate::tests::private_key;
1317 use crate::tests::UnsafeVerifier;
1318 use futures::stream::FuturesUnordered;
1319 use futures::FutureExt;
1320 use futures::StreamExt;
1321 use rstest::rstest;
1322 use rustls::version::TLS12;
1323 use rustls::ClientConfig;
1324 use rustls::SupportedProtocolVersion;
1325 use std::io::ErrorKind;
1326 use std::io::IoSlice;
1327 use std::net::Ipv4Addr;
1328 use std::net::SocketAddr;
1329 use std::net::SocketAddrV4;
1330 use std::time::Duration;
1331 use tokio::io::AsyncReadExt;
1332 use tokio::io::AsyncWriteExt;
1333 use tokio::net::TcpListener;
1334 use tokio::net::TcpSocket;
1335 use tokio::spawn;
1336 use tokio::sync::Barrier;
1337
1338 type TestResult = Result<(), std::io::Error>;
1339
1340 type TlsStream = super::TlsStream<TcpStream>;
1341
1342 fn server_config(alpn: &[&str]) -> ServerConfig {
1343 let mut config = ServerConfig::builder()
1344 .with_no_client_auth()
1345 .with_single_cert(vec![certificate()], private_key())
1346 .expect("Failed to build server config");
1347 config.alpn_protocols =
1348 alpn.iter().map(|v| v.as_bytes().to_owned()).collect();
1349 config
1350 }
1351
1352 fn server_config_protocol(
1353 protocol: &'static SupportedProtocolVersion,
1354 ) -> ServerConfig {
1355 let config = ServerConfig::builder_with_protocol_versions(&[protocol])
1356 .with_no_client_auth()
1357 .with_single_cert(vec![certificate()], private_key())
1358 .expect("Failed to build server config");
1359 config
1360 }
1361
1362 fn client_config(alpn: &[&str]) -> ClientConfig {
1363 let mut config = ClientConfig::builder()
1364 .dangerous()
1365 .with_custom_certificate_verifier(Arc::new(UnsafeVerifier {}))
1366 .with_no_client_auth();
1367 config.alpn_protocols =
1368 alpn.iter().map(|v| v.as_bytes().to_owned()).collect();
1369 config.enable_sni = true;
1370 config
1371 }
1372
1373 async fn tcp_pair() -> (TcpStream, TcpStream) {
1374 let listener = TcpListener::bind(SocketAddr::V4(SocketAddrV4::new(
1375 Ipv4Addr::LOCALHOST,
1376 0,
1377 )))
1378 .await
1379 .unwrap();
1380 let port = listener.local_addr().unwrap().port();
1381 let server = spawn(async move { listener.accept().await.unwrap().0 });
1382 let client = spawn(async move {
1383 TcpSocket::new_v4()
1384 .unwrap()
1385 .connect(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)))
1386 .await
1387 .unwrap()
1388 });
1389
1390 let (server, client) = (server.await.unwrap(), client.await.unwrap());
1391 (server, client)
1392 }
1393
1394 pub async fn tls_pair() -> (TlsStream, TlsStream) {
1395 tls_pair_buffer_size(None).await
1396 }
1397
1398 pub async fn tls_pair_protocol(
1399 buffer_size: Option<NonZeroUsize>,
1400 protocol: &'static SupportedProtocolVersion,
1401 ) -> (TlsStream, TlsStream) {
1402 let (server, client) = tcp_pair().await;
1403 let server = TlsStream::new_server_side(
1404 server,
1405 server_config_protocol(protocol).into(),
1406 None,
1407 );
1408 let client = TlsStream::new_client_side_test_options(
1409 client,
1410 client_config(&[]).into(),
1411 "example.com".try_into().unwrap(),
1412 buffer_size,
1413 TestOptions::default(),
1414 );
1415
1416 (server, client)
1417 }
1418
1419 pub async fn tls_pair_buffer_size(
1420 buffer_size: Option<NonZeroUsize>,
1421 ) -> (TlsStream, TlsStream) {
1422 let (server, client) = tcp_pair().await;
1423 let server =
1424 TlsStream::new_server_side(server, server_config(&[]).into(), None);
1425 let client = TlsStream::new_client_side_test_options(
1426 client,
1427 client_config(&[]).into(),
1428 "example.com".try_into().unwrap(),
1429 buffer_size,
1430 TestOptions::default(),
1431 );
1432
1433 (server, client)
1434 }
1435
1436 async fn tls_with_tcp_server(
1437 delay_handshake: bool,
1438 ) -> (TcpStream, TlsStream) {
1439 let (server, client) = tcp_pair().await;
1440 let client_test_options = TestOptions {
1441 delay_handshake,
1442 ..Default::default()
1443 };
1444 let client = TlsStream::new_client_side_test_options(
1445 client,
1446 client_config(&[]).into(),
1447 "example.com".try_into().unwrap(),
1448 None,
1449 client_test_options,
1450 );
1451 (server, client)
1452 }
1453
1454 async fn tls_pair_slow_handshake(
1455 delay_handshake: bool,
1456 slow_server: bool,
1457 slow_client: bool,
1458 buffer: bool,
1459 ) -> (TlsStream, TlsStream) {
1460 let (server, client) = tcp_pair().await;
1461 let server_test_options = TestOptions {
1462 delay_handshake,
1463 slow_handshake_read: slow_server,
1464 slow_handshake_write: slow_server,
1465 };
1466 let client_test_options = TestOptions {
1467 delay_handshake,
1468 slow_handshake_read: slow_client,
1469 slow_handshake_write: slow_client,
1470 };
1471 let buffer_size = if buffer {
1472 NonZeroUsize::new(1024)
1473 } else {
1474 None
1475 };
1476
1477 let server = TlsStream::new_server_side_test_options(
1478 server,
1479 server_config(&[]).into(),
1480 buffer_size,
1481 server_test_options,
1482 );
1483 let client = TlsStream::new_client_side_test_options(
1484 client,
1485 client_config(&[]).into(),
1486 "example.com".try_into().unwrap(),
1487 buffer_size,
1488 client_test_options,
1489 );
1490
1491 (server, client)
1492 }
1493
1494 async fn tls_pair_alpn(
1495 server_alpn: &[&str],
1496 server_buffer_size: Option<NonZeroUsize>,
1497 client_alpn: &[&str],
1498 client_buffer_size: Option<NonZeroUsize>,
1499 ) -> (TlsStream, TlsStream) {
1500 let (server, client) = tcp_pair().await;
1501 let server = TlsStream::new_server_side(
1502 server,
1503 server_config(server_alpn).into(),
1504 server_buffer_size,
1505 );
1506 let client = TlsStream::new_client_side_test_options(
1507 client,
1508 client_config(client_alpn).into(),
1509 "example.com".try_into().unwrap(),
1510 client_buffer_size,
1511 TestOptions::default(),
1512 );
1513
1514 (server, client)
1515 }
1516
1517 async fn make_config(
1518 alpn: Result<&'static [&'static str], &'static str>,
1519 ) -> Result<Arc<ServerConfig>, io::Error> {
1520 Ok(
1521 server_config(
1522 alpn.map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?,
1523 )
1524 .into(),
1525 )
1526 }
1527
1528 async fn tls_pair_alpn_acceptor(
1529 server_alpn: fn(
1530 ClientHello,
1531 ) -> Result<&'static [&'static str], &'static str>,
1532 server_buffer_size: Option<NonZeroUsize>,
1533 client_alpn: &[&str],
1534 client_buffer_size: Option<NonZeroUsize>,
1535 ) -> (TlsStream, TlsStream) {
1536 let (server, client) = tcp_pair().await;
1537 let server = TlsStream::new_server_side_acceptor(
1538 server,
1539 Arc::new(move |client_hello| {
1540 Box::pin(make_config(server_alpn(client_hello)))
1541 }),
1542 server_buffer_size,
1543 );
1544 let client = TlsStream::new_client_side_test_options(
1545 client,
1546 client_config(client_alpn).into(),
1547 "example.com".try_into().unwrap(),
1548 client_buffer_size,
1549 TestOptions::default(),
1550 );
1551
1552 (server, client)
1553 }
1554
1555 async fn tls_pair_alpn_from_acceptor(
1556 server_alpn: fn(
1557 ClientHello,
1558 ) -> Result<&'static [&'static str], &'static str>,
1559 server_buffer_size: Option<NonZeroUsize>,
1560 client_alpn: &[&str],
1561 client_buffer_size: Option<NonZeroUsize>,
1562 ) -> (TlsStream, TlsStream) {
1563 let (mut server, client) = tcp_pair().await;
1564
1565 let client = TlsStream::new_client_side_test_options(
1568 client,
1569 client_config(client_alpn).into(),
1570 "example.com".try_into().unwrap(),
1571 client_buffer_size,
1572 TestOptions::default(),
1573 );
1574
1575 let mut prefix = [0; 8];
1580 server
1581 .read_exact(&mut prefix)
1582 .await
1583 .expect("Failed to read prefix");
1584 let mut acceptor = Acceptor::default();
1585 assert_eq!(
1586 acceptor.read_tls(&mut prefix.as_slice()).unwrap(),
1587 prefix.len()
1588 );
1589
1590 let server = TlsStream::new_server_side_from_acceptor(
1591 acceptor,
1592 server,
1593 Arc::new(move |client_hello| {
1594 Box::pin(make_config(server_alpn(client_hello)))
1595 }),
1596 server_buffer_size,
1597 );
1598
1599 (server, client)
1600 }
1601
1602 async fn tls_pair_handshake_buffer_size(
1603 server_buffer_size: Option<NonZeroUsize>,
1604 client_buffer_size: Option<NonZeroUsize>,
1605 ) -> (TlsStream, TlsStream) {
1606 let (mut server, mut client) =
1607 tls_pair_alpn(&[], server_buffer_size, &[], client_buffer_size).await;
1608 let a = spawn(async move {
1609 server.handshake().await.unwrap();
1610 server
1611 });
1612 let b = spawn(async move {
1613 client.handshake().await.unwrap();
1614 client
1615 });
1616 (a.await.unwrap(), b.await.unwrap())
1617 }
1618
1619 async fn tls_pair_handshake() -> (TlsStream, TlsStream) {
1620 tls_pair_handshake_buffer_size(None, None).await
1621 }
1622
1623 async fn expect_eof_read(stm: &mut (impl AsyncReadExt + Unpin)) {
1624 let mut buf = [0_u8; 1];
1625 let e = stm.read(&mut buf).await.expect("Expected no error");
1626 assert_eq!(e, 0, "expected eof");
1627 }
1628
1629 async fn expect_io_error_read(
1630 stm: &mut (impl AsyncReadExt + Unpin),
1631 kind: io::ErrorKind,
1632 ) {
1633 let mut buf = [0_u8; 1];
1634 let e = stm.read(&mut buf).await.expect_err("Expected error");
1635 assert_eq!(e.kind(), kind);
1636 }
1637
1638 #[rstest]
1641 #[tokio::test]
1642 async fn test_client_server(
1643 #[values(true, false)] server_slow: bool,
1644 #[values(true, false)] client_slow: bool,
1645 #[values(true, false)] buffer: bool,
1646 ) -> TestResult {
1647 let (mut server, mut client) =
1648 tls_pair_slow_handshake(false, server_slow, client_slow, buffer).await;
1649 let a = spawn(async move {
1650 server.write_all(b"hello?").await.unwrap();
1651 let mut buf = [0; 6];
1652 server.read_exact(&mut buf).await.unwrap();
1653 assert_eq!(buf.as_slice(), b"hello!");
1654 });
1655 let b = spawn(async move {
1656 client.write_all(b"hello!").await.unwrap();
1657 let mut buf = [0; 6];
1658 client.read_exact(&mut buf).await.unwrap();
1659 });
1660 a.await?;
1661 b.await?;
1662
1663 Ok(())
1664 }
1665
1666 #[tokio::test]
1668 #[ntest::timeout(60000)]
1669 async fn test_flush_before_handshake() -> TestResult {
1670 let (mut server, mut client) = tls_pair().await;
1671 server.write_all(b"hello?").await.unwrap();
1672 server.flush().await.unwrap();
1673 let mut buf = [0; 6];
1674 assert_eq!(6, client.read_exact(&mut buf).await.unwrap());
1675 Ok(())
1676 }
1677
1678 #[rstest]
1679 #[tokio::test(flavor = "multi_thread")]
1680 #[ntest::timeout(60000)]
1681 async fn test_read_with_buffered_write(
1682 #[values(true, false)] delay_handshake: bool,
1683 #[values(true, false)] slow_server: bool,
1684 #[values(true, false)] slow_client: bool,
1685 #[values(true, false)] buffer: bool,
1686 ) -> TestResult {
1687 let (mut server, mut client) = tls_pair_slow_handshake(
1688 delay_handshake,
1689 slow_server,
1690 slow_client,
1691 buffer,
1692 )
1693 .await;
1694
1695 let a = tokio::task::spawn(async move {
1696 server.read_u8().await.unwrap();
1697 server.write_u8(1).await.unwrap();
1698 });
1699
1700 let b = tokio::task::spawn(async move {
1701 let buf = [0; 1024];
1702 client.write_all(&buf).await.unwrap();
1703 client.read_u8().await.unwrap();
1704 });
1705
1706 a.await.unwrap();
1707 b.await.unwrap();
1708
1709 Ok(())
1710 }
1711
1712 #[tokio::test]
1714 #[ntest::timeout(60000)]
1715 async fn test_client_server_alpn() -> TestResult {
1716 let (mut server, mut client) =
1717 tls_pair_alpn(&["a", "b", "c"], None, &["b"], None).await;
1718 let a = spawn(async move {
1719 let handshake = server.handshake().await.unwrap();
1720 assert_eq!(handshake.alpn, Some("b".as_bytes().to_vec()));
1721 assert_eq!(handshake.sni, Some("example.com".into()));
1722 server.write_all(b"hello?").await.unwrap();
1723 let mut buf = [0; 6];
1724 server.read_exact(&mut buf).await.unwrap();
1725 assert_eq!(buf.as_slice(), b"hello!");
1726 });
1727 let b = spawn(async move {
1728 let handshake = client.handshake().await.unwrap();
1729 assert_eq!(handshake.alpn, Some("b".as_bytes().to_vec()));
1730 client.write_all(b"hello!").await.unwrap();
1731 let mut buf = [0; 6];
1732 client.read_exact(&mut buf).await.unwrap();
1733 });
1734 a.await?;
1735 b.await?;
1736
1737 Ok(())
1738 }
1739
1740 fn alpn_handler(
1741 client_hello: ClientHello,
1742 ) -> Result<&'static [&'static str], &'static str> {
1743 if let Some(alpn) = client_hello.alpn() {
1744 for alpn in alpn {
1745 if alpn == b"a" {
1746 return Ok(&["a"]);
1747 }
1748 if alpn == b"b" {
1749 return Ok(&["b"]);
1750 }
1751 }
1752 }
1753 Err("bad server")
1754 }
1755
1756 #[rstest]
1758 #[case("a")]
1759 #[case("b")]
1760 #[case("c")]
1761 #[tokio::test]
1762 #[ntest::timeout(60000)]
1763 async fn test_client_server_alpn_acceptor(
1764 #[case] alpn: &'static str,
1765 #[values(true, false)] use_from: bool,
1766 ) -> TestResult {
1767 let (mut server, mut client) = if use_from {
1768 tls_pair_alpn_from_acceptor(alpn_handler, None, &[alpn], None).await
1769 } else {
1770 tls_pair_alpn_acceptor(alpn_handler, None, &[alpn], None).await
1771 };
1772 let a = spawn(async move {
1773 if alpn == "c" {
1774 server.handshake().await.expect_err("expected failure");
1775 return;
1776 }
1777 let handshake = server.handshake().await.unwrap();
1778 assert_eq!(handshake.alpn, Some(alpn.as_bytes().to_vec()));
1779 assert_eq!(handshake.sni, Some("example.com".into()));
1780 server.write_all(b"hello?").await.unwrap();
1781 let mut buf = [0; 6];
1782 server.read_exact(&mut buf).await.unwrap();
1783 assert_eq!(buf.as_slice(), b"hello!");
1784 });
1785 let b = spawn(async move {
1786 if alpn == "c" {
1787 client.handshake().await.expect_err("expected failure");
1788 return;
1789 }
1790 let handshake = client.handshake().await.unwrap();
1791 assert_eq!(handshake.alpn, Some(alpn.as_bytes().to_vec()));
1792 client.write_all(b"hello!").await.unwrap();
1793 let mut buf = [0; 6];
1794 client.read_exact(&mut buf).await.unwrap();
1795 });
1796 a.await?;
1797 b.await?;
1798
1799 Ok(())
1800 }
1801
1802 #[tokio::test]
1804 #[ntest::timeout(60000)]
1805 async fn test_client_server_alpn_mismatch() -> TestResult {
1806 let (mut server, mut client) =
1807 tls_pair_alpn(&["a"], None, &["b"], None).await;
1808 let a = spawn(async move {
1809 let e = server.handshake().await.expect_err("Expected a failure");
1810 assert_eq!(e.kind(), ErrorKind::InvalidData);
1811 assert_eq!(e.to_string(), "peer doesn't support any known protocol");
1812 let e = server.flush().await.expect_err("Expected a failure");
1813 assert_eq!(e.kind(), ErrorKind::InvalidData);
1814 });
1815 let b = spawn(async move {
1816 let e = client.handshake().await.expect_err("Expected a failure");
1817 assert_eq!(e.kind(), ErrorKind::InvalidData);
1818 assert_eq!(e.to_string(), "received fatal alert: NoApplicationProtocol");
1819 let e = client.flush().await.expect_err("Expected a failure");
1820 assert_eq!(e.kind(), ErrorKind::InvalidData);
1821 });
1822 a.await?;
1823 b.await?;
1824
1825 Ok(())
1826 }
1827
1828 #[tokio::test]
1830 #[ntest::timeout(60000)]
1831 async fn test_client_server_raw_connection() -> TestResult {
1832 let (mut server, mut client) =
1833 tls_pair_alpn(&["a"], None, &["a"], None).await;
1834
1835 assert!(server.connection().is_none());
1836 assert!(client.connection().is_none());
1837
1838 server.handshake().await?;
1839 client.handshake().await?;
1840
1841 assert!(server.connection().is_some());
1842 assert!(client.connection().is_some());
1843
1844 Ok(())
1845 }
1846
1847 #[tokio::test]
1848 async fn test_peer_and_local_addresses() {
1849 let (server, client) =
1850 tls_pair_slow_handshake(true, true, true, false).await;
1851 let barrier = Arc::new(Barrier::new(2));
1853 let barrier_clone = barrier.clone();
1854 let a = spawn(async move {
1855 loop {
1856 tokio::time::sleep(Duration::from_millis(10)).await;
1857 server.local_addr().unwrap();
1858 server.peer_addr().unwrap();
1859 if server.try_handshake().unwrap().is_some() {
1860 server.local_addr().unwrap();
1861 server.peer_addr().unwrap();
1862 break;
1863 }
1864 }
1865 barrier.wait().await;
1866 });
1867 let b = spawn(async move {
1868 loop {
1869 tokio::time::sleep(Duration::from_millis(10)).await;
1870 client.local_addr().unwrap();
1871 client.peer_addr().unwrap();
1872 if client.try_handshake().unwrap().is_some() {
1873 client.local_addr().unwrap();
1874 client.peer_addr().unwrap();
1875 break;
1876 }
1877 }
1878 barrier_clone.wait().await;
1879 });
1880 a.await.unwrap();
1881 b.await.unwrap();
1882 }
1883
1884 #[rstest]
1885 #[case(false, false)]
1886 #[case(false, true)]
1887 #[case(true, false)]
1888 #[case(true, true)]
1889 #[tokio::test]
1890 async fn test_client_immediate_close(
1891 #[case] server_slow: bool,
1892 #[case] client_slow: bool,
1893 ) -> TestResult {
1894 let (mut server, client) =
1895 tls_pair_slow_handshake(false, server_slow, client_slow, false).await;
1896 let a = spawn(async move {
1897 server.shutdown().await.unwrap();
1898 expect_eof_read(&mut server).await;
1901 drop(server);
1902 });
1903 let b = spawn(async move {
1904 drop(client);
1905 });
1906 a.await?;
1907 b.await?;
1908
1909 Ok(())
1910 }
1911
1912 #[rstest]
1942 #[case(false, false)]
1943 #[case(false, true)]
1944 #[case(true, false)]
1945 #[case(true, true)]
1946 #[tokio::test]
1947 async fn test_server_immediate_close(
1948 #[case] server_slow: bool,
1949 #[case] client_slow: bool,
1950 ) -> TestResult {
1951 let (server, mut client) =
1952 tls_pair_slow_handshake(false, server_slow, client_slow, false).await;
1953 let a = spawn(async move {
1954 drop(server);
1955 });
1956 let b = spawn(async move {
1957 client.shutdown().await.unwrap();
1958 expect_eof_read(&mut client).await;
1961 drop(client);
1962 });
1963 a.await?;
1964 b.await?;
1965
1966 Ok(())
1967 }
1968
1969 #[rstest]
1970 #[case(false, false)]
1971 #[case(false, true)]
1972 #[case(true, false)]
1973 #[case(true, true)]
1974 #[tokio::test]
1975 async fn test_orderly_shutdown(
1976 #[case] server_slow: bool,
1977 #[case] client_slow: bool,
1978 ) -> TestResult {
1979 let (mut server, mut client) =
1980 tls_pair_slow_handshake(false, server_slow, client_slow, false).await;
1981 let (tx, rx) = tokio::sync::oneshot::channel();
1982 let a = spawn(async move {
1983 server.write_all(b"hello?").await.unwrap();
1984 let mut buf = [0; 6];
1985 server.read_exact(&mut buf).await.unwrap();
1986 assert_eq!(buf.as_slice(), b"hello!");
1987 server.shutdown().await.unwrap();
1989 server.read_exact(&mut buf).await.unwrap();
1990 assert_eq!(buf.as_slice(), b"hello*");
1991 drop(server);
1993 tokio::time::sleep(Duration::from_millis(10)).await;
1994 tx.send(()).unwrap();
1995 });
1996 let b = spawn(async move {
1997 client.write_all(b"hello!").await.unwrap();
1998 let mut buf = [0; 6];
1999 client.read_exact(&mut buf).await.unwrap();
2000 assert_eq!(client.read(&mut buf).await.unwrap(), 0);
2001 client.write_all(b"hello*").await.unwrap();
2002 rx.await.unwrap();
2004 client.shutdown().await.unwrap();
2005 drop(client);
2006 });
2007 a.await?;
2008 b.await?;
2009
2010 Ok(())
2011 }
2012
  /// After a completed handshake, the server shuts down. It then reports
  /// `NotConnected` on a further write, while the client reads a clean EOF
  /// once the oneshot says the shutdown has happened.
  #[rstest]
  #[case(false, false)]
  #[case(false, true)]
  #[case(true, false)]
  #[case(true, true)]
  #[tokio::test]
  async fn test_server_shutdown_after_handshake(
    #[case] server_slow: bool,
    #[case] client_slow: bool,
  ) -> TestResult {
    let (mut server, mut client) =
      tls_pair_slow_handshake(false, server_slow, client_slow, false).await;
    let (tx, rx) = tokio::sync::oneshot::channel();
    let a = spawn(async move {
      server.handshake().await.unwrap();
      server.shutdown().await.unwrap();
      // Tell the client the shutdown is done before probing the write side.
      tx.send(()).unwrap();
      expect_io_error(
        server.write_all(b"hello?").await,
        io::ErrorKind::NotConnected,
      );
    });
    let b = spawn(async move {
      client.handshake().await.unwrap();
      // Only read once the server has shut down.
      rx.await.unwrap();
      expect_eof_read(&mut client).await;
    });
    a.await?;
    b.await?;

    Ok(())
  }
2048
2049 #[rstest]
2050 #[case(false, false)]
2051 #[case(false, true)]
2052 #[case(true, false)]
2053 #[case(true, true)]
2054 #[tokio::test]
2055 async fn test_server_shutdown_before_handshake(
2056 #[case] server_slow: bool,
2057 #[case] client_slow: bool,
2058 ) -> TestResult {
2059 let (mut server, mut client) =
2060 tls_pair_slow_handshake(false, server_slow, client_slow, false).await;
2061 let a = spawn(async move {
2062 let mut futures = FuturesUnordered::new();
2063
2064 futures.push(server.shutdown().map(|_| 1).boxed());
2066 futures.push(client.handshake().map(|_| 2).boxed());
2067
2068 assert_eq!(poll_fn(|cx| futures.poll_next_unpin(cx)).await.unwrap(), 2);
2069 assert_eq!(poll_fn(|cx| futures.poll_next_unpin(cx)).await.unwrap(), 1);
2070 drop(futures);
2071
2072 expect_eof_read(&mut client).await;
2074 });
2075 a.await?;
2076
2077 Ok(())
2078 }
2079
2080 #[rstest]
2081 #[case(false, false)]
2082 #[case(false, true)]
2083 #[case(true, false)]
2084 #[case(true, true)]
2085 #[tokio::test]
2086 async fn test_server_dropped(
2087 #[case] server_slow: bool,
2088 #[case] client_slow: bool,
2089 ) -> TestResult {
2090 let (server, mut client) =
2091 tls_pair_slow_handshake(false, server_slow, client_slow, false).await;
2092 drop(server);
2094 client.handshake().await?;
2095 expect_eof_read(&mut client).await;
2097 Ok(())
2098 }
2099
2100 #[tokio::test]
2101 #[ntest::timeout(60000)]
2102 async fn test_server_dropped_after_handshake() -> TestResult {
2103 let (server, mut client) = tls_pair_handshake().await;
2104 drop(server);
2105 expect_eof_read(&mut client).await;
2107 Ok(())
2108 }
2109
2110 #[tokio::test]
2111 #[ntest::timeout(60000)]
2112 async fn test_server_dropped_after_handshake_with_write() -> TestResult {
2113 let (mut server, mut client) = tls_pair_handshake().await;
2114 server.write_all(b"XYZ").await.unwrap();
2115 drop(server);
2116 let mut buf: [u8; 10] = [0; 10];
2118 assert_eq!(client.read(&mut buf).await.unwrap(), 3);
2119 Ok(())
2120 }
2121
2122 #[rstest]
2123 #[case(false, false)]
2124 #[case(false, true)]
2125 #[case(true, false)]
2126 #[case(true, true)]
2127 #[tokio::test]
2128 async fn test_client_dropped(
2129 #[case] server_slow: bool,
2130 #[case] client_slow: bool,
2131 ) -> TestResult {
2132 let (mut server, client) =
2133 tls_pair_slow_handshake(false, server_slow, client_slow, false).await;
2134 drop(client);
2135 server.handshake().await?;
2137 expect_eof_read(&mut server).await;
2139 Ok(())
2140 }
2141
2142 #[tokio::test]
2143 async fn test_server_half_crash_before_handshake() -> TestResult {
2144 let (mut server, mut client) = tls_with_tcp_server(false).await;
2145 tokio::time::sleep(Duration::from_millis(100)).await;
2148 <TcpStream as AsyncWriteExt>::shutdown(&mut server).await?;
2149
2150 let expected = ErrorKind::UnexpectedEof;
2151
2152 expect_io_error(client.handshake().await, expected);
2153 expect_io_error_read(&mut client, expected).await;
2155 Ok(())
2156 }
2157
2158 #[tokio::test]
2159 async fn test_server_crash_before_handshake() -> TestResult {
2160 let (mut server, mut client) = tls_with_tcp_server(false).await;
2161 <TcpStream as AsyncWriteExt>::shutdown(&mut server).await?;
2162 drop(server);
2163
2164 let expected = ErrorKind::UnexpectedEof;
2165
2166 expect_io_error(client.handshake().await, expected);
2167 expect_io_error_read(&mut client, expected).await;
2169 Ok(())
2170 }
2171
2172 #[tokio::test]
2173 async fn test_server_crash_after_handshake() -> TestResult {
2174 let (server, mut client) = tls_pair_handshake().await;
2175
2176 let (mut tcp, _tls) = server.into_inner().await.unwrap();
2177 <TcpStream as AsyncWriteExt>::shutdown(&mut tcp).await?;
2178 drop(tcp);
2179
2180 expect_io_error_read(&mut client, ErrorKind::UnexpectedEof).await;
2182 Ok(())
2183 }
2184
2185 #[rstest]
2186 #[case(true)]
2187 #[case(false)]
2188 #[tokio::test]
2189 async fn large_transfer_no_buffer_limit_or_handshake(
2190 #[case] swap: bool,
2191 ) -> TestResult {
2192 const BUF_SIZE: usize = 64 * 1024;
2193 const BUF_COUNT: usize = 1024;
2194
2195 let (server, client) = tls_pair().await;
2196
2197 let (mut server, mut client) = if swap {
2198 (client, server)
2199 } else {
2200 (server, client)
2201 };
2202
2203 let a = spawn(async move {
2204 let buf = vec![42; BUF_COUNT * BUF_SIZE];
2206 server.write_all(&buf).await.unwrap();
2207 assert_eq!(server.read_u8().await.unwrap(), 0xff);
2208 server.shutdown().await.unwrap();
2209 server.close().await.unwrap();
2210 });
2211 let b = spawn(async move {
2212 for _ in 0..BUF_COUNT {
2213 tokio::time::sleep(Duration::from_millis(1)).await;
2214 let mut buf = [0; BUF_SIZE];
2215 assert_eq!(BUF_SIZE, client.read_exact(&mut buf).await.unwrap());
2216 }
2217 client.write_u8(0xff).await.unwrap();
2218 expect_eof_read(&mut client).await;
2219 });
2220 a.await?;
2221 b.await?;
2222 Ok(())
2223 }
2224
2225 #[rstest]
2226 #[case(true)]
2227 #[case(false)]
2228 #[tokio::test]
2229 async fn large_transfer_with_buffer_limit(#[case] swap: bool) -> TestResult {
2230 const BUF_SIZE: usize = 10 * 1024;
2231 const BUF_COUNT: usize = 1024;
2232
2233 let (server, client) = tls_pair_handshake_buffer_size(
2234 BUF_SIZE.try_into().ok(),
2235 BUF_SIZE.try_into().ok(),
2236 )
2237 .await;
2238
2239 let (mut server, mut client) = if swap {
2240 (client, server)
2241 } else {
2242 (server, client)
2243 };
2244
2245 let a = spawn(async move {
2246 let buf = vec![42; BUF_COUNT * BUF_SIZE];
2248 server.write_all(&buf).await.unwrap();
2249 server.shutdown().await.unwrap();
2250 server.close().await.unwrap();
2251 });
2252 let b = spawn(async move {
2253 for _ in 0..BUF_COUNT {
2254 tokio::time::sleep(Duration::from_millis(1)).await;
2255 let mut buf = [0; BUF_SIZE];
2256 assert_eq!(BUF_SIZE, client.read_exact(&mut buf).await.unwrap());
2257 }
2258 expect_eof_read(&mut client).await;
2259 });
2260 a.await?;
2261 b.await?;
2262 Ok(())
2263 }
2264
2265 #[rstest]
2266 #[case(true, &TLS12)]
2267 #[case(false, &TLS12)]
2268 #[case(true, &TLS13)]
2269 #[case(false, &TLS13)]
2270 #[tokio::test]
2271 async fn large_transfer_with_aggressive_close_split(
2272 #[case] swap: bool,
2273 #[case] protocol: &'static SupportedProtocolVersion,
2274 ) -> TestResult {
2275 const BUF_SIZE: usize = 1024;
2276 const BUF_COUNT: usize = 1 * 1024;
2277
2278 let (server, client) =
2279 tls_pair_protocol(NonZeroUsize::new(65536), protocol).await;
2280 let (server, client) = if swap {
2281 (client, server)
2282 } else {
2283 (server, client)
2284 };
2285
2286 let a = spawn(async move {
2287 let (mut r, mut w) = server.into_split();
2288 let barrier = Arc::new(Barrier::new(2));
2289 let barrier2 = barrier.clone();
2290 let a = spawn(async move {
2291 tokio::select! {
2294 x = r.read_u8() => { _ = x.expect_err("should have failed") },
2295 _ = barrier.wait() => {}
2296 };
2297 r
2298 });
2299 let b = spawn(async move {
2300 let mut buf = vec![42; BUF_COUNT * BUF_SIZE];
2302 let mut buf: &mut [u8] = &mut buf;
2303 w.handshake().await.unwrap();
2304 while !buf.is_empty() {
2305 let n = w.write(&buf).await.unwrap();
2306 w.flush().await.unwrap();
2307 buf = &mut buf[n..];
2308 trace!("[TEST] wrote {n}");
2309 }
2310 w.shutdown().await.unwrap();
2311 barrier2.wait().await;
2312 w
2313 });
2314
2315 let r = a.await.unwrap();
2316 let w = b.await.unwrap();
2317 r.unsplit(w).close().await.unwrap();
2321 });
2322 let b = spawn(async move {
2323 let (mut r, _w) = client.into_split();
2324 let mut buf = vec![0; BUF_SIZE];
2325 for i in 0..BUF_COUNT {
2326 let r = r.read_exact(&mut buf).await;
2327 if let Err(e) = &r {
2328 panic!("Failed to read after {i} of {BUF_COUNT} reads: {e:?}");
2329 };
2330 assert_eq!(BUF_SIZE, r.unwrap());
2331 }
2332 expect_eof_read(&mut r).await;
2333 });
2334 a.await?;
2335 b.await?;
2336 Ok(())
2337 }
2338
2339 #[rstest]
2340 #[case(true)]
2341 #[case(false)]
2342 #[tokio::test(flavor = "current_thread")]
2343 async fn large_transfer_with_shutdown(#[case] swap: bool) -> TestResult {
2344 const BUF_SIZE: usize = 10 * 1024;
2345 const BUF_COUNT: usize = 1024;
2346
2347 let (server, client) = tls_pair_handshake().await;
2348 let (mut server, mut client) = if swap {
2349 (client, server)
2350 } else {
2351 (server, client)
2352 };
2353
2354 let a = spawn(async move {
2355 let buf = vec![42; BUF_COUNT * BUF_SIZE];
2357 server.write_all(&buf).await.unwrap();
2358 server.shutdown().await.unwrap();
2359 server.close().await.unwrap();
2360 });
2361 let b = spawn(async move {
2362 for _ in 0..BUF_COUNT {
2363 tokio::time::sleep(Duration::from_millis(1)).await;
2364 let mut buf = [0; BUF_SIZE];
2365 assert_eq!(BUF_SIZE, client.read_exact(&mut buf).await.unwrap());
2366 }
2367 expect_eof_read(&mut client).await;
2368 });
2369 a.await?;
2370 b.await?;
2371 Ok(())
2372 }
2373
2374 #[rstest]
2375 #[case(true)]
2376 #[case(false)]
2377 #[tokio::test(flavor = "current_thread")]
2378 #[ntest::timeout(60000)]
2379 async fn large_transfer_no_shutdown(#[case] swap: bool) -> TestResult {
2380 const BUF_SIZE: usize = 10 * 1024;
2381 const BUF_COUNT: usize = 1024;
2382
2383 let (server, client) = tls_pair_handshake().await;
2384 let (mut server, mut client) = if swap {
2385 (client, server)
2386 } else {
2387 (server, client)
2388 };
2389
2390 let a = spawn(async move {
2391 let buf = vec![42; BUF_COUNT * BUF_SIZE];
2393 server.write_all(&buf).await.unwrap();
2394 server.close().await.unwrap();
2395 });
2396 let b = spawn(async move {
2397 for _ in 0..BUF_COUNT {
2398 tokio::time::sleep(Duration::from_millis(1)).await;
2399 let mut buf = [0; BUF_SIZE];
2400 assert_eq!(BUF_SIZE, client.read_exact(&mut buf).await.unwrap());
2401 }
2402 expect_eof_read(&mut client).await;
2403 });
2404 a.await?;
2405 b.await?;
2406 Ok(())
2407 }
2408
2409 #[rstest]
2411 #[case(true, 1024, 1024, 1024)]
2412 #[case(false, 1024, 1024, 1024)]
2413 #[case(true, 1024, 16, 1024)]
2414 #[case(false, 1024, 16, 1024)]
2415 #[case(true, 1024, 10000, 1)]
2416 #[case(false, 1024, 10000, 1)]
2417 #[case(true, 32, 16, 16)]
2418 #[case(false, 32, 16, 16)]
2419 #[tokio::test]
2420 async fn vectored_stream_write(
2421 #[case] handshake_first: bool,
2422 #[case] expected: usize,
2423 #[case] first: usize,
2424 #[case] second: usize,
2425 ) -> TestResult {
2426 let (mut server, mut client) =
2427 tls_pair_buffer_size(Some(NonZeroUsize::try_from(1024).unwrap())).await;
2428 if handshake_first {
2429 server.handshake().await.unwrap();
2430 server.flush().await.unwrap();
2431 client.handshake().await.unwrap();
2432 client.flush().await.unwrap();
2433 }
2434 let n = client
2435 .write_vectored(&[
2436 IoSlice::new(&vec![1; first]),
2437 IoSlice::new(&vec![2; second]),
2438 ])
2439 .await
2440 .expect("failed to write");
2441 assert_eq!(n, expected);
2442 let mut buf = [0; 2048];
2443 client.flush().await.expect("failed to flush");
2445 tokio::time::sleep(Duration::from_millis(1)).await;
2447 let n = server.read(&mut buf).await.expect("failed to read");
2448 assert_eq!(n, expected);
2449 Ok(())
2450 }
2451
2452 #[tokio::test]
2454 async fn test_split_peer_certificates_before_handshake() -> TestResult {
2455 let (server, client) = tls_pair().await;
2456
2457 let (server_read, server_write) = server.into_split();
2458 let (client_read, client_write) = client.into_split();
2459
2460 assert!(
2462 server_read.try_handshake()?.is_none(),
2463 "Server handshake should be None before completion"
2464 );
2465 assert!(
2466 server_write.try_handshake()?.is_none(),
2467 "Server handshake should be None before completion"
2468 );
2469 assert!(
2470 client_read.try_handshake()?.is_none(),
2471 "Client handshake should be None before completion"
2472 );
2473 assert!(
2474 client_write.try_handshake()?.is_none(),
2475 "Client handshake should be None before completion"
2476 );
2477
2478 Ok(())
2479 }
2480
2481 #[tokio::test]
2483 async fn test_split_peer_certificates_access() -> TestResult {
2484 let (server, client) = tls_pair_handshake().await;
2485
2486 let (server_read, server_write) = server.into_split();
2487 let (client_read, client_write) = client.into_split();
2488
2489 let server_read_handshake = server_read.try_handshake()?.unwrap();
2491 let server_write_handshake = server_write.try_handshake()?.unwrap();
2492 let client_read_handshake = client_read.try_handshake()?.unwrap();
2493 let client_write_handshake = client_write.try_handshake()?.unwrap();
2494
2495 assert_eq!(
2497 server_read_handshake.peer_certificates.is_some(),
2498 server_write_handshake.peer_certificates.is_some()
2499 );
2500 assert_eq!(
2501 client_read_handshake.peer_certificates.is_some(),
2502 client_write_handshake.peer_certificates.is_some()
2503 );
2504
2505 if let (Some(read_certs), Some(write_certs)) = (
2506 &server_read_handshake.peer_certificates,
2507 &server_write_handshake.peer_certificates,
2508 ) {
2509 assert_eq!(read_certs.len(), write_certs.len());
2510 }
2511
2512 if let (Some(read_certs), Some(write_certs)) = (
2513 &client_read_handshake.peer_certificates,
2514 &client_write_handshake.peer_certificates,
2515 ) {
2516 assert_eq!(read_certs.len(), write_certs.len());
2517 }
2518
2519 Ok(())
2520 }
2521}