// kafka_protocol/protocol/mod.rs
use std::cmp;
use std::collections::BTreeMap;
use std::ops::RangeBounds;
use std::{borrow::Borrow, fmt::Display};
use anyhow::{bail, Result};
use buf::{ByteBuf, ByteBufMut};
use bytes::Bytes;
pub mod buf;
pub mod types;
mod str_bytes {
    use bytes::Bytes;
    use std::convert::TryFrom;
    use std::fmt::{Debug, Display, Formatter};
    use std::ops::Deref;
    use std::str::Utf8Error;

    /// An immutable string whose storage is a [`Bytes`] buffer.
    ///
    /// Invariant: the wrapped buffer is always valid UTF-8. Every way of obtaining a
    /// `StrBytes` upholds this: `from_utf8` validates its input, `from_static_str` and
    /// `from_string` start from `str`/`String`, `Default` yields an empty buffer, and
    /// `Clone` copies an already-valid value. The buffer is only ever handed out by
    /// value (`into_bytes`, `From<StrBytes> for Bytes`), never mutably. `as_str`
    /// relies on this invariant.
    #[derive(Clone, Hash, Ord, PartialOrd, PartialEq, Eq, Default)]
    pub struct StrBytes(Bytes);

    impl StrBytes {
        /// Wraps `bytes` after checking that they are valid UTF-8.
        ///
        /// # Errors
        ///
        /// Returns the [`Utf8Error`] produced by [`std::str::from_utf8`] when `bytes`
        /// is not valid UTF-8.
        pub fn from_utf8(bytes: Bytes) -> Result<Self, Utf8Error> {
            // Validation only: the borrowed `&str` is discarded before `bytes` moves.
            std::str::from_utf8(&bytes)?;
            Ok(StrBytes(bytes))
        }

        /// Builds a `StrBytes` from a `'static` string via `Bytes::from_static`.
        pub fn from_static_str(s: &'static str) -> Self {
            let raw: &'static [u8] = s.as_bytes();
            StrBytes(Bytes::from_static(raw))
        }

        /// Builds a `StrBytes` from an owned `String`.
        pub fn from_string(s: String) -> Self {
            let raw: Vec<u8> = s.into_bytes();
            StrBytes(raw.into())
        }

        /// Borrows the contents as a `&str`.
        pub fn as_str(&self) -> &str {
            let raw: &[u8] = &self.0;
            // SAFETY: per the type-level invariant, the buffer is valid UTF-8 by
            // construction and is never mutated in place, so no re-validation is needed.
            unsafe { std::str::from_utf8_unchecked(raw) }
        }

        /// Unwraps the underlying byte buffer.
        pub fn into_bytes(self) -> Bytes {
            let StrBytes(inner) = self;
            inner
        }
    }

    impl TryFrom<Bytes> for StrBytes {
        type Error = Utf8Error;

        /// Same as [`StrBytes::from_utf8`].
        fn try_from(value: Bytes) -> Result<Self, Self::Error> {
            Self::from_utf8(value)
        }
    }

    impl From<StrBytes> for Bytes {
        fn from(value: StrBytes) -> Bytes {
            value.into_bytes()
        }
    }

    impl From<String> for StrBytes {
        fn from(value: String) -> Self {
            StrBytes::from_string(value)
        }
    }

    impl From<&'static str> for StrBytes {
        fn from(value: &'static str) -> Self {
            StrBytes::from_static_str(value)
        }
    }

    impl Deref for StrBytes {
        type Target = str;

        fn deref(&self) -> &str {
            self.as_str()
        }
    }

    impl Debug for StrBytes {
        /// Formats the same way `str`'s `Debug` implementation does.
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            <str as Debug>::fmt(self.as_str(), f)
        }
    }

    impl Display for StrBytes {
        /// Formats the same way `str`'s `Display` implementation does.
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            <str as Display>::fmt(self.as_str(), f)
        }
    }

    impl PartialEq<str> for StrBytes {
        fn eq(&self, other: &str) -> bool {
            self.as_str() == other
        }
    }
}
pub use str_bytes::StrBytes;
/// Bound for types that can be built from `Inner`, turned back into `Inner`, and
/// borrowed as `Inner`.
///
/// The only impl here is the reflexive blanket impl below, so every type is a
/// `NewType` of itself. NOTE(review): presumably used so code can accept either a raw
/// type or a wrapper around it — confirm at call sites.
pub(crate) trait NewType<Inner>
where
    Self: From<Inner> + Into<Inner> + Borrow<Inner>,
{
}

// Reflexive impl: `T: From<T>`, `T: Into<T>` and `T: Borrow<T>` are all provided by
// std's blanket impls, so the supertrait bounds are always satisfied.
impl<Inner> NewType<Inner> for Inner {}
/// Serializes values of type `Value` into a byte buffer.
///
/// For example, this module uses `types::UnsignedVarInt` as an `Encoder` of `u32`
/// when writing tagged-field headers.
pub(crate) trait Encoder<Value> {
    /// Writes `value` into `buf`.
    ///
    /// # Errors
    /// Returns an error if `value` cannot be encoded (conditions are
    /// implementation-defined).
    fn encode<B: ByteBufMut>(&self, buf: &mut B, value: Value) -> Result<()>;
    /// Returns the number of bytes needed to encode `value`.
    ///
    /// The tagged-field helpers in this module sum these sizes to mirror what they
    /// `encode`, so implementations are expected to agree with `encode`.
    fn compute_size(&self, value: Value) -> Result<usize>;
    /// Returns the encoded size when it does not depend on the value.
    ///
    /// The default is `None`. NOTE(review): presumably implementors return `Some(n)`
    /// when every encoding is exactly `n` bytes long — confirm in implementations.
    fn fixed_size(&self) -> Option<usize> {
        None
    }
}
/// Reads values of type `Value` from a byte buffer.
///
/// NOTE(review): presumably the reading counterpart of `Encoder` for the same
/// wire types — confirm in implementations.
pub(crate) trait Decoder<Value> {
    /// Reads one `Value` from `buf`.
    ///
    /// # Errors
    /// Returns an error if decoding fails (conditions are implementation-defined).
    fn decode<B: ByteBuf>(&self, buf: &mut B) -> Result<Value>;
}
/// A range of protocol versions with inclusive bounds.
///
/// `min == max` denotes a single version; a range with `min > max` is empty.
/// `Eq` and `Hash` are derived alongside `PartialEq` (all fields are `i16`), so
/// ranges can be used as keys in hash-based collections.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct VersionRange {
    /// Lowest version in the range (inclusive).
    pub min: i16,
    /// Highest version in the range (inclusive).
    pub max: i16,
}
impl VersionRange {
    /// Returns `true` when the range contains no versions, i.e. when `min` exceeds `max`.
    ///
    /// Both bounds are inclusive, so a range with `min == max` holds exactly one version
    /// and is not empty.
    pub fn is_empty(&self) -> bool {
        let VersionRange { min, max } = *self;
        max < min
    }

    /// Returns the versions that lie in both `self` and `other`.
    ///
    /// When the two ranges do not overlap, the result is empty
    /// (see [`VersionRange::is_empty`]).
    pub fn intersect(&self, other: &VersionRange) -> VersionRange {
        let lower = cmp::max(self.min, other.min);
        let upper = cmp::min(self.max, other.max);
        VersionRange {
            min: lower,
            max: upper,
        }
    }
}
impl Display for VersionRange {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}..{}", self.min, self.max)
}
}
/// Version metadata attached to a protocol message type.
pub trait Message: Sized {
    /// The versions this message type supports (bounds are inclusive, as with every
    /// `VersionRange`).
    const VERSIONS: VersionRange;
    /// The deprecated versions, if any. NOTE(review): `None` presumably means no
    /// version is deprecated — confirm against generated message definitions.
    const DEPRECATED_VERSIONS: Option<VersionRange>;
}
/// A message that can be serialized for a given protocol version.
pub trait Encodable: Sized {
    /// Writes `self` into `buf` using the layout for `version`.
    ///
    /// # Errors
    /// Returns an error if `self` cannot be encoded at `version` (conditions are
    /// implementation-defined).
    fn encode<B: ByteBufMut>(&self, buf: &mut B, version: i16) -> Result<()>;
    /// Returns the number of bytes `encode` needs for `version`.
    ///
    /// NOTE(review): presumably kept in sync with `encode` by implementations —
    /// confirm there.
    fn compute_size(&self, version: i16) -> Result<usize>;
}
/// A message that can be parsed for a given protocol version.
pub trait Decodable: Sized {
    /// Reads a `Self` from `buf` using the layout for `version`.
    ///
    /// # Errors
    /// Returns an error if decoding fails (conditions are implementation-defined).
    fn decode<B: ByteBuf>(buf: &mut B, version: i16) -> Result<Self>;
}
/// Maps a message version to a header version.
pub trait HeaderVersion {
    /// Returns the header version for message `version`. NOTE(review): presumably the
    /// version of the request/response header framing this message — confirm in
    /// implementations.
    fn header_version(version: i16) -> i16;
}
/// A request message type, paired with its response type.
pub trait Request: Message + Encodable + Decodable + HeaderVersion {
    /// Numeric identifier of this request type. NOTE(review): presumably the Kafka
    /// API key — confirm against generated request definitions.
    const KEY: i16;
    /// The paired response message type (presumably what is sent back for this
    /// request).
    type Response: Message + Encodable + Decodable + HeaderVersion;
}
/// Writes every entry of `unknown_tagged_fields` whose tag lies in `range`, in
/// ascending tag order.
///
/// Each entry is written as an unsigned-varint tag, then an unsigned-varint payload
/// length, then the raw payload bytes.
///
/// Tags are stored as `i32`; `tag as u32` reinterprets the bits, so a negative key is
/// written as a large unsigned tag. NOTE(review): presumably this mirrors how decoding
/// stores tags read as unsigned varints — confirm before changing the cast.
///
/// # Errors
/// Fails if a payload is longer than `u32::MAX` bytes, or if writing a varint fails.
/// Entries already written before the failure stay in `buf`.
pub(crate) fn write_unknown_tagged_fields<B: ByteBufMut, R: RangeBounds<i32>>(
    buf: &mut B,
    range: R,
    unknown_tagged_fields: &BTreeMap<i32, Bytes>,
) -> Result<()> {
    unknown_tagged_fields
        .range(range)
        .try_for_each(|(&tag, payload)| -> Result<()> {
            let len = payload.len();
            // Guard first so the `len as u32` below cannot truncate.
            if len > u32::MAX as usize {
                bail!("Tagged field is too long to encode ({} bytes)", len);
            }
            types::UnsignedVarInt.encode(buf, tag as u32)?;
            types::UnsignedVarInt.encode(buf, len as u32)?;
            buf.put_slice(payload);
            Ok(())
        })
}
/// Returns the total number of bytes needed to encode all of `unknown_tagged_fields`.
///
/// Every entry is counted, with no tag-range filter. Each entry contributes the size of
/// its unsigned-varint tag, the size of its unsigned-varint payload length, and the
/// payload length itself, matching the per-entry layout of
/// `write_unknown_tagged_fields`.
///
/// # Errors
/// Fails if a payload is longer than `u32::MAX` bytes, or if computing a varint size
/// fails.
pub(crate) fn compute_unknown_tagged_fields_size(
    unknown_tagged_fields: &BTreeMap<i32, Bytes>,
) -> Result<usize> {
    unknown_tagged_fields
        .iter()
        .try_fold(0usize, |acc, (&tag, payload)| -> Result<usize> {
            let len = payload.len();
            // Guard first so the `len as u32` below cannot truncate.
            if len > u32::MAX as usize {
                bail!("Tagged field is too long to encode ({} bytes)", len);
            }
            let tag_size = types::UnsignedVarInt.compute_size(tag as u32)?;
            let len_size = types::UnsignedVarInt.compute_size(len as u32)?;
            Ok(acc + tag_size + len_size + len)
        })
}