async_rustls/
client.rs

1use super::*;
2use crate::common::IoSession;
3#[cfg(unix)]
4use std::os::unix::io::{AsRawFd, RawFd};
5#[cfg(windows)]
6use std::os::windows::io::{AsRawSocket, RawSocket};
7
/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
///
/// This is the client side: it pairs a transport `IO` with a
/// `ClientConnection` and exposes the decrypted byte stream through the
/// `AsyncRead` / `AsyncWrite` impls in this file.
#[derive(Debug)]
pub struct TlsStream<IO> {
    /// Underlying transport that carries the encrypted TLS records.
    pub(crate) io: IO,
    /// TLS client session driven by this wrapper (handshake, early data,
    /// `close_notify`).
    pub(crate) session: ClientConnection,
    /// Current phase of the stream (early data, open, or shut down in one or
    /// both directions); the read/write impls dispatch on it.
    pub(crate) state: TlsState,

    /// Waker of a reader parked by `poll_read` while the stream is still in
    /// the early-data state. `poll_write` / `poll_flush` take and wake it
    /// after switching the state to `TlsState::Stream`.
    #[cfg(feature = "early-data")]
    pub(crate) early_waker: Option<std::task::Waker>,
}
19
20impl<IO> TlsStream<IO> {
21    #[inline]
22    pub fn get_ref(&self) -> (&IO, &ClientConnection) {
23        (&self.io, &self.session)
24    }
25
26    #[inline]
27    pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) {
28        (&mut self.io, &mut self.session)
29    }
30
31    #[inline]
32    pub fn into_inner(self) -> (IO, ClientConnection) {
33        (self.io, self.session)
34    }
35}
36
37#[cfg(unix)]
38impl<S> AsRawFd for TlsStream<S>
39where
40    S: AsRawFd,
41{
42    #[inline]
43    fn as_raw_fd(&self) -> RawFd {
44        self.get_ref().0.as_raw_fd()
45    }
46}
47
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    /// Returns the raw socket handle of the wrapped transport `S`.
    #[inline]
    fn as_raw_socket(&self) -> RawSocket {
        AsRawSocket::as_raw_socket(&self.io)
    }
}
58
59impl<IO> IoSession for TlsStream<IO> {
60    type Io = IO;
61    type Session = ClientConnection;
62
63    #[inline]
64    fn skip_handshake(&self) -> bool {
65        self.state.is_early_data()
66    }
67
68    #[inline]
69    fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
70        (&mut self.state, &mut self.io, &mut self.session)
71    }
72
73    #[inline]
74    fn into_io(self) -> Self::Io {
75        self.io
76    }
77}
78
79impl<IO> AsyncRead for TlsStream<IO>
80where
81    IO: AsyncRead + AsyncWrite + Unpin,
82{
83    fn poll_read(
84        self: Pin<&mut Self>,
85        cx: &mut Context<'_>,
86        buf: &mut [u8],
87    ) -> Poll<io::Result<usize>> {
88        match self.state {
89            #[cfg(feature = "early-data")]
90            TlsState::EarlyData(..) => {
91                let this = self.get_mut();
92
93                // In the EarlyData state, we have not really established a Tls connection.
94                // Before writing data through `AsyncWrite` and completing the tls handshake,
95                // we ignore read readiness and return to pending.
96                //
97                // In order to avoid event loss,
98                // we need to register a waker and wake it up after tls is connected.
99                if this
100                    .early_waker
101                    .as_ref()
102                    .filter(|waker| cx.waker().will_wake(waker))
103                    .is_none()
104                {
105                    this.early_waker = Some(cx.waker().clone());
106                }
107
108                Poll::Pending
109            }
110            TlsState::Stream | TlsState::WriteShutdown => {
111                let this = self.get_mut();
112                let mut stream =
113                    Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
114
115                match stream.as_mut_pin().poll_read(cx, buf) {
116                    Poll::Ready(Ok(n)) => {
117                        if n == 0 || stream.eof {
118                            this.state.shutdown_read();
119                        }
120
121                        Poll::Ready(Ok(n))
122                    }
123                    Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
124                        this.state.shutdown_read();
125                        Poll::Ready(Err(err))
126                    }
127                    output => output,
128                }
129            }
130            TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
131        }
132    }
133}
134
135impl<IO> AsyncWrite for TlsStream<IO>
136where
137    IO: AsyncRead + AsyncWrite + Unpin,
138{
139    /// Note: that it does not guarantee the final data to be sent.
140    /// To be cautious, you must manually call `flush`.
141    fn poll_write(
142        self: Pin<&mut Self>,
143        cx: &mut Context<'_>,
144        buf: &[u8],
145    ) -> Poll<io::Result<usize>> {
146        let this = self.get_mut();
147        let mut stream =
148            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
149
150        #[allow(clippy::match_single_binding)]
151        match this.state {
152            #[cfg(feature = "early-data")]
153            TlsState::EarlyData(ref mut pos, ref mut data) => {
154                use std::io::Write;
155
156                // write early data
157                if let Some(mut early_data) = stream.session.early_data() {
158                    let len = match early_data.write(buf) {
159                        Ok(n) => n,
160                        Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
161                            return Poll::Pending
162                        }
163                        Err(err) => return Poll::Ready(Err(err)),
164                    };
165                    if len != 0 {
166                        data.extend_from_slice(&buf[..len]);
167                        return Poll::Ready(Ok(len));
168                    }
169                }
170
171                // complete handshake
172                while stream.session.is_handshaking() {
173                    ready!(stream.handshake(cx))?;
174                }
175
176                // write early data (fallback)
177                if !stream.session.is_early_data_accepted() {
178                    while *pos < data.len() {
179                        let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
180                        *pos += len;
181                    }
182                }
183
184                // end
185                this.state = TlsState::Stream;
186
187                if let Some(waker) = this.early_waker.take() {
188                    waker.wake();
189                }
190
191                stream.as_mut_pin().poll_write(cx, buf)
192            }
193            _ => stream.as_mut_pin().poll_write(cx, buf),
194        }
195    }
196
197    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
198        let this = self.get_mut();
199        let mut stream =
200            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
201
202        #[cfg(feature = "early-data")]
203        {
204            if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
205                // complete handshake
206                while stream.session.is_handshaking() {
207                    ready!(stream.handshake(cx))?;
208                }
209
210                // write early data (fallback)
211                if !stream.session.is_early_data_accepted() {
212                    while *pos < data.len() {
213                        let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
214                        *pos += len;
215                    }
216                }
217
218                this.state = TlsState::Stream;
219
220                if let Some(waker) = this.early_waker.take() {
221                    waker.wake();
222                }
223            }
224        }
225
226        stream.as_mut_pin().poll_flush(cx)
227    }
228
229    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
230        // complete handshake
231        #[cfg(feature = "early-data")]
232        if matches!(self.state, TlsState::EarlyData(..)) {
233            ready!(self.as_mut().poll_flush(cx))?;
234        }
235
236        if self.state.writeable() {
237            self.session.send_close_notify();
238            self.state.shutdown_write();
239        }
240
241        let this = self.get_mut();
242        let mut stream =
243            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
244        stream.as_mut_pin().poll_close(cx)
245    }
246}