1pub mod body;
2pub mod cache;
3pub mod client;
4pub mod dns;
5pub mod headers;
6pub mod middleware;
7pub mod proxy;
8pub mod server;
9pub mod shed;
10
11use std::{
12 io,
13 pin::{Pin, pin},
14 sync::{Arc, atomic::Ordering},
15 task::{Context, Poll},
16};
17
18use axum::response::{IntoResponse, Redirect};
19use http::{HeaderMap, Method, Request, StatusCode, Uri, Version, header::HOST, uri::PathAndQuery};
20use ic_bn_lib_common::types::http::Stats;
21use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
22
23#[cfg(feature = "clients-hyper")]
24pub use client::clients_hyper::{HyperClient, HyperClientLeastLoaded};
25pub use client::clients_reqwest::{
26 ReqwestClient, ReqwestClientLeastLoaded, ReqwestClientRoundRobin,
27};
28pub use server::{Server, ServerBuilder};
29use url::Url;
30
31use crate::http::headers::X_FORWARDED_HOST;
32
33trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Sync + Unpin {}
35impl<T: AsyncRead + AsyncWrite + Send + Sync + Unpin> AsyncReadWrite for T {}
36
37pub fn calc_headers_size(h: &HeaderMap) -> usize {
41 h.iter().map(|(k, v)| k.as_str().len() + v.len() + 2).sum()
42}
43
44pub const fn http_version(v: Version) -> &'static str {
46 match v {
47 Version::HTTP_09 => "0.9",
48 Version::HTTP_10 => "1.0",
49 Version::HTTP_11 => "1.1",
50 Version::HTTP_2 => "2.0",
51 Version::HTTP_3 => "3.0",
52 _ => "-",
53 }
54}
55
56pub const fn http_method(v: &Method) -> &'static str {
58 match *v {
59 Method::OPTIONS => "OPTIONS",
60 Method::GET => "GET",
61 Method::POST => "POST",
62 Method::PUT => "PUT",
63 Method::DELETE => "DELETE",
64 Method::HEAD => "HEAD",
65 Method::TRACE => "TRACE",
66 Method::CONNECT => "CONNECT",
67 Method::PATCH => "PATCH",
68 _ => "",
69 }
70}
71
/// Extracts the host part from a `host[:port]` string.
///
/// Bracketed IPv6 literals (`[::1]:443`) are returned without their brackets
/// (`::1`). Anything after the closing bracket is ignored.
///
/// Returns `None` if the input is empty, if a `[` is never closed, or if the
/// resulting host is empty (e.g. `":443"` or `"[]"`).
pub fn extract_host(host_port: &str) -> Option<&str> {
    let host = match host_port.strip_prefix('[') {
        // IPv6 literal: take everything up to the first `]`, if there is one.
        Some(bracketed) => bracketed.split_once(']').map(|(inner, _)| inner),
        // Name or IPv4: the host ends at the first `:`.
        None => host_port.split(':').next(),
    };

    host.filter(|h| !h.is_empty())
}
87
88pub fn extract_authority<T>(request: &Request<T>) -> Option<&str> {
91 request
93 .headers()
94 .get(X_FORWARDED_HOST)
95 .and_then(|x| x.to_str().ok())
96 .or_else(|| request.uri().authority().map(|x| x.host()))
98 .or_else(|| request.headers().get(HOST).and_then(|x| x.to_str().ok()))
100 .and_then(extract_host)
102}
103
104struct AsyncCounter<T: AsyncReadWrite> {
106 inner: T,
107 stats: Arc<Stats>,
108}
109
110impl<T: AsyncReadWrite> AsyncCounter<T> {
111 pub fn new(inner: T) -> (Self, Arc<Stats>) {
113 let stats = Arc::new(Stats::new());
114
115 (
116 Self {
117 inner,
118 stats: stats.clone(),
119 },
120 stats,
121 )
122 }
123}
124
125impl<T: AsyncReadWrite> AsyncRead for AsyncCounter<T> {
126 fn poll_read(
127 mut self: Pin<&mut Self>,
128 cx: &mut Context<'_>,
129 buf: &mut ReadBuf<'_>,
130 ) -> Poll<io::Result<()>> {
131 let size_before = buf.filled().len();
132 let poll = pin!(&mut self.inner).poll_read(cx, buf);
133 if matches!(&poll, Poll::Ready(Ok(()))) {
134 let rcvd = buf.filled().len() - size_before;
135 self.stats.rcvd.fetch_add(rcvd as u64, Ordering::SeqCst);
136 }
137
138 poll
139 }
140}
141
142impl<T: AsyncReadWrite> AsyncWrite for AsyncCounter<T> {
143 fn poll_write(
144 mut self: Pin<&mut Self>,
145 cx: &mut Context<'_>,
146 buf: &[u8],
147 ) -> Poll<io::Result<usize>> {
148 let poll = pin!(&mut self.inner).poll_write(cx, buf);
149 if let Poll::Ready(Ok(v)) = &poll {
150 self.stats.sent.fetch_add(*v as u64, Ordering::SeqCst);
151 }
152
153 poll
154 }
155
156 fn poll_shutdown(
157 mut self: Pin<&mut Self>,
158 cx: &mut Context<'_>,
159 ) -> Poll<Result<(), io::Error>> {
160 pin!(&mut self.inner).poll_shutdown(cx)
161 }
162
163 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
164 pin!(&mut self.inner).poll_flush(cx)
165 }
166}
167
/// Errors returned by [`url_to_uri`].
#[derive(thiserror::Error, Debug)]
pub enum UrlToUriError {
    /// The URL has no authority component.
    #[error("No Authority")]
    NoAuthority,
    /// The URL has an authority component but no host.
    #[error("No Host")]
    NoHost,
    /// The `http::Uri` builder rejected the parts taken from the URL.
    #[error(transparent)]
    Http(#[from] http::Error),
}
178
179pub fn url_to_uri(url: &Url) -> Result<Uri, UrlToUriError> {
181 if !url.has_authority() {
182 return Err(UrlToUriError::NoAuthority);
183 }
184
185 if !url.has_host() {
186 return Err(UrlToUriError::NoHost);
187 }
188
189 let scheme = url.scheme();
190 let authority = url.authority();
191
192 let authority_end = scheme.len() + "://".len() + authority.len();
193 let path_and_query = &url.as_str()[authority_end..];
194
195 Uri::builder()
196 .scheme(scheme)
197 .authority(authority)
198 .path_and_query(path_and_query)
199 .build()
200 .map_err(UrlToUriError::Http)
201}
202
203pub async fn redirect_to_https(
205 request: axum::extract::Request,
206) -> Result<impl IntoResponse, impl IntoResponse> {
207 let host = extract_authority(&request)
208 .ok_or((StatusCode::BAD_REQUEST, "Unable to extract authority"))?;
209 let uri = request.uri().clone();
210
211 let fallback_path = PathAndQuery::from_static("/");
212 let pq = uri.path_and_query().unwrap_or(&fallback_path).as_str();
213
214 Ok::<_, (_, _)>(Redirect::permanent(
215 &Uri::builder()
216 .scheme("https")
217 .authority(host)
218 .path_and_query(pq)
219 .build()
220 .map_err(|_| (StatusCode::BAD_REQUEST, "Incorrect URL"))?
221 .to_string(),
222 ))
223}
224
#[cfg(test)]
mod test {
    use axum::{Router, body::Body};
    use http::{
        Uri,
        header::{HOST, LOCATION},
    };
    use tower::ServiceExt;

    use crate::hval;

    use super::*;

    /// `extract_host` strips ports and IPv6 brackets, and rejects malformed or empty hosts.
    #[test]
    fn test_extract_host() {
        // Names, with a port, with an empty port, and a short name.
        assert_eq!(extract_host("foo.bar"), Some("foo.bar"));
        assert_eq!(extract_host("foo.bar:443"), Some("foo.bar"));
        assert_eq!(extract_host("foo.bar:"), Some("foo.bar"));
        assert_eq!(extract_host("foo:443"), Some("foo"));

        // IPv4 keeps its dots; IPv6 loses its brackets.
        assert_eq!(extract_host("127.0.0.1:443"), Some("127.0.0.1"));
        assert_eq!(extract_host("[::1]:443"), Some("::1"));

        // Bracketed IPv6 with and without a port.
        assert_eq!(
            extract_host("[fe80::b696:91ff:fe84:3ae8]"),
            Some("fe80::b696:91ff:fe84:3ae8")
        );
        assert_eq!(
            extract_host("[fe80::b696:91ff:fe84:3ae8]:123"),
            Some("fe80::b696:91ff:fe84:3ae8")
        );

        // Unclosed bracket.
        assert_eq!(extract_host("[fe80::b696:91ff:fe84:3ae8:123"), None);
        // Empty input.
        assert_eq!(extract_host(""), None);
        // Empty brackets.
        assert_eq!(extract_host("[]:443"), None);
    }

    /// `extract_authority` precedence: X-Forwarded-Host, then URI authority, then Host.
    #[test]
    fn test_extract_authority() {
        // No authority anywhere.
        let mut req = Request::new(());
        *req.uri_mut() = Uri::builder()
            .path_and_query("/foo?bar=baz")
            .build()
            .unwrap();
        assert_eq!(extract_authority(&req), None);

        // Authority taken from the URI, port stripped.
        let mut req = Request::new(());
        *req.uri_mut() = Uri::builder()
            .scheme("http")
            .authority("foo.bar:443")
            .path_and_query("/foo?bar=baz")
            .build()
            .unwrap();
        assert_eq!(extract_authority(&req), Some("foo.bar"));

        // IPv6 URI authority comes back without brackets.
        let mut req = Request::new(());
        *req.uri_mut() = Uri::builder()
            .scheme("http")
            .authority("[::1]:443")
            .path_and_query("/foo?bar=baz")
            .build()
            .unwrap();
        assert_eq!(extract_authority(&req), Some("::1"));

        // Host header only.
        let mut req = Request::new(());
        *req.uri_mut() = Uri::builder()
            .path_and_query("/foo?bar=baz")
            .build()
            .unwrap();
        (*req.headers_mut()).insert(HOST, hval!("foo.baz:443"));
        assert_eq!(extract_authority(&req), Some("foo.baz"));

        // X-Forwarded-Host header only.
        let mut req = Request::new(());
        *req.uri_mut() = Uri::builder()
            .path_and_query("/foo?bar=baz")
            .build()
            .unwrap();
        (*req.headers_mut()).insert(X_FORWARDED_HOST, hval!("foo.baz:443"));
        assert_eq!(extract_authority(&req), Some("foo.baz"));

        // The URI authority wins over the Host header.
        let mut req = Request::new(());
        *req.uri_mut() = Uri::builder()
            .scheme("http")
            .authority("foo.bar:443")
            .path_and_query("/foo?bar=baz")
            .build()
            .unwrap();
        (*req.headers_mut()).insert(HOST, hval!("foo.baz:443"));
        assert_eq!(extract_authority(&req), Some("foo.bar"));

        // X-Forwarded-Host wins over both the URI authority and the Host header.
        let mut req = Request::new(());
        *req.uri_mut() = Uri::builder()
            .scheme("http")
            .authority("foo.bar:443")
            .path_and_query("/foo?bar=baz")
            .build()
            .unwrap();
        (*req.headers_mut()).insert(HOST, hval!("foo.baz:443"));
        (*req.headers_mut()).insert(X_FORWARDED_HOST, hval!("dead.beef:443"));
        assert_eq!(extract_authority(&req), Some("dead.beef"));
    }

    /// `url_to_uri` round-trips a URL with an authority and rejects one without.
    #[test]
    fn test_url_to_uri() {
        let url = "https://foo.bar/baz?dead=beef".parse().unwrap();

        assert_eq!(
            url_to_uri(&url).unwrap(),
            Uri::from_static("https://foo.bar/baz?dead=beef")
        );

        // No `//`, so no authority.
        let url = "unix:/foo/bar".parse().unwrap();
        assert!(url_to_uri(&url).is_err());
    }

    /// The redirect keeps path and query and switches the scheme to https.
    #[tokio::test]
    async fn test_redirect_to_https() {
        let mut request = axum::extract::Request::new(Body::empty());
        *request.uri_mut() = Uri::from_static("http://foo/bar/baz.bin?a=b");

        let router = Router::new().fallback(redirect_to_https);

        let response = router.oneshot(request).await.unwrap();
        let location = response.headers().get(LOCATION).unwrap().to_str().unwrap();
        assert_eq!(location, "https://foo/bar/baz.bin?a=b");
    }
}