use std::time::Instant;
use super::record::ProducerRecord;
use crate::PartitionId;
use crate::protocol::{Compression, RecordBatch, RecordBatchBuilder};
#[derive(Debug)]
pub struct ProducerBatch {
pub topic: String,
pub partition: PartitionId,
records: Vec<ProducerRecord>,
tracked_count: usize,
size: usize,
max_size: usize,
compression: Compression,
created_at: Instant,
}
impl ProducerBatch {
    /// Creates an empty batch for `topic`/`partition`.
    ///
    /// `max_size` is the budget that [`try_add`](Self::try_add) and
    /// [`would_fit`](Self::would_fit) enforce. It is measured in the units
    /// returned by `ProducerRecord::estimated_size`. `compression` is applied
    /// by [`build`](Self::build). The batch's [`age`](Self::age) is measured
    /// from this call.
    pub fn new(
        topic: String,
        partition: PartitionId,
        max_size: usize,
        compression: Compression,
    ) -> Self {
        Self {
            topic,
            partition,
            records: Vec::new(),
            tracked_count: 0,
            size: 0,
            max_size,
            compression,
            created_at: Instant::now(),
        }
    }

    /// Stores `record` in the batch if it fits within the size budget.
    ///
    /// An empty batch always accepts the record, even one that alone exceeds
    /// `max_size`. Without this, such a record could never be placed in any
    /// batch.
    ///
    /// # Errors
    ///
    /// Returns the record unchanged when the batch is non-empty and adding it
    /// would push the accumulated size past `max_size`. The caller can then
    /// move it into another batch without cloning.
    #[inline]
    // The Err payload is the caller's own record, handed back to avoid a clone.
    #[allow(clippy::result_large_err)]
    pub fn try_add(&mut self, record: ProducerRecord) -> Result<(), ProducerRecord> {
        let record_size = record.estimated_size();
        if !self.would_fit(record_size) {
            return Err(record);
        }
        self.track(record_size);
        self.records.push(record);
        Ok(())
    }

    /// Reports whether a record of `record_size` would be accepted.
    ///
    /// Uses the same rule as [`try_add`](Self::try_add): always `true` for an
    /// empty batch, otherwise `true` only if the new total stays within
    /// `max_size`.
    #[inline]
    pub(crate) fn would_fit(&self, record_size: usize) -> bool {
        // `saturating_add` stops an oversized hint from wrapping around in
        // release builds, which would misreport it as fitting.
        self.is_empty() || self.size.saturating_add(record_size) <= self.max_size
    }

    /// Accounts for a record of `record_size` without storing the record.
    ///
    /// The record is counted by [`len`](Self::len) and [`size`](Self::size).
    /// It is not encoded by [`build`](Self::build) and not returned by
    /// [`drain`](Self::drain). Only records added via `try_add` are stored.
    #[inline]
    pub(crate) fn track(&mut self, record_size: usize) {
        self.size = self.size.saturating_add(record_size);
        self.tracked_count += 1;
    }

    /// Returns `true` when no record has been accounted for.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.tracked_count == 0
    }

    /// Returns the number of accounted-for records, including tracked-only ones.
    #[inline]
    pub fn len(&self) -> usize {
        self.tracked_count
    }

    /// Returns the accumulated estimated size of all accounted-for records.
    #[inline]
    pub fn size(&self) -> usize {
        self.size
    }

    /// Returns `true` once the accumulated size has reached `max_size`.
    #[inline]
    pub fn is_full(&self) -> bool {
        self.size >= self.max_size
    }

    /// Returns the time elapsed since the batch was created.
    #[inline]
    pub fn age(&self) -> std::time::Duration {
        self.created_at.elapsed()
    }

    /// Encodes the stored records into a [`RecordBatch`] using the batch's
    /// compression.
    ///
    /// Records are cloned, so the batch itself is left unchanged. Every value
    /// is passed as `Some`. Records that carry headers go through
    /// `add_record_with_headers`; all others go through `add_record`.
    pub fn build(&self) -> RecordBatch {
        let mut builder = RecordBatchBuilder::new().compression(self.compression);
        for record in &self.records {
            if record.headers.is_empty() {
                builder = builder.add_record(record.key.clone(), Some(record.value.clone()));
            } else {
                let hdrs: Vec<(String, Vec<u8>)> = record
                    .headers
                    .iter()
                    .map(|(k, v)| (k.clone(), v.clone()))
                    .collect();
                builder = builder.add_record_with_headers(
                    record.key.clone(),
                    Some(record.value.clone()),
                    hdrs,
                );
            }
        }
        builder.build()
    }

    /// Removes and returns all stored records, and resets the size and count
    /// bookkeeping to empty.
    ///
    /// Entries registered only through [`track`](Self::track) are cleared from
    /// the bookkeeping but are not returned, because they were never stored.
    pub fn drain(&mut self) -> Vec<ProducerRecord> {
        self.size = 0;
        self.tracked_count = 0;
        // Hand over the existing allocation instead of moving every element
        // into a freshly allocated Vec (clippy::drain_collect).
        std::mem::take(&mut self.records)
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    /// An uncompressed, empty batch for topic "test", partition 0.
    fn empty_batch(max_size: usize) -> ProducerBatch {
        ProducerBatch::new(String::from("test"), 0, max_size, Compression::None)
    }

    /// A record for topic "test" whose value is `len` zero bytes.
    fn zero_record(len: usize) -> ProducerRecord {
        ProducerRecord::new("test", vec![0u8; len])
    }

    #[test]
    fn test_batch_new() {
        let fresh = empty_batch(1024);
        assert!(fresh.is_empty());
        assert_eq!(fresh.len(), 0);
        assert_eq!(fresh.size(), 0);
    }

    #[test]
    fn test_batch_try_add() {
        let mut b = empty_batch(1024);
        let outcome = b.try_add(ProducerRecord::new("test", b"hello".to_vec()));
        assert!(outcome.is_ok());
        assert!(!b.is_empty());
        assert_eq!(b.len(), 1);
        assert!(b.size() > 0);
    }

    #[test]
    fn test_batch_full() {
        let mut b = empty_batch(200);
        // The first two 20-byte-value records are accepted; the third is refused.
        for _ in 0..2 {
            assert!(b.try_add(zero_record(20)).is_ok());
        }
        assert!(b.try_add(zero_record(20)).is_err());
    }

    #[test]
    fn test_batch_drain() {
        let mut b = empty_batch(1024);
        for payload in [b"hello", b"world"] {
            let _ = b.try_add(ProducerRecord::new("test", payload.to_vec()));
        }
        let drained = b.drain();
        assert_eq!(drained.len(), 2);
        assert!(b.is_empty());
    }

    #[test]
    fn test_batch_build() {
        let mut b = empty_batch(1024);
        let keyed = ProducerRecord::new("test", b"value".to_vec()).with_key(b"key".to_vec());
        let _ = b.try_add(keyed);
        let built = b.build();
        assert_eq!(built.records.len(), 1);
    }

    #[test]
    fn test_batch_build_preserves_headers() {
        let mut b = empty_batch(4096);
        let _ = b.try_add(
            ProducerRecord::new("test", b"value".to_vec())
                .with_key(b"key".to_vec())
                .with_header("trace-id", b"abc123")
                .with_header("content-type", b"application/json"),
        );
        let built = b.build();
        assert_eq!(built.records.len(), 1);
        let headers = &built.records[0].headers;
        assert_eq!(
            headers.len(),
            2,
            "Headers should be preserved in built batch"
        );
        assert_eq!(headers[0].key, "trace-id");
        assert_eq!(headers[1].key, "content-type");
    }

    #[test]
    fn test_would_fit_and_track() {
        let mut b = empty_batch(200);
        let sz = zero_record(20).estimated_size();
        assert!(b.would_fit(sz));
        b.track(sz);
        assert_eq!(b.len(), 1);
        assert_eq!(b.size(), sz);
        assert!(!b.is_empty());
        assert!(b.would_fit(sz));
        b.track(sz);
        assert_eq!(b.len(), 2);
        assert!(!b.would_fit(sz));
    }

    #[test]
    fn test_would_fit_first_record_always_fits() {
        let mut b = empty_batch(10);
        let oversized = 100;
        assert!(b.would_fit(oversized));
        b.track(oversized);
        assert!(b.is_full());
        assert!(!b.would_fit(1));
    }
}