use std::io::{Read, Write};
use thiserror::Error;
use super::{
api_key::ApiKey,
api_version::{ApiVersion, ApiVersionRange},
primitives::{Int32, UnsignedVarint},
traits::{ReadError, ReadType, WriteError, WriteType},
vec_builder::VecBuilder,
};
// Submodules of this module. Each one below is declared privately and then
// glob re-exported, so its public items are reachable directly from here.
mod api_versions;
pub use api_versions::*;
mod constants;
pub use constants::*;
mod create_topics;
pub use create_topics::*;
mod delete_records;
pub use delete_records::*;
mod delete_topics;
pub use delete_topics::*;
mod fetch;
pub use fetch::*;
mod header;
pub use header::*;
mod list_offsets;
pub use list_offsets::*;
mod metadata;
pub use metadata::*;
mod produce;
pub use produce::*;
mod sasl_msg;
pub use sasl_msg::*;
// Only compiled for test builds and, unlike the modules above, not re-exported.
#[cfg(test)]
mod test_utils;
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ReadVersionedError {
#[error("Read error: {0}")]
ReadError(#[from] ReadError),
}
/// A value that can be decoded from a reader `R`, where the caller supplies
/// the [`ApiVersion`] to decode for.
pub trait ReadVersionedType<R: Read>: Sized {
    /// Reads one `Self` from `reader` using the given `version`.
    ///
    /// # Errors
    ///
    /// Returns a [`ReadVersionedError`] if the value cannot be read.
    fn read_versioned(reader: &mut R, version: ApiVersion) -> Result<Self, ReadVersionedError>;
}
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum WriteVersionedError {
#[error("Write error: {0}")]
WriteError(#[from] WriteError),
#[error("Field {field} not available in version: {version:?}")]
FieldNotAvailable { field: String, version: ApiVersion },
}
/// A value that can be encoded into a writer `W`, where the caller supplies
/// the [`ApiVersion`] to encode for.
pub trait WriteVersionedType<W: Write>: Sized {
    /// Writes `self` to `writer` using the given `version`.
    ///
    /// # Errors
    ///
    /// Returns a [`WriteVersionedError`] if the value cannot be written.
    fn write_versioned(
        &self,
        writer: &mut W,
        version: ApiVersion,
    ) -> Result<(), WriteVersionedError>;
}
/// References to a writable type are writable too: the call is forwarded
/// unchanged to the implementation for `T`.
impl<W, T> WriteVersionedType<W> for &T
where
    W: Write,
    T: WriteVersionedType<W>,
{
    fn write_versioned(
        &self,
        writer: &mut W,
        version: ApiVersion,
    ) -> Result<(), WriteVersionedError> {
        // `self` is `&&T`; peel one reference and name the `T` impl explicitly
        // so the call cannot resolve back to this impl.
        <T as WriteVersionedType<W>>::write_versioned(*self, writer, version)
    }
}
/// Compile-time metadata attached to a request body type.
pub trait RequestBody {
    /// Response body type that belongs to this request.
    type ResponseBody;

    /// API key of this request.
    const API_KEY: ApiKey;

    /// Range of API versions associated with this request.
    ///
    /// NOTE(review): presumably the versions this implementation supports;
    /// confirm against the code that consumes this constant.
    const API_VERSION_RANGE: ApiVersionRange;

    /// NOTE(review): judging by the name, the first request version that
    /// carries tagged fields; confirm against the code that consumes it.
    const FIRST_TAGGED_FIELD_IN_REQUEST_VERSION: ApiVersion;

    /// Counterpart of [`Self::FIRST_TAGGED_FIELD_IN_REQUEST_VERSION`] for the
    /// response. Defaults to the request value unless an implementor overrides
    /// it.
    const FIRST_TAGGED_FIELD_IN_RESPONSE_VERSION: ApiVersion =
        Self::FIRST_TAGGED_FIELD_IN_REQUEST_VERSION;
}
/// A reference to a request body exposes exactly the same metadata as the
/// body itself; every item is forwarded from `T`.
impl<T> RequestBody for &T
where
    T: RequestBody,
{
    type ResponseBody = <T as RequestBody>::ResponseBody;

    const API_KEY: ApiKey = <T as RequestBody>::API_KEY;

    const API_VERSION_RANGE: ApiVersionRange = <T as RequestBody>::API_VERSION_RANGE;

    const FIRST_TAGGED_FIELD_IN_REQUEST_VERSION: ApiVersion =
        <T as RequestBody>::FIRST_TAGGED_FIELD_IN_REQUEST_VERSION;

    // Forwarded explicitly (rather than relying on the trait default) so that
    // an override on `T` is preserved.
    const FIRST_TAGGED_FIELD_IN_RESPONSE_VERSION: ApiVersion =
        <T as RequestBody>::FIRST_TAGGED_FIELD_IN_RESPONSE_VERSION;
}
/// Reads an optional array of versioned elements prefixed by an `Int32` length.
///
/// * A length of `-1` yields `Ok(None)`.
/// * Any other negative length is rejected as malformed.
/// * Otherwise exactly that many elements are read with
///   [`ReadVersionedType::read_versioned`], each using `version`.
///
/// # Errors
///
/// Returns a [`ReadVersionedError`] if the length prefix or any element cannot
/// be read, if the length is below `-1`, or if it does not fit into `usize`.
fn read_versioned_array<R: Read, T: ReadVersionedType<R>>(
    reader: &mut R,
    version: ApiVersion,
) -> Result<Option<Vec<T>>, ReadVersionedError> {
    let raw_len = Int32::read(reader)?.0;

    // -1 is the marker for "no array at all" (as opposed to an empty one).
    if raw_len == -1 {
        return Ok(None);
    }
    if raw_len < -1 {
        return Err(ReadVersionedError::ReadError(ReadError::Malformed(
            format!("Invalid negative length for array: {}", raw_len).into(),
        )));
    }

    let count = usize::try_from(raw_len).map_err(ReadError::Overflow)?;
    let mut elements = VecBuilder::new(count);
    for _ in 0..count {
        elements.push(T::read_versioned(reader, version)?);
    }
    Ok(Some(elements.into()))
}
/// Writes an optional slice of versioned elements prefixed by an `Int32` length.
///
/// `None` is written as the single length value `-1`. `Some(slice)` is written
/// as the element count followed by every element, each encoded with
/// [`WriteVersionedType::write_versioned`] using `version`.
///
/// # Errors
///
/// Returns a [`WriteVersionedError`] if the element count does not fit into an
/// `i32`, or if writing the length prefix or any element fails.
fn write_versioned_array<W: Write, T: WriteVersionedType<W>>(
    writer: &mut W,
    version: ApiVersion,
    data: Option<&[T]>,
) -> Result<(), WriteVersionedError> {
    let elements = match data {
        Some(elements) => elements,
        None => {
            // -1 is the marker for "no array at all".
            Int32(-1).write(writer)?;
            return Ok(());
        }
    };

    let count = i32::try_from(elements.len()).map_err(WriteError::from)?;
    Int32(count).write(writer)?;

    // Stops at the first element that fails to write.
    elements
        .iter()
        .try_for_each(|element| element.write_versioned(writer, version))
}
/// Reads an optional array of versioned elements in the compact form, where
/// the prefix is an `UnsignedVarint` holding the element count plus one.
///
/// * A prefix of `0` yields `Ok(None)`.
/// * A prefix of `n > 0` is followed by `n - 1` elements, each read with
///   [`ReadVersionedType::read_versioned`] using `version`.
///
/// # Errors
///
/// Returns a [`ReadVersionedError`] if the prefix or any element cannot be
/// read, or if the element count does not fit into `usize`.
fn read_compact_versioned_array<R: Read, T: ReadVersionedType<R>>(
    reader: &mut R,
    version: ApiVersion,
) -> Result<Option<Vec<T>>, ReadVersionedError> {
    let encoded_len = UnsignedVarint::read(reader)?.0;

    // 0 is the marker for "no array at all"; an empty array is encoded as 1.
    if encoded_len == 0 {
        return Ok(None);
    }

    // Cannot underflow: `encoded_len` is non-zero here.
    let count = usize::try_from(encoded_len - 1).map_err(ReadError::Overflow)?;
    let mut elements = VecBuilder::new(count);
    for _ in 0..count {
        elements.push(T::read_versioned(reader, version)?);
    }
    Ok(Some(elements.into()))
}
/// Writes an optional slice of versioned elements in the compact form, where
/// the prefix is an `UnsignedVarint` holding the element count plus one.
///
/// `None` is written as the single prefix value `0`. `Some(slice)` is written
/// as `slice.len() + 1` followed by every element, each encoded with
/// [`WriteVersionedType::write_versioned`] using `version`.
///
/// # Errors
///
/// Returns a [`WriteVersionedError`] if `slice.len() + 1` does not fit into a
/// `u64`, or if writing the prefix or any element fails.
fn write_compact_versioned_array<W: Write, T: WriteVersionedType<W>>(
    writer: &mut W,
    version: ApiVersion,
    data: Option<&[T]>,
) -> Result<(), WriteVersionedError> {
    match data {
        // 0 is the marker for "no array at all".
        None => Ok(UnsignedVarint(0).write(writer)?),
        Some(inner) => {
            // Do the `+ 1` in `u128` rather than `usize`: for a slice of
            // zero-sized elements `len()` may be `usize::MAX`, and adding one
            // in `usize` would panic in debug builds or wrap to 0 (the "no
            // array" marker) in release builds. The `usize -> u128` cast is
            // lossless, and the checked narrowing reports the same overflow
            // error type as before.
            let len = u64::try_from(inner.len() as u128 + 1).map_err(WriteError::from)?;
            UnsignedVarint(len).write(writer)?;

            for element in inner {
                element.write_versioned(writer, version)?;
            }

            Ok(())
        }
    }
}
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use assert_matches::assert_matches;

    use crate::protocol::primitives::Int16;

    use super::*;

    /// Test element that remembers the version it was read with and asserts,
    /// when written, that it is handed that same version again.
    #[derive(Debug, Copy, Clone, PartialEq)]
    struct VersionTest {
        version: ApiVersion,
    }

    impl<W: Write> WriteVersionedType<W> for VersionTest {
        fn write_versioned(
            &self,
            writer: &mut W,
            version: ApiVersion,
        ) -> Result<(), WriteVersionedError> {
            assert_eq!(version, self.version);
            // Fixed payload; the reader side checks for the same value.
            Int32(42).write(writer).map_err(WriteVersionedError::from)
        }
    }

    impl<R: Read> ReadVersionedType<R> for VersionTest {
        fn read_versioned(reader: &mut R, version: ApiVersion) -> Result<Self, ReadVersionedError> {
            let payload = Int32::read(reader)?.0;
            assert_eq!(payload, 42);
            Ok(Self { version })
        }
    }

    /// Round-trips empty and non-empty arrays for several versions, then the
    /// `None` case, through the `Int32`-prefixed encoding.
    #[test]
    fn test_read_write_versioned() {
        for len in [0, 6] {
            for i in 0..3 {
                let version = ApiVersion(Int16(i));
                let input = vec![VersionTest { version }; len];

                let mut encoded: Vec<u8> = Vec::new();
                write_versioned_array(&mut encoded, version, Some(input.as_slice())).unwrap();

                let mut cursor = Cursor::new(encoded);
                let output: Vec<VersionTest> =
                    read_versioned_array(&mut cursor, version).unwrap().unwrap();
                assert_eq!(input, output);
            }
        }

        let version = ApiVersion(Int16(0));
        let mut encoded: Vec<u8> = Vec::new();
        write_versioned_array::<_, VersionTest>(&mut encoded, version, None).unwrap();

        let mut cursor = Cursor::new(encoded);
        let decoded = read_versioned_array::<_, VersionTest>(&mut cursor, version).unwrap();
        assert!(decoded.is_none());
    }

    /// A length prefix of `i32::MAX` with no element data behind it must
    /// surface as an IO error.
    #[test]
    fn test_read_versioned_blowup_memory() {
        let mut input = Cursor::new(Vec::<u8>::new());
        Int32(i32::MAX).write(&mut input).unwrap();
        input.set_position(0);

        let result = read_versioned_array::<_, VersionTest>(&mut input, ApiVersion(Int16(42)));
        let err = result.unwrap_err();
        assert_matches!(err, ReadVersionedError::ReadError(ReadError::IO(_)));
    }

    /// Round-trips empty and non-empty arrays for several versions, then the
    /// `None` case, through the compact (`UnsignedVarint`-prefixed) encoding.
    #[test]
    fn test_read_write_compact_versioned() {
        for len in [0, 6] {
            for i in 0..3 {
                let version = ApiVersion(Int16(i));
                let input = vec![VersionTest { version }; len];

                let mut encoded: Vec<u8> = Vec::new();
                write_compact_versioned_array(&mut encoded, version, Some(input.as_slice()))
                    .unwrap();

                let mut cursor = Cursor::new(encoded);
                let output: Vec<VersionTest> = read_compact_versioned_array(&mut cursor, version)
                    .unwrap()
                    .unwrap();
                assert_eq!(input, output);
            }
        }

        let version = ApiVersion(Int16(0));
        let mut encoded: Vec<u8> = Vec::new();
        write_compact_versioned_array::<_, VersionTest>(&mut encoded, version, None).unwrap();

        let mut cursor = Cursor::new(encoded);
        let decoded =
            read_compact_versioned_array::<_, VersionTest>(&mut cursor, version).unwrap();
        assert!(decoded.is_none());
    }

    /// A compact length prefix of `u64::MAX` with no element data behind it
    /// must surface as an IO error.
    #[test]
    fn test_read_compact_versioned_blowup_memory() {
        let mut input = Cursor::new(Vec::<u8>::new());
        UnsignedVarint(u64::MAX).write(&mut input).unwrap();
        input.set_position(0);

        let result =
            read_compact_versioned_array::<_, VersionTest>(&mut input, ApiVersion(Int16(42)));
        let err = result.unwrap_err();
        assert_matches!(err, ReadVersionedError::ReadError(ReadError::IO(_)));
    }
}