rama_http_core/
upgrade.rs1use std::any::TypeId;
43use std::error::Error as StdError;
44use std::fmt;
45use std::io;
46use std::pin::Pin;
47use std::sync::{Arc, Mutex};
48use std::task::{Context, Poll};
49
50use bytes::Bytes;
51use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
52use tokio::sync::oneshot;
53use tracing::trace;
54
55use crate::common::io::Rewind;
56
57pub struct Upgraded {
66 io: Rewind<Box<dyn Io + Send>>,
67}
68
69#[derive(Clone)]
73pub struct OnUpgrade {
74 rx: Option<Arc<Mutex<oneshot::Receiver<crate::Result<Upgraded>>>>>,
75}
76
77#[derive(Debug)]
82#[non_exhaustive]
83pub struct Parts<T> {
84 pub io: T,
86 pub read_buf: Bytes,
95}
96
97pub fn on<T: sealed::CanUpgrade>(msg: T) -> OnUpgrade {
106 msg.on_upgrade()
107}
108
109pub(super) struct Pending {
110 tx: oneshot::Sender<crate::Result<Upgraded>>,
111}
112
113pub(super) fn pending() -> (Pending, OnUpgrade) {
114 let (tx, rx) = oneshot::channel();
115 (
116 Pending { tx },
117 OnUpgrade {
118 rx: Some(Arc::new(Mutex::new(rx))),
119 },
120 )
121}
122
123impl Upgraded {
126 pub(super) fn new<T>(io: T, read_buf: Bytes) -> Self
127 where
128 T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
129 {
130 Upgraded {
131 io: Rewind::new_buffered(Box::new(io), read_buf),
132 }
133 }
134
135 pub fn downcast<T: AsyncRead + AsyncWrite + Unpin + 'static>(self) -> Result<Parts<T>, Self> {
140 let (io, buf) = self.io.into_inner();
141 match io.__rama_downcast() {
142 Ok(t) => Ok(Parts {
143 io: *t,
144 read_buf: buf,
145 }),
146 Err(io) => Err(Upgraded {
147 io: Rewind::new_buffered(io, buf),
148 }),
149 }
150 }
151}
152
153impl AsyncRead for Upgraded {
154 fn poll_read(
155 mut self: Pin<&mut Self>,
156 cx: &mut Context<'_>,
157 buf: &mut ReadBuf<'_>,
158 ) -> Poll<io::Result<()>> {
159 Pin::new(&mut self.io).poll_read(cx, buf)
160 }
161}
162
163impl AsyncWrite for Upgraded {
164 fn poll_write(
165 mut self: Pin<&mut Self>,
166 cx: &mut Context<'_>,
167 buf: &[u8],
168 ) -> Poll<io::Result<usize>> {
169 Pin::new(&mut self.io).poll_write(cx, buf)
170 }
171
172 fn poll_write_vectored(
173 mut self: Pin<&mut Self>,
174 cx: &mut Context<'_>,
175 bufs: &[io::IoSlice<'_>],
176 ) -> Poll<io::Result<usize>> {
177 Pin::new(&mut self.io).poll_write_vectored(cx, bufs)
178 }
179
180 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
181 Pin::new(&mut self.io).poll_flush(cx)
182 }
183
184 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
185 Pin::new(&mut self.io).poll_shutdown(cx)
186 }
187
188 fn is_write_vectored(&self) -> bool {
189 self.io.is_write_vectored()
190 }
191}
192
193impl fmt::Debug for Upgraded {
194 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 f.debug_struct("Upgraded").finish()
196 }
197}
198
199impl OnUpgrade {
202 pub(super) fn none() -> Self {
203 OnUpgrade { rx: None }
204 }
205
206 pub(super) fn is_none(&self) -> bool {
207 self.rx.is_none()
208 }
209}
210
211impl Future for OnUpgrade {
212 type Output = Result<Upgraded, crate::Error>;
213
214 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
215 match self.rx {
216 Some(ref rx) => Pin::new(&mut *rx.lock().unwrap())
217 .poll(cx)
218 .map(|res| match res {
219 Ok(Ok(upgraded)) => Ok(upgraded),
220 Ok(Err(err)) => Err(err),
221 Err(_oneshot_canceled) => {
222 Err(crate::Error::new_canceled().with(UpgradeExpected))
223 }
224 }),
225 None => Poll::Ready(Err(crate::Error::new_user_no_upgrade())),
226 }
227 }
228}
229
230impl fmt::Debug for OnUpgrade {
231 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232 f.debug_struct("OnUpgrade").finish()
233 }
234}
235
236impl Pending {
239 pub(super) fn fulfill(self, upgraded: Upgraded) {
240 trace!("pending upgrade fulfill");
241 let _ = self.tx.send(Ok(upgraded));
242 }
243
244 pub(super) fn manual(self) {
247 trace!("pending upgrade handled manually");
248 let _ = self.tx.send(Err(crate::Error::new_user_manual_upgrade()));
249 }
250}
251
/// Error cause attached when an upgrade was expected, but the sending side of
/// the upgrade channel went away without delivering a result.
#[derive(Debug)]
struct UpgradeExpected;

impl fmt::Display for UpgradeExpected {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "upgrade expected but not completed")
    }
}

// A leaf cause: it has no `source()` of its own.
impl StdError for UpgradeExpected {}
268
269pub(super) trait Io: AsyncRead + AsyncWrite + Unpin + 'static {
272 fn __rama_type_id(&self) -> TypeId {
273 TypeId::of::<Self>()
274 }
275}
276
277impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Io for T {}
278
279impl dyn Io + Send {
280 fn __rama_is<T: Io>(&self) -> bool {
281 let t = TypeId::of::<T>();
282 self.__rama_type_id() == t
283 }
284
285 fn __rama_downcast<T: Io>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
286 if self.__rama_is::<T>() {
287 unsafe {
289 let raw: *mut dyn Io = Box::into_raw(self);
290 Ok(Box::from_raw(raw as *mut T))
291 }
292 } else {
293 Err(self)
294 }
295 }
296}
297
298mod sealed {
299 use rama_http_types::{Request, Response};
300
301 use super::OnUpgrade;
302
303 pub trait CanUpgrade {
304 fn on_upgrade(self) -> OnUpgrade;
305 }
306
307 impl<B> CanUpgrade for Request<B> {
308 fn on_upgrade(mut self) -> OnUpgrade {
309 self.extensions_mut()
310 .remove::<OnUpgrade>()
311 .unwrap_or_else(OnUpgrade::none)
312 }
313 }
314
315 impl<B> CanUpgrade for &'_ mut Request<B> {
316 fn on_upgrade(self) -> OnUpgrade {
317 self.extensions_mut()
318 .remove::<OnUpgrade>()
319 .unwrap_or_else(OnUpgrade::none)
320 }
321 }
322
323 impl<B> CanUpgrade for Response<B> {
324 fn on_upgrade(mut self) -> OnUpgrade {
325 self.extensions_mut()
326 .remove::<OnUpgrade>()
327 .unwrap_or_else(OnUpgrade::none)
328 }
329 }
330
331 impl<B> CanUpgrade for &'_ mut Response<B> {
332 fn on_upgrade(self) -> OnUpgrade {
333 self.extensions_mut()
334 .remove::<OnUpgrade>()
335 .unwrap_or_else(OnUpgrade::none)
336 }
337 }
338}
339
#[cfg(test)]
mod tests {
    use tokio_test::io::{Builder, Mock};

    use super::*;

    /// Downcasting to the wrong type fails, and the returned `Upgraded` can still
    /// be downcast to the right one.
    #[test]
    fn upgraded_downcast() {
        let upgraded = Upgraded::new(Builder::default().build(), Bytes::new());
        let upgraded = upgraded.downcast::<std::io::Cursor<Vec<u8>>>().unwrap_err();
        upgraded.downcast::<Mock>().unwrap();
    }

    /// Bytes buffered before the upgrade must survive a failed downcast. They
    /// must also come back in `Parts::read_buf` on the successful downcast.
    #[test]
    fn upgraded_downcast_preserves_read_buf() {
        let upgraded = Upgraded::new(Builder::default().build(), Bytes::from_static(b"hello"));
        let upgraded = upgraded.downcast::<std::io::Cursor<Vec<u8>>>().unwrap_err();
        let parts = upgraded.downcast::<Mock>().unwrap();
        assert_eq!(&parts.read_buf[..], b"hello");
    }
}