Skip to main content

ic_bn_lib/http/
mod.rs

1pub mod body;
2pub mod cache;
3pub mod client;
4pub mod dns;
5pub mod headers;
6pub mod middleware;
7pub mod proxy;
8pub mod server;
9pub mod shed;
10
11use std::{
12    io,
13    pin::{Pin, pin},
14    sync::{Arc, atomic::Ordering},
15    task::{Context, Poll},
16};
17
18use axum::response::{IntoResponse, Redirect};
19use http::{HeaderMap, Method, Request, StatusCode, Uri, Version, header::HOST, uri::PathAndQuery};
20use ic_bn_lib_common::types::http::Stats;
21use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
22
23#[cfg(feature = "clients-hyper")]
24pub use client::clients_hyper::{HyperClient, HyperClientLeastLoaded};
25pub use client::clients_reqwest::{
26    ReqwestClient, ReqwestClientLeastLoaded, ReqwestClientRoundRobin,
27};
28pub use server::{Server, ServerBuilder};
29use url::Url;
30
31use crate::http::headers::X_FORWARDED_HOST;
32
33/// Blanket async read+write trait for streams Box-ing
34trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Sync + Unpin {}
35impl<T: AsyncRead + AsyncWrite + Send + Sync + Unpin> AsyncReadWrite for T {}
36
37/// Calculate very approximate HTTP request/response headers size in bytes.
38/// More or less accurate only for http/1.1 since in h2 headers are in HPACK-compressed.
39/// But it seems there's no better way.
40pub fn calc_headers_size(h: &HeaderMap) -> usize {
41    h.iter().map(|(k, v)| k.as_str().len() + v.len() + 2).sum()
42}
43
44/// Get a static string representing given HTTP version
45pub const fn http_version(v: Version) -> &'static str {
46    match v {
47        Version::HTTP_09 => "0.9",
48        Version::HTTP_10 => "1.0",
49        Version::HTTP_11 => "1.1",
50        Version::HTTP_2 => "2.0",
51        Version::HTTP_3 => "3.0",
52        _ => "-",
53    }
54}
55
56/// Get a static string representing given HTTP method
57pub const fn http_method(v: &Method) -> &'static str {
58    match *v {
59        Method::OPTIONS => "OPTIONS",
60        Method::GET => "GET",
61        Method::POST => "POST",
62        Method::PUT => "PUT",
63        Method::DELETE => "DELETE",
64        Method::HEAD => "HEAD",
65        Method::TRACE => "TRACE",
66        Method::CONNECT => "CONNECT",
67        Method::PATCH => "PATCH",
68        _ => "",
69    }
70}
71
/// Extracts the host part from a `host:port` string.
///
/// The host may be an FQDN, an IPv4 address, or a bracketed IPv6 literal
/// (`[::1]:443`), whose brackets are stripped. Returns `None` for empty input,
/// an unterminated `[`, or when the resulting host would be empty.
pub fn extract_host(host_port: &str) -> Option<&str> {
    let host = match host_port.strip_prefix('[') {
        // IPv6 literal: everything up to the closing bracket
        Some(rest) => &rest[..rest.find(']')?],
        // FQDN / IPv4: everything before the first colon
        None => host_port.split(':').next()?,
    };

    (!host.is_empty()).then_some(host)
}
87
88/// Attempts to extract host from `X-Forwarded-Host` header, HTTP2 "authority" pseudo-header or from HTTP/1.1 `Host` header
89/// (in this order of preference)
90pub fn extract_authority<T>(request: &Request<T>) -> Option<&str> {
91    // Try `X-Forwarded-Host` header first
92    request
93        .headers()
94        .get(X_FORWARDED_HOST)
95        .and_then(|x| x.to_str().ok())
96        // Then URI authority
97        .or_else(|| request.uri().authority().map(|x| x.host()))
98        // THen `Host` header
99        .or_else(|| request.headers().get(HOST).and_then(|x| x.to_str().ok()))
100        // Extract host w/o port
101        .and_then(extract_host)
102}
103
104/// Async read+write wrapper that counts bytes read/written
105struct AsyncCounter<T: AsyncReadWrite> {
106    inner: T,
107    stats: Arc<Stats>,
108}
109
110impl<T: AsyncReadWrite> AsyncCounter<T> {
111    /// Create new `AsyncCounter`
112    pub fn new(inner: T) -> (Self, Arc<Stats>) {
113        let stats = Arc::new(Stats::new());
114
115        (
116            Self {
117                inner,
118                stats: stats.clone(),
119            },
120            stats,
121        )
122    }
123}
124
125impl<T: AsyncReadWrite> AsyncRead for AsyncCounter<T> {
126    fn poll_read(
127        mut self: Pin<&mut Self>,
128        cx: &mut Context<'_>,
129        buf: &mut ReadBuf<'_>,
130    ) -> Poll<io::Result<()>> {
131        let size_before = buf.filled().len();
132        let poll = pin!(&mut self.inner).poll_read(cx, buf);
133        if matches!(&poll, Poll::Ready(Ok(()))) {
134            let rcvd = buf.filled().len() - size_before;
135            self.stats.rcvd.fetch_add(rcvd as u64, Ordering::SeqCst);
136        }
137
138        poll
139    }
140}
141
142impl<T: AsyncReadWrite> AsyncWrite for AsyncCounter<T> {
143    fn poll_write(
144        mut self: Pin<&mut Self>,
145        cx: &mut Context<'_>,
146        buf: &[u8],
147    ) -> Poll<io::Result<usize>> {
148        let poll = pin!(&mut self.inner).poll_write(cx, buf);
149        if let Poll::Ready(Ok(v)) = &poll {
150            self.stats.sent.fetch_add(*v as u64, Ordering::SeqCst);
151        }
152
153        poll
154    }
155
156    fn poll_shutdown(
157        mut self: Pin<&mut Self>,
158        cx: &mut Context<'_>,
159    ) -> Poll<Result<(), io::Error>> {
160        pin!(&mut self.inner).poll_shutdown(cx)
161    }
162
163    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
164        pin!(&mut self.inner).poll_flush(cx)
165    }
166}
167
168/// Error that might happen during Url to Uri conversion
169#[derive(thiserror::Error, Debug)]
170pub enum UrlToUriError {
171    #[error("No Authority")]
172    NoAuthority,
173    #[error("No Host")]
174    NoHost,
175    #[error(transparent)]
176    Http(#[from] http::Error),
177}
178
179/// Converts `Url` to `Uri`
180pub fn url_to_uri(url: &Url) -> Result<Uri, UrlToUriError> {
181    if !url.has_authority() {
182        return Err(UrlToUriError::NoAuthority);
183    }
184
185    if !url.has_host() {
186        return Err(UrlToUriError::NoHost);
187    }
188
189    let scheme = url.scheme();
190    let authority = url.authority();
191
192    let authority_end = scheme.len() + "://".len() + authority.len();
193    let path_and_query = &url.as_str()[authority_end..];
194
195    Uri::builder()
196        .scheme(scheme)
197        .authority(authority)
198        .path_and_query(path_and_query)
199        .build()
200        .map_err(UrlToUriError::Http)
201}
202
203/// Redirects any request to an HTTPS scheme
204pub async fn redirect_to_https(
205    request: axum::extract::Request,
206) -> Result<impl IntoResponse, impl IntoResponse> {
207    let host = extract_authority(&request)
208        .ok_or((StatusCode::BAD_REQUEST, "Unable to extract authority"))?;
209    let uri = request.uri().clone();
210
211    let fallback_path = PathAndQuery::from_static("/");
212    let pq = uri.path_and_query().unwrap_or(&fallback_path).as_str();
213
214    Ok::<_, (_, _)>(Redirect::permanent(
215        &Uri::builder()
216            .scheme("https")
217            .authority(host)
218            .path_and_query(pq)
219            .build()
220            .map_err(|_| (StatusCode::BAD_REQUEST, "Incorrect URL"))?
221            .to_string(),
222    ))
223}
224
#[cfg(test)]
mod test {
    use axum::{Router, body::Body};
    use http::{
        Uri,
        header::{HOST, LOCATION},
    };
    use tower::ServiceExt;

    use crate::hval;

    use super::*;

    #[test]
    fn test_calc_headers_size() {
        let mut h = HeaderMap::new();
        assert_eq!(calc_headers_size(&h), 0);

        // "host" (4) + "foo" (3) + 2
        h.insert(HOST, hval!("foo"));
        assert_eq!(calc_headers_size(&h), 9);

        // Each value of a multi-valued header is counted together with its name
        h.append(HOST, hval!("barbaz"));
        assert_eq!(calc_headers_size(&h), 9 + 12);
    }

    #[test]
    fn test_http_version() {
        assert_eq!(http_version(Version::HTTP_09), "0.9");
        assert_eq!(http_version(Version::HTTP_10), "1.0");
        assert_eq!(http_version(Version::HTTP_11), "1.1");
        assert_eq!(http_version(Version::HTTP_2), "2.0");
        assert_eq!(http_version(Version::HTTP_3), "3.0");
    }

    #[test]
    fn test_http_method() {
        assert_eq!(http_method(&Method::OPTIONS), "OPTIONS");
        assert_eq!(http_method(&Method::GET), "GET");
        assert_eq!(http_method(&Method::POST), "POST");
        assert_eq!(http_method(&Method::PUT), "PUT");
        assert_eq!(http_method(&Method::DELETE), "DELETE");
        assert_eq!(http_method(&Method::HEAD), "HEAD");
        assert_eq!(http_method(&Method::TRACE), "TRACE");
        assert_eq!(http_method(&Method::CONNECT), "CONNECT");
        assert_eq!(http_method(&Method::PATCH), "PATCH");

        // Extension methods have no static representation
        assert_eq!(http_method(&Method::from_bytes(b"PURGE").unwrap()), "");
    }

    #[test]
    fn test_extract_host() {
        assert_eq!(extract_host("foo.bar"), Some("foo.bar"));
        assert_eq!(extract_host("foo.bar:443"), Some("foo.bar"));
        assert_eq!(extract_host("foo.bar:"), Some("foo.bar"));
        assert_eq!(extract_host("foo:443"), Some("foo"));

        assert_eq!(extract_host("127.0.0.1:443"), Some("127.0.0.1"));
        assert_eq!(extract_host("[::1]:443"), Some("::1"));

        assert_eq!(
            extract_host("[fe80::b696:91ff:fe84:3ae8]"),
            Some("fe80::b696:91ff:fe84:3ae8")
        );
        assert_eq!(
            extract_host("[fe80::b696:91ff:fe84:3ae8]:123"),
            Some("fe80::b696:91ff:fe84:3ae8")
        );

        // Unterminated bracket
        assert_eq!(extract_host("[fe80::b696:91ff:fe84:3ae8:123"), None);
        // Empty
        assert_eq!(extract_host(""), None);
        assert_eq!(extract_host("[]:443"), None);
    }

    #[test]
    fn test_extract_authority() {
        // No authority & no host header
        let mut req = Request::new(());
        *req.uri_mut() = Uri::builder()
            .path_and_query("/foo?bar=baz")
            .build()
            .unwrap();
        assert_eq!(extract_authority(&req), None);

        // Authority
        let mut req = Request::new(());
        *req.uri_mut() = Uri::builder()
            .scheme("http")
            .authority("foo.bar:443")
            .path_and_query("/foo?bar=baz")
            .build()
            .unwrap();
        assert_eq!(extract_authority(&req), Some("foo.bar"));

        let mut req = Request::new(());
        *req.uri_mut() = Uri::builder()
            .scheme("http")
            .authority("[::1]:443")
            .path_and_query("/foo?bar=baz")
            .build()
            .unwrap();
        assert_eq!(extract_authority(&req), Some("::1"));

        // Host header
        let mut req = Request::new(());
        *req.uri_mut() = Uri::builder()
            .path_and_query("/foo?bar=baz")
            .build()
            .unwrap();
        (*req.headers_mut()).insert(HOST, hval!("foo.baz:443"));
        assert_eq!(extract_authority(&req), Some("foo.baz"));

        // XFH header
        let mut req = Request::new(());
        *req.uri_mut() = Uri::builder()
            .path_and_query("/foo?bar=baz")
            .build()
            .unwrap();
        (*req.headers_mut()).insert(X_FORWARDED_HOST, hval!("foo.baz:443"));
        assert_eq!(extract_authority(&req), Some("foo.baz"));

        // Host+Authority: authority should take precedence
        let mut req = Request::new(());
        *req.uri_mut() = Uri::builder()
            .scheme("http")
            .authority("foo.bar:443")
            .path_and_query("/foo?bar=baz")
            .build()
            .unwrap();
        (*req.headers_mut()).insert(HOST, hval!("foo.baz:443"));
        assert_eq!(extract_authority(&req), Some("foo.bar"));

        // XFH+Host+Authority: XFH should take precedence
        let mut req = Request::new(());
        *req.uri_mut() = Uri::builder()
            .scheme("http")
            .authority("foo.bar:443")
            .path_and_query("/foo?bar=baz")
            .build()
            .unwrap();
        (*req.headers_mut()).insert(HOST, hval!("foo.baz:443"));
        (*req.headers_mut()).insert(X_FORWARDED_HOST, hval!("dead.beef:443"));
        assert_eq!(extract_authority(&req), Some("dead.beef"));
    }

    #[test]
    fn test_url_to_uri() {
        let url = "https://foo.bar/baz?dead=beef".parse().unwrap();

        assert_eq!(
            url_to_uri(&url).unwrap(),
            Uri::from_static("https://foo.bar/baz?dead=beef")
        );

        // Non-default port is kept; `url` serializes the empty path as "/"
        let url = "http://foo.bar:8080".parse().unwrap();
        assert_eq!(
            url_to_uri(&url).unwrap(),
            Uri::from_static("http://foo.bar:8080/")
        );

        let url = "unix:/foo/bar".parse().unwrap();
        assert!(url_to_uri(&url).is_err());
    }

    #[tokio::test]
    async fn test_redirect_to_https() {
        let mut request = axum::extract::Request::new(Body::empty());
        *request.uri_mut() = Uri::from_static("http://foo/bar/baz.bin?a=b");

        let router = Router::new().fallback(redirect_to_https);

        let response = router.oneshot(request).await.unwrap();
        let location = response.headers().get(LOCATION).unwrap().to_str().unwrap();
        assert_eq!(location, "https://foo/bar/baz.bin?a=b");
    }
}