use std::collections::HashMap;
/// Decodes Kafka request headers and encodes responses for a fixed set of APIs.
///
/// The only state is the table of API version ranges this handler advertises.
#[derive(Debug)]
pub struct KafkaProtocolHandler {
    /// Supported `[min_version, max_version]` range, keyed by Kafka API key.
    api_versions: HashMap<i16, ApiVersion>,
}
impl KafkaProtocolHandler {
    /// Creates a handler whose version table advertises a fixed set of API keys.
    ///
    /// Every advertised range starts at version 0.
    pub fn new() -> Self {
        // (api_key, min_version, max_version)
        const SUPPORTED_APIS: [(i16, i16, i16); 11] = [
            (0, 0, 12),
            (1, 0, 16),
            (3, 0, 12),
            (9, 0, 5),
            (15, 0, 9),
            (16, 0, 9),
            (18, 0, 4),
            (19, 0, 7),
            (20, 0, 6),
            (32, 0, 4),
            (49, 0, 4),
        ];

        let api_versions: HashMap<_, _> = SUPPORTED_APIS
            .iter()
            .map(|&(api_key, min_version, max_version)| {
                (api_key, ApiVersion { min_version, max_version })
            })
            .collect();

        Self { api_versions }
    }
}
impl Default for KafkaProtocolHandler {
fn default() -> Self {
Self::new()
}
}
impl KafkaProtocolHandler {
    /// Decodes the request header at the start of a Kafka request frame.
    ///
    /// `data` is read as:
    /// - a 4-byte prefix, which is skipped and not validated;
    /// - big-endian `api_key: i16`, `api_version: i16` and `correlation_id: i32`;
    /// - a client ID: an `i16` byte length followed by that many UTF-8 bytes.
    ///
    /// Bytes after the client ID are ignored. A non-positive client ID length yields an empty
    /// client ID; Kafka encodes a null client ID with length -1. API keys without a dedicated
    /// variant are reported as [`KafkaRequestType::ApiVersions`].
    ///
    /// # Errors
    ///
    /// Returns an error if any of the following holds:
    /// - `data` is shorter than 12 bytes (header);
    /// - `data` is shorter than 14 bytes (client ID length);
    /// - `data` is shorter than the declared client ID;
    /// - the client ID is not valid UTF-8.
    pub fn parse_request(&self, data: &[u8]) -> Result<KafkaRequest> {
        const CLIENT_ID_START: usize = 14;

        if data.len() < 12 {
            return Err(anyhow::anyhow!("Message too short for header"));
        }
        let api_key = i16::from_be_bytes([data[4], data[5]]);
        let api_version = i16::from_be_bytes([data[6], data[7]]);
        let correlation_id = i32::from_be_bytes([data[8], data[9], data[10], data[11]]);

        if data.len() < CLIENT_ID_START {
            return Err(anyhow::anyhow!("Message too short for client ID length"));
        }
        let client_id_len = i16::from_be_bytes([data[12], data[13]]);

        // A negative length must never reach the `as usize` cast. It would sign-extend to a
        // huge offset and overflow the end-offset addition, which panics in debug builds.
        let client_id = if client_id_len <= 0 {
            String::new()
        } else {
            // `client_id_len` is in 1..=i16::MAX here, so the cast and the addition are exact.
            let client_id_end = CLIENT_ID_START + client_id_len as usize;
            let bytes = data
                .get(CLIENT_ID_START..client_id_end)
                .ok_or_else(|| anyhow::anyhow!("Message too short for client ID"))?;
            std::str::from_utf8(bytes)
                .map_err(|e| anyhow::anyhow!("Invalid client ID encoding: {}", e))?
                .to_owned()
        };

        // NOTE(review): upstream Kafka assigns API key 9 to OffsetFetch and key 16 to
        // ListGroups. This mapping, and the tests, treat 9 as ListGroups. Key 16 is in the
        // advertised version table but falls through to ApiVersions. Confirm intent.
        let request_type = match api_key {
            0 => KafkaRequestType::Produce,
            1 => KafkaRequestType::Fetch,
            3 => KafkaRequestType::Metadata,
            9 => KafkaRequestType::ListGroups,
            15 => KafkaRequestType::DescribeGroups,
            18 => KafkaRequestType::ApiVersions,
            19 => KafkaRequestType::CreateTopics,
            20 => KafkaRequestType::DeleteTopics,
            32 => KafkaRequestType::DescribeConfigs,
            // Unrecognised keys are deliberately reported as ApiVersions.
            _ => KafkaRequestType::ApiVersions,
        };

        Ok(KafkaRequest {
            api_key,
            api_version,
            correlation_id,
            client_id,
            request_type,
        })
    }

    /// Encodes `response` as bytes, starting with `correlation_id` as a big-endian `i32`.
    ///
    /// No size prefix is written. The body that follows is fixed per variant:
    /// - `ApiVersions`: an `i16` error code of 0, then an `i32` entry count, then one
    ///   `(api_key, min_version, max_version)` `i16` triple per advertised API sorted by key,
    ///   then a trailing `i32` 0 (presumably throttle_time_ms; confirm).
    /// - `CreateTopics`: an `i32` 0, then a one-element array holding the topic name
    ///   `"default-topic"`, an `i16` error code of 0 and an `i16` -1 (presumably a null
    ///   error message; confirm).
    /// - every other variant: only an `i16` error code of 0.
    ///
    /// # Errors
    ///
    /// Never returns `Err` at present. The `Result` return leaves room for fallible
    /// encodings.
    pub fn serialize_response(
        &self,
        response: &KafkaResponse,
        correlation_id: i32,
    ) -> Result<Vec<u8>> {
        /// Appends a Kafka STRING: a big-endian `i16` byte length followed by the bytes.
        fn push_kafka_string(buf: &mut Vec<u8>, value: &str) {
            // Lengths above i16::MAX would be truncated. Only short literals are passed here.
            buf.extend_from_slice(&(value.len() as i16).to_be_bytes());
            buf.extend_from_slice(value.as_bytes());
        }

        let mut data = Vec::new();
        data.extend_from_slice(&correlation_id.to_be_bytes());

        match response {
            KafkaResponse::ApiVersions => {
                let mut entries: Vec<_> = self.api_versions.iter().collect();
                // HashMap iteration order is unspecified; sort so the output is deterministic.
                entries.sort_unstable_by_key(|(api_key, _)| **api_key);
                data.reserve(2 + 4 + 6 * entries.len() + 4);
                data.extend_from_slice(&0i16.to_be_bytes());
                data.extend_from_slice(&(entries.len() as i32).to_be_bytes());
                for (api_key, version) in entries {
                    data.extend_from_slice(&api_key.to_be_bytes());
                    data.extend_from_slice(&version.min_version.to_be_bytes());
                    data.extend_from_slice(&version.max_version.to_be_bytes());
                }
                data.extend_from_slice(&0i32.to_be_bytes());
            }
            KafkaResponse::CreateTopics => {
                data.extend_from_slice(&0i32.to_be_bytes());
                data.extend_from_slice(&1i32.to_be_bytes());
                push_kafka_string(&mut data, "default-topic");
                data.extend_from_slice(&0i16.to_be_bytes());
                data.extend_from_slice(&(-1i16).to_be_bytes());
            }
            // Listed explicitly so that adding a variant forces a decision here.
            KafkaResponse::Metadata
            | KafkaResponse::Produce
            | KafkaResponse::Fetch
            | KafkaResponse::ListGroups
            | KafkaResponse::DescribeGroups
            | KafkaResponse::DeleteTopics
            | KafkaResponse::DescribeConfigs => {
                data.extend_from_slice(&0i16.to_be_bytes());
            }
        }

        Ok(data)
    }

    /// Reports whether `version` lies within the inclusive range advertised for `api_key`.
    ///
    /// Returns `false` for API keys absent from the version table.
    pub fn is_api_version_supported(&self, api_key: i16, version: i16) -> bool {
        match self.api_versions.get(&api_key) {
            Some(range) => (range.min_version..=range.max_version).contains(&version),
            None => false,
        }
    }
}
/// A request header decoded by `KafkaProtocolHandler::parse_request`.
#[derive(Debug)]
pub struct KafkaRequest {
    /// API key as read from the header (bytes 4..6, big-endian).
    pub api_key: i16,
    /// Requested API version (bytes 6..8, big-endian).
    pub api_version: i16,
    /// Correlation id from the header (bytes 8..12, big-endian).
    pub correlation_id: i32,
    /// Client identifier decoded as UTF-8; empty for a zero-length client ID.
    pub client_id: String,
    /// `api_key` mapped onto a known request kind.
    pub request_type: KafkaRequestType,
}
/// Request kinds recognised by `KafkaProtocolHandler::parse_request`.
///
/// This is a fieldless enum, so it is cheap to copy, compare and hash. Callers can use `==`
/// or use it as a map key instead of relying only on `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KafkaRequestType {
    Metadata,
    Produce,
    Fetch,
    ListGroups,
    DescribeGroups,
    /// Also reported by `parse_request` for API keys it does not recognise.
    ApiVersions,
    CreateTopics,
    DeleteTopics,
    DescribeConfigs,
}
/// Response kinds that `KafkaProtocolHandler::serialize_response` can encode.
///
/// The variants carry no payload; the encoded body is fixed per variant. Being fieldless,
/// the enum derives the cheap common traits so callers can copy, compare and hash it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KafkaResponse {
    Metadata,
    Produce,
    Fetch,
    ListGroups,
    DescribeGroups,
    ApiVersions,
    CreateTopics,
    DeleteTopics,
    DescribeConfigs,
}
/// Inclusive range of versions supported for one Kafka API key.
///
/// Two `i16`s, so it is `Copy`; `PartialEq`/`Eq` allow comparing whole ranges directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct ApiVersion {
    /// Lowest supported version (inclusive).
    min_version: i16,
    /// Highest supported version (inclusive).
    max_version: i16,
}
type Result<T> = std::result::Result<T, anyhow::Error>;
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a frame in the layout `parse_request` reads.
    ///
    /// The layout is a zeroed 4-byte prefix, then big-endian api_key, api_version,
    /// correlation_id, an i16 client ID length and the client ID bytes.
    fn build_request(
        api_key: i16,
        api_version: i16,
        correlation_id: i32,
        client_id: &[u8],
    ) -> Vec<u8> {
        let mut data = vec![0u8; 4];
        data.extend_from_slice(&api_key.to_be_bytes());
        data.extend_from_slice(&api_version.to_be_bytes());
        data.extend_from_slice(&correlation_id.to_be_bytes());
        // Test client IDs are far shorter than i16::MAX bytes, so the cast is exact.
        data.extend_from_slice(&(client_id.len() as i16).to_be_bytes());
        data.extend_from_slice(client_id);
        data
    }

    /// Parses a header-only request for `api_key` and returns the decoded request type.
    fn request_type_for(api_key: i16) -> KafkaRequestType {
        KafkaProtocolHandler::new()
            .parse_request(&build_request(api_key, 0, 0, b""))
            .expect("a 14-byte header with an empty client ID should parse")
            .request_type
    }

    #[test]
    fn test_protocol_handler_new() {
        let handler = KafkaProtocolHandler::new();
        assert!(!handler.api_versions.is_empty());
        assert!(handler.api_versions.contains_key(&0));
        assert!(handler.api_versions.contains_key(&1));
        assert!(handler.api_versions.contains_key(&18));
    }

    #[test]
    fn test_protocol_handler_default() {
        let handler = KafkaProtocolHandler::default();
        assert!(!handler.api_versions.is_empty());
    }

    #[test]
    fn test_is_api_version_supported_produce() {
        let handler = KafkaProtocolHandler::new();
        assert!(handler.is_api_version_supported(0, 0));
        assert!(handler.is_api_version_supported(0, 12));
        assert!(!handler.is_api_version_supported(0, 13));
        assert!(!handler.is_api_version_supported(0, -1));
    }

    #[test]
    fn test_is_api_version_supported_fetch() {
        let handler = KafkaProtocolHandler::new();
        assert!(handler.is_api_version_supported(1, 0));
        assert!(handler.is_api_version_supported(1, 16));
        assert!(!handler.is_api_version_supported(1, 17));
    }

    #[test]
    fn test_is_api_version_supported_metadata() {
        let handler = KafkaProtocolHandler::new();
        assert!(handler.is_api_version_supported(3, 0));
        assert!(handler.is_api_version_supported(3, 12));
        assert!(!handler.is_api_version_supported(3, 13));
    }

    #[test]
    fn test_is_api_version_supported_api_versions() {
        let handler = KafkaProtocolHandler::new();
        assert!(handler.is_api_version_supported(18, 0));
        assert!(handler.is_api_version_supported(18, 4));
        assert!(!handler.is_api_version_supported(18, 5));
    }

    #[test]
    fn test_is_api_version_unsupported_api_key() {
        let handler = KafkaProtocolHandler::new();
        assert!(!handler.is_api_version_supported(999, 0));
        assert!(!handler.is_api_version_supported(-1, 0));
    }

    #[test]
    fn test_parse_request_too_short() {
        let handler = KafkaProtocolHandler::new();
        let data = [0u8; 5];
        assert!(handler.parse_request(&data).is_err());
    }

    #[test]
    fn test_parse_request_minimal_header() {
        let handler = KafkaProtocolHandler::new();
        let data = build_request(18, 0, 1, b"");
        assert_eq!(data.len(), 14);
        let request = handler
            .parse_request(&data)
            .expect("minimal header should parse");
        assert_eq!(request.api_key, 18);
        assert_eq!(request.api_version, 0);
        assert_eq!(request.correlation_id, 1);
        assert_eq!(request.client_id, "");
    }

    #[test]
    fn test_parse_request_with_client_id() {
        let handler = KafkaProtocolHandler::new();
        let data = build_request(0, 7, 42, b"test-client");
        let request = handler
            .parse_request(&data)
            .expect("header with client ID should parse");
        assert_eq!(request.api_key, 0);
        assert_eq!(request.api_version, 7);
        assert_eq!(request.correlation_id, 42);
        assert_eq!(request.client_id, "test-client");
        assert!(matches!(request.request_type, KafkaRequestType::Produce));
    }

    #[test]
    fn test_parse_request_produce() {
        assert!(matches!(request_type_for(0), KafkaRequestType::Produce));
    }

    #[test]
    fn test_parse_request_fetch() {
        assert!(matches!(request_type_for(1), KafkaRequestType::Fetch));
    }

    #[test]
    fn test_parse_request_metadata() {
        assert!(matches!(request_type_for(3), KafkaRequestType::Metadata));
    }

    #[test]
    fn test_parse_request_list_groups() {
        assert!(matches!(request_type_for(9), KafkaRequestType::ListGroups));
    }

    #[test]
    fn test_parse_request_describe_groups() {
        assert!(matches!(
            request_type_for(15),
            KafkaRequestType::DescribeGroups
        ));
    }

    #[test]
    fn test_parse_request_api_versions() {
        assert!(matches!(request_type_for(18), KafkaRequestType::ApiVersions));
    }

    #[test]
    fn test_parse_request_create_topics() {
        assert!(matches!(request_type_for(19), KafkaRequestType::CreateTopics));
    }

    #[test]
    fn test_parse_request_delete_topics() {
        assert!(matches!(request_type_for(20), KafkaRequestType::DeleteTopics));
    }

    #[test]
    fn test_parse_request_describe_configs() {
        assert!(matches!(
            request_type_for(32),
            KafkaRequestType::DescribeConfigs
        ));
    }

    #[test]
    fn test_parse_request_unsupported_api_defaults_to_api_versions() {
        assert!(matches!(request_type_for(99), KafkaRequestType::ApiVersions));
    }

    #[test]
    fn test_parse_request_invalid_client_id_length() {
        let handler = KafkaProtocolHandler::new();
        let mut data = build_request(18, 0, 0, b"");
        // Declare a 100-byte client ID but supply none of it.
        data[13] = 100;
        assert!(handler.parse_request(&data).is_err());
    }

    #[test]
    fn test_parse_request_missing_client_id_length() {
        let handler = KafkaProtocolHandler::new();
        let data = [0u8; 12];
        assert!(handler.parse_request(&data).is_err());
    }

    #[test]
    fn test_parse_request_max_values() {
        let handler = KafkaProtocolHandler::new();
        let data = build_request(i16::MAX, i16::MAX, i32::MAX, b"");
        let request = handler
            .parse_request(&data)
            .expect("maximal header values should parse");
        assert_eq!(request.api_key, 0x7FFF);
        assert_eq!(request.api_version, 0x7FFF);
        assert_eq!(request.correlation_id, 0x7FFF_FFFF);
    }

    #[test]
    fn test_serialize_response_api_versions() {
        let handler = KafkaProtocolHandler::new();
        let correlation_id = 12345;
        let data = handler
            .serialize_response(&KafkaResponse::ApiVersions, correlation_id)
            .expect("ApiVersions should serialize");
        assert!(!data.is_empty());
        let corr_id = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
        assert_eq!(corr_id, correlation_id);
        let error_code = i16::from_be_bytes([data[4], data[5]]);
        assert_eq!(error_code, 0);
        let entry_count = handler.api_versions.len();
        let count = i32::from_be_bytes([data[6], data[7], data[8], data[9]]);
        assert_eq!(count as usize, entry_count);
        // Frame layout: correlation_id, error_code, count, one (key, min, max) triple per
        // API, then a trailing i32.
        assert_eq!(data.len(), 4 + 2 + 4 + 6 * entry_count + 4);
        // Entries are sorted by API key, so the lowest key (0) comes first.
        assert_eq!(i16::from_be_bytes([data[10], data[11]]), 0);
    }

    #[test]
    fn test_serialize_response_create_topics() {
        let handler = KafkaProtocolHandler::new();
        let correlation_id = 999;
        let data = handler
            .serialize_response(&KafkaResponse::CreateTopics, correlation_id)
            .expect("CreateTopics should serialize");
        assert!(!data.is_empty());
        let corr_id = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
        assert_eq!(corr_id, correlation_id);
        // Frame layout: correlation_id, i32 0, array count, i16 + "default-topic",
        // error_code, i16 -1.
        assert_eq!(data.len(), 4 + 4 + 4 + 2 + "default-topic".len() + 2 + 2);
    }

    #[test]
    fn test_serialize_response_metadata() {
        let handler = KafkaProtocolHandler::new();
        let data = handler
            .serialize_response(&KafkaResponse::Metadata, 1)
            .expect("Metadata should serialize");
        assert!(data.len() >= 6);
    }

    #[test]
    fn test_serialize_response_produce() {
        let handler = KafkaProtocolHandler::new();
        assert!(handler.serialize_response(&KafkaResponse::Produce, 42).is_ok());
    }

    #[test]
    fn test_serialize_response_fetch() {
        let handler = KafkaProtocolHandler::new();
        assert!(handler.serialize_response(&KafkaResponse::Fetch, 100).is_ok());
    }

    #[test]
    fn test_serialize_response_list_groups() {
        let handler = KafkaProtocolHandler::new();
        assert!(handler
            .serialize_response(&KafkaResponse::ListGroups, 200)
            .is_ok());
    }

    #[test]
    fn test_serialize_response_describe_groups() {
        let handler = KafkaProtocolHandler::new();
        assert!(handler
            .serialize_response(&KafkaResponse::DescribeGroups, 300)
            .is_ok());
    }

    #[test]
    fn test_serialize_response_delete_topics() {
        let handler = KafkaProtocolHandler::new();
        assert!(handler
            .serialize_response(&KafkaResponse::DeleteTopics, 400)
            .is_ok());
    }

    #[test]
    fn test_serialize_response_describe_configs() {
        let handler = KafkaProtocolHandler::new();
        assert!(handler
            .serialize_response(&KafkaResponse::DescribeConfigs, 500)
            .is_ok());
    }

    #[test]
    fn test_serialize_response_negative_correlation_id() {
        let handler = KafkaProtocolHandler::new();
        let data = handler
            .serialize_response(&KafkaResponse::ApiVersions, -1)
            .expect("ApiVersions should serialize");
        let corr_id = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
        assert_eq!(corr_id, -1);
    }

    #[test]
    fn test_serialize_response_zero_correlation_id() {
        let handler = KafkaProtocolHandler::new();
        let data = handler
            .serialize_response(&KafkaResponse::ApiVersions, 0)
            .expect("ApiVersions should serialize");
        let corr_id = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
        assert_eq!(corr_id, 0);
    }

    #[test]
    fn test_kafka_request_debug() {
        let request = KafkaRequest {
            api_key: 0,
            api_version: 7,
            correlation_id: 123,
            client_id: "test".to_string(),
            request_type: KafkaRequestType::Produce,
        };
        let debug_str = format!("{:?}", request);
        assert!(debug_str.contains("KafkaRequest"));
        assert!(debug_str.contains("api_key"));
    }

    #[test]
    fn test_kafka_request_type_debug() {
        let debug_str = format!("{:?}", KafkaRequestType::Metadata);
        assert!(debug_str.contains("Metadata"));
    }

    #[test]
    fn test_kafka_response_debug() {
        let debug_str = format!("{:?}", KafkaResponse::Produce);
        assert!(debug_str.contains("Produce"));
    }

    #[test]
    fn test_api_version_ranges_complete() {
        let handler = KafkaProtocolHandler::new();
        // (api_key, min_version, max_version)
        let api_configs = [
            (0, 0, 12),
            (1, 0, 16),
            (3, 0, 12),
            (9, 0, 5),
            (15, 0, 9),
            (16, 0, 9),
            (18, 0, 4),
            (19, 0, 7),
            (20, 0, 6),
            (32, 0, 4),
            (49, 0, 4),
        ];
        for &(api_key, min_ver, max_ver) in api_configs.iter() {
            assert!(handler.is_api_version_supported(api_key, min_ver));
            assert!(handler.is_api_version_supported(api_key, max_ver));
            assert!(!handler.is_api_version_supported(api_key, max_ver + 1));
            if min_ver > 0 {
                assert!(!handler.is_api_version_supported(api_key, min_ver - 1));
            }
        }
    }

    #[test]
    fn test_parse_request_large_client_id() {
        let handler = KafkaProtocolHandler::new();
        let client_id = "a".repeat(1000);
        let data = build_request(18, 0, 0, client_id.as_bytes());
        let request = handler
            .parse_request(&data)
            .expect("long client ID should parse");
        assert_eq!(request.client_id, client_id);
    }

    #[test]
    fn test_parse_request_invalid_utf8_client_id() {
        let handler = KafkaProtocolHandler::new();
        let data = build_request(18, 0, 0, &[0xFF, 0xFF, 0xFF]);
        assert!(handler.parse_request(&data).is_err());
    }

    #[test]
    fn test_api_version_struct() {
        let api_version = ApiVersion {
            min_version: 0,
            max_version: 10,
        };
        assert_eq!(api_version.min_version, 0);
        assert_eq!(api_version.max_version, 10);
    }
}