boluo_core/
upgrade.rs

1//! HTTP 升级。
2
3use std::any::Any;
4use std::pin::Pin;
5use std::sync::{Arc, Mutex};
6use std::task::{Context, Poll};
7
8use futures_core::future::BoxFuture;
9use futures_io::{AsyncRead, AsyncWrite};
10
11use crate::BoxError;
12
13/// 用于处理 HTTP 升级请求,获取升级后的连接。
14#[derive(Clone)]
15pub struct OnUpgrade {
16    fut: Arc<Mutex<BoxFuture<'static, Result<Upgraded, BoxError>>>>,
17}
18
19impl OnUpgrade {
20    /// 创建一个 `OnUpgrade` 实例。
21    pub fn new<T>(fut: T) -> Self
22    where
23        T: Future<Output = Result<Upgraded, BoxError>> + Send + 'static,
24    {
25        Self {
26            fut: Arc::new(Mutex::new(Box::pin(fut))),
27        }
28    }
29}
30
31impl Future for OnUpgrade {
32    type Output = Result<Upgraded, BoxError>;
33
34    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
35        self.fut.lock().unwrap().as_mut().poll(cx)
36    }
37}
38
39impl std::fmt::Debug for OnUpgrade {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        f.debug_struct("OnUpgrade").finish()
42    }
43}
44
45/// HTTP 升级后的连接。
46pub struct Upgraded {
47    io: Box<dyn IO + Send>,
48}
49
50impl Upgraded {
51    /// 创建一个 `Upgraded` 实例。
52    pub fn new<T>(io: T) -> Self
53    where
54        T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
55    {
56        Self { io: Box::new(io) }
57    }
58
59    /// 尝试将 `Upgraded` 实例转换为指定类型。
60    pub fn downcast<T: 'static>(self) -> Result<T, Self> {
61        if self.io.as_ref().as_any().is::<T>() {
62            Ok(*self.io.into_any().downcast::<T>().unwrap())
63        } else {
64            Err(self)
65        }
66    }
67}
68
69impl AsyncRead for Upgraded {
70    fn poll_read(
71        mut self: Pin<&mut Self>,
72        cx: &mut Context<'_>,
73        buf: &mut [u8],
74    ) -> Poll<std::io::Result<usize>> {
75        Pin::new(&mut self.io).poll_read(cx, buf)
76    }
77}
78
79impl AsyncWrite for Upgraded {
80    fn poll_write(
81        mut self: Pin<&mut Self>,
82        cx: &mut Context<'_>,
83        buf: &[u8],
84    ) -> Poll<std::io::Result<usize>> {
85        Pin::new(&mut self.io).poll_write(cx, buf)
86    }
87
88    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
89        Pin::new(&mut self.io).poll_flush(cx)
90    }
91
92    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
93        Pin::new(&mut self.io).poll_close(cx)
94    }
95}
96
97impl std::fmt::Debug for Upgraded {
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        f.debug_struct("Upgraded").finish()
100    }
101}
102
103trait IO: AsyncRead + AsyncWrite + Unpin + 'static {
104    fn as_any(&self) -> &dyn Any;
105
106    fn into_any(self: Box<Self>) -> Box<dyn Any>;
107}
108
109impl<T: AsyncRead + AsyncWrite + Unpin + 'static> IO for T {
110    fn as_any(&self) -> &dyn Any {
111        self
112    }
113
114    fn into_any(self: Box<Self>) -> Box<dyn Any> {
115        self
116    }
117}
118
#[cfg(test)]
mod tests {
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    use futures_io::{AsyncRead, AsyncWrite};

    use super::Upgraded;

    /// Placeholder transport. It exists only so `Upgraded` has a concrete type to wrap and
    /// downcast back to; none of its I/O methods are called by the tests.
    struct StubIo;

    impl AsyncRead for StubIo {
        fn poll_read(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
            _: &mut [u8],
        ) -> Poll<std::io::Result<usize>> {
            todo!()
        }
    }

    impl AsyncWrite for StubIo {
        fn poll_write(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
            _: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            todo!()
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            todo!()
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            todo!()
        }
    }

    #[test]
    fn upgraded_downcast() {
        // Asking for the wrong type hands the connection back.
        let mismatch = Upgraded::new(StubIo).downcast::<()>();
        assert!(mismatch.is_err());

        // Asking for the wrapped type recovers it.
        let recovered = Upgraded::new(StubIo).downcast::<StubIo>();
        assert!(matches!(recovered, Ok(StubIo)));
    }
}
163}