// rivven_cluster/protocol.rs

//! Wire protocol for cluster communication

use crate::error::{ClusterError, Result};
use crate::metadata::MetadataCommand;
use crate::node::NodeId;
use crate::partition::PartitionId;
use serde::{Deserialize, Serialize};
use std::time::Duration;

/// Protocol version for compatibility checking.
///
/// This is the version stamped into every new `RequestHeader` and the highest
/// version `RequestHeader::validate_version` accepts.
pub const PROTOCOL_VERSION: u16 = 1;

/// Minimum protocol version we can interoperate with.
///
/// Must never exceed [`PROTOCOL_VERSION`]; together they form the inclusive
/// range of accepted versions.
pub const MIN_PROTOCOL_VERSION: u16 = 1;

/// Maximum message size in bytes (16 MiB).
///
/// `decode_request` and `decode_response` reject any buffer longer than this.
pub const MAX_MESSAGE_SIZE: usize = 16 * 1024 * 1024;

/// Error codes for protocol-level errors.
///
/// These are the raw `u16` values placed in `ResponseHeader::error_code` for
/// failures detected while handling the envelope itself.
///
/// NOTE(review): these values overlap numerically with the `ErrorCode` enum
/// declared later in this file (for example `UNSUPPORTED_VERSION` is 1 here,
/// while `ErrorCode::Unknown` is 1 and `ErrorCode::UnsupportedVersion` is 13).
/// Confirm which numbering peers are expected to interpret before relying on
/// either one; the values are left untouched because they are on the wire.
pub mod error_codes {
    /// Unsupported protocol version
    pub const UNSUPPORTED_VERSION: u16 = 1;
    /// Message too large
    pub const MESSAGE_TOO_LARGE: u16 = 2;
    /// Unknown request type
    pub const UNKNOWN_REQUEST: u16 = 3;
}

29/// Request header included in all requests
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct RequestHeader {
32    /// Protocol version
33    pub version: u16,
34    /// Correlation ID for matching responses
35    pub correlation_id: u64,
36    /// Source node ID
37    pub source: NodeId,
38    /// Request timeout
39    pub timeout_ms: u32,
40}
41
42impl RequestHeader {
43    pub fn new(correlation_id: u64, source: NodeId) -> Self {
44        Self {
45            version: PROTOCOL_VERSION,
46            correlation_id,
47            source,
48            timeout_ms: 30000,
49        }
50    }
51
52    pub fn with_timeout(mut self, timeout: Duration) -> Self {
53        self.timeout_ms = timeout.as_millis() as u32;
54        self
55    }
56
57    /// Validate that the protocol version is supported.
58    /// Returns an error ResponseHeader if the version is out of range.
59    pub fn validate_version(&self) -> std::result::Result<(), ResponseHeader> {
60        if self.version < MIN_PROTOCOL_VERSION || self.version > PROTOCOL_VERSION {
61            Err(ResponseHeader::error(
62                self.correlation_id,
63                error_codes::UNSUPPORTED_VERSION,
64                format!(
65                    "unsupported protocol version {}: supported range [{}, {}]",
66                    self.version, MIN_PROTOCOL_VERSION, PROTOCOL_VERSION
67                ),
68            ))
69        } else {
70            Ok(())
71        }
72    }
73}
74
75/// Response header included in all responses
76#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct ResponseHeader {
78    /// Correlation ID matching the request
79    pub correlation_id: u64,
80    /// Error code (0 = success)
81    pub error_code: u16,
82    /// Error message (if any)
83    pub error_message: Option<String>,
84}
85
86impl ResponseHeader {
87    pub fn success(correlation_id: u64) -> Self {
88        Self {
89            correlation_id,
90            error_code: 0,
91            error_message: None,
92        }
93    }
94
95    pub fn error(correlation_id: u64, code: u16, message: impl Into<String>) -> Self {
96        Self {
97            correlation_id,
98            error_code: code,
99            error_message: Some(message.into()),
100        }
101    }
102
103    pub fn is_success(&self) -> bool {
104        self.error_code == 0
105    }
106}
107
108/// Request types for cluster operations
109#[derive(Debug, Clone, Serialize, Deserialize)]
110#[allow(clippy::large_enum_variant)] // Acceptable for protocol enums - they're short-lived
111pub enum ClusterRequest {
112    // ==================== Metadata Requests ====================
113    /// Fetch metadata for topics
114    FetchMetadata {
115        header: RequestHeader,
116        topics: Option<Vec<String>>, // None = all topics
117    },
118
119    /// Propose a metadata change (forwarded to Raft leader)
120    ProposeMetadata {
121        header: RequestHeader,
122        command: MetadataCommand,
123    },
124
125    // ==================== Replication Requests ====================
126    /// Fetch records from a partition (follower -> leader)
127    Fetch {
128        header: RequestHeader,
129        partition: PartitionId,
130        offset: u64,
131        max_bytes: u32,
132    },
133
134    /// Append records to a partition (client -> leader -> followers)
135    Append {
136        header: RequestHeader,
137        partition: PartitionId,
138        records: Vec<u8>, // Serialized records batch
139        required_acks: Acks,
140    },
141
142    /// Report replica state to leader
143    ReplicaState {
144        header: RequestHeader,
145        partition: PartitionId,
146        log_end_offset: u64,
147        high_watermark: u64,
148    },
149
150    // ==================== Leader Election ====================
151    /// Request leader election for a partition
152    ElectLeader {
153        header: RequestHeader,
154        partition: PartitionId,
155        preferred_leader: Option<NodeId>,
156    },
157
158    // ==================== Heartbeat ====================
159    /// Heartbeat from leader to followers
160    Heartbeat {
161        header: RequestHeader,
162        partitions: Vec<HeartbeatPartition>,
163    },
164}
165
166/// Partition info in heartbeat
167#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct HeartbeatPartition {
169    pub partition: PartitionId,
170    pub leader_epoch: u64,
171    pub high_watermark: u64,
172}
173
174/// Response types for cluster operations
175#[derive(Debug, Clone, Serialize, Deserialize)]
176pub enum ClusterResponse {
177    // ==================== Metadata Responses ====================
178    /// Metadata response
179    Metadata {
180        header: ResponseHeader,
181        cluster_id: String,
182        controller_id: Option<NodeId>,
183        topics: Vec<TopicMetadata>,
184        brokers: Vec<BrokerMetadata>,
185    },
186
187    /// Metadata proposal response
188    MetadataProposal { header: ResponseHeader },
189
190    // ==================== Replication Responses ====================
191    /// Fetch response with records
192    Fetch {
193        header: ResponseHeader,
194        partition: PartitionId,
195        high_watermark: u64,
196        log_start_offset: u64,
197        records: Vec<u8>, // Serialized records
198    },
199
200    /// Append response
201    Append {
202        header: ResponseHeader,
203        partition: PartitionId,
204        base_offset: u64,
205        log_append_time: i64,
206    },
207
208    /// Replica state acknowledgment
209    ReplicaStateAck {
210        header: ResponseHeader,
211        partition: PartitionId,
212        in_sync: bool,
213    },
214
215    // ==================== Leader Election ====================
216    /// Leader election response
217    ElectLeader {
218        header: ResponseHeader,
219        partition: PartitionId,
220        leader: Option<NodeId>,
221        epoch: u64,
222    },
223
224    // ==================== Heartbeat ====================
225    /// Heartbeat response
226    Heartbeat { header: ResponseHeader },
227
228    // ==================== Error ====================
229    /// Generic error response
230    Error { header: ResponseHeader },
231}
232
233/// Topic metadata in response
234#[derive(Debug, Clone, Serialize, Deserialize)]
235pub struct TopicMetadata {
236    pub name: String,
237    pub partitions: Vec<PartitionMetadata>,
238    pub is_internal: bool,
239}
240
241/// Partition metadata in response
242#[derive(Debug, Clone, Serialize, Deserialize)]
243pub struct PartitionMetadata {
244    pub partition_index: u32,
245    pub leader_id: Option<NodeId>,
246    pub leader_epoch: u64,
247    pub replica_nodes: Vec<NodeId>,
248    pub isr_nodes: Vec<NodeId>,
249    pub offline_replicas: Vec<NodeId>,
250}
251
252/// Broker metadata in response
253#[derive(Debug, Clone, Serialize, Deserialize)]
254pub struct BrokerMetadata {
255    pub node_id: NodeId,
256    pub host: String,
257    pub port: u16,
258    pub rack: Option<String>,
259}
260
261/// Required acknowledgments for writes
262#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
263pub enum Acks {
264    /// No acknowledgment (fire and forget)
265    None,
266    /// Leader acknowledgment only
267    #[default]
268    Leader,
269    /// All ISR acknowledgment
270    All,
271}
272
273impl Acks {
274    pub fn from_i8(v: i8) -> Self {
275        match v {
276            0 => Acks::None,
277            1 => Acks::Leader,
278            -1 => Acks::All,
279            _ => Acks::Leader,
280        }
281    }
282
283    pub fn to_i8(self) -> i8 {
284        match self {
285            Acks::None => 0,
286            Acks::Leader => 1,
287            Acks::All => -1,
288        }
289    }
290}
291
/// Error codes for cluster protocol.
///
/// Each variant has an explicit `u16` discriminant, so `code as u16` yields
/// the numeric value shown here. Existing discriminants must stay stable;
/// add new codes with fresh numbers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum ErrorCode {
    None = 0,
    Unknown = 1,
    CorruptMessage = 2,
    UnknownTopic = 3,
    InvalidPartition = 4,
    LeaderNotAvailable = 5,
    NotLeaderForPartition = 6,
    RequestTimedOut = 7,
    NotEnoughReplicas = 8,
    NotEnoughReplicasAfterAppend = 9,
    InvalidRequiredAcks = 10,
    NotController = 11,
    InvalidRequest = 12,
    UnsupportedVersion = 13,
    TopicAlreadyExists = 14,
    InvalidReplicationFactor = 15,
    IneligibleReplica = 16,
    OffsetOutOfRange = 17,
    NotReplicaForPartition = 18,
    GroupAuthorizationFailed = 19,
    UnknownMemberId = 20,
}

impl ErrorCode {
    /// Returns `true` for the codes this protocol classifies as retriable:
    /// `LeaderNotAvailable`, `NotLeaderForPartition`, `RequestTimedOut`,
    /// `NotEnoughReplicas` and `NotController`.
    ///
    /// Every other code, including `NotEnoughReplicasAfterAppend`, is reported
    /// as not retriable.
    pub fn is_retriable(self) -> bool {
        matches!(
            self,
            ErrorCode::LeaderNotAvailable
                | ErrorCode::NotLeaderForPartition
                | ErrorCode::RequestTimedOut
                | ErrorCode::NotEnoughReplicas
                | ErrorCode::NotController
        )
    }
}

332/// Encode a request to bytes
333pub fn encode_request(request: &ClusterRequest) -> Result<Vec<u8>> {
334    postcard::to_allocvec(request).map_err(|e| ClusterError::Serialization(e.to_string()))
335}
336
337/// Decode a request from bytes
338pub fn decode_request(bytes: &[u8]) -> Result<ClusterRequest> {
339    if bytes.len() > MAX_MESSAGE_SIZE {
340        return Err(ClusterError::MessageTooLarge {
341            size: bytes.len(),
342            max: MAX_MESSAGE_SIZE,
343        });
344    }
345    postcard::from_bytes(bytes).map_err(|e| ClusterError::Deserialization(e.to_string()))
346}
347
348/// Encode a response to bytes
349pub fn encode_response(response: &ClusterResponse) -> Result<Vec<u8>> {
350    postcard::to_allocvec(response).map_err(|e| ClusterError::Serialization(e.to_string()))
351}
352
353/// Decode a response from bytes
354pub fn decode_response(bytes: &[u8]) -> Result<ClusterResponse> {
355    if bytes.len() > MAX_MESSAGE_SIZE {
356        return Err(ClusterError::MessageTooLarge {
357            size: bytes.len(),
358            max: MAX_MESSAGE_SIZE,
359        });
360    }
361    postcard::from_bytes(bytes).map_err(|e| ClusterError::Deserialization(e.to_string()))
362}
363
/// Frame a message with a length prefix for TCP transmission.
///
/// The frame is the payload length as a 4-byte big-endian `u32`, followed by
/// the payload bytes. This is the header layout that `frame_length` reads
/// back.
///
/// # Panics
///
/// Panics if `data` is longer than `u32::MAX` bytes. Such a length cannot be
/// represented in the prefix, and truncating it would emit a frame whose
/// header disagrees with its payload and desynchronise the stream.
pub fn frame_message(data: &[u8]) -> Vec<u8> {
    assert!(
        data.len() <= u32::MAX as usize,
        "frame payload of {} bytes does not fit in a u32 length prefix",
        data.len()
    );
    // Lossless: the assertion above guarantees the length fits in a u32.
    let len = data.len() as u32;
    let mut framed = Vec::with_capacity(4 + data.len());
    framed.extend_from_slice(&len.to_be_bytes());
    framed.extend_from_slice(data);
    framed
}

/// Extract the payload length from a 4-byte frame header.
///
/// The header is interpreted as a big-endian `u32`. The value is returned
/// as-is; comparing it against `MAX_MESSAGE_SIZE` before allocating a read
/// buffer is left to the caller.
pub fn frame_length(header: &[u8; 4]) -> usize {
    u32::from_be_bytes(*header) as usize
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A request survives an encode/decode round trip.
    #[test]
    fn test_request_roundtrip() {
        let header = RequestHeader::new(42, "node-1".to_string());
        let request = ClusterRequest::FetchMetadata {
            header,
            topics: Some(vec!["test-topic".to_string()]),
        };

        let bytes = encode_request(&request).unwrap();
        let decoded = decode_request(&bytes).unwrap();

        match decoded {
            ClusterRequest::FetchMetadata { header, topics } => {
                assert_eq!(header.correlation_id, 42);
                assert_eq!(topics, Some(vec!["test-topic".to_string()]));
            }
            _ => panic!("Wrong request type"),
        }
    }

    /// A response survives an encode/decode round trip.
    #[test]
    fn test_response_roundtrip() {
        let header = ResponseHeader::success(42);
        let response = ClusterResponse::Metadata {
            header,
            cluster_id: "test-cluster".to_string(),
            controller_id: Some("node-1".to_string()),
            topics: vec![],
            brokers: vec![],
        };

        let bytes = encode_response(&response).unwrap();
        let decoded = decode_response(&bytes).unwrap();

        match decoded {
            ClusterResponse::Metadata {
                header, cluster_id, ..
            } => {
                assert!(header.is_success());
                assert_eq!(cluster_id, "test-cluster");
            }
            _ => panic!("Wrong response type"),
        }
    }

    /// The frame header written by `frame_message` is read back by `frame_length`.
    #[test]
    fn test_framing() {
        let data = b"hello world";
        let framed = frame_message(data);

        assert_eq!(framed.len(), 4 + data.len());

        let mut header = [0u8; 4];
        header.copy_from_slice(&framed[..4]);
        assert_eq!(frame_length(&header), data.len());
    }

    /// `Acks` maps to and from its signed-byte form.
    #[test]
    fn test_acks_conversion() {
        assert_eq!(Acks::from_i8(0), Acks::None);
        assert_eq!(Acks::from_i8(1), Acks::Leader);
        assert_eq!(Acks::from_i8(-1), Acks::All);

        assert_eq!(Acks::None.to_i8(), 0);
        assert_eq!(Acks::Leader.to_i8(), 1);
        assert_eq!(Acks::All.to_i8(), -1);
    }

    /// A freshly built header carries a supported version.
    #[test]
    fn test_version_validation_ok() {
        let header = RequestHeader::new(1, "node-1".to_string());
        assert!(header.validate_version().is_ok());
    }

    /// A version above `PROTOCOL_VERSION` is rejected with the matching code.
    #[test]
    fn test_version_validation_too_high() {
        let mut header = RequestHeader::new(1, "node-1".to_string());
        header.version = PROTOCOL_VERSION + 1;
        let err = header.validate_version().unwrap_err();
        assert_eq!(err.error_code, error_codes::UNSUPPORTED_VERSION);
    }

    /// Version 0 is below `MIN_PROTOCOL_VERSION` and is rejected.
    #[test]
    fn test_version_validation_zero() {
        let mut header = RequestHeader::new(1, "node-1".to_string());
        header.version = 0;
        assert!(header.validate_version().is_err());
    }
}
469}