Skip to main content

ts_http_util/
lib.rs

1#![doc = include_str!("../README.md")]
2#![deny(unsafe_code)]
3
4use bytes::Bytes;
5use http::header::{CONNECTION, UPGRADE};
6pub use http::{
7    HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode, header::HOST,
8};
9use http_body_util::{Empty, Full};
10use hyper::body::Incoming;
11use tokio::net::TcpStream;
12
13mod client;
14mod error;
15pub mod http1;
16pub mod http2;
17
18pub use client::{Client, ClientExt};
19pub use error::Error;
20pub use http1::Http1;
21pub use http2::Http2;
22pub use hyper::upgrade::on as upgrade;
23pub use sealed::ResponseExt;
24
25/// The body of an HTTP [`Request`] or [`Response`] that's always empty; i.e., the body will always
26/// be zero bytes in length.
27pub type EmptyBody = Empty<Bytes>;
28
29/// The body of an HTTP [`Request`] or [`Response`] that may contain one or more bytes; i.e., a body
30/// may be present.
31pub type BytesBody = Full<Bytes>;
32
33/// A connection that has been upgraded from HTTP/1.1 to a different protocol, such as HTTP/2 or
34/// DERP, via HTTP/1.1's upgrade mechanism.protocol upgrade
35pub type Upgraded = hyper_util::rt::TokioIo<hyper::upgrade::Upgraded>;
36
37/// Upgrade a [`Response`] from HTTP/1.1 to the requested protocol.
38pub async fn do_upgrade(resp: Response<Incoming>) -> hyper::Result<Upgraded> {
39    let upgraded = hyper::upgrade::on(resp).await?;
40    Ok(hyper_util::rt::TokioIo::new(upgraded))
41}
42
43mod sealed {
44    use futures::TryStreamExt;
45    use http_body_util::{BodyExt, Limited};
46    use tokio::io::AsyncRead;
47    use tokio_util::io::StreamReader;
48
49    use crate::Error;
50
51    /// Helper methods for [`http::Response`].
52    pub trait ResponseExt {
53        /// Collect the response body into a [`bytes::Bytes`] with **no size limit**.
54        ///
55        /// Use this only when the body is locally generated or otherwise trusted to be small. For
56        /// any body read from a network peer (a control/DERP/upstream server response — all of which
57        /// this fork treats as hostile-capable), prefer
58        /// [`collect_bytes_limited`][Self::collect_bytes_limited]: `collect()` here buffers the
59        /// *entire* body into memory before returning, so a malicious or MITM'd server answering a
60        /// short request with a multi-gigabyte streamed body would OOM the client (a length check
61        /// *after* `collect_bytes` is too late — the allocation already happened).
62        fn collect_bytes(self) -> impl Future<Output = Result<bytes::Bytes, Error>> + Send;
63        /// Collect the response body into a [`bytes::Bytes`], failing if it exceeds `max` bytes.
64        ///
65        /// The body is wrapped in [`http_body_util::Limited`], which aborts collection as soon as
66        /// more than `max` bytes arrive — so the allocation is bounded *during* the read, mirroring
67        /// Go's `io.LimitedReader`-bounded body reads. An over-limit body yields
68        /// [`Error::BodyTooLarge`] (distinct from a transient [`Error::Io`], so callers can treat it
69        /// as terminal). This is the body reader every network-response caller should use; pick `max`
70        /// to comfortably fit the largest legitimate response for that endpoint.
71        fn collect_bytes_limited(
72            self,
73            max: usize,
74        ) -> impl Future<Output = Result<bytes::Bytes, Error>> + Send;
75        /// Convert the response body into an [`AsyncRead`].
76        fn into_read(self) -> impl AsyncRead + Send + Unpin + 'static;
77    }
78
79    impl<B> ResponseExt for http::Response<B>
80    where
81        B: hyper::body::Body + Send + Unpin + 'static,
82        B::Data: Send + 'static,
83        B::Error: core::error::Error + Send + Sync + 'static,
84    {
85        async fn collect_bytes(self) -> Result<bytes::Bytes, Error> {
86            let buf = self
87                .into_body()
88                .collect()
89                .await
90                .map_err(|e| {
91                    tracing::error!(error = %e, "collecting response body");
92                    Error::Io
93                })?
94                .to_bytes();
95
96            Ok(buf)
97        }
98
99        async fn collect_bytes_limited(self, max: usize) -> Result<bytes::Bytes, Error> {
100            // `Limited` errors (with a boxed `LengthLimitError`) once the body exceeds `max`, so
101            // collection stops there instead of buffering an unbounded body — the cap binds the
102            // allocation, not just a post-hoc length check.
103            let buf = Limited::new(self.into_body(), max)
104                .collect()
105                .await
106                .map_err(|e| {
107                    // Distinguish "the peer exceeded the cap" (terminal — an attack/misconfig signal,
108                    // not worth retrying) from a transient mid-read I/O failure. `Limited` boxes a
109                    // `LengthLimitError` for the former.
110                    if e.downcast_ref::<http_body_util::LengthLimitError>()
111                        .is_some()
112                    {
113                        tracing::error!(max, "response body exceeded the size limit");
114                        Error::BodyTooLarge
115                    } else {
116                        tracing::error!(error = %e, max, "collecting response body (limited)");
117                        Error::Io
118                    }
119                })?
120                .to_bytes();
121
122            Ok(buf)
123        }
124
125        fn into_read(self) -> impl AsyncRead + Send + Unpin + 'static {
126            StreamReader::new(
127                self.into_body()
128                    .into_data_stream()
129                    .map_err(tokio::io::Error::other),
130            )
131        }
132    }
133}
134
135/// Create a [`Request`] to upgrade from HTTP/1.1 to the given `protocol`, which can be sent to the
136/// server via an [`Http1`] client to start the [HTTP/1.1 protocol upgrade] process.
137///
138/// Some protocols, such as TS2021, require additional headers in the initial request to
139/// successfully upgrade; these can be provided via `extra_headers`.
140///
141/// [HTTP/1.1 protocol upgrade]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Guides/Protocol_upgrade_mechanism
142pub fn make_upgrade_req(
143    u: &url::Url,
144    protocol: &str,
145    extra_headers: impl IntoIterator<Item = (HeaderName, HeaderValue)>,
146) -> Result<Request<EmptyBody>, Error> {
147    // Use POST for the upgrade request. Some server implementations accept both
148    // GET and POST, but others (e.g. Go's testcontrol) only accept POST. POST
149    // is what Go's controlhttp client sends, so use it for widest compatibility.
150    let mut req = Request::post(u.as_str())
151        .header(HOST, u.host_str().ok_or(Error::InvalidInput)?)
152        .header(UPGRADE, protocol)
153        .header(CONNECTION, "Upgrade")
154        .body(EmptyBody::new())
155        .map_err(|e| {
156            tracing::error!(error = %e, "creating upgrade request");
157            Error::InvalidInput
158        })?;
159
160    req.headers_mut().extend(extra_headers);
161
162    Ok(req)
163}
164
165/// Produce a `Host` header for the given URL.
166///
167/// Includes the port when the URL carries a non-default one (`u.port()` is `Some`), per
168/// RFC 7230 §5.4 — e.g. `localhost:14000`. Origin servers that reconstruct their own absolute
169/// URLs from the `Host` header (such as an ACME directory emitting `newNonce`/`newAccount`
170/// endpoints) would otherwise drop the port and advertise unreachable `:443` URLs.
171///
172/// Returns `None` if `u.host_str()` is `None` or includes non-ascii-printable characters.
173pub fn host_header(u: &url::Url) -> Option<(HeaderName, HeaderValue)> {
174    let host = match u.port() {
175        Some(port) => format!("{}:{port}", u.host_str()?),
176        None => u.host_str()?.to_owned(),
177    };
178    Some((HOST, HeaderValue::from_str(&host).ok()?))
179}
180
181async fn dial_tcp(url: &url::Url) -> Result<TcpStream, Error> {
182    let conn = TcpStream::connect((
183        url.host_str().ok_or(Error::InvalidInput)?,
184        url.port_or_known_default()
185            .ok_or(Error::InvalidInput)
186            .inspect_err(|_err| tracing::error!("unknown url port"))?,
187    ))
188    .await
189    .map_err(|e| {
190        tracing::error!(error = %e, %url, "dialing tcp");
191        Error::Io
192    })?;
193
194    Ok(conn)
195}
196
197async fn dial_tls(
198    url: &url::Url,
199    alpn: impl IntoIterator<Item = Vec<u8>>,
200) -> Result<ts_tls_util::TlsStream<TcpStream>, Error> {
201    let server_name = ts_tls_util::server_name(url)
202        .ok_or_else(|| {
203            tracing::error!(%url, "parsing server name");
204            Error::InvalidInput
205        })?
206        .to_owned();
207
208    let conn = dial_tcp(url).await?;
209
210    ts_tls_util::connect_alpn(server_name, conn, alpn)
211        .await
212        .map_err(|e| {
213            tracing::error!(error = %e, "dialing tls connection");
214
215            Error::Io
216        })
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    fn url(s: &str) -> url::Url {
        url::Url::parse(s).unwrap()
    }

    #[test]
    fn host_header_omits_default_https_port() {
        let (name, value) = host_header(&url("https://h/")).unwrap();
        assert_eq!(name, HOST);
        assert_eq!(value, "h");
        assert!(!value.to_str().unwrap().contains(":443"));
    }

    #[test]
    fn host_header_omits_default_http_port() {
        let (name, value) = host_header(&url("http://h/")).unwrap();
        assert_eq!(name, HOST);
        assert_eq!(value, "h");
        assert!(!value.to_str().unwrap().contains(":80"));
    }

    #[test]
    fn host_header_includes_non_default_port() {
        let (name, value) = host_header(&url("https://localhost:14000/")).unwrap();
        assert_eq!(name, HOST);
        assert_eq!(value, "localhost:14000");
    }

    /// IPv6 literals must keep their brackets in `Host` (RFC 7230 §5.4 / RFC 3986 `IP-literal`),
    /// with the port appended only when it is non-default.
    #[test]
    fn host_header_brackets_ipv6_literals() {
        let (name, value) = host_header(&url("https://[::1]:8443/")).unwrap();
        assert_eq!(name, HOST);
        assert_eq!(value, "[::1]:8443");

        let (_, value) = host_header(&url("https://[::1]/")).unwrap();
        assert_eq!(value, "[::1]");
    }

    /// `collect_bytes_limited` must accept a body up to and including `max` bytes and reject anything
    /// larger, bounding the allocation during the read (a control/DERP/upstream server can't OOM the
    /// client by streaming an oversized body to a small request). Pins the boundary: `< max` and
    /// `== max` succeed, `max + 1` errors, and a zero-byte cap admits only an empty body.
    #[tokio::test]
    async fn collect_bytes_limited_caps_the_body() {
        use http_body_util::Full;

        async fn collect(len: usize, max: usize) -> Result<bytes::Bytes, Error> {
            let body = Full::new(Bytes::from(vec![0u8; len]));
            Response::new(body).collect_bytes_limited(max).await
        }

        const MAX: usize = 1024;

        // Below the cap: ok, full body returned.
        let under = collect(MAX - 1, MAX)
            .await
            .expect("a body under the cap is collected");
        assert_eq!(under.len(), MAX - 1);

        // Exactly at the cap: ok (Limited allows == max, rejects only > max).
        let at = collect(MAX, MAX)
            .await
            .expect("a body exactly at the cap is collected");
        assert_eq!(at.len(), MAX);

        // Over the cap: rejected with the distinct BodyTooLarge (not a generic Io), allocation
        // bounded — never buffers the whole oversized body.
        let over = collect(MAX + 1, MAX).await;
        assert!(
            matches!(over, Err(Error::BodyTooLarge)),
            "a body over the cap must be rejected as BodyTooLarge, got {over:?}"
        );

        // Degenerate cap: an empty body fits a zero-byte limit, a single byte does not.
        let empty = collect(0, 0)
            .await
            .expect("an empty body fits a zero-byte cap");
        assert!(empty.is_empty());
        let one = collect(1, 0).await;
        assert!(
            matches!(one, Err(Error::BodyTooLarge)),
            "a one-byte body must exceed a zero-byte cap, got {one:?}"
        );
    }
}