kafkit-client 0.1.2

Kafka 4.0+ pure Rust client.
Documentation
use std::collections::BTreeMap;
use std::time::Duration;

use anyhow::Result;
use bytes::{Bytes, BytesMut};
use kafka_protocol::indexmap::IndexMap;
use kafka_protocol::messages::produce_request::{PartitionProduceData, TopicProduceData};
use kafka_protocol::messages::{ProduceRequest, TransactionalId};
use kafka_protocol::protocol::StrBytes;
use kafka_protocol::records::{
    NO_PARTITION_LEADER_EPOCH, NO_PRODUCER_EPOCH, NO_PRODUCER_ID, Record, RecordBatchEncoder,
    RecordEncodeOptions, TimestampType,
};

use super::accumulator::{BatchProducerState, ProducerBatch};
use crate::config::ProducerCompression;
use crate::network::{duration_to_i32_ms, now_unix_ms};
use crate::types::ProduceRecord;

/// Chooses the `acks` value to put on an outgoing produce request.
///
/// Transactional producers are always forced to `-1` (Kafka's `acks=all`); every other
/// producer sends the configured value unchanged.
pub fn acks_for_request(configured_acks: i16, transactional: bool) -> i16 {
    if transactional {
        return -1;
    }
    configured_acks
}

pub fn build_produce_request(
    batches: &[ProducerBatch],
    acks: i16,
    request_timeout: Duration,
    compression: ProducerCompression,
    transactional_id: Option<&str>,
) -> Result<ProduceRequest> {
    let mut topics = BTreeMap::<String, Vec<PartitionProduceData>>::new();
    for batch in batches {
        let encoded_records =
            encode_record_batch(&batch.records, compression, batch.producer_state)?;
        topics.entry(batch.key.topic.clone()).or_default().push(
            PartitionProduceData::default()
                .with_index(batch.key.partition)
                .with_records(Some(encoded_records)),
        );
    }

    Ok(ProduceRequest::default()
        .with_transactional_id(
            transactional_id.map(|value| TransactionalId(StrBytes::from_string(value.to_owned()))),
        )
        .with_acks(acks)
        .with_timeout_ms(duration_to_i32_ms(request_timeout)?)
        .with_topic_data(
            topics
                .into_iter()
                .map(|(topic, partition_data)| {
                    TopicProduceData::default()
                        .with_name(StrBytes::from_string(topic).into())
                        .with_partition_data(partition_data)
                })
                .collect(),
        ))
}

fn encode_record_batch(
    records: &[ProduceRecord],
    compression: ProducerCompression,
    producer_state: Option<BatchProducerState>,
) -> Result<Bytes> {
    let mut buffer = BytesMut::new();
    let default_timestamp = now_unix_ms()?;
    let kafka_records = records
        .iter()
        .enumerate()
        .map(|(index, record)| {
            let offset = i64::try_from(index).unwrap_or(0);
            let base_sequence = producer_state.map(|state| state.base_sequence).unwrap_or(0);
            let sequence = base_sequence.saturating_add(i32::try_from(index).unwrap_or(0));
            Record {
                transactional: producer_state
                    .map(|state| state.transactional)
                    .unwrap_or(false),
                control: false,
                partition_leader_epoch: NO_PARTITION_LEADER_EPOCH,
                producer_id: producer_state
                    .map(|state| state.producer_id)
                    .unwrap_or(NO_PRODUCER_ID),
                producer_epoch: producer_state
                    .map(|state| state.producer_epoch)
                    .unwrap_or(NO_PRODUCER_EPOCH),
                timestamp_type: TimestampType::Creation,
                offset,
                sequence,
                timestamp: record.timestamp.unwrap_or(default_timestamp),
                key: record.key.clone(),
                value: record.value.clone(),
                headers: record_headers(&record.headers),
            }
        })
        .collect::<Vec<_>>();

    let options = RecordEncodeOptions {
        version: 2,
        compression: compression.into(),
    };
    RecordBatchEncoder::encode(&mut buffer, kafka_records.iter(), &options)?;
    Ok(buffer.freeze())
}

/// Converts client-side record headers into the ordered header map used by `kafka_protocol`.
///
/// Header order is preserved. Because the target is a map, a repeated key keeps its first
/// position but ends up with the last value supplied for it. Kafka itself permits duplicate
/// header keys, so earlier duplicates are lost here.
fn record_headers(headers: &[crate::types::RecordHeader]) -> IndexMap<StrBytes, Option<Bytes>> {
    headers
        .iter()
        .map(|header| (StrBytes::from_string(header.key.clone()), header.value.clone()))
        .collect()
}

#[cfg(test)]
mod tests {
    use std::time::Instant;

    use super::*;
    use kafka_protocol::records::RecordBatchDecoder;

    use crate::types::TopicPartitionKey;

    /// Decodes every record batch in `encoded` and flattens them into a single record list.
    fn decode_records(encoded: &Bytes) -> Vec<Record> {
        let mut buffer = encoded.clone();
        RecordBatchDecoder::decode_all(&mut buffer)
            .unwrap()
            .into_iter()
            .flat_map(|batch| batch.records)
            .collect()
    }

    /// Two plain records on `topic-a` partition 0, shared by several tests.
    fn hello_world_records() -> Vec<ProduceRecord> {
        vec![
            ProduceRecord::new("topic-a", 0, b"hello".as_slice()),
            ProduceRecord::new("topic-a", 0, b"world".as_slice()),
        ]
    }

    #[test]
    fn producer_compression_builds_non_empty_batch() {
        let encoded =
            encode_record_batch(&hello_world_records(), ProducerCompression::Gzip, None).unwrap();
        assert!(!encoded.is_empty());
    }

    #[test]
    fn transactional_batches_encode_transactional_records_with_sequences() {
        let state = BatchProducerState {
            producer_id: 42,
            producer_epoch: 7,
            base_sequence: 11,
            transactional: true,
        };
        let encoded =
            encode_record_batch(&hello_world_records(), ProducerCompression::None, Some(state))
                .unwrap();

        let decoded = decode_records(&encoded);
        assert_eq!(decoded.len(), 2);
        for record in &decoded {
            assert!(record.transactional);
            assert_eq!(record.producer_id, 42);
            assert_eq!(record.producer_epoch, 7);
        }
        let sequences: Vec<_> = decoded.iter().map(|record| record.sequence).collect();
        assert_eq!(sequences, [11, 12]);
    }

    #[test]
    fn producer_records_encode_timestamps_and_headers() {
        let record = ProduceRecord::new("topic-a", 0, b"hello".as_slice())
            .with_timestamp(1_700_000_000_123)
            .with_header(crate::types::RecordHeader::new(
                "trace-id",
                b"abc-123".as_slice(),
            ))
            .with_header(crate::types::RecordHeader::null("nullable"));

        let encoded = encode_record_batch(&[record], ProducerCompression::None, None).unwrap();
        let decoded = decode_records(&encoded);

        assert_eq!(decoded.len(), 1);
        let only = &decoded[0];
        assert_eq!(only.timestamp, 1_700_000_000_123);

        let trace_id = only
            .headers
            .get(&StrBytes::from_static_str("trace-id"))
            .and_then(|value| value.as_deref());
        assert_eq!(trace_id, Some(&b"abc-123"[..]));

        // A null header must survive as a present key whose value is `None`.
        assert!(matches!(
            only.headers.get(&StrBytes::from_static_str("nullable")),
            Some(None)
        ));
    }

    #[test]
    fn producer_records_encode_null_values_for_tombstones() {
        let tombstone = ProduceRecord::tombstone("topic-a", 0).with_key("delete-me");

        let encoded = encode_record_batch(&[tombstone], ProducerCompression::None, None).unwrap();
        let decoded = decode_records(&encoded);

        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0].key.as_deref(), Some(&b"delete-me"[..]));
        assert!(decoded[0].value.is_none());
    }

    #[test]
    fn produce_request_includes_transactional_id() {
        let batch = ProducerBatch {
            key: TopicPartitionKey::new("topic-a".to_owned(), 0),
            records: vec![ProduceRecord::new("topic-a", 0, b"hello".as_slice())],
            replies: Vec::new(),
            created_at: Instant::now(),
            estimated_bytes: 0,
            attempts: 0,
            retry_at: None,
            producer_state: Some(BatchProducerState {
                producer_id: 9,
                producer_epoch: 3,
                base_sequence: 0,
                transactional: true,
            }),
        };

        let request = build_produce_request(
            &[batch],
            -1,
            Duration::from_secs(5),
            ProducerCompression::None,
            Some("tx-123"),
        )
        .unwrap();

        let transactional_id = request
            .transactional_id
            .as_ref()
            .map(|value| value.0.to_string());
        assert_eq!(transactional_id.as_deref(), Some("tx-123"));
    }
}