Skip to main content

mill_rpc_core/
protocol.rs

1//! Wire protocol: frame format for Mill-RPC.
2//!
3//! ```text
4//! +--------+--------+-------+--------+-----------+---------+
5//! | Magic  | Version| Flags | MsgType| PayloadLen| Payload |
6//! | 2B     | 1B     | 1B    | 1B     | 4B (LE)   | N bytes |
7//! +--------+--------+-------+--------+-----------+---------+
8//! ```
9//!
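//! Request payload:
//! ```text
//! +------------+-----------+-----------+---------+
//! | RequestID  | ServiceID | MethodID  | Args    |
//! | 8B (LE)    | 2B (LE)   | 2B (LE)   | N bytes |
//! +------------+-----------+-----------+---------+
//! ```
//!
//! Response and Error payload:
//! ```text
//! +------------+---------+
//! | RequestID  | Data    |
//! | 8B (LE)    | N bytes |
//! +------------+---------+
//! ```
//!
//! Ping and Pong frames carry an empty payload. Flag bits: bit 0 = COMPRESSED,
//! bit 1 = ONE_WAY.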
17
18use crate::error::RpcError;
19use serde::{Deserialize, Serialize};
20
/// Two-byte preamble, the ASCII characters `"MR"`, that opens every Mill-RPC frame.
pub const MAGIC: [u8; 2] = *b"MR";

/// Protocol version written into, and required by, every frame header.
pub const VERSION: u8 = 1;

/// Fixed length of a frame header in bytes:
/// magic (2) + version (1) + flags (1) + message type (1) + payload length (4).
pub const HEADER_SIZE: usize = 2 + 1 + 1 + 1 + 4;

/// Largest payload, in bytes, that a header may advertise (16 MiB).
pub const MAX_PAYLOAD_SIZE: u32 = 16 << 20;
32
33/// Message types in the wire protocol.
34#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
35#[repr(u8)]
36pub enum MessageType {
37    Request = 0x01,
38    Response = 0x02,
39    Error = 0x03,
40    Ping = 0x04,
41    Pong = 0x05,
42    Cancel = 0x06,
43}
44
45impl MessageType {
46    pub fn from_u8(v: u8) -> Result<Self, RpcError> {
47        match v {
48            0x01 => Ok(MessageType::Request),
49            0x02 => Ok(MessageType::Response),
50            0x03 => Ok(MessageType::Error),
51            0x04 => Ok(MessageType::Ping),
52            0x05 => Ok(MessageType::Pong),
53            0x06 => Ok(MessageType::Cancel),
54            _ => Err(RpcError::invalid_argument(format!(
55                "Unknown message type: 0x{:02X}",
56                v
57            ))),
58        }
59    }
60}
61
/// Bit flags carried in the header's flags byte.
///
/// Flags combine with `|` (e.g. `Flags::ONE_WAY | Flags::COMPRESSED`). Bits
/// without a named constant are kept as-is in the wrapped byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Flags(pub u8);

impl Flags {
    /// No bits set.
    pub const NONE: Flags = Flags(0);
    /// Bit 0: marks the payload as compressed.
    pub const COMPRESSED: Flags = Flags(1 << 0);
    /// Bit 1: marks a request as one-way.
    pub const ONE_WAY: Flags = Flags(1 << 1);

    /// Returns `true` when every bit set in `other` is also set in `self`.
    ///
    /// `contains(Flags::NONE)` is therefore always `true`.
    pub const fn contains(self, other: Flags) -> bool {
        (self.0 & other.0) == other.0
    }

    /// Returns `true` if the [`Flags::ONE_WAY`] bit is set.
    pub fn is_one_way(self) -> bool {
        self.contains(Self::ONE_WAY)
    }

    /// Returns `true` if the [`Flags::COMPRESSED`] bit is set.
    pub fn is_compressed(self) -> bool {
        self.contains(Self::COMPRESSED)
    }
}

/// Union of two flag sets.
impl std::ops::BitOr for Flags {
    type Output = Flags;

    fn bitor(self, rhs: Flags) -> Flags {
        Flags(self.0 | rhs.0)
    }
}
79
80/// Parsed frame header.
81#[derive(Debug, Clone)]
82pub struct FrameHeader {
83    pub version: u8,
84    pub flags: Flags,
85    pub message_type: MessageType,
86    pub payload_len: u32,
87}
88
89impl FrameHeader {
90    /// Encode the header into a 9-byte array.
91    pub fn encode(&self) -> [u8; HEADER_SIZE] {
92        let mut buf = [0u8; HEADER_SIZE];
93        buf[0] = MAGIC[0];
94        buf[1] = MAGIC[1];
95        buf[2] = self.version;
96        buf[3] = self.flags.0;
97        buf[4] = self.message_type as u8;
98        buf[5..9].copy_from_slice(&self.payload_len.to_le_bytes());
99        buf
100    }
101
102    /// Decode a header from a 9-byte slice.
103    pub fn decode(buf: &[u8; HEADER_SIZE]) -> Result<Self, RpcError> {
104        if buf[0] != MAGIC[0] || buf[1] != MAGIC[1] {
105            return Err(RpcError::invalid_argument(format!(
106                "Invalid magic: [{:#04X}, {:#04X}]",
107                buf[0], buf[1]
108            )));
109        }
110
111        let version = buf[2];
112        if version != VERSION {
113            return Err(RpcError::invalid_argument(format!(
114                "Unsupported version: {}",
115                version
116            )));
117        }
118
119        let flags = Flags(buf[3]);
120        let message_type = MessageType::from_u8(buf[4])?;
121        let payload_len = u32::from_le_bytes([buf[5], buf[6], buf[7], buf[8]]);
122
123        if payload_len > MAX_PAYLOAD_SIZE {
124            return Err(RpcError::invalid_argument(format!(
125                "Payload too large: {} > {}",
126                payload_len, MAX_PAYLOAD_SIZE
127            )));
128        }
129
130        Ok(Self {
131            version,
132            flags,
133            message_type,
134            payload_len,
135        })
136    }
137}
138
139/// A complete frame (header + payload).
140#[derive(Debug, Clone)]
141pub struct Frame {
142    pub header: FrameHeader,
143    pub payload: Vec<u8>,
144}
145
146impl Frame {
147    /// Create a new request frame.
148    pub fn request(
149        request_id: u64,
150        service_id: u16,
151        method_id: u16,
152        args: Vec<u8>,
153        one_way: bool,
154    ) -> Self {
155        let mut payload = Vec::with_capacity(12 + args.len());
156        payload.extend_from_slice(&request_id.to_le_bytes());
157        payload.extend_from_slice(&service_id.to_le_bytes());
158        payload.extend_from_slice(&method_id.to_le_bytes());
159        payload.extend_from_slice(&args);
160
161        let flags = if one_way { Flags::ONE_WAY } else { Flags::NONE };
162
163        Frame {
164            header: FrameHeader {
165                version: VERSION,
166                flags,
167                message_type: MessageType::Request,
168                payload_len: payload.len() as u32,
169            },
170            payload,
171        }
172    }
173
174    /// Create a response frame.
175    pub fn response(request_id: u64, data: Vec<u8>) -> Self {
176        let mut payload = Vec::with_capacity(8 + data.len());
177        payload.extend_from_slice(&request_id.to_le_bytes());
178        payload.extend_from_slice(&data);
179
180        Frame {
181            header: FrameHeader {
182                version: VERSION,
183                flags: Flags::NONE,
184                message_type: MessageType::Response,
185                payload_len: payload.len() as u32,
186            },
187            payload,
188        }
189    }
190
191    /// Create an error frame.
192    pub fn error(request_id: u64, error_data: Vec<u8>) -> Self {
193        let mut payload = Vec::with_capacity(8 + error_data.len());
194        payload.extend_from_slice(&request_id.to_le_bytes());
195        payload.extend_from_slice(&error_data);
196
197        Frame {
198            header: FrameHeader {
199                version: VERSION,
200                flags: Flags::NONE,
201                message_type: MessageType::Error,
202                payload_len: payload.len() as u32,
203            },
204            payload,
205        }
206    }
207
208    /// Create a ping frame.
209    pub fn ping() -> Self {
210        Frame {
211            header: FrameHeader {
212                version: VERSION,
213                flags: Flags::NONE,
214                message_type: MessageType::Ping,
215                payload_len: 0,
216            },
217            payload: Vec::new(),
218        }
219    }
220
221    /// Create a pong frame.
222    pub fn pong() -> Self {
223        Frame {
224            header: FrameHeader {
225                version: VERSION,
226                flags: Flags::NONE,
227                message_type: MessageType::Pong,
228                payload_len: 0,
229            },
230            payload: Vec::new(),
231        }
232    }
233
234    /// Encode the full frame to bytes.
235    pub fn encode(&self) -> Vec<u8> {
236        let header_bytes = self.header.encode();
237        let mut buf = Vec::with_capacity(HEADER_SIZE + self.payload.len());
238        buf.extend_from_slice(&header_bytes);
239        buf.extend_from_slice(&self.payload);
240        buf
241    }
242
243    /// Parse request payload fields (request_id, service_id, method_id, args).
244    pub fn parse_request_payload(&self) -> Result<(u64, u16, u16, &[u8]), RpcError> {
245        if self.payload.len() < 12 {
246            return Err(RpcError::invalid_argument("Request payload too short"));
247        }
248        let request_id = u64::from_le_bytes(self.payload[0..8].try_into().unwrap());
249        let service_id = u16::from_le_bytes(self.payload[8..10].try_into().unwrap());
250        let method_id = u16::from_le_bytes(self.payload[10..12].try_into().unwrap());
251        let args = &self.payload[12..];
252        Ok((request_id, service_id, method_id, args))
253    }
254
255    /// Parse response payload fields (request_id, data).
256    pub fn parse_response_payload(&self) -> Result<(u64, &[u8]), RpcError> {
257        if self.payload.len() < 8 {
258            return Err(RpcError::invalid_argument("Response payload too short"));
259        }
260        let request_id = u64::from_le_bytes(self.payload[0..8].try_into().unwrap());
261        let data = &self.payload[8..];
262        Ok((request_id, data))
263    }
264}
265
266/// Reads frames from a byte buffer. Returns parsed frames and the number of bytes consumed.
267///
268/// This is a streaming parser: it handles partial frames by returning only
269/// complete frames and leaving remaining bytes unconsumed.
270pub fn parse_frames(buf: &[u8]) -> Result<(Vec<Frame>, usize), RpcError> {
271    let mut frames = Vec::new();
272    let mut offset = 0;
273
274    while offset + HEADER_SIZE <= buf.len() {
275        let header_bytes: &[u8; HEADER_SIZE] = buf[offset..offset + HEADER_SIZE]
276            .try_into()
277            .map_err(|_| RpcError::internal("Header slice conversion failed"))?;
278
279        let header = FrameHeader::decode(header_bytes)?;
280        let total_frame_size = HEADER_SIZE + header.payload_len as usize;
281
282        if offset + total_frame_size > buf.len() {
283            // Incomplete frame, wait for more data.
284            break;
285        }
286
287        let payload = buf[offset + HEADER_SIZE..offset + total_frame_size].to_vec();
288        frames.push(Frame { header, payload });
289        offset += total_frame_size;
290    }
291
292    Ok((frames, offset))
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// Header with the current version, no flags, and the given type/length.
    fn make_header(message_type: MessageType, payload_len: u32) -> FrameHeader {
        FrameHeader {
            version: VERSION,
            flags: Flags::NONE,
            message_type,
            payload_len,
        }
    }

    #[test]
    fn test_header_roundtrip() {
        let header = FrameHeader {
            version: VERSION,
            flags: Flags::NONE,
            message_type: MessageType::Request,
            payload_len: 42,
        };
        let encoded = header.encode();
        let decoded = FrameHeader::decode(&encoded).unwrap();
        assert_eq!(decoded.version, VERSION);
        assert_eq!(decoded.flags, Flags::NONE);
        assert_eq!(decoded.message_type, MessageType::Request);
        assert_eq!(decoded.payload_len, 42);
    }

    #[test]
    fn test_request_frame_roundtrip() {
        let frame = Frame::request(123, 1, 2, vec![10, 20, 30], false);
        let bytes = frame.encode();
        let (frames, consumed) = parse_frames(&bytes).unwrap();
        assert_eq!(consumed, bytes.len());
        assert_eq!(frames.len(), 1);

        let (req_id, svc_id, method_id, args) = frames[0].parse_request_payload().unwrap();
        assert_eq!(req_id, 123);
        assert_eq!(svc_id, 1);
        assert_eq!(method_id, 2);
        assert_eq!(args, &[10, 20, 30]);
    }

    #[test]
    fn test_response_frame_roundtrip() {
        let frame = Frame::response(456, vec![1, 2, 3]);
        let bytes = frame.encode();
        let (frames, _) = parse_frames(&bytes).unwrap();
        let (req_id, data) = frames[0].parse_response_payload().unwrap();
        assert_eq!(req_id, 456);
        assert_eq!(data, &[1, 2, 3]);
    }

    #[test]
    fn test_error_frame_roundtrip() {
        let bytes = Frame::error(789, vec![9, 8, 7]).encode();
        let (frames, consumed) = parse_frames(&bytes).unwrap();
        assert_eq!(consumed, bytes.len());
        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].header.message_type, MessageType::Error);
        let (req_id, data) = frames[0].parse_response_payload().unwrap();
        assert_eq!(req_id, 789);
        assert_eq!(data, &[9, 8, 7]);
    }

    #[test]
    fn test_multiple_frames() {
        let f1 = Frame::request(1, 0, 0, vec![0xAA], false);
        let f2 = Frame::response(1, vec![0xBB]);
        let mut bytes = f1.encode();
        bytes.extend_from_slice(&f2.encode());

        let (frames, consumed) = parse_frames(&bytes).unwrap();
        assert_eq!(consumed, bytes.len());
        assert_eq!(frames.len(), 2);
        assert_eq!(frames[0].header.message_type, MessageType::Request);
        assert_eq!(frames[1].header.message_type, MessageType::Response);
    }

    #[test]
    fn test_partial_frame() {
        let frame = Frame::request(1, 0, 0, vec![0xAA; 100], false);
        let bytes = frame.encode();
        // Only give half the bytes
        let partial = &bytes[..bytes.len() / 2];
        let (frames, consumed) = parse_frames(partial).unwrap();
        assert_eq!(frames.len(), 0);
        assert_eq!(consumed, 0);
    }

    #[test]
    fn test_complete_frame_followed_by_partial() {
        let first = Frame::request(1, 0, 0, vec![0xAA], false).encode();
        let second = Frame::response(1, vec![0xBB; 32]).encode();
        let mut bytes = first.clone();
        // Full header of the second frame but only part of its payload.
        bytes.extend_from_slice(&second[..HEADER_SIZE + 4]);

        let (frames, consumed) = parse_frames(&bytes).unwrap();
        assert_eq!(frames.len(), 1);
        assert_eq!(consumed, first.len());
    }

    #[test]
    fn test_empty_and_header_short_buffers() {
        let (frames, consumed) = parse_frames(&[]).unwrap();
        assert!(frames.is_empty());
        assert_eq!(consumed, 0);

        let bytes = Frame::ping().encode();
        let (frames, consumed) = parse_frames(&bytes[..HEADER_SIZE - 1]).unwrap();
        assert!(frames.is_empty());
        assert_eq!(consumed, 0);
    }

    #[test]
    fn test_parse_frames_rejects_bad_header() {
        assert!(parse_frames(&[0xFF; HEADER_SIZE]).is_err());
    }

    #[test]
    fn test_ping_pong() {
        let ping = Frame::ping();
        let pong = Frame::pong();
        assert_eq!(ping.header.message_type, MessageType::Ping);
        assert_eq!(pong.header.message_type, MessageType::Pong);
        assert_eq!(ping.payload.len(), 0);
        assert_eq!(pong.payload.len(), 0);
    }

    #[test]
    fn test_ping_through_parser() {
        let bytes = Frame::ping().encode();
        assert_eq!(bytes.len(), HEADER_SIZE);
        let (frames, consumed) = parse_frames(&bytes).unwrap();
        assert_eq!(consumed, HEADER_SIZE);
        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].header.message_type, MessageType::Ping);
        assert!(frames[0].payload.is_empty());
    }

    #[test]
    fn test_invalid_magic() {
        let mut buf = [0u8; HEADER_SIZE];
        buf[0] = 0xFF;
        buf[1] = 0xFF;
        assert!(FrameHeader::decode(&buf).is_err());
    }

    #[test]
    fn test_unsupported_version_rejected() {
        let mut buf = make_header(MessageType::Ping, 0).encode();
        buf[2] = VERSION.wrapping_add(1);
        assert!(FrameHeader::decode(&buf).is_err());
    }

    #[test]
    fn test_unknown_message_type_rejected() {
        let mut buf = make_header(MessageType::Ping, 0).encode();
        buf[4] = 0x7F;
        assert!(FrameHeader::decode(&buf).is_err());
    }

    #[test]
    fn test_payload_size_limit() {
        let at_limit = make_header(MessageType::Request, MAX_PAYLOAD_SIZE).encode();
        assert!(FrameHeader::decode(&at_limit).is_ok());

        let over_limit = make_header(MessageType::Request, MAX_PAYLOAD_SIZE + 1).encode();
        assert!(FrameHeader::decode(&over_limit).is_err());
    }

    #[test]
    fn test_message_type_byte_roundtrip() {
        let all = [
            MessageType::Request,
            MessageType::Response,
            MessageType::Error,
            MessageType::Ping,
            MessageType::Pong,
            MessageType::Cancel,
        ];
        for &ty in &all {
            assert_eq!(MessageType::from_u8(ty as u8).unwrap(), ty);
        }
        assert!(MessageType::from_u8(0x00).is_err());
        assert!(MessageType::from_u8(0x07).is_err());
    }

    #[test]
    fn test_short_payloads_rejected() {
        let short_request = Frame {
            header: make_header(MessageType::Request, 11),
            payload: vec![0; 11],
        };
        assert!(short_request.parse_request_payload().is_err());

        let short_response = Frame {
            header: make_header(MessageType::Response, 7),
            payload: vec![0; 7],
        };
        assert!(short_response.parse_response_payload().is_err());
    }

    #[test]
    fn test_one_way_flag() {
        let frame = Frame::request(1, 0, 0, vec![], true);
        assert!(frame.header.flags.is_one_way());

        let frame = Frame::request(1, 0, 0, vec![], false);
        assert!(!frame.header.flags.is_one_way());
    }

    #[test]
    fn test_compressed_flag() {
        assert!(Flags::COMPRESSED.is_compressed());
        assert!(!Flags::COMPRESSED.is_one_way());
        assert!(!Flags::NONE.is_compressed());
        assert!(!Flags::ONE_WAY.is_compressed());
    }
}