1use crate::common::tls_state::TlsState;
4use crate::rusttls::stream::Stream;
5use futures_core::ready;
6use futures_io::{AsyncRead, AsyncWrite};
7use rustls::ClientConnection;
8use std::future::Future;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::{io, mem};
12
13#[derive(Debug)]
16pub struct TlsStream<IO> {
17 pub(crate) io: IO,
18 pub(crate) session: ClientConnection,
19 pub(crate) state: TlsState,
20
21 #[cfg(feature = "early-data")]
22 pub(crate) early_data: (usize, Vec<u8>),
23}
24
25pub(crate) enum MidHandshake<IO> {
26 Handshaking(TlsStream<IO>),
27 #[cfg(feature = "early-data")]
28 EarlyData(TlsStream<IO>),
29 End,
30}
31
32impl<IO> TlsStream<IO> {
33 pub fn get_ref(&self) -> &IO {
35 &self.io
36 }
37
38 pub fn get_mut(&mut self) -> &mut IO {
40 &mut self.io
41 }
42}
43
44impl<IO> Future for MidHandshake<IO>
45where
46 IO: AsyncRead + AsyncWrite + Unpin,
47{
48 type Output = io::Result<TlsStream<IO>>;
49
50 #[inline]
51 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
52 let this = self.get_mut();
53
54 if let MidHandshake::Handshaking(stream) = this {
55 let eof = !stream.state.readable();
56 let (io, session) = (&mut stream.io, &mut stream.session);
57 let mut stream = Stream::new(io, session).set_eof(eof);
58
59 if stream.conn.is_handshaking() {
60 ready!(stream.complete_io(cx))?;
61 }
62
63 if stream.conn.wants_write() {
64 ready!(stream.complete_io(cx))?;
65 }
66 }
67
68 match mem::replace(this, MidHandshake::End) {
69 MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)),
70 #[cfg(feature = "early-data")]
71 MidHandshake::EarlyData(stream) => Poll::Ready(Ok(stream)),
72 MidHandshake::End => panic!(),
73 }
74 }
75}
76
77impl<IO> AsyncRead for TlsStream<IO>
78where
79 IO: AsyncRead + AsyncWrite + Unpin,
80{
81 fn poll_read(
82 self: Pin<&mut Self>,
83 cx: &mut Context<'_>,
84 buf: &mut [u8],
85 ) -> Poll<io::Result<usize>> {
86 match self.state {
87 #[cfg(feature = "early-data")]
88 TlsState::EarlyData => {
89 let this = self.get_mut();
90
91 let is_handshaking = this.session.is_handshaking();
92 let is_early_data_accepted = this.session.is_early_data_accepted();
93
94 let mut stream =
95 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
96 let (pos, data) = &mut this.early_data;
97
98 if is_handshaking {
100 ready!(stream.complete_io(cx))?;
101 }
102
103 if !is_early_data_accepted {
105 while *pos < data.len() {
106 let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
107 *pos += len;
108 }
109 }
110
111 this.state = TlsState::Stream;
113 data.clear();
114
115 Pin::new(this).poll_read(cx, buf)
116 }
117 TlsState::Stream | TlsState::WriteShutdown => {
118 let this = self.get_mut();
119 let mut stream =
120 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
121
122 match stream.as_mut_pin().poll_read(cx, buf) {
123 Poll::Ready(Ok(0)) => {
124 this.state.shutdown_read();
125 Poll::Ready(Ok(0))
126 }
127 Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
128 Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => {
129 this.state.shutdown_read();
130 if this.state.writeable() {
131 stream.conn.send_close_notify();
132 this.state.shutdown_write();
133 }
134 Poll::Ready(Ok(0))
135 }
136 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
137 Poll::Pending => Poll::Pending,
138 }
139 }
140 TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
141 }
142 }
143}
144
145impl<IO> AsyncWrite for TlsStream<IO>
146where
147 IO: AsyncRead + AsyncWrite + Unpin,
148{
149 fn poll_write(
150 self: Pin<&mut Self>,
151 cx: &mut Context<'_>,
152 buf: &[u8],
153 ) -> Poll<io::Result<usize>> {
154 let this = self.get_mut();
155 let mut stream =
156 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
157
158 match this.state {
159 #[cfg(feature = "early-data")]
160 TlsState::EarlyData => {
161 use std::io::Write;
162
163 let (pos, data) = &mut this.early_data;
164
165 if let Some(mut early_data) = stream.conn.client_early_data() {
167 let len = match early_data.write(buf) {
168 Ok(n) => n,
169 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
170 return Poll::Pending
171 }
172 Err(err) => return Poll::Ready(Err(err)),
173 };
174 data.extend_from_slice(&buf[..len]);
175 return Poll::Ready(Ok(len));
176 }
177
178 if stream.conn.is_handshaking() {
180 ready!(stream.complete_io(cx))?;
181 }
182
183 if !stream.conn.is_early_data_accepted() {
185 while *pos < data.len() {
186 let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
187 *pos += len;
188 }
189 }
190
191 this.state = TlsState::Stream;
193 data.clear();
194 stream.as_mut_pin().poll_write(cx, buf)
195 }
196 _ => stream.as_mut_pin().poll_write(cx, buf),
197 }
198 }
199
200 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
201 let this = self.get_mut();
202 let mut stream =
203 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
204 stream.as_mut_pin().poll_flush(cx)
205 }
206
207 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
208 if self.state.writeable() {
209 self.session.send_close_notify();
210 self.state.shutdown_write();
211 }
212
213 let this = self.get_mut();
214 let mut stream =
215 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
216 stream.as_mut_pin().poll_close(cx)
217 }
218}