use nexus_net::http::ResponseReader;
use nexus_net::rest::{RequestWriter, RestError};
#[cfg(feature = "tls")]
use nexus_net::tls::TlsConfig;
use nexus_pool::local::{Pool, Pooled};
use super::connection::{AsyncHttpConnection, AsyncHttpConnectionBuilder};
use crate::maybe_tls::MaybeTls;
/// One pooled HTTP client: a request writer, a response reader, and the
/// connection they are used with.
///
/// `conn` is `None` when the slot has no connection (slots produced by the
/// pool factory start that way, and the pool's reset hook clears a poisoned
/// connection); such a slot reports `needs_reconnect() == true`.
pub struct ClientSlot {
/// Builds outgoing requests for this slot.
pub writer: RequestWriter,
/// Parses responses received on `conn`.
pub reader: ResponseReader,
/// The live connection, or `None` if the slot must be reconnected.
pub conn: Option<AsyncHttpConnection<MaybeTls>>,
}
impl ClientSlot {
    /// Reports whether this slot must be reconnected before it can send a
    /// request: true when there is no connection or the connection is poisoned.
    pub fn needs_reconnect(&self) -> bool {
        match &self.conn {
            Some(conn) => conn.is_poisoned(),
            None => true,
        }
    }

    /// Borrows the connection and the response reader at the same time, so the
    /// reader can be passed into a send on the connection without borrowing the
    /// whole slot mutably twice.
    ///
    /// # Errors
    ///
    /// Returns [`RestError::ConnectionPoisoned`] when the slot has no connection.
    pub fn conn_and_reader(
        &mut self,
    ) -> Result<(&mut AsyncHttpConnection<MaybeTls>, &mut ResponseReader), RestError> {
        let Some(conn) = self.conn.as_mut() else {
            return Err(RestError::ConnectionPoisoned);
        };
        Ok((conn, &mut self.reader))
    }
}
/// A pool of keep-alive HTTP connections to a single base URL.
///
/// Built on `nexus_pool::local::Pool`. Unhealthy slots are reconnected by
/// tasks spawned with `tokio::task::spawn_local`, so the pool is meant to be
/// used from within a `tokio::task::LocalSet`. Create one with
/// [`ClientPool::builder`].
pub struct ClientPool {
pool: Pool<ClientSlot>,
// Settings used by background reconnect tasks to open replacement connections.
reconnect_config: ReconnectConfig,
}
/// Connection settings captured from [`ClientPoolBuilder`] at build time.
///
/// A copy is cloned into each background reconnect task so it can open a
/// connection with the same URL, TLS, and socket options as the initial ones.
#[derive(Clone)]
struct ReconnectConfig {
// URL passed to `AsyncHttpConnectionBuilder::connect`.
url: String,
#[cfg(feature = "tls")]
tls_config: Option<TlsConfig>,
// When true, `disable_nagle()` is applied to the connection builder.
nodelay: bool,
#[cfg(feature = "socket-opts")]
tcp_keepalive: Option<std::time::Duration>,
#[cfg(feature = "socket-opts")]
recv_buf_size: Option<usize>,
#[cfg(feature = "socket-opts")]
send_buf_size: Option<usize>,
}
// The pool is `nexus_pool::local` and reconnects run via `spawn_local`, so the
// async methods here are intentionally single-threaded; their futures are not
// expected to be `Send`.
#[allow(clippy::future_not_send)]
impl ClientPool {
    /// Returns a [`ClientPoolBuilder`] with default settings.
    #[must_use]
    pub fn builder() -> ClientPoolBuilder {
        ClientPoolBuilder::new()
    }

    /// Takes a slot with a live connection without waiting.
    ///
    /// Each slot that is disconnected or poisoned is handed to a background
    /// reconnect task, which releases it back to the pool once reconnected, and
    /// the next idle slot is tried. Returns `None` when no healthy slot is idle
    /// right now.
    ///
    /// # Panics
    ///
    /// Reconnect tasks are started with [`tokio::task::spawn_local`], which
    /// panics when called outside a `tokio::task::LocalSet`. This can only
    /// happen when an unhealthy slot is encountered.
    #[must_use]
    pub fn try_acquire(&self) -> Option<Pooled<ClientSlot>> {
        loop {
            let slot = self.pool.try_acquire()?;
            if !slot.needs_reconnect() {
                return Some(slot);
            }
            self.spawn_reconnect(slot);
        }
    }

    /// Waits for a slot with a live connection, polling with exponential backoff.
    ///
    /// Calls [`try_acquire`](Self::try_acquire) up to `MAX_ATTEMPTS` times,
    /// sleeping between attempts (starting at 1 ms, doubling, capped at 1 s).
    ///
    /// # Errors
    ///
    /// Returns [`RestError::ConnectionClosed`] if no healthy slot became
    /// available within the attempt budget.
    ///
    /// # Panics
    ///
    /// Same condition as [`try_acquire`](Self::try_acquire).
    pub async fn acquire(&self) -> Result<Pooled<ClientSlot>, RestError> {
        const MAX_BACKOFF_MS: u64 = 1_000;
        const MAX_ATTEMPTS: u32 = 20;
        let mut backoff_ms = 1u64;
        for attempt in 1..=MAX_ATTEMPTS {
            if let Some(slot) = self.try_acquire() {
                return Ok(slot);
            }
            // Do not sleep after the final failed attempt: nothing would observe
            // a slot freed during that sleep, it would only delay the error.
            if attempt < MAX_ATTEMPTS {
                tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
                backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS);
            }
        }
        Err(RestError::ConnectionClosed(
            "pool acquire timed out: no healthy slots available",
        ))
    }

    /// Number of slots currently idle in the underlying pool.
    ///
    /// Checked-out slots, including those held by a reconnect task, are not
    /// counted.
    #[must_use]
    pub fn available(&self) -> usize {
        self.pool.available()
    }

    /// Spawns a local task that reconnects `slot` and then releases it.
    ///
    /// Retries with exponential backoff (starting at 100 ms, doubling, capped
    /// at 5 s). On success the new connection is installed, the reader is reset,
    /// and the slot returns to the pool when the task drops it.
    ///
    /// NOTE(review): the retry loop has no exit other than success, so while the
    /// endpoint is unreachable the task keeps retrying and holds the slot.
    fn spawn_reconnect(&self, mut slot: Pooled<ClientSlot>) {
        let config = self.reconnect_config.clone();
        tokio::task::spawn_local(async move {
            const MAX_BACKOFF_MS: u64 = 5_000;
            let mut backoff_ms = 100u64;
            loop {
                if let Ok(conn) = Self::connect_one_with(&config).await {
                    slot.conn = Some(conn);
                    slot.reader.reset();
                    return;
                }
                tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
                backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS);
            }
        });
    }

    /// Opens one connection to `config.url` with the configured TLS and socket
    /// options applied to the connection builder.
    ///
    /// # Errors
    ///
    /// Propagates the error from `AsyncHttpConnectionBuilder::connect`.
    async fn connect_one_with(
        config: &ReconnectConfig,
    ) -> Result<AsyncHttpConnection<MaybeTls>, RestError> {
        let mut builder = AsyncHttpConnectionBuilder::new();
        #[cfg(feature = "tls")]
        if let Some(ref tls) = config.tls_config {
            builder = builder.tls(tls);
        }
        if config.nodelay {
            builder = builder.disable_nagle();
        }
        #[cfg(feature = "socket-opts")]
        {
            if let Some(idle) = config.tcp_keepalive {
                builder = builder.tcp_keepalive(idle);
            }
            if let Some(size) = config.recv_buf_size {
                builder = builder.recv_buffer_size(size);
            }
            if let Some(size) = config.send_buf_size {
                builder = builder.send_buffer_size(size);
            }
        }
        builder.connect(&config.url).await
    }
}
/// Builder for [`ClientPool`]; obtain one with [`ClientPool::builder`] or
/// [`ClientPoolBuilder::new`], configure it, then call `build().await`.
pub struct ClientPoolBuilder {
// Base URL to connect to; required (build rejects an empty URL).
url: String,
// Path prefix handed to each `RequestWriter` when non-empty.
base_path: String,
// Headers added to every request via `RequestWriter::default_header`.
default_headers: Vec<(String, String)>,
// Number of connections opened by `build`; must be > 0.
connections: usize,
#[cfg(feature = "tls")]
tls_config: Option<TlsConfig>,
// When true, connections are built with `disable_nagle()`.
nodelay: bool,
#[cfg(feature = "socket-opts")]
tcp_keepalive: Option<std::time::Duration>,
#[cfg(feature = "socket-opts")]
recv_buf_size: Option<usize>,
#[cfg(feature = "socket-opts")]
send_buf_size: Option<usize>,
// Passed to `RequestWriter::set_write_buffer_capacity`.
write_buffer_capacity: usize,
// Passed to `ResponseReader::new`.
response_buffer_capacity: usize,
// Passed to `ResponseReader::max_body_size`.
max_body_size: usize,
}
impl ClientPoolBuilder {
#[must_use]
pub fn new() -> Self {
Self {
url: String::new(),
base_path: String::new(),
default_headers: Vec::new(),
connections: 1,
#[cfg(feature = "tls")]
tls_config: None,
nodelay: false,
#[cfg(feature = "socket-opts")]
tcp_keepalive: None,
#[cfg(feature = "socket-opts")]
recv_buf_size: None,
#[cfg(feature = "socket-opts")]
send_buf_size: None,
write_buffer_capacity: 32 * 1024,
response_buffer_capacity: 32 * 1024,
max_body_size: 0,
}
}
#[must_use]
pub fn url(mut self, url: &str) -> Self {
self.url = url.to_string();
self
}
#[must_use]
pub fn base_path(mut self, path: &str) -> Self {
self.base_path = path.to_string();
self
}
pub fn default_header(mut self, name: &str, value: &str) -> Result<Self, RestError> {
if name.bytes().any(|b| b == b'\r' || b == b'\n')
|| value.bytes().any(|b| b == b'\r' || b == b'\n')
{
return Err(RestError::CrlfInjection);
}
self.default_headers
.push((name.to_string(), value.to_string()));
Ok(self)
}
#[must_use]
pub fn connections(mut self, n: usize) -> Self {
self.connections = n;
self
}
#[must_use]
#[cfg(feature = "tls")]
pub fn tls(mut self, config: &TlsConfig) -> Self {
self.tls_config = Some(config.clone());
self
}
#[must_use]
pub fn disable_nagle(mut self) -> Self {
self.nodelay = true;
self
}
#[cfg(feature = "socket-opts")]
#[must_use]
pub fn tcp_keepalive(mut self, idle: std::time::Duration) -> Self {
self.tcp_keepalive = Some(idle);
self
}
#[cfg(feature = "socket-opts")]
#[must_use]
pub fn recv_buffer_size(mut self, n: usize) -> Self {
self.recv_buf_size = Some(n);
self
}
#[cfg(feature = "socket-opts")]
#[must_use]
pub fn send_buffer_size(mut self, n: usize) -> Self {
self.send_buf_size = Some(n);
self
}
#[must_use]
pub fn write_buffer_capacity(mut self, n: usize) -> Self {
self.write_buffer_capacity = n;
self
}
#[must_use]
pub fn response_buffer_capacity(mut self, n: usize) -> Self {
self.response_buffer_capacity = n;
self
}
#[must_use]
pub fn max_body_size(mut self, n: usize) -> Self {
self.max_body_size = n;
self
}
pub async fn build(self) -> Result<ClientPool, RestError> {
if self.url.is_empty() {
return Err(RestError::InvalidUrl("url is required".to_string()));
}
if self.connections == 0 {
return Err(RestError::InvalidUrl("connections must be > 0".to_string()));
}
let parsed = nexus_net::rest::parse_base_url(&self.url)?;
let host_header = parsed.host_header();
let reconnect_config = ReconnectConfig {
url: self.url.clone(),
#[cfg(feature = "tls")]
tls_config: self.tls_config.clone(),
nodelay: self.nodelay,
#[cfg(feature = "socket-opts")]
tcp_keepalive: self.tcp_keepalive,
#[cfg(feature = "socket-opts")]
recv_buf_size: self.recv_buf_size,
#[cfg(feature = "socket-opts")]
send_buf_size: self.send_buf_size,
};
let mut initial_slots = Vec::with_capacity(self.connections);
for _ in 0..self.connections {
let mut builder = AsyncHttpConnectionBuilder::new();
#[cfg(feature = "tls")]
if let Some(ref tls) = self.tls_config {
builder = builder.tls(tls);
}
if self.nodelay {
builder = builder.disable_nagle();
}
#[cfg(feature = "socket-opts")]
{
if let Some(idle) = self.tcp_keepalive {
builder = builder.tcp_keepalive(idle);
}
if let Some(size) = self.recv_buf_size {
builder = builder.recv_buffer_size(size);
}
if let Some(size) = self.send_buf_size {
builder = builder.send_buffer_size(size);
}
}
let conn = builder.connect(&self.url).await?;
let mut writer = RequestWriter::new(&host_header)?;
if !self.base_path.is_empty() {
writer.set_base_path(&self.base_path)?;
}
writer.set_write_buffer_capacity(self.write_buffer_capacity);
for (name, value) in &self.default_headers {
writer.default_header(name, value)?;
}
let reader = ResponseReader::new(self.response_buffer_capacity)
.max_body_size(self.max_body_size);
initial_slots.push(ClientSlot {
writer,
reader,
conn: Some(conn),
});
}
let host = host_header.clone();
let base = self.base_path.clone();
let headers = self.default_headers.clone();
let wbuf_cap = self.write_buffer_capacity;
let rbuf_cap = self.response_buffer_capacity;
let max_body = self.max_body_size;
let pool = Pool::new(
move || {
let mut writer = RequestWriter::new(&host).expect("host already validated");
if !base.is_empty() {
writer
.set_base_path(&base)
.expect("base_path already validated");
}
writer.set_write_buffer_capacity(wbuf_cap);
for (name, value) in &headers {
writer
.default_header(name, value)
.expect("headers already validated");
}
ClientSlot {
writer,
reader: ResponseReader::new(rbuf_cap).max_body_size(max_body),
conn: None,
}
},
|slot: &mut ClientSlot| {
if slot.needs_reconnect() {
slot.conn = None;
slot.reader.reset();
}
},
);
for slot in initial_slots {
pool.put(slot);
}
Ok(ClientPool {
pool,
reconnect_config,
})
}
}
impl Default for ClientPoolBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Builds a slot with no connection, so `needs_reconnect()` is true.
fn make_disconnected_slot() -> ClientSlot {
ClientSlot {
writer: RequestWriter::new("host").unwrap(),
reader: ResponseReader::new(4096),
conn: None,
}
}
/// A slot with `conn == None` must report that it needs reconnecting.
#[test]
fn slot_needs_reconnect_when_no_conn() {
let slot = make_disconnected_slot();
assert!(slot.needs_reconnect());
}
/// Acquiring removes a slot from the idle count; dropping it returns it.
#[test]
fn pool_acquire_release_cycle() {
let pool = Pool::new(make_disconnected_slot, |_| {});
pool.put(make_disconnected_slot());
assert_eq!(pool.available(), 1);
let slot = pool.acquire();
assert_eq!(pool.available(), 0);
drop(slot);
assert_eq!(pool.available(), 1);
}
/// Each acquire decrements `available()` while slots are held.
#[test]
fn pool_acquire_returns_available() {
let pool: Pool<ClientSlot> = Pool::new(make_disconnected_slot, |_| {});
pool.put(make_disconnected_slot());
pool.put(make_disconnected_slot());
assert_eq!(pool.available(), 2);
let _s1 = pool.acquire();
assert_eq!(pool.available(), 1);
let _s2 = pool.acquire();
assert_eq!(pool.available(), 0);
}
/// A reset hook mirroring the production one leaves a dead slot disconnected
/// across a release/acquire cycle.
#[test]
fn pool_reset_clears_dead_conn() {
let pool = Pool::new(make_disconnected_slot, |slot| {
if slot.needs_reconnect() {
slot.conn = None;
}
});
pool.put(make_disconnected_slot());
let slot = pool.acquire();
assert!(slot.conn.is_none());
assert!(slot.needs_reconnect());
drop(slot);
let slot = pool.acquire();
assert!(slot.conn.is_none());
}
/// Idle count tracks multiple slots being acquired and released.
#[test]
fn pool_multiple_slots() {
let pool = Pool::new(make_disconnected_slot, |_| {});
for _ in 0..4 {
pool.put(make_disconnected_slot());
}
assert_eq!(pool.available(), 4);
let s1 = pool.acquire();
let s2 = pool.acquire();
assert_eq!(pool.available(), 2);
drop(s1);
assert_eq!(pool.available(), 3);
drop(s2);
assert_eq!(pool.available(), 4);
}
/// `try_acquire` on the raw pool returns `None` once every slot is checked out.
#[test]
fn try_acquire_returns_none_when_all_in_use() {
let pool = Pool::new(make_disconnected_slot, |_| {});
pool.put(make_disconnected_slot());
pool.put(make_disconnected_slot());
let s1 = pool.try_acquire().unwrap();
let s2 = pool.try_acquire().unwrap();
assert!(pool.try_acquire().is_none());
assert_eq!(pool.available(), 0);
drop(s1);
drop(s2);
assert_eq!(pool.available(), 2);
assert!(pool.try_acquire().is_some());
}
/// Releasing a held slot makes `try_acquire` succeed again.
#[test]
fn try_acquire_returns_some_after_slot_released() {
let pool = Pool::new(make_disconnected_slot, |_| {});
pool.put(make_disconnected_slot());
let held = pool.try_acquire().unwrap();
assert!(pool.try_acquire().is_none());
drop(held);
assert!(pool.try_acquire().is_some());
}
/// Four sequential slots, each with its own loopback connection, each send one
/// request and read a 200 "ok" response.
#[tokio::test(flavor = "current_thread")]
async fn pool_four_connections_all_succeed() {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
// Server: accept four connections, answer one request on each.
tokio::spawn(async move {
for _ in 0..4 {
let (mut tcp, _) = listener.accept().await.unwrap();
let mut buf = [0u8; 4096];
let _ = tcp.read(&mut buf).await.unwrap();
let resp = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok";
tcp.write_all(resp).await.unwrap();
}
});
let pool = Pool::new(make_disconnected_slot, |slot| {
if slot.needs_reconnect() {
slot.conn = None;
slot.reader.reset();
}
});
let mut success_count = 0u8;
for _ in 0..4u8 {
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
tcp.set_nodelay(true).unwrap();
let stream = MaybeTls::Plain(tcp);
let conn = AsyncHttpConnection::new(stream);
pool.put(ClientSlot {
writer: RequestWriter::new(&addr.to_string()).unwrap(),
reader: ResponseReader::new(4096),
conn: Some(conn),
});
let mut slot = pool.try_acquire().unwrap();
let s: &mut ClientSlot = &mut slot;
let req = s.writer.get("/test").finish().unwrap();
let conn = s.conn.as_mut().unwrap();
let resp = conn.send(req, &mut s.reader).await.unwrap();
assert_eq!(resp.status(), 200);
assert_eq!(resp.body_str().unwrap(), "ok");
success_count += 1;
}
assert_eq!(success_count, 4);
}
/// A single slot (no pool) sends one request over loopback and reads a JSON body.
#[tokio::test(flavor = "current_thread")]
async fn pool_loopback_send() {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
let (mut tcp, _) = listener.accept().await.unwrap();
let mut buf = [0u8; 4096];
let _ = tcp.read(&mut buf).await.unwrap();
let resp = b"HTTP/1.1 200 OK\r\nContent-Length: 15\r\n\r\n{\"orderId\":123}";
tcp.write_all(resp).await.unwrap();
});
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
let stream = MaybeTls::Plain(tcp);
let conn = AsyncHttpConnection::new(stream);
let mut slot = ClientSlot {
writer: RequestWriter::new(&addr.to_string()).unwrap(),
reader: ResponseReader::new(4096),
conn: Some(conn),
};
let req = slot.writer.get("/test").finish().unwrap();
let conn = slot.conn.as_mut().unwrap();
let resp = conn.send(req, &mut slot.reader).await.unwrap();
assert_eq!(resp.status(), 200);
assert_eq!(resp.body_str().unwrap(), r#"{"orderId":123}"#);
}
/// With its only slot held, the raw pool's `try_acquire` returns `None`.
#[test]
fn try_acquire_returns_none_when_exhausted() {
let pool = Pool::new(make_disconnected_slot, |_| {});
pool.put(make_disconnected_slot());
let _s1 = pool.try_acquire().unwrap();
assert!(pool.try_acquire().is_none());
}
/// Two requests on the same pooled connection, with the slot released and
/// re-acquired in between, both succeed (connection kept alive).
#[tokio::test(flavor = "current_thread")]
async fn pool_keep_alive_multiple_requests() {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
// Server: one connection, two request/response exchanges.
tokio::spawn(async move {
let (mut tcp, _) = listener.accept().await.unwrap();
let mut buf = [0u8; 4096];
let _ = tcp.read(&mut buf).await.unwrap();
tcp.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 7\r\n\r\n{\"r\":1}")
.await
.unwrap();
let _ = tcp.read(&mut buf).await.unwrap();
tcp.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 7\r\n\r\n{\"r\":2}")
.await
.unwrap();
});
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
tcp.set_nodelay(true).unwrap();
let stream = MaybeTls::Plain(tcp);
let conn = AsyncHttpConnection::new(stream);
let pool = Pool::new(make_disconnected_slot, |slot| {
if slot.needs_reconnect() {
slot.conn = None;
slot.reader.reset();
}
});
pool.put(ClientSlot {
writer: RequestWriter::new(&addr.to_string()).unwrap(),
reader: ResponseReader::new(4096),
conn: Some(conn),
});
{
let mut slot = pool.acquire();
let s: &mut ClientSlot = &mut slot;
let req = s.writer.get("/first").finish().unwrap();
let conn = s.conn.as_mut().unwrap();
let resp = conn.send(req, &mut s.reader).await.unwrap();
assert_eq!(resp.body_str().unwrap(), r#"{"r":1}"#);
}
{
let mut slot = pool.acquire();
let s: &mut ClientSlot = &mut slot;
let req = s.writer.get("/second").finish().unwrap();
let conn = s.conn.as_mut().unwrap();
let resp = conn.send(req, &mut s.reader).await.unwrap();
assert_eq!(resp.body_str().unwrap(), r#"{"r":2}"#);
}
}
/// `build` rejects a builder with no URL.
#[tokio::test(flavor = "current_thread")]
async fn builder_validates_empty_url() {
let result = ClientPool::builder().connections(1).build().await;
assert!(result.is_err());
}
/// `build` rejects `connections(0)`.
#[tokio::test(flavor = "current_thread")]
async fn builder_validates_zero_connections() {
let result = ClientPool::builder()
.url("http://localhost")
.connections(0)
.build()
.await;
assert!(result.is_err());
}
/// After the server closes the connection, a second send must fail or time out
/// (not hang forever), and the slot must then report `needs_reconnect()`.
#[tokio::test(flavor = "current_thread")]
async fn stale_connection_timeout_not_hang() {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
// Server: answer one request, then close the socket.
tokio::spawn(async move {
let (mut tcp, _) = listener.accept().await.unwrap();
let mut buf = [0u8; 4096];
let _ = tcp.read(&mut buf).await.unwrap();
tcp.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok")
.await
.unwrap();
drop(tcp);
});
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
tcp.set_nodelay(true).unwrap();
let stream = MaybeTls::Plain(tcp);
let conn = AsyncHttpConnection::new(stream);
let pool = Pool::new(make_disconnected_slot, |slot| {
if slot.needs_reconnect() {
slot.conn = None;
slot.reader.reset();
}
});
pool.put(ClientSlot {
writer: RequestWriter::new(&addr.to_string()).unwrap(),
reader: ResponseReader::new(4096),
conn: Some(conn),
});
{
let mut slot = pool.acquire();
let s: &mut ClientSlot = &mut slot;
let req = s.writer.get("/first").finish().unwrap();
let conn = s.conn.as_mut().unwrap();
let resp = conn.send(req, &mut s.reader).await.unwrap();
assert_eq!(resp.status(), 200);
}
// Give the server time to close its end.
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
{
let mut slot = pool.acquire();
let s: &mut ClientSlot = &mut slot;
let req = s.writer.get("/second").finish().unwrap();
let conn = s.conn.as_mut().unwrap();
let timeout = std::time::Duration::from_millis(500);
let result = tokio::time::timeout(timeout, conn.send(req, &mut s.reader)).await;
match result {
Ok(Ok(_)) => panic!("stale connection should not succeed"),
Ok(Err(_)) => {} Err(_elapsed) => {
// A timeout is also acceptable: the test only requires that the
// send does not succeed and that the slot ends up poisoned.
}
}
assert!(
s.needs_reconnect(),
"slot should be poisoned after stale connection"
);
}
}
/// End-to-end: a pool built with one connection loses it, the next
/// `ClientPool::acquire` hands the dead slot to a `spawn_local` reconnect task,
/// and the healed slot serves a request on a fresh connection.
#[tokio::test(flavor = "current_thread")]
// NOTE(review): presumably allowed because the combined LocalSet future is large.
#[allow(clippy::large_futures)]
async fn dead_connection_heals_via_reconnect_task() {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
// Reconnects use `spawn_local`, so everything runs inside a LocalSet.
let local = tokio::task::LocalSet::new();
local
.run_until(async {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
// Server: first connection answers once and closes; second
// connection (the reconnect) answers once.
let server_handle = tokio::task::spawn_local({
let listener_fd = listener.into_std().unwrap();
async move {
let listener = TcpListener::from_std(listener_fd).unwrap();
let (mut tcp, _) = listener.accept().await.unwrap();
let mut buf = [0u8; 4096];
let _ = tcp.read(&mut buf).await.unwrap();
tcp.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nfirst")
.await
.unwrap();
drop(tcp);
let (mut tcp, _) = listener.accept().await.unwrap();
let _ = tcp.read(&mut buf).await.unwrap();
tcp.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nhealed")
.await
.unwrap();
}
});
let pool = ClientPool::builder()
.url(&format!("http://{addr}"))
.connections(1)
.build()
.await
.unwrap();
{
let mut slot = pool.try_acquire().unwrap();
let s: &mut ClientSlot = &mut slot;
let req = s.writer.get("/first").finish().unwrap();
let conn = s.conn.as_mut().unwrap();
let resp = conn.send(req, &mut s.reader).await.unwrap();
assert_eq!(resp.body_str().unwrap(), "first");
}
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
// This send hits the closed connection and poisons the slot.
{
let mut slot = pool.try_acquire().unwrap();
let s: &mut ClientSlot = &mut slot;
let req = s.writer.get("/dead").finish().unwrap();
let conn = s.conn.as_mut().unwrap();
let timeout = std::time::Duration::from_millis(500);
let _ = tokio::time::timeout(timeout, conn.send(req, &mut s.reader)).await;
assert!(s.needs_reconnect(), "should be poisoned");
}
// `acquire` spawns the reconnect and polls until the slot comes back.
let mut slot = pool.acquire().await.unwrap();
let s: &mut ClientSlot = &mut slot;
let req = s.writer.get("/healed").finish().unwrap();
let conn = s.conn.as_mut().unwrap();
let resp = conn.send(req, &mut s.reader).await.unwrap();
assert_eq!(resp.body_str().unwrap(), "healed");
server_handle.await.unwrap();
})
.await;
}
}