use super::unified_client::{ClientStats, Headers, RequestPlaneClient};
use anyhow::Result;
use async_trait::async_trait;
use bytes::Bytes;
use dashmap::DashMap;
use futures::StreamExt;
use std::io::IoSlice;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::{Mutex, mpsc, oneshot};
use tokio::task::JoinHandle;
use tokio_util::codec::FramedRead;
/// Default per-request timeout, in seconds.
const DEFAULT_TCP_REQUEST_TIMEOUT_SECS: u64 = 5;
/// Default maximum number of idle connections kept per remote address.
const DEFAULT_POOL_SIZE: usize = 100;
/// Default capacity of the per-connection outgoing-request channel.
const REQUEST_CHANNEL_BUFFER: usize = 50;
/// Read buffer size hint, in bytes.
/// NOTE(review): not referenced anywhere in this file — confirm it is used
/// elsewhere before treating it as live.
const READ_BUFFER_SIZE: usize = 65536;
/// Default cap on a single decoded response message (32 MiB).
const DEFAULT_MAX_MESSAGE_SIZE: usize = 32 * 1024 * 1024;

/// Maximum decoded message size: honors the `DYN_TCP_MAX_MESSAGE_SIZE`
/// environment override, falling back to `DEFAULT_MAX_MESSAGE_SIZE` when the
/// variable is unset or does not parse as a `usize`.
fn get_max_message_size() -> usize {
    match std::env::var("DYN_TCP_MAX_MESSAGE_SIZE") {
        Ok(raw) => raw.parse().unwrap_or(DEFAULT_MAX_MESSAGE_SIZE),
        Err(_) => DEFAULT_MAX_MESSAGE_SIZE,
    }
}
/// Tunable parameters for the TCP request client.
/// See [`TcpRequestConfig::from_env`] for the environment-variable overrides.
#[derive(Debug, Clone)]
pub struct TcpRequestConfig {
    /// How long to wait for a complete response to a single request.
    pub request_timeout: Duration,
    /// Maximum number of idle connections kept per remote address.
    pub pool_size: usize,
    /// How long to wait for the TCP connect itself to succeed.
    pub connect_timeout: Duration,
    /// Capacity of the per-connection outgoing-request channel.
    pub channel_buffer: usize,
}
/// Default connect timeout, in seconds. Named (rather than the previous
/// inline `5`) for consistency with the other `DEFAULT_*` constants and so
/// it can be tuned independently of the request timeout.
const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 5;

impl Default for TcpRequestConfig {
    /// Builds a config entirely from the module-level default constants.
    fn default() -> Self {
        Self {
            request_timeout: Duration::from_secs(DEFAULT_TCP_REQUEST_TIMEOUT_SECS),
            pool_size: DEFAULT_POOL_SIZE,
            connect_timeout: Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS),
            channel_buffer: REQUEST_CHANNEL_BUFFER,
        }
    }
}
impl TcpRequestConfig {
pub fn from_env() -> Self {
let mut config = Self::default();
if let Ok(val) = std::env::var("DYN_TCP_REQUEST_TIMEOUT")
&& let Ok(timeout) = val.parse::<u64>()
{
config.request_timeout = Duration::from_secs(timeout);
}
if let Ok(val) = std::env::var("DYN_TCP_POOL_SIZE")
&& let Ok(size) = val.parse::<usize>()
{
config.pool_size = size;
}
if let Ok(val) = std::env::var("DYN_TCP_CONNECT_TIMEOUT")
&& let Ok(timeout) = val.parse::<u64>()
{
config.connect_timeout = Duration::from_secs(timeout);
}
if let Ok(val) = std::env::var("DYN_TCP_CHANNEL_BUFFER")
&& let Ok(size) = val.parse::<usize>()
{
config.channel_buffer = size;
}
config
}
}
/// A single in-flight request: the fully encoded wire bytes plus the oneshot
/// channel on which the matching response (or error) is delivered.
struct TcpRequest {
    // Request already encoded via `TcpRequestMessage::encode`.
    encoded_data: Bytes,
    // Fulfilled by the reader task on response, or by the writer task when
    // the write itself fails.
    response_tx: oneshot::Sender<Result<Bytes>>,
}
/// A pooled TCP connection. I/O is handled by a detached writer/reader task
/// pair; requests are submitted through `request_tx`, and responses are
/// matched to requests in FIFO write order.
struct TcpConnection {
    // Remote peer this connection is bound to.
    addr: SocketAddr,
    // Hands encoded requests to the writer task.
    request_tx: mpsc::Sender<TcpRequest>,
    // Stored but never awaited or aborted in this file; both tasks exit on
    // their own once the channels close when this struct is dropped.
    writer_handle: Arc<JoinHandle<()>>,
    reader_handle: Arc<JoinHandle<()>>,
    // Cleared by either task on failure; checked before pool reuse.
    healthy: Arc<AtomicBool>,
}
impl TcpConnection {
    /// Opens a TCP connection to `addr` (bounded by `timeout`) and spawns
    /// the writer/reader task pair. `channel_buffer` bounds the number of
    /// queued outgoing requests.
    async fn connect(addr: SocketAddr, timeout: Duration, channel_buffer: usize) -> Result<Self> {
        let stream = tokio::time::timeout(timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| anyhow::anyhow!("TCP connect timeout to {}", addr))??;
        Self::configure_socket(&stream)?;
        let (read_half, write_half) = tokio::io::split(stream);
        // Outgoing requests flow to the writer task through this bounded channel.
        let (request_tx, request_rx) = mpsc::channel::<TcpRequest>(channel_buffer);
        // After each successful write, the writer forwards the request's
        // oneshot sender here; the reader consumes them in FIFO order so
        // response N on the wire fulfills request N.
        let (response_tx_channel, response_rx_channel) =
            mpsc::unbounded_channel::<oneshot::Sender<Result<Bytes>>>();
        let healthy = Arc::new(AtomicBool::new(true));
        let writer_handle = {
            let healthy = healthy.clone();
            tokio::spawn(async move {
                if let Err(e) = Self::writer_task(write_half, request_rx, response_tx_channel).await
                {
                    tracing::debug!("Writer task failed for {}: {}", addr, e);
                    // Mark the connection so the pool stops handing it out.
                    healthy.store(false, Ordering::Relaxed);
                }
            })
        };
        let reader_handle = {
            let healthy = healthy.clone();
            tokio::spawn(async move {
                if let Err(e) = Self::reader_task(read_half, response_rx_channel).await {
                    tracing::debug!("Reader task failed for {}: {}", addr, e);
                    healthy.store(false, Ordering::Relaxed);
                }
            })
        };
        Ok(Self {
            addr,
            request_tx,
            writer_handle: Arc::new(writer_handle),
            reader_handle: Arc::new(reader_handle),
            healthy,
        })
    }

    /// Applies latency-oriented socket options: disables Nagle and enlarges
    /// both kernel buffers to 2 MiB. Under the `tcp-low-latency` feature it
    /// additionally sets `TCP_QUICKACK` and `SO_BUSY_POLL` (Linux-specific,
    /// best-effort — setsockopt return codes are ignored).
    fn configure_socket(stream: &TcpStream) -> Result<()> {
        // NOTE(review): `Socket` appears unused here (only `SockRef` is used).
        use socket2::{SockRef, Socket};
        let sock_ref = SockRef::from(stream);
        sock_ref.set_nodelay(true)?;
        sock_ref.set_recv_buffer_size(2 * 1024 * 1024)?; sock_ref.set_send_buffer_size(2 * 1024 * 1024)?;
        #[cfg(feature = "tcp-low-latency")]
        {
            use std::os::unix::io::AsRawFd;
            // SAFETY: `fd` is a valid open socket descriptor for the duration
            // of these calls (borrowed from `stream`), and both option values
            // are correctly sized C ints passed with matching socklen_t.
            unsafe {
                let fd = stream.as_raw_fd();
                let quickack: libc::c_int = 1;
                libc::setsockopt(
                    fd,
                    libc::SOL_TCP,
                    libc::TCP_QUICKACK,
                    &quickack as *const _ as *const libc::c_void,
                    std::mem::size_of_val(&quickack) as libc::socklen_t,
                );
                let busy_poll: libc::c_int = 50;
                libc::setsockopt(
                    fd,
                    libc::SOL_SOCKET,
                    libc::SO_BUSY_POLL,
                    &busy_poll as *const _ as *const libc::c_void,
                    std::mem::size_of_val(&busy_poll) as libc::socklen_t,
                );
            }
            tracing::debug!("TCP low-latency optimizations enabled (TCP_QUICKACK, SO_BUSY_POLL)");
        }
        Ok(())
    }

    /// Writer loop: serially writes each queued request to the socket, then
    /// hands its response sender to the reader task. Exits cleanly when the
    /// request channel closes; returns `Err` when a write fails (which makes
    /// the spawn wrapper mark the connection unhealthy).
    async fn writer_task(
        mut write_half: tokio::io::WriteHalf<TcpStream>,
        mut request_rx: mpsc::Receiver<TcpRequest>,
        response_tx_channel: mpsc::UnboundedSender<oneshot::Sender<Result<Bytes>>>,
    ) -> Result<()> {
        while let Some(req) = request_rx.recv().await {
            match write_half.write_all(&req.encoded_data).await {
                Ok(()) => {
                    // Forward only after a successful write so the reader's
                    // queue of pending senders matches the wire order exactly.
                    if response_tx_channel.send(req.response_tx).is_err() {
                        tracing::debug!("Reader task closed, stopping writer");
                        break;
                    }
                }
                Err(e) => {
                    let err_msg = format!("Write failed: {}", e);
                    // Fail the waiting caller first (send result ignored: the
                    // caller may already have timed out and dropped).
                    let _ = req.response_tx.send(Err(anyhow::anyhow!("{}", err_msg)));
                    return Err(anyhow::anyhow!("{}", err_msg));
                }
            }
        }
        Ok(())
    }

    /// Reader loop: for each pending request sender (in write order), decodes
    /// one framed response from the socket and fulfills that request. A
    /// decode error or EOF fails the current request and ends the task.
    async fn reader_task(
        read_half: tokio::io::ReadHalf<TcpStream>,
        mut response_rx_channel: mpsc::UnboundedReceiver<oneshot::Sender<Result<Bytes>>>,
    ) -> Result<()> {
        use crate::pipeline::network::codec::TcpResponseCodec;
        let max_message_size = get_max_message_size();
        let codec = TcpResponseCodec::new(Some(max_message_size));
        let mut framed = FramedRead::new(read_half, codec);
        while let Some(response_tx) = response_rx_channel.recv().await {
            match framed.next().await {
                Some(Ok(response_msg)) => {
                    // Receiver may have timed out and dropped; ignore that.
                    let _ = response_tx.send(Ok(response_msg.data));
                }
                Some(Err(e)) => {
                    let err = anyhow::anyhow!("Failed to decode response: {}", e);
                    let _ = response_tx.send(Err(err));
                    return Err(anyhow::anyhow!("Failed to decode response"));
                }
                None => {
                    let err = anyhow::anyhow!("Connection closed by peer");
                    let _ = response_tx.send(Err(err));
                    return Err(anyhow::anyhow!("Connection closed"));
                }
            }
        }
        Ok(())
    }

    /// Encodes and submits one request, then waits for its response.
    /// Requires an `x-endpoint-path` header for server-side routing.
    /// No timeout here — callers (see `TcpRequestClient::send_request`)
    /// wrap this in `tokio::time::timeout`.
    async fn send_request(&self, payload: Bytes, headers: &Headers) -> Result<Bytes> {
        use crate::pipeline::network::codec::TcpRequestMessage;
        if !self.healthy.load(Ordering::Relaxed) {
            anyhow::bail!("Connection unhealthy (tasks failed)");
        }
        let endpoint_path = headers
            .get("x-endpoint-path")
            .ok_or_else(|| anyhow::anyhow!("Missing x-endpoint-path header for TCP request"))?
            .to_string();
        let request_msg = TcpRequestMessage::with_headers(endpoint_path, headers.clone(), payload);
        let encoded_data = request_msg.encode()?;
        let (response_tx, response_rx) = oneshot::channel();
        let req = TcpRequest {
            encoded_data,
            response_tx,
        };
        self.request_tx
            .send(req)
            .await
            .map_err(|_| anyhow::anyhow!("Writer task closed"))?;
        response_rx
            .await
            .map_err(|_| anyhow::anyhow!("Reader task closed"))?
    }

    /// True until either background task records a failure.
    fn is_healthy(&self) -> bool {
        self.healthy.load(Ordering::Relaxed)
    }
}
/// Keyed pool of idle TCP connections: one LIFO stack per remote address,
/// each stack guarded by an async mutex inside a concurrent map.
struct TcpConnectionPool {
    pools: DashMap<SocketAddr, Arc<Mutex<Vec<TcpConnection>>>>,
    // `pool_size` caps each per-address stack; timeouts feed new connects.
    config: TcpRequestConfig,
}
impl TcpConnectionPool {
    /// Creates an empty pool; connections are opened lazily per address.
    fn new(config: TcpRequestConfig) -> Self {
        Self {
            pools: DashMap::new(),
            config,
        }
    }

    /// Returns a healthy pooled connection for `addr`, opening a new one
    /// when none is available. Unhealthy pooled connections encountered
    /// along the way are discarded.
    async fn get_connection(&self, addr: SocketAddr) -> Result<TcpConnection> {
        // Clone the Arc out of the DashMap *before* awaiting. The previous
        // code held the DashMap shard `Ref` guard across `.lock().await`;
        // a suspended task holding that read guard can deadlock against
        // `return_connection`'s synchronous `entry()` (shard write lock)
        // when both run on the same runtime thread.
        let idle = self.pools.get(&addr).map(|entry| Arc::clone(entry.value()));
        if let Some(idle) = idle {
            let mut idle = idle.lock().await;
            while let Some(conn) = idle.pop() {
                if conn.is_healthy() {
                    return Ok(conn);
                }
                tracing::debug!("Discarding unhealthy connection for {}", addr);
            }
        }
        tracing::debug!("Creating new TCP connection to {}", addr);
        TcpConnection::connect(
            addr,
            self.config.connect_timeout,
            self.config.channel_buffer,
        )
        .await
    }

    /// Returns `conn` to the idle stack for its address, dropping it instead
    /// when it is unhealthy or the stack is already at `pool_size` capacity.
    async fn return_connection(&self, conn: TcpConnection) {
        if !conn.is_healthy() {
            tracing::debug!("Not returning unhealthy connection to pool");
            return;
        }
        let addr = conn.addr;
        // The entry guard is dropped at the end of this statement (only the
        // cloned Arc survives), so the await below does not hold a DashMap
        // shard lock.
        let pool = self
            .pools
            .entry(addr)
            .or_insert_with(|| Arc::new(Mutex::new(Vec::new())))
            .clone();
        let mut pool = pool.lock().await;
        if pool.len() < self.config.pool_size {
            pool.push(conn);
        } else {
            tracing::debug!("Connection pool full for {}, dropping connection", addr);
        }
    }
}
/// Pooled TCP request-plane client; implements `RequestPlaneClient`.
pub struct TcpRequestClient {
    // Per-address pool of reusable connections.
    pool: Arc<TcpConnectionPool>,
    // Timeouts and sizing knobs; a clone is handed to the pool at creation.
    config: TcpRequestConfig,
    // Shared counters surfaced through `stats()`.
    stats: Arc<TcpClientStats>,
}
/// Monotonic counters backing `ClientStats`. All updates use relaxed
/// atomics, so concurrent snapshots are approximate but never torn.
struct TcpClientStats {
    // Incremented once per send_request call.
    requests_sent: AtomicU64,
    // Incremented once per successful response.
    responses_received: AtomicU64,
    // Incremented on request failure or timeout.
    errors: AtomicU64,
    // Payload bytes only (framing overhead not counted).
    bytes_sent: AtomicU64,
    bytes_received: AtomicU64,
}
impl TcpRequestClient {
    /// Creates a client with the default configuration.
    pub fn new() -> Result<Self> {
        Self::with_config(TcpRequestConfig::default())
    }

    /// Creates a client with an explicit configuration; all counters start
    /// at zero. Currently infallible, but kept as `Result` for API stability.
    pub fn with_config(config: TcpRequestConfig) -> Result<Self> {
        Ok(Self {
            pool: Arc::new(TcpConnectionPool::new(config.clone())),
            config,
            stats: Arc::new(TcpClientStats {
                requests_sent: AtomicU64::new(0),
                responses_received: AtomicU64::new(0),
                errors: AtomicU64::new(0),
                bytes_sent: AtomicU64::new(0),
                bytes_received: AtomicU64::new(0),
            }),
        })
    }

    /// Creates a client configured from the `DYN_TCP_*` environment variables.
    pub fn from_env() -> Result<Self> {
        Self::with_config(TcpRequestConfig::from_env())
    }

    /// Parses `address` of the form `[tcp://]host:port[/endpoint]` into a
    /// socket address plus an optional endpoint path.
    ///
    /// # Errors
    /// Returns an error when the socket portion is not a valid `SocketAddr`.
    fn parse_address(address: &str) -> Result<(SocketAddr, Option<String>)> {
        let addr_str = address.strip_prefix("tcp://").unwrap_or(address);
        // Split off an optional `/endpoint` suffix, then parse the socket
        // part once (the previous version duplicated the parse + error
        // mapping in both branches).
        let (socket_part, endpoint_name) = match addr_str.split_once('/') {
            Some((socket, endpoint)) => (socket, Some(endpoint.to_string())),
            None => (addr_str, None),
        };
        let socket_addr = socket_part
            .parse::<SocketAddr>()
            .map_err(|e| anyhow::anyhow!("Invalid TCP address '{}': {}", address, e))?;
        Ok((socket_addr, endpoint_name))
    }
}
impl Default for TcpRequestClient {
    /// Convenience for contexts requiring `Default`. Panics if construction
    /// fails — `with_config` currently always returns `Ok`, so this cannot
    /// panic as the code stands.
    fn default() -> Self {
        Self::new().expect("Failed to create TCP request client")
    }
}
#[async_trait]
impl RequestPlaneClient for TcpRequestClient {
    /// Sends `payload` to `address` (`[tcp://]host:port[/endpoint]`) and
    /// returns the response body. A `/endpoint` suffix is propagated to the
    /// server as the `x-endpoint-path` header. The whole round-trip is
    /// bounded by `config.request_timeout`; every failure is surfaced as
    /// `ErrorType::CannotConnect`.
    async fn send_request(
        &self,
        address: String,
        payload: Bytes,
        mut headers: Headers,
    ) -> Result<Bytes> {
        tracing::debug!("TCP client sending request to address: {}", address);
        // Counted on attempt, not on success.
        self.stats.requests_sent.fetch_add(1, Ordering::Relaxed);
        self.stats
            .bytes_sent
            .fetch_add(payload.len() as u64, Ordering::Relaxed);
        let (addr, endpoint_name) = Self::parse_address(&address)?;
        if let Some(endpoint_name) = endpoint_name {
            headers.insert("x-endpoint-path".to_string(), endpoint_name.clone());
        }
        let conn = self.pool.get_connection(addr).await?;
        let result = tokio::time::timeout(
            self.config.request_timeout,
            conn.send_request(payload, &headers),
        )
        .await;
        match result {
            Ok(Ok(response)) => {
                self.stats
                    .responses_received
                    .fetch_add(1, Ordering::Relaxed);
                self.stats
                    .bytes_received
                    .fetch_add(response.len() as u64, Ordering::Relaxed);
                // Only connections that completed successfully go back to
                // the pool for reuse.
                self.pool.return_connection(conn).await;
                Ok(response)
            }
            Ok(Err(e)) => {
                self.stats.errors.fetch_add(1, Ordering::Relaxed);
                tracing::warn!("TCP request failed to {}: {}", addr, e);
                // `conn` is deliberately dropped here (not returned): after
                // a failure its stream state is suspect.
                let cause = crate::error::DynamoError::from(
                    e.into_boxed_dyn_error() as Box<dyn std::error::Error + 'static>
                );
                Err(anyhow::anyhow!(
                    crate::error::DynamoError::builder()
                        .error_type(crate::error::ErrorType::CannotConnect)
                        .message(format!("TCP request to {addr} failed"))
                        .cause(cause)
                        .build()
                ))
            }
            Err(_) => {
                self.stats.errors.fetch_add(1, Ordering::Relaxed);
                tracing::warn!("TCP request timeout to {}", addr);
                // Timeout: the connection is dropped, which closes its
                // channels and lets the background tasks wind down.
                Err(anyhow::anyhow!(
                    crate::error::DynamoError::builder()
                        .error_type(crate::error::ErrorType::CannotConnect)
                        .message(format!("TCP request to {addr} timed out"))
                        .build()
                ))
            }
        }
    }

    /// Transport identifier for this client.
    fn transport_name(&self) -> &'static str {
        "tcp"
    }

    /// Always true — no client-level health signal is tracked yet.
    fn is_healthy(&self) -> bool {
        true
    }

    /// Snapshot of the relaxed atomic counters.
    /// NOTE(review): connection gauges and latency are not tracked yet and
    /// are reported as zero.
    fn stats(&self) -> ClientStats {
        ClientStats {
            requests_sent: self.stats.requests_sent.load(Ordering::Relaxed),
            responses_received: self.stats.responses_received.load(Ordering::Relaxed),
            errors: self.stats.errors.load(Ordering::Relaxed),
            bytes_sent: self.stats.bytes_sent.load(Ordering::Relaxed),
            bytes_received: self.stats.bytes_received.load(Ordering::Relaxed),
            active_connections: 0,
            idle_connections: 0,
            avg_latency_us: 0,
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::AtomicUsize;
use tokio::net::TcpListener;
#[test]
fn test_tcp_config_default() {
    // The defaults must line up with the module-level constants.
    let TcpRequestConfig {
        request_timeout,
        pool_size,
        connect_timeout: _,
        channel_buffer,
    } = TcpRequestConfig::default();
    assert_eq!(pool_size, DEFAULT_POOL_SIZE);
    assert_eq!(
        request_timeout,
        Duration::from_secs(DEFAULT_TCP_REQUEST_TIMEOUT_SECS)
    );
    assert_eq!(channel_buffer, REQUEST_CHANNEL_BUFFER);
}
#[test]
fn test_tcp_config_from_env() {
    // Table of overrides applied before the read and cleared afterwards.
    let vars = [
        ("DYN_TCP_REQUEST_TIMEOUT", "10"),
        ("DYN_TCP_POOL_SIZE", "50"),
        ("DYN_TCP_CONNECT_TIMEOUT", "3"),
        ("DYN_TCP_CHANNEL_BUFFER", "100"),
    ];
    unsafe {
        for &(key, value) in vars.iter() {
            std::env::set_var(key, value);
        }
    }
    let config = TcpRequestConfig::from_env();
    assert_eq!(config.request_timeout, Duration::from_secs(10));
    assert_eq!(config.pool_size, 50);
    assert_eq!(config.connect_timeout, Duration::from_secs(3));
    assert_eq!(config.channel_buffer, 100);
    unsafe {
        for &(key, _) in vars.iter() {
            std::env::remove_var(key);
        }
    }
}
/// Covers all four address shapes: bare, scheme-prefixed, with endpoint,
/// and scheme + endpoint combined (the last was previously untested).
#[test]
fn test_parse_address() {
    // Bare socket address — no endpoint.
    let (addr1, endpoint1) = TcpRequestClient::parse_address("127.0.0.1:8080").unwrap();
    assert_eq!(addr1.port(), 8080);
    assert_eq!(endpoint1, None);
    // With a tcp:// scheme prefix.
    let (addr2, _) = TcpRequestClient::parse_address("tcp://127.0.0.1:9090").unwrap();
    assert_eq!(addr2.port(), 9090);
    // With an endpoint path.
    let (addr3, endpoint) =
        TcpRequestClient::parse_address("127.0.0.1:8080/test_endpoint").unwrap();
    assert_eq!(addr3.port(), 8080);
    assert_eq!(endpoint, Some("test_endpoint".to_string()));
    // Scheme prefix and endpoint path together.
    let (addr4, endpoint4) =
        TcpRequestClient::parse_address("tcp://127.0.0.1:7070/ep").unwrap();
    assert_eq!(addr4.port(), 7070);
    assert_eq!(endpoint4, Some("ep".to_string()));
    // Unparsable input is rejected.
    assert!(TcpRequestClient::parse_address("invalid").is_err());
}
#[test]
fn test_tcp_client_creation() {
    // Construction is infallible today; a panic here means that changed.
    let client = TcpRequestClient::new().expect("client construction should succeed");
    assert_eq!(client.transport_name(), "tcp");
    assert!(client.is_healthy());
}
/// End-to-end check: a fresh connection reports healthy, completes one
/// request/response round-trip against a hand-rolled server, and stays
/// healthy afterwards.
#[tokio::test]
async fn test_connection_health_check() {
    // Only the response type is needed (the request frame is parsed by
    // hand below); the old `TcpRequestMessage` import was unused.
    use crate::pipeline::network::codec::TcpResponseMessage;
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    tokio::spawn(async move {
        let (stream, _) = listener.accept().await.unwrap();
        let (mut read_half, mut write_half) = tokio::io::split(stream);
        // Request wire format: u16 path len + path, u16 headers len +
        // headers, u32 payload len + payload (all big-endian).
        let mut len_buf = [0u8; 2];
        read_half.read_exact(&mut len_buf).await.unwrap();
        let path_len = u16::from_be_bytes(len_buf) as usize;
        let mut path_buf = vec![0u8; path_len];
        read_half.read_exact(&mut path_buf).await.unwrap();
        let mut headers_len_buf = [0u8; 2];
        read_half.read_exact(&mut headers_len_buf).await.unwrap();
        let headers_len = u16::from_be_bytes(headers_len_buf) as usize;
        let mut headers_buf = vec![0u8; headers_len];
        read_half.read_exact(&mut headers_buf).await.unwrap();
        let mut len_buf = [0u8; 4];
        read_half.read_exact(&mut len_buf).await.unwrap();
        let payload_len = u32::from_be_bytes(len_buf) as usize;
        let mut payload_buf = vec![0u8; payload_len];
        read_half.read_exact(&mut payload_buf).await.unwrap();
        // Answer with a fixed response regardless of the payload.
        let response = TcpResponseMessage::new(Bytes::from_static(b"pong"));
        let encoded = response.encode().unwrap();
        write_half.write_all(&encoded).await.unwrap();
    });
    let conn = TcpConnection::connect(addr, Duration::from_secs(5), 10)
        .await
        .unwrap();
    assert!(conn.is_healthy(), "New connection should be healthy");
    let mut headers = Headers::new();
    headers.insert("x-endpoint-path".to_string(), "test".to_string());
    let result = conn.send_request(Bytes::from("ping"), &headers).await;
    assert!(result.is_ok(), "Request should succeed");
    assert_eq!(result.unwrap(), Bytes::from("pong"));
    assert!(
        conn.is_healthy(),
        "Connection should remain healthy after successful request"
    );
}
/// Five concurrent callers share one connection. The server echoes each
/// payload in arrival order; FIFO response matching must therefore pair
/// every caller with the echo of its OWN payload (now asserted — the old
/// test only counted requests). Also drops the unused `TcpRequestMessage`
/// import.
#[tokio::test]
async fn test_concurrent_requests_single_connection() {
    use crate::pipeline::network::codec::TcpResponseMessage;
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    let request_count = Arc::new(AtomicUsize::new(0));
    let request_count_clone = request_count.clone();
    tokio::spawn(async move {
        let (stream, _) = listener.accept().await.unwrap();
        let (mut read_half, mut write_half) = tokio::io::split(stream);
        for _ in 0..5 {
            // Request wire format: u16 path len + path, u16 headers len +
            // headers, u32 payload len + payload (all big-endian).
            let mut len_buf = [0u8; 2];
            if read_half.read_exact(&mut len_buf).await.is_err() {
                break;
            }
            let path_len = u16::from_be_bytes(len_buf) as usize;
            let mut path_buf = vec![0u8; path_len];
            if read_half.read_exact(&mut path_buf).await.is_err() {
                break;
            }
            let mut headers_len_buf = [0u8; 2];
            if read_half.read_exact(&mut headers_len_buf).await.is_err() {
                break;
            }
            let headers_len = u16::from_be_bytes(headers_len_buf) as usize;
            let mut headers_buf = vec![0u8; headers_len];
            if read_half.read_exact(&mut headers_buf).await.is_err() {
                break;
            }
            let mut len_buf = [0u8; 4];
            if read_half.read_exact(&mut len_buf).await.is_err() {
                break;
            }
            let payload_len = u32::from_be_bytes(len_buf) as usize;
            let mut payload_buf = vec![0u8; payload_len];
            if read_half.read_exact(&mut payload_buf).await.is_err() {
                break;
            }
            request_count_clone.fetch_add(1, Ordering::SeqCst);
            // Echo the payload back.
            let response = TcpResponseMessage::new(Bytes::from(payload_buf));
            let encoded = response.encode().unwrap();
            if write_half.write_all(&encoded).await.is_err() {
                break;
            }
        }
    });
    let conn = Arc::new(
        TcpConnection::connect(addr, Duration::from_secs(5), 10)
            .await
            .unwrap(),
    );
    let mut handles = vec![];
    for i in 0..5 {
        let conn = conn.clone();
        let handle = tokio::spawn(async move {
            let mut headers = Headers::new();
            headers.insert("x-endpoint-path".to_string(), "test".to_string());
            let payload = format!("request_{}", i);
            conn.send_request(Bytes::from(payload.clone()), &headers)
                .await
                .map(|response| (payload, response))
        });
        handles.push(handle);
    }
    let mut results = vec![];
    for handle in handles {
        let result = handle.await.unwrap();
        assert!(result.is_ok(), "Request should succeed");
        results.push(result.unwrap());
    }
    assert_eq!(results.len(), 5);
    // Every caller must receive the echo of its own payload, proving that
    // responses are never cross-matched between concurrent callers.
    for (payload, response) in &results {
        assert_eq!(
            response,
            &Bytes::from(payload.clone()),
            "response should echo the matching request"
        );
    }
    assert_eq!(
        request_count.load(Ordering::SeqCst),
        5,
        "Server should have received 5 requests"
    );
}
/// Verifies the pool hands back the same underlying TCP connection on a
/// second checkout instead of dialing the server again.
#[tokio::test]
async fn test_connection_pool_reuse() {
    use crate::pipeline::network::codec::TcpResponseMessage;
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    // Counts accepted TCP connections; reuse means this stays at 1.
    let connection_count = Arc::new(AtomicUsize::new(0));
    let connection_count_clone = connection_count.clone();
    tokio::spawn(async move {
        loop {
            let result = listener.accept().await;
            if result.is_err() {
                break;
            }
            let (stream, _) = result.unwrap();
            connection_count_clone.fetch_add(1, Ordering::SeqCst);
            // Per-connection handler: parse each request frame and answer
            // every one with a fixed "ok" response until the peer closes.
            tokio::spawn(async move {
                let (mut read_half, mut write_half) = tokio::io::split(stream);
                loop {
                    let mut len_buf = [0u8; 2];
                    if read_half.read_exact(&mut len_buf).await.is_err() {
                        break;
                    }
                    let path_len = u16::from_be_bytes(len_buf) as usize;
                    let mut path_buf = vec![0u8; path_len];
                    if read_half.read_exact(&mut path_buf).await.is_err() {
                        break;
                    }
                    let mut headers_len_buf = [0u8; 2];
                    if read_half.read_exact(&mut headers_len_buf).await.is_err() {
                        break;
                    }
                    let headers_len = u16::from_be_bytes(headers_len_buf) as usize;
                    let mut headers_buf = vec![0u8; headers_len];
                    if read_half.read_exact(&mut headers_buf).await.is_err() {
                        break;
                    }
                    let mut len_buf = [0u8; 4];
                    if read_half.read_exact(&mut len_buf).await.is_err() {
                        break;
                    }
                    let payload_len = u32::from_be_bytes(len_buf) as usize;
                    let mut payload_buf = vec![0u8; payload_len];
                    if read_half.read_exact(&mut payload_buf).await.is_err() {
                        break;
                    }
                    let response = TcpResponseMessage::new(Bytes::from_static(b"ok"));
                    let encoded = response.encode().unwrap();
                    if write_half.write_all(&encoded).await.is_err() {
                        break;
                    }
                }
            });
        }
    });
    let config = TcpRequestConfig {
        request_timeout: Duration::from_secs(5),
        connect_timeout: Duration::from_secs(5),
        pool_size: 2,
        channel_buffer: 10,
    };
    let pool = TcpConnectionPool::new(config);
    let conn1 = pool.get_connection(addr).await.unwrap();
    pool.return_connection(conn1).await;
    // NOTE(review): presumably gives background tasks a moment to settle
    // before the second checkout — confirm whether this is still required.
    tokio::time::sleep(Duration::from_millis(10)).await;
    let conn2 = pool.get_connection(addr).await.unwrap();
    pool.return_connection(conn2).await;
    assert_eq!(
        connection_count.load(Ordering::SeqCst),
        1,
        "Should reuse connection from pool"
    );
}
/// Smoke test against a server that immediately drops every accepted
/// connection: exercises the failure path plus the pool's unhealthy
/// filtering. Deliberately assertion-free — where exactly the failure
/// surfaces (connect vs. send) is timing dependent.
#[tokio::test]
async fn test_unhealthy_connection_filtered() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    tokio::spawn(async move {
        while let Ok((stream, _)) = listener.accept().await {
            // Close every accepted connection immediately.
            drop(stream);
        }
    });
    let config = TcpRequestConfig {
        request_timeout: Duration::from_secs(1),
        connect_timeout: Duration::from_secs(1),
        pool_size: 2,
        channel_buffer: 10,
    };
    let pool = TcpConnectionPool::new(config.clone());
    let result =
        TcpConnection::connect(addr, config.connect_timeout, config.channel_buffer).await;
    if let Ok(conn) = result {
        let mut headers = Headers::new();
        headers.insert("x-endpoint-path".to_string(), "test".to_string());
        // Expected to fail: the peer has already closed the socket.
        let _ = conn.send_request(Bytes::from("test"), &headers).await;
        // Returning a now-unhealthy connection should be a no-op, and the
        // next checkout should not hand it back.
        pool.return_connection(conn).await;
        let result2 = pool.get_connection(addr).await;
        let _ = result2;
    }
}
}