1#[cfg(feature = "tokio")]
2use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
3
4use std::future::Future;
5#[cfg(feature = "tokio")]
6use std::{
7 any::Any,
8 io::IoSlice,
9 pin::Pin,
10 task::{Context, Poll},
11};
12
13use crate::{Ssl, SslError, TlsDriver, TlsHandshake, TlsServerParameterProvider};
14
/// Byte-stream abstraction used by the stream wrappers in this module (Tokio build).
///
/// Requires Tokio's `AsyncRead + AsyncWrite` plus `Unpin + Send + 'static`.
#[cfg(feature = "tokio")]
pub trait Stream: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static {
    /// Tries to recover the concrete stream type `S` from `self`.
    ///
    /// Returns `Ok(stream)` when `Self` is exactly `S`; otherwise `self` is handed
    /// back unchanged inside `Err`.
    fn downcast<S: Stream + 'static>(self) -> Result<S, Self>
    where
        Self: Sized + 'static,
    {
        // `Any::downcast_mut` only works through references, so park the value in
        // an `Option` that can be viewed as `dyn Any` and moved out of afterwards.
        let mut slot = Some(self);
        let erased: &mut dyn Any = &mut slot;
        if let Some(matched) = erased.downcast_mut::<Option<S>>() {
            // `slot` was populated just above, so the inner `Option` is `Some`.
            return Ok(matched.take().unwrap());
        }
        Err(slot.take().unwrap())
    }
}
31
// Every Tokio byte stream that satisfies the supertrait bounds is a `Stream`.
#[cfg(feature = "tokio")]
impl<T> Stream for T
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
}
34
35#[cfg(not(feature = "tokio"))]
36pub trait Stream: 'static {}
37#[cfg(not(feature = "tokio"))]
38impl<S: Stream, D: TlsDriver> Stream for UpgradableStream<S, D> {}
39#[cfg(not(feature = "tokio"))]
40impl Stream for () {}
41
42pub trait StreamUpgrade: Stream {
43 fn secure_upgrade(&mut self) -> impl Future<Output = Result<(), SslError>> + Send;
44 fn handshake(&self) -> Option<&TlsHandshake>;
45}
46
47#[allow(private_bounds)]
48#[derive(derive_more::Debug)]
49pub struct UpgradableStream<S: Stream, D: TlsDriver = Ssl> {
50 inner: UpgradableStreamInner<S, D>,
51}
52
53#[allow(private_bounds)]
54impl<S: Stream, D: TlsDriver> UpgradableStream<S, D> {
55 #[inline(always)]
56 pub(crate) fn new_client(base: S, config: Option<D::ClientParams>) -> Self {
57 UpgradableStream {
58 inner: UpgradableStreamInner::BaseClient(base, config),
59 }
60 }
61
62 #[inline(always)]
63 pub(crate) fn new_server(base: S, config: Option<TlsServerParameterProvider>) -> Self {
64 UpgradableStream {
65 inner: UpgradableStreamInner::BaseServer(base, config),
66 }
67 }
68
69 pub fn into_boxed(self) -> Result<Box<dyn Stream>, Self> {
71 match self.inner {
72 UpgradableStreamInner::BaseClient(base, _) => Ok(Box::new(base)),
73 UpgradableStreamInner::BaseServer(base, _) => Ok(Box::new(base)),
74 UpgradableStreamInner::Upgraded(upgraded, _) => Ok(Box::new(upgraded)),
75 UpgradableStreamInner::Upgrading => Err(self),
76 }
77 }
78
79 pub fn handshake(&self) -> Option<&TlsHandshake> {
80 match &self.inner {
81 UpgradableStreamInner::Upgraded(_, handshake) => Some(handshake),
82 _ => None,
83 }
84 }
85}
86
87impl<S: Stream, D: TlsDriver> StreamUpgrade for UpgradableStream<S, D> {
88 async fn secure_upgrade(&mut self) -> Result<(), SslError> {
89 match std::mem::replace(&mut self.inner, UpgradableStreamInner::Upgrading) {
90 UpgradableStreamInner::BaseClient(base, config) => {
91 let Some(config) = config else {
92 return Err(SslError::SslUnsupportedByClient);
93 };
94 let (upgraded, handshake) = D::upgrade_client(config, base).await?;
95 self.inner = UpgradableStreamInner::Upgraded(upgraded, handshake);
96 Ok(())
97 }
98 UpgradableStreamInner::BaseServer(base, config) => {
99 let Some(config) = config else {
100 return Err(SslError::SslUnsupportedByClient);
101 };
102 let (upgraded, handshake) = D::upgrade_server(config, base).await?;
103 self.inner = UpgradableStreamInner::Upgraded(upgraded, handshake);
104 Ok(())
105 }
106 UpgradableStreamInner::Upgraded(..) => Err(SslError::SslAlreadyUpgraded),
107 UpgradableStreamInner::Upgrading => Err(SslError::SslAlreadyUpgraded),
108 }
109 }
110
111 fn handshake(&self) -> Option<&TlsHandshake> {
112 match &self.inner {
113 UpgradableStreamInner::Upgraded(_, handshake) => Some(handshake),
114 _ => None,
115 }
116 }
117}
118
#[cfg(feature = "tokio")]
impl<S: Stream, D: TlsDriver> tokio::io::AsyncRead for UpgradableStream<S, D> {
    /// Reads from whichever stream is currently active (plaintext or TLS).
    ///
    /// Fails with `InvalidInput` while an upgrade is in progress.
    #[inline(always)]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => Pin::new(plain).poll_read(cx, buf),
            UpgradableStreamInner::Upgraded(tls, _) => Pin::new(tls).poll_read(cx, buf),
            UpgradableStreamInner::Upgrading => std::task::Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot read while upgrading",
            ))),
        }
    }
}
139
#[cfg(feature = "tokio")]
impl<S: Stream, D: TlsDriver> tokio::io::AsyncWrite for UpgradableStream<S, D> {
    // Every operation forwards to the currently active stream (plaintext or TLS).
    // While an upgrade is in progress there is no stream, so I/O fails with
    // `InvalidInput`.

    #[inline(always)]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => Pin::new(plain).poll_write(cx, buf),
            UpgradableStreamInner::Upgraded(tls, _) => Pin::new(tls).poll_write(cx, buf),
            UpgradableStreamInner::Upgrading => std::task::Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot write while upgrading",
            ))),
        }
    }

    #[inline(always)]
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => Pin::new(plain).poll_flush(cx),
            UpgradableStreamInner::Upgraded(tls, _) => Pin::new(tls).poll_flush(cx),
            UpgradableStreamInner::Upgrading => std::task::Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot flush while upgrading",
            ))),
        }
    }

    #[inline(always)]
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => Pin::new(plain).poll_shutdown(cx),
            UpgradableStreamInner::Upgraded(tls, _) => Pin::new(tls).poll_shutdown(cx),
            UpgradableStreamInner::Upgrading => std::task::Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot shutdown while upgrading",
            ))),
        }
    }

    /// Reports the active stream's vectored-write support; `false` while upgrading.
    #[inline(always)]
    fn is_write_vectored(&self) -> bool {
        match &self.inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => plain.is_write_vectored(),
            UpgradableStreamInner::Upgraded(tls, _) => tls.is_write_vectored(),
            UpgradableStreamInner::Upgrading => false,
        }
    }

    #[inline(always)]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => {
                Pin::new(plain).poll_write_vectored(cx, bufs)
            }
            UpgradableStreamInner::Upgraded(tls, _) => Pin::new(tls).poll_write_vectored(cx, bufs),
            UpgradableStreamInner::Upgrading => std::task::Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot write vectored while upgrading",
            ))),
        }
    }
}
228
229#[derive(derive_more::Debug)]
230enum UpgradableStreamInner<S: Stream, D: TlsDriver> {
231 #[debug("BaseClient(..)")]
232 BaseClient(S, Option<D::ClientParams>),
233 #[debug("BaseServer(..)")]
234 BaseServer(S, Option<TlsServerParameterProvider>),
235 #[debug("Upgraded(..)")]
236 Upgraded(D::Stream, TlsHandshake),
237 #[debug("Upgrading")]
238 Upgrading,
239}
240
/// A stream that can take back bytes it already produced, so that they are
/// read again.
pub trait Rewindable {
    /// Pushes `bytes` back into the stream's pending read data.
    ///
    /// # Errors
    ///
    /// Implementations may fail when the stream cannot currently accept rewound
    /// data (for example, while it is mid-upgrade).
    fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()>;
}
244
/// Wraps a stream `S` with a buffer of bytes that readers receive before any
/// further data from the inner stream.
pub struct RewindStream<S> {
    /// Pending bytes; reads consume them from the front.
    buffer: Vec<u8>,
    /// The wrapped stream.
    inner: S,
}

impl<S> RewindStream<S> {
    /// Wraps `inner` with an empty pending buffer.
    pub fn new(inner: S) -> Self {
        Self {
            inner,
            buffer: Vec::new(),
        }
    }

    /// Appends `data` to the end of the pending buffer.
    pub fn rewind(&mut self, data: &[u8]) {
        self.buffer.extend(data);
    }

    /// Splits into the inner stream and the bytes that have not been read yet.
    pub fn into_inner(self) -> (S, Vec<u8>) {
        let Self { buffer, inner } = self;
        (inner, buffer)
    }
}
266
#[cfg(feature = "tokio")]
impl<S: AsyncRead + Unpin> AsyncRead for RewindStream<S> {
    /// Serves rewound bytes first, then reads from the inner stream.
    ///
    /// When the pending buffer is non-empty, up to `buf.remaining()` bytes are
    /// copied out of it. The call then completes immediately, without polling the
    /// inner stream.
    #[inline(always)]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        if this.buffer.is_empty() {
            return Pin::new(&mut this.inner).poll_read(cx, buf);
        }
        let to_read = std::cmp::min(buf.remaining(), this.buffer.len());
        // Copy directly from the buffer and then discard the consumed prefix. This
        // avoids allocating a temporary `Vec` on every buffered read.
        buf.put_slice(&this.buffer[..to_read]);
        this.buffer.drain(..to_read);
        Poll::Ready(Ok(()))
    }
}
285
#[cfg(feature = "tokio")]
impl<S: AsyncWrite + Unpin> AsyncWrite for RewindStream<S> {
    // Writes bypass the rewind buffer entirely: every call is forwarded to `inner`.

    #[inline(always)]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let inner = &mut self.get_mut().inner;
        Pin::new(inner).poll_write(cx, buf)
    }

    #[inline(always)]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        let inner = &mut self.get_mut().inner;
        Pin::new(inner).poll_flush(cx)
    }

    #[inline(always)]
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let inner = &mut self.get_mut().inner;
        Pin::new(inner).poll_shutdown(cx)
    }

    #[inline(always)]
    fn is_write_vectored(&self) -> bool {
        let inner = &self.inner;
        inner.is_write_vectored()
    }

    #[inline(always)]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        let inner = &mut self.get_mut().inner;
        Pin::new(inner).poll_write_vectored(cx, bufs)
    }
}
327
328impl<S: Stream> Rewindable for RewindStream<S> {
329 fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()> {
330 self.rewind(bytes);
331 Ok(())
332 }
333}
334
335impl<S: Stream + Rewindable, D: TlsDriver> Rewindable for UpgradableStream<S, D>
336where
337 D::Stream: Rewindable,
338{
339 fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()> {
340 match &mut self.inner {
341 UpgradableStreamInner::BaseClient(stm, _) => stm.rewind(bytes),
342 UpgradableStreamInner::BaseServer(stm, _) => stm.rewind(bytes),
343 UpgradableStreamInner::Upgraded(stm, _) => stm.rewind(bytes),
344 UpgradableStreamInner::Upgrading => Err(std::io::Error::new(
345 std::io::ErrorKind::Unsupported,
346 "Cannot rewind a stream that is upgrading",
347 )),
348 }
349 }
350}