// rskafka/protocol/messages/header.rs
use std::io::{Read, Write};

use crate::protocol::{
    api_key::ApiKey,
    api_version::ApiVersion,
    primitives::{Int16, Int32, NullableString, TaggedFields},
    traits::{ReadType, WriteType},
};

use super::{ReadVersionedError, ReadVersionedType, WriteVersionedError, WriteVersionedType};

12#[derive(Debug, PartialEq, Eq)]
13#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
14pub struct RequestHeader {
15 pub request_api_key: ApiKey,
17
18 pub request_api_version: ApiVersion,
20
21 pub correlation_id: Int32,
23
24 pub client_id: Option<NullableString>,
28
29 pub tagged_fields: Option<TaggedFields>,
33}
34
35impl<R> ReadVersionedType<R> for RequestHeader
36where
37 R: Read,
38{
39 fn read_versioned(reader: &mut R, version: ApiVersion) -> Result<Self, ReadVersionedError> {
40 let v = version.0 .0;
41 assert!(v <= 2);
42
43 Ok(Self {
44 request_api_key: ApiKey::from(Int16::read(reader)?),
45 request_api_version: ApiVersion(Int16::read(reader)?),
46 correlation_id: Int32::read(reader)?,
47 client_id: (v >= 1).then(|| NullableString::read(reader)).transpose()?,
48 tagged_fields: (v >= 2).then(|| TaggedFields::read(reader)).transpose()?,
49 })
50 }
51}
52
53impl<W> WriteVersionedType<W> for RequestHeader
54where
55 W: Write,
56{
57 fn write_versioned(
58 &self,
59 writer: &mut W,
60 version: ApiVersion,
61 ) -> Result<(), WriteVersionedError> {
62 let v = version.0 .0;
63 assert!(v <= 2);
64
65 Int16::from(self.request_api_key).write(writer)?;
66 self.request_api_version.0.write(writer)?;
67 self.correlation_id.write(writer)?;
68
69 if v >= 1 {
70 match self.client_id.as_ref() {
71 Some(client_id) => {
72 client_id.write(writer)?;
73 }
74 None => {
75 NullableString::default().write(writer)?;
76 }
77 }
78 }
79
80 if v >= 2 {
81 match self.tagged_fields.as_ref() {
82 Some(tagged_fields) => {
83 tagged_fields.write(writer)?;
84 }
85 None => {
86 TaggedFields::default().write(writer)?;
87 }
88 }
89 }
90
91 Ok(())
92 }
93}
94
95#[derive(Debug, PartialEq, Eq)]
96#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
97pub struct ResponseHeader {
98 pub correlation_id: Int32,
100
101 pub tagged_fields: Option<TaggedFields>,
105}
106
107impl<R> ReadVersionedType<R> for ResponseHeader
108where
109 R: Read,
110{
111 fn read_versioned(reader: &mut R, version: ApiVersion) -> Result<Self, ReadVersionedError> {
112 let v = version.0 .0;
113 assert!(v <= 1);
114
115 Ok(Self {
116 correlation_id: Int32::read(reader)?,
117 tagged_fields: (v >= 1).then(|| TaggedFields::read(reader)).transpose()?,
118 })
119 }
120}
121
122impl<W> WriteVersionedType<W> for ResponseHeader
124where
125 W: Write,
126{
127 fn write_versioned(
128 &self,
129 writer: &mut W,
130 version: ApiVersion,
131 ) -> Result<(), WriteVersionedError> {
132 let v = version.0 .0;
133 assert!(v <= 1);
134
135 self.correlation_id.write(writer)?;
136
137 if v >= 1 {
138 match self.tagged_fields.as_ref() {
139 Some(tagged_fields) => {
140 tagged_fields.write(writer)?;
141 }
142 None => {
143 TaggedFields::default().write(writer)?;
144 }
145 }
146 }
147
148 Ok(())
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    use crate::protocol::messages::test_utils::test_roundtrip_versioned;

    // Round-trip property tests. Each invocation passes the lowest and highest header
    // version that the corresponding read/write impls in this module accept.

    test_roundtrip_versioned!(
        ResponseHeader,
        ApiVersion(Int16(0)),
        ApiVersion(Int16(1)),
        test_roundtrip_response_header
    );

    test_roundtrip_versioned!(
        RequestHeader,
        ApiVersion(Int16(0)),
        ApiVersion(Int16(2)),
        test_roundtrip_request_header
    );
}