1#![doc = include_str!("../README.md")]
2#![deny(unsafe_code)]
3
4use bytes::Bytes;
5use http::header::{CONNECTION, UPGRADE};
6pub use http::{
7 HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode, header::HOST,
8};
9use http_body_util::{Empty, Full};
10use hyper::body::Incoming;
11use tokio::net::TcpStream;
12
13mod client;
14mod error;
15pub mod http1;
16pub mod http2;
17
18pub use client::{Client, ClientExt};
19pub use error::Error;
20pub use http1::Http1;
21pub use http2::Http2;
22pub use hyper::upgrade::on as upgrade;
23pub use sealed::ResponseExt;
24
/// Request/response body type that carries no data: [`Empty`] over [`Bytes`].
pub type EmptyBody = Empty<Bytes>;
28
/// Request/response body type backed by a single in-memory [`Bytes`] buffer ([`Full`]).
pub type BytesBody = Full<Bytes>;
32
/// An upgraded hyper connection wrapped in [`hyper_util::rt::TokioIo`], the adapter between
/// hyper's I/O traits and tokio's.
pub type Upgraded = hyper_util::rt::TokioIo<hyper::upgrade::Upgraded>;
36
37pub async fn do_upgrade(resp: Response<Incoming>) -> hyper::Result<Upgraded> {
39 let upgraded = hyper::upgrade::on(resp).await?;
40 Ok(hyper_util::rt::TokioIo::new(upgraded))
41}
42
43mod sealed {
44 use futures::TryStreamExt;
45 use http_body_util::BodyExt;
46 use tokio::io::AsyncRead;
47 use tokio_util::io::StreamReader;
48
49 use crate::Error;
50
51 pub trait ResponseExt {
53 fn collect_bytes(self) -> impl Future<Output = Result<bytes::Bytes, Error>> + Send;
55 fn into_read(self) -> impl AsyncRead + Send + Unpin + 'static;
57 }
58
59 impl<B> ResponseExt for http::Response<B>
60 where
61 B: hyper::body::Body + Send + Unpin + 'static,
62 B::Data: Send + 'static,
63 B::Error: core::error::Error + Send + Sync + 'static,
64 {
65 async fn collect_bytes(self) -> Result<bytes::Bytes, Error> {
66 let buf = self
67 .into_body()
68 .collect()
69 .await
70 .map_err(|e| {
71 tracing::error!(error = %e, "collecting response body");
72 Error::Io
73 })?
74 .to_bytes();
75
76 Ok(buf)
77 }
78
79 fn into_read(self) -> impl AsyncRead + Send + Unpin + 'static {
80 StreamReader::new(
81 self.into_body()
82 .into_data_stream()
83 .map_err(tokio::io::Error::other),
84 )
85 }
86 }
87}
88
89pub fn make_upgrade_req(
97 u: &url::Url,
98 protocol: &str,
99 extra_headers: impl IntoIterator<Item = (HeaderName, HeaderValue)>,
100) -> Result<Request<EmptyBody>, Error> {
101 let mut req = Request::post(u.as_str())
105 .header(HOST, u.host_str().ok_or(Error::InvalidInput)?)
106 .header(UPGRADE, protocol)
107 .header(CONNECTION, "Upgrade")
108 .body(EmptyBody::new())
109 .map_err(|e| {
110 tracing::error!(error = %e, "creating upgrade request");
111 Error::InvalidInput
112 })?;
113
114 req.headers_mut().extend(extra_headers);
115
116 Ok(req)
117}
118
119pub fn host_header(u: &url::Url) -> Option<(HeaderName, HeaderValue)> {
128 let host = match u.port() {
129 Some(port) => format!("{}:{port}", u.host_str()?),
130 None => u.host_str()?.to_owned(),
131 };
132 Some((HOST, HeaderValue::from_str(&host).ok()?))
133}
134
135async fn dial_tcp(url: &url::Url) -> Result<TcpStream, Error> {
136 let conn = TcpStream::connect((
137 url.host_str().ok_or(Error::InvalidInput)?,
138 url.port_or_known_default()
139 .ok_or(Error::InvalidInput)
140 .inspect_err(|_err| tracing::error!("unknown url port"))?,
141 ))
142 .await
143 .map_err(|e| {
144 tracing::error!(error = %e, %url, "dialing tcp");
145 Error::Io
146 })?;
147
148 Ok(conn)
149}
150
151async fn dial_tls(
152 url: &url::Url,
153 alpn: impl IntoIterator<Item = Vec<u8>>,
154) -> Result<ts_tls_util::TlsStream<TcpStream>, Error> {
155 let server_name = ts_tls_util::server_name(url)
156 .ok_or_else(|| {
157 tracing::error!(%url, "parsing server name");
158 Error::InvalidInput
159 })?
160 .to_owned();
161
162 let conn = dial_tcp(url).await?;
163
164 ts_tls_util::connect_alpn(server_name, conn, alpn)
165 .await
166 .map_err(|e| {
167 tracing::error!(error = %e, "dialing tls connection");
168
169 Error::Io
170 })
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `s` as a URL, panicking if it is malformed.
    fn url(s: &str) -> url::Url {
        s.parse().unwrap()
    }

    /// Build the `Host` header for the URL written as `s`, check that the header name is
    /// [`HOST`], and hand back the value for further assertions.
    fn host_value(s: &str) -> HeaderValue {
        let (name, value) = host_header(&url(s)).unwrap();
        assert_eq!(name, HOST);
        value
    }

    #[test]
    fn host_header_omits_default_https_port() {
        let value = host_value("https://h/");
        assert_eq!(value, "h");
        assert!(!value.to_str().unwrap().contains(":443"));
    }

    #[test]
    fn host_header_omits_default_http_port() {
        let value = host_value("http://h/");
        assert_eq!(value, "h");
        assert!(!value.to_str().unwrap().contains(":80"));
    }

    #[test]
    fn host_header_includes_non_default_port() {
        let value = host_value("https://localhost:14000/");
        assert_eq!(value, "localhost:14000");
    }
}