1use std::{
2 borrow::Cow,
3 io,
4 ops::DerefMut,
5 pin::Pin,
6 task::{Context, Poll, ready},
7};
8
9#[cfg(any(feature = "native-tls", feature = "rustls"))]
10use compio::tls::TlsStream;
11use compio::{
12 buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut},
13 io::{AsyncRead, AsyncWrite, compat::AsyncStream},
14 net::TcpStream,
15};
16use hyper::Uri;
17#[cfg(feature = "client")]
18use hyper_util::client::legacy::connect::{Connected, Connection};
19use send_wrapper::SendWrapper;
20
21use crate::TlsBackend;
22
/// Transport underneath an `HttpStream`: plain TCP, or TLS over TCP when the
/// `native-tls` or `rustls` feature is enabled.
// The allow silences clippy's warning about the size disparity between the
// variants (presumably the TLS variant is far larger than `TcpStream`);
// NOTE(review): confirm boxing `Tls` was deliberately rejected.
#[allow(clippy::large_enum_variant)]
enum HttpStreamInner {
    /// Unencrypted TCP connection (the `http` scheme in `connect`).
    Tcp(TcpStream),
    /// TLS session over TCP (the `https` scheme in `connect`).
    #[cfg(any(feature = "native-tls", feature = "rustls"))]
    Tls(TlsStream<TcpStream>),
}
29
30impl HttpStreamInner {
31 pub async fn connect(uri: Uri, tls: TlsBackend) -> io::Result<Self> {
32 let scheme = uri.scheme_str().unwrap_or("http");
33 let host = uri.host().expect("there should be host");
34 let port = uri.port_u16();
35 match scheme {
36 "http" => {
37 let stream = TcpStream::connect((host, port.unwrap_or(80))).await?;
38 let _tls = tls;
40 Ok(Self::Tcp(stream))
41 }
42 #[cfg(any(feature = "native-tls", feature = "rustls"))]
43 "https" => {
44 let stream = TcpStream::connect((host, port.unwrap_or(443))).await?;
45 let connector = tls.create_connector()?;
46 Ok(Self::Tls(connector.connect(host, stream).await?))
47 }
48 _ => Err(io::Error::new(
49 io::ErrorKind::InvalidInput,
50 "unsupported scheme",
51 )),
52 }
53 }
54
55 pub fn from_tcp(s: TcpStream) -> Self {
56 Self::Tcp(s)
57 }
58
59 #[cfg(any(feature = "native-tls", feature = "rustls"))]
60 pub fn from_tls(s: TlsStream<TcpStream>) -> Self {
61 Self::Tls(s)
62 }
63
64 fn negotiated_alpn(&self) -> Option<Cow<[u8]>> {
65 match self {
66 Self::Tcp(_) => None,
67 #[cfg(any(feature = "native-tls", feature = "rustls"))]
68 Self::Tls(s) => s.negotiated_alpn(),
69 }
70 }
71}
72
73impl AsyncRead for HttpStreamInner {
74 async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
75 match self {
76 Self::Tcp(s) => s.read(buf).await,
77 #[cfg(any(feature = "native-tls", feature = "rustls"))]
78 Self::Tls(s) => s.read(buf).await,
79 }
80 }
81
82 async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
83 match self {
84 Self::Tcp(s) => s.read_vectored(buf).await,
85 #[cfg(any(feature = "native-tls", feature = "rustls"))]
86 Self::Tls(s) => s.read_vectored(buf).await,
87 }
88 }
89}
90
91impl AsyncWrite for HttpStreamInner {
92 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
93 match self {
94 Self::Tcp(s) => s.write(buf).await,
95 #[cfg(any(feature = "native-tls", feature = "rustls"))]
96 Self::Tls(s) => s.write(buf).await,
97 }
98 }
99
100 async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
101 match self {
102 Self::Tcp(s) => s.write_vectored(buf).await,
103 #[cfg(any(feature = "native-tls", feature = "rustls"))]
104 Self::Tls(s) => s.write_vectored(buf).await,
105 }
106 }
107
108 async fn flush(&mut self) -> io::Result<()> {
109 match self {
110 Self::Tcp(s) => s.flush().await,
111 #[cfg(any(feature = "native-tls", feature = "rustls"))]
112 Self::Tls(s) => s.flush().await,
113 }
114 }
115
116 async fn shutdown(&mut self) -> io::Result<()> {
117 match self {
118 Self::Tcp(s) => s.shutdown().await,
119 #[cfg(any(feature = "native-tls", feature = "rustls"))]
120 Self::Tls(s) => s.shutdown().await,
121 }
122 }
123}
124
/// A TCP or TLS-over-TCP byte stream, adapted to hyper's `rt::Read` /
/// `rt::Write` traits through `HyperStream`.
pub struct HttpStream(HyperStream<HttpStreamInner>);
128
129impl HttpStream {
130 pub async fn connect(uri: Uri, tls: TlsBackend) -> io::Result<Self> {
132 Ok(Self::from_inner(HttpStreamInner::connect(uri, tls).await?))
133 }
134
135 pub fn from_tcp(s: TcpStream) -> Self {
137 Self::from_inner(HttpStreamInner::from_tcp(s))
138 }
139
140 #[cfg(any(feature = "native-tls", feature = "rustls"))]
142 pub fn from_tls(s: TlsStream<TcpStream>) -> Self {
143 Self::from_inner(HttpStreamInner::from_tls(s))
144 }
145
146 fn from_inner(s: HttpStreamInner) -> Self {
147 Self(HyperStream::new(s))
148 }
149}
150
151impl hyper::rt::Read for HttpStream {
152 fn poll_read(
153 mut self: Pin<&mut Self>,
154 cx: &mut Context<'_>,
155 buf: hyper::rt::ReadBufCursor<'_>,
156 ) -> Poll<io::Result<()>> {
157 let inner = std::pin::pin!(&mut self.0);
158 inner.poll_read(cx, buf)
159 }
160}
161
162impl hyper::rt::Write for HttpStream {
163 fn poll_write(
164 mut self: Pin<&mut Self>,
165 cx: &mut Context<'_>,
166 buf: &[u8],
167 ) -> Poll<io::Result<usize>> {
168 let inner = std::pin::pin!(&mut self.0);
169 inner.poll_write(cx, buf)
170 }
171
172 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
173 let inner = std::pin::pin!(&mut self.0);
174 inner.poll_flush(cx)
175 }
176
177 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
178 let inner = std::pin::pin!(&mut self.0);
179 inner.poll_shutdown(cx)
180 }
181}
182
#[cfg(feature = "client")]
impl Connection for HttpStream {
    /// Reports connection metadata to hyper-util's legacy client.
    ///
    /// The connection is marked as HTTP/2 when the TLS handshake negotiated
    /// the `h2` ALPN protocol. Plain TCP connections never report an ALPN
    /// protocol, so they are never marked as HTTP/2.
    fn connected(&self) -> Connected {
        let conn = Connected::new();
        // Use the `HyperStream` accessor rather than its private wrapper field.
        // Compare the protocol id as a plain byte slice.
        let is_h2 = self
            .0
            .get_ref()
            .negotiated_alpn()
            .is_some_and(|alpn| &*alpn == b"h2");
        if is_h2 { conn.negotiated_h2() } else { conn }
    }
}
197
/// Adapts a compio stream `S` to hyper's `rt::Read` / `rt::Write` traits.
///
/// The stream is wrapped in compio's `AsyncStream` compat adapter and then in
/// a `SendWrapper`, presumably so the type can satisfy `Send` bounds. Per the
/// send_wrapper docs, dereferencing it from a thread other than the creating
/// one panics.
pub struct HyperStream<S>(SendWrapper<AsyncStream<S>>);
200
201impl<S> HyperStream<S> {
202 pub fn new(s: S) -> Self {
204 Self(SendWrapper::new(AsyncStream::new(s)))
205 }
206
207 pub fn get_ref(&self) -> &S {
209 self.0.get_ref()
210 }
211}
212
213impl<S> std::fmt::Debug for HyperStream<S> {
214 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215 f.debug_struct("HyperStream").finish_non_exhaustive()
216 }
217}
218
/// Bridges hyper's poll-based reads onto the compio stream `S` via `AsyncStream`.
impl<S: AsyncRead + Unpin + 'static> hyper::rt::Read for HyperStream<S> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        mut buf: hyper::rt::ReadBufCursor<'_>,
    ) -> Poll<io::Result<()>> {
        // SAFETY: projects the pin from `HyperStream` onto the `AsyncStream`
        // stored inside the `SendWrapper`. This is sound only while that value
        // is never moved out after being pinned; nothing visible in this module
        // moves it after `new`, so keep it that way.
        // NOTE(review): per the send_wrapper docs, `deref_mut` panics when
        // called off the creating thread.
        let stream = unsafe { self.map_unchecked_mut(|this| this.0.deref_mut()) };
        // SAFETY: hyper requires callers of `as_mut` never to de-initialize
        // bytes. This assumes `poll_read_uninit` only writes into the slice.
        // TODO confirm against compio's docs.
        let slice = unsafe { buf.as_mut() };
        // `ready!` returns early on `Pending`; `?` forwards I/O errors.
        let len = ready!(stream.poll_read_uninit(cx, slice))?;
        // SAFETY: `advance(len)` requires the first `len` bytes to be
        // initialized. This assumes `poll_read_uninit` returns exactly the
        // number of leading bytes it filled.
        unsafe { buf.advance(len) };
        Poll::Ready(Ok(()))
    }
}
232
/// Bridges hyper's poll-based writes onto the compio stream `S` by forwarding
/// to the `futures_util::AsyncWrite` implementation of `AsyncStream`.
impl<S: AsyncWrite + Unpin + 'static> hyper::rt::Write for HyperStream<S> {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // SAFETY: pin projection onto the `AsyncStream` inside the
        // `SendWrapper`. This is sound only while that value is never moved out
        // after being pinned; nothing visible in this module moves it.
        let stream = unsafe { self.map_unchecked_mut(|this| this.0.deref_mut()) };
        futures_util::AsyncWrite::poll_write(stream, cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        // SAFETY: same pin projection as in `poll_write`.
        let stream = unsafe { self.map_unchecked_mut(|this| this.0.deref_mut()) };
        futures_util::AsyncWrite::poll_flush(stream, cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        // SAFETY: same pin projection as in `poll_write`.
        let stream = unsafe { self.map_unchecked_mut(|this| this.0.deref_mut()) };
        // The futures-io trait names hyper's "shutdown" operation `poll_close`.
        futures_util::AsyncWrite::poll_close(stream, cx)
    }
}