rocketmq_controller/rpc/codec.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18use bytes::Buf;
19use bytes::BufMut;
20use bytes::BytesMut;
21use serde::Deserialize;
22use serde::Serialize;
23use tokio_util::codec::Decoder;
24use tokio_util::codec::Encoder;
25use tracing::debug;
26use tracing::trace;
27
28use crate::error::ControllerError;
29use crate::error::Result;
30use crate::processor::RequestType;
31
32/// RPC request message
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct RpcRequest {
35    /// Request ID for correlation
36    pub request_id: u64,
37
38    /// Request type
39    pub request_type: RequestType,
40
41    /// Request payload (JSON-encoded)
42    pub payload: Vec<u8>,
43}
44
45/// RPC response message
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct RpcResponse {
48    /// Request ID for correlation
49    pub request_id: u64,
50
51    /// Success flag
52    pub success: bool,
53
54    /// Error message if failed
55    pub error: Option<String>,
56
57    /// Response payload (JSON-encoded)
58    pub payload: Vec<u8>,
59}
60
/// RPC message codec
///
/// Each message is one frame made of two parts:
///
/// - a 4-byte big-endian length prefix, which counts only the body and not the
///   prefix itself;
/// - a JSON body.
///
/// The codec decodes [`RpcRequest`]s and encodes [`RpcResponse`]s. Bodies larger
/// than 16 MiB (`MAX_FRAME_SIZE`) are rejected in both directions.
///
/// Protocol format:
/// ```text
/// +--------+--------+--------+--------+
/// | Length (4 bytes, big-endian)      |
/// +--------+--------+--------+--------+
/// | JSON-encoded message              |
/// | ...                               |
/// +-----------------------------------+
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct RpcCodec;

impl RpcCodec {
    /// Maximum frame body size in bytes (16 MiB), excluding the 4-byte prefix.
    const MAX_FRAME_SIZE: usize = 16 * 1024 * 1024;

    /// Create a new RPC codec.
    ///
    /// The codec holds no state, so this is equivalent to `RpcCodec::default()`.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
89
90impl Decoder for RpcCodec {
91    type Item = RpcRequest;
92    type Error = ControllerError;
93
94    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
95        // Check if we have enough bytes for the length prefix
96        if src.len() < 4 {
97            trace!("Not enough bytes for length prefix: {}", src.len());
98            return Ok(None);
99        }
100
101        // Read the length prefix
102        let mut length_bytes = [0u8; 4];
103        length_bytes.copy_from_slice(&src[..4]);
104        let length = u32::from_be_bytes(length_bytes) as usize;
105
106        trace!("RPC request length: {}", length);
107
108        // Validate length
109        if length > Self::MAX_FRAME_SIZE {
110            return Err(ControllerError::InvalidRequest(format!(
111                "Frame size {} exceeds maximum {}",
112                length,
113                Self::MAX_FRAME_SIZE
114            )));
115        }
116
117        // Check if we have the complete frame
118        if src.len() < 4 + length {
119            trace!("Incomplete frame: have {}, need {}", src.len(), 4 + length);
120            // Reserve space for the rest of the frame
121            src.reserve(4 + length - src.len());
122            return Ok(None);
123        }
124
125        // Skip the length prefix
126        src.advance(4);
127
128        // Read the frame data
129        let data = src.split_to(length);
130
131        // Deserialize the request
132        let request: RpcRequest = serde_json::from_slice(&data)
133            .map_err(|e| ControllerError::InvalidRequest(e.to_string()))?;
134
135        debug!(
136            "Decoded RPC request: id={}, type={:?}",
137            request.request_id, request.request_type
138        );
139
140        Ok(Some(request))
141    }
142}
143
144impl Encoder<RpcResponse> for RpcCodec {
145    type Error = ControllerError;
146
147    fn encode(&mut self, item: RpcResponse, dst: &mut BytesMut) -> Result<()> {
148        debug!(
149            "Encoding RPC response: id={}, success={}",
150            item.request_id, item.success
151        );
152
153        // Serialize the response
154        let data = serde_json::to_vec(&item)
155            .map_err(|e| ControllerError::SerializationError(e.to_string()))?;
156
157        // Check size
158        if data.len() > Self::MAX_FRAME_SIZE {
159            return Err(ControllerError::SerializationError(format!(
160                "Response size {} exceeds maximum {}",
161                data.len(),
162                Self::MAX_FRAME_SIZE
163            )));
164        }
165
166        // Write length prefix
167        let length = data.len() as u32;
168        dst.reserve(4 + data.len());
169        dst.put_u32(length);
170
171        // Write data
172        dst.put_slice(&data);
173
174        trace!("Encoded RPC response: {} bytes", 4 + data.len());
175
176        Ok(())
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one wire frame, as `RpcCodec` expects it.
    ///
    /// The frame is a 4-byte big-endian body length followed by `body`.
    fn frame(body: &[u8]) -> BytesMut {
        let mut buf = BytesMut::with_capacity(4 + body.len());
        buf.put_u32(body.len() as u32);
        buf.put_slice(body);
        buf
    }

    #[test]
    fn test_rpc_request_serialization() {
        let request = RpcRequest {
            request_id: 123,
            request_type: RequestType::RegisterBroker,
            payload: b"test payload".to_vec(),
        };

        let serialized = serde_json::to_vec(&request).unwrap();
        let deserialized: RpcRequest = serde_json::from_slice(&serialized).unwrap();

        assert_eq!(deserialized.request_id, request.request_id);
        assert_eq!(deserialized.request_type, request.request_type);
        assert_eq!(deserialized.payload, request.payload);
    }

    #[test]
    fn test_rpc_response_serialization() {
        let response = RpcResponse {
            request_id: 456,
            success: true,
            error: None,
            payload: b"response payload".to_vec(),
        };

        let serialized = serde_json::to_vec(&response).unwrap();
        let deserialized: RpcResponse = serde_json::from_slice(&serialized).unwrap();

        assert_eq!(deserialized.request_id, response.request_id);
        assert_eq!(deserialized.success, response.success);
        assert_eq!(deserialized.error, response.error);
        assert_eq!(deserialized.payload, response.payload);
    }

    #[test]
    fn test_codec_decode_incomplete() {
        let mut codec = RpcCodec::new();
        let mut buf = BytesMut::new();

        // Write only 2 bytes of length prefix
        buf.put_u16(0x00);

        let result = codec.decode(&mut buf);
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn test_codec_decode_incomplete_body() {
        let mut codec = RpcCodec::new();
        let request = RpcRequest {
            request_id: 7,
            request_type: RequestType::BrokerHeartbeat,
            payload: b"partial".to_vec(),
        };
        let full = frame(&serde_json::to_vec(&request).unwrap());
        let split_at = full.len() - 1;

        // Prefix is complete but the body is missing its final byte.
        let mut buf = BytesMut::from(&full[..split_at]);
        assert!(codec.decode(&mut buf).unwrap().is_none());
        // Nothing may be consumed while the frame is incomplete.
        assert_eq!(buf.len(), split_at);

        buf.put_slice(&full[split_at..]);
        let decoded = codec
            .decode(&mut buf)
            .unwrap()
            .expect("complete frame should decode");
        assert_eq!(decoded.request_id, request.request_id);
        assert_eq!(decoded.payload, request.payload);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_codec_decode_complete_frame() {
        let mut codec = RpcCodec::new();

        let request = RpcRequest {
            request_id: 789,
            request_type: RequestType::BrokerHeartbeat,
            payload: b"heartbeat data".to_vec(),
        };

        let mut buf = frame(&serde_json::to_vec(&request).unwrap());

        let decoded = codec.decode(&mut buf).unwrap();
        assert!(decoded.is_some());

        let decoded_request = decoded.unwrap();
        assert_eq!(decoded_request.request_id, request.request_id);
        assert_eq!(decoded_request.request_type, request.request_type);
        assert_eq!(decoded_request.payload, request.payload);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_codec_decode_two_frames_in_one_buffer() {
        let mut codec = RpcCodec::new();
        let first = RpcRequest {
            request_id: 1,
            request_type: RequestType::RegisterBroker,
            payload: b"one".to_vec(),
        };
        let second = RpcRequest {
            request_id: 2,
            request_type: RequestType::BrokerHeartbeat,
            payload: b"two".to_vec(),
        };

        let mut buf = frame(&serde_json::to_vec(&first).unwrap());
        buf.put_slice(&frame(&serde_json::to_vec(&second).unwrap()));

        let a = codec.decode(&mut buf).unwrap().expect("first frame");
        let b = codec.decode(&mut buf).unwrap().expect("second frame");
        assert_eq!(a.request_id, 1);
        assert_eq!(b.request_id, 2);
        assert!(codec.decode(&mut buf).unwrap().is_none());
    }

    #[test]
    fn test_codec_decode_rejects_oversized_frame() {
        let mut codec = RpcCodec::new();
        let mut buf = BytesMut::new();
        buf.put_u32((RpcCodec::MAX_FRAME_SIZE + 1) as u32);

        let result = codec.decode(&mut buf);
        assert!(matches!(result, Err(ControllerError::InvalidRequest(_))));
    }

    #[test]
    fn test_codec_decode_invalid_json_consumes_frame() {
        let mut codec = RpcCodec::new();
        let mut buf = frame(b"not json");

        let result = codec.decode(&mut buf);
        assert!(matches!(result, Err(ControllerError::InvalidRequest(_))));
        // The malformed frame is dropped, so the buffer is ready for the next one.
        assert!(buf.is_empty());
    }

    #[test]
    fn test_codec_encode_response() {
        let mut codec = RpcCodec::new();
        let mut buf = BytesMut::new();

        let response = RpcResponse {
            request_id: 999,
            success: true,
            error: None,
            payload: b"success response".to_vec(),
        };

        let result = codec.encode(response.clone(), &mut buf);
        assert!(result.is_ok());

        // Check that length prefix is present
        assert!(buf.len() >= 4);

        // Read length
        let length = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
        assert_eq!(buf.len(), 4 + length);

        // The body must be the JSON form of the response we encoded.
        let decoded: RpcResponse = serde_json::from_slice(&buf[4..]).unwrap();
        assert_eq!(decoded.request_id, response.request_id);
        assert_eq!(decoded.success, response.success);
        assert_eq!(decoded.error, response.error);
        assert_eq!(decoded.payload, response.payload);
    }

    #[test]
    fn test_codec_encode_error_response() {
        let mut codec = RpcCodec::new();
        let mut buf = BytesMut::new();

        let response = RpcResponse {
            request_id: 1,
            success: false,
            error: Some("boom".to_string()),
            payload: Vec::new(),
        };

        codec.encode(response, &mut buf).unwrap();

        let decoded: RpcResponse = serde_json::from_slice(&buf[4..]).unwrap();
        assert_eq!(decoded.request_id, 1);
        assert!(!decoded.success);
        assert_eq!(decoded.error.as_deref(), Some("boom"));
        assert!(decoded.payload.is_empty());
    }
}