use {
crate::{
constants::HEADER_BEARER_PREFIX,
url::IpUrl,
},
flowcontrol::shed,
futures::future::join_all,
hickory_resolver::{
config::LookupIpStrategy,
proto::{
op::Query,
rr::{
RData,
RecordType,
},
},
Hosts,
Name,
},
http::{
header::{
AUTHORIZATION,
HOST,
},
HeaderValue,
},
http_body_util::{
BodyExt,
Full,
Limited,
},
hyper::{
body::Bytes,
client::conn::http1::{
Connection,
SendRequest,
},
HeaderMap,
Request,
StatusCode,
Uri,
},
hyper_rustls::{
ConfigBuilderExt,
HttpsConnectorBuilder,
},
loga::{
ea,
DebugDisplay,
Log,
ResultContext,
},
rand::{
seq::SliceRandom,
thread_rng,
},
rustls::ClientConfig,
serde::{
de::DeserializeOwned,
Serialize,
},
std::{
collections::HashMap,
net::{
IpAddr,
Ipv4Addr,
Ipv6Addr,
},
str::FromStr,
sync::LazyLock,
time::Duration,
},
tokio::{
io::{
AsyncWrite,
AsyncWriteExt,
},
join,
select,
sync::mpsc::{
self,
channel,
},
time::sleep,
},
tower_service::Service,
};
/// Time and size budgets applied to the individual phases of an outgoing HTTP request.
#[derive(Clone, Copy)]
pub struct Limits {
    /// Timeout handed to the DNS resolver when a host name has to be looked up.
    pub resolve_time: Duration,
    /// Maximum time allowed for each individual connection attempt.
    pub connect_time: Duration,
    /// Maximum time for sending a request and receiving the response headers.
    pub read_header_time: Duration,
    /// Maximum time for reading a complete response body.
    pub read_body_time: Duration,
    /// Maximum number of response body bytes that will be buffered.
    pub read_body_size: usize,
}

impl Default for Limits {
    /// 10 seconds each for resolving, connecting and reading headers, 30 seconds for reading
    /// the body, and a 16 MiB body size cap.
    fn default() -> Self {
        Self {
            resolve_time: Duration::from_secs(10),
            connect_time: Duration::from_secs(10),
            read_header_time: Duration::from_secs(10),
            read_body_time: Duration::from_secs(30),
            read_body_size: 16 * 1024 * 1024,
        }
    }
}

// Builder-style setters. Each one replaces a single field and carries every other field over
// from `self`, so calls can be chained without losing earlier overrides.
impl Limits {
    /// Returns a copy of these limits with `resolve_time` replaced by `time`.
    pub fn with_resolve_time(self, time: Duration) -> Self {
        Self {
            resolve_time: time,
            ..self
        }
    }

    /// Returns a copy of these limits with `connect_time` replaced by `time`.
    pub fn with_connect_time(self, time: Duration) -> Self {
        Self {
            connect_time: time,
            ..self
        }
    }

    /// Returns a copy of these limits with `read_header_time` replaced by `time`.
    pub fn with_read_header_time(self, time: Duration) -> Self {
        Self {
            read_header_time: time,
            ..self
        }
    }

    /// Returns a copy of these limits with `read_body_time` replaced by `time`.
    pub fn with_read_body_time(self, time: Duration) -> Self {
        Self {
            read_body_time: time,
            ..self
        }
    }

    /// Returns a copy of these limits with `read_body_size` replaced by `size`.
    pub fn with_read_body_size(self, size: usize) -> Self {
        Self {
            read_body_size: size,
            ..self
        }
    }
}
pub fn default_tls() -> rustls::ClientConfig {
static S: LazyLock<rustls::ClientConfig> =
LazyLock::new(
|| ClientConfig::builder()
.with_native_roots()
.context("Error loading native roots")
.unwrap()
.with_no_client_auth(),
);
return S.clone();
}
/// An HTTP/1 client connection, stored as hyper's request sender plus the connection driver
/// future that has to be polled while a request is in flight.
pub struct Conn<B: 'static + http_body::Body = Full<Bytes>> {
    /// The two connection halves. `send` takes them out, leaving `None`; `receive` and
    /// `receive_stream` put them back once a response has been read successfully.
    pub inner: Option<(
        SendRequest<B>,
        Connection<hyper_rustls::MaybeHttpsStream<hyper_util::rt::tokio::TokioIo<tokio::net::TcpStream>>, B>,
    )>,
}

impl<B: 'static + http_body::Body> Conn<B> {
    /// Wraps the `(sender, connection)` pair produced by an HTTP/1 handshake.
    pub fn new(
        c: (
            SendRequest<B>,
            Connection<hyper_rustls::MaybeHttpsStream<hyper_util::rt::tokio::TokioIo<tokio::net::TcpStream>>, B>,
        ),
    ) -> Self {
        let inner = Some(c);
        Self { inner }
    }
}
/// The host part of a URL: either a literal IP address or a name that still needs resolving.
#[derive(Clone)]
pub enum Host {
    /// A literal IPv4 or IPv6 address.
    Ip(IpAddr),
    /// A host name.
    Name(String),
}

impl std::fmt::Display for Host {
    /// Formats the wrapped address or name exactly as its own `Display` impl would.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Host::Ip(addr) => std::fmt::Display::fmt(addr, f),
            Host::Name(name) => std::fmt::Display::fmt(name, f),
        }
    }
}
/// Splits `url` into its scheme, host and port.
///
/// A host starting with `[` is parsed as a bracketed IPv6 literal, a host consisting only of
/// digits and dots is parsed as an IPv4 literal, and anything else is kept as a name. When the
/// URL has no explicit port, 80 is used for `http` and 443 for `https`.
///
/// # Errors
///
/// Returns an error if the host or scheme is missing, the host is empty, an IP literal does not
/// parse, or the port is absent and the scheme is neither `http` nor `https`.
pub fn uri_parts(url: &Uri) -> Result<(String, Host, u16), loga::Error> {
    let raw_host = url.host().context("Url is missing host")?;
    if raw_host.is_empty() {
        return Err(loga::err("Host portion of url is empty"));
    }
    let raw_bytes = raw_host.as_bytes();
    let host = if raw_bytes[0] == b'[' {
        // Drop the opening bracket and the final byte, then parse what is left as ipv6.
        let (_, inner) = raw_bytes[1..].split_last().context("URL ipv6 missing ending ]")?;
        let inner = String::from_utf8(inner.to_vec()).unwrap();
        let addr = Ipv6Addr::from_str(&inner).context("Invalid ipv6 address in URL")?;
        Host::Ip(IpAddr::V6(addr))
    } else if raw_bytes.iter().all(|b| *b == b'.' || b.is_ascii_digit()) {
        let addr = Ipv4Addr::from_str(raw_host).context("Invalid ipv4 address in URL")?;
        Host::Ip(IpAddr::V4(addr))
    } else {
        Host::Name(raw_host.to_string())
    };
    let scheme = url.scheme().context("Url is missing scheme")?.to_string();
    let port = if let Some(explicit) = url.port_u16() {
        explicit
    } else if scheme == "http" {
        80
    } else if scheme == "https" {
        443
    } else {
        return Err(loga::err("Only http/https urls are supported"));
    };
    Ok((scheme, host, port))
}
/// A set of addresses for one host, kept separately per address family.
#[derive(Clone)]
pub struct Ips {
    /// IPv4 addresses.
    pub ipv4s: Vec<Ipv4Addr>,
    /// IPv6 addresses.
    pub ipv6s: Vec<Ipv6Addr>,
}

impl Ips {
    /// Appends `addr` to the list matching its address family.
    pub fn push(&mut self, addr: IpAddr) {
        match addr {
            IpAddr::V4(v4) => self.ipv4s.push(v4),
            IpAddr::V6(v6) => self.ipv6s.push(v6),
        }
    }
}

impl From<IpAddr> for Ips {
    /// Builds a set containing only `value`; the list for the other family is empty.
    fn from(value: IpAddr) -> Self {
        let mut single = Ips {
            ipv4s: Vec::new(),
            ipv6s: Vec::new(),
        };
        single.push(value);
        single
    }
}
/// Resolves `host` into lists of IPv4 and IPv6 addresses.
///
/// A `Host::Ip` is returned as the single entry of the matching list without any lookup. A
/// `Host::Name` gets a trailing `.` appended and is first looked up in the static `Hosts`
/// table; if that produced no A/AAAA records, the resolver described by the system DNS
/// configuration is queried for both address families with `limits.resolve_time` as its timeout.
///
/// Both lists are shuffled before being returned.
///
/// # Errors
///
/// Returns an error if the system resolver configuration cannot be read or the DNS lookup fails.
pub async fn resolve(limits: Limits, host: &Host) -> Result<Ips, loga::Error> {
let mut ipv4s = vec![];
let mut ipv6s = vec![];
match host {
Host::Ip(i) => {
match i {
IpAddr::V4(i) => ipv4s.push(*i),
IpAddr::V6(i) => ipv6s.push(*i),
};
},
Host::Name(host) => {
// The trailing dot is appended before the name is parsed and looked up.
let host = format!("{}.", host);
// Built once per process. NOTE(review): presumably `Hosts::new()` loads the system hosts
// file (the flag below is named `found_etc_hosts`) — confirm against the hickory docs.
static HOSTS: LazyLock<Hosts> = LazyLock::new(|| Hosts::new());
// NOTE(review): `shed!` is an external macro; from its use here it appears to be a breakable
// block, with `'found_hosts` labelling the outer one — confirm.
shed!{
'found_hosts _;
shed!{
// A name that fails to parse skips the static lookup (unlabeled break, presumably leaving
// only the inner block) and falls through to the DNS lookup below.
let Ok(name) = Name::from_utf8(&host) else {
break;
};
let mut found_etc_hosts = false;
if let Some(res) = HOSTS.lookup_static_host(&Query::query(name.clone(), RecordType::A)) {
for rec in res {
let RData::A(rec) = rec else {
continue;
};
ipv4s.push(rec.0);
found_etc_hosts = true;
}
};
if let Some(res) = HOSTS.lookup_static_host(&Query::query(name.clone(), RecordType::AAAA)) {
for rec in res {
let RData::AAAA(rec) = rec else {
continue;
};
ipv6s.push(rec.0);
found_etc_hosts = true;
}
};
// Any static record wins: leave the outer block so the DNS lookup is skipped.
if found_etc_hosts {
break 'found_hosts;
}
}
// The system resolver configuration is re-read and a new resolver built on every call.
let (hickory_config, mut hickory_options) =
hickory_resolver
::system_conf
::read_system_conf().context("Error reading system dns resolver config for http request")?;
// Ask for both families regardless of the system default strategy.
hickory_options.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
hickory_options.timeout = limits.resolve_time;
for ip in hickory_resolver::TokioAsyncResolver::tokio(hickory_config, hickory_options)
.lookup_ip(&host)
.await
.context("Failed to look up lookup host ip addresses")? {
match ip {
std::net::IpAddr::V4(ip) => {
ipv4s.push(ip);
},
std::net::IpAddr::V6(ip) => {
ipv6s.push(ip);
},
}
};
}
},
};
// Randomise the order within each family; the two lists are shuffled independently.
{
let mut r = thread_rng();
ipv4s.shuffle(&mut r);
ipv6s.shuffle(&mut r);
}
return Ok(Ips {
ipv4s: ipv4s,
ipv6s: ipv6s,
});
}
/// Opens an HTTP/1 connection to one of the already-resolved addresses in `ips`.
///
/// Two futures are built, one for the IPv6 list and one for the IPv4 list, and polled
/// concurrently. Within a family the addresses are tried one after another, each attempt limited
/// to `limits.connect_time`. The first attempt that connects performs the HTTP/1 handshake and
/// hands the resulting `Conn` over a channel; the function returns as soon as one arrives.
///
/// `host` is passed to the connector as the server name and `scheme`/`port` are combined with
/// each address to form the URI that is dialled.
///
/// # Errors
///
/// Returns an aggregate of the per-family errors when no connection could be established.
pub async fn connect_ips<
D: hyper::body::Buf + Send,
E: 'static + std::error::Error + Send + Sync,
B: 'static + http_body::Body<Data = D, Error = E>,
>(
limits: Limits,
ips: Ips,
tls: rustls::ClientConfig,
scheme: String,
host: Host,
port: u16,
) -> Result<Conn<B>, loga::Error> {
// One future per address family.
let mut bg = vec![];
// Capacity 1: only the first established connection is kept.
let (found_tx, mut found_rx) = mpsc::channel(1);
// IPv6 addresses form the first future, IPv4 the second.
for ips in [
ips.ipv6s.into_iter().map(|x| IpAddr::V6(x)).collect::<Vec<_>>(),
ips.ipv4s.into_iter().map(|x| IpAddr::V4(x)).collect::<Vec<_>>(),
] {
bg.push({
let found_tx = found_tx.clone();
let scheme = &scheme;
let host = &host;
let tls = tls.clone();
async move {
// Errors from the addresses of this family that failed to connect.
let mut errs = vec![];
for ip in &ips {
let connect = async {
return Ok(
HttpsConnectorBuilder::new()
.with_tls_config(tls.clone())
.https_or_http()
.with_server_name(host.to_string())
.enable_http1()
.build()
.call(Uri::from_str(&format!("{}://{}:{}", scheme, ip.as_url_host(), port)).unwrap())
.await
.map_err(
|e| loga::err_with(
"Connection failed",
ea!(err = e.to_string(), dest_addr = ip, host = host, port = port),
),
)?,
);
};
// Bound each individual attempt by the connect timeout.
let res = select!{
_ = sleep(limits.connect_time) => Err(loga::err("Timeout connecting")),
res = connect => res,
};
match res {
Ok(conn) => {
// The `try_send` result is ignored: if the channel is already full (the other family
// delivered first) this connection is simply dropped. A handshake failure leaves this
// family's future through `?` without trying its remaining addresses.
_ =
found_tx.try_send(
Conn::new(
hyper::client::conn::http1::handshake(conn)
.await
.context("Error completing http handshake")?,
),
);
return Ok(());
},
Err(e) => {
errs.push(e);
},
}
}
return Err(
loga::agg_err_with(
"Couldn't establish a connection to any ip in version set",
errs,
ea!(ips = ips.dbg_str()),
),
);
}
});
}
// Race the attempts against the channel: return the moment a connection is delivered,
// otherwise collect the outcome of both families.
let results = select!{
results = join_all(bg) => results,
found = found_rx.recv() => {
return Ok(found.unwrap());
}
};
// NOTE(review): `bg` always receives two entries above, so this check looks unreachable;
// empty address lists surface through the aggregated error at the end instead — confirm.
if results.is_empty() {
return Err(loga::err("No addresses found when looking up host"));
}
let mut failed = vec![];
for res in results {
match res {
Ok(_) => {
// A family reported success, so its connection is waiting in the channel.
let found = found_rx.recv().await.unwrap();
return Ok(found);
},
Err(e) => {
failed.push(e);
},
}
}
return Err(loga::agg_err("Unable to connect to host", failed));
}
pub async fn connect<
D: hyper::body::Buf + Send,
E: 'static + std::error::Error + Send + Sync,
B: 'static + http_body::Body<Data = D, Error = E>,
>(limits: Limits, base_url: &Uri) -> Result<Conn<B>, loga::Error> {
let log = &Log::new().fork(ea!(url = base_url));
let (scheme, host, port) = uri_parts(base_url).stack_context(log, "Incomplete url")?;
let ips = resolve(limits, &host).await.stack_context(log, "Error resolving ips for url")?;
return Ok(
connect_ips(limits, ips, default_tls(), scheme, host, port)
.await
.stack_context(log, "Failed to establish connection")?,
);
}
pub async fn connect_with_tls<
D: hyper::body::Buf + Send,
E: 'static + std::error::Error + Send + Sync,
B: 'static + http_body::Body<Data = D, Error = E>,
>(limits: Limits, base_url: &Uri, tls: rustls::ClientConfig) -> Result<Conn<B>, loga::Error> {
let log = &Log::new().fork(ea!(url = base_url));
let (scheme, host, port) = uri_parts(base_url).stack_context(log, "Incomplete url")?;
let ips = resolve(limits, &host).await.stack_context(log, "Error resolving ips for url")?;
return Ok(
connect_ips(limits, ips, tls, scheme, host, port).await.stack_context(log, "Failed to establish connection")?,
);
}
/// State of a request whose response headers have arrived but whose body is still unread.
///
/// Produced by `send` and consumed by `receive` or `receive_stream`, which read the body and,
/// on success, store the connection halves back into `conn`.
#[must_use]
pub struct ContinueSend<'a, B: 'static + http_body::Body> {
// The not-yet-read response body.
body: hyper::body::Incoming,
// Request sender half taken out of `conn` by `send`.
conn_send: SendRequest<B>,
// Connection driver future; polled alongside body reads.
conn_bg: Connection<hyper_rustls::MaybeHttpsStream<hyper_util::rt::TokioIo<tokio::net::TcpStream>>, B>,
// The connection wrapper the two halves above are put back into.
conn: &'a mut Conn<B>,
}
/// Sends `req` over `conn` and waits for the response status and headers.
///
/// If the request has no `HOST` header and its URI has a host, the header is set from it. The
/// connection halves are taken out of `conn` for the duration of the exchange; they are only put
/// back by `receive`/`receive_stream`, so after an error here `conn` stays empty.
///
/// On a successful status, returns the status, the headers and a `ContinueSend` for reading the
/// body. On any other status the body is read via `receive` and included in the returned error.
///
/// # Errors
///
/// Returns an error if the connection was already taken, the connection fails or the request
/// cannot be sent, headers do not arrive within `limits.read_header_time`, or the server
/// answers with a non-success status.
pub async fn send<
'a,
D: hyper::body::Buf + Send,
E: 'static + std::error::Error + Send + Sync,
B: 'static + http_body::Body<Data = D, Error = E>,
>(
log: &Log,
limits: Limits,
conn: &'a mut Conn<B>,
mut req: Request<B>,
) -> Result<(StatusCode, HeaderMap, ContinueSend<'a, B>), loga::Error> {
// Only the host part of the URI is used for the header (no port is appended).
if !req.headers().contains_key(HOST) {
if let Some(host) = req.uri().host() {
let host = HeaderValue::from_str(host).context("Error setting HOST header")?;
req.headers_mut().insert(HOST, host);
}
}
// Take ownership of both halves; `conn.inner` is `None` from here until a body has been read.
let Some((mut conn_send, mut conn_bg)) = conn.inner.take() else {
return Err(loga::err("Connection already lost"));
};
// Captured before `req` is moved into the request future, for the log line below.
let method = req.method().to_string();
let url = req.uri().to_string();
let read = async move {
let work = conn_send.send_request(req);
// The connection future is polled together with the request; if it completes first the
// request is treated as failed.
let resp = select!{
_ =& mut conn_bg => {
return Err(loga::err("Connection failed while sending request"));
}
r = work => r,
}.context("Error sending request")?;
let status = resp.status();
let headers = resp.headers().clone();
return Ok((status, headers, ContinueSend {
body: resp.into_body(),
conn_send: conn_send,
conn_bg: conn_bg,
conn: conn,
}));
};
// A single timeout covers both sending the request and receiving the headers.
let (status, headers, continue_send) = select!{
_ = sleep(limits.read_header_time) => {
return Err(loga::err("Timeout sending request and waiting for headers from server"));
}
x = read => x ?,
};
log.log_with(
loga::DEBUG,
"Receive (streamed)",
ea!(method = method, url = url, status = status, headers = headers.dbg_str()),
);
// Non-success: read the body so it can be reported, then fail either way.
if !status.is_success() {
match receive(limits, continue_send).await {
Ok(body) => {
return Err(
loga::err_with(
"Server returned error response",
ea!(status = status, body = String::from_utf8_lossy(&body)),
),
);
},
Err(e) => {
let err = loga::err_with("Server returned error response", ea!(status = status));
return Err(err.also(e));
},
}
}
return Ok((status, headers, continue_send));
}
/// Streams the response body held by `continue_send` into `writer`.
///
/// One future reads body frames and pushes their bytes into a bounded channel while a second
/// future drains the channel into `writer`; both run concurrently. When both finish without
/// error the connection halves are stored back into the originating `Conn`.
///
/// No `Limits` are taken, so neither a timeout nor a size cap is applied here, and `writer` is
/// written with `write_all` only (it is not flushed by this function).
///
/// # Errors
///
/// Returns an aggregate error if reading the body, forwarding a frame, or writing fails; in
/// that case the connection is not put back.
pub async fn receive_stream<
'a,
D: hyper::body::Buf + Send,
E: 'static + std::error::Error + Send + Sync,
B: 'static + http_body::Body<Data = D, Error = E>,
>(mut continue_send: ContinueSend<'a, B>, mut writer: impl Unpin + AsyncWrite) -> Result<(), loga::Error> {
// At most 10 frames are buffered between the reader and the writer.
let (chan_tx, mut chan_rx) = channel(10);
let work_read = {
// Borrow rather than move, so the connection halves can be moved out after the join below.
let continue_send = &mut continue_send;
async move {
loop {
let work = continue_send.body.frame();
// Poll the connection future alongside each frame read; if it completes first the read
// is treated as failed.
let frame = match select!{
_ =& mut continue_send.conn_bg => {
return Err(loga::err("Connection failed while reading body"));
}
r = work => r,
} {
Some(f) => {
f
.map_err(|e| loga::err_with("Error reading response", ea!(err = e)))?
.into_data()
.map_err(
|e| loga::err_with("Received unexpected non-data frame", ea!(err = e.dbg_str())),
)?
.to_vec()
},
// `None` means there are no more frames.
None => {
break;
},
};
chan_tx.send(frame).await.context("Error writing frame to channel for writer")?;
}
return Ok(()) as Result<(), loga::Error>;
}
};
let work_write = async move {
// Runs until the channel yields `None`.
while let Some(frame) = chan_rx.recv().await {
writer.write_all(&frame).await.context("Error sending frame to writer")?;
}
return Ok(()) as Result<(), loga::Error>;
};
let (read_res, write_res) = join!(work_read, work_write);
// Report both failures if both sides failed.
let mut errs = vec![];
if let Err(e) = read_res {
errs.push(e);
}
if let Err(e) = write_res {
errs.push(e);
}
if !errs.is_empty() {
return Err(loga::agg_err("Encountered errors while streaming response body", errs));
}
// Success: hand the connection back.
continue_send.conn.inner = Some((continue_send.conn_send, continue_send.conn_bg));
return Ok(());
}
/// Reads the whole response body held by `continue_send` into memory.
///
/// The body is capped at `limits.read_body_size` bytes and must arrive within
/// `limits.read_body_time`. On success the connection halves are stored back into the
/// originating `Conn`.
///
/// # Errors
///
/// Returns an error if the connection fails, reading fails (including exceeding the size
/// limit), or the timeout elapses; in those cases the connection is not put back.
pub async fn receive<
'a,
D: hyper::body::Buf + Send,
E: 'static + std::error::Error + Send + Sync,
B: 'static + http_body::Body<Data = D, Error = E>,
>(limits: Limits, mut continue_send: ContinueSend<'a, B>) -> Result<Vec<u8>, loga::Error> {
let read = async move {
// `Limited` enforces the size cap while the body is collected.
let work = Limited::new(continue_send.body, limits.read_body_size).collect();
// Poll the connection future alongside the read; if it completes first the read is treated
// as failed.
let body = select!{
_ =& mut continue_send.conn_bg => {
return Err(loga::err("Connection failed while reading body"));
}
r = work => r,
}.map_err(|e| loga::err_with("Error reading response", ea!(err = e)))?.to_bytes().to_vec();
// Return the connection halves along with the body so the caller can restore them.
return Ok((body, continue_send.conn_send, continue_send.conn_bg));
};
let (body, conn_send, conn_bg) = select!{
_ = sleep(limits.read_body_time) => {
return Err(loga::err("Timeout waiting for response from server"));
}
x = read => x ?,
};
// Success: hand the connection back.
continue_send.conn.inner = Some((conn_send, conn_bg));
return Ok(body);
}
pub async fn send_simple<
D: hyper::body::Buf + Send,
E: 'static + std::error::Error + Send + Sync,
B: 'static + http_body::Body<Data = D, Error = E>,
>(log: &Log, limits: Limits, conn: &mut Conn<B>, req: Request<B>) -> Result<Vec<u8>, loga::Error> {
let (code, _, continue_send) = send(log, limits, conn, req).await?;
let body = receive(limits, continue_send).await?;
if !code.is_success() {
return Err(
loga::err_with("Received invalid status in http response", ea!(body = String::from_utf8_lossy(&body))),
);
}
return Ok(body);
}
pub async fn post(
log: &Log,
limits: Limits,
conn: &mut Conn,
url: &Uri,
headers: &HashMap<String, String>,
body: Vec<u8>,
) -> Result<Vec<u8>, loga::Error> {
let req = Request::builder();
let mut req = req.method("POST").uri(url.clone());
for (k, v) in headers.iter() {
req = req.header(k, v);
}
log.log_with(
loga::DEBUG,
"Send",
ea!(method = "POST", url = url, headers = req.headers_ref().dbg_str(), body = String::from_utf8_lossy(&body)),
);
return Ok(
send_simple(log, limits, conn, req.body(Full::new(Bytes::from(body))).unwrap())
.await
.context_with("Error sending POST", ea!(url = url))?,
);
}
pub async fn post_json<
T: DeserializeOwned,
>(
log: &Log,
limits: Limits,
conn: &mut Conn,
url: &Uri,
headers: &HashMap<String, String>,
body: impl Serialize,
) -> Result<T, loga::Error> {
let res = post(log, limits, conn, url, headers, serde_json::to_vec(&body).unwrap()).await?;
return Ok(
serde_json::from_slice(
&res,
).context_with("Error deserializing response as json", ea!(url = url, body = String::from_utf8_lossy(&res)))?,
);
}
/// Builds a header map containing a single authorization header whose value is
/// `HEADER_BEARER_PREFIX` immediately followed by `token`.
pub fn auth_token_headers(token: &str) -> HashMap<String, String> {
    let value = format!("{}{}", HEADER_BEARER_PREFIX, token);
    HashMap::from([(AUTHORIZATION.to_string(), value)])
}
pub async fn get(
log: &Log,
limits: Limits,
conn: &mut Conn,
url: &Uri,
headers: &HashMap<String, String>,
) -> Result<Vec<u8>, loga::Error> {
let req = Request::builder();
const METHOD: &'static str = "GET";
let mut req = req.method(METHOD).uri(url.clone());
for (k, v) in headers.iter() {
req = req.header(k, v);
}
return Ok(
send_simple(log, limits, conn, req.body(Full::<Bytes>::new(Bytes::new())).unwrap())
.await
.context_with("Error sending GET", ea!(url = url))?,
);
}
/// Performs a `GET` (see `get`) and returns the response body as a UTF-8 string.
///
/// # Errors
///
/// Returns an error if the request fails or the body is not valid UTF-8; in the latter case
/// the error carries a lossy rendering of the body.
pub async fn get_text(
    log: &Log,
    limits: Limits,
    conn: &mut Conn,
    url: &Uri,
    headers: &HashMap<String, String>,
) -> Result<String, loga::Error> {
    let raw = get(log, limits, conn, url, headers).await?;
    match String::from_utf8(raw) {
        Ok(text) => Ok(text),
        Err(e) => Err(
            loga::err_with(
                "Received data isn't valid utf-8",
                ea!(err = e.to_string(), body = String::from_utf8_lossy(e.as_bytes())),
            ),
        ),
    }
}
/// Performs a `GET` (see `get`) and deserializes the JSON response body into `T`.
///
/// # Errors
///
/// Returns an error if the request fails or the body is not valid JSON for `T`; in the latter
/// case the error carries the URL and a lossy rendering of the body.
pub async fn get_json<T: DeserializeOwned>(
    log: &Log,
    limits: Limits,
    conn: &mut Conn,
    url: &Uri,
    headers: &HashMap<String, String>,
) -> Result<T, loga::Error> {
    let raw = get(log, limits, conn, url, headers).await?;
    let parsed =
        serde_json::from_slice(&raw).context_with(
            "Error deserializing response as json",
            ea!(url = url, body = String::from_utf8_lossy(&raw)),
        )?;
    Ok(parsed)
}
pub async fn delete(
log: &Log,
limits: Limits,
conn: &mut Conn,
url: &Uri,
headers: &HashMap<String, String>,
) -> Result<Vec<u8>, loga::Error> {
let req = Request::builder();
const METHOD: &'static str = "DELETE";
let mut req = req.method(METHOD).uri(url.clone());
for (k, v) in headers.iter() {
req = req.header(k, v);
}
return Ok(
send_simple(log, limits, conn, req.body(Full::<Bytes>::new(Bytes::new())).unwrap())
.await
.context_with("Error sending DELETE", ea!(url = url))?,
);
}