Skip to main content

ts_http_util/
lib.rs

1#![doc = include_str!("../README.md")]
2#![deny(unsafe_code)]
3
4use bytes::Bytes;
5use http::header::{CONNECTION, UPGRADE};
6pub use http::{
7    HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode, header::HOST,
8};
9use http_body_util::{Empty, Full};
10use hyper::body::Incoming;
11use tokio::net::TcpStream;
12
13mod client;
14mod error;
15pub mod http1;
16pub mod http2;
17
18pub use client::{Client, ClientExt};
19pub use error::Error;
20pub use http1::Http1;
21pub use http2::Http2;
22pub use hyper::upgrade::on as upgrade;
23pub use sealed::ResponseExt;
24
25/// The body of an HTTP [`Request`] or [`Response`] that's always empty; i.e., the body will always
26/// be zero bytes in length.
27pub type EmptyBody = Empty<Bytes>;
28
29/// The body of an HTTP [`Request`] or [`Response`] that may contain one or more bytes; i.e., a body
30/// may be present.
31pub type BytesBody = Full<Bytes>;
32
33/// A connection that has been upgraded from HTTP/1.1 to a different protocol, such as HTTP/2 or
34/// DERP, via HTTP/1.1's upgrade mechanism.protocol upgrade
35pub type Upgraded = hyper_util::rt::TokioIo<hyper::upgrade::Upgraded>;
36
37/// Upgrade a [`Response`] from HTTP/1.1 to the requested protocol.
38pub async fn do_upgrade(resp: Response<Incoming>) -> hyper::Result<Upgraded> {
39    let upgraded = hyper::upgrade::on(resp).await?;
40    Ok(hyper_util::rt::TokioIo::new(upgraded))
41}
42
43mod sealed {
44    use futures::TryStreamExt;
45    use http_body_util::BodyExt;
46    use tokio::io::AsyncRead;
47    use tokio_util::io::StreamReader;
48
49    use crate::Error;
50
51    /// Helper methods for [`http::Response`].
52    pub trait ResponseExt {
53        /// Collect the response body into a [`bytes::Bytes`].
54        fn collect_bytes(self) -> impl Future<Output = Result<bytes::Bytes, Error>> + Send;
55        /// Convert the response body into an [`AsyncRead`].
56        fn into_read(self) -> impl AsyncRead + Send + Unpin + 'static;
57    }
58
59    impl<B> ResponseExt for http::Response<B>
60    where
61        B: hyper::body::Body + Send + Unpin + 'static,
62        B::Data: Send + 'static,
63        B::Error: core::error::Error + Send + Sync + 'static,
64    {
65        async fn collect_bytes(self) -> Result<bytes::Bytes, Error> {
66            let buf = self
67                .into_body()
68                .collect()
69                .await
70                .map_err(|e| {
71                    tracing::error!(error = %e, "collecting response body");
72                    Error::Io
73                })?
74                .to_bytes();
75
76            Ok(buf)
77        }
78
79        fn into_read(self) -> impl AsyncRead + Send + Unpin + 'static {
80            StreamReader::new(
81                self.into_body()
82                    .into_data_stream()
83                    .map_err(tokio::io::Error::other),
84            )
85        }
86    }
87}
88
89/// Create a [`Request`] to upgrade from HTTP/1.1 to the given `protocol`, which can be sent to the
90/// server via an [`Http1`] client to start the [HTTP/1.1 protocol upgrade] process.
91///
92/// Some protocols, such as TS2021, require additional headers in the initial request to
93/// successfully upgrade; these can be provided via `extra_headers`.
94///
95/// [HTTP/1.1 protocol upgrade]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Guides/Protocol_upgrade_mechanism
96pub fn make_upgrade_req(
97    u: &url::Url,
98    protocol: &str,
99    extra_headers: impl IntoIterator<Item = (HeaderName, HeaderValue)>,
100) -> Result<Request<EmptyBody>, Error> {
101    // Use POST for the upgrade request. Some server implementations accept both
102    // GET and POST, but others (e.g. Go's testcontrol) only accept POST. POST
103    // is what Go's controlhttp client sends, so use it for widest compatibility.
104    let mut req = Request::post(u.as_str())
105        .header(HOST, u.host_str().ok_or(Error::InvalidInput)?)
106        .header(UPGRADE, protocol)
107        .header(CONNECTION, "Upgrade")
108        .body(EmptyBody::new())
109        .map_err(|e| {
110            tracing::error!(error = %e, "creating upgrade request");
111            Error::InvalidInput
112        })?;
113
114    req.headers_mut().extend(extra_headers);
115
116    Ok(req)
117}
118
119/// Produce a `Host` header for the given URL.
120///
121/// Includes the port when the URL carries a non-default one (`u.port()` is `Some`), per
122/// RFC 7230 §5.4 — e.g. `localhost:14000`. Origin servers that reconstruct their own absolute
123/// URLs from the `Host` header (such as an ACME directory emitting `newNonce`/`newAccount`
124/// endpoints) would otherwise drop the port and advertise unreachable `:443` URLs.
125///
126/// Returns `None` if `u.host_str()` is `None` or includes non-ascii-printable characters.
127pub fn host_header(u: &url::Url) -> Option<(HeaderName, HeaderValue)> {
128    let host = match u.port() {
129        Some(port) => format!("{}:{port}", u.host_str()?),
130        None => u.host_str()?.to_owned(),
131    };
132    Some((HOST, HeaderValue::from_str(&host).ok()?))
133}
134
135async fn dial_tcp(url: &url::Url) -> Result<TcpStream, Error> {
136    let conn = TcpStream::connect((
137        url.host_str().ok_or(Error::InvalidInput)?,
138        url.port_or_known_default()
139            .ok_or(Error::InvalidInput)
140            .inspect_err(|_err| tracing::error!("unknown url port"))?,
141    ))
142    .await
143    .map_err(|e| {
144        tracing::error!(error = %e, %url, "dialing tcp");
145        Error::Io
146    })?;
147
148    Ok(conn)
149}
150
151async fn dial_tls(
152    url: &url::Url,
153    alpn: impl IntoIterator<Item = Vec<u8>>,
154) -> Result<ts_tls_util::TlsStream<TcpStream>, Error> {
155    let server_name = ts_tls_util::server_name(url)
156        .ok_or_else(|| {
157            tracing::error!(%url, "parsing server name");
158            Error::InvalidInput
159        })?
160        .to_owned();
161
162    let conn = dial_tcp(url).await?;
163
164    ts_tls_util::connect_alpn(server_name, conn, alpn)
165        .await
166        .map_err(|e| {
167            tracing::error!(error = %e, "dialing tls connection");
168
169            Error::Io
170        })
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `s` as a URL, panicking on malformed test input.
    fn url(s: &str) -> url::Url {
        url::Url::parse(s).unwrap()
    }

    #[test]
    fn host_header_omits_default_https_port() {
        let (name, value) = host_header(&url("https://h/")).unwrap();
        assert_eq!(name, HOST);
        assert_eq!(value, "h");
        assert!(!value.to_str().unwrap().contains(":443"));
    }

    #[test]
    fn host_header_omits_default_http_port() {
        let (name, value) = host_header(&url("http://h/")).unwrap();
        assert_eq!(name, HOST);
        assert_eq!(value, "h");
        assert!(!value.to_str().unwrap().contains(":80"));
    }

    #[test]
    fn host_header_omits_explicit_default_port() {
        // `url` normalises an explicitly written default port away, so `u.port()` is `None`.
        let (name, value) = host_header(&url("https://h:443/")).unwrap();
        assert_eq!(name, HOST);
        assert_eq!(value, "h");
    }

    #[test]
    fn host_header_includes_non_default_port() {
        let (name, value) = host_header(&url("https://localhost:14000/")).unwrap();
        assert_eq!(name, HOST);
        assert_eq!(value, "localhost:14000");
    }

    #[test]
    fn host_header_keeps_ipv6_brackets() {
        // `host_str` keeps the brackets, which the `Host` header requires for IPv6 literals.
        let (name, value) = host_header(&url("http://[::1]:8080/")).unwrap();
        assert_eq!(name, HOST);
        assert_eq!(value, "[::1]:8080");
    }

    #[test]
    fn host_header_none_without_host() {
        assert!(host_header(&url("data:text/plain,hi")).is_none());
    }
}