// kafka_api/records/record_batch.rs

1// Copyright 2024 tison <wander4096@gmail.com>
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cell::Ref;
16use std::cell::RefCell;
17use std::cell::RefMut;
18use std::fmt::Debug;
19use std::fmt::Formatter;
20
21use byteorder::BigEndian;
22use byteorder::ReadBytesExt;
23use byteorder::WriteBytesExt;
24
25use crate::codec::err_codec_message;
26use crate::codec::Decoder;
27use crate::codec::RecordList;
28use crate::records::*;
29use crate::IoResult;
30
31fn check_and_fetch_batch_size(bytes: &[u8], remaining: usize) -> IoResult<usize> {
32    if remaining < RECORD_BATCH_OVERHEAD {
33        return Err(err_codec_message(format!(
34            "no enough bytes when decode records (remaining: {}, required header: {})",
35            remaining, RECORD_BATCH_OVERHEAD
36        )));
37    }
38
39    let record_size = (&bytes[LENGTH_OFFSET..])
40        .read_i32::<BigEndian>()
41        .map_err(|err| err_codec_message(format!("failed to read record size: {err}")))?;
42    let batch_size = record_size as usize + LOG_OVERHEAD;
43    if remaining < batch_size {
44        return Err(err_codec_message(format!(
45            "no enough bytes when decode records (remaining: {}, required batch: {})",
46            remaining, batch_size
47        )));
48    }
49
50    let magic = (&bytes[MAGIC_OFFSET..])
51        .read_i8()
52        .map_err(|err| err_codec_message(format!("failed to read version: {err}")))?;
53    if magic != 2 {
54        return Err(err_codec_message(format!(
55            "unsupported record batch version: {}",
56            magic
57        )));
58    }
59
60    Ok(batch_size)
61}
62
/// An owning container of encoded record batches.
///
/// The bytes live in a [RefCell] so that both shared ([RecordBatches::batches])
/// and mutable ([RecordBatches::mut_batches]) slice views can be handed out.
#[derive(Default)]
pub struct RecordBatches {
    bytes: RefCell<Vec<u8>>,
}
67
68impl Debug for RecordBatches {
69    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
70        let mut de = f.debug_struct("RecordBatches");
71        de.field("batches", &self.batches());
72        de.finish()
73    }
74}
75
impl RecordBatches {
    /// Wraps the given encoded bytes. No validation happens here — malformed
    /// input surfaces as errors from [RecordBatches::batches] or
    /// [RecordBatches::mut_batches].
    pub fn new(bytes: Vec<u8>) -> Self {
        RecordBatches {
            bytes: RefCell::new(bytes),
        }
    }

    /// Consumes the container and returns the underlying encoded bytes.
    pub fn into_bytes(self) -> Vec<u8> {
        self.bytes.into_inner()
    }

    /// Splits the bytes into mutable batches, validating each batch header.
    ///
    /// Takes `&mut self` so the outstanding [RefMut] borrows cannot overlap
    /// statically with shared views; the [RefCell] would otherwise panic at
    /// runtime on a conflicting borrow.
    pub fn mut_batches(&mut self) -> IoResult<Vec<MutableRecordBatch>> {
        let mut batches = vec![];
        // Map the RefMut to a plain slice so it can be repeatedly split below.
        let mut bytes = RefMut::map(self.bytes.borrow_mut(), |bs| bs.as_mut_slice());
        let mut remaining = bytes.len();
        while remaining > 0 {
            // Validate the header first; the returned size is trusted afterwards.
            let batch_size = check_and_fetch_batch_size(&bytes, remaining)?;
            // Carve one batch off the front; `right` keeps the tail for the next round.
            let (left, right) = RefMut::map_split(bytes, |b| b.split_at_mut(batch_size));
            batches.push(MutableRecordBatch { bytes: left });
            bytes = right;
            remaining -= batch_size;
        }
        Ok(batches)
    }

    /// Splits the bytes into read-only batches, validating each batch header.
    pub fn batches(&self) -> IoResult<Vec<RecordBatch>> {
        let mut batches = vec![];
        // Map the Ref to a plain slice so it can be repeatedly split below.
        let mut bytes = Ref::map(self.bytes.borrow(), |bs| bs.as_slice());
        let mut remaining = bytes.len();
        while remaining > 0 {
            let batch_size = check_and_fetch_batch_size(&bytes, remaining)?;
            // Same carve-off pattern as mut_batches, with shared borrows.
            let (left, right) = Ref::map_split(bytes, |b| b.split_at(batch_size));
            batches.push(RecordBatch { bytes: left });
            bytes = right;
            remaining -= batch_size;
        }
        Ok(batches)
    }
}
115
/// A mutable record batch borrowing one batch-sized slice of the bytes owned
/// by [RecordBatches].
pub struct MutableRecordBatch<'a> {
    bytes: RefMut<'a, [u8]>,
}
119
120// SAFETY: record's length are validated on construction; so all slices are valid.
121impl MutableRecordBatch<'_> {
122    pub fn view(&self) -> RecordBatchView {
123        RecordBatchView { bytes: &self.bytes }
124    }
125
126    pub fn set_last_offset(&mut self, offset: i64) {
127        let base_offset = offset - self.view().last_offset_delta() as i64;
128        (&mut self.bytes[BASE_OFFSET_OFFSET..])
129            .write_i64::<BigEndian>(base_offset)
130            .expect("write base offset");
131    }
132
133    pub fn set_partition_leader_epoch(&mut self, epoch: i32) {
134        (&mut self.bytes[PARTITION_LEADER_EPOCH_OFFSET..])
135            .write_i32::<BigEndian>(epoch)
136            .expect("write partition leader epoch");
137    }
138}
139
140impl Debug for MutableRecordBatch<'_> {
141    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
142        RecordBatchView::fmt(&self.view(), f)
143    }
144}
145
/// A read-only record batch borrowing one batch-sized slice of the bytes
/// owned by [RecordBatches].
pub struct RecordBatch<'a> {
    bytes: Ref<'a, [u8]>,
}
149
150impl RecordBatch<'_> {
151    pub fn view(&self) -> RecordBatchView {
152        RecordBatchView { bytes: &self.bytes }
153    }
154}
155
156impl Debug for RecordBatch<'_> {
157    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
158        RecordBatchView::fmt(&self.view(), f)
159    }
160}
161
/// A borrowed, read-only decoder over one record batch's bytes; all field
/// accessors read at fixed offsets into this slice.
pub struct RecordBatchView<'a> {
    bytes: &'a [u8],
}
165
166impl Debug for RecordBatchView<'_> {
167    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
168        let mut de = f.debug_struct("RecordBatch");
169        de.field("magic", &self.magic());
170        de.field("offset", &(self.base_offset()..=self.last_offset()));
171        de.field("sequence", &(self.base_sequence()..=self.last_sequence()));
172        de.field("is_transactional", &self.is_transactional());
173        de.field("is_control_batch", &self.is_control_batch());
174        de.field("compression_type", &self.compression_type());
175        de.field("timestamp_type", &self.timestamp_type());
176        de.field("crc", &self.checksum());
177        de.field("records_count", &self.records_count());
178        de.field("records", &self.records());
179        de.finish()
180    }
181}
182
/// Adds `increment` to `sequence`, wrapping around to `0` (rather than to
/// [i32::MIN], as [i32::wrapping_add] would) once the sum passes [i32::MAX].
///
/// NOTE(review): assumes a non-negative `increment`; `i32::MAX - increment`
/// could itself overflow for negative increments — confirm callers' range.
pub fn increment_sequence(sequence: i32, increment: i32) -> i32 {
    let wraps = sequence > i32::MAX - increment;
    if wraps {
        increment - (i32::MAX - sequence) - 1
    } else {
        sequence + increment
    }
}
191
/// Similar to [i32::wrapping_sub], but wraps around to [i32::MAX] when the
/// result would drop below `0`, instead of wrapping at [i32::MIN].
///
/// (The previous doc referenced `wrapping_add`; this function subtracts.)
pub fn decrement_sequence(sequence: i32, decrement: i32) -> i32 {
    if sequence < decrement {
        i32::MAX - (decrement - sequence) + 1
    } else {
        sequence - decrement
    }
}
200
// SAFETY: the batch length is validated on construction, so all slice accesses
// below are in bounds and the `expect`s cannot fail.
impl RecordBatchView<'_> {
    /// Returns the magic (format version) byte of this batch.
    pub fn magic(&self) -> i8 {
        (&self.bytes[MAGIC_OFFSET..]).read_i8().expect("read magic")
    }

    /// Returns the offset of the first record in the batch.
    pub fn base_offset(&self) -> i64 {
        (&self.bytes[BASE_OFFSET_OFFSET..])
            .read_i64::<BigEndian>()
            .expect("read base offset")
    }

    /// Returns the offset of the last record: base offset plus last offset delta.
    pub fn last_offset(&self) -> i64 {
        self.base_offset() + self.last_offset_delta() as i64
    }

    /// Returns the producer sequence number of the first record.
    pub fn base_sequence(&self) -> i32 {
        (&self.bytes[BASE_SEQUENCE_OFFSET..])
            .read_i32::<BigEndian>()
            .expect("read base sequence")
    }

    /// Returns the sequence number of the last record. [NO_SEQUENCE] is passed
    /// through unchanged; otherwise the sequence advances by the last offset
    /// delta, wrapping to 0 via [increment_sequence].
    pub fn last_sequence(&self) -> i32 {
        match self.base_sequence() {
            NO_SEQUENCE => NO_SEQUENCE,
            seq => increment_sequence(seq, self.last_offset_delta()),
        }
    }

    // Offset difference between the last and the first record of the batch.
    fn last_offset_delta(&self) -> i32 {
        (&self.bytes[LAST_OFFSET_DELTA_OFFSET..])
            .read_i32::<BigEndian>()
            .expect("read last offset delta")
    }

    /// Returns the max timestamp field of the batch.
    pub fn max_timestamp(&self) -> i64 {
        (&self.bytes[MAX_TIMESTAMP_OFFSET..])
            .read_i64::<BigEndian>()
            .expect("read max timestamp")
    }

    /// Returns the record count field of the batch.
    pub fn records_count(&self) -> i32 {
        (&self.bytes[RECORDS_COUNT_OFFSET..])
            .read_i32::<BigEndian>()
            .expect("read records count")
    }

    /// Decodes and returns all records of this batch.
    ///
    /// NOTE(review): decoding starts at the records-count field, which implies
    /// [RecordList] consumes the leading count itself — confirm against its
    /// [Decoder] implementation.
    pub fn records(&self) -> Vec<Record> {
        let mut records = &self.bytes[RECORDS_COUNT_OFFSET..];
        RecordList.decode(&mut records).expect("malformed records")
    }

    /// Returns the stored CRC field of the batch (not recomputed here).
    pub fn checksum(&self) -> u32 {
        (&self.bytes[CRC_OFFSET..])
            .read_u32::<BigEndian>()
            .expect("read checksum")
    }

    /// Whether the transactional flag is set in the attributes.
    pub fn is_transactional(&self) -> bool {
        self.attributes() & TRANSACTIONAL_FLAG_MASK > 0
    }

    /// Whether the control-batch flag is set in the attributes.
    pub fn is_control_batch(&self) -> bool {
        self.attributes() & CONTROL_FLAG_MASK > 0
    }

    /// Returns the timestamp type encoded in the attributes.
    pub fn timestamp_type(&self) -> TimestampType {
        if self.attributes() & TIMESTAMP_TYPE_MASK != 0 {
            TimestampType::LogAppendTime
        } else {
            TimestampType::CreateTime
        }
    }

    /// Returns the compression codec encoded in the attributes.
    pub fn compression_type(&self) -> CompressionType {
        (self.attributes() & COMPRESSION_CODEC_MASK).into()
    }

    /// Returns the delete horizon timestamp when the corresponding attribute
    /// flag is set; in that case the base timestamp field holds it.
    pub fn delete_horizon_ms(&self) -> Option<i64> {
        if self.has_delete_horizon_ms() {
            Some(
                (&self.bytes[BASE_TIMESTAMP_OFFSET..])
                    .read_i64::<BigEndian>()
                    .expect("read base timestamp offset"),
            )
        } else {
            None
        }
    }

    // Whether the delete-horizon flag is set in the attributes.
    fn has_delete_horizon_ms(&self) -> bool {
        self.attributes() & DELETE_HORIZON_FLAG_MASK > 0
    }

    // Reads the 16-bit attributes field and truncates to the low-order byte:
    // every flag tested above lives in that byte, so the high byte is unused here.
    fn attributes(&self) -> u8 {
        (&self.bytes[ATTRIBUTES_OFFSET..])
            .read_u16::<BigEndian>()
            .expect("read attributes") as u8
    }
}
302
#[cfg(test)]
mod tests {
    use std::io;

    use crate::records::record_batch::RecordBatches;

    // Two back-to-back v2 record batches with identical payloads.
    const RECORD: &[u8] = &[
        // batch 1
        0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, // first offset
        0x0, 0x0, 0x0, 0x52, // record batch size
        0xFF, 0xFF, 0xFF, 0xFF, // partition leader epoch
        0x2,  // magic byte
        0xE2, 0x3F, 0xC9, 0x74, // crc
        0x0, 0x0, // attributes
        0x0, 0x0, 0x0, 0x0, // last offset delta
        0x0, 0x0, 0x1, 0x89, 0xAF, 0x78, 0x40, 0x72, // base timestamp
        0x0, 0x0, 0x1, 0x89, 0xAF, 0x78, 0x40, 0x72, // max timestamp
        0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, // producer ID
        0x0, 0x0, // producer epoch
        0x0, 0x0, 0x0, 0x0, // base sequence
        0x0, 0x0, 0x0, 0x1,  // record counts
        0x40, // first record size
        0x0,  // attribute
        0x0,  // timestamp delta
        0x0,  // offset delta
        0x1,  // key length (zigzag : -1)
        // empty key payload
        0x34, // value length (zigzag : 26)
        0x54, 0x68, 0x69, 0x73, 0x20, 0x69, 0x73, 0x20, 0x74, 0x68, 0x65, 0x20, 0x66, 0x69, 0x72,
        0x73, 0x74, 0x20, 0x6D, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2E, // value payload
        0x0,  // header counts
        // batch 2
        0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, // first offset
        0x0, 0x0, 0x0, 0x52, // record batch size
        0xFF, 0xFF, 0xFF, 0xFF, // partition leader epoch
        0x2,  // magic byte
        0xE2, 0x3F, 0xC9, 0x74, // crc
        0x0, 0x0, // attributes
        0x0, 0x0, 0x0, 0x0, // last offset delta
        0x0, 0x0, 0x1, 0x89, 0xAF, 0x78, 0x40, 0x72, // base timestamp
        0x0, 0x0, 0x1, 0x89, 0xAF, 0x78, 0x40, 0x72, // max timestamp
        0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, // producer ID
        0x0, 0x0, // producer epoch
        0x0, 0x0, 0x0, 0x0, // base sequence
        0x0, 0x0, 0x0, 0x1,  // record counts
        0x40, // first record size
        0x0,  // attribute
        0x0,  // timestamp delta
        0x0,  // offset delta
        0x1,  // key length (zigzag : -1)
        // empty key payload
        0x34, // value length (zigzag : 26)
        0x54, 0x68, 0x69, 0x73, 0x20, 0x69, 0x73, 0x20, 0x74, 0x68, 0x65, 0x20, 0x66, 0x69, 0x72,
        0x73, 0x74, 0x20, 0x6D, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2E, // value payload
        0x0,  // header counts
    ];

    #[test]
    fn test_codec_records() -> io::Result<()> {
        let records = RecordBatches::new(RECORD.to_vec());
        // Propagate decode failures via `?` rather than unwrap — the test
        // already returns io::Result.
        let record_batches = records.batches()?;
        assert_eq!(record_batches.len(), 2);
        // Both encoded batches carry identical payloads; verify each one
        // decodes the same way instead of checking only the first.
        for batch in &record_batches {
            let record_batch = batch.view();
            assert_eq!(record_batch.records_count(), 1);
            let record_vec = record_batch.records();
            assert_eq!(record_vec.len(), 1);
            let record = &record_vec[0];
            assert_eq!(record.key_len, -1);
            assert_eq!(record.key, None);
            assert_eq!(record.value_len, 26);
            assert_eq!(
                record.value.as_deref().map(String::from_utf8_lossy),
                Some("This is the first message.".into())
            );
        }
        Ok(())
    }
}
379}