use core::error::Error;
use core::ops::{Deref, DerefMut};
use core::time::Duration;
use hyper::client::conn::http1;
use hyper_util::rt::TokioIo;
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, LazyLock};
use std::time::Instant;
use tokio::join;
use tokio::net::{TcpStream, ToSocketAddrs};
use tokio::sync::{Mutex, RwLock};
use tokio::task::{AbortHandle, JoinSet};
use tracing::{trace, warn};
use wrpc_interface_http::bindings::{
wasi::http::types::DnsErrorPayload, wrpc::http::types::ErrorCode,
};
/// Default idle timeout (90 seconds).
// NOTE(review): presumably the `timeout` passed to `ConnPool::evict`; confirm against callers.
pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(90);
/// Default connect timeout (10 minutes).
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10 * 60);
/// Default time-to-first-byte timeout (10 minutes).
pub const DEFAULT_FIRST_BYTE_TIMEOUT: Duration = Duration::from_secs(10 * 60);

/// Default `User-Agent` value: `<crate name>/<crate version>`, taken from Cargo metadata at
/// compile time.
pub const DEFAULT_USER_AGENT: &str =
    concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));

// Key names for certificate-loading options.
// NOTE(review): their consumers live outside this file; semantics inferred from the names only.
/// Key enabling loading of the platform's native root certificates.
pub const LOAD_NATIVE_CERTS: &str = "load_native_certs";
/// Key enabling loading of the bundled WebPKI root certificates.
pub const LOAD_WEBPKI_CERTS: &str = "load_webpki_certs";
/// Key naming a file of additional certificates.
pub const SSL_CERTS_FILE: &str = "ssl_certs_file";

/// Reference instant, captured lazily the first time it is read.
///
/// [`PooledConn::new`] uses it as the initial `last_seen` of a fresh connection.
pub static ZERO_INSTANT: LazyLock<Instant> = LazyLock::new(Instant::now);
#[derive(Clone, Debug)]
pub struct PooledConn<T> {
pub sender: T,
pub abort: AbortHandle,
pub last_seen: Instant,
}
impl<T> Deref for PooledConn<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.sender
}
}
impl<T> DerefMut for PooledConn<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.sender
}
}
impl<T: PartialEq> PartialEq for PooledConn<T> {
fn eq(
&self,
Self {
sender,
abort,
last_seen,
}: &Self,
) -> bool {
self.sender == *sender && self.abort.id() == abort.id() && self.last_seen == *last_seen
}
}
impl<T> Drop for PooledConn<T> {
fn drop(&mut self) {
self.abort.abort();
}
}
impl<T> PooledConn<T> {
pub fn new(sender: T, abort: AbortHandle) -> Self {
Self {
sender,
abort,
last_seen: *ZERO_INSTANT,
}
}
}
/// Authority-keyed table of idle HTTP/1 connections.
///
/// The outer `RwLock` guards the map itself. Each authority's queue sits behind its own
/// `std::sync::Mutex`, so lookups for different authorities only need the shared read lock.
/// `connect_http`/`connect_https` take connections from the front of the queue.
pub type ConnPoolTable<T> = RwLock<
    HashMap<Box<str>, std::sync::Mutex<VecDeque<PooledConn<http1::SendRequest<T>>>>>,
>;

/// Shared pool of idle HTTP and HTTPS connections plus the tasks driving them.
///
/// Every field is an `Arc`, so cloning a `ConnPool` is cheap and clones share the same state.
#[derive(Debug)]
pub struct ConnPool<T> {
    /// Idle plain-HTTP connections, keyed by authority.
    pub http: Arc<ConnPoolTable<T>>,
    /// Idle HTTPS connections, keyed by authority.
    pub https: Arc<ConnPoolTable<T>>,
    /// Tasks spawned to drive each connection created by this pool.
    pub tasks: Arc<Mutex<JoinSet<()>>>,
}

impl<T> Default for ConnPool<T> {
    /// Creates an empty pool with no connections and no tasks.
    fn default() -> Self {
        Self {
            http: Default::default(),
            https: Default::default(),
            tasks: Default::default(),
        }
    }
}

impl<T> Clone for ConnPool<T> {
    /// Returns a handle sharing the same tables and task set (no deep copy, no `T: Clone`).
    fn clone(&self) -> Self {
        Self {
            http: Arc::clone(&self.http),
            https: Arc::clone(&self.https),
            tasks: Arc::clone(&self.tasks),
        }
    }
}
pub fn evict_conns<T>(
cutoff: Instant,
conns: &mut HashMap<Box<str>, std::sync::Mutex<VecDeque<PooledConn<T>>>>,
) {
trace!(target: "http_client::evict", ?cutoff, total_authorities=conns.len(), "evicting connections older than cutoff");
let mut total_evicted = 0;
conns.retain(|authority, conns| {
let Ok(conns) = conns.get_mut() else {
trace!(target: "http_client::evict", %authority, "skipping locked connection pool");
return true;
};
let total_conns = conns.len();
let idx = conns.partition_point(|&PooledConn { last_seen, .. }| last_seen <= cutoff);
if idx == conns.len() {
trace!(target: "http_client::evict", %authority, evicted=total_conns, "evicting all connections");
total_evicted += total_conns;
false
} else if idx == 0 {
trace!(target: "http_client::evict", %authority, total=total_conns, "no connections to evict");
true
} else {
trace!(target: "http_client::evict", %authority, evicted=idx, remaining=(total_conns - idx), "partially evicting connections");
conns.rotate_left(idx);
conns.truncate(total_conns - idx);
total_evicted += idx;
true
}
});
trace!(target: "http_client::evict", total_evicted, remaining_authorities=conns.len(), "connection eviction complete");
}
impl<T> ConnPool<T> {
pub async fn evict(&self, timeout: Duration) {
let Some(cutoff) = Instant::now().checked_sub(timeout) else {
return;
};
join!(
async {
let mut conns = self.http.write().await;
evict_conns(cutoff, &mut conns);
},
async {
let mut conns = self.https.write().await;
evict_conns(cutoff, &mut conns);
}
);
}
#[allow(dead_code)]
pub async fn connect_http(
&self,
authority: &str,
) -> Result<Cacheable<PooledConn<http1::SendRequest<T>>>, ErrorCode>
where
T: http_body::Body + Send + 'static,
T::Data: Send,
T::Error: Into<Box<dyn Error + Send + Sync>>,
{
trace!(target: "http_client::connect_http", authority, "attempting HTTP connection");
{
let http = self.http.read().await;
if let Some(conns) = http.get(authority) {
if let Ok(mut conns) = conns.lock() {
trace!(target: "http_client::connect_http", authority, cached_connections=conns.len(), "checking cached HTTP connections");
while let Some(conn) = conns.pop_front() {
trace!(target: "http_client::connect_http", authority, "found cached HTTP connection");
if !conn.is_closed() && conn.is_ready() {
trace!(target: "http_client::connect_http", authority, "returning HTTP connection cache hit");
return Ok(Cacheable::Hit(conn));
} else {
trace!(target: "http_client::connect_http", authority, is_closed=conn.is_closed(), is_ready=conn.is_ready(), "discarding unusable cached HTTP connection");
}
}
}
}
}
trace!(target: "http_client::connect_http", authority, "establishing new TCP connection");
let stream = connect(authority).await?;
trace!(target: "http_client::connect_http", authority, "starting HTTP handshake");
let (sender, conn) = http1::handshake(TokioIo::new(stream))
.await
.map_err(|err| {
warn!(target: "http_client::connect_http", error=?err, authority, "HTTP handshake failed");
hyper_request_error(err)
})?;
let tasks = Arc::clone(&self.tasks);
let authority_clone = authority.to_string();
let abort = tasks.lock().await.spawn(async move {
match conn.await {
Ok(()) => trace!(target: "http_client::connect_http", authority=authority_clone, "HTTP connection closed successfully"),
Err(err) => warn!(target: "http_client::connect_http", ?err, authority=authority_clone, "HTTP connection closed with error"),
}
});
trace!(target: "http_client::connect_http", authority, "returning HTTP connection cache miss");
Ok(Cacheable::Miss(PooledConn::new(sender, abort)))
}
#[cfg(any(target_arch = "riscv64", target_arch = "s390x"))]
pub async fn connect_https(
&self,
_tls: &tokio_rustls::TlsConnector,
_authority: &str,
) -> Result<Cacheable<PooledConn<http1::SendRequest<T>>>, ErrorCode>
where
T: http_body::Body + Send + 'static,
T::Data: Send,
T::Error: Into<Box<dyn Error + Send + Sync>>,
{
Err(ErrorCode::InternalError(Some(
"HTTPS connections are not supported on this architecture".to_string(),
)))
}
#[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))]
pub async fn connect_https(
&self,
tls: &tokio_rustls::TlsConnector,
authority: &str,
) -> Result<Cacheable<PooledConn<http1::SendRequest<T>>>, ErrorCode>
where
T: http_body::Body + Send + 'static,
T::Data: Send,
T::Error: Into<Box<dyn Error + Send + Sync>>,
{
use rustls::pki_types::ServerName;
trace!(target: "http_client::connect_https", authority, "attempting HTTPS connection");
{
let https = self.https.read().await;
if let Some(conns) = https.get(authority) {
if let Ok(mut conns) = conns.lock() {
trace!(target: "http_client::connect_https", authority, cached_connections=conns.len(), "checking cached HTTPS connections");
while let Some(conn) = conns.pop_front() {
trace!(target: "http_client::connect_https", authority, "found cached HTTPS connection");
if !conn.is_closed() && conn.is_ready() {
trace!(target: "http_client::connect_https", authority, "returning HTTPS connection cache hit");
return Ok(Cacheable::Hit(conn));
} else {
trace!(target: "http_client::connect_https", authority, is_closed=conn.is_closed(), is_ready=conn.is_ready(), "discarding unusable cached HTTPS connection");
}
}
}
}
}
trace!(target: "http_client::connect_https", authority, "establishing new TCP connection");
let stream = connect(authority).await?;
let mut parts = authority.split(":");
let host = parts.next().unwrap_or(authority);
trace!(target: "http_client::connect_https", authority, host, "resolving server name for TLS");
let domain = ServerName::try_from(host)
.map_err(|err| {
warn!(target: "http_client::connect_https", ?err, authority, host, "invalid DNS name for TLS");
dns_error("invalid DNS name".to_string(), 0)
})?
.to_owned();
trace!(target: "http_client::connect_https", authority, host, "starting TLS handshake");
let stream = tls.connect(domain, stream).await.map_err(|err| {
warn!(target: "http_client::connect_https", ?err, authority, host, "TLS handshake failed");
ErrorCode::TlsProtocolError
})?;
trace!(target: "http_client::connect_https", authority, "starting HTTP handshake over TLS");
let (sender, conn) = http1::handshake(TokioIo::new(stream))
.await
.map_err(|err| {
warn!(target: "http_client::connect_https", error=?err, authority, "HTTP handshake failed over TLS");
hyper_request_error(err)
})?;
let tasks = Arc::clone(&self.tasks);
let authority_clone = authority.to_string();
let abort = tasks.lock().await.spawn(async move {
match conn.await {
Ok(()) => trace!(target: "http_client::connect_https", authority=authority_clone, "HTTPS connection closed successfully"),
Err(err) => warn!(target: "http_client::connect_https", ?err, authority=authority_clone, "HTTPS connection closed with error"),
}
});
trace!(target: "http_client::connect_https", authority, "returning HTTPS connection cache miss");
Ok(Cacheable::Miss(PooledConn::new(sender, abort)))
}
}
/// Result of a pool lookup, tagged with where the value came from.
///
/// Both variants dereference to the wrapped value, so callers that do not care about the
/// provenance can use it directly.
#[derive(Debug)]
pub enum Cacheable<T> {
    /// Freshly created; not taken from the cache.
    Miss(T),
    /// Taken from the cache.
    Hit(T),
}

impl<T> Deref for Cacheable<T> {
    type Target = T;

    /// Borrows the wrapped value regardless of variant.
    fn deref(&self) -> &Self::Target {
        match self {
            Self::Miss(v) | Self::Hit(v) => v,
        }
    }
}

impl<T> DerefMut for Cacheable<T> {
    /// Mutably borrows the wrapped value regardless of variant.
    fn deref_mut(&mut self) -> &mut Self::Target {
        match self {
            Self::Miss(v) | Self::Hit(v) => v,
        }
    }
}
impl<T> Cacheable<T> {
    /// Consumes the wrapper and returns the inner value, emitting a trace event that says
    /// whether it was a cache hit or a cache miss.
    // NOTE(review): the reason for `dead_code` isn't visible here; presumably only some build
    // configurations call this.
    #[allow(dead_code)]
    pub fn unwrap(self) -> T {
        let (value, from_cache) = match self {
            Self::Miss(value) => (value, false),
            Self::Hit(value) => (value, true),
        };
        if from_cache {
            trace!(target: "http_client::cache", "unwrapping cache hit");
        } else {
            trace!(target: "http_client::cache", "unwrapping cache miss");
        }
        value
    }
}
/// Builds an [`ErrorCode::DnsError`] whose payload carries `rcode` and `info_code`, both
/// always set to `Some`.
fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
    let payload = DnsErrorPayload {
        info_code: Some(info_code),
        rcode: Some(rcode),
    };
    ErrorCode::DnsError(payload)
}
async fn connect(addr: impl ToSocketAddrs) -> Result<TcpStream, ErrorCode> {
trace!(target: "http_client::connect", "attempting TCP connection");
match TcpStream::connect(addr).await {
Ok(stream) => {
trace!(target: "http_client::connect", "TCP connection established successfully");
Ok(stream)
}
Err(err) if err.kind() == std::io::ErrorKind::AddrNotAvailable => {
warn!(target: "http_client::connect", error=?err, "address not available");
Err(dns_error("address not available".to_string(), 0))
}
Err(err) => {
if err
.to_string()
.starts_with("failed to lookup address information")
{
warn!(target: "http_client::connect", error=?err, "DNS lookup failed");
Err(dns_error("address not available".to_string(), 0))
} else {
warn!(target: "http_client::connect", error=?err, "connection refused");
Err(ErrorCode::ConnectionRefused)
}
}
}
}
/// Maps a [`hyper::Error`] onto a wRPC HTTP [`ErrorCode`].
///
/// Only the error's immediate `source()` is inspected. If it is a [`std::io::Error`] of kind
/// `ConnectionRefused`, `ConnectionReset` or `TimedOut`, the matching connection error code
/// is returned without logging. Every other case is logged at `warn` level and reported as
/// [`ErrorCode::HttpProtocolError`].
pub fn hyper_request_error(err: hyper::Error) -> ErrorCode {
    let Some(cause) = err.source() else {
        warn!(
            target: "http_client::error",
            error=?err,
            "HTTP request failed"
        );
        return ErrorCode::HttpProtocolError;
    };
    let io_kind = cause
        .downcast_ref::<std::io::Error>()
        .map(std::io::Error::kind);
    match io_kind {
        Some(std::io::ErrorKind::ConnectionRefused) => ErrorCode::ConnectionRefused,
        Some(std::io::ErrorKind::ConnectionReset) => ErrorCode::ConnectionTerminated,
        Some(std::io::ErrorKind::TimedOut) => ErrorCode::ConnectionTimeout,
        _ => {
            warn!(
                target: "http_client::error",
                error=?err,
                cause=?cause,
                "HTTP request failed with underlying cause"
            );
            ErrorCode::HttpProtocolError
        }
    }
}
#[cfg(test)]
mod tests {
    use core::net::Ipv4Addr;
    use std::collections::{HashMap, VecDeque};
    use std::time::Instant;
    use anyhow::Context as _;
    use bytes::Bytes;
    use hyper_util::rt::TokioIo;
    use tokio::net::TcpListener;
    use tokio::spawn;
    use tokio::try_join;
    use super::*;
    use wrpc_interface_http::HttpBody;

    /// Number of connections `test_pool_evict` opens against its local server.
    const N: usize = 20;

    /// Returns the instant `secs` seconds before `now`.
    fn secs_before(now: Instant, secs: u64) -> Instant {
        now.checked_sub(Duration::from_secs(secs))
            .expect("time subtraction should not overflow")
    }

    /// Returns the instant `secs` seconds after `now`.
    fn secs_after(now: Instant, secs: u64) -> Instant {
        now.checked_add(Duration::from_secs(secs))
            .expect("time addition should not overflow")
    }

    /// Builds a sender-less pooled connection last seen at `last_seen`, backed by a no-op
    /// task. Must run inside a Tokio runtime because it spawns that task.
    fn idle_conn(last_seen: Instant) -> PooledConn<()> {
        PooledConn {
            sender: (),
            abort: spawn(async {}).abort_handle(),
            last_seen,
        }
    }

    /// Turns `instants` into a queue of [`idle_conn`]s, preserving order.
    fn idle_conns(instants: impl IntoIterator<Item = Instant>) -> VecDeque<PooledConn<()>> {
        instants.into_iter().map(idle_conn).collect()
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_conn_evict() -> anyhow::Result<()> {
        let now = Instant::now();
        // Sorted by `last_seen`; the first three are at or before the `now` cutoff.
        let mut foo = idle_conns([
            secs_before(now, 10),
            secs_before(now, 1),
            now,
            secs_after(now, 1),
            secs_after(now, 1),
            secs_after(now, 3),
        ]);
        // Entirely newer than the cutoff: must survive untouched.
        let qux = idle_conns([secs_after(now, 10), secs_after(now, 12)]);
        let mut conns = HashMap::from([
            ("foo".into(), std::sync::Mutex::new(foo.clone())),
            // Already empty: the authority must be dropped.
            ("bar".into(), std::sync::Mutex::default()),
            // Entirely older than the cutoff: the authority must be dropped.
            (
                "baz".into(),
                std::sync::Mutex::new(idle_conns([secs_before(now, 10), secs_before(now, 1)])),
            ),
            ("qux".into(), std::sync::Mutex::new(qux.clone())),
        ]);
        evict_conns(now, &mut conns);

        let remaining_foo = conns
            .remove("foo")
            .expect("foo should exist")
            .into_inner()
            .expect("mutex should be unlocked");
        assert_eq!(remaining_foo, foo.split_off(3));

        let remaining_qux = conns
            .remove("qux")
            .expect("qux should exist")
            .into_inner()
            .expect("mutex should be unlocked");
        assert_eq!(remaining_qux, qux);

        assert!(conns.is_empty());
        // Evicting from an empty table is a no-op.
        evict_conns(now, &mut conns);
        assert!(conns.is_empty());
        Ok(())
    }

    #[cfg(feature = "http")]
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_pool_evict() -> anyhow::Result<()> {
        eprintln!("Starting test_pool_evict");
        const IDLE_TIMEOUT: Duration = Duration::from_millis(10);
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await?;
        let addr = listener.local_addr()?;
        eprintln!("Test server bound to {addr}");

        // Server side: accept and serve exactly `N` connections, one after another.
        let server = async {
            eprintln!("Server task starting, will accept {N} connections");
            for i in 0..N {
                let ord = i + 1;
                eprintln!("[{ord}/{N}] Waiting to accept connection...");
                let (stream, _) = listener
                    .accept()
                    .await
                    .with_context(|| format!("failed to accept connection {i}"))?;
                eprintln!("[{ord}/{N}] Connection accepted, serving...");
                let service = hyper::service::service_fn(move |_| async {
                    anyhow::Ok(http::Response::new(http_body_util::Empty::<Bytes>::new()))
                });
                hyper::server::conn::http1::Builder::new()
                    .serve_connection(TokioIo::new(stream), service)
                    .await
                    .with_context(|| format!("failed to serve connection {i}"))?;
                eprintln!("[{ord}/{N}] Connection served and closed");
            }
            eprintln!("Server task completed all {N} connections");
            anyhow::Ok(())
        };

        // Client side: fill the pool with already-stale connections, then evict them.
        let client = async {
            eprintln!("Client task starting");
            let pool = ConnPool::<HttpBody>::default();
            let now = Instant::now();
            let authority: Box<str> = addr.to_string().into();
            eprintln!(" Creating {N} connections to server at {addr}");
            {
                let mut http_conns = pool.http.write().await;
                let mut connections = VecDeque::with_capacity(N);
                for i in 0..N {
                    let ord = i + 1;
                    eprintln!("[{ord}/{N}] Establishing handshake...");
                    let stream = TcpStream::connect(addr).await?;
                    let (sender, _) = http1::handshake(TokioIo::new(stream)).await?;
                    eprintln!("[{ord}/{N}] Handshake completed");
                    connections.push_back(PooledConn {
                        sender,
                        abort: spawn(async {}).abort_handle(),
                        last_seen: secs_before(now, 10),
                    });
                }
                http_conns.insert(authority.clone(), std::sync::Mutex::new(connections));
                eprintln!("All {N} connections added to pool");
            }
            eprintln!("Sleeping for a bit to let connections age...");
            tokio::time::sleep(Duration::from_millis(20)).await;
            eprintln!("Starting eviction process...");
            pool.evict(IDLE_TIMEOUT).await;
            eprintln!("Eviction completed");
            eprintln!("Verifying connections were evicted...");
            let result = pool.http.read().await.get(&*authority).is_none();
            eprintln!("Eviction verification result: authority removed = {result}");
            assert!(result);
            eprintln!("Client task completed successfully");
            anyhow::Ok(())
        };

        try_join!(server, client)?;
        Ok(())
    }
}