use bytes::{Buf, BufMut, Bytes, BytesMut};
use crate::error::{KrafkaError, ProtocolErrorKind, Result};
/// Default cap (100 MiB) on the payload length of a single frame, used as the
/// limit by [`Decoder::new`].
pub const MAX_MESSAGE_SIZE: usize = 100 << 20;
/// Builds outgoing length-prefixed messages in a growable byte buffer.
#[derive(Default, Debug)]
pub struct Encoder {
    /// Bytes encoded so far; handed out (and cleared) by [`Encoder::take`].
    buffer: BytesMut,
}
impl Encoder {
    /// Creates an encoder whose buffer starts with 1 KiB of capacity.
    pub fn new() -> Self {
        Self {
            buffer: BytesMut::with_capacity(1024),
        }
    }

    /// Creates an encoder whose buffer pre-allocates `capacity` bytes.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            buffer: BytesMut::with_capacity(capacity),
        }
    }

    /// Mutable access to the underlying buffer, used to write message bodies.
    pub fn buffer_mut(&mut self) -> &mut BytesMut {
        &mut self.buffer
    }

    /// Read-only access to the bytes encoded so far.
    pub fn buffer(&self) -> &BytesMut {
        &self.buffer
    }

    /// Number of bytes currently in the buffer.
    #[inline]
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Returns `true` if the buffer holds no bytes.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }

    /// Discards all buffered bytes.
    pub fn reset(&mut self) {
        self.buffer.clear();
    }

    /// Reserves a 4-byte length prefix (zero-filled for now) and returns its offset.
    ///
    /// Write the message body after calling this, then pass the returned offset
    /// to [`Encoder::finish_message`] to fill in the real length.
    pub fn start_message(&mut self) -> usize {
        let pos = self.buffer.len();
        self.buffer.put_i32(0);
        pos
    }

    /// Back-fills the length prefix reserved by [`Encoder::start_message`].
    ///
    /// The prefix at `size_pos` is overwritten with the big-endian `i32` count
    /// of bytes that follow it in the buffer.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidLength` protocol error if `size_pos` does not leave
    /// room for a 4-byte prefix inside the buffer (for example an offset that
    /// went stale after [`Encoder::reset`] or [`Encoder::take`]), or if the
    /// body length does not fit in an `i32`.
    pub fn finish_message(&mut self, size_pos: usize) -> Result<()> {
        let buffered = self.buffer.len();
        // A stale or bogus offset must not underflow the length subtraction
        // below or panic on the slice index; report it as a protocol error.
        let body_start = size_pos
            .checked_add(4)
            .filter(|&end| end <= buffered)
            .ok_or_else(|| {
                KrafkaError::protocol_kind(
                    ProtocolErrorKind::InvalidLength,
                    format!(
                        "size position {size_pos} out of range for buffer of {buffered} bytes"
                    ),
                )
            })?;
        let message_size = i32::try_from(buffered - body_start).map_err(|_| {
            KrafkaError::protocol_kind(
                ProtocolErrorKind::InvalidLength,
                "message size exceeds i32::MAX",
            )
        })?;
        self.buffer[size_pos..body_start].copy_from_slice(&message_size.to_be_bytes());
        Ok(())
    }

    /// Splits off all encoded bytes as an immutable [`Bytes`], leaving the
    /// encoder's buffer empty for the next batch.
    pub fn take(&mut self) -> Bytes {
        self.buffer.split().freeze()
    }
}
#[derive(Debug, Default)]
pub struct Decoder {
buffer: BytesMut,
max_size: usize,
}
impl Decoder {
    /// Creates a decoder with an 8 KiB initial buffer and the
    /// [`MAX_MESSAGE_SIZE`] frame limit.
    pub fn new() -> Self {
        Self::with_max_size(MAX_MESSAGE_SIZE)
    }

    /// Creates a decoder with an 8 KiB initial buffer that rejects frames whose
    /// payload is longer than `max_size` bytes.
    pub fn with_max_size(max_size: usize) -> Self {
        let buffer = BytesMut::with_capacity(8192);
        Self { buffer, max_size }
    }

    /// Appends freshly received bytes to the internal buffer.
    pub fn extend(&mut self, data: &[u8]) {
        self.buffer.extend_from_slice(data);
    }

    /// Tries to extract one complete frame from the buffered bytes.
    ///
    /// Returns `Ok(Some(payload))`, with the 4-byte length prefix stripped,
    /// when a whole frame is buffered; the frame is then removed from the
    /// buffer. Returns `Ok(None)` when more bytes are needed, leaving the
    /// buffer untouched.
    ///
    /// # Errors
    ///
    /// - `Malformed` if the length prefix is negative.
    /// - `InvalidLength` if the length exceeds the configured maximum.
    pub fn decode(&mut self) -> Result<Option<Bytes>> {
        const PREFIX_LEN: usize = 4;

        if self.buffer.len() < PREFIX_LEN {
            return Ok(None);
        }

        // Peek at the big-endian length prefix without consuming it, so a
        // partially received frame stays intact for the next call.
        let mut header = &self.buffer[..PREFIX_LEN];
        let raw_len = header.get_i32();

        let size = usize::try_from(raw_len).map_err(|_| {
            KrafkaError::protocol_kind(
                ProtocolErrorKind::Malformed,
                format!("negative message size: {raw_len}"),
            )
        })?;

        if size > self.max_size {
            return Err(KrafkaError::protocol_kind(
                ProtocolErrorKind::InvalidLength,
                format!("message size {} exceeds maximum {}", size, self.max_size),
            ));
        }

        // `len() >= PREFIX_LEN` was checked above, so this cannot underflow.
        if self.buffer.len() - PREFIX_LEN < size {
            return Ok(None);
        }

        self.buffer.advance(PREFIX_LEN);
        let payload = self.buffer.split_to(size);
        Ok(Some(payload.freeze()))
    }

    /// Number of bytes buffered but not yet returned as a frame.
    pub fn buffered(&self) -> usize {
        self.buffer.len()
    }

    /// Drops all buffered bytes, including any partially received frame.
    pub fn clear(&mut self) {
        self.buffer.clear();
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_encoder_basic() {
        let mut encoder = Encoder::new();
        let pos = encoder.start_message();
        encoder.buffer_mut().put_slice(b"hello");
        encoder.finish_message(pos).unwrap();
        let bytes = encoder.take();
        assert_eq!(bytes.len(), 9);
        let size = i32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        assert_eq!(size, 5);
        assert_eq!(&bytes[4..], b"hello");
    }

    // The prefix must be written at the offset returned by `start_message`,
    // not at the start of the buffer.
    #[test]
    fn test_encoder_message_after_existing_bytes() {
        let mut encoder = Encoder::new();
        encoder.buffer_mut().put_slice(b"xy");
        let pos = encoder.start_message();
        assert_eq!(pos, 2);
        encoder.buffer_mut().put_slice(b"abc");
        encoder.finish_message(pos).unwrap();
        let bytes = encoder.take();
        assert_eq!(&bytes[..2], b"xy");
        assert_eq!(i32::from_be_bytes([bytes[2], bytes[3], bytes[4], bytes[5]]), 3);
        assert_eq!(&bytes[6..], b"abc");
    }

    #[test]
    fn test_encoder_reset() {
        let mut encoder = Encoder::new();
        encoder.buffer_mut().put_slice(b"test");
        assert!(!encoder.is_empty());
        encoder.reset();
        assert!(encoder.is_empty());
    }

    #[test]
    fn test_encoder_decoder_roundtrip() {
        let mut encoder = Encoder::new();
        for body in [&b"first"[..], &b"second message"[..]] {
            let pos = encoder.start_message();
            encoder.buffer_mut().put_slice(body);
            encoder.finish_message(pos).unwrap();
        }
        let wire = encoder.take();
        assert!(encoder.is_empty());

        let mut decoder = Decoder::new();
        decoder.extend(&wire);
        assert_eq!(decoder.decode().unwrap().unwrap().as_ref(), b"first");
        assert_eq!(decoder.decode().unwrap().unwrap().as_ref(), b"second message");
        assert!(decoder.decode().unwrap().is_none());
        assert_eq!(decoder.buffered(), 0);
    }

    #[test]
    fn test_decoder_complete_message() {
        let mut decoder = Decoder::new();
        let mut msg = BytesMut::new();
        msg.put_i32(5);
        msg.put_slice(b"hello");
        decoder.extend(&msg);
        let result = decoder.decode().unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().as_ref(), b"hello");
        assert_eq!(decoder.buffered(), 0);
    }

    #[test]
    fn test_decoder_zero_length_message() {
        let mut decoder = Decoder::new();
        decoder.extend(&0i32.to_be_bytes());
        let message = decoder
            .decode()
            .unwrap()
            .expect("a bare zero-length prefix is a complete frame");
        assert!(message.is_empty());
        assert_eq!(decoder.buffered(), 0);
    }

    #[test]
    fn test_decoder_incomplete_header() {
        let mut decoder = Decoder::new();
        decoder.extend(&[0, 0]);
        let result = decoder.decode().unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_decoder_incomplete_message() {
        let mut decoder = Decoder::new();
        let mut msg = BytesMut::new();
        msg.put_i32(10);
        msg.put_slice(b"hello");
        decoder.extend(&msg);
        let result = decoder.decode().unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_decoder_multiple_messages() {
        let mut decoder = Decoder::new();
        let mut msg = BytesMut::new();
        msg.put_i32(5);
        msg.put_slice(b"hello");
        msg.put_i32(5);
        msg.put_slice(b"world");
        decoder.extend(&msg);
        let result1 = decoder.decode().unwrap();
        assert_eq!(result1.unwrap().as_ref(), b"hello");
        let result2 = decoder.decode().unwrap();
        assert_eq!(result2.unwrap().as_ref(), b"world");
        let result3 = decoder.decode().unwrap();
        assert!(result3.is_none());
    }

    #[test]
    fn test_decoder_message_too_large() {
        let mut decoder = Decoder::with_max_size(100);
        let mut msg = BytesMut::new();
        msg.put_i32(1000);
        msg.put_slice(b"test");
        decoder.extend(&msg);
        let err = decoder
            .decode()
            .expect_err("oversized message should be rejected");
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("exceeds maximum"),
            "error should mention the size limit: {err_msg}"
        );
    }

    // A frame exactly `max_size` bytes long is accepted; one byte more is not.
    #[test]
    fn test_decoder_size_limit_is_inclusive() {
        let mut decoder = Decoder::with_max_size(5);
        let mut msg = BytesMut::new();
        msg.put_i32(5);
        msg.put_slice(b"hello");
        decoder.extend(&msg);
        assert_eq!(decoder.decode().unwrap().unwrap().as_ref(), b"hello");

        let mut decoder = Decoder::with_max_size(5);
        decoder.extend(&6i32.to_be_bytes());
        assert!(decoder.decode().is_err());
    }

    #[test]
    fn test_decoder_streaming() {
        let mut decoder = Decoder::new();
        let mut msg = BytesMut::new();
        msg.put_i32(10);
        msg.put_slice(b"0123456789");
        decoder.extend(&msg[..4]);
        assert!(decoder.decode().unwrap().is_none());
        decoder.extend(&msg[4..8]);
        assert!(decoder.decode().unwrap().is_none());
        decoder.extend(&msg[8..]);
        let result = decoder.decode().unwrap();
        assert_eq!(result.unwrap().as_ref(), b"0123456789");
    }

    #[test]
    fn test_decoder_negative_size() {
        let mut decoder = Decoder::new();
        let mut msg = BytesMut::new();
        msg.put_i32(-1);
        msg.put_slice(b"junk");
        decoder.extend(&msg);
        let result = decoder.decode();
        assert!(result.is_err(), "negative message size should be rejected");
        let err_msg = format!("{}", result.unwrap_err());
        assert!(
            err_msg.contains("negative message size"),
            "error should mention negative size: {err_msg}"
        );
    }
}