1use crate::error::RpcError;
19use serde::{Deserialize, Serialize};
20
/// Two-byte magic prefix (ASCII `"MR"`, i.e. `0x4D 0x52`) that opens every frame header.
pub const MAGIC: [u8; 2] = *b"MR";

/// Protocol version written into every encoded header; decoding rejects any other value.
pub const VERSION: u8 = 1;

/// Encoded header length in bytes:
/// magic (2) + version (1) + flags (1) + message type (1) + payload length (4).
pub const HEADER_SIZE: usize = 2 + 1 + 1 + 1 + 4;

/// Largest payload length (16 MiB) a decoded header is allowed to advertise.
pub const MAX_PAYLOAD_SIZE: u32 = 1 << 24;
32
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
35#[repr(u8)]
36pub enum MessageType {
37 Request = 0x01,
38 Response = 0x02,
39 Error = 0x03,
40 Ping = 0x04,
41 Pong = 0x05,
42 Cancel = 0x06,
43}
44
45impl MessageType {
46 pub fn from_u8(v: u8) -> Result<Self, RpcError> {
47 match v {
48 0x01 => Ok(MessageType::Request),
49 0x02 => Ok(MessageType::Response),
50 0x03 => Ok(MessageType::Error),
51 0x04 => Ok(MessageType::Ping),
52 0x05 => Ok(MessageType::Pong),
53 0x06 => Ok(MessageType::Cancel),
54 _ => Err(RpcError::invalid_argument(format!(
55 "Unknown message type: 0x{:02X}",
56 v
57 ))),
58 }
59 }
60}
61
/// Bit set stored in header byte 3.
///
/// Only the bits named by the associated constants have predicates here. Other bits
/// are carried through unchanged, because header decoding wraps the raw byte without
/// validating it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct Flags(pub u8);

impl Flags {
    /// No bits set (also the `Default` value).
    pub const NONE: Flags = Flags(0);
    /// Compression bit. Nothing in this module sets it or compresses payloads.
    pub const COMPRESSED: Flags = Flags(1 << 0);
    /// One-way bit, set by `Frame::request` when `one_way` is true.
    /// Presumably means "no response expected" — confirm with the dispatcher.
    pub const ONE_WAY: Flags = Flags(1 << 1);

    /// Returns `true` if every bit set in `other` is also set in `self`.
    ///
    /// `contains(Flags::NONE)` is therefore always `true`.
    pub const fn contains(self, other: Flags) -> bool {
        self.0 & other.0 == other.0
    }

    /// Returns `true` if the [`Flags::ONE_WAY`] bit is set.
    pub const fn is_one_way(self) -> bool {
        self.contains(Self::ONE_WAY)
    }

    /// Returns `true` if the [`Flags::COMPRESSED`] bit is set.
    pub const fn is_compressed(self) -> bool {
        self.contains(Self::COMPRESSED)
    }
}

/// Union of two flag sets, e.g. `Flags::ONE_WAY | Flags::COMPRESSED`.
impl std::ops::BitOr for Flags {
    type Output = Flags;

    fn bitor(self, rhs: Flags) -> Flags {
        Flags(self.0 | rhs.0)
    }
}
79
80#[derive(Debug, Clone)]
82pub struct FrameHeader {
83 pub version: u8,
84 pub flags: Flags,
85 pub message_type: MessageType,
86 pub payload_len: u32,
87}
88
89impl FrameHeader {
90 pub fn encode(&self) -> [u8; HEADER_SIZE] {
92 let mut buf = [0u8; HEADER_SIZE];
93 buf[0] = MAGIC[0];
94 buf[1] = MAGIC[1];
95 buf[2] = self.version;
96 buf[3] = self.flags.0;
97 buf[4] = self.message_type as u8;
98 buf[5..9].copy_from_slice(&self.payload_len.to_le_bytes());
99 buf
100 }
101
102 pub fn decode(buf: &[u8; HEADER_SIZE]) -> Result<Self, RpcError> {
104 if buf[0] != MAGIC[0] || buf[1] != MAGIC[1] {
105 return Err(RpcError::invalid_argument(format!(
106 "Invalid magic: [{:#04X}, {:#04X}]",
107 buf[0], buf[1]
108 )));
109 }
110
111 let version = buf[2];
112 if version != VERSION {
113 return Err(RpcError::invalid_argument(format!(
114 "Unsupported version: {}",
115 version
116 )));
117 }
118
119 let flags = Flags(buf[3]);
120 let message_type = MessageType::from_u8(buf[4])?;
121 let payload_len = u32::from_le_bytes([buf[5], buf[6], buf[7], buf[8]]);
122
123 if payload_len > MAX_PAYLOAD_SIZE {
124 return Err(RpcError::invalid_argument(format!(
125 "Payload too large: {} > {}",
126 payload_len, MAX_PAYLOAD_SIZE
127 )));
128 }
129
130 Ok(Self {
131 version,
132 flags,
133 message_type,
134 payload_len,
135 })
136 }
137}
138
139#[derive(Debug, Clone)]
141pub struct Frame {
142 pub header: FrameHeader,
143 pub payload: Vec<u8>,
144}
145
146impl Frame {
147 pub fn request(
149 request_id: u64,
150 service_id: u16,
151 method_id: u16,
152 args: Vec<u8>,
153 one_way: bool,
154 ) -> Self {
155 let mut payload = Vec::with_capacity(12 + args.len());
156 payload.extend_from_slice(&request_id.to_le_bytes());
157 payload.extend_from_slice(&service_id.to_le_bytes());
158 payload.extend_from_slice(&method_id.to_le_bytes());
159 payload.extend_from_slice(&args);
160
161 let flags = if one_way { Flags::ONE_WAY } else { Flags::NONE };
162
163 Frame {
164 header: FrameHeader {
165 version: VERSION,
166 flags,
167 message_type: MessageType::Request,
168 payload_len: payload.len() as u32,
169 },
170 payload,
171 }
172 }
173
174 pub fn response(request_id: u64, data: Vec<u8>) -> Self {
176 let mut payload = Vec::with_capacity(8 + data.len());
177 payload.extend_from_slice(&request_id.to_le_bytes());
178 payload.extend_from_slice(&data);
179
180 Frame {
181 header: FrameHeader {
182 version: VERSION,
183 flags: Flags::NONE,
184 message_type: MessageType::Response,
185 payload_len: payload.len() as u32,
186 },
187 payload,
188 }
189 }
190
191 pub fn error(request_id: u64, error_data: Vec<u8>) -> Self {
193 let mut payload = Vec::with_capacity(8 + error_data.len());
194 payload.extend_from_slice(&request_id.to_le_bytes());
195 payload.extend_from_slice(&error_data);
196
197 Frame {
198 header: FrameHeader {
199 version: VERSION,
200 flags: Flags::NONE,
201 message_type: MessageType::Error,
202 payload_len: payload.len() as u32,
203 },
204 payload,
205 }
206 }
207
208 pub fn ping() -> Self {
210 Frame {
211 header: FrameHeader {
212 version: VERSION,
213 flags: Flags::NONE,
214 message_type: MessageType::Ping,
215 payload_len: 0,
216 },
217 payload: Vec::new(),
218 }
219 }
220
221 pub fn pong() -> Self {
223 Frame {
224 header: FrameHeader {
225 version: VERSION,
226 flags: Flags::NONE,
227 message_type: MessageType::Pong,
228 payload_len: 0,
229 },
230 payload: Vec::new(),
231 }
232 }
233
234 pub fn encode(&self) -> Vec<u8> {
236 let header_bytes = self.header.encode();
237 let mut buf = Vec::with_capacity(HEADER_SIZE + self.payload.len());
238 buf.extend_from_slice(&header_bytes);
239 buf.extend_from_slice(&self.payload);
240 buf
241 }
242
243 pub fn parse_request_payload(&self) -> Result<(u64, u16, u16, &[u8]), RpcError> {
245 if self.payload.len() < 12 {
246 return Err(RpcError::invalid_argument("Request payload too short"));
247 }
248 let request_id = u64::from_le_bytes(self.payload[0..8].try_into().unwrap());
249 let service_id = u16::from_le_bytes(self.payload[8..10].try_into().unwrap());
250 let method_id = u16::from_le_bytes(self.payload[10..12].try_into().unwrap());
251 let args = &self.payload[12..];
252 Ok((request_id, service_id, method_id, args))
253 }
254
255 pub fn parse_response_payload(&self) -> Result<(u64, &[u8]), RpcError> {
257 if self.payload.len() < 8 {
258 return Err(RpcError::invalid_argument("Response payload too short"));
259 }
260 let request_id = u64::from_le_bytes(self.payload[0..8].try_into().unwrap());
261 let data = &self.payload[8..];
262 Ok((request_id, data))
263 }
264}
265
266pub fn parse_frames(buf: &[u8]) -> Result<(Vec<Frame>, usize), RpcError> {
271 let mut frames = Vec::new();
272 let mut offset = 0;
273
274 while offset + HEADER_SIZE <= buf.len() {
275 let header_bytes: &[u8; HEADER_SIZE] = buf[offset..offset + HEADER_SIZE]
276 .try_into()
277 .map_err(|_| RpcError::internal("Header slice conversion failed"))?;
278
279 let header = FrameHeader::decode(header_bytes)?;
280 let total_frame_size = HEADER_SIZE + header.payload_len as usize;
281
282 if offset + total_frame_size > buf.len() {
283 break;
285 }
286
287 let payload = buf[offset + HEADER_SIZE..offset + total_frame_size].to_vec();
288 frames.push(Frame { header, payload });
289 offset += total_frame_size;
290 }
291
292 Ok((frames, offset))
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// Convenience constructor for a header with the current version and no flags.
    fn header(message_type: MessageType, payload_len: u32) -> FrameHeader {
        FrameHeader {
            version: VERSION,
            flags: Flags::NONE,
            message_type,
            payload_len,
        }
    }

    #[test]
    fn test_header_roundtrip() {
        let header = header(MessageType::Request, 42);
        let encoded = header.encode();
        let decoded = FrameHeader::decode(&encoded).unwrap();
        assert_eq!(decoded.version, VERSION);
        assert_eq!(decoded.flags, Flags::NONE);
        assert_eq!(decoded.message_type, MessageType::Request);
        assert_eq!(decoded.payload_len, 42);
    }

    #[test]
    fn test_request_frame_roundtrip() {
        let frame = Frame::request(123, 1, 2, vec![10, 20, 30], false);
        let bytes = frame.encode();
        let (frames, consumed) = parse_frames(&bytes).unwrap();
        assert_eq!(consumed, bytes.len());
        assert_eq!(frames.len(), 1);

        let (req_id, svc_id, method_id, args) = frames[0].parse_request_payload().unwrap();
        assert_eq!(req_id, 123);
        assert_eq!(svc_id, 1);
        assert_eq!(method_id, 2);
        assert_eq!(args, &[10, 20, 30]);
    }

    #[test]
    fn test_response_frame_roundtrip() {
        let frame = Frame::response(456, vec![1, 2, 3]);
        let bytes = frame.encode();
        let (frames, _) = parse_frames(&bytes).unwrap();
        let (req_id, data) = frames[0].parse_response_payload().unwrap();
        assert_eq!(req_id, 456);
        assert_eq!(data, &[1, 2, 3]);
    }

    #[test]
    fn test_error_frame_roundtrip() {
        let frame = Frame::error(789, vec![4, 5]);
        let bytes = frame.encode();
        let (frames, consumed) = parse_frames(&bytes).unwrap();
        assert_eq!(consumed, bytes.len());
        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].header.message_type, MessageType::Error);

        // Error frames share the response payload layout.
        let (req_id, data) = frames[0].parse_response_payload().unwrap();
        assert_eq!(req_id, 789);
        assert_eq!(data, &[4, 5]);
    }

    #[test]
    fn test_multiple_frames() {
        let f1 = Frame::request(1, 0, 0, vec![0xAA], false);
        let f2 = Frame::response(1, vec![0xBB]);
        let mut bytes = f1.encode();
        bytes.extend_from_slice(&f2.encode());

        let (frames, consumed) = parse_frames(&bytes).unwrap();
        assert_eq!(consumed, bytes.len());
        assert_eq!(frames.len(), 2);
        assert_eq!(frames[0].header.message_type, MessageType::Request);
        assert_eq!(frames[1].header.message_type, MessageType::Response);
    }

    #[test]
    fn test_partial_frame() {
        let frame = Frame::request(1, 0, 0, vec![0xAA; 100], false);
        let bytes = frame.encode();
        let partial = &bytes[..bytes.len() / 2];
        let (frames, consumed) = parse_frames(partial).unwrap();
        assert_eq!(frames.len(), 0);
        assert_eq!(consumed, 0);
    }

    #[test]
    fn test_partial_header_and_empty_input() {
        let bytes = Frame::ping().encode();
        let (frames, consumed) = parse_frames(&bytes[..HEADER_SIZE - 1]).unwrap();
        assert!(frames.is_empty());
        assert_eq!(consumed, 0);

        let (frames, consumed) = parse_frames(&[]).unwrap();
        assert!(frames.is_empty());
        assert_eq!(consumed, 0);
    }

    #[test]
    fn test_complete_frame_followed_by_partial() {
        let first = Frame::response(9, vec![1]).encode();
        let second = Frame::request(10, 0, 0, vec![0xCC; 16], false).encode();
        let mut bytes = first.clone();
        bytes.extend_from_slice(&second[..second.len() - 1]);

        // Only the first frame is complete; the offset must stop right after it.
        let (frames, consumed) = parse_frames(&bytes).unwrap();
        assert_eq!(frames.len(), 1);
        assert_eq!(consumed, first.len());
    }

    #[test]
    fn test_ping_pong() {
        let ping = Frame::ping();
        let pong = Frame::pong();
        assert_eq!(ping.header.message_type, MessageType::Ping);
        assert_eq!(pong.header.message_type, MessageType::Pong);
        assert_eq!(ping.payload.len(), 0);
        assert_eq!(pong.payload.len(), 0);

        let mut bytes = ping.encode();
        bytes.extend_from_slice(&pong.encode());
        let (frames, consumed) = parse_frames(&bytes).unwrap();
        assert_eq!(consumed, 2 * HEADER_SIZE);
        assert_eq!(frames.len(), 2);
    }

    #[test]
    fn test_invalid_magic() {
        let mut buf = [0u8; HEADER_SIZE];
        buf[0] = 0xFF;
        buf[1] = 0xFF;
        assert!(FrameHeader::decode(&buf).is_err());
    }

    #[test]
    fn test_parse_frames_rejects_bad_magic() {
        let mut bytes = Frame::ping().encode();
        bytes[0] = 0x00;
        assert!(parse_frames(&bytes).is_err());
    }

    #[test]
    fn test_unsupported_version() {
        let mut buf = header(MessageType::Ping, 0).encode();
        buf[2] = VERSION.wrapping_add(1);
        assert!(FrameHeader::decode(&buf).is_err());
    }

    #[test]
    fn test_unknown_message_type_in_header() {
        let mut buf = header(MessageType::Ping, 0).encode();
        buf[4] = 0x7F;
        assert!(FrameHeader::decode(&buf).is_err());
    }

    #[test]
    fn test_message_type_byte_roundtrip() {
        for &t in &[
            MessageType::Request,
            MessageType::Response,
            MessageType::Error,
            MessageType::Ping,
            MessageType::Pong,
            MessageType::Cancel,
        ] {
            assert_eq!(MessageType::from_u8(t as u8).unwrap(), t);
        }
        assert!(MessageType::from_u8(0x00).is_err());
        assert!(MessageType::from_u8(0x07).is_err());
    }

    #[test]
    fn test_payload_size_limit() {
        // Exactly at the limit is accepted; one byte over is rejected.
        let at_limit = header(MessageType::Request, MAX_PAYLOAD_SIZE).encode();
        assert!(FrameHeader::decode(&at_limit).is_ok());

        let over_limit = header(MessageType::Request, MAX_PAYLOAD_SIZE + 1).encode();
        assert!(FrameHeader::decode(&over_limit).is_err());
    }

    #[test]
    fn test_short_payloads_rejected() {
        let ping = Frame::ping();
        assert!(ping.parse_request_payload().is_err());
        assert!(ping.parse_response_payload().is_err());

        // An 8-byte response payload is enough for a response, too short for a request.
        let resp = Frame::response(1, vec![]);
        let (req_id, data) = resp.parse_response_payload().unwrap();
        assert_eq!(req_id, 1);
        assert!(data.is_empty());
        assert!(resp.parse_request_payload().is_err());
    }

    #[test]
    fn test_empty_args_request() {
        let frame = Frame::request(5, 7, 9, vec![], true);
        assert_eq!(frame.payload.len(), 12);
        assert_eq!(frame.header.payload_len, 12);

        let (req_id, svc_id, method_id, args) = frame.parse_request_payload().unwrap();
        assert_eq!((req_id, svc_id, method_id), (5, 7, 9));
        assert!(args.is_empty());
    }

    #[test]
    fn test_one_way_flag() {
        let frame = Frame::request(1, 0, 0, vec![], true);
        assert!(frame.header.flags.is_one_way());

        let frame = Frame::request(1, 0, 0, vec![], false);
        assert!(!frame.header.flags.is_one_way());
    }

    #[test]
    fn test_flag_bits() {
        assert!(Flags::COMPRESSED.is_compressed());
        assert!(!Flags::COMPRESSED.is_one_way());
        assert!(!Flags::NONE.is_compressed());

        let both = Flags(Flags::ONE_WAY.0 | Flags::COMPRESSED.0);
        assert!(both.is_one_way());
        assert!(both.is_compressed());
    }
}