kafka_protocol/protocol/
mod.rs1use std::cmp;
5use std::collections::BTreeMap;
6use std::ops::RangeBounds;
7use std::{borrow::Borrow, fmt::Display};
8
9use anyhow::{bail, Result};
10use buf::{ByteBuf, ByteBufMut};
11use bytes::Bytes;
12
13pub mod buf;
14pub mod types;
15
16mod str_bytes {
17 use bytes::Bytes;
18 use std::borrow::Borrow;
19 use std::convert::TryFrom;
20 use std::fmt::{Debug, Display, Formatter};
21 use std::ops::Deref;
22 use std::str::Utf8Error;
23
24 #[derive(Clone, Hash, Ord, PartialOrd, PartialEq, Eq, Default)]
26 pub struct StrBytes(Bytes);
27
28 impl StrBytes {
29 pub fn from_utf8(bytes: Bytes) -> Result<Self, Utf8Error> {
32 let _: &str = std::str::from_utf8(&bytes)?;
33 Ok(Self(bytes))
34 }
35
36 pub fn from_static_str(s: &'static str) -> Self {
38 Self(Bytes::from_static(s.as_bytes()))
39 }
40
41 pub fn from_string(s: String) -> Self {
43 Self(Bytes::from(s.into_bytes()))
44 }
45
46 pub fn as_str(&self) -> &str {
48 unsafe { std::str::from_utf8_unchecked(&self.0) }
52 }
53
54 pub fn into_bytes(self) -> Bytes {
56 self.0
57 }
58 }
59
60 impl TryFrom<Bytes> for StrBytes {
61 type Error = Utf8Error;
62
63 fn try_from(value: Bytes) -> Result<Self, Self::Error> {
64 StrBytes::from_utf8(value)
65 }
66 }
67
68 impl From<StrBytes> for Bytes {
69 fn from(value: StrBytes) -> Bytes {
70 value.0
71 }
72 }
73
74 impl From<String> for StrBytes {
75 fn from(value: String) -> Self {
76 Self::from_string(value)
77 }
78 }
79
80 impl From<&'static str> for StrBytes {
81 fn from(value: &'static str) -> Self {
82 Self::from_static_str(value)
83 }
84 }
85
86 impl Deref for StrBytes {
87 type Target = str;
88
89 fn deref(&self) -> &Self::Target {
90 self.as_str()
91 }
92 }
93
94 impl Debug for StrBytes {
95 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
96 Debug::fmt(self.as_str(), f)
97 }
98 }
99
100 impl Display for StrBytes {
101 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
102 std::fmt::Display::fmt(&**self, f)
103 }
104 }
105
106 impl PartialEq<str> for StrBytes {
107 fn eq(&self, other: &str) -> bool {
108 self.as_str().eq(other)
109 }
110 }
111
112 impl Borrow<[u8]> for StrBytes {
113 fn borrow(&self) -> &[u8] {
114 self.as_bytes()
118 }
119 }
120}
121
122pub use str_bytes::StrBytes;
123
124use crate::messages::{ApiKey, RequestHeader};
125
/// Marker for types that wrap an `Inner` value and can be converted to and
/// from it, as well as borrowed as it.
pub(crate) trait NewType<Inner>
where
    Self: From<Inner> + Into<Inner> + Borrow<Inner>,
{
}

/// Every type is trivially a new-type around itself.
impl<T> NewType<T> for T {}
129
130pub(crate) trait Encoder<Value> {
131 fn encode<B: ByteBufMut>(&self, buf: &mut B, value: Value) -> Result<()>;
132 fn compute_size(&self, value: Value) -> Result<usize>;
133 fn fixed_size(&self) -> Option<usize> {
134 None
135 }
136}
137
138pub(crate) trait Decoder<Value> {
139 fn decode<B: ByteBuf>(&self, buf: &mut B) -> Result<Value>;
140}
141
/// A range of versions described by its two `i16` endpoints.
///
/// The range counts as empty when `min > max`, so a range with `min == max`
/// is non-empty (see [`VersionRange::is_empty`]).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct VersionRange {
    /// Lower endpoint of the range.
    pub min: i16,
    /// Upper endpoint of the range.
    pub max: i16,
}

impl VersionRange {
    /// Returns `true` when `min` is greater than `max`.
    pub fn is_empty(&self) -> bool {
        self.min > self.max
    }

    /// Returns the overlap of `self` and `other`.
    ///
    /// The result takes the larger of the two `min` values and the smaller of
    /// the two `max` values; if the inputs do not overlap the result satisfies
    /// [`VersionRange::is_empty`].
    pub fn intersect(&self, other: &VersionRange) -> VersionRange {
        let min = cmp::max(self.min, other.min);
        let max = cmp::min(self.max, other.max);
        VersionRange { min, max }
    }
}

/// Renders the range as `min..max`.
impl Display for VersionRange {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}..{}", self.min, self.max)
    }
}
171
172pub trait Message: Sized {
176 const VERSIONS: VersionRange;
178 const DEPRECATED_VERSIONS: Option<VersionRange>;
180}
181
182pub trait Encodable: Sized {
184 fn encode<B: ByteBufMut>(&self, buf: &mut B, version: i16) -> Result<()>;
186 fn compute_size(&self, version: i16) -> Result<usize>;
188}
189
190pub trait Decodable: Sized {
192 fn decode<B: ByteBuf>(buf: &mut B, version: i16) -> Result<Self>;
194}
195
/// Maps a message version to the header version that goes with it.
pub trait HeaderVersion {
    /// Returns the header version to use for the given message `version`.
    fn header_version(version: i16) -> i16;
}
201
202pub trait Request: Message + Encodable + Decodable + HeaderVersion {
208 const KEY: i16;
210 type Response: Message + Encodable + Decodable + HeaderVersion;
212}
213
214pub fn decode_request_header_from_buffer<B: ByteBuf>(buf: &mut B) -> Result<RequestHeader> {
216 let api_key = ApiKey::try_from(bytes::Buf::get_i16(&mut buf.peek_bytes(0..2)))
217 .map_err(|_| anyhow::Error::msg("Unknown API key"))?;
218 let api_version = bytes::Buf::get_i16(&mut buf.peek_bytes(2..4));
219 let header_version = api_key.request_header_version(api_version);
220 RequestHeader::decode(buf, header_version)
221}
222
223pub fn encode_request_header_into_buffer<B: ByteBufMut>(
225 buf: &mut B,
226 header: &RequestHeader,
227) -> Result<()> {
228 let api_key = ApiKey::try_from(header.request_api_key)
229 .map_err(|_| anyhow::Error::msg("Unknown API key"))?;
230 let version = api_key.request_header_version(header.request_api_version);
231 header.encode(buf, version)
232}
233
234pub(crate) fn write_unknown_tagged_fields<B: ByteBufMut, R: RangeBounds<i32>>(
235 buf: &mut B,
236 range: R,
237 unknown_tagged_fields: &BTreeMap<i32, Bytes>,
238) -> Result<()> {
239 for (&k, v) in unknown_tagged_fields.range(range) {
240 if v.len() > u32::MAX as usize {
241 bail!("Tagged field is too long to encode ({} bytes)", v.len());
242 }
243 types::UnsignedVarInt.encode(buf, k as u32)?;
244 types::UnsignedVarInt.encode(buf, v.len() as u32)?;
245 buf.put_slice(v);
246 }
247 Ok(())
248}
249
250pub(crate) fn compute_unknown_tagged_fields_size(
251 unknown_tagged_fields: &BTreeMap<i32, Bytes>,
252) -> Result<usize> {
253 let mut total_size = 0;
254 for (&k, v) in unknown_tagged_fields {
255 if v.len() > u32::MAX as usize {
256 bail!("Tagged field is too long to encode ({} bytes)", v.len());
257 }
258 total_size += types::UnsignedVarInt.compute_size(k as u32)?;
259 total_size += types::UnsignedVarInt.compute_size(v.len() as u32)?;
260 total_size += v.len();
261 }
262 Ok(total_size)
263}