1use crate::error::{ClusterError, Result};
4use crate::metadata::MetadataCommand;
5use crate::node::NodeId;
6use crate::partition::PartitionId;
7use serde::{Deserialize, Serialize};
8use std::time::Duration;
9
/// Protocol version spoken by this build; also the highest version accepted
/// from peers (see `RequestHeader::validate_version`).
pub const PROTOCOL_VERSION: u16 = 1;

/// Lowest protocol version still accepted from peers.
pub const MIN_PROTOCOL_VERSION: u16 = 1;

/// Upper bound, in bytes, on an encoded message accepted by the decoders (16 MiB).
pub const MAX_MESSAGE_SIZE: usize = 16 << 20;
18
/// Protocol-level error codes written into a response header's `error_code`.
///
/// `0` is reserved for success (`ResponseHeader::is_success` tests for it).
///
/// NOTE(review): these values overlap the `ErrorCode` enum defined in this
/// file, where `1` is `Unknown`, `2` is `CorruptMessage`, `3` is
/// `UnknownTopic` and `UnsupportedVersion` is `13`. A peer decoding
/// `error_code` through `ErrorCode` will misread these. Confirm which
/// numbering is authoritative before changing either, because both are on
/// the wire.
pub mod error_codes {
    /// The request's protocol version is outside
    /// `[MIN_PROTOCOL_VERSION, PROTOCOL_VERSION]`.
    pub const UNSUPPORTED_VERSION: u16 = 0x0001;

    /// Presumably reported when a message exceeds `MAX_MESSAGE_SIZE`
    /// (no use site in this file — confirm).
    pub const MESSAGE_TOO_LARGE: u16 = 0x0002;

    /// Presumably reported for an unrecognised request kind
    /// (no use site in this file — confirm).
    pub const UNKNOWN_REQUEST: u16 = 0x0003;
}
28
/// Header carried by every [`ClusterRequest`] variant.
///
/// Serialized with postcard, which encodes struct fields positionally:
/// field order is part of the wire format, so do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestHeader {
    /// Protocol version of the sender; checked by `validate_version`.
    pub version: u16,
    /// Request identifier; `validate_version` copies it into the error
    /// `ResponseHeader` it builds.
    pub correlation_id: u64,
    /// Node that issued the request.
    pub source: NodeId,
    /// Request timeout in milliseconds (`new` sets 30 000; `with_timeout`
    /// sets it from a `Duration`).
    pub timeout_ms: u32,
}
41
42impl RequestHeader {
43 pub fn new(correlation_id: u64, source: NodeId) -> Self {
44 Self {
45 version: PROTOCOL_VERSION,
46 correlation_id,
47 source,
48 timeout_ms: 30000,
49 }
50 }
51
52 pub fn with_timeout(mut self, timeout: Duration) -> Self {
53 self.timeout_ms = timeout.as_millis() as u32;
54 self
55 }
56
57 pub fn validate_version(&self) -> std::result::Result<(), ResponseHeader> {
60 if self.version < MIN_PROTOCOL_VERSION || self.version > PROTOCOL_VERSION {
61 Err(ResponseHeader::error(
62 self.correlation_id,
63 error_codes::UNSUPPORTED_VERSION,
64 format!(
65 "unsupported protocol version {}: supported range [{}, {}]",
66 self.version, MIN_PROTOCOL_VERSION, PROTOCOL_VERSION
67 ),
68 ))
69 } else {
70 Ok(())
71 }
72 }
73}
74
/// Header carried by every [`ClusterResponse`] variant.
///
/// `error_code == 0` means success (see `is_success`). Any other value is an
/// error, optionally described by `error_message`.
///
/// Serialized with postcard, which encodes struct fields positionally:
/// field order is part of the wire format, so do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseHeader {
    /// Expected to match the `correlation_id` of the request being answered
    /// (as `RequestHeader::validate_version` does).
    pub correlation_id: u64,
    /// `0` on success; any non-zero value signals an error.
    pub error_code: u16,
    /// Human-readable error detail; `ResponseHeader::success` leaves it `None`.
    pub error_message: Option<String>,
}
85
86impl ResponseHeader {
87 pub fn success(correlation_id: u64) -> Self {
88 Self {
89 correlation_id,
90 error_code: 0,
91 error_message: None,
92 }
93 }
94
95 pub fn error(correlation_id: u64, code: u16, message: impl Into<String>) -> Self {
96 Self {
97 correlation_id,
98 error_code: code,
99 error_message: Some(message.into()),
100 }
101 }
102
103 pub fn is_success(&self) -> bool {
104 self.error_code == 0
105 }
106}
107
/// Requests exchanged between cluster nodes.
///
/// Every variant carries a [`RequestHeader`]. Messages are encoded with
/// postcard (see `encode_request`). Postcard writes the variant index and
/// then the fields in declaration order, so variants and fields must not be
/// reordered. Add new variants only at the end.
#[derive(Debug, Clone, Serialize, Deserialize)]
// NOTE(review): this allowance presumably exists because one variant (likely
// `ProposeMetadata`, via `MetadataCommand`) dwarfs the others. Confirm with
// `std::mem::size_of` before deciding whether to box it.
#[allow(clippy::large_enum_variant)]
pub enum ClusterRequest {
    /// Asks for cluster metadata. `topics` names the topics of interest;
    /// `None` presumably means "all topics" — confirm in the handler.
    FetchMetadata {
        header: RequestHeader,
        topics: Option<Vec<String>>,
    },

    /// Submits a [`MetadataCommand`] to the metadata layer.
    ProposeMetadata {
        header: RequestHeader,
        command: MetadataCommand,
    },

    /// Reads from `partition` starting at `offset`, bounded by `max_bytes`.
    Fetch {
        header: RequestHeader,
        partition: PartitionId,
        offset: u64,
        max_bytes: u32,
    },

    /// Appends the opaque `records` bytes to `partition`. `required_acks`
    /// selects the acknowledgement level (see [`Acks`]).
    Append {
        header: RequestHeader,
        partition: PartitionId,
        records: Vec<u8>,
        required_acks: Acks,
    },

    /// Reports a replica's `log_end_offset` and `high_watermark` for
    /// `partition`. NOTE(review): presumably sent by a follower to the
    /// partition leader — confirm.
    ReplicaState {
        header: RequestHeader,
        partition: PartitionId,
        log_end_offset: u64,
        high_watermark: u64,
    },

    /// Requests a leader election for `partition`, optionally naming a
    /// `preferred_leader`.
    ElectLeader {
        header: RequestHeader,
        partition: PartitionId,
        preferred_leader: Option<NodeId>,
    },

    /// Heartbeat carrying per-partition leader epoch and high watermark
    /// (see [`HeartbeatPartition`]).
    Heartbeat {
        header: RequestHeader,
        partitions: Vec<HeartbeatPartition>,
    },
}
165
/// Per-partition state carried by [`ClusterRequest::Heartbeat`].
///
/// Postcard-encoded positionally: do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HeartbeatPartition {
    /// Partition this entry describes.
    pub partition: PartitionId,
    /// Leader epoch the sender has for `partition`.
    pub leader_epoch: u64,
    /// High watermark the sender has for `partition`.
    pub high_watermark: u64,
}
173
/// Responses exchanged between cluster nodes.
///
/// Every variant carries a [`ResponseHeader`]. Most variants are named after
/// a [`ClusterRequest`] variant, but the type system does not enforce that
/// pairing; the handlers do. Encoded with postcard (see `encode_response`),
/// so variants and fields must not be reordered.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ClusterResponse {
    /// Cluster metadata: cluster id, current controller (if any), topics and brokers.
    Metadata {
        header: ResponseHeader,
        cluster_id: String,
        controller_id: Option<NodeId>,
        topics: Vec<TopicMetadata>,
        brokers: Vec<BrokerMetadata>,
    },

    /// Outcome of a metadata proposal; only the header is returned.
    MetadataProposal { header: ResponseHeader },

    /// Records read from `partition` together with its current offsets bounds.
    Fetch {
        header: ResponseHeader,
        partition: PartitionId,
        high_watermark: u64,
        log_start_offset: u64,
        records: Vec<u8>,
    },

    /// Outcome of an append: the offset assigned to the first record and the
    /// append time. NOTE(review): the unit/epoch of `log_append_time` is not
    /// visible here — confirm.
    Append {
        header: ResponseHeader,
        partition: PartitionId,
        base_offset: u64,
        log_append_time: i64,
    },

    /// Acknowledges a replica-state report and says whether the replica is in sync.
    ReplicaStateAck {
        header: ResponseHeader,
        partition: PartitionId,
        in_sync: bool,
    },

    /// Result of a leader election: the elected leader (if any) and its epoch.
    ElectLeader {
        header: ResponseHeader,
        partition: PartitionId,
        leader: Option<NodeId>,
        epoch: u64,
    },

    /// Heartbeat reply; only the header is returned.
    Heartbeat { header: ResponseHeader },

    /// Generic failure reply carrying only a header.
    Error { header: ResponseHeader },
}
232
/// Topic entry inside [`ClusterResponse::Metadata`].
///
/// Postcard-encoded positionally: do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TopicMetadata {
    /// Topic name.
    pub name: String,
    /// Metadata for each partition of the topic.
    pub partitions: Vec<PartitionMetadata>,
    /// Whether the topic is flagged internal (meaning defined outside this file).
    pub is_internal: bool,
}
240
/// Per-partition entry inside [`TopicMetadata`].
///
/// Postcard-encoded positionally: do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PartitionMetadata {
    /// Index of the partition within its topic.
    pub partition_index: u32,
    /// Current leader, if any.
    pub leader_id: Option<NodeId>,
    /// Leader epoch associated with `leader_id`.
    pub leader_epoch: u64,
    /// Nodes hosting a replica of this partition.
    pub replica_nodes: Vec<NodeId>,
    /// Replicas in the in-sync replica set (ISR).
    pub isr_nodes: Vec<NodeId>,
    /// Replicas reported offline.
    pub offline_replicas: Vec<NodeId>,
}
251
/// Broker entry inside [`ClusterResponse::Metadata`].
///
/// Postcard-encoded positionally: do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BrokerMetadata {
    /// Broker node identifier.
    pub node_id: NodeId,
    /// Host name or address the broker is reachable at.
    pub host: String,
    /// Port paired with `host`.
    pub port: u16,
    /// Rack label, if one is set.
    pub rack: Option<String>,
}
260
/// Acknowledgement level requested by [`ClusterRequest::Append`].
///
/// `from_i8`/`to_i8` map `None` = 0, `Leader` = 1, `All` = -1. Note that the
/// serde/postcard encoding of this enum uses the variant index, not those
/// `i8` values, so variant order is part of the wire format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum Acks {
    /// `i8` value `0` (presumably: do not wait for any acknowledgement).
    None,
    /// `i8` value `1`; the `Default` (presumably: leader-only acknowledgement).
    #[default]
    Leader,
    /// `i8` value `-1` (presumably: acknowledgement from all in-sync replicas).
    All,
}
272
273impl Acks {
274 pub fn from_i8(v: i8) -> Self {
275 match v {
276 0 => Acks::None,
277 1 => Acks::Leader,
278 -1 => Acks::All,
279 _ => Acks::Leader,
280 }
281 }
282
283 pub fn to_i8(self) -> i8 {
284 match self {
285 Acks::None => 0,
286 Acks::Leader => 1,
287 Acks::All => -1,
288 }
289 }
290}
291
/// Typed error codes with explicit `u16` discriminants (`#[repr(u16)]`).
///
/// Use [`ErrorCode::as_u16`] / [`ErrorCode::from_u16`] to move between this
/// enum and a raw `u16` code, e.g. a response header's `error_code`
/// (presumably the intended carrier — confirm).
///
/// NOTE(review): the `error_codes` module in this file uses `1`..`3` for
/// protocol-level failures, which collide with `Unknown`, `CorruptMessage`
/// and `UnknownTopic` here. Confirm which numbering peers expect.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum ErrorCode {
    None = 0,
    Unknown = 1,
    CorruptMessage = 2,
    UnknownTopic = 3,
    InvalidPartition = 4,
    LeaderNotAvailable = 5,
    NotLeaderForPartition = 6,
    RequestTimedOut = 7,
    NotEnoughReplicas = 8,
    NotEnoughReplicasAfterAppend = 9,
    InvalidRequiredAcks = 10,
    NotController = 11,
    InvalidRequest = 12,
    UnsupportedVersion = 13,
    TopicAlreadyExists = 14,
    InvalidReplicationFactor = 15,
    IneligibleReplica = 16,
    OffsetOutOfRange = 17,
    NotReplicaForPartition = 18,
    GroupAuthorizationFailed = 19,
    UnknownMemberId = 20,
}

impl ErrorCode {
    /// Returns the numeric discriminant of this code.
    pub fn as_u16(self) -> u16 {
        self as u16
    }

    /// Maps a raw numeric code back to an [`ErrorCode`].
    ///
    /// Returns `None` (the `Option`) for values with no matching variant.
    /// Keep this in step with the discriminants above when adding variants.
    pub fn from_u16(code: u16) -> Option<Self> {
        let decoded = match code {
            0 => Self::None,
            1 => Self::Unknown,
            2 => Self::CorruptMessage,
            3 => Self::UnknownTopic,
            4 => Self::InvalidPartition,
            5 => Self::LeaderNotAvailable,
            6 => Self::NotLeaderForPartition,
            7 => Self::RequestTimedOut,
            8 => Self::NotEnoughReplicas,
            9 => Self::NotEnoughReplicasAfterAppend,
            10 => Self::InvalidRequiredAcks,
            11 => Self::NotController,
            12 => Self::InvalidRequest,
            13 => Self::UnsupportedVersion,
            14 => Self::TopicAlreadyExists,
            15 => Self::InvalidReplicationFactor,
            16 => Self::IneligibleReplica,
            17 => Self::OffsetOutOfRange,
            18 => Self::NotReplicaForPartition,
            19 => Self::GroupAuthorizationFailed,
            20 => Self::UnknownMemberId,
            _ => return None,
        };
        Some(decoded)
    }

    /// Returns `true` for the codes this module classifies as retriable:
    /// `LeaderNotAvailable`, `NotLeaderForPartition`, `RequestTimedOut`,
    /// `NotEnoughReplicas` and `NotController`.
    ///
    /// NOTE(review): `NotEnoughReplicasAfterAppend` is not included. The
    /// write may already be in the leader's log, so confirm this is intended.
    pub fn is_retriable(self) -> bool {
        matches!(
            self,
            ErrorCode::LeaderNotAvailable
                | ErrorCode::NotLeaderForPartition
                | ErrorCode::RequestTimedOut
                | ErrorCode::NotEnoughReplicas
                | ErrorCode::NotController
        )
    }
}
331
332pub fn encode_request(request: &ClusterRequest) -> Result<Vec<u8>> {
334 postcard::to_allocvec(request).map_err(|e| ClusterError::Serialization(e.to_string()))
335}
336
337pub fn decode_request(bytes: &[u8]) -> Result<ClusterRequest> {
339 if bytes.len() > MAX_MESSAGE_SIZE {
340 return Err(ClusterError::MessageTooLarge {
341 size: bytes.len(),
342 max: MAX_MESSAGE_SIZE,
343 });
344 }
345 postcard::from_bytes(bytes).map_err(|e| ClusterError::Deserialization(e.to_string()))
346}
347
348pub fn encode_response(response: &ClusterResponse) -> Result<Vec<u8>> {
350 postcard::to_allocvec(response).map_err(|e| ClusterError::Serialization(e.to_string()))
351}
352
353pub fn decode_response(bytes: &[u8]) -> Result<ClusterResponse> {
355 if bytes.len() > MAX_MESSAGE_SIZE {
356 return Err(ClusterError::MessageTooLarge {
357 size: bytes.len(),
358 max: MAX_MESSAGE_SIZE,
359 });
360 }
361 postcard::from_bytes(bytes).map_err(|e| ClusterError::Deserialization(e.to_string()))
362}
363
/// Prefixes `data` with its length, encoded as a 4-byte big-endian `u32`.
///
/// The first four bytes of the result are what `frame_length` decodes.
///
/// # Panics
///
/// Panics if `data` is longer than `u32::MAX` bytes. That length cannot be
/// represented in the prefix, and truncating it (as a bare `as u32` would)
/// emits a frame whose header disagrees with its body, desynchronising the
/// reader.
pub fn frame_message(data: &[u8]) -> Vec<u8> {
    assert!(
        data.len() <= u32::MAX as usize,
        "frame payload of {} bytes exceeds the u32 length prefix",
        data.len()
    );
    // Lossless: bounded by the assertion above.
    let len = data.len() as u32;
    let mut framed = Vec::with_capacity(4 + data.len());
    framed.extend_from_slice(&len.to_be_bytes());
    framed.extend_from_slice(data);
    framed
}
372
/// Decodes the 4-byte big-endian length prefix written by `frame_message`.
///
/// The value is returned unchecked. Callers should compare it against
/// `MAX_MESSAGE_SIZE` before allocating a buffer for the frame body.
pub fn frame_length(header: &[u8; 4]) -> usize {
    let prefix = u32::from_be_bytes(*header);
    prefix as usize
}
377
#[cfg(test)]
mod tests {
    use super::*;

    /// Request header from the fixed test node "node-1".
    fn header_from(correlation_id: u64) -> RequestHeader {
        RequestHeader::new(correlation_id, "node-1".to_string())
    }

    #[test]
    fn test_request_roundtrip() {
        let original = ClusterRequest::FetchMetadata {
            header: header_from(42),
            topics: Some(vec!["test-topic".to_string()]),
        };

        let wire = encode_request(&original).unwrap();

        if let ClusterRequest::FetchMetadata { header, topics } = decode_request(&wire).unwrap() {
            assert_eq!(header.correlation_id, 42);
            assert_eq!(topics, Some(vec!["test-topic".to_string()]));
        } else {
            panic!("Wrong request type");
        }
    }

    #[test]
    fn test_response_roundtrip() {
        let original = ClusterResponse::Metadata {
            header: ResponseHeader::success(42),
            cluster_id: "test-cluster".to_string(),
            controller_id: Some("node-1".to_string()),
            topics: Vec::new(),
            brokers: Vec::new(),
        };

        let wire = encode_response(&original).unwrap();

        if let ClusterResponse::Metadata {
            header, cluster_id, ..
        } = decode_response(&wire).unwrap()
        {
            assert!(header.is_success());
            assert_eq!(cluster_id, "test-cluster");
        } else {
            panic!("Wrong response type");
        }
    }

    #[test]
    fn test_framing() {
        let payload: &[u8] = b"hello world";
        let framed = frame_message(payload);

        assert_eq!(framed.len(), payload.len() + 4);

        let prefix: [u8; 4] = [framed[0], framed[1], framed[2], framed[3]];
        assert_eq!(frame_length(&prefix), payload.len());
    }

    #[test]
    fn test_acks_conversion() {
        let table = [(0i8, Acks::None), (1, Acks::Leader), (-1, Acks::All)];
        for &(wire, acks) in table.iter() {
            assert_eq!(Acks::from_i8(wire), acks);
            assert_eq!(acks.to_i8(), wire);
        }
    }

    #[test]
    fn test_version_validation_ok() {
        assert!(header_from(1).validate_version().is_ok());
    }

    #[test]
    fn test_version_validation_too_high() {
        let mut header = header_from(1);
        header.version = PROTOCOL_VERSION + 1;
        let rejection = header.validate_version().unwrap_err();
        assert_eq!(rejection.error_code, error_codes::UNSUPPORTED_VERSION);
    }

    #[test]
    fn test_version_validation_zero() {
        let mut header = header_from(1);
        header.version = 0;
        assert!(header.validate_version().is_err());
    }
}