use crate::{Error, Request, Response, Result};
use bytes::{Bytes, BytesMut};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot, Mutex, Semaphore};
use tracing::{debug, info, trace, warn};
/// Map from request correlation id to the oneshot sender that delivers the
/// decoded response (or error) back to the waiting caller. Shared between
/// the client handle, the writer task, and the reader task.
type PendingResponses = Arc<Mutex<HashMap<u64, oneshot::Sender<Result<Response>>>>>;
/// Tuning knobs for a pipelined client connection.
#[derive(Debug, Clone)]
pub struct PipelineConfig {
    /// Maximum number of requests awaiting a response at once
    /// (enforced by a semaphore in `send_request`).
    pub max_in_flight: usize,
    /// How long (microseconds) the writer may wait for more requests before
    /// flushing a partially filled batch. 0 disables lingering entirely.
    pub batch_linger_us: u64,
    /// Maximum number of requests written per batch flush.
    pub max_batch_size: usize,
    /// Capacity of the buffered reader wrapping the socket's read half.
    pub read_buffer_size: usize,
    /// Capacity of the buffered writer wrapping the socket's write half.
    pub write_buffer_size: usize,
    /// Per-request response deadline; also reused as the TCP connect timeout.
    pub request_timeout: Duration,
    /// How long `close()` waits for in-flight responses to drain.
    pub close_timeout: Duration,
    /// Optional TLS settings (only available with the `tls` feature).
    #[cfg(feature = "tls")]
    pub tls: Option<PipelineTlsConfig>,
    /// Optional SCRAM credentials, applied right after the handshake.
    pub auth: Option<PipelineAuthConfig>,
}
/// TLS settings for a pipelined connection (requires the `tls` feature).
#[cfg(feature = "tls")]
#[derive(Debug, Clone)]
pub struct PipelineTlsConfig {
    /// Configuration handed to `rivven_core::tls::TlsConnector::new`.
    pub tls_config: rivven_core::tls::TlsConfig,
    /// Server name passed to the TLS connector during the handshake.
    pub server_name: String,
}
/// SCRAM username/password credentials.
///
/// `Debug` is deliberately not derived: a manual impl redacts the password
/// so credentials cannot leak through `{:?}` formatting.
#[derive(Clone)]
pub struct PipelineAuthConfig {
    pub username: String,
    // Plaintext password; never printed (see the manual Debug impl).
    pub password: String,
}
impl std::fmt::Debug for PipelineAuthConfig {
    /// Shows the username but replaces the password with `[REDACTED]` so
    /// credentials never leak into logs via `{:?}`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("PipelineAuthConfig");
        dbg.field("username", &self.username);
        dbg.field("password", &"[REDACTED]");
        dbg.finish()
    }
}
impl Default for PipelineConfig {
    /// Balanced defaults: 100 in-flight, 1 ms linger, 64-entry batches,
    /// 64 KiB buffers, 30 s request timeout, 5 s close timeout, no TLS/auth.
    fn default() -> Self {
        let buffer_bytes = 64 * 1024;
        Self {
            max_in_flight: 100,
            batch_linger_us: 1000,
            max_batch_size: 64,
            read_buffer_size: buffer_bytes,
            write_buffer_size: buffer_bytes,
            request_timeout: Duration::from_secs(30),
            close_timeout: Duration::from_secs(5),
            #[cfg(feature = "tls")]
            tls: None,
            auth: None,
        }
    }
}
impl PipelineConfig {
    /// Starts a fluent builder seeded with the default configuration.
    pub fn builder() -> PipelineConfigBuilder {
        PipelineConfigBuilder::default()
    }

    /// Preset tuned for throughput: a deep pipeline, long linger window,
    /// large batches and large socket buffers.
    pub fn high_throughput() -> Self {
        Self {
            max_in_flight: 1000,
            batch_linger_us: 5000,
            max_batch_size: 256,
            read_buffer_size: 256 * 1024,
            write_buffer_size: 256 * 1024,
            request_timeout: Duration::from_secs(60),
            close_timeout: Duration::from_secs(10),
            ..Self::default()
        }
    }

    /// Preset tuned for latency: a shallow pipeline, no batching delay,
    /// single-request batches and small buffers.
    pub fn low_latency() -> Self {
        Self {
            max_in_flight: 32,
            batch_linger_us: 0,
            max_batch_size: 1,
            read_buffer_size: 16 * 1024,
            write_buffer_size: 16 * 1024,
            request_timeout: Duration::from_secs(10),
            close_timeout: Duration::from_secs(3),
            ..Self::default()
        }
    }
}
/// Fluent builder for `PipelineConfig`, obtained via `PipelineConfig::builder()`.
#[derive(Default)]
pub struct PipelineConfigBuilder {
    // Accumulates settings on top of `PipelineConfig::default()`.
    config: PipelineConfig,
}
impl PipelineConfigBuilder {
    /// Caps how many requests may await responses at once.
    pub fn max_in_flight(mut self, limit: usize) -> Self {
        self.config.max_in_flight = limit;
        self
    }

    /// Sets the batch linger window in milliseconds.
    pub fn batch_linger_ms(mut self, millis: u64) -> Self {
        self.config.batch_linger_us = millis * 1000;
        self
    }

    /// Sets the batch linger window in microseconds.
    pub fn batch_linger_us(mut self, micros: u64) -> Self {
        self.config.batch_linger_us = micros;
        self
    }

    /// Sets the maximum number of requests per batch flush.
    pub fn max_batch_size(mut self, count: usize) -> Self {
        self.config.max_batch_size = count;
        self
    }

    /// Sets the buffered-reader capacity in bytes.
    pub fn read_buffer_size(mut self, bytes: usize) -> Self {
        self.config.read_buffer_size = bytes;
        self
    }

    /// Sets the buffered-writer capacity in bytes.
    pub fn write_buffer_size(mut self, bytes: usize) -> Self {
        self.config.write_buffer_size = bytes;
        self
    }

    /// Sets the per-request response deadline.
    pub fn request_timeout(mut self, deadline: Duration) -> Self {
        self.config.request_timeout = deadline;
        self
    }

    /// Enables TLS with the given config and server name (requires the
    /// `tls` feature).
    #[cfg(feature = "tls")]
    pub fn tls(
        mut self,
        tls_config: rivven_core::tls::TlsConfig,
        server_name: impl Into<String>,
    ) -> Self {
        self.config.tls = Some(PipelineTlsConfig {
            tls_config,
            server_name: server_name.into(),
        });
        self
    }

    /// Sets SCRAM credentials to be used right after the handshake.
    pub fn auth(mut self, username: impl Into<String>, password: impl Into<String>) -> Self {
        self.config.auth = Some(PipelineAuthConfig {
            username: username.into(),
            password: password.into(),
        });
        self
    }

    /// Sets how long `close()` waits for in-flight responses to drain.
    pub fn close_timeout(mut self, deadline: Duration) -> Self {
        self.config.close_timeout = deadline;
        self
    }

    /// Finalizes and returns the configuration.
    pub fn build(self) -> PipelineConfig {
        self.config
    }
}
/// A single serialized request queued for the background writer task.
struct PipelinedRequest {
    /// Correlation id; also embedded in the first 8 bytes of `data`.
    id: u64,
    /// Wire frame body: 8-byte big-endian id followed by the encoded request
    /// (the writer task adds the 4-byte length prefix).
    data: Bytes,
    /// Delivers the matching response (or error) back to the caller.
    response_tx: oneshot::Sender<Result<Response>>,
    // Creation timestamp; currently unused (kept for future latency metrics).
    #[allow(dead_code)]
    created_at: Instant,
}
/// A client that pipelines many concurrent requests over one connection,
/// batching writes in a background task and demultiplexing responses by
/// correlation id. Cloning is cheap: all clones share the same `Arc` state.
pub struct PipelinedClient {
    inner: Arc<PipelinedClientInner>,
}
struct PipelinedClientInner {
request_tx: mpsc::Sender<PipelinedRequest>,
in_flight_semaphore: Arc<Semaphore>,
next_request_id: AtomicU64,
config: PipelineConfig,
stats: Arc<PipelineStats>,
shutdown: tokio::sync::watch::Sender<bool>,
pending_responses: Arc<Mutex<HashMap<u64, oneshot::Sender<Result<Response>>>>>,
}
impl Clone for PipelinedClient {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
impl PipelinedClient {
/// Connects to `addr` and returns a ready-to-use pipelined client.
///
/// Steps: TCP connect (bounded by `config.request_timeout`), enable
/// `TCP_NODELAY`, optionally wrap the stream in TLS (when the `tls`
/// feature is on and `config.tls` is set), spawn the reader/writer tasks,
/// run the protocol handshake, then SCRAM-authenticate if credentials
/// were configured.
pub async fn connect(addr: &str, config: PipelineConfig) -> Result<Self> {
    // `config` is moved into `setup_pipeline` below, so copy out what the
    // post-setup steps need first.
    let auth_config = config.auth.clone();
    let connect_timeout = config.request_timeout;
    let stream = tokio::time::timeout(connect_timeout, TcpStream::connect(addr))
        .await
        .map_err(|_| Error::Timeout)?
        .map_err(|e| Error::ConnectionError(e.to_string()))?;
    // Disable Nagle so small batched frames are written out immediately.
    stream
        .set_nodelay(true)
        .map_err(|e| Error::ConnectionError(format!("Failed to set TCP_NODELAY: {}", e)))?;
    #[cfg(feature = "tls")]
    if let Some(tls_cfg) = &config.tls {
        let connector = rivven_core::tls::TlsConnector::new(&tls_cfg.tls_config)
            .map_err(|e| Error::ConnectionError(format!("TLS config error: {e}")))?;
        let tls_stream = connector
            .connect(stream, &tls_cfg.server_name)
            .await
            .map_err(|e| Error::ConnectionError(format!("TLS handshake error: {e}")))?;
        // TLS streams don't offer `into_split`, so use `tokio::io::split`
        // and return early from this branch.
        let (read_half, write_half) = tokio::io::split(tls_stream);
        let client = Self::setup_pipeline(addr, config, read_half, write_half).await?;
        Self::pipeline_handshake(&client).await?;
        if let Some(auth) = &auth_config {
            Self::pipeline_authenticate(&client, &auth.username, &auth.password).await?;
        }
        return Ok(client);
    }
    // Plain-TCP path (the only path when the `tls` feature is disabled).
    let (read_half, write_half) = stream.into_split();
    let client = Self::setup_pipeline(addr, config, read_half, write_half).await?;
    Self::pipeline_handshake(&client).await?;
    if let Some(auth) = &auth_config {
        Self::pipeline_authenticate(&client, &auth.username, &auth.password).await?;
    }
    Ok(client)
}
/// Wires up the shared pipeline state and spawns the background writer and
/// reader tasks over the given stream halves.
async fn setup_pipeline<R, W>(
    _addr: &str,
    config: PipelineConfig,
    read_half: R,
    write_half: W,
) -> Result<Self>
where
    R: tokio::io::AsyncRead + Unpin + Send + 'static,
    W: tokio::io::AsyncWrite + Unpin + Send + 'static,
{
    // Channel to the writer task, sized at 2x the in-flight window so
    // senders rarely block on the channel itself.
    let (request_tx, request_rx) = mpsc::channel(config.max_in_flight * 2);
    let in_flight_semaphore = Arc::new(Semaphore::new(config.max_in_flight));
    let pending_responses = Arc::new(Mutex::new(HashMap::new()));
    let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
    let stats = Arc::new(PipelineStats::new());

    // Writer half: batches queued requests onto the socket.
    {
        let cfg = config.clone();
        let pending = Arc::clone(&pending_responses);
        let shutdown = shutdown_rx.clone();
        let task_stats = Arc::clone(&stats);
        tokio::spawn(async move {
            writer_task(write_half, request_rx, pending, cfg, shutdown, task_stats).await;
        });
    }

    // Reader half: demultiplexes responses back to waiting callers.
    {
        let cfg = config.clone();
        let pending = Arc::clone(&pending_responses);
        tokio::spawn(async move {
            reader_task(read_half, pending, cfg, shutdown_rx).await;
        });
    }

    Ok(Self {
        inner: Arc::new(PipelinedClientInner {
            request_tx,
            in_flight_semaphore,
            next_request_id: AtomicU64::new(1),
            config,
            stats,
            shutdown: shutdown_tx,
            pending_responses,
        }),
    })
}
/// Gracefully closes the pipeline: waits (up to `close_timeout`) for
/// registered in-flight responses to drain, then signals the background
/// tasks to shut down.
///
/// NOTE(review): only requests already flushed by the writer appear in
/// `pending_responses`; requests still queued in the writer channel are
/// not counted by this drain loop — confirm whether that is acceptable.
pub async fn close(&self) {
    let deadline = tokio::time::Instant::now() + self.inner.config.close_timeout;
    loop {
        let in_flight = self.inner.pending_responses.lock().await.len();
        if in_flight == 0 {
            break;
        }
        if tokio::time::Instant::now() >= deadline {
            tracing::warn!(
                pending = in_flight,
                "Pipeline close() timed out waiting for in-flight responses"
            );
            break;
        }
        tokio::time::sleep(Duration::from_millis(25)).await;
    }
    // Tell both the writer and reader tasks to stop.
    let _ = self.inner.shutdown.send(true);
}
async fn pipeline_handshake(client: &PipelinedClient) -> Result<()> {
let response = client
.send_request(Request::Handshake {
protocol_version: rivven_protocol::PROTOCOL_VERSION,
client_id: format!("pipeline-{}", std::process::id()),
})
.await?;
match response {
Response::HandshakeResult {
compatible,
server_version,
message: _,
} => {
if compatible {
info!(
"Pipeline handshake OK (client v{}, server v{})",
rivven_protocol::PROTOCOL_VERSION,
server_version
);
Ok(())
} else {
Err(Error::ProtocolError(
rivven_protocol::ProtocolError::VersionMismatch {
expected: rivven_protocol::PROTOCOL_VERSION,
actual: server_version,
},
))
}
}
Response::Error { message } => {
warn!(
"Server returned error on pipeline handshake: {}, proceeding anyway",
message
);
Ok(())
}
_ => {
warn!("Server did not return HandshakeResult, proceeding without version check");
Ok(())
}
}
}
/// Authenticates over the pipeline using a SCRAM (RFC 5802-style) exchange.
///
/// Flow: client-first -> server-first (combined nonce, salt, iterations)
/// -> client-final (proof) -> server-final (server signature). The server's
/// nonce continuation and signature are both verified, so each side proves
/// knowledge of the password-derived keys.
async fn pipeline_authenticate(
    client: &PipelinedClient,
    username: &str,
    password: &str,
) -> Result<()> {
    use crate::client::{
        base64_decode, base64_encode, escape_username, generate_nonce, parse_server_first,
        pbkdf2_sha256, sha256, xor_bytes,
    };
    use rivven_core::PasswordHash;
    let client_nonce = generate_nonce();
    // "n,," is the GS2 header meaning "no channel binding".
    let client_first_bare = format!("n={},r={}", escape_username(username), client_nonce);
    let client_first = format!("n,,{}", client_first_bare);
    let response = client
        .send_request(Request::ScramClientFirst {
            message: Bytes::from(client_first.clone()),
        })
        .await?;
    let server_first = match response {
        Response::ScramServerFirst { message } => String::from_utf8(message.to_vec())
            .map_err(|_| Error::AuthenticationFailed("Invalid server-first encoding".into()))?,
        Response::Error { message } => return Err(Error::AuthenticationFailed(message)),
        _ => return Err(Error::InvalidResponse),
    };
    let (combined_nonce, salt_b64, iterations) = parse_server_first(&server_first)?;
    // The server must echo our nonce as a prefix of the combined nonce,
    // binding its reply to this exchange.
    if !combined_nonce.starts_with(&client_nonce) {
        return Err(Error::AuthenticationFailed("Server nonce mismatch".into()));
    }
    let salt = base64_decode(&salt_b64)
        .map_err(|_| Error::AuthenticationFailed("Invalid salt encoding".into()))?;
    // SaltedPassword := PBKDF2-SHA256(password, salt, iterations)
    let salted_password = pbkdf2_sha256(password.as_bytes(), &salt, iterations);
    // ClientKey := HMAC(SaltedPassword, "Client Key"); StoredKey := SHA-256(ClientKey)
    let client_key = PasswordHash::hmac_sha256(&salted_password, b"Client Key");
    let stored_key = sha256(&client_key);
    // "c=biws" is base64("n,,") — the same GS2 header sent in client-first.
    let client_final_without_proof = format!("c=biws,r={}", combined_nonce);
    let auth_message = format!(
        "{},{},{}",
        client_first_bare, server_first, client_final_without_proof
    );
    // ClientProof := ClientKey XOR HMAC(StoredKey, AuthMessage)
    let client_signature = PasswordHash::hmac_sha256(&stored_key, auth_message.as_bytes());
    let client_proof = xor_bytes(&client_key, &client_signature);
    let client_final = format!(
        "{},p={}",
        client_final_without_proof,
        base64_encode(&client_proof)
    );
    let response = client
        .send_request(Request::ScramClientFinal {
            message: Bytes::from(client_final),
        })
        .await?;
    match response {
        Response::ScramServerFinal { message, .. } => {
            let verifier = String::from_utf8(message.to_vec())
                .map_err(|_| Error::AuthenticationFailed("Invalid server-final".into()))?;
            // Mutual auth: the server proves knowledge of ServerKey by
            // sending v=HMAC(ServerKey, AuthMessage).
            let server_key = PasswordHash::hmac_sha256(&salted_password, b"Server Key");
            let expected_sig = PasswordHash::hmac_sha256(&server_key, auth_message.as_bytes());
            let expected_verifier = format!("v={}", base64_encode(&expected_sig));
            if verifier != expected_verifier {
                return Err(Error::AuthenticationFailed(
                    "Server signature mismatch".into(),
                ));
            }
            tracing::info!("Pipeline SCRAM auth successful for '{}'", username);
            Ok(())
        }
        Response::Error { message } => Err(Error::AuthenticationFailed(message)),
        _ => Err(Error::InvalidResponse),
    }
}
/// Sends one request through the pipeline and awaits its response.
///
/// Flow: acquire an in-flight permit (backpressure), allocate a correlation
/// id, serialize the request, hand the frame to the writer task, then wait
/// (bounded by `request_timeout`) on a oneshot channel completed by the
/// reader task.
///
/// # Errors
/// - `Error::RequestTooLarge` if the encoded request exceeds `MAX_MESSAGE_SIZE`.
/// - `Error::ConnectionError` if the pipeline or writer task has closed.
/// - `Error::Timeout` if no response arrives within `request_timeout`.
pub async fn send_request(&self, request: Request) -> Result<Response> {
    // The permit is held for the whole call, capping concurrency at
    // `max_in_flight`; it is released on every return path.
    let _permit = self
        .inner
        .in_flight_semaphore
        .acquire()
        .await
        .map_err(|_| Error::ConnectionError("Pipeline closed".into()))?;
    // Relaxed is sufficient: only uniqueness of the id matters.
    let request_id = self.inner.next_request_id.fetch_add(1, Ordering::Relaxed);
    let (response_tx, response_rx) = oneshot::channel();
    let request_bytes = request.to_wire(rivven_protocol::WireFormat::Postcard, 0u32)?;
    if request_bytes.len() > rivven_protocol::MAX_MESSAGE_SIZE {
        return Err(Error::RequestTooLarge(
            request_bytes.len(),
            rivven_protocol::MAX_MESSAGE_SIZE,
        ));
    }
    // Frame body: 8-byte big-endian correlation id + encoded request.
    // (The writer task prepends the 4-byte length prefix.)
    let mut data = BytesMut::with_capacity(8 + request_bytes.len());
    data.extend_from_slice(&request_id.to_be_bytes());
    data.extend_from_slice(&request_bytes);
    let pipelined = PipelinedRequest {
        id: request_id,
        data: data.freeze(),
        response_tx,
        created_at: Instant::now(),
    };
    self.inner
        .request_tx
        .send(pipelined)
        .await
        .map_err(|_| Error::ConnectionError("Writer task closed".into()))?;
    self.inner
        .stats
        .requests_sent
        .fetch_add(1, Ordering::Relaxed);
    let timeout_duration = self.inner.config.request_timeout;
    match tokio::time::timeout(timeout_duration, response_rx).await {
        Ok(Ok(result)) => {
            self.inner
                .stats
                .responses_received
                .fetch_add(1, Ordering::Relaxed);
            result
        }
        Ok(Err(_)) => Err(Error::ConnectionError("Response channel dropped".into())),
        Err(_) => {
            // Drop our pending entry so a late response is discarded.
            // NOTE(review): if the request is still in the writer's batch it
            // gets re-registered on flush; that entry is then cleaned up when
            // the (ignored) response arrives or the connection closes.
            {
                let mut pending = self.inner.pending_responses.lock().await;
                pending.remove(&request_id);
            }
            self.inner.stats.timeouts.fetch_add(1, Ordering::Relaxed);
            Err(Error::Timeout)
        }
    }
}
/// Publishes `value` to `topic` with no key, no explicit partition, and no
/// leader epoch, returning the offset assigned by the server.
pub async fn publish(&self, topic: impl Into<String>, value: impl Into<Bytes>) -> Result<u64> {
    let reply = self
        .send_request(Request::Publish {
            topic: topic.into(),
            partition: None,
            key: None,
            value: value.into(),
            leader_epoch: None,
        })
        .await?;
    match reply {
        Response::Published { offset, .. } => Ok(offset),
        Response::Error { message } => Err(Error::ServerError(message)),
        _ => Err(Error::InvalidResponse),
    }
}
/// Publishes a keyed record to `topic`, returning the assigned offset.
pub async fn publish_with_key(
    &self,
    topic: impl Into<String>,
    key: impl Into<Bytes>,
    value: impl Into<Bytes>,
) -> Result<u64> {
    let reply = self
        .send_request(Request::Publish {
            topic: topic.into(),
            partition: None,
            key: Some(key.into()),
            value: value.into(),
            leader_epoch: None,
        })
        .await?;
    match reply {
        Response::Published { offset, .. } => Ok(offset),
        Response::Error { message } => Err(Error::ServerError(message)),
        _ => Err(Error::InvalidResponse),
    }
}
/// Returns a point-in-time snapshot of the pipeline counters.
pub fn stats(&self) -> PipelineStatsSnapshot {
    let counters = &self.inner.stats;
    PipelineStatsSnapshot {
        requests_sent: counters.requests_sent.load(Ordering::Relaxed),
        responses_received: counters.responses_received.load(Ordering::Relaxed),
        batches_flushed: counters.batches_flushed.load(Ordering::Relaxed),
        timeouts: counters.timeouts.load(Ordering::Relaxed),
    }
}
}
/// Background task that batches queued requests and writes them to the socket.
///
/// With an empty batch it blocks on the request channel (or shutdown); with
/// a non-empty batch it waits at most the remaining `batch_linger_us` for
/// more requests. A batch is flushed once it reaches `max_batch_size` or its
/// linger window expires (immediately when `batch_linger_us == 0`). On exit,
/// any partially collected batch is flushed so queued requests aren't dropped.
async fn writer_task<W: tokio::io::AsyncWrite + Unpin>(
    write_half: W,
    mut request_rx: mpsc::Receiver<PipelinedRequest>,
    pending: PendingResponses,
    config: PipelineConfig,
    mut shutdown: tokio::sync::watch::Receiver<bool>,
    stats: Arc<PipelineStats>,
) {
    let mut writer = BufWriter::with_capacity(config.write_buffer_size, write_half);
    let mut batch: Vec<PipelinedRequest> = Vec::with_capacity(config.max_batch_size);
    // When the first request of the current batch arrived; drives the
    // linger deadline.
    let mut batch_started: Option<Instant> = None;
    loop {
        if *shutdown.borrow() {
            break;
        }
        let request = if batch.is_empty() {
            // Nothing buffered: block until a request or a shutdown signal.
            tokio::select! {
                req = request_rx.recv() => {
                    match req {
                        Some(req) => Some(req),
                        // Channel closed: every sender has been dropped.
                        None => break,
                    }
                }
                _ = shutdown.changed() => {
                    if *shutdown.borrow() {
                        break;
                    }
                    continue;
                }
            }
        } else if config.batch_linger_us == 0 {
            // Lingering disabled: flush what we have immediately.
            None
        } else {
            // Wait for more requests, but only up to the remaining linger.
            let elapsed = batch_started
                .map(|t| t.elapsed().as_micros() as u64)
                .unwrap_or(0);
            let remaining = config.batch_linger_us.saturating_sub(elapsed);
            if remaining == 0 {
                None
            } else {
                match tokio::time::timeout(Duration::from_micros(remaining), request_rx.recv())
                    .await
                {
                    Ok(Some(req)) => Some(req),
                    Ok(None) => break,
                    // Linger expired: fall through to the flush check.
                    Err(_) => None,
                }
            }
        };
        if let Some(req) = request {
            if batch.is_empty() {
                batch_started = Some(Instant::now());
            }
            batch.push(req);
        }
        let should_flush = batch.len() >= config.max_batch_size
            || (!batch.is_empty()
                && batch_started
                    .is_some_and(|t| t.elapsed().as_micros() as u64 >= config.batch_linger_us));
        if should_flush && !batch.is_empty() {
            if let Err(e) = flush_batch(&mut writer, &mut batch, &pending, &stats).await {
                warn!("Failed to flush batch: {}", e);
                // Defensive: fail any requests flush_batch left in the batch
                // (it normally drains everything into `pending` itself).
                for req in batch.drain(..) {
                    let _ = req
                        .response_tx
                        .send(Err(Error::ConnectionError(e.to_string())));
                }
            }
            batch_started = None;
        }
    }
    // Best-effort flush of whatever was collected when shutdown hit.
    if !batch.is_empty() {
        if let Err(e) = flush_batch(&mut writer, &mut batch, &pending, &stats).await {
            tracing::warn!(error = %e, "Failed to flush remaining batch on shutdown");
        }
    }
}
/// Writes one batch of length-prefixed frames and flushes the buffered writer.
///
/// Each request's response sender is registered in `pending` *before* its
/// bytes hit the wire, so a fast response can never race past registration.
/// On any write error, every request in this batch is removed from `pending`
/// and failed with a `ConnectionError`, and a `BrokenPipe` error is returned.
async fn flush_batch<W: tokio::io::AsyncWrite + Unpin>(
    writer: &mut BufWriter<W>,
    batch: &mut Vec<PipelinedRequest>,
    pending: &PendingResponses,
    stats: &PipelineStats,
) -> std::io::Result<()> {
    let batch_count = batch.len();
    let mut request_data: Vec<(u64, bytes::Bytes)> = Vec::with_capacity(batch_count);
    {
        // Register response channels first (see doc comment above).
        let mut pending_guard = pending.lock().await;
        for req in batch.drain(..) {
            pending_guard.insert(req.id, req.response_tx);
            request_data.push((req.id, req.data));
        }
    }
    let write_result: std::io::Result<()> = async {
        for (_id, data) in &request_data {
            // Bug fix: the previous `try_into().unwrap_or(u32::MAX)` would
            // silently write a wrong length prefix for frames larger than
            // u32::MAX bytes, desynchronizing the framing for every later
            // message. Treat oversized frames as a hard connection error.
            let len = u32::try_from(data.len()).map_err(|_| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!("frame too large for u32 length prefix: {} bytes", data.len()),
                )
            })?;
            writer.write_all(&len.to_be_bytes()).await?;
            writer.write_all(data).await?;
        }
        writer.flush().await?;
        Ok(())
    }
    .await;
    if let Err(ref e) = write_result {
        // Fail everything we just registered so callers see the error
        // immediately instead of waiting for their request timeout.
        let mut pending_guard = pending.lock().await;
        for (id, _) in &request_data {
            if let Some(tx) = pending_guard.remove(id) {
                let _ = tx.send(Err(Error::ConnectionError(e.to_string())));
            }
        }
        return Err(std::io::Error::new(
            std::io::ErrorKind::BrokenPipe,
            e.to_string(),
        ));
    }
    trace!("Flushed batch of {} requests", batch_count);
    stats.batches_flushed.fetch_add(1, Ordering::Relaxed);
    Ok(())
}
/// Background task that decodes length-prefixed responses and routes them to
/// waiting callers by correlation id.
///
/// Wire format per response: 4-byte big-endian length (covering the rest of
/// the frame), then an 8-byte big-endian correlation id, then the encoded
/// `Response`. On any framing or I/O error the task exits and fails every
/// still-pending request with a `ConnectionError`.
async fn reader_task<R: tokio::io::AsyncRead + Unpin>(
    read_half: R,
    pending: PendingResponses,
    config: PipelineConfig,
    mut shutdown: tokio::sync::watch::Receiver<bool>,
) {
    let mut reader = BufReader::with_capacity(config.read_buffer_size, read_half);
    let mut len_buf = [0u8; 4];
    let mut id_buf = [0u8; 8];
    loop {
        if *shutdown.borrow() {
            break;
        }
        // Only the length read races shutdown; once a frame has started we
        // read it to completion so the stream stays in sync.
        let read_result = tokio::select! {
            result = reader.read_exact(&mut len_buf) => result,
            _ = shutdown.changed() => {
                if *shutdown.borrow() {
                    break;
                }
                continue;
            }
        };
        if read_result.is_err() {
            // EOF or I/O error: treat the connection as closed.
            break;
        }
        let msg_len = u32::from_be_bytes(len_buf) as usize;
        // Defensive cap against huge allocations from a corrupt/hostile peer.
        const MAX_PIPELINE_RESPONSE_SIZE: usize = 100 * 1024 * 1024;
        if msg_len > MAX_PIPELINE_RESPONSE_SIZE {
            warn!(
                "Pipeline response too large: {} bytes (max {})",
                msg_len, MAX_PIPELINE_RESPONSE_SIZE
            );
            break;
        }
        // A frame must at least contain the 8-byte correlation id.
        if msg_len < 8 {
            warn!(
                "Invalid response length: {} — stream desynchronized, closing connection",
                msg_len
            );
            break;
        }
        if reader.read_exact(&mut id_buf).await.is_err() {
            break;
        }
        let request_id = u64::from_be_bytes(id_buf);
        let body_len = msg_len - 8;
        let mut response_buf = vec![0u8; body_len];
        if reader.read_exact(&mut response_buf).await.is_err() {
            break;
        }
        // Decode errors are delivered to the caller rather than killing the
        // connection: the frame itself was well-formed.
        let result = Response::from_wire(&response_buf)
            .map(|(resp, _format, _correlation_id)| resp)
            .map_err(Error::ProtocolError);
        let sender = {
            let mut pending_guard = pending.lock().await;
            pending_guard.remove(&request_id)
        };
        if let Some(tx) = sender {
            let _ = tx.send(result);
        } else {
            // Caller likely timed out and removed its entry already.
            debug!("Received response for unknown request ID: {}", request_id);
        }
    }
    // Fail everything still outstanding so callers don't hang until timeout.
    let mut pending_guard = pending.lock().await;
    for (_, tx) in pending_guard.drain() {
        let _ = tx.send(Err(Error::ConnectionError("Connection closed".into())));
    }
}
/// Lock-free counters shared by the client handle and its background tasks.
struct PipelineStats {
    /// Requests successfully handed to the writer channel.
    requests_sent: AtomicU64,
    /// Responses delivered back to callers (timeouts excluded).
    responses_received: AtomicU64,
    /// Batch flushes performed by the writer task.
    batches_flushed: AtomicU64,
    /// Requests that hit `request_timeout` before a response arrived.
    timeouts: AtomicU64,
}
impl PipelineStats {
fn new() -> Self {
Self {
requests_sent: AtomicU64::new(0),
responses_received: AtomicU64::new(0),
batches_flushed: AtomicU64::new(0),
timeouts: AtomicU64::new(0),
}
}
}
/// Point-in-time copy of the pipeline counters, safe to hand to callers.
#[derive(Debug, Clone)]
pub struct PipelineStatsSnapshot {
    pub requests_sent: u64,
    pub responses_received: u64,
    pub batches_flushed: u64,
    pub timeouts: u64,
}
impl PipelineStatsSnapshot {
    /// Requests sent but not yet answered at snapshot time (saturating,
    /// so a racy snapshot can never underflow).
    pub fn in_flight(&self) -> u64 {
        self.requests_sent.saturating_sub(self.responses_received)
    }

    /// Fraction of sent requests that got a response; reports 1.0 for an
    /// idle pipeline rather than dividing by zero.
    pub fn success_rate(&self) -> f64 {
        match self.requests_sent {
            0 => 1.0,
            sent => self.responses_received as f64 / sent as f64,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder setters must land on the corresponding config fields.
    #[test]
    fn test_pipeline_config_builder() {
        let built = PipelineConfig::builder()
            .max_in_flight(200)
            .batch_linger_ms(10)
            .max_batch_size(128)
            .request_timeout(Duration::from_secs(60))
            .build();
        assert_eq!(built.max_in_flight, 200);
        // 10 ms expressed in microseconds.
        assert_eq!(built.batch_linger_us, 10_000);
        assert_eq!(built.max_batch_size, 128);
        assert_eq!(built.request_timeout, Duration::from_secs(60));
    }

    /// Throughput preset: deep pipeline, long linger, big batches.
    #[test]
    fn test_high_throughput_config() {
        let preset = PipelineConfig::high_throughput();
        assert_eq!(preset.max_in_flight, 1000);
        assert_eq!(preset.batch_linger_us, 5000);
        assert_eq!(preset.max_batch_size, 256);
    }

    /// Latency preset: shallow pipeline, no lingering, singleton batches.
    #[test]
    fn test_low_latency_config() {
        let preset = PipelineConfig::low_latency();
        assert_eq!(preset.max_in_flight, 32);
        assert_eq!(preset.batch_linger_us, 0);
        assert_eq!(preset.max_batch_size, 1);
    }

    /// 100 sent / 95 answered -> 5 in flight, 95% success.
    #[test]
    fn test_stats_snapshot() {
        let snap = PipelineStatsSnapshot {
            requests_sent: 100,
            responses_received: 95,
            batches_flushed: 10,
            timeouts: 5,
        };
        assert_eq!(snap.in_flight(), 5);
        assert!((snap.success_rate() - 0.95).abs() < 0.001);
    }

    /// An idle pipeline reports zero in-flight and a perfect success rate.
    #[test]
    fn test_stats_snapshot_empty() {
        let snap = PipelineStatsSnapshot {
            requests_sent: 0,
            responses_received: 0,
            batches_flushed: 0,
            timeouts: 0,
        };
        assert_eq!(snap.in_flight(), 0);
        assert!((snap.success_rate() - 1.0).abs() < 0.001);
    }
}