Skip to main content

vibeio_http/
upgrade.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    sync::{atomic::AtomicBool, Arc},
5    task::ready,
6};
7
8use bytes::BufMut;
9use futures_util::FutureExt;
10use http::Request;
11use http_body::Body;
12use send_wrapper::SendWrapper;
13use tokio::io::{AsyncRead, AsyncWrite};
14
15/// Represents a successfully upgraded HTTP connection.
16///
17/// After a successful HTTP upgrade handshake (e.g. WebSocket or HTTP/2
18/// cleartext), the original TCP stream is handed off as an [`Upgraded`] value.
19/// It implements both [`AsyncRead`] and [`AsyncWrite`], so it can be used as a
20/// plain async I/O object by the protocol taking over the connection.
21///
22/// Any bytes that were already read from the socket as part of the HTTP request
23/// but not yet consumed are prepended to the read stream via the `leftover`
24/// buffer, ensuring no data is lost during the transition.
25pub struct Upgraded {
26    reader: SendWrapper<Pin<Box<dyn AsyncRead + Unpin>>>,
27    writer: SendWrapper<Pin<Box<dyn AsyncWrite + Unpin>>>,
28    leftover: Option<bytes::Bytes>,
29}
30
31impl Upgraded {
32    #[inline]
33    pub(super) fn new(
34        io: impl AsyncRead + AsyncWrite + Unpin + 'static,
35        leftover: Option<bytes::Bytes>,
36    ) -> Self {
37        let (reader, writer) = tokio::io::split(io);
38        Self {
39            reader: SendWrapper::new(Box::pin(reader)),
40            writer: SendWrapper::new(Box::pin(writer)),
41            leftover,
42        }
43    }
44}
45
46impl AsyncRead for Upgraded {
47    #[inline]
48    fn poll_read(
49        mut self: std::pin::Pin<&mut Self>,
50        cx: &mut std::task::Context<'_>,
51        buf: &mut tokio::io::ReadBuf<'_>,
52    ) -> std::task::Poll<std::io::Result<()>> {
53        if let Some(leftover) = &mut self.leftover {
54            let slice_len = leftover.len().min(buf.remaining());
55            let leftover_to_write = leftover.split_to(slice_len);
56            buf.put(leftover_to_write);
57            if leftover.is_empty() {
58                self.leftover = None;
59            }
60            return std::task::Poll::Ready(Ok(()));
61        }
62        (*self.reader).as_mut().poll_read(cx, buf)
63    }
64}
65
66impl AsyncWrite for Upgraded {
67    #[inline]
68    fn poll_write(
69        mut self: std::pin::Pin<&mut Self>,
70        cx: &mut std::task::Context<'_>,
71        buf: &[u8],
72    ) -> std::task::Poll<std::io::Result<usize>> {
73        (*self.writer).as_mut().poll_write(cx, buf)
74    }
75
76    #[inline]
77    fn poll_flush(
78        mut self: std::pin::Pin<&mut Self>,
79        cx: &mut std::task::Context<'_>,
80    ) -> std::task::Poll<std::io::Result<()>> {
81        (*self.writer).as_mut().poll_flush(cx)
82    }
83
84    #[inline]
85    fn poll_shutdown(
86        mut self: std::pin::Pin<&mut Self>,
87        cx: &mut std::task::Context<'_>,
88    ) -> std::task::Poll<std::io::Result<()>> {
89        (*self.writer).as_mut().poll_shutdown(cx)
90    }
91
92    #[inline]
93    fn is_write_vectored(&self) -> bool {
94        self.writer.is_write_vectored()
95    }
96
97    #[inline]
98    fn poll_write_vectored(
99        mut self: Pin<&mut Self>,
100        cx: &mut std::task::Context<'_>,
101        bufs: &[std::io::IoSlice<'_>],
102    ) -> std::task::Poll<std::io::Result<usize>> {
103        (*self.writer).as_mut().poll_write_vectored(cx, bufs)
104    }
105}
106
107#[derive(Clone)]
108pub(super) struct Upgrade {
109    inner: Arc<futures_util::lock::Mutex<oneshot::AsyncReceiver<Upgraded>>>,
110    pub(super) upgraded: Arc<AtomicBool>,
111}
112
113impl Upgrade {
114    #[inline]
115    pub(super) fn new(inner: oneshot::AsyncReceiver<Upgraded>) -> Self {
116        Self {
117            inner: Arc::new(futures_util::lock::Mutex::new(inner)),
118            upgraded: Arc::new(AtomicBool::new(false)),
119        }
120    }
121}
122
123impl Future for Upgrade {
124    type Output = Option<Upgraded>;
125
126    #[inline]
127    fn poll(
128        self: Pin<&mut Self>,
129        cx: &mut std::task::Context<'_>,
130    ) -> std::task::Poll<Self::Output> {
131        let mut inner = ready!(self.inner.lock().poll_unpin(cx));
132        match inner.poll_unpin(cx) {
133            std::task::Poll::Ready(result) => std::task::Poll::Ready(result.ok()),
134            std::task::Poll::Pending => std::task::Poll::Pending,
135        }
136    }
137}
138
139/// A future that resolves to an [`Upgraded`] connection once the HTTP upgrade
140/// handshake has been completed by the server side.
141///
142/// Obtain an `OnUpgrade` by calling [`prepare_upgrade`] on an incoming
143/// request. The future will yield `Some(Upgraded)` when the server has
144/// finished writing the upgrade response, or `None` if the upgrade was
145/// cancelled or the connection was closed before the handshake completed.
146///
147/// # Example
148///
149/// ```rust,ignore
150/// if let Some(on_upgrade) = prepare_upgrade(&mut request) {
151///     tokio::spawn(async move {
152///         if let Some(upgraded) = on_upgrade.await {
153///             // `upgraded` is now a raw async I/O stream
154///         }
155///     });
156/// }
157/// ```
158#[derive(Clone)]
159pub struct OnUpgrade {
160    inner: Upgrade,
161}
162
163impl Future for OnUpgrade {
164    type Output = Option<Upgraded>;
165
166    #[inline]
167    fn poll(
168        mut self: Pin<&mut Self>,
169        cx: &mut std::task::Context<'_>,
170    ) -> std::task::Poll<Self::Output> {
171        self.inner.poll_unpin(cx)
172    }
173}
174
175/// Prepares an HTTP upgrade on the given request.
176///
177/// This function removes the internal `Upgrade` token from the request's
178/// extensions and marks the connection as "to be upgraded". The returned
179/// [`OnUpgrade`] future resolves to the raw [`Upgraded`] I/O stream after the
180/// server has sent the `101 Switching Protocols` response.
181///
182/// Returns `None` if the request does not carry an upgrade token, which
183/// happens when the connection handler was not configured to support upgrades
184/// or the upgrade extension has already been consumed.
185///
186/// # Panics
187///
188/// Does not panic; returns `None` instead of panicking on missing state.
189#[inline]
190pub fn prepare_upgrade(req: &mut Request<impl Body>) -> Option<OnUpgrade> {
191    req.extensions_mut().remove::<Upgrade>().map(|inner| {
192        inner
193            .upgraded
194            .store(true, std::sync::atomic::Ordering::Relaxed);
195        OnUpgrade { inner }
196    })
197}