use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use bytes::{Bytes, BytesMut};
use tokio::sync::{Mutex, RwLock, mpsc, oneshot};
use tokio::time::interval;
use crate::client::LanceClient;
use crate::connection::ReconnectingClient;
use crate::error::{ClientError, Result};
/// Tunable settings for a [`Producer`].
///
/// Obtain the defaults with [`ProducerConfig::new`] and adjust individual
/// fields through the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ProducerConfig {
    /// Payload size in bytes at which a topic's open batch is flushed
    /// immediately by the send methods.
    pub batch_size: usize,
    /// Period of the background flush timer in milliseconds (clamped to at
    /// least 1 ms) and the minimum age a batch must reach before that timer
    /// flushes it.
    pub linger_ms: u64,
    /// NOTE(review): not read by the code visible in this file; presumably a
    /// cap on concurrent in-flight requests — confirm where it is enforced.
    pub max_in_flight: usize,
    /// Upper bound in bytes on buffered, not-yet-sent payload; the send
    /// methods reject new records once it would be exceeded.
    pub buffer_memory: usize,
    /// NOTE(review): not read by the code visible in this file.
    pub connect_timeout: Duration,
    /// NOTE(review): not read by the code visible in this file.
    pub request_timeout: Duration,
    /// NOTE(review): not read by the code visible in this file.
    pub compression: bool,
}
impl Default for ProducerConfig {
fn default() -> Self {
Self {
batch_size: 16 * 1024, linger_ms: 5, max_in_flight: 5, buffer_memory: 32 * 1024 * 1024, connect_timeout: Duration::from_secs(30),
request_timeout: Duration::from_secs(30),
compression: false,
}
}
}
impl ProducerConfig {
    /// Creates a configuration holding the default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the batch size threshold (bytes).
    pub fn with_batch_size(self, size: usize) -> Self {
        Self {
            batch_size: size,
            ..self
        }
    }

    /// Replaces the linger time (milliseconds).
    pub fn with_linger_ms(self, ms: u64) -> Self {
        Self {
            linger_ms: ms,
            ..self
        }
    }

    /// Replaces the `max_in_flight` setting.
    pub fn with_max_in_flight(self, n: usize) -> Self {
        Self {
            max_in_flight: n,
            ..self
        }
    }

    /// Replaces the buffer memory limit (bytes).
    pub fn with_buffer_memory(self, bytes: usize) -> Self {
        Self {
            buffer_memory: bytes,
            ..self
        }
    }

    /// Replaces the connect timeout.
    pub fn with_connect_timeout(self, timeout: Duration) -> Self {
        Self {
            connect_timeout: timeout,
            ..self
        }
    }

    /// Replaces the request timeout.
    pub fn with_request_timeout(self, timeout: Duration) -> Self {
        Self {
            request_timeout: timeout,
            ..self
        }
    }

    /// Turns the compression flag on or off.
    pub fn with_compression(self, enabled: bool) -> Self {
        Self {
            compression: enabled,
            ..self
        }
    }
}
/// Acknowledgement handed back for one successfully sent record.
#[derive(Debug, Clone)]
pub struct SendAck {
    /// Identifier returned by the ingest call for the batch that carried the
    /// record.
    pub batch_id: u64,
    /// Topic the record was sent to.
    pub topic_id: u32,
    /// Instant captured right after the batch send succeeded; shared by all
    /// records of the batch.
    pub timestamp: Instant,
    /// Zero-based position of the record inside its batch.
    pub offset_in_batch: u32,
}
/// Outcome reported for a single record: its [`SendAck`] on success, or the
/// client error its batch failed with.
pub type SendResult = Result<SendAck>;
/// Records accumulated for one topic that have not been sent yet.
struct RecordBatch {
    /// Topic every record in this batch belongs to.
    topic_id: u32,
    /// Record payloads appended back to back.
    data: BytesMut,
    /// Number of records appended so far.
    record_count: usize,
    /// One completion channel per record, in insertion order.
    ack_txs: Vec<oneshot::Sender<SendResult>>,
    /// Byte offset into `data` at which each record starts.
    record_offsets: Vec<usize>,
    /// When the batch was created; compared against the linger time.
    created_at: Instant,
}
impl RecordBatch {
fn new(topic_id: u32) -> Self {
Self {
topic_id,
data: BytesMut::with_capacity(16 * 1024),
record_count: 0,
ack_txs: Vec::new(),
record_offsets: Vec::new(),
created_at: Instant::now(),
}
}
fn add(&mut self, data: Bytes, ack_tx: oneshot::Sender<SendResult>) {
self.record_offsets.push(self.data.len());
self.data.extend_from_slice(&data);
self.record_count += 1;
self.ack_txs.push(ack_tx);
}
fn size(&self) -> usize {
self.data.len()
}
fn is_empty(&self) -> bool {
self.record_count == 0
}
}
/// Live, atomically updated counters describing producer activity.
#[derive(Debug, Default)]
pub struct ProducerMetrics {
    /// Records contained in batches that were sent successfully.
    pub records_sent: AtomicU64,
    /// Payload bytes contained in batches that were sent successfully.
    pub bytes_sent: AtomicU64,
    /// Number of batches sent successfully.
    pub batches_sent: AtomicU64,
    /// Failed send attempts: one per retried attempt plus one when a batch
    /// finally fails.
    pub errors: AtomicU64,
    /// Payload bytes currently buffered: increased on enqueue, decreased when
    /// a batch send finishes (successfully or not).
    pub buffer_size: AtomicU64,
    /// Records rejected because the buffer memory limit would be exceeded.
    pub backpressure_drops: AtomicU64,
    /// NOTE(review): never updated by the code visible in this file.
    pub backpressure_waits: AtomicU64,
    /// NOTE(review): never updated by the code visible in this file.
    pub backpressure_wait_ms: AtomicU64,
}
impl ProducerMetrics {
pub fn snapshot(&self) -> MetricsSnapshot {
MetricsSnapshot {
records_sent: self.records_sent.load(Ordering::Relaxed),
bytes_sent: self.bytes_sent.load(Ordering::Relaxed),
batches_sent: self.batches_sent.load(Ordering::Relaxed),
errors: self.errors.load(Ordering::Relaxed),
buffer_size: self.buffer_size.load(Ordering::Relaxed),
backpressure_drops: self.backpressure_drops.load(Ordering::Relaxed),
backpressure_waits: self.backpressure_waits.load(Ordering::Relaxed),
backpressure_wait_ms: self.backpressure_wait_ms.load(Ordering::Relaxed),
}
}
}
/// Plain-value copy of [`ProducerMetrics`] taken at one point in time.
#[derive(Debug, Clone)]
pub struct MetricsSnapshot {
    /// Value of `ProducerMetrics::records_sent` when the snapshot was taken.
    pub records_sent: u64,
    /// Value of `ProducerMetrics::bytes_sent` when the snapshot was taken.
    pub bytes_sent: u64,
    /// Value of `ProducerMetrics::batches_sent` when the snapshot was taken.
    pub batches_sent: u64,
    /// Value of `ProducerMetrics::errors` when the snapshot was taken.
    pub errors: u64,
    /// Value of `ProducerMetrics::buffer_size` when the snapshot was taken.
    pub buffer_size: u64,
    /// Value of `ProducerMetrics::backpressure_drops` when the snapshot was taken.
    pub backpressure_drops: u64,
    /// Value of `ProducerMetrics::backpressure_waits` when the snapshot was taken.
    pub backpressure_waits: u64,
    /// Value of `ProducerMetrics::backpressure_wait_ms` when the snapshot was taken.
    pub backpressure_wait_ms: u64,
}
/// Batching producer: records are buffered per topic and sent either when a
/// batch reaches the configured size or when the background timer finds it
/// older than the linger time.
pub struct Producer {
    /// Shared reconnecting client; the async mutex serialises batch sends.
    client: Arc<Mutex<ReconnectingClient>>,
    /// Settings this producer was created with.
    config: ProducerConfig,
    /// Open (not yet sent) batch per topic id.
    batches: Arc<RwLock<HashMap<u32, RecordBatch>>>,
    /// Counters shared with the background flush task.
    metrics: Arc<ProducerMetrics>,
    /// Checked by the background task on every linger tick; cleared by
    /// `close` to stop the task.
    running: Arc<AtomicBool>,
    /// Outcome of the most recent send/flush, exposed through `is_healthy`.
    connection_healthy: Arc<AtomicBool>,
    /// Requests a full flush from the background task; the enclosed oneshot
    /// sender carries the flush result back.
    flush_tx: mpsc::Sender<oneshot::Sender<Result<()>>>,
}
impl Producer {
/// Connects to `addr` and returns a producer using `config`.
///
/// The underlying reconnecting client is configured with unlimited retries,
/// a 500 ms base delay and a 30 s maximum delay.
///
/// # Errors
/// Propagates the error of the initial connection attempt.
pub async fn connect(addr: &str, config: ProducerConfig) -> Result<Self> {
    let base_delay = Duration::from_millis(500);
    let max_delay = Duration::from_secs(30);

    let connected = ReconnectingClient::connect(addr).await?;
    let reconnecting = connected
        .with_unlimited_retries()
        .with_base_delay(base_delay)
        .with_max_delay(max_delay);

    Self::from_reconnecting_client(reconnecting, config).await
}
/// Builds a producer on top of an already constructed [`LanceClient`].
///
/// `addr` is handed to `ReconnectingClient::from_existing` together with the
/// client (presumably the address used for reconnects — confirm).
pub async fn from_client(client: LanceClient, addr: &str, config: ProducerConfig) -> Result<Self> {
    let reconnecting = ReconnectingClient::from_existing(client, addr);
    let producer = Self::from_reconnecting_client(reconnecting, config).await?;
    Ok(producer)
}
/// Wires up the shared state and spawns the background flush task.
///
/// The task reacts to two events:
/// * the linger timer — flushes every batch older than `linger_ms`;
/// * an explicit flush request arriving on the flush channel — flushes all
///   batches and reports the result on the request's oneshot sender.
///
/// The task stops when `running` has been cleared (checked on each tick) or
/// when the flush channel is closed, which happens once the `Producer` —
/// the only holder of the sender — has been dropped. On channel closure the
/// remaining batches are flushed one last time so that records enqueued
/// without awaiting their ack are not silently discarded.
async fn from_reconnecting_client(
    client: ReconnectingClient,
    config: ProducerConfig,
) -> Result<Self> {
    let client = Arc::new(Mutex::new(client));
    let batches = Arc::new(RwLock::new(HashMap::new()));
    let metrics = Arc::new(ProducerMetrics::default());
    let running = Arc::new(AtomicBool::new(true));
    let connection_healthy = Arc::new(AtomicBool::new(true));
    let (flush_tx, mut flush_rx) = mpsc::channel::<oneshot::Sender<Result<()>>>(16);
    let linger_ms = config.linger_ms;

    let task_client = Arc::clone(&client);
    let task_batches = Arc::clone(&batches);
    let task_metrics = Arc::clone(&metrics);
    let task_running = Arc::clone(&running);
    let task_healthy = Arc::clone(&connection_healthy);

    tokio::spawn(async move {
        // A zero period would panic in `interval`, hence the clamp to 1 ms.
        let mut linger_interval = interval(Duration::from_millis(linger_ms.max(1)));
        // A batch send can stall for a long time while retrying; do not replay
        // every missed tick in a burst afterwards.
        linger_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);

        loop {
            tokio::select! {
                _ = linger_interval.tick() => {
                    if !task_running.load(Ordering::Relaxed) {
                        break;
                    }
                    let flushed = Self::flush_expired_batches(
                        &task_client,
                        &task_batches,
                        &task_metrics,
                        linger_ms,
                    )
                    .await;
                    task_healthy.store(flushed.is_ok(), Ordering::SeqCst);
                }
                request = flush_rx.recv() => {
                    let result = Self::flush_all_batches(
                        &task_client,
                        &task_batches,
                        &task_metrics,
                    )
                    .await;
                    // Mirror the timer branch: a failed flush marks the
                    // connection unhealthy instead of leaving a stale `true`.
                    task_healthy.store(result.is_ok(), Ordering::SeqCst);
                    match request {
                        Some(ack_tx) => {
                            // The requester may have given up waiting; ignore that.
                            let _ = ack_tx.send(result);
                        }
                        // Every sender is gone: the producer was dropped, so
                        // nobody can enqueue or request flushes any more.
                        None => break,
                    }
                }
            }
        }
    });

    Ok(Self {
        client,
        config,
        batches,
        metrics,
        running,
        connection_healthy,
        flush_tx,
    })
}
/// Sends every element of `records` to `topic_id` and returns their acks in
/// input order.
///
/// Records are sent one after another: each one is acknowledged before the
/// next is enqueued.
///
/// # Errors
/// Stops at, and returns, the first error reported by [`Producer::send`].
pub async fn produce<T>(&self, topic_id: u32, records: &[T]) -> Result<Vec<SendAck>>
where
    T: AsRef<[u8]>,
{
    let mut acks = Vec::with_capacity(records.len());
    for record in records.iter() {
        let ack = self.send(topic_id, record.as_ref()).await?;
        acks.push(ack);
    }
    Ok(acks)
}
/// Sends a single record; thin alias for [`Producer::send`].
#[inline]
pub async fn produce_single(&self, topic_id: u32, data: &[u8]) -> Result<SendAck> {
    let ack = self.send(topic_id, data).await?;
    Ok(ack)
}
/// Enqueues `data` for `topic_id` and waits until its batch has been sent.
///
/// The batch is flushed right away when adding the record makes it reach
/// `batch_size`; otherwise the ack arrives once the batch is flushed by
/// another caller or by the background task.
///
/// # Errors
/// * `ClientError::ServerBackpressure` when the record would push the
///   buffered bytes past `buffer_memory` (counted in `backpressure_drops`);
/// * the flush error when the inline flush fails;
/// * `ClientError::ConnectionClosed` when the ack channel is dropped without
///   a result;
/// * otherwise the error reported for the record's batch.
pub async fn send(&self, topic_id: u32, data: &[u8]) -> Result<SendAck> {
    let payload_len = data.len() as u64;
    let (ack_tx, ack_rx) = oneshot::channel();

    let buffered = self.metrics.buffer_size.load(Ordering::Relaxed);
    if buffered + payload_len > self.config.buffer_memory as u64 {
        self.metrics
            .backpressure_drops
            .fetch_add(1, Ordering::Relaxed);
        return Err(ClientError::ServerBackpressure);
    }

    // Keep the write guard only for the time needed to append the record.
    let batch_full = {
        let mut open_batches = self.batches.write().await;
        let batch = open_batches
            .entry(topic_id)
            .or_insert_with(|| RecordBatch::new(topic_id));
        batch.add(Bytes::copy_from_slice(data), ack_tx);
        self.metrics
            .buffer_size
            .fetch_add(payload_len, Ordering::Relaxed);
        batch.size() >= self.config.batch_size
    };

    if batch_full {
        self.flush_topic(topic_id).await?;
    }

    match ack_rx.await {
        Ok(outcome) => outcome,
        Err(_) => Err(ClientError::ConnectionClosed),
    }
}
/// Enqueues `data` for `topic_id` without waiting for its acknowledgement.
///
/// The call still awaits an inline flush when the record makes the batch
/// reach `batch_size`.
///
/// # Errors
/// * `ClientError::WouldBlock` when the record would push the buffered bytes
///   past `buffer_memory` (counted in `backpressure_drops`);
/// * the flush error when the inline flush fails.
pub async fn try_send(&self, topic_id: u32, data: &[u8]) -> Result<()> {
    let payload_len = data.len() as u64;
    // The receiver is kept alive until the function returns, but its result
    // is never inspected.
    let (ack_tx, _ack_rx) = oneshot::channel();

    let buffered = self.metrics.buffer_size.load(Ordering::Relaxed);
    if buffered + payload_len > self.config.buffer_memory as u64 {
        self.metrics
            .backpressure_drops
            .fetch_add(1, Ordering::Relaxed);
        return Err(ClientError::WouldBlock);
    }

    let batch_full = {
        let mut open_batches = self.batches.write().await;
        let batch = open_batches
            .entry(topic_id)
            .or_insert_with(|| RecordBatch::new(topic_id));
        batch.add(Bytes::copy_from_slice(data), ack_tx);
        self.metrics
            .buffer_size
            .fetch_add(payload_len, Ordering::Relaxed);
        batch.size() >= self.config.batch_size
    };

    if batch_full {
        self.flush_topic(topic_id).await?;
    }
    Ok(())
}
/// Enqueues `data` for `topic_id` and returns without waiting for the
/// acknowledgement (fire-and-forget).
///
/// The call still awaits an inline flush when the record makes the batch
/// reach `batch_size`.
///
/// # Errors
/// * `ClientError::ServerBackpressure` when the record would push the
///   buffered bytes past `buffer_memory`; the rejection is counted in
///   `backpressure_drops`, like in `send` and `try_send`;
/// * the flush error when the inline flush fails.
pub async fn send_async(&self, topic_id: u32, data: &[u8]) -> Result<()> {
    let payload_len = data.len() as u64;
    // The receiver is kept alive until the function returns, but its result
    // is never inspected.
    let (ack_tx, _ack_rx) = oneshot::channel();

    let buffered = self.metrics.buffer_size.load(Ordering::Relaxed);
    if buffered + payload_len > self.config.buffer_memory as u64 {
        // Every rejected record must show up in the drop counter; this path
        // previously skipped it.
        self.metrics
            .backpressure_drops
            .fetch_add(1, Ordering::Relaxed);
        return Err(ClientError::ServerBackpressure);
    }

    let batch_full = {
        let mut open_batches = self.batches.write().await;
        let batch = open_batches
            .entry(topic_id)
            .or_insert_with(|| RecordBatch::new(topic_id));
        batch.add(Bytes::copy_from_slice(data), ack_tx);
        self.metrics
            .buffer_size
            .fetch_add(payload_len, Ordering::Relaxed);
        batch.size() >= self.config.batch_size
    };

    if batch_full {
        self.flush_topic(topic_id).await?;
    }
    Ok(())
}
/// Enqueues `data` for `topic_id` and invokes `callback` with the record's
/// outcome once its batch has been sent.
///
/// The callback runs on a spawned task. It receives
/// `ClientError::ConnectionClosed` when the ack channel is dropped without a
/// result. When the inline flush fails, the error is returned from this call
/// and the callback is also invoked with the batch error.
///
/// # Errors
/// * `ClientError::ServerBackpressure` when the record would push the
///   buffered bytes past `buffer_memory`; the rejection is counted in
///   `backpressure_drops` and the callback is not invoked;
/// * the flush error when the inline flush fails.
pub async fn send_callback<F>(&self, topic_id: u32, data: &[u8], callback: F) -> Result<()>
where
    F: FnOnce(SendResult) + Send + 'static,
{
    let payload_len = data.len() as u64;
    let (ack_tx, ack_rx) = oneshot::channel();

    let buffered = self.metrics.buffer_size.load(Ordering::Relaxed);
    if buffered + payload_len > self.config.buffer_memory as u64 {
        // Every rejected record must show up in the drop counter; this path
        // previously skipped it.
        self.metrics
            .backpressure_drops
            .fetch_add(1, Ordering::Relaxed);
        return Err(ClientError::ServerBackpressure);
    }

    let batch_full = {
        let mut open_batches = self.batches.write().await;
        let batch = open_batches
            .entry(topic_id)
            .or_insert_with(|| RecordBatch::new(topic_id));
        batch.add(Bytes::copy_from_slice(data), ack_tx);
        self.metrics
            .buffer_size
            .fetch_add(payload_len, Ordering::Relaxed);
        batch.size() >= self.config.batch_size
    };

    // Spawn the waiter before flushing so the callback fires as soon as the
    // ack is delivered, even while this call is still inside the flush.
    tokio::spawn(async move {
        let outcome = ack_rx
            .await
            .unwrap_or_else(|_| Err(ClientError::ConnectionClosed));
        callback(outcome);
    });

    if batch_full {
        self.flush_topic(topic_id).await?;
    }
    Ok(())
}
/// Asks the background task to send every open batch and waits for the
/// result.
///
/// # Errors
/// * `ClientError::ConnectionClosed` when the request cannot be delivered to
///   the background task or the task drops the reply channel;
/// * otherwise the error returned by the flush itself.
pub async fn flush(&self) -> Result<()> {
    let (done_tx, done_rx) = oneshot::channel();

    if self.flush_tx.send(done_tx).await.is_err() {
        return Err(ClientError::ConnectionClosed);
    }

    match done_rx.await {
        Ok(outcome) => outcome,
        Err(_) => Err(ClientError::ConnectionClosed),
    }
}
/// Removes the open batch of `topic_id`, if any, and sends it.
///
/// Does nothing when the topic has no batch or the batch holds no records.
async fn flush_topic(&self, topic_id: u32) -> Result<()> {
    // The write guard is a temporary and is released before the send starts.
    let removed = self.batches.write().await.remove(&topic_id);

    match removed {
        Some(batch) if !batch.is_empty() => self.send_batch(batch).await,
        _ => Ok(()),
    }
}
async fn send_batch(&self, batch: RecordBatch) -> Result<()> {
let topic_id = batch.topic_id;
let record_count = batch.record_count;
let byte_count = batch.data.len();
let ack_txs = batch.ack_txs;
let data = batch.data.freeze();
const MAX_RETRIES: u32 = 30;
let mut attempt = 0u32;
let mut backoff = Duration::from_millis(500);
const MAX_BACKOFF: Duration = Duration::from_secs(30);
let result = loop {
let send_result = {
let mut rc = self.client.lock().await;
match rc.client().await {
Ok(c) => {
c.send_ingest_to_topic_sync(
topic_id,
data.clone(),
record_count as u32,
None,
)
.await
},
Err(e) => Err(e),
}
};
match &send_result {
Ok(_) => break send_result,
Err(e) if e.is_retryable() && attempt < MAX_RETRIES => {
attempt += 1;
self.metrics.errors.fetch_add(1, Ordering::Relaxed);
{
let mut rc = self.client.lock().await;
rc.mark_failed();
}
tokio::time::sleep(backoff).await;
backoff = (backoff * 2).min(MAX_BACKOFF);
},
_ => break send_result,
}
};
self.metrics
.buffer_size
.fetch_sub(byte_count as u64, Ordering::Relaxed);
match result {
Ok(batch_id) => {
self.metrics
.records_sent
.fetch_add(record_count as u64, Ordering::Relaxed);
self.metrics
.bytes_sent
.fetch_add(byte_count as u64, Ordering::Relaxed);
self.metrics.batches_sent.fetch_add(1, Ordering::Relaxed);
self.connection_healthy.store(true, Ordering::SeqCst);
let timestamp = Instant::now();
for (offset_in_batch, tx) in ack_txs.into_iter().enumerate() {
let ack = SendAck {
batch_id,
topic_id,
timestamp,
offset_in_batch: offset_in_batch as u32,
};
let _ = tx.send(Ok(ack));
if offset_in_batch % 8 == 7 {
tokio::task::yield_now().await;
}
}
Ok(())
},
Err(e) => {
self.metrics.errors.fetch_add(1, Ordering::Relaxed);
for tx in ack_txs {
let _ = tx.send(Err(ClientError::ServerError(e.to_string())));
}
Err(e)
},
}
}
/// Sends every non-empty batch that is at least `linger_ms` milliseconds
/// old.
///
/// Candidate topics are collected under the read lock; each batch is then
/// removed under its own short write lock, so a batch may already be gone
/// (flushed by a sender) by the time it is processed.
///
/// # Errors
/// Returns at the first failed send; expired batches not yet removed stay in
/// the map.
async fn flush_expired_batches(
    client: &Arc<Mutex<ReconnectingClient>>,
    batches: &Arc<RwLock<HashMap<u32, RecordBatch>>>,
    metrics: &Arc<ProducerMetrics>,
    linger_ms: u64,
) -> Result<()> {
    let max_age = Duration::from_millis(linger_ms);
    let started = Instant::now();

    let mut due_topics: Vec<u32> = Vec::new();
    {
        let open_batches = batches.read().await;
        for (topic_id, batch) in open_batches.iter() {
            if batch.is_empty() {
                continue;
            }
            if started.duration_since(batch.created_at) >= max_age {
                due_topics.push(*topic_id);
            }
        }
    }

    for topic_id in due_topics {
        let removed = batches.write().await.remove(&topic_id);
        match removed {
            Some(batch) if !batch.is_empty() => {
                Self::send_batch_static(client, metrics, batch).await?;
            }
            _ => {}
        }
    }
    Ok(())
}
/// Removes every open batch and sends each non-empty one.
///
/// All drained batches are attempted even when one of them fails. Bailing
/// out at the first error would drop the remaining batches unsent: their
/// records would be lost, their ack channels closed without a result, and
/// their bytes never subtracted from `buffer_size`.
///
/// # Errors
/// Returns the first send error encountered, after every batch has been
/// attempted.
async fn flush_all_batches(
    client: &Arc<Mutex<ReconnectingClient>>,
    batches: &Arc<RwLock<HashMap<u32, RecordBatch>>>,
    metrics: &Arc<ProducerMetrics>,
) -> Result<()> {
    // Take ownership of all batches under a short write lock so senders are
    // not blocked while the network sends are in progress.
    let drained: Vec<RecordBatch> = {
        let mut open_batches = batches.write().await;
        open_batches.drain().map(|(_, batch)| batch).collect()
    };

    let mut first_error = None;
    for batch in drained {
        if batch.is_empty() {
            continue;
        }
        if let Err(err) = Self::send_batch_static(client, metrics, batch).await {
            if first_error.is_none() {
                first_error = Some(err);
            }
        }
    }

    first_error.map_or(Ok(()), Err)
}
async fn send_batch_static(
client: &Arc<Mutex<ReconnectingClient>>,
metrics: &Arc<ProducerMetrics>,
batch: RecordBatch,
) -> Result<()> {
let topic_id = batch.topic_id;
let record_count = batch.record_count;
let byte_count = batch.data.len();
let ack_txs = batch.ack_txs;
let data = batch.data.freeze();
const MAX_RETRIES: u32 = 30;
let mut attempt = 0u32;
let mut backoff = Duration::from_millis(500);
const MAX_BACKOFF: Duration = Duration::from_secs(30);
let result = loop {
let send_result = {
let mut rc = client.lock().await;
match rc.client().await {
Ok(c) => {
c.send_ingest_to_topic_sync(
topic_id,
data.clone(),
record_count as u32,
None,
)
.await
},
Err(e) => Err(e),
}
};
match &send_result {
Ok(_) => break send_result,
Err(e) if e.is_retryable() && attempt < MAX_RETRIES => {
attempt += 1;
metrics.errors.fetch_add(1, Ordering::Relaxed);
{
let mut rc = client.lock().await;
rc.mark_failed();
}
tokio::time::sleep(backoff).await;
backoff = (backoff * 2).min(MAX_BACKOFF);
},
_ => break send_result,
}
};
metrics
.buffer_size
.fetch_sub(byte_count as u64, Ordering::Relaxed);
match result {
Ok(batch_id) => {
metrics
.records_sent
.fetch_add(record_count as u64, Ordering::Relaxed);
metrics
.bytes_sent
.fetch_add(byte_count as u64, Ordering::Relaxed);
metrics.batches_sent.fetch_add(1, Ordering::Relaxed);
let timestamp = Instant::now();
for (offset_in_batch, tx) in ack_txs.into_iter().enumerate() {
let ack = SendAck {
batch_id,
topic_id,
timestamp,
offset_in_batch: offset_in_batch as u32,
};
let _ = tx.send(Ok(ack));
if offset_in_batch % 8 == 7 {
tokio::task::yield_now().await;
}
}
Ok(())
},
Err(e) => {
metrics.errors.fetch_add(1, Ordering::Relaxed);
for tx in ack_txs {
let _ = tx.send(Err(ClientError::ServerError(e.to_string())));
}
Err(e)
},
}
}
/// Reports whether the most recent send/flush recorded by the producer
/// succeeded.
pub fn is_healthy(&self) -> bool {
    let healthy = self.connection_healthy.load(Ordering::SeqCst);
    healthy
}
pub fn metrics(&self) -> MetricsSnapshot {
self.metrics.snapshot()
}
/// Flushes every open batch and shuts the producer down.
///
/// The `running` flag is cleared whether or not the final flush succeeds, so
/// the background task is always told to stop; returning early on a flush
/// error would leave it running after the producer is gone.
///
/// # Errors
/// Returns the error of the final flush, if any.
pub async fn close(self) -> Result<()> {
    let flushed = self.flush().await;
    self.running.store(false, Ordering::Relaxed);
    flushed
}
}
impl std::fmt::Debug for Producer {
    /// Formats the configuration, a metrics snapshot and the running flag;
    /// the client, the open batches and the flush channel are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let metrics = self.metrics.snapshot();
        let running = self.running.load(Ordering::Relaxed);

        let mut out = f.debug_struct("Producer");
        out.field("config", &self.config);
        out.field("metrics", &metrics);
        out.field("running", &running);
        out.finish()
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// A fresh configuration carries the default values.
    #[test]
    fn test_producer_config_defaults() {
        let defaults = ProducerConfig::new();
        assert_eq!(defaults.batch_size, 16 * 1024);
        assert_eq!(defaults.linger_ms, 5);
        assert_eq!(defaults.max_in_flight, 5);
        assert!(!defaults.compression);
    }

    /// Builder methods override exactly the fields they name.
    #[test]
    fn test_producer_config_builder() {
        let built = ProducerConfig::new()
            .with_batch_size(32 * 1024)
            .with_linger_ms(10)
            .with_max_in_flight(10)
            .with_compression(true);
        assert_eq!(built.batch_size, 32 * 1024);
        assert_eq!(built.linger_ms, 10);
        assert_eq!(built.max_in_flight, 10);
        assert!(built.compression);
    }

    /// Adding records updates size, count and offsets.
    #[test]
    fn test_record_batch() {
        let mut batch = RecordBatch::new(1);
        assert!(batch.is_empty());
        assert_eq!(batch.size(), 0);
        assert!(batch.record_offsets.is_empty());

        let (first_tx, _first_rx) = oneshot::channel();
        batch.add(Bytes::from_static(b"hello"), first_tx);
        assert!(!batch.is_empty());
        assert_eq!(batch.size(), 5);
        assert_eq!(batch.record_count, 1);
        assert_eq!(batch.record_offsets, vec![0]);

        let (second_tx, _second_rx) = oneshot::channel();
        batch.add(Bytes::from_static(b"world"), second_tx);
        assert_eq!(batch.record_count, 2);
        assert_eq!(batch.record_offsets, vec![0, 5]);
        assert_eq!(batch.size(), 10);
    }

    /// A snapshot reflects the counters at the time it is taken.
    #[test]
    fn test_metrics_snapshot() {
        let metrics = ProducerMetrics::default();
        metrics.records_sent.fetch_add(100, Ordering::Relaxed);
        metrics.bytes_sent.fetch_add(1000, Ordering::Relaxed);

        let snap = metrics.snapshot();
        assert_eq!(snap.records_sent, 100);
        assert_eq!(snap.bytes_sent, 1000);
    }

    /// Compile-time check of the closure bounds accepted by `send_callback`.
    #[test]
    fn test_send_callback_closure_traits() {
        fn assert_callback_traits<F>(_f: F)
        where
            F: FnOnce(SendResult) + Send + 'static,
        {
        }

        assert_callback_traits(|_result| {});

        let hits = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0));
        let hits_in_callback = hits.clone();
        assert_callback_traits(move |result| {
            if result.is_ok() {
                hits_in_callback.fetch_add(1, Ordering::Relaxed);
            }
        });
    }

    /// Acks delivered in batch order carry matching `offset_in_batch` values.
    #[tokio::test]
    async fn test_send_ack_ordering() {
        let mut batch = RecordBatch::new(1);
        let mut receivers = Vec::new();
        for i in 0..5 {
            let (tx, rx) = oneshot::channel();
            batch.add(Bytes::from(format!("record-{}", i)), tx);
            receivers.push(rx);
        }

        let batch_id = 42u64;
        let timestamp = Instant::now();
        let topic_id = batch.topic_id;
        for (position, tx) in batch.ack_txs.into_iter().enumerate() {
            let ack = SendAck {
                batch_id,
                topic_id,
                timestamp,
                offset_in_batch: position as u32,
            };
            let _ = tx.send(Ok(ack));
        }

        for (i, rx) in receivers.into_iter().enumerate() {
            let result = rx.await.unwrap();
            assert!(result.is_ok(), "Record {} should be ACKed", i);
            let ack = result.unwrap();
            assert_eq!(ack.batch_id, batch_id);
            assert_eq!(
                ack.offset_in_batch, i as u32,
                "Record {} should have offset_in_batch={}",
                i, i
            );
            assert_eq!(ack.topic_id, 1);
        }
    }

    /// A batch failure reaches every record's receiver as a server error.
    #[tokio::test]
    async fn test_send_ack_failure_ordering() {
        let mut batch = RecordBatch::new(2);
        let mut receivers = Vec::new();
        for i in 0..3 {
            let (tx, rx) = oneshot::channel();
            batch.add(Bytes::from(format!("record-{}", i)), tx);
            receivers.push(rx);
        }

        let error_msg = "server timeout";
        for tx in batch.ack_txs {
            let _ = tx.send(Err(ClientError::ServerError(error_msg.to_string())));
        }

        for (i, rx) in receivers.into_iter().enumerate() {
            let result = rx.await.unwrap();
            assert!(result.is_err(), "Record {} should receive error", i);
            let err = result.unwrap_err();
            assert!(matches!(err, ClientError::ServerError(ref msg) if msg == error_msg));
        }
    }

    /// Offsets are the running sum of the preceding record sizes.
    #[test]
    fn test_record_offsets_mapping() {
        let mut batch = RecordBatch::new(1);
        let records = vec![
            Bytes::from(vec![0u8; 100]),
            Bytes::from(vec![1u8; 50]),
            Bytes::from(vec![2u8; 200]),
            Bytes::from(vec![3u8; 75]),
        ];
        let expected_offsets = vec![0, 100, 150, 350];
        let expected_total_size: usize = records.iter().map(|r| r.len()).sum();

        for payload in records {
            let (tx, _rx) = oneshot::channel();
            batch.add(payload, tx);
        }

        assert_eq!(
            batch.record_offsets, expected_offsets,
            "Record offsets should match expected positions"
        );
        assert_eq!(
            batch.size(),
            expected_total_size,
            "Total batch size should be sum of record sizes"
        );
        assert_eq!(batch.record_count, 4);
    }

    /// Backpressure counters are carried into the snapshot.
    #[test]
    fn test_backpressure_telemetry_metrics() {
        let metrics = ProducerMetrics::default();
        metrics.backpressure_drops.fetch_add(5, Ordering::Relaxed);
        metrics.backpressure_waits.fetch_add(3, Ordering::Relaxed);
        metrics
            .backpressure_wait_ms
            .fetch_add(150, Ordering::Relaxed);

        let snap = metrics.snapshot();
        assert_eq!(snap.backpressure_drops, 5);
        assert_eq!(snap.backpressure_waits, 3);
        assert_eq!(snap.backpressure_wait_ms, 150);
    }

    /// The buffer counter goes up on enqueue and back down on release.
    #[test]
    fn test_buffer_memory_limit_checking() {
        let metrics = Arc::new(ProducerMetrics::default());
        let data_len = 1000u64;

        metrics.buffer_size.fetch_add(data_len, Ordering::Relaxed);
        let buffered = metrics.buffer_size.load(Ordering::Relaxed);
        assert_eq!(buffered, data_len);

        metrics.buffer_size.fetch_sub(data_len, Ordering::Relaxed);
        assert_eq!(metrics.buffer_size.load(Ordering::Relaxed), 0);
    }
}