use std::io::{Read, Write};
use crate::protocol::{
api_key::ApiKey,
api_version::ApiVersion,
primitives::{Int16, Int32, NullableString, TaggedFields},
traits::{ReadType, WriteType},
};
use super::{ReadVersionedError, ReadVersionedType, WriteVersionedError, WriteVersionedType};
/// Request header, (de)serialized through `ReadVersionedType` / `WriteVersionedType`.
///
/// Which fields appear on the wire depends on the header version handed to
/// `read_versioned` / `write_versioned`; the impls in this file accept header
/// versions up to 2.
#[derive(Debug, PartialEq, Eq)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct RequestHeader {
/// API key of the request; encoded on the wire as an `Int16`.
pub request_api_key: ApiKey,
/// API version stored inside the header (encoded as an `Int16`). This is a
/// separate value from the header version passed to the read/write methods.
pub request_api_version: ApiVersion,
/// Correlation id, encoded as an `Int32`.
/// NOTE(review): presumably used to pair a response with its request — confirm.
pub correlation_id: Int32,
/// Only read/written for header version >= 1. Reading an older header yields
/// `None`; writing `None` at version >= 1 emits `NullableString::default()`.
pub client_id: Option<NullableString>,
/// Only read/written for header version >= 2. Reading an older header yields
/// `None`; writing `None` at version >= 2 emits `TaggedFields::default()`.
pub tagged_fields: Option<TaggedFields>,
}
impl<R> ReadVersionedType<R> for RequestHeader
where
    R: Read,
{
    /// Decodes a request header laid out according to header version `version`.
    ///
    /// Fields are consumed from `reader` in this order: API key (`Int16`),
    /// API version (`Int16`), correlation id (`Int32`), then the client id
    /// (header version >= 1 only) and the tagged fields (header version >= 2
    /// only). Fields that the requested version does not carry come back as
    /// `None`.
    ///
    /// # Errors
    ///
    /// Propagates the first failure reported by any of the field reads.
    ///
    /// # Panics
    ///
    /// Panics if the header version is greater than 2.
    fn read_versioned(reader: &mut R, version: ApiVersion) -> Result<Self, ReadVersionedError> {
        let v = version.0.0;
        assert!(v <= 2);

        // Fixed part, present in every supported header version.
        let request_api_key = ApiKey::from(Int16::read(reader)?);
        let request_api_version = ApiVersion(Int16::read(reader)?);
        let correlation_id = Int32::read(reader)?;

        // Version-gated part; the reads must stay in wire order.
        let client_id = if v >= 1 {
            Some(NullableString::read(reader)?)
        } else {
            None
        };
        let tagged_fields = if v >= 2 {
            Some(TaggedFields::read(reader)?)
        } else {
            None
        };

        Ok(Self {
            request_api_key,
            request_api_version,
            correlation_id,
            client_id,
            tagged_fields,
        })
    }
}
impl<W> WriteVersionedType<W> for RequestHeader
where
    W: Write,
{
    /// Encodes this request header using the layout of header version `version`.
    ///
    /// The API key, API version and correlation id are always written. The
    /// client id follows for header version >= 1 and the tagged fields for
    /// header version >= 2. When such a field is required by the version but
    /// is `None` on `self`, the type's `Default` value is written in its
    /// place. A field the version does not carry is skipped even if it is
    /// `Some`.
    ///
    /// # Errors
    ///
    /// Propagates the first failure reported by any of the field writes.
    ///
    /// # Panics
    ///
    /// Panics if the header version is greater than 2.
    fn write_versioned(
        &self,
        writer: &mut W,
        version: ApiVersion,
    ) -> Result<(), WriteVersionedError> {
        let v = version.0.0;
        assert!(v <= 2);

        // Fixed part, present in every supported header version.
        Int16::from(self.request_api_key).write(writer)?;
        self.request_api_version.0.write(writer)?;
        self.correlation_id.write(writer)?;

        // Version 0 ends here; later versions only append fields.
        if v < 1 {
            return Ok(());
        }
        if let Some(client_id) = &self.client_id {
            client_id.write(writer)?;
        } else {
            NullableString::default().write(writer)?;
        }

        // Version 1 ends here.
        if v < 2 {
            return Ok(());
        }
        if let Some(tagged_fields) = &self.tagged_fields {
            tagged_fields.write(writer)?;
        } else {
            TaggedFields::default().write(writer)?;
        }

        Ok(())
    }
}
/// Response header, (de)serialized through `ReadVersionedType` / `WriteVersionedType`.
///
/// The impls in this file accept header versions up to 1.
#[derive(Debug, PartialEq, Eq)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct ResponseHeader {
/// Correlation id, encoded as an `Int32`.
/// NOTE(review): presumably echoes the id of the originating request — confirm.
pub correlation_id: Int32,
/// Only read/written for header version >= 1. Reading a version-0 header
/// yields `None`; writing `None` at version >= 1 emits `TaggedFields::default()`.
pub tagged_fields: Option<TaggedFields>,
}
impl<R> ReadVersionedType<R> for ResponseHeader
where
    R: Read,
{
    /// Decodes a response header laid out according to header version `version`.
    ///
    /// Reads the correlation id (`Int32`) and, for header version >= 1, the
    /// tagged fields. For older versions `tagged_fields` is `None`.
    ///
    /// # Errors
    ///
    /// Propagates the first failure reported by any of the field reads.
    ///
    /// # Panics
    ///
    /// Panics if the header version is greater than 1.
    fn read_versioned(reader: &mut R, version: ApiVersion) -> Result<Self, ReadVersionedError> {
        let v = version.0.0;
        assert!(v <= 1);

        let correlation_id = Int32::read(reader)?;
        let tagged_fields = if v >= 1 {
            Some(TaggedFields::read(reader)?)
        } else {
            None
        };

        Ok(Self {
            correlation_id,
            tagged_fields,
        })
    }
}
impl<W> WriteVersionedType<W> for ResponseHeader
where
    W: Write,
{
    /// Encodes this response header using the layout of header version `version`.
    ///
    /// The correlation id is always written. For header version >= 1 the
    /// tagged fields follow; if `self.tagged_fields` is `None`,
    /// `TaggedFields::default()` is written instead. For version 0 the tagged
    /// fields are skipped even if they are `Some`.
    ///
    /// # Errors
    ///
    /// Propagates the first failure reported by any of the field writes.
    ///
    /// # Panics
    ///
    /// Panics if the header version is greater than 1.
    fn write_versioned(
        &self,
        writer: &mut W,
        version: ApiVersion,
    ) -> Result<(), WriteVersionedError> {
        let v = version.0.0;
        assert!(v <= 1);

        self.correlation_id.write(writer)?;

        // Version 0 carries nothing beyond the correlation id.
        if v < 1 {
            return Ok(());
        }

        if let Some(tagged_fields) = &self.tagged_fields {
            tagged_fields.write(writer)?;
        } else {
            TaggedFields::default().write(writer)?;
        }

        Ok(())
    }
}
// Tests for both header types, driven by the shared `test_roundtrip_versioned!` macro.
// NOTE(review): the macro is defined in `protocol::messages::test_utils`; it presumably
// generates a write-then-read round-trip test over the given version range — confirm
// against the macro definition.
#[cfg(test)]
mod tests {
use crate::protocol::messages::test_utils::test_roundtrip_versioned;
use super::*;
// Arguments: type under test, two header versions (0 and 2, matching the `v <= 2`
// assertion in the `RequestHeader` impls), and the name of the generated test.
test_roundtrip_versioned!(
RequestHeader,
ApiVersion(Int16(0)),
ApiVersion(Int16(2)),
test_roundtrip_request_header
);
// Same for `ResponseHeader`, whose impls assert `v <= 1`.
test_roundtrip_versioned!(
ResponseHeader,
ApiVersion(Int16(0)),
ApiVersion(Int16(1)),
test_roundtrip_response_header
);
}