use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::{Semaphore, mpsc, oneshot};
use tokio::time::interval;
use tracing::{debug, trace, warn};
use super::barrier::{InFlightBarrier, InFlightOpGuard};
use super::batch::ProducerBatch;
use super::record::{ProducerRecord, RecordMetadata, RoutedRecord, TopicHandle};
use super::retry::{RetryContext, RetryPolicy};
use crate::PartitionId;
use crate::error::{ErrorCode, KrafkaError, Result};
use crate::interceptor::ProducerInterceptor;
use crate::metadata::ClusterMetadata;
use crate::metrics::ProducerMetrics;
use crate::protocol::{
ApiKey, Compression, ProducePartitionData, ProduceRequest, ProduceResponse, ProduceTopicData,
RecordBatchBuilder, VersionedDecode, VersionedEncode, versions,
};
/// Upper bound on the number of batch-send tasks that
/// `RecordAccumulator::spawn_batches_bounded` keeps in its `JoinSet` at once;
/// when the set reaches this size it waits for one task to finish before
/// spawning the next.
const MAX_CONCURRENT_BATCH_SENDS: usize = 64;
/// Checks whether a record of `record_size` bytes can ever be admitted.
///
/// Two limits are enforced, in this order:
/// 1. the size must fit in a `u32`, because it is later used as a `u32`
///    permit count when acquiring from the memory semaphore;
/// 2. the size must not exceed `memory_capacity`.
///
/// # Errors
/// Returns a `KrafkaError::config` error naming the limit that was exceeded.
fn check_record_admission(record_size: usize, memory_capacity: usize) -> Result<()> {
    let permit_limit = u32::MAX;
    if record_size > permit_limit as usize {
        let message = format!(
            "record size {record_size} B exceeds the semaphore \
             permit-count limit ({} B, u32::MAX); Kafka records must \
             be smaller",
            permit_limit
        );
        return Err(KrafkaError::config(message));
    }
    if record_size <= memory_capacity {
        return Ok(());
    }
    Err(KrafkaError::config(format!(
        "record size {record_size} B exceeds producer buffer_memory \
         capacity ({} B); raise ProducerConfig::buffer_memory or \
         shrink the record",
        memory_capacity
    )))
}
/// Reply delivered over the per-record oneshot channel to the caller that is
/// waiting inside `RecordAccumulatorHandle::append_routed_with_guard`.
#[derive(Debug)]
enum AppendResponse {
// Final outcome of the append: the record metadata, or the error.
Done(Result<RecordMetadata>),
}
/// Payload of an `AccumulatorMessage::Append`: everything the accumulator
/// task needs to place one record into a batch and later answer the caller.
#[derive(Debug)]
struct AppendCommand {
// Topic the record is destined for; part of the batch key.
topic: TopicHandle,
// The record payload after routing.
record: RoutedRecord,
// Target partition; the other half of the batch key.
partition: PartitionId,
// Estimated size in bytes; also the number of memory permits reserved.
record_size: usize,
// Channel used to report the final result back to the caller.
response_tx: oneshot::Sender<AppendResponse>,
// In-flight barrier guard, kept alive alongside the record.
operation_guard: InFlightOpGuard,
// Returns the reserved permits if the command is dropped before the
// accumulator takes ownership of the memory accounting.
permit_reservation: PermitReservation,
}
/// Messages sent from `RecordAccumulatorHandle` to the accumulator task.
#[derive(Debug)]
enum AccumulatorMessage {
// Add one record to its topic-partition batch.
Append(AppendCommand),
// Send every buffered batch and report the result on `response_tx`.
Flush {
response_tx: oneshot::Sender<Result<()>>,
},
// Send every buffered batch, acknowledge on `response_tx`, then stop the task.
Shutdown { response_tx: oneshot::Sender<()> },
}
/// Memory permits reserved for a record that is still on its way to the
/// accumulator task. Dropping the reservation adds `bytes` permits back to
/// `memory_permits`; `forget` zeroes `bytes` first so nothing is returned.
struct PermitReservation {
// Number of permits this reservation will give back when dropped.
bytes: usize,
// Semaphore the permits were taken from.
memory_permits: Arc<Semaphore>,
}
impl PermitReservation {
    /// Consumes the reservation without returning its permits.
    ///
    /// The byte count is zeroed before the value is dropped, so the `Drop`
    /// impl adds zero permits back to the semaphore.
    fn forget(self) {
        let mut reservation = self;
        reservation.bytes = 0;
    }
}
impl Drop for PermitReservation {
    /// Gives the reserved permits back to the memory semaphore.
    fn drop(&mut self) {
        let reserved = self.bytes;
        self.memory_permits.add_permits(reserved);
    }
}
impl std::fmt::Debug for PermitReservation {
    /// Formats the reservation showing only `bytes`.
    ///
    /// The semaphore handle is intentionally left out; the output is marked
    /// non-exhaustive (`..`) so readers can tell a field was omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PermitReservation")
            .field("bytes", &self.bytes)
            .finish_non_exhaustive()
    }
}
/// Cloneable handle used to talk to the accumulator task created by
/// `RecordAccumulator::spawn`.
#[derive(Clone)]
pub struct RecordAccumulatorHandle {
// Channel carrying `AccumulatorMessage`s to the accumulator task.
sender: mpsc::Sender<AccumulatorMessage>,
// Byte-granular semaphore used for buffer-memory backpressure.
memory_permits: Arc<Semaphore>,
// Largest record size (in bytes) that can ever be admitted.
memory_capacity: usize,
// Time budget for the admission phase of an append.
max_block_ms: Duration,
// Barrier from which `append` obtains its in-flight operation guard.
in_flight_barrier: Arc<InFlightBarrier>,
}
impl RecordAccumulatorHandle {
/// Appends `record` to `partition` and waits for the produce result.
///
/// Registers the operation with the in-flight barrier (label `"producer"`)
/// and then delegates to `append_with_guard`.
///
/// # Errors
/// Propagates the barrier's error if the operation cannot be started, and
/// any error returned by `append_with_guard`.
pub async fn append(
    &self,
    record: ProducerRecord,
    partition: PartitionId,
) -> Result<RecordMetadata> {
    let guard = self.in_flight_barrier.start("producer")?;
    self.append_with_guard(record, partition, guard).await
}
/// Appends `record` using an in-flight guard the caller already holds.
///
/// The record's estimated size is measured before it is split into its
/// routed parts, then everything is forwarded to
/// `append_routed_with_guard`.
pub(crate) async fn append_with_guard(
    &self,
    record: ProducerRecord,
    partition: PartitionId,
    operation_guard: InFlightOpGuard,
) -> Result<RecordMetadata> {
    // Size must be taken first: `into_routed_parts` consumes the record.
    let size = record.estimated_size();
    let parts = record.into_routed_parts();
    let (topic, routed_record) = (parts.topic, parts.record);
    self.append_routed_with_guard(topic, routed_record, size, partition, operation_guard)
        .await
}
/// Reserves buffer memory for an already-routed record, hands it to the
/// accumulator task, and waits for the final produce result.
///
/// The admission phase (waiting for memory permits, then for channel
/// capacity) shares a single `max_block_ms` budget. Waiting for the produce
/// result afterwards is not bounded by this method.
///
/// # Errors
/// - a config error from `check_record_admission` when the record can never fit;
/// - a timeout error when the `max_block_ms` budget runs out while waiting
///   for memory or for channel capacity;
/// - an invalid-state error when the memory semaphore is closed, the
///   accumulator channel is closed, or the response sender is dropped;
/// - otherwise whatever result the accumulator reports for the record.
pub(crate) async fn append_routed_with_guard(
    &self,
    topic: TopicHandle,
    record: RoutedRecord,
    record_size: usize,
    partition: PartitionId,
    operation_guard: InFlightOpGuard,
) -> Result<RecordMetadata> {
    let max_block = self.max_block_ms;
    // `Instant + Duration` panics on overflow, so a very large budget (for
    // example `Duration::MAX` meaning "block forever") must not be added
    // directly. `None` means the deadline is too far away to represent; the
    // raw budget is then handed to `timeout` as-is.
    let deadline = tokio::time::Instant::now().checked_add(max_block);
    let remaining = move || match deadline {
        Some(at) => at.saturating_duration_since(tokio::time::Instant::now()),
        None => max_block,
    };
    check_record_admission(record_size, self.memory_capacity)?;
    let permit = match tokio::time::timeout(
        remaining(),
        // The cast cannot truncate: `check_record_admission` rejected sizes above `u32::MAX`.
        self.memory_permits.acquire_many(record_size as u32),
    )
    .await
    {
        Ok(Ok(p)) => p,
        Ok(Err(_)) => return Err(KrafkaError::invalid_state("accumulator closed")),
        Err(_) => {
            return Err(KrafkaError::timeout(
                "producer append: max_block exceeded while waiting for buffer \
                 memory (ProducerConfig::max_block / AccumulatorConfig::max_block_ms)",
            ));
        }
    };
    // Move ownership of the permits from the borrowed semaphore permit into
    // an owned reservation that travels with the command: if the command is
    // dropped anywhere before the accumulator accepts it, the reservation's
    // `Drop` gives the permits back.
    let permit_reservation = PermitReservation {
        bytes: record_size,
        memory_permits: self.memory_permits.clone(),
    };
    permit.forget();
    let (response_tx, response_rx) = oneshot::channel();
    match tokio::time::timeout(
        remaining(),
        self.sender.send(AccumulatorMessage::Append(AppendCommand {
            topic,
            record,
            partition,
            record_size,
            response_tx,
            operation_guard,
            permit_reservation,
        })),
    )
    .await
    {
        Ok(Ok(())) => {}
        Ok(Err(_)) => return Err(KrafkaError::invalid_state("accumulator closed")),
        Err(_) => {
            return Err(KrafkaError::timeout(
                "producer append: max_block exceeded while sending to accumulator",
            ));
        }
    }
    match response_rx
        .await
        .map_err(|_| KrafkaError::invalid_state("accumulator response dropped"))?
    {
        AppendResponse::Done(result) => result,
    }
}
/// Asks the accumulator task to send every buffered batch and waits for
/// its reply.
///
/// # Errors
/// Returns an invalid-state error if the accumulator channel is closed or
/// the reply sender is dropped; otherwise returns the task's own result.
pub async fn flush(&self) -> Result<()> {
    let (response_tx, response_rx) = oneshot::channel();
    let request = AccumulatorMessage::Flush { response_tx };
    if self.sender.send(request).await.is_err() {
        return Err(KrafkaError::invalid_state("accumulator closed"));
    }
    match response_rx.await {
        Ok(flush_result) => flush_result,
        Err(_) => Err(KrafkaError::invalid_state("accumulator response dropped")),
    }
}
/// Asks the accumulator task to flush and stop, then waits for its
/// acknowledgement.
///
/// # Errors
/// Returns an invalid-state error (and logs a warning) if the task has
/// already exited or drops the acknowledgement channel before replying.
pub async fn shutdown(&self) -> Result<()> {
    let (response_tx, response_rx) = oneshot::channel();
    let request = AccumulatorMessage::Shutdown { response_tx };
    if self.sender.send(request).await.is_err() {
        warn!("Accumulator shutdown failed: task already exited");
        return Err(KrafkaError::invalid_state("accumulator already shut down"));
    }
    match response_rx.await {
        Ok(()) => Ok(()),
        Err(_) => {
            warn!("Accumulator shutdown: response channel dropped before completion");
            Err(KrafkaError::invalid_state("accumulator shutdown interrupted"))
        }
    }
}
}
/// Configuration for the record accumulator and for the batch sends it
/// spawns.
///
/// `Clone` is derived: every field is either `Copy`, a `String`, or a
/// reference-counted handle, so a clone shares the semaphore, interceptor
/// and identity with the original. `Debug` is implemented by hand elsewhere
/// in this file.
#[derive(Clone)]
pub struct AccumulatorConfig {
    /// Maximum size passed to each per-partition batch when it is created.
    pub batch_size: usize,
    /// How long a non-empty batch may wait before being sent; zero sends
    /// each batch right after a record is added.
    pub linger: Duration,
    /// Compression applied when a batch is encoded.
    pub compression: Compression,
    /// `acks` value placed in produce requests; `0` selects the
    /// fire-and-forget send path.
    pub acks: i16,
    /// Client id used when validating the produce request size.
    pub client_id: String,
    /// Converted to milliseconds and sent as the produce request timeout.
    pub request_timeout: Duration,
    /// Upper bound enforced on an encoded produce request.
    pub max_request_size: usize,
    /// Total bytes of records that may be buffered; `0` disables
    /// backpressure.
    pub buffer_memory: usize,
    /// Time budget for the admission phase of an append (waiting for
    /// buffer memory and channel capacity).
    pub max_block_ms: Duration,
    /// Semaphore from which each batch send acquires one permit before
    /// talking to the broker.
    pub in_flight_semaphore: Arc<Semaphore>,
    /// Interceptor notified when records are acknowledged or fail.
    pub interceptor: Arc<dyn ProducerInterceptor>,
    /// Idempotent-producer identity; when `None`, batches carry no
    /// producer id or sequence number.
    pub identity: Option<Arc<super::idempotent::ProducerIdentity>>,
}
impl fmt::Debug for AccumulatorConfig {
    /// Formats the configuration for diagnostics.
    ///
    /// `in_flight_semaphore` and `identity` are left out; the output is
    /// marked non-exhaustive (`..`) so it does not look like the complete
    /// set of fields.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("AccumulatorConfig")
            .field("batch_size", &self.batch_size)
            .field("linger", &self.linger)
            .field("compression", &self.compression)
            .field("acks", &self.acks)
            .field("client_id", &self.client_id)
            .field("request_timeout", &self.request_timeout)
            .field("max_request_size", &self.max_request_size)
            .field("buffer_memory", &self.buffer_memory)
            .field("max_block_ms", &self.max_block_ms)
            .field("interceptor", &self.interceptor)
            .finish_non_exhaustive()
    }
}
impl Default for AccumulatorConfig {
fn default() -> Self {
Self {
batch_size: 16384,
linger: Duration::from_millis(0),
compression: Compression::None,
acks: -1,
client_id: "krafka".to_string(),
request_timeout: Duration::from_secs(30),
max_request_size: crate::protocol::MAX_MESSAGE_SIZE,
buffer_memory: 32 * 1024 * 1024, max_block_ms: Duration::from_secs(60), in_flight_semaphore: Arc::new(Semaphore::new(5)), interceptor: Arc::new(crate::interceptor::NoOpProducerInterceptor),
identity: None,
}
}
}
/// A record sitting in a batch together with what is needed to answer its
/// caller once the batch has been sent.
struct PendingRecord {
// The routed record payload that is encoded into the batch.
record: RoutedRecord,
// Channel on which the final result for this record is reported.
response_tx: oneshot::Sender<AppendResponse>,
// Position of the record inside its batch; added to the broker's base
// offset to compute the record's offset.
offset_in_batch: i64,
// Estimated size in bytes, used for memory accounting and byte metrics.
estimated_size: usize,
// Held only for its lifetime: dropped together with the pending record.
_operation_guard: InFlightOpGuard,
}
/// Memory accounting for a batch that has been extracted for sending. On
/// drop it subtracts `bytes` from `in_flight_memory` and returns `bytes`
/// permits to `memory_permits`.
struct InFlightGuard {
// Sum of the estimated sizes of the batch's records.
bytes: usize,
// Shared counter of bytes currently being sent.
in_flight_memory: Arc<AtomicUsize>,
// Buffer-memory semaphore the permits are returned to.
memory_permits: Arc<Semaphore>,
}
impl Drop for InFlightGuard {
    /// Removes the batch's bytes from the in-flight counter and returns the
    /// same number of permits to the buffer-memory semaphore.
    fn drop(&mut self) {
        let released = self.bytes;
        self.in_flight_memory.fetch_sub(released, Ordering::Relaxed);
        self.memory_permits.add_permits(released);
    }
}
/// One in-progress batch for a single topic-partition.
struct AccumulatorBatch {
// Size/fullness tracking for the batch.
batch: ProducerBatch,
// Records added so far, each with its response channel.
pending: Vec<PendingRecord>,
// Creation time; used for linger expiry and as the retry context start.
created_at: Instant,
}
impl AccumulatorBatch {
    /// Creates an empty batch for `topic`/`partition`, timestamped now.
    fn new(
        topic: TopicHandle,
        partition: PartitionId,
        max_size: usize,
        compression: Compression,
    ) -> Self {
        let topic_name = topic.to_string();
        let batch = ProducerBatch::new(topic_name, partition, max_size, compression);
        Self {
            batch,
            pending: Vec::new(),
            created_at: Instant::now(),
        }
    }

    /// Time elapsed since the batch was created.
    fn age(&self) -> Duration {
        Instant::now().duration_since(self.created_at)
    }
}
/// State owned by the accumulator task: the open batches plus everything a
/// batch send needs.
pub struct RecordAccumulator {
// Batching and send configuration.
config: AccumulatorConfig,
// Open batches keyed by (topic, partition).
batches: HashMap<(TopicHandle, PartitionId), AccumulatorBatch>,
// Cluster metadata used to find partition leaders and refresh topics.
metadata: Arc<ClusterMetadata>,
// Bytes belonging to batches that have been extracted for sending.
in_flight_memory: Arc<AtomicUsize>,
// Retry policy handed to every batch send.
retry_policy: RetryPolicy,
// Producer metrics sink.
metrics: Arc<ProducerMetrics>,
// Buffer-memory semaphore shared with the handle.
memory_permits: Arc<Semaphore>,
}
impl RecordAccumulator {
/// Starts the accumulator task and returns the handle used to talk to it.
///
/// * The command channel capacity is `buffer_memory / 10 / batch_size`
///   clamped to `1..=256`, or 64 when `buffer_memory` is zero.
/// * The memory semaphore holds `buffer_memory` permits, clamped to
///   `Semaphore::MAX_PERMITS`; a `buffer_memory` of zero uses
///   `Semaphore::MAX_PERMITS` and logs that backpressure is disabled.
/// * A supervisor task awaits the accumulator task and closes the memory
///   semaphore if that task panics or is cancelled, so blocked appenders
///   fail instead of waiting.
pub(crate) fn spawn(
    config: AccumulatorConfig,
    metadata: Arc<ClusterMetadata>,
    retry_policy: RetryPolicy,
    metrics: Arc<ProducerMetrics>,
    in_flight_barrier: Arc<InFlightBarrier>,
) -> RecordAccumulatorHandle {
    let buffer_memory = config.buffer_memory;
    let backpressure_enabled = buffer_memory > 0;

    let channel_capacity = if backpressure_enabled {
        (buffer_memory / 10 / config.batch_size.max(1)).clamp(1, 256)
    } else {
        64
    };
    let (sender, receiver) = mpsc::channel(channel_capacity);

    let memory_capacity = if !backpressure_enabled {
        Semaphore::MAX_PERMITS
    } else if buffer_memory > Semaphore::MAX_PERMITS {
        warn!(
            requested = buffer_memory,
            effective = Semaphore::MAX_PERMITS,
            "buffer_memory exceeds Semaphore::MAX_PERMITS; clamping effective \
             producer memory capacity"
        );
        Semaphore::MAX_PERMITS
    } else {
        buffer_memory
    };
    let memory_permits = Arc::new(Semaphore::new(memory_capacity));
    let max_block_ms = config.max_block_ms;

    if !backpressure_enabled {
        warn!(
            "buffer_memory=0 disables producer backpressure; \
             memory usage is unbounded. Not recommended for production."
        );
    }

    let accumulator = Self {
        config,
        batches: HashMap::new(),
        metadata,
        in_flight_memory: Arc::new(AtomicUsize::new(0)),
        retry_policy,
        metrics,
        memory_permits: memory_permits.clone(),
    };

    // Supervisor: the inner task runs the accumulator loop; the outer one
    // only observes how it ended.
    let supervisor_permits = memory_permits.clone();
    tokio::spawn(async move {
        let worker = tokio::spawn(accumulator.run(receiver));
        let Err(join_err) = worker.await else {
            return;
        };
        if join_err.is_panic() {
            tracing::error!("Accumulator task panicked: {join_err}");
        } else {
            tracing::error!("Accumulator task cancelled: {join_err}");
        }
        supervisor_permits.close();
    });

    RecordAccumulatorHandle {
        sender,
        memory_permits,
        memory_capacity,
        max_block_ms,
        in_flight_barrier,
    }
}
/// Main loop of the accumulator task.
///
/// Handles incoming messages and a periodic linger check. The check
/// period is one tenth of `linger`, but never less than 1 ms. The loop
/// ends after a `Shutdown` message or when the channel closes; in both
/// cases the remaining batches are flushed first.
async fn run(mut self, mut receiver: mpsc::Receiver<AccumulatorMessage>) {
    let tick_period = std::cmp::max(Duration::from_millis(1), self.config.linger / 10);
    let mut linger_timer = interval(tick_period);
    linger_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
    loop {
        tokio::select! {
            incoming = receiver.recv() => {
                let Some(message) = incoming else {
                    debug!("Accumulator channel closed, flushing remaining batches");
                    let _ = self.flush_all().await;
                    break;
                };
                match message {
                    AccumulatorMessage::Append(append) => {
                        self.handle_append(append).await;
                    }
                    AccumulatorMessage::Flush { response_tx } => {
                        let outcome = self.flush_all().await;
                        // The requester may have gone away; that is fine.
                        let _ = response_tx.send(outcome);
                    }
                    AccumulatorMessage::Shutdown { response_tx } => {
                        debug!("Accumulator shutting down, flushing remaining batches");
                        let _ = self.flush_all().await;
                        let _ = response_tx.send(());
                        break;
                    }
                }
            }
            _ = linger_timer.tick() => {
                self.check_linger_expiry();
            }
        }
    }
    debug!("Accumulator shutdown complete");
}
/// Places one record into the batch for its topic-partition.
///
/// * If the record fits the current batch it is added, and the batch is
///   flushed when it becomes full or when `linger` is zero.
/// * Otherwise the current batch is flushed and the record goes into a
///   fresh batch.
/// * If it does not fit a fresh batch either, the reserved permits are
///   released and the caller receives a config error.
///
/// Once a record has been stored in a batch, its permit reservation is
/// forgotten; the permits are returned later by the `InFlightGuard` built
/// when the batch is extracted.
async fn handle_append(&mut self, append: AppendCommand) {
    let AppendCommand {
        topic,
        record,
        partition,
        record_size,
        response_tx,
        operation_guard,
        permit_reservation,
    } = append;
    let key = (topic, partition);
    let (max_batch_bytes, codec) = (self.config.batch_size, self.config.compression);

    let current = self.batches.entry(key.clone()).or_insert_with(|| {
        AccumulatorBatch::new(key.0.clone(), partition, max_batch_bytes, codec)
    });
    // Position is taken before the record is tracked.
    let next_offset = current.batch.len() as i64;

    if current.batch.would_fit(record_size) {
        current.batch.track(record_size);
        current.pending.push(PendingRecord {
            record,
            response_tx,
            offset_in_batch: next_offset,
            estimated_size: record_size,
            _operation_guard: operation_guard,
        });
        permit_reservation.forget();
        if current.batch.is_full() {
            trace!("Batch full for {}-{}, flushing", key.0, partition);
            self.flush_batch(&key);
        } else if self.config.linger.is_zero() {
            trace!("Linger=0 for {}-{}, flushing immediately", key.0, partition);
            self.flush_batch(&key);
        }
        return;
    }

    // The record does not fit: send what we have and start over.
    self.flush_batch(&key);
    let mut fresh = AccumulatorBatch::new(key.0.clone(), partition, max_batch_bytes, codec);
    if !fresh.batch.would_fit(record_size) {
        // Give the memory back before answering the caller.
        drop(permit_reservation);
        let _ = response_tx.send(AppendResponse::Done(Err(KrafkaError::config(
            "record too large for batch size",
        ))));
        return;
    }
    fresh.batch.track(record_size);
    fresh.pending.push(PendingRecord {
        record,
        response_tx,
        offset_in_batch: 0,
        estimated_size: record_size,
        _operation_guard: operation_guard,
    });
    self.batches.insert(key, fresh);
    permit_reservation.forget();
}
/// Sends, in a detached task, every non-empty batch that is at least
/// `linger` old. With a zero `linger` every non-empty batch is sent.
fn check_linger_expiry(&mut self) {
    let linger = self.config.linger;
    if linger.is_zero() {
        self.flush_all_ready();
        return;
    }

    let mut expired_keys = Vec::new();
    for (key, entry) in &self.batches {
        if !entry.batch.is_empty() && entry.age() >= linger {
            expired_keys.push(key.clone());
        }
    }
    if expired_keys.is_empty() {
        return;
    }

    let mut ready = Vec::with_capacity(expired_keys.len());
    for key in expired_keys {
        trace!("Linger expired for {:?}, flushing", key);
        if let Some(item) = self.extract_batch(&key) {
            ready.push((key, item));
        }
    }
    Self::spawn_batches_detached(
        ready,
        &self.metadata,
        &self.config,
        &self.retry_policy,
        &self.metrics,
    );
}
/// Extracts every non-empty batch and sends them in a detached task,
/// without waiting for the sends to finish.
fn flush_all_ready(&mut self) {
    let mut non_empty_keys = Vec::new();
    for (key, entry) in &self.batches {
        if !entry.batch.is_empty() {
            non_empty_keys.push(key.clone());
        }
    }
    if non_empty_keys.is_empty() {
        return;
    }

    let mut ready = Vec::with_capacity(non_empty_keys.len());
    for key in non_empty_keys {
        if let Some(item) = self.extract_batch(&key) {
            ready.push((key, item));
        }
    }
    Self::spawn_batches_detached(
        ready,
        &self.metadata,
        &self.config,
        &self.retry_policy,
        &self.metrics,
    );
}
/// Sends the given batches concurrently and returns once all of them have
/// finished.
///
/// At most `MAX_CONCURRENT_BATCH_SENDS` send tasks are kept in the
/// `JoinSet`; when that many are present, one is awaited before the next
/// is spawned. A send task that panicked is logged with a warning.
async fn spawn_batches_bounded(
    extracted: Vec<(
        (TopicHandle, PartitionId),
        (AccumulatorBatch, InFlightGuard),
    )>,
    metadata: &Arc<ClusterMetadata>,
    config: &AccumulatorConfig,
    retry_policy: &RetryPolicy,
    metrics: &Arc<ProducerMetrics>,
) {
    let mut join_set = tokio::task::JoinSet::new();
    for ((topic, partition), (batch, guard)) in extracted {
        if join_set.len() >= MAX_CONCURRENT_BATCH_SENDS {
            // Make room by waiting for one running send.
            match join_set.join_next().await {
                Some(Err(e)) if e.is_panic() => {
                    warn!("send_extracted_batch task panicked: {e}");
                }
                _ => {}
            }
        }
        let send = Self::send_extracted_batch(
            topic,
            partition,
            batch.pending,
            batch.created_at,
            guard,
            metadata.clone(),
            config.clone(),
            retry_policy.clone(),
            metrics.clone(),
        );
        join_set.spawn(send);
    }
    // Drain whatever is still running.
    loop {
        match join_set.join_next().await {
            None => break,
            Some(Err(e)) if e.is_panic() => {
                warn!("send_extracted_batch task panicked: {e}");
            }
            Some(_) => {}
        }
    }
}
/// Sends the given batches in a background task and returns immediately.
///
/// Does nothing for an empty list. The shared state is cloned so the
/// spawned task owns everything it needs.
fn spawn_batches_detached(
    extracted: Vec<(
        (TopicHandle, PartitionId),
        (AccumulatorBatch, InFlightGuard),
    )>,
    metadata: &Arc<ClusterMetadata>,
    config: &AccumulatorConfig,
    retry_policy: &RetryPolicy,
    metrics: &Arc<ProducerMetrics>,
) {
    if extracted.is_empty() {
        return;
    }
    let owned_metadata = metadata.clone();
    let owned_config = config.clone();
    let owned_policy = retry_policy.clone();
    let owned_metrics = metrics.clone();
    let handle = tokio::spawn(async move {
        Self::spawn_batches_bounded(
            extracted,
            &owned_metadata,
            &owned_config,
            &owned_policy,
            &owned_metrics,
        )
        .await;
    });
    // The join handle is not needed; dropping it leaves the task running.
    drop(handle);
}
/// Removes the batch stored under `key` and prepares it for sending.
///
/// Returns `None` when there is no batch for `key` or the batch is empty
/// (an empty batch is still removed from the map). Otherwise the batch's
/// total estimated size is added to the in-flight counter and an
/// `InFlightGuard` carrying that size is returned with the batch.
fn extract_batch(
    &mut self,
    key: &(TopicHandle, PartitionId),
) -> Option<(AccumulatorBatch, InFlightGuard)> {
    let removed = self.batches.remove(key)?;
    if removed.batch.is_empty() {
        return None;
    }
    let bytes: usize = removed
        .pending
        .iter()
        .map(|entry| entry.estimated_size)
        .sum();
    self.in_flight_memory.fetch_add(bytes, Ordering::Relaxed);
    let guard = InFlightGuard {
        bytes,
        in_flight_memory: self.in_flight_memory.clone(),
        memory_permits: self.memory_permits.clone(),
    };
    Some((removed, guard))
}
/// Extracts the batch for `key`, if it has any records, and sends it in a
/// detached task.
fn flush_batch(&mut self, key: &(TopicHandle, PartitionId)) {
    let Some(item) = self.extract_batch(key) else {
        return;
    };
    let single = vec![(key.clone(), item)];
    Self::spawn_batches_detached(
        single,
        &self.metadata,
        &self.config,
        &self.retry_policy,
        &self.metrics,
    );
}
#[allow(clippy::too_many_arguments)]
async fn send_extracted_batch(
topic: TopicHandle,
partition: PartitionId,
pending: Vec<PendingRecord>,
enqueued_at: Instant,
_in_flight_guard: InFlightGuard,
metadata: Arc<ClusterMetadata>,
config: AccumulatorConfig,
retry_policy: RetryPolicy,
metrics: Arc<ProducerMetrics>,
) {
if let Some(identity) = config.identity.as_ref()
&& let Err(error) =
super::ensure_idempotent_producer_id_initialized(identity, &metadata, &retry_policy)
.await
{
metrics.record_error();
for pending_record in pending {
let _ = pending_record
.response_tx
.send(AppendResponse::Done(Err(error.clone())));
}
return;
}
let _permit = config.in_flight_semaphore.acquire().await;
let _timer = metrics.send_latency.start();
let record_count = pending.len() as i32;
let mut sequence: Option<i32> = match config
.identity
.as_ref()
.map(|id| id.allocate_sequence(topic.as_ref(), partition, record_count))
.transpose()
{
Ok(s) => s,
Err(e) => {
for p in pending {
let _ = p.response_tx.send(AppendResponse::Done(Err(e.clone())));
}
return;
}
};
let build_batch = |seq: Option<i32>,
cfg: &AccumulatorConfig|
-> crate::error::Result<ProduceRequest> {
let mut batch_builder = RecordBatchBuilder::new().compression(cfg.compression);
if let (Some(identity), Some(s)) = (&cfg.identity, seq) {
batch_builder =
batch_builder.producer(identity.producer_id(), identity.producer_epoch(), s);
}
for p in &pending {
batch_builder = p.record.append_to_batch_builder(batch_builder);
}
let batch = batch_builder.build();
let batch_bytes = batch.encode()?;
Ok(ProduceRequest {
transactional_id: None,
acks: cfg.acks,
timeout_ms: crate::util::duration_to_millis_i32(cfg.request_timeout),
topic_data: vec![ProduceTopicData {
name: topic.to_string(),
topic_id: None,
partition_data: vec![ProducePartitionData {
index: partition,
records: batch_bytes,
}],
}],
})
};
let mut request = match build_batch(sequence, &config) {
Ok(r) => r,
Err(e) => {
if let Some(ref identity) = config.identity {
let _ =
identity.rollback_sequence_range(topic.as_ref(), partition, record_count);
}
for p in pending {
let _ = p.response_tx.send(AppendResponse::Done(Err(e.clone())));
}
return;
}
};
let mut retry_ctx = RetryContext::new_with_start(
retry_policy.clone(),
format!("batch({topic}-{partition})"),
enqueued_at,
);
let result: std::result::Result<(i64, i64), KrafkaError> = loop {
let conn = match metadata
.get_leader_connection(topic.as_ref(), partition)
.await
{
Ok(c) => c,
Err(e) => {
if e.is_retriable() {
debug!(
topic = %topic,
partition = partition,
error = %e,
"Batch connection error, refreshing metadata"
);
if let Err(refresh_err) =
metadata.refresh_for_topics(Some(&[topic.as_ref()])).await
{
debug!(error = %refresh_err, "Metadata refresh failed during batch retry");
}
}
if let Some(backoff) = retry_ctx.record_failure(&e) {
metrics.record_retry();
retry_ctx.wait(backoff).await;
continue;
}
break Err(e);
}
};
let produce_version = match conn
.negotiate_api_version(
ApiKey::Produce,
versions::PRODUCE_MAX,
versions::PRODUCE_MIN,
)
.await
{
Some(v) => v,
None => {
let e = KrafkaError::protocol("no mutually supported Produce API version");
debug!(
topic = %topic,
partition = partition,
"Produce version negotiation failed, refreshing metadata"
);
if let Err(refresh_err) =
metadata.refresh_for_topics(Some(&[topic.as_ref()])).await
{
debug!(
error = %refresh_err,
"Metadata refresh failed during batch retry"
);
}
if let Some(backoff) = retry_ctx.record_failure(&e) {
metrics.record_retry();
retry_ctx.wait(backoff).await;
continue;
}
break Err(e);
}
};
if let Err(error) = super::validate_produce_request_size(
&config.client_id,
config.max_request_size,
produce_version,
&request,
) {
if let (Some(identity), Some(_)) = (config.identity.as_ref(), sequence) {
let _ =
identity.rollback_sequence_range(topic.as_ref(), partition, record_count);
}
metrics.record_error();
for pending_record in pending {
let _ = pending_record
.response_tx
.send(AppendResponse::Done(Err(error.clone())));
}
return;
}
if config.acks == 0 {
match conn
.send_fire_and_forget(ApiKey::Produce, produce_version, |buf| {
request.encode_versioned(produce_version, buf)
})
.await
{
Ok(()) => {
retry_ctx.record_success();
break Ok((-1, -1));
}
Err(e) => {
if let Some(backoff) = retry_ctx.record_failure(&e) {
metrics.record_retry();
retry_ctx.wait(backoff).await;
continue;
}
break Err(e);
}
}
}
let response_result = conn
.send_request(ApiKey::Produce, produce_version, |buf| {
request.encode_versioned(produce_version, buf)
})
.await;
match response_result {
Ok(mut response_buf) => {
match ProduceResponse::decode_versioned(produce_version, &mut response_buf) {
Ok(produce_response) => {
conn.notify_throttle(produce_response.throttle_time_ms);
let pr = produce_response
.responses
.iter()
.find(|r| r.name == topic.as_ref())
.and_then(|r| {
r.partition_responses.iter().find(|p| p.index == partition)
});
match pr {
Some(pr) if pr.error_code.is_ok() => {
retry_ctx.record_success();
break Ok((pr.base_offset, pr.log_append_time_ms));
}
Some(pr)
if pr.error_code == ErrorCode::DuplicateSequenceNumber
&& config.identity.is_some() =>
{
debug!(
topic = %topic,
partition = partition,
"DuplicateSequenceNumber in batch — dedup confirmed"
);
retry_ctx.record_success();
break Ok((-1, -1));
}
Some(pr) => {
let err = KrafkaError::broker(
pr.error_code,
format!("batch produce failed for {topic}-{partition}"),
);
if pr.error_code == ErrorCode::UnknownProducerId
&& let (Some(identity), Some(current_sequence)) =
(config.identity.as_ref(), sequence)
{
warn!(
topic = %topic,
partition = partition,
"UnknownProducerId in batch, reinitializing idempotent producer state"
);
let new_sequence = match super::recover_unknown_producer_id(
identity,
&metadata,
&retry_policy,
topic.as_ref(),
partition,
current_sequence,
record_count,
)
.await
{
Ok(new_sequence) => new_sequence,
Err(recovery_error) => break Err(recovery_error),
};
sequence = Some(new_sequence);
match build_batch(sequence, &config) {
Ok(new_request) => request = new_request,
Err(encode_err) => break Err(encode_err),
}
} else if pr.error_code == ErrorCode::OutOfOrderSequenceNumber
&& let Some(identity) = config.identity.as_ref()
{
warn!(
topic = %topic,
partition = partition,
"OutOfOrderSequenceNumber in batch, resetting sequence"
);
let new_seq = match identity.reset_and_allocate(
topic.as_ref(),
partition,
record_count,
) {
Ok(s) => s,
Err(e) => break Err(e),
};
sequence = Some(new_seq);
match build_batch(sequence, &config) {
Ok(r) => request = r,
Err(encode_err) => {
break Err(encode_err);
}
}
} else if err.is_retriable()
&& let Err(refresh_err) = metadata
.refresh_for_topics(Some(&[topic.as_ref()]))
.await
{
debug!(error = %refresh_err, "Metadata refresh failed during batch retry");
}
if let Some(backoff) = retry_ctx.record_failure(&err) {
metrics.record_retry();
retry_ctx.wait(backoff).await;
continue;
}
break Err(err);
}
None => {
break Err(KrafkaError::protocol(
"partition not found in response",
));
}
}
}
Err(e) => {
if let Some(backoff) = retry_ctx.record_failure(&e) {
metrics.record_retry();
retry_ctx.wait(backoff).await;
continue;
}
break Err(e);
}
}
}
Err(e) => {
if e.is_retriable() {
debug!(
topic = %topic,
partition = partition,
error = %e,
"Batch send error, refreshing metadata"
);
if let Err(refresh_err) =
metadata.refresh_for_topics(Some(&[topic.as_ref()])).await
{
debug!(error = %refresh_err, "Metadata refresh failed during batch retry");
}
}
if let Some(backoff) = retry_ctx.record_failure(&e) {
metrics.record_retry();
retry_ctx.wait(backoff).await;
continue;
}
break Err(e);
}
}
};
match result {
Ok((base_offset, timestamp)) => {
if let (Some(identity), Some(seq)) = (&config.identity, sequence)
&& let Ok(last_seq) =
super::idempotent::last_sequence_of_batch(seq, record_count)
{
identity.acknowledge(topic.as_ref(), partition, last_seq);
}
let batch_bytes_total: u64 = pending.iter().map(|p| p.estimated_size as u64).sum();
metrics.record_batch(pending.len() as u64);
metrics.bytes_sent.add(batch_bytes_total);
let topic_owned = topic.to_string();
for p in pending {
let meta = RecordMetadata {
topic: topic_owned.clone(),
partition,
offset: if base_offset >= 0 {
base_offset + p.offset_in_batch
} else {
-1
},
timestamp,
};
crate::interceptor::safe_on_acknowledgement(&*config.interceptor, &meta, None);
let _ = p.response_tx.send(AppendResponse::Done(Ok(meta)));
}
}
Err(e) => {
if let Some(identity) = config.identity.as_ref() {
let _ =
identity.rollback_sequence_range(topic.as_ref(), partition, record_count);
}
metrics.record_error();
let topic_owned = topic.to_string();
for p in pending {
let meta = RecordMetadata {
topic: topic_owned.clone(),
partition,
offset: -1,
timestamp: 0,
};
crate::interceptor::safe_on_acknowledgement(
&*config.interceptor,
&meta,
Some(&e),
);
let _ = p.response_tx.send(AppendResponse::Done(Err(e.clone())));
}
}
}
}
/// Extracts every non-empty batch, sends them, and waits until all of
/// those sends have finished.
///
/// Always returns `Ok(())`; per-record outcomes are delivered through
/// each record's response channel.
async fn flush_all(&mut self) -> Result<()> {
    let mut non_empty_keys = Vec::new();
    for (key, entry) in &self.batches {
        if !entry.batch.is_empty() {
            non_empty_keys.push(key.clone());
        }
    }

    let mut ready = Vec::with_capacity(non_empty_keys.len());
    for key in non_empty_keys {
        if let Some(item) = self.extract_batch(&key) {
            ready.push((key, item));
        }
    }

    Self::spawn_batches_bounded(
        ready,
        &self.metadata,
        &self.config,
        &self.retry_policy,
        &self.metrics,
    )
    .await;
    Ok(())
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
use super::*;
/// The default configuration uses 16 KiB batches, zero linger and acks=-1.
#[test]
fn test_accumulator_config_default() {
    let defaults = AccumulatorConfig::default();
    let observed = (defaults.batch_size, defaults.linger, defaults.acks);
    assert_eq!(observed.0, 16384);
    assert_eq!(observed.1, Duration::from_millis(0));
    assert_eq!(observed.2, -1);
}
/// A batch's age is at least the time slept since it was created.
#[test]
fn test_accumulator_batch_age() {
    let pause = Duration::from_millis(10);
    let batch = AccumulatorBatch::new("test".to_string().into(), 0, 16384, Compression::None);
    std::thread::sleep(pause);
    let age = batch.age();
    assert!(age >= pause);
}
/// A freshly created batch holds no records and no pending entries.
#[test]
fn test_accumulator_batch_new() {
    let topic = "test-topic".to_string().into();
    let fresh = AccumulatorBatch::new(topic, 1, 32768, Compression::Gzip);
    assert!(fresh.batch.is_empty());
    assert!(fresh.pending.is_empty());
}
/// Explicitly supplied configuration values are stored unchanged.
#[test]
fn test_accumulator_config_custom() {
    let sixty_four_mib = 64 * 1024 * 1024;
    let custom = AccumulatorConfig {
        batch_size: 65536,
        linger: Duration::from_millis(50),
        compression: Compression::Snappy,
        acks: 1,
        client_id: "test-client".to_string(),
        request_timeout: Duration::from_secs(10),
        max_request_size: 131072,
        buffer_memory: sixty_four_mib,
        max_block_ms: Duration::from_secs(30),
        in_flight_semaphore: Arc::new(Semaphore::new(5)),
        interceptor: Arc::new(crate::interceptor::NoOpProducerInterceptor),
        identity: None,
    };
    assert_eq!(custom.batch_size, 65536);
    assert_eq!(custom.linger, Duration::from_millis(50));
    assert_eq!(custom.acks, 1);
    assert_eq!(custom.client_id, "test-client");
    assert_eq!(custom.max_request_size, 131072);
    assert_eq!(custom.buffer_memory, sixty_four_mib);
}
/// The estimated size covers the value plus overhead (more than 64 bytes
/// for a 5-byte value) and grows when a key is added.
#[test]
fn test_estimate_record_size() {
    let keyless = ProducerRecord::new("test-topic", b"value".to_vec());
    let keyless_size = keyless.estimated_size();
    assert!(keyless_size >= 5);
    assert!(keyless_size > 64);

    let keyed = ProducerRecord::new("test-topic", b"value".to_vec()).with_key(b"key".to_vec());
    let keyed_size = keyed.estimated_size();
    assert!(keyed_size > keyless_size);
}
/// With a zero linger, the check-interval formula (`max(1 ms, linger / 10)`)
/// bottoms out at 1 ms.
#[test]
fn test_linger_zero_check_interval() {
    let floor = Duration::from_millis(1);
    let tenth_of_linger = Duration::from_millis(0) / 10;
    let period = floor.max(tenth_of_linger);
    assert_eq!(period, Duration::from_millis(1));
}
/// A configuration built with a 0 ms linger reports `is_zero()`.
#[test]
fn test_linger_zero_is_zero() {
    let mut config = AccumulatorConfig::default();
    config.linger = Duration::from_millis(0);
    assert!(config.linger.is_zero());
}
/// Compile-time check that a boxed `Send` future of `()` is `Send`.
///
/// NOTE(review): this only inspects a hand-written boxed future type; it
/// does not reference the future returned by `send_extracted_batch`.
#[test]
fn test_send_extracted_batch_is_send() {
    type BoxedSendFuture = std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>;
    fn require_send<T: Send>() {}
    require_send::<BoxedSendFuture>();
}
/// With no memory permits available, `append` gives up after
/// `max_block_ms` and reports a `Timeout` error mentioning `max_block`.
#[tokio::test]
async fn test_backpressure_timeout_returns_timeout_error() {
    // The receiver is kept alive so the channel stays open.
    let (tx, _rx) = mpsc::channel::<AccumulatorMessage>(16);
    let exhausted = Arc::new(Semaphore::new(0));
    let handle = RecordAccumulatorHandle {
        sender: tx,
        memory_permits: exhausted,
        memory_capacity: 1024 * 1024,
        max_block_ms: Duration::from_millis(50),
        in_flight_barrier: Arc::new(InFlightBarrier::new()),
    };
    let outcome = handle
        .append(ProducerRecord::new("topic", b"value".to_vec()), 0)
        .await;
    assert!(outcome.is_err());
    let err = outcome.unwrap_err();
    let err_msg = err.to_string();
    assert!(
        err_msg.contains("max_block"),
        "expected max_block in error, got: {err_msg}"
    );
    assert!(
        matches!(err, KrafkaError::Timeout { .. }),
        "expected Timeout variant, got: {err:?}"
    );
}
/// A pending `acquire_many` completes once another task adds enough
/// permits to the semaphore.
#[tokio::test]
async fn test_backpressure_unblocks_on_permit_release() {
    let sem = Arc::new(Semaphore::new(0));
    let releaser = sem.clone();
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(10)).await;
        releaser.add_permits(128);
    });
    let waited = tokio::time::timeout(Duration::from_secs(2), sem.acquire_many(64)).await;
    assert!(waited.is_ok(), "acquire_many should have completed");
    let acquired = waited.unwrap();
    assert!(acquired.is_ok(), "acquire_many should have succeeded");
}
/// A record larger than the buffer capacity is rejected right away with a
/// `buffer_memory` error instead of waiting for `max_block_ms`.
#[tokio::test]
async fn test_oversize_record_rejected_immediately() {
    let (tx, _rx) = mpsc::channel::<AccumulatorMessage>(16);
    let handle = RecordAccumulatorHandle {
        sender: tx,
        memory_permits: Arc::new(Semaphore::new(16)),
        memory_capacity: 16,
        max_block_ms: Duration::from_secs(60),
        in_flight_barrier: Arc::new(InFlightBarrier::new()),
    };
    let oversized = ProducerRecord::new("topic", vec![0u8; 1024]);
    let started = std::time::Instant::now();
    let outcome = handle.append(oversized, 0).await;
    assert!(started.elapsed() < Duration::from_secs(1));
    let err = outcome.expect_err("oversize record must be rejected");
    assert!(
        err.to_string().contains("buffer_memory"),
        "expected buffer_memory error, got: {err}"
    );
}
/// Closing the memory semaphore wakes a blocked `append`, which fails
/// with `InvalidState` well before the 60 s blocking budget expires.
#[tokio::test]
async fn test_closed_semaphore_unblocks_waiters() {
    let (tx, _rx) = mpsc::channel::<AccumulatorMessage>(16);
    let sem = Arc::new(Semaphore::new(0));
    let handle = RecordAccumulatorHandle {
        sender: tx,
        memory_permits: sem.clone(),
        memory_capacity: 1024 * 1024,
        max_block_ms: Duration::from_secs(60),
        in_flight_barrier: Arc::new(InFlightBarrier::new()),
    };
    // Close the semaphore shortly after the append starts waiting.
    let closer = sem.clone();
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(10)).await;
        closer.close();
    });
    let started = std::time::Instant::now();
    let outcome = handle
        .append(ProducerRecord::new("topic", b"value".to_vec()), 0)
        .await;
    assert!(
        started.elapsed() < Duration::from_secs(1),
        "must unblock on close, not on max_block timeout"
    );
    let err = outcome.expect_err("closed semaphore must surface as error");
    assert!(
        matches!(err, KrafkaError::InvalidState { .. }),
        "expected InvalidState variant, got: {err:?}"
    );
}
/// Dropping an `Append` message that was never handled returns its
/// reserved permits to the semaphore.
#[tokio::test]
async fn test_permits_released_when_append_message_dropped() {
    let (tx, mut rx) = mpsc::channel::<AccumulatorMessage>(16);
    let sem = Arc::new(Semaphore::new(1024));
    let handle = RecordAccumulatorHandle {
        sender: tx,
        memory_permits: sem.clone(),
        memory_capacity: 1024,
        max_block_ms: Duration::from_millis(500),
        in_flight_barrier: Arc::new(InFlightBarrier::new()),
    };
    let record = ProducerRecord::new("topic", vec![0u8; 256]);
    let appender = tokio::spawn(async move { handle.append(record, 0).await });

    let received = tokio::time::timeout(Duration::from_secs(2), rx.recv())
        .await
        .expect("timed out waiting for Append message to arrive in channel")
        .expect("channel closed before message arrived");
    // Discard the message and the receiver without processing the append.
    drop(received);
    drop(rx);
    let _ = appender.await;

    assert_eq!(
        sem.available_permits(),
        1024,
        "permits leaked when the Append message was dropped"
    );
}
/// A record above the buffer capacity is rejected with a message that
/// names `buffer_memory` and does not mention `u32::MAX`.
#[test]
fn test_check_record_admission_rejects_oversized_for_buffer() {
    let rejection = check_record_admission(1024, 16).expect_err("must reject");
    let msg = rejection.to_string();
    assert!(
        msg.contains("buffer_memory"),
        "error must cite buffer_memory, got: {msg}"
    );
    let mentions_u32_limit = msg.contains("u32::MAX");
    assert!(
        !mentions_u32_limit,
        "must not cite u32::MAX for a buffer_memory rejection, got: {msg}"
    );
}
/// A record above `u32::MAX` bytes is rejected with a message that names
/// `u32::MAX` and does not mention `buffer_memory`.
#[test]
fn test_check_record_admission_rejects_oversized_for_u32_max() {
    let too_big = u32::MAX as usize + 1;
    let rejection = check_record_admission(too_big, usize::MAX).expect_err("must reject");
    let msg = rejection.to_string();
    assert!(
        msg.contains("u32::MAX"),
        "error must cite u32::MAX, got: {msg}"
    );
    let mentions_buffer = msg.contains("buffer_memory");
    assert!(
        !mentions_buffer,
        "must not cite buffer_memory for a u32::MAX rejection, got: {msg}"
    );
}
}