rocketmq_controller/rpc/
codec.rs1use bytes::Buf;
19use bytes::BufMut;
20use bytes::BytesMut;
21use serde::Deserialize;
22use serde::Serialize;
23use tokio_util::codec::Decoder;
24use tokio_util::codec::Encoder;
25use tracing::debug;
26use tracing::trace;
27
28use crate::error::ControllerError;
29use crate::error::Result;
30use crate::processor::RequestType;
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct RpcRequest {
35 pub request_id: u64,
37
38 pub request_type: RequestType,
40
41 pub payload: Vec<u8>,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct RpcResponse {
48 pub request_id: u64,
50
51 pub success: bool,
53
54 pub error: Option<String>,
56
57 pub payload: Vec<u8>,
59}
60
/// Length-prefixed JSON codec for the controller RPC channel.
///
/// Each frame is a 4-byte big-endian body length followed by a JSON body.
/// The decoder yields `RpcRequest`s and the encoder writes `RpcResponse`s.
/// The codec holds no state, so it is cheap to copy and clone.
#[derive(Debug, Clone, Copy)]
pub struct RpcCodec;
73
74impl RpcCodec {
75 const MAX_FRAME_SIZE: usize = 16 * 1024 * 1024;
77
78 pub fn new() -> Self {
80 Self
81 }
82}
83
84impl Default for RpcCodec {
85 fn default() -> Self {
86 Self::new()
87 }
88}
89
90impl Decoder for RpcCodec {
91 type Item = RpcRequest;
92 type Error = ControllerError;
93
94 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
95 if src.len() < 4 {
97 trace!("Not enough bytes for length prefix: {}", src.len());
98 return Ok(None);
99 }
100
101 let mut length_bytes = [0u8; 4];
103 length_bytes.copy_from_slice(&src[..4]);
104 let length = u32::from_be_bytes(length_bytes) as usize;
105
106 trace!("RPC request length: {}", length);
107
108 if length > Self::MAX_FRAME_SIZE {
110 return Err(ControllerError::InvalidRequest(format!(
111 "Frame size {} exceeds maximum {}",
112 length,
113 Self::MAX_FRAME_SIZE
114 )));
115 }
116
117 if src.len() < 4 + length {
119 trace!("Incomplete frame: have {}, need {}", src.len(), 4 + length);
120 src.reserve(4 + length - src.len());
122 return Ok(None);
123 }
124
125 src.advance(4);
127
128 let data = src.split_to(length);
130
131 let request: RpcRequest = serde_json::from_slice(&data)
133 .map_err(|e| ControllerError::InvalidRequest(e.to_string()))?;
134
135 debug!(
136 "Decoded RPC request: id={}, type={:?}",
137 request.request_id, request.request_type
138 );
139
140 Ok(Some(request))
141 }
142}
143
144impl Encoder<RpcResponse> for RpcCodec {
145 type Error = ControllerError;
146
147 fn encode(&mut self, item: RpcResponse, dst: &mut BytesMut) -> Result<()> {
148 debug!(
149 "Encoding RPC response: id={}, success={}",
150 item.request_id, item.success
151 );
152
153 let data = serde_json::to_vec(&item)
155 .map_err(|e| ControllerError::SerializationError(e.to_string()))?;
156
157 if data.len() > Self::MAX_FRAME_SIZE {
159 return Err(ControllerError::SerializationError(format!(
160 "Response size {} exceeds maximum {}",
161 data.len(),
162 Self::MAX_FRAME_SIZE
163 )));
164 }
165
166 let length = data.len() as u32;
168 dst.reserve(4 + data.len());
169 dst.put_u32(length);
170
171 dst.put_slice(&data);
173
174 trace!("Encoded RPC response: {} bytes", 4 + data.len());
175
176 Ok(())
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `body` in the codec's wire format: 4-byte big-endian length, then body.
    fn framed(body: &[u8]) -> BytesMut {
        let mut buf = BytesMut::with_capacity(4 + body.len());
        buf.put_u32(body.len() as u32);
        buf.put_slice(body);
        buf
    }

    /// Builds a heartbeat request with a fixed payload and the given id.
    fn heartbeat_request(request_id: u64) -> RpcRequest {
        RpcRequest {
            request_id,
            request_type: RequestType::BrokerHeartbeat,
            payload: b"heartbeat data".to_vec(),
        }
    }

    #[test]
    fn test_rpc_request_serialization() {
        let request = RpcRequest {
            request_id: 123,
            request_type: RequestType::RegisterBroker,
            payload: b"test payload".to_vec(),
        };

        let serialized = serde_json::to_vec(&request).unwrap();
        let deserialized: RpcRequest = serde_json::from_slice(&serialized).unwrap();

        assert_eq!(deserialized.request_id, request.request_id);
        assert_eq!(deserialized.request_type, request.request_type);
        assert_eq!(deserialized.payload, request.payload);
    }

    #[test]
    fn test_rpc_response_serialization() {
        let response = RpcResponse {
            request_id: 456,
            success: true,
            error: None,
            payload: b"response payload".to_vec(),
        };

        let serialized = serde_json::to_vec(&response).unwrap();
        let deserialized: RpcResponse = serde_json::from_slice(&serialized).unwrap();

        assert_eq!(deserialized.request_id, response.request_id);
        assert_eq!(deserialized.success, response.success);
        assert_eq!(deserialized.error, response.error);
        assert_eq!(deserialized.payload, response.payload);
    }

    #[test]
    fn test_codec_decode_incomplete() {
        let mut codec = RpcCodec::new();
        let mut buf = BytesMut::new();

        // Only 2 of the 4 length-prefix bytes are present.
        buf.put_u16(0x00);

        let result = codec.decode(&mut buf);
        assert!(matches!(result, Ok(None)));
        assert_eq!(buf.len(), 2, "a partial prefix must not be consumed");
    }

    /// Decodes a hand-built request frame (the codec only encodes responses).
    #[test]
    fn test_codec_encode_decode() {
        let mut codec = RpcCodec::new();
        let request = heartbeat_request(789);

        let mut encode_buf = framed(&serde_json::to_vec(&request).unwrap());

        let decoded = codec.decode(&mut encode_buf).unwrap();
        assert!(decoded.is_some());

        let decoded_request = decoded.unwrap();
        assert_eq!(decoded_request.request_id, request.request_id);
        assert_eq!(decoded_request.request_type, request.request_type);
        assert_eq!(decoded_request.payload, request.payload);
        assert!(encode_buf.is_empty());
    }

    #[test]
    fn test_codec_decode_partial_body_then_complete() {
        let mut codec = RpcCodec::new();
        let full = framed(&serde_json::to_vec(&heartbeat_request(42)).unwrap());
        let split_at = full.len() - 1;

        let mut buf = BytesMut::from(&full[..split_at]);
        assert!(codec.decode(&mut buf).unwrap().is_none());
        assert_eq!(buf.len(), split_at, "an incomplete frame must not be consumed");

        buf.extend_from_slice(&full[split_at..]);
        let decoded = codec.decode(&mut buf).unwrap().expect("frame is now complete");
        assert_eq!(decoded.request_id, 42);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_codec_decode_multiple_frames() {
        let mut codec = RpcCodec::new();
        let mut buf = framed(&serde_json::to_vec(&heartbeat_request(1)).unwrap());
        buf.extend_from_slice(&framed(&serde_json::to_vec(&heartbeat_request(2)).unwrap()));

        let first = codec.decode(&mut buf).unwrap().expect("first frame");
        let second = codec.decode(&mut buf).unwrap().expect("second frame");
        assert_eq!(first.request_id, 1);
        assert_eq!(second.request_id, 2);
        assert!(codec.decode(&mut buf).unwrap().is_none());
    }

    #[test]
    fn test_codec_decode_rejects_oversized_frame() {
        let mut codec = RpcCodec::new();
        let mut buf = BytesMut::new();
        buf.put_u32((RpcCodec::MAX_FRAME_SIZE + 1) as u32);

        let result = codec.decode(&mut buf);
        assert!(matches!(result, Err(ControllerError::InvalidRequest(_))));
    }

    #[test]
    fn test_codec_decode_invalid_json_consumes_frame() {
        let mut codec = RpcCodec::new();
        let mut buf = framed(b"not json");

        let result = codec.decode(&mut buf);
        assert!(matches!(result, Err(ControllerError::InvalidRequest(_))));
        assert!(buf.is_empty(), "a malformed frame is still removed from the buffer");
    }

    #[test]
    fn test_codec_encode_response() {
        let mut codec = RpcCodec::new();
        let mut buf = BytesMut::new();

        let response = RpcResponse {
            request_id: 999,
            success: true,
            error: None,
            payload: b"success response".to_vec(),
        };

        codec.encode(response.clone(), &mut buf).unwrap();

        assert!(buf.len() >= 4);
        let length = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
        assert_eq!(buf.len(), 4 + length);

        // The body after the prefix must round-trip back to the same response.
        let decoded: RpcResponse = serde_json::from_slice(&buf[4..]).unwrap();
        assert_eq!(decoded.request_id, response.request_id);
        assert_eq!(decoded.success, response.success);
        assert_eq!(decoded.error, response.error);
        assert_eq!(decoded.payload, response.payload);
    }
}