1use bytes::{Buf, BufMut, Bytes, BytesMut};
4use zerocopy::FromBytes as _;
5
6use crate::primitives::varint::{
7 get_varint, get_varlong, put_varint, put_varlong, varint_len, varlong_len,
8};
9use crate::records::RecordsError;
10use crate::records::crc::{crc32c, crc32c_append};
11use crate::records::header::{Attributes, HEADER_LEN};
12
/// A single record header: a UTF-8 key paired with an optional value.
///
/// On the wire the key length must be non-negative, while a negative value length
/// decodes as `None` (see `decode_record_header`).
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct RecordHeader {
    /// Header key; decoding rejects bytes that are not valid UTF-8.
    pub key: String,
    /// Header value; `None` is encoded as a varint length of `-1`.
    pub value: Option<Bytes>,
}
18
/// One record inside a magic-v2 record batch.
///
/// Offsets and timestamps are stored as deltas. NOTE(review): presumably these are
/// relative to the enclosing batch's `base_offset` / `base_timestamp`; this file never
/// combines them, so confirm against callers.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Record {
    /// Per-record attribute byte, written verbatim as the first byte of the body.
    pub attributes: i8,
    /// Encoded with `put_varlong`.
    pub timestamp_delta: i64,
    /// Encoded with `put_varint`.
    pub offset_delta: i32,
    /// Optional key; `None` is encoded as a varint length of `-1`.
    pub key: Option<Bytes>,
    /// Optional value; `None` is encoded as a varint length of `-1`.
    pub value: Option<Bytes>,
    /// Headers, encoded as a varint count followed by each header in order.
    pub headers: Vec<RecordHeader>,
}
28
/// A magic-v2 record batch: the fixed batch header fields plus its records.
///
/// `magic`, `batch_length`, `crc` and `records_count` are not stored here:
/// `RecordBatch::encode` always writes magic 2 and derives the other three.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RecordBatch {
    /// Base offset; the first field on the wire. Records carry `offset_delta`
    /// values, presumably relative to this — confirm.
    pub base_offset: i64,
    /// Written immediately after `batch_length`; outside the CRC-covered region.
    pub partition_leader_epoch: i32,
    /// Batch attribute bits; the compression codec is read from them via
    /// `Attributes::compression`.
    pub attributes: Attributes,
    /// Written verbatim by `encode`; it is not recomputed from `records`.
    pub last_offset_delta: i32,
    /// Written verbatim by `encode`.
    pub base_timestamp: i64,
    /// Written verbatim by `encode`; it is not recomputed from `records`.
    pub max_timestamp: i64,
    /// Defaults to `-1`. NOTE(review): presumably the "no producer" sentinel — confirm.
    pub producer_id: i64,
    /// Defaults to `-1`.
    pub producer_epoch: i16,
    /// Defaults to `-1`.
    pub base_sequence: i32,
    /// Records in wire order; their count becomes the header's `records_count`.
    pub records: Vec<Record>,
}
42
43impl Default for RecordBatch {
44 fn default() -> Self {
45 Self {
46 base_offset: 0,
47 partition_leader_epoch: 0,
48 attributes: Attributes::default(),
49 last_offset_delta: 0,
50 base_timestamp: 0,
51 max_timestamp: 0,
52 producer_id: -1, producer_epoch: -1,
54 base_sequence: -1,
55 records: Vec::new(),
56 }
57 }
58}
59
60impl Record {
61 pub fn encode<B: BufMut>(&self, buf: &mut B) -> Result<(), RecordsError> {
63 let body_len = self.body_len();
64 put_varlong(
65 buf,
66 i64::try_from(body_len)
67 .map_err(|_| RecordsError::RecordParse("record body length overflow".into()))?,
68 );
69 self.encode_body(buf)
70 }
71
72 pub fn encoded_len(&self) -> usize {
74 let body = self.body_len();
75 #[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
76 let body_i64 = body as i64;
77 varlong_len(body_i64) + body
78 }
79
80 fn body_len(&self) -> usize {
81 let mut n = 1; n += varlong_len(self.timestamp_delta);
83 n += varint_len(self.offset_delta);
84 n += match &self.key {
85 None => varint_len(-1),
86 Some(k) => varint_len(i32::try_from(k.len()).unwrap_or(i32::MAX)) + k.len(),
87 };
88 n += match &self.value {
89 None => varint_len(-1),
90 Some(v) => varint_len(i32::try_from(v.len()).unwrap_or(i32::MAX)) + v.len(),
91 };
92 n += varint_len(i32::try_from(self.headers.len()).unwrap_or(i32::MAX));
93 for h in &self.headers {
94 let key_bytes = h.key.as_bytes();
95 n += varint_len(i32::try_from(key_bytes.len()).unwrap_or(i32::MAX)) + key_bytes.len();
96 n += match &h.value {
97 None => varint_len(-1),
98 Some(v) => varint_len(i32::try_from(v.len()).unwrap_or(i32::MAX)) + v.len(),
99 };
100 }
101 n
102 }
103
104 fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<(), RecordsError> {
105 buf.put_i8(self.attributes);
106 put_varlong(buf, self.timestamp_delta);
107 put_varint(buf, self.offset_delta);
108 match &self.key {
109 None => put_varint(buf, -1),
110 Some(k) => {
111 put_varint(
112 buf,
113 i32::try_from(k.len()).map_err(|_| {
114 RecordsError::RecordParse("record key length overflow".into())
115 })?,
116 );
117 buf.put_slice(k);
118 }
119 }
120 match &self.value {
121 None => put_varint(buf, -1),
122 Some(v) => {
123 put_varint(
124 buf,
125 i32::try_from(v.len()).map_err(|_| {
126 RecordsError::RecordParse("record value length overflow".into())
127 })?,
128 );
129 buf.put_slice(v);
130 }
131 }
132 put_varint(
133 buf,
134 i32::try_from(self.headers.len())
135 .map_err(|_| RecordsError::RecordParse("record header count overflow".into()))?,
136 );
137 for h in &self.headers {
138 let key_bytes = h.key.as_bytes();
139 put_varint(
140 buf,
141 i32::try_from(key_bytes.len())
142 .map_err(|_| RecordsError::RecordParse("header key length overflow".into()))?,
143 );
144 buf.put_slice(key_bytes);
145 match &h.value {
146 None => put_varint(buf, -1),
147 Some(v) => {
148 put_varint(
149 buf,
150 i32::try_from(v.len()).map_err(|_| {
151 RecordsError::RecordParse("header value length overflow".into())
152 })?,
153 );
154 buf.put_slice(v);
155 }
156 }
157 }
158 Ok(())
159 }
160
161 pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, RecordsError> {
164 let body_len = get_varlong(buf)
165 .map_err(|e| RecordsError::RecordParse(format!("record length: {e}")))?;
166 let body_len = usize::try_from(body_len).map_err(|_| {
167 RecordsError::RecordParse(format!("record length negative or too large: {body_len}"))
168 })?;
169 if buf.remaining() < body_len {
170 return Err(RecordsError::BodyTooShort {
171 needed: body_len - buf.remaining(),
172 });
173 }
174 let mut body = buf.take(body_len);
177 let r = Self::decode_body(&mut body)?;
178 if body.has_remaining() {
180 return Err(RecordsError::RecordParse(format!(
181 "trailing bytes inside record (left={})",
182 body.remaining()
183 )));
184 }
185 Ok(r)
186 }
187
188 fn decode_body<B: Buf>(buf: &mut B) -> Result<Self, RecordsError> {
189 if buf.remaining() == 0 {
190 return Err(RecordsError::RecordParse("record body empty".into()));
191 }
192 let attributes = buf.get_i8();
193 let timestamp_delta = get_varlong(buf)
194 .map_err(|e| RecordsError::RecordParse(format!("timestamp_delta: {e}")))?;
195 let offset_delta =
196 get_varint(buf).map_err(|e| RecordsError::RecordParse(format!("offset_delta: {e}")))?;
197
198 let key = decode_nullable_bytes(buf, "key")?;
199 let value = decode_nullable_bytes(buf, "value")?;
200
201 let header_count =
202 get_varint(buf).map_err(|e| RecordsError::RecordParse(format!("header_count: {e}")))?;
203 if header_count < 0 {
204 return Err(RecordsError::RecordParse(format!(
205 "negative header count {header_count}"
206 )));
207 }
208 #[allow(clippy::cast_sign_loss)] let header_count_usize = header_count as usize;
210 let mut headers = Vec::with_capacity(header_count_usize);
211 for i in 0..header_count {
212 headers.push(
213 decode_record_header(buf)
214 .map_err(|e| RecordsError::RecordParse(format!("header[{i}]: {e}")))?,
215 );
216 }
217
218 Ok(Self {
219 attributes,
220 timestamp_delta,
221 offset_delta,
222 key,
223 value,
224 headers,
225 })
226 }
227}
228
229fn decode_nullable_bytes<B: Buf>(buf: &mut B, label: &str) -> Result<Option<Bytes>, RecordsError> {
230 let len =
231 get_varint(buf).map_err(|e| RecordsError::RecordParse(format!("{label} length: {e}")))?;
232 if len < 0 {
233 Ok(None)
234 } else {
235 #[allow(clippy::cast_sign_loss)] let n = len as usize;
237 if buf.remaining() < n {
238 return Err(RecordsError::BodyTooShort {
239 needed: n - buf.remaining(),
240 });
241 }
242 let mut v = vec![0u8; n];
243 buf.copy_to_slice(&mut v);
244 Ok(Some(Bytes::from(v)))
245 }
246}
247
248fn decode_record_header<B: Buf>(buf: &mut B) -> Result<RecordHeader, String> {
249 let key_len = get_varint(buf).map_err(|e| format!("key length: {e}"))?;
250 if key_len < 0 {
251 return Err(format!("non-nullable key has negative length {key_len}"));
252 }
253 #[allow(clippy::cast_sign_loss)] let n = key_len as usize;
255 if buf.remaining() < n {
256 return Err(format!("key truncated (need {} more)", n - buf.remaining()));
257 }
258 let mut kv = vec![0u8; n];
259 buf.copy_to_slice(&mut kv);
260 let key = String::from_utf8(kv).map_err(|e| format!("key utf-8: {e}"))?;
261
262 let value_len = get_varint(buf).map_err(|e| format!("value length: {e}"))?;
263 let value = if value_len < 0 {
264 None
265 } else {
266 #[allow(clippy::cast_sign_loss)] let n = value_len as usize;
268 if buf.remaining() < n {
269 return Err(format!(
270 "value truncated (need {} more)",
271 n - buf.remaining()
272 ));
273 }
274 let mut vv = vec![0u8; n];
275 buf.copy_to_slice(&mut vv);
276 Some(Bytes::from(vv))
277 };
278
279 Ok(RecordHeader { key, value })
280}
281
#[cfg(test)]
mod record_tests {
    use super::*;
    use assert2::assert;
    use bytes::BytesMut;

    /// Smallest valid record: no key, no value, no headers.
    fn fixture_minimal_record() -> Record {
        Record {
            attributes: 0,
            timestamp_delta: 0,
            offset_delta: 0,
            key: None,
            value: None,
            headers: vec![],
        }
    }

    /// Keyed record with one valued header and one null-valued header.
    fn fixture_keyed_record() -> Record {
        Record {
            attributes: 0,
            timestamp_delta: 17,
            offset_delta: 2,
            key: Some(Bytes::from_static(b"the-key")),
            value: Some(Bytes::from_static(b"hello kafka")),
            headers: vec![
                RecordHeader {
                    key: "trace-id".to_string(),
                    value: Some(Bytes::from_static(b"abc")),
                },
                RecordHeader {
                    key: "null-val".to_string(),
                    value: None,
                },
            ],
        }
    }

    /// Payloads large enough that their length varints need more than one byte.
    fn fixture_large_payload_record() -> Record {
        Record {
            attributes: 0,
            timestamp_delta: 1_000_000,
            offset_delta: 999,
            key: Some(Bytes::from(vec![b'k'; 128])),
            value: Some(Bytes::from(vec![b'v'; 4096])),
            headers: vec![],
        }
    }

    /// Present-but-empty key and header fields next to a null value, so that
    /// `Some(empty)` and `None` must stay distinct through a roundtrip.
    fn fixture_empty_fields_record() -> Record {
        Record {
            key: Some(Bytes::new()),
            value: None,
            headers: vec![RecordHeader {
                key: String::new(),
                value: Some(Bytes::new()),
            }],
            ..Record::default()
        }
    }

    macro_rules! roundtrip {
        ($name:ident, $fixture:ident) => {
            #[test]
            fn $name() {
                let r = $fixture();
                let mut buf = BytesMut::new();
                r.encode(&mut buf).unwrap();
                assert!(buf.len() == r.encoded_len(), "predicted len mismatch");

                let mut cur: &[u8] = &buf[..];
                let decoded = Record::decode(&mut cur).unwrap();
                assert!(decoded == r);
                assert!(cur.is_empty(), "trailing bytes after decode");
            }
        };
    }

    roundtrip!(minimal, fixture_minimal_record);
    roundtrip!(keyed_with_headers, fixture_keyed_record);
    roundtrip!(large_payload, fixture_large_payload_record);
    roundtrip!(empty_key_and_header_fields, fixture_empty_fields_record);

    #[test]
    fn decode_rejects_negative_header_count() {
        let mut buf = BytesMut::new();
        put_varlong(&mut buf, 6); // body length
        buf.put_i8(0); // attributes
        put_varlong(&mut buf, 0); // timestamp_delta
        put_varint(&mut buf, 0); // offset_delta
        put_varint(&mut buf, -1); // key: null
        put_varint(&mut buf, -1); // value: null
        put_varint(&mut buf, -1); // header count: invalid
        let mut cur: &[u8] = &buf[..];
        match Record::decode(&mut cur) {
            Err(RecordsError::RecordParse(msg)) => {
                assert!(msg.contains("negative header count"), "got: {msg}");
            }
            other => panic!("expected RecordParse, got {other:?}"),
        }
    }

    #[test]
    fn decode_rejects_length_prefix_beyond_buffer() {
        let mut buf = BytesMut::new();
        put_varlong(&mut buf, 10); // claims a 10-byte body...
        buf.put_slice(&[0, 0, 0]); // ...but only 3 bytes follow
        let mut cur: &[u8] = &buf[..];
        assert!(matches!(
            Record::decode(&mut cur),
            Err(RecordsError::BodyTooShort { needed: 7 })
        ));
    }

    #[test]
    fn decode_rejects_truncated_key() {
        let mut buf = BytesMut::new();
        put_varlong(&mut buf, 5); // body length
        buf.put_i8(0); // attributes
        put_varlong(&mut buf, 0); // timestamp_delta
        put_varint(&mut buf, 0); // offset_delta
        put_varint(&mut buf, 10); // key claims 10 bytes...
        buf.put_u8(b'x'); // ...but the record body ends after 1
        let mut cur: &[u8] = &buf[..];
        assert!(matches!(
            Record::decode(&mut cur),
            Err(RecordsError::BodyTooShort { needed: 9 })
        ));
    }

    #[test]
    fn decode_rejects_trailing_bytes_inside_record() {
        let mut buf = BytesMut::new();
        put_varlong(&mut buf, 7); // 6 meaningful body bytes + 1 stray byte
        buf.put_i8(0); // attributes
        put_varlong(&mut buf, 0); // timestamp_delta
        put_varint(&mut buf, 0); // offset_delta
        put_varint(&mut buf, -1); // key: null
        put_varint(&mut buf, -1); // value: null
        put_varint(&mut buf, 0); // no headers
        buf.put_u8(0xAA); // stray byte still inside the declared body
        let mut cur: &[u8] = &buf[..];
        match Record::decode(&mut cur) {
            Err(RecordsError::RecordParse(msg)) => {
                assert!(msg.contains("trailing bytes inside record"), "got: {msg}");
            }
            other => panic!("expected RecordParse, got {other:?}"),
        }
    }

    #[test]
    fn decode_rejects_negative_header_key_length() {
        let mut buf = BytesMut::new();
        put_varlong(&mut buf, 7); // body length
        buf.put_i8(0); // attributes
        put_varlong(&mut buf, 0); // timestamp_delta
        put_varint(&mut buf, 0); // offset_delta
        put_varint(&mut buf, -1); // key: null
        put_varint(&mut buf, -1); // value: null
        put_varint(&mut buf, 1); // one header
        put_varint(&mut buf, -1); // header key: negative length is invalid
        let mut cur: &[u8] = &buf[..];
        match Record::decode(&mut cur) {
            Err(RecordsError::RecordParse(msg)) => {
                assert!(msg.contains("header[0]"), "got: {msg}");
                assert!(msg.contains("negative length"), "got: {msg}");
            }
            other => panic!("expected RecordParse, got {other:?}"),
        }
    }
}
373
374impl RecordBatch {
375 pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, RecordsError> {
378 const HEADER_TAIL_LEN: i32 = 49;
384
385 if buf.remaining() < HEADER_LEN {
387 return Err(RecordsError::HeaderTooShort {
388 needed: HEADER_LEN - buf.remaining(),
389 });
390 }
391 let mut hdr_bytes = [0u8; HEADER_LEN];
393 buf.copy_to_slice(&mut hdr_bytes);
394
395 let hdr = crate::records::header::RecordBatchHeader::ref_from_bytes(&hdr_bytes[..])
396 .map_err(|_| RecordsError::ZerocopyFailure)?;
397
398 if hdr.magic != 2 {
399 return Err(RecordsError::UnsupportedMagic { found: hdr.magic });
400 }
401
402 let body_len = i32::checked_sub(hdr.batch_length.get(), HEADER_TAIL_LEN)
404 .and_then(|n| usize::try_from(n).ok())
405 .ok_or_else(|| {
406 RecordsError::RecordParse("negative or oversized batch_length".into())
407 })?;
408
409 if buf.remaining() < body_len {
410 return Err(RecordsError::BodyTooShort {
411 needed: body_len - buf.remaining(),
412 });
413 }
414
415 let mut body = vec![0u8; body_len];
417 buf.copy_to_slice(&mut body);
418
419 let expected_crc = hdr.crc.get();
422 let mut computed = crc32c(&hdr_bytes[21..HEADER_LEN]);
423 computed = crc32c_append(computed, &body);
424 if computed != expected_crc {
425 return Err(RecordsError::CrcMismatch {
426 expected: expected_crc,
427 computed,
428 });
429 }
430
431 let attributes = Attributes(hdr.attributes.get());
432 let codec = attributes.compression();
433
434 let body_for_records: Bytes = if codec == crabka_compression::CompressionType::None {
436 Bytes::from(body)
437 } else {
438 crabka_compression::decompress(codec, &body)?
439 };
440
441 let count = hdr.records_count.get();
443 if count < 0 {
444 return Err(RecordsError::RecordParse(format!(
445 "negative records_count {count}"
446 )));
447 }
448 let mut body_cur: &[u8] = &body_for_records[..];
449 #[allow(clippy::cast_sign_loss)] let mut records = Vec::with_capacity(count as usize);
451 for i in 0..count {
452 records.push(
453 Record::decode(&mut body_cur)
454 .map_err(|e| RecordsError::RecordParse(format!("record[{i}]: {e}")))?,
455 );
456 }
457 if !body_cur.is_empty() {
458 return Err(RecordsError::RecordParse(format!(
459 "trailing bytes after records (left={})",
460 body_cur.len()
461 )));
462 }
463
464 Ok(Self {
465 base_offset: hdr.base_offset.get(),
466 partition_leader_epoch: hdr.partition_leader_epoch.get(),
467 attributes,
468 last_offset_delta: hdr.last_offset_delta.get(),
469 base_timestamp: hdr.base_timestamp.get(),
470 max_timestamp: hdr.max_timestamp.get(),
471 producer_id: hdr.producer_id.get(),
472 producer_epoch: hdr.producer_epoch.get(),
473 base_sequence: hdr.base_sequence.get(),
474 records,
475 })
476 }
477
478 pub fn encode<B: BufMut>(&self, buf: &mut B) -> Result<(), RecordsError> {
480 const HEADER_TAIL_LEN: i32 = 49;
481
482 let mut raw_body =
484 BytesMut::with_capacity(self.records.iter().map(Record::encoded_len).sum());
485 for r in &self.records {
486 r.encode(&mut raw_body)?;
487 }
488 let raw_body = raw_body.freeze();
489
490 let codec = self.attributes.compression();
492 let body: Bytes = if codec == crabka_compression::CompressionType::None {
493 raw_body
494 } else {
495 crabka_compression::compress(codec, &raw_body)?
496 };
497
498 let batch_length = HEADER_TAIL_LEN
500 + i32::try_from(body.len())
501 .map_err(|_| RecordsError::RecordParse("body length exceeds i32".into()))?;
502
503 let mut covered = BytesMut::with_capacity(40);
505 covered.put_i16(self.attributes.0);
506 covered.put_i32(self.last_offset_delta);
507 covered.put_i64(self.base_timestamp);
508 covered.put_i64(self.max_timestamp);
509 covered.put_i64(self.producer_id);
510 covered.put_i16(self.producer_epoch);
511 covered.put_i32(self.base_sequence);
512 covered.put_i32(
513 i32::try_from(self.records.len())
514 .map_err(|_| RecordsError::RecordParse("records_count exceeds i32".into()))?,
515 );
516 let covered_head = covered.freeze();
517
518 let mut crc = crc32c(&covered_head);
520 crc = crc32c_append(crc, &body);
521
522 buf.put_i64(self.base_offset);
524 buf.put_i32(batch_length);
525 buf.put_i32(self.partition_leader_epoch);
526 buf.put_i8(2); buf.put_u32(crc);
528 buf.put_slice(&covered_head);
529 buf.put_slice(&body);
530 Ok(())
531 }
532
533 pub fn encoded_len(&self) -> usize {
536 let body: usize = self.records.iter().map(Record::encoded_len).sum();
537 HEADER_LEN + body
538 }
539}
540
#[cfg(test)]
mod batch_tests {
    use super::*;
    use assert2::assert;
    use crabka_compression::CompressionType;

    fn fixture_empty_batch() -> RecordBatch {
        RecordBatch::default()
    }

    fn fixture_single_record_batch() -> RecordBatch {
        RecordBatch {
            records: vec![Record {
                key: Some(Bytes::from_static(b"k1")),
                value: Some(Bytes::from_static(b"v1")),
                ..Default::default()
            }],
            ..RecordBatch::default()
        }
    }

    fn fixture_multi_record_batch() -> RecordBatch {
        RecordBatch {
            base_offset: 42,
            partition_leader_epoch: 5,
            last_offset_delta: 2,
            base_timestamp: 1_700_000_000,
            max_timestamp: 1_700_000_500,
            producer_id: 100,
            producer_epoch: 3,
            base_sequence: 7,
            records: vec![
                Record {
                    offset_delta: 0,
                    timestamp_delta: 0,
                    key: Some(Bytes::from_static(b"a")),
                    value: Some(Bytes::from_static(b"1")),
                    ..Default::default()
                },
                Record {
                    offset_delta: 1,
                    timestamp_delta: 100,
                    key: Some(Bytes::from_static(b"b")),
                    value: Some(Bytes::from_static(b"2")),
                    ..Default::default()
                },
                Record {
                    offset_delta: 2,
                    timestamp_delta: 500,
                    key: None,
                    value: Some(Bytes::from_static(b"3")),
                    headers: vec![RecordHeader {
                        key: "h".to_string(),
                        value: Some(Bytes::from_static(b"hv")),
                    }],
                    ..Default::default()
                },
            ],
            ..RecordBatch::default()
        }
    }

    /// Builds a bare `HEADER_LEN`-byte batch header with the given `batch_length` and
    /// `magic`; every other field, including the CRC, is zero.
    fn raw_header(batch_length: i32, magic: i8) -> BytesMut {
        let mut buf = BytesMut::with_capacity(HEADER_LEN);
        buf.put_i64(0); // base_offset
        buf.put_i32(batch_length);
        buf.put_i32(0); // partition_leader_epoch
        buf.put_i8(magic);
        buf.put_u32(0); // crc
        while buf.len() < HEADER_LEN {
            buf.put_u8(0); // attributes .. records_count
        }
        buf
    }

    macro_rules! roundtrip_uncompressed {
        ($name:ident, $fixture:ident) => {
            #[test]
            fn $name() {
                let mut b = $fixture();
                b.attributes = b.attributes.with_compression(CompressionType::None);

                let mut buf = BytesMut::new();
                b.encode(&mut buf).unwrap();
                assert!(buf.len() == b.encoded_len());

                let mut cur: &[u8] = &buf[..];
                let decoded = RecordBatch::decode(&mut cur).unwrap();
                assert!(decoded == b);
                assert!(cur.is_empty());
            }
        };
    }

    roundtrip_uncompressed!(uncompressed_empty, fixture_empty_batch);
    roundtrip_uncompressed!(uncompressed_single, fixture_single_record_batch);
    roundtrip_uncompressed!(uncompressed_multi, fixture_multi_record_batch);

    #[test]
    fn rejects_pre_v2_magic() {
        // batch_length 49 = header tail only, i.e. an empty body.
        let buf = raw_header(49, 1);
        let mut cur: &[u8] = &buf[..];
        assert!(matches!(
            RecordBatch::decode(&mut cur),
            Err(RecordsError::UnsupportedMagic { found: 1 })
        ));
    }

    #[test]
    fn rejects_bad_crc() {
        let b = fixture_single_record_batch();
        let mut buf = BytesMut::new();
        b.encode(&mut buf).unwrap();
        // Byte 17 is the first CRC byte (after base_offset, batch_length, epoch, magic).
        buf[17] ^= 0xFF;
        let mut cur: &[u8] = &buf[..];
        assert!(matches!(
            RecordBatch::decode(&mut cur),
            Err(RecordsError::CrcMismatch { .. })
        ));
    }

    #[test]
    fn rejects_truncated_header() {
        let buf = [0u8; 10];
        let mut cur: &[u8] = &buf[..];
        assert!(matches!(
            RecordBatch::decode(&mut cur),
            Err(RecordsError::HeaderTooShort { needed }) if needed == HEADER_LEN - 10
        ));
    }

    #[test]
    fn rejects_truncated_body() {
        let mut buf = BytesMut::new();
        fixture_single_record_batch().encode(&mut buf).unwrap();
        buf.truncate(buf.len() - 1);
        let mut cur: &[u8] = &buf[..];
        assert!(matches!(
            RecordBatch::decode(&mut cur),
            Err(RecordsError::BodyTooShort { needed: 1 })
        ));
    }

    #[test]
    fn rejects_batch_length_below_header_tail() {
        let buf = raw_header(48, 2);
        let mut cur: &[u8] = &buf[..];
        assert!(matches!(
            RecordBatch::decode(&mut cur),
            Err(RecordsError::RecordParse(_))
        ));
    }

    #[test]
    fn rejects_negative_records_count() {
        let mut b = RecordBatch::default();
        b.attributes = b.attributes.with_compression(CompressionType::None);
        let mut buf = BytesMut::new();
        b.encode(&mut buf).unwrap();
        // records_count is the last header field; set it to -1 and reseal the CRC so
        // the decoder gets past the checksum.
        buf[HEADER_LEN - 4..HEADER_LEN].copy_from_slice(&(-1_i32).to_be_bytes());
        let crc = crc32c(&buf[21..]);
        buf[17..21].copy_from_slice(&crc.to_be_bytes());
        let mut cur: &[u8] = &buf[..];
        assert!(matches!(
            RecordBatch::decode(&mut cur),
            Err(RecordsError::RecordParse(msg)) if msg.contains("negative records_count")
        ));
    }

    macro_rules! roundtrip_compressed {
        ($name:ident, $codec:expr) => {
            #[test]
            fn $name() {
                let mut b = fixture_multi_record_batch();
                b.attributes = b.attributes.with_compression($codec);

                let mut buf = BytesMut::new();
                b.encode(&mut buf).unwrap();
                let mut cur: &[u8] = &buf[..];
                let decoded = RecordBatch::decode(&mut cur).unwrap();
                assert!(decoded == b);
                assert!(cur.is_empty());
            }
        };
    }

    roundtrip_compressed!(compressed_gzip, CompressionType::Gzip);
    roundtrip_compressed!(compressed_snappy, CompressionType::Snappy);
    roundtrip_compressed!(compressed_lz4, CompressionType::Lz4);
    roundtrip_compressed!(compressed_zstd, CompressionType::Zstd);
}
680
681impl crate::Encode for RecordBatch {
682 fn encode<B: BufMut>(&self, buf: &mut B, _version: i16) -> Result<(), crate::ProtocolError> {
683 RecordBatch::encode(self, buf).map_err(Into::into)
684 }
685
686 fn encoded_len(&self, _version: i16) -> usize {
687 RecordBatch::encoded_len(self)
688 }
689}
690
691impl crate::Decode<'_> for RecordBatch {
692 fn decode<B: Buf>(buf: &mut B, _version: i16) -> Result<Self, crate::ProtocolError> {
693 RecordBatch::decode(buf).map_err(Into::into)
694 }
695}