use std::collections::{HashMap, HashSet, VecDeque};
use std::time::{Duration, Instant};
use tokio::sync::oneshot;
use crate::config::ProducerConfig;
use crate::metadata::MetadataCache;
use crate::types::{ProduceAck, ProduceRecord, TopicPartitionKey};
use crate::{Error, ProducerError, Result};
/// Buffers produce records into per-partition batches until they are drained for sending.
///
/// Each topic-partition maps to a FIFO of batches: new records only ever join the batch at
/// the back of the queue, while draining and re-enqueued retries work at the front.
#[derive(Default)]
pub struct RecordAccumulator {
    /// Queued batches per topic-partition; the front batch is the next one drained.
    batches: HashMap<TopicPartitionKey, VecDeque<ProducerBatch>>,
}
/// Outcome of [`RecordAccumulator::ready`]: which partitions currently have a sendable batch.
#[derive(Debug)]
pub struct ReadyCheckResult {
    /// Sendable partitions grouped by the leader id reported by the metadata cache.
    pub ready_by_leader: HashMap<i32, Vec<TopicPartitionKey>>,
    /// Sendable partitions whose leader is not known to the metadata cache.
    pub unknown_leaders: Vec<TopicPartitionKey>,
    /// How long the caller may wait before a batch that was not yet sendable might become
    /// sendable (remaining backoff or linger), capped at 250 ms.
    pub next_ready_check_delay: Duration,
}
/// Point-in-time aggregate counters over every batch held by a [`RecordAccumulator`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RecordAccumulatorStats {
    /// Number of queued batches across all partitions.
    pub batch_count: usize,
    /// Total number of records across all queued batches.
    pub record_count: usize,
    /// Sum of the batches' size estimates (saturating instead of overflowing).
    pub estimated_bytes: usize,
    /// Number of batches that have a retry time set.
    pub retrying_batch_count: usize,
    /// Age in milliseconds of the oldest queued batch, measured from its creation;
    /// 0 when nothing is queued.
    pub oldest_batch_age_ms: u128,
}
/// A group of records for one topic-partition that is drained, retried, or failed as a unit.
pub struct ProducerBatch {
    /// Topic-partition every record in the batch targets.
    pub key: TopicPartitionKey,
    /// Records in append order.
    pub records: Vec<ProduceRecord>,
    /// One reply channel per record, index-aligned with `records`.
    pub replies: Vec<oneshot::Sender<Result<ProduceAck>>>,
    /// When the batch was started; used for linger, age, and delivery-timeout checks.
    pub created_at: Instant,
    /// Running sum of `estimate_record_size` over `records`.
    pub estimated_bytes: usize,
    /// Attempt counter, 0 on creation; any non-zero value closes the batch to appends.
    /// NOTE(review): presumably incremented by the sender per send attempt — confirm.
    pub attempts: usize,
    /// When set, the batch is not sendable before this instant and no longer accepts appends.
    pub retry_at: Option<Instant>,
    /// Producer id/epoch/sequence stamped on the batch, if any (`None` on creation).
    pub producer_state: Option<BatchProducerState>,
}
/// Producer identity and sequencing metadata stamped on a batch.
///
/// `base_sequence` belongs to the batch's first record; when a batch is split, each piece
/// gets a base sequence offset by its record's position.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BatchProducerState {
    /// Producer id the batch is written under.
    pub producer_id: i64,
    /// Epoch of `producer_id` when this state was assigned.
    pub producer_epoch: i16,
    /// Sequence number of the first record in the batch.
    pub base_sequence: i32,
    /// Whether the batch is written as part of a transaction.
    pub transactional: bool,
}
impl RecordAccumulator {
pub fn append(
&mut self,
record: ProduceRecord,
reply: oneshot::Sender<Result<ProduceAck>>,
batch_size: usize,
) -> bool {
let key = TopicPartitionKey::new(
record.topic.clone(),
record
.partition
.expect("producer sender must resolve the target partition before append"),
);
let estimated_size = estimate_record_size(&record);
let deque = self.batches.entry(key.clone()).or_default();
if let Some(batch) = deque.back_mut()
&& batch.is_appendable(batch_size, estimated_size)
{
batch.push(record, reply, estimated_size);
return false;
}
deque.push_back(ProducerBatch::new(key, record, reply, estimated_size));
true
}
pub fn can_append_to(
&self,
key: &TopicPartitionKey,
batch_size: usize,
estimated_size: usize,
) -> bool {
self.batches
.get(key)
.and_then(|deque| deque.back())
.is_some_and(|batch| batch.is_appendable(batch_size, estimated_size))
}
pub fn has_undrained(&self) -> bool {
self.batches.values().any(|deque| !deque.is_empty())
}
pub fn stats(&self) -> RecordAccumulatorStats {
let now = Instant::now();
self.batches.values().flat_map(|deque| deque.iter()).fold(
RecordAccumulatorStats::default(),
|mut stats, batch| {
stats.batch_count += 1;
stats.record_count += batch.records.len();
stats.estimated_bytes = stats.estimated_bytes.saturating_add(batch.estimated_bytes);
if batch.retry_at.is_some() {
stats.retrying_batch_count += 1;
}
stats.oldest_batch_age_ms = stats
.oldest_batch_age_ms
.max(now.saturating_duration_since(batch.created_at).as_millis());
stats
},
)
}
pub fn estimated_bytes(&self) -> usize {
self.stats().estimated_bytes
}
pub fn ready(
&self,
metadata: &MetadataCache,
config: &ProducerConfig,
closing: bool,
) -> ReadyCheckResult {
let now = Instant::now();
let mut ready_by_leader = HashMap::<i32, Vec<TopicPartitionKey>>::new();
let mut unknown_leaders = Vec::new();
let mut next_ready_check_delay = Duration::from_millis(250);
for (key, deque) in &self.batches {
let Some(batch) = deque.front() else {
continue;
};
if let Some(retry_at) = batch.retry_at
&& now < retry_at
{
next_ready_check_delay =
next_ready_check_delay.min(retry_at.saturating_duration_since(now));
continue;
}
let linger_elapsed = now.saturating_duration_since(batch.created_at);
let sendable =
closing || batch.is_full(config.batch_size) || linger_elapsed >= config.linger;
if !sendable {
next_ready_check_delay =
next_ready_check_delay.min(config.linger.saturating_sub(linger_elapsed));
continue;
}
match metadata.leader_for(&key.topic, key.partition) {
Some(leader_id) => ready_by_leader
.entry(leader_id)
.or_default()
.push(key.clone()),
None => unknown_leaders.push(key.clone()),
}
}
ReadyCheckResult {
ready_by_leader,
unknown_leaders,
next_ready_check_delay,
}
}
pub fn drain_ready(
&mut self,
ready_by_leader: HashMap<i32, Vec<TopicPartitionKey>>,
) -> HashMap<i32, Vec<ProducerBatch>> {
let mut drained = HashMap::<i32, Vec<ProducerBatch>>::new();
for (leader_id, keys) in ready_by_leader {
for key in keys {
let Some(deque) = self.batches.get_mut(&key) else {
continue;
};
let Some(batch) = deque.pop_front() else {
continue;
};
if deque.is_empty() {
self.batches.remove(&key);
}
drained.entry(leader_id).or_default().push(batch);
}
}
drained
}
pub fn drain_unknown_leader(
&mut self,
partitions: Vec<TopicPartitionKey>,
) -> Vec<ProducerBatch> {
let mut drained = Vec::new();
let mut seen = HashSet::new();
for key in partitions {
if !seen.insert(key.clone()) {
continue;
}
let Some(deque) = self.batches.get_mut(&key) else {
continue;
};
let Some(batch) = deque.pop_front() else {
continue;
};
if deque.is_empty() {
self.batches.remove(&key);
}
drained.push(batch);
}
drained
}
pub fn reenqueue_front(&mut self, batch: ProducerBatch) {
self.batches
.entry(batch.key.clone())
.or_default()
.push_front(batch);
}
pub fn split_and_reenqueue_front(&mut self, batch: ProducerBatch) -> bool {
if !batch.can_split() {
return false;
}
let key = batch.key.clone();
let split_batches = batch.split();
let deque = self.batches.entry(key).or_default();
for batch in split_batches.into_iter().rev() {
deque.push_front(batch);
}
true
}
pub fn fail_all(&mut self, message: &str) {
let pending = std::mem::take(&mut self.batches);
for deque in pending.into_values() {
for batch in deque {
fail_batch(batch, message);
}
}
}
pub fn expire_batches(&mut self, delivery_timeout: Duration) -> Vec<ProducerBatch> {
let now = Instant::now();
let mut expired = Vec::new();
let keys = self.batches.keys().cloned().collect::<Vec<_>>();
let mut empty_keys = Vec::new();
for key in keys {
let Some(deque) = self.batches.get_mut(&key) else {
continue;
};
let mut retained = VecDeque::new();
while let Some(batch) = deque.pop_front() {
if now.saturating_duration_since(batch.created_at) >= delivery_timeout {
expired.push(batch);
} else {
retained.push_back(batch);
}
}
if retained.is_empty() {
empty_keys.push(key);
} else {
*deque = retained;
}
}
for key in empty_keys {
self.batches.remove(&key);
}
expired
}
pub fn next_expiration_in(&self, delivery_timeout: Duration) -> Option<Duration> {
let now = Instant::now();
self.batches
.values()
.flat_map(|deque| deque.iter())
.map(|batch| {
delivery_timeout.saturating_sub(now.saturating_duration_since(batch.created_at))
})
.min()
}
}
impl ProducerBatch {
    /// Starts a fresh batch for `key` holding `record` and its reply channel.
    fn new(
        key: TopicPartitionKey,
        record: ProduceRecord,
        reply: oneshot::Sender<Result<ProduceAck>>,
        estimated_size: usize,
    ) -> Self {
        Self {
            key,
            records: vec![record],
            replies: vec![reply],
            created_at: Instant::now(),
            estimated_bytes: estimated_size,
            attempts: 0,
            retry_at: None,
            producer_state: None,
        }
    }

    /// Reports whether a record estimated at `additional_bytes` may join this batch.
    ///
    /// Only batches with no attempts and no scheduled retry accept appends, and the combined
    /// estimate must stay within `batch_size`.
    fn is_appendable(&self, batch_size: usize, additional_bytes: usize) -> bool {
        self.attempts == 0
            && self.retry_at.is_none()
            && self.estimated_bytes.saturating_add(additional_bytes) <= batch_size
    }

    /// Reports whether the batch's size estimate has reached `batch_size`.
    pub fn is_full(&self, batch_size: usize) -> bool {
        self.estimated_bytes >= batch_size
    }

    /// Reports whether the batch holds more than one record and can therefore be split.
    pub fn can_split(&self) -> bool {
        self.records.len() > 1
    }

    /// Breaks the batch into one single-record batch per record, in the original order.
    ///
    /// Each piece keeps the creation time, attempt count, and producer identity. Its size
    /// estimate is recomputed, its retry time is cleared, and its base sequence is advanced
    /// by the record's position, wrapping past `i32::MAX` back to 0.
    fn split(self) -> Vec<Self> {
        let Self {
            key,
            records,
            replies,
            created_at,
            attempts,
            producer_state,
            ..
        } = self;
        records
            .into_iter()
            .zip(replies)
            .enumerate()
            .map(|(index, (record, reply))| {
                let estimated_bytes = estimate_record_size(&record);
                let producer_state = producer_state.map(|state| BatchProducerState {
                    base_sequence: Self::increment_sequence(state.base_sequence, index),
                    ..state
                });
                Self {
                    key: key.clone(),
                    records: vec![record],
                    replies: vec![reply],
                    created_at,
                    estimated_bytes,
                    attempts,
                    retry_at: None,
                    producer_state,
                }
            })
            .collect()
    }

    /// Advances a producer sequence number by `offset`, wrapping from `i32::MAX` to 0.
    ///
    /// This mirrors Kafka's sequence wraparound (`DefaultRecordBatch.incrementSequence` in
    /// the Java client). Saturating at `i32::MAX` instead would give several records the
    /// same sequence.
    ///
    /// # Panics
    ///
    /// Panics if `offset` exceeds `i32::MAX`, far beyond any realistic batch size.
    fn increment_sequence(sequence: i32, offset: usize) -> i32 {
        let offset = i32::try_from(offset)
            .expect("a producer batch never holds more than i32::MAX records");
        // `i32::MAX - offset` cannot overflow because `offset` is non-negative.
        if sequence > i32::MAX - offset {
            offset - (i32::MAX - sequence) - 1
        } else {
            sequence + offset
        }
    }

    /// Appends `record` and its reply channel, growing the size estimate by `estimated_size`.
    fn push(
        &mut self,
        record: ProduceRecord,
        reply: oneshot::Sender<Result<ProduceAck>>,
        estimated_size: usize,
    ) {
        self.records.push(record);
        self.replies.push(reply);
        self.estimated_bytes = self.estimated_bytes.saturating_add(estimated_size);
    }
}
pub fn fail_batch(mut batch: ProducerBatch, message: &str) {
while let Some(reply) = batch.replies.pop() {
let _ = reply.send(Err(Error::Producer(ProducerError::BatchFailed {
topic: batch.key.topic.clone(),
partition: batch.key.partition,
message: message.to_owned(),
})));
}
}
pub fn fail_batch_with_error(mut batch: ProducerBatch, error: &Error) {
while let Some(reply) = batch.replies.pop() {
let _ = reply.send(Err(clone_error_for_batch_reply(error, &batch)));
}
}
fn clone_error_for_batch_reply(error: &Error, batch: &ProducerBatch) -> Error {
match error {
Error::Broker(error) => Error::Broker(error.clone()),
Error::Validation(error) => Error::Validation(error.clone()),
Error::Protocol(error) => Error::Protocol(error.clone()),
Error::Producer(ProducerError::TransactionAbortRequired { operation, message }) => {
Error::Producer(ProducerError::TransactionAbortRequired {
operation,
message: message.clone(),
})
}
Error::Producer(ProducerError::TransactionFatal { operation, message }) => {
Error::Producer(ProducerError::TransactionFatal {
operation,
message: message.clone(),
})
}
Error::Producer(ProducerError::BatchFailed { .. }) => {
Error::Producer(ProducerError::BatchFailed {
topic: batch.key.topic.clone(),
partition: batch.key.partition,
message: error.to_string(),
})
}
Error::Producer(ProducerError::RecordTooLarge {
size,
limit_name,
limit,
}) => Error::Producer(ProducerError::RecordTooLarge {
size: *size,
limit_name,
limit: *limit,
}),
Error::Producer(ProducerError::RequestTooLarge { size, limit }) => {
Error::Producer(ProducerError::RequestTooLarge {
size: *size,
limit: *limit,
})
}
Error::Producer(ProducerError::BufferExhausted {
buffered,
required,
limit,
max_block_ms,
}) => Error::Producer(ProducerError::BufferExhausted {
buffered: *buffered,
required: *required,
limit: *limit,
max_block_ms: *max_block_ms,
}),
_ => Error::Producer(ProducerError::BatchFailed {
topic: batch.key.topic.clone(),
partition: batch.key.partition,
message: error.to_string(),
}),
}
}
/// Estimates how many bytes `record` contributes to a batch.
///
/// The estimate is the key, value, and header payload lengths plus flat allowances of
/// 8 bytes per header and 64 bytes per record. It is a sizing heuristic, not an exact
/// encoded size.
pub fn estimate_record_size(record: &ProduceRecord) -> usize {
    const PER_HEADER_ALLOWANCE: usize = 8;
    const PER_RECORD_ALLOWANCE: usize = 64;
    let mut header_bytes: usize = 0;
    for header in record.headers.iter() {
        let header_value_len = header.value.as_ref().map_or(0, |value| value.len());
        header_bytes += header.key.len() + header_value_len + PER_HEADER_ALLOWANCE;
    }
    let key_len = record.key.as_ref().map_or(0, |key| key.len());
    let value_len = record.value.as_ref().map_or(0, |value| value.len());
    key_len + value_len + header_bytes + PER_RECORD_ALLOWANCE
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::oneshot;

    /// Appends `record` with a throwaway reply channel and a 1024-byte batch budget.
    fn push(acc: &mut RecordAccumulator, record: ProduceRecord) {
        let (reply, _) = oneshot::channel();
        acc.append(record, reply, 1024);
    }

    /// Returns the front batch queued for `key`, panicking if there is none.
    fn front_batch_mut<'a>(
        acc: &'a mut RecordAccumulator,
        key: &TopicPartitionKey,
    ) -> &'a mut ProducerBatch {
        acc.batches.get_mut(key).unwrap().front_mut().unwrap()
    }

    #[test]
    fn appends_same_partition_into_open_batch() {
        let mut acc = RecordAccumulator::default();
        for value in [b"first".as_slice(), b"second".as_slice()] {
            push(&mut acc, ProduceRecord::new("topic-a", 0, value));
        }
        let key = TopicPartitionKey::new("topic-a".to_owned(), 0);
        let queued = &acc.batches[&key];
        assert_eq!(queued.len(), 1);
        assert_eq!(queued[0].records.len(), 2);
    }

    #[test]
    fn linger_makes_batch_ready() {
        let mut acc = RecordAccumulator::default();
        push(&mut acc, ProduceRecord::new("topic-a", 0, b"hello".as_slice()));
        let key = TopicPartitionKey::new("topic-a".to_owned(), 0);
        front_batch_mut(&mut acc, &key).created_at = Instant::now() - Duration::from_millis(20);
        let config = ProducerConfig::new("localhost:9092").with_linger(Duration::from_millis(10));
        let ready = acc.ready(&MetadataCache::default(), &config, false);
        assert_eq!(ready.unknown_leaders, vec![key]);
    }

    #[test]
    fn expires_batches_past_delivery_timeout() {
        let mut acc = RecordAccumulator::default();
        push(&mut acc, ProduceRecord::new("topic-a", 0, b"hello".as_slice()));
        let key = TopicPartitionKey::new("topic-a".to_owned(), 0);
        front_batch_mut(&mut acc, &key).created_at = Instant::now() - Duration::from_secs(2);
        let expired = acc.expire_batches(Duration::from_secs(1));
        assert_eq!(expired.len(), 1);
        assert!(!acc.has_undrained());
    }

    #[test]
    fn stats_reports_aggregate_queue_depth_without_partition_labels() {
        let mut acc = RecordAccumulator::default();
        push(&mut acc, ProduceRecord::new("topic-a", 0, b"first".as_slice()));
        push(&mut acc, ProduceRecord::new("topic-a", 0, b"second".as_slice()));
        push(&mut acc, ProduceRecord::new("topic-a", 1, b"third".as_slice()));
        let retrying_key = TopicPartitionKey::new("topic-a".to_owned(), 1);
        front_batch_mut(&mut acc, &retrying_key).retry_at =
            Some(Instant::now() + Duration::from_millis(10));
        let stats = acc.stats();
        assert_eq!(stats.batch_count, 2);
        assert_eq!(stats.record_count, 3);
        assert!(stats.estimated_bytes > 0);
        assert_eq!(stats.retrying_batch_count, 1);
    }

    #[test]
    fn split_and_reenqueue_front_preserves_order_attempts_and_sequences() {
        let mut acc = RecordAccumulator::default();
        let key = TopicPartitionKey::new("topic-a".to_owned(), 0);
        for value in [
            b"first".as_slice(),
            b"second".as_slice(),
            b"third".as_slice(),
        ] {
            push(&mut acc, ProduceRecord::new("topic-a", 0, value));
        }
        let mut batch = acc
            .drain_unknown_leader(vec![key.clone()])
            .pop()
            .unwrap();
        batch.attempts = 2;
        batch.producer_state = Some(BatchProducerState {
            producer_id: 42,
            producer_epoch: 7,
            base_sequence: 11,
            transactional: false,
        });
        assert!(acc.split_and_reenqueue_front(batch));
        let stats = acc.stats();
        assert_eq!(stats.batch_count, 3);
        assert_eq!(stats.record_count, 3);
        let expectations = [
            (b"first".as_slice(), 11),
            (b"second".as_slice(), 12),
            (b"third".as_slice(), 13),
        ];
        for (expected_value, expected_sequence) in expectations {
            let piece = acc
                .drain_unknown_leader(vec![key.clone()])
                .pop()
                .unwrap();
            assert_eq!(piece.records.len(), 1);
            assert_eq!(piece.records[0].value.as_deref(), Some(expected_value));
            assert_eq!(piece.attempts, 2);
            assert_eq!(
                piece.producer_state.map(|state| state.base_sequence),
                Some(expected_sequence)
            );
        }
        assert!(!acc.has_undrained());
    }
}