1use std::{
2 future::Future,
3 pin::Pin,
4 sync::{atomic::AtomicBool, Arc},
5 task::ready,
6};
7
8use bytes::BufMut;
9use futures_util::FutureExt;
10use http::Request;
11use http_body::Body;
12use send_wrapper::SendWrapper;
13use tokio::io::{AsyncRead, AsyncWrite};
14
/// A connection handed over after an upgrade, usable as a raw byte stream.
///
/// Implements [`AsyncRead`] and [`AsyncWrite`]. Any bytes held in `leftover`
/// are returned by reads before anything from the underlying reader.
pub struct Upgraded {
    /// Read half produced by `tokio::io::split` in `Upgraded::new`.
    ///
    /// The transport given to `new` is not required to be `Send`, so the half is
    /// kept in a `SendWrapper`. Per the `send_wrapper` crate, dereferencing it
    /// from a thread other than the one that created it panics.
    reader: SendWrapper<Pin<Box<dyn AsyncRead + Unpin>>>,
    /// Write half produced by `tokio::io::split` in `Upgraded::new`.
    writer: SendWrapper<Pin<Box<dyn AsyncWrite + Unpin>>>,
    /// Bytes served by `poll_read` ahead of `reader`. These are presumably data
    /// already buffered from the transport before the upgrade — confirm with the
    /// callers of `new`.
    leftover: Option<bytes::Bytes>,
}
30
31impl Upgraded {
32 #[inline]
33 pub(super) fn new(
34 io: impl AsyncRead + AsyncWrite + Unpin + 'static,
35 leftover: Option<bytes::Bytes>,
36 ) -> Self {
37 let (reader, writer) = tokio::io::split(io);
38 Self {
39 reader: SendWrapper::new(Box::pin(reader)),
40 writer: SendWrapper::new(Box::pin(writer)),
41 leftover,
42 }
43 }
44}
45
46impl AsyncRead for Upgraded {
47 #[inline]
48 fn poll_read(
49 mut self: std::pin::Pin<&mut Self>,
50 cx: &mut std::task::Context<'_>,
51 buf: &mut tokio::io::ReadBuf<'_>,
52 ) -> std::task::Poll<std::io::Result<()>> {
53 if let Some(leftover) = &mut self.leftover {
54 let slice_len = leftover.len().min(buf.remaining());
55 let leftover_to_write = leftover.split_to(slice_len);
56 buf.put(leftover_to_write);
57 if leftover.is_empty() {
58 self.leftover = None;
59 }
60 return std::task::Poll::Ready(Ok(()));
61 }
62 (*self.reader).as_mut().poll_read(cx, buf)
63 }
64}
65
66impl AsyncWrite for Upgraded {
67 #[inline]
68 fn poll_write(
69 mut self: std::pin::Pin<&mut Self>,
70 cx: &mut std::task::Context<'_>,
71 buf: &[u8],
72 ) -> std::task::Poll<std::io::Result<usize>> {
73 (*self.writer).as_mut().poll_write(cx, buf)
74 }
75
76 #[inline]
77 fn poll_flush(
78 mut self: std::pin::Pin<&mut Self>,
79 cx: &mut std::task::Context<'_>,
80 ) -> std::task::Poll<std::io::Result<()>> {
81 (*self.writer).as_mut().poll_flush(cx)
82 }
83
84 #[inline]
85 fn poll_shutdown(
86 mut self: std::pin::Pin<&mut Self>,
87 cx: &mut std::task::Context<'_>,
88 ) -> std::task::Poll<std::io::Result<()>> {
89 (*self.writer).as_mut().poll_shutdown(cx)
90 }
91
92 #[inline]
93 fn is_write_vectored(&self) -> bool {
94 self.writer.is_write_vectored()
95 }
96
97 #[inline]
98 fn poll_write_vectored(
99 mut self: Pin<&mut Self>,
100 cx: &mut std::task::Context<'_>,
101 bufs: &[std::io::IoSlice<'_>],
102 ) -> std::task::Poll<std::io::Result<usize>> {
103 (*self.writer).as_mut().poll_write_vectored(cx, bufs)
104 }
105}
106
/// Receiving end of an upgrade handoff, kept in a request's extensions until
/// [`prepare_upgrade`] removes it.
///
/// Clones share the same receiver and the same `upgraded` flag through their
/// `Arc`s.
#[derive(Clone)]
pub(super) struct Upgrade {
    /// The oneshot receiver that the `Upgraded` connection arrives on, behind a
    /// mutex so that clones can share it.
    inner: Arc<futures_util::lock::Mutex<oneshot::AsyncReceiver<Upgraded>>>,
    /// Set to `true` by `prepare_upgrade`. Presumably read by the connection
    /// driver to decide whether to hand the transport over — confirm at the
    /// reader.
    pub(super) upgraded: Arc<AtomicBool>,
}
112
113impl Upgrade {
114 #[inline]
115 pub(super) fn new(inner: oneshot::AsyncReceiver<Upgraded>) -> Self {
116 Self {
117 inner: Arc::new(futures_util::lock::Mutex::new(inner)),
118 upgraded: Arc::new(AtomicBool::new(false)),
119 }
120 }
121}
122
123impl Future for Upgrade {
124 type Output = Option<Upgraded>;
125
126 #[inline]
127 fn poll(
128 self: Pin<&mut Self>,
129 cx: &mut std::task::Context<'_>,
130 ) -> std::task::Poll<Self::Output> {
131 let mut inner = ready!(self.inner.lock().poll_unpin(cx));
132 match inner.poll_unpin(cx) {
133 std::task::Poll::Ready(result) => std::task::Poll::Ready(result.ok()),
134 std::task::Poll::Pending => std::task::Poll::Pending,
135 }
136 }
137}
138
/// Future returned by [`prepare_upgrade`] that resolves to the upgraded
/// connection.
///
/// It yields `Some(Upgraded)` once the connection is delivered. It yields
/// `None` if the underlying receiver reports an error, presumably because the
/// sending side was dropped without sending — confirm.
#[derive(Clone)]
pub struct OnUpgrade {
    inner: Upgrade,
}
162
163impl Future for OnUpgrade {
164 type Output = Option<Upgraded>;
165
166 #[inline]
167 fn poll(
168 mut self: Pin<&mut Self>,
169 cx: &mut std::task::Context<'_>,
170 ) -> std::task::Poll<Self::Output> {
171 self.inner.poll_unpin(cx)
172 }
173}
174
175#[inline]
190pub fn prepare_upgrade(req: &mut Request<impl Body>) -> Option<OnUpgrade> {
191 req.extensions_mut().remove::<Upgrade>().map(|inner| {
192 inner
193 .upgraded
194 .store(true, std::sync::atomic::Ordering::Relaxed);
195 OnUpgrade { inner }
196 })
197}