pub mod body;
pub mod cache;
pub mod client;
pub mod dns;
pub mod headers;
pub mod middleware;
pub mod proxy;
pub mod server;
pub mod shed;
use std::{
io,
pin::{Pin, pin},
sync::{Arc, atomic::Ordering},
task::{Context, Poll},
};
use axum::response::{IntoResponse, Redirect};
use http::{HeaderMap, Method, Request, StatusCode, Uri, Version, header::HOST, uri::PathAndQuery};
use ic_bn_lib_common::types::http::Stats;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
#[cfg(feature = "clients-hyper")]
pub use client::clients_hyper::{HyperClient, HyperClientLeastLoaded};
pub use client::clients_reqwest::{
ReqwestClient, ReqwestClientLeastLoaded, ReqwestClientRoundRobin,
};
pub use server::{Server, ServerBuilder};
use url::Url;
use crate::http::headers::X_FORWARDED_HOST;
/// Shorthand bound for a bidirectional async byte stream that can be moved
/// between threads and polled without pinning.
trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Sync + Unpin {}

// Blanket impl: any type meeting the bounds is automatically an `AsyncReadWrite`.
impl<S> AsyncReadWrite for S where S: AsyncRead + AsyncWrite + Send + Sync + Unpin {}
/// Estimates the byte size of a header map.
///
/// Every (name, value) pair yielded by `HeaderMap::iter` contributes the name
/// length, the value length and 2 extra bytes (presumably the `: ` separator).
pub fn calc_headers_size(h: &HeaderMap) -> usize {
    h.iter().fold(0, |total, (name, value)| {
        total + name.as_str().len() + value.len() + 2
    })
}
/// Returns a short textual label for an HTTP protocol version
/// (`"0.9"`, `"1.0"`, `"1.1"`, `"2.0"`, `"3.0"`), or `"-"` for any other value.
pub const fn http_version(v: Version) -> &'static str {
    // `match` on associated constants keeps this usable in `const` contexts.
    match v {
        Version::HTTP_3 => "3.0",
        Version::HTTP_2 => "2.0",
        Version::HTTP_11 => "1.1",
        Version::HTTP_10 => "1.0",
        Version::HTTP_09 => "0.9",
        _ => "-",
    }
}
/// Returns the name of a standard HTTP method as a `&'static str`.
///
/// Extension (non-standard) methods yield an empty string, because their name
/// is not available with a `'static` lifetime.
pub const fn http_method(v: &Method) -> &'static str {
    match *v {
        Method::GET => "GET",
        Method::POST => "POST",
        Method::PUT => "PUT",
        Method::PATCH => "PATCH",
        Method::DELETE => "DELETE",
        Method::HEAD => "HEAD",
        Method::OPTIONS => "OPTIONS",
        Method::CONNECT => "CONNECT",
        Method::TRACE => "TRACE",
        _ => "",
    }
}
/// Extracts the host part of a `host[:port]` string such as a `Host` header value.
///
/// Bracketed IPv6 literals (`[::1]:443`) are returned without the brackets,
/// and anything after the closing bracket is ignored. Returns `None` for an
/// empty input, an empty host, or an opening `[` with no matching `]`.
pub fn extract_host(host_port: &str) -> Option<&str> {
    let host = if let Some(bracketed) = host_port.strip_prefix('[') {
        // The IPv6 literal runs up to the first closing bracket.
        let end = bracketed.find(']')?;
        &bracketed[..end]
    } else {
        // `split` always yields at least one (possibly empty) piece.
        host_port.split(':').next()?
    };

    (!host.is_empty()).then_some(host)
}
pub fn extract_authority<T>(request: &Request<T>) -> Option<&str> {
request
.headers()
.get(X_FORWARDED_HOST)
.and_then(|x| x.to_str().ok())
.or_else(|| request.uri().authority().map(|x| x.host()))
.or_else(|| request.headers().get(HOST).and_then(|x| x.to_str().ok()))
.and_then(extract_host)
}
/// Wraps an async stream and records the bytes read from and written to it
/// in a shared [`Stats`] handle.
struct AsyncCounter<T>
where
    T: AsyncReadWrite,
{
    /// The wrapped stream that all I/O is delegated to.
    inner: T,
    /// Byte counters, shared with whoever created this wrapper.
    stats: Arc<Stats>,
}
impl<T: AsyncReadWrite> AsyncCounter<T> {
    /// Wraps `inner` and returns the wrapper together with a second handle to
    /// its freshly created counters, so the caller can read them later.
    pub fn new(inner: T) -> (Self, Arc<Stats>) {
        let shared = Arc::new(Stats::new());
        let counter = Self {
            inner,
            stats: Arc::clone(&shared),
        };
        (counter, shared)
    }
}
impl<T: AsyncReadWrite> AsyncRead for AsyncCounter<T> {
    /// Delegates to the wrapped stream. On a successful read, the number of
    /// bytes newly placed into `buf` is added to `stats.rcvd`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let filled_before = buf.filled().len();

        // `T: Unpin`, so the inner stream can be pinned in place directly.
        let outcome = Pin::new(&mut self.inner).poll_read(cx, buf);

        if let Poll::Ready(Ok(())) = &outcome {
            let delta = buf.filled().len() - filled_before;
            self.stats.rcvd.fetch_add(delta as u64, Ordering::SeqCst);
        }

        outcome
    }
}
impl<T: AsyncReadWrite> AsyncWrite for AsyncCounter<T> {
    /// Delegates to the wrapped stream and adds the number of bytes it
    /// accepted to `stats.sent`.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let poll = pin!(&mut self.inner).poll_write(cx, buf);
        if let Poll::Ready(Ok(v)) = &poll {
            self.stats.sent.fetch_add(*v as u64, Ordering::SeqCst);
        }
        poll
    }

    /// Forwards vectored writes to the wrapped stream and counts the bytes it
    /// accepted, like `poll_write`.
    ///
    /// Without this override, tokio's default would write only the first
    /// non-empty buffer per call, hiding the inner stream's vectored-write
    /// support.
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let poll = pin!(&mut self.inner).poll_write_vectored(cx, bufs);
        if let Poll::Ready(Ok(v)) = &poll {
            self.stats.sent.fetch_add(*v as u64, Ordering::SeqCst);
        }
        poll
    }

    /// Reports the wrapped stream's vectored-write capability, rather than
    /// tokio's default of `false`.
    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }

    /// Delegates shutdown to the wrapped stream; nothing is counted.
    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), io::Error>> {
        pin!(&mut self.inner).poll_shutdown(cx)
    }

    /// Delegates flushing to the wrapped stream; nothing is counted.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        pin!(&mut self.inner).poll_flush(cx)
    }
}
#[derive(thiserror::Error, Debug)]
pub enum UrlToUriError {
#[error("No Authority")]
NoAuthority,
#[error("No Host")]
NoHost,
#[error(transparent)]
Http(#[from] http::Error),
}
/// Converts a [`Url`] into a [`Uri`] with the same scheme, authority, path and
/// query.
///
/// # Errors
///
/// - [`UrlToUriError::NoAuthority`] if the URL has no authority component.
/// - [`UrlToUriError::NoHost`] if it has an authority but no host.
/// - [`UrlToUriError::Http`] if the `http` URI builder rejects the parts.
pub fn url_to_uri(url: &Url) -> Result<Uri, UrlToUriError> {
    match (url.has_authority(), url.has_host()) {
        (false, _) => return Err(UrlToUriError::NoAuthority),
        (true, false) => return Err(UrlToUriError::NoHost),
        (true, true) => {}
    }

    let (scheme, authority) = (url.scheme(), url.authority());

    // The serialization starts with `<scheme>://<authority>`; the remainder is
    // the path and query (plus any fragment present in the URL string).
    let prefix_len = scheme.len() + "://".len() + authority.len();
    let (_, path_and_query) = url.as_str().split_at(prefix_len);

    let uri = Uri::builder()
        .scheme(scheme)
        .authority(authority)
        .path_and_query(path_and_query)
        .build()?;

    Ok(uri)
}
/// Axum handler that answers every request with a permanent redirect to the
/// same host, path and query over `https`.
///
/// The host comes from [`extract_authority`], which strips any port, so the
/// redirect targets the default HTTPS port. IPv6 literals come back from it
/// without brackets and are re-bracketed here, because an authority such as
/// `::1` is rejected by the URI builder.
///
/// # Errors
///
/// Responds with `400 Bad Request` when no host can be extracted from the
/// request, or when the target URL cannot be assembled.
pub async fn redirect_to_https(
    request: axum::extract::Request,
) -> Result<impl IntoResponse, impl IntoResponse> {
    let host = extract_authority(&request)
        .ok_or((StatusCode::BAD_REQUEST, "Unable to extract authority"))?;

    // `extract_host` removes the port from non-bracketed hosts, so a remaining
    // ':' can only belong to an IPv6 literal, which needs its brackets back.
    let authority = if host.contains(':') {
        format!("[{host}]")
    } else {
        host.to_owned()
    };

    let pq = request
        .uri()
        .path_and_query()
        .map_or("/", PathAndQuery::as_str);

    let target = Uri::builder()
        .scheme("https")
        .authority(authority.as_str())
        .path_and_query(pq)
        .build()
        .map_err(|_| (StatusCode::BAD_REQUEST, "Incorrect URL"))?;

    Ok::<_, (_, _)>(Redirect::permanent(&target.to_string()))
}
#[cfg(test)]
mod test {
    use axum::{Router, body::Body};
    use http::{
        Uri,
        header::{HOST, LOCATION},
    };
    use tower::ServiceExt;

    use crate::hval;
    use super::*;

    /// Builds a body-less request for `/foo?bar=baz`. When a scheme and
    /// authority are given, the request URI is made absolute with them.
    fn make_request(scheme_and_authority: Option<(&str, &str)>) -> Request<()> {
        let mut uri = Uri::builder();
        if let Some((scheme, authority)) = scheme_and_authority {
            uri = uri.scheme(scheme).authority(authority);
        }

        let mut req = Request::new(());
        *req.uri_mut() = uri.path_and_query("/foo?bar=baz").build().unwrap();
        req
    }

    #[test]
    fn test_extract_host() {
        let cases = [
            ("foo.bar", Some("foo.bar")),
            ("foo.bar:443", Some("foo.bar")),
            ("foo.bar:", Some("foo.bar")),
            ("foo:443", Some("foo")),
            ("127.0.0.1:443", Some("127.0.0.1")),
            ("[::1]:443", Some("::1")),
            ("[fe80::b696:91ff:fe84:3ae8]", Some("fe80::b696:91ff:fe84:3ae8")),
            ("[fe80::b696:91ff:fe84:3ae8]:123", Some("fe80::b696:91ff:fe84:3ae8")),
            ("[fe80::b696:91ff:fe84:3ae8:123", None),
            ("", None),
            ("[]:443", None),
        ];

        for (input, expected) in cases {
            assert_eq!(extract_host(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn test_extract_authority() {
        // Origin-form URI and no headers: nothing to extract.
        let req = make_request(None);
        assert_eq!(extract_authority(&req), None);

        // Host taken from the URI authority, port stripped.
        let req = make_request(Some(("http", "foo.bar:443")));
        assert_eq!(extract_authority(&req), Some("foo.bar"));

        // IPv6 authority loses its brackets.
        let req = make_request(Some(("http", "[::1]:443")));
        assert_eq!(extract_authority(&req), Some("::1"));

        // Only a `Host` header.
        let mut req = make_request(None);
        req.headers_mut().insert(HOST, hval!("foo.baz:443"));
        assert_eq!(extract_authority(&req), Some("foo.baz"));

        // Only an `X-Forwarded-Host` header.
        let mut req = make_request(None);
        req.headers_mut()
            .insert(X_FORWARDED_HOST, hval!("foo.baz:443"));
        assert_eq!(extract_authority(&req), Some("foo.baz"));

        // The URI authority wins over `Host`.
        let mut req = make_request(Some(("http", "foo.bar:443")));
        req.headers_mut().insert(HOST, hval!("foo.baz:443"));
        assert_eq!(extract_authority(&req), Some("foo.bar"));

        // `X-Forwarded-Host` wins over both.
        let mut req = make_request(Some(("http", "foo.bar:443")));
        req.headers_mut().insert(HOST, hval!("foo.baz:443"));
        req.headers_mut()
            .insert(X_FORWARDED_HOST, hval!("dead.beef:443"));
        assert_eq!(extract_authority(&req), Some("dead.beef"));
    }

    #[test]
    fn test_url_to_uri() {
        let with_authority: Url = "https://foo.bar/baz?dead=beef".parse().unwrap();
        assert_eq!(
            url_to_uri(&with_authority).unwrap(),
            Uri::from_static("https://foo.bar/baz?dead=beef")
        );

        let without_authority: Url = "unix:/foo/bar".parse().unwrap();
        assert!(url_to_uri(&without_authority).is_err());
    }

    #[tokio::test]
    async fn test_redirect_to_https() {
        let mut request = axum::extract::Request::new(Body::empty());
        *request.uri_mut() = Uri::from_static("http://foo/bar/baz.bin?a=b");

        let response = Router::new()
            .fallback(redirect_to_https)
            .oneshot(request)
            .await
            .unwrap();

        let location = response.headers().get(LOCATION).unwrap().to_str().unwrap();
        assert_eq!(location, "https://foo/bar/baz.bin?a=b");
    }
}