rskafka/protocol/messages/
mod.rs1use std::io::{Read, Write};
8
9use thiserror::Error;
10
11use super::{
12 api_key::ApiKey,
13 api_version::{ApiVersion, ApiVersionRange},
14 primitives::{Int32, UnsignedVarint},
15 traits::{ReadError, ReadType, WriteError, WriteType},
16 vec_builder::VecBuilder,
17};
18
19mod api_versions;
20pub use api_versions::*;
21mod constants;
22pub use constants::*;
23mod create_topics;
24pub use create_topics::*;
25mod delete_records;
26pub use delete_records::*;
27mod delete_topics;
28pub use delete_topics::*;
29mod fetch;
30pub use fetch::*;
31mod header;
32pub use header::*;
33mod list_offsets;
34pub use list_offsets::*;
35mod metadata;
36pub use metadata::*;
37mod produce;
38pub use produce::*;
39#[cfg(test)]
40mod test_utils;
41
42#[derive(Error, Debug)]
43#[non_exhaustive]
44pub enum ReadVersionedError {
45 #[error("Read error: {0}")]
46 ReadError(#[from] ReadError),
47}
48
49pub trait ReadVersionedType<R>: Sized
50where
51 R: Read,
52{
53 fn read_versioned(reader: &mut R, version: ApiVersion) -> Result<Self, ReadVersionedError>;
54}
55
56#[derive(Error, Debug)]
57#[non_exhaustive]
58pub enum WriteVersionedError {
59 #[error("Write error: {0}")]
60 WriteError(#[from] WriteError),
61
62 #[error("Field {field} not available in version: {version:?}")]
63 FieldNotAvailable { field: String, version: ApiVersion },
64}
65
66pub trait WriteVersionedType<W>: Sized
67where
68 W: Write,
69{
70 fn write_versioned(
71 &self,
72 writer: &mut W,
73 version: ApiVersion,
74 ) -> Result<(), WriteVersionedError>;
75}
76
77impl<'a, W: Write, T: WriteVersionedType<W>> WriteVersionedType<W> for &'a T {
78 fn write_versioned(
79 &self,
80 writer: &mut W,
81 version: ApiVersion,
82 ) -> Result<(), WriteVersionedError> {
83 T::write_versioned(self, writer, version)
84 }
85}
86
87pub trait RequestBody {
89 type ResponseBody;
91
92 const API_KEY: ApiKey;
96
97 const API_VERSION_RANGE: ApiVersionRange;
101
102 const FIRST_TAGGED_FIELD_IN_REQUEST_VERSION: ApiVersion;
110
111 const FIRST_TAGGED_FIELD_IN_RESPONSE_VERSION: ApiVersion =
114 Self::FIRST_TAGGED_FIELD_IN_REQUEST_VERSION;
115}
116
117impl<T: RequestBody> RequestBody for &T {
118 type ResponseBody = T::ResponseBody;
119 const API_KEY: ApiKey = T::API_KEY;
120 const API_VERSION_RANGE: ApiVersionRange = T::API_VERSION_RANGE;
121 const FIRST_TAGGED_FIELD_IN_REQUEST_VERSION: ApiVersion =
122 T::FIRST_TAGGED_FIELD_IN_REQUEST_VERSION;
123 const FIRST_TAGGED_FIELD_IN_RESPONSE_VERSION: ApiVersion =
124 T::FIRST_TAGGED_FIELD_IN_RESPONSE_VERSION;
125}
126
127fn read_versioned_array<R: Read, T: ReadVersionedType<R>>(
134 reader: &mut R,
135 version: ApiVersion,
136) -> Result<Option<Vec<T>>, ReadVersionedError> {
137 let len = Int32::read(reader)?.0;
138 match len {
139 -1 => Ok(None),
140 l if l < -1 => Err(ReadVersionedError::ReadError(ReadError::Malformed(
141 format!("Invalid negative length for array: {}", l).into(),
142 ))),
143 _ => {
144 let len = usize::try_from(len).map_err(ReadError::Overflow)?;
145 let mut builder = VecBuilder::new(len);
146 for _ in 0..len {
147 builder.push(T::read_versioned(reader, version)?);
148 }
149 Ok(Some(builder.into()))
150 }
151 }
152}
153
154fn write_versioned_array<W: Write, T: WriteVersionedType<W>>(
161 writer: &mut W,
162 version: ApiVersion,
163 data: Option<&[T]>,
164) -> Result<(), WriteVersionedError> {
165 match data {
166 None => Ok(Int32(-1).write(writer)?),
167 Some(inner) => {
168 let len = i32::try_from(inner.len()).map_err(WriteError::from)?;
169 Int32(len).write(writer)?;
170
171 for element in inner {
172 element.write_versioned(writer, version)?
173 }
174
175 Ok(())
176 }
177 }
178}
179
180fn read_compact_versioned_array<R: Read, T: ReadVersionedType<R>>(
187 reader: &mut R,
188 version: ApiVersion,
189) -> Result<Option<Vec<T>>, ReadVersionedError> {
190 let len = UnsignedVarint::read(reader)?.0;
191 match len {
192 0 => Ok(None),
193 n => {
194 let len = usize::try_from(n - 1).map_err(ReadError::Overflow)?;
195 let mut builder = VecBuilder::new(len);
196 for _ in 0..len {
197 builder.push(T::read_versioned(reader, version)?);
198 }
199 Ok(Some(builder.into()))
200 }
201 }
202}
203
204fn write_compact_versioned_array<W: Write, T: WriteVersionedType<W>>(
211 writer: &mut W,
212 version: ApiVersion,
213 data: Option<&[T]>,
214) -> Result<(), WriteVersionedError> {
215 match data {
216 None => Ok(UnsignedVarint(0).write(writer)?),
217 Some(inner) => {
218 let len = u64::try_from(inner.len() + 1).map_err(WriteError::from)?;
219 UnsignedVarint(len).write(writer)?;
220
221 for element in inner {
222 element.write_versioned(writer, version)?
223 }
224
225 Ok(())
226 }
227 }
228}
229
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use assert_matches::assert_matches;

    use crate::protocol::primitives::Int16;

    use super::*;

    /// Test element: serializes as the constant `Int32(42)` and tracks the version it was used
    /// with, so the tests can check that the array helpers pass the version through.
    #[derive(Debug, Copy, Clone, PartialEq)]
    struct VersionTest {
        version: ApiVersion,
    }

    impl<W: Write> WriteVersionedType<W> for VersionTest {
        fn write_versioned(
            &self,
            writer: &mut W,
            version: ApiVersion,
        ) -> Result<(), WriteVersionedError> {
            // The helper under test must hand over exactly the version it was given.
            assert_eq!(version, self.version);
            Int32(42).write(writer).map_err(WriteVersionedError::from)
        }
    }

    impl<R: Read> ReadVersionedType<R> for VersionTest {
        fn read_versioned(reader: &mut R, version: ApiVersion) -> Result<Self, ReadVersionedError> {
            let marker = Int32::read(reader)?.0;
            assert_eq!(marker, 42);
            Ok(Self { version })
        }
    }

    /// Round-trips empty and non-empty arrays, plus the `None` case, through the Int32-prefixed
    /// array helpers.
    #[test]
    fn test_read_write_versioned() {
        for len in [0, 6] {
            for raw_version in 0..3 {
                let version = ApiVersion(Int16(raw_version));
                let expected = vec![VersionTest { version }; len];

                let mut bytes = vec![];
                write_versioned_array(&mut bytes, version, Some(expected.as_slice())).unwrap();

                let mut reader = Cursor::new(bytes);
                let actual = read_versioned_array(&mut reader, version)
                    .unwrap()
                    .unwrap();

                assert_eq!(expected, actual);
            }
        }

        let version = ApiVersion(Int16(0));
        let mut bytes = vec![];
        write_versioned_array::<_, VersionTest>(&mut bytes, version, None).unwrap();

        let mut reader = Cursor::new(bytes);
        let decoded = read_versioned_array::<_, VersionTest>(&mut reader, version).unwrap();
        assert!(decoded.is_none())
    }

    /// A huge announced length with no payload must surface as an I/O error.
    #[test]
    fn test_read_versioned_blowup_memory() {
        let mut input = Cursor::new(Vec::<u8>::new());
        Int32(i32::MAX).write(&mut input).unwrap();
        input.set_position(0);

        let result = read_versioned_array::<_, VersionTest>(&mut input, ApiVersion(Int16(42)));
        let err = result.unwrap_err();
        assert_matches!(err, ReadVersionedError::ReadError(ReadError::IO(_)));
    }

    /// Round-trips empty and non-empty arrays, plus the `None` case, through the compact
    /// (varint-prefixed) array helpers.
    #[test]
    fn test_read_write_compact_versioned() {
        for len in [0, 6] {
            for raw_version in 0..3 {
                let version = ApiVersion(Int16(raw_version));
                let expected = vec![VersionTest { version }; len];

                let mut bytes = vec![];
                write_compact_versioned_array(&mut bytes, version, Some(expected.as_slice()))
                    .unwrap();

                let mut reader = Cursor::new(bytes);
                let actual = read_compact_versioned_array(&mut reader, version)
                    .unwrap()
                    .unwrap();

                assert_eq!(expected, actual);
            }
        }

        let version = ApiVersion(Int16(0));
        let mut bytes = vec![];
        write_compact_versioned_array::<_, VersionTest>(&mut bytes, version, None).unwrap();

        let mut reader = Cursor::new(bytes);
        let decoded =
            read_compact_versioned_array::<_, VersionTest>(&mut reader, version).unwrap();
        assert!(decoded.is_none())
    }

    /// A huge announced compact length with no payload must surface as an I/O error.
    #[test]
    fn test_read_compact_versioned_blowup_memory() {
        let mut input = Cursor::new(Vec::<u8>::new());
        UnsignedVarint(u64::MAX).write(&mut input).unwrap();
        input.set_position(0);

        let result =
            read_compact_versioned_array::<_, VersionTest>(&mut input, ApiVersion(Int16(42)));
        let err = result.unwrap_err();
        assert_matches!(err, ReadVersionedError::ReadError(ReadError::IO(_)));
    }
}
343}