use nexus_net::http::ResponseReader;
use nexus_net::rest::{RequestWriter, RestError};
#[cfg(feature = "tls")]
use nexus_net::tls::TlsConfig;
use nexus_pool::local::{Pool, Pooled};
use super::connection::{AsyncHttpConnection, AsyncHttpConnectionBuilder};
use crate::maybe_tls::MaybeTls;
/// One pooled client: a request writer, a response reader, and the connection
/// they are used with.
///
/// `conn` is `None` while the slot is disconnected; use
/// [`ClientSlot::needs_reconnect`] to check whether the slot is usable.
pub struct ClientSlot {
    /// Writer used to serialize requests for this slot.
    pub writer: RequestWriter,
    /// Reader used to parse responses received on `conn`.
    pub reader: ResponseReader,
    /// The live connection, or `None` when the slot has been disconnected.
    pub conn: Option<AsyncHttpConnection<MaybeTls>>,
}

impl ClientSlot {
    /// Returns `true` if the slot has no connection, or if its connection
    /// reports itself as poisoned.
    pub fn needs_reconnect(&self) -> bool {
        match &self.conn {
            Some(connection) => connection.is_poisoned(),
            None => true,
        }
    }

    /// Borrows the connection and the response reader at the same time.
    ///
    /// # Errors
    ///
    /// Returns `RestError::ConnectionPoisoned` when the slot has no connection.
    pub fn conn_and_reader(
        &mut self,
    ) -> Result<(&mut AsyncHttpConnection<MaybeTls>, &mut ResponseReader), RestError> {
        let Some(connection) = self.conn.as_mut() else {
            return Err(RestError::ConnectionPoisoned);
        };
        Ok((connection, &mut self.reader))
    }
}
/// A pool of [`ClientSlot`]s whose broken connections are re-established in
/// the background.
///
/// Construct one with [`ClientPool::builder`].
pub struct ClientPool {
    /// Connection settings reused every time a slot has to be reconnected.
    reconnect_config: ReconnectConfig,
    /// Storage for the idle slots.
    pool: Pool<ClientSlot>,
}

/// Copy of the builder's connection settings, kept so replacement connections
/// are opened with the same configuration as the original ones.
#[derive(Clone)]
struct ReconnectConfig {
    /// Target passed to `AsyncHttpConnectionBuilder::connect`.
    url: String,
    /// When `true`, `disable_nagle` is requested on the connection builder.
    nodelay: bool,
    /// TLS settings forwarded to the connection builder, if any.
    #[cfg(feature = "tls")]
    tls_config: Option<TlsConfig>,
    /// TCP keepalive idle time forwarded to the connection builder, if set.
    #[cfg(feature = "socket-opts")]
    tcp_keepalive: Option<std::time::Duration>,
    /// Socket receive-buffer size forwarded to the connection builder, if set.
    #[cfg(feature = "socket-opts")]
    recv_buf_size: Option<usize>,
    /// Socket send-buffer size forwarded to the connection builder, if set.
    #[cfg(feature = "socket-opts")]
    send_buf_size: Option<usize>,
}
// NOTE(review): `clippy::future_not_send` is allowed for this whole impl. The
// pool comes from `nexus_pool::local` and the async methods borrow `&self`, so
// their futures are presumably `!Send` and intended for a single-threaded
// executor — confirm before relaxing this.
#[allow(clippy::future_not_send)]
impl ClientPool {
    /// Returns a [`ClientPoolBuilder`] with default settings.
    #[must_use]
    pub fn builder() -> ClientPoolBuilder {
        ClientPoolBuilder::new()
    }

    /// Takes an idle, connected slot without waiting.
    ///
    /// Idle slots that need reconnecting are never handed out. Each one is
    /// given to a background reconnect task, which keeps it until the new
    /// connection is established, and the next idle slot is tried.
    /// Returns `None` once no idle slots remain.
    pub fn try_acquire(&self) -> Option<Pooled<ClientSlot>> {
        loop {
            let slot = self.pool.try_acquire()?;
            if !slot.needs_reconnect() {
                return Some(slot);
            }
            self.spawn_reconnect(slot);
        }
    }

    /// Waits for a healthy slot, retrying [`ClientPool::try_acquire`] with
    /// exponential backoff.
    ///
    /// Makes up to 20 attempts. Between consecutive attempts it sleeps
    /// 1 ms, 2 ms, 4 ms, … with each delay capped at 1 s.
    ///
    /// # Errors
    ///
    /// Returns `RestError::ConnectionClosed` if no attempt finds a healthy slot.
    pub async fn acquire(&self) -> Result<Pooled<ClientSlot>, RestError> {
        const MAX_BACKOFF_MS: u64 = 1_000;
        const MAX_ATTEMPTS: u32 = 20;
        let mut backoff_ms = 1u64;
        for attempt in 1..=MAX_ATTEMPTS {
            if let Some(slot) = self.try_acquire() {
                return Ok(slot);
            }
            // Back off only when another attempt follows. Sleeping after the
            // final failed attempt would just delay the error by up to 1 s.
            if attempt < MAX_ATTEMPTS {
                nexus_async_rt::sleep(std::time::Duration::from_millis(backoff_ms)).await;
                backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS);
            }
        }
        Err(RestError::ConnectionClosed(
            "pool acquire timed out: no healthy slots available",
        ))
    }

    /// Number of idle slots currently held by the underlying pool.
    ///
    /// This excludes slots that are checked out or held by reconnect tasks.
    pub fn available(&self) -> usize {
        self.pool.available()
    }

    /// Moves `slot` into a spawned task that reconnects it.
    ///
    /// The task retries [`ClientPool::connect_one_with`] until it succeeds.
    /// Between failures it sleeps from 100 ms, doubling up to a 5 s cap, and
    /// it never gives up, so an unreachable server keeps the slot out of the
    /// pool indefinitely. On success it installs the new connection and resets
    /// the reader. The slot is then dropped at the end of the task, presumably
    /// returning it to the pool via `Pooled`'s drop.
    fn spawn_reconnect(&self, mut slot: Pooled<ClientSlot>) {
        let config = self.reconnect_config.clone();
        // NOTE(review): the handle is dropped immediately; this assumes that
        // dropping it detaches (rather than cancels) the task — confirm.
        drop(nexus_async_rt::spawn_boxed(async move {
            const MAX_BACKOFF_MS: u64 = 5_000;
            let mut backoff_ms = 100u64;
            loop {
                // Failures are deliberately ignored: the loop simply retries.
                if let Ok(conn) = Self::connect_one_with(&config).await {
                    slot.conn = Some(conn);
                    slot.reader.reset();
                    return;
                }
                nexus_async_rt::sleep(std::time::Duration::from_millis(backoff_ms)).await;
                backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS);
            }
        }));
    }

    /// Opens one connection to `config.url`, applying the TLS, Nagle, and
    /// socket options recorded in `config`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `AsyncHttpConnectionBuilder::connect`.
    async fn connect_one_with(
        config: &ReconnectConfig,
    ) -> Result<AsyncHttpConnection<MaybeTls>, RestError> {
        let mut builder = AsyncHttpConnectionBuilder::new();
        #[cfg(feature = "tls")]
        if let Some(ref tls) = config.tls_config {
            builder = builder.tls(tls);
        }
        if config.nodelay {
            builder = builder.disable_nagle();
        }
        #[cfg(feature = "socket-opts")]
        {
            if let Some(idle) = config.tcp_keepalive {
                builder = builder.tcp_keepalive(idle);
            }
            if let Some(size) = config.recv_buf_size {
                builder = builder.recv_buffer_size(size);
            }
            if let Some(size) = config.send_buf_size {
                builder = builder.send_buffer_size(size);
            }
        }
        builder.connect(&config.url).await
    }
}
/// Builder for [`ClientPool`].
///
/// Obtain one with [`ClientPool::builder`] or [`ClientPoolBuilder::new`].
pub struct ClientPoolBuilder {
    /// Server URL. Required; `build` rejects an empty value.
    url: String,
    /// Number of connections `build` opens up front. Must be greater than 0.
    connections: usize,
    /// Base path applied to every request writer. Empty means none is set.
    base_path: String,
    /// Headers installed on every request writer.
    default_headers: Vec<(String, String)>,
    /// When `true`, `disable_nagle` is requested on every connection.
    nodelay: bool,
    /// TLS settings for every connection, if any.
    #[cfg(feature = "tls")]
    tls_config: Option<TlsConfig>,
    /// TCP keepalive idle time, if set.
    #[cfg(feature = "socket-opts")]
    tcp_keepalive: Option<std::time::Duration>,
    /// Socket receive-buffer size, if set.
    #[cfg(feature = "socket-opts")]
    recv_buf_size: Option<usize>,
    /// Socket send-buffer size, if set.
    #[cfg(feature = "socket-opts")]
    send_buf_size: Option<usize>,
    /// Capacity handed to each `RequestWriter::set_write_buffer_capacity`.
    write_buffer_capacity: usize,
    /// Capacity handed to each `ResponseReader::new`.
    response_buffer_capacity: usize,
    /// Limit handed to each `ResponseReader::max_body_size`.
    max_body_size: usize,
}
impl ClientPoolBuilder {
#[must_use]
pub fn new() -> Self {
Self {
url: String::new(),
base_path: String::new(),
default_headers: Vec::new(),
connections: 1,
#[cfg(feature = "tls")]
tls_config: None,
nodelay: false,
#[cfg(feature = "socket-opts")]
tcp_keepalive: None,
#[cfg(feature = "socket-opts")]
recv_buf_size: None,
#[cfg(feature = "socket-opts")]
send_buf_size: None,
write_buffer_capacity: 32 * 1024,
response_buffer_capacity: 32 * 1024,
max_body_size: 0,
}
}
#[must_use]
pub fn url(mut self, url: &str) -> Self {
self.url = url.to_string();
self
}
#[must_use]
pub fn base_path(mut self, path: &str) -> Self {
self.base_path = path.to_string();
self
}
pub fn default_header(mut self, name: &str, value: &str) -> Result<Self, RestError> {
if name.bytes().any(|b| b == b'\r' || b == b'\n')
|| value.bytes().any(|b| b == b'\r' || b == b'\n')
{
return Err(RestError::CrlfInjection);
}
self.default_headers
.push((name.to_string(), value.to_string()));
Ok(self)
}
#[must_use]
pub fn connections(mut self, n: usize) -> Self {
self.connections = n;
self
}
#[must_use]
#[cfg(feature = "tls")]
pub fn tls(mut self, config: &TlsConfig) -> Self {
self.tls_config = Some(config.clone());
self
}
#[must_use]
pub fn disable_nagle(mut self) -> Self {
self.nodelay = true;
self
}
#[cfg(feature = "socket-opts")]
#[must_use]
pub fn tcp_keepalive(mut self, idle: std::time::Duration) -> Self {
self.tcp_keepalive = Some(idle);
self
}
#[cfg(feature = "socket-opts")]
#[must_use]
pub fn recv_buffer_size(mut self, n: usize) -> Self {
self.recv_buf_size = Some(n);
self
}
#[cfg(feature = "socket-opts")]
#[must_use]
pub fn send_buffer_size(mut self, n: usize) -> Self {
self.send_buf_size = Some(n);
self
}
#[must_use]
pub fn write_buffer_capacity(mut self, n: usize) -> Self {
self.write_buffer_capacity = n;
self
}
#[must_use]
pub fn response_buffer_capacity(mut self, n: usize) -> Self {
self.response_buffer_capacity = n;
self
}
#[must_use]
pub fn max_body_size(mut self, n: usize) -> Self {
self.max_body_size = n;
self
}
#[allow(clippy::future_not_send)]
pub async fn build(self) -> Result<ClientPool, RestError> {
if self.url.is_empty() {
return Err(RestError::InvalidUrl("url is required".to_string()));
}
if self.connections == 0 {
return Err(RestError::InvalidUrl("connections must be > 0".to_string()));
}
let parsed = nexus_net::rest::parse_base_url(&self.url)?;
let host_header = parsed.host_header();
let reconnect_config = ReconnectConfig {
url: self.url.clone(),
#[cfg(feature = "tls")]
tls_config: self.tls_config.clone(),
nodelay: self.nodelay,
#[cfg(feature = "socket-opts")]
tcp_keepalive: self.tcp_keepalive,
#[cfg(feature = "socket-opts")]
recv_buf_size: self.recv_buf_size,
#[cfg(feature = "socket-opts")]
send_buf_size: self.send_buf_size,
};
let mut initial_slots = Vec::with_capacity(self.connections);
for _ in 0..self.connections {
let mut builder = AsyncHttpConnectionBuilder::new();
#[cfg(feature = "tls")]
if let Some(ref tls) = self.tls_config {
builder = builder.tls(tls);
}
if self.nodelay {
builder = builder.disable_nagle();
}
#[cfg(feature = "socket-opts")]
{
if let Some(idle) = self.tcp_keepalive {
builder = builder.tcp_keepalive(idle);
}
if let Some(size) = self.recv_buf_size {
builder = builder.recv_buffer_size(size);
}
if let Some(size) = self.send_buf_size {
builder = builder.send_buffer_size(size);
}
}
let conn = builder.connect(&self.url).await?;
let mut writer = RequestWriter::new(&host_header)?;
if !self.base_path.is_empty() {
writer.set_base_path(&self.base_path)?;
}
writer.set_write_buffer_capacity(self.write_buffer_capacity);
for (name, value) in &self.default_headers {
writer.default_header(name, value)?;
}
let reader = ResponseReader::new(self.response_buffer_capacity)
.max_body_size(self.max_body_size);
initial_slots.push(ClientSlot {
writer,
reader,
conn: Some(conn),
});
}
let host = host_header.clone();
let base = self.base_path.clone();
let headers = self.default_headers.clone();
let wbuf_cap = self.write_buffer_capacity;
let rbuf_cap = self.response_buffer_capacity;
let max_body = self.max_body_size;
let pool = Pool::new(
move || {
let mut writer = RequestWriter::new(&host).expect("host already validated");
if !base.is_empty() {
writer
.set_base_path(&base)
.expect("base_path already validated");
}
writer.set_write_buffer_capacity(wbuf_cap);
for (name, value) in &headers {
writer
.default_header(name, value)
.expect("headers already validated");
}
ClientSlot {
writer,
reader: ResponseReader::new(rbuf_cap).max_body_size(max_body),
conn: None,
}
},
|slot: &mut ClientSlot| {
if slot.needs_reconnect() {
slot.conn = None;
slot.reader.reset();
}
},
);
for slot in initial_slots {
pool.put(slot);
}
Ok(ClientPool {
pool,
reconnect_config,
})
}
}
impl Default for ClientPoolBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A slot with no connection, as produced by the pool's factory.
    fn make_disconnected_slot() -> ClientSlot {
        ClientSlot {
            writer: RequestWriter::new("host").unwrap(),
            reader: ResponseReader::new(4096),
            conn: None,
        }
    }

    /// Waker that does nothing; enough to poll a future a single time.
    struct NoopWaker;

    impl std::task::Wake for NoopWaker {
        fn wake(self: std::sync::Arc<Self>) {}
    }

    /// Polls `fut` exactly once.
    ///
    /// This is only suitable for futures expected to finish without awaiting
    /// anything, such as `build` failing its up-front validation.
    fn poll_once<F: std::future::Future>(fut: F) -> std::task::Poll<F::Output> {
        let waker = std::task::Waker::from(std::sync::Arc::new(NoopWaker));
        let mut cx = std::task::Context::from_waker(&waker);
        let fut = std::pin::pin!(fut);
        std::future::Future::poll(fut, &mut cx)
    }

    #[test]
    fn slot_needs_reconnect_when_no_conn() {
        let slot = make_disconnected_slot();
        assert!(slot.needs_reconnect());
    }

    #[test]
    fn conn_and_reader_errors_when_disconnected() {
        let mut slot = make_disconnected_slot();
        assert!(matches!(
            slot.conn_and_reader(),
            Err(RestError::ConnectionPoisoned)
        ));
    }

    #[test]
    fn pool_acquire_release_cycle() {
        let pool = Pool::new(make_disconnected_slot, |_| {});
        pool.put(make_disconnected_slot());
        assert_eq!(pool.available(), 1);
        let slot = pool.acquire();
        assert_eq!(pool.available(), 0);
        drop(slot);
        assert_eq!(pool.available(), 1);
    }

    #[test]
    fn pool_acquire_returns_available() {
        let pool: Pool<ClientSlot> = Pool::new(make_disconnected_slot, |_| {});
        pool.put(make_disconnected_slot());
        pool.put(make_disconnected_slot());
        assert_eq!(pool.available(), 2);
        let _s1 = pool.acquire();
        assert_eq!(pool.available(), 1);
        let _s2 = pool.acquire();
        assert_eq!(pool.available(), 0);
    }

    #[test]
    fn pool_reset_clears_dead_conn() {
        // Mirrors the reset hook installed by `ClientPoolBuilder::build`.
        let pool = Pool::new(make_disconnected_slot, |slot| {
            if slot.needs_reconnect() {
                slot.conn = None;
                slot.reader.reset();
            }
        });
        pool.put(make_disconnected_slot());
        let slot = pool.acquire();
        assert!(slot.conn.is_none());
        assert!(slot.needs_reconnect());
        drop(slot);
        let slot = pool.acquire();
        assert!(slot.conn.is_none());
    }

    #[test]
    fn pool_multiple_slots() {
        let pool = Pool::new(make_disconnected_slot, |_| {});
        for _ in 0..4 {
            pool.put(make_disconnected_slot());
        }
        assert_eq!(pool.available(), 4);
        let s1 = pool.acquire();
        let s2 = pool.acquire();
        assert_eq!(pool.available(), 2);
        drop(s1);
        assert_eq!(pool.available(), 3);
        drop(s2);
        assert_eq!(pool.available(), 4);
    }

    #[test]
    fn try_acquire_returns_none_when_all_in_use() {
        let pool = Pool::new(make_disconnected_slot, |_| {});
        pool.put(make_disconnected_slot());
        pool.put(make_disconnected_slot());
        let s1 = pool.try_acquire().unwrap();
        let s2 = pool.try_acquire().unwrap();
        assert!(pool.try_acquire().is_none());
        assert_eq!(pool.available(), 0);
        drop(s1);
        drop(s2);
        assert_eq!(pool.available(), 2);
        assert!(pool.try_acquire().is_some());
    }

    #[test]
    fn try_acquire_returns_some_after_slot_released() {
        let pool = Pool::new(make_disconnected_slot, |_| {});
        pool.put(make_disconnected_slot());
        let held = pool.try_acquire().unwrap();
        assert!(pool.try_acquire().is_none());
        drop(held);
        assert!(pool.try_acquire().is_some());
    }

    #[test]
    fn try_acquire_returns_none_when_exhausted() {
        let pool = Pool::new(make_disconnected_slot, |_| {});
        pool.put(make_disconnected_slot());
        let _s1 = pool.try_acquire().unwrap();
        assert!(pool.try_acquire().is_none());
    }

    #[test]
    fn builder_defaults() {
        let builder = ClientPoolBuilder::default();
        assert!(builder.url.is_empty());
        assert!(builder.base_path.is_empty());
        assert!(builder.default_headers.is_empty());
        assert_eq!(builder.connections, 1);
        assert!(!builder.nodelay);
        assert_eq!(builder.write_buffer_capacity, 32 * 1024);
        assert_eq!(builder.response_buffer_capacity, 32 * 1024);
        assert_eq!(builder.max_body_size, 0);
    }

    #[test]
    fn default_header_rejects_crlf() {
        assert!(matches!(
            ClientPoolBuilder::new().default_header("X-Test", "a\r\nInjected: 1"),
            Err(RestError::CrlfInjection)
        ));
        assert!(matches!(
            ClientPoolBuilder::new().default_header("X-Te\nst", "ok"),
            Err(RestError::CrlfInjection)
        ));
    }

    #[test]
    fn default_header_accepts_clean_values() {
        let Ok(builder) = ClientPoolBuilder::new().default_header("X-Test", "value") else {
            panic!("a header without CR/LF must be accepted");
        };
        assert_eq!(
            builder.default_headers,
            vec![("X-Test".to_string(), "value".to_string())]
        );
    }

    #[test]
    fn build_rejects_missing_url() {
        let result = poll_once(ClientPoolBuilder::new().build());
        assert!(matches!(
            result,
            std::task::Poll::Ready(Err(RestError::InvalidUrl(_)))
        ));
    }

    #[test]
    fn build_rejects_zero_connections() {
        let result = poll_once(
            ClientPoolBuilder::new()
                .url("http://127.0.0.1:1")
                .connections(0)
                .build(),
        );
        assert!(matches!(
            result,
            std::task::Poll::Ready(Err(RestError::InvalidUrl(_)))
        ));
    }
}