fluvio_async_tls/
client.rs1use crate::common::tls_state::TlsState;
4use crate::rusttls::stream::Stream;
5use futures_core::ready;
6use futures_io::{AsyncRead, AsyncWrite};
7use rustls::ClientConnection;
8use std::future::Future;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::{io, mem};
12
13pub struct TlsStream<IO> {
16 pub(crate) io: IO,
17 pub(crate) session: ClientConnection,
18 pub(crate) state: TlsState,
19
20 #[cfg(feature = "early-data")]
21 pub(crate) early_data: (usize, Vec<u8>),
22}
23
24pub(crate) enum MidHandshake<IO> {
25 Handshaking(TlsStream<IO>),
26 #[cfg(feature = "early-data")]
27 EarlyData(TlsStream<IO>),
28 End,
29}
30
31impl<IO> TlsStream<IO> {
32 pub fn get_ref(&self) -> &IO {
34 &self.io
35 }
36
37 pub fn get_mut(&mut self) -> &mut IO {
39 &mut self.io
40 }
41}
42
43impl<IO> Future for MidHandshake<IO>
44where
45 IO: AsyncRead + AsyncWrite + Unpin,
46{
47 type Output = io::Result<TlsStream<IO>>;
48
49 #[inline]
50 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
51 let this = self.get_mut();
52
53 if let MidHandshake::Handshaking(stream) = this {
54 let eof = !stream.state.readable();
55 let (io, session) = (&mut stream.io, &mut stream.session);
56 let mut stream = Stream::new(io, session).set_eof(eof);
57
58 if stream.session.is_handshaking() {
59 ready!(stream.complete_io(cx))?;
60 }
61
62 if stream.session.wants_write() {
63 ready!(stream.complete_io(cx))?;
64 }
65 }
66
67 match mem::replace(this, MidHandshake::End) {
68 MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)),
69 #[cfg(feature = "early-data")]
70 MidHandshake::EarlyData(stream) => Poll::Ready(Ok(stream)),
71 MidHandshake::End => panic!(),
72 }
73 }
74}
75
76impl<IO> AsyncRead for TlsStream<IO>
77where
78 IO: AsyncRead + AsyncWrite + Unpin,
79{
80 fn poll_read(
81 self: Pin<&mut Self>,
82 cx: &mut Context<'_>,
83 buf: &mut [u8],
84 ) -> Poll<io::Result<usize>> {
85 match self.state {
86 #[cfg(feature = "early-data")]
87 TlsState::EarlyData => {
88 let this = self.get_mut();
89
90 let is_early_data_accepted = this.session.is_early_data_accepted();
91 let is_handshaking = this.session.is_handshaking();
92
93 let mut stream =
94 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
95 let (pos, data) = &mut this.early_data;
96
97 if is_handshaking {
99 ready!(stream.complete_io(cx))?;
100 }
101
102 if !is_early_data_accepted {
104 while *pos < data.len() {
105 let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
106 *pos += len;
107 }
108 }
109
110 this.state = TlsState::Stream;
112 data.clear();
113
114 Pin::new(this).poll_read(cx, buf)
115 }
116 TlsState::Stream | TlsState::WriteShutdown => {
117 let this = self.get_mut();
118 let mut stream =
119 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
120
121 match stream.as_mut_pin().poll_read(cx, buf) {
122 Poll::Ready(Ok(0)) => {
123 this.state.shutdown_read();
124 Poll::Ready(Ok(0))
125 }
126 Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
127 Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => {
128 this.state.shutdown_read();
129 if this.state.writeable() {
130 stream.session.send_close_notify();
131 this.state.shutdown_write();
132 }
133 Poll::Ready(Ok(0))
134 }
135 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
136 Poll::Pending => Poll::Pending,
137 }
138 }
139 TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
140 }
141 }
142}
143
144impl<IO> AsyncWrite for TlsStream<IO>
145where
146 IO: AsyncRead + AsyncWrite + Unpin,
147{
148 fn poll_write(
149 self: Pin<&mut Self>,
150 cx: &mut Context<'_>,
151 buf: &[u8],
152 ) -> Poll<io::Result<usize>> {
153 let this = self.get_mut();
154
155 match this.state {
156 #[cfg(feature = "early-data")]
157 TlsState::EarlyData => {
158 use std::io::Write;
159 let (pos, data) = &mut this.early_data;
160
161 if let Some(mut early_data) = this.session.early_data() {
163 let len = match early_data.write(buf) {
164 Ok(n) => n,
165 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
166 return Poll::Pending
167 }
168 Err(err) => return Poll::Ready(Err(err)),
169 };
170 data.extend_from_slice(&buf[..len]);
171 return Poll::Ready(Ok(len));
172 }
173
174 let is_early_data_accepted = this.session.is_early_data_accepted();
175
176 let mut stream =
177 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
178
179 if stream.session.is_handshaking() {
181 ready!(stream.complete_io(cx))?;
182 }
183
184 if !is_early_data_accepted {
186 while *pos < data.len() {
187 let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
188 *pos += len;
189 }
190 }
191
192 this.state = TlsState::Stream;
194 data.clear();
195 stream.as_mut_pin().poll_write(cx, buf)
196 }
197
198 _ => Stream::new(&mut this.io, &mut this.session)
199 .set_eof(!this.state.readable())
200 .as_mut_pin()
201 .poll_write(cx, buf),
202 }
203 }
204
205 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
206 let this = self.get_mut();
207 let mut stream =
208 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
209 stream.as_mut_pin().poll_flush(cx)
210 }
211
212 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
213 if self.state.writeable() {
214 self.session.send_close_notify();
215 self.state.shutdown_write();
216 }
217
218 let this = self.get_mut();
219 let mut stream =
220 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
221 stream.as_mut_pin().poll_close(cx)
222 }
223}