rskafka/protocol/messages/mod.rs

1//! Individual API messages.
2//!
3//! # References
4//! - <https://kafka.apache.org/protocol#protocol_messages>
5//! - <https://cwiki.apache.org/confluence/display/KAFKA/KIP-482%3A+The+Kafka+Protocol+should+Support+Optional+Tagged+Fields>
6
7use std::io::{Read, Write};
8
9use thiserror::Error;
10
11use super::{
12    api_key::ApiKey,
13    api_version::{ApiVersion, ApiVersionRange},
14    primitives::{Int32, UnsignedVarint},
15    traits::{ReadError, ReadType, WriteError, WriteType},
16    vec_builder::VecBuilder,
17};
18
19mod api_versions;
20pub use api_versions::*;
21mod constants;
22pub use constants::*;
23mod create_topics;
24pub use create_topics::*;
25mod delete_records;
26pub use delete_records::*;
27mod delete_topics;
28pub use delete_topics::*;
29mod fetch;
30pub use fetch::*;
31mod header;
32pub use header::*;
33mod list_offsets;
34pub use list_offsets::*;
35mod metadata;
36pub use metadata::*;
37mod produce;
38pub use produce::*;
39#[cfg(test)]
40mod test_utils;
41
42#[derive(Error, Debug)]
43#[non_exhaustive]
44pub enum ReadVersionedError {
45    #[error("Read error: {0}")]
46    ReadError(#[from] ReadError),
47}
48
49pub trait ReadVersionedType<R>: Sized
50where
51    R: Read,
52{
53    fn read_versioned(reader: &mut R, version: ApiVersion) -> Result<Self, ReadVersionedError>;
54}
55
56#[derive(Error, Debug)]
57#[non_exhaustive]
58pub enum WriteVersionedError {
59    #[error("Write error: {0}")]
60    WriteError(#[from] WriteError),
61
62    #[error("Field {field} not available in version: {version:?}")]
63    FieldNotAvailable { field: String, version: ApiVersion },
64}
65
66pub trait WriteVersionedType<W>: Sized
67where
68    W: Write,
69{
70    fn write_versioned(
71        &self,
72        writer: &mut W,
73        version: ApiVersion,
74    ) -> Result<(), WriteVersionedError>;
75}
76
77impl<'a, W: Write, T: WriteVersionedType<W>> WriteVersionedType<W> for &'a T {
78    fn write_versioned(
79        &self,
80        writer: &mut W,
81        version: ApiVersion,
82    ) -> Result<(), WriteVersionedError> {
83        T::write_versioned(self, writer, version)
84    }
85}
86
87/// Specifies a request body.
88pub trait RequestBody {
89    /// The response type that will follow when issuing this request.
90    type ResponseBody;
91
92    /// Kafka API key.
93    ///
94    /// This will be added to the request header.
95    const API_KEY: ApiKey;
96
97    /// Supported version range.
98    ///
99    /// From this range and the range that the broker reports, we will pick the highest version that both support.
100    const API_VERSION_RANGE: ApiVersionRange;
101
102    /// The first version of the messages (not of the header) that uses tagged fields, if any.
103    ///
104    /// To determine the version just look for the `_tagged_fields` or `TAG_BUFFER` in the protocol description.
105    ///
106    /// This will be used to control which request and response header versions will be used.
107    ///
108    /// It's OK to specify a version here that is larger then the highest supported version.
109    const FIRST_TAGGED_FIELD_IN_REQUEST_VERSION: ApiVersion;
110
111    /// Normally the same as [`FIRST_TAGGED_FIELD_IN_REQUEST_VERSION`](Self::FIRST_TAGGED_FIELD_IN_REQUEST_VERSION) but
112    /// there are some special snowflakes.
113    const FIRST_TAGGED_FIELD_IN_RESPONSE_VERSION: ApiVersion =
114        Self::FIRST_TAGGED_FIELD_IN_REQUEST_VERSION;
115}
116
117impl<T: RequestBody> RequestBody for &T {
118    type ResponseBody = T::ResponseBody;
119    const API_KEY: ApiKey = T::API_KEY;
120    const API_VERSION_RANGE: ApiVersionRange = T::API_VERSION_RANGE;
121    const FIRST_TAGGED_FIELD_IN_REQUEST_VERSION: ApiVersion =
122        T::FIRST_TAGGED_FIELD_IN_REQUEST_VERSION;
123    const FIRST_TAGGED_FIELD_IN_RESPONSE_VERSION: ApiVersion =
124        T::FIRST_TAGGED_FIELD_IN_RESPONSE_VERSION;
125}
126
127/// Read an array of versioned objects.
128///
129/// Note that this is normally only used for messages that DO NOT contain tagged fields. All messages with tagged fields
130/// normally use [`read_compact_versioned_array`] to comply with [KIP-482].
131///
132/// [KIP-482]: https://cwiki.apache.org/confluence/display/KAFKA/KIP-482%3A+The+Kafka+Protocol+should+Support+Optional+Tagged+Fields
133fn read_versioned_array<R: Read, T: ReadVersionedType<R>>(
134    reader: &mut R,
135    version: ApiVersion,
136) -> Result<Option<Vec<T>>, ReadVersionedError> {
137    let len = Int32::read(reader)?.0;
138    match len {
139        -1 => Ok(None),
140        l if l < -1 => Err(ReadVersionedError::ReadError(ReadError::Malformed(
141            format!("Invalid negative length for array: {}", l).into(),
142        ))),
143        _ => {
144            let len = usize::try_from(len).map_err(ReadError::Overflow)?;
145            let mut builder = VecBuilder::new(len);
146            for _ in 0..len {
147                builder.push(T::read_versioned(reader, version)?);
148            }
149            Ok(Some(builder.into()))
150        }
151    }
152}
153
154/// Write an array of versioned objects.
155///
156/// Note that this is normally only used for messages that DO NOT contain tagged fields. All messages with tagged fields
157/// normally use [`write_compact_versioned_array`] to comply with [KIP-482].
158///
159/// [KIP-482]: https://cwiki.apache.org/confluence/display/KAFKA/KIP-482%3A+The+Kafka+Protocol+should+Support+Optional+Tagged+Fields
160fn write_versioned_array<W: Write, T: WriteVersionedType<W>>(
161    writer: &mut W,
162    version: ApiVersion,
163    data: Option<&[T]>,
164) -> Result<(), WriteVersionedError> {
165    match data {
166        None => Ok(Int32(-1).write(writer)?),
167        Some(inner) => {
168            let len = i32::try_from(inner.len()).map_err(WriteError::from)?;
169            Int32(len).write(writer)?;
170
171            for element in inner {
172                element.write_versioned(writer, version)?
173            }
174
175            Ok(())
176        }
177    }
178}
179
180/// Read a compact array of versioned objects.
181///
182/// Note that this is normally only used for messages that DO contain tagged fields. All messages without tagged fields
183/// normally use [`read_versioned_array`] to comply with [KIP-482].
184///
185/// [KIP-482]: https://cwiki.apache.org/confluence/display/KAFKA/KIP-482%3A+The+Kafka+Protocol+should+Support+Optional+Tagged+Fields
186fn read_compact_versioned_array<R: Read, T: ReadVersionedType<R>>(
187    reader: &mut R,
188    version: ApiVersion,
189) -> Result<Option<Vec<T>>, ReadVersionedError> {
190    let len = UnsignedVarint::read(reader)?.0;
191    match len {
192        0 => Ok(None),
193        n => {
194            let len = usize::try_from(n - 1).map_err(ReadError::Overflow)?;
195            let mut builder = VecBuilder::new(len);
196            for _ in 0..len {
197                builder.push(T::read_versioned(reader, version)?);
198            }
199            Ok(Some(builder.into()))
200        }
201    }
202}
203
204/// Write a compact array of versioned objects.
205///
206/// Note that this is normally only used for messages that DO contain tagged fields. All messages without tagged fields
207/// normally use [`write_versioned_array`] to comply with [KIP-482].
208///
209/// [KIP-482]: https://cwiki.apache.org/confluence/display/KAFKA/KIP-482%3A+The+Kafka+Protocol+should+Support+Optional+Tagged+Fields
210fn write_compact_versioned_array<W: Write, T: WriteVersionedType<W>>(
211    writer: &mut W,
212    version: ApiVersion,
213    data: Option<&[T]>,
214) -> Result<(), WriteVersionedError> {
215    match data {
216        None => Ok(UnsignedVarint(0).write(writer)?),
217        Some(inner) => {
218            let len = u64::try_from(inner.len() + 1).map_err(WriteError::from)?;
219            UnsignedVarint(len).write(writer)?;
220
221            for element in inner {
222                element.write_versioned(writer, version)?
223            }
224
225            Ok(())
226        }
227    }
228}
229
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use assert_matches::assert_matches;

    use crate::protocol::primitives::Int16;

    use super::*;

    /// Test element that asserts it is written with the expected version.
    ///
    /// On the wire every element is the `Int32` value `42`. The version is not encoded; on read the element records the
    /// version it was read with, so round trips can be compared with `assert_eq!`.
    #[derive(Debug, Copy, Clone, PartialEq)]
    struct VersionTest {
        version: ApiVersion,
    }

    impl<W: Write> WriteVersionedType<W> for VersionTest {
        fn write_versioned(
            &self,
            writer: &mut W,
            version: ApiVersion,
        ) -> Result<(), WriteVersionedError> {
            assert_eq!(version, self.version);
            Int32(42).write(writer)?;
            Ok(())
        }
    }

    impl<R: Read> ReadVersionedType<R> for VersionTest {
        fn read_versioned(reader: &mut R, version: ApiVersion) -> Result<Self, ReadVersionedError> {
            assert_eq!(Int32::read(reader)?.0, 42);
            Ok(Self { version })
        }
    }

    /// Assert that reading consumed exactly the bytes that were written (no trailing bytes left).
    fn assert_fully_consumed(cursor: &Cursor<Vec<u8>>) {
        let consumed = usize::try_from(cursor.position()).unwrap();
        assert_eq!(consumed, cursor.get_ref().len(), "trailing bytes left unread");
    }

    #[test]
    fn test_read_write_versioned() {
        for len in [0, 6] {
            for i in 0..3 {
                let version = ApiVersion(Int16(i));
                let test = VersionTest { version };
                let input = vec![test; len];

                let mut buffer = vec![];
                write_versioned_array(&mut buffer, version, Some(&input)).unwrap();

                let mut cursor = Cursor::new(buffer);
                let output = read_versioned_array(&mut cursor, version).unwrap().unwrap();

                assert_eq!(input, output);
                assert_fully_consumed(&cursor);
            }
        }

        let version = ApiVersion(Int16(0));
        let mut buffer = vec![];
        write_versioned_array::<_, VersionTest>(&mut buffer, version, None).unwrap();
        let mut cursor = Cursor::new(buffer);
        assert!(read_versioned_array::<_, VersionTest>(&mut cursor, version)
            .unwrap()
            .is_none());
        assert_fully_consumed(&cursor);
    }

    #[test]
    fn test_read_versioned_blowup_memory() {
        let mut buf = Cursor::new(Vec::<u8>::new());
        Int32(i32::MAX).write(&mut buf).unwrap();
        buf.set_position(0);

        let err =
            read_versioned_array::<_, VersionTest>(&mut buf, ApiVersion(Int16(42))).unwrap_err();
        assert_matches!(err, ReadVersionedError::ReadError(ReadError::IO(_)));
    }

    #[test]
    fn test_read_write_compact_versioned() {
        for len in [0, 6] {
            for i in 0..3 {
                let version = ApiVersion(Int16(i));
                let test = VersionTest { version };
                let input = vec![test; len];

                let mut buffer = vec![];
                write_compact_versioned_array(&mut buffer, version, Some(&input)).unwrap();

                let mut cursor = Cursor::new(buffer);
                let output = read_compact_versioned_array(&mut cursor, version)
                    .unwrap()
                    .unwrap();

                assert_eq!(input, output);
                assert_fully_consumed(&cursor);
            }
        }

        let version = ApiVersion(Int16(0));
        let mut buffer = vec![];
        write_compact_versioned_array::<_, VersionTest>(&mut buffer, version, None).unwrap();
        let mut cursor = Cursor::new(buffer);
        assert!(
            read_compact_versioned_array::<_, VersionTest>(&mut cursor, version)
                .unwrap()
                .is_none()
        );
        assert_fully_consumed(&cursor);
    }

    #[test]
    fn test_read_compact_versioned_blowup_memory() {
        let mut buf = Cursor::new(Vec::<u8>::new());
        UnsignedVarint(u64::MAX).write(&mut buf).unwrap();
        buf.set_position(0);

        let err = read_compact_versioned_array::<_, VersionTest>(&mut buf, ApiVersion(Int16(42)))
            .unwrap_err();
        assert_matches!(err, ReadVersionedError::ReadError(ReadError::IO(_)));
    }
}