use crate::{Error, Request, Response, Result};
use bytes::Bytes;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot, Notify, RwLock, Semaphore};
use tracing::{debug, info, warn};
const DEFAULT_BATCH_SIZE: usize = 16384;
const DEFAULT_LINGER_MS: u64 = 0;
const DEFAULT_BUFFER_MEMORY: usize = 32 * 1024 * 1024;
const DEFAULT_METADATA_TTL: Duration = Duration::from_secs(300);
const STICKY_BATCH_THRESHOLD: usize = 8;
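/// Tunable settings for a [`Producer`], modeled on the familiar Kafka
/// producer knobs: batching, lingering, in-flight limits, acks, and
/// compression.
///
/// A minimal builder sketch (the values are illustrative, not tuned
/// recommendations):
///
/// ```ignore
/// let config = ProducerConfig::builder()
///     .bootstrap_servers(vec!["localhost:9092".to_string()])
///     .batch_size(32 * 1024)
///     .linger_ms(5)
///     .compression_type(CompressionType::Lz4)
///     .build();
/// ```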
#[derive(Debug, Clone)]
pub struct ProducerConfig {
pub bootstrap_servers: Vec<String>,
pub batch_size: usize,
pub linger_ms: u64,
pub buffer_memory: usize,
pub max_in_flight_requests: usize,
pub enable_idempotence: bool,
pub request_timeout: Duration,
pub delivery_timeout: Duration,
pub retries: u32,
pub retry_backoff_ms: u64,
pub retry_backoff_max_ms: u64,
pub metadata_max_age: Duration,
pub compression_type: CompressionType,
pub acks: i8,
pub connection_timeout: Duration,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CompressionType {
#[default]
None,
Gzip,
Snappy,
Lz4,
Zstd,
}
impl Default for ProducerConfig {
fn default() -> Self {
Self {
bootstrap_servers: vec!["localhost:9092".to_string()],
batch_size: DEFAULT_BATCH_SIZE,
linger_ms: DEFAULT_LINGER_MS,
buffer_memory: DEFAULT_BUFFER_MEMORY,
max_in_flight_requests: 5,
enable_idempotence: false,
request_timeout: Duration::from_secs(30),
delivery_timeout: Duration::from_secs(120),
retries: 3,
retry_backoff_ms: 100,
retry_backoff_max_ms: 1000,
metadata_max_age: DEFAULT_METADATA_TTL,
compression_type: CompressionType::None,
acks: 1,
connection_timeout: Duration::from_secs(10),
}
}
}
impl ProducerConfig {
pub fn builder() -> ProducerConfigBuilder {
ProducerConfigBuilder::default()
}
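/// Preset favoring throughput: larger batches, a short linger window,
/// more requests in flight, and LZ4 compression.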
pub fn high_throughput() -> Self {
Self {
batch_size: 65536,
linger_ms: 10,
max_in_flight_requests: 10,
compression_type: CompressionType::Lz4,
..Default::default()
}
}
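/// Preset favoring latency: no batching, no linger, and a single
/// request in flight at a time.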
pub fn low_latency() -> Self {
Self {
batch_size: 1,
linger_ms: 0,
max_in_flight_requests: 1,
compression_type: CompressionType::None,
..Default::default()
}
}
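/// Preset enabling idempotent delivery: `acks = -1` (wait for all
/// replicas) and effectively unlimited retries. As in Kafka, idempotence
/// alone gives exactly-once semantics per partition within a producer
/// session; it is not an end-to-end transactional guarantee.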
pub fn exactly_once() -> Self {
Self {
enable_idempotence: true,
acks: -1,
max_in_flight_requests: 5,
retries: i32::MAX as u32,
..Default::default()
}
}
}
#[derive(Default)]
pub struct ProducerConfigBuilder {
config: ProducerConfig,
}
impl ProducerConfigBuilder {
pub fn bootstrap_servers(mut self, servers: Vec<String>) -> Self {
self.config.bootstrap_servers = servers;
self
}
pub fn batch_size(mut self, size: usize) -> Self {
self.config.batch_size = size;
self
}
pub fn linger_ms(mut self, ms: u64) -> Self {
self.config.linger_ms = ms;
self
}
pub fn buffer_memory(mut self, bytes: usize) -> Self {
self.config.buffer_memory = bytes;
self
}
pub fn max_in_flight_requests(mut self, max: usize) -> Self {
self.config.max_in_flight_requests = max;
self
}
pub fn enable_idempotence(mut self, enable: bool) -> Self {
self.config.enable_idempotence = enable;
self
}
pub fn request_timeout(mut self, timeout: Duration) -> Self {
self.config.request_timeout = timeout;
self
}
pub fn delivery_timeout(mut self, timeout: Duration) -> Self {
self.config.delivery_timeout = timeout;
self
}
pub fn retries(mut self, retries: u32) -> Self {
self.config.retries = retries;
self
}
pub fn retry_backoff_ms(mut self, ms: u64) -> Self {
self.config.retry_backoff_ms = ms;
self
}
pub fn retry_backoff_max_ms(mut self, ms: u64) -> Self {
self.config.retry_backoff_max_ms = ms;
self
}
pub fn metadata_max_age(mut self, duration: Duration) -> Self {
self.config.metadata_max_age = duration;
self
}
pub fn compression_type(mut self, compression: CompressionType) -> Self {
self.config.compression_type = compression;
self
}
pub fn acks(mut self, acks: i8) -> Self {
self.config.acks = acks;
self
}
pub fn connection_timeout(mut self, timeout: Duration) -> Self {
self.config.connection_timeout = timeout;
self
}
pub fn build(self) -> ProducerConfig {
self.config
}
}
#[derive(Debug, Clone)]
struct CachedMetadata {
partition_count: u32,
cached_at: Instant,
}
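/// TTL-bounded cache of per-topic partition counts, filled by metadata
/// requests and swept periodically by a background task.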
struct MetadataCache {
cache: RwLock<HashMap<String, CachedMetadata>>,
ttl: Duration,
}
impl MetadataCache {
fn new(ttl: Duration) -> Self {
Self {
cache: RwLock::new(HashMap::new()),
ttl,
}
}
async fn get(&self, topic: &str) -> Option<u32> {
let cache = self.cache.read().await;
cache.get(topic).and_then(|entry| {
if entry.cached_at.elapsed() < self.ttl {
Some(entry.partition_count)
} else {
None
}
})
}
#[allow(dead_code)]
async fn put(&self, topic: String, partition_count: u32) {
let mut cache = self.cache.write().await;
cache.insert(
topic,
CachedMetadata {
partition_count,
cached_at: Instant::now(),
},
);
}
#[allow(dead_code)]
async fn invalidate(&self, topic: &str) {
let mut cache = self.cache.write().await;
cache.remove(topic);
}
async fn cleanup(&self) {
let mut cache = self.cache.write().await;
cache.retain(|_, entry| entry.cached_at.elapsed() < self.ttl);
}
}
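/// Partition selection in the style of Kafka's default partitioner:
/// keyed records are hashed with murmur2 so a given key always lands on
/// the same partition, while keyless records "stick" to one randomly
/// chosen partition for `STICKY_BATCH_THRESHOLD` records before rotating
/// to the next (a KIP-480-style strategy, approximated here by record
/// count rather than by completed batches).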
struct StickyPartitioner {
sticky_partitions: RwLock<HashMap<String, StickyState>>,
}
struct StickyState {
partition: u32,
batch_count: usize,
total_partitions: u32,
}
impl StickyPartitioner {
fn new() -> Self {
Self {
sticky_partitions: RwLock::new(HashMap::new()),
}
}
async fn partition(&self, topic: &str, key: Option<&[u8]>, num_partitions: u32) -> u32 {
if num_partitions == 0 {
return 0;
}
if let Some(key) = key {
return murmur2_partition(key, num_partitions);
}
let mut sticky = self.sticky_partitions.write().await;
let state = sticky
.entry(topic.to_string())
.or_insert_with(|| StickyState {
partition: rand_partition(num_partitions),
batch_count: 0,
total_partitions: num_partitions,
});
if state.total_partitions != num_partitions {
state.total_partitions = num_partitions;
state.partition = rand_partition(num_partitions);
state.batch_count = 0;
}
state.batch_count += 1;
if state.batch_count >= STICKY_BATCH_THRESHOLD {
state.partition = (state.partition + 1) % num_partitions;
state.batch_count = 0;
}
state.partition
}
#[allow(dead_code)]
async fn switch_partition(&self, topic: &str) {
let mut sticky = self.sticky_partitions.write().await;
if let Some(state) = sticky.get_mut(topic) {
state.partition = (state.partition + 1) % state.total_partitions;
state.batch_count = 0;
}
}
}
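/// MurmurHash2 over the key, masked to a non-negative value and reduced
/// modulo the partition count. This matches the hash used by Kafka's
/// Java `DefaultPartitioner`, so keyed records co-locate with records
/// produced by other Kafka-compatible clients.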
#[inline]
fn murmur2_partition(key: &[u8], num_partitions: u32) -> u32 {
const SEED: u32 = 0x9747b28c;
const M: u32 = 0x5bd1e995;
const R: u32 = 24;
let len = key.len();
let mut h: u32 = SEED ^ (len as u32);
let mut i = 0;
while i + 4 <= len {
let k = u32::from_le_bytes([key[i], key[i + 1], key[i + 2], key[i + 3]]);
let k = k.wrapping_mul(M);
let k = k ^ (k >> R);
let k = k.wrapping_mul(M);
h = h.wrapping_mul(M);
h ^= k;
i += 4;
}
let remainder = len - i;
if remainder >= 3 {
h ^= (key[i + 2] as u32) << 16;
}
if remainder >= 2 {
h ^= (key[i + 1] as u32) << 8;
}
if remainder >= 1 {
h ^= key[i] as u32;
h = h.wrapping_mul(M);
}
h ^= h >> 13;
h = h.wrapping_mul(M);
h ^= h >> 15;
(h & 0x7fffffff) % num_partitions
}
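/// Cheap pseudo-random partition pick seeded from the clock's
/// sub-second nanoseconds; good enough for choosing an initial sticky
/// partition without pulling in a `rand` dependency.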
fn rand_partition(num_partitions: u32) -> u32 {
use std::time::SystemTime;
let nanos = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.subsec_nanos();
nanos % num_partitions
}
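/// A record queued for delivery. `response_tx` is completed with the
/// broker's acknowledgement (or an error) once the batch containing the
/// record has been processed.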
pub struct ProducerRecord {
pub topic: String,
pub partition: Option<u32>,
pub key: Option<Bytes>,
pub value: Bytes,
response_tx: oneshot::Sender<Result<RecordMetadata>>,
}
#[derive(Debug, Clone)]
pub struct RecordMetadata {
pub topic: String,
pub partition: u32,
pub offset: u64,
pub timestamp: u64,
}
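/// Records accumulated for one (topic, partition) pair, flushed once the
/// batch reaches `batch_size` bytes or its linger window expires.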
struct RecordBatch {
topic: String,
partition: u32,
records: Vec<(
Bytes,
Option<Bytes>,
oneshot::Sender<Result<RecordMetadata>>,
)>,
batch_size_bytes: usize,
created_at: Instant,
}
impl RecordBatch {
fn new(topic: String, partition: u32) -> Self {
Self {
topic,
partition,
records: Vec::new(),
batch_size_bytes: 0,
created_at: Instant::now(),
}
}
fn add(
&mut self,
key: Option<Bytes>,
value: Bytes,
tx: oneshot::Sender<Result<RecordMetadata>>,
) {
self.batch_size_bytes += value.len() + key.as_ref().map(|k| k.len()).unwrap_or(0);
self.records.push((value, key, tx));
}
fn is_empty(&self) -> bool {
self.records.is_empty()
}
fn len(&self) -> usize {
self.records.len()
}
}
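/// A cloneable handle to an asynchronous producer. Records are handed to
/// a background sender task over a bounded channel, and the send path
/// applies backpressure through a semaphore whose permits represent
/// buffered bytes (roughly 1 KiB each), sized from `buffer_memory`.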
pub struct Producer {
inner: Arc<ProducerInner>,
}
struct ProducerInner {
config: ProducerConfig,
metadata_cache: MetadataCache,
partitioner: StickyPartitioner,
record_tx: mpsc::Sender<ProducerRecord>,
memory_semaphore: Arc<Semaphore>,
#[allow(dead_code)]
producer_id: AtomicU64,
#[allow(dead_code)]
sequences: RwLock<HashMap<String, AtomicU32>>,
stats: ProducerStats,
pending_records: AtomicU64,
flush_notify: Notify,
shutdown: AtomicBool,
shutdown_tx: tokio::sync::watch::Sender<bool>,
}
impl Clone for Producer {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
impl Producer {
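/// Connects to the first bootstrap server and spawns the background
/// sender and metadata-cleanup tasks. The returned handle is cheap to
/// clone; all clones share the same connection and accumulator state.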
pub async fn new(config: ProducerConfig) -> Result<Self> {
if config.bootstrap_servers.is_empty() {
return Err(Error::ConfigError(
"No bootstrap servers configured".to_string(),
));
}
// Each permit represents roughly 1 KiB of buffered record data.
let memory_permits = (config.buffer_memory / 1024).max(1);
let memory_semaphore = Arc::new(Semaphore::new(memory_permits));
// tokio's bounded mpsc panics on zero capacity, so clamp to one slot.
let (record_tx, record_rx) = mpsc::channel((config.buffer_memory / config.batch_size).max(1));
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
let inner = Arc::new(ProducerInner {
config: config.clone(),
metadata_cache: MetadataCache::new(config.metadata_max_age),
partitioner: StickyPartitioner::new(),
record_tx,
memory_semaphore: memory_semaphore.clone(),
producer_id: AtomicU64::new(0),
sequences: RwLock::new(HashMap::new()),
stats: ProducerStats::new(),
pending_records: AtomicU64::new(0),
flush_notify: Notify::new(),
shutdown: AtomicBool::new(false),
shutdown_tx,
});
let addr = config.bootstrap_servers[0].clone();
let stream = tokio::time::timeout(config.connection_timeout, TcpStream::connect(&addr))
.await
.map_err(|_| Error::ConnectionError(format!("Connection timeout to {}", addr)))?
.map_err(|e| Error::ConnectionError(e.to_string()))?;
stream.set_nodelay(true).ok();
let sender_inner = Arc::clone(&inner);
tokio::spawn(async move {
sender_task(sender_inner, stream, record_rx, shutdown_rx).await;
});
let cleanup_inner = Arc::clone(&inner);
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(60));
loop {
interval.tick().await;
if cleanup_inner.shutdown.load(Ordering::Relaxed) {
break;
}
cleanup_inner.metadata_cache.cleanup().await;
}
});
info!(
"Producer initialized with {} bootstrap servers",
config.bootstrap_servers.len()
);
Ok(Self { inner })
}
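/// Sends a keyless record and waits for the broker acknowledgement,
/// bounded by `delivery_timeout`.
///
/// A hedged usage sketch (topic name and broker address are
/// illustrative):
///
/// ```ignore
/// let producer = Producer::new(ProducerConfig::default()).await?;
/// let meta = producer.send("events", "hello world").await?;
/// println!("delivered to {}/{} at offset {}", meta.topic, meta.partition, meta.offset);
/// ```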
pub async fn send(
&self,
topic: impl Into<String>,
value: impl Into<Bytes>,
) -> Result<RecordMetadata> {
self.send_with_key(topic, None::<Bytes>, value).await
}
pub async fn send_with_key(
&self,
topic: impl Into<String>,
key: Option<impl Into<Bytes>>,
value: impl Into<Bytes>,
) -> Result<RecordMetadata> {
let topic = topic.into();
let key: Option<Bytes> = key.map(|k| k.into());
let value = value.into();
// Size the permit request by the whole record (key + value), in KiB.
let record_size = value.len() + key.as_ref().map(|k| k.len()).unwrap_or(0);
let permits_needed = (record_size / 1024).max(1);
let _permit = tokio::time::timeout(
self.inner.config.request_timeout,
self.inner
.memory_semaphore
.acquire_many(permits_needed as u32),
)
.await
.map_err(|_| Error::Timeout)?
.map_err(|_| Error::ConnectionError("Producer closed".into()))?;
let (response_tx, response_rx) = oneshot::channel();
let record = ProducerRecord {
topic,
partition: None,
key,
value,
response_tx,
};
self.inner
.record_tx
.send(record)
.await
.map_err(|_| Error::ConnectionError("Sender task closed".into()))?;
self.inner
.stats
.records_sent
.fetch_add(1, Ordering::Relaxed);
self.inner.pending_records.fetch_add(1, Ordering::Release);
tokio::time::timeout(self.inner.config.delivery_timeout, response_rx)
.await
.map_err(|_| Error::Timeout)?
.map_err(|_| Error::ConnectionError("Response channel dropped".into()))?
}
pub async fn send_to_partition(
&self,
topic: impl Into<String>,
partition: u32,
key: Option<impl Into<Bytes>>,
value: impl Into<Bytes>,
) -> Result<RecordMetadata> {
let topic = topic.into();
let key: Option<Bytes> = key.map(|k| k.into());
let value = value.into();
let (response_tx, response_rx) = oneshot::channel();
let record = ProducerRecord {
topic,
partition: Some(partition),
key,
value,
response_tx,
};
self.inner
.record_tx
.send(record)
.await
.map_err(|_| Error::ConnectionError("Sender task closed".into()))?;
self.inner
.stats
.records_sent
.fetch_add(1, Ordering::Relaxed);
// Count this record too, so the sender task's per-batch pending-count
// bookkeeping never underflows for sends to explicit partitions.
self.inner.pending_records.fetch_add(1, Ordering::Release);
tokio::time::timeout(self.inner.config.delivery_timeout, response_rx)
.await
.map_err(|_| Error::Timeout)?
.map_err(|_| Error::ConnectionError("Response channel dropped".into()))?
}
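/// Waits until every record accepted so far has been acknowledged or
/// failed, up to `delivery_timeout`. The loop re-checks on a short sleep
/// as well as on the notification, so a wakeup lost between the counter
/// check and `notified()` cannot stall the flush.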
pub async fn flush(&self) -> Result<()> {
if self.inner.pending_records.load(Ordering::Acquire) == 0 {
return Ok(());
}
let deadline = tokio::time::Instant::now() + self.inner.config.delivery_timeout;
loop {
if self.inner.pending_records.load(Ordering::Acquire) == 0 {
return Ok(());
}
let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
if remaining.is_zero() {
return Err(Error::Timeout);
}
tokio::select! {
_ = self.inner.flush_notify.notified() => {
continue;
}
_ = tokio::time::sleep(remaining.min(Duration::from_millis(10))) => {
continue;
}
}
}
}
pub fn stats(&self) -> ProducerStatsSnapshot {
ProducerStatsSnapshot {
records_sent: self.inner.stats.records_sent.load(Ordering::Relaxed),
records_delivered: self.inner.stats.records_delivered.load(Ordering::Relaxed),
batches_sent: self.inner.stats.batches_sent.load(Ordering::Relaxed),
errors: self.inner.stats.errors.load(Ordering::Relaxed),
retries: self.inner.stats.retries.load(Ordering::Relaxed),
}
}
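/// Signals the background tasks to stop. This does not wait for
/// buffered records to drain; call [`Producer::flush`] first if you need
/// delivery guarantees on shutdown.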
pub async fn close(&self) {
self.inner.shutdown.store(true, Ordering::Relaxed);
let _ = self.inner.shutdown_tx.send(true);
}
#[allow(dead_code)]
pub async fn get_partition_count(&self, topic: &str) -> Result<u32> {
if let Some(count) = self.inner.metadata_cache.get(topic).await {
return Ok(count);
}
self.refresh_metadata(topic).await
}
pub async fn refresh_metadata(&self, topic: &str) -> Result<u32> {
let addr = &self.inner.config.bootstrap_servers[0];
let stream = tokio::time::timeout(
self.inner.config.connection_timeout,
TcpStream::connect(addr),
)
.await
.map_err(|_| Error::ConnectionError(format!("Metadata request timeout to {}", addr)))?
.map_err(|e| Error::ConnectionError(e.to_string()))?;
stream.set_nodelay(true).ok();
let (read_half, write_half) = stream.into_split();
let mut writer = BufWriter::with_capacity(4096, write_half);
let mut reader = BufReader::with_capacity(4096, read_half);
let request = Request::GetMetadata {
topic: topic.to_string(),
};
let request_bytes = request.to_wire(rivven_protocol::WireFormat::Postcard)?;
let len = request_bytes.len() as u32;
writer
.write_all(&len.to_be_bytes())
.await
.map_err(|e| Error::IoError(e.to_string()))?;
writer
.write_all(&request_bytes)
.await
.map_err(|e| Error::IoError(e.to_string()))?;
writer
.flush()
.await
.map_err(|e| Error::IoError(e.to_string()))?;
let mut len_buf = [0u8; 4];
reader
.read_exact(&mut len_buf)
.await
.map_err(|e| Error::IoError(e.to_string()))?;
let response_len = u32::from_be_bytes(len_buf) as usize;
let mut response_buf = vec![0u8; response_len];
reader
.read_exact(&mut response_buf)
.await
.map_err(|e| Error::IoError(e.to_string()))?;
let (response, _format) =
Response::from_wire(&response_buf).map_err(Error::ProtocolError)?;
match response {
Response::Metadata { name, partitions } => {
self.inner
.metadata_cache
.put(name.clone(), partitions)
.await;
debug!(
"Refreshed metadata for topic '{}': {} partitions",
name, partitions
);
Ok(partitions)
}
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
}
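/// Background sender: accumulates records into per-(topic, partition)
/// batches, honors `linger_ms`, and flushes a batch once it reaches
/// `batch_size` bytes or its linger window expires (with `linger_ms == 0`,
/// whatever has accumulated is flushed on every pass). Requests are
/// framed on the wire as a 4-byte big-endian length followed by a
/// postcard-encoded payload.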
async fn sender_task(
inner: Arc<ProducerInner>,
stream: TcpStream,
mut record_rx: mpsc::Receiver<ProducerRecord>,
mut shutdown_rx: tokio::sync::watch::Receiver<bool>,
) {
let (read_half, write_half) = stream.into_split();
let mut writer = BufWriter::with_capacity(64 * 1024, write_half);
let mut reader = BufReader::with_capacity(64 * 1024, read_half);
let mut batches: HashMap<(String, u32), RecordBatch> = HashMap::new();
let mut last_flush = Instant::now();
loop {
if *shutdown_rx.borrow() {
break;
}
let linger_remaining = if inner.config.linger_ms == 0 {
Duration::ZERO
} else {
let elapsed = last_flush.elapsed();
let linger = Duration::from_millis(inner.config.linger_ms);
linger.saturating_sub(elapsed)
};
let record = if linger_remaining == Duration::ZERO && !batches.is_empty() {
None
} else {
tokio::select! {
biased;
_ = shutdown_rx.changed() => {
if *shutdown_rx.borrow() {
break;
}
continue;
}
result = record_rx.recv() => {
match result {
Some(r) => Some(r),
None => break,
}
}
_ = tokio::time::sleep(linger_remaining), if linger_remaining > Duration::ZERO => {
None
}
}
};
if let Some(record) = record {
let partition = if let Some(p) = record.partition {
p
} else {
let num_partitions = inner.metadata_cache.get(&record.topic).await.unwrap_or(1);
inner
.partitioner
.partition(&record.topic, record.key.as_deref(), num_partitions)
.await
};
let key = (record.topic.clone(), partition);
let batch = batches
.entry(key)
.or_insert_with(|| RecordBatch::new(record.topic.clone(), partition));
batch.add(record.key, record.value, record.response_tx);
}
let should_flush = batches.values().any(|b| {
b.batch_size_bytes >= inner.config.batch_size
|| b.created_at.elapsed().as_millis() as u64 >= inner.config.linger_ms
}) || (inner.config.linger_ms == 0 && !batches.is_empty());
if should_flush {
for ((topic, partition), batch) in batches.drain() {
if batch.is_empty() {
continue;
}
let batch_len = batch.len() as u64;
if let Err(e) = send_batch(&mut writer, &mut reader, &inner, batch).await {
warn!("Failed to send batch to {}/{}: {}", topic, partition, e);
inner.stats.errors.fetch_add(1, Ordering::Relaxed);
}
// Settle the pending count whether the batch succeeded or failed, so a
// send error cannot leave flush() waiting on records whose response
// channels were already dropped.
let prev = inner.pending_records.fetch_sub(batch_len, Ordering::Release);
if prev == batch_len {
inner.flush_notify.notify_waiters();
}
}
last_flush = Instant::now();
}
}
// Best-effort drain of any batches still buffered at shutdown.
for ((_topic, _partition), batch) in batches.drain() {
if !batch.is_empty() {
let batch_len = batch.len() as u64;
let _ = send_batch(&mut writer, &mut reader, &inner, batch).await;
let prev = inner.pending_records.fetch_sub(batch_len, Ordering::Release);
if prev == batch_len {
inner.flush_notify.notify_waiters();
}
}
}
}
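/// Writes every request in the batch, flushes the socket once, then
/// reads the responses back in order. This pipelining relies on the
/// broker answering requests in the order they arrive on a connection.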
async fn send_batch(
writer: &mut BufWriter<tokio::net::tcp::OwnedWriteHalf>,
reader: &mut BufReader<tokio::net::tcp::OwnedReadHalf>,
inner: &Arc<ProducerInner>,
batch: RecordBatch,
) -> Result<()> {
let topic = batch.topic.clone();
let partition = batch.partition;
let num_records = batch.len();
if num_records == 0 {
return Ok(());
}
inner.stats.batches_sent.fetch_add(1, Ordering::Relaxed);
let mut response_channels: Vec<oneshot::Sender<Result<RecordMetadata>>> =
Vec::with_capacity(num_records);
for (value, key, response_tx) in batch.records {
response_channels.push(response_tx);
let request = Request::Publish {
topic: topic.clone(),
partition: Some(partition),
key,
value,
};
let request_bytes = request.to_wire(rivven_protocol::WireFormat::Postcard)?;
let len = request_bytes.len() as u32;
writer
.write_all(&len.to_be_bytes())
.await
.map_err(|e| Error::IoError(e.to_string()))?;
writer
.write_all(&request_bytes)
.await
.map_err(|e| Error::IoError(e.to_string()))?;
}
writer
.flush()
.await
.map_err(|e| Error::IoError(e.to_string()))?;
let mut len_buf = [0u8; 4];
for response_tx in response_channels {
reader
.read_exact(&mut len_buf)
.await
.map_err(|e| Error::IoError(e.to_string()))?;
let response_len = u32::from_be_bytes(len_buf) as usize;
let mut response_buf = vec![0u8; response_len];
reader
.read_exact(&mut response_buf)
.await
.map_err(|e| Error::IoError(e.to_string()))?;
let (response, _format) =
Response::from_wire(&response_buf).map_err(Error::ProtocolError)?;
match response {
Response::Published {
offset,
partition: resp_partition,
} => {
inner
.stats
.records_delivered
.fetch_add(1, Ordering::Relaxed);
let _ = response_tx.send(Ok(RecordMetadata {
topic: topic.clone(),
partition: resp_partition,
offset,
// The broker does not return a record timestamp yet; report 0.
timestamp: 0,
}));
}
Response::Error { message } => {
inner.stats.errors.fetch_add(1, Ordering::Relaxed);
let _ = response_tx.send(Err(Error::ServerError(message)));
}
_ => {
let _ = response_tx.send(Err(Error::InvalidResponse));
}
}
// Pending-record bookkeeping is settled per batch by the sender task,
// so a failed batch cannot strand flush() on a stale counter.
}
debug!(
"Sent batch of {} records to {}/{}",
num_records, topic, partition
);
Ok(())
}
struct ProducerStats {
records_sent: AtomicU64,
records_delivered: AtomicU64,
batches_sent: AtomicU64,
errors: AtomicU64,
retries: AtomicU64,
}
impl ProducerStats {
fn new() -> Self {
Self {
records_sent: AtomicU64::new(0),
records_delivered: AtomicU64::new(0),
batches_sent: AtomicU64::new(0),
errors: AtomicU64::new(0),
retries: AtomicU64::new(0),
}
}
}
#[derive(Debug, Clone)]
pub struct ProducerStatsSnapshot {
pub records_sent: u64,
pub records_delivered: u64,
pub batches_sent: u64,
pub errors: u64,
pub retries: u64,
}
impl ProducerStatsSnapshot {
pub fn success_rate(&self) -> f64 {
if self.records_sent == 0 {
1.0
} else {
self.records_delivered as f64 / self.records_sent as f64
}
}
pub fn pending(&self) -> u64 {
self.records_sent
.saturating_sub(self.records_delivered + self.errors)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_producer_config_builder() {
let config = ProducerConfig::builder()
.bootstrap_servers(vec!["server1:9092".to_string()])
.batch_size(32768)
.linger_ms(10)
.buffer_memory(64 * 1024 * 1024)
.enable_idempotence(true)
.retries(5)
.build();
assert_eq!(config.bootstrap_servers.len(), 1);
assert_eq!(config.batch_size, 32768);
assert_eq!(config.linger_ms, 10);
assert_eq!(config.buffer_memory, 64 * 1024 * 1024);
assert!(config.enable_idempotence);
assert_eq!(config.retries, 5);
}
#[test]
fn test_high_throughput_config() {
let config = ProducerConfig::high_throughput();
assert_eq!(config.batch_size, 65536);
assert_eq!(config.linger_ms, 10);
assert_eq!(config.max_in_flight_requests, 10);
}
#[test]
fn test_low_latency_config() {
let config = ProducerConfig::low_latency();
assert_eq!(config.batch_size, 1);
assert_eq!(config.linger_ms, 0);
assert_eq!(config.max_in_flight_requests, 1);
}
#[test]
fn test_exactly_once_config() {
let config = ProducerConfig::exactly_once();
assert!(config.enable_idempotence);
assert_eq!(config.acks, -1);
assert_eq!(config.retries, i32::MAX as u32);
}
#[test]
fn test_murmur2_partition() {
let key = b"test-key";
let p1 = murmur2_partition(key, 10);
let p2 = murmur2_partition(key, 10);
assert_eq!(p1, p2);
let key2 = b"other-key";
let p3 = murmur2_partition(key2, 10);
assert!(p3 < 10);
}
#[test]
fn test_murmur2_partition_range() {
for i in 0..1000 {
let key = format!("key-{}", i);
let partition = murmur2_partition(key.as_bytes(), 16);
assert!(partition < 16);
}
}
#[test]
fn test_stats_snapshot() {
let stats = ProducerStatsSnapshot {
records_sent: 100,
records_delivered: 90,
batches_sent: 10,
errors: 5,
retries: 3,
};
assert!((stats.success_rate() - 0.9).abs() < 0.001);
assert_eq!(stats.pending(), 5);
}
#[tokio::test]
async fn test_metadata_cache() {
let cache = MetadataCache::new(Duration::from_millis(100));
assert!(cache.get("topic1").await.is_none());
cache.put("topic1".to_string(), 4).await;
assert_eq!(cache.get("topic1").await, Some(4));
tokio::time::sleep(Duration::from_millis(150)).await;
assert!(cache.get("topic1").await.is_none());
}
#[tokio::test]
async fn test_sticky_partitioner_with_key() {
let partitioner = StickyPartitioner::new();
let p1 = partitioner
.partition("topic", Some(b"key1".as_slice()), 10)
.await;
let p2 = partitioner
.partition("topic", Some(b"key1".as_slice()), 10)
.await;
assert_eq!(p1, p2);
}
#[tokio::test]
async fn test_sticky_partitioner_without_key() {
let partitioner = StickyPartitioner::new();
let mut partitions = Vec::new();
for _ in 0..STICKY_BATCH_THRESHOLD {
let p = partitioner.partition("topic", None, 10).await;
partitions.push(p);
}
// Records below the threshold stick to the initial partition; the
// record that reaches the threshold rotates to the next one.
for i in 1..STICKY_BATCH_THRESHOLD - 1 {
assert_eq!(partitions[0], partitions[i]);
}
assert_ne!(partitions[0], partitions[STICKY_BATCH_THRESHOLD - 1]);
}
#[test]
fn test_pending_records_tracking() {
use std::sync::atomic::AtomicU64;
let pending = AtomicU64::new(0);
pending.fetch_add(1, Ordering::Release);
pending.fetch_add(1, Ordering::Release);
assert_eq!(pending.load(Ordering::Acquire), 2);
let prev = pending.fetch_sub(1, Ordering::Release);
assert_eq!(prev, 2);
assert_eq!(pending.load(Ordering::Acquire), 1);
let prev = pending.fetch_sub(1, Ordering::Release);
assert_eq!(prev, 1);
assert_eq!(pending.load(Ordering::Acquire), 0);
}
#[tokio::test]
async fn test_flush_notify_mechanism() {
let notify = tokio::sync::Notify::new();
let pending = Arc::new(AtomicU64::new(1));
let pending_clone = Arc::clone(&pending);
let notify_clone = Arc::new(notify);
let notify_clone2 = Arc::clone(&notify_clone);
let handle = tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(10)).await;
let prev = pending_clone.fetch_sub(1, Ordering::Release);
if prev == 1 {
notify_clone2.notify_waiters();
}
});
let start = Instant::now();
loop {
if pending.load(Ordering::Acquire) == 0 {
break;
}
tokio::select! {
_ = notify_clone.notified() => continue,
_ = tokio::time::sleep(Duration::from_millis(5)) => continue,
}
}
let elapsed = start.elapsed();
handle.await.unwrap();
assert!(elapsed < Duration::from_millis(100));
}
#[test]
fn test_murmur2_kafka_compatibility() {
let p = murmur2_partition(b"", 10);
assert!(p < 10);
let p = murmur2_partition(b"a", 10);
assert!(p < 10);
let p = murmur2_partition(b"abc", 10); assert!(p < 10);
let p = murmur2_partition(b"abcd", 10); assert!(p < 10);
let p = murmur2_partition(b"abcdefgh", 10); assert!(p < 10);
let p = murmur2_partition(b"abcdefghijk", 10); assert!(p < 10);
}
#[test]
fn test_murmur2_distribution() {
let mut buckets = [0u32; 16];
for i in 0..10000 {
let key = format!("key-{}-data-{}", i, i * 7);
let partition = murmur2_partition(key.as_bytes(), 16);
buckets[partition as usize] += 1;
}
for (i, &count) in buckets.iter().enumerate() {
assert!(
count > 300 && count < 950,
"Bucket {} has {} keys, expected ~625",
i,
count
);
}
}
#[test]
fn test_compression_type_default() {
let compression = CompressionType::default();
assert_eq!(compression, CompressionType::None);
}
#[test]
fn test_record_batch() {
let mut batch = RecordBatch::new("test-topic".to_string(), 0);
assert!(batch.is_empty());
assert_eq!(batch.len(), 0);
let (tx, _rx) = oneshot::channel();
batch.add(Some(Bytes::from("key")), Bytes::from("value"), tx);
assert!(!batch.is_empty());
assert_eq!(batch.len(), 1);
// "key" (3 bytes) + "value" (5 bytes)
assert_eq!(batch.batch_size_bytes, 8);
}
#[test]
fn test_stats_snapshot_zero_sent() {
let stats = ProducerStatsSnapshot {
records_sent: 0,
records_delivered: 0,
batches_sent: 0,
errors: 0,
retries: 0,
};
assert!((stats.success_rate() - 1.0).abs() < 0.001);
assert_eq!(stats.pending(), 0);
}
#[tokio::test]
async fn test_metadata_cache_cleanup() {
let cache = MetadataCache::new(Duration::from_millis(50));
cache.put("topic1".to_string(), 4).await;
cache.put("topic2".to_string(), 8).await;
assert_eq!(cache.get("topic1").await, Some(4));
assert_eq!(cache.get("topic2").await, Some(8));
tokio::time::sleep(Duration::from_millis(100)).await;
cache.cleanup().await;
assert!(cache.get("topic1").await.is_none());
assert!(cache.get("topic2").await.is_none());
}
#[tokio::test]
async fn test_sticky_partitioner_partition_change() {
let partitioner = StickyPartitioner::new();
let p1 = partitioner.partition("topic", None, 4).await;
assert!(p1 < 4);
let p2 = partitioner.partition("topic", None, 8).await;
assert!(p2 < 8);
}
}