1use super::*;
2use crate::common::IoSession;
3#[cfg(unix)]
4use std::os::unix::io::{AsRawFd, RawFd};
5#[cfg(windows)]
6use std::os::windows::io::{AsRawSocket, RawSocket};
7
8#[derive(Debug)]
11pub struct TlsStream<IO> {
12 pub(crate) io: IO,
13 pub(crate) session: ClientConnection,
14 pub(crate) state: TlsState,
15
16 #[cfg(feature = "early-data")]
17 pub(crate) early_waker: Option<std::task::Waker>,
18}
19
20impl<IO> TlsStream<IO> {
21 #[inline]
22 pub fn get_ref(&self) -> (&IO, &ClientConnection) {
23 (&self.io, &self.session)
24 }
25
26 #[inline]
27 pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) {
28 (&mut self.io, &mut self.session)
29 }
30
31 #[inline]
32 pub fn into_inner(self) -> (IO, ClientConnection) {
33 (self.io, self.session)
34 }
35}
36
37#[cfg(unix)]
38impl<S> AsRawFd for TlsStream<S>
39where
40 S: AsRawFd,
41{
42 #[inline]
43 fn as_raw_fd(&self) -> RawFd {
44 self.get_ref().0.as_raw_fd()
45 }
46}
47
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    /// Exposes the raw socket handle of the wrapped I/O object.
    #[inline]
    fn as_raw_socket(&self) -> RawSocket {
        AsRawSocket::as_raw_socket(&self.io)
    }
}
58
59impl<IO> IoSession for TlsStream<IO> {
60 type Io = IO;
61 type Session = ClientConnection;
62
63 #[inline]
64 fn skip_handshake(&self) -> bool {
65 self.state.is_early_data()
66 }
67
68 #[inline]
69 fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
70 (&mut self.state, &mut self.io, &mut self.session)
71 }
72
73 #[inline]
74 fn into_io(self) -> Self::Io {
75 self.io
76 }
77}
78
79impl<IO> AsyncRead for TlsStream<IO>
80where
81 IO: AsyncRead + AsyncWrite + Unpin,
82{
83 fn poll_read(
84 self: Pin<&mut Self>,
85 cx: &mut Context<'_>,
86 buf: &mut [u8],
87 ) -> Poll<io::Result<usize>> {
88 match self.state {
89 #[cfg(feature = "early-data")]
90 TlsState::EarlyData(..) => {
91 let this = self.get_mut();
92
93 if this
100 .early_waker
101 .as_ref()
102 .filter(|waker| cx.waker().will_wake(waker))
103 .is_none()
104 {
105 this.early_waker = Some(cx.waker().clone());
106 }
107
108 Poll::Pending
109 }
110 TlsState::Stream | TlsState::WriteShutdown => {
111 let this = self.get_mut();
112 let mut stream =
113 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
114
115 match stream.as_mut_pin().poll_read(cx, buf) {
116 Poll::Ready(Ok(n)) => {
117 if n == 0 || stream.eof {
118 this.state.shutdown_read();
119 }
120
121 Poll::Ready(Ok(n))
122 }
123 Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
124 this.state.shutdown_read();
125 Poll::Ready(Err(err))
126 }
127 output => output,
128 }
129 }
130 TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
131 }
132 }
133}
134
135impl<IO> AsyncWrite for TlsStream<IO>
136where
137 IO: AsyncRead + AsyncWrite + Unpin,
138{
139 fn poll_write(
142 self: Pin<&mut Self>,
143 cx: &mut Context<'_>,
144 buf: &[u8],
145 ) -> Poll<io::Result<usize>> {
146 let this = self.get_mut();
147 let mut stream =
148 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
149
150 #[allow(clippy::match_single_binding)]
151 match this.state {
152 #[cfg(feature = "early-data")]
153 TlsState::EarlyData(ref mut pos, ref mut data) => {
154 use std::io::Write;
155
156 if let Some(mut early_data) = stream.session.early_data() {
158 let len = match early_data.write(buf) {
159 Ok(n) => n,
160 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
161 return Poll::Pending
162 }
163 Err(err) => return Poll::Ready(Err(err)),
164 };
165 if len != 0 {
166 data.extend_from_slice(&buf[..len]);
167 return Poll::Ready(Ok(len));
168 }
169 }
170
171 while stream.session.is_handshaking() {
173 ready!(stream.handshake(cx))?;
174 }
175
176 if !stream.session.is_early_data_accepted() {
178 while *pos < data.len() {
179 let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
180 *pos += len;
181 }
182 }
183
184 this.state = TlsState::Stream;
186
187 if let Some(waker) = this.early_waker.take() {
188 waker.wake();
189 }
190
191 stream.as_mut_pin().poll_write(cx, buf)
192 }
193 _ => stream.as_mut_pin().poll_write(cx, buf),
194 }
195 }
196
197 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
198 let this = self.get_mut();
199 let mut stream =
200 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
201
202 #[cfg(feature = "early-data")]
203 {
204 if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
205 while stream.session.is_handshaking() {
207 ready!(stream.handshake(cx))?;
208 }
209
210 if !stream.session.is_early_data_accepted() {
212 while *pos < data.len() {
213 let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
214 *pos += len;
215 }
216 }
217
218 this.state = TlsState::Stream;
219
220 if let Some(waker) = this.early_waker.take() {
221 waker.wake();
222 }
223 }
224 }
225
226 stream.as_mut_pin().poll_flush(cx)
227 }
228
229 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
230 #[cfg(feature = "early-data")]
232 if matches!(self.state, TlsState::EarlyData(..)) {
233 ready!(self.as_mut().poll_flush(cx))?;
234 }
235
236 if self.state.writeable() {
237 self.session.send_close_notify();
238 self.state.shutdown_write();
239 }
240
241 let this = self.get_mut();
242 let mut stream =
243 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
244 stream.as_mut_pin().poll_close(cx)
245 }
246}