rskafka/protocol/messages/api_versions.rs

1use std::io::{Read, Write};
2
3use crate::protocol::{
4    api_key::ApiKey,
5    api_version::{ApiVersion, ApiVersionRange},
6    error::Error as ApiError,
7    messages::{
8        read_compact_versioned_array, write_compact_versioned_array, write_versioned_array,
9    },
10    primitives::{CompactString, Int16, Int32, TaggedFields},
11    traits::{ReadType, WriteType},
12};
13
14use super::{
15    read_versioned_array, ReadVersionedError, ReadVersionedType, RequestBody, WriteVersionedError,
16    WriteVersionedType,
17};
18
19#[cfg(test)]
20use proptest::prelude::*;
21
22#[derive(Debug, PartialEq, Eq)]
23#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
24pub struct ApiVersionsRequest {
25    /// The name of the client.
26    ///
27    /// Added in version 3.
28    pub client_software_name: Option<CompactString>,
29
30    /// The version of the client.
31    ///
32    /// Added in version 3.
33    pub client_software_version: Option<CompactString>,
34
35    /// The tagged fields.
36    ///
37    /// Added in version 3.
38    pub tagged_fields: Option<TaggedFields>,
39}
40
41impl<R> ReadVersionedType<R> for ApiVersionsRequest
42where
43    R: Read,
44{
45    fn read_versioned(reader: &mut R, version: ApiVersion) -> Result<Self, ReadVersionedError> {
46        let v = version.0 .0;
47        assert!(v <= 3);
48
49        Ok(Self {
50            client_software_name: (v >= 3).then(|| CompactString::read(reader)).transpose()?,
51            client_software_version: (v >= 3).then(|| CompactString::read(reader)).transpose()?,
52            tagged_fields: (v >= 3).then(|| TaggedFields::read(reader)).transpose()?,
53        })
54    }
55}
56
57impl<W> WriteVersionedType<W> for ApiVersionsRequest
58where
59    W: Write,
60{
61    fn write_versioned(
62        &self,
63        writer: &mut W,
64        version: ApiVersion,
65    ) -> Result<(), WriteVersionedError> {
66        let v = version.0 .0;
67        assert!(v <= 3);
68
69        if v >= 3 {
70            match self.client_software_name.as_ref() {
71                Some(client_software_name) => {
72                    client_software_name.write(writer)?;
73                }
74                None => {
75                    CompactString::default().write(writer)?;
76                }
77            }
78
79            match self.client_software_version.as_ref() {
80                Some(client_software_version) => {
81                    client_software_version.write(writer)?;
82                }
83                None => {
84                    CompactString::default().write(writer)?;
85                }
86            }
87
88            match self.tagged_fields.as_ref() {
89                Some(tagged_fields) => {
90                    tagged_fields.write(writer)?;
91                }
92                None => {
93                    TaggedFields::default().write(writer)?;
94                }
95            }
96        }
97
98        Ok(())
99    }
100}
101
102impl RequestBody for ApiVersionsRequest {
103    type ResponseBody = ApiVersionsResponse;
104    const API_KEY: ApiKey = ApiKey::ApiVersions;
105    const API_VERSION_RANGE: ApiVersionRange =
106        ApiVersionRange::new(ApiVersion(Int16(0)), ApiVersion(Int16(3)));
107    const FIRST_TAGGED_FIELD_IN_REQUEST_VERSION: ApiVersion = ApiVersion(Int16(3));
108
109    // It seems version 3 actually doesn't use tagged fields during response, at least not for Kafka 3.
110    //
111    // rdkafka also does this, see
112    // https://github.com/edenhill/librdkafka/blob/2b76b65212e5efda213961d5f84e565038036270/src/rdkafka_broker.c#L1781-L1785
113    const FIRST_TAGGED_FIELD_IN_RESPONSE_VERSION: ApiVersion = ApiVersion(Int16(i16::MAX));
114}
115
116#[derive(Debug, PartialEq, Eq)]
117#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
118pub struct ApiVersionsResponseApiKey {
119    /// The API index.
120    pub api_key: ApiKey,
121
122    /// The minimum supported version, inclusive.
123    pub min_version: ApiVersion,
124
125    /// The maximum supported version, inclusive.
126    pub max_version: ApiVersion,
127
128    /// The tagged fields.
129    ///
130    /// Added in version 3
131    pub tagged_fields: Option<TaggedFields>,
132}
133
134impl<R> ReadVersionedType<R> for ApiVersionsResponseApiKey
135where
136    R: Read,
137{
138    fn read_versioned(reader: &mut R, version: ApiVersion) -> Result<Self, ReadVersionedError> {
139        let v = version.0 .0;
140        assert!(v <= 3);
141
142        Ok(Self {
143            api_key: Int16::read(reader)?.into(),
144            min_version: ApiVersion(Int16::read(reader)?),
145            max_version: ApiVersion(Int16::read(reader)?),
146            tagged_fields: (v >= 3).then(|| TaggedFields::read(reader)).transpose()?,
147        })
148    }
149}
150
151// this is not technically required for production but helpful for testing
152impl<W> WriteVersionedType<W> for ApiVersionsResponseApiKey
153where
154    W: Write,
155{
156    fn write_versioned(
157        &self,
158        writer: &mut W,
159        version: ApiVersion,
160    ) -> Result<(), WriteVersionedError> {
161        let v = version.0 .0;
162        assert!(v <= 3);
163
164        let api_key: Int16 = self.api_key.into();
165        api_key.write(writer)?;
166
167        self.min_version.0.write(writer)?;
168        self.max_version.0.write(writer)?;
169
170        if v >= 3 {
171            match self.tagged_fields.as_ref() {
172                Some(tagged_fields) => {
173                    tagged_fields.write(writer)?;
174                }
175                None => {
176                    TaggedFields::default().write(writer)?;
177                }
178            }
179        }
180
181        Ok(())
182    }
183}
184
185#[derive(Debug, PartialEq, Eq)]
186#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
187pub struct ApiVersionsResponse {
188    /// The top-level error code.
189    #[cfg_attr(test, proptest(strategy = "any::<i16>().prop_map(ApiError::new)"))]
190    pub error_code: Option<ApiError>,
191
192    /// The APIs supported by the broker.
193    // tell proptest to only generate small vectors, otherwise tests take forever
194    #[cfg_attr(
195        test,
196        proptest(strategy = "prop::collection::vec(any::<ApiVersionsResponseApiKey>(), 0..2)")
197    )]
198    pub api_keys: Vec<ApiVersionsResponseApiKey>,
199
200    /// The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota.
201    ///
202    /// Added in version 1
203    pub throttle_time_ms: Option<Int32>,
204
205    /// The tagged fields.
206    ///
207    /// Added in version 3
208    pub tagged_fields: Option<TaggedFields>,
209}
210
211impl<R> ReadVersionedType<R> for ApiVersionsResponse
212where
213    R: Read,
214{
215    fn read_versioned(reader: &mut R, version: ApiVersion) -> Result<Self, ReadVersionedError> {
216        let v = version.0 .0;
217        assert!(v <= 3);
218
219        let error_code = ApiError::new(Int16::read(reader)?.0);
220        let api_keys = if v >= 3 {
221            read_compact_versioned_array(reader, version)?.unwrap_or_default()
222        } else {
223            read_versioned_array(reader, version)?.unwrap_or_default()
224        };
225        let throttle_time_ms = (v >= 1).then(|| Int32::read(reader)).transpose()?;
226        let tagged_fields = (v >= 3).then(|| TaggedFields::read(reader)).transpose()?;
227
228        Ok(Self {
229            error_code,
230            api_keys,
231            throttle_time_ms,
232            tagged_fields,
233        })
234    }
235}
236
237// this is not technically required for production but helpful for testing
238impl<W> WriteVersionedType<W> for ApiVersionsResponse
239where
240    W: Write,
241{
242    fn write_versioned(
243        &self,
244        writer: &mut W,
245        version: ApiVersion,
246    ) -> Result<(), WriteVersionedError> {
247        let v = version.0 .0;
248        assert!(v <= 3);
249
250        let error_code: Int16 = self.error_code.into();
251        error_code.write(writer)?;
252
253        if v >= 3 {
254            write_compact_versioned_array(writer, version, Some(&self.api_keys))?;
255        } else {
256            write_versioned_array(writer, version, Some(&self.api_keys))?;
257        }
258
259        if v >= 1 {
260            // defaults to "no throttle"
261            self.throttle_time_ms.unwrap_or(Int32(0)).write(writer)?;
262        }
263
264        if v >= 3 {
265            match self.tagged_fields.as_ref() {
266                Some(tagged_fields) => {
267                    tagged_fields.write(writer)?;
268                }
269                None => {
270                    TaggedFields::default().write(writer)?;
271                }
272            }
273        }
274
275        Ok(())
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::messages::test_utils::test_roundtrip_versioned;

    test_roundtrip_versioned!(
        ApiVersionsRequest,
        <ApiVersionsRequest as RequestBody>::API_VERSION_RANGE.min(),
        <ApiVersionsRequest as RequestBody>::API_VERSION_RANGE.max(),
        test_roundtrip_api_versions_request
    );

    // The response has no version range of its own: it is encoded with the version of the
    // request that produced it, so the request's range covers every response version.
    test_roundtrip_versioned!(
        ApiVersionsResponse,
        <ApiVersionsRequest as RequestBody>::API_VERSION_RANGE.min(),
        <ApiVersionsRequest as RequestBody>::API_VERSION_RANGE.max(),
        test_roundtrip_api_versions_response
    );
}