1use std::any::TypeId;
43use std::error::Error as StdError;
44use std::fmt;
45use std::future::Future;
46use std::io;
47use std::marker::Unpin;
48use std::pin::Pin;
49use std::sync::{Arc, Mutex};
50use std::task::{Context, Poll};
51
52use crate::rt::{Read, ReadBufCursor, Write};
53use bytes::Bytes;
54use tokio::sync::oneshot;
55
56use crate::common::io::Rewind;
57
/// A connection that has been upgraded away from HTTP.
///
/// It is usable as a bidirectional byte stream (it implements `Read` and
/// `Write`). The concrete transport type is erased; use
/// [`Upgraded::downcast`] to recover it.
pub struct Upgraded {
    // Type-erased transport wrapped together with a byte buffer in a `Rewind`.
    // NOTE(review): `Rewind` presumably serves the buffered bytes before reading
    // from the inner IO — confirm against `crate::common::io::Rewind`.
    io: Rewind<Box<dyn Io + Send>>,
}
69
/// A future that resolves to the [`Upgraded`] connection once the upgrade
/// has been fulfilled (or to an error if it was not).
///
/// `Clone` is derived; every clone shares the same underlying receiver.
#[derive(Clone)]
pub struct OnUpgrade {
    // `None` when the message carries no upgrade (see `OnUpgrade::none`).
    // Otherwise it is the receiving half of the oneshot channel created by
    // `pending()`. The `Arc<Mutex<..>>` lets the derived `Clone` share that
    // single receiver.
    rx: Option<Arc<Mutex<oneshot::Receiver<crate::Result<Upgraded>>>>>,
}
77
/// The deconstructed parts of an [`Upgraded`] connection, produced by a
/// successful [`Upgraded::downcast`].
#[derive(Debug)]
pub struct Parts<T> {
    /// The concrete transport that was type-erased inside the `Upgraded`.
    pub io: T,
    /// The byte buffer that was stored alongside `io` (initially the
    /// `read_buf` passed to `Upgraded::new`).
    ///
    /// NOTE(review): presumably these bytes were already read from the
    /// transport and must be processed before reading from `io` — confirm.
    pub read_buf: Bytes,
    // Private field: prevents construction and exhaustive destructuring
    // outside this module, so more fields can be added later.
    _inner: (),
}
97
98pub fn on<T: sealed::CanUpgrade>(msg: T) -> OnUpgrade {
107 msg.on_upgrade()
108}
109
/// The sending half of an upgrade, created together with its paired
/// [`OnUpgrade`] by `pending()`.
///
/// Consuming it via `fulfill` (or `manual`) decides what the `OnUpgrade`
/// resolves to.
#[cfg(all(
    any(feature = "client", feature = "server"),
    any(feature = "http1", feature = "http2"),
))]
pub(super) struct Pending {
    // Delivers the upgrade result to the paired `OnUpgrade`.
    tx: oneshot::Sender<crate::Result<Upgraded>>,
}
117
/// Creates a linked [`Pending`] / [`OnUpgrade`] pair over a fresh oneshot
/// channel.
///
/// The receiver is wrapped in `Arc<Mutex<..>>` so that `OnUpgrade` can be
/// cloned.
#[cfg(all(
    any(feature = "client", feature = "server"),
    any(feature = "http1", feature = "http2"),
))]
pub(super) fn pending() -> (Pending, OnUpgrade) {
    let (sender, receiver) = oneshot::channel();
    let shared_receiver = Arc::new(Mutex::new(receiver));
    let on_upgrade = OnUpgrade {
        rx: Some(shared_receiver),
    };
    (Pending { tx: sender }, on_upgrade)
}
131
132impl Upgraded {
135 #[cfg(all(
136 any(feature = "client", feature = "server"),
137 any(feature = "http1", feature = "http2")
138 ))]
139 pub(super) fn new<T>(io: T, read_buf: Bytes) -> Self
140 where
141 T: Read + Write + Unpin + Send + 'static,
142 {
143 Upgraded {
144 io: Rewind::new_buffered(Box::new(io), read_buf),
145 }
146 }
147
148 pub fn downcast<T: Read + Write + Unpin + 'static>(self) -> Result<Parts<T>, Self> {
153 let (io, buf) = self.io.into_inner();
154 match io.__hyper_downcast() {
155 Ok(t) => Ok(Parts {
156 io: *t,
157 read_buf: buf,
158 _inner: (),
159 }),
160 Err(io) => Err(Upgraded {
161 io: Rewind::new_buffered(io, buf),
162 }),
163 }
164 }
165}
166
167impl Read for Upgraded {
168 fn poll_read(
169 mut self: Pin<&mut Self>,
170 cx: &mut Context<'_>,
171 buf: ReadBufCursor<'_>,
172 ) -> Poll<io::Result<()>> {
173 Pin::new(&mut self.io).poll_read(cx, buf)
174 }
175}
176
177impl Write for Upgraded {
178 fn poll_write(
179 mut self: Pin<&mut Self>,
180 cx: &mut Context<'_>,
181 buf: &[u8],
182 ) -> Poll<io::Result<usize>> {
183 Pin::new(&mut self.io).poll_write(cx, buf)
184 }
185
186 fn poll_write_vectored(
187 mut self: Pin<&mut Self>,
188 cx: &mut Context<'_>,
189 bufs: &[io::IoSlice<'_>],
190 ) -> Poll<io::Result<usize>> {
191 Pin::new(&mut self.io).poll_write_vectored(cx, bufs)
192 }
193
194 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
195 Pin::new(&mut self.io).poll_flush(cx)
196 }
197
198 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
199 Pin::new(&mut self.io).poll_shutdown(cx)
200 }
201
202 fn is_write_vectored(&self) -> bool {
203 self.io.is_write_vectored()
204 }
205}
206
207impl fmt::Debug for Upgraded {
208 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
209 f.debug_struct("Upgraded").finish()
210 }
211}
212
213impl OnUpgrade {
216 pub(super) fn none() -> Self {
217 OnUpgrade { rx: None }
218 }
219
220 #[cfg(all(any(feature = "client", feature = "server"), feature = "http1"))]
221 pub(super) fn is_none(&self) -> bool {
222 self.rx.is_none()
223 }
224}
225
226impl Future for OnUpgrade {
227 type Output = Result<Upgraded, crate::Error>;
228
229 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
230 match self.rx {
231 Some(ref rx) => Pin::new(&mut *rx.lock().unwrap())
232 .poll(cx)
233 .map(|res| match res {
234 Ok(Ok(upgraded)) => Ok(upgraded),
235 Ok(Err(err)) => Err(err),
236 Err(_oneshot_canceled) => {
237 Err(crate::Error::new_canceled().with(UpgradeExpected))
238 }
239 }),
240 None => Poll::Ready(Err(crate::Error::new_user_no_upgrade())),
241 }
242 }
243}
244
245impl fmt::Debug for OnUpgrade {
246 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
247 f.debug_struct("OnUpgrade").finish()
248 }
249}
250
#[cfg(all(
    any(feature = "client", feature = "server"),
    any(feature = "http1", feature = "http2")
))]
impl Pending {
    /// Completes the upgrade by handing `upgraded` to the paired [`OnUpgrade`].
    ///
    /// A failed send is deliberately ignored: tokio's oneshot `send` only fails
    /// when the receiving side has been dropped, i.e. nobody is waiting.
    pub(super) fn fulfill(self, upgraded: Upgraded) {
        trace!("pending upgrade fulfill");
        let _ = self.tx.send(Ok(upgraded));
    }

    /// Declines to fulfill the upgrade because upgrades are handled manually.
    ///
    /// The paired [`OnUpgrade`] resolves with a "manual upgrade" user error.
    /// As in `fulfill`, a dropped receiver is not an error.
    #[cfg(feature = "http1")]
    pub(super) fn manual(self) {
        // This method is already gated on `http1`, so no extra `cfg` is needed
        // on the trace statement.
        trace!("pending upgrade handled manually");
        let _ = self.tx.send(Err(crate::Error::new_user_manual_upgrade()));
    }
}
272
/// Error cause attached when an `OnUpgrade` resolves because its `Pending`
/// sender was dropped before sending a result (the oneshot was canceled).
#[derive(Debug)]
struct UpgradeExpected;

impl fmt::Display for UpgradeExpected {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "upgrade expected but not completed")
    }
}

// Uses only the trait's default methods (no `source`).
impl StdError for UpgradeExpected {}
289
/// Object-safe IO trait used to type-erase the transport inside [`Upgraded`].
///
/// Its only addition over `Read + Write` is `__hyper_type_id`. `dyn Io + Send`
/// uses that method to support downcasting back to the concrete type.
pub(super) trait Io: Read + Write + Unpin + 'static {
    /// Returns the `TypeId` of the concrete implementing type.
    ///
    /// The unsafe downcast on `dyn Io + Send` trusts this value. It must never
    /// be overridden; the blanket impl below is the only impl and uses this
    /// default.
    fn __hyper_type_id(&self) -> TypeId {
        TypeId::of::<Self>()
    }
}

// Blanket impl: every eligible type gets `Io` with the default `__hyper_type_id`.
// It covers every type that satisfies the supertraits, so coherence forbids any
// other (possibly overriding) impl.
impl<T: Read + Write + Unpin + 'static> Io for T {}
299
impl dyn Io + Send {
    /// Returns `true` if the erased value's concrete type is `T`.
    fn __hyper_is<T: Io>(&self) -> bool {
        let t = TypeId::of::<T>();
        self.__hyper_type_id() == t
    }

    /// Converts the boxed trait object back into a `Box<T>`.
    ///
    /// Succeeds only if the concrete type is `T`; otherwise the original box
    /// is returned unchanged in `Err`.
    fn __hyper_downcast<T: Io>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
        if self.__hyper_is::<T>() {
            unsafe {
                // SAFETY:
                // - `__hyper_is::<T>()` just confirmed that the erased value's
                //   concrete type is `T`. The check is trustworthy because
                //   `__hyper_type_id` can only come from the blanket `Io` impl,
                //   which uses the non-overridden default `TypeId::of::<Self>()`.
                // - `Box::into_raw` yields the fat pointer to that allocation.
                //   Casting it to the thin `*mut T` discards the vtable and keeps
                //   the data pointer.
                // - The allocation was created as a `T` (unsized coercion from a
                //   `Box<T>`), so rebuilding a `Box<T>` uses the same layout and
                //   deallocation. Ownership moves from `self` to the new box;
                //   nothing is freed twice.
                let raw: *mut dyn Io = Box::into_raw(self);
                Ok(Box::from_raw(raw as *mut T))
            }
        } else {
            Err(self)
        }
    }
}
318
319mod sealed {
320 use super::OnUpgrade;
321
322 pub trait CanUpgrade {
323 fn on_upgrade(self) -> OnUpgrade;
324 }
325
326 impl<B> CanUpgrade for http::Request<B> {
327 fn on_upgrade(mut self) -> OnUpgrade {
328 self.extensions_mut()
329 .remove::<OnUpgrade>()
330 .unwrap_or_else(OnUpgrade::none)
331 }
332 }
333
334 impl<B> CanUpgrade for &'_ mut http::Request<B> {
335 fn on_upgrade(self) -> OnUpgrade {
336 self.extensions_mut()
337 .remove::<OnUpgrade>()
338 .unwrap_or_else(OnUpgrade::none)
339 }
340 }
341
342 impl<B> CanUpgrade for http::Response<B> {
343 fn on_upgrade(mut self) -> OnUpgrade {
344 self.extensions_mut()
345 .remove::<OnUpgrade>()
346 .unwrap_or_else(OnUpgrade::none)
347 }
348 }
349
350 impl<B> CanUpgrade for &'_ mut http::Response<B> {
351 fn on_upgrade(self) -> OnUpgrade {
352 self.extensions_mut()
353 .remove::<OnUpgrade>()
354 .unwrap_or_else(OnUpgrade::none)
355 }
356 }
357}
358
#[cfg(all(
    any(feature = "client", feature = "server"),
    any(feature = "http1", feature = "http2"),
))]
#[cfg(test)]
mod tests {
    use super::*;

    /// A wrong-type downcast must hand the `Upgraded` back intact, and a
    /// right-type downcast on the returned value must then succeed.
    #[test]
    fn upgraded_downcast() {
        type WrongIo = crate::common::io::Compat<std::io::Cursor<Vec<u8>>>;

        let original = Upgraded::new(Mock, Bytes::new());
        let returned = original.downcast::<WrongIo>().unwrap_err();
        returned.downcast::<Mock>().unwrap();
    }

    /// Minimal IO stand-in. Only `poll_write` is expected to be reached; the
    /// other methods panic if called.
    struct Mock;

    impl Read for Mock {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: ReadBufCursor<'_>,
        ) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_read")
        }
    }

    impl Write for Mock {
        // Pretends the whole buffer was written.
        fn poll_write(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let written = buf.len();
            Poll::Ready(Ok(written))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_flush")
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_shutdown")
        }
    }
}