kafka_protocol/protocol/
mod.rs

//! Most types in this module are used internally for encoding and decoding, and are not needed
//! for typical interactions with the protocol. However, they can also be used to decode partial
//! messages or to rewrite parts of an encoded message.
4use std::cmp;
5use std::collections::BTreeMap;
6use std::ops::RangeBounds;
7use std::{borrow::Borrow, fmt::Display};
8
9use anyhow::{bail, Result};
10use buf::{ByteBuf, ByteBufMut};
11use bytes::Bytes;
12
13pub mod buf;
14pub mod types;
15
16mod str_bytes {
17    use bytes::Bytes;
18    use std::convert::TryFrom;
19    use std::fmt::{Debug, Display, Formatter};
20    use std::ops::Deref;
21    use std::str::Utf8Error;
22
23    /// A string type backed by [Bytes].
24    #[derive(Clone, Hash, Ord, PartialOrd, PartialEq, Eq, Default)]
25    pub struct StrBytes(Bytes);
26
27    impl StrBytes {
28        /// Construct a [StrBytes] from the given [Bytes] instance,
29        /// checking that it contains valid UTF-8 data.
30        pub fn from_utf8(bytes: Bytes) -> Result<Self, Utf8Error> {
31            let _: &str = std::str::from_utf8(&bytes)?;
32            Ok(Self(bytes))
33        }
34
35        /// Construct a [StrBytes] from the provided static [str].
36        pub fn from_static_str(s: &'static str) -> Self {
37            Self(Bytes::from_static(s.as_bytes()))
38        }
39
40        /// Construct a [StrBytes] from the provided [String] without additional allocations.
41        pub fn from_string(s: String) -> Self {
42            Self(Bytes::from(s.into_bytes()))
43        }
44
45        /// View the contents of this [StrBytes] as a [str] reference.
46        pub fn as_str(&self) -> &str {
47            // SAFETY: all methods of constructing `self` check that the backing data is valid utf8,
48            // and bytes::Bytes guarantees that its contents will not change unless we mutate it,
49            // and we never mutate it.
50            unsafe { std::str::from_utf8_unchecked(&self.0) }
51        }
52
53        /// Extract the underlying [Bytes].
54        pub fn into_bytes(self) -> Bytes {
55            self.0
56        }
57    }
58
59    impl TryFrom<Bytes> for StrBytes {
60        type Error = Utf8Error;
61
62        fn try_from(value: Bytes) -> Result<Self, Self::Error> {
63            StrBytes::from_utf8(value)
64        }
65    }
66
67    impl From<StrBytes> for Bytes {
68        fn from(value: StrBytes) -> Bytes {
69            value.0
70        }
71    }
72
73    impl From<String> for StrBytes {
74        fn from(value: String) -> Self {
75            Self::from_string(value)
76        }
77    }
78
79    impl From<&'static str> for StrBytes {
80        fn from(value: &'static str) -> Self {
81            Self::from_static_str(value)
82        }
83    }
84
85    impl Deref for StrBytes {
86        type Target = str;
87
88        fn deref(&self) -> &Self::Target {
89            self.as_str()
90        }
91    }
92
93    impl Debug for StrBytes {
94        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
95            Debug::fmt(self.as_str(), f)
96        }
97    }
98
99    impl Display for StrBytes {
100        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
101            std::fmt::Display::fmt(&**self, f)
102        }
103    }
104
105    impl PartialEq<str> for StrBytes {
106        fn eq(&self, other: &str) -> bool {
107            self.as_str().eq(other)
108        }
109    }
110}
111
112pub use str_bytes::StrBytes;
113
/// Marker for types that convert to and from `Inner` in both directions and can be
/// borrowed as `Inner`.
///
/// The blanket impl directly below makes every type a `NewType` of itself.
pub(crate) trait NewType<Inner>
where
    Self: From<Inner> + Into<Inner> + Borrow<Inner>,
{
}

impl<T> NewType<T> for T {}
117
118pub(crate) trait Encoder<Value> {
119    fn encode<B: ByteBufMut>(&self, buf: &mut B, value: Value) -> Result<()>;
120    fn compute_size(&self, value: Value) -> Result<usize>;
121    fn fixed_size(&self) -> Option<usize> {
122        None
123    }
124}
125
126pub(crate) trait Decoder<Value> {
127    fn decode<B: ByteBuf>(&self, buf: &mut B) -> Result<Value>;
128}
129
/// The inclusive range of versions (`min` through `max`) allowed for a given message.
///
/// A range whose `min` is greater than its `max` contains no versions; see
/// [`VersionRange::is_empty`].
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct VersionRange {
    /// The minimum version in the range (inclusive).
    pub min: i16,
    /// The maximum version in the range (inclusive).
    pub max: i16,
}

impl VersionRange {
    /// Checks whether the version range contains no versions.
    ///
    /// Both bounds are inclusive, so a range with `min == max` contains exactly one version.
    pub fn is_empty(&self) -> bool {
        self.min > self.max
    }

    /// Finds the intersection of this range with `other`.
    ///
    /// When the two ranges do not overlap the result is empty; check it with
    /// [`VersionRange::is_empty`].
    pub fn intersect(&self, other: &VersionRange) -> VersionRange {
        VersionRange {
            min: cmp::max(self.min, other.min),
            max: cmp::min(self.max, other.max),
        }
    }
}

impl Display for VersionRange {
    /// Formats the range as `min..max`. Unlike Rust's `..` range syntax, `max` is
    /// inclusive here.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}..{}", self.min, self.max)
    }
}
159
160/// An API request or response.
161///
162/// All API messages must provide a set of valid versions.
163pub trait Message: Sized {
164    /// The valid versions for this message.
165    const VERSIONS: VersionRange;
166    /// The deprecated versions for this message.
167    const DEPRECATED_VERSIONS: Option<VersionRange>;
168}
169
170/// An encodable message.
171pub trait Encodable: Sized {
172    /// Encode the message into the target buffer.
173    fn encode<B: ByteBufMut>(&self, buf: &mut B, version: i16) -> Result<()>;
174    /// Compute the total size of the message when encoded.
175    fn compute_size(&self, version: i16) -> Result<usize>;
176}
177
178/// A decodable message.
179pub trait Decodable: Sized {
180    /// Decode the message from the provided buffer and version.
181    fn decode<B: ByteBuf>(buf: &mut B, version: i16) -> Result<Self>;
182}
183
/// Associates each version of a message with the header version that accompanies it.
pub trait HeaderVersion {
    /// Maps the API message `version` to the header version to use with that
    /// version of this message.
    fn header_version(version: i16) -> i16;
}
189
190/// An API request.
191///
192/// Every abstract request must be able to provide the following items:
193/// - An API key mapped to this request.
194/// - A version based on a provided header version.
195pub trait Request: Message + Encodable + Decodable + HeaderVersion {
196    /// The API key of this request.
197    const KEY: i16;
198    /// The response associated with this request.
199    type Response: Message + Encodable + Decodable + HeaderVersion;
200}
201
202pub(crate) fn write_unknown_tagged_fields<B: ByteBufMut, R: RangeBounds<i32>>(
203    buf: &mut B,
204    range: R,
205    unknown_tagged_fields: &BTreeMap<i32, Bytes>,
206) -> Result<()> {
207    for (&k, v) in unknown_tagged_fields.range(range) {
208        if v.len() > u32::MAX as usize {
209            bail!("Tagged field is too long to encode ({} bytes)", v.len());
210        }
211        types::UnsignedVarInt.encode(buf, k as u32)?;
212        types::UnsignedVarInt.encode(buf, v.len() as u32)?;
213        buf.put_slice(v);
214    }
215    Ok(())
216}
217
218pub(crate) fn compute_unknown_tagged_fields_size(
219    unknown_tagged_fields: &BTreeMap<i32, Bytes>,
220) -> Result<usize> {
221    let mut total_size = 0;
222    for (&k, v) in unknown_tagged_fields {
223        if v.len() > u32::MAX as usize {
224            bail!("Tagged field is too long to encode ({} bytes)", v.len());
225        }
226        total_size += types::UnsignedVarInt.compute_size(k as u32)?;
227        total_size += types::UnsignedVarInt.compute_size(v.len() as u32)?;
228        total_size += v.len();
229    }
230    Ok(total_size)
231}