// kafkit-client 0.1.7
//
// Kafka 4.0+ pure Rust client.
// Documentation
use std::collections::{HashMap, HashSet, VecDeque};
use std::time::{Duration, Instant};

use tokio::sync::oneshot;

use crate::config::ProducerConfig;
use crate::metadata::MetadataCache;
use crate::types::{ProduceAck, ProduceRecord, TopicPartitionKey};
use crate::{Error, Result};

/// Per-partition queues of producer batches.
///
/// Each topic-partition maps to a FIFO of batches: new records are appended
/// to the back batch and batches are drained from the front.
#[derive(Default)]
pub struct RecordAccumulator {
    /// Pending batches, keyed by topic-partition, oldest batch first.
    batches: HashMap<TopicPartitionKey, VecDeque<ProducerBatch>>,
}

/// Outcome of [`RecordAccumulator::ready`].
pub struct ReadyCheckResult {
    /// Partitions whose front batch is sendable, grouped by the leader id
    /// reported by the metadata cache.
    pub ready_by_leader: HashMap<i32, Vec<TopicPartitionKey>>,
    /// Partitions whose front batch is sendable but for which the metadata
    /// cache reported no leader.
    pub unknown_leaders: Vec<TopicPartitionKey>,
    /// Shortest remaining wait (retry backoff or linger) among the front
    /// batches that are not sendable yet, capped at 250 ms.
    pub next_ready_check_delay: Duration,
}

/// A group of records bound for one topic-partition, together with the
/// channels used to report each record's outcome.
pub struct ProducerBatch {
    /// Topic-partition this batch belongs to.
    pub key: TopicPartitionKey,
    /// Records in append order.
    pub records: Vec<ProduceRecord>,
    /// One reply channel per record, pushed in the same order as `records`.
    pub replies: Vec<oneshot::Sender<Result<ProduceAck>>>,
    /// Creation time; linger and delivery-timeout checks are measured from it.
    pub created_at: Instant,
    /// Running sum of the estimated sizes of the contained records.
    pub estimated_bytes: usize,
    /// Starts at 0; a batch with a non-zero value no longer accepts appends.
    /// NOTE(review): presumably incremented by the sender on each send attempt —
    /// the increment is not in this file, confirm against the caller.
    pub attempts: usize,
    /// When set, the batch is not reported as ready before this instant and
    /// no longer accepts appends.
    pub retry_at: Option<Instant>,
    /// Starts as `None`; not read in this file.
    /// NOTE(review): presumably filled in by the sender for idempotent or
    /// transactional produce — confirm.
    pub producer_state: Option<BatchProducerState>,
}

/// Producer identity attached to a batch.
///
/// NOTE(review): the field names match Kafka's idempotent/transactional
/// producer protocol (producer id, epoch, base sequence); nothing in this
/// file reads them, so confirm the exact semantics against the sender.
#[derive(Debug, Clone, Copy)]
pub struct BatchProducerState {
    /// Producer id.
    pub producer_id: i64,
    /// Producer epoch.
    pub producer_epoch: i16,
    /// Presumably the sequence number of the first record in the batch — TODO confirm.
    pub base_sequence: i32,
    /// Presumably whether the batch is part of a transaction — TODO confirm.
    pub transactional: bool,
}

impl RecordAccumulator {
    pub fn append(
        &mut self,
        record: ProduceRecord,
        reply: oneshot::Sender<Result<ProduceAck>>,
        batch_size: usize,
    ) -> bool {
        let key = TopicPartitionKey::new(
            record.topic.clone(),
            record
                .partition
                .expect("producer sender must resolve the target partition before append"),
        );
        let estimated_size = estimate_record_size(&record);
        let deque = self.batches.entry(key.clone()).or_default();

        if let Some(batch) = deque.back_mut()
            && batch.is_appendable(batch_size, estimated_size)
        {
            batch.push(record, reply, estimated_size);
            return false;
        }

        deque.push_back(ProducerBatch::new(key, record, reply, estimated_size));
        true
    }

    pub fn can_append_to(
        &self,
        key: &TopicPartitionKey,
        batch_size: usize,
        estimated_size: usize,
    ) -> bool {
        self.batches
            .get(key)
            .and_then(|deque| deque.back())
            .is_some_and(|batch| batch.is_appendable(batch_size, estimated_size))
    }

    pub fn has_undrained(&self) -> bool {
        self.batches.values().any(|deque| !deque.is_empty())
    }

    pub fn ready(
        &self,
        metadata: &MetadataCache,
        config: &ProducerConfig,
        closing: bool,
    ) -> ReadyCheckResult {
        let now = Instant::now();
        let mut ready_by_leader = HashMap::<i32, Vec<TopicPartitionKey>>::new();
        let mut unknown_leaders = Vec::new();
        let mut next_ready_check_delay = Duration::from_millis(250);

        for (key, deque) in &self.batches {
            let Some(batch) = deque.front() else {
                continue;
            };

            if let Some(retry_at) = batch.retry_at
                && now < retry_at
            {
                next_ready_check_delay =
                    next_ready_check_delay.min(retry_at.saturating_duration_since(now));
                continue;
            }

            let linger_elapsed = now.saturating_duration_since(batch.created_at);
            let sendable =
                closing || batch.is_full(config.batch_size) || linger_elapsed >= config.linger;

            if !sendable {
                next_ready_check_delay =
                    next_ready_check_delay.min(config.linger.saturating_sub(linger_elapsed));
                continue;
            }

            match metadata.leader_for(&key.topic, key.partition) {
                Some(leader_id) => ready_by_leader
                    .entry(leader_id)
                    .or_default()
                    .push(key.clone()),
                None => unknown_leaders.push(key.clone()),
            }
        }

        ReadyCheckResult {
            ready_by_leader,
            unknown_leaders,
            next_ready_check_delay,
        }
    }

    pub fn drain_ready(
        &mut self,
        ready_by_leader: HashMap<i32, Vec<TopicPartitionKey>>,
    ) -> HashMap<i32, Vec<ProducerBatch>> {
        let mut drained = HashMap::<i32, Vec<ProducerBatch>>::new();

        for (leader_id, keys) in ready_by_leader {
            for key in keys {
                let Some(deque) = self.batches.get_mut(&key) else {
                    continue;
                };
                let Some(batch) = deque.pop_front() else {
                    continue;
                };
                if deque.is_empty() {
                    self.batches.remove(&key);
                }
                drained.entry(leader_id).or_default().push(batch);
            }
        }

        drained
    }

    pub fn drain_unknown_leader(
        &mut self,
        partitions: Vec<TopicPartitionKey>,
    ) -> Vec<ProducerBatch> {
        let mut drained = Vec::new();
        let mut seen = HashSet::new();

        for key in partitions {
            if !seen.insert(key.clone()) {
                continue;
            }

            let Some(deque) = self.batches.get_mut(&key) else {
                continue;
            };
            let Some(batch) = deque.pop_front() else {
                continue;
            };
            if deque.is_empty() {
                self.batches.remove(&key);
            }
            drained.push(batch);
        }

        drained
    }

    pub fn reenqueue_front(&mut self, batch: ProducerBatch) {
        self.batches
            .entry(batch.key.clone())
            .or_default()
            .push_front(batch);
    }

    pub fn fail_all(&mut self, message: &str) {
        let pending = std::mem::take(&mut self.batches);
        for deque in pending.into_values() {
            for batch in deque {
                fail_batch(batch, message);
            }
        }
    }

    pub fn expire_batches(&mut self, delivery_timeout: Duration) -> Vec<ProducerBatch> {
        let now = Instant::now();
        let mut expired = Vec::new();
        let keys = self.batches.keys().cloned().collect::<Vec<_>>();
        let mut empty_keys = Vec::new();

        for key in keys {
            let Some(deque) = self.batches.get_mut(&key) else {
                continue;
            };

            let mut retained = VecDeque::new();
            while let Some(batch) = deque.pop_front() {
                if now.saturating_duration_since(batch.created_at) >= delivery_timeout {
                    expired.push(batch);
                } else {
                    retained.push_back(batch);
                }
            }

            if retained.is_empty() {
                empty_keys.push(key);
            } else {
                *deque = retained;
            }
        }

        for key in empty_keys {
            self.batches.remove(&key);
        }

        expired
    }

    pub fn next_expiration_in(&self, delivery_timeout: Duration) -> Option<Duration> {
        let now = Instant::now();
        self.batches
            .values()
            .flat_map(|deque| deque.iter())
            .map(|batch| {
                delivery_timeout.saturating_sub(now.saturating_duration_since(batch.created_at))
            })
            .min()
    }
}

impl ProducerBatch {
    fn new(
        key: TopicPartitionKey,
        record: ProduceRecord,
        reply: oneshot::Sender<Result<ProduceAck>>,
        estimated_size: usize,
    ) -> Self {
        Self {
            key,
            records: vec![record],
            replies: vec![reply],
            created_at: Instant::now(),
            estimated_bytes: estimated_size,
            attempts: 0,
            retry_at: None,
            producer_state: None,
        }
    }

    fn is_appendable(&self, batch_size: usize, additional_bytes: usize) -> bool {
        self.attempts == 0
            && self.retry_at.is_none()
            && self.estimated_bytes.saturating_add(additional_bytes) <= batch_size
    }

    pub fn is_full(&self, batch_size: usize) -> bool {
        self.estimated_bytes >= batch_size
    }

    fn push(
        &mut self,
        record: ProduceRecord,
        reply: oneshot::Sender<Result<ProduceAck>>,
        estimated_size: usize,
    ) {
        self.records.push(record);
        self.replies.push(reply);
        self.estimated_bytes = self.estimated_bytes.saturating_add(estimated_size);
    }
}

pub fn fail_batch(mut batch: ProducerBatch, message: &str) {
    while let Some(reply) = batch.replies.pop() {
        let _ = reply.send(Err(Error::Internal(anyhow::anyhow!(message.to_owned()))));
    }
}

/// Estimates how many bytes `record` contributes to a batch.
///
/// The estimate is the key length plus the value length (0 when absent), plus
/// for each header its key length, value length and 8 extra bytes, plus a
/// fixed 64 bytes per record.
/// NOTE(review): the 8 and 64 presumably approximate encoding overhead — confirm.
fn estimate_record_size(record: &ProduceRecord) -> usize {
    const PER_HEADER_OVERHEAD: usize = 8;
    const PER_RECORD_OVERHEAD: usize = 64;

    let mut header_bytes = 0usize;
    for header in record.headers.iter() {
        let header_value_len = header.value.as_ref().map_or(0, |value| value.len());
        header_bytes += header.key.len() + header_value_len + PER_HEADER_OVERHEAD;
    }

    let key_bytes = record.key.as_ref().map_or(0, |key| key.len());
    let value_bytes = record.value.as_ref().map_or(0, |value| value.len());

    key_bytes + value_bytes + header_bytes + PER_RECORD_OVERHEAD
}

#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::oneshot;

    /// Batch size limit used by every scenario.
    const BATCH_SIZE: usize = 1024;

    /// Key of the only partition the scenarios write to.
    fn topic_a_key() -> TopicPartitionKey {
        TopicPartitionKey::new("topic-a".to_owned(), 0)
    }

    /// Appends `payload` to partition 0 of `topic-a`; the reply receiver is
    /// dropped right away because no scenario waits for an ack.
    fn append_payload(accumulator: &mut RecordAccumulator, payload: &'static [u8]) {
        let (reply, _) = oneshot::channel();
        accumulator.append(ProduceRecord::new("topic-a", 0, payload), reply, BATCH_SIZE);
    }

    /// Rewrites the creation time of the front batch of `key` so that it
    /// appears to be `age` old.
    fn age_front_batch(
        accumulator: &mut RecordAccumulator,
        key: &TopicPartitionKey,
        age: Duration,
    ) {
        let front = accumulator
            .batches
            .get_mut(key)
            .unwrap()
            .front_mut()
            .unwrap();
        front.created_at = Instant::now() - age;
    }

    #[test]
    fn appends_same_partition_into_open_batch() {
        let mut accumulator = RecordAccumulator::default();
        append_payload(&mut accumulator, b"first");
        append_payload(&mut accumulator, b"second");

        let queued = accumulator.batches.get(&topic_a_key()).unwrap();
        assert_eq!(queued.len(), 1);
        assert_eq!(queued.front().unwrap().records.len(), 2);
    }

    #[test]
    fn linger_makes_batch_ready() {
        let mut accumulator = RecordAccumulator::default();
        append_payload(&mut accumulator, b"hello");

        let key = topic_a_key();
        age_front_batch(&mut accumulator, &key, Duration::from_millis(20));

        let config = ProducerConfig::new("localhost:9092").with_linger(Duration::from_millis(10));
        let ready = accumulator.ready(&MetadataCache::default(), &config, false);
        assert_eq!(ready.unknown_leaders, vec![key]);
    }

    #[test]
    fn expires_batches_past_delivery_timeout() {
        let mut accumulator = RecordAccumulator::default();
        append_payload(&mut accumulator, b"hello");

        let key = topic_a_key();
        age_front_batch(&mut accumulator, &key, Duration::from_secs(2));

        let expired = accumulator.expire_batches(Duration::from_secs(1));
        assert_eq!(expired.len(), 1);
        assert!(!accumulator.has_undrained());
    }
}