Skip to main content

razor_stream/
proto.rs

1//! # The protocol
2//!
//! Each packet starts with a fixed-length header followed by a variable-length body.
//! An [RpcAction] identifies the kind of request; it is either numeric or a string.
//!
//! The request body contains a mandatory structured message and optional blob data.
//!
//! A response either succeeds, carrying an optional structured message and optional blob data
//! (so the response body may be empty), or fails with an RpcError. The error can be numeric
//! (like a Unix errno), text (for user-defined errors), or a statically predefined error string
//! (for errors that occur during socket communication or encoding/decoding).
13//!
14//! ## Request
15//!
//! Fixed length of `ReqHead` = 32B (all multi-byte fields are little-endian)
//!
//! | Field       | Size | Description                                                         |
//! |-------------|------|---------------------------------------------------------------------|
//! | `magic`     | 2B   | Magic number                                                        |
//! | `ver`       | 1B   | Protocol version                                                    |
//! | `format`    | 1B   | Encoder-decoder format                                              |
//! | `action`    | 4B   | Numeric action, or the action string length with the highest bit set |
//! | `seq`       | 8B   | Increasing ID of the request message                                |
//! | `client_id` | 8B   | Client identifier                                                   |
//! | `msg_len`   | 4B   | Structured message length                                           |
//! | `blob_len`  | 4B   | Unstructured message (blob) length                                  |
//!
//! The header is followed, in order, by:
//! - the action string (`action_len` bytes, only if `action` is a string)
//! - the structured message (`msg_len` bytes)
//! - the blob (`blob_len` bytes)
33//!
//! ## Response
//!
//! Fixed length of `RespHead` = 20B (all multi-byte fields are little-endian)
//!
//! | Field      | Size | Description                                                  |
//! |------------|------|--------------------------------------------------------------|
//! | `magic`    | 2B   | Magic number                                                 |
//! | `ver`      | 1B   | Protocol version                                             |
//! | `flag`     | 1B   | Error flag (0 on success)                                    |
//! | `msg_len`  | 4B   | Structured message length, or errno                          |
//! | `seq`      | 8B   | ID of the request this response answers (same `seq`)         |
//! | `blob_len` | 4B   | Unstructured message (blob) length, or error string length   |
//!
//! The header is followed, in order, by:
//! - the structured message (`msg_len` bytes, unless `msg_len` carries an errno)
//! - the blob or error string (`blob_len` bytes)
50
51use crate::client::task::ClientTask;
52use crate::server::task::ServerTaskEncode;
53use crate::{Codec, error::*};
54use std::fmt;
55use std::io::Write;
56use std::mem::size_of;
57use std::ptr::addr_of;
58use zerocopy::byteorder::little_endian;
59use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned};
60
61pub const PING_ACTION: u32 = 0;
62
63pub const RPC_MAGIC: little_endian::U16 = little_endian::U16::new(19749);
64pub const U32_HIGH_MASK: u32 = 1 << 31;
65
66pub const RESP_FLAG_HAS_ERRNO: u8 = 1;
67pub const RESP_FLAG_HAS_ERR_STRING: u8 = 2;
68pub const RPC_VERSION_1: u8 = 1;
69
/// Borrowed request action: either a string name or a numeric code.
///
/// A numeric action is stored directly in the request header's `action` field; a string
/// action stores its byte length there (with the high bit set) and its bytes follow the header.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RpcAction<'a> {
    Str(&'a str),
    Num(i32),
}

/// Owned counterpart of [`RpcAction`]; being `Eq + Hash` it can be used as a map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RpcActionOwned {
    Str(String),
    Num(i32),
}

impl From<RpcAction<'_>> for RpcActionOwned {
    /// Copies a borrowed action; string names are cloned into a new `String`.
    fn from(action: RpcAction<'_>) -> Self {
        match action {
            RpcAction::Num(code) => Self::Num(code),
            RpcAction::Str(name) => Self::Str(String::from(name)),
        }
    }
}

impl RpcActionOwned {
    /// Borrows this action as an [`RpcAction`] without allocating.
    pub fn to_action(&self) -> RpcAction<'_> {
        match *self {
            Self::Num(code) => RpcAction::Num(code),
            Self::Str(ref name) => RpcAction::Str(name.as_str()),
        }
    }
}
99
100/// Fixed-length header for request
101#[derive(FromBytes, IntoBytes, Unaligned, Immutable, KnownLayout, PartialEq, Clone)]
102#[repr(C, packed)]
103pub struct ReqHead {
104    pub magic: little_endian::U16,
105    pub ver: u8,
106    pub format: u8,
107    /// encoder-decoder format
108
109    /// If highest bit is 0, the rest will be i32 action_num.
110    ///
111    /// If highest is 1, the lower bit will be i32 action_len.
112    pub action: little_endian::U32,
113
114    /// Increased ID of request msg in the socket connection.
115    pub seq: little_endian::U64,
116
117    pub client_id: little_endian::U64,
118    /// structured msg len
119    pub msg_len: little_endian::U32,
120    /// unstructured msg
121    pub blob_len: little_endian::U32,
122}
123
124pub const RPC_REQ_HEADER_LEN: usize = size_of::<ReqHead>();
125
126impl ReqHead {
127    #[inline(always)]
128    pub fn encode_ping(buf: &mut Vec<u8>, client_id: u64, seq: u64) {
129        debug_assert!(buf.capacity() > RPC_REQ_HEADER_LEN);
130        unsafe { buf.set_len(RPC_REQ_HEADER_LEN) };
131        Self::_write_head(buf, client_id, PING_ACTION, seq, 0, 0);
132    }
133
134    #[inline(always)]
135    fn _write_head(
136        buf: &mut [u8], client_id: u64, action: u32, seq: u64, msg_len: u32, blob_len: i32,
137    ) {
138        // NOTE: We are directly init ReqHead on the buffer with unsafe, check carefully don't miss
139        // a field
140        let header: &mut Self =
141            Self::mut_from_bytes(&mut buf[0..RPC_REQ_HEADER_LEN]).expect("fill header buf");
142        header.magic = RPC_MAGIC;
143        header.ver = RPC_VERSION_1;
144        header.format = 0;
145        header.action.set(action);
146        header.seq.set(seq);
147        header.client_id.set(client_id);
148        header.msg_len.set(msg_len);
149        header.blob_len.set(blob_len as u32);
150    }
151
152    /// write header, action, msg into `buf`, return reference to blob if there's any
153    #[inline(always)]
154    pub fn encode<'a, T, C>(
155        codec: &C, buf: &mut Vec<u8>, client_id: u64, task: &'a T,
156    ) -> Result<Option<&'a [u8]>, ()>
157    where
158        T: ClientTask,
159        C: Codec,
160    {
161        debug_assert!(buf.capacity() >= RPC_REQ_HEADER_LEN);
162        // Leave a room at the beginning of buffer for ReqHead
163        unsafe { buf.set_len(RPC_REQ_HEADER_LEN) };
164        // But we have to write action str and encode the msg first to get message data len
165        let action_flag: u32;
166        match task.get_action() {
167            RpcAction::Num(num) => action_flag = num as u32,
168            RpcAction::Str(s) => {
169                action_flag = s.len() as u32 | U32_HIGH_MASK;
170                buf.write_all(s.as_bytes()).expect("fill action buffer");
171            }
172        }
173        let msg_len = task.encode_req(codec, buf)?;
174        if msg_len > u32::MAX as usize {
175            error!("ReqHead: req len {} cannot larger than u32", msg_len);
176            return Err(());
177        }
178        let blob = task.get_req_blob();
179        let blob_len = if let Some(blob) = blob { blob.len() } else { 0 };
180        if blob_len > i32::MAX as usize {
181            error!("ReqHead: blob_len {} cannot larger than i32", blob_len);
182            return Err(());
183        }
184        Self::_write_head(buf, client_id, action_flag, task.seq(), msg_len as u32, blob_len as i32);
185        Ok(blob)
186    }
187
188    #[inline(always)]
189    pub fn decode_head(head_buf: &[u8]) -> Result<&Self, RpcIntErr> {
190        let head: &Self = Self::ref_from_bytes(head_buf).expect("from header buf");
191        if head.magic != RPC_MAGIC {
192            warn!("rpc server: wrong magic receive {:?}", head.magic);
193            return Err(RpcIntErr::IO);
194        }
195        if head.ver != RPC_VERSION_1 {
196            warn!("rpc server: version {} not supported", head.ver);
197            return Err(RpcIntErr::Version);
198        }
199        return Ok(head);
200    }
201
202    #[inline]
203    pub fn get_action(&self) -> Result<i32, i32> {
204        if self.action & U32_HIGH_MASK == 0 {
205            Ok(self.action.get() as i32)
206        } else {
207            let action_len = self.action ^ U32_HIGH_MASK;
208            Err(action_len.get() as i32)
209        }
210    }
211}
212
213impl fmt::Display for ReqHead {
214    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
215        let _ = unsafe {
216            write!(
217                f,
218                "[client_id:{}, seq:{}, msg:{}, blob:{}",
219                addr_of!(self.client_id).read_unaligned(),
220                addr_of!(self.seq).read_unaligned(),
221                addr_of!(self.msg_len).read_unaligned(),
222                addr_of!(self.blob_len).read_unaligned(),
223            )
224        };
225        match self.get_action() {
226            Ok(action_num) => {
227                write!(f, ", action:{:?}]", action_num)
228            }
229            Err(action_len) => {
230                write!(f, "action_len:{}]", action_len)
231            }
232        }
233    }
234}
235
236impl fmt::Debug for ReqHead {
237    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
238        fmt::Display::fmt(self, f)
239    }
240}
241
242/// Fixed-length header for response
243#[derive(FromBytes, IntoBytes, Unaligned, Immutable, KnownLayout, PartialEq, Clone)]
244#[repr(C, packed)]
245pub struct RespHead {
246    pub magic: little_endian::U16,
247    pub ver: u8,
248
249    /// when flag == RESP_FLAG_HAS_ERRNO: msg_len is posix errno; blob_len = 0
250    /// when flag == RESP_FLAG_HAS_ERR_STRING: msg_len=0, blob_len > 0 and follow an error string
251    pub flag: u8,
252
253    /// structured msg_len or errno
254    pub msg_len: little_endian::U32,
255
256    /// Increased ID of request msg in the socket connection (response.seq==request.seq)
257    pub seq: little_endian::U64,
258    /// unstructured msg, only support half of 16Byte, must larger than zero
259    pub blob_len: little_endian::I32,
260}
261
262pub const RPC_RESP_HEADER_LEN: usize = size_of::<RespHead>();
263
264impl RespHead {
265    #[inline]
266    pub fn encode<'a, 'b, L, C, T>(
267        logger: &'b L, codec: &'b C, buf: &'b mut Vec<u8>, task: &'a mut T,
268    ) -> (u64, Option<&'a [u8]>)
269    where
270        L: captains_log::filter::Filter,
271        C: Codec,
272        T: ServerTaskEncode,
273    {
274        debug_assert!(buf.capacity() >= RPC_RESP_HEADER_LEN);
275        // Leave a room at the beginning of buffer for RespHead
276        unsafe { buf.set_len(RPC_RESP_HEADER_LEN) };
277        let (seq, r) = task.encode_resp(codec, buf);
278        match r {
279            Ok((msg_len, None)) => {
280                if msg_len > u32::MAX as usize {
281                    error!("write_resp: encoded msg len {} exceed u32 limit", msg_len);
282                    Self::_encode_error::<L>(logger, buf, seq, EncodedErr::Rpc(RpcIntErr::Encode));
283                } else {
284                    Self::_write_head(logger, buf, 0, seq, msg_len as u32, 0);
285                }
286                return (seq, None);
287            }
288            Ok((msg_len, Some(blob))) => {
289                if msg_len > u32::MAX as usize {
290                    error!("write_resp: encoded msg len {} exceed u32 limit", msg_len);
291                    Self::_encode_error::<L>(logger, buf, seq, EncodedErr::Rpc(RpcIntErr::Encode));
292                    return (seq, None);
293                } else if blob.len() > i32::MAX as usize {
294                    error!("write_resp: blob len {} exceed i32 limit", blob.len());
295                    Self::_encode_error::<L>(logger, buf, seq, EncodedErr::Rpc(RpcIntErr::Encode));
296                    return (seq, None);
297                }
298                Self::_write_head::<L>(logger, buf, 0, seq, msg_len as u32, blob.len() as i32);
299                return (seq, Some(blob));
300            }
301            Err(e) => {
302                Self::_encode_error::<L>(logger, buf, seq, e);
303                return (seq, None);
304            }
305        }
306    }
307
308    #[inline]
309    pub fn encode_internal<'a, L>(
310        logger: &'a L, buf: &'a mut Vec<u8>, seq: u64, err: Option<RpcIntErr>,
311    ) -> u64
312    where
313        L: captains_log::filter::Filter,
314    {
315        debug_assert!(buf.capacity() >= RPC_RESP_HEADER_LEN);
316        // Leave a room at the beginning of buffer for RespHead
317        unsafe { buf.set_len(RPC_RESP_HEADER_LEN) };
318        if let Some(e) = err {
319            Self::_encode_error::<L>(logger, buf, seq, e.into());
320            return seq;
321        } else {
322            // ping
323            Self::_write_head::<L>(logger, buf, 0, seq, 0, 0);
324            return seq;
325        }
326    }
327
328    #[inline(always)]
329    fn _encode_error<'b, L>(logger: &'b L, buf: &'b mut Vec<u8>, seq: u64, e: EncodedErr)
330    where
331        L: captains_log::filter::Filter,
332    {
333        macro_rules! write_err {
334            ($s: expr) => {
335                Self::_write_head::<L>(
336                    logger,
337                    buf,
338                    RESP_FLAG_HAS_ERR_STRING,
339                    seq,
340                    0,
341                    $s.len() as i32,
342                );
343                buf.write_all($s).expect("fill error str");
344            };
345        }
346        match e {
347            EncodedErr::Num(n) => {
348                Self::_write_head::<L>(logger, buf, RESP_FLAG_HAS_ERRNO, seq, n, 0);
349            }
350            EncodedErr::Rpc(s) => {
351                write_err!(s.as_bytes());
352            }
353            EncodedErr::Buf(s) => {
354                write_err!(&s);
355            }
356            EncodedErr::Static(s) => {
357                write_err!(s.as_bytes());
358            }
359        }
360    }
361
362    #[inline]
363    fn _write_head<L>(logger: &L, buf: &mut [u8], flag: u8, seq: u64, msg_len: u32, blob_len: i32)
364    where
365        L: captains_log::filter::Filter,
366    {
367        let header = Self::mut_from_bytes(&mut buf[0..RPC_RESP_HEADER_LEN]).expect("fill header");
368        header.magic = RPC_MAGIC;
369        header.ver = RPC_VERSION_1;
370        header.flag = flag;
371        header.msg_len.set(msg_len);
372        header.seq.set(seq);
373        header.blob_len.set(blob_len);
374        logger_trace!(logger, "resp {:?}", header);
375    }
376
377    #[inline(always)]
378    pub fn decode_head(head_buf: &[u8]) -> Result<&Self, RpcIntErr> {
379        let head: &Self = Self::ref_from_bytes(head_buf).expect("decode header");
380        if head.magic != RPC_MAGIC {
381            warn!("rpc server: wrong magic receive {:?}", head.magic);
382            return Err(RpcIntErr::IO);
383        }
384        if head.ver != RPC_VERSION_1 {
385            warn!("rpc server: version {} not supported", head.ver);
386            return Err(RpcIntErr::Version);
387        }
388        return Ok(head);
389    }
390}
391
392impl fmt::Display for RespHead {
393    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
394        unsafe {
395            write!(
396                f,
397                "[seq:{}, flag:{}, msg:{}, blob:{}]",
398                addr_of!(self.seq).read_unaligned(), // format_args deals with unaligned field
399                addr_of!(self.flag).read_unaligned(),
400                addr_of!(self.msg_len).read_unaligned(),
401                addr_of!(self.blob_len).read_unaligned(),
402            )
403        }
404    }
405}
406
407impl fmt::Debug for RespHead {
408    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
409        fmt::Display::fmt(self, f)
410    }
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    /// Both headers are fixed-size wire structures; their sizes must match the module docs.
    #[test]
    fn test_header_len() {
        assert_eq!(32, RPC_REQ_HEADER_LEN, "ReqHead size changed");
        assert_eq!(20, RPC_RESP_HEADER_LEN, "RespHead size changed");
    }
}