use nexus_net::http::ResponseReader;
use nexus_net::rest::{RequestWriter, RestError};
#[cfg(feature = "tls")]
use nexus_net::tls::TlsConfig;
use nexus_pool::sync::{Pool, Pooled};
use super::connection::{AsyncHttpConnection, AsyncHttpConnectionBuilder};
use crate::maybe_tls::MaybeTls;
/// Slot type handed out by [`AtomicClientPool`]; an alias for the parent module's `ClientSlot`.
pub type AtomicClientSlot = super::ClientSlot;
/// Pool of HTTP client slots, each holding a request writer, a response reader,
/// and an optional connection.
///
/// Slots that report `needs_reconnect()` when checked out are not handed to the
/// caller; they are reconnected in a background task instead.
pub struct AtomicClientPool {
    /// Underlying fixed-size slot pool.
    pool: Pool<AtomicClientSlot>,
    /// Connection settings reused whenever a slot must be reconnected.
    reconnect_config: ReconnectConfig,
}
/// Connection settings captured from [`AtomicClientPoolBuilder`] at build time.
///
/// Cloned into each background reconnect task so new connections use the same
/// options as the initial ones.
#[derive(Clone)]
struct ReconnectConfig {
    /// Base URL passed to `AsyncHttpConnectionBuilder::connect`.
    url: String,
    /// TLS settings applied to the connection builder when present.
    #[cfg(feature = "tls")]
    tls_config: Option<TlsConfig>,
    /// When true, `disable_nagle()` is called on the connection builder.
    nodelay: bool,
    /// TCP keepalive idle time, if configured.
    #[cfg(feature = "socket-opts")]
    tcp_keepalive: Option<std::time::Duration>,
    /// Socket receive buffer size, if configured.
    #[cfg(feature = "socket-opts")]
    recv_buf_size: Option<usize>,
    /// Socket send buffer size, if configured.
    #[cfg(feature = "socket-opts")]
    send_buf_size: Option<usize>,
}
impl AtomicClientPool {
    /// Returns a builder for configuring and connecting a new pool.
    #[must_use]
    pub fn builder() -> AtomicClientPoolBuilder {
        AtomicClientPoolBuilder::new()
    }

    /// Checks out a connected slot without waiting.
    ///
    /// Every free slot that reports `needs_reconnect()` is handed to a background
    /// reconnect task instead of being returned, and the next free slot is tried.
    /// Returns `None` once no free slot is left.
    ///
    /// # Panics
    ///
    /// Reconnecting uses `tokio::spawn`, which panics when called outside a Tokio
    /// runtime. This can only happen when a slot needing reconnection is found.
    pub fn try_acquire(&self) -> Option<Pooled<AtomicClientSlot>> {
        loop {
            let slot = self.pool.try_acquire()?;
            if !slot.needs_reconnect() {
                return Some(slot);
            }
            // The slot moves into the task and only returns to the pool once the
            // task drops it after a successful reconnect.
            self.spawn_reconnect(slot);
        }
    }

    /// Checks out a connected slot, waiting with exponential backoff.
    ///
    /// Makes up to 20 attempts via [`try_acquire`](Self::try_acquire), sleeping
    /// between attempts (starting at 1 ms and doubling up to 1 s), for roughly
    /// 10 s of cumulative backoff.
    ///
    /// # Errors
    ///
    /// Returns `RestError::ConnectionClosed` if no healthy slot became available
    /// within the attempt budget.
    ///
    /// # Panics
    ///
    /// Under the same conditions as [`try_acquire`](Self::try_acquire).
    pub async fn acquire(&self) -> Result<Pooled<AtomicClientSlot>, RestError> {
        const MAX_BACKOFF_MS: u64 = 1_000;
        const MAX_ATTEMPTS: u32 = 20;
        let mut backoff_ms = 1u64;
        for attempt in 1..=MAX_ATTEMPTS {
            if let Some(slot) = self.try_acquire() {
                return Ok(slot);
            }
            // Sleeping after the final attempt would only delay the error: no
            // further attempt follows it.
            if attempt < MAX_ATTEMPTS {
                tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
                backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS);
            }
        }
        Err(RestError::ConnectionClosed(
            "pool acquire timed out: no healthy slots available",
        ))
    }

    /// Number of slots currently free in the underlying pool.
    ///
    /// This counts free slots whether or not they still need reconnecting.
    #[must_use]
    pub fn available(&self) -> usize {
        self.pool.available()
    }

    /// Reconnects `slot` in a detached Tokio task.
    ///
    /// Retries with exponential backoff (100 ms doubling, capped at 5 s) until a
    /// connection succeeds. It then installs the connection and resets the slot's
    /// reader. The slot is dropped at the end of the task, which returns it to the pool.
    ///
    /// NOTE(review): the task never gives up, so an unreachable server keeps the
    /// slot checked out (and the task alive) indefinitely, even after the pool is
    /// dropped.
    fn spawn_reconnect(&self, mut slot: Pooled<AtomicClientSlot>) {
        let config = self.reconnect_config.clone();
        tokio::spawn(async move {
            const MAX_BACKOFF_MS: u64 = 5_000;
            let mut backoff_ms = 100u64;
            loop {
                if let Ok(conn) = Self::connect_one_with(&config).await {
                    slot.conn = Some(conn);
                    // Discard any partially read response from the old connection.
                    slot.reader.reset();
                    return;
                }
                tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
                backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS);
            }
        });
    }

    /// Opens one connection to `config.url` with every configured socket/TLS option.
    ///
    /// # Errors
    ///
    /// Propagates any error from `AsyncHttpConnectionBuilder::connect`.
    async fn connect_one_with(
        config: &ReconnectConfig,
    ) -> Result<AsyncHttpConnection<MaybeTls>, RestError> {
        let mut builder = AsyncHttpConnectionBuilder::new();
        #[cfg(feature = "tls")]
        if let Some(ref tls) = config.tls_config {
            builder = builder.tls(tls);
        }
        if config.nodelay {
            builder = builder.disable_nagle();
        }
        #[cfg(feature = "socket-opts")]
        {
            if let Some(idle) = config.tcp_keepalive {
                builder = builder.tcp_keepalive(idle);
            }
            if let Some(size) = config.recv_buf_size {
                builder = builder.recv_buffer_size(size);
            }
            if let Some(size) = config.send_buf_size {
                builder = builder.send_buffer_size(size);
            }
        }
        builder.connect(&config.url).await
    }
}
/// Builder for [`AtomicClientPool`]; obtain one via `AtomicClientPool::builder()`
/// or `AtomicClientPoolBuilder::new()`.
pub struct AtomicClientPoolBuilder {
    /// Base URL to connect to; required (empty by default, rejected by `build`).
    url: String,
    /// Path prefix given to each slot's `RequestWriter`; empty means none.
    base_path: String,
    /// Headers added to every slot's `RequestWriter`, already checked for CR/LF.
    default_headers: Vec<(String, String)>,
    /// Number of slots (and initial connections); must be > 0.
    connections: usize,
    /// TLS settings for every connection, if any.
    #[cfg(feature = "tls")]
    tls_config: Option<TlsConfig>,
    /// When true, connections are built with `disable_nagle()`.
    nodelay: bool,
    /// TCP keepalive idle time, if configured.
    #[cfg(feature = "socket-opts")]
    tcp_keepalive: Option<std::time::Duration>,
    /// Socket receive buffer size, if configured.
    #[cfg(feature = "socket-opts")]
    recv_buf_size: Option<usize>,
    /// Socket send buffer size, if configured.
    #[cfg(feature = "socket-opts")]
    send_buf_size: Option<usize>,
    /// Capacity passed to `RequestWriter::set_write_buffer_capacity` (default 32 KiB).
    write_buffer_capacity: usize,
    /// Capacity passed to `ResponseReader::new` (default 32 KiB).
    response_buffer_capacity: usize,
    /// Passed to `ResponseReader::max_body_size` (default 0).
    /// NOTE(review): confirm whether 0 means "unlimited" in `ResponseReader`.
    max_body_size: usize,
}
impl AtomicClientPoolBuilder {
#[must_use]
pub fn new() -> Self {
Self {
url: String::new(),
base_path: String::new(),
default_headers: Vec::new(),
connections: 1,
#[cfg(feature = "tls")]
tls_config: None,
nodelay: false,
#[cfg(feature = "socket-opts")]
tcp_keepalive: None,
#[cfg(feature = "socket-opts")]
recv_buf_size: None,
#[cfg(feature = "socket-opts")]
send_buf_size: None,
write_buffer_capacity: 32 * 1024,
response_buffer_capacity: 32 * 1024,
max_body_size: 0,
}
}
#[must_use]
pub fn url(mut self, url: &str) -> Self {
self.url = url.to_string();
self
}
#[must_use]
pub fn base_path(mut self, path: &str) -> Self {
self.base_path = path.to_string();
self
}
pub fn default_header(mut self, name: &str, value: &str) -> Result<Self, RestError> {
if name.bytes().any(|b| b == b'\r' || b == b'\n')
|| value.bytes().any(|b| b == b'\r' || b == b'\n')
{
return Err(RestError::CrlfInjection);
}
self.default_headers
.push((name.to_string(), value.to_string()));
Ok(self)
}
#[must_use]
pub fn connections(mut self, n: usize) -> Self {
self.connections = n;
self
}
#[must_use]
#[cfg(feature = "tls")]
pub fn tls(mut self, config: &TlsConfig) -> Self {
self.tls_config = Some(config.clone());
self
}
#[must_use]
pub fn disable_nagle(mut self) -> Self {
self.nodelay = true;
self
}
#[cfg(feature = "socket-opts")]
#[must_use]
pub fn tcp_keepalive(mut self, idle: std::time::Duration) -> Self {
self.tcp_keepalive = Some(idle);
self
}
#[cfg(feature = "socket-opts")]
#[must_use]
pub fn recv_buffer_size(mut self, n: usize) -> Self {
self.recv_buf_size = Some(n);
self
}
#[cfg(feature = "socket-opts")]
#[must_use]
pub fn send_buffer_size(mut self, n: usize) -> Self {
self.send_buf_size = Some(n);
self
}
#[must_use]
pub fn write_buffer_capacity(mut self, n: usize) -> Self {
self.write_buffer_capacity = n;
self
}
#[must_use]
pub fn response_buffer_capacity(mut self, n: usize) -> Self {
self.response_buffer_capacity = n;
self
}
#[must_use]
pub fn max_body_size(mut self, n: usize) -> Self {
self.max_body_size = n;
self
}
pub async fn build(self) -> Result<AtomicClientPool, RestError> {
if self.url.is_empty() {
return Err(RestError::InvalidUrl("url is required".to_string()));
}
if self.connections == 0 {
return Err(RestError::InvalidUrl("connections must be > 0".to_string()));
}
let parsed = nexus_net::rest::parse_base_url(&self.url)?;
let host_header = parsed.host_header();
let reconnect_config = ReconnectConfig {
url: self.url.clone(),
#[cfg(feature = "tls")]
tls_config: self.tls_config.clone(),
nodelay: self.nodelay,
#[cfg(feature = "socket-opts")]
tcp_keepalive: self.tcp_keepalive,
#[cfg(feature = "socket-opts")]
recv_buf_size: self.recv_buf_size,
#[cfg(feature = "socket-opts")]
send_buf_size: self.send_buf_size,
};
let host = host_header.clone();
let base = self.base_path.clone();
let headers = self.default_headers.clone();
let wbuf_cap = self.write_buffer_capacity;
let rbuf_cap = self.response_buffer_capacity;
let max_body = self.max_body_size;
let pool = Pool::new(
self.connections,
move || {
let mut writer = RequestWriter::new(&host).expect("host already validated");
if !base.is_empty() {
writer
.set_base_path(&base)
.expect("base_path already validated");
}
writer.set_write_buffer_capacity(wbuf_cap);
for (name, value) in &headers {
writer
.default_header(name, value)
.expect("headers already validated");
}
AtomicClientSlot {
writer,
reader: ResponseReader::new(rbuf_cap).max_body_size(max_body),
conn: None,
}
},
|slot: &mut AtomicClientSlot| {
if slot.needs_reconnect() {
slot.conn = None;
slot.reader.reset();
}
},
);
for _ in 0..self.connections {
{
let mut slot = pool
.try_acquire()
.expect("pool should have slots during initial setup");
let mut builder = AsyncHttpConnectionBuilder::new();
#[cfg(feature = "tls")]
if let Some(ref tls) = self.tls_config {
builder = builder.tls(tls);
}
if self.nodelay {
builder = builder.disable_nagle();
}
#[cfg(feature = "socket-opts")]
{
if let Some(idle) = self.tcp_keepalive {
builder = builder.tcp_keepalive(idle);
}
if let Some(size) = self.recv_buf_size {
builder = builder.recv_buffer_size(size);
}
if let Some(size) = self.send_buf_size {
builder = builder.send_buffer_size(size);
}
}
let conn = builder.connect(&self.url).await?;
slot.conn = Some(conn);
}
}
Ok(AtomicClientPool {
pool,
reconnect_config,
})
}
}
impl Default for AtomicClientPoolBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `n`-slot pool with unconnected slots. Its release hook matches
    /// the production one: drop a dead connection and reset the reader.
    fn make_pool(n: usize) -> Pool<AtomicClientSlot> {
        Pool::new(
            n,
            || AtomicClientSlot {
                writer: RequestWriter::new("host").unwrap(),
                reader: ResponseReader::new(4096),
                conn: None,
            },
            |slot: &mut AtomicClientSlot| {
                if slot.needs_reconnect() {
                    slot.conn = None;
                    slot.reader.reset();
                }
            },
        )
    }

    /// Dropping a checked-out slot makes it available again.
    #[test]
    fn atomic_pool_acquire_release() {
        let pool = make_pool(2);
        assert_eq!(pool.available(), 2);
        let s1 = pool.try_acquire().unwrap();
        assert_eq!(pool.available(), 1);
        let s2 = pool.try_acquire().unwrap();
        assert_eq!(pool.available(), 0);
        assert!(pool.try_acquire().is_none());
        drop(s1);
        assert_eq!(pool.available(), 1);
        drop(s2);
        assert_eq!(pool.available(), 2);
    }

    /// A slot without a connection reports that it needs reconnecting.
    #[test]
    fn atomic_slot_needs_reconnect() {
        let slot = AtomicClientSlot {
            writer: RequestWriter::new("host").unwrap(),
            reader: ResponseReader::new(4096),
            conn: None,
        };
        assert!(slot.needs_reconnect());
    }

    /// A connection installed in a slot survives release and serves a request.
    #[tokio::test]
    async fn atomic_pool_loopback() {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        use tokio::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let (mut tcp, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 4096];
            let _ = tcp.read(&mut buf).await.unwrap();
            let resp = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok";
            tcp.write_all(resp).await.unwrap();
        });

        let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
        let conn = AsyncHttpConnection::new(MaybeTls::Plain(tcp));
        let pool = make_pool(1);
        {
            let mut slot = pool.try_acquire().unwrap();
            slot.writer = RequestWriter::new(&addr.to_string()).unwrap();
            slot.conn = Some(conn);
        }

        let mut slot = pool.try_acquire().unwrap();
        // Reborrow through DerefMut once so writer, conn, and reader can be
        // borrowed independently.
        let s: &mut AtomicClientSlot = &mut slot;
        let req = s.writer.get("/test").finish().unwrap();
        let conn = s.conn.as_mut().unwrap();
        let resp = conn.send(req, &mut s.reader).await.unwrap();
        assert_eq!(resp.status(), 200);
        assert_eq!(resp.body_str().unwrap(), "ok");
    }

    /// A single-slot pool has nothing left to hand out while its slot is held.
    #[test]
    fn atomic_try_acquire_none_when_exhausted() {
        let pool = make_pool(1);
        let _s1 = pool.try_acquire().unwrap();
        assert!(pool.try_acquire().is_none());
    }

    /// Asking for the connection of an unconnected slot is an error.
    #[test]
    fn atomic_conn_and_reader_error_when_no_conn() {
        let pool = make_pool(1);
        let mut slot = pool.try_acquire().unwrap();
        assert!(slot.conn_and_reader().is_err());
    }

    /// Exhaustion and recovery with more than one slot.
    #[test]
    fn atomic_try_acquire_none_when_all_in_use() {
        let pool = make_pool(2);
        let s1 = pool.try_acquire().unwrap();
        let s2 = pool.try_acquire().unwrap();
        assert!(pool.try_acquire().is_none());
        assert_eq!(pool.available(), 0);
        drop(s1);
        drop(s2);
        assert_eq!(pool.available(), 2);
        assert!(pool.try_acquire().is_some());
    }

    /// Releasing the only slot makes it acquirable again.
    #[test]
    fn atomic_try_acquire_returns_some_after_release() {
        let pool = make_pool(1);
        let held = pool.try_acquire().unwrap();
        assert!(pool.try_acquire().is_none());
        drop(held);
        assert!(pool.try_acquire().is_some());
    }

    /// Four sequential requests, each over a fresh connection in a pooled slot.
    #[tokio::test]
    async fn atomic_pool_four_connections() {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        use tokio::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            for _ in 0..4 {
                let (mut tcp, _) = listener.accept().await.unwrap();
                let mut buf = [0u8; 4096];
                let _ = tcp.read(&mut buf).await.unwrap();
                let resp = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok";
                tcp.write_all(resp).await.unwrap();
            }
        });

        let pool = make_pool(4);
        for _ in 0..4 {
            let mut slot = pool.try_acquire().unwrap();
            let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
            tcp.set_nodelay(true).unwrap();
            slot.conn = Some(AsyncHttpConnection::new(MaybeTls::Plain(tcp)));
            slot.writer = RequestWriter::new(&addr.to_string()).unwrap();
            let s: &mut AtomicClientSlot = &mut slot;
            let req = s.writer.get("/test").finish().unwrap();
            let conn = s.conn.as_mut().unwrap();
            let resp = conn.send(req, &mut s.reader).await.unwrap();
            assert_eq!(resp.status(), 200);
            assert_eq!(resp.body_str().unwrap(), "ok");
        }
    }

    /// `build` rejects a missing URL.
    #[tokio::test]
    async fn atomic_builder_validates_empty_url() {
        let result = AtomicClientPool::builder().connections(1).build().await;
        assert!(result.is_err());
    }

    /// `build` rejects a zero-sized pool.
    #[tokio::test]
    async fn atomic_builder_validates_zero_connections() {
        let result = AtomicClientPool::builder()
            .url("http://localhost")
            .connections(0)
            .build()
            .await;
        assert!(result.is_err());
    }
}