1use crate::common::tls_state::TlsState;
4use crate::rusttls::stream::Stream;
5
6use futures_core::ready;
7use futures_io::{AsyncRead, AsyncWrite};
8use rustls::ServerConnection;
9use std::future::Future;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12use std::{io, mem};
13
14#[derive(Debug)]
17pub struct TlsStream<IO> {
18 pub(crate) io: IO,
19 pub(crate) conn: ServerConnection,
20 pub(crate) state: TlsState,
21}
22
23pub(crate) enum MidHandshake<IO> {
24 Handshaking(TlsStream<IO>),
25 End,
26}
27
28impl<IO> Future for MidHandshake<IO>
29where
30 IO: AsyncRead + AsyncWrite + Unpin,
31{
32 type Output = io::Result<TlsStream<IO>>;
33
34 #[inline]
35 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
36 let this = self.get_mut();
37
38 if let MidHandshake::Handshaking(stream) = this {
39 let eof = !stream.state.readable();
40 let (io, session) = (&mut stream.io, &mut stream.conn);
41 let mut stream = Stream::new(io, session).set_eof(eof);
42
43 if stream.conn.is_handshaking() {
44 ready!(stream.complete_io(cx))?;
45 }
46
47 if stream.conn.wants_write() {
48 ready!(stream.complete_io(cx))?;
49 }
50 }
51
52 match mem::replace(this, MidHandshake::End) {
53 MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)),
54 MidHandshake::End => panic!(),
55 }
56 }
57}
58
59impl<IO> AsyncRead for TlsStream<IO>
60where
61 IO: AsyncRead + AsyncWrite + Unpin,
62{
63 fn poll_read(
64 self: Pin<&mut Self>,
65 cx: &mut Context<'_>,
66 buf: &mut [u8],
67 ) -> Poll<io::Result<usize>> {
68 let this = self.get_mut();
69 let mut stream = Stream::new(&mut this.io, &mut this.conn).set_eof(!this.state.readable());
70
71 match this.state {
72 TlsState::Stream | TlsState::WriteShutdown => {
73 match stream.as_mut_pin().poll_read(cx, buf) {
74 Poll::Ready(Ok(0)) => {
75 this.state.shutdown_read();
76 Poll::Ready(Ok(0))
77 }
78 Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
79 Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
80 this.state.shutdown_read();
81 if this.state.writeable() {
82 stream.conn.send_close_notify();
83 this.state.shutdown_write();
84 }
85 Poll::Ready(Ok(0))
86 }
87 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
88 Poll::Pending => Poll::Pending,
89 }
90 }
91 TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
92 #[cfg(feature = "early-data")]
93 s => unreachable!("server TLS can not hit this state: {:?}", s),
94 }
95 }
96}
97
98impl<IO> AsyncWrite for TlsStream<IO>
99where
100 IO: AsyncRead + AsyncWrite + Unpin,
101{
102 fn poll_write(
103 self: Pin<&mut Self>,
104 cx: &mut Context<'_>,
105 buf: &[u8],
106 ) -> Poll<io::Result<usize>> {
107 let this = self.get_mut();
108 let mut stream = Stream::new(&mut this.io, &mut this.conn).set_eof(!this.state.readable());
109 stream.as_mut_pin().poll_write(cx, buf)
110 }
111
112 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
113 let this = self.get_mut();
114 let mut stream = Stream::new(&mut this.io, &mut this.conn).set_eof(!this.state.readable());
115 stream.as_mut_pin().poll_flush(cx)
116 }
117
118 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
119 if self.state.writeable() {
120 self.conn.send_close_notify();
121 self.state.shutdown_write();
122 }
123
124 let this = self.get_mut();
125 let mut stream = Stream::new(&mut this.io, &mut this.conn).set_eof(!this.state.readable());
126 stream.as_mut_pin().poll_close(cx)
127 }
128}