use bytes::{Buf, BufMut, Bytes, BytesMut};
use zerocopy::FromBytes as _;
use crate::primitives::varint::{
get_varint, get_varlong, put_varint, put_varlong, varint_len, varlong_len,
};
use crate::records::RecordsError;
use crate::records::crc::{crc32c, crc32c_append};
use crate::records::header::{Attributes, HEADER_LEN};
/// One header attached to a [`Record`]: a UTF-8 key with an optional value.
///
/// When decoded in this file the key must carry a non-negative length and be
/// valid UTF-8, while a negative value length is read back as `None`.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct RecordHeader {
    /// Header name. Encoded as a varint byte length followed by the bytes.
    pub key: String,
    /// Header payload; `None` is encoded as a length of `-1`.
    pub value: Option<Bytes>,
}
/// A single record as stored inside a [`RecordBatch`] body.
///
/// Field order here matches the order in which the fields are written to and
/// read from the wire by `Record::encode` / `Record::decode`.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Record {
    /// Per-record attribute byte, written verbatim as a single `i8`.
    pub attributes: i8,
    /// Timestamp delta, written as a varlong.
    /// Presumably relative to the batch's `base_timestamp` — confirm with callers.
    pub timestamp_delta: i64,
    /// Offset delta, written as a varint.
    /// Presumably relative to the batch's `base_offset` — confirm with callers.
    pub offset_delta: i32,
    /// Optional record key; `None` is encoded as a length of `-1`.
    pub key: Option<Bytes>,
    /// Optional record value; `None` is encoded as a length of `-1`.
    pub value: Option<Bytes>,
    /// Headers, written after a varint count in the order stored here.
    pub headers: Vec<RecordHeader>,
}
/// A decoded magic-2 record batch: the fixed header fields plus the records
/// carried in the batch body.
///
/// `batch_length`, `magic`, `crc` and `records_count` are not stored; they are
/// derived while encoding and validated while decoding.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RecordBatch {
    /// First header field on the wire (big-endian `i64`).
    pub base_offset: i64,
    /// Written right after `batch_length`; not covered by the CRC.
    pub partition_leader_epoch: i32,
    /// Batch attribute bits; the compression codec is read from here.
    pub attributes: Attributes,
    /// CRC-covered header field, stored and written verbatim.
    pub last_offset_delta: i32,
    /// CRC-covered header field, stored and written verbatim.
    pub base_timestamp: i64,
    /// CRC-covered header field, stored and written verbatim.
    pub max_timestamp: i64,
    /// CRC-covered header field; `-1` in the default batch.
    pub producer_id: i64,
    /// CRC-covered header field; `-1` in the default batch.
    pub producer_epoch: i16,
    /// CRC-covered header field; `-1` in the default batch.
    pub base_sequence: i32,
    /// Records carried in the batch body, in wire order.
    pub records: Vec<Record>,
}
impl Default for RecordBatch {
    /// Builds an empty batch: no records, default attributes, zeroed offsets
    /// and timestamps, and `-1` for the three producer-related fields.
    fn default() -> Self {
        Self {
            // Producer fields start at -1 rather than 0
            // (presumably the "no producer" sentinel — confirm against the protocol).
            producer_id: -1,
            producer_epoch: -1,
            base_sequence: -1,
            // Everything else starts out zeroed / empty.
            base_offset: 0,
            partition_leader_epoch: 0,
            last_offset_delta: 0,
            base_timestamp: 0,
            max_timestamp: 0,
            attributes: Attributes::default(),
            records: Vec::new(),
        }
    }
}
impl Record {
pub fn encode<B: BufMut>(&self, buf: &mut B) -> Result<(), RecordsError> {
let body_len = self.body_len();
put_varlong(
buf,
i64::try_from(body_len)
.map_err(|_| RecordsError::RecordParse("record body length overflow".into()))?,
);
self.encode_body(buf)
}
pub fn encoded_len(&self) -> usize {
let body = self.body_len();
#[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
let body_i64 = body as i64;
varlong_len(body_i64) + body
}
fn body_len(&self) -> usize {
let mut n = 1; n += varlong_len(self.timestamp_delta);
n += varint_len(self.offset_delta);
n += match &self.key {
None => varint_len(-1),
Some(k) => varint_len(i32::try_from(k.len()).unwrap_or(i32::MAX)) + k.len(),
};
n += match &self.value {
None => varint_len(-1),
Some(v) => varint_len(i32::try_from(v.len()).unwrap_or(i32::MAX)) + v.len(),
};
n += varint_len(i32::try_from(self.headers.len()).unwrap_or(i32::MAX));
for h in &self.headers {
let key_bytes = h.key.as_bytes();
n += varint_len(i32::try_from(key_bytes.len()).unwrap_or(i32::MAX)) + key_bytes.len();
n += match &h.value {
None => varint_len(-1),
Some(v) => varint_len(i32::try_from(v.len()).unwrap_or(i32::MAX)) + v.len(),
};
}
n
}
fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<(), RecordsError> {
buf.put_i8(self.attributes);
put_varlong(buf, self.timestamp_delta);
put_varint(buf, self.offset_delta);
match &self.key {
None => put_varint(buf, -1),
Some(k) => {
put_varint(
buf,
i32::try_from(k.len()).map_err(|_| {
RecordsError::RecordParse("record key length overflow".into())
})?,
);
buf.put_slice(k);
}
}
match &self.value {
None => put_varint(buf, -1),
Some(v) => {
put_varint(
buf,
i32::try_from(v.len()).map_err(|_| {
RecordsError::RecordParse("record value length overflow".into())
})?,
);
buf.put_slice(v);
}
}
put_varint(
buf,
i32::try_from(self.headers.len())
.map_err(|_| RecordsError::RecordParse("record header count overflow".into()))?,
);
for h in &self.headers {
let key_bytes = h.key.as_bytes();
put_varint(
buf,
i32::try_from(key_bytes.len())
.map_err(|_| RecordsError::RecordParse("header key length overflow".into()))?,
);
buf.put_slice(key_bytes);
match &h.value {
None => put_varint(buf, -1),
Some(v) => {
put_varint(
buf,
i32::try_from(v.len()).map_err(|_| {
RecordsError::RecordParse("header value length overflow".into())
})?,
);
buf.put_slice(v);
}
}
}
Ok(())
}
pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, RecordsError> {
let body_len = get_varlong(buf)
.map_err(|e| RecordsError::RecordParse(format!("record length: {e}")))?;
let body_len = usize::try_from(body_len).map_err(|_| {
RecordsError::RecordParse(format!("record length negative or too large: {body_len}"))
})?;
if buf.remaining() < body_len {
return Err(RecordsError::BodyTooShort {
needed: body_len - buf.remaining(),
});
}
let mut body = buf.take(body_len);
let r = Self::decode_body(&mut body)?;
if body.has_remaining() {
return Err(RecordsError::RecordParse(format!(
"trailing bytes inside record (left={})",
body.remaining()
)));
}
Ok(r)
}
fn decode_body<B: Buf>(buf: &mut B) -> Result<Self, RecordsError> {
if buf.remaining() == 0 {
return Err(RecordsError::RecordParse("record body empty".into()));
}
let attributes = buf.get_i8();
let timestamp_delta = get_varlong(buf)
.map_err(|e| RecordsError::RecordParse(format!("timestamp_delta: {e}")))?;
let offset_delta =
get_varint(buf).map_err(|e| RecordsError::RecordParse(format!("offset_delta: {e}")))?;
let key = decode_nullable_bytes(buf, "key")?;
let value = decode_nullable_bytes(buf, "value")?;
let header_count =
get_varint(buf).map_err(|e| RecordsError::RecordParse(format!("header_count: {e}")))?;
if header_count < 0 {
return Err(RecordsError::RecordParse(format!(
"negative header count {header_count}"
)));
}
#[allow(clippy::cast_sign_loss)] let header_count_usize = header_count as usize;
let mut headers = Vec::with_capacity(header_count_usize.min(buf.remaining()));
for i in 0..header_count {
headers.push(
decode_record_header(buf)
.map_err(|e| RecordsError::RecordParse(format!("header[{i}]: {e}")))?,
);
}
Ok(Self {
attributes,
timestamp_delta,
offset_delta,
key,
value,
headers,
})
}
}
/// Reads a varint-length-prefixed byte field, treating any negative length as
/// null.
///
/// `label` names the field in error messages (e.g. `"key"`).
///
/// # Errors
/// * [`RecordsError::RecordParse`] when the length varint cannot be read.
/// * [`RecordsError::BodyTooShort`] when fewer than the announced number of
///   bytes remain in `buf`.
fn decode_nullable_bytes<B: Buf>(buf: &mut B, label: &str) -> Result<Option<Bytes>, RecordsError> {
    let len =
        get_varint(buf).map_err(|e| RecordsError::RecordParse(format!("{label} length: {e}")))?;
    if len < 0 {
        return Ok(None);
    }
    #[allow(clippy::cast_sign_loss)] // non-negative: checked just above
    let n = len as usize;

    let available = buf.remaining();
    if available < n {
        return Err(RecordsError::BodyTooShort {
            needed: n - available,
        });
    }
    // `copy_to_bytes` skips the zero-fill of a scratch Vec and lets buffer
    // types that support it hand out a shared slice instead of copying.
    Ok(Some(buf.copy_to_bytes(n)))
}
/// Reads one record header: a non-nullable UTF-8 key followed by a nullable
/// value.
///
/// Errors are returned as plain strings; the caller wraps them with the
/// header index.
///
/// # Errors
/// Fails when either length varint cannot be read, the key length is
/// negative, the key is not valid UTF-8, or `buf` holds fewer bytes than a
/// length announces.
fn decode_record_header<B: Buf>(buf: &mut B) -> Result<RecordHeader, String> {
    let key_len = get_varint(buf).map_err(|e| format!("key length: {e}"))?;
    if key_len < 0 {
        return Err(format!("non-nullable key has negative length {key_len}"));
    }
    #[allow(clippy::cast_sign_loss)] // non-negative: checked just above
    let key_len = key_len as usize;
    let available = buf.remaining();
    if available < key_len {
        return Err(format!("key truncated (need {} more)", key_len - available));
    }
    // The key needs an owned Vec because `String::from_utf8` takes ownership of it.
    let mut key_bytes = vec![0u8; key_len];
    buf.copy_to_slice(&mut key_bytes);
    let key = String::from_utf8(key_bytes).map_err(|e| format!("key utf-8: {e}"))?;

    let value_len = get_varint(buf).map_err(|e| format!("value length: {e}"))?;
    let value = if value_len < 0 {
        // Any negative length denotes a null value.
        None
    } else {
        #[allow(clippy::cast_sign_loss)] // non-negative in this branch
        let value_len = value_len as usize;
        let available = buf.remaining();
        if available < value_len {
            return Err(format!(
                "value truncated (need {} more)",
                value_len - available
            ));
        }
        // Avoids zero-filling a scratch Vec before overwriting it.
        Some(buf.copy_to_bytes(value_len))
    };

    Ok(RecordHeader { key, value })
}
#[cfg(test)]
mod record_tests {
    use super::*;
    use assert2::assert;
    use bytes::BytesMut;

    /// Record with every field at its zero / null value.
    fn fixture_minimal_record() -> Record {
        Record::default()
    }

    /// Record with key, value, one header with a value and one with a null value.
    fn fixture_keyed_record() -> Record {
        let trace = RecordHeader {
            key: "trace-id".to_string(),
            value: Some(Bytes::from_static(b"abc")),
        };
        let null_valued = RecordHeader {
            key: "null-val".to_string(),
            value: None,
        };
        Record {
            timestamp_delta: 17,
            offset_delta: 2,
            key: Some(Bytes::from_static(b"the-key")),
            value: Some(Bytes::from_static(b"hello kafka")),
            headers: vec![trace, null_valued],
            ..Record::default()
        }
    }

    /// Record whose key and value need multi-byte length prefixes.
    fn fixture_large_payload_record() -> Record {
        Record {
            timestamp_delta: 1_000_000,
            offset_delta: 999,
            key: Some(Bytes::from(vec![b'k'; 128])),
            value: Some(Bytes::from(vec![b'v'; 4096])),
            ..Record::default()
        }
    }

    /// Encodes `record`, checks the predicted length, decodes it again and
    /// expects the original back with no input left over.
    fn assert_roundtrip(record: Record) {
        let mut encoded = BytesMut::new();
        record.encode(&mut encoded).unwrap();
        assert!(encoded.len() == record.encoded_len(), "predicted len mismatch");

        let mut rest: &[u8] = &encoded[..];
        let decoded = Record::decode(&mut rest).unwrap();
        assert!(decoded == record);
        assert!(rest.is_empty(), "trailing bytes after decode");
    }

    /// Body bytes up to (not including) the header count: attributes,
    /// timestamp delta, offset delta, null key, null value.
    fn body_before_header_count() -> BytesMut {
        let mut body = BytesMut::new();
        body.put_i8(0);
        put_varlong(&mut body, 0);
        put_varint(&mut body, 0);
        put_varint(&mut body, -1);
        put_varint(&mut body, -1);
        body
    }

    /// Prepends the varlong length prefix to a record body.
    fn with_length_prefix(body: &[u8]) -> BytesMut {
        let mut framed = BytesMut::new();
        put_varlong(&mut framed, i64::try_from(body.len()).unwrap());
        framed.extend_from_slice(body);
        framed
    }

    #[test]
    fn minimal() {
        assert_roundtrip(fixture_minimal_record());
    }

    #[test]
    fn keyed_with_headers() {
        assert_roundtrip(fixture_keyed_record());
    }

    #[test]
    fn large_payload() {
        assert_roundtrip(fixture_large_payload_record());
    }

    #[test]
    fn decode_rejects_negative_header_count() {
        let mut body = body_before_header_count();
        put_varint(&mut body, -1);
        let framed = with_length_prefix(&body);

        let mut input: &[u8] = &framed[..];
        match Record::decode(&mut input) {
            Err(RecordsError::RecordParse(msg)) => {
                assert!(msg.contains("negative header count"), "got: {msg}");
            }
            other => panic!("expected RecordParse, got {other:?}"),
        }
    }

    #[test]
    fn decode_huge_header_count_does_not_overallocate() {
        // The count announces a billion headers but no header bytes follow.
        let mut body = body_before_header_count();
        put_varint(&mut body, 1_000_000_000);
        let framed = with_length_prefix(&body);

        let mut input: &[u8] = &framed[..];
        assert!(Record::decode(&mut input).is_err());
    }
}
impl RecordBatch {
    /// Header bytes that `batch_length` counts in addition to the body:
    /// partition leader epoch (4) + magic (1) + CRC (4) + the 40 CRC-covered
    /// header bytes.
    const HEADER_TAIL_LEN: i32 = 49;

    /// Offset of the first CRC-covered header byte: it follows base offset (8),
    /// batch length (4), partition leader epoch (4), magic (1) and the CRC (4).
    const CRC_COVERED_START: usize = 21;

    /// Smallest encoded record: a one-byte length prefix, the attribute byte,
    /// two one-byte deltas, null key, null value and a zero header count.
    const MIN_RECORD_LEN: usize = 7;

    /// Reads one magic-2 record batch from `buf`, verifying its CRC and
    /// decompressing the body when the attributes name a codec.
    ///
    /// # Errors
    /// * [`RecordsError::HeaderTooShort`] / [`RecordsError::BodyTooShort`] when
    ///   `buf` ends before the header or the announced body.
    /// * [`RecordsError::ZerocopyFailure`] when the header bytes cannot be
    ///   viewed as a `RecordBatchHeader`.
    /// * [`RecordsError::UnsupportedMagic`] for any magic other than 2.
    /// * [`RecordsError::CrcMismatch`] when the stored CRC differs from the
    ///   one computed over the covered header bytes and the body.
    /// * [`RecordsError::RecordParse`] for an invalid `batch_length`, a negative
    ///   record count, a malformed record, or bytes left after the last record.
    /// * Whatever `crabka_compression::decompress` reports, converted with `?`.
    pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, RecordsError> {
        let available = buf.remaining();
        if available < HEADER_LEN {
            return Err(RecordsError::HeaderTooShort {
                needed: HEADER_LEN - available,
            });
        }
        let mut hdr_bytes = [0u8; HEADER_LEN];
        buf.copy_to_slice(&mut hdr_bytes);
        let hdr = crate::records::header::RecordBatchHeader::ref_from_bytes(&hdr_bytes[..])
            .map_err(|_| RecordsError::ZerocopyFailure)?;
        if hdr.magic != 2 {
            return Err(RecordsError::UnsupportedMagic { found: hdr.magic });
        }

        // `batch_length` also counts the tail of the header; what is left is the body.
        let body_len = hdr
            .batch_length
            .get()
            .checked_sub(Self::HEADER_TAIL_LEN)
            .and_then(|n| usize::try_from(n).ok())
            .ok_or_else(|| {
                RecordsError::RecordParse("negative or oversized batch_length".into())
            })?;
        let available = buf.remaining();
        if available < body_len {
            return Err(RecordsError::BodyTooShort {
                needed: body_len - available,
            });
        }
        let mut body = vec![0u8; body_len];
        buf.copy_to_slice(&mut body);

        // The CRC covers the header from the attributes onward plus the body as
        // it appears on the wire (i.e. still compressed).
        let expected_crc = hdr.crc.get();
        let computed = crc32c_append(
            crc32c(&hdr_bytes[Self::CRC_COVERED_START..HEADER_LEN]),
            &body,
        );
        if computed != expected_crc {
            return Err(RecordsError::CrcMismatch {
                expected: expected_crc,
                computed,
            });
        }

        let attributes = Attributes(hdr.attributes.get());
        let codec = attributes.compression();
        let body_for_records: Bytes = if codec == crabka_compression::CompressionType::None {
            Bytes::from(body)
        } else {
            // Bound the decompressed size: 100x the compressed size, but never
            // below 16 MiB nor above 1 GiB.
            const DECOMPRESS_MIN_CAP: usize = 16 * 1024 * 1024;
            const DECOMPRESS_MAX_RATIO: usize = 100;
            const DECOMPRESS_ABSOLUTE_CEILING: usize = 1024 * 1024 * 1024;
            let max_output = body
                .len()
                .saturating_mul(DECOMPRESS_MAX_RATIO)
                .clamp(DECOMPRESS_MIN_CAP, DECOMPRESS_ABSOLUTE_CEILING);
            crabka_compression::decompress(codec, &body, max_output)?
        };

        let raw_count = hdr.records_count.get();
        let count = usize::try_from(raw_count).map_err(|_| {
            RecordsError::RecordParse(format!("negative records_count {raw_count}"))
        })?;

        // The count is untrusted: reserve only as many slots as the body could
        // actually hold, not one slot per body byte.
        let mut body_cur: &[u8] = &body_for_records[..];
        let mut records =
            Vec::with_capacity(count.min(body_for_records.len() / Self::MIN_RECORD_LEN));
        for i in 0..count {
            let record = Record::decode(&mut body_cur)
                .map_err(|e| RecordsError::RecordParse(format!("record[{i}]: {e}")))?;
            records.push(record);
        }
        if !body_cur.is_empty() {
            return Err(RecordsError::RecordParse(format!(
                "trailing bytes after records (left={})",
                body_cur.len()
            )));
        }

        Ok(Self {
            base_offset: hdr.base_offset.get(),
            partition_leader_epoch: hdr.partition_leader_epoch.get(),
            attributes,
            last_offset_delta: hdr.last_offset_delta.get(),
            base_timestamp: hdr.base_timestamp.get(),
            max_timestamp: hdr.max_timestamp.get(),
            producer_id: hdr.producer_id.get(),
            producer_epoch: hdr.producer_epoch.get(),
            base_sequence: hdr.base_sequence.get(),
            records,
        })
    }

    /// Appends the wire form of this batch to `buf`, compressing the records
    /// when the attributes name a codec and computing the CRC.
    ///
    /// Nothing is written to `buf` unless the whole batch can be encoded.
    ///
    /// # Errors
    /// * [`RecordsError::RecordParse`] when a record cannot be encoded, or the
    ///   body length, total batch length or record count does not fit an `i32`.
    /// * Whatever `crabka_compression::compress` reports, converted with `?`.
    pub fn encode<B: BufMut>(&self, buf: &mut B) -> Result<(), RecordsError> {
        // Records go into a scratch buffer first: the header needs the final
        // body length and a CRC over the (possibly compressed) body.
        let mut raw_body =
            BytesMut::with_capacity(self.records.iter().map(Record::encoded_len).sum());
        for record in &self.records {
            record.encode(&mut raw_body)?;
        }
        let raw_body = raw_body.freeze();

        let codec = self.attributes.compression();
        let body: Bytes = if codec == crabka_compression::CompressionType::None {
            raw_body
        } else {
            crabka_compression::compress(codec, &raw_body)?
        };

        let body_len = i32::try_from(body.len())
            .map_err(|_| RecordsError::RecordParse("body length exceeds i32".into()))?;
        // A body within HEADER_TAIL_LEN bytes of i32::MAX would overflow the
        // sum; reject it instead of wrapping or panicking.
        let batch_length = body_len
            .checked_add(Self::HEADER_TAIL_LEN)
            .ok_or_else(|| RecordsError::RecordParse("batch_length exceeds i32".into()))?;
        let records_count = i32::try_from(self.records.len())
            .map_err(|_| RecordsError::RecordParse("records_count exceeds i32".into()))?;

        // Header fields covered by the CRC (2+4+8+8+8+2+4+4 = 40 bytes).
        let mut covered = BytesMut::with_capacity(40);
        covered.put_i16(self.attributes.0);
        covered.put_i32(self.last_offset_delta);
        covered.put_i64(self.base_timestamp);
        covered.put_i64(self.max_timestamp);
        covered.put_i64(self.producer_id);
        covered.put_i16(self.producer_epoch);
        covered.put_i32(self.base_sequence);
        covered.put_i32(records_count);
        let covered_head = covered.freeze();
        let crc = crc32c_append(crc32c(&covered_head), &body);

        buf.put_i64(self.base_offset);
        buf.put_i32(batch_length);
        buf.put_i32(self.partition_leader_epoch);
        buf.put_i8(2); // magic
        buf.put_u32(crc);
        buf.put_slice(&covered_head);
        buf.put_slice(&body);
        Ok(())
    }

    /// Header length plus the summed encoded length of all records.
    ///
    /// The record bytes are counted uncompressed, so this equals the number of
    /// bytes written by [`RecordBatch::encode`] only when the batch carries no
    /// compression codec.
    pub fn encoded_len(&self) -> usize {
        let records_len: usize = self.records.iter().map(Record::encoded_len).sum();
        HEADER_LEN + records_len
    }
}
#[cfg(test)]
mod batch_tests {
    use super::*;
    use assert2::assert;
    use crabka_compression::CompressionType;

    /// Default batch with no records.
    fn fixture_empty_batch() -> RecordBatch {
        RecordBatch::default()
    }

    /// Default batch carrying one keyed record.
    fn fixture_single_record_batch() -> RecordBatch {
        let only = Record {
            key: Some(Bytes::from_static(b"k1")),
            value: Some(Bytes::from_static(b"v1")),
            ..Default::default()
        };
        RecordBatch {
            records: vec![only],
            ..RecordBatch::default()
        }
    }

    /// Batch with non-default header fields and three records, the last of
    /// which has a null key and one header.
    fn fixture_multi_record_batch() -> RecordBatch {
        let first = Record {
            offset_delta: 0,
            timestamp_delta: 0,
            key: Some(Bytes::from_static(b"a")),
            value: Some(Bytes::from_static(b"1")),
            ..Default::default()
        };
        let second = Record {
            offset_delta: 1,
            timestamp_delta: 100,
            key: Some(Bytes::from_static(b"b")),
            value: Some(Bytes::from_static(b"2")),
            ..Default::default()
        };
        let third = Record {
            offset_delta: 2,
            timestamp_delta: 500,
            key: None,
            value: Some(Bytes::from_static(b"3")),
            headers: vec![RecordHeader {
                key: "h".to_string(),
                value: Some(Bytes::from_static(b"hv")),
            }],
            ..Default::default()
        };
        RecordBatch {
            base_offset: 42,
            partition_leader_epoch: 5,
            last_offset_delta: 2,
            base_timestamp: 1_700_000_000,
            max_timestamp: 1_700_000_500,
            producer_id: 100,
            producer_epoch: 3,
            base_sequence: 7,
            records: vec![first, second, third],
            ..RecordBatch::default()
        }
    }

    /// Round-trips `batch` without compression and checks the predicted length.
    fn assert_uncompressed_roundtrip(mut batch: RecordBatch) {
        batch.attributes = batch.attributes.with_compression(CompressionType::None);
        let mut encoded = BytesMut::new();
        batch.encode(&mut encoded).unwrap();
        assert!(encoded.len() == batch.encoded_len());

        let mut rest: &[u8] = &encoded[..];
        let decoded = RecordBatch::decode(&mut rest).unwrap();
        assert!(decoded == batch);
        assert!(rest.is_empty());
    }

    /// Round-trips the multi-record fixture through `codec`.
    fn assert_compressed_roundtrip(codec: CompressionType) {
        let mut batch = fixture_multi_record_batch();
        batch.attributes = batch.attributes.with_compression(codec);
        let mut encoded = BytesMut::new();
        batch.encode(&mut encoded).unwrap();

        let mut rest: &[u8] = &encoded[..];
        let decoded = RecordBatch::decode(&mut rest).unwrap();
        assert!(decoded == batch);
        assert!(rest.is_empty());
    }

    #[test]
    fn uncompressed_empty() {
        assert_uncompressed_roundtrip(fixture_empty_batch());
    }

    #[test]
    fn uncompressed_single() {
        assert_uncompressed_roundtrip(fixture_single_record_batch());
    }

    #[test]
    fn uncompressed_multi() {
        assert_uncompressed_roundtrip(fixture_multi_record_batch());
    }

    #[test]
    fn rejects_pre_v2_magic() {
        let mut raw = BytesMut::new();
        raw.put_i64(0); // base_offset
        raw.put_i32(49); // batch_length: header tail only, empty body
        raw.put_i32(0); // partition_leader_epoch
        raw.put_i8(1); // magic: anything but 2
        raw.put_u32(0); // crc (never checked, magic is rejected first)
        // Zero-fill the rest of the header.
        raw.extend_from_slice(&[0u8; HEADER_LEN - 21]);

        let mut input: &[u8] = &raw[..];
        assert!(matches!(
            RecordBatch::decode(&mut input),
            Err(RecordsError::UnsupportedMagic { found: 1 })
        ));
    }

    #[test]
    fn rejects_bad_crc() {
        let batch = fixture_single_record_batch();
        let mut encoded = BytesMut::new();
        batch.encode(&mut encoded).unwrap();
        // Byte 17 is the first byte of the stored CRC.
        encoded[17] ^= 0xFF;

        let mut input: &[u8] = &encoded[..];
        assert!(matches!(
            RecordBatch::decode(&mut input),
            Err(RecordsError::CrcMismatch { .. })
        ));
    }

    #[test]
    fn compressed_gzip() {
        assert_compressed_roundtrip(CompressionType::Gzip);
    }

    #[test]
    fn compressed_snappy() {
        assert_compressed_roundtrip(CompressionType::Snappy);
    }

    #[test]
    fn compressed_lz4() {
        assert_compressed_roundtrip(CompressionType::Lz4);
    }

    #[test]
    fn compressed_zstd() {
        assert_compressed_roundtrip(CompressionType::Zstd);
    }

    #[test]
    fn decode_huge_records_count_does_not_overallocate() {
        let mut batch = fixture_empty_batch();
        batch.attributes = batch.attributes.with_compression(CompressionType::None);
        let mut encoded = BytesMut::new();
        batch.encode(&mut encoded).unwrap();

        // Overwrite records_count (the last four header bytes) with a huge value.
        let count_at = HEADER_LEN - 4;
        encoded[count_at..HEADER_LEN].copy_from_slice(&1_000_000_000i32.to_be_bytes());

        // Re-seal the CRC so decoding gets past the checksum and reaches the count.
        let resealed = crc32c_append(crc32c(&encoded[21..HEADER_LEN]), &encoded[HEADER_LEN..]);
        encoded[17..21].copy_from_slice(&resealed.to_be_bytes());

        let mut input: &[u8] = &encoded[..];
        assert!(RecordBatch::decode(&mut input).is_err());
    }
}
/// Protocol-level encoding for [`RecordBatch`]; the layout does not depend on
/// the API version, so `_version` is ignored.
impl crate::Encode for RecordBatch {
    /// Delegates to the inherent [`RecordBatch::encode`], converting its error
    /// into a `ProtocolError`.
    fn encode<B: BufMut>(&self, buf: &mut B, _version: i16) -> Result<(), crate::ProtocolError> {
        match RecordBatch::encode(self, buf) {
            Ok(()) => Ok(()),
            Err(err) => Err(err.into()),
        }
    }

    /// Delegates to the inherent [`RecordBatch::encoded_len`], which counts
    /// record bytes uncompressed.
    fn encoded_len(&self, _version: i16) -> usize {
        let len = RecordBatch::encoded_len(self);
        len
    }
}
/// Protocol-level decoding for [`RecordBatch`]; the layout does not depend on
/// the API version, so `_version` is ignored.
impl crate::Decode<'_> for RecordBatch {
    /// Delegates to the inherent [`RecordBatch::decode`], converting its error
    /// into a `ProtocolError`.
    fn decode<B: Buf>(buf: &mut B, _version: i16) -> Result<Self, crate::ProtocolError> {
        match RecordBatch::decode(buf) {
            Ok(batch) => Ok(batch),
            Err(err) => Err(err.into()),
        }
    }
}