use bytes::BytesMut;
use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::{timeout, Duration};
use tracing::{error, info, warn};
/// Kafka wire-protocol API key identifying JoinGroup requests.
const JOIN_GROUP_API_KEY: i16 = 11;
/// Wire-protocol version used for every request this client sends.
const API_VERSION: i16 = 0;
/// Constant correlation id placed in every request header.
const CORRELATION_ID: i32 = 1;
/// Client id string sent in every request header.
const CLIENT_ID: &str = "tiny-kafka-consumer";
/// Upper bound on any single request/response round-trip.
const RESPONSE_TIMEOUT: Duration = Duration::from_secs(5);
/// Upper bound on establishing the initial TCP connection.
const CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
/// Minimal Kafka consumer that speaks a small subset of the wire protocol
/// (JoinGroup v0 and Fetch v0) over a single raw TCP connection.
pub struct KafkaConsumer {
    /// `host:port` of the broker to connect to.
    broker_address: String,
    /// Consumer group this client identifies itself with.
    group_id: String,
    /// Topic whose messages are fetched.
    topic: String,
    /// Active broker connection; `None` until `connect` succeeds.
    stream: Option<TcpStream>,
    /// Partitions this consumer reads from (falls back to `[0]` in `consume`).
    partition_assignments: Vec<i32>,
    /// Next fetch offset; advanced by the number of messages received.
    current_offset: i64,
}
impl KafkaConsumer {
pub async fn new(broker_address: String, group_id: String, topic: String) -> io::Result<Self> {
let consumer = KafkaConsumer {
broker_address,
group_id,
topic,
stream: None,
partition_assignments: Vec::new(),
current_offset: 0,
};
Ok(consumer)
}
pub async fn connect(&mut self) -> io::Result<()> {
info!("Connecting to Kafka broker at {}", self.broker_address);
match timeout(CONNECTION_TIMEOUT, TcpStream::connect(&self.broker_address)).await {
Ok(result) => match result {
Ok(stream) => {
info!("Successfully connected to Kafka broker");
self.stream = Some(stream);
Ok(())
}
Err(e) => {
error!("Failed to connect to broker: {}", e);
Err(e)
}
},
Err(_) => {
error!("Connection attempt timed out");
Err(io::Error::new(
io::ErrorKind::TimedOut,
"Connection timed out",
))
}
}
}
/// Sends a JoinGroup request for `self.group_id` and parses the broker's
/// reply via `handle_join_response`. Requires `connect` to have succeeded.
///
/// NOTE(review): this method is not called from anywhere in this file —
/// presumably it was meant to be wired into `consume`; confirm with callers.
///
/// # Errors
/// `NotConnected` when there is no open socket, `TimedOut` when either the
/// send or the response exceeds `RESPONSE_TIMEOUT`, plus any I/O or parse
/// error bubbled up from the write or from `handle_join_response`.
async fn join_group(&mut self) -> io::Result<()> {
    info!("Joining consumer group: {}", self.group_id);
    let request = self.create_join_group_request();
    info!(
        "Created join group request of size: {} bytes",
        request.len()
    );
    if let Some(ref mut stream) = self.stream {
        // Write + flush under one timeout so a stalled socket cannot hang
        // the caller indefinitely.
        match timeout(RESPONSE_TIMEOUT, async {
            stream.write_all(&request).await?;
            stream.flush().await?;
            info!("Sent join group request to broker");
            Ok::<(), io::Error>(())
        })
        .await
        {
            Ok(result) => result?,
            Err(_) => {
                error!("Timeout while sending join group request");
                return Err(io::Error::new(
                    io::ErrorKind::TimedOut,
                    "Join group request timed out",
                ));
            }
        }
        // Separate timeout budget for reading and parsing the response.
        match timeout(RESPONSE_TIMEOUT, self.handle_join_response()).await {
            Ok(result) => match result {
                Ok(_) => {
                    info!("Successfully joined consumer group");
                    Ok(())
                }
                Err(e) => {
                    error!("Failed to handle join response: {}", e);
                    Err(e)
                }
            },
            Err(_) => {
                error!("Timeout while waiting for join response");
                Err(io::Error::new(
                    io::ErrorKind::TimedOut,
                    "Join group response timed out",
                ))
            }
        }
    } else {
        error!("Failed to join group: Not connected to Kafka broker");
        Err(io::Error::new(
            io::ErrorKind::NotConnected,
            "Not connected to Kafka broker",
        ))
    }
}
/// Serializes a JoinGroup (v0) request for this consumer's group,
/// advertising a single "range" protocol whose metadata is a small JSON
/// blob naming the subscribed topic.
fn create_join_group_request(&self) -> BytesMut {
    let mut request = BytesMut::new();
    // Size prefix placeholder — patched once the body is complete.
    request.extend_from_slice(&0i32.to_be_bytes());
    // Request header: api key, version, correlation id, client id string.
    request.extend_from_slice(&JOIN_GROUP_API_KEY.to_be_bytes());
    request.extend_from_slice(&API_VERSION.to_be_bytes());
    request.extend_from_slice(&CORRELATION_ID.to_be_bytes());
    let client_id = CLIENT_ID.as_bytes();
    request.extend_from_slice(&(client_id.len() as i16).to_be_bytes());
    request.extend_from_slice(client_id);
    // JoinGroup body: group id, session timeout, member id, protocols.
    request.extend_from_slice(&(self.group_id.len() as i16).to_be_bytes());
    request.extend_from_slice(self.group_id.as_bytes());
    request.extend_from_slice(&30000i32.to_be_bytes()); // session timeout (ms)
    request.extend_from_slice(&0i16.to_be_bytes()); // empty member id (first join)
    let protocol_type = "consumer";
    request.extend_from_slice(&(protocol_type.len() as i16).to_be_bytes());
    request.extend_from_slice(protocol_type.as_bytes());
    request.extend_from_slice(&1i32.to_be_bytes()); // exactly one supported protocol
    let protocol_name = "range";
    request.extend_from_slice(&(protocol_name.len() as i16).to_be_bytes());
    request.extend_from_slice(protocol_name.as_bytes());
    let metadata = format!(
        "{{\"version\":0,\"topics\":[\"{}\"],\"user_data\":\"\"}}",
        self.topic
    );
    request.extend_from_slice(&(metadata.len() as i32).to_be_bytes());
    request.extend_from_slice(metadata.as_bytes());
    // Back-patch the size prefix (total length minus the prefix itself).
    let body_len = (request.len() - 4) as i32;
    request[0..4].copy_from_slice(&body_len.to_be_bytes());
    request
}
/// Reads a length-prefixed JoinGroup (v0) response from the broker and
/// walks its fields (error code, generation id, protocol, leader id,
/// member id, members array), validating bounds at each step.
///
/// On success the consumer is statically assigned partition 0 — this tiny
/// client does not perform a real SyncGroup/assignment round.
///
/// # Errors
/// `NotConnected` when `connect` has not been called, `InvalidData` for a
/// truncated or malformed frame, and `Other` when the broker reports a
/// non-zero error code.
async fn handle_join_response(&mut self) -> io::Result<()> {
    if let Some(ref mut stream) = self.stream {
        // The frame is length-prefixed: read the 4-byte size first.
        let mut size_buf = [0u8; 4];
        match stream.read_exact(&mut size_buf).await {
            Ok(_) => {
                let response_size = i32::from_be_bytes(size_buf);
                info!("Join group response size: {} bytes", response_size);
                if response_size <= 0 {
                    error!("Invalid response size: {}", response_size);
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        "Invalid response size",
                    ));
                }
                // NOTE(review): response_size is trusted as-is, so a hostile
                // or corrupt peer can make this allocate up to ~2 GiB.
                // Consider capping it.
                // The buffer keeps the 4-byte size prefix at the front so
                // offsets below count from the start of the frame.
                let mut response = vec![0u8; (response_size + 4) as usize];
                response[0..4].copy_from_slice(&size_buf);
                match stream.read_exact(&mut response[4..]).await {
                    Ok(_) => {
                        info!("Read complete response of {} bytes", response.len());
                        // Skip size prefix (4) + correlation id (4).
                        let mut pos = 8;
                        // error_code: i16
                        if pos + 2 > response.len() {
                            error!("Cannot read error code from join response");
                            return Err(io::Error::new(
                                io::ErrorKind::InvalidData,
                                "Invalid join response",
                            ));
                        }
                        let error_code =
                            i16::from_be_bytes(response[pos..pos + 2].try_into().unwrap());
                        if error_code != 0 {
                            error!("Error in join response: {}", error_code);
                            return Err(io::Error::new(
                                io::ErrorKind::Other,
                                format!("Join error: {}", error_code),
                            ));
                        }
                        pos += 2;
                        // generation_id: i32 (value is skipped, not stored)
                        if pos + 4 > response.len() {
                            error!("Cannot read generation ID");
                            return Err(io::Error::new(
                                io::ErrorKind::InvalidData,
                                "Invalid join response",
                            ));
                        }
                        pos += 4;
                        // group_protocol: i16-length-prefixed string (skipped)
                        if pos + 2 > response.len() {
                            error!("Cannot read group protocol length");
                            return Err(io::Error::new(
                                io::ErrorKind::InvalidData,
                                "Invalid join response",
                            ));
                        }
                        let protocol_len =
                            i16::from_be_bytes(response[pos..pos + 2].try_into().unwrap())
                                as usize;
                        pos += 2;
                        if pos + protocol_len > response.len() {
                            error!("Cannot read group protocol");
                            return Err(io::Error::new(
                                io::ErrorKind::InvalidData,
                                "Invalid join response",
                            ));
                        }
                        pos += protocol_len;
                        // leader_id: i16-length-prefixed string (skipped)
                        if pos + 2 > response.len() {
                            error!("Cannot read leader ID length");
                            return Err(io::Error::new(
                                io::ErrorKind::InvalidData,
                                "Invalid join response",
                            ));
                        }
                        let leader_len =
                            i16::from_be_bytes(response[pos..pos + 2].try_into().unwrap())
                                as usize;
                        pos += 2;
                        if pos + leader_len > response.len() {
                            error!("Cannot read leader ID");
                            return Err(io::Error::new(
                                io::ErrorKind::InvalidData,
                                "Invalid join response",
                            ));
                        }
                        pos += leader_len;
                        // member_id: i16-length-prefixed string — the only
                        // field this parser actually keeps (for logging).
                        if pos + 2 > response.len() {
                            error!("Cannot read member ID length");
                            return Err(io::Error::new(
                                io::ErrorKind::InvalidData,
                                "Invalid join response",
                            ));
                        }
                        let member_id_len =
                            i16::from_be_bytes(response[pos..pos + 2].try_into().unwrap())
                                as usize;
                        pos += 2;
                        if pos + member_id_len > response.len() {
                            error!("Cannot read member ID");
                            return Err(io::Error::new(
                                io::ErrorKind::InvalidData,
                                "Invalid join response",
                            ));
                        }
                        let member_id =
                            String::from_utf8_lossy(&response[pos..pos + member_id_len])
                                .to_string();
                        pos += member_id_len;
                        // members: i32-count array of (member_id string,
                        // metadata bytes). Entries are validated and skipped.
                        if pos + 4 > response.len() {
                            error!("Cannot read members array length");
                            return Err(io::Error::new(
                                io::ErrorKind::InvalidData,
                                "Invalid join response",
                            ));
                        }
                        let members_len =
                            i32::from_be_bytes(response[pos..pos + 4].try_into().unwrap())
                                as usize;
                        pos += 4;
                        for _ in 0..members_len {
                            if pos + 2 > response.len() {
                                error!("Cannot read member ID length");
                                return Err(io::Error::new(
                                    io::ErrorKind::InvalidData,
                                    "Invalid join response",
                                ));
                            }
                            let member_len =
                                i16::from_be_bytes(response[pos..pos + 2].try_into().unwrap())
                                    as usize;
                            pos += 2;
                            if pos + member_len > response.len() {
                                error!("Cannot read member ID");
                                return Err(io::Error::new(
                                    io::ErrorKind::InvalidData,
                                    "Invalid join response",
                                ));
                            }
                            pos += member_len;
                            if pos + 4 > response.len() {
                                error!("Cannot read metadata length");
                                return Err(io::Error::new(
                                    io::ErrorKind::InvalidData,
                                    "Invalid join response",
                                ));
                            }
                            let metadata_len =
                                i32::from_be_bytes(response[pos..pos + 4].try_into().unwrap())
                                    as usize;
                            pos += 4;
                            if pos + metadata_len > response.len() {
                                error!("Cannot read metadata");
                                return Err(io::Error::new(
                                    io::ErrorKind::InvalidData,
                                    "Invalid join response",
                                ));
                            }
                            pos += metadata_len;
                        }
                        // Static assignment: always partition 0, regardless
                        // of the members list parsed above.
                        self.partition_assignments = vec![0];
                        info!(
                            "Assigned partition 0 to consumer with member ID: {}",
                            member_id
                        );
                        info!("Successfully parsed join group response");
                        Ok(())
                    }
                    Err(e) => {
                        error!("Failed to read response data: {}", e);
                        Err(e)
                    }
                }
            }
            Err(e) => {
                error!("Failed to read response size: {}", e);
                Err(e)
            }
        }
    } else {
        error!("No active connection to broker");
        Err(io::Error::new(
            io::ErrorKind::NotConnected,
            "No active connection to broker",
        ))
    }
}
/// Fetches one batch of messages from every assigned partition and
/// returns their value payloads.
///
/// If no partitions are assigned it falls back to partition 0. The
/// current offset is advanced by the number of messages received (not by
/// the broker-reported message offsets).
///
/// NOTE(review): `join_group` is never invoked here, so the broker-side
/// group membership path is unused — confirm whether that is intended.
///
/// # Errors
/// `NotConnected` when there is no open socket, `TimedOut` when any send
/// or read exceeds `RESPONSE_TIMEOUT`, plus underlying I/O errors.
pub async fn consume(&mut self) -> io::Result<Vec<Vec<u8>>> {
    info!("Consuming messages from topic: {}", self.topic);
    if self.stream.is_none() {
        error!("Not connected to broker");
        return Err(io::Error::new(
            io::ErrorKind::NotConnected,
            "Not connected to broker",
        ));
    }
    if self.partition_assignments.is_empty() {
        info!("No partition assignments, using default partition 0");
        self.partition_assignments = vec![0];
    }
    let mut messages = Vec::new();
    for partition in &self.partition_assignments {
        info!(
            "Fetching messages from partition {} at offset {}",
            partition, self.current_offset
        );
        let request = self.create_fetch_request(*partition);
        if let Some(ref mut stream) = self.stream {
            // Send the fetch request under a bounded timeout.
            match timeout(RESPONSE_TIMEOUT, async {
                stream.write_all(&request).await?;
                stream.flush().await?;
                info!("Sent fetch request to broker");
                Ok::<(), io::Error>(())
            })
            .await
            {
                Ok(result) => result?,
                Err(_) => {
                    error!("Timeout while sending fetch request");
                    return Err(io::Error::new(
                        io::ErrorKind::TimedOut,
                        "Fetch request timed out",
                    ));
                }
            }
            // Read the 4-byte length prefix of the response frame.
            let mut size_buf = [0u8; 4];
            match timeout(RESPONSE_TIMEOUT, stream.read_exact(&mut size_buf)).await {
                Ok(result) => {
                    match result {
                        Ok(_) => {
                            let response_size = i32::from_be_bytes(size_buf);
                            info!("Response size: {} bytes", response_size);
                            if response_size <= 0 {
                                // Nothing to read for this partition; move on.
                                info!("Empty response from partition {}", partition);
                                continue;
                            }
                            // Body only — unlike handle_join_response, the
                            // size prefix is NOT kept in this buffer.
                            let mut response = vec![0u8; response_size as usize];
                            match timeout(RESPONSE_TIMEOUT, stream.read_exact(&mut response))
                                .await
                            {
                                Ok(result) => {
                                    match result {
                                        Ok(_) => {
                                            info!(
                                                "Read {} bytes from partition {}",
                                                response.len(),
                                                partition
                                            );
                                            if let Some(batch) =
                                                self.extract_messages_from_response(&response)
                                            {
                                                let batch_len = batch.len();
                                                messages.extend(batch);
                                                // Offset tracking: advance by
                                                // count of messages received.
                                                self.current_offset += batch_len as i64;
                                                info!(
                                                    "Received {} messages from partition {}",
                                                    batch_len, partition
                                                );
                                            } else {
                                                // Parse failure is logged but
                                                // not fatal for the batch.
                                                error!(
                                                    "Failed to extract messages from response"
                                                );
                                            }
                                        }
                                        Err(e) => {
                                            error!("Failed to read response data: {}", e);
                                            return Err(e);
                                        }
                                    }
                                }
                                Err(_) => {
                                    error!("Timeout while reading response data");
                                    return Err(io::Error::new(
                                        io::ErrorKind::TimedOut,
                                        "Response read timed out",
                                    ));
                                }
                            }
                        }
                        Err(e) => {
                            error!("Failed to read response size: {}", e);
                            return Err(e);
                        }
                    }
                }
                Err(_) => {
                    error!("Timeout while waiting for response size");
                    return Err(io::Error::new(
                        io::ErrorKind::TimedOut,
                        "Response size read timed out",
                    ));
                }
            }
        }
    }
    info!("Total messages received: {}", messages.len());
    Ok(messages)
}
/// Parses a Fetch (v0) response body (size prefix already stripped) and
/// returns the value payload of every message in the first partition of
/// the first topic.
///
/// Returns `None` when the frame is malformed, carries a non-zero broker
/// error code, or contains no extractable messages.
///
/// Fix over the previous version: every length field is now validated as a
/// signed value before being used as an offset. The old code cast raw
/// `i16`/`i32` lengths straight to `usize`, so Kafka's null sentinel (-1,
/// used for null keys/values) wrapped to a huge number and the subsequent
/// `pos + len` additions overflowed `usize` — a panic in debug builds and
/// a misparse in release builds.
fn extract_messages_from_response(&self, response: &[u8]) -> Option<Vec<Vec<u8>>> {
    // --- checked-read helpers: every bounds/overflow check lives here ---
    fn bytes_at(buf: &[u8], pos: usize, n: usize) -> Option<&[u8]> {
        buf.get(pos..pos.checked_add(n)?)
    }
    fn read_i16(buf: &[u8], pos: usize) -> Option<i16> {
        bytes_at(buf, pos, 2).map(|b| i16::from_be_bytes(b.try_into().unwrap()))
    }
    fn read_i32(buf: &[u8], pos: usize) -> Option<i32> {
        bytes_at(buf, pos, 4).map(|b| i32::from_be_bytes(b.try_into().unwrap()))
    }
    fn read_i64(buf: &[u8], pos: usize) -> Option<i64> {
        bytes_at(buf, pos, 8).map(|b| i64::from_be_bytes(b.try_into().unwrap()))
    }
    // Kafka encodes "null" strings/bytes as length -1 (treated as empty
    // here); any other negative length marks a malformed frame.
    fn checked_len(raw: i64) -> Option<usize> {
        match raw {
            -1 => Some(0),
            n if n >= 0 => Some(n as usize),
            _ => None,
        }
    }

    if response.len() < 8 {
        error!("Response too short: {} bytes", response.len());
        return None;
    }
    let mut pos = 4; // skip the 4-byte correlation id

    let num_topics = read_i32(response, pos)?;
    info!("Number of topics: {}", num_topics);
    pos += 4;
    if num_topics <= 0 {
        error!("No topics in response");
        return None;
    }

    let topic_len = checked_len(read_i16(response, pos)? as i64)?;
    info!("Topic name length: {}", topic_len);
    pos += 2;
    bytes_at(response, pos, topic_len)?; // topic name is skipped, not used
    pos += topic_len;

    let num_partitions = read_i32(response, pos)?;
    info!("Number of partitions: {}", num_partitions);
    pos += 4;
    if num_partitions <= 0 {
        error!("No partitions in response");
        return None;
    }

    let partition_id = read_i32(response, pos)?;
    info!("Partition ID: {}", partition_id);
    pos += 4;

    let error_code = read_i16(response, pos)?;
    if error_code != 0 {
        error!("Error code in response: {}", error_code);
        return None;
    }
    pos += 2;

    let high_watermark = read_i64(response, pos)?;
    info!("High watermark: {}", high_watermark);
    pos += 8;

    let message_set_size = checked_len(read_i32(response, pos)? as i64)?;
    info!("Message set size: {}", message_set_size);
    pos += 4;
    if message_set_size == 0 {
        info!("Empty message set");
        return None;
    }
    if bytes_at(response, pos, message_set_size).is_none() {
        error!(
            "Message set size {} exceeds response length {}",
            message_set_size,
            response.len()
        );
        return None;
    }
    let message_set_end = pos + message_set_size; // cannot overflow: validated above
    let set = &response[..message_set_end]; // reads below never see trailing bytes

    let mut messages = Vec::new();
    while pos < message_set_end {
        // Message set entry layout: offset(i64) size(i32) crc(i32)
        // magic(i8) attributes(i8) key(bytes) value(bytes).
        let offset = match read_i64(set, pos) {
            Some(v) => v,
            None => break, // truncated trailing entry: keep what we have
        };
        info!("Message offset: {}", offset);
        pos += 8;
        let message_size = match read_i32(set, pos) {
            Some(v) if v >= 0 => v as usize,
            _ => break, // negative or missing size: malformed entry
        };
        info!("Message size: {}", message_size);
        pos += 4;
        if bytes_at(set, pos, message_size).is_none() {
            error!("Message size {} exceeds message set end", message_size);
            break;
        }
        pos += 4 + 1 + 1; // skip CRC, magic byte, attributes
        let key_len = match read_i32(set, pos).map(i64::from).and_then(checked_len) {
            Some(v) => v, // -1 (null key) maps to 0
            None => break,
        };
        info!("Key length: {}", key_len);
        pos += 4;
        if key_len > 0 {
            if bytes_at(set, pos, key_len).is_none() {
                error!("Key length {} exceeds message set end", key_len);
                break;
            }
            pos += key_len;
        }
        let value_len = match read_i32(set, pos).map(i64::from).and_then(checked_len) {
            Some(v) => v, // -1 (null value) maps to 0
            None => break,
        };
        info!("Value length: {}", value_len);
        pos += 4;
        if value_len == 0 {
            continue; // null or empty value: nothing to extract
        }
        match bytes_at(set, pos, value_len) {
            Some(value) => {
                messages.push(value.to_vec());
                info!("Extracted message of length {}", value_len);
                pos += value_len;
            }
            None => {
                error!("Value length {} exceeds message set end", value_len);
                break;
            }
        }
    }

    if messages.is_empty() {
        None
    } else {
        Some(messages)
    }
}
/// Builds a Fetch (v0) request for one partition of `self.topic`,
/// starting at the consumer's current offset.
fn create_fetch_request(&self, partition: i32) -> BytesMut {
    let mut request = BytesMut::new();
    // Size prefix placeholder — patched once the body is complete.
    request.extend_from_slice(&0i32.to_be_bytes());
    // Request header: Fetch api key (1), version, correlation id, client id.
    request.extend_from_slice(&1i16.to_be_bytes());
    request.extend_from_slice(&API_VERSION.to_be_bytes());
    request.extend_from_slice(&CORRELATION_ID.to_be_bytes());
    let client_id = CLIENT_ID.as_bytes();
    request.extend_from_slice(&(client_id.len() as i16).to_be_bytes());
    request.extend_from_slice(client_id);
    // Fetch body: replica id (-1 = ordinary client), max wait (ms),
    // min bytes, then one topic with one partition entry.
    request.extend_from_slice(&(-1i32).to_be_bytes());
    request.extend_from_slice(&100i32.to_be_bytes());
    request.extend_from_slice(&1i32.to_be_bytes());
    request.extend_from_slice(&1i32.to_be_bytes()); // topic count
    request.extend_from_slice(&(self.topic.len() as i16).to_be_bytes());
    request.extend_from_slice(self.topic.as_bytes());
    request.extend_from_slice(&1i32.to_be_bytes()); // partition count
    request.extend_from_slice(&partition.to_be_bytes());
    request.extend_from_slice(&self.current_offset.to_be_bytes());
    request.extend_from_slice(&(1024 * 1024i32).to_be_bytes()); // max bytes
    // Back-patch the size prefix (total length minus the prefix itself).
    let body_len = (request.len() - 4) as i32;
    request[0..4].copy_from_slice(&body_len.to_be_bytes());
    request
}
/// Commits the current offset.
///
/// NOTE(review): this is a stub — it only logs and returns `Ok(())`; no
/// OffsetCommit request is ever sent to the broker.
pub async fn commit(&mut self) -> io::Result<()> {
    info!("Committing offset {}", self.current_offset);
    Ok(())
}
/// Gracefully shuts down the broker connection, if one is open.
/// Idempotent: calling it again (or on a never-connected consumer) is a
/// no-op that returns `Ok(())`.
pub async fn close(&mut self) -> io::Result<()> {
    info!("Closing consumer connection");
    match self.stream.take() {
        Some(mut stream) => stream.shutdown().await,
        None => Ok(()),
    }
}
}
impl Drop for KafkaConsumer {
    /// Best-effort cleanup: if `close` was never called, tear the socket
    /// down synchronously via the std handle. All errors are deliberately
    /// ignored — `drop` cannot report them and may run outside a runtime.
    fn drop(&mut self) {
        if let Some(stream) = self.stream.take() {
            if let Ok(std_stream) = stream.into_std() {
                let _ = std_stream.shutdown(std::net::Shutdown::Both);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, timeout};
    use tracing_test::traced_test;

    const TEST_TIMEOUT: Duration = Duration::from_secs(30);
    const SETUP_DELAY: Duration = Duration::from_secs(2);
    /// How many times the initial broker connection is attempted.
    const CONNECT_ATTEMPTS: u32 = 3;

    /// Builds a consumer against the local test broker and connects it,
    /// retrying a few times to ride out broker start-up races.
    /// (Previously this 20-line snippet was copy-pasted into all three tests.)
    async fn connect_test_consumer() -> io::Result<KafkaConsumer> {
        let mut consumer = KafkaConsumer::new(
            "127.0.0.1:9092".to_string(),
            "test_group".to_string(),
            "test-topic".to_string(),
        )
        .await?;
        for attempt in 1..=CONNECT_ATTEMPTS {
            match consumer.connect().await {
                Ok(_) => return Ok(consumer),
                Err(e) if attempt < CONNECT_ATTEMPTS => {
                    warn!("Connection attempt {} failed: {}", attempt, e);
                    sleep(Duration::from_secs(1)).await;
                }
                Err(e) => return Err(e),
            }
        }
        Ok(consumer)
    }

    /// Shared setup used by every test: settle delay, then a bounded
    /// connect. Panics (failing the test) on timeout or connect failure.
    async fn setup_consumer() -> KafkaConsumer {
        sleep(SETUP_DELAY).await;
        let consumer_result = timeout(TEST_TIMEOUT, connect_test_consumer()).await;
        assert!(consumer_result.is_ok(), "Consumer creation timed out");
        consumer_result.unwrap().expect("Failed to create consumer")
    }

    #[tokio::test]
    #[traced_test]
    async fn test_consumer_creation() {
        let consumer = setup_consumer().await;
        assert_eq!(consumer.group_id, "test_group");
        assert!(consumer.stream.is_some());
    }

    #[tokio::test]
    #[traced_test]
    async fn test_consume_messages() {
        let mut consumer = setup_consumer().await;
        sleep(SETUP_DELAY).await;
        let messages_result = timeout(TEST_TIMEOUT, consumer.consume()).await;
        assert!(messages_result.is_ok(), "Consume operation timed out");
        let messages = messages_result.unwrap();
        assert!(
            messages.is_ok(),
            "Failed to consume messages: {:?}",
            messages.err()
        );
    }

    #[tokio::test]
    #[traced_test]
    async fn test_consumer_lifecycle() {
        let mut consumer = setup_consumer().await;
        sleep(SETUP_DELAY).await;
        let consume_result = timeout(TEST_TIMEOUT, consumer.consume()).await;
        assert!(consume_result.is_ok(), "Consume operation timed out");
        assert!(
            consume_result.unwrap().is_ok(),
            "Failed to consume messages"
        );
        let commit_result = timeout(TEST_TIMEOUT, consumer.commit()).await;
        assert!(commit_result.is_ok(), "Commit operation timed out");
        assert!(commit_result.unwrap().is_ok(), "Failed to commit offset");
        let close_result = timeout(TEST_TIMEOUT, consumer.close()).await;
        assert!(close_result.is_ok(), "Close operation timed out");
        assert!(close_result.unwrap().is_ok(), "Failed to close consumer");
    }
}