kafka_protocol/protocol/
mod.rs

1//! Most types are used internally in encoding/decoding, and are not required by typical use cases
2//! for interacting with the protocol. However, types can be used for decoding partial messages,
3//! or rewriting parts of an encoded message.
4use std::cmp;
5use std::collections::BTreeMap;
6use std::ops::RangeBounds;
7use std::{borrow::Borrow, fmt::Display};
8
9use anyhow::{bail, Result};
10use buf::{ByteBuf, ByteBufMut};
11use bytes::Bytes;
12
13pub mod buf;
14pub mod types;
15
16mod str_bytes {
17    use bytes::Bytes;
18    use std::borrow::Borrow;
19    use std::convert::TryFrom;
20    use std::fmt::{Debug, Display, Formatter};
21    use std::ops::Deref;
22    use std::str::Utf8Error;
23
24    /// A string type backed by [Bytes].
25    #[derive(Clone, Hash, Ord, PartialOrd, PartialEq, Eq, Default)]
26    pub struct StrBytes(Bytes);
27
28    impl StrBytes {
29        /// Construct a [StrBytes] from the given [Bytes] instance,
30        /// checking that it contains valid UTF-8 data.
31        pub fn from_utf8(bytes: Bytes) -> Result<Self, Utf8Error> {
32            let _: &str = std::str::from_utf8(&bytes)?;
33            Ok(Self(bytes))
34        }
35
36        /// Construct a [StrBytes] from the provided static [str].
37        pub fn from_static_str(s: &'static str) -> Self {
38            Self(Bytes::from_static(s.as_bytes()))
39        }
40
41        /// Construct a [StrBytes] from the provided [String] without additional allocations.
42        pub fn from_string(s: String) -> Self {
43            Self(Bytes::from(s.into_bytes()))
44        }
45
46        /// View the contents of this [StrBytes] as a [str] reference.
47        pub fn as_str(&self) -> &str {
48            // SAFETY: all methods of constructing `self` check that the backing data is valid utf8,
49            // and bytes::Bytes guarantees that its contents will not change unless we mutate it,
50            // and we never mutate it.
51            unsafe { std::str::from_utf8_unchecked(&self.0) }
52        }
53
54        /// Extract the underlying [Bytes].
55        pub fn into_bytes(self) -> Bytes {
56            self.0
57        }
58    }
59
60    impl TryFrom<Bytes> for StrBytes {
61        type Error = Utf8Error;
62
63        fn try_from(value: Bytes) -> Result<Self, Self::Error> {
64            StrBytes::from_utf8(value)
65        }
66    }
67
68    impl From<StrBytes> for Bytes {
69        fn from(value: StrBytes) -> Bytes {
70            value.0
71        }
72    }
73
74    impl From<String> for StrBytes {
75        fn from(value: String) -> Self {
76            Self::from_string(value)
77        }
78    }
79
80    impl From<&'static str> for StrBytes {
81        fn from(value: &'static str) -> Self {
82            Self::from_static_str(value)
83        }
84    }
85
86    impl Deref for StrBytes {
87        type Target = str;
88
89        fn deref(&self) -> &Self::Target {
90            self.as_str()
91        }
92    }
93
94    impl Debug for StrBytes {
95        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
96            Debug::fmt(self.as_str(), f)
97        }
98    }
99
100    impl Display for StrBytes {
101        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
102            std::fmt::Display::fmt(&**self, f)
103        }
104    }
105
106    impl PartialEq<str> for StrBytes {
107        fn eq(&self, other: &str) -> bool {
108            self.as_str().eq(other)
109        }
110    }
111
112    impl Borrow<[u8]> for StrBytes {
113        fn borrow(&self) -> &[u8] {
114            // Note that there is an equivalent Hash implementation between
115            // &[u8] and StrBytes, which makes this impl correct
116            // as described in the `std::borrow::Borrow` docs.
117            self.as_bytes()
118        }
119    }
120}
121
122pub use str_bytes::StrBytes;
123
124use crate::messages::{ApiKey, RequestHeader};
125
/// Crate-internal marker for a type that can stand in for `Inner`: it must be constructible
/// from an `Inner`, convertible back into one, and borrowable as one.
pub(crate) trait NewType<Inner>: From<Inner> + Into<Inner> + Borrow<Inner> {}

// Blanket impl: every type counts as a new-type of itself.
impl<T> NewType<T> for T {}
129
/// Crate-internal interface for a type that knows how to serialize a `Value`.
pub(crate) trait Encoder<Value> {
    /// Write `value` into `buf`, returning an error if it cannot be encoded.
    fn encode<B: ByteBufMut>(&self, buf: &mut B, value: Value) -> Result<()>;
    /// Compute the size of `value`.
    ///
    /// NOTE(review): presumably the number of bytes `encode` would write — confirm
    /// against the implementors.
    fn compute_size(&self, value: Value) -> Result<usize>;
    /// Optional fixed size for this encoder; the default implementation reports `None`.
    ///
    /// NOTE(review): presumably `Some(n)` means every value encodes to exactly `n`
    /// bytes — confirm against the implementors.
    fn fixed_size(&self) -> Option<usize> {
        None
    }
}
137
/// Crate-internal interface for a type that knows how to deserialize a `Value`.
pub(crate) trait Decoder<Value> {
    /// Read a `Value` out of `buf`, returning an error if decoding fails.
    fn decode<B: ByteBuf>(&self, buf: &mut B) -> Result<Value>;
}
141
/// The range of versions (min, max) allowed for a given message.
///
/// Both bounds are inclusive: a range with `min == max` contains exactly one version,
/// and a range with `min > max` contains none (see [`VersionRange::is_empty`]).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct VersionRange {
    /// The minimum version in the range.
    pub min: i16,
    /// The maximum version in the range.
    pub max: i16,
}

impl VersionRange {
    /// Checks whether the version range contains no versions.
    ///
    /// Returns `true` when `min` is strictly greater than `max`.
    pub fn is_empty(&self) -> bool {
        self.min > self.max
    }

    /// Finds the valid intersection with a provided other version range.
    ///
    /// The result takes the larger of the two minimums and the smaller of the two
    /// maximums. When the ranges do not overlap the result is an empty range, so
    /// callers should check [`VersionRange::is_empty`] before using it.
    pub fn intersect(&self, other: &VersionRange) -> VersionRange {
        VersionRange {
            min: cmp::max(self.min, other.min),
            max: cmp::min(self.max, other.max),
        }
    }
}

/// Renders the range as `min..max`.
impl Display for VersionRange {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}..{}", self.min, self.max)
    }
}
171
/// An API request or response.
///
/// All API messages must provide a set of valid versions. The version information is
/// exposed as associated constants, so it is available without an instance of the message.
pub trait Message: Sized {
    /// The valid versions for this message.
    const VERSIONS: VersionRange;
    /// The deprecated versions for this message.
    ///
    /// NOTE(review): presumably `None` means no version is deprecated — confirm against
    /// the implementors.
    const DEPRECATED_VERSIONS: Option<VersionRange>;
}
181
/// An encodable message.
///
/// Both methods take the protocol `version` to encode for and return an error on failure.
pub trait Encodable: Sized {
    /// Encode the message into the target buffer.
    fn encode<B: ByteBufMut>(&self, buf: &mut B, version: i16) -> Result<()>;
    /// Compute the total size of the message when encoded.
    fn compute_size(&self, version: i16) -> Result<usize>;
}
189
/// A decodable message.
pub trait Decodable: Sized {
    /// Decode the message from the provided buffer and version.
    ///
    /// Returns the decoded message, or an error if decoding fails.
    fn decode<B: ByteBuf>(buf: &mut B, version: i16) -> Result<Self>;
}
195
/// Every message has a set of versions valid for a given header version.
pub trait HeaderVersion {
    /// Maps a header version to a given version for a particular API message.
    ///
    /// NOTE(review): the method name suggests the opposite direction — that it takes the
    /// message's API version and returns the matching header version. Confirm against the
    /// implementors and reword the line above if so.
    fn header_version(version: i16) -> i16;
}
201
/// An API request.
///
/// Every abstract request must be able to provide the following items:
/// - An API key mapped to this request.
/// - A version based on a provided header version.
///
/// A request also names its response type, which must itself be a versioned message
/// that can be encoded and decoded.
pub trait Request: Message + Encodable + Decodable + HeaderVersion {
    /// The API key of this request.
    const KEY: i16;
    /// The response associated with this request.
    type Response: Message + Encodable + Decodable + HeaderVersion;
}
213
214/// Decode the request header from the provided buffer.
215pub fn decode_request_header_from_buffer<B: ByteBuf>(buf: &mut B) -> Result<RequestHeader> {
216    let api_key = ApiKey::try_from(bytes::Buf::get_i16(&mut buf.peek_bytes(0..2)))
217        .map_err(|_| anyhow::Error::msg("Unknown API key"))?;
218    let api_version = bytes::Buf::get_i16(&mut buf.peek_bytes(2..4));
219    let header_version = api_key.request_header_version(api_version);
220    RequestHeader::decode(buf, header_version)
221}
222
223/// Encode the request header into the provided buffer.
224pub fn encode_request_header_into_buffer<B: ByteBufMut>(
225    buf: &mut B,
226    header: &RequestHeader,
227) -> Result<()> {
228    let api_key = ApiKey::try_from(header.request_api_key)
229        .map_err(|_| anyhow::Error::msg("Unknown API key"))?;
230    let version = api_key.request_header_version(header.request_api_version);
231    header.encode(buf, version)
232}
233
234pub(crate) fn write_unknown_tagged_fields<B: ByteBufMut, R: RangeBounds<i32>>(
235    buf: &mut B,
236    range: R,
237    unknown_tagged_fields: &BTreeMap<i32, Bytes>,
238) -> Result<()> {
239    for (&k, v) in unknown_tagged_fields.range(range) {
240        if v.len() > u32::MAX as usize {
241            bail!("Tagged field is too long to encode ({} bytes)", v.len());
242        }
243        types::UnsignedVarInt.encode(buf, k as u32)?;
244        types::UnsignedVarInt.encode(buf, v.len() as u32)?;
245        buf.put_slice(v);
246    }
247    Ok(())
248}
249
250pub(crate) fn compute_unknown_tagged_fields_size(
251    unknown_tagged_fields: &BTreeMap<i32, Bytes>,
252) -> Result<usize> {
253    let mut total_size = 0;
254    for (&k, v) in unknown_tagged_fields {
255        if v.len() > u32::MAX as usize {
256            bail!("Tagged field is too long to encode ({} bytes)", v.len());
257        }
258        total_size += types::UnsignedVarInt.compute_size(k as u32)?;
259        total_size += types::UnsignedVarInt.compute_size(v.len() as u32)?;
260        total_size += v.len();
261    }
262    Ok(total_size)
263}