kafka_protocol/protocol/
mod.rs

//! Most types in this module are used internally for encoding and decoding and are not needed
//! for typical interaction with the protocol. However, they can also be used to decode partial
//! messages or to rewrite parts of an encoded message.
4use std::cmp;
5use std::collections::BTreeMap;
6use std::ops::RangeBounds;
7use std::{borrow::Borrow, fmt::Display};
8
9use anyhow::{bail, Result};
10use buf::{ByteBuf, ByteBufMut};
11use bytes::Bytes;
12
13pub mod buf;
14pub mod types;
15
16mod str_bytes {
17    use bytes::Bytes;
18    use std::borrow::Borrow;
19    use std::convert::TryFrom;
20    use std::fmt::{Debug, Display, Formatter};
21    use std::ops::Deref;
22    use std::str::Utf8Error;
23
24    /// A string type backed by [Bytes].
25    #[derive(Clone, Hash, Ord, PartialOrd, PartialEq, Eq, Default)]
26    pub struct StrBytes(Bytes);
27
28    impl StrBytes {
29        /// Construct a [StrBytes] from the given [Bytes] instance,
30        /// checking that it contains valid UTF-8 data.
31        pub fn from_utf8(bytes: Bytes) -> Result<Self, Utf8Error> {
32            let _: &str = std::str::from_utf8(&bytes)?;
33            Ok(Self(bytes))
34        }
35
36        /// Construct a [StrBytes] from the provided static [str].
37        pub fn from_static_str(s: &'static str) -> Self {
38            Self(Bytes::from_static(s.as_bytes()))
39        }
40
41        /// Construct a [StrBytes] from the provided [String] without additional allocations.
42        pub fn from_string(s: String) -> Self {
43            Self(Bytes::from(s.into_bytes()))
44        }
45
46        /// View the contents of this [StrBytes] as a [str] reference.
47        pub fn as_str(&self) -> &str {
48            // SAFETY: all methods of constructing `self` check that the backing data is valid utf8,
49            // and bytes::Bytes guarantees that its contents will not change unless we mutate it,
50            // and we never mutate it.
51            unsafe { std::str::from_utf8_unchecked(&self.0) }
52        }
53
54        /// Extract the underlying [Bytes].
55        pub fn into_bytes(self) -> Bytes {
56            self.0
57        }
58    }
59
60    impl TryFrom<Bytes> for StrBytes {
61        type Error = Utf8Error;
62
63        fn try_from(value: Bytes) -> Result<Self, Self::Error> {
64            StrBytes::from_utf8(value)
65        }
66    }
67
68    impl From<StrBytes> for Bytes {
69        fn from(value: StrBytes) -> Bytes {
70            value.0
71        }
72    }
73
74    impl From<String> for StrBytes {
75        fn from(value: String) -> Self {
76            Self::from_string(value)
77        }
78    }
79
80    impl From<&'static str> for StrBytes {
81        fn from(value: &'static str) -> Self {
82            Self::from_static_str(value)
83        }
84    }
85
86    impl Deref for StrBytes {
87        type Target = str;
88
89        fn deref(&self) -> &Self::Target {
90            self.as_str()
91        }
92    }
93
94    impl Debug for StrBytes {
95        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
96            Debug::fmt(self.as_str(), f)
97        }
98    }
99
100    impl Display for StrBytes {
101        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
102            std::fmt::Display::fmt(&**self, f)
103        }
104    }
105
106    impl PartialEq<str> for StrBytes {
107        fn eq(&self, other: &str) -> bool {
108            self.as_str().eq(other)
109        }
110    }
111
112    impl Borrow<[u8]> for StrBytes {
113        fn borrow(&self) -> &[u8] {
114            // Note that there is an equivalent Hash implementation between
115            // &[u8] and StrBytes, which makes this impl correct
116            // as described in the `std::borrow::Borrow` docs.
117            self.as_bytes()
118        }
119    }
120}
121
122pub use str_bytes::StrBytes;
123
/// A type that can be built from `Inner`, turned back into `Inner`, and
/// borrowed as `Inner`.
///
/// The blanket impl below makes every type trivially a `NewType` of itself.
pub(crate) trait NewType<Inner>
where
    Self: From<Inner> + Into<Inner> + Borrow<Inner>,
{
}

// Every `T` is `From<T>`, `Into<T>` and `Borrow<T>` through std's reflexive
// impls, so each type satisfies the bounds for itself.
impl<T> NewType<T> for T {}
127
128pub(crate) trait Encoder<Value> {
129    fn encode<B: ByteBufMut>(&self, buf: &mut B, value: Value) -> Result<()>;
130    fn compute_size(&self, value: Value) -> Result<usize>;
131    fn fixed_size(&self) -> Option<usize> {
132        None
133    }
134}
135
136pub(crate) trait Decoder<Value> {
137    fn decode<B: ByteBuf>(&self, buf: &mut B) -> Result<Value>;
138}
139
/// The range of versions allowed for a given message.
///
/// Both bounds are inclusive: a range with `min == max` contains exactly one
/// version, and the range is empty only when `min > max`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct VersionRange {
    /// The minimum version in the range (inclusive).
    pub min: i16,
    /// The maximum version in the range (inclusive).
    pub max: i16,
}
148
149impl VersionRange {
150    /// Checks whether the version range contains no versions.
151    pub fn is_empty(&self) -> bool {
152        self.min > self.max
153    }
154
155    /// Finds the valid intersection with a provided other version range.
156    pub fn intersect(&self, other: &VersionRange) -> VersionRange {
157        VersionRange {
158            min: cmp::max(self.min, other.min),
159            max: cmp::min(self.max, other.max),
160        }
161    }
162}
163
164impl Display for VersionRange {
165    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166        write!(f, "{}..{}", self.min, self.max)
167    }
168}
169
170/// An API request or response.
171///
172/// All API messages must provide a set of valid versions.
173pub trait Message: Sized {
174    /// The valid versions for this message.
175    const VERSIONS: VersionRange;
176    /// The deprecated versions for this message.
177    const DEPRECATED_VERSIONS: Option<VersionRange>;
178}
179
180/// An encodable message.
181pub trait Encodable: Sized {
182    /// Encode the message into the target buffer.
183    fn encode<B: ByteBufMut>(&self, buf: &mut B, version: i16) -> Result<()>;
184    /// Compute the total size of the message when encoded.
185    fn compute_size(&self, version: i16) -> Result<usize>;
186}
187
188/// A decodable message.
189pub trait Decodable: Sized {
190    /// Decode the message from the provided buffer and version.
191    fn decode<B: ByteBuf>(buf: &mut B, version: i16) -> Result<Self>;
192}
193
/// Associates each version of a message with the header version used alongside it.
pub trait HeaderVersion {
    /// Returns the header version to use for the given message `version`.
    ///
    /// The input is the version of this API message, not a header version; the
    /// result is the header version that accompanies it.
    fn header_version(version: i16) -> i16;
}
199
200/// An API request.
201///
202/// Every abstract request must be able to provide the following items:
203/// - An API key mapped to this request.
204/// - A version based on a provided header version.
205pub trait Request: Message + Encodable + Decodable + HeaderVersion {
206    /// The API key of this request.
207    const KEY: i16;
208    /// The response associated with this request.
209    type Response: Message + Encodable + Decodable + HeaderVersion;
210}
211
212pub(crate) fn write_unknown_tagged_fields<B: ByteBufMut, R: RangeBounds<i32>>(
213    buf: &mut B,
214    range: R,
215    unknown_tagged_fields: &BTreeMap<i32, Bytes>,
216) -> Result<()> {
217    for (&k, v) in unknown_tagged_fields.range(range) {
218        if v.len() > u32::MAX as usize {
219            bail!("Tagged field is too long to encode ({} bytes)", v.len());
220        }
221        types::UnsignedVarInt.encode(buf, k as u32)?;
222        types::UnsignedVarInt.encode(buf, v.len() as u32)?;
223        buf.put_slice(v);
224    }
225    Ok(())
226}
227
228pub(crate) fn compute_unknown_tagged_fields_size(
229    unknown_tagged_fields: &BTreeMap<i32, Bytes>,
230) -> Result<usize> {
231    let mut total_size = 0;
232    for (&k, v) in unknown_tagged_fields {
233        if v.len() > u32::MAX as usize {
234            bail!("Tagged field is too long to encode ({} bytes)", v.len());
235        }
236        total_size += types::UnsignedVarInt.compute_size(k as u32)?;
237        total_size += types::UnsignedVarInt.compute_size(v.len() as u32)?;
238        total_size += v.len();
239    }
240    Ok(total_size)
241}