1use std::{
26 error::Error as StdError,
27 fmt,
28 future::Future,
29 io,
30 pin::Pin,
31 sync::{Arc, Mutex},
32 task::{Context, Poll},
33};
34
35use bytes::Bytes;
36use tokio::{
37 io::{AsyncRead, AsyncWrite, ReadBuf},
38 sync::oneshot,
39};
40
41use self::rewind::Rewind;
42use super::{Error, Result};
43
44pub struct Upgraded {
53 io: Rewind<Box<dyn Io + Send>>,
54}
55
56#[derive(Clone)]
60pub struct OnUpgrade {
61 rx: Option<Arc<Mutex<oneshot::Receiver<Result<Upgraded>>>>>,
62}
63
64#[inline]
73pub fn on<T: sealed::CanUpgrade>(msg: T) -> OnUpgrade {
74 msg.on_upgrade()
75}
76
77pub(crate) struct Pending {
78 tx: oneshot::Sender<Result<Upgraded>>,
79}
80
81pub(crate) fn pending() -> (Pending, OnUpgrade) {
82 let (tx, rx) = oneshot::channel();
83 (
84 Pending { tx },
85 OnUpgrade {
86 rx: Some(Arc::new(Mutex::new(rx))),
87 },
88 )
89}
90
91impl Upgraded {
94 #[inline]
95 pub(crate) fn new<T>(io: T, read_buf: Bytes) -> Self
96 where
97 T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
98 {
99 Upgraded {
100 io: Rewind::new_buffered(Box::new(io), read_buf),
101 }
102 }
103}
104
105impl AsyncRead for Upgraded {
106 #[inline]
107 fn poll_read(
108 mut self: Pin<&mut Self>,
109 cx: &mut Context<'_>,
110 buf: &mut ReadBuf<'_>,
111 ) -> Poll<io::Result<()>> {
112 Pin::new(&mut self.io).poll_read(cx, buf)
113 }
114}
115
116impl AsyncWrite for Upgraded {
117 #[inline]
118 fn poll_write(
119 mut self: Pin<&mut Self>,
120 cx: &mut Context<'_>,
121 buf: &[u8],
122 ) -> Poll<io::Result<usize>> {
123 Pin::new(&mut self.io).poll_write(cx, buf)
124 }
125
126 #[inline]
127 fn poll_write_vectored(
128 mut self: Pin<&mut Self>,
129 cx: &mut Context<'_>,
130 bufs: &[io::IoSlice<'_>],
131 ) -> Poll<io::Result<usize>> {
132 Pin::new(&mut self.io).poll_write_vectored(cx, bufs)
133 }
134
135 #[inline]
136 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
137 Pin::new(&mut self.io).poll_flush(cx)
138 }
139
140 #[inline]
141 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
142 Pin::new(&mut self.io).poll_shutdown(cx)
143 }
144
145 #[inline]
146 fn is_write_vectored(&self) -> bool {
147 self.io.is_write_vectored()
148 }
149}
150
151impl fmt::Debug for Upgraded {
152 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
153 f.debug_struct("Upgraded").finish()
154 }
155}
156
157impl OnUpgrade {
160 #[inline]
161 pub(super) fn none() -> Self {
162 OnUpgrade { rx: None }
163 }
164
165 #[inline]
166 pub(super) fn is_none(&self) -> bool {
167 self.rx.is_none()
168 }
169}
170
171impl Future for OnUpgrade {
172 type Output = Result<Upgraded, Error>;
173
174 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
175 match self.rx {
176 Some(ref rx) => Pin::new(&mut *rx.lock().unwrap())
177 .poll(cx)
178 .map(|res| match res {
179 Ok(Ok(upgraded)) => Ok(upgraded),
180 Ok(Err(err)) => Err(err),
181 Err(_oneshot_canceled) => Err(Error::new_canceled().with(UpgradeExpected)),
182 }),
183 None => Poll::Ready(Err(Error::new_user_no_upgrade())),
184 }
185 }
186}
187
188impl Pending {
191 #[inline]
192 pub(super) fn fulfill(self, upgraded: Upgraded) {
193 trace!("pending upgrade fulfill");
194 let _ = self.tx.send(Ok(upgraded));
195 }
196
197 #[inline]
200 pub(super) fn manual(self) {
201 trace!("pending upgrade handled manually");
202 let _ = self.tx.send(Err(Error::new_user_manual_upgrade()));
203 }
204}
205
/// Error context attached when the upgrade channel closes without delivering
/// a result.
#[derive(Debug)]
struct UpgradeExpected;

impl fmt::Display for UpgradeExpected {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const MESSAGE: &str = "upgrade expected but not completed";
        // Written verbatim; width and padding flags on the formatter are ignored.
        f.write_str(MESSAGE)
    }
}

// No `source`: this type is itself the innermost cause.
impl StdError for UpgradeExpected {}
222
223trait Io: AsyncRead + AsyncWrite + Unpin + 'static {}
226
227impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Io for T {}
228
229mod sealed {
230 use super::OnUpgrade;
231
232 pub trait CanUpgrade {
233 fn on_upgrade(self) -> OnUpgrade;
234 }
235
236 impl<B> CanUpgrade for http::Request<B> {
237 fn on_upgrade(mut self) -> OnUpgrade {
238 self.extensions_mut()
239 .remove::<OnUpgrade>()
240 .unwrap_or_else(OnUpgrade::none)
241 }
242 }
243
244 impl<B> CanUpgrade for &'_ mut http::Request<B> {
245 fn on_upgrade(self) -> OnUpgrade {
246 self.extensions_mut()
247 .remove::<OnUpgrade>()
248 .unwrap_or_else(OnUpgrade::none)
249 }
250 }
251
252 impl<B> CanUpgrade for http::Response<B> {
253 fn on_upgrade(mut self) -> OnUpgrade {
254 self.extensions_mut()
255 .remove::<OnUpgrade>()
256 .unwrap_or_else(OnUpgrade::none)
257 }
258 }
259
260 impl<B> CanUpgrade for &'_ mut http::Response<B> {
261 fn on_upgrade(self) -> OnUpgrade {
262 self.extensions_mut()
263 .remove::<OnUpgrade>()
264 .unwrap_or_else(OnUpgrade::none)
265 }
266 }
267}
268
269mod rewind {
270 use std::{
271 cmp, io,
272 pin::Pin,
273 task::{Context, Poll},
274 };
275
276 use bytes::{Buf, Bytes};
277 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
278
279 #[derive(Debug)]
281 pub(crate) struct Rewind<T> {
282 pre: Option<Bytes>,
283 inner: T,
284 }
285
286 impl<T> Rewind<T> {
287 #[inline]
288 pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self {
289 Rewind {
290 pre: Some(buf),
291 inner: io,
292 }
293 }
294
295 #[cfg(test)]
296 pub(crate) fn rewind(&mut self, bs: Bytes) {
297 debug_assert!(self.pre.is_none());
298 self.pre = Some(bs);
299 }
300 }
301
302 impl<T> AsyncRead for Rewind<T>
303 where
304 T: AsyncRead + Unpin,
305 {
306 fn poll_read(
307 mut self: Pin<&mut Self>,
308 cx: &mut Context<'_>,
309 buf: &mut ReadBuf<'_>,
310 ) -> Poll<io::Result<()>> {
311 if let Some(mut prefix) = self.pre.take() {
312 if !prefix.is_empty() {
314 let copy_len = cmp::min(prefix.len(), buf.remaining());
315 buf.put_slice(&prefix[..copy_len]);
317 prefix.advance(copy_len);
318 if !prefix.is_empty() {
320 self.pre = Some(prefix);
321 }
322
323 return Poll::Ready(Ok(()));
324 }
325 }
326 Pin::new(&mut self.inner).poll_read(cx, buf)
327 }
328 }
329
330 impl<T> AsyncWrite for Rewind<T>
331 where
332 T: AsyncWrite + Unpin,
333 {
334 #[inline]
335 fn poll_write(
336 mut self: Pin<&mut Self>,
337 cx: &mut Context<'_>,
338 buf: &[u8],
339 ) -> Poll<io::Result<usize>> {
340 Pin::new(&mut self.inner).poll_write(cx, buf)
341 }
342
343 #[inline]
344 fn poll_write_vectored(
345 mut self: Pin<&mut Self>,
346 cx: &mut Context<'_>,
347 bufs: &[io::IoSlice<'_>],
348 ) -> Poll<io::Result<usize>> {
349 Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
350 }
351
352 #[inline]
353 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
354 Pin::new(&mut self.inner).poll_flush(cx)
355 }
356
357 #[inline]
358 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
359 Pin::new(&mut self.inner).poll_shutdown(cx)
360 }
361
362 #[inline]
363 fn is_write_vectored(&self) -> bool {
364 self.inner.is_write_vectored()
365 }
366 }
367
368 #[cfg(test)]
369 mod tests {
370 use bytes::Bytes;
371 use tokio::io::AsyncReadExt;
372
373 use super::Rewind;
374
375 #[tokio::test]
376 async fn partial_rewind() {
377 let underlying = [104, 101, 108, 108, 111];
378
379 let mock = tokio_test::io::Builder::new().read(&underlying).build();
380
381 let mut stream = Rewind::new_buffered(mock, Bytes::new());
382
383 let mut buf = [0; 2];
385 stream.read_exact(&mut buf).await.expect("read1");
386
387 stream.rewind(Bytes::copy_from_slice(&buf[..]));
389
390 let mut buf = [0; 5];
391 stream.read_exact(&mut buf).await.expect("read1");
392
393 assert_eq!(&buf, &underlying);
395 }
396
397 #[tokio::test]
398 async fn full_rewind() {
399 let underlying = [104, 101, 108, 108, 111];
400
401 let mock = tokio_test::io::Builder::new().read(&underlying).build();
402
403 let mut stream = Rewind::new_buffered(mock, Bytes::new());
404
405 let mut buf = [0; 5];
406 stream.read_exact(&mut buf).await.expect("read1");
407
408 stream.rewind(Bytes::copy_from_slice(&buf[..]));
410
411 let mut buf = [0; 5];
412 stream.read_exact(&mut buf).await.expect("read1");
413
414 assert_eq!(&buf, &underlying);
415 }
416 }
417}