1#[cfg(feature = "tokio")]
2use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
3
4use std::future::Future;
5#[cfg(feature = "tokio")]
6use std::{
7 any::Any,
8 io::IoSlice,
9 pin::Pin,
10 task::{Context, Poll},
11};
12
13use crate::{Ssl, SslError, TlsDriver, TlsHandshake, TlsServerParameterProvider};
14
/// A bidirectional tokio byte stream that is `Unpin + Send + 'static`.
///
/// Streams of this kind can be type-erased into `Box<dyn Stream>` and recovered later with
/// [`Stream::downcast`].
#[cfg(feature = "tokio")]
pub trait Stream: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static {
    /// Tries to turn `self` into the concrete stream type `S`.
    ///
    /// Returns `Ok` with the value when `Self` is exactly `S`. Otherwise it returns `self`
    /// unchanged in `Err`. The check works on `Any` type ids and needs no heap allocation.
    fn downcast<S: Stream + 'static>(self) -> Result<S, Self>
    where
        Self: Sized + 'static,
    {
        // Park the value in an `Option` so that, viewed as `dyn Any`, a successful downcast to
        // `Option<S>` lets the value be moved out with `take()`.
        let mut slot: Option<Self> = Some(self);
        match (&mut slot as &mut dyn Any).downcast_mut::<Option<S>>() {
            Some(matched) => Ok(matched.take().unwrap()),
            None => Err(slot.take().unwrap()),
        }
    }
}
32
/// Any tokio `AsyncRead + AsyncWrite` type that is also `Unpin + Send + 'static` is a
/// [`Stream`].
#[cfg(feature = "tokio")]
impl<T: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stream for T {}
35
/// Without the `tokio` feature there are no I/O traits to build on, so `Stream` is only a
/// `'static` marker trait, implemented explicitly for the types below.
#[cfg(not(feature = "tokio"))]
pub trait Stream: 'static {}
// Lets an `UpgradableStream` be used wherever a `Stream` is expected.
#[cfg(not(feature = "tokio"))]
impl<S: Stream, D: TlsDriver> Stream for UpgradableStream<S, D> {}
// NOTE(review): presumably a placeholder/dummy stream for builds without tokio — confirm.
#[cfg(not(feature = "tokio"))]
impl Stream for () {}
42
/// A [`Stream`] that can be upgraded in place to TLS.
pub trait StreamUpgrade: Stream {
    /// Performs the TLS upgrade of this stream.
    ///
    /// The returned future is `Send`; failures are reported as [`SslError`].
    fn secure_upgrade(&mut self) -> impl Future<Output = Result<(), SslError>> + Send;
    /// Returns the TLS handshake details, if the stream has been upgraded.
    fn handshake(&self) -> Option<&TlsHandshake>;
}
48
/// Behavioral switches for [`UpgradableStream`]; all default to off.
#[derive(Default, Debug)]
struct UpgradableStreamOptions {
    /// When set, a read that fails with `std::io::ErrorKind::UnexpectedEof` is reported to the
    /// caller as a clean end-of-stream instead of an error.
    ignore_missing_close_notify: bool,
}
53
/// A stream that starts as a plain client or server connection and can be upgraded to TLS in
/// place via [`StreamUpgrade::secure_upgrade`]. `D` is the TLS driver, defaulting to [`Ssl`].
// NOTE(review): `private_bounds` is presumably allowed because `TlsDriver` (or one of its
// associated types) is less visible than this public struct — confirm.
#[allow(private_bounds)]
#[derive(derive_more::Debug)]
pub struct UpgradableStream<S: Stream, D: TlsDriver = Ssl> {
    // Current connection state: plain, upgrading, or upgraded.
    inner: UpgradableStreamInner<S, D>,
    options: UpgradableStreamOptions,
}
60
61#[allow(private_bounds)]
62impl<S: Stream, D: TlsDriver> UpgradableStream<S, D> {
63 #[inline(always)]
64 pub(crate) fn new_client(base: S, config: Option<D::ClientParams>) -> Self {
65 UpgradableStream {
66 inner: UpgradableStreamInner::BaseClient(base, config),
67 options: Default::default(),
68 }
69 }
70
71 #[inline(always)]
72 pub(crate) fn new_server(base: S, config: Option<TlsServerParameterProvider>) -> Self {
73 UpgradableStream {
74 inner: UpgradableStreamInner::BaseServer(base, config),
75 options: Default::default(),
76 }
77 }
78
79 pub fn into_boxed(self) -> Result<Box<dyn Stream>, Self> {
81 match self.inner {
82 UpgradableStreamInner::BaseClient(base, _) => Ok(Box::new(base)),
83 UpgradableStreamInner::BaseServer(base, _) => Ok(Box::new(base)),
84 UpgradableStreamInner::Upgraded(upgraded, _) => Ok(Box::new(upgraded)),
85 UpgradableStreamInner::Upgrading => Err(self),
86 }
87 }
88
89 pub fn handshake(&self) -> Option<&TlsHandshake> {
90 match &self.inner {
91 UpgradableStreamInner::Upgraded(_, handshake) => Some(handshake),
92 _ => None,
93 }
94 }
95
96 pub fn ignore_missing_close_notify(&mut self) {
97 self.options.ignore_missing_close_notify = true;
98 }
99
100 pub fn unclean_shutdown(self) -> Result<(), Self> {
103 match self.inner {
104 UpgradableStreamInner::BaseClient(..) => Ok(()),
105 UpgradableStreamInner::BaseServer(..) => Ok(()),
106 UpgradableStreamInner::Upgraded(upgraded, cfg) => {
107 if let Err(e) = D::unclean_shutdown(upgraded) {
108 Err(Self {
109 inner: UpgradableStreamInner::Upgraded(e, cfg),
110 options: self.options,
111 })
112 } else {
113 Ok(())
114 }
115 }
116 UpgradableStreamInner::Upgrading => Err(self),
117 }
118 }
119}
120
121impl<S: Stream, D: TlsDriver> StreamUpgrade for UpgradableStream<S, D> {
122 async fn secure_upgrade(&mut self) -> Result<(), SslError> {
123 match std::mem::replace(&mut self.inner, UpgradableStreamInner::Upgrading) {
124 UpgradableStreamInner::BaseClient(base, config) => {
125 let Some(config) = config else {
126 return Err(SslError::SslUnsupportedByClient);
127 };
128 let (upgraded, handshake) = D::upgrade_client(config, base).await?;
129 self.inner = UpgradableStreamInner::Upgraded(upgraded, handshake);
130 Ok(())
131 }
132 UpgradableStreamInner::BaseServer(base, config) => {
133 let Some(config) = config else {
134 return Err(SslError::SslUnsupportedByClient);
135 };
136 let (upgraded, handshake) = D::upgrade_server(config, base).await?;
137 self.inner = UpgradableStreamInner::Upgraded(upgraded, handshake);
138 Ok(())
139 }
140 UpgradableStreamInner::Upgraded(..) => Err(SslError::SslAlreadyUpgraded),
141 UpgradableStreamInner::Upgrading => Err(SslError::SslAlreadyUpgraded),
142 }
143 }
144
145 fn handshake(&self) -> Option<&TlsHandshake> {
146 match &self.inner {
147 UpgradableStreamInner::Upgraded(_, handshake) => Some(handshake),
148 _ => None,
149 }
150 }
151}
152
#[cfg(feature = "tokio")]
impl<S: Stream, D: TlsDriver> tokio::io::AsyncRead for UpgradableStream<S, D> {
    /// Reads from the currently active stream (plain or TLS).
    ///
    /// Fails with `InvalidInput` while an upgrade is in progress. If
    /// `ignore_missing_close_notify` was requested, an `UnexpectedEof` error from the
    /// underlying stream is reported as a clean end-of-stream (`Ok(())` with nothing read).
    #[inline(always)]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let this = self.get_mut();
        let res = match &mut this.inner {
            UpgradableStreamInner::BaseClient(base, _) => Pin::new(base).poll_read(cx, buf),
            UpgradableStreamInner::BaseServer(base, _) => Pin::new(base).poll_read(cx, buf),
            UpgradableStreamInner::Upgraded(upgraded, _) => Pin::new(upgraded).poll_read(cx, buf),
            UpgradableStreamInner::Upgrading => std::task::Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot read while upgrading",
            ))),
        };
        // Per the option's name, this is meant for peers that close the TLS connection without
        // sending close_notify. The EOF is swallowed silently: library code should not write
        // diagnostics to stderr.
        let is_unexpected_eof = matches!(
            &res,
            std::task::Poll::Ready(Err(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof
        );
        if this.options.ignore_missing_close_notify && is_unexpected_eof {
            return std::task::Poll::Ready(Ok(()));
        }
        res
    }
}
182
/// Forwards every write-side operation to whichever stream is currently active.
///
/// While an upgrade is in progress there is no stream, so each operation fails with
/// `InvalidInput` and `is_write_vectored` reports `false`.
#[cfg(feature = "tokio")]
impl<S: Stream, D: TlsDriver> tokio::io::AsyncWrite for UpgradableStream<S, D> {
    #[inline(always)]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::Upgrading => std::task::Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot write while upgrading",
            ))),
            UpgradableStreamInner::Upgraded(tls, _) => Pin::new(tls).poll_write(cx, buf),
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => Pin::new(plain).poll_write(cx, buf),
        }
    }

    #[inline(always)]
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::Upgrading => std::task::Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot flush while upgrading",
            ))),
            UpgradableStreamInner::Upgraded(tls, _) => Pin::new(tls).poll_flush(cx),
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => Pin::new(plain).poll_flush(cx),
        }
    }

    #[inline(always)]
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::Upgrading => std::task::Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot shutdown while upgrading",
            ))),
            UpgradableStreamInner::Upgraded(tls, _) => Pin::new(tls).poll_shutdown(cx),
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => Pin::new(plain).poll_shutdown(cx),
        }
    }

    #[inline(always)]
    fn is_write_vectored(&self) -> bool {
        match &self.inner {
            UpgradableStreamInner::Upgrading => false,
            UpgradableStreamInner::Upgraded(tls, _) => tls.is_write_vectored(),
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => plain.is_write_vectored(),
        }
    }

    #[inline(always)]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::Upgrading => std::task::Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot write vectored while upgrading",
            ))),
            UpgradableStreamInner::Upgraded(tls, _) => Pin::new(tls).poll_write_vectored(cx, bufs),
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => {
                Pin::new(plain).poll_write_vectored(cx, bufs)
            }
        }
    }
}
271
/// Connection state of an [`UpgradableStream`].
///
/// The `#[debug(...)]` attributes keep `Debug` output to the variant name only, so stream
/// and parameter payloads are never printed.
#[derive(derive_more::Debug)]
enum UpgradableStreamInner<S: Stream, D: TlsDriver> {
    /// Plain client-side stream, plus the TLS client parameters for a later upgrade (if any).
    #[debug("BaseClient(..)")]
    BaseClient(S, Option<D::ClientParams>),
    /// Plain server-side stream, plus the TLS server parameter provider (if any).
    #[debug("BaseServer(..)")]
    BaseServer(S, Option<TlsServerParameterProvider>),
    /// Stream produced by the TLS driver, with the handshake that established it.
    #[debug("Upgraded(..)")]
    Upgraded(D::Stream, TlsHandshake),
    /// Placeholder while `secure_upgrade` owns the base stream; it is also what remains if
    /// the driver's upgrade fails or is never completed.
    #[debug("Upgrading")]
    Upgrading,
}
283
/// Streams that can take previously read bytes back so they are delivered again.
pub trait Rewindable {
    /// Pushes `bytes` back into the stream.
    ///
    /// Implementations may refuse. For example, an `UpgradableStream` that is mid-upgrade
    /// returns an `Unsupported` error.
    fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()>;
}
287
/// Wraps a stream `S` with a replay buffer.
///
/// Bytes handed to [`RewindStream::rewind`] are served to readers before any further data is
/// read from the wrapped stream.
pub struct RewindStream<S> {
    // Bytes queued for replay, consumed from the front.
    buffer: Vec<u8>,
    inner: S,
}

impl<S> RewindStream<S> {
    /// Wraps `inner` with an empty replay buffer.
    pub fn new(inner: S) -> Self {
        Self {
            inner,
            buffer: Vec::default(),
        }
    }

    /// Appends `data` to the end of the replay buffer, after any bytes already queued.
    pub fn rewind(&mut self, data: &[u8]) {
        self.buffer.extend(data.iter().copied());
    }

    /// Consumes the wrapper.
    ///
    /// Returns the wrapped stream together with any bytes still waiting to be replayed.
    pub fn into_inner(self) -> (S, Vec<u8>) {
        let Self { buffer, inner } = self;
        (inner, buffer)
    }
}
309
#[cfg(feature = "tokio")]
impl<S: AsyncRead + Unpin> AsyncRead for RewindStream<S> {
    /// Serves rewound bytes first and forwards to the wrapped stream only once the replay
    /// buffer is empty.
    ///
    /// While replaying, a call copies at most `buf.remaining()` bytes and returns
    /// `Ready(Ok(()))` without polling the inner stream, even if the buffer just drained.
    #[inline(always)]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        if this.buffer.is_empty() {
            return Pin::new(&mut this.inner).poll_read(cx, buf);
        }
        let to_read = std::cmp::min(buf.remaining(), this.buffer.len());
        // Copy straight out of the replay buffer, then discard the consumed prefix. This
        // avoids allocating a temporary `Vec` on every buffered read.
        buf.put_slice(&this.buffer[..to_read]);
        this.buffer.drain(..to_read);
        Poll::Ready(Ok(()))
    }
}
328
/// Writes bypass the replay buffer entirely and go straight to the wrapped stream.
#[cfg(feature = "tokio")]
impl<S: AsyncWrite + Unpin> AsyncWrite for RewindStream<S> {
    #[inline(always)]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let wrapped = &mut self.get_mut().inner;
        Pin::new(wrapped).poll_write(cx, buf)
    }

    #[inline(always)]
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let wrapped = &mut self.get_mut().inner;
        Pin::new(wrapped).poll_flush(cx)
    }

    #[inline(always)]
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let wrapped = &mut self.get_mut().inner;
        Pin::new(wrapped).poll_shutdown(cx)
    }

    #[inline(always)]
    fn is_write_vectored(&self) -> bool {
        let wrapped = &self.inner;
        wrapped.is_write_vectored()
    }

    #[inline(always)]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        let wrapped = &mut self.get_mut().inner;
        Pin::new(wrapped).poll_write_vectored(cx, bufs)
    }
}
370
371impl<S: Stream> Rewindable for RewindStream<S> {
372 fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()> {
373 self.rewind(bytes);
374 Ok(())
375 }
376}
377
378impl<S: Stream + Rewindable, D: TlsDriver> Rewindable for UpgradableStream<S, D>
379where
380 D::Stream: Rewindable,
381{
382 fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()> {
383 match &mut self.inner {
384 UpgradableStreamInner::BaseClient(stm, _) => stm.rewind(bytes),
385 UpgradableStreamInner::BaseServer(stm, _) => stm.rewind(bytes),
386 UpgradableStreamInner::Upgraded(stm, _) => stm.rewind(bytes),
387 UpgradableStreamInner::Upgrading => Err(std::io::Error::new(
388 std::io::ErrorKind::Unsupported,
389 "Cannot rewind a stream that is upgrading",
390 )),
391 }
392 }
393}