1pub mod body;
2pub mod cache;
3pub mod client;
4pub mod dns;
5pub mod headers;
6pub mod middleware;
7pub mod proxy;
8pub mod server;
9pub mod shed;
10
11use std::{
12 io,
13 pin::{Pin, pin},
14 sync::{Arc, atomic::Ordering},
15 task::{Context, Poll},
16};
17
18use axum::{
19 extract::OriginalUri,
20 response::{IntoResponse, Redirect},
21};
22use axum_extra::extract::Host;
23use http::{HeaderMap, Method, Request, StatusCode, Uri, Version, header::HOST, uri::PathAndQuery};
24use ic_bn_lib_common::types::http::Stats;
25use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
26
27#[cfg(feature = "clients-hyper")]
28pub use client::clients_hyper::{HyperClient, HyperClientLeastLoaded};
29pub use client::clients_reqwest::{
30 ReqwestClient, ReqwestClientLeastLoaded, ReqwestClientRoundRobin,
31};
32pub use server::{Server, ServerBuilder};
33use url::Url;
34
35use crate::http::headers::X_FORWARDED_HOST;
36
37trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Sync + Unpin {}
39impl<T: AsyncRead + AsyncWrite + Send + Sync + Unpin> AsyncReadWrite for T {}
40
41pub fn calc_headers_size(h: &HeaderMap) -> usize {
45 h.iter().map(|(k, v)| k.as_str().len() + v.len() + 2).sum()
46}
47
48pub const fn http_version(v: Version) -> &'static str {
50 match v {
51 Version::HTTP_09 => "0.9",
52 Version::HTTP_10 => "1.0",
53 Version::HTTP_11 => "1.1",
54 Version::HTTP_2 => "2.0",
55 Version::HTTP_3 => "3.0",
56 _ => "-",
57 }
58}
59
60pub const fn http_method(v: &Method) -> &'static str {
62 match *v {
63 Method::OPTIONS => "OPTIONS",
64 Method::GET => "GET",
65 Method::POST => "POST",
66 Method::PUT => "PUT",
67 Method::DELETE => "DELETE",
68 Method::HEAD => "HEAD",
69 Method::TRACE => "TRACE",
70 Method::CONNECT => "CONNECT",
71 Method::PATCH => "PATCH",
72 _ => "",
73 }
74}
75
/// Extracts the host part from a `host[:port]` or `[ipv6][:port]` string.
///
/// - For a bracketed IPv6 literal, returns the text between `[` and the first
///   `]`, without the brackets; an unterminated literal yields `None`.
/// - Otherwise returns everything before the first `:`.
///
/// Returns `None` for empty input or when the extracted host is empty.
pub fn extract_host(host_port: &str) -> Option<&str> {
    let host = match host_port.strip_prefix('[') {
        Some(bracketed) => {
            let close = bracketed.find(']')?;
            &bracketed[..close]
        }
        None => host_port.split(':').next()?,
    };

    (!host.is_empty()).then_some(host)
}
91
92pub fn extract_authority<T>(request: &Request<T>) -> Option<&str> {
95 request
97 .headers()
98 .get(X_FORWARDED_HOST)
99 .and_then(|x| x.to_str().ok())
100 .or_else(|| request.uri().authority().map(|x| x.host()))
102 .or_else(|| request.headers().get(HOST).and_then(|x| x.to_str().ok()))
104 .and_then(extract_host)
106}
107
108struct AsyncCounter<T: AsyncReadWrite> {
110 inner: T,
111 stats: Arc<Stats>,
112}
113
114impl<T: AsyncReadWrite> AsyncCounter<T> {
115 pub fn new(inner: T) -> (Self, Arc<Stats>) {
117 let stats = Arc::new(Stats::new());
118
119 (
120 Self {
121 inner,
122 stats: stats.clone(),
123 },
124 stats,
125 )
126 }
127}
128
129impl<T: AsyncReadWrite> AsyncRead for AsyncCounter<T> {
130 fn poll_read(
131 mut self: Pin<&mut Self>,
132 cx: &mut Context<'_>,
133 buf: &mut ReadBuf<'_>,
134 ) -> Poll<io::Result<()>> {
135 let size_before = buf.filled().len();
136 let poll = pin!(&mut self.inner).poll_read(cx, buf);
137 if matches!(&poll, Poll::Ready(Ok(()))) {
138 let rcvd = buf.filled().len() - size_before;
139 self.stats.rcvd.fetch_add(rcvd as u64, Ordering::SeqCst);
140 }
141
142 poll
143 }
144}
145
146impl<T: AsyncReadWrite> AsyncWrite for AsyncCounter<T> {
147 fn poll_write(
148 mut self: Pin<&mut Self>,
149 cx: &mut Context<'_>,
150 buf: &[u8],
151 ) -> Poll<io::Result<usize>> {
152 let poll = pin!(&mut self.inner).poll_write(cx, buf);
153 if let Poll::Ready(Ok(v)) = &poll {
154 self.stats.sent.fetch_add(*v as u64, Ordering::SeqCst);
155 }
156
157 poll
158 }
159
160 fn poll_shutdown(
161 mut self: Pin<&mut Self>,
162 cx: &mut Context<'_>,
163 ) -> Poll<Result<(), io::Error>> {
164 pin!(&mut self.inner).poll_shutdown(cx)
165 }
166
167 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
168 pin!(&mut self.inner).poll_flush(cx)
169 }
170}
171
/// Errors returned by [`url_to_uri`] when a [`Url`] cannot be turned into an [`http::Uri`].
#[derive(thiserror::Error, Debug)]
pub enum UrlToUriError {
    /// The URL has no authority component (`Url::has_authority()` is false),
    /// e.g. `unix:/foo/bar`.
    #[error("No Authority")]
    NoAuthority,
    /// The URL has an authority but no host (`Url::has_host()` is false).
    #[error("No Host")]
    NoHost,
    /// The `Uri` builder rejected one of the URL's parts.
    #[error(transparent)]
    Http(#[from] http::Error),
}
182
183pub fn url_to_uri(url: &Url) -> Result<Uri, UrlToUriError> {
185 if !url.has_authority() {
186 return Err(UrlToUriError::NoAuthority);
187 }
188
189 if !url.has_host() {
190 return Err(UrlToUriError::NoHost);
191 }
192
193 let scheme = url.scheme();
194 let authority = url.authority();
195
196 let authority_end = scheme.len() + "://".len() + authority.len();
197 let path_and_query = &url.as_str()[authority_end..];
198
199 Uri::builder()
200 .scheme(scheme)
201 .authority(authority)
202 .path_and_query(path_and_query)
203 .build()
204 .map_err(UrlToUriError::Http)
205}
206
207pub async fn redirect_to_https(
209 Host(host): Host,
210 OriginalUri(uri): OriginalUri,
211) -> Result<impl IntoResponse, impl IntoResponse> {
212 let fallback_path = PathAndQuery::from_static("/");
213 let pq = uri.path_and_query().unwrap_or(&fallback_path).as_str();
214
215 Ok::<_, (_, _)>(Redirect::permanent(
216 &Uri::builder()
217 .scheme("https")
218 .authority(host)
219 .path_and_query(pq)
220 .build()
221 .map_err(|_| (StatusCode::BAD_REQUEST, "Incorrect URL"))?
222 .to_string(),
223 ))
224}
225
#[cfg(test)]
mod test {
    use http::{Uri, header::HOST};

    use crate::hval;

    use super::*;

    /// Builds a body-less request for `/foo?bar=baz`.
    ///
    /// The URI is in absolute form when `absolute` carries `(scheme, authority)`,
    /// and in origin form otherwise.
    fn request_for(absolute: Option<(&str, &str)>) -> Request<()> {
        let mut builder = Uri::builder();
        if let Some((scheme, authority)) = absolute {
            builder = builder.scheme(scheme).authority(authority);
        }

        let mut req = Request::new(());
        *req.uri_mut() = builder.path_and_query("/foo?bar=baz").build().unwrap();
        req
    }

    #[test]
    fn test_extract_host() {
        let cases = [
            ("foo.bar", Some("foo.bar")),
            ("foo.bar:443", Some("foo.bar")),
            ("foo.bar:", Some("foo.bar")),
            ("foo:443", Some("foo")),
            ("127.0.0.1:443", Some("127.0.0.1")),
            ("[::1]:443", Some("::1")),
            (
                "[fe80::b696:91ff:fe84:3ae8]",
                Some("fe80::b696:91ff:fe84:3ae8"),
            ),
            (
                "[fe80::b696:91ff:fe84:3ae8]:123",
                Some("fe80::b696:91ff:fe84:3ae8"),
            ),
            // Unterminated IPv6 literal.
            ("[fe80::b696:91ff:fe84:3ae8:123", None),
            ("", None),
            // Empty IPv6 literal.
            ("[]:443", None),
        ];

        for (input, expected) in cases {
            assert_eq!(extract_host(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn test_extract_authority() {
        // Origin-form URI and no headers: nothing to extract.
        let req = request_for(None);
        assert_eq!(extract_authority(&req), None);

        // Absolute-form URI: the port is dropped.
        let req = request_for(Some(("http", "foo.bar:443")));
        assert_eq!(extract_authority(&req), Some("foo.bar"));

        // IPv6 authority: the brackets are dropped too.
        let req = request_for(Some(("http", "[::1]:443")));
        assert_eq!(extract_authority(&req), Some("::1"));

        // Only a Host header.
        let mut req = request_for(None);
        req.headers_mut().insert(HOST, hval!("foo.baz:443"));
        assert_eq!(extract_authority(&req), Some("foo.baz"));

        // Only an X-Forwarded-Host header.
        let mut req = request_for(None);
        req.headers_mut()
            .insert(X_FORWARDED_HOST, hval!("foo.baz:443"));
        assert_eq!(extract_authority(&req), Some("foo.baz"));

        // The URI authority beats the Host header.
        let mut req = request_for(Some(("http", "foo.bar:443")));
        req.headers_mut().insert(HOST, hval!("foo.baz:443"));
        assert_eq!(extract_authority(&req), Some("foo.bar"));

        // X-Forwarded-Host beats both the URI authority and the Host header.
        let mut req = request_for(Some(("http", "foo.bar:443")));
        req.headers_mut().insert(HOST, hval!("foo.baz:443"));
        req.headers_mut()
            .insert(X_FORWARDED_HOST, hval!("dead.beef:443"));
        assert_eq!(extract_authority(&req), Some("dead.beef"));
    }

    #[test]
    fn test_url_to_uri() {
        let url: Url = "https://foo.bar/baz?dead=beef".parse().unwrap();
        assert_eq!(
            url_to_uri(&url).unwrap(),
            Uri::from_static("https://foo.bar/baz?dead=beef")
        );

        // A URL without an authority cannot be converted.
        let url: Url = "unix:/foo/bar".parse().unwrap();
        assert!(url_to_uri(&url).is_err());
    }
}