// rskafka/protocol/messages/header.rs

use std::io::{Read, Write};

use crate::protocol::{
    api_key::ApiKey,
    api_version::ApiVersion,
    primitives::{Int16, Int32, NullableString, TaggedFields},
    traits::{ReadType, WriteType},
};

use super::{ReadVersionedError, ReadVersionedType, WriteVersionedError, WriteVersionedType};

12#[derive(Debug, PartialEq, Eq)]
13#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
14pub struct RequestHeader {
15    /// The API key of this request.
16    pub request_api_key: ApiKey,
17
18    /// The API version of this request.
19    pub request_api_version: ApiVersion,
20
21    /// The correlation ID of this request.
22    pub correlation_id: Int32,
23
24    /// The client ID string.
25    ///
26    /// Added in version 1.
27    pub client_id: Option<NullableString>,
28
29    /// The tagged fields.
30    ///
31    /// Added in version 2.
32    pub tagged_fields: Option<TaggedFields>,
33}
34
35impl<R> ReadVersionedType<R> for RequestHeader
36where
37    R: Read,
38{
39    fn read_versioned(reader: &mut R, version: ApiVersion) -> Result<Self, ReadVersionedError> {
40        let v = version.0 .0;
41        assert!(v <= 2);
42
43        Ok(Self {
44            request_api_key: ApiKey::from(Int16::read(reader)?),
45            request_api_version: ApiVersion(Int16::read(reader)?),
46            correlation_id: Int32::read(reader)?,
47            client_id: (v >= 1).then(|| NullableString::read(reader)).transpose()?,
48            tagged_fields: (v >= 2).then(|| TaggedFields::read(reader)).transpose()?,
49        })
50    }
51}
52
53impl<W> WriteVersionedType<W> for RequestHeader
54where
55    W: Write,
56{
57    fn write_versioned(
58        &self,
59        writer: &mut W,
60        version: ApiVersion,
61    ) -> Result<(), WriteVersionedError> {
62        let v = version.0 .0;
63        assert!(v <= 2);
64
65        Int16::from(self.request_api_key).write(writer)?;
66        self.request_api_version.0.write(writer)?;
67        self.correlation_id.write(writer)?;
68
69        if v >= 1 {
70            match self.client_id.as_ref() {
71                Some(client_id) => {
72                    client_id.write(writer)?;
73                }
74                None => {
75                    NullableString::default().write(writer)?;
76                }
77            }
78        }
79
80        if v >= 2 {
81            match self.tagged_fields.as_ref() {
82                Some(tagged_fields) => {
83                    tagged_fields.write(writer)?;
84                }
85                None => {
86                    TaggedFields::default().write(writer)?;
87                }
88            }
89        }
90
91        Ok(())
92    }
93}
94
95#[derive(Debug, PartialEq, Eq)]
96#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
97pub struct ResponseHeader {
98    /// The correlation ID of this response.
99    pub correlation_id: Int32,
100
101    /// The tagged fields.
102    ///
103    /// Added in version 1.
104    pub tagged_fields: Option<TaggedFields>,
105}
106
107impl<R> ReadVersionedType<R> for ResponseHeader
108where
109    R: Read,
110{
111    fn read_versioned(reader: &mut R, version: ApiVersion) -> Result<Self, ReadVersionedError> {
112        let v = version.0 .0;
113        assert!(v <= 1);
114
115        Ok(Self {
116            correlation_id: Int32::read(reader)?,
117            tagged_fields: (v >= 1).then(|| TaggedFields::read(reader)).transpose()?,
118        })
119    }
120}
121
122// this is not technically required for production but helpful for testing
123impl<W> WriteVersionedType<W> for ResponseHeader
124where
125    W: Write,
126{
127    fn write_versioned(
128        &self,
129        writer: &mut W,
130        version: ApiVersion,
131    ) -> Result<(), WriteVersionedError> {
132        let v = version.0 .0;
133        assert!(v <= 1);
134
135        self.correlation_id.write(writer)?;
136
137        if v >= 1 {
138            match self.tagged_fields.as_ref() {
139                Some(tagged_fields) => {
140                    tagged_fields.write(writer)?;
141                }
142                None => {
143                    TaggedFields::default().write(writer)?;
144                }
145            }
146        }
147
148        Ok(())
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    use crate::protocol::messages::test_utils::test_roundtrip_versioned;

    // Round-trip property test for request headers, bounded by header versions 0 and 2.
    test_roundtrip_versioned!(
        RequestHeader,
        ApiVersion(Int16(0)),
        ApiVersion(Int16(2)),
        test_roundtrip_request_header
    );

    // Round-trip property test for response headers, bounded by header versions 0 and 1.
    test_roundtrip_versioned!(
        ResponseHeader,
        ApiVersion(Int16(0)),
        ApiVersion(Int16(1)),
        test_roundtrip_response_header
    );
}
171}