1#[cfg(feature = "tokio")]
2use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
3
4use std::future::Future;
5#[cfg(feature = "tokio")]
6use std::{
7 any::Any,
8 io::IoSlice,
9 pin::Pin,
10 task::{Context, Poll},
11};
12
13use crate::{Ssl, SslError, TlsDriver, TlsHandshake, TlsServerParameterProvider};
14
/// An asynchronous, bidirectional byte stream usable as a transport.
///
/// The trait only bundles its supertraits (`AsyncRead + AsyncWrite + Unpin +
/// Send + 'static`) and adds a by-value [`Stream::downcast`] helper.
#[cfg(feature = "tokio")]
pub trait Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static {
    /// Converts this stream into the concrete stream type `S`.
    ///
    /// Returns `Ok` holding the stream when `Self` and `S` are the same type;
    /// otherwise the untouched stream is handed back in `Err`.
    fn downcast<S: Stream + 'static>(self) -> Result<S, Self>
    where
        Self: Sized + 'static,
    {
        // `Any` only downcasts through references, so the stream is parked in
        // an `Option` and moved out with `take()` once the type check (on
        // `Option<Self>` vs. `Option<S>`) has been made.
        let mut slot = Some(self);
        match (&mut slot as &mut dyn Any).downcast_mut::<Option<S>>() {
            Some(typed) => Ok(typed.take().unwrap()),
            None => Err(slot.take().unwrap()),
        }
    }
}
32
/// Blanket implementation: every type that is an async reader and writer and
/// is `Unpin + Send + 'static` is a [`Stream`].
#[cfg(feature = "tokio")]
impl<T: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stream for T {}
35
36#[cfg(not(feature = "tokio"))]
37pub trait Stream: 'static {}
38#[cfg(not(feature = "tokio"))]
39impl<S: Stream, D: TlsDriver> Stream for UpgradableStream<S, D> {}
40#[cfg(not(feature = "tokio"))]
41impl Stream for () {}
42
43pub trait StreamUpgrade: Stream {
45 fn secure_upgrade(&mut self) -> impl Future<Output = Result<(), SslError>> + Send;
46 fn handshake(&self) -> Option<&TlsHandshake>;
47}
48
/// Behavior switches for an `UpgradableStream`; everything defaults to off.
#[derive(Debug, Default)]
struct UpgradableStreamOptions {
    /// When `true`, reads report an `UnexpectedEof` error as a normal end of
    /// stream instead of an error.
    ignore_missing_close_notify: bool,
}
53
54#[allow(private_bounds)]
55#[derive(derive_more::Debug)]
56pub struct UpgradableStream<S: Stream, D: TlsDriver = Ssl> {
57 inner: UpgradableStreamInner<S, D>,
58 options: UpgradableStreamOptions,
59}
60
61#[allow(private_bounds)]
62impl<S: Stream, D: TlsDriver> UpgradableStream<S, D> {
63 #[inline(always)]
64 pub(crate) fn new_client(base: S, config: Option<D::ClientParams>) -> Self {
65 UpgradableStream {
66 inner: UpgradableStreamInner::BaseClient(base, config),
67 options: Default::default(),
68 }
69 }
70
71 #[inline(always)]
72 pub(crate) fn new_server(base: S, config: Option<TlsServerParameterProvider>) -> Self {
73 UpgradableStream {
74 inner: UpgradableStreamInner::BaseServer(base, config),
75 options: Default::default(),
76 }
77 }
78
79 pub fn into_boxed(self) -> Result<Box<dyn Stream>, Self> {
81 match self.inner {
82 UpgradableStreamInner::BaseClient(base, _) => Ok(Box::new(base)),
83 UpgradableStreamInner::BaseServer(base, _) => Ok(Box::new(base)),
84 UpgradableStreamInner::Upgraded(upgraded, _) => Ok(Box::new(upgraded)),
85 UpgradableStreamInner::Upgrading => Err(self),
86 }
87 }
88
89 pub fn handshake(&self) -> Option<&TlsHandshake> {
90 match &self.inner {
91 UpgradableStreamInner::Upgraded(_, handshake) => Some(handshake),
92 _ => None,
93 }
94 }
95
96 pub fn ignore_missing_close_notify(&mut self) {
97 self.options.ignore_missing_close_notify = true;
98 }
99
100 pub fn unclean_shutdown(self) -> Result<(), Self> {
103 match self.inner {
104 UpgradableStreamInner::BaseClient(..) => Ok(()),
105 UpgradableStreamInner::BaseServer(..) => Ok(()),
106 UpgradableStreamInner::Upgraded(upgraded, cfg) => {
107 if let Err(e) = D::unclean_shutdown(upgraded) {
108 Err(Self {
109 inner: UpgradableStreamInner::Upgraded(e, cfg),
110 options: self.options,
111 })
112 } else {
113 Ok(())
114 }
115 }
116 UpgradableStreamInner::Upgrading => Err(self),
117 }
118 }
119}
120
121impl<S: Stream, D: TlsDriver> StreamUpgrade for UpgradableStream<S, D> {
122 async fn secure_upgrade(&mut self) -> Result<(), SslError> {
123 match std::mem::replace(&mut self.inner, UpgradableStreamInner::Upgrading) {
124 UpgradableStreamInner::BaseClient(base, config) => {
125 let Some(config) = config else {
126 return Err(SslError::SslUnsupportedByClient);
127 };
128 let (upgraded, handshake) = D::upgrade_client(config, base).await?;
129 self.inner = UpgradableStreamInner::Upgraded(upgraded, handshake);
130 Ok(())
131 }
132 UpgradableStreamInner::BaseServer(base, config) => {
133 let Some(config) = config else {
134 return Err(SslError::SslUnsupportedByClient);
135 };
136 let (upgraded, handshake) = D::upgrade_server(config, base).await?;
137 self.inner = UpgradableStreamInner::Upgraded(upgraded, handshake);
138 Ok(())
139 }
140 UpgradableStreamInner::Upgraded(..) => Err(SslError::SslAlreadyUpgraded),
141 UpgradableStreamInner::Upgrading => Err(SslError::SslAlreadyUpgraded),
142 }
143 }
144
145 fn handshake(&self) -> Option<&TlsHandshake> {
146 match &self.inner {
147 UpgradableStreamInner::Upgraded(_, handshake) => Some(handshake),
148 _ => None,
149 }
150 }
151}
152
/// Reads are forwarded to whichever stream is currently active.
#[cfg(feature = "tokio")]
impl<S: Stream, D: TlsDriver> AsyncRead for UpgradableStream<S, D> {
    /// Polls the active stream for data.
    ///
    /// Fails with `InvalidInput` while the stream is in the `Upgrading`
    /// state. When `ignore_missing_close_notify` is set, an `UnexpectedEof`
    /// error is reported as a successful read (end of stream) instead.
    #[inline(always)]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        let outcome = match &mut this.inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => Pin::new(plain).poll_read(cx, buf),
            UpgradableStreamInner::Upgraded(secure, _) => Pin::new(secure).poll_read(cx, buf),
            UpgradableStreamInner::Upgrading => Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot read while upgrading",
            ))),
        };
        match outcome {
            Poll::Ready(Err(ref err))
                if this.options.ignore_missing_close_notify
                    && err.kind() == std::io::ErrorKind::UnexpectedEof =>
            {
                Poll::Ready(Ok(()))
            }
            other => other,
        }
    }
}
181
/// Writes are forwarded to whichever stream is currently active. Every
/// polling operation fails with `InvalidInput` while the stream is in the
/// `Upgrading` state.
#[cfg(feature = "tokio")]
impl<S: Stream, D: TlsDriver> AsyncWrite for UpgradableStream<S, D> {
    /// Forwards the write to the active stream.
    #[inline(always)]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => Pin::new(plain).poll_write(cx, buf),
            UpgradableStreamInner::Upgraded(secure, _) => Pin::new(secure).poll_write(cx, buf),
            UpgradableStreamInner::Upgrading => Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot write while upgrading",
            ))),
        }
    }

    /// Forwards the flush to the active stream.
    #[inline(always)]
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => Pin::new(plain).poll_flush(cx),
            UpgradableStreamInner::Upgraded(secure, _) => Pin::new(secure).poll_flush(cx),
            UpgradableStreamInner::Upgrading => Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot flush while upgrading",
            ))),
        }
    }

    /// Forwards the shutdown to the active stream.
    #[inline(always)]
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => Pin::new(plain).poll_shutdown(cx),
            UpgradableStreamInner::Upgraded(secure, _) => Pin::new(secure).poll_shutdown(cx),
            UpgradableStreamInner::Upgrading => Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot shutdown while upgrading",
            ))),
        }
    }

    /// Reports the active stream's answer; `false` while upgrading.
    #[inline(always)]
    fn is_write_vectored(&self) -> bool {
        match &self.inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => plain.is_write_vectored(),
            UpgradableStreamInner::Upgraded(secure, _) => secure.is_write_vectored(),
            UpgradableStreamInner::Upgrading => false,
        }
    }

    /// Forwards the vectored write to the active stream.
    #[inline(always)]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => {
                Pin::new(plain).poll_write_vectored(cx, bufs)
            }
            UpgradableStreamInner::Upgraded(secure, _) => {
                Pin::new(secure).poll_write_vectored(cx, bufs)
            }
            UpgradableStreamInner::Upgrading => Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot write vectored while upgrading",
            ))),
        }
    }
}
270
271#[derive(derive_more::Debug)]
272enum UpgradableStreamInner<S: Stream, D: TlsDriver> {
273 #[debug("BaseClient(..)")]
274 BaseClient(S, Option<D::ClientParams>),
275 #[debug("BaseServer(..)")]
276 BaseServer(S, Option<TlsServerParameterProvider>),
277 #[debug("Upgraded(..)")]
278 Upgraded(D::Stream, TlsHandshake),
279 #[debug("Upgrading")]
280 Upgrading,
281}
282
/// A stream onto which bytes can be pushed back.
pub trait Rewindable {
    /// Pushes `bytes` back onto the stream.
    ///
    /// # Errors
    ///
    /// Returns an I/O error when the implementation cannot accept the bytes.
    fn rewind(&mut self, bytes: &[u8]) -> Result<(), std::io::Error>;
}
286
/// Wraps a stream `S` together with a buffer of pushed-back bytes.
///
/// The `AsyncRead` implementation (with the `tokio` feature) serves the
/// pushed-back bytes before reading from the wrapped stream again.
pub struct RewindStream<S> {
    /// Pushed-back bytes that have not been re-read yet, in read order.
    buffer: Vec<u8>,
    /// The wrapped stream.
    inner: S,
}

impl<S> RewindStream<S> {
    /// Wraps `inner` with an empty push-back buffer.
    pub fn new(inner: S) -> Self {
        Self {
            buffer: Vec::new(),
            inner,
        }
    }

    /// Pushes `data` back so that it is the next thing read from the stream.
    ///
    /// If bytes from an earlier rewind are still pending, `data` is placed in
    /// front of them: rewound bytes were read before whatever is still
    /// waiting in the buffer, so they must be replayed first to keep the byte
    /// order of the stream intact.
    pub fn rewind(&mut self, data: &[u8]) {
        // Insert at the front (an append would replay `data` out of order).
        // The splice takes effect when the returned iterator is dropped.
        drop(self.buffer.splice(..0, data.iter().copied()));
    }

    /// Splits the wrapper into the wrapped stream and the pushed-back bytes
    /// that have not been re-read yet.
    pub fn into_inner(self) -> (S, Vec<u8>) {
        (self.inner, self.buffer)
    }
}
308
/// Reads drain the pushed-back bytes first and only then touch the wrapped
/// stream.
#[cfg(feature = "tokio")]
impl<S: AsyncRead + Unpin> AsyncRead for RewindStream<S> {
    /// Fills `buf` from the push-back buffer while it holds data (completing
    /// immediately), and from the wrapped stream otherwise. A single call
    /// never mixes both sources.
    #[inline(always)]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        if this.buffer.is_empty() {
            return Pin::new(&mut this.inner).poll_read(cx, buf);
        }

        let to_read = buf.remaining().min(this.buffer.len());
        // Copy straight out of the buffer, then discard the consumed prefix;
        // this avoids collecting the drained bytes into a temporary `Vec`.
        buf.put_slice(&this.buffer[..to_read]);
        this.buffer.drain(..to_read);
        Poll::Ready(Ok(()))
    }
}
327
/// Writing is unaffected by the push-back buffer: every operation is passed
/// straight through to the wrapped stream.
#[cfg(feature = "tokio")]
impl<S: AsyncWrite + Unpin> AsyncWrite for RewindStream<S> {
    /// Delegates to the wrapped stream's `poll_write`.
    #[inline(always)]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let wrapped = &mut self.get_mut().inner;
        Pin::new(wrapped).poll_write(cx, buf)
    }

    /// Delegates to the wrapped stream's `poll_flush`.
    #[inline(always)]
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let wrapped = &mut self.get_mut().inner;
        Pin::new(wrapped).poll_flush(cx)
    }

    /// Delegates to the wrapped stream's `poll_shutdown`.
    #[inline(always)]
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let wrapped = &mut self.get_mut().inner;
        Pin::new(wrapped).poll_shutdown(cx)
    }

    /// Reports whatever the wrapped stream reports.
    #[inline(always)]
    fn is_write_vectored(&self) -> bool {
        let wrapped = &self.inner;
        wrapped.is_write_vectored()
    }

    /// Delegates to the wrapped stream's `poll_write_vectored`.
    #[inline(always)]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        let wrapped = &mut self.get_mut().inner;
        Pin::new(wrapped).poll_write_vectored(cx, bufs)
    }
}
369
370impl<S: Stream> Rewindable for RewindStream<S> {
371 fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()> {
372 self.rewind(bytes);
373 Ok(())
374 }
375}
376
377impl<S: Stream + Rewindable, D: TlsDriver> Rewindable for UpgradableStream<S, D>
378where
379 D::Stream: Rewindable,
380{
381 fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()> {
382 match &mut self.inner {
383 UpgradableStreamInner::BaseClient(stm, _) => stm.rewind(bytes),
384 UpgradableStreamInner::BaseServer(stm, _) => stm.rewind(bytes),
385 UpgradableStreamInner::Upgraded(stm, _) => stm.rewind(bytes),
386 UpgradableStreamInner::Upgrading => Err(std::io::Error::new(
387 std::io::ErrorKind::Unsupported,
388 "Cannot rewind a stream that is upgrading",
389 )),
390 }
391 }
392}