mod accumulator;
mod barrier;
mod batch;
mod config;
mod idempotent;
mod partitioner;
mod record;
mod retry;
mod transaction;
pub use accumulator::{AccumulatorConfig, RecordAccumulator, RecordAccumulatorHandle};
pub use batch::ProducerBatch;
pub use config::{Acks, ProducerConfig, ProducerConfigBuilder};
pub use idempotent::{PartitionSequenceSnapshot, ProducerIdentity, ProducerIdentitySnapshot};
pub use partitioner::{
DefaultPartitioner, HashPartitioner, Partitioner, RoundRobinPartitioner, StickyPartitioner,
UniformStickyPartitioner, murmur2,
};
pub use record::{ProducerRecord, RecordMetadata};
pub use retry::{RetryContext, RetryPolicy};
pub use transaction::{
TopicPartitionOffset, TransactionState, TransactionalProducer, TransactionalProducerBuilder,
TransactionalProducerConfig,
};
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::time::{Duration, Instant};
use bytes::Bytes;
use tokio::sync::Semaphore;
use tracing::{debug, info, warn};
use crate::PartitionId;
use crate::auth::AuthConfig;
use crate::error::{ErrorCode, KrafkaError, ProtocolErrorKind, Result};
use crate::metadata::ClusterMetadata;
use crate::metrics::{ConnectionMetrics, ProducerMetrics as ProducerMetricsInner};
use crate::network::{ConnectionConfig, ConnectionPool};
use crate::protocol::{
ApiKey, Compression, InitProducerIdRequest, InitProducerIdResponse, ProducePartitionData,
ProduceRequest, ProduceResponse, ProduceTopicData, RecordBatchBuilder, VersionedDecode,
VersionedEncode, versions,
};
use crate::schema_registry::SchemaEncoder;
use self::barrier::{InFlightBarrier, InFlightOpGuard};
use self::record::{RoutedRecord, TopicHandle};
/// RAII reservation of `bytes` permits against the producer's buffer-memory gate.
///
/// Created by `Producer::reserve_send_memory`, which acquires the permits and
/// `forget`s them; this type's `Drop` impl hands them back via `add_permits`.
struct SendMemoryReservation {
/// Number of permits (one per reserved byte) returned to `memory_permits` on drop.
bytes: usize,
/// The shared memory gate the permits were taken from.
memory_permits: Arc<Semaphore>,
/// Guard created together with the reservation; presumably maintains the
/// buffered-record count (see `accumulator::BufferedRecordGuard`) — NOTE(review): confirm.
/// As a field it is dropped after `Drop::drop` runs, i.e. after the permits are released.
_buffered_record_guard: accumulator::BufferedRecordGuard,
}
impl Drop for SendMemoryReservation {
    /// Gives the reserved bytes back to the shared buffer-memory semaphore.
    fn drop(&mut self) {
        let Self {
            bytes,
            memory_permits,
            ..
        } = self;
        memory_permits.add_permits(*bytes);
    }
}
/// Kafka producer that sends records either directly (zero `linger`) or through a
/// background record accumulator (non-zero `linger`).
///
/// Construct it with [`Producer::builder`]. Call `close()` / `close_with_timeout()`
/// before dropping it; the `Drop` impl only logs a warning if that was not done.
pub struct Producer {
/// Configuration the producer was built with.
config: ProducerConfig,
/// Cluster metadata used for partition counts, leader lookup and topic-ID resolution.
metadata: Arc<ClusterMetadata>,
/// Broker connection pool; closed by `close()`.
pool: Arc<ConnectionPool>,
/// Chooses a partition for records that do not specify one.
partitioner: Arc<dyn Partitioner>,
/// Batching accumulator; `Some` only when `config.linger` is non-zero.
accumulator: Option<RecordAccumulatorHandle>,
/// Tracks in-flight operations so `flush()` and `close()` can wait for them.
in_flight_barrier: Arc<InFlightBarrier>,
/// Retry/backoff policy for direct sends and idempotent re-initialization.
retry_policy: RetryPolicy,
/// Shared producer metrics counters.
metrics: Arc<ProducerMetricsInner>,
/// Buffer-memory gate: one permit per reserved byte.
memory_permits: Arc<Semaphore>,
/// Number of permits `memory_permits` was created with.
memory_capacity: usize,
/// Copy of `config.max_request_size`, used for record admission checks.
max_request_size: usize,
/// Counter shared with every `BufferedRecordGuard` created for direct sends.
buffered_records: Arc<AtomicUsize>,
/// Limits concurrent direct sends (`config.max_in_flight`); also handed to the accumulator.
in_flight_semaphore: Arc<Semaphore>,
/// User interceptor notified on send and on acknowledgement or failure.
interceptor: Arc<dyn crate::interceptor::ProducerInterceptor>,
/// Idempotent producer identity (PID, epoch, sequences); `Some` when `config.idempotent`.
identity: Option<Arc<ProducerIdentity>>,
/// Optional schema encoder applied to record keys in `send_record`.
key_encoder: Option<Arc<dyn SchemaEncoder>>,
/// Optional schema encoder applied to record values in `send_record`.
value_encoder: Option<Arc<dyn SchemaEncoder>>,
}
/// Reports whether `error` is a broker rejection with code `UnknownProducerId`.
///
/// Such an error means the broker no longer recognises this producer's ID; the
/// direct-send path reacts by re-initializing the idempotent identity.
fn is_unknown_producer_id_error(error: &KrafkaError) -> bool {
    match error {
        KrafkaError::Broker {
            code: ErrorCode::UnknownProducerId,
            ..
        } => true,
        _ => false,
    }
}
/// Obtains a producer ID and epoch via `InitProducerId` and stores them in `identity`.
///
/// Makes at most `retry_policy.max_retries + 1` attempts and picks the broker from
/// `metadata.brokers()` by attempt number. Before every attempt after the first it
/// sleeps for `retry_policy.calculate_backoff(attempt)`. When
/// `retry_policy.delivery_timeout` is set, that sleep is clamped so it never runs
/// past the deadline.
///
/// # Errors
/// * a timeout error once `delivery_timeout` has elapsed;
/// * a `Malformed` protocol error if no brokers are known on the final attempt;
/// * `UnknownApiVersion` if the chosen broker shares no `InitProducerId` version
///   (returned immediately, not retried);
/// * connection or request errors that are non-retriable or occur on the final attempt;
/// * response decode errors (propagated immediately via `?`);
/// * a broker error carrying the response's error code when that code is not
///   retriable or the attempts are exhausted.
async fn init_idempotent_producer_id(
identity: &ProducerIdentity,
metadata: &ClusterMetadata,
retry_policy: &RetryPolicy,
) -> Result<()> {
let started_at = Instant::now();
for attempt in 0..=retry_policy.max_retries {
// NOTE(review): elapsed time only grows, so this check is subsumed by the
// deadline checks below (inside the backoff block and right before the request).
// It never changes the outcome.
if let Some(deadline) = retry_policy.delivery_timeout
&& started_at.elapsed() >= deadline
{
return Err(KrafkaError::timeout("InitProducerId"));
}
if attempt > 0 {
let mut backoff = retry_policy.calculate_backoff(attempt);
if let Some(deadline) = retry_policy.delivery_timeout {
let elapsed = started_at.elapsed();
if elapsed >= deadline {
return Err(KrafkaError::timeout("InitProducerId"));
}
// Never sleep beyond the remaining delivery budget.
backoff = backoff.min(deadline.saturating_sub(elapsed));
}
if !backoff.is_zero() {
tokio::time::sleep(backoff).await;
}
}
// Re-check after the (possibly clamped) sleep.
if let Some(deadline) = retry_policy.delivery_timeout
&& started_at.elapsed() >= deadline
{
return Err(KrafkaError::timeout("InitProducerId"));
}
let brokers = metadata.brokers();
if brokers.is_empty() {
if attempt < retry_policy.max_retries {
warn!(attempt, "No brokers available for InitProducerId, retrying");
continue;
}
return Err(KrafkaError::protocol_kind(
ProtocolErrorKind::Malformed,
"no brokers available for InitProducerId",
));
}
// Rotate through known brokers so a single bad broker does not absorb every attempt.
let broker = &brokers[attempt as usize % brokers.len()];
let conn = match metadata.get_broker_connection(broker.id()).await {
Ok(connection) => connection,
Err(error) if error.is_retriable() && attempt < retry_policy.max_retries => {
warn!(
attempt,
error = %error,
"Connection failed for InitProducerId, retrying"
);
continue;
}
Err(error) => return Err(error),
};
let ip_version = match conn
.negotiate_api_version(
ApiKey::InitProducerId,
versions::INIT_PRODUCER_ID_MAX,
versions::INIT_PRODUCER_ID_MIN,
)
.await
{
Some(version) => version,
None => {
return Err(KrafkaError::protocol_kind(
ProtocolErrorKind::UnknownApiVersion,
"no mutually supported InitProducerId API version",
));
}
};
let request = InitProducerIdRequest::idempotent();
let response_bytes = match conn
.send_request(ApiKey::InitProducerId, ip_version, |buf| {
request.encode_versioned(ip_version, buf)
})
.await
{
Ok(bytes) => bytes,
Err(error) if error.is_retriable() && attempt < retry_policy.max_retries => {
warn!(
attempt,
error = %error,
"InitProducerId request failed, retrying"
);
continue;
}
Err(error) => return Err(error),
};
let mut buf = response_bytes;
let response = InitProducerIdResponse::decode_versioned(ip_version, &mut buf)?;
if response.is_ok() {
identity.initialize(response.producer_id, response.producer_epoch);
info!(
"Idempotent producer initialized: PID={}, epoch={}",
response.producer_id, response.producer_epoch
);
return Ok(());
}
if response.error_code.is_retriable() && attempt < retry_policy.max_retries {
warn!(
error_code = ?response.error_code,
attempt,
"InitProducerId returned retriable error, retrying"
);
} else {
return Err(KrafkaError::broker(
response.error_code,
"failed to initialize producer ID",
));
}
}
// Defensive fallback. On the final attempt every branch above returns, so this
// should be unreachable.
Err(KrafkaError::protocol_kind(
ProtocolErrorKind::Malformed,
format!(
"InitProducerId retry loop exhausted after {} retries",
retry_policy.max_retries
),
))
}
/// Makes sure `identity` holds a producer ID, running `InitProducerId` only if needed.
///
/// # Errors
/// * `invalid_state` if the identity has been poisoned by an unrecoverable
///   `UnknownProducerId`; the producer must then be recreated.
/// * Any error from `init_idempotent_producer_id`.
async fn ensure_idempotent_producer_id_initialized(
    identity: &ProducerIdentity,
    metadata: &ClusterMetadata,
    retry_policy: &RetryPolicy,
) -> Result<()> {
    if identity.is_poisoned() {
        return Err(KrafkaError::invalid_state(
            "producer identity is poisoned after an unrecoverable UnknownProducerId; recreate the producer",
        ));
    }
    if !identity.is_initialized() {
        init_idempotent_producer_id(identity, metadata, retry_policy).await?;
    }
    Ok(())
}
/// Recovers from an `UnknownProducerId` response for `topic`-`partition`.
///
/// If the identity reports that the failed batch (`base_sequence`, `record_count`)
/// can be retried safely, a fresh producer ID is obtained and a new base sequence
/// for `record_count` records is returned. Otherwise the identity is poisoned, so
/// every later send fails until the producer is recreated.
///
/// # Errors
/// * `invalid_state` if the identity is already poisoned or gets poisoned here;
/// * errors from `check_and_reset_if_retryable`, `init_idempotent_producer_id`
///   or `allocate_sequence`.
async fn recover_unknown_producer_id(
    identity: &ProducerIdentity,
    metadata: &ClusterMetadata,
    retry_policy: &RetryPolicy,
    topic: &str,
    partition: PartitionId,
    base_sequence: i32,
    record_count: i32,
) -> Result<i32> {
    if identity.is_poisoned() {
        return Err(KrafkaError::invalid_state(
            "producer identity is poisoned after an unrecoverable UnknownProducerId; recreate the producer",
        ));
    }
    let safe_to_retry =
        identity.check_and_reset_if_retryable(topic, partition, base_sequence, record_count)?;
    if safe_to_retry {
        init_idempotent_producer_id(identity, metadata, retry_policy).await?;
        return identity.allocate_sequence(topic, partition, record_count);
    }
    identity.poison();
    Err(KrafkaError::invalid_state(format!(
        "UnknownProducerId for {topic}-{partition} cannot be retried safely while newer batches are still in flight; producer identity poisoned, recreate the producer after in-flight work drains"
    )))
}
/// Number of bytes an unsigned 7-bits-per-byte varint needs to encode `value`.
///
/// The result is between 1 and 5 for a `u32`. Zero still occupies one byte.
fn unsigned_varint_size(value: u32) -> usize {
    let significant_bits = u32::BITS - value.leading_zeros();
    significant_bits.max(1).div_ceil(7) as usize
}
/// Encoded size of a classic (non-compact) nullable `KafkaString`.
///
/// The encoding is a 2-byte length prefix followed by the string bytes. `None`
/// costs only the prefix.
///
/// # Errors
/// `InvalidLength` if the string is longer than `i16::MAX` bytes.
fn kafka_string_size(value: Option<&str>) -> Result<usize> {
    const LENGTH_PREFIX: usize = 2;
    let Some(text) = value else {
        return Ok(LENGTH_PREFIX);
    };
    if i16::try_from(text.len()).is_err() {
        return Err(KrafkaError::protocol_kind(
            ProtocolErrorKind::InvalidLength,
            format!(
                "KafkaString length {} exceeds protocol limit of {}",
                text.len(),
                i16::MAX
            ),
        ));
    }
    Ok(LENGTH_PREFIX + text.len())
}
/// Encoded size of a compact (flexible-version) nullable string.
///
/// The length is written as an unsigned varint of `len + 1`, followed by the bytes.
/// `None` is sized as the varint `0` alone.
///
/// # Errors
/// `InvalidLength` if `len + 1` does not fit in a `u32`.
fn compact_kafka_string_size(value: Option<&str>) -> Result<usize> {
    let Some(text) = value else {
        return Ok(unsigned_varint_size(0));
    };
    let byte_len = text.len();
    match u32::try_from(byte_len.saturating_add(1)) {
        Ok(encoded_len) => Ok(unsigned_varint_size(encoded_len) + byte_len),
        Err(_) => Err(KrafkaError::protocol_kind(
            ProtocolErrorKind::InvalidLength,
            format!(
                "compact KafkaString length {} exceeds u32 limit",
                byte_len
            ),
        )),
    }
}
/// Encoded size of a classic `KafkaBytes` field: a 4-byte length followed by `len` bytes.
///
/// # Errors
/// `InvalidLength` if `len` exceeds `i32::MAX`.
fn kafka_bytes_size(len: usize) -> Result<usize> {
    match i32::try_from(len) {
        Ok(_) => Ok(4 + len),
        Err(_) => Err(KrafkaError::protocol_kind(
            ProtocolErrorKind::InvalidLength,
            format!(
                "KafkaBytes length {} exceeds protocol limit of {}",
                len,
                i32::MAX
            ),
        )),
    }
}
/// Encoded size of a compact bytes field: a varint of `len + 1` followed by `len` bytes.
///
/// # Errors
/// `InvalidLength` if `len + 1` does not fit in a `u32`.
fn compact_kafka_bytes_size(len: usize) -> Result<usize> {
    let Ok(encoded_len) = u32::try_from(len.saturating_add(1)) else {
        return Err(KrafkaError::protocol_kind(
            ProtocolErrorKind::InvalidLength,
            format!("compact KafkaBytes length {} exceeds u32 limit", len),
        ));
    };
    Ok(unsigned_varint_size(encoded_len) + len)
}
/// Size of a classic array length prefix (always 4 bytes), after checking `len` fits in an `i32`.
///
/// # Errors
/// `InvalidLength` if `len` exceeds `i32::MAX`.
fn array_len_size(len: usize) -> Result<usize> {
    const ARRAY_LENGTH_PREFIX: usize = 4;
    if i32::try_from(len).is_ok() {
        Ok(ARRAY_LENGTH_PREFIX)
    } else {
        Err(KrafkaError::protocol_kind(
            ProtocolErrorKind::InvalidLength,
            format!("Kafka array length {len} exceeds protocol limit"),
        ))
    }
}
fn compact_array_len_size(len: usize) -> Result<usize> {
let len_plus_one = u32::try_from(len.saturating_add(1)).map_err(|_| {
KrafkaError::protocol_kind(
ProtocolErrorKind::InvalidLength,
format!("compact Kafka array length {len} exceeds u32 limit"),
)
})?;
Ok(unsigned_varint_size(len_plus_one))
}
/// Encoded size of the request header for `api_key` at `api_version`.
///
/// Header v1 is `api_key` (2) + `api_version` (2) + `correlation_id` (4) +
/// `client_id` as a classic string. Header v2 adds one byte for the empty
/// tagged-field section.
///
/// # Errors
/// * `InvalidLength` if `client_id` is too long;
/// * `UnknownApiVersion` for any header version other than 1 or 2.
fn request_header_size(api_key: ApiKey, api_version: i16, client_id: &str) -> Result<usize> {
    let fixed_fields = 2 + 2 + 4;
    let common = fixed_fields + kafka_string_size(Some(client_id))?;
    let tagged_fields = match crate::protocol::RequestHeader::header_version(api_key, api_version)
    {
        1 => 0,
        2 => 1,
        version => {
            return Err(KrafkaError::protocol_kind(
                ProtocolErrorKind::UnknownApiVersion,
                format!("unsupported request header version {version}"),
            ));
        }
    };
    Ok(common + tagged_fields)
}
/// How a flexible (v9+) Produce request identifies each topic on the wire.
#[derive(Clone, Copy)]
enum ProduceTopicKey {
    /// v9–v12: topics are keyed by their compact-string name.
    Name,
    /// v13+: topics are keyed by a 16-byte topic ID (KIP-516). A topic with no ID is
    /// rejected with `missing_id_message`.
    Id { missing_id_message: &'static str },
}

/// Sizes the topic array of a non-flexible (v3–v8) Produce request.
fn legacy_produce_topics_size(request: &ProduceRequest) -> Result<usize> {
    let mut size = array_len_size(request.topic_data.len())?;
    for topic in &request.topic_data {
        size += kafka_string_size(Some(&topic.name))?;
        size += array_len_size(topic.partition_data.len())?;
        for partition in &topic.partition_data {
            // partition index (int32) + records bytes
            size += 4 + kafka_bytes_size(partition.records.len())?;
        }
    }
    Ok(size)
}

/// Sizes the topic array of a flexible (v9+) Produce request, including the
/// trailing request-level tagged-field byte.
fn compact_produce_topics_size(request: &ProduceRequest, key: ProduceTopicKey) -> Result<usize> {
    let mut size = compact_array_len_size(request.topic_data.len())?;
    for topic in &request.topic_data {
        size += match key {
            ProduceTopicKey::Name => compact_kafka_string_size(Some(&topic.name))?,
            ProduceTopicKey::Id { missing_id_message } => {
                if topic.topic_id.is_none() {
                    return Err(KrafkaError::protocol_kind(
                        ProtocolErrorKind::InvalidValue,
                        missing_id_message,
                    ));
                }
                16
            }
        };
        size += compact_array_len_size(topic.partition_data.len())?;
        for partition in &topic.partition_data {
            // partition index (int32) + compact records + empty partition tag section
            size += 4 + compact_kafka_bytes_size(partition.records.len())? + 1;
        }
        // empty topic-level tag section
        size += 1;
    }
    // empty request-level tag section
    size += 1;
    Ok(size)
}

/// Computes the encoded size of a Produce request body for `api_version`.
///
/// The body excludes the size prefix and the request header. Versions 3–13 are
/// sized explicitly. Versions 14+ are sized with the v13 layout as a best-effort
/// fallback and log a one-time warning.
///
/// # Errors
/// * `UnknownApiVersion` for versions below 3;
/// * `InvalidValue` for v13+ when a topic has no `topic_id`;
/// * `InvalidLength` when a string, bytes, or array length exceeds its wire limit.
fn produce_request_body_size(api_version: i16, request: &ProduceRequest) -> Result<usize> {
    let transactional_id = request.transactional_id.as_deref();
    let transactional_id_size = match api_version {
        3..=8 => kafka_string_size(transactional_id)?,
        9.. => compact_kafka_string_size(transactional_id)?,
        _ => {
            return Err(KrafkaError::protocol_kind(
                ProtocolErrorKind::UnknownApiVersion,
                format!("unsupported ProduceRequest version {api_version}"),
            ));
        }
    };
    // transactional_id, then acks (int16) and timeout_ms (int32)
    let preamble_size = transactional_id_size + 2 + 4;

    let topics_size = match api_version {
        3..=8 => legacy_produce_topics_size(request)?,
        9..=12 => compact_produce_topics_size(request, ProduceTopicKey::Name)?,
        13 => compact_produce_topics_size(
            request,
            ProduceTopicKey::Id {
                missing_id_message: "topic_id is required for Produce v13+ (KIP-516)",
            },
        )?,
        v @ 14.. => {
            use std::sync::atomic::{AtomicBool, Ordering};
            static WARNED: AtomicBool = AtomicBool::new(false);
            if !WARNED.swap(true, Ordering::AcqRel) {
                tracing::warn!(
                    api_version = v,
                    "ProduceRequest version {v} exceeds max known version 13; \
                     using v13 compact encoding as a best-effort fallback. \
                     Update produce_request_body_size() to add explicit v{v}+ support."
                );
            }
            compact_produce_topics_size(
                request,
                ProduceTopicKey::Id {
                    missing_id_message: "topic_id is required for Produce v14+ (KIP-516)",
                },
            )?
        }
        _ => unreachable!("validated above"),
    };
    Ok(preamble_size + topics_size)
}
/// Total on-the-wire size of a Produce request frame.
///
/// The frame is a 4-byte size prefix, then the request header, then the body.
fn produce_request_frame_size(
    client_id: &str,
    api_version: i16,
    request: &ProduceRequest,
) -> Result<usize> {
    const SIZE_PREFIX: usize = 4;
    let header = request_header_size(ApiKey::Produce, api_version, client_id)?;
    let body = produce_request_body_size(api_version, request)?;
    Ok(SIZE_PREFIX + header + body)
}
/// Fills in missing `topic_id`s in `request` from cached cluster metadata.
///
/// Every topic without an ID is attempted, even after an earlier lookup fails.
/// Returns `true` only if every topic ended up with an ID.
pub(crate) fn fill_produce_topic_ids(
    request: &mut ProduceRequest,
    metadata: &ClusterMetadata,
) -> bool {
    let mut unresolved = 0usize;
    for topic_data in request
        .topic_data
        .iter_mut()
        .filter(|topic_data| topic_data.topic_id.is_none())
    {
        match metadata.topic_id_for_name(&topic_data.name) {
            Some(id) => topic_data.topic_id = Some(id),
            None => unresolved += 1,
        }
    }
    unresolved == 0
}
/// Rejects a Produce request whose encoded frame would exceed `max_request_size`.
///
/// # Errors
/// * errors from computing the frame size;
/// * `InvalidLength` when the frame is larger than `max_request_size`.
fn validate_produce_request_size(
    client_id: &str,
    max_request_size: usize,
    api_version: i16,
    request: &ProduceRequest,
) -> Result<()> {
    let frame_size = produce_request_frame_size(client_id, api_version, request)?;
    if frame_size <= max_request_size {
        return Ok(());
    }
    Err(KrafkaError::protocol_kind(
        ProtocolErrorKind::InvalidLength,
        format!("produce request size {frame_size} exceeds max_request_size {max_request_size}"),
    ))
}
impl Producer {
pub fn builder() -> ProducerBuilder {
ProducerBuilder::default()
}
async fn reserve_send_memory(&self, record_size: usize) -> Result<SendMemoryReservation> {
accumulator::check_record_admission(
record_size,
self.memory_capacity,
self.max_request_size,
)?;
let permit = match tokio::time::timeout(
self.config.max_block,
self.memory_permits.acquire_many(record_size as u32),
)
.await
{
Ok(Ok(permit)) => permit,
Ok(Err(_)) => return Err(KrafkaError::invalid_state("producer memory gate closed")),
Err(_) => {
return Err(KrafkaError::timeout(
"producer send: max_block exceeded while waiting for buffer memory \
(ProducerConfig::max_block)",
));
}
};
permit.forget();
Ok(SendMemoryReservation {
bytes: record_size,
memory_permits: self.memory_permits.clone(),
_buffered_record_guard: accumulator::BufferedRecordGuard::new(
self.buffered_records.clone(),
self.metrics.clone(),
),
})
}
async fn new(
config: ProducerConfig,
interceptor: Arc<dyn crate::interceptor::ProducerInterceptor>,
partitioner: Option<Arc<dyn Partitioner>>,
key_encoder: Option<Arc<dyn SchemaEncoder>>,
value_encoder: Option<Arc<dyn SchemaEncoder>>,
shared: Option<(Arc<ConnectionPool>, Arc<crate::metadata::ClusterMetadata>)>,
) -> Result<Self> {
let (pool, metadata) = if let Some((pool, metadata)) = shared {
(pool, metadata)
} else {
let mut pool_config_builder = ConnectionConfig::builder()
.client_id(&config.client_id)
.request_timeout(config.request_timeout);
if let Some(ref auth) = config.auth {
pool_config_builder = pool_config_builder.auth(auth.clone());
}
#[cfg(feature = "socks5")]
if let Some(ref proxy) = config.proxy {
pool_config_builder = pool_config_builder.proxy(proxy.clone());
}
let mut pool_config = pool_config_builder.build()?;
pool_config.init_tls().await?;
let pool = Arc::new(ConnectionPool::new(pool_config));
pool.start_idle_evictor();
let bootstrap_servers =
crate::util::parse_bootstrap_servers(&config.bootstrap_servers)?;
let metadata = Arc::new({
let mut meta =
ClusterMetadata::new(bootstrap_servers, pool.clone(), config.metadata_max_age)
.with_recovery_strategy(config.metadata_recovery_strategy)
.with_rebootstrap_trigger(config.metadata_recovery_rebootstrap_trigger);
if let Some(ttl) = config.metadata_topic_cache_ttl {
meta = meta.with_topic_cache_ttl(ttl);
} else {
meta = meta.with_topic_cache_ttl_disabled();
}
meta
});
metadata.refresh().await?;
info!(
"Producer initialized with {} brokers",
metadata.brokers().len()
);
(pool, metadata)
};
let init_retry_policy = RetryPolicy::new()
.with_max_retries(config.retries)
.with_initial_backoff(config.retry_backoff)
.with_max_backoff(Duration::from_secs(10))
.with_delivery_timeout(Some(config.delivery_timeout));
let identity = if config.idempotent {
let identity = Arc::new(ProducerIdentity::new());
init_idempotent_producer_id(&identity, &metadata, &init_retry_policy).await?;
Some(identity)
} else {
None
};
let partitioner: Arc<dyn Partitioner> = partitioner.unwrap_or_else(|| {
if config.linger > Duration::ZERO {
Arc::new(UniformStickyPartitioner::new())
} else {
Arc::new(DefaultPartitioner::new())
}
});
let retry_policy = RetryPolicy::new()
.with_max_retries(config.retries)
.with_initial_backoff(config.retry_backoff)
.with_max_backoff(Duration::from_secs(30))
.with_delivery_timeout(Some(config.delivery_timeout));
let metrics = Arc::new(ProducerMetricsInner::default());
let memory_capacity = accumulator::effective_memory_capacity(config.buffer_memory);
let memory_permits = Arc::new(Semaphore::new(memory_capacity));
let buffered_records = Arc::new(AtomicUsize::new(0));
if config.buffer_memory == 0 {
warn!(
"buffer_memory=0 disables producer backpressure; \
memory usage is unbounded. Not recommended for production."
);
}
let in_flight_semaphore = Arc::new(Semaphore::new(config.max_in_flight));
let in_flight_barrier = Arc::new(InFlightBarrier::new());
let accumulator = if !config.linger.is_zero() {
let acc_config = accumulator::AccumulatorConfig {
batch_size: config.batch_size,
linger: config.linger,
compression: config.compression,
acks: config.acks.to_i16(),
client_id: config.client_id.clone(),
request_timeout: config.request_timeout,
max_request_size: config.max_request_size,
buffer_memory: config.buffer_memory,
max_block_ms: config.max_block,
in_flight_semaphore: in_flight_semaphore.clone(),
interceptor: interceptor.clone(),
identity: identity.clone(),
partitioner: partitioner.clone(),
};
Some(accumulator::RecordAccumulator::spawn(
acc_config,
metadata.clone(),
retry_policy.clone(),
metrics.clone(),
in_flight_barrier.clone(),
))
} else {
None
};
Ok(Self {
config: config.clone(),
metadata,
pool,
partitioner,
accumulator,
in_flight_barrier,
retry_policy,
metrics,
memory_permits,
memory_capacity,
max_request_size: config.max_request_size,
buffered_records,
in_flight_semaphore,
interceptor,
identity,
key_encoder,
value_encoder,
})
}
pub async fn send(
&self,
topic: &str,
key: Option<&[u8]>,
value: &[u8],
) -> Result<RecordMetadata> {
let mut record = ProducerRecord::new(topic, Bytes::copy_from_slice(value));
if let Some(k) = key {
record = record.with_key(Bytes::copy_from_slice(k));
}
self.send_record(record).await
}
pub async fn send_with_headers(
&self,
topic: &str,
key: Option<&[u8]>,
value: &[u8],
headers: Vec<(String, Bytes)>,
) -> Result<RecordMetadata> {
let mut record = ProducerRecord::new(topic, Bytes::copy_from_slice(value));
if let Some(k) = key {
record = record.with_key(Bytes::copy_from_slice(k));
}
record.headers = headers;
self.send_record(record).await
}
pub async fn send_record(&self, record: ProducerRecord) -> Result<RecordMetadata> {
let operation_guard = self.in_flight_barrier.start("producer")?;
let mut record = record;
crate::interceptor::safe_on_send(&*self.interceptor, &mut record);
if let Some(enc) = &self.value_encoder {
record.value = enc
.encode(
record.value.clone(),
&record.topic,
record.record_name.as_deref(),
false,
)
.await?;
}
if let Some(enc) = &self.key_encoder {
let key = record.key.clone().unwrap_or_default();
record.key = Some(
enc.encode(key, &record.topic, record.record_name.as_deref(), true)
.await?,
);
}
record.validate()?;
let record_size = record.estimated_size();
let routed = record.into_routed_parts();
let topic = routed.topic;
let record = routed.record;
let partition = match routed.partition {
Some(p) => p,
None => {
let partition_count = self
.metadata
.partition_count(topic.as_ref())
.ok_or_else(|| KrafkaError::invalid_state(format!("unknown topic: {topic}")))?;
self.partitioner
.partition(topic.as_ref(), record.key_bytes(), partition_count)
}
};
if let Some(ref accumulator) = self.accumulator {
return accumulator
.append_routed_with_guard(topic, record, record_size, partition, operation_guard)
.await;
}
self.send_to_partition(topic, partition, record, record_size, operation_guard)
.await
}
async fn send_to_partition(
&self,
topic: TopicHandle,
partition: PartitionId,
record: RoutedRecord,
record_size: usize,
operation_guard: InFlightOpGuard,
) -> Result<RecordMetadata> {
let _operation_guard = operation_guard;
let _memory_reservation = self.reserve_send_memory(record_size).await?;
let topic_owned = topic.to_string();
if let Some(identity) = self.identity.as_ref() {
ensure_idempotent_producer_id_initialized(identity, &self.metadata, &self.retry_policy)
.await?;
}
let _permit = self
.in_flight_semaphore
.acquire()
.await
.map_err(|_| KrafkaError::invalid_state("in-flight semaphore closed"))?;
let mut sequence: Option<i32> = if let Some(ref identity) = self.identity {
match identity.checked_allocate_sequence(topic.as_ref(), partition, 1)? {
Some(seq) => Some(seq),
None => {
ensure_idempotent_producer_id_initialized(
identity,
&self.metadata,
&self.retry_policy,
)
.await?;
Some(identity.checked_allocate_sequence(topic.as_ref(), partition, 1)?.ok_or_else(|| {
KrafkaError::invalid_state(
"producer identity reset during sequence allocation; retry the send",
)
})?)
}
}
} else {
None
};
let mut request =
match self.build_produce_request(topic.as_ref(), partition, &record, sequence) {
Ok(r) => r,
Err(e) => {
if let Some(ref identity) = self.identity {
let _ = identity.rollback_sequence(topic.as_ref(), partition);
}
return Err(e);
}
};
let mut retry_ctx = RetryContext::new(
self.retry_policy.clone(),
format!("produce({topic}-{partition})"),
);
loop {
let result = self.do_send(topic.as_ref(), partition, &request).await;
let result = if let Err(KrafkaError::Broker { code, .. }) = &result
&& *code == ErrorCode::DuplicateSequenceNumber
&& self.identity.is_some()
{
debug!(
topic = %topic,
partition = partition,
"DuplicateSequenceNumber — dedup confirmed"
);
Ok(RecordMetadata {
topic: topic_owned.clone(),
partition,
offset: -1,
timestamp: -1,
})
} else {
result
};
match result {
Ok(metadata) => {
retry_ctx.record_success();
if let (Some(identity), Some(seq)) = (&self.identity, sequence) {
identity.acknowledge(topic.as_ref(), partition, seq);
}
self.metrics.record_send(record.payload_size_bytes());
self.metrics.connections.set(self.pool.len() as u64);
crate::interceptor::safe_on_acknowledgement(
&*self.interceptor,
&metadata,
None,
);
return Ok(metadata);
}
Err(e) => {
if is_unknown_producer_id_error(&e)
&& let (Some(identity), Some(current_sequence)) =
(self.identity.as_ref(), sequence)
{
warn!(
topic = %topic,
partition = partition,
"UnknownProducerId, reinitializing idempotent producer state"
);
let new_sequence = match recover_unknown_producer_id(
identity,
&self.metadata,
&self.retry_policy,
topic.as_ref(),
partition,
current_sequence,
1,
)
.await
{
Ok(new_sequence) => new_sequence,
Err(recovery_error) => {
self.metrics.record_error();
let dummy_metadata = RecordMetadata {
topic: topic_owned.clone(),
partition,
offset: -1,
timestamp: 0,
};
crate::interceptor::safe_on_acknowledgement(
&*self.interceptor,
&dummy_metadata,
Some(&recovery_error),
);
return Err(recovery_error);
}
};
sequence = Some(new_sequence);
match self.build_produce_request(
topic.as_ref(),
partition,
&record,
sequence,
) {
Ok(new_request) => request = new_request,
Err(build_error) => {
let _ = identity.rollback_sequence(topic.as_ref(), partition);
self.metrics.record_error();
let dummy_metadata = RecordMetadata {
topic: topic_owned.clone(),
partition,
offset: -1,
timestamp: 0,
};
crate::interceptor::safe_on_acknowledgement(
&*self.interceptor,
&dummy_metadata,
Some(&build_error),
);
return Err(build_error);
}
}
} else if let KrafkaError::Broker { code, .. } = &e
&& *code == ErrorCode::OutOfOrderSequenceNumber
&& let Some(ref identity) = self.identity
{
warn!(
topic = %topic,
partition = partition,
"OutOfOrderSequenceNumber, resetting sequence and rebuilding batch"
);
let new_seq =
match identity.reset_and_allocate(topic.as_ref(), partition, 1) {
Ok(s) => s,
Err(e) => {
self.metrics.record_error();
return Err(e);
}
};
sequence = Some(new_seq);
match self.build_produce_request(
topic.as_ref(),
partition,
&record,
sequence,
) {
Ok(r) => request = r,
Err(build_err) => {
let _ = identity.rollback_sequence(topic.as_ref(), partition);
self.metrics.record_error();
let dummy_metadata = RecordMetadata {
topic: topic_owned.clone(),
partition,
offset: -1,
timestamp: 0,
};
crate::interceptor::safe_on_acknowledgement(
&*self.interceptor,
&dummy_metadata,
Some(&build_err),
);
return Err(build_err);
}
}
} else if e.is_retriable() {
debug!(
topic = %topic,
partition = partition,
error = %e,
"Transient error, refreshing metadata"
);
if let Err(refresh_err) = self
.metadata
.refresh_for_topics(Some(&[topic.as_ref()]))
.await
{
debug!(error = %refresh_err, "Metadata refresh failed during retry");
}
}
if let Some(backoff) = retry_ctx.record_failure(&e) {
self.metrics.retries.inc();
retry_ctx.wait(backoff).await;
continue;
}
if let Some(ref identity) = self.identity {
let _ = identity.rollback_sequence(topic.as_ref(), partition);
}
self.metrics.record_error();
let dummy_metadata = RecordMetadata {
topic: topic_owned.clone(),
partition,
offset: -1,
timestamp: 0,
};
crate::interceptor::safe_on_acknowledgement(
&*self.interceptor,
&dummy_metadata,
Some(&e),
);
return Err(e);
}
}
}
}
fn build_produce_request(
&self,
topic: &str,
partition: PartitionId,
record: &RoutedRecord,
sequence: Option<i32>,
) -> Result<ProduceRequest> {
let mut batch_builder = RecordBatchBuilder::new().compression(self.config.compression);
if let Some(ts) = record.timestamp {
batch_builder = batch_builder.base_timestamp(ts);
}
if let (Some(identity), Some(seq)) = (&self.identity, sequence) {
batch_builder =
batch_builder.producer(identity.producer_id(), identity.producer_epoch(), seq);
}
batch_builder = record.append_to_batch_builder(batch_builder);
let batch = batch_builder.build();
let batch_bytes = batch.encode()?;
Ok(ProduceRequest {
transactional_id: None,
acks: self.config.acks.to_i16(),
timeout_ms: crate::util::duration_to_millis_i32(self.config.request_timeout),
topic_data: vec![ProduceTopicData {
name: topic.to_string(),
topic_id: None,
partition_data: vec![ProducePartitionData {
index: partition,
records: batch_bytes,
}],
}],
})
}
async fn do_send(
&self,
topic: &str,
partition: PartitionId,
request: &ProduceRequest,
) -> Result<RecordMetadata> {
let _timer = self.metrics.send_latency.start();
let conn = self
.metadata
.get_leader_connection(topic, partition)
.await?;
let mut version = conn
.negotiate_api_version(
ApiKey::Produce,
versions::PRODUCE_MAX,
versions::PRODUCE_MIN,
)
.await
.ok_or_else(|| {
KrafkaError::protocol_kind(
ProtocolErrorKind::UnknownApiVersion,
"no mutually supported Produce API version",
)
})?;
let mut owned_request;
let effective_request: &ProduceRequest = if version >= 13 {
owned_request = request.clone();
if !fill_produce_topic_ids(&mut owned_request, &self.metadata) {
version = 12;
request
} else {
&owned_request
}
} else {
request
};
validate_produce_request_size(
&self.config.client_id,
self.config.max_request_size,
version,
effective_request,
)?;
if self.config.acks == Acks::None {
conn.send_fire_and_forget(ApiKey::Produce, version, |buf| {
effective_request.encode_versioned(version, buf)
})
.await?;
return Ok(RecordMetadata {
topic: topic.to_string(),
partition,
offset: -1, timestamp: -1,
});
}
let response = conn
.send_request(ApiKey::Produce, version, |buf| {
effective_request.encode_versioned(version, buf)
})
.await?;
let mut buf = response;
let produce_response = ProduceResponse::decode_versioned(version, &mut buf)?;
for topic_response in &produce_response.responses {
for partition_response in &topic_response.partition_responses {
if partition_response.index == partition {
if !partition_response.error_code.is_ok() {
return Err(KrafkaError::broker(
partition_response.error_code,
format!("produce failed for {topic}-{partition}"),
));
}
return Ok(RecordMetadata {
topic: topic.to_string(),
partition,
offset: partition_response.base_offset,
timestamp: partition_response.log_append_time_ms,
});
}
}
}
Err(KrafkaError::protocol_kind(
ProtocolErrorKind::Malformed,
"partition not found in response",
))
}
pub async fn flush(&self) -> Result<()> {
let target = self.in_flight_barrier.snapshot();
if let Some(ref accumulator) = self.accumulator {
accumulator.flush().await?;
}
self.in_flight_barrier.wait_for(target).await;
Ok(())
}
pub fn update_seed_brokers(&self, servers: Vec<String>) -> Result<()> {
self.metadata.update_seed_brokers(servers)
}
pub async fn rebootstrap(&self) {
self.metadata.rebootstrap().await;
}
pub async fn close(&self) {
let _ = self.close_inner(None).await;
}
pub async fn close_with_timeout(&self, timeout: Duration) -> Result<()> {
self.close_inner(Some(timeout)).await
}
async fn close_inner(&self, timeout: Option<Duration>) -> Result<()> {
let Some(target) = self.in_flight_barrier.begin_close() else {
return Ok(());
};
let graceful_close = async {
if let Some(ref accumulator) = self.accumulator
&& let Err(e) = accumulator.shutdown().await
{
warn!("Accumulator shutdown error during close: {e}");
}
self.in_flight_barrier.wait_for(target).await;
};
let close_result = if let Some(timeout) = timeout {
tokio::time::timeout(timeout, graceful_close)
.await
.map_err(|_| {
warn!("Producer close timed out; batches in retry backoff may be lost");
KrafkaError::timeout("producer close")
})
} else {
graceful_close.await;
Ok(())
};
crate::interceptor::safe_producer_close(&*self.interceptor);
self.pool.close_all().await;
info!("Producer closed");
close_result
}
#[inline]
pub fn is_closed(&self) -> bool {
self.in_flight_barrier.is_closing()
}
pub async fn metrics(&self) -> ProducerMetricsSnapshot {
ProducerMetricsSnapshot {
connections: self.pool.len(),
records_sent: self.metrics.records_sent.get(),
bytes_sent: self.metrics.bytes_sent.get(),
errors: self.metrics.errors.get(),
retries: self.metrics.retries.get(),
buffered_records: self.metrics.buffered_records.get(),
}
}
#[inline]
pub fn metrics_handle(&self) -> Arc<ProducerMetricsInner> {
self.metrics.clone()
}
#[inline]
pub fn connection_metrics(&self) -> Arc<ConnectionMetrics> {
self.pool.metrics()
}
}
impl Drop for Producer {
    /// Warns when the producer is dropped without an explicit close.
    ///
    /// Nothing is logged if closing has already begun or the thread is unwinding
    /// from a panic.
    fn drop(&mut self) {
        let already_closing = self.in_flight_barrier.is_closing();
        if already_closing || std::thread::panicking() {
            return;
        }
        warn!(
            "Producer dropped without close(); in-flight batches may be lost. \
             Call `Producer::close()` (or `close_with_timeout`) before drop to flush."
        );
    }
}
/// Point-in-time copy of producer metrics, returned by `Producer::metrics()`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ProducerMetricsSnapshot {
/// Number of connections in the pool at snapshot time.
pub connections: usize,
/// Value of the `records_sent` counter.
pub records_sent: u64,
/// Value of the `bytes_sent` counter.
pub bytes_sent: u64,
/// Value of the `errors` counter.
pub errors: u64,
/// Value of the `retries` counter (incremented once per send retry).
pub retries: u64,
/// Value of the `buffered_records` metric.
pub buffered_records: u64,
}
/// Builder for [`Producer`]; obtain one via `Producer::builder()`.
#[derive(Default)]
#[must_use = "builders do nothing until .build() is called"]
pub struct ProducerBuilder {
/// Producer configuration accumulated by the setter methods.
config: ProducerConfig,
/// Registered interceptors. NOTE(review): `Producer::new` takes a single interceptor,
/// so these are presumably combined in `build()` (not visible here); confirm.
interceptors: Vec<Arc<dyn crate::interceptor::ProducerInterceptor>>,
/// Custom partitioner. When `None`, `Producer::new` picks one based on `linger`.
partitioner: Option<Arc<dyn Partitioner>>,
/// Optional schema encoder applied to record keys.
key_encoder: Option<Arc<dyn SchemaEncoder>>,
/// Optional schema encoder applied to record values.
value_encoder: Option<Arc<dyn SchemaEncoder>>,
/// Existing connection pool and metadata to reuse instead of creating new ones.
shared: Option<(Arc<ConnectionPool>, Arc<crate::metadata::ClusterMetadata>)>,
}
impl ProducerBuilder {
pub fn bootstrap_servers(mut self, servers: impl Into<String>) -> Self {
self.config.bootstrap_servers = servers.into();
self
}
pub fn client_id(mut self, client_id: impl Into<String>) -> Self {
self.config.client_id = client_id.into();
self
}
pub fn acks(mut self, acks: Acks) -> Self {
self.config.acks = acks;
self
}
pub fn compression(mut self, compression: Compression) -> Self {
self.config.compression = compression;
self
}
pub fn batch_size(mut self, size: usize) -> Self {
self.config.batch_size = size;
self
}
pub fn linger(mut self, duration: Duration) -> Self {
self.config.linger = duration;
self
}
pub fn request_timeout(mut self, timeout: Duration) -> Self {
self.config.request_timeout = timeout;
self
}
pub fn delivery_timeout(mut self, timeout: Duration) -> Self {
self.config.delivery_timeout = timeout;
self
}
pub fn retries(mut self, retries: u32) -> Self {
self.config.retries = retries;
self
}
pub fn retry_backoff(mut self, backoff: Duration) -> Self {
self.config.retry_backoff = backoff;
self
}
pub fn max_in_flight(mut self, max: usize) -> Self {
self.config.max_in_flight = max;
self
}
pub fn max_request_size(mut self, bytes: usize) -> Self {
self.config.max_request_size = bytes;
self
}
pub fn metadata_max_age(mut self, duration: Duration) -> Self {
self.config.metadata_max_age = duration;
self
}
pub fn metadata_topic_cache_ttl(mut self, ttl: Duration) -> Self {
self.config.metadata_topic_cache_ttl = Some(ttl);
self
}
pub fn disable_metadata_topic_cache_ttl(mut self) -> Self {
self.config.metadata_topic_cache_ttl = None;
self
}
pub fn idempotent(mut self, enable: bool) -> Self {
self.config.idempotent = enable;
self
}
pub fn auth(mut self, auth: AuthConfig) -> Self {
self.config.auth = Some(auth);
self
}
pub fn sasl_plain(
mut self,
username: impl Into<String>,
password: impl Into<String>,
) -> crate::Result<Self> {
self.config.auth = Some(AuthConfig::sasl_plain(username, password)?);
Ok(self)
}
pub fn sasl_scram_sha256(
mut self,
username: impl Into<String>,
password: impl Into<String>,
) -> Self {
self.config.auth = Some(AuthConfig::sasl_scram_sha256(username, password));
self
}
pub fn sasl_scram_sha512(
mut self,
username: impl Into<String>,
password: impl Into<String>,
) -> Self {
self.config.auth = Some(AuthConfig::sasl_scram_sha512(username, password));
self
}
pub fn sasl_oauthbearer(mut self, token: impl Into<String>) -> Self {
self.config.auth = Some(AuthConfig::sasl_oauthbearer(token));
self
}
pub fn sasl_oauthbearer_provider(
mut self,
provider: impl crate::auth::OAuthBearerTokenProvider + 'static,
) -> Self {
self.config.auth = Some(AuthConfig::sasl_oauthbearer_provider(provider));
self
}
pub fn partitioner(mut self, partitioner: impl Partitioner + 'static) -> Self {
self.partitioner = Some(Arc::new(partitioner));
self
}
pub fn interceptor(
mut self,
interceptor: Arc<dyn crate::interceptor::ProducerInterceptor>,
) -> Self {
self.interceptors = vec![interceptor];
self
}
pub fn add_interceptor(
mut self,
interceptor: Arc<dyn crate::interceptor::ProducerInterceptor>,
) -> Self {
self.interceptors.push(interceptor);
self
}
pub fn key_encoder(mut self, encoder: Arc<dyn SchemaEncoder>) -> Self {
self.key_encoder = Some(encoder);
self
}
pub fn value_encoder(mut self, encoder: Arc<dyn SchemaEncoder>) -> Self {
self.value_encoder = Some(encoder);
self
}
pub fn with_client(mut self, client: &crate::client::KrafkaClient) -> Self {
self.shared = Some((client.pool().clone(), client.metadata().clone()));
self
}
pub async fn build(mut self) -> Result<Producer> {
if self.shared.is_none() && self.config.bootstrap_servers.is_empty() {
return Err(KrafkaError::config("bootstrap.servers is required"));
}
if self.config.max_in_flight == 0 {
return Err(KrafkaError::config(format!(
"max_in_flight must be >= 1 (got {})",
self.config.max_in_flight
)));
}
if self.config.max_request_size == 0 {
return Err(KrafkaError::config("max_request_size must be >= 1"));
}
if self.config.batch_size == 0 {
return Err(KrafkaError::config(format!(
"batch_size must be >= 1 (got {})",
self.config.batch_size
)));
}
if self.config.delivery_timeout.is_zero() {
return Err(KrafkaError::config(
"delivery_timeout must be greater than zero",
));
}
if self.config.idempotent {
if self.config.retries == 0 {
return Err(KrafkaError::config(
"idempotent producer requires retries > 0",
));
}
if self.config.acks != Acks::All {
return Err(KrafkaError::config(format!(
"idempotent producer requires acks = All (got {:?})",
self.config.acks
)));
}
if self.config.max_in_flight > 5 {
tracing::info!(
configured = self.config.max_in_flight,
effective = 5,
"idempotent producer requires max_in_flight ≤ 5; capping automatically"
);
self.config.max_in_flight = 5;
}
}
if self.config.buffer_memory > 0 && self.config.batch_size > self.config.buffer_memory {
return Err(KrafkaError::config(format!(
"batch_size must not exceed buffer_memory (got batch_size={}, buffer_memory={})",
self.config.batch_size, self.config.buffer_memory
)));
}
if self.config.batch_size > self.config.max_request_size {
return Err(KrafkaError::config(format!(
"batch_size must not exceed max_request_size (got batch_size={}, max_request_size={})",
self.config.batch_size, self.config.max_request_size
)));
}
let interceptor: Arc<dyn crate::interceptor::ProducerInterceptor> =
if self.interceptors.is_empty() {
Arc::new(crate::interceptor::NoOpProducerInterceptor)
} else if self.interceptors.len() == 1 {
let Some(single) = self.interceptors.into_iter().next() else {
unreachable!("len == 1 verified above");
};
single
} else {
Arc::new(crate::interceptor::ProducerInterceptorChain::new(
self.interceptors,
))
};
let producer = Producer::new(
self.config,
interceptor,
self.partitioner,
self.key_encoder,
self.value_encoder,
self.shared,
)
.await?;
Ok(producer)
}
}
// Unit tests covering:
// - builder setters and validation;
// - produce-request size validation;
// - unknown-producer-id recovery for idempotent producers;
// - the metrics snapshot.
#[cfg(test)]
// Tests may unwrap/expect/panic freely; a failure should abort the test loudly.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
use super::*;
use crate::metadata::ClusterMetadata;
use crate::network::{ConnectionConfig, ConnectionPool};
// Each setter writes directly into `builder.config`; `build()` is never called,
// so no validation runs here.
#[test]
fn test_producer_builder() {
let builder = Producer::builder()
.bootstrap_servers("localhost:9092")
.client_id("test")
.acks(Acks::All)
.compression(Compression::Gzip)
.batch_size(32768)
.max_request_size(65536)
.linger(Duration::from_millis(10));
assert_eq!(builder.config.bootstrap_servers, "localhost:9092");
assert_eq!(builder.config.client_id, "test");
assert_eq!(builder.config.acks, Acks::All);
assert_eq!(builder.config.compression, Compression::Gzip);
assert_eq!(builder.config.batch_size, 32768);
assert_eq!(builder.config.max_request_size, 65536);
assert_eq!(builder.config.linger, Duration::from_millis(10));
assert!(builder.config.auth.is_none());
}
// `auth()` stores the given config as-is. SASL/PLAIN without TLS is expected to
// map to the SaslPlaintext security protocol.
#[test]
fn test_producer_builder_with_auth() {
let builder = Producer::builder()
.bootstrap_servers("broker:9093")
.auth(AuthConfig::sasl_plain("user", "pass").unwrap());
let auth = builder.config.auth.as_ref().unwrap();
assert!(auth.requires_sasl());
assert!(!auth.requires_tls());
assert_eq!(
auth.security_protocol,
crate::auth::SecurityProtocol::SaslPlaintext
);
assert_eq!(auth.sasl_mechanism, Some(crate::auth::SaslMechanism::Plain));
}
// MSK IAM auth must require both TLS and SASL, and must carry IAM credentials
// plus a TLS config.
#[test]
fn test_producer_builder_aws_msk_iam() {
let auth = AuthConfig::aws_msk_iam("AKID", "secret", "us-east-1");
let builder = Producer::builder()
.bootstrap_servers("broker:9094")
.auth(auth);
let auth = builder.config.auth.as_ref().unwrap();
assert!(auth.requires_tls());
assert!(auth.requires_sasl());
assert_eq!(
auth.sasl_mechanism,
Some(crate::auth::SaslMechanism::AwsMskIam)
);
assert!(auth.aws_msk_iam_credentials.is_some());
assert!(auth.tls_config.is_some());
}
// A fresh builder carries no auth configuration.
#[test]
fn test_producer_builder_no_auth_by_default() {
let builder = Producer::builder().bootstrap_servers("broker:9092");
assert!(builder.config.auth.is_none());
}
// The `sasl_plain` convenience setter populates PLAIN credentials.
#[test]
fn test_producer_builder_sasl_plain() {
let builder = Producer::builder()
.bootstrap_servers("broker:9093")
.sasl_plain("user", "pass")
.unwrap();
let auth = builder.config.auth.as_ref().unwrap();
assert!(auth.requires_sasl());
assert!(auth.plain_credentials.is_some());
}
// Both SCRAM variants populate `scram_credentials`.
#[test]
fn test_producer_builder_sasl_scram() {
let builder = Producer::builder()
.bootstrap_servers("broker:9093")
.sasl_scram_sha256("user", "pass");
let auth = builder.config.auth.as_ref().unwrap();
assert!(auth.requires_sasl());
assert!(auth.scram_credentials.is_some());
let builder = Producer::builder()
.bootstrap_servers("broker:9093")
.sasl_scram_sha512("user", "pass");
let auth = builder.config.auth.as_ref().unwrap();
assert!(auth.requires_sasl());
assert!(auth.scram_credentials.is_some());
}
// Building without setting bootstrap servers must fail.
// NOTE(review): this only checks `is_err()`. Whether the failure comes from the
// "bootstrap.servers is required" check depends on the default
// `bootstrap_servers` in `ProducerConfig::default()`. Confirm, and consider
// asserting on the error message.
#[tokio::test]
async fn test_producer_builder_no_servers() {
let result = Producer::builder().build().await;
assert!(result.is_err());
}
// Retry count and backoff setters are stored in the config.
#[test]
fn test_producer_builder_retry_config() {
let builder = Producer::builder()
.bootstrap_servers("localhost:9092")
.retries(5)
.retry_backoff(Duration::from_millis(200));
assert_eq!(builder.config.retries, 5);
assert_eq!(builder.config.retry_backoff, Duration::from_millis(200));
}
// A 512-byte record payload cannot fit a 128-byte max_request_size; the error
// must mention max_request_size.
#[test]
fn test_validate_produce_request_size_rejects_oversized_frame() {
let request = ProduceRequest {
transactional_id: None,
acks: Acks::All.to_i16(),
timeout_ms: 30_000,
topic_data: vec![ProduceTopicData {
name: "topic".to_string(),
topic_id: None,
partition_data: vec![ProducePartitionData {
index: 0,
records: Bytes::from(vec![0; 512]),
}],
}],
};
let error = validate_produce_request_size("client", 128, versions::PRODUCE_MIN, &request)
.expect_err("oversized frame should be rejected");
assert!(error.to_string().contains("max_request_size"));
}
// The size check at PRODUCE_MAX uses the exact encoded frame size. A limit equal
// to that size passes; one byte less fails.
#[test]
fn test_validate_produce_request_size_uses_exact_flexible_encoding_size() {
let request = ProduceRequest {
transactional_id: Some("txn-123".to_string()),
acks: Acks::All.to_i16(),
timeout_ms: 30_000,
topic_data: vec![ProduceTopicData {
name: "topic".to_string(),
topic_id: Some([0u8; 16]),
partition_data: vec![ProducePartitionData {
index: 0,
records: Bytes::from(vec![1; 32]),
}],
}],
};
let exact_size =
produce_request_frame_size("client", versions::PRODUCE_MAX, &request).unwrap();
validate_produce_request_size("client", exact_size, versions::PRODUCE_MAX, &request)
.unwrap();
let error = validate_produce_request_size(
"client",
exact_size.saturating_sub(1),
versions::PRODUCE_MAX,
&request,
)
.unwrap_err();
assert!(error.to_string().contains("max_request_size"));
}
// At version 13 a topic without `topic_id` is rejected with a
// "topic_id is required" error.
#[test]
fn test_validate_produce_request_size_v13_requires_topic_id() {
let request = ProduceRequest {
transactional_id: None,
acks: Acks::All.to_i16(),
timeout_ms: 30_000,
topic_data: vec![ProduceTopicData {
name: "topic".to_string(),
topic_id: None,
partition_data: vec![ProducePartitionData {
index: 0,
records: Bytes::from_static(b"payload"),
}],
}],
};
let error = validate_produce_request_size("client", 1024, 13, &request).unwrap_err();
assert!(error.to_string().contains("topic_id is required"));
}
// Setup: sequences 0..2 and then 2 are allocated on topic/0. Recovery is then
// attempted for the earlier batch while a newer one is still outstanding.
// NOTE(review): the trailing `0, 0, 2` arguments are presumably
// (partition, base_sequence, record_count). Confirm against the signature.
// Expected outcome: the identity is poisoned rather than reset. The producer id
// (7) and next sequence (3) are kept, and `ensure_idempotent_producer_id_initialized`
// also reports the poison afterwards.
#[tokio::test]
async fn test_recover_unknown_producer_id_poisoned_when_newer_batches_in_flight() {
let identity = ProducerIdentity::new();
identity.initialize(7, 1);
assert_eq!(identity.allocate_sequence("topic", 0, 2).unwrap(), 0);
assert_eq!(identity.allocate_sequence("topic", 0, 1).unwrap(), 2);
let pool = Arc::new(ConnectionPool::new(ConnectionConfig::default()));
let metadata = ClusterMetadata::new(
vec!["localhost:9092".to_string()],
pool,
Duration::from_secs(300),
);
let retry_policy = RetryPolicy::default();
let error =
recover_unknown_producer_id(&identity, &metadata, &retry_policy, "topic", 0, 0, 2)
.await
.unwrap_err();
assert!(error.to_string().contains("poisoned"));
assert!(identity.is_initialized());
assert!(identity.is_poisoned());
assert_eq!(identity.producer_id(), 7);
assert_eq!(identity.peek_sequence("topic", 0), 3);
let ensure_error =
ensure_idempotent_producer_id_initialized(&identity, &metadata, &retry_policy)
.await
.unwrap_err();
assert!(ensure_error.to_string().contains("poisoned"));
}
// Plain field round-trip for the public snapshot type.
#[test]
fn test_producer_metrics_snapshot() {
let snapshot = ProducerMetricsSnapshot {
connections: 3,
records_sent: 100,
bytes_sent: 50000,
errors: 2,
retries: 5,
buffered_records: 7,
};
assert_eq!(snapshot.connections, 3);
assert_eq!(snapshot.records_sent, 100);
assert_eq!(snapshot.bytes_sent, 50000);
assert_eq!(snapshot.errors, 2);
assert_eq!(snapshot.retries, 5);
assert_eq!(snapshot.buffered_records, 7);
}
// Builds a Producer by hand with buffer_memory = 16 and no accumulator, so that
// `send_to_partition` is exercised directly. A 1 KiB record must be rejected with
// a buffer_memory error and must leave buffered_records at 0.
// NOTE(review): the `1024` argument is presumably the record's estimated size.
// This literal must track every field of `Producer`.
#[tokio::test]
async fn test_direct_send_rejects_record_larger_than_buffer_memory() {
let pool = Arc::new(ConnectionPool::new(ConnectionConfig::default()));
let metadata = Arc::new(ClusterMetadata::new(
vec!["localhost:9092".to_string()],
pool.clone(),
Duration::from_secs(300),
));
let metrics = Arc::new(ProducerMetricsInner::default());
let producer = Producer {
config: ProducerConfig {
buffer_memory: 16,
..ProducerConfig::default()
},
metadata,
pool,
partitioner: Arc::new(DefaultPartitioner::new()),
accumulator: None,
in_flight_barrier: Arc::new(InFlightBarrier::new()),
retry_policy: RetryPolicy::default(),
metrics: metrics.clone(),
memory_permits: Arc::new(Semaphore::new(16)),
memory_capacity: 16,
max_request_size: 0,
buffered_records: Arc::new(AtomicUsize::new(0)),
in_flight_semaphore: Arc::new(Semaphore::new(1)),
interceptor: Arc::new(crate::interceptor::NoOpProducerInterceptor),
identity: None,
key_encoder: None,
value_encoder: None,
};
let record = RoutedRecord {
key: None,
value: Bytes::from(vec![0u8; 1024]),
timestamp: None,
headers: Vec::new(),
};
let err = producer
.send_to_partition(
Arc::<str>::from("topic"),
0,
record,
1024,
producer.in_flight_barrier.start("producer").unwrap(),
)
.await
.expect_err("direct send must reject records larger than buffer_memory");
assert!(err.to_string().contains("buffer_memory"));
assert_eq!(metrics.buffered_records.get(), 0);
}
// RetryPolicy builder methods are reflected by its accessors.
#[test]
fn test_retry_policy_from_config() {
let policy = RetryPolicy::new()
.with_max_retries(10)
.with_initial_backoff(Duration::from_millis(50))
.with_max_backoff(Duration::from_secs(30));
assert_eq!(policy.max_retries, 10);
assert_eq!(policy.initial_backoff(), Duration::from_millis(50));
assert_eq!(policy.max_backoff(), Duration::from_secs(30));
}
// The default config must pass the builder's `max_in_flight >= 1` check.
#[test]
fn test_producer_config_max_in_flight_default() {
let config = ProducerConfig::default();
assert!(config.max_in_flight > 0);
}
// NOTE(review): despite its name, this test only checks that Acks::None is stored
// and maps to wire value 0. No send happens and no metadata is inspected.
#[test]
fn test_acks_none_returns_fire_and_forget_metadata() {
let builder = Producer::builder()
.bootstrap_servers("localhost:9092")
.acks(Acks::None);
assert_eq!(builder.config.acks, Acks::None);
assert_eq!(builder.config.acks.to_i16(), 0);
}
// Idempotence is on by default and can be switched off.
#[test]
fn test_idempotent_builder() {
let builder = Producer::builder().bootstrap_servers("broker:9092");
assert!(builder.config.idempotent);
let builder = Producer::builder()
.bootstrap_servers("broker:9092")
.idempotent(false);
assert!(!builder.config.idempotent);
}
// An idempotent producer with acks != All fails validation before any connection
// is made.
#[tokio::test]
async fn test_idempotent_requires_acks_all() {
let builder = Producer::builder()
.bootstrap_servers("localhost:9092")
.acks(Acks::Leader)
.idempotent(true);
let result = builder.build().await;
match result {
Err(e) => assert!(e.to_string().contains("acks")),
Ok(_) => panic!("expected config error for idempotent with acks != All"),
}
}
// Exercises `ProducerConfig::builder()`'s capping, not `ProducerBuilder::build`.
#[test]
fn test_idempotent_autocaps_max_in_flight() {
let cfg = ProducerConfig::builder()
.bootstrap_servers("localhost:9092")
.idempotent(true)
.max_in_flight(10)
.build()
.expect("idempotent config should auto-cap max_in_flight to 5");
assert_eq!(cfg.max_in_flight(), 5);
}
// max_in_flight == 0 is rejected by `build`.
#[tokio::test]
async fn test_producer_builder_rejects_zero_max_in_flight() {
let mut builder = Producer::builder().bootstrap_servers("localhost:9092");
builder.config.max_in_flight = 0;
let result = builder.build().await;
match result {
Err(e) => assert!(e.to_string().contains("max_in_flight")),
Ok(_) => panic!("expected error for max_in_flight=0"),
}
}
// batch_size == 0 is rejected by `build`.
#[tokio::test]
async fn test_producer_builder_rejects_zero_batch_size() {
let mut builder = Producer::builder().bootstrap_servers("localhost:9092");
builder.config.batch_size = 0;
let result = builder.build().await;
match result {
Err(e) => assert!(e.to_string().contains("batch_size")),
Ok(_) => panic!("expected error for batch_size=0"),
}
}
// `interceptor()` registers exactly one interceptor.
#[test]
fn test_producer_builder_interceptor() {
use crate::interceptor::{InterceptorResult, ProducerInterceptor};
#[derive(Debug)]
struct TestInterceptor;
impl ProducerInterceptor for TestInterceptor {
fn on_send(&self, _record: &mut ProducerRecord) -> InterceptorResult {
Ok(())
}
}
let builder = Producer::builder()
.bootstrap_servers("localhost:9092")
.interceptor(Arc::new(TestInterceptor));
assert_eq!(builder.interceptors.len(), 1);
}
// `add_interceptor()` accumulates interceptors.
#[test]
fn test_producer_builder_add_interceptor() {
use crate::interceptor::ProducerInterceptor;
#[derive(Debug)]
struct A;
impl ProducerInterceptor for A {}
#[derive(Debug)]
struct B;
impl ProducerInterceptor for B {}
let builder = Producer::builder()
.bootstrap_servers("localhost:9092")
.add_interceptor(Arc::new(A))
.add_interceptor(Arc::new(B));
assert_eq!(builder.interceptors.len(), 2);
}
// `interceptor()` after `add_interceptor()` discards the earlier registrations.
// NOTE(review): only the length is checked, not that the survivor is `B`.
#[test]
fn test_producer_builder_interceptor_replaces_chain() {
use crate::interceptor::ProducerInterceptor;
#[derive(Debug)]
struct A;
impl ProducerInterceptor for A {}
#[derive(Debug)]
struct B;
impl ProducerInterceptor for B {}
let builder = Producer::builder()
.bootstrap_servers("localhost:9092")
.add_interceptor(Arc::new(A))
.add_interceptor(Arc::new(A))
.interceptor(Arc::new(B));
assert_eq!(builder.interceptors.len(), 1);
}
// Compile-time check: Producer can be shared across threads/tasks.
#[test]
fn test_producer_is_send_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<Producer>();
}
}