rskafka/protocol/
primitives.rs

1//! Primitive types.
2//!
3//! # References
4//! - <https://kafka.apache.org/protocol#protocol_types>
5//! - <https://cwiki.apache.org/confluence/display/KAFKA/KIP-482%3A+The+Kafka+Protocol+should+Support+Optional+Tagged+Fields#KIP482:TheKafkaProtocolshouldSupportOptionalTaggedFields-UnsignedVarints>
6
7use std::io::{Cursor, Read, Write};
8
9use integer_encoding::{VarIntReader, VarIntWriter};
10
11#[cfg(test)]
12use proptest::prelude::*;
13
14use super::{
15    record::RecordBatch,
16    traits::{ReadError, ReadType, WriteError, WriteType},
17    vec_builder::VecBuilder,
18};
19
20/// Represents a boolean
21#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
22#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
23pub struct Boolean(pub bool);
24
25impl<R> ReadType<R> for Boolean
26where
27    R: Read,
28{
29    fn read(reader: &mut R) -> Result<Self, ReadError> {
30        let mut buf = [0u8; 1];
31        reader.read_exact(&mut buf)?;
32        match buf[0] {
33            0 => Ok(Self(false)),
34            _ => Ok(Self(true)),
35        }
36    }
37}
38
39impl<W> WriteType<W> for Boolean
40where
41    W: Write,
42{
43    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
44        match self.0 {
45            true => Ok(writer.write_all(&[1])?),
46            false => Ok(writer.write_all(&[0])?),
47        }
48    }
49}
50
51/// Represents an integer between `-2^7` and `2^7-1` inclusive.
52#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
53#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
54pub struct Int8(pub i8);
55
56impl<R> ReadType<R> for Int8
57where
58    R: Read,
59{
60    fn read(reader: &mut R) -> Result<Self, ReadError> {
61        let mut buf = [0u8; 1];
62        reader.read_exact(&mut buf)?;
63        Ok(Self(i8::from_be_bytes(buf)))
64    }
65}
66
67impl<W> WriteType<W> for Int8
68where
69    W: Write,
70{
71    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
72        let buf = self.0.to_be_bytes();
73        writer.write_all(&buf)?;
74        Ok(())
75    }
76}
77
78/// Represents an integer between `-2^15` and `2^15-1` inclusive.
79///
80/// The values are encoded using two bytes in network byte order (big-endian).
81#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
82#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
83pub struct Int16(pub i16);
84
85impl<R> ReadType<R> for Int16
86where
87    R: Read,
88{
89    fn read(reader: &mut R) -> Result<Self, ReadError> {
90        let mut buf = [0u8; 2];
91        reader.read_exact(&mut buf)?;
92        Ok(Self(i16::from_be_bytes(buf)))
93    }
94}
95
96impl<W> WriteType<W> for Int16
97where
98    W: Write,
99{
100    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
101        let buf = self.0.to_be_bytes();
102        writer.write_all(&buf)?;
103        Ok(())
104    }
105}
106
107/// Represents an integer between `-2^31` and `2^31-1` inclusive.
108///
109/// The values are encoded using four bytes in network byte order (big-endian).
110#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
111#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
112pub struct Int32(pub i32);
113
114impl<R> ReadType<R> for Int32
115where
116    R: Read,
117{
118    fn read(reader: &mut R) -> Result<Self, ReadError> {
119        let mut buf = [0u8; 4];
120        reader.read_exact(&mut buf)?;
121        Ok(Self(i32::from_be_bytes(buf)))
122    }
123}
124
125impl<W> WriteType<W> for Int32
126where
127    W: Write,
128{
129    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
130        let buf = self.0.to_be_bytes();
131        writer.write_all(&buf)?;
132        Ok(())
133    }
134}
135
136/// Represents an integer between `-2^63` and `2^63-1` inclusive.
137///
138/// The values are encoded using eight bytes in network byte order (big-endian).
139#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
140#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
141pub struct Int64(pub i64);
142
143impl<R> ReadType<R> for Int64
144where
145    R: Read,
146{
147    fn read(reader: &mut R) -> Result<Self, ReadError> {
148        let mut buf = [0u8; 8];
149        reader.read_exact(&mut buf)?;
150        Ok(Self(i64::from_be_bytes(buf)))
151    }
152}
153
154impl<W> WriteType<W> for Int64
155where
156    W: Write,
157{
158    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
159        let buf = self.0.to_be_bytes();
160        writer.write_all(&buf)?;
161        Ok(())
162    }
163}
164
165/// Represents an integer between `-2^31` and `2^31-1` inclusive.
166///
167/// Encoding follows the variable-length zig-zag encoding from Google Protocol Buffers.
168#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
169#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
170pub struct Varint(pub i32);
171
172impl<R> ReadType<R> for Varint
173where
174    R: Read,
175{
176    fn read(reader: &mut R) -> Result<Self, ReadError> {
177        // workaround for https://github.com/dermesser/integer-encoding-rs/issues/21
178        // read 64bit and use a checked downcast instead
179        let i: i64 = reader.read_varint()?;
180        Ok(Self(i32::try_from(i)?))
181    }
182}
183
184impl<W> WriteType<W> for Varint
185where
186    W: Write,
187{
188    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
189        writer.write_varint(self.0)?;
190        Ok(())
191    }
192}
193
194/// Represents an integer between `-2^63` and `2^63-1` inclusive.
195///
196/// Encoding follows the variable-length zig-zag encoding from Google Protocol Buffers.
197#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
198#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
199pub struct Varlong(pub i64);
200
201impl<R> ReadType<R> for Varlong
202where
203    R: Read,
204{
205    fn read(reader: &mut R) -> Result<Self, ReadError> {
206        Ok(Self(reader.read_varint()?))
207    }
208}
209
210impl<W> WriteType<W> for Varlong
211where
212    W: Write,
213{
214    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
215        writer.write_varint(self.0)?;
216        Ok(())
217    }
218}
219
220/// The UNSIGNED_VARINT type describes an unsigned variable length integer.
221///
222/// To serialize a number as a variable-length integer, you break it up into groups of 7 bits. The lowest 7 bits is
223/// written out first, followed by the second-lowest, and so on.  Each time a group of 7 bits is written out, the high
224/// bit (bit 8) is cleared if this group is the last one, and set if it is not.
225#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
226#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
227pub struct UnsignedVarint(pub u64);
228
229impl<R> ReadType<R> for UnsignedVarint
230where
231    R: Read,
232{
233    fn read(reader: &mut R) -> Result<Self, ReadError> {
234        let mut buf = [0u8; 1];
235        let mut res: u64 = 0;
236        let mut shift = 0;
237        loop {
238            reader.read_exact(&mut buf)?;
239            let c: u64 = buf[0].into();
240
241            res |= (c & 0x7f) << shift;
242            shift += 7;
243
244            if (c & 0x80) == 0 {
245                break;
246            }
247            if shift > 63 {
248                return Err(ReadError::Malformed(
249                    String::from("Overflow while reading unsigned varint").into(),
250                ));
251            }
252        }
253
254        Ok(Self(res))
255    }
256}
257
258impl<W> WriteType<W> for UnsignedVarint
259where
260    W: Write,
261{
262    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
263        let mut curr = self.0;
264        loop {
265            let mut c = u8::try_from(curr & 0x7f).map_err(WriteError::Overflow)?;
266            curr >>= 7;
267            if curr > 0 {
268                c |= 0x80;
269            }
270            writer.write_all(&[c])?;
271
272            if curr == 0 {
273                break;
274            }
275        }
276        Ok(())
277    }
278}
279
280/// Represents a sequence of characters or null.
281///
282/// For non-null strings, first the length N is given as an INT16. Then N bytes follow which are the UTF-8 encoding of
283/// the character sequence. A null value is encoded with length of -1 and there are no following bytes.
284#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default, Clone)]
285#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
286pub struct NullableString(pub Option<String>);
287
288impl<R> ReadType<R> for NullableString
289where
290    R: Read,
291{
292    fn read(reader: &mut R) -> Result<Self, ReadError> {
293        let len = Int16::read(reader)?;
294        match len.0 {
295            l if l < -1 => Err(ReadError::Malformed(
296                format!("Invalid negative length for nullable string: {}", l).into(),
297            )),
298            -1 => Ok(Self(None)),
299            l => {
300                let len = usize::try_from(l)?;
301                let mut buf = VecBuilder::new(len);
302                buf = buf.read_exact(reader)?;
303                let s =
304                    String::from_utf8(buf.into()).map_err(|e| ReadError::Malformed(Box::new(e)))?;
305                Ok(Self(Some(s)))
306            }
307        }
308    }
309}
310
311impl<W> WriteType<W> for NullableString
312where
313    W: Write,
314{
315    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
316        match &self.0 {
317            Some(s) => {
318                let l = i16::try_from(s.len()).map_err(|e| WriteError::Malformed(Box::new(e)))?;
319                Int16(l).write(writer)?;
320                writer.write_all(s.as_bytes())?;
321                Ok(())
322            }
323            None => Int16(-1).write(writer),
324        }
325    }
326}
327
328/// Represents a sequence of characters.
329///
330/// First the length N is given as an INT16. Then N bytes follow which are the UTF-8 encoding of the character
331/// sequence. Length must not be negative.
332#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
333#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
334pub struct String_(pub String);
335
336impl<R> ReadType<R> for String_
337where
338    R: Read,
339{
340    fn read(reader: &mut R) -> Result<Self, ReadError> {
341        let len = Int16::read(reader)?;
342        let len = usize::try_from(len.0).map_err(|e| ReadError::Malformed(Box::new(e)))?;
343        let mut buf = VecBuilder::new(len);
344        buf = buf.read_exact(reader)?;
345        let s = String::from_utf8(buf.into()).map_err(|e| ReadError::Malformed(Box::new(e)))?;
346        Ok(Self(s))
347    }
348}
349
350impl<W> WriteType<W> for String_
351where
352    W: Write,
353{
354    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
355        let len = i16::try_from(self.0.len()).map_err(WriteError::Overflow)?;
356        Int16(len).write(writer)?;
357        writer.write_all(self.0.as_bytes())?;
358        Ok(())
359    }
360}
361
362/// Represents a string whose length is expressed as a variable-length integer rather than a fixed 2-byte length.
363#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
364#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
365pub struct CompactString(pub String);
366
367impl<R> ReadType<R> for CompactString
368where
369    R: Read,
370{
371    fn read(reader: &mut R) -> Result<Self, ReadError> {
372        let len = UnsignedVarint::read(reader)?;
373        match len.0 {
374            0 => Err(ReadError::Malformed(
375                "CompactString must have non-zero length".into(),
376            )),
377            len => {
378                let len = usize::try_from(len)?;
379                let len = len - 1;
380
381                let mut buf = VecBuilder::new(len);
382                buf = buf.read_exact(reader)?;
383
384                let s =
385                    String::from_utf8(buf.into()).map_err(|e| ReadError::Malformed(Box::new(e)))?;
386                Ok(Self(s))
387            }
388        }
389    }
390}
391
392impl<W> WriteType<W> for CompactString
393where
394    W: Write,
395{
396    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
397        CompactStringRef(&self.0).write(writer)
398    }
399}
400
401/// Same as [`CompactString`] but contains referenced data.
402///
403/// This only supports writing.
404#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
405pub struct CompactStringRef<'a>(pub &'a str);
406
407impl<'a, W> WriteType<W> for CompactStringRef<'a>
408where
409    W: Write,
410{
411    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
412        let len = u64::try_from(self.0.len() + 1).map_err(WriteError::Overflow)?;
413        UnsignedVarint(len).write(writer)?;
414        writer.write_all(self.0.as_bytes())?;
415        Ok(())
416    }
417}
418
419/// Represents a nullable string whose length is expressed as a variable-length integer rather than a fixed 2-byte length.
420#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
421#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
422pub struct CompactNullableString(pub Option<String>);
423
424impl<R> ReadType<R> for CompactNullableString
425where
426    R: Read,
427{
428    fn read(reader: &mut R) -> Result<Self, ReadError> {
429        let len = UnsignedVarint::read(reader)?;
430        match len.0 {
431            0 => Ok(Self(None)),
432            len => {
433                let len = usize::try_from(len)?;
434                let len = len - 1;
435
436                let mut buf = VecBuilder::new(len);
437                buf = buf.read_exact(reader)?;
438
439                let s =
440                    String::from_utf8(buf.into()).map_err(|e| ReadError::Malformed(Box::new(e)))?;
441                Ok(Self(Some(s)))
442            }
443        }
444    }
445}
446
447impl<W> WriteType<W> for CompactNullableString
448where
449    W: Write,
450{
451    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
452        CompactNullableStringRef(self.0.as_deref()).write(writer)
453    }
454}
455
456/// Same as [`CompactNullableString`] but contains referenced data.
457///
458/// This only supports writing.
459#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
460pub struct CompactNullableStringRef<'a>(pub Option<&'a str>);
461
462impl<'a, W> WriteType<W> for CompactNullableStringRef<'a>
463where
464    W: Write,
465{
466    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
467        match &self.0 {
468            Some(s) => {
469                let len = u64::try_from(s.len() + 1).map_err(WriteError::Overflow)?;
470                UnsignedVarint(len).write(writer)?;
471                writer.write_all(s.as_bytes())?;
472            }
473            None => {
474                UnsignedVarint(0).write(writer)?;
475            }
476        }
477        Ok(())
478    }
479}
480
481/// Represents a raw sequence of bytes or null.
482///
483/// For non-null values, first the length N is given as an INT32. Then N bytes follow. A null value is encoded with
484/// length of -1 and there are no following bytes.
485#[derive(Debug, PartialEq, Eq)]
486#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
487pub struct NullableBytes(pub Option<Vec<u8>>);
488
489impl<R> ReadType<R> for NullableBytes
490where
491    R: Read,
492{
493    fn read(reader: &mut R) -> Result<Self, ReadError> {
494        let len = Int32::read(reader)?;
495        match len.0 {
496            l if l < -1 => Err(ReadError::Malformed(
497                format!("Invalid negative length for nullable bytes: {}", l).into(),
498            )),
499            -1 => Ok(Self(None)),
500            l => {
501                let len = usize::try_from(l)?;
502                let mut buf = VecBuilder::new(len);
503                buf = buf.read_exact(reader)?;
504                Ok(Self(Some(buf.into())))
505            }
506        }
507    }
508}
509
510impl<W> WriteType<W> for NullableBytes
511where
512    W: Write,
513{
514    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
515        match &self.0 {
516            Some(s) => {
517                let l = i32::try_from(s.len()).map_err(|e| WriteError::Malformed(Box::new(e)))?;
518                Int32(l).write(writer)?;
519                writer.write_all(s)?;
520                Ok(())
521            }
522            None => Int32(-1).write(writer),
523        }
524    }
525}
526
527/// Represents a section containing optional tagged fields.
528#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
529#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
530pub struct TaggedFields(pub Vec<(UnsignedVarint, Vec<u8>)>);
531
532impl<R> ReadType<R> for TaggedFields
533where
534    R: Read,
535{
536    fn read(reader: &mut R) -> Result<Self, ReadError> {
537        let len = UnsignedVarint::read(reader)?;
538        let len = usize::try_from(len.0).map_err(ReadError::Overflow)?;
539        let mut res = VecBuilder::new(len);
540        for _ in 0..len {
541            let tag = UnsignedVarint::read(reader)?;
542
543            let data_len = UnsignedVarint::read(reader)?;
544            let data_len = usize::try_from(data_len.0).map_err(ReadError::Overflow)?;
545            let mut data_builder = VecBuilder::new(data_len);
546            data_builder = data_builder.read_exact(reader)?;
547
548            res.push((tag, data_builder.into()));
549        }
550        Ok(Self(res.into()))
551    }
552}
553
554impl<W> WriteType<W> for TaggedFields
555where
556    W: Write,
557{
558    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
559        let len = u64::try_from(self.0.len()).map_err(WriteError::Overflow)?;
560        UnsignedVarint(len).write(writer)?;
561
562        for (tag, data) in &self.0 {
563            tag.write(writer)?;
564            let data_len = u64::try_from(data.len()).map_err(WriteError::Overflow)?;
565            UnsignedVarint(data_len).write(writer)?;
566            writer.write_all(data)?;
567        }
568
569        Ok(())
570    }
571}
572
573/// Represents a sequence of objects of a given type T.
574///
575/// Type T can be either a primitive type (e.g. STRING) or a structure. First, the length N is given as an INT32. Then
576/// N instances of type T follow. A null array is represented with a length of -1. In protocol documentation an array
577/// of T instances is referred to as `[T]`.
578#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
579#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
580pub struct Array<T>(pub Option<Vec<T>>);
581
582impl<R, T> ReadType<R> for Array<T>
583where
584    R: Read,
585    T: ReadType<R>,
586{
587    fn read(reader: &mut R) -> Result<Self, ReadError> {
588        let len = Int32::read(reader)?;
589        if len.0 == -1 {
590            Ok(Self(None))
591        } else {
592            let len = usize::try_from(len.0)?;
593            let mut res = VecBuilder::new(len);
594            for _ in 0..len {
595                res.push(T::read(reader)?);
596            }
597            Ok(Self(Some(res.into())))
598        }
599    }
600}
601
602impl<W, T> WriteType<W> for Array<T>
603where
604    W: Write,
605    T: WriteType<W>,
606{
607    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
608        ArrayRef(self.0.as_deref()).write(writer)
609    }
610}
611
612/// Same as [`Array`] but contains referenced data.
613///
614/// This only supports writing.
615#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
616pub struct ArrayRef<'a, T>(pub Option<&'a [T]>);
617
618impl<'a, W, T> WriteType<W> for ArrayRef<'a, T>
619where
620    W: Write,
621    T: WriteType<W>,
622{
623    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
624        match self.0 {
625            None => Int32(-1).write(writer),
626            Some(inner) => {
627                let len = i32::try_from(inner.len())?;
628                Int32(len).write(writer)?;
629
630                for element in inner {
631                    element.write(writer)?;
632                }
633
634                Ok(())
635            }
636        }
637    }
638}
639
640/// Represents a sequence of objects of a given type T.
641///
642/// Type T can be either a primitive type (e.g. STRING) or a structure. First, the length N + 1 is given as an
643/// UNSIGNED_VARINT. Then N instances of type T follow. A null array is represented with a length of 0. In protocol
644/// documentation an array of T instances is referred to as `[T]`.
645#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
646#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
647pub struct CompactArray<T>(pub Option<Vec<T>>);
648
649impl<R, T> ReadType<R> for CompactArray<T>
650where
651    R: Read,
652    T: ReadType<R>,
653{
654    fn read(reader: &mut R) -> Result<Self, ReadError> {
655        let len = UnsignedVarint::read(reader)?.0;
656        match len {
657            0 => Ok(Self(None)),
658            n => {
659                let len = usize::try_from(n - 1).map_err(ReadError::Overflow)?;
660                let mut builder = VecBuilder::new(len);
661                for _ in 0..len {
662                    builder.push(T::read(reader)?);
663                }
664                Ok(Self(Some(builder.into())))
665            }
666        }
667    }
668}
669
670impl<W, T> WriteType<W> for CompactArray<T>
671where
672    W: Write,
673    T: WriteType<W>,
674{
675    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
676        CompactArrayRef(self.0.as_deref()).write(writer)
677    }
678}
679
680/// Same as [`CompactArray`] but contains referenced data.
681///
682/// This only supports writing.
683#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
684pub struct CompactArrayRef<'a, T>(pub Option<&'a [T]>);
685
686impl<'a, W, T> WriteType<W> for CompactArrayRef<'a, T>
687where
688    W: Write,
689    T: WriteType<W>,
690{
691    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
692        match self.0 {
693            None => UnsignedVarint(0).write(writer),
694            Some(inner) => {
695                let len = u64::try_from(inner.len() + 1).map_err(WriteError::from)?;
696                UnsignedVarint(len).write(writer)?;
697
698                for element in inner {
699                    element.write(writer)?;
700                }
701
702                Ok(())
703            }
704        }
705    }
706}
707
708/// Represents a sequence of Kafka records as NULLABLE_BYTES.
709///
710/// This primitive actually depends on the message version and evolved twice in [KIP-32] and [KIP-98]. We only support
711/// the latest generation (message version 2).
712///
713/// It seems that during `Produce` this must contain exactly one batch, but during `Fetch` this can contain zero, one or
714/// more batches -- however I could not find any documentation stating this behavior. [KIP-74] at least documents the
715/// `Fetch` case, although it does not clearly state that record batches might be cut off half-way (this however is what
716/// we see during integration tests w/ Apache Kafka).
717///
718/// [KIP-32]: https://cwiki.apache.org/confluence/display/KAFKA/KIP-32+-+Add+timestamps+to+Kafka+message
719/// [KIP-74]: https://cwiki.apache.org/confluence/display/KAFKA/KIP-74%3A+Add+Fetch+Response+Size+Limit+in+Bytes
720/// [KIP-98]: https://cwiki.apache.org/confluence/display/KAFKA/KIP-98+-+Exactly+Once+Delivery+and+Transactional+Messaging
721#[derive(Debug, PartialEq, Eq)]
722#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
723pub struct Records(
724    // tell proptest to only generate small vectors, otherwise tests take forever
725    #[cfg_attr(
726        test,
727        proptest(strategy = "prop::collection::vec(any::<RecordBatch>(), 0..2)")
728    )]
729    pub Vec<RecordBatch>,
730);
731
732impl<R> ReadType<R> for Records
733where
734    R: Read,
735{
736    fn read(reader: &mut R) -> Result<Self, ReadError> {
737        let buf = NullableBytes::read(reader)?.0.unwrap_or_default();
738        let len = u64::try_from(buf.len())?;
739        let mut buf = Cursor::new(buf);
740
741        let mut batches = vec![];
742        while buf.position() < len {
743            let batch = match RecordBatch::read(&mut buf) {
744                Ok(batch) => batch,
745                Err(ReadError::IO(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
746                    // Record batch got cut off, likely due to `FetchRequest::max_bytes`.
747                    break;
748                }
749                Err(e) => {
750                    return Err(e);
751                }
752            };
753            batches.push(batch);
754        }
755
756        Ok(Self(batches))
757    }
758}
759
760impl<W> WriteType<W> for Records
761where
762    W: Write,
763{
764    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
765        // TODO: it would be nice if we could avoid the copy here by writing the records and then seeking back.
766        let mut buf = vec![];
767        for record in &self.0 {
768            record.write(&mut buf)?;
769        }
770        NullableBytes(Some(buf)).write(writer)?;
771        Ok(())
772    }
773}
774
#[cfg(test)]
mod tests {
    // Round-trip tests for every primitive, plus decoding tests for malformed, oversized, or truncated input.
    use std::io::Cursor;

    use crate::protocol::{
        record::{ControlBatchOrRecords, RecordBatchCompression, RecordBatchTimestampType},
        test_utils::test_roundtrip,
    };

    use super::*;

    use assert_matches::assert_matches;

    test_roundtrip!(Boolean, test_bool_roundtrip);

    #[test]
    fn test_boolean_decode() {
        assert!(!Boolean::read(&mut Cursor::new(vec![0])).unwrap().0);

        // When reading a boolean value, any non-zero value is considered true.
        for v in [1, 35, 255] {
            assert!(Boolean::read(&mut Cursor::new(vec![v])).unwrap().0);
        }
    }

    test_roundtrip!(Int8, test_int8_roundtrip);

    test_roundtrip!(Int16, test_int16_roundtrip);

    test_roundtrip!(Int32, test_int32_roundtrip);

    test_roundtrip!(Int64, test_int64_roundtrip);

    test_roundtrip!(Varint, test_varint_roundtrip);

    #[test]
    fn test_varint_special_values() {
        // Taken from https://developers.google.com/protocol-buffers/docs/encoding?csw=1#varints
        for v in [0, -1, 1, -2, 2147483647, -2147483648] {
            let mut data = vec![];
            Varint(v).write(&mut data).unwrap();

            let restored = Varint::read(&mut Cursor::new(data)).unwrap();
            assert_eq!(restored.0, v);
        }
    }

    #[test]
    fn test_varint_read_overflow() {
        // Eleven bytes that all carry the continuation bit: longer than any valid 64-bit varint.
        let mut buf = Cursor::new(vec![0xffu8; 11]);

        let err = Varint::read(&mut buf).unwrap_err();
        assert_matches!(err, ReadError::IO(_));
        assert_eq!(err.to_string(), "Cannot read data: Unterminated varint",);
    }

    #[test]
    fn test_varint_read_downcast_overflow() {
        // this should overflow when reading a 64bit varint and casting it down to 32bit
        let mut data = vec![0xffu8; 9];
        data.push(0x00);
        let mut buf = Cursor::new(data);

        let err = Varint::read(&mut buf).unwrap_err();
        assert_matches!(err, ReadError::Overflow(_));
        assert_eq!(
            err.to_string(),
            "Overflow converting integer: out of range integral type conversion attempted",
        );
    }

    test_roundtrip!(Varlong, test_varlong_roundtrip);

    #[test]
    fn test_varlong_special_values() {
        // Taken from https://developers.google.com/protocol-buffers/docs/encoding?csw=1#varints + min/max
        for v in [0, -1, 1, -2, 2147483647, -2147483648, i64::MIN, i64::MAX] {
            let mut data = vec![];
            Varlong(v).write(&mut data).unwrap();

            let restored = Varlong::read(&mut Cursor::new(data)).unwrap();
            assert_eq!(restored.0, v);
        }
    }

    #[test]
    fn test_varlong_read_overflow() {
        let mut buf = Cursor::new(vec![0xffu8; 11]);

        let err = Varlong::read(&mut buf).unwrap_err();
        assert_matches!(err, ReadError::IO(_));
        assert_eq!(err.to_string(), "Cannot read data: Unterminated varint",);
    }

    test_roundtrip!(UnsignedVarint, test_unsigned_varint_roundtrip);

    #[test]
    fn test_unsigned_varint_read_overflow() {
        let mut buf = Cursor::new(vec![0xffu8; 64 / 7 + 1]);

        let err = UnsignedVarint::read(&mut buf).unwrap_err();
        assert_matches!(err, ReadError::Malformed(_));
        assert_eq!(
            err.to_string(),
            "Malformed data: Overflow while reading unsigned varint",
        );
    }

    test_roundtrip!(String_, test_string_roundtrip);

    // The `*_blowup_memory` tests declare a huge length but supply no payload; decoding must fail with an I/O error
    // instead of exhausting memory up front.

    #[test]
    fn test_string_blowup_memory() {
        let mut buf = Cursor::new(Vec::<u8>::new());
        Int16(i16::MAX).write(&mut buf).unwrap();
        buf.set_position(0);

        let err = String_::read(&mut buf).unwrap_err();
        assert_matches!(err, ReadError::IO(_));
    }

    test_roundtrip!(NullableString, test_nullable_string_roundtrip);

    #[test]
    fn test_nullable_string_read_negative_length() {
        let mut buf = Cursor::new(Vec::<u8>::new());
        Int16(-2).write(&mut buf).unwrap();
        buf.set_position(0);

        let err = NullableString::read(&mut buf).unwrap_err();
        assert_matches!(err, ReadError::Malformed(_));
        assert_eq!(
            err.to_string(),
            "Malformed data: Invalid negative length for nullable string: -2",
        );
    }

    #[test]
    fn test_nullable_string_blowup_memory() {
        let mut buf = Cursor::new(Vec::<u8>::new());
        Int16(i16::MAX).write(&mut buf).unwrap();
        buf.set_position(0);

        let err = NullableString::read(&mut buf).unwrap_err();
        assert_matches!(err, ReadError::IO(_));
    }

    test_roundtrip!(CompactString, test_compact_string_roundtrip);

    #[test]
    fn test_compact_string_blowup_memory() {
        let mut buf = Cursor::new(Vec::<u8>::new());
        UnsignedVarint(u64::MAX).write(&mut buf).unwrap();
        buf.set_position(0);

        let err = CompactString::read(&mut buf).unwrap_err();
        assert_matches!(err, ReadError::IO(_));
    }

    test_roundtrip!(
        CompactNullableString,
        test_compact_nullable_string_roundtrip
    );

    #[test]
    fn test_compact_nullable_string_blowup_memory() {
        let mut buf = Cursor::new(Vec::<u8>::new());
        UnsignedVarint(u64::MAX).write(&mut buf).unwrap();
        buf.set_position(0);

        let err = CompactNullableString::read(&mut buf).unwrap_err();
        assert_matches!(err, ReadError::IO(_));
    }

    test_roundtrip!(NullableBytes, test_nullable_bytes_roundtrip);

    #[test]
    fn test_nullable_bytes_read_negative_length() {
        let mut buf = Cursor::new(Vec::<u8>::new());
        Int32(-2).write(&mut buf).unwrap();
        buf.set_position(0);

        let err = NullableBytes::read(&mut buf).unwrap_err();
        assert_matches!(err, ReadError::Malformed(_));
        assert_eq!(
            err.to_string(),
            "Malformed data: Invalid negative length for nullable bytes: -2",
        );
    }

    #[test]
    fn test_nullable_bytes_blowup_memory() {
        let mut buf = Cursor::new(Vec::<u8>::new());
        Int32(i32::MAX).write(&mut buf).unwrap();
        buf.set_position(0);

        let err = NullableBytes::read(&mut buf).unwrap_err();
        assert_matches!(err, ReadError::IO(_));
    }

    test_roundtrip!(TaggedFields, test_tagged_fields_roundtrip);

    #[test]
    fn test_tagged_fields_blowup_memory() {
        let mut buf = Cursor::new(Vec::<u8>::new());

        // number of fields
        UnsignedVarint(u64::MAX).write(&mut buf).unwrap();

        // tag
        UnsignedVarint(u64::MAX).write(&mut buf).unwrap();

        // data length
        UnsignedVarint(u64::MAX).write(&mut buf).unwrap();

        buf.set_position(0);

        let err = TaggedFields::read(&mut buf).unwrap_err();
        assert_matches!(err, ReadError::IO(_));
    }

    test_roundtrip!(Array<Int32>, test_array_roundtrip);

    #[test]
    fn test_array_blowup_memory() {
        let mut buf = Cursor::new(Vec::<u8>::new());
        Int32(i32::MAX).write(&mut buf).unwrap();
        buf.set_position(0);

        let err = Array::<Large>::read(&mut buf).unwrap_err();
        assert_matches!(err, ReadError::IO(_));
    }

    test_roundtrip!(CompactArray<Int32>, test_compact_array_roundtrip);

    #[test]
    fn test_compact_array_blowup_memory() {
        let mut buf = Cursor::new(Vec::<u8>::new());
        UnsignedVarint(u64::MAX).write(&mut buf).unwrap();
        buf.set_position(0);

        let err = CompactArray::<Large>::read(&mut buf).unwrap_err();
        assert_matches!(err, ReadError::IO(_));
    }

    test_roundtrip!(Records, test_records_roundtrip);

    #[test]
    fn test_records_partial() {
        // Records might be partially returned when fetch requests are issued w/ size limits
        let batch_1 = record_batch(1);
        let batch_2 = record_batch(2);

        let mut buf = vec![];
        batch_1.write(&mut buf).unwrap();
        batch_2.write(&mut buf).unwrap();
        // Drop the last byte so that the second batch is truncated.
        let inner = buf[..buf.len() - 1].to_vec();

        let mut buf = vec![];
        NullableBytes(Some(inner)).write(&mut buf).unwrap();

        let records = Records::read(&mut Cursor::new(buf)).unwrap();
        assert_eq!(records.0, vec![batch_1]);
    }

    /// Builds a minimal, empty record batch that differs only in its base offset.
    fn record_batch(base_offset: i64) -> RecordBatch {
        RecordBatch {
            base_offset,
            partition_leader_epoch: 0,
            last_offset_delta: 0,
            first_timestamp: 0,
            max_timestamp: 0,
            producer_id: 0,
            producer_epoch: 0,
            base_sequence: 0,
            records: ControlBatchOrRecords::Records(vec![]),
            compression: RecordBatchCompression::NoCompression,
            is_transactional: false,
            timestamp_type: RecordBatchTimestampType::CreateTime,
        }
    }

    /// A rather large struct here to trigger OOM.
    #[derive(Debug)]
    struct Large {
        _inner: [u8; 1024],
    }

    impl<R> ReadType<R> for Large
    where
        R: Read,
    {
        fn read(reader: &mut R) -> Result<Self, ReadError> {
            // The test buffers contain only the length prefix, so this read always fails before the end.
            Int32::read(reader)?;
            unreachable!()
        }
    }
}