use super::unified_client::{ClientStats, Headers, RequestPlaneClient};
use crate::metrics::transport_metrics::{
TCP_BYTES_RECEIVED_TOTAL, TCP_BYTES_SENT_TOTAL, TCP_ERRORS_TOTAL,
};
use anyhow::Result;
use async_trait::async_trait;
use bytes::{Bytes, BytesMut};
use crossbeam_queue::SegQueue;
use dashmap::DashMap;
use futures::StreamExt;
use lru::LruCache;
use std::net::SocketAddr;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::Duration;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tokio_util::codec::FramedRead;
/// Default per-request timeout in seconds (override: `DYN_TCP_REQUEST_TIMEOUT`).
const DEFAULT_TCP_REQUEST_TIMEOUT_SECS: u64 = 5;
/// Default maximum number of connections per remote host (override: `DYN_TCP_POOL_SIZE`).
const DEFAULT_POOL_SIZE: usize = 100;
/// Default number of concurrently admitted requests per connection
/// (override: `DYN_TCP_CHANNEL_BUFFER`).
const REQUEST_CHANNEL_BUFFER: usize = 1024;
/// Wait-and-recheck rounds a caller performs while another task is connecting.
const MAX_CONNECT_RETRIES: usize = 5;
/// Default cap on concurrent connection attempts across all hosts
/// (override: `DYN_TCP_GLOBAL_CONNECT_LIMIT`).
const DEFAULT_GLOBAL_CONNECT_LIMIT: usize = 64;
/// Default idle time in seconds before an unused host pool is evicted
/// (override: `DYN_TCP_HOST_IDLE_TTL_SECS`).
const DEFAULT_HOST_IDLE_TTL_SECS: u64 = 300;
/// Default largest accepted response frame: 32 MiB (override: `DYN_TCP_MAX_MESSAGE_SIZE`).
const DEFAULT_MAX_MESSAGE_SIZE: usize = 32 << 20;
/// Busy-spin iterations the writer performs before parking on its notifier.
const WRITER_SPIN_LIMIT: u32 = 64;
/// Initial capacity of the writer's batch coalescing buffer: 256 KiB.
const WRITER_INITIAL_BUF_CAPACITY: usize = 256 << 10;
/// Returns the largest response frame, in bytes, the reader's codec will accept.
///
/// Read from `DYN_TCP_MAX_MESSAGE_SIZE`. Falls back to [`DEFAULT_MAX_MESSAGE_SIZE`] when
/// the variable is unset, is not a valid `usize`, or is zero. A zero limit would
/// presumably make the codec reject every response, so it is treated as misconfiguration.
fn get_max_message_size() -> usize {
    std::env::var("DYN_TCP_MAX_MESSAGE_SIZE")
        .ok()
        .and_then(|s| s.parse::<usize>().ok())
        .filter(|&size| size > 0)
        .unwrap_or(DEFAULT_MAX_MESSAGE_SIZE)
}
/// Reports whether per-request latency tracing is switched on.
///
/// Only the exact values `"1"` and `"true"` of `DYN_TCP_LATENCY_TRACE` enable it. An
/// unset or non-UTF-8 variable, or any other value, leaves tracing off. The variable is
/// re-read on every call.
fn latency_trace_enabled() -> bool {
    matches!(
        std::env::var("DYN_TCP_LATENCY_TRACE").as_deref(),
        Ok("1") | Ok("true")
    )
}
/// Tunables for the TCP request-plane client.
#[derive(Debug, Clone)]
pub struct TcpRequestConfig {
    /// Deadline for one request exchange on an already-acquired connection.
    pub request_timeout: Duration,
    /// Maximum number of connections kept per remote host.
    pub pool_size: usize,
    /// Timeout for establishing a single TCP connection.
    pub connect_timeout: Duration,
    /// Maximum number of requests admitted concurrently on one connection.
    pub channel_buffer: usize,
}
impl Default for TcpRequestConfig {
    /// Built-in defaults:
    /// - request timeout of [`DEFAULT_TCP_REQUEST_TIMEOUT_SECS`];
    /// - 5 s connect timeout;
    /// - [`DEFAULT_POOL_SIZE`] connections per host;
    /// - [`REQUEST_CHANNEL_BUFFER`] admitted requests per connection.
    fn default() -> Self {
        let request_timeout = Duration::from_secs(DEFAULT_TCP_REQUEST_TIMEOUT_SECS);
        let connect_timeout = Duration::from_secs(5);
        Self {
            pool_size: DEFAULT_POOL_SIZE,
            channel_buffer: REQUEST_CHANNEL_BUFFER,
            request_timeout,
            connect_timeout,
        }
    }
}
impl TcpRequestConfig {
pub fn from_env() -> Self {
let mut config = Self::default();
if let Ok(val) = std::env::var("DYN_TCP_REQUEST_TIMEOUT")
&& let Ok(timeout) = val.parse::<u64>()
{
config.request_timeout = Duration::from_secs(timeout);
}
if let Ok(val) = std::env::var("DYN_TCP_POOL_SIZE")
&& let Ok(size) = val.parse::<usize>()
{
config.pool_size = size;
}
if let Ok(val) = std::env::var("DYN_TCP_CONNECT_TIMEOUT")
&& let Ok(timeout) = val.parse::<u64>()
{
config.connect_timeout = Duration::from_secs(timeout);
}
if let Ok(val) = std::env::var("DYN_TCP_CHANNEL_BUFFER")
&& let Ok(size) = val.parse::<usize>()
{
config.channel_buffer = size;
}
config
}
}
/// A request waiting in a connection's submit queue for the writer task.
struct PendingRequest {
    /// Fully encoded request frame; the writer copies it verbatim into its send buffer.
    encoded_data: Bytes,
    /// Completed by the reader with the matching response, or with an error if the
    /// request cannot be written or the connection closes first.
    response_tx: oneshot::Sender<Result<Bytes>>,
}
/// Drop guard for one in-flight request: decrements the shared counter when dropped.
///
/// Only the decrement lives here; the matching increment is done by whoever creates
/// the guard.
struct InflightGuard(Arc<AtomicU64>);

impl Drop for InflightGuard {
    fn drop(&mut self) {
        let Self(counter) = self;
        counter.fetch_sub(1, Ordering::Release);
    }
}
/// Releases a host's "connect in progress" flag when dropped.
///
/// On drop it clears `connecting` (release ordering) and then wakes every task currently
/// waiting on `notify`. Because this runs on any drop, the flag is released however the
/// owning scope exits, including early returns and future cancellation.
struct ConnectingGuard<'a> {
    connecting: &'a AtomicBool,
    notify: &'a tokio::sync::Notify,
}

impl Drop for ConnectingGuard<'_> {
    fn drop(&mut self) {
        let Self { connecting, notify } = self;
        connecting.store(false, Ordering::Release);
        notify.notify_waiters();
    }
}
/// One multiplexed TCP connection driven by a writer task and a reader task.
///
/// Callers push requests onto `submit_queue`. The writer writes them in batches and then
/// moves each response sender onto `response_queue` in write order. The reader completes
/// those senders in FIFO order as response frames are decoded.
struct TcpConnection {
    /// Remote peer address.
    addr: SocketAddr,
    /// Requests waiting to be written by the writer task.
    submit_queue: Arc<SegQueue<PendingRequest>>,
    /// Senders of requests already written, oldest first.
    response_queue: Arc<SegQueue<oneshot::Sender<Result<Bytes>>>>,
    /// Wakes the writer when work is enqueued or the connection is being torn down.
    writer_notify: Arc<tokio::sync::Notify>,
    writer_handle: Arc<JoinHandle<()>>,
    reader_handle: Arc<JoinHandle<()>>,
    /// Cleared when either task stops or the connection is torn down.
    healthy: Arc<AtomicBool>,
    /// Set once the writer task has exited; nothing further will be written.
    closed: Arc<AtomicBool>,
    /// Requests currently admitted and not yet completed (decremented by `InflightGuard`).
    inflight: Arc<AtomicU64>,
    /// Admission limit, also used as the capacity reported by `available_capacity`.
    channel_buffer: usize,
    /// Bounds concurrent requests; closed when a task fails so waiters error out.
    admission: Arc<tokio::sync::Semaphore>,
    /// Test hook: when set, `send_request` waits on this barrier twice right after
    /// enqueueing, letting a test interleave actions before the `closed` re-check.
    #[cfg(test)]
    post_enqueue_barrier: Option<Arc<tokio::sync::Barrier>>,
}
impl TcpConnection {
    /// Opens a TCP connection to `addr` and spawns its writer and reader tasks.
    ///
    /// `timeout` bounds the TCP handshake. `channel_buffer` is the number of requests that
    /// may be admitted concurrently. The admission semaphore gets that many permits, clamped
    /// to `1..=Semaphore::MAX_PERMITS`: a zero value would otherwise park every request
    /// until its timeout, and an oversized value would panic. `available_capacity` still
    /// reports against the unclamped `channel_buffer`.
    ///
    /// # Errors
    /// Fails if the connect times out or errors, or if the basic socket options cannot be set.
    async fn connect(addr: SocketAddr, timeout: Duration, channel_buffer: usize) -> Result<Self> {
        let stream = tokio::time::timeout(timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| anyhow::anyhow!("TCP connect timeout to {}", addr))??;
        Self::configure_socket(&stream)?;
        let (read_half, write_half) = tokio::io::split(stream);
        let submit_queue = Arc::new(SegQueue::new());
        let response_queue = Arc::new(SegQueue::new());
        let writer_notify = Arc::new(tokio::sync::Notify::new());
        let healthy = Arc::new(AtomicBool::new(true));
        let closed = Arc::new(AtomicBool::new(false));
        let inflight = Arc::new(AtomicU64::new(0));
        let permits = channel_buffer.clamp(1, tokio::sync::Semaphore::MAX_PERMITS);
        let admission = Arc::new(tokio::sync::Semaphore::new(permits));
        let writer_handle = {
            let submit_q = submit_queue.clone();
            let response_q = response_queue.clone();
            let notify = writer_notify.clone();
            let healthy = healthy.clone();
            let closed = closed.clone();
            let admission = admission.clone();
            tokio::spawn(async move {
                if let Err(e) = Self::writer_task(
                    write_half,
                    submit_q,
                    response_q,
                    notify,
                    healthy.clone(),
                    closed.clone(),
                )
                .await
                {
                    tracing::debug!("Writer task failed for {}: {}", addr, e);
                    healthy.store(false, Ordering::Relaxed);
                    closed.store(true, Ordering::Release);
                    // Fail callers still waiting for (or later asking for) a permit.
                    admission.close();
                }
            })
        };
        let reader_handle = {
            let response_q = response_queue.clone();
            let healthy = healthy.clone();
            let writer_notify = writer_notify.clone();
            let admission = admission.clone();
            tokio::spawn(async move {
                if let Err(e) =
                    Self::reader_task(read_half, response_q, healthy.clone(), writer_notify).await
                {
                    tracing::debug!("Reader task failed for {}: {}", addr, e);
                    healthy.store(false, Ordering::Relaxed);
                    admission.close();
                }
            })
        };
        Ok(Self {
            addr,
            submit_queue,
            response_queue,
            writer_notify,
            writer_handle: Arc::new(writer_handle),
            reader_handle: Arc::new(reader_handle),
            healthy,
            closed,
            inflight,
            channel_buffer,
            admission,
            #[cfg(test)]
            post_enqueue_barrier: None,
        })
    }

    /// Applies socket options: `TCP_NODELAY` and requested 2 MiB kernel receive/send buffers.
    ///
    /// With the `tcp-low-latency` feature on Linux, it also sets two best-effort options:
    /// - `TCP_QUICKACK`;
    /// - `SO_BUSY_POLL` with a value of 50.
    ///
    /// A failure to set either one is logged at debug level rather than failing the
    /// connection.
    ///
    /// # Errors
    /// Fails if `TCP_NODELAY` or the buffer sizes cannot be set.
    fn configure_socket(stream: &TcpStream) -> Result<()> {
        use socket2::SockRef;
        let sock_ref = SockRef::from(stream);
        sock_ref.set_nodelay(true)?;
        sock_ref.set_recv_buffer_size(2 * 1024 * 1024)?;
        sock_ref.set_send_buffer_size(2 * 1024 * 1024)?;
        // `SO_BUSY_POLL` and `TCP_QUICKACK` are Linux-specific and `AsRawFd` is Unix-only,
        // so gate on the target OS as well as the feature.
        #[cfg(all(feature = "tcp-low-latency", target_os = "linux"))]
        {
            use std::os::unix::io::AsRawFd;
            let fd = stream.as_raw_fd();
            let set_int_opt = |level: libc::c_int, name: libc::c_int, value: libc::c_int| {
                // SAFETY: `fd` is an open socket borrowed from `stream` for the duration of
                // this call, and `value` is a live `c_int` whose exact size is passed as
                // the option length.
                let rc = unsafe {
                    libc::setsockopt(
                        fd,
                        level,
                        name,
                        &value as *const libc::c_int as *const libc::c_void,
                        std::mem::size_of::<libc::c_int>() as libc::socklen_t,
                    )
                };
                if rc == 0 {
                    Ok(())
                } else {
                    Err(std::io::Error::last_os_error())
                }
            };
            if let Err(e) = set_int_opt(libc::SOL_TCP, libc::TCP_QUICKACK, 1) {
                tracing::debug!("Failed to set TCP_QUICKACK: {}", e);
            }
            if let Err(e) = set_int_opt(libc::SOL_SOCKET, libc::SO_BUSY_POLL, 50) {
                tracing::debug!("Failed to set SO_BUSY_POLL: {}", e);
            }
            tracing::debug!("TCP low-latency optimizations enabled (TCP_QUICKACK, SO_BUSY_POLL)");
        }
        Ok(())
    }

    /// Fails every request still waiting in `submit_queue` with "Connection closed".
    fn drain_pending(submit_queue: &SegQueue<PendingRequest>) {
        while let Some(req) = submit_queue.pop() {
            let _ = req
                .response_tx
                .send(Err(anyhow::anyhow!("Connection closed")));
        }
    }

    /// Fails every sender still waiting in `response_queue` with `err_msg`.
    fn drain_response_waiters(
        response_queue: &SegQueue<oneshot::Sender<Result<Bytes>>>,
        err_msg: impl Into<String>,
    ) {
        let err_msg = err_msg.into();
        while let Some(tx) = response_queue.pop() {
            let _ = tx.send(Err(anyhow::anyhow!("{}", err_msg)));
        }
    }

    /// Writer loop: drains `submit_queue` in batches and writes each batch with one
    /// `write_all`. After each write it hands the batch's senders to the reader through
    /// `response_queue`, in write order.
    ///
    /// When idle it spins up to `WRITER_SPIN_LIMIT` times before parking on `notify`.
    /// It stops with an error once `healthy` is cleared or a write fails. On exit it:
    /// - marks the connection unhealthy;
    /// - sets `closed`;
    /// - fails every request still queued.
    async fn writer_task(
        mut write_half: tokio::io::WriteHalf<TcpStream>,
        submit_queue: Arc<SegQueue<PendingRequest>>,
        response_queue: Arc<SegQueue<oneshot::Sender<Result<Bytes>>>>,
        notify: Arc<tokio::sync::Notify>,
        healthy: Arc<AtomicBool>,
        closed: Arc<AtomicBool>,
    ) -> Result<()> {
        let mut send_buf = BytesMut::with_capacity(WRITER_INITIAL_BUF_CAPACITY);
        let mut encoded_batch: Vec<Bytes> = Vec::with_capacity(64);
        let mut response_batch: Vec<oneshot::Sender<Result<Bytes>>> = Vec::with_capacity(64);
        let trace = latency_trace_enabled();
        let mut batch_count: u64 = 0;
        let mut total_batch_size: u64 = 0;
        let mut total_batch_write_ns: u64 = 0;
        let mut last_report = std::time::Instant::now();
        let result: Result<()> = async {
            loop {
                // Busy-wait briefly for new work before parking on the notifier.
                let mut spins: u32 = 0;
                while submit_queue.is_empty() {
                    if !healthy.load(Ordering::Relaxed) {
                        return Err(anyhow::anyhow!("Reader exited, writer stopping"));
                    }
                    spins += 1;
                    if spins >= WRITER_SPIN_LIMIT {
                        notify.notified().await;
                        break;
                    }
                    std::hint::spin_loop();
                }
                encoded_batch.clear();
                response_batch.clear();
                while let Some(req) = submit_queue.pop() {
                    encoded_batch.push(req.encoded_data);
                    response_batch.push(req.response_tx);
                }
                let count = encoded_batch.len();
                if count == 0 {
                    // Woken with nothing queued; go back to waiting (health is re-checked there).
                    continue;
                }
                let write_start = if trace {
                    Some(std::time::Instant::now())
                } else {
                    None
                };
                // Coalesce the batch so it goes out in a single write.
                for data in &encoded_batch {
                    send_buf.extend_from_slice(data);
                }
                if let Err(e) = write_half.write_all(&send_buf).await {
                    send_buf.clear();
                    let err_msg = format!("Write failed: {}", e);
                    for tx in response_batch.drain(..) {
                        let _ = tx.send(Err(anyhow::anyhow!("{}", err_msg)));
                    }
                    return Err(e.into());
                }
                TCP_BYTES_SENT_TOTAL.inc_by(send_buf.len() as f64);
                send_buf.clear();
                // Publish senders only after the write so the reader's FIFO matching lines
                // up with the order on the wire.
                for tx in response_batch.drain(..) {
                    response_queue.push(tx);
                }
                if !healthy.load(Ordering::Relaxed) {
                    return Err(anyhow::anyhow!("Reader exited, writer stopping"));
                }
                if trace {
                    if let Some(start) = write_start {
                        total_batch_write_ns += start.elapsed().as_nanos() as u64;
                    }
                    batch_count += 1;
                    total_batch_size += count as u64;
                    if last_report.elapsed() >= Duration::from_secs(5) {
                        let avg_batch = if batch_count > 0 {
                            total_batch_size / batch_count
                        } else {
                            0
                        };
                        let avg_write_ns = if batch_count > 0 {
                            total_batch_write_ns / batch_count
                        } else {
                            0
                        };
                        tracing::info!(
                            batches = batch_count,
                            avg_batch_size = avg_batch,
                            avg_batch_write_ns = avg_write_ns,
                            "TCP writer instrumentation summary"
                        );
                        batch_count = 0;
                        total_batch_size = 0;
                        total_batch_write_ns = 0;
                        last_report = std::time::Instant::now();
                    }
                }
                // Release the request buffers before waiting for the next batch.
                encoded_batch.clear();
            }
        }
        .await;
        healthy.store(false, Ordering::Relaxed);
        closed.store(true, Ordering::Release);
        Self::drain_pending(&submit_queue);
        result
    }

    /// Reader loop: decodes response frames and completes the oldest waiting sender for each.
    ///
    /// A frame can arrive before the writer has published its sender, so the reader yields
    /// until one appears, giving up once the connection is unhealthy. On a decode error or
    /// EOF it:
    /// - marks the connection unhealthy;
    /// - fails all remaining waiters;
    /// - wakes the writer so it can shut down.
    async fn reader_task(
        read_half: tokio::io::ReadHalf<TcpStream>,
        response_queue: Arc<SegQueue<oneshot::Sender<Result<Bytes>>>>,
        healthy: Arc<AtomicBool>,
        writer_notify: Arc<tokio::sync::Notify>,
    ) -> Result<()> {
        use crate::pipeline::network::codec::TcpResponseCodec;
        let max_message_size = get_max_message_size();
        let codec = TcpResponseCodec::new(Some(max_message_size));
        let mut framed = FramedRead::new(read_half, codec);
        while let Some(result) = framed.next().await {
            let tx = loop {
                if let Some(tx) = response_queue.pop() {
                    break tx;
                }
                if !healthy.load(Ordering::Relaxed) {
                    return Err(anyhow::anyhow!("Connection unhealthy, reader exiting"));
                }
                tokio::task::yield_now().await;
            };
            match result {
                Ok(response_msg) => {
                    let _ = tx.send(Ok(response_msg.data));
                }
                Err(e) => {
                    let err_msg = format!("Failed to decode response: {}", e);
                    let _ = tx.send(Err(anyhow::anyhow!("{}", err_msg)));
                    healthy.store(false, Ordering::Relaxed);
                    Self::drain_response_waiters(
                        &response_queue,
                        format!("Connection closed after decode failure: {}", e),
                    );
                    writer_notify.notify_one();
                    return Err(anyhow::anyhow!("Failed to decode response"));
                }
            }
        }
        healthy.store(false, Ordering::Relaxed);
        Self::drain_response_waiters(
            &response_queue,
            "Connection closed before response was received",
        );
        writer_notify.notify_one();
        Ok(())
    }

    /// Sends one request on this connection and waits for its response.
    ///
    /// Requires an `x-endpoint-path` header. Waits for an admission permit, which is held
    /// until this call returns. It then enqueues the encoded frame for the writer and awaits
    /// the reader's reply.
    ///
    /// # Errors
    /// Fails fast if the connection is unhealthy or closed, or if the header is missing.
    /// Otherwise it returns any encode error, admission-closed or connection-closed error,
    /// or the error delivered by the I/O tasks.
    async fn send_request(&self, payload: Bytes, headers: &Headers) -> Result<Bytes> {
        use crate::pipeline::network::codec::TcpRequestMessage;
        if !self.healthy.load(Ordering::Relaxed) {
            anyhow::bail!("Connection unhealthy (tasks failed)");
        }
        if self.closed.load(Ordering::Acquire) {
            anyhow::bail!("Connection closed (writer exited)");
        }
        let endpoint_path = headers
            .get("x-endpoint-path")
            .ok_or_else(|| anyhow::anyhow!("Missing x-endpoint-path header for TCP request"))?
            .to_string();
        let trace = latency_trace_enabled();
        let e2e_start = if trace {
            Some(std::time::Instant::now())
        } else {
            None
        };
        let _permit = self
            .admission
            .acquire()
            .await
            .map_err(|_| anyhow::anyhow!("Connection closed (admission gate shut)"))?;
        let request_msg = TcpRequestMessage::with_headers(endpoint_path, headers.clone(), payload);
        let encoded_data = request_msg.encode()?;
        let (response_tx, response_rx) = oneshot::channel();
        if !self.healthy.load(Ordering::Relaxed) {
            anyhow::bail!("Connection unhealthy (tasks failed)");
        }
        if self.closed.load(Ordering::Acquire) {
            anyhow::bail!("Connection closed (writer exited)");
        }
        self.inflight.fetch_add(1, Ordering::Release);
        let _inflight_guard = InflightGuard(self.inflight.clone());
        self.submit_queue.push(PendingRequest {
            encoded_data,
            response_tx,
        });
        #[cfg(test)]
        if let Some(barrier) = &self.post_enqueue_barrier {
            barrier.wait().await;
            barrier.wait().await;
        }
        // The writer sets `closed` before its final drain of `submit_queue`. Re-checking
        // after the push ensures a request enqueued concurrently with shutdown is failed
        // rather than stranded.
        if self.closed.load(Ordering::Acquire) {
            Self::drain_pending(&self.submit_queue);
        } else {
            self.writer_notify.notify_one();
            if self.closed.load(Ordering::Acquire) {
                Self::drain_pending(&self.submit_queue);
            }
        }
        let result = response_rx
            .await
            .map_err(|_| anyhow::anyhow!("Reader task closed"))?;
        if trace && let Some(start) = e2e_start {
            let e2e_ns = start.elapsed().as_nanos() as u64;
            tracing::trace!(e2e_ns = e2e_ns, "TCP request e2e latency");
        }
        result
    }

    /// Whether both I/O tasks are still considered alive.
    fn is_healthy(&self) -> bool {
        self.healthy.load(Ordering::Relaxed)
    }

    /// Remaining admission headroom: `channel_buffer` minus in-flight requests (saturating).
    fn available_capacity(&self) -> usize {
        let inflight = self.inflight.load(Ordering::Acquire) as usize;
        self.channel_buffer.saturating_sub(inflight)
    }
}
/// Connection pool for a single remote host.
///
/// The authoritative set of connections lives in `lru` behind a mutex. `snapshot` is a
/// copy of it that request paths read without taking the lock.
struct HostPool {
    /// Copy of the connections in `lru`, replaced after each mutation.
    snapshot: arc_swap::ArcSwap<Vec<Arc<TcpConnection>>>,
    /// Connections keyed by an id taken from `next_id`.
    lru: parking_lot::Mutex<LruCache<u64, Arc<TcpConnection>>>,
    /// Round-robin cursor used to choose a starting index into `snapshot`.
    counter: AtomicU64,
    /// Source of ids for new `lru` entries (incremented per connection).
    next_id: AtomicU64,
    /// True while one task is establishing a new connection for this host.
    connecting: AtomicBool,
    /// Wakes tasks waiting for an in-progress connect to finish.
    connect_notify: tokio::sync::Notify,
    /// Growth limit for healthy connections (taken from `pool_size`).
    max_connections: usize,
    addr: SocketAddr,
    connect_timeout: Duration,
    channel_buffer: usize,
    /// Wall-clock ms since the Unix epoch of the last lookup; used for idle eviction.
    last_used_ms: AtomicU64,
}
impl HostPool {
fn new(addr: SocketAddr, config: &TcpRequestConfig) -> Self {
let cap = NonZeroUsize::new(config.pool_size).unwrap_or(NonZeroUsize::new(1).unwrap());
Self {
snapshot: arc_swap::ArcSwap::from_pointee(Vec::new()),
lru: parking_lot::Mutex::new(LruCache::new(cap)),
counter: AtomicU64::new(0),
next_id: AtomicU64::new(0),
connecting: AtomicBool::new(false),
connect_notify: tokio::sync::Notify::new(),
max_connections: config.pool_size,
addr,
connect_timeout: config.connect_timeout,
channel_buffer: config.channel_buffer,
last_used_ms: AtomicU64::new(current_time_ms()),
}
}
async fn get_connection(
&self,
connect_limiter: &tokio::sync::Semaphore,
) -> Result<Arc<TcpConnection>> {
{
let guard = self.snapshot.load();
let conns = &**guard;
let len = conns.len();
if len > 0 {
let start = self.counter.fetch_add(1, Ordering::Relaxed) as usize;
for i in 0..len {
let idx = (start + i) % len;
let conn = &conns[idx];
if conn.is_healthy() && conn.available_capacity() > 0 {
return Ok(conn.clone());
}
}
if len >= self.max_connections {
for i in 0..len {
let idx = (start + i) % len;
if conns[idx].is_healthy() {
return Ok(conns[idx].clone());
}
}
}
}
}
self.ensure_capacity_or_heal(connect_limiter).await
}
fn should_grow(healthy: &[Arc<TcpConnection>], max_connections: usize) -> bool {
if healthy.is_empty() {
return true;
}
if healthy.len() >= max_connections {
return false;
}
healthy.iter().all(|c| c.available_capacity() == 0)
}
async fn ensure_capacity_or_heal(
&self,
connect_limiter: &tokio::sync::Semaphore,
) -> Result<Arc<TcpConnection>> {
let (need_connect, new_snap) = {
let mut lru = self.lru.lock();
let dead: Vec<u64> = lru
.iter()
.filter(|(_, c)| !c.is_healthy())
.map(|(&k, _)| k)
.collect();
for k in dead {
lru.pop(&k);
}
let snap: Vec<Arc<TcpConnection>> = lru.iter().map(|(_, c)| c.clone()).collect();
let grow = Self::should_grow(&snap, self.max_connections);
(grow, snap)
};
self.snapshot.store(Arc::new(new_snap.clone()));
{
let guard = self.snapshot.load();
let conns = &**guard;
if !conns.is_empty() {
let start = self.counter.fetch_add(1, Ordering::Relaxed) as usize;
for i in 0..conns.len() {
let idx = (start + i) % conns.len();
let conn = &conns[idx];
if conn.is_healthy() && (!need_connect || conn.available_capacity() > 0) {
return Ok(conn.clone());
}
}
}
}
if !need_connect {
anyhow::bail!(
"No healthy TCP connection to {} and pool at capacity ({})",
self.addr,
self.max_connections
);
}
for retry in 0..MAX_CONNECT_RETRIES {
if self
.connecting
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
let _connecting_guard = ConnectingGuard {
connecting: &self.connecting,
notify: &self.connect_notify,
};
let _permit = connect_limiter
.acquire()
.await
.map_err(|_| anyhow::anyhow!("Global connect limiter closed"))?;
let connect_result =
TcpConnection::connect(self.addr, self.connect_timeout, self.channel_buffer)
.await;
self.connecting.store(false, Ordering::Release);
match connect_result {
Ok(stream) => {
let new_conn = Arc::new(stream);
{
let mut lru = self.lru.lock();
let dead: Vec<u64> = lru
.iter()
.filter(|(_, c)| !c.is_healthy())
.map(|(&k, _)| k)
.collect();
for k in dead {
lru.pop(&k);
}
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
lru.put(id, new_conn.clone());
let snap: Vec<Arc<TcpConnection>> =
lru.iter().map(|(_, c)| c.clone()).collect();
drop(lru);
self.snapshot.store(Arc::new(snap));
}
self.connect_notify.notify_waiters();
return Ok(new_conn);
}
Err(e) => {
self.connect_notify.notify_waiters();
return Err(e);
}
}
}
let _ =
tokio::time::timeout(self.connect_timeout, self.connect_notify.notified()).await;
let guard = self.snapshot.load();
let conns = &**guard;
let len = conns.len();
if len > 0 {
let start = self.counter.fetch_add(1, Ordering::Relaxed) as usize;
for i in 0..len {
let idx = (start + i) % len;
if conns[idx].is_healthy() && conns[idx].available_capacity() > 0 {
return Ok(conns[idx].clone());
}
}
if len >= self.max_connections {
for i in 0..len {
let idx = (start + i) % len;
if conns[idx].is_healthy() {
return Ok(conns[idx].clone());
}
}
}
}
drop(guard);
tracing::trace!(
"TCP pool connect retry {}/{} for {}",
retry + 1,
MAX_CONNECT_RETRIES,
self.addr
);
}
{
let guard = self.snapshot.load();
let conns = &**guard;
let len = conns.len();
if len > 0 {
let start = self.counter.fetch_add(1, Ordering::Relaxed) as usize;
for i in 0..len {
let idx = (start + i) % len;
if conns[idx].is_healthy() {
return Ok(conns[idx].clone());
}
}
}
}
anyhow::bail!(
"No healthy TCP connection to {} after {} connect retries",
self.addr,
MAX_CONNECT_RETRIES
)
}
}
/// Wall-clock milliseconds since the Unix epoch.
///
/// A system clock set before the epoch yields 0. The `u128` millisecond count is
/// truncated to `u64` with `as`.
fn current_time_ms() -> u64 {
    let since_epoch = std::time::UNIX_EPOCH.elapsed().unwrap_or_default();
    since_epoch.as_millis() as u64
}
/// Client-wide pool: one `HostPool` per remote address plus limits shared by all hosts.
struct TcpConnectionPool {
    /// Per-address pools, created on first use.
    hosts: DashMap<SocketAddr, Arc<HostPool>>,
    /// Configuration passed to every `HostPool` created here.
    config: TcpRequestConfig,
    /// Caps concurrent connection attempts across all hosts.
    connect_limiter: Arc<tokio::sync::Semaphore>,
    /// Idle time after which an unused host pool becomes eligible for removal.
    host_idle_ttl_ms: u64,
}
impl TcpConnectionPool {
fn new(config: TcpRequestConfig) -> Self {
let global_limit = std::env::var("DYN_TCP_GLOBAL_CONNECT_LIMIT")
.ok()
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(DEFAULT_GLOBAL_CONNECT_LIMIT);
let host_idle_ttl_secs = std::env::var("DYN_TCP_HOST_IDLE_TTL_SECS")
.ok()
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(DEFAULT_HOST_IDLE_TTL_SECS);
Self {
hosts: DashMap::new(),
config,
connect_limiter: Arc::new(tokio::sync::Semaphore::new(global_limit)),
host_idle_ttl_ms: host_idle_ttl_secs * 1000,
}
}
async fn get_connection(&self, addr: SocketAddr) -> Result<Arc<TcpConnection>> {
if let Some(host) = self.hosts.get(&addr).map(|entry| Arc::clone(&*entry)) {
host.last_used_ms
.store(current_time_ms(), Ordering::Relaxed);
return host.get_connection(&self.connect_limiter).await;
}
let host = self
.hosts
.entry(addr)
.or_insert_with(|| Arc::new(HostPool::new(addr, &self.config)))
.clone();
host.last_used_ms
.store(current_time_ms(), Ordering::Relaxed);
host.get_connection(&self.connect_limiter).await
}
async fn warmup(&self, addr: SocketAddr) {
let host = self
.hosts
.entry(addr)
.or_insert_with(|| Arc::new(HostPool::new(addr, &self.config)))
.clone();
host.last_used_ms
.store(current_time_ms(), Ordering::Relaxed);
match host.get_connection(&self.connect_limiter).await {
Ok(_) => tracing::debug!("TCP warmup: pre-connected to {}", addr),
Err(e) => tracing::warn!("TCP warmup: failed to pre-connect to {}: {}", addr, e),
}
}
fn start_warmup_watcher(
self: &Arc<Self>,
mut instance_rx: tokio::sync::watch::Receiver<Vec<crate::component::Instance>>,
cancel_token: tokio_util::sync::CancellationToken,
) {
let pool = Arc::clone(self);
tokio::spawn(async move {
let mut known_addrs = std::collections::HashSet::<SocketAddr>::new();
{
let instances = instance_rx.borrow_and_update();
for inst in instances.iter() {
if let crate::component::TransportType::Tcp(ref addr_str) = inst.transport
&& let Ok((sock, _)) = TcpRequestClient::parse_address(addr_str)
{
known_addrs.insert(sock);
}
}
}
loop {
tokio::select! {
_ = cancel_token.cancelled() => {
tracing::debug!("TCP warmup watcher cancelled");
break;
}
result = instance_rx.changed() => {
if result.is_err() {
tracing::debug!("TCP warmup watcher: instance channel closed");
break;
}
let instances = instance_rx.borrow_and_update().clone();
let mut current_addrs = std::collections::HashSet::<SocketAddr>::new();
for inst in &instances {
if let crate::component::TransportType::Tcp(ref addr_str) = inst.transport
&& let Ok((sock, _)) = TcpRequestClient::parse_address(addr_str)
{
current_addrs.insert(sock);
if !known_addrs.contains(&sock) {
let pool = Arc::clone(&pool);
tokio::spawn(async move {
pool.warmup(sock).await;
});
}
}
}
known_addrs = current_addrs;
}
}
}
});
}
fn cleanup_idle_hosts(&self) {
let now = current_time_ms();
let ttl = self.host_idle_ttl_ms;
let stale: Vec<SocketAddr> = self
.hosts
.iter()
.filter(|entry| {
let last = entry.value().last_used_ms.load(Ordering::Relaxed);
if now.saturating_sub(last) <= ttl {
return false;
}
let snap = entry.value().snapshot.load();
!snap.iter().any(|c| c.inflight.load(Ordering::Acquire) > 0)
})
.map(|entry| *entry.key())
.collect();
for addr in stale {
tracing::debug!("Removing idle TCP host pool for {}", addr);
if let Some((_, host)) = self.hosts.remove(&addr) {
let snap = host.snapshot.load();
for conn in snap.iter() {
conn.healthy.store(false, Ordering::Relaxed);
conn.admission.close();
conn.writer_notify.notify_one();
conn.reader_handle.abort();
}
}
}
}
fn connection_counts(&self) -> (usize, usize) {
let mut active = 0usize;
let mut idle = 0usize;
for entry in self.hosts.iter() {
let snap = entry.value().snapshot.load();
for conn in snap.iter() {
if conn.is_healthy() {
if conn.inflight.load(Ordering::Acquire) > 0 {
active += 1;
} else {
idle += 1;
}
}
}
}
(active, idle)
}
fn start_idle_cleanup(self: &Arc<Self>) {
let Ok(handle) = tokio::runtime::Handle::try_current() else {
tracing::debug!("No tokio runtime available; idle host cleanup disabled");
return;
};
let pool = Arc::downgrade(self);
let interval = Duration::from_millis((self.host_idle_ttl_ms / 2).max(30_000));
handle.spawn(async move {
loop {
tokio::time::sleep(interval).await;
match pool.upgrade() {
Some(pool) => pool.cleanup_idle_hosts(),
None => break,
}
}
});
}
}
/// TCP implementation of `RequestPlaneClient`, backed by per-host pools of multiplexed
/// connections.
pub struct TcpRequestClient {
    /// Shared per-host connection pools.
    pool: Arc<TcpConnectionPool>,
    /// Configuration the client was built with; `request_timeout` is applied per request.
    config: TcpRequestConfig,
    /// Counters reported through `stats()`.
    stats: Arc<TcpClientStats>,
}
/// Cumulative counters for a `TcpRequestClient`.
struct TcpClientStats {
    /// Requests attempted (counted before the address is parsed).
    requests_sent: AtomicU64,
    /// Requests that returned a response.
    responses_received: AtomicU64,
    /// Errors recorded by `send_request`.
    errors: AtomicU64,
    /// Sum of request payload lengths (frame headers not included).
    bytes_sent: AtomicU64,
    /// Sum of response payload lengths.
    bytes_received: AtomicU64,
}
impl TcpRequestClient {
    /// Creates a client using `TcpRequestConfig::default()`.
    pub fn new() -> Result<Self> {
        let config = TcpRequestConfig::default();
        Self::with_config(config)
    }

    /// Creates a client with an explicit configuration.
    ///
    /// Also starts the pool's idle-host cleanup when a Tokio runtime is current. Never fails
    /// at present; the `Result` is part of the public signature.
    pub fn with_config(config: TcpRequestConfig) -> Result<Self> {
        let pool = Arc::new(TcpConnectionPool::new(config.clone()));
        pool.start_idle_cleanup();
        let zero = || AtomicU64::new(0);
        let stats = Arc::new(TcpClientStats {
            requests_sent: zero(),
            responses_received: zero(),
            errors: zero(),
            bytes_sent: zero(),
            bytes_received: zero(),
        });
        Ok(Self {
            pool,
            config,
            stats,
        })
    }

    /// Creates a client configured from `DYN_TCP_*` environment variables.
    pub fn from_env() -> Result<Self> {
        let config = TcpRequestConfig::from_env();
        Self::with_config(config)
    }

    /// Starts background pre-connection to TCP instances published on `instance_rx`,
    /// running until `cancel_token` fires.
    pub fn start_warmup(
        &self,
        instance_rx: tokio::sync::watch::Receiver<Vec<crate::component::Instance>>,
        cancel_token: tokio_util::sync::CancellationToken,
    ) {
        let pool = &self.pool;
        pool.start_warmup_watcher(instance_rx, cancel_token);
    }

    /// Parses `[tcp://]ip:port[/endpoint]` into a socket address and an optional endpoint.
    ///
    /// Everything after the first `/` that follows the optional scheme is the endpoint name.
    pub(crate) fn parse_address(address: &str) -> Result<(SocketAddr, Option<String>)> {
        let without_scheme = address.strip_prefix("tcp://").unwrap_or(address);
        let (socket_part, endpoint) = match without_scheme.split_once('/') {
            Some((socket, name)) => (socket, Some(name.to_string())),
            None => (without_scheme, None),
        };
        let socket_addr = socket_part
            .parse::<SocketAddr>()
            .map_err(|e| anyhow::anyhow!("Invalid TCP address '{}': {}", address, e))?;
        Ok((socket_addr, endpoint))
    }
}
impl Default for TcpRequestClient {
fn default() -> Self {
Self::new().expect("Failed to create TCP request client")
}
}
#[async_trait]
impl RequestPlaneClient for TcpRequestClient {
async fn send_request(
&self,
address: String,
payload: Bytes,
mut headers: Headers,
) -> Result<Bytes> {
tracing::debug!("TCP client sending request to address: {}", address);
self.stats.requests_sent.fetch_add(1, Ordering::Relaxed);
self.stats
.bytes_sent
.fetch_add(payload.len() as u64, Ordering::Relaxed);
let (addr, endpoint_name) = Self::parse_address(&address)?;
if let Some(endpoint_name) = endpoint_name {
headers.insert("x-endpoint-path".to_string(), endpoint_name.clone());
}
let conn = self.pool.get_connection(addr).await?;
let result = tokio::time::timeout(
self.config.request_timeout,
conn.send_request(payload, &headers),
)
.await;
match result {
Ok(Ok(response)) => {
self.stats
.responses_received
.fetch_add(1, Ordering::Relaxed);
self.stats
.bytes_received
.fetch_add(response.len() as u64, Ordering::Relaxed);
TCP_BYTES_RECEIVED_TOTAL.inc_by(response.len() as f64);
Ok(response)
}
Ok(Err(e)) => {
self.stats.errors.fetch_add(1, Ordering::Relaxed);
TCP_ERRORS_TOTAL.inc();
tracing::warn!("TCP request failed to {}: {}", addr, e);
let cause = crate::error::DynamoError::from(
e.into_boxed_dyn_error() as Box<dyn std::error::Error + 'static>
);
Err(anyhow::anyhow!(
crate::error::DynamoError::builder()
.error_type(crate::error::ErrorType::CannotConnect)
.message(format!("TCP request to {addr} failed"))
.cause(cause)
.build()
))
}
Err(_) => {
self.stats.errors.fetch_add(1, Ordering::Relaxed);
TCP_ERRORS_TOTAL.inc();
tracing::warn!("TCP request timeout to {}", addr);
Err(anyhow::anyhow!(
crate::error::DynamoError::builder()
.error_type(crate::error::ErrorType::CannotConnect)
.message(format!("TCP request to {addr} timed out"))
.build()
))
}
}
}
fn transport_name(&self) -> &'static str {
"tcp"
}
fn is_healthy(&self) -> bool {
true }
fn start_warmup(
&self,
instance_rx: tokio::sync::watch::Receiver<Vec<crate::component::Instance>>,
cancel_token: tokio_util::sync::CancellationToken,
) {
TcpRequestClient::start_warmup(self, instance_rx, cancel_token);
}
fn stats(&self) -> ClientStats {
let (active, idle) = self.pool.connection_counts();
ClientStats {
requests_sent: self.stats.requests_sent.load(Ordering::Relaxed),
responses_received: self.stats.responses_received.load(Ordering::Relaxed),
errors: self.stats.errors.load(Ordering::Relaxed),
bytes_sent: self.stats.bytes_sent.load(Ordering::Relaxed),
bytes_received: self.stats.bytes_received.load(Ordering::Relaxed),
active_connections: active,
idle_connections: idle,
avg_latency_us: 0,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::AtomicUsize;
use tokio::io::AsyncReadExt;
use tokio::net::TcpListener;
#[test]
fn test_tcp_config_default() {
let config = TcpRequestConfig::default();
assert_eq!(config.pool_size, DEFAULT_POOL_SIZE);
assert_eq!(
config.request_timeout,
Duration::from_secs(DEFAULT_TCP_REQUEST_TIMEOUT_SECS)
);
assert_eq!(config.channel_buffer, REQUEST_CHANNEL_BUFFER);
}
#[test]
fn test_tcp_config_from_env() {
unsafe {
std::env::set_var("DYN_TCP_REQUEST_TIMEOUT", "10");
std::env::set_var("DYN_TCP_POOL_SIZE", "50");
std::env::set_var("DYN_TCP_CONNECT_TIMEOUT", "3");
std::env::set_var("DYN_TCP_CHANNEL_BUFFER", "100");
}
let config = TcpRequestConfig::from_env();
assert_eq!(config.request_timeout, Duration::from_secs(10));
assert_eq!(config.pool_size, 50);
assert_eq!(config.connect_timeout, Duration::from_secs(3));
assert_eq!(config.channel_buffer, 100);
unsafe {
std::env::remove_var("DYN_TCP_REQUEST_TIMEOUT");
std::env::remove_var("DYN_TCP_POOL_SIZE");
std::env::remove_var("DYN_TCP_CONNECT_TIMEOUT");
std::env::remove_var("DYN_TCP_CHANNEL_BUFFER");
}
}
#[test]
fn test_parse_address() {
let (addr1, _) = TcpRequestClient::parse_address("127.0.0.1:8080").unwrap();
assert_eq!(addr1.port(), 8080);
let (addr2, _) = TcpRequestClient::parse_address("tcp://127.0.0.1:9090").unwrap();
assert_eq!(addr2.port(), 9090);
let (addr3, endpoint) =
TcpRequestClient::parse_address("127.0.0.1:8080/test_endpoint").unwrap();
assert_eq!(addr3.port(), 8080);
assert_eq!(endpoint, Some("test_endpoint".to_string()));
assert!(TcpRequestClient::parse_address("invalid").is_err());
}
#[test]
fn test_tcp_client_creation() {
let client = TcpRequestClient::new();
assert!(client.is_ok());
let client = client.unwrap();
assert_eq!(client.transport_name(), "tcp");
assert!(client.is_healthy());
}
async fn spawn_echo_server() -> (SocketAddr, Arc<AtomicUsize>) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let conn_count = Arc::new(AtomicUsize::new(0));
let conn_count_clone = conn_count.clone();
tokio::spawn(async move {
loop {
let result = listener.accept().await;
if result.is_err() {
break;
}
let (stream, _) = result.unwrap();
conn_count_clone.fetch_add(1, Ordering::SeqCst);
tokio::spawn(async move {
let (mut read_half, mut write_half) = tokio::io::split(stream);
loop {
let mut len_buf = [0u8; 2];
if read_half.read_exact(&mut len_buf).await.is_err() {
break;
}
let path_len = u16::from_be_bytes(len_buf) as usize;
let mut path_buf = vec![0u8; path_len];
if read_half.read_exact(&mut path_buf).await.is_err() {
break;
}
let mut headers_len_buf = [0u8; 2];
if read_half.read_exact(&mut headers_len_buf).await.is_err() {
break;
}
let headers_len = u16::from_be_bytes(headers_len_buf) as usize;
let mut headers_buf = vec![0u8; headers_len];
if read_half.read_exact(&mut headers_buf).await.is_err() {
break;
}
let mut len_buf = [0u8; 4];
if read_half.read_exact(&mut len_buf).await.is_err() {
break;
}
let payload_len = u32::from_be_bytes(len_buf) as usize;
let mut payload_buf = vec![0u8; payload_len];
if read_half.read_exact(&mut payload_buf).await.is_err() {
break;
}
use crate::pipeline::network::codec::TcpResponseMessage;
let response = TcpResponseMessage::new(Bytes::from(payload_buf));
let encoded = response.encode().unwrap();
if write_half.write_all(&encoded).await.is_err() {
break;
}
}
});
}
});
(addr, conn_count)
}
#[tokio::test]
async fn test_connection_health_check() {
use crate::pipeline::network::codec::TcpResponseMessage;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let (mut read_half, mut write_half) = tokio::io::split(stream);
let mut len_buf = [0u8; 2];
read_half.read_exact(&mut len_buf).await.unwrap();
let path_len = u16::from_be_bytes(len_buf) as usize;
let mut path_buf = vec![0u8; path_len];
read_half.read_exact(&mut path_buf).await.unwrap();
let mut headers_len_buf = [0u8; 2];
read_half.read_exact(&mut headers_len_buf).await.unwrap();
let headers_len = u16::from_be_bytes(headers_len_buf) as usize;
let mut headers_buf = vec![0u8; headers_len];
read_half.read_exact(&mut headers_buf).await.unwrap();
let mut len_buf = [0u8; 4];
read_half.read_exact(&mut len_buf).await.unwrap();
let payload_len = u32::from_be_bytes(len_buf) as usize;
let mut payload_buf = vec![0u8; payload_len];
read_half.read_exact(&mut payload_buf).await.unwrap();
let response = TcpResponseMessage::new(Bytes::from_static(b"pong"));
let encoded = response.encode().unwrap();
write_half.write_all(&encoded).await.unwrap();
});
let conn = TcpConnection::connect(addr, Duration::from_secs(5), 10)
.await
.unwrap();
assert!(conn.is_healthy(), "New connection should be healthy");
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
let result = conn.send_request(Bytes::from("ping"), &headers).await;
assert!(result.is_ok(), "Request should succeed");
assert_eq!(result.unwrap(), Bytes::from("pong"));
}
/// Several concurrent `send_request` calls multiplexed over a single
/// `TcpConnection` must all succeed, and every caller must receive the
/// response to *its own* request (the test server echoes payloads, so a
/// mismatch means responses were routed to the wrong waiter).
#[tokio::test]
async fn test_concurrent_requests_single_connection() {
    use crate::pipeline::network::codec::TcpResponseMessage;
    const REQUESTS: usize = 5;
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    let request_count = Arc::new(AtomicUsize::new(0));
    let request_count_clone = request_count.clone();
    // Echo server on one accepted socket: reads up to REQUESTS request frames
    // (u16 BE path len + path, u16 BE headers len + headers, u32 BE payload
    // len + payload) and answers each with its payload unchanged.
    tokio::spawn(async move {
        let (stream, _) = listener.accept().await.unwrap();
        let (mut read_half, mut write_half) = tokio::io::split(stream);
        for _ in 0..REQUESTS {
            let mut len_buf = [0u8; 2];
            if read_half.read_exact(&mut len_buf).await.is_err() {
                break;
            }
            let path_len = u16::from_be_bytes(len_buf) as usize;
            let mut path_buf = vec![0u8; path_len];
            if read_half.read_exact(&mut path_buf).await.is_err() {
                break;
            }
            let mut headers_len_buf = [0u8; 2];
            if read_half.read_exact(&mut headers_len_buf).await.is_err() {
                break;
            }
            let headers_len = u16::from_be_bytes(headers_len_buf) as usize;
            let mut headers_buf = vec![0u8; headers_len];
            if read_half.read_exact(&mut headers_buf).await.is_err() {
                break;
            }
            let mut len_buf = [0u8; 4];
            if read_half.read_exact(&mut len_buf).await.is_err() {
                break;
            }
            let payload_len = u32::from_be_bytes(len_buf) as usize;
            let mut payload_buf = vec![0u8; payload_len];
            if read_half.read_exact(&mut payload_buf).await.is_err() {
                break;
            }
            request_count_clone.fetch_add(1, Ordering::SeqCst);
            let response = TcpResponseMessage::new(Bytes::from(payload_buf));
            let encoded = response.encode().unwrap();
            if write_half.write_all(&encoded).await.is_err() {
                break;
            }
        }
    });
    let conn = Arc::new(
        TcpConnection::connect(addr, Duration::from_secs(5), 10)
            .await
            .unwrap(),
    );
    let mut handles = Vec::with_capacity(REQUESTS);
    for i in 0..REQUESTS {
        let conn = conn.clone();
        let handle = tokio::spawn(async move {
            let mut headers = Headers::new();
            headers.insert("x-endpoint-path".to_string(), "test".to_string());
            let payload = format!("request_{}", i);
            conn.send_request(Bytes::from(payload.clone()), &headers)
                .await
                .map(|response| (payload, response))
        });
        handles.push(handle);
    }
    let mut results = Vec::with_capacity(REQUESTS);
    for handle in handles {
        let result = handle.await.unwrap();
        assert!(result.is_ok(), "Request should succeed");
        results.push(result.unwrap());
    }
    assert_eq!(results.len(), REQUESTS);
    assert_eq!(
        request_count.load(Ordering::SeqCst),
        REQUESTS,
        "Server should have received 5 requests"
    );
    // The server echoes payloads verbatim, so each response must equal the
    // payload its own caller sent.
    for (payload, response) in results {
        assert_eq!(
            response,
            Bytes::from(payload),
            "Each caller should receive the echo of its own payload"
        );
    }
}
/// Two back-to-back checkouts from the pool for the same address must be
/// served by one underlying TCP socket.
#[tokio::test]
async fn test_lru_connection_reuse() {
    let (addr, accepted) = spawn_echo_server().await;
    let pool = TcpConnectionPool::new(TcpRequestConfig {
        request_timeout: Duration::from_secs(5),
        connect_timeout: Duration::from_secs(5),
        pool_size: 4,
        channel_buffer: 10,
    });
    let mut headers = Headers::new();
    headers.insert("x-endpoint-path".to_string(), "test".to_string());
    // Check out, use, and release a connection once per body.
    for body in ["ping1", "ping2"] {
        let conn = pool.get_connection(addr).await.unwrap();
        conn.send_request(Bytes::from(body), &headers).await.unwrap();
        drop(conn);
    }
    assert_eq!(
        accepted.load(Ordering::SeqCst),
        1,
        "Should reuse connection from pool (1 TCP connection total)"
    );
}
/// When the only pooled connection is flagged unhealthy, the pool hands out a
/// fresh healthy connection, and the old handle keeps reporting unhealthy.
#[tokio::test]
async fn test_lru_eviction_keeps_inflight_alive() {
    let (addr, _accepted) = spawn_echo_server().await;
    let pool = TcpConnectionPool::new(TcpRequestConfig {
        request_timeout: Duration::from_secs(5),
        connect_timeout: Duration::from_secs(5),
        pool_size: 1,
        channel_buffer: 10,
    });
    let evicted = pool.get_connection(addr).await.unwrap();
    // Poison the single pooled connection so the next lookup must replace it.
    evicted.healthy.store(false, Ordering::Relaxed);
    let replacement = pool.get_connection(addr).await.unwrap();
    assert!(replacement.is_healthy());
    assert!(!evicted.is_healthy());
    // NOTE(review): `evicted` itself still holds a clone of `writer_handle`,
    // so this count is always >= 1; on its own it does not show that the
    // evicted connection's writer task is still running.
    assert!(Arc::strong_count(&evicted.writer_handle) >= 1);
}
#[tokio::test]
async fn test_high_concurrency_bounded_connections() {
let (addr, conn_count) = spawn_echo_server().await;
let config = TcpRequestConfig {
request_timeout: Duration::from_secs(10),
connect_timeout: Duration::from_secs(5),
pool_size: 2,
channel_buffer: 50,
};
let pool = Arc::new(TcpConnectionPool::new(config));
let mut handles = vec![];
for i in 0..500 {
let pool = pool.clone();
handles.push(tokio::spawn(async move {
let conn = pool.get_connection(addr).await?;
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
conn.send_request(Bytes::from(format!("req_{}", i)), &headers)
.await
}));
}
let mut ok_count = 0;
for handle in handles {
if handle.await.unwrap().is_ok() {
ok_count += 1;
}
}
let total_conns = conn_count.load(Ordering::SeqCst);
assert!(
total_conns <= 2,
"Should create at most pool_size (2) connections, got {}",
total_conns
);
assert!(ok_count > 0, "At least some requests should succeed");
}
/// 100 tasks hitting a cold pool at once must not open more than `pool_size`
/// (4) connections.
///
/// Successes are counted as well: otherwise a run in which every request
/// failed (and no connection was ever opened) would satisfy the cap vacuously.
#[tokio::test]
async fn test_thundering_herd_cold_start() {
    let (addr, conn_count) = spawn_echo_server().await;
    let config = TcpRequestConfig {
        request_timeout: Duration::from_secs(10),
        connect_timeout: Duration::from_secs(5),
        pool_size: 4,
        channel_buffer: 50,
    };
    let pool = Arc::new(TcpConnectionPool::new(config));
    let mut handles = Vec::with_capacity(100);
    for i in 0..100 {
        let pool = pool.clone();
        handles.push(tokio::spawn(async move {
            let conn = pool.get_connection(addr).await?;
            let mut headers = Headers::new();
            headers.insert("x-endpoint-path".to_string(), "test".to_string());
            conn.send_request(Bytes::from(format!("req_{}", i)), &headers)
                .await
        }));
    }
    let mut ok_count = 0;
    for handle in handles {
        if handle.await.unwrap().is_ok() {
            ok_count += 1;
        }
    }
    let total_conns = conn_count.load(Ordering::SeqCst);
    assert!(
        total_conns <= 4,
        "Thundering herd: should create at most pool_size (4) connections, got {}",
        total_conns
    );
    assert!(
        ok_count > 0,
        "Thundering herd: at least some requests should succeed"
    );
}
/// Pool behaviour across a server "crash" and restart on the same address.
///
/// Phase 1: an echo server that can be stopped via a `CancellationToken`
/// serves one request through the pool.
/// Phase 2: the server is cancelled. The pool is asked for a connection
/// again; either that fails, or the request on it fails, or the connection
/// reports unhealthy.
/// Phase 3: a fresh listener is bound to the same address, and the pool must
/// hand out a working connection again.
#[tokio::test]
async fn test_server_crash_recovery() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let cancel = tokio_util::sync::CancellationToken::new();
let cancel_clone = cancel.clone();
// Phase 1 server: an accept loop plus one task per accepted connection.
// Both stop when `cancel` fires.
let server_task = tokio::spawn(async move {
loop {
tokio::select! {
result = listener.accept() => {
if let Ok((stream, _)) = result {
let cancel = cancel_clone.clone();
tokio::spawn(async move {
let (mut read_half, mut write_half) = tokio::io::split(stream);
loop {
tokio::select! {
_ = cancel.cancelled() => break,
// Read one request frame (u16 BE path len + path, u16 BE
// headers len + headers, u32 BE payload len + payload)
// and echo the payload back.
result = async {
let mut len_buf = [0u8; 2];
read_half.read_exact(&mut len_buf).await?;
let path_len = u16::from_be_bytes(len_buf) as usize;
let mut buf = vec![0u8; path_len];
read_half.read_exact(&mut buf).await?;
let mut hlen = [0u8; 2];
read_half.read_exact(&mut hlen).await?;
let hl = u16::from_be_bytes(hlen) as usize;
let mut hbuf = vec![0u8; hl];
read_half.read_exact(&mut hbuf).await?;
let mut plen = [0u8; 4];
read_half.read_exact(&mut plen).await?;
let pl = u32::from_be_bytes(plen) as usize;
let mut pbuf = vec![0u8; pl];
read_half.read_exact(&mut pbuf).await?;
use crate::pipeline::network::codec::TcpResponseMessage;
let resp = TcpResponseMessage::new(Bytes::from(pbuf));
let enc = resp.encode()?;
write_half.write_all(&enc).await?;
Ok::<_, anyhow::Error>(())
} => {
// Any read/encode/write error ends this connection's task.
if result.is_err() { break; }
}
}
}
});
}
}
_ = cancel_clone.cancelled() => break,
}
}
});
let config = TcpRequestConfig {
request_timeout: Duration::from_secs(2),
connect_timeout: Duration::from_secs(2),
pool_size: 2,
channel_buffer: 10,
};
let pool = Arc::new(TcpConnectionPool::new(config));
let conn = pool.get_connection(addr).await.unwrap();
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
let result = conn
.send_request(Bytes::from("before_crash"), &headers)
.await;
assert!(result.is_ok());
drop(conn);
// Simulate the crash: cancel the accept loop and every per-connection task.
// Those tasks drop their socket halves when they exit.
cancel.cancel();
// Awaiting the server task guarantees the listener it owns has been dropped.
let _ = server_task.await;
tokio::time::sleep(Duration::from_millis(50)).await;
// Phase 2: with no listener, the pool may fail to connect (Err). If it still
// returns a connection, the request must fail or the connection must be unhealthy.
let conn = pool.get_connection(addr).await;
if let Ok(conn) = conn {
let result = conn
.send_request(Bytes::from("after_crash"), &headers)
.await;
assert!(result.is_err() || !conn.is_healthy());
}
// Phase 3: "restart" the server on the very same port.
// NOTE(review): relies on the OS allowing an immediate rebind of a just-closed
// listening port — confirm on the CI platforms.
let listener2 = TcpListener::bind(addr).await.unwrap();
tokio::spawn(async move {
loop {
let result = listener2.accept().await;
if result.is_err() {
break;
}
let (stream, _) = result.unwrap();
// Same echo framing as the phase-1 server, but with no cancellation.
tokio::spawn(async move {
let (mut read_half, mut write_half) = tokio::io::split(stream);
loop {
let mut len_buf = [0u8; 2];
if read_half.read_exact(&mut len_buf).await.is_err() {
break;
}
let path_len = u16::from_be_bytes(len_buf) as usize;
let mut buf = vec![0u8; path_len];
if read_half.read_exact(&mut buf).await.is_err() {
break;
}
let mut hlen = [0u8; 2];
if read_half.read_exact(&mut hlen).await.is_err() {
break;
}
let hl = u16::from_be_bytes(hlen) as usize;
let mut hbuf = vec![0u8; hl];
if read_half.read_exact(&mut hbuf).await.is_err() {
break;
}
let mut plen = [0u8; 4];
if read_half.read_exact(&mut plen).await.is_err() {
break;
}
let pl = u32::from_be_bytes(plen) as usize;
let mut pbuf = vec![0u8; pl];
if read_half.read_exact(&mut pbuf).await.is_err() {
break;
}
use crate::pipeline::network::codec::TcpResponseMessage;
let resp = TcpResponseMessage::new(Bytes::from(pbuf));
let enc = resp.encode().unwrap();
if write_half.write_all(&enc).await.is_err() {
break;
}
}
});
}
});
tokio::time::sleep(Duration::from_millis(100)).await;
// The pool must now produce a connection that works against the restarted server.
let conn = pool.get_connection(addr).await.unwrap();
let result = conn
.send_request(Bytes::from("after_recovery"), &headers)
.await;
assert!(result.is_ok(), "Pool should heal after server recovery");
}
#[tokio::test]
async fn test_pool_scales_under_pressure() {
let (addr, conn_count) = spawn_echo_server().await;
let config = TcpRequestConfig {
request_timeout: Duration::from_secs(10),
connect_timeout: Duration::from_secs(5),
pool_size: 4,
channel_buffer: 1, };
let pool = Arc::new(TcpConnectionPool::new(config));
let mut handles = vec![];
for i in 0..20 {
let pool = pool.clone();
handles.push(tokio::spawn(async move {
let conn = pool.get_connection(addr).await?;
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
conn.send_request(Bytes::from(format!("req_{}", i)), &headers)
.await
}));
}
for handle in handles {
let _ = handle.await.unwrap();
}
let total_conns = conn_count.load(Ordering::SeqCst);
assert!(
total_conns > 1,
"Pool should scale beyond 1 connection under pressure, got {}",
total_conns
);
assert!(
total_conns <= 4,
"Pool should not exceed pool_size (4), got {}",
total_conns
);
}
/// Three successive rounds of 200 concurrent requests must never push the
/// pool past `pool_size` (3) TCP connections.
///
/// Each round must also complete at least one request. Otherwise a run where
/// every request failed (no connection ever opened) would pass the cap check
/// vacuously.
#[tokio::test]
async fn test_pool_size_cap_sustained_load() {
    let (addr, conn_count) = spawn_echo_server().await;
    let config = TcpRequestConfig {
        request_timeout: Duration::from_secs(10),
        connect_timeout: Duration::from_secs(5),
        pool_size: 3,
        channel_buffer: 50,
    };
    let pool = Arc::new(TcpConnectionPool::new(config));
    for round in 0..3 {
        let mut handles = Vec::with_capacity(200);
        for i in 0..200 {
            let pool = pool.clone();
            handles.push(tokio::spawn(async move {
                let conn = pool.get_connection(addr).await?;
                let mut headers = Headers::new();
                headers.insert("x-endpoint-path".to_string(), "test".to_string());
                conn.send_request(Bytes::from(format!("round_{}_req_{}", round, i)), &headers)
                    .await
            }));
        }
        let mut round_ok = 0;
        for handle in handles {
            if handle.await.unwrap().is_ok() {
                round_ok += 1;
            }
        }
        assert!(
            round_ok > 0,
            "Round {} should complete at least one request",
            round
        );
    }
    let total_conns = conn_count.load(Ordering::SeqCst);
    assert!(
        total_conns <= 3,
        "Sustained load should not exceed pool_size (3), got {}",
        total_conns
    );
}
#[tokio::test]
async fn test_backpressure_small_channel() {
let (addr, _conn_count) = spawn_echo_server().await;
let config = TcpRequestConfig {
request_timeout: Duration::from_secs(10),
connect_timeout: Duration::from_secs(5),
pool_size: 1,
channel_buffer: 1,
};
let pool = Arc::new(TcpConnectionPool::new(config));
let mut handles = vec![];
for i in 0..50 {
let pool = pool.clone();
handles.push(tokio::spawn(async move {
let conn = pool.get_connection(addr).await?;
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
conn.send_request(Bytes::from(format!("req_{}", i)), &headers)
.await
}));
}
let mut ok_count = 0;
for handle in handles {
if handle.await.unwrap().is_ok() {
ok_count += 1;
}
}
assert_eq!(
ok_count, 50,
"All 50 requests should complete under backpressure"
);
}
#[tokio::test]
async fn test_no_recursive_retry_under_connect_contention() {
let (addr, conn_count) = spawn_echo_server().await;
let config = TcpRequestConfig {
request_timeout: Duration::from_secs(10),
connect_timeout: Duration::from_secs(5),
pool_size: 1,
channel_buffer: 50,
};
let pool = Arc::new(TcpConnectionPool::new(config));
let mut handles = vec![];
for _ in 0..50 {
let pool = pool.clone();
handles.push(tokio::spawn(async move { pool.get_connection(addr).await }));
}
let mut ok = 0;
for handle in handles {
if handle.await.unwrap().is_ok() {
ok += 1;
}
}
assert!(ok > 0, "At least some tasks should get connections");
assert_eq!(
conn_count.load(Ordering::SeqCst),
1,
"Only 1 TCP connection should be created"
);
}
#[tokio::test]
async fn test_global_connect_limiter_multi_host() {
let mut addrs = vec![];
let total_conns = Arc::new(AtomicUsize::new(0));
for _ in 0..4 {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
addrs.push(addr);
let total_conns = total_conns.clone();
tokio::spawn(async move {
loop {
let result = listener.accept().await;
if result.is_err() {
break;
}
let (stream, _) = result.unwrap();
total_conns.fetch_add(1, Ordering::SeqCst);
tokio::spawn(async move {
let (mut read_half, mut write_half) = tokio::io::split(stream);
loop {
let mut len_buf = [0u8; 2];
if read_half.read_exact(&mut len_buf).await.is_err() {
break;
}
let path_len = u16::from_be_bytes(len_buf) as usize;
let mut buf = vec![0u8; path_len];
if read_half.read_exact(&mut buf).await.is_err() {
break;
}
let mut hlen = [0u8; 2];
if read_half.read_exact(&mut hlen).await.is_err() {
break;
}
let hl = u16::from_be_bytes(hlen) as usize;
let mut hbuf = vec![0u8; hl];
if read_half.read_exact(&mut hbuf).await.is_err() {
break;
}
let mut plen = [0u8; 4];
if read_half.read_exact(&mut plen).await.is_err() {
break;
}
let pl = u32::from_be_bytes(plen) as usize;
let mut pbuf = vec![0u8; pl];
if read_half.read_exact(&mut pbuf).await.is_err() {
break;
}
use crate::pipeline::network::codec::TcpResponseMessage;
let resp = TcpResponseMessage::new(Bytes::from(pbuf));
let enc = resp.encode().unwrap();
if write_half.write_all(&enc).await.is_err() {
break;
}
}
});
}
});
}
let config = TcpRequestConfig {
request_timeout: Duration::from_secs(10),
connect_timeout: Duration::from_secs(5),
pool_size: 2,
channel_buffer: 50,
};
let pool = Arc::new(TcpConnectionPool::new(config));
let mut handles = vec![];
for addr in &addrs {
let pool = pool.clone();
let addr = *addr;
for i in 0..10 {
let pool = pool.clone();
handles.push(tokio::spawn(async move {
let conn = pool.get_connection(addr).await?;
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
conn.send_request(Bytes::from(format!("req_{}", i)), &headers)
.await
}));
}
}
let mut ok_count = 0;
for handle in handles {
if handle.await.unwrap().is_ok() {
ok_count += 1;
}
}
assert!(
ok_count > 0,
"Requests across multiple hosts should succeed"
);
let tc = total_conns.load(Ordering::SeqCst);
assert!(
tc <= 8,
"Total connections across 4 hosts should be <= 4*pool_size(2)=8, got {}",
tc
);
}
/// With a zero idle TTL, a host entry that was used and then released is
/// expected to be removed by the next `cleanup_idle_hosts` pass.
#[tokio::test]
async fn test_idle_host_pool_cleanup() {
    let (addr, _accepted) = spawn_echo_server().await;
    // Build a pool whose host idle TTL is overridden to 0 ms.
    let pool = TcpConnectionPool {
        host_idle_ttl_ms: 0,
        ..TcpConnectionPool::new(TcpRequestConfig {
            request_timeout: Duration::from_secs(5),
            connect_timeout: Duration::from_secs(5),
            pool_size: 2,
            channel_buffer: 10,
        })
    };
    let conn = pool.get_connection(addr).await.unwrap();
    let mut headers = Headers::new();
    headers.insert("x-endpoint-path".to_string(), "test".to_string());
    // The request outcome is irrelevant; it just marks the host as used.
    let _ = conn.send_request(Bytes::from("test"), &headers).await;
    drop(conn);
    assert!(pool.hosts.contains_key(&addr), "Host entry should exist");
    tokio::time::sleep(Duration::from_millis(10)).await;
    pool.cleanup_idle_hosts();
    let still_present = pool.hosts.contains_key(&addr);
    assert!(!still_present, "Idle host entry should be cleaned up");
}
/// A connection released back to the pool is reused by the next checkout for
/// the same address, even after a short pause between the two uses.
#[tokio::test]
async fn test_connection_pool_reuse() {
    let (addr, accepted) = spawn_echo_server().await;
    let pool = TcpConnectionPool::new(TcpRequestConfig {
        request_timeout: Duration::from_secs(5),
        connect_timeout: Duration::from_secs(5),
        pool_size: 2,
        channel_buffer: 10,
    });
    let mut headers = Headers::new();
    headers.insert("x-endpoint-path".to_string(), "test".to_string());
    {
        // First checkout; the handle is released at the end of this scope.
        let first = pool.get_connection(addr).await.unwrap();
        first.send_request(Bytes::from("test1"), &headers).await.unwrap();
    }
    tokio::time::sleep(Duration::from_millis(10)).await;
    {
        let second = pool.get_connection(addr).await.unwrap();
        second.send_request(Bytes::from("test2"), &headers).await.unwrap();
    }
    assert_eq!(
        accepted.load(Ordering::SeqCst),
        1,
        "Should reuse connection from pool"
    );
}
/// Against a peer that accepts and immediately closes every socket,
/// `send_request` must return an error within 250 ms, and the connection must
/// then report itself unhealthy.
#[tokio::test]
async fn test_unhealthy_connection_filtered() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    // Peer that closes each accepted socket right away.
    tokio::spawn(async move {
        while let Ok((stream, _)) = listener.accept().await {
            drop(stream);
        }
    });
    let config = TcpRequestConfig {
        request_timeout: Duration::from_secs(1),
        connect_timeout: Duration::from_secs(1),
        pool_size: 2,
        channel_buffer: 10,
    };
    // If the connection cannot be established at all, there is nothing to check.
    let Ok(conn) =
        TcpConnection::connect(addr, config.connect_timeout, config.channel_buffer).await
    else {
        return;
    };
    let mut headers = Headers::new();
    headers.insert("x-endpoint-path".to_string(), "test".to_string());
    let outcome = tokio::time::timeout(
        Duration::from_millis(250),
        conn.send_request(Bytes::from("test"), &headers),
    )
    .await;
    assert!(
        outcome.is_ok(),
        "send_request should fail promptly when the peer closes cleanly"
    );
    let inner = outcome.unwrap();
    assert!(
        inner.is_err(),
        "clean peer close should surface as a request error"
    );
    tokio::time::sleep(Duration::from_millis(50)).await;
    assert!(
        !conn.is_healthy(),
        "Connection should become unhealthy after peer closes it"
    );
}
/// The warmup watcher must stay idle while the instance list is empty. Once a
/// TCP instance is published, it must open exactly one connection to that
/// backend and register a host entry for its address.
#[tokio::test]
async fn test_warmup_pre_connects_on_instance_discovery() {
    use crate::component::{Instance, TransportType};
    let (addr, conn_count) = spawn_echo_server().await;
    let config = TcpRequestConfig {
        request_timeout: Duration::from_secs(5),
        connect_timeout: Duration::from_secs(5),
        pool_size: 4,
        channel_buffer: 10,
    };
    let pool = Arc::new(TcpConnectionPool::new(config));
    let cancel_token = tokio_util::sync::CancellationToken::new();
    let (instance_tx, instance_rx) = tokio::sync::watch::channel::<Vec<Instance>>(Vec::new());
    pool.start_warmup_watcher(instance_rx, cancel_token.clone());
    // With no instances published yet, no connection may be opened.
    tokio::time::sleep(Duration::from_millis(50)).await;
    assert_eq!(conn_count.load(Ordering::SeqCst), 0);
    let tcp_addr = format!("{}:{}/test_endpoint", addr.ip(), addr.port());
    instance_tx
        .send(vec![Instance {
            component: "test".to_string(),
            endpoint: "test_ep".to_string(),
            namespace: "default".to_string(),
            instance_id: 1,
            transport: TransportType::Tcp(tcp_addr),
            device_type: None,
        }])
        .unwrap();
    // Warmup runs asynchronously. Poll with a generous deadline instead of
    // sleeping a fixed interval, which is flaky on loaded CI hosts.
    let deadline = tokio::time::Instant::now() + Duration::from_secs(5);
    while (conn_count.load(Ordering::SeqCst) == 0 || !pool.hosts.contains_key(&addr))
        && tokio::time::Instant::now() < deadline
    {
        tokio::time::sleep(Duration::from_millis(10)).await;
    }
    // Brief settle window so the exact-count check below still catches a
    // warmup that opens more than one connection.
    tokio::time::sleep(Duration::from_millis(50)).await;
    assert_eq!(
        conn_count.load(Ordering::SeqCst),
        1,
        "Warmup should have created 1 connection to the newly discovered backend"
    );
    assert!(
        pool.hosts.contains_key(&addr),
        "Pool should have a host entry for the warmed address"
    );
    cancel_token.cancel();
}
/// A request that has already been enqueued when its connection is closed
/// must fail within 200 ms, instead of waiting out the request timeout.
#[tokio::test]
async fn test_closed_connection_after_enqueue_fails_promptly() {
let (addr, _conn_count) = spawn_echo_server().await;
let mut conn = TcpConnection::connect(addr, Duration::from_secs(5), 10)
.await
.unwrap();
// Two-party barrier shared between this test and the connection's request path.
// NOTE(review): presumably `send_request` waits on `post_enqueue_barrier` twice:
// once right after enqueueing (pairing with the first `wait` below) and once
// before continuing (pairing with the second). Confirm against send_request.
let barrier = Arc::new(tokio::sync::Barrier::new(2));
conn.post_enqueue_barrier = Some(barrier.clone());
let conn = Arc::new(conn);
let request = {
let conn = conn.clone();
tokio::spawn(async move {
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
// The 200 ms cap makes "promptly" observable: an outer Err means the
// request was still pending when the timeout fired.
tokio::time::timeout(
Duration::from_millis(200),
conn.send_request(Bytes::from("queued_after_close"), &headers),
)
.await
})
};
// First rendezvous with the request path (see the note above).
barrier.wait().await;
// Close the connection while the request is (presumably) parked after enqueueing.
conn.closed.store(true, Ordering::Release);
conn.healthy.store(false, Ordering::Relaxed);
// Second rendezvous: release the request path to observe the closed state.
barrier.wait().await;
let result = request.await.unwrap();
assert!(
result.is_ok(),
"send_request should fail promptly instead of waiting for request_timeout"
);
let inner = result.unwrap();
assert!(inner.is_err(), "closed connection should return an error");
// Abort the connection's writer and reader tasks so they do not outlive the test.
conn.writer_handle.abort();
conn.reader_handle.abort();
}
/// 100 concurrent submissions sharing one pooled connection must all
/// succeed, and only one TCP socket may be opened.
#[tokio::test]
async fn test_lockfree_submit_and_batch() {
    let (addr, accepted) = spawn_echo_server().await;
    let pool = Arc::new(TcpConnectionPool::new(TcpRequestConfig {
        request_timeout: Duration::from_secs(10),
        connect_timeout: Duration::from_secs(5),
        pool_size: 1,
        channel_buffer: 200,
    }));
    let shared = pool.get_connection(addr).await.unwrap();
    let tasks: Vec<_> = (0..100)
        .map(|i| {
            let conn = shared.clone();
            tokio::spawn(async move {
                let mut headers = Headers::new();
                headers.insert("x-endpoint-path".to_string(), "test".to_string());
                conn.send_request(Bytes::from(format!("batch_req_{}", i)), &headers)
                    .await
            })
        })
        .collect();
    let mut succeeded = 0;
    for task in tasks {
        if task.await.unwrap().is_ok() {
            succeeded += 1;
        }
    }
    assert_eq!(succeeded, 100, "All 100 concurrent requests should succeed");
    assert_eq!(
        accepted.load(Ordering::SeqCst),
        1,
        "Should use only 1 connection"
    );
}
}