1use std::any::Any;
4use std::pin::Pin;
5use std::sync::{Arc, Mutex};
6use std::task::{Context, Poll};
7
8use futures_core::future::BoxFuture;
9use futures_io::{AsyncRead, AsyncWrite};
10
11use crate::BoxError;
12
13#[derive(Clone)]
15pub struct OnUpgrade {
16 fut: Arc<Mutex<BoxFuture<'static, Result<Upgraded, BoxError>>>>,
17}
18
19impl OnUpgrade {
20 pub fn new<T>(fut: T) -> Self
22 where
23 T: Future<Output = Result<Upgraded, BoxError>> + Send + 'static,
24 {
25 Self {
26 fut: Arc::new(Mutex::new(Box::pin(fut))),
27 }
28 }
29}
30
31impl Future for OnUpgrade {
32 type Output = Result<Upgraded, BoxError>;
33
34 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
35 self.fut.lock().unwrap().as_mut().poll(cx)
36 }
37}
38
39impl std::fmt::Debug for OnUpgrade {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 f.debug_struct("OnUpgrade").finish()
42 }
43}
44
45pub struct Upgraded {
47 io: Box<dyn IO + Send>,
48}
49
50impl Upgraded {
51 pub fn new<T>(io: T) -> Self
53 where
54 T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
55 {
56 Self { io: Box::new(io) }
57 }
58
59 pub fn downcast<T: 'static>(self) -> Result<T, Self> {
61 if self.io.as_ref().as_any().is::<T>() {
62 Ok(*self.io.into_any().downcast::<T>().unwrap())
63 } else {
64 Err(self)
65 }
66 }
67}
68
69impl AsyncRead for Upgraded {
70 fn poll_read(
71 mut self: Pin<&mut Self>,
72 cx: &mut Context<'_>,
73 buf: &mut [u8],
74 ) -> Poll<std::io::Result<usize>> {
75 Pin::new(&mut self.io).poll_read(cx, buf)
76 }
77}
78
79impl AsyncWrite for Upgraded {
80 fn poll_write(
81 mut self: Pin<&mut Self>,
82 cx: &mut Context<'_>,
83 buf: &[u8],
84 ) -> Poll<std::io::Result<usize>> {
85 Pin::new(&mut self.io).poll_write(cx, buf)
86 }
87
88 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
89 Pin::new(&mut self.io).poll_flush(cx)
90 }
91
92 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
93 Pin::new(&mut self.io).poll_close(cx)
94 }
95}
96
97impl std::fmt::Debug for Upgraded {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 f.debug_struct("Upgraded").finish()
100 }
101}
102
103trait IO: AsyncRead + AsyncWrite + Unpin + 'static {
104 fn as_any(&self) -> &dyn Any;
105
106 fn into_any(self: Box<Self>) -> Box<dyn Any>;
107}
108
109impl<T: AsyncRead + AsyncWrite + Unpin + 'static> IO for T {
110 fn as_any(&self) -> &dyn Any {
111 self
112 }
113
114 fn into_any(self: Box<Self>) -> Box<dyn Any> {
115 self
116 }
117}
118
#[cfg(test)]
mod tests {
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use futures_io::{AsyncRead, AsyncWrite};

    use super::Upgraded;

    /// Minimal I/O stand-in used only as a concrete type to erase and recover.
    /// Its I/O methods are never called by these tests.
    struct FuturesIo;

    impl AsyncRead for FuturesIo {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut [u8],
        ) -> Poll<std::io::Result<usize>> {
            todo!()
        }
    }

    impl AsyncWrite for FuturesIo {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            todo!()
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            todo!()
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            todo!()
        }
    }

    #[test]
    fn upgraded_downcast() {
        assert!(Upgraded::new(FuturesIo).downcast::<()>().is_err());
        assert!(Upgraded::new(FuturesIo).downcast::<FuturesIo>().is_ok());
    }

    /// A failed downcast must hand back the original wrapper intact, so that a
    /// second attempt with the correct type still succeeds.
    #[test]
    fn failed_downcast_returns_original() {
        let upgraded = match Upgraded::new(FuturesIo).downcast::<()>() {
            Ok(()) => panic!("downcast to an unrelated type must fail"),
            Err(original) => original,
        };
        assert!(upgraded.downcast::<FuturesIo>().is_ok());
    }
}