use crate::error::{ClusterError, Result};
use crate::metadata::MetadataCommand;
use crate::node::NodeId;
use crate::partition::PartitionId;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Protocol version this build speaks; `RequestHeader::new` stamps it on every request.
pub const PROTOCOL_VERSION: u16 = 1;
/// Oldest protocol version that `RequestHeader::validate_version` still accepts.
pub const MIN_PROTOCOL_VERSION: u16 = 1;
/// Largest encoded message, in bytes, that `decode_request` / `decode_response`
/// will attempt to decode (16 MiB; a message of exactly this size is accepted).
pub const MAX_MESSAGE_SIZE: usize = 16 << 20;
/// Numeric codes for transport-level failures, placed in
/// `ResponseHeader::error_code` (e.g. by `RequestHeader::validate_version`).
///
/// NOTE(review): these values (1, 2, 3) overlap the discriminants of
/// `ErrorCode::Unknown`, `ErrorCode::CorruptMessage` and `ErrorCode::UnknownTopic`,
/// and both numberings can end up in the same `u16` field. A receiver decoding
/// `error_code` as `ErrorCode` would misread these — confirm which numbering is
/// authoritative before relying on either.
pub mod error_codes {
    /// The request's `version` lies outside `[MIN_PROTOCOL_VERSION, PROTOCOL_VERSION]`.
    pub const UNSUPPORTED_VERSION: u16 = 0x0001;
    /// Presumably: an encoded message exceeded `MAX_MESSAGE_SIZE` (usage not visible here).
    pub const MESSAGE_TOO_LARGE: u16 = 0x0002;
    /// Presumably: the request kind was not recognised (usage not visible here).
    pub const UNKNOWN_REQUEST: u16 = 0x0003;
}
/// Header carried by every [`ClusterRequest`] variant.
///
/// Field order is part of the wire format: requests are encoded with `postcard`
/// (see `encode_request`), which is not self-describing and writes fields
/// positionally, so reordering or inserting fields breaks compatibility with peers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestHeader {
/// Protocol version of the sender; checked by `validate_version`.
pub version: u16,
/// Request identifier; `validate_version` copies it into the error
/// `ResponseHeader` (presumably used to match responses to requests).
pub correlation_id: u64,
/// Node identified as the sender of this request.
pub source: NodeId,
/// Request timeout in milliseconds; `new` defaults it to 30000 and
/// `with_timeout` sets it from a `Duration`.
pub timeout_ms: u32,
}
impl RequestHeader {
    /// Creates a header for `source`, stamped with [`PROTOCOL_VERSION`] and a
    /// default timeout of 30 000 ms.
    pub fn new(correlation_id: u64, source: NodeId) -> Self {
        Self {
            version: PROTOCOL_VERSION,
            correlation_id,
            source,
            timeout_ms: 30000,
        }
    }

    /// Returns the header with `timeout_ms` set from `timeout`.
    ///
    /// Sub-millisecond precision is truncated. Durations longer than `u32::MAX`
    /// milliseconds (about 49.7 days) saturate to `u32::MAX`.
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        // A plain `as u32` cast would wrap a very long timeout around to an
        // arbitrary (possibly tiny) value; clamp first so it saturates instead.
        self.timeout_ms = timeout.as_millis().min(u128::from(u32::MAX)) as u32;
        self
    }

    /// Checks that the sender's `version` lies in
    /// `MIN_PROTOCOL_VERSION..=PROTOCOL_VERSION`.
    ///
    /// # Errors
    ///
    /// Returns a ready-made error [`ResponseHeader`] carrying this request's
    /// `correlation_id` and [`error_codes::UNSUPPORTED_VERSION`] when the version
    /// is out of range.
    pub fn validate_version(&self) -> std::result::Result<(), ResponseHeader> {
        if (MIN_PROTOCOL_VERSION..=PROTOCOL_VERSION).contains(&self.version) {
            return Ok(());
        }
        Err(ResponseHeader::error(
            self.correlation_id,
            error_codes::UNSUPPORTED_VERSION,
            format!(
                "unsupported protocol version {}: supported range [{}, {}]",
                self.version, MIN_PROTOCOL_VERSION, PROTOCOL_VERSION
            ),
        ))
    }
}
/// Header carried by every [`ClusterResponse`] variant.
///
/// Field order is part of the postcard wire format; do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseHeader {
/// Correlation id of the request being answered (`RequestHeader::validate_version`
/// copies it from the request header).
pub correlation_id: u64,
/// `0` on success (see `is_success`); any other value is an error code.
pub error_code: u16,
/// Human-readable detail: `Some` when built by `ResponseHeader::error`,
/// `None` when built by `ResponseHeader::success`.
pub error_message: Option<String>,
}
impl ResponseHeader {
    /// Builds the header for a successful reply: `error_code` is `0` and there is
    /// no `error_message`.
    pub fn success(correlation_id: u64) -> Self {
        let error_message = None;
        Self {
            error_code: 0,
            error_message,
            correlation_id,
        }
    }

    /// Builds the header for a failed reply carrying `code` and `message`.
    ///
    /// Callers should pass a non-zero `code`; `0` would make
    /// [`Self::is_success`] report success.
    pub fn error(correlation_id: u64, code: u16, message: impl Into<String>) -> Self {
        let text: String = message.into();
        Self {
            error_code: code,
            error_message: Some(text),
            correlation_id,
        }
    }

    /// Reports whether this header signals success (`error_code == 0`).
    pub fn is_success(&self) -> bool {
        matches!(self.error_code, 0)
    }
}
/// Requests exchanged between cluster nodes; serialized with `postcard` by
/// `encode_request` / `decode_request`.
///
/// Every variant carries a [`RequestHeader`]. Variant and field order are part
/// of the wire format (postcard encodes the variant index and fields
/// positionally), so append new variants rather than inserting or reordering.
#[derive(Debug, Clone, Serialize, Deserialize)]
// NOTE(review): `large_enum_variant` is allowed without a stated reason; the size
// of `MetadataCommand` is not visible in this file. If one variant dwarfs the
// others, boxing it (serde encodes `Box<T>` exactly like `T`) may be preferable — confirm.
#[allow(clippy::large_enum_variant)] pub enum ClusterRequest {
/// Ask for cluster metadata, optionally restricted to the listed topics.
/// NOTE(review): `topics: None` presumably means "all topics" — confirm with the handler.
FetchMetadata {
header: RequestHeader,
/// Topic names to describe (see the variant note for `None`).
topics: Option<Vec<String>>, },
/// Propose a [`MetadataCommand`] to be applied to cluster metadata
/// (presumably handled by the controller — confirm).
ProposeMetadata {
header: RequestHeader,
command: MetadataCommand,
},
/// Read record data from `partition` starting at `offset`.
Fetch {
header: RequestHeader,
partition: PartitionId,
/// First offset to read.
offset: u64,
/// Upper bound on the returned record data (presumably in bytes, per the name — confirm).
max_bytes: u32,
},
/// Append an encoded record batch to `partition`; `required_acks` selects the
/// acknowledgement level the sender asks for.
Append {
header: RequestHeader,
partition: PartitionId,
/// Opaque encoded records (encoding not visible here); followed on the same line by `required_acks`.
records: Vec<u8>, required_acks: Acks,
},
/// Report a replica's log end offset and high watermark for `partition`
/// (presumably sent from a follower to the leader — confirm).
ReplicaState {
header: RequestHeader,
partition: PartitionId,
log_end_offset: u64,
high_watermark: u64,
},
/// Ask for a leader election on `partition`, optionally naming a preferred leader.
ElectLeader {
header: RequestHeader,
partition: PartitionId,
preferred_leader: Option<NodeId>,
},
/// Heartbeat carrying per-partition leader epoch and high watermark.
Heartbeat {
header: RequestHeader,
partitions: Vec<HeartbeatPartition>,
},
}
/// Per-partition entry inside a `ClusterRequest::Heartbeat`.
///
/// Field order is part of the postcard wire format; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HeartbeatPartition {
/// Partition this entry describes.
pub partition: PartitionId,
/// Leader epoch the sender associates with the partition.
pub leader_epoch: u64,
/// High watermark the sender reports for the partition.
pub high_watermark: u64,
}
/// Responses exchanged between cluster nodes; serialized with `postcard` by
/// `encode_response` / `decode_response`.
///
/// Variants appear to pair with [`ClusterRequest`] variants by name (e.g. `Fetch`
/// answers `Fetch`; `ReplicaStateAck` presumably answers `ReplicaState`) — confirm
/// at the dispatch site. Variant and field order are part of the wire format, so
/// append new variants rather than inserting or reordering.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ClusterResponse {
/// Cluster metadata: cluster id, current controller (if any), topics and brokers.
Metadata {
header: ResponseHeader,
cluster_id: String,
/// Controller node, if one is known.
controller_id: Option<NodeId>,
topics: Vec<TopicMetadata>,
brokers: Vec<BrokerMetadata>,
},
/// Outcome of a metadata proposal; only the header is returned.
MetadataProposal { header: ResponseHeader },
/// Record data read from `partition` plus the partition's offset bounds.
Fetch {
header: ResponseHeader,
partition: PartitionId,
high_watermark: u64,
log_start_offset: u64,
/// Opaque encoded records (encoding not visible here).
records: Vec<u8>, },
/// Result of an append: offset assigned to the first appended record
/// (presumably) and the append timestamp.
Append {
header: ResponseHeader,
partition: PartitionId,
base_offset: u64,
/// Append timestamp (NOTE(review): unit/epoch not visible here — presumably ms since Unix epoch; confirm).
log_append_time: i64,
},
/// Acknowledges a replica-state report and says whether the replica is in sync.
ReplicaStateAck {
header: ResponseHeader,
partition: PartitionId,
in_sync: bool,
},
/// Result of a leader election: the elected leader (if any) and its epoch.
ElectLeader {
header: ResponseHeader,
partition: PartitionId,
leader: Option<NodeId>,
epoch: u64,
},
/// Heartbeat reply; only the header is returned.
Heartbeat { header: ResponseHeader },
/// Generic error reply carrying only a header (presumably for requests that
/// could not be answered with a typed response — confirm).
Error { header: ResponseHeader },
}
/// Description of one topic inside `ClusterResponse::Metadata`.
///
/// Field order is part of the postcard wire format; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TopicMetadata {
/// Topic name.
pub name: String,
/// Per-partition details for this topic.
pub partitions: Vec<PartitionMetadata>,
/// Whether the topic is flagged as internal (meaning of "internal" not defined here).
pub is_internal: bool,
}
/// Description of one partition inside a [`TopicMetadata`].
///
/// Field order is part of the postcard wire format; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PartitionMetadata {
/// Index of the partition within its topic.
pub partition_index: u32,
/// Current leader, or `None` when no leader is known.
pub leader_id: Option<NodeId>,
/// Epoch of the current leadership.
pub leader_epoch: u64,
/// All nodes assigned a replica of this partition.
pub replica_nodes: Vec<NodeId>,
/// Replicas currently in sync (presumably a subset of `replica_nodes` — confirm).
pub isr_nodes: Vec<NodeId>,
/// Assigned replicas that are currently offline.
pub offline_replicas: Vec<NodeId>,
}
/// Network location of one broker inside `ClusterResponse::Metadata`.
///
/// Field order is part of the postcard wire format; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BrokerMetadata {
/// Broker's node id.
pub node_id: NodeId,
/// Host name or address the broker can be reached at.
pub host: String,
/// Port the broker can be reached at.
pub port: u16,
/// Optional rack label (presumably used for rack-aware placement — confirm).
pub rack: Option<String>,
}
/// Acknowledgement level requested by `ClusterRequest::Append`.
///
/// Integer form via `Acks::from_i8` / `Acks::to_i8`: `None` = 0, `Leader` = 1,
/// `All` = -1. Independently of that, the serde/postcard encoding uses the
/// variant index, so do not reorder variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum Acks {
/// No acknowledgement requested.
None,
/// Default level; presumably acknowledged by the partition leader alone — confirm.
#[default]
Leader,
/// Presumably acknowledged only once all in-sync replicas have the data — confirm.
All,
}
impl Acks {
    /// Strictly decodes the integer form: `0` → `None`, `1` → `Leader`, `-1` → `All`.
    ///
    /// Returns `Option::None` for any other value, so callers can reject it
    /// (e.g. with `ErrorCode::InvalidRequiredAcks`) instead of silently guessing.
    pub fn try_from_i8(v: i8) -> Option<Self> {
        match v {
            0 => Some(Acks::None),
            1 => Some(Acks::Leader),
            -1 => Some(Acks::All),
            _ => Option::None,
        }
    }

    /// Leniently decodes the integer form, mapping any unrecognised value to
    /// `Acks::Leader` (the same variant as `Acks::default()`).
    ///
    /// Prefer [`Acks::try_from_i8`] when an invalid value should be reported
    /// rather than coerced.
    pub fn from_i8(v: i8) -> Self {
        Self::try_from_i8(v).unwrap_or(Acks::Leader)
    }

    /// Encodes to the integer form; the inverse of [`Acks::try_from_i8`].
    pub fn to_i8(self) -> i8 {
        match self {
            Acks::None => 0,
            Acks::Leader => 1,
            Acks::All => -1,
        }
    }
}
/// Typed error codes with explicit `u16` discriminants (`#[repr(u16)]`).
///
/// Use [`ErrorCode::as_u16`] / [`ErrorCode::from_u16`] to convert to and from a
/// raw `u16` code, such as the one carried in `ResponseHeader::error_code`.
///
/// NOTE(review): the constants in `error_codes` reuse the values 1–3, which here
/// mean `Unknown`, `CorruptMessage` and `UnknownTopic` — confirm which numbering a
/// given `error_code` uses before decoding it with `from_u16`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum ErrorCode {
    /// No error.
    None = 0,
    Unknown = 1,
    CorruptMessage = 2,
    UnknownTopic = 3,
    InvalidPartition = 4,
    LeaderNotAvailable = 5,
    NotLeaderForPartition = 6,
    RequestTimedOut = 7,
    NotEnoughReplicas = 8,
    NotEnoughReplicasAfterAppend = 9,
    InvalidRequiredAcks = 10,
    NotController = 11,
    InvalidRequest = 12,
    UnsupportedVersion = 13,
    TopicAlreadyExists = 14,
    InvalidReplicationFactor = 15,
    IneligibleReplica = 16,
    OffsetOutOfRange = 17,
    NotReplicaForPartition = 18,
    GroupAuthorizationFailed = 19,
    UnknownMemberId = 20,
}

impl ErrorCode {
    /// Returns the raw `u16` discriminant of this code.
    pub fn as_u16(self) -> u16 {
        self as u16
    }

    /// Decodes a raw `u16` code, returning `None` for values with no variant.
    ///
    /// Keep this match in sync with the discriminants above when adding variants.
    pub fn from_u16(code: u16) -> Option<Self> {
        let decoded = match code {
            0 => ErrorCode::None,
            1 => ErrorCode::Unknown,
            2 => ErrorCode::CorruptMessage,
            3 => ErrorCode::UnknownTopic,
            4 => ErrorCode::InvalidPartition,
            5 => ErrorCode::LeaderNotAvailable,
            6 => ErrorCode::NotLeaderForPartition,
            7 => ErrorCode::RequestTimedOut,
            8 => ErrorCode::NotEnoughReplicas,
            9 => ErrorCode::NotEnoughReplicasAfterAppend,
            10 => ErrorCode::InvalidRequiredAcks,
            11 => ErrorCode::NotController,
            12 => ErrorCode::InvalidRequest,
            13 => ErrorCode::UnsupportedVersion,
            14 => ErrorCode::TopicAlreadyExists,
            15 => ErrorCode::InvalidReplicationFactor,
            16 => ErrorCode::IneligibleReplica,
            17 => ErrorCode::OffsetOutOfRange,
            18 => ErrorCode::NotReplicaForPartition,
            19 => ErrorCode::GroupAuthorizationFailed,
            20 => ErrorCode::UnknownMemberId,
            _ => return None,
        };
        Some(decoded)
    }

    /// Returns `true` for the codes this module treats as transient:
    /// leader unavailable / not leader, request timeout, not enough replicas,
    /// and not controller. The request may be retried for these.
    pub fn is_retriable(self) -> bool {
        matches!(
            self,
            ErrorCode::LeaderNotAvailable
                | ErrorCode::NotLeaderForPartition
                | ErrorCode::RequestTimedOut
                | ErrorCode::NotEnoughReplicas
                | ErrorCode::NotController
        )
    }
}
/// Serializes `request` into a freshly allocated `postcard` byte buffer.
///
/// # Errors
///
/// Returns `ClusterError::Serialization` with postcard's error text if encoding fails.
pub fn encode_request(request: &ClusterRequest) -> Result<Vec<u8>> {
    let encoded = postcard::to_allocvec(request)
        .map_err(|err| ClusterError::Serialization(err.to_string()))?;
    Ok(encoded)
}
/// Decodes a `postcard`-encoded [`ClusterRequest`].
///
/// # Errors
///
/// - `ClusterError::MessageTooLarge` if `bytes` is longer than `MAX_MESSAGE_SIZE`
///   (exactly `MAX_MESSAGE_SIZE` bytes is accepted); decoding is not attempted.
/// - `ClusterError::Deserialization` with postcard's error text if decoding fails.
pub fn decode_request(bytes: &[u8]) -> Result<ClusterRequest> {
    let size = bytes.len();
    if size <= MAX_MESSAGE_SIZE {
        postcard::from_bytes(bytes).map_err(|err| ClusterError::Deserialization(err.to_string()))
    } else {
        Err(ClusterError::MessageTooLarge {
            size,
            max: MAX_MESSAGE_SIZE,
        })
    }
}
/// Serializes `response` into a freshly allocated `postcard` byte buffer.
///
/// # Errors
///
/// Returns `ClusterError::Serialization` with postcard's error text if encoding fails.
pub fn encode_response(response: &ClusterResponse) -> Result<Vec<u8>> {
    let encoded = postcard::to_allocvec(response)
        .map_err(|err| ClusterError::Serialization(err.to_string()))?;
    Ok(encoded)
}
/// Decodes a `postcard`-encoded [`ClusterResponse`].
///
/// # Errors
///
/// - `ClusterError::MessageTooLarge` if `bytes` is longer than `MAX_MESSAGE_SIZE`
///   (exactly `MAX_MESSAGE_SIZE` bytes is accepted); decoding is not attempted.
/// - `ClusterError::Deserialization` with postcard's error text if decoding fails.
pub fn decode_response(bytes: &[u8]) -> Result<ClusterResponse> {
    let size = bytes.len();
    if size <= MAX_MESSAGE_SIZE {
        postcard::from_bytes(bytes).map_err(|err| ClusterError::Deserialization(err.to_string()))
    } else {
        Err(ClusterError::MessageTooLarge {
            size,
            max: MAX_MESSAGE_SIZE,
        })
    }
}
/// Prefixes `data` with its length as a 4-byte big-endian `u32`.
///
/// The first four bytes of the result decode (via [`frame_length`]) to
/// `data.len()`, followed by `data` verbatim.
///
/// # Panics
///
/// Panics if `data` is longer than `u32::MAX` bytes, because that length cannot
/// be represented in the 4-byte prefix.
pub fn frame_message(data: &[u8]) -> Vec<u8> {
    // A silently truncated prefix would desynchronise the peer's framing and
    // corrupt every following message, so fail loudly instead.
    assert!(
        data.len() <= u32::MAX as usize,
        "frame payload of {} bytes does not fit in a 4-byte length prefix",
        data.len()
    );
    let len = data.len() as u32;
    let mut framed = Vec::with_capacity(4 + data.len());
    framed.extend_from_slice(&len.to_be_bytes());
    framed.extend_from_slice(data);
    framed
}
/// Decodes the 4-byte big-endian length prefix written by `frame_message`.
///
/// The value is not bounds-checked here; callers should compare it against
/// `MAX_MESSAGE_SIZE` before allocating a buffer of that size.
pub fn frame_length(header: &[u8; 4]) -> usize {
    let payload_len = u32::from_be_bytes(*header);
    payload_len as usize
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_request_roundtrip() {
        let header = RequestHeader::new(42, "node-1".to_string());
        let request = ClusterRequest::FetchMetadata {
            header,
            topics: Some(vec!["test-topic".to_string()]),
        };
        let bytes = encode_request(&request).unwrap();
        let decoded = decode_request(&bytes).unwrap();
        match decoded {
            ClusterRequest::FetchMetadata { header, topics } => {
                assert_eq!(header.correlation_id, 42);
                assert_eq!(topics, Some(vec!["test-topic".to_string()]));
            }
            _ => panic!("Wrong request type"),
        }
    }

    #[test]
    fn test_response_roundtrip() {
        let header = ResponseHeader::success(42);
        let response = ClusterResponse::Metadata {
            header,
            cluster_id: "test-cluster".to_string(),
            controller_id: Some("node-1".to_string()),
            topics: vec![],
            brokers: vec![],
        };
        let bytes = encode_response(&response).unwrap();
        let decoded = decode_response(&bytes).unwrap();
        match decoded {
            ClusterResponse::Metadata {
                header, cluster_id, ..
            } => {
                assert!(header.is_success());
                assert_eq!(cluster_id, "test-cluster");
            }
            _ => panic!("Wrong response type"),
        }
    }

    /// Both decoders must refuse input above `MAX_MESSAGE_SIZE` before decoding.
    #[test]
    fn test_decode_rejects_oversized_messages() {
        let oversized = vec![0u8; MAX_MESSAGE_SIZE + 1];
        assert!(matches!(
            decode_request(&oversized),
            Err(ClusterError::MessageTooLarge { size, max })
                if size == MAX_MESSAGE_SIZE + 1 && max == MAX_MESSAGE_SIZE
        ));
        assert!(matches!(
            decode_response(&oversized),
            Err(ClusterError::MessageTooLarge { .. })
        ));
    }

    /// A truncated varint cannot decode to any request or response.
    #[test]
    fn test_decode_rejects_garbage() {
        assert!(decode_request(&[0xFF; 3]).is_err());
        assert!(decode_response(&[0xFF; 3]).is_err());
    }

    #[test]
    fn test_framing() {
        let data = b"hello world";
        let framed = frame_message(data);
        assert_eq!(framed.len(), 4 + data.len());
        let mut header = [0u8; 4];
        header.copy_from_slice(&framed[..4]);
        assert_eq!(frame_length(&header), data.len());
    }

    /// Empty payloads frame to a zero prefix; the prefix is big-endian.
    #[test]
    fn test_framing_empty_and_byte_order() {
        assert_eq!(frame_message(b""), vec![0, 0, 0, 0]);
        assert_eq!(frame_length(&[0, 0, 1, 0]), 256);
    }

    #[test]
    fn test_acks_conversion() {
        assert_eq!(Acks::from_i8(0), Acks::None);
        assert_eq!(Acks::from_i8(1), Acks::Leader);
        assert_eq!(Acks::from_i8(-1), Acks::All);
        assert_eq!(Acks::None.to_i8(), 0);
        assert_eq!(Acks::Leader.to_i8(), 1);
        assert_eq!(Acks::All.to_i8(), -1);
    }

    /// `from_i8` is lenient: unknown values fall back to the default level.
    #[test]
    fn test_acks_unknown_value_falls_back_to_leader() {
        assert_eq!(Acks::from_i8(42), Acks::Leader);
        assert_eq!(Acks::default(), Acks::Leader);
    }

    #[test]
    fn test_error_response_header() {
        let header = ResponseHeader::error(7, error_codes::UNKNOWN_REQUEST, "unknown request");
        assert!(!header.is_success());
        assert_eq!(header.correlation_id, 7);
        assert_eq!(header.error_code, error_codes::UNKNOWN_REQUEST);
        assert_eq!(header.error_message.as_deref(), Some("unknown request"));
    }

    #[test]
    fn test_error_code_retriable() {
        assert!(ErrorCode::NotLeaderForPartition.is_retriable());
        assert!(ErrorCode::RequestTimedOut.is_retriable());
        assert!(!ErrorCode::CorruptMessage.is_retriable());
        assert!(!ErrorCode::None.is_retriable());
    }

    #[test]
    fn test_with_timeout_sets_millis() {
        assert_eq!(RequestHeader::new(1, "node-1".to_string()).timeout_ms, 30000);
        let header = RequestHeader::new(1, "node-1".to_string())
            .with_timeout(Duration::from_millis(1500));
        assert_eq!(header.timeout_ms, 1500);
    }

    #[test]
    fn test_version_validation_ok() {
        let header = RequestHeader::new(1, "node-1".to_string());
        assert!(header.validate_version().is_ok());
    }

    #[test]
    fn test_version_validation_too_high() {
        let mut header = RequestHeader::new(1, "node-1".to_string());
        header.version = PROTOCOL_VERSION + 1;
        let err = header.validate_version().unwrap_err();
        assert_eq!(err.error_code, error_codes::UNSUPPORTED_VERSION);
    }

    #[test]
    fn test_version_validation_zero() {
        let mut header = RequestHeader::new(1, "node-1".to_string());
        header.version = 0;
        assert!(header.validate_version().is_err());
    }

    /// The error header must echo the request's correlation id so the caller
    /// can match it.
    #[test]
    fn test_version_error_echoes_correlation_id() {
        let mut header = RequestHeader::new(99, "node-1".to_string());
        header.version = 0;
        let err = header.validate_version().unwrap_err();
        assert_eq!(err.correlation_id, 99);
        assert!(!err.is_success());
    }
}