use crate::error::{ProtocolError, Result};
use dashmap::DashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot, Semaphore};
use tracing::{debug, error, warn};
pub type RequestId = u64;
#[derive(Debug, Clone)]
pub struct MultiplexFrame {
pub request_id: RequestId,
pub payload: Vec<u8>,
}
#[derive(Debug, Clone)]
pub struct MultiplexConfig {
pub max_in_flight: usize,
pub request_timeout: Duration,
pub send_buffer_size: usize,
}
impl MultiplexConfig {
pub fn validate(&self) -> Result<()> {
let mut errors = Vec::new();
if self.max_in_flight == 0 {
errors.push("max_in_flight must be greater than 0".to_string());
}
if self.max_in_flight > 1_000_000 {
errors.push(format!(
"max_in_flight ({}) exceeds recommended limit (1,000,000)",
self.max_in_flight
));
}
if self.request_timeout.is_zero() {
errors.push("request_timeout must be greater than 0".to_string());
}
if self.request_timeout.as_millis() < 100 {
errors.push(format!(
"request_timeout ({} ms) is too short (minimum: 100ms)",
self.request_timeout.as_millis()
));
}
if self.request_timeout.as_secs() > 300 {
errors.push(format!(
"request_timeout ({} seconds) is unusually long (recommended: < 5 minutes)",
self.request_timeout.as_secs()
));
}
if self.send_buffer_size == 0 {
errors.push("send_buffer_size must be greater than 0".to_string());
}
if self.send_buffer_size > 10_000 {
errors.push(format!(
"send_buffer_size ({}) is unusually large (recommended: < 10,000)",
self.send_buffer_size
));
}
if errors.is_empty() {
Ok(())
} else {
Err(ProtocolError::ConfigError(format!(
"Multiplex configuration validation failed:\n - {}",
errors.join("\n - ")
)))
}
}
}
impl Default for MultiplexConfig {
fn default() -> Self {
Self {
max_in_flight: 10_000, request_timeout: Duration::from_secs(30),
send_buffer_size: 100,
}
}
}
struct PendingRequest {
response_tx: oneshot::Sender<Vec<u8>>,
created_at: Instant,
}
#[derive(Debug, Default)]
pub struct MultiplexMetrics {
pub requests_sent: AtomicU64,
pub responses_received: AtomicU64,
pub timeouts: AtomicU64,
pub errors: AtomicU64,
pub in_flight: AtomicU64,
}
pub struct Multiplexer<R, W>
where
R: AsyncReadExt + Send + Unpin + 'static,
W: AsyncWriteExt + Send + Unpin + 'static,
{
config: MultiplexConfig,
next_request_id: Arc<AtomicU64>,
pending: Arc<DashMap<RequestId, PendingRequest>>,
send_tx: mpsc::Sender<MultiplexFrame>,
backpressure: Arc<Semaphore>,
metrics: Arc<MultiplexMetrics>,
reader: Option<R>,
writer: Option<W>,
}
impl<R, W> Multiplexer<R, W>
where
R: AsyncReadExt + Send + Unpin + 'static,
W: AsyncWriteExt + Send + Unpin + 'static,
{
pub fn new(reader: R, writer: W, config: MultiplexConfig) -> Self {
let (send_tx, send_rx) = mpsc::channel(config.send_buffer_size);
let pending = Arc::new(DashMap::new());
let metrics = Arc::new(MultiplexMetrics::default());
let backpressure = Arc::new(Semaphore::new(config.max_in_flight));
let mut multiplexer = Self {
config: config.clone(),
next_request_id: Arc::new(AtomicU64::new(1)),
pending: pending.clone(),
send_tx,
backpressure,
metrics: metrics.clone(),
reader: Some(reader),
writer: Some(writer),
};
#[allow(clippy::expect_used)] let writer = multiplexer.writer.take().expect("Writer should exist");
tokio::spawn(Self::send_loop(writer, send_rx, metrics.clone()));
#[allow(clippy::expect_used)] let reader = multiplexer.reader.take().expect("Reader should exist");
tokio::spawn(Self::receive_loop(reader, pending.clone(), metrics.clone()));
let pending_clone = pending.clone();
let timeout = config.request_timeout;
let metrics_clone = metrics.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(5));
loop {
interval.tick().await;
Self::cleanup_stale_requests(&pending_clone, timeout, &metrics_clone);
}
});
multiplexer
}
pub async fn request(&self, payload: Vec<u8>) -> Result<Vec<u8>> {
let _permit = self
.backpressure
.acquire()
.await
.map_err(|_| ProtocolError::PoolExhausted)?;
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
let (response_tx, response_rx) = oneshot::channel();
self.pending.insert(
request_id,
PendingRequest {
response_tx,
created_at: Instant::now(),
},
);
self.metrics.in_flight.fetch_add(1, Ordering::Relaxed);
let frame = MultiplexFrame {
request_id,
payload,
};
self.send_tx
.send(frame)
.await
.map_err(|_| ProtocolError::ConnectionClosed)?;
self.metrics.requests_sent.fetch_add(1, Ordering::Relaxed);
tokio::time::timeout(self.config.request_timeout, response_rx)
.await
.map_err(|_| {
self.pending.remove(&request_id);
self.metrics.timeouts.fetch_add(1, Ordering::Relaxed);
self.metrics.in_flight.fetch_sub(1, Ordering::Relaxed);
ProtocolError::Timeout
})?
.map_err(|_| {
self.metrics.errors.fetch_add(1, Ordering::Relaxed);
self.metrics.in_flight.fetch_sub(1, Ordering::Relaxed);
ProtocolError::ConnectionClosed
})
}
async fn send_loop(
mut writer: W,
mut send_rx: mpsc::Receiver<MultiplexFrame>,
_metrics: Arc<MultiplexMetrics>,
) {
while let Some(frame) = send_rx.recv().await {
let payload_len = frame.payload.len() as u32;
if let Err(e) = writer.write_u64(frame.request_id).await {
error!("Failed to write request ID: {}", e);
break;
}
if let Err(e) = writer.write_u32(payload_len).await {
error!("Failed to write payload length: {}", e);
break;
}
if let Err(e) = writer.write_all(&frame.payload).await {
error!("Failed to write payload: {}", e);
break;
}
if let Err(e) = writer.flush().await {
error!("Failed to flush writer: {}", e);
break;
}
debug!("Sent multiplexed request {}", frame.request_id);
}
}
async fn receive_loop(
mut reader: R,
pending: Arc<DashMap<RequestId, PendingRequest>>,
metrics: Arc<MultiplexMetrics>,
) {
loop {
let request_id = match reader.read_u64().await {
Ok(id) => id,
Err(e) => {
error!("Failed to read request ID: {}", e);
break;
}
};
let payload_len = match reader.read_u32().await {
Ok(len) => len as usize,
Err(e) => {
error!("Failed to read payload length: {}", e);
break;
}
};
let mut payload = vec![0u8; payload_len];
if let Err(e) = reader.read_exact(&mut payload).await {
error!("Failed to read payload: {}", e);
break;
}
debug!("Received multiplexed response {}", request_id);
if let Some((_, pending_req)) = pending.remove(&request_id) {
metrics.responses_received.fetch_add(1, Ordering::Relaxed);
metrics.in_flight.fetch_sub(1, Ordering::Relaxed);
if pending_req.response_tx.send(payload).is_err() {
warn!("Failed to send response to waiting request {}", request_id);
}
} else {
warn!("Received response for unknown request {}", request_id);
}
}
}
fn cleanup_stale_requests(
pending: &Arc<DashMap<RequestId, PendingRequest>>,
timeout: Duration,
metrics: &Arc<MultiplexMetrics>,
) {
let now = Instant::now();
let mut stale_count = 0;
pending.retain(|_id, req| {
let is_stale = now.duration_since(req.created_at) > timeout;
if is_stale {
stale_count += 1;
metrics.timeouts.fetch_add(1, Ordering::Relaxed);
metrics.in_flight.fetch_sub(1, Ordering::Relaxed);
}
!is_stale
});
if stale_count > 0 {
warn!("Cleaned up {} stale requests", stale_count);
}
}
pub fn metrics(&self) -> Arc<MultiplexMetrics> {
self.metrics.clone()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
#[allow(clippy::unwrap_used)] async fn test_multiplex_single_request() {
let (client_stream, server_stream) = tokio::io::duplex(1024);
let (client_reader, client_writer) = tokio::io::split(client_stream);
let config = MultiplexConfig::default();
let multiplexer = Multiplexer::new(client_reader, client_writer, config);
tokio::spawn(async move {
let (mut server_reader, mut server_writer) = tokio::io::split(server_stream);
#[allow(clippy::while_let_loop)] loop {
let request_id = match server_reader.read_u64().await {
Ok(id) => id,
Err(_) => break,
};
let payload_len = match server_reader.read_u32().await {
Ok(len) => len,
Err(_) => break,
};
let mut payload = vec![0u8; payload_len as usize];
if server_reader.read_exact(&mut payload).await.is_err() {
break;
}
if server_writer.write_u64(request_id).await.is_err() {
break;
}
if server_writer.write_u32(payload_len).await.is_err() {
break;
}
if server_writer.write_all(&payload).await.is_err() {
break;
}
if server_writer.flush().await.is_err() {
break;
}
}
});
let response = multiplexer.request(b"hello".to_vec()).await.unwrap();
assert_eq!(response, b"hello");
let metrics = multiplexer.metrics();
assert_eq!(metrics.requests_sent.load(Ordering::Relaxed), 1);
assert_eq!(metrics.responses_received.load(Ordering::Relaxed), 1);
}
#[tokio::test]
#[allow(clippy::unwrap_used)] async fn test_multiplex_concurrent_requests() {
let (client_stream, server_stream) = tokio::io::duplex(8192);
let (client_reader, client_writer) = tokio::io::split(client_stream);
let config = MultiplexConfig::default();
let multiplexer = Arc::new(Multiplexer::new(client_reader, client_writer, config));
tokio::spawn(async move {
let (mut server_reader, mut server_writer) = tokio::io::split(server_stream);
#[allow(clippy::while_let_loop)] loop {
let request_id = match server_reader.read_u64().await {
Ok(id) => id,
Err(_) => break,
};
let payload_len = match server_reader.read_u32().await {
Ok(len) => len,
Err(_) => break,
};
let mut payload = vec![0u8; payload_len as usize];
if server_reader.read_exact(&mut payload).await.is_err() {
break;
}
if server_writer.write_u64(request_id).await.is_err() {
break;
}
if server_writer.write_u32(payload_len).await.is_err() {
break;
}
if server_writer.write_all(&payload).await.is_err() {
break;
}
if server_writer.flush().await.is_err() {
break;
}
}
});
let mut tasks = vec![];
for i in 0..10 {
let multiplexer_clone = multiplexer.clone();
tasks.push(tokio::spawn(async move {
let payload = format!("request_{}", i).into_bytes();
multiplexer_clone.request(payload.clone()).await.unwrap()
}));
}
for task in tasks {
task.await.unwrap();
}
let metrics = multiplexer.metrics();
assert_eq!(metrics.requests_sent.load(Ordering::Relaxed), 10);
assert_eq!(metrics.responses_received.load(Ordering::Relaxed), 10);
}
#[tokio::test]
async fn test_multiplex_config_validation() {
let config = MultiplexConfig::default();
assert!(config.validate().is_ok());
}
#[tokio::test]
async fn test_multiplex_config_validation_zero_in_flight() {
let config = MultiplexConfig {
max_in_flight: 0,
..Default::default()
};
assert!(config.validate().is_err());
}
#[tokio::test]
async fn test_multiplex_config_validation_zero_timeout() {
let config = MultiplexConfig {
request_timeout: Duration::from_secs(0),
..Default::default()
};
assert!(config.validate().is_err());
}
#[tokio::test]
async fn test_multiplex_config_validation_short_timeout() {
let config = MultiplexConfig {
request_timeout: Duration::from_millis(50),
..Default::default()
};
assert!(config.validate().is_err());
}
#[tokio::test]
async fn test_multiplex_config_validation_zero_buffer() {
let config = MultiplexConfig {
send_buffer_size: 0,
..Default::default()
};
assert!(config.validate().is_err());
}
}