use std::collections::{HashMap, HashSet, VecDeque};
use std::time::{Duration, Instant};
use tokio::sync::oneshot;
use crate::config::ProducerConfig;
use crate::metadata::MetadataCache;
use crate::types::{ProduceAck, ProduceRecord, TopicPartitionKey};
use crate::{Error, Result};
/// Per-partition queues of producer batches waiting to be sent.
///
/// Every topic-partition owns a FIFO of batches. New records are written into the batch at
/// the back of the queue, while the batch at the front is the one examined by the readiness
/// check and removed by the drain operations.
#[derive(Default)]
pub struct RecordAccumulator {
    /// Queued batches per topic-partition; the front entry is the next batch to send.
    batches: HashMap<TopicPartitionKey, VecDeque<ProducerBatch>>,
}
/// Outcome of a readiness check over the accumulator: which partitions have a front batch
/// that can be sent right now, and how long to wait before checking again.
///
/// Derives `Debug` so the result can be logged or inspected in assertions.
#[derive(Debug)]
pub struct ReadyCheckResult {
    /// Sendable partitions grouped by the broker id of their current leader.
    pub ready_by_leader: HashMap<i32, Vec<TopicPartitionKey>>,
    /// Sendable partitions for which the metadata cache reported no leader.
    pub unknown_leaders: Vec<TopicPartitionKey>,
    /// Time until the earliest not-yet-sendable batch could become sendable (capped).
    pub next_ready_check_delay: Duration,
}
/// A group of records destined for one topic-partition that are sent together.
pub struct ProducerBatch {
    /// The topic-partition every record in this batch targets.
    pub key: TopicPartitionKey,
    /// Records in the order they were appended.
    pub records: Vec<ProduceRecord>,
    /// One reply channel per record, kept index-aligned with `records`.
    pub replies: Vec<oneshot::Sender<Result<ProduceAck>>>,
    /// When the batch was created; linger and delivery timeout are measured from here.
    pub created_at: Instant,
    /// Running sum of the per-record size estimates of `records`.
    pub estimated_bytes: usize,
    /// Starts at 0; any non-zero value closes the batch to further appends.
    /// Presumably incremented by the sender on each send attempt — confirm against caller.
    pub attempts: usize,
    /// When set and still in the future, the readiness check skips this batch; a set value
    /// also closes the batch to further appends.
    pub retry_at: Option<Instant>,
    /// Producer id/epoch/sequence metadata; `None` when the batch is created.
    /// Presumably assigned by the sender before the first send — confirm against caller.
    pub producer_state: Option<BatchProducerState>,
}
/// Producer identity and sequencing metadata attached to a batch.
///
/// Presumably used for idempotent/transactional produce requests — confirm against the
/// sender. Being a small `Copy` value of primitives, it also derives `PartialEq`/`Eq` so
/// states can be compared directly (e.g. in assertions).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BatchProducerState {
    /// Producer id the batch is sent under.
    pub producer_id: i64,
    /// Producer epoch the batch is sent under.
    pub producer_epoch: i16,
    /// Sequence number of the batch's first record (presumably; confirm against the sender).
    pub base_sequence: i32,
    /// Whether the batch is part of a transaction.
    pub transactional: bool,
}
impl RecordAccumulator {
pub fn append(
&mut self,
record: ProduceRecord,
reply: oneshot::Sender<Result<ProduceAck>>,
batch_size: usize,
) -> bool {
let key = TopicPartitionKey::new(
record.topic.clone(),
record
.partition
.expect("producer sender must resolve the target partition before append"),
);
let estimated_size = estimate_record_size(&record);
let deque = self.batches.entry(key.clone()).or_default();
if let Some(batch) = deque.back_mut()
&& batch.is_appendable(batch_size, estimated_size)
{
batch.push(record, reply, estimated_size);
return false;
}
deque.push_back(ProducerBatch::new(key, record, reply, estimated_size));
true
}
pub fn can_append_to(
&self,
key: &TopicPartitionKey,
batch_size: usize,
estimated_size: usize,
) -> bool {
self.batches
.get(key)
.and_then(|deque| deque.back())
.is_some_and(|batch| batch.is_appendable(batch_size, estimated_size))
}
pub fn has_undrained(&self) -> bool {
self.batches.values().any(|deque| !deque.is_empty())
}
pub fn ready(
&self,
metadata: &MetadataCache,
config: &ProducerConfig,
closing: bool,
) -> ReadyCheckResult {
let now = Instant::now();
let mut ready_by_leader = HashMap::<i32, Vec<TopicPartitionKey>>::new();
let mut unknown_leaders = Vec::new();
let mut next_ready_check_delay = Duration::from_millis(250);
for (key, deque) in &self.batches {
let Some(batch) = deque.front() else {
continue;
};
if let Some(retry_at) = batch.retry_at
&& now < retry_at
{
next_ready_check_delay =
next_ready_check_delay.min(retry_at.saturating_duration_since(now));
continue;
}
let linger_elapsed = now.saturating_duration_since(batch.created_at);
let sendable =
closing || batch.is_full(config.batch_size) || linger_elapsed >= config.linger;
if !sendable {
next_ready_check_delay =
next_ready_check_delay.min(config.linger.saturating_sub(linger_elapsed));
continue;
}
match metadata.leader_for(&key.topic, key.partition) {
Some(leader_id) => ready_by_leader
.entry(leader_id)
.or_default()
.push(key.clone()),
None => unknown_leaders.push(key.clone()),
}
}
ReadyCheckResult {
ready_by_leader,
unknown_leaders,
next_ready_check_delay,
}
}
pub fn drain_ready(
&mut self,
ready_by_leader: HashMap<i32, Vec<TopicPartitionKey>>,
) -> HashMap<i32, Vec<ProducerBatch>> {
let mut drained = HashMap::<i32, Vec<ProducerBatch>>::new();
for (leader_id, keys) in ready_by_leader {
for key in keys {
let Some(deque) = self.batches.get_mut(&key) else {
continue;
};
let Some(batch) = deque.pop_front() else {
continue;
};
if deque.is_empty() {
self.batches.remove(&key);
}
drained.entry(leader_id).or_default().push(batch);
}
}
drained
}
pub fn drain_unknown_leader(
&mut self,
partitions: Vec<TopicPartitionKey>,
) -> Vec<ProducerBatch> {
let mut drained = Vec::new();
let mut seen = HashSet::new();
for key in partitions {
if !seen.insert(key.clone()) {
continue;
}
let Some(deque) = self.batches.get_mut(&key) else {
continue;
};
let Some(batch) = deque.pop_front() else {
continue;
};
if deque.is_empty() {
self.batches.remove(&key);
}
drained.push(batch);
}
drained
}
pub fn reenqueue_front(&mut self, batch: ProducerBatch) {
self.batches
.entry(batch.key.clone())
.or_default()
.push_front(batch);
}
pub fn fail_all(&mut self, message: &str) {
let pending = std::mem::take(&mut self.batches);
for deque in pending.into_values() {
for batch in deque {
fail_batch(batch, message);
}
}
}
pub fn expire_batches(&mut self, delivery_timeout: Duration) -> Vec<ProducerBatch> {
let now = Instant::now();
let mut expired = Vec::new();
let keys = self.batches.keys().cloned().collect::<Vec<_>>();
let mut empty_keys = Vec::new();
for key in keys {
let Some(deque) = self.batches.get_mut(&key) else {
continue;
};
let mut retained = VecDeque::new();
while let Some(batch) = deque.pop_front() {
if now.saturating_duration_since(batch.created_at) >= delivery_timeout {
expired.push(batch);
} else {
retained.push_back(batch);
}
}
if retained.is_empty() {
empty_keys.push(key);
} else {
*deque = retained;
}
}
for key in empty_keys {
self.batches.remove(&key);
}
expired
}
pub fn next_expiration_in(&self, delivery_timeout: Duration) -> Option<Duration> {
let now = Instant::now();
self.batches
.values()
.flat_map(|deque| deque.iter())
.map(|batch| {
delivery_timeout.saturating_sub(now.saturating_duration_since(batch.created_at))
})
.min()
}
}
impl ProducerBatch {
    /// Starts a new batch for `key` containing the single `record` and its `reply`.
    ///
    /// The batch is stamped with the current time, has zero attempts, no retry deadline and
    /// no producer state.
    fn new(
        key: TopicPartitionKey,
        record: ProduceRecord,
        reply: oneshot::Sender<Result<ProduceAck>>,
        estimated_size: usize,
    ) -> Self {
        let created_at = Instant::now();
        Self {
            records: vec![record],
            replies: vec![reply],
            estimated_bytes: estimated_size,
            key,
            created_at,
            attempts: 0,
            retry_at: None,
            producer_state: None,
        }
    }

    /// Whether a record of `additional_bytes` may still be added to this batch.
    ///
    /// Only batches with no attempts and no retry deadline accept records, and only while
    /// the new estimated total stays within `batch_size`.
    fn is_appendable(&self, batch_size: usize, additional_bytes: usize) -> bool {
        let fresh = self.attempts == 0 && self.retry_at.is_none();
        fresh && self.estimated_bytes.saturating_add(additional_bytes) <= batch_size
    }

    /// Whether the estimated size has reached (or passed) `batch_size`.
    pub fn is_full(&self, batch_size: usize) -> bool {
        batch_size <= self.estimated_bytes
    }

    /// Adds `record` and its `reply`, growing the size estimate by `estimated_size`.
    fn push(
        &mut self,
        record: ProduceRecord,
        reply: oneshot::Sender<Result<ProduceAck>>,
        estimated_size: usize,
    ) {
        self.estimated_bytes = self.estimated_bytes.saturating_add(estimated_size);
        self.records.push(record);
        self.replies.push(reply);
    }
}
pub fn fail_batch(mut batch: ProducerBatch, message: &str) {
while let Some(reply) = batch.replies.pop() {
let _ = reply.send(Err(Error::Internal(anyhow::anyhow!(message.to_owned()))));
}
}
/// Rough size estimate of `record` in bytes, used to decide when a batch is full.
///
/// The estimate is built from:
/// - the key and value lengths (absent ones count as zero);
/// - for each header, its key length, its value length and a fixed 8-byte overhead;
/// - a flat 64-byte overhead per record.
fn estimate_record_size(record: &ProduceRecord) -> usize {
    const RECORD_OVERHEAD: usize = 64;
    const PER_HEADER_OVERHEAD: usize = 8;

    let mut total = RECORD_OVERHEAD;
    if let Some(key) = record.key.as_ref() {
        total += key.len();
    }
    if let Some(value) = record.value.as_ref() {
        total += value.len();
    }
    for header in record.headers.iter() {
        total += header.key.len() + PER_HEADER_OVERHEAD;
        if let Some(value) = header.value.as_ref() {
            total += value.len();
        }
    }
    total
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::oneshot;

    /// Batch size passed to every append in these tests.
    const BATCH_SIZE: usize = 1024;

    /// Key of the single partition (`topic-a`, partition 0) these tests write to.
    fn topic_a_p0() -> TopicPartitionKey {
        TopicPartitionKey::new("topic-a".to_owned(), 0)
    }

    /// Appends one record carrying `value` to `topic-a`/0; the reply receiver is dropped.
    fn append_to_topic_a(accumulator: &mut RecordAccumulator, value: &'static [u8]) {
        let (reply, _) = oneshot::channel();
        accumulator.append(ProduceRecord::new("topic-a", 0, value), reply, BATCH_SIZE);
    }

    /// Rewinds the creation time of the front batch queued for `key` by `age`.
    fn backdate_front_batch(
        accumulator: &mut RecordAccumulator,
        key: &TopicPartitionKey,
        age: Duration,
    ) {
        let batch = accumulator
            .batches
            .get_mut(key)
            .and_then(|deque| deque.front_mut())
            .expect("partition should have a queued batch");
        batch.created_at = Instant::now() - age;
    }

    #[test]
    fn appends_same_partition_into_open_batch() {
        let mut accumulator = RecordAccumulator::default();
        append_to_topic_a(&mut accumulator, b"first");
        append_to_topic_a(&mut accumulator, b"second");
        let deque = &accumulator.batches[&topic_a_p0()];
        assert_eq!(deque.len(), 1);
        assert_eq!(deque.front().map(|batch| batch.records.len()), Some(2));
    }

    #[test]
    fn linger_makes_batch_ready() {
        let mut accumulator = RecordAccumulator::default();
        append_to_topic_a(&mut accumulator, b"hello");
        let key = topic_a_p0();
        backdate_front_batch(&mut accumulator, &key, Duration::from_millis(20));
        let config = ProducerConfig::new("localhost:9092").with_linger(Duration::from_millis(10));
        let ready = accumulator.ready(&MetadataCache::default(), &config, false);
        assert_eq!(ready.unknown_leaders, vec![key]);
    }

    #[test]
    fn expires_batches_past_delivery_timeout() {
        let mut accumulator = RecordAccumulator::default();
        append_to_topic_a(&mut accumulator, b"hello");
        backdate_front_batch(&mut accumulator, &topic_a_p0(), Duration::from_secs(2));
        let expired = accumulator.expire_batches(Duration::from_secs(1));
        assert_eq!(expired.len(), 1);
        assert!(!accumulator.has_undrained());
    }
}