async_tls/
client.rs

1//! The client end of a TLS connection.
2
3use crate::common::tls_state::TlsState;
4use crate::rusttls::stream::Stream;
5use futures_core::ready;
6use futures_io::{AsyncRead, AsyncWrite};
7use rustls::ClientConnection;
8use std::future::Future;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::{io, mem};
12
13/// The client end of a TLS connection. Can be used like any other bidirectional IO stream.
14/// Wraps the underlying TCP stream.
15#[derive(Debug)]
16pub struct TlsStream<IO> {
17    pub(crate) io: IO,
18    pub(crate) session: ClientConnection,
19    pub(crate) state: TlsState,
20
21    #[cfg(feature = "early-data")]
22    pub(crate) early_data: (usize, Vec<u8>),
23}
24
25pub(crate) enum MidHandshake<IO> {
26    Handshaking(TlsStream<IO>),
27    #[cfg(feature = "early-data")]
28    EarlyData(TlsStream<IO>),
29    End,
30}
31
32impl<IO> TlsStream<IO> {
33    /// Returns a reference to the underlying IO stream.
34    pub fn get_ref(&self) -> &IO {
35        &self.io
36    }
37
38    /// Returns a mutuable reference to the underlying IO stream.
39    pub fn get_mut(&mut self) -> &mut IO {
40        &mut self.io
41    }
42}
43
44impl<IO> Future for MidHandshake<IO>
45where
46    IO: AsyncRead + AsyncWrite + Unpin,
47{
48    type Output = io::Result<TlsStream<IO>>;
49
50    #[inline]
51    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
52        let this = self.get_mut();
53
54        if let MidHandshake::Handshaking(stream) = this {
55            let eof = !stream.state.readable();
56            let (io, session) = (&mut stream.io, &mut stream.session);
57            let mut stream = Stream::new(io, session).set_eof(eof);
58
59            if stream.conn.is_handshaking() {
60                ready!(stream.complete_io(cx))?;
61            }
62
63            if stream.conn.wants_write() {
64                ready!(stream.complete_io(cx))?;
65            }
66        }
67
68        match mem::replace(this, MidHandshake::End) {
69            MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)),
70            #[cfg(feature = "early-data")]
71            MidHandshake::EarlyData(stream) => Poll::Ready(Ok(stream)),
72            MidHandshake::End => panic!(),
73        }
74    }
75}
76
77impl<IO> AsyncRead for TlsStream<IO>
78where
79    IO: AsyncRead + AsyncWrite + Unpin,
80{
81    fn poll_read(
82        self: Pin<&mut Self>,
83        cx: &mut Context<'_>,
84        buf: &mut [u8],
85    ) -> Poll<io::Result<usize>> {
86        match self.state {
87            #[cfg(feature = "early-data")]
88            TlsState::EarlyData => {
89                let this = self.get_mut();
90
91                let is_handshaking = this.session.is_handshaking();
92                let is_early_data_accepted = this.session.is_early_data_accepted();
93
94                let mut stream =
95                    Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
96                let (pos, data) = &mut this.early_data;
97
98                // complete handshake
99                if is_handshaking {
100                    ready!(stream.complete_io(cx))?;
101                }
102
103                // write early data (fallback)
104                if !is_early_data_accepted {
105                    while *pos < data.len() {
106                        let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
107                        *pos += len;
108                    }
109                }
110
111                // end
112                this.state = TlsState::Stream;
113                data.clear();
114
115                Pin::new(this).poll_read(cx, buf)
116            }
117            TlsState::Stream | TlsState::WriteShutdown => {
118                let this = self.get_mut();
119                let mut stream =
120                    Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
121
122                match stream.as_mut_pin().poll_read(cx, buf) {
123                    Poll::Ready(Ok(0)) => {
124                        this.state.shutdown_read();
125                        Poll::Ready(Ok(0))
126                    }
127                    Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
128                    Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => {
129                        this.state.shutdown_read();
130                        if this.state.writeable() {
131                            stream.conn.send_close_notify();
132                            this.state.shutdown_write();
133                        }
134                        Poll::Ready(Ok(0))
135                    }
136                    Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
137                    Poll::Pending => Poll::Pending,
138                }
139            }
140            TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
141        }
142    }
143}
144
145impl<IO> AsyncWrite for TlsStream<IO>
146where
147    IO: AsyncRead + AsyncWrite + Unpin,
148{
149    fn poll_write(
150        self: Pin<&mut Self>,
151        cx: &mut Context<'_>,
152        buf: &[u8],
153    ) -> Poll<io::Result<usize>> {
154        let this = self.get_mut();
155        let mut stream =
156            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
157
158        match this.state {
159            #[cfg(feature = "early-data")]
160            TlsState::EarlyData => {
161                use std::io::Write;
162
163                let (pos, data) = &mut this.early_data;
164
165                // write early data
166                if let Some(mut early_data) = stream.conn.client_early_data() {
167                    let len = match early_data.write(buf) {
168                        Ok(n) => n,
169                        Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
170                            return Poll::Pending
171                        }
172                        Err(err) => return Poll::Ready(Err(err)),
173                    };
174                    data.extend_from_slice(&buf[..len]);
175                    return Poll::Ready(Ok(len));
176                }
177
178                // complete handshake
179                if stream.conn.is_handshaking() {
180                    ready!(stream.complete_io(cx))?;
181                }
182
183                // write early data (fallback)
184                if !stream.conn.is_early_data_accepted() {
185                    while *pos < data.len() {
186                        let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
187                        *pos += len;
188                    }
189                }
190
191                // end
192                this.state = TlsState::Stream;
193                data.clear();
194                stream.as_mut_pin().poll_write(cx, buf)
195            }
196            _ => stream.as_mut_pin().poll_write(cx, buf),
197        }
198    }
199
200    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
201        let this = self.get_mut();
202        let mut stream =
203            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
204        stream.as_mut_pin().poll_flush(cx)
205    }
206
207    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
208        if self.state.writeable() {
209            self.session.send_close_notify();
210            self.state.shutdown_write();
211        }
212
213        let this = self.get_mut();
214        let mut stream =
215            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
216        stream.as_mut_pin().poll_close(cx)
217    }
218}