use bytes::{Buf, BufMut};
use super::api::ApiKey;
use super::primitives::{Decode, Encode, KafkaString, TaggedFields, TryEncode};
use crate::error::{KrafkaError, ProtocolErrorKind, Result};
/// Header fields that `RequestHeader::encode` writes for a request.
///
/// The fields are written in declaration order: API key, API version,
/// correlation id, then the client id.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct RequestHeader {
/// API key of the request; together with `api_version` it selects the header version.
pub api_key: ApiKey,
/// Version of the API being called.
pub api_version: i16,
/// Correlation id written into the header.
pub correlation_id: i32,
/// Optional client id; `None` is encoded as a null string.
pub client_id: Option<KafkaString>,
}
impl RequestHeader {
    /// Creates a header for the given API key, API version and correlation id.
    ///
    /// The client id starts out unset; use [`RequestHeader::with_client_id`] to add one.
    pub fn new(api_key: ApiKey, api_version: i16, correlation_id: i32) -> Self {
        Self {
            client_id: None,
            api_key,
            api_version,
            correlation_id,
        }
    }

    /// Consumes the header and returns it with `client_id` set to the given value.
    pub fn with_client_id(self, client_id: impl Into<String>) -> Self {
        Self {
            client_id: Some(KafkaString::new(client_id)),
            ..self
        }
    }

    /// Writes the version 1 header layout into `buf`.
    ///
    /// Layout: API key, API version, correlation id, then the client id
    /// (a null string when no client id is set).
    ///
    /// # Errors
    ///
    /// Propagates any error returned while encoding the client id string.
    #[inline]
    pub fn encode_v1(&self, buf: &mut impl BufMut) -> Result<()> {
        self.api_key.encode(buf);
        self.api_version.encode(buf);
        self.correlation_id.encode(buf);

        if let Some(client_id) = self.client_id.as_ref() {
            client_id.try_encode(buf)?;
        } else {
            KafkaString::null().try_encode(buf)?;
        }
        Ok(())
    }

    /// Writes the version 2 header layout into `buf`.
    ///
    /// Version 2 is the version 1 layout followed by an empty tagged-fields
    /// section; the client id keeps the same string encoding as in version 1.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while encoding the client id string or
    /// the tagged fields.
    #[inline]
    pub fn encode_v2(&self, buf: &mut impl BufMut) -> Result<()> {
        // The leading fields are identical to the v1 layout.
        self.encode_v1(buf)?;
        TaggedFields::default().try_encode(buf)?;
        Ok(())
    }

    /// Selects the request header version for an API key / API version pair.
    ///
    /// Returns `2` when `api_version` is at least `api_key.flexible_version()`,
    /// and `1` otherwise.
    pub fn header_version(api_key: ApiKey, api_version: i16) -> i16 {
        let is_flexible = api_version >= api_key.flexible_version();
        if is_flexible {
            2
        } else {
            1
        }
    }

    /// Encodes the header into `buf` using the layout chosen by
    /// [`RequestHeader::header_version`] for this header's API key and version.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`RequestHeader::encode_v1`] /
    /// [`RequestHeader::encode_v2`]. A header version other than 1 or 2 yields
    /// a `ProtocolErrorKind::UnknownApiVersion` error; the current
    /// `header_version` never produces such a value.
    pub fn encode(&self, buf: &mut impl BufMut) -> Result<()> {
        match Self::header_version(self.api_key, self.api_version) {
            1 => self.encode_v1(buf),
            2 => self.encode_v2(buf),
            v => Err(KrafkaError::protocol_kind(
                ProtocolErrorKind::UnknownApiVersion,
                format!("unsupported request header version {v}"),
            )),
        }
    }
}
/// Header decoded from the start of a response.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct ResponseHeader {
/// Correlation id, the first field read from the response header.
pub correlation_id: i32,
}
impl ResponseHeader {
    /// Creates a response header carrying the given correlation id.
    pub fn new(correlation_id: i32) -> Self {
        Self { correlation_id }
    }

    /// Decodes a version 0 response header: just the correlation id.
    ///
    /// # Errors
    ///
    /// Propagates the error from decoding the `i32` correlation id.
    #[inline]
    pub fn decode_v0(buf: &mut impl Buf) -> Result<Self> {
        let correlation_id = i32::decode(buf)?;
        Ok(Self::new(correlation_id))
    }

    /// Decodes a version 1 response header: the correlation id followed by a
    /// tagged-fields section, which is read and then discarded.
    ///
    /// # Errors
    ///
    /// Propagates errors from decoding the correlation id or the tagged fields.
    #[inline]
    pub fn decode_v1(buf: &mut impl Buf) -> Result<Self> {
        // A v1 header starts with exactly the v0 layout.
        let header = Self::decode_v0(buf)?;
        // The tagged fields are consumed so the buffer is left positioned
        // after the header; their content is not kept.
        let _ = TaggedFields::decode(buf)?;
        Ok(header)
    }

    /// Selects the response header version for an API key / API version pair.
    ///
    /// Returns `0` for `ApiKey::ApiVersions` regardless of version. For every
    /// other key it returns `1` when `api_version` is at least
    /// `api_key.flexible_version()`, and `0` otherwise.
    ///
    /// NOTE(review): the `ApiVersions` special case presumably mirrors the
    /// Kafka protocol, which keeps that response on header v0 — confirm
    /// against the protocol specification.
    pub fn header_version(api_key: ApiKey, api_version: i16) -> i16 {
        // `&&` short-circuits, so `flexible_version` is not consulted for ApiVersions.
        let has_tagged_fields =
            api_key != ApiKey::ApiVersions && api_version >= api_key.flexible_version();
        i16::from(has_tagged_fields)
    }

    /// Decodes a response header from `buf`, using the layout chosen by
    /// [`ResponseHeader::header_version`] for the given API key and version.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`ResponseHeader::decode_v0`] /
    /// [`ResponseHeader::decode_v1`]. A header version other than 0 or 1
    /// yields a `ProtocolErrorKind::UnknownApiVersion` error; the current
    /// `header_version` never produces such a value.
    pub fn decode(buf: &mut impl Buf, api_key: ApiKey, api_version: i16) -> Result<Self> {
        match Self::header_version(api_key, api_version) {
            0 => Self::decode_v0(buf),
            1 => Self::decode_v1(buf),
            v => Err(KrafkaError::protocol_kind(
                ProtocolErrorKind::UnknownApiVersion,
                format!("unsupported response header version {v}"),
            )),
        }
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use bytes::BytesMut;

    use super::*;

    /// Reads the three fixed-width leading fields of an encoded request header
    /// (API key, API version, correlation id) and checks each one.
    fn assert_fixed_fields(
        buf: &mut impl Buf,
        api_key: i16,
        api_version: i16,
        correlation_id: i32,
    ) {
        assert_eq!(i16::decode(buf).unwrap(), api_key);
        assert_eq!(i16::decode(buf).unwrap(), api_version);
        assert_eq!(i32::decode(buf).unwrap(), correlation_id);
    }

    /// A v1 header without a client id encodes the client id as a null string.
    #[test]
    fn test_request_header_v1_without_client_id() {
        let mut encoded = BytesMut::new();
        RequestHeader::new(ApiKey::ApiVersions, 0, 1)
            .encode_v1(&mut encoded)
            .unwrap();

        let mut wire = encoded.freeze();
        assert_fixed_fields(&mut wire, 18, 0, 1);
        let client_id = KafkaString::decode(&mut wire).unwrap();
        assert!(client_id.is_null());
    }

    /// A v1 header with a client id round-trips every field.
    #[test]
    fn test_request_header_v1() {
        let mut encoded = BytesMut::new();
        RequestHeader::new(ApiKey::Metadata, 0, 42)
            .with_client_id("test-client")
            .encode_v1(&mut encoded)
            .unwrap();

        let mut wire = encoded.freeze();
        assert_fixed_fields(&mut wire, 3, 0, 42);
        let client_id = KafkaString::decode(&mut wire).unwrap();
        assert_eq!(client_id.as_str(), Some("test-client"));
    }

    /// In a v2 header the client id is still readable with the regular
    /// `KafkaString` decoder, and is followed only by empty tagged fields.
    #[test]
    fn test_request_header_v2_client_id_uses_standard_encoding() {
        let mut encoded = BytesMut::new();
        RequestHeader::new(ApiKey::Metadata, 12, 99)
            .with_client_id("krafka-client")
            .encode_v2(&mut encoded)
            .unwrap();

        let mut wire = encoded.freeze();
        assert_fixed_fields(&mut wire, 3, 12, 99);
        let client_id = KafkaString::decode(&mut wire).unwrap();
        assert_eq!(client_id.as_str(), Some("krafka-client"));
        let tf = TaggedFields::decode(&mut wire).unwrap();
        assert!(tf.0.is_empty());
        assert!(!wire.has_remaining(), "no trailing bytes");
    }

    /// A v2 header without a client id encodes a null string, then tagged
    /// fields, and nothing else.
    #[test]
    fn test_request_header_v2_without_client_id() {
        let mut encoded = BytesMut::new();
        RequestHeader::new(ApiKey::Metadata, 12, 1)
            .encode_v2(&mut encoded)
            .unwrap();

        let mut wire = encoded.freeze();
        assert_fixed_fields(&mut wire, 3, 12, 1);
        let client_id = KafkaString::decode(&mut wire).unwrap();
        assert!(client_id.is_null());
        let _ = TaggedFields::decode(&mut wire).unwrap();
        assert!(!wire.has_remaining());
    }

    /// A v0 response header is just the correlation id.
    #[test]
    fn test_response_header_v0() {
        let mut encoded = BytesMut::new();
        42i32.encode(&mut encoded);

        let mut wire = encoded.freeze();
        let header = ResponseHeader::decode_v0(&mut wire).unwrap();
        assert_eq!(header.correlation_id, 42);
    }

    /// A v1 response header is the correlation id followed by tagged fields.
    #[test]
    fn test_response_header_v1() {
        let mut encoded = BytesMut::new();
        42i32.encode(&mut encoded);
        TaggedFields::default().try_encode(&mut encoded).unwrap();

        let mut wire = encoded.freeze();
        let header = ResponseHeader::decode_v1(&mut wire).unwrap();
        assert_eq!(header.correlation_id, 42);
    }

    /// ApiVersions requests switch to header v2 at version 3, while the
    /// response header stays at v0 for every version.
    #[test]
    fn test_header_version_api_versions() {
        let key = ApiKey::ApiVersions;

        assert_eq!(RequestHeader::header_version(key, 0), 1);
        assert_eq!(RequestHeader::header_version(key, 2), 1);
        assert_eq!(RequestHeader::header_version(key, 3), 2);

        assert_eq!(ResponseHeader::header_version(key, 0), 0);
        assert_eq!(ResponseHeader::header_version(key, 3), 0);
    }

    /// Fetch uses the non-flexible headers below version 12 and the flexible
    /// ones from version 12 on.
    #[test]
    fn test_header_version_fetch() {
        for v in 0..12 {
            let request_version = RequestHeader::header_version(ApiKey::Fetch, v);
            assert_eq!(
                request_version, 1,
                "Fetch v{v} request header should be v1 (non-flexible)"
            );

            let response_version = ResponseHeader::header_version(ApiKey::Fetch, v);
            assert_eq!(
                response_version, 0,
                "Fetch v{v} response header should be v0 (non-flexible)"
            );
        }

        assert_eq!(RequestHeader::header_version(ApiKey::Fetch, 12), 2);
        assert_eq!(ResponseHeader::header_version(ApiKey::Fetch, 12), 1);
    }

    /// For every listed API, header versions flip exactly at
    /// `flexible_version()`: the version just below it (when one exists) uses
    /// the non-flexible headers, and the flexible version itself uses request
    /// header v2 and response header v1 (v0 for ApiVersions).
    #[test]
    fn test_header_version_flexible_boundaries() {
        let apis = [
            ApiKey::Produce,
            ApiKey::Fetch,
            ApiKey::ListOffsets,
            ApiKey::Metadata,
            ApiKey::OffsetCommit,
            ApiKey::OffsetFetch,
            ApiKey::FindCoordinator,
            ApiKey::JoinGroup,
            ApiKey::Heartbeat,
            ApiKey::LeaveGroup,
            ApiKey::SyncGroup,
            ApiKey::DescribeGroups,
            ApiKey::ListGroups,
            ApiKey::CreateTopics,
            ApiKey::DeleteTopics,
            ApiKey::DeleteRecords,
            ApiKey::InitProducerId,
            ApiKey::OffsetForLeaderEpoch,
            ApiKey::AddPartitionsToTxn,
            ApiKey::AddOffsetsToTxn,
            ApiKey::EndTxn,
            ApiKey::TxnOffsetCommit,
            ApiKey::DescribeAcls,
            ApiKey::CreateAcls,
            ApiKey::DeleteAcls,
            ApiKey::DescribeConfigs,
            ApiKey::AlterConfigs,
            ApiKey::CreatePartitions,
            ApiKey::ApiVersions,
        ];

        for api in apis {
            let flex = api.flexible_version();

            // Only check the pre-flexible side when such a version exists.
            if flex > 0 {
                let before = flex - 1;
                assert_eq!(
                    RequestHeader::header_version(api, before),
                    1,
                    "{api:?} v{before} request header should be v1"
                );
                assert_eq!(
                    ResponseHeader::header_version(api, before),
                    0,
                    "{api:?} v{before} response header should be v0"
                );
            }

            assert_eq!(
                RequestHeader::header_version(api, flex),
                2,
                "{api:?} v{flex} request header should be v2"
            );

            // ApiVersions is the one key whose response header stays at v0.
            let expected_resp = i16::from(api != ApiKey::ApiVersions);
            assert_eq!(
                ResponseHeader::header_version(api, flex),
                expected_resp,
                "{api:?} v{flex} response header mismatch"
            );
        }
    }
}