fluvio_async_tls/
client.rs

1//! The client end of a TLS connection.
2
3use crate::common::tls_state::TlsState;
4use crate::rusttls::stream::Stream;
5use futures_core::ready;
6use futures_io::{AsyncRead, AsyncWrite};
7use rustls::ClientConnection;
8use std::future::Future;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::{io, mem};
12
13/// The client end of a TLS connection. Can be used like any other bidirectional IO stream.
14/// Wraps the underlying TCP stream.
15pub struct TlsStream<IO> {
16    pub(crate) io: IO,
17    pub(crate) session: ClientConnection,
18    pub(crate) state: TlsState,
19
20    #[cfg(feature = "early-data")]
21    pub(crate) early_data: (usize, Vec<u8>),
22}
23
24pub(crate) enum MidHandshake<IO> {
25    Handshaking(TlsStream<IO>),
26    #[cfg(feature = "early-data")]
27    EarlyData(TlsStream<IO>),
28    End,
29}
30
31impl<IO> TlsStream<IO> {
32    /// Returns a reference to the underlying IO stream.
33    pub fn get_ref(&self) -> &IO {
34        &self.io
35    }
36
37    /// Returns a mutuable reference to the underlying IO stream.
38    pub fn get_mut(&mut self) -> &mut IO {
39        &mut self.io
40    }
41}
42
43impl<IO> Future for MidHandshake<IO>
44where
45    IO: AsyncRead + AsyncWrite + Unpin,
46{
47    type Output = io::Result<TlsStream<IO>>;
48
49    #[inline]
50    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
51        let this = self.get_mut();
52
53        if let MidHandshake::Handshaking(stream) = this {
54            let eof = !stream.state.readable();
55            let (io, session) = (&mut stream.io, &mut stream.session);
56            let mut stream = Stream::new(io, session).set_eof(eof);
57
58            if stream.session.is_handshaking() {
59                ready!(stream.complete_io(cx))?;
60            }
61
62            if stream.session.wants_write() {
63                ready!(stream.complete_io(cx))?;
64            }
65        }
66
67        match mem::replace(this, MidHandshake::End) {
68            MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)),
69            #[cfg(feature = "early-data")]
70            MidHandshake::EarlyData(stream) => Poll::Ready(Ok(stream)),
71            MidHandshake::End => panic!(),
72        }
73    }
74}
75
76impl<IO> AsyncRead for TlsStream<IO>
77where
78    IO: AsyncRead + AsyncWrite + Unpin,
79{
80    fn poll_read(
81        self: Pin<&mut Self>,
82        cx: &mut Context<'_>,
83        buf: &mut [u8],
84    ) -> Poll<io::Result<usize>> {
85        match self.state {
86            #[cfg(feature = "early-data")]
87            TlsState::EarlyData => {
88                let this = self.get_mut();
89
90                let is_early_data_accepted = this.session.is_early_data_accepted();
91                let is_handshaking = this.session.is_handshaking();
92
93                let mut stream =
94                    Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
95                let (pos, data) = &mut this.early_data;
96
97                // complete handshake
98                if is_handshaking {
99                    ready!(stream.complete_io(cx))?;
100                }
101
102                // write early data (fallback)
103                if !is_early_data_accepted {
104                    while *pos < data.len() {
105                        let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
106                        *pos += len;
107                    }
108                }
109
110                // end
111                this.state = TlsState::Stream;
112                data.clear();
113
114                Pin::new(this).poll_read(cx, buf)
115            }
116            TlsState::Stream | TlsState::WriteShutdown => {
117                let this = self.get_mut();
118                let mut stream =
119                    Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
120
121                match stream.as_mut_pin().poll_read(cx, buf) {
122                    Poll::Ready(Ok(0)) => {
123                        this.state.shutdown_read();
124                        Poll::Ready(Ok(0))
125                    }
126                    Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
127                    Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => {
128                        this.state.shutdown_read();
129                        if this.state.writeable() {
130                            stream.session.send_close_notify();
131                            this.state.shutdown_write();
132                        }
133                        Poll::Ready(Ok(0))
134                    }
135                    Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
136                    Poll::Pending => Poll::Pending,
137                }
138            }
139            TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
140        }
141    }
142}
143
144impl<IO> AsyncWrite for TlsStream<IO>
145where
146    IO: AsyncRead + AsyncWrite + Unpin,
147{
148    fn poll_write(
149        self: Pin<&mut Self>,
150        cx: &mut Context<'_>,
151        buf: &[u8],
152    ) -> Poll<io::Result<usize>> {
153        let this = self.get_mut();
154
155        match this.state {
156            #[cfg(feature = "early-data")]
157            TlsState::EarlyData => {
158                use std::io::Write;
159                let (pos, data) = &mut this.early_data;
160
161                // write early data
162                if let Some(mut early_data) = this.session.early_data() {
163                    let len = match early_data.write(buf) {
164                        Ok(n) => n,
165                        Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
166                            return Poll::Pending
167                        }
168                        Err(err) => return Poll::Ready(Err(err)),
169                    };
170                    data.extend_from_slice(&buf[..len]);
171                    return Poll::Ready(Ok(len));
172                }
173
174                let is_early_data_accepted = this.session.is_early_data_accepted();
175
176                let mut stream =
177                    Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
178
179                // complete handshake
180                if stream.session.is_handshaking() {
181                    ready!(stream.complete_io(cx))?;
182                }
183
184                // write early data (fallback)
185                if !is_early_data_accepted {
186                    while *pos < data.len() {
187                        let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
188                        *pos += len;
189                    }
190                }
191
192                // end
193                this.state = TlsState::Stream;
194                data.clear();
195                stream.as_mut_pin().poll_write(cx, buf)
196            }
197
198            _ => Stream::new(&mut this.io, &mut this.session)
199                .set_eof(!this.state.readable())
200                .as_mut_pin()
201                .poll_write(cx, buf),
202        }
203    }
204
205    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
206        let this = self.get_mut();
207        let mut stream =
208            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
209        stream.as_mut_pin().poll_flush(cx)
210    }
211
212    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
213        if self.state.writeable() {
214            self.session.send_close_notify();
215            self.state.shutdown_write();
216        }
217
218        let this = self.get_mut();
219        let mut stream =
220            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
221        stream.as_mut_pin().poll_close(cx)
222    }
223}