1#![doc = include_str!("../README.md")]
2#![deny(unsafe_code)]
3
4use bytes::Bytes;
5use http::header::{CONNECTION, UPGRADE};
6pub use http::{
7 HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode, header::HOST,
8};
9use http_body_util::{Empty, Full};
10use hyper::body::Incoming;
11use tokio::net::TcpStream;
12
13mod client;
14mod error;
15pub mod http1;
16pub mod http2;
17
18pub use client::{Client, ClientExt};
19pub use error::Error;
20pub use http1::Http1;
21pub use http2::Http2;
22pub use hyper::upgrade::on as upgrade;
23pub use sealed::ResponseExt;
24
/// A body type that carries no data (`http_body_util::Empty` over `Bytes`).
pub type EmptyBody = Empty<Bytes>;

/// A body type holding an in-memory buffer of bytes (`http_body_util::Full` over `Bytes`).
pub type BytesBody = Full<Bytes>;

/// An upgraded hyper connection, wrapped in `hyper_util::rt::TokioIo` for use with tokio I/O.
pub type Upgraded = hyper_util::rt::TokioIo<hyper::upgrade::Upgraded>;
36
37pub async fn do_upgrade(resp: Response<Incoming>) -> hyper::Result<Upgraded> {
39 let upgraded = hyper::upgrade::on(resp).await?;
40 Ok(hyper_util::rt::TokioIo::new(upgraded))
41}
42
43mod sealed {
44 use futures::TryStreamExt;
45 use http_body_util::{BodyExt, Limited};
46 use tokio::io::AsyncRead;
47 use tokio_util::io::StreamReader;
48
49 use crate::Error;
50
51 pub trait ResponseExt {
53 fn collect_bytes(self) -> impl Future<Output = Result<bytes::Bytes, Error>> + Send;
63 fn collect_bytes_limited(
72 self,
73 max: usize,
74 ) -> impl Future<Output = Result<bytes::Bytes, Error>> + Send;
75 fn into_read(self) -> impl AsyncRead + Send + Unpin + 'static;
77 }
78
79 impl<B> ResponseExt for http::Response<B>
80 where
81 B: hyper::body::Body + Send + Unpin + 'static,
82 B::Data: Send + 'static,
83 B::Error: core::error::Error + Send + Sync + 'static,
84 {
85 async fn collect_bytes(self) -> Result<bytes::Bytes, Error> {
86 let buf = self
87 .into_body()
88 .collect()
89 .await
90 .map_err(|e| {
91 tracing::error!(error = %e, "collecting response body");
92 Error::Io
93 })?
94 .to_bytes();
95
96 Ok(buf)
97 }
98
99 async fn collect_bytes_limited(self, max: usize) -> Result<bytes::Bytes, Error> {
100 let buf = Limited::new(self.into_body(), max)
104 .collect()
105 .await
106 .map_err(|e| {
107 if e.downcast_ref::<http_body_util::LengthLimitError>()
111 .is_some()
112 {
113 tracing::error!(max, "response body exceeded the size limit");
114 Error::BodyTooLarge
115 } else {
116 tracing::error!(error = %e, max, "collecting response body (limited)");
117 Error::Io
118 }
119 })?
120 .to_bytes();
121
122 Ok(buf)
123 }
124
125 fn into_read(self) -> impl AsyncRead + Send + Unpin + 'static {
126 StreamReader::new(
127 self.into_body()
128 .into_data_stream()
129 .map_err(tokio::io::Error::other),
130 )
131 }
132 }
133}
134
135pub fn make_upgrade_req(
143 u: &url::Url,
144 protocol: &str,
145 extra_headers: impl IntoIterator<Item = (HeaderName, HeaderValue)>,
146) -> Result<Request<EmptyBody>, Error> {
147 let mut req = Request::post(u.as_str())
151 .header(HOST, u.host_str().ok_or(Error::InvalidInput)?)
152 .header(UPGRADE, protocol)
153 .header(CONNECTION, "Upgrade")
154 .body(EmptyBody::new())
155 .map_err(|e| {
156 tracing::error!(error = %e, "creating upgrade request");
157 Error::InvalidInput
158 })?;
159
160 req.headers_mut().extend(extra_headers);
161
162 Ok(req)
163}
164
165pub fn host_header(u: &url::Url) -> Option<(HeaderName, HeaderValue)> {
174 let host = match u.port() {
175 Some(port) => format!("{}:{port}", u.host_str()?),
176 None => u.host_str()?.to_owned(),
177 };
178 Some((HOST, HeaderValue::from_str(&host).ok()?))
179}
180
181async fn dial_tcp(url: &url::Url) -> Result<TcpStream, Error> {
182 let conn = TcpStream::connect((
183 url.host_str().ok_or(Error::InvalidInput)?,
184 url.port_or_known_default()
185 .ok_or(Error::InvalidInput)
186 .inspect_err(|_err| tracing::error!("unknown url port"))?,
187 ))
188 .await
189 .map_err(|e| {
190 tracing::error!(error = %e, %url, "dialing tcp");
191 Error::Io
192 })?;
193
194 Ok(conn)
195}
196
197async fn dial_tls(
198 url: &url::Url,
199 alpn: impl IntoIterator<Item = Vec<u8>>,
200) -> Result<ts_tls_util::TlsStream<TcpStream>, Error> {
201 let server_name = ts_tls_util::server_name(url)
202 .ok_or_else(|| {
203 tracing::error!(%url, "parsing server name");
204 Error::InvalidInput
205 })?
206 .to_owned();
207
208 let conn = dial_tcp(url).await?;
209
210 ts_tls_util::connect_alpn(server_name, conn, alpn)
211 .await
212 .map_err(|e| {
213 tracing::error!(error = %e, "dialing tls connection");
214
215 Error::Io
216 })
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    fn url(s: &str) -> url::Url {
        url::Url::parse(s).unwrap()
    }

    /// Run `host_header` on `s`, check the header name is `Host`, and return the value.
    fn host_value(s: &str) -> HeaderValue {
        let (name, value) = host_header(&url(s)).unwrap();
        assert_eq!(name, HOST);
        value
    }

    #[test]
    fn host_header_omits_default_https_port() {
        let value = host_value("https://h/");
        assert_eq!(value, "h");
        assert!(!value.to_str().unwrap().contains(":443"));
    }

    #[test]
    fn host_header_omits_default_http_port() {
        let value = host_value("http://h/");
        assert_eq!(value, "h");
        assert!(!value.to_str().unwrap().contains(":80"));
    }

    #[test]
    fn host_header_includes_non_default_port() {
        let value = host_value("https://localhost:14000/");
        assert_eq!(value, "localhost:14000");
    }

    #[tokio::test]
    async fn collect_bytes_limited_caps_the_body() {
        use http_body_util::Full;

        const MAX: usize = 1024;

        /// Collect a zero-filled body of `len` bytes under a cap of `max` bytes.
        async fn collect(len: usize, max: usize) -> Result<bytes::Bytes, Error> {
            Response::new(Full::new(Bytes::from(vec![0u8; len])))
                .collect_bytes_limited(max)
                .await
        }

        // Bodies up to and including the cap must be returned intact.
        for (len, expectation) in [
            (MAX - 1, "a body under the cap is collected"),
            (MAX, "a body exactly at the cap is collected"),
        ] {
            let body = collect(len, MAX).await.expect(expectation);
            assert_eq!(body.len(), len);
        }

        let over = collect(MAX + 1, MAX).await;
        assert!(
            matches!(over, Err(Error::BodyTooLarge)),
            "a body over the cap must be rejected as BodyTooLarge, got {over:?}"
        );
    }
}