cyper_core/
stream.rs

1use std::{
2    borrow::Cow,
3    io,
4    ops::DerefMut,
5    pin::Pin,
6    task::{Context, Poll, ready},
7};
8
9#[cfg(any(feature = "native-tls", feature = "rustls"))]
10use compio::tls::TlsStream;
11use compio::{
12    buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut},
13    io::{AsyncRead, AsyncWrite, compat::AsyncStream},
14    net::TcpStream,
15};
16use hyper::Uri;
17#[cfg(feature = "client")]
18use hyper_util::client::legacy::connect::{Connected, Connection};
19use send_wrapper::SendWrapper;
20
21use crate::TlsBackend;
22
/// The transport underneath an [`HttpStream`]: a plain TCP connection or, when
/// the `native-tls` or `rustls` feature is enabled, a TLS session over TCP.
// NOTE(review): clippy flags the TLS variant as much larger than the TCP one;
// the lint is silenced instead of boxing the variant, presumably to avoid a heap
// allocation per connection — confirm this trade-off is intentional.
#[allow(clippy::large_enum_variant)]
enum HttpStreamInner {
    /// Plain-text TCP connection (the `http` scheme).
    Tcp(TcpStream),
    /// TLS session layered over TCP (the `https` scheme).
    #[cfg(any(feature = "native-tls", feature = "rustls"))]
    Tls(TlsStream<TcpStream>),
}
29
30impl HttpStreamInner {
31    pub async fn connect(uri: Uri, tls: TlsBackend) -> io::Result<Self> {
32        let scheme = uri.scheme_str().unwrap_or("http");
33        let host = uri.host().expect("there should be host");
34        let port = uri.port_u16();
35        match scheme {
36            "http" => {
37                let stream = TcpStream::connect((host, port.unwrap_or(80))).await?;
38                // Ignore it.
39                let _tls = tls;
40                Ok(Self::Tcp(stream))
41            }
42            #[cfg(any(feature = "native-tls", feature = "rustls"))]
43            "https" => {
44                let stream = TcpStream::connect((host, port.unwrap_or(443))).await?;
45                let connector = tls.create_connector()?;
46                Ok(Self::Tls(connector.connect(host, stream).await?))
47            }
48            _ => Err(io::Error::new(
49                io::ErrorKind::InvalidInput,
50                "unsupported scheme",
51            )),
52        }
53    }
54
55    pub fn from_tcp(s: TcpStream) -> Self {
56        Self::Tcp(s)
57    }
58
59    #[cfg(any(feature = "native-tls", feature = "rustls"))]
60    pub fn from_tls(s: TlsStream<TcpStream>) -> Self {
61        Self::Tls(s)
62    }
63
64    fn negotiated_alpn(&self) -> Option<Cow<[u8]>> {
65        match self {
66            Self::Tcp(_) => None,
67            #[cfg(any(feature = "native-tls", feature = "rustls"))]
68            Self::Tls(s) => s.negotiated_alpn(),
69        }
70    }
71}
72
73impl AsyncRead for HttpStreamInner {
74    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
75        match self {
76            Self::Tcp(s) => s.read(buf).await,
77            #[cfg(any(feature = "native-tls", feature = "rustls"))]
78            Self::Tls(s) => s.read(buf).await,
79        }
80    }
81
82    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
83        match self {
84            Self::Tcp(s) => s.read_vectored(buf).await,
85            #[cfg(any(feature = "native-tls", feature = "rustls"))]
86            Self::Tls(s) => s.read_vectored(buf).await,
87        }
88    }
89}
90
91impl AsyncWrite for HttpStreamInner {
92    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
93        match self {
94            Self::Tcp(s) => s.write(buf).await,
95            #[cfg(any(feature = "native-tls", feature = "rustls"))]
96            Self::Tls(s) => s.write(buf).await,
97        }
98    }
99
100    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
101        match self {
102            Self::Tcp(s) => s.write_vectored(buf).await,
103            #[cfg(any(feature = "native-tls", feature = "rustls"))]
104            Self::Tls(s) => s.write_vectored(buf).await,
105        }
106    }
107
108    async fn flush(&mut self) -> io::Result<()> {
109        match self {
110            Self::Tcp(s) => s.flush().await,
111            #[cfg(any(feature = "native-tls", feature = "rustls"))]
112            Self::Tls(s) => s.flush().await,
113        }
114    }
115
116    async fn shutdown(&mut self) -> io::Result<()> {
117        match self {
118            Self::Tcp(s) => s.shutdown().await,
119            #[cfg(any(feature = "native-tls", feature = "rustls"))]
120            Self::Tls(s) => s.shutdown().await,
121        }
122    }
123}
124
125/// A HTTP stream wrapper, based on compio, and exposes [`hyper::rt`]
126/// interfaces.
127pub struct HttpStream(HyperStream<HttpStreamInner>);
128
129impl HttpStream {
130    /// Create [`HttpStream`] with target uri and TLS backend.
131    pub async fn connect(uri: Uri, tls: TlsBackend) -> io::Result<Self> {
132        Ok(Self::from_inner(HttpStreamInner::connect(uri, tls).await?))
133    }
134
135    /// Create [`HttpStream`] with connected TCP stream.
136    pub fn from_tcp(s: TcpStream) -> Self {
137        Self::from_inner(HttpStreamInner::from_tcp(s))
138    }
139
140    /// Create [`HttpStream`] with connected TLS stream.
141    #[cfg(any(feature = "native-tls", feature = "rustls"))]
142    pub fn from_tls(s: TlsStream<TcpStream>) -> Self {
143        Self::from_inner(HttpStreamInner::from_tls(s))
144    }
145
146    fn from_inner(s: HttpStreamInner) -> Self {
147        Self(HyperStream::new(s))
148    }
149}
150
151impl hyper::rt::Read for HttpStream {
152    fn poll_read(
153        mut self: Pin<&mut Self>,
154        cx: &mut Context<'_>,
155        buf: hyper::rt::ReadBufCursor<'_>,
156    ) -> Poll<io::Result<()>> {
157        let inner = std::pin::pin!(&mut self.0);
158        inner.poll_read(cx, buf)
159    }
160}
161
162impl hyper::rt::Write for HttpStream {
163    fn poll_write(
164        mut self: Pin<&mut Self>,
165        cx: &mut Context<'_>,
166        buf: &[u8],
167    ) -> Poll<io::Result<usize>> {
168        let inner = std::pin::pin!(&mut self.0);
169        inner.poll_write(cx, buf)
170    }
171
172    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
173        let inner = std::pin::pin!(&mut self.0);
174        inner.poll_flush(cx)
175    }
176
177    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
178        let inner = std::pin::pin!(&mut self.0);
179        inner.poll_shutdown(cx)
180    }
181}
182
#[cfg(feature = "client")]
impl Connection for HttpStream {
    /// Describe this connection to hyper's client.
    ///
    /// The connection is marked as HTTP/2 when ALPN negotiated `h2`.
    fn connected(&self) -> Connected {
        let negotiated_h2 = self
            .0
            .get_ref()
            .negotiated_alpn()
            .is_some_and(|protocol| &*protocol == b"h2");
        let connected = Connected::new();
        match negotiated_h2 {
            true => connected.negotiated_h2(),
            false => connected,
        }
    }
}
197
/// A stream wrapper for hyper.
///
/// Adapts a compio stream, through compio's `AsyncStream` compat layer, to
/// hyper's poll-based [`hyper::rt::Read`] / [`hyper::rt::Write`] traits.
// NOTE(review): the `SendWrapper` presumably lets the compio stream be used
// where hyper requires `Send`. Per `send_wrapper`'s documented behaviour,
// dereferencing it from a thread other than the creating one panics — confirm
// that callers keep each stream on its originating thread.
pub struct HyperStream<S>(SendWrapper<AsyncStream<S>>);
200
201impl<S> HyperStream<S> {
202    /// Create a hyper stream wrapper.
203    pub fn new(s: S) -> Self {
204        Self(SendWrapper::new(AsyncStream::new(s)))
205    }
206
207    /// Get the reference of the inner stream.
208    pub fn get_ref(&self) -> &S {
209        self.0.get_ref()
210    }
211}
212
213impl<S> std::fmt::Debug for HyperStream<S> {
214    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215        f.debug_struct("HyperStream").finish_non_exhaustive()
216    }
217}
218
219impl<S: AsyncRead + Unpin + 'static> hyper::rt::Read for HyperStream<S> {
220    fn poll_read(
221        self: Pin<&mut Self>,
222        cx: &mut Context<'_>,
223        mut buf: hyper::rt::ReadBufCursor<'_>,
224    ) -> Poll<io::Result<()>> {
225        let stream = unsafe { self.map_unchecked_mut(|this| this.0.deref_mut()) };
226        let slice = unsafe { buf.as_mut() };
227        let len = ready!(stream.poll_read_uninit(cx, slice))?;
228        unsafe { buf.advance(len) };
229        Poll::Ready(Ok(()))
230    }
231}
232
233impl<S: AsyncWrite + Unpin + 'static> hyper::rt::Write for HyperStream<S> {
234    fn poll_write(
235        self: Pin<&mut Self>,
236        cx: &mut Context<'_>,
237        buf: &[u8],
238    ) -> Poll<io::Result<usize>> {
239        let stream = unsafe { self.map_unchecked_mut(|this| this.0.deref_mut()) };
240        futures_util::AsyncWrite::poll_write(stream, cx, buf)
241    }
242
243    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
244        let stream = unsafe { self.map_unchecked_mut(|this| this.0.deref_mut()) };
245        futures_util::AsyncWrite::poll_flush(stream, cx)
246    }
247
248    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
249        let stream = unsafe { self.map_unchecked_mut(|this| this.0.deref_mut()) };
250        futures_util::AsyncWrite::poll_close(stream, cx)
251    }
252}