use std::io::{Cursor, Read, Write};
use integer_encoding::{VarIntReader, VarIntWriter};
#[cfg(test)]
use proptest::prelude::*;
use super::{
record::RecordBatch,
traits::{ReadError, ReadType, WriteError, WriteType},
vec_builder::VecBuilder,
};
/// Boolean transmitted as a single byte.
///
/// Encoding always emits `0` (false) or `1` (true); decoding accepts any non-zero byte as `true`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct Boolean(pub bool);

impl<R> ReadType<R> for Boolean
where
    R: Read,
{
    fn read(reader: &mut R) -> Result<Self, ReadError> {
        let mut byte = [0u8; 1];
        reader.read_exact(&mut byte)?;
        // Lenient decoding: only an exact zero means `false`.
        Ok(Self(byte[0] != 0))
    }
}

impl<W> WriteType<W> for Boolean
where
    W: Write,
{
    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
        let byte: u8 = if self.0 { 1 } else { 0 };
        writer.write_all(&[byte])?;
        Ok(())
    }
}
/// Signed 8-bit integer, transmitted as a single byte.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct Int8(pub i8);

impl<R> ReadType<R> for Int8
where
    R: Read,
{
    fn read(reader: &mut R) -> Result<Self, ReadError> {
        let mut raw = [0u8; 1];
        reader.read_exact(&mut raw)?;
        let value = i8::from_be_bytes(raw);
        Ok(Self(value))
    }
}

impl<W> WriteType<W> for Int8
where
    W: Write,
{
    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
        writer.write_all(&self.0.to_be_bytes())?;
        Ok(())
    }
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct Int16(pub i16);
impl<R> ReadType<R> for Int16
where
R: Read,
{
fn read(reader: &mut R) -> Result<Self, ReadError> {
let mut buf = [0u8; 2];
reader.read_exact(&mut buf)?;
Ok(Self(i16::from_be_bytes(buf)))
}
}
impl<W> WriteType<W> for Int16
where
W: Write,
{
fn write(&self, writer: &mut W) -> Result<(), WriteError> {
let buf = self.0.to_be_bytes();
writer.write_all(&buf)?;
Ok(())
}
}
/// Signed 32-bit integer, transmitted as four big-endian bytes.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct Int32(pub i32);

impl<R> ReadType<R> for Int32
where
    R: Read,
{
    fn read(reader: &mut R) -> Result<Self, ReadError> {
        let mut be_bytes = [0u8; 4];
        reader.read_exact(&mut be_bytes)?;
        Ok(Int32(i32::from_be_bytes(be_bytes)))
    }
}

impl<W> WriteType<W> for Int32
where
    W: Write,
{
    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
        let Int32(value) = *self;
        writer.write_all(&value.to_be_bytes())?;
        Ok(())
    }
}
/// Signed 64-bit integer, transmitted as eight big-endian bytes.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct Int64(pub i64);

impl<R> ReadType<R> for Int64
where
    R: Read,
{
    fn read(reader: &mut R) -> Result<Self, ReadError> {
        let mut octets = [0u8; 8];
        reader.read_exact(&mut octets)?;
        let value = i64::from_be_bytes(octets);
        Ok(Self(value))
    }
}

impl<W> WriteType<W> for Int64
where
    W: Write,
{
    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
        let octets = i64::to_be_bytes(self.0);
        writer.write_all(&octets)?;
        Ok(())
    }
}
/// Signed 32-bit integer in the signed varint encoding provided by `integer_encoding`.
///
/// Decoding reads a 64-bit varint and then narrows it. A well-formed varint whose value does not
/// fit into an `i32` is therefore reported as an error instead of being truncated.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct Varint(pub i32);

impl<R> ReadType<R> for Varint
where
    R: Read,
{
    fn read(reader: &mut R) -> Result<Self, ReadError> {
        let wide = reader.read_varint::<i64>()?;
        let narrow = i32::try_from(wide)?;
        Ok(Self(narrow))
    }
}

impl<W> WriteType<W> for Varint
where
    W: Write,
{
    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
        // The number of bytes written is not needed by callers.
        let _bytes_written = writer.write_varint(self.0)?;
        Ok(())
    }
}
/// Signed 64-bit integer in the signed varint encoding provided by `integer_encoding`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct Varlong(pub i64);

impl<R> ReadType<R> for Varlong
where
    R: Read,
{
    fn read(reader: &mut R) -> Result<Self, ReadError> {
        let value: i64 = reader.read_varint()?;
        Ok(Self(value))
    }
}

impl<W> WriteType<W> for Varlong
where
    W: Write,
{
    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
        let Varlong(value) = *self;
        writer.write_varint(value)?;
        Ok(())
    }
}
/// Unsigned 64-bit integer in LEB128-style variable-length encoding.
///
/// Each byte carries 7 payload bits, with the least-significant group first. The high bit
/// (`0x80`) is set on every byte except the last.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct UnsignedVarint(pub u64);
impl<R> ReadType<R> for UnsignedVarint
where
    R: Read,
{
    /// Decodes an unsigned varint of at most 10 bytes into a `u64`.
    ///
    /// # Errors
    ///
    /// * An I/O error if the reader runs out of data before a terminating byte (no `0x80` bit).
    /// * `ReadError::Malformed` if the encoded value does not fit into a `u64`. This covers two
    ///   cases: the 10th byte still has the continuation bit set, or the 10th byte carries
    ///   payload bits beyond bit 63.
    fn read(reader: &mut R) -> Result<Self, ReadError> {
        let mut buf = [0u8; 1];
        let mut res: u64 = 0;
        let mut shift: u32 = 0;
        loop {
            reader.read_exact(&mut buf)?;
            let c: u64 = buf[0].into();
            let payload = c & 0x7f;
            // At shift 63 only the lowest payload bit still fits into a `u64`. Any higher bit
            // would be shifted out and silently lost, yielding a wrong value.
            if shift == 63 && payload > 1 {
                return Err(ReadError::Malformed(
                    String::from("Overflow while reading unsigned varint").into(),
                ));
            }
            res |= payload << shift;
            shift += 7;
            if (c & 0x80) == 0 {
                break;
            }
            // A continuation bit after the 10th byte can never fit into a `u64`.
            if shift > 63 {
                return Err(ReadError::Malformed(
                    String::from("Overflow while reading unsigned varint").into(),
                ));
            }
        }
        Ok(Self(res))
    }
}
impl<W> WriteType<W> for UnsignedVarint
where
    W: Write,
{
    /// Emits the value as 7-bit groups, least-significant first.
    ///
    /// Every byte except the last has its high bit set. Zero is written as the single byte `0x00`.
    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
        let mut remaining = self.0;
        loop {
            let low_bits = u8::try_from(remaining & 0x7f).map_err(WriteError::Overflow)?;
            remaining >>= 7;
            let more_follow = remaining != 0;
            let byte = if more_follow { low_bits | 0x80 } else { low_bits };
            writer.write_all(&[byte])?;
            if !more_follow {
                return Ok(());
            }
        }
    }
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default, Clone)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct NullableString(pub Option<String>);
impl<R> ReadType<R> for NullableString
where
R: Read,
{
fn read(reader: &mut R) -> Result<Self, ReadError> {
let len = Int16::read(reader)?;
match len.0 {
l if l < -1 => Err(ReadError::Malformed(
format!("Invalid negative length for nullable string: {}", l).into(),
)),
-1 => Ok(Self(None)),
l => {
let len = usize::try_from(l)?;
let mut buf = VecBuilder::new(len);
buf = buf.read_exact(reader)?;
let s =
String::from_utf8(buf.into()).map_err(|e| ReadError::Malformed(Box::new(e)))?;
Ok(Self(Some(s)))
}
}
}
}
impl<W> WriteType<W> for NullableString
where
W: Write,
{
fn write(&self, writer: &mut W) -> Result<(), WriteError> {
match &self.0 {
Some(s) => {
let l = i16::try_from(s.len()).map_err(|e| WriteError::Malformed(Box::new(e)))?;
Int16(l).write(writer)?;
writer.write_all(s.as_bytes())?;
Ok(())
}
None => Int16(-1).write(writer),
}
}
}
/// Non-nullable UTF-8 string prefixed by an `Int16` byte length.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct String_(pub String);

impl<R> ReadType<R> for String_
where
    R: Read,
{
    fn read(reader: &mut R) -> Result<Self, ReadError> {
        let raw_len = Int16::read(reader)?.0;
        // Any negative length (including -1) is invalid for a non-nullable string.
        let len = usize::try_from(raw_len).map_err(|e| ReadError::Malformed(Box::new(e)))?;
        let builder = VecBuilder::new(len).read_exact(reader)?;
        String::from_utf8(builder.into())
            .map(String_)
            .map_err(|e| ReadError::Malformed(Box::new(e)))
    }
}

impl<W> WriteType<W> for String_
where
    W: Write,
{
    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
        let String_(text) = self;
        let len = i16::try_from(text.len()).map_err(WriteError::Overflow)?;
        Int16(len).write(writer)?;
        writer.write_all(text.as_bytes())?;
        Ok(())
    }
}
/// Non-nullable UTF-8 string whose length is encoded as an `UnsignedVarint` of `len + 1`.
///
/// An encoded length of `0` would mean null, so it is rejected as malformed here.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct CompactString(pub String);

impl<R> ReadType<R> for CompactString
where
    R: Read,
{
    fn read(reader: &mut R) -> Result<Self, ReadError> {
        let encoded_len = UnsignedVarint::read(reader)?.0;
        if encoded_len == 0 {
            return Err(ReadError::Malformed(
                "CompactString must have non-zero length".into(),
            ));
        }
        let len = usize::try_from(encoded_len)? - 1;
        let builder = VecBuilder::new(len).read_exact(reader)?;
        let text =
            String::from_utf8(builder.into()).map_err(|e| ReadError::Malformed(Box::new(e)))?;
        Ok(Self(text))
    }
}

impl<W> WriteType<W> for CompactString
where
    W: Write,
{
    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
        // Delegate to the borrowed form so both share one encoder.
        CompactStringRef(self.0.as_str()).write(writer)
    }
}
/// Borrowed, write-only counterpart of `CompactString`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CompactStringRef<'a>(pub &'a str);

impl<'a, W> WriteType<W> for CompactStringRef<'a>
where
    W: Write,
{
    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
        let text = self.0;
        // The length goes on the wire as `len + 1`; `0` is reserved for a null string.
        let encoded_len = u64::try_from(text.len() + 1).map_err(WriteError::Overflow)?;
        UnsignedVarint(encoded_len).write(writer)?;
        writer.write_all(text.as_bytes())?;
        Ok(())
    }
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct CompactNullableString(pub Option<String>);
impl<R> ReadType<R> for CompactNullableString
where
R: Read,
{
fn read(reader: &mut R) -> Result<Self, ReadError> {
let len = UnsignedVarint::read(reader)?;
match len.0 {
0 => Ok(Self(None)),
len => {
let len = usize::try_from(len)?;
let len = len - 1;
let mut buf = VecBuilder::new(len);
buf = buf.read_exact(reader)?;
let s =
String::from_utf8(buf.into()).map_err(|e| ReadError::Malformed(Box::new(e)))?;
Ok(Self(Some(s)))
}
}
}
}
impl<W> WriteType<W> for CompactNullableString
where
W: Write,
{
fn write(&self, writer: &mut W) -> Result<(), WriteError> {
CompactNullableStringRef(self.0.as_deref()).write(writer)
}
}
/// Borrowed, write-only counterpart of `CompactNullableString`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CompactNullableStringRef<'a>(pub Option<&'a str>);

impl<'a, W> WriteType<W> for CompactNullableStringRef<'a>
where
    W: Write,
{
    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
        if let Some(text) = self.0 {
            let encoded_len = u64::try_from(text.len() + 1).map_err(WriteError::Overflow)?;
            UnsignedVarint(encoded_len).write(writer)?;
            writer.write_all(text.as_bytes())?;
        } else {
            // Null: a zero length and no payload.
            UnsignedVarint(0).write(writer)?;
        }
        Ok(())
    }
}
#[derive(Debug, PartialEq, Eq)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct NullableBytes(pub Option<Vec<u8>>);
impl<R> ReadType<R> for NullableBytes
where
R: Read,
{
fn read(reader: &mut R) -> Result<Self, ReadError> {
let len = Int32::read(reader)?;
match len.0 {
l if l < -1 => Err(ReadError::Malformed(
format!("Invalid negative length for nullable bytes: {}", l).into(),
)),
-1 => Ok(Self(None)),
l => {
let len = usize::try_from(l)?;
let mut buf = VecBuilder::new(len);
buf = buf.read_exact(reader)?;
Ok(Self(Some(buf.into())))
}
}
}
}
impl<W> WriteType<W> for NullableBytes
where
W: Write,
{
fn write(&self, writer: &mut W) -> Result<(), WriteError> {
match &self.0 {
Some(s) => {
let l = i32::try_from(s.len()).map_err(|e| WriteError::Malformed(Box::new(e)))?;
Int32(l).write(writer)?;
writer.write_all(s)?;
Ok(())
}
None => Int32(-1).write(writer),
}
}
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct TaggedFields(pub Vec<(UnsignedVarint, Vec<u8>)>);
impl<R> ReadType<R> for TaggedFields
where
R: Read,
{
fn read(reader: &mut R) -> Result<Self, ReadError> {
let len = UnsignedVarint::read(reader)?;
let len = usize::try_from(len.0).map_err(ReadError::Overflow)?;
let mut res = VecBuilder::new(len);
for _ in 0..len {
let tag = UnsignedVarint::read(reader)?;
let data_len = UnsignedVarint::read(reader)?;
let data_len = usize::try_from(data_len.0).map_err(ReadError::Overflow)?;
let mut data_builder = VecBuilder::new(data_len);
data_builder = data_builder.read_exact(reader)?;
res.push((tag, data_builder.into()));
}
Ok(Self(res.into()))
}
}
impl<W> WriteType<W> for TaggedFields
where
W: Write,
{
fn write(&self, writer: &mut W) -> Result<(), WriteError> {
let len = u64::try_from(self.0.len()).map_err(WriteError::Overflow)?;
UnsignedVarint(len).write(writer)?;
for (tag, data) in &self.0 {
tag.write(writer)?;
let data_len = u64::try_from(data.len()).map_err(WriteError::Overflow)?;
UnsignedVarint(data_len).write(writer)?;
writer.write_all(data)?;
}
Ok(())
}
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct Array<T>(pub Option<Vec<T>>);
impl<R, T> ReadType<R> for Array<T>
where
R: Read,
T: ReadType<R>,
{
fn read(reader: &mut R) -> Result<Self, ReadError> {
let len = Int32::read(reader)?;
if len.0 == -1 {
Ok(Self(None))
} else {
let len = usize::try_from(len.0)?;
let mut res = VecBuilder::new(len);
for _ in 0..len {
res.push(T::read(reader)?);
}
Ok(Self(Some(res.into())))
}
}
}
impl<W, T> WriteType<W> for Array<T>
where
W: Write,
T: WriteType<W>,
{
fn write(&self, writer: &mut W) -> Result<(), WriteError> {
ArrayRef(self.0.as_deref()).write(writer)
}
}
/// Borrowed, write-only counterpart of `Array`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ArrayRef<'a, T>(pub Option<&'a [T]>);

impl<'a, W, T> WriteType<W> for ArrayRef<'a, T>
where
    W: Write,
    T: WriteType<W>,
{
    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
        let elements = match self.0 {
            Some(elements) => elements,
            // A null array is written as count -1 with no elements.
            None => return Int32(-1).write(writer),
        };
        Int32(i32::try_from(elements.len())?).write(writer)?;
        for element in elements {
            element.write(writer)?;
        }
        Ok(())
    }
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct CompactArray<T>(pub Option<Vec<T>>);
impl<R, T> ReadType<R> for CompactArray<T>
where
R: Read,
T: ReadType<R>,
{
fn read(reader: &mut R) -> Result<Self, ReadError> {
let len = UnsignedVarint::read(reader)?.0;
match len {
0 => Ok(Self(None)),
n => {
let len = usize::try_from(n - 1).map_err(ReadError::Overflow)?;
let mut builder = VecBuilder::new(len);
for _ in 0..len {
builder.push(T::read(reader)?);
}
Ok(Self(Some(builder.into())))
}
}
}
}
impl<W, T> WriteType<W> for CompactArray<T>
where
W: Write,
T: WriteType<W>,
{
fn write(&self, writer: &mut W) -> Result<(), WriteError> {
CompactArrayRef(self.0.as_deref()).write(writer)
}
}
/// Borrowed, write-only counterpart of `CompactArray`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CompactArrayRef<'a, T>(pub Option<&'a [T]>);

impl<'a, W, T> WriteType<W> for CompactArrayRef<'a, T>
where
    W: Write,
    T: WriteType<W>,
{
    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
        if let Some(elements) = self.0 {
            // The count is written as `len + 1`; `0` is reserved for a null array.
            let encoded_len = u64::try_from(elements.len() + 1).map_err(WriteError::from)?;
            UnsignedVarint(encoded_len).write(writer)?;
            elements.iter().try_for_each(|element| element.write(writer))
        } else {
            UnsignedVarint(0).write(writer)
        }
    }
}
/// Sequence of record batches carried inside a `NullableBytes` payload.
///
/// A null payload decodes as an empty list. If the final batch is cut off (`UnexpectedEof`),
/// it is dropped silently instead of causing an error.
#[derive(Debug, PartialEq, Eq)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct Records(
    #[cfg_attr(
        test,
        proptest(strategy = "prop::collection::vec(any::<RecordBatch>(), 0..2)")
    )]
    pub Vec<RecordBatch>,
);

impl<R> ReadType<R> for Records
where
    R: Read,
{
    fn read(reader: &mut R) -> Result<Self, ReadError> {
        let payload = NullableBytes::read(reader)?.0.unwrap_or_default();
        let total = u64::try_from(payload.len())?;
        let mut cursor = Cursor::new(payload);
        let mut batches = vec![];
        while cursor.position() < total {
            match RecordBatch::read(&mut cursor) {
                Ok(batch) => batches.push(batch),
                // Record batch got cut off, likely due to `FetchRequest::max_bytes`.
                Err(ReadError::IO(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
                Err(e) => return Err(e),
            }
        }
        Ok(Self(batches))
    }
}

impl<W> WriteType<W> for Records
where
    W: Write,
{
    fn write(&self, writer: &mut W) -> Result<(), WriteError> {
        // TODO: it would be nice if we could avoid the copy here by writing the records and then seeking back.
        let mut encoded: Vec<u8> = Vec::new();
        for batch in self.0.iter() {
            batch.write(&mut encoded)?;
        }
        NullableBytes(Some(encoded)).write(writer)
    }
}
// Unit tests for the primitive wire types in this module.
//
// Besides encode/decode roundtrips, most types have a "blowup memory" test. In these tests the
// input contains only a huge length/count prefix, and decoding must fail with an I/O error once
// the data runs out.
// NOTE(review): presumably these rely on `VecBuilder` not reserving the full advertised
// capacity up front — confirm against `vec_builder`.
#[cfg(test)]
mod tests {
use std::io::Cursor;
use crate::protocol::{
record::{ControlBatchOrRecords, RecordBatchCompression, RecordBatchTimestampType},
test_utils::test_roundtrip,
};
use super::*;
use assert_matches::assert_matches;
test_roundtrip!(Boolean, test_bool_roundtrip);
#[test]
fn test_boolean_decode() {
assert!(!Boolean::read(&mut Cursor::new(vec![0])).unwrap().0);
// When reading a boolean value, any non-zero value is considered true.
for v in [1, 35, 255] {
assert!(Boolean::read(&mut Cursor::new(vec![v])).unwrap().0);
}
}
test_roundtrip!(Int8, test_int8_roundtrip);
test_roundtrip!(Int16, test_int16_roundtrip);
test_roundtrip!(Int32, test_int32_roundtrip);
test_roundtrip!(Int64, test_int64_roundtrip);
test_roundtrip!(Varint, test_varint_roundtrip);
#[test]
fn test_varint_special_values() {
// Taken from https://developers.google.com/protocol-buffers/docs/encoding?csw=1#varints
for v in [0, -1, 1, -2, 2147483647, -2147483648] {
let mut data = vec![];
Varint(v).write(&mut data).unwrap();
let restored = Varint::read(&mut Cursor::new(data)).unwrap();
assert_eq!(restored.0, v);
}
}
#[test]
fn test_varint_read_read_overflow() {
// Eleven 0xff bytes: every byte has the continuation bit set, so no terminating byte
// appears. The varint decoder reports this as an "Unterminated varint" I/O error.
let mut buf = Cursor::new(vec![0xffu8; 11]);
let err = Varint::read(&mut buf).unwrap_err();
assert_matches!(err, ReadError::IO(_));
assert_eq!(err.to_string(), "Cannot read data: Unterminated varint",);
}
#[test]
fn test_varint_read_downcast_overflow() {
// This is a valid, terminated 64-bit varint (nine continuation bytes, then 0x00), but its
// value does not fit into an i32, so narrowing it to `Varint` must fail.
let mut data = vec![0xffu8; 9];
data.push(0x00);
let mut buf = Cursor::new(data);
let err = Varint::read(&mut buf).unwrap_err();
assert_matches!(err, ReadError::Overflow(_));
assert_eq!(
err.to_string(),
"Overflow converting integer: out of range integral type conversion attempted",
);
}
test_roundtrip!(Varlong, test_varlong_roundtrip);
#[test]
fn test_varlong_special_values() {
// Taken from https://developers.google.com/protocol-buffers/docs/encoding?csw=1#varints + min/max
for v in [0, -1, 1, -2, 2147483647, -2147483648, i64::MIN, i64::MAX] {
let mut data = vec![];
Varlong(v).write(&mut data).unwrap();
let restored = Varlong::read(&mut Cursor::new(data)).unwrap();
assert_eq!(restored.0, v);
}
}
#[test]
fn test_varlong_read_overflow() {
// Eleven continuation bytes with no terminator; same failure mode as for `Varint`.
let mut buf = Cursor::new(vec![0xffu8; 11]);
let err = Varlong::read(&mut buf).unwrap_err();
assert_matches!(err, ReadError::IO(_));
assert_eq!(err.to_string(), "Cannot read data: Unterminated varint",);
}
test_roundtrip!(UnsignedVarint, test_unsigned_varint_roundtrip);
#[test]
fn test_unsigned_varint_read_overflow() {
// Ten bytes (64 / 7 + 1), all with the continuation bit set. The value would need more
// than 64 bits, so this must be rejected as malformed.
let mut buf = Cursor::new(vec![0xffu8; 64 / 7 + 1]);
let err = UnsignedVarint::read(&mut buf).unwrap_err();
assert_matches!(err, ReadError::Malformed(_));
assert_eq!(
err.to_string(),
"Malformed data: Overflow while reading unsigned varint",
);
}
test_roundtrip!(String_, test_string_roundtrip);
#[test]
fn test_string_blowup_memory() {
// Only the length prefix is present, so reading the payload must hit EOF.
let mut buf = Cursor::new(Vec::<u8>::new());
Int16(i16::MAX).write(&mut buf).unwrap();
buf.set_position(0);
let err = String_::read(&mut buf).unwrap_err();
assert_matches!(err, ReadError::IO(_));
}
test_roundtrip!(NullableString, test_nullable_string_roundtrip);
#[test]
fn test_nullable_string_read_negative_length() {
// -1 means null; any smaller length is malformed.
let mut buf = Cursor::new(Vec::<u8>::new());
Int16(-2).write(&mut buf).unwrap();
buf.set_position(0);
let err = NullableString::read(&mut buf).unwrap_err();
assert_matches!(err, ReadError::Malformed(_));
assert_eq!(
err.to_string(),
"Malformed data: Invalid negative length for nullable string: -2",
);
}
#[test]
fn test_nullable_string_blowup_memory() {
// Only the length prefix is present, so reading the payload must hit EOF.
let mut buf = Cursor::new(Vec::<u8>::new());
Int16(i16::MAX).write(&mut buf).unwrap();
buf.set_position(0);
let err = NullableString::read(&mut buf).unwrap_err();
assert_matches!(err, ReadError::IO(_));
}
test_roundtrip!(CompactString, test_compact_string_roundtrip);
#[test]
fn test_compact_string_blowup_memory() {
// Only the length prefix is present, so reading the payload must hit EOF.
let mut buf = Cursor::new(Vec::<u8>::new());
UnsignedVarint(u64::MAX).write(&mut buf).unwrap();
buf.set_position(0);
let err = CompactString::read(&mut buf).unwrap_err();
assert_matches!(err, ReadError::IO(_));
}
test_roundtrip!(
CompactNullableString,
test_compact_nullable_string_roundtrip
);
#[test]
fn test_compact_nullable_string_blowup_memory() {
// Only the length prefix is present, so reading the payload must hit EOF.
let mut buf = Cursor::new(Vec::<u8>::new());
UnsignedVarint(u64::MAX).write(&mut buf).unwrap();
buf.set_position(0);
let err = CompactNullableString::read(&mut buf).unwrap_err();
assert_matches!(err, ReadError::IO(_));
}
test_roundtrip!(NullableBytes, test_nullable_bytes_roundtrip);
#[test]
fn test_nullable_bytes_read_negative_length() {
// -1 means null; any smaller length is malformed.
let mut buf = Cursor::new(Vec::<u8>::new());
Int32(-2).write(&mut buf).unwrap();
buf.set_position(0);
let err = NullableBytes::read(&mut buf).unwrap_err();
assert_matches!(err, ReadError::Malformed(_));
assert_eq!(
err.to_string(),
"Malformed data: Invalid negative length for nullable bytes: -2",
);
}
#[test]
fn test_nullable_bytes_blowup_memory() {
// Only the length prefix is present, so reading the payload must hit EOF.
let mut buf = Cursor::new(Vec::<u8>::new());
Int32(i32::MAX).write(&mut buf).unwrap();
buf.set_position(0);
let err = NullableBytes::read(&mut buf).unwrap_err();
assert_matches!(err, ReadError::IO(_));
}
test_roundtrip!(TaggedFields, test_tagged_fields_roundtrip);
#[test]
fn test_tagged_fields_blowup_memory() {
let mut buf = Cursor::new(Vec::<u8>::new());
// number of fields
UnsignedVarint(u64::MAX).write(&mut buf).unwrap();
// tag
UnsignedVarint(u64::MAX).write(&mut buf).unwrap();
// data length
UnsignedVarint(u64::MAX).write(&mut buf).unwrap();
buf.set_position(0);
let err = TaggedFields::read(&mut buf).unwrap_err();
assert_matches!(err, ReadError::IO(_));
}
test_roundtrip!(Array<Int32>, test_array_roundtrip);
#[test]
fn test_array_blowup_memory() {
// A huge element count with a large element type and no element data: must hit EOF.
let mut buf = Cursor::new(Vec::<u8>::new());
Int32(i32::MAX).write(&mut buf).unwrap();
buf.set_position(0);
let err = Array::<Large>::read(&mut buf).unwrap_err();
assert_matches!(err, ReadError::IO(_));
}
test_roundtrip!(CompactArray<Int32>, test_compact_array_roundtrip);
#[test]
fn test_compact_array_blowup_memory() {
// A huge element count with a large element type and no element data: must hit EOF.
let mut buf = Cursor::new(Vec::<u8>::new());
UnsignedVarint(u64::MAX).write(&mut buf).unwrap();
buf.set_position(0);
let err = CompactArray::<Large>::read(&mut buf).unwrap_err();
assert_matches!(err, ReadError::IO(_));
}
test_roundtrip!(Records, test_records_roundtrip);
#[test]
fn test_records_partial() {
// Records might be partially returned when fetch requests are issued w/ size limits
let batch_1 = record_batch(1);
let batch_2 = record_batch(2);
let mut buf = vec![];
batch_1.write(&mut buf).unwrap();
batch_2.write(&mut buf).unwrap();
// Drop the final byte so the second batch is truncated. Decoding must keep the first
// batch and silently discard the incomplete one.
let inner = buf[..buf.len() - 1].to_vec();
let mut buf = vec![];
NullableBytes(Some(inner)).write(&mut buf).unwrap();
let records = Records::read(&mut Cursor::new(buf)).unwrap();
assert_eq!(records.0, vec![batch_1]);
}
/// Builds a minimal, empty record batch with the given base offset.
fn record_batch(base_offset: i64) -> RecordBatch {
RecordBatch {
base_offset,
partition_leader_epoch: 0,
last_offset_delta: 0,
first_timestamp: 0,
max_timestamp: 0,
producer_id: 0,
producer_epoch: 0,
base_sequence: 0,
records: ControlBatchOrRecords::Records(vec![]),
compression: RecordBatchCompression::NoCompression,
is_transactional: false,
timestamp_type: RecordBatchTimestampType::CreateTime,
}
}
/// A rather large struct here to trigger OOM.
///
/// If an array reader reserved `count * size_of::<Large>()` bytes eagerly from an untrusted
/// count, the blowup tests would attempt a huge allocation. Its `read` only consumes an
/// `Int32`. In the blowup tests that read hits EOF, so `unreachable!()` is only reached if a
/// test buffer unexpectedly contains element data.
#[derive(Debug)]
struct Large {
_inner: [u8; 1024],
}
impl<R> ReadType<R> for Large
where
R: Read,
{
fn read(reader: &mut R) -> Result<Self, ReadError> {
Int32::read(reader)?;
unreachable!()
}
}
}