Skip to main content

razor_stream/
proto.rs

1//! # The protocol
2//!
3//! ## Request
4//!
5//! Fixed length of `ReqHead` = 32B
6//!
7//! | Field     | Size | Description                               |
8//! |-----------|------|-------------------------------------------|
9//! | `magic`   | 2B   | Magic number                              |
10//! | `ver`     | 1B   | Protocol version                          |
11//! | `format`  | 1B   | Encoder-decoder format                    |
12//! | `action`  | 4B   | Action type (numeric or length if string) |
13//! | `seq`     | 8B   | Increased ID of request message           |
14//! | `client_id`| 8B   | Client identifier                         |
15//! | `msg_len` | 4B   | Structured message length                 |
16//! | `blob_len`| 4B   | Unstructured message (blob) length        |
17//!
18//! Variable length message components:
19//! - `action_len` (if `action` is a string)
20//! - `msg_len`
21//! - `blob_len`
22//!
23//! ## Response:
24//!
25//! Fixed length of `RespHead` = 20B
26//!
//! | Field     | Size | Description                               |
//! |-----------|------|-------------------------------------------|
//! | `magic`   | 2B   | Magic number                              |
//! | `ver`     | 1B   | Protocol version                          |
//! | `flag`    | 1B   | Error flag (0 on success)                 |
//! | `msg_len` | 4B   | Structured message length or errno        |
//! | `seq`     | 8B   | Increased ID of request message           |
//! | `blob_len`| 4B   | Unstructured message (blob) length        |
35//!
36//! Variable length message components:
37//! - `msg_len`
38//! - `blob_len`
//!
40use crate::client::task::ClientTask;
41use crate::server::task::ServerTaskEncode;
42use crate::{Codec, error::*};
43use std::fmt;
44use std::io::Write;
45use std::mem::size_of;
46use std::ptr::addr_of;
47use zerocopy::byteorder::little_endian;
48use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned};
49
/// Action number that `ReqHead::encode_ping` writes into the request header.
pub const PING_ACTION: u32 = 0;

/// Magic number (19749 = 0x4D25) stored in the first two bytes of every request and response
/// header; both `decode_head` functions reject headers that carry a different value.
pub const RPC_MAGIC: little_endian::U16 = little_endian::U16::new(19749);
/// Highest bit of a `u32`. When set in `ReqHead::action`, the remaining bits hold the length of
/// an action string instead of a numeric action.
pub const U32_HIGH_MASK: u32 = 1 << 31;

/// `RespHead::flag` value: `msg_len` carries an errno and `blob_len` is 0.
pub const RESP_FLAG_HAS_ERRNO: u8 = 1;
/// `RespHead::flag` value: `msg_len` is 0 and an error string of `blob_len` bytes follows.
pub const RESP_FLAG_HAS_ERR_STRING: u8 = 2;
/// Protocol version written into every header and the only one accepted when decoding.
pub const RPC_VERSION_1: u8 = 1;
58
/// Borrowed identifier of the remote action a request targets.
///
/// Derives `Eq` and `Hash` (in addition to `PartialEq`) so the borrowed form can be used as a
/// map/set key, matching the derives of its owned counterpart `RpcActionOwned`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RpcAction<'a> {
    /// Action addressed by name; the string bytes are sent right after the request header.
    Str(&'a str),
    /// Action addressed by number; the value is stored in the header's `action` field.
    Num(i32),
}
64
/// Owned counterpart of `RpcAction`: holds the action name as a `String` so the value does not
/// borrow from anything.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RpcActionOwned {
    /// Action addressed by name.
    Str(String),
    /// Action addressed by number.
    Num(i32),
}
70
71impl<'a> From<RpcAction<'a>> for RpcActionOwned {
72    fn from(action: RpcAction<'a>) -> Self {
73        match action {
74            RpcAction::Str(s) => RpcActionOwned::Str(s.to_string()),
75            RpcAction::Num(n) => RpcActionOwned::Num(n),
76        }
77    }
78}
79
80impl RpcActionOwned {
81    pub fn to_action<'a>(&'a self) -> RpcAction<'a> {
82        match self {
83            RpcActionOwned::Str(s) => RpcAction::Str(s.as_str()),
84            RpcActionOwned::Num(n) => RpcAction::Num(*n),
85        }
86    }
87}
88
/// Fixed-length header for request
///
/// The layout is packed and every multi-byte field is stored little-endian, so the struct can
/// be viewed directly over the first `RPC_REQ_HEADER_LEN` bytes of a buffer.
#[derive(FromBytes, IntoBytes, Unaligned, Immutable, KnownLayout, PartialEq, Clone)]
#[repr(packed)]
pub struct ReqHead {
    /// Magic number; must equal `RPC_MAGIC` for the header to decode.
    pub magic: little_endian::U16,
    /// Protocol version; must equal `RPC_VERSION_1` for the header to decode.
    pub ver: u8,
    /// encoder-decoder format
    pub format: u8,

    /// If highest bit is 0, the rest will be i32 action_num.
    ///
    /// If highest is 1, the lower bit will be i32 action_len.
    pub action: little_endian::U32,

    /// Increased ID of request msg in the socket connection.
    pub seq: little_endian::U64,

    /// Client identifier.
    pub client_id: little_endian::U64,
    /// structured msg len
    pub msg_len: little_endian::U32,
    /// unstructured msg (blob) len
    pub blob_len: little_endian::U32,
}

/// Size in bytes of the fixed request header (`size_of::<ReqHead>()`).
pub const RPC_REQ_HEADER_LEN: usize = size_of::<ReqHead>();
114
115impl ReqHead {
116    #[inline(always)]
117    pub fn encode_ping(buf: &mut Vec<u8>, client_id: u64, seq: u64) {
118        debug_assert!(buf.capacity() > RPC_REQ_HEADER_LEN);
119        unsafe { buf.set_len(RPC_REQ_HEADER_LEN) };
120        Self::_write_head(buf, client_id, PING_ACTION, seq, 0, 0);
121    }
122
123    #[inline(always)]
124    fn _write_head(
125        buf: &mut Vec<u8>, client_id: u64, action: u32, seq: u64, msg_len: u32, blob_len: i32,
126    ) {
127        // NOTE: We are directly init ReqHead on the buffer with unsafe, check carefully don't miss
128        // a field
129        let header: &mut Self =
130            Self::mut_from_bytes(&mut buf[0..RPC_REQ_HEADER_LEN]).expect("fill header buf");
131        header.magic = RPC_MAGIC;
132        header.ver = RPC_VERSION_1;
133        header.format = 0;
134        header.action.set(action);
135        header.seq.set(seq);
136        header.client_id.set(client_id);
137        header.msg_len.set(msg_len);
138        header.blob_len.set(blob_len as u32);
139    }
140
141    /// write header, action, msg into `buf`, return reference to blob if there's any
142    #[inline(always)]
143    pub fn encode<'a, T, C>(
144        codec: &C, buf: &mut Vec<u8>, client_id: u64, task: &'a T,
145    ) -> Result<Option<&'a [u8]>, ()>
146    where
147        T: ClientTask,
148        C: Codec,
149    {
150        debug_assert!(buf.capacity() >= RPC_REQ_HEADER_LEN);
151        // Leave a room at the beginning of buffer for ReqHead
152        unsafe { buf.set_len(RPC_REQ_HEADER_LEN) };
153        // But we have to write action str and encode the msg first to get message data len
154        let action_flag: u32;
155        match task.get_action() {
156            RpcAction::Num(num) => action_flag = num as u32,
157            RpcAction::Str(s) => {
158                action_flag = s.len() as u32 | U32_HIGH_MASK;
159                buf.write_all(s.as_bytes()).expect("fill action buffer");
160            }
161        }
162        let msg_len = task.encode_req(codec, buf)?;
163        if msg_len > u32::MAX as usize {
164            error!("ReqHead: req len {} cannot larger than u32", msg_len);
165            return Err(());
166        }
167        let blob = task.get_req_blob();
168        let blob_len = if let Some(blob) = blob { blob.len() } else { 0 };
169        if blob_len > i32::MAX as usize {
170            error!("ReqHead: blob_len {} cannot larger than i32", blob_len);
171            return Err(());
172        }
173        Self::_write_head(buf, client_id, action_flag, task.seq(), msg_len as u32, blob_len as i32);
174        Ok(blob)
175    }
176
177    #[inline(always)]
178    pub fn decode_head(head_buf: &[u8]) -> Result<&Self, RpcIntErr> {
179        let head: &Self = Self::ref_from_bytes(head_buf).expect("from header buf");
180        if head.magic != RPC_MAGIC {
181            warn!("rpc server: wrong magic receive {:?}", head.magic);
182            return Err(RpcIntErr::IO);
183        }
184        if head.ver != RPC_VERSION_1 {
185            warn!("rpc server: version {} not supported", head.ver);
186            return Err(RpcIntErr::Version);
187        }
188        return Ok(head);
189    }
190
191    #[inline]
192    pub fn get_action(&self) -> Result<i32, i32> {
193        if self.action & U32_HIGH_MASK == 0 {
194            Ok(self.action.get() as i32)
195        } else {
196            let action_len = self.action ^ U32_HIGH_MASK;
197            Err(action_len.get() as i32)
198        }
199    }
200}
201
202impl fmt::Display for ReqHead {
203    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
204        let _ = unsafe {
205            write!(
206                f,
207                "[client_id:{}, seq:{}, msg:{}, blob:{}",
208                addr_of!(self.client_id).read_unaligned(),
209                addr_of!(self.seq).read_unaligned(),
210                addr_of!(self.msg_len).read_unaligned(),
211                addr_of!(self.blob_len).read_unaligned(),
212            )
213        };
214        match self.get_action() {
215            Ok(action_num) => {
216                write!(f, ", action:{:?}]", action_num)
217            }
218            Err(action_len) => {
219                write!(f, "action_len:{}]", action_len)
220            }
221        }
222    }
223}
224
225impl fmt::Debug for ReqHead {
226    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
227        fmt::Display::fmt(self, f)
228    }
229}
230
/// Fixed-length header for response
///
/// The layout is packed and every multi-byte field is stored little-endian, so the struct can
/// be viewed directly over the first `RPC_RESP_HEADER_LEN` bytes of a buffer.
#[derive(FromBytes, IntoBytes, Unaligned, Immutable, KnownLayout, PartialEq, Clone)]
#[repr(packed)]
pub struct RespHead {
    /// Magic number; must equal `RPC_MAGIC` for the header to decode.
    pub magic: little_endian::U16,
    /// Protocol version; must equal `RPC_VERSION_1` for the header to decode.
    pub ver: u8,

    /// when flag == RESP_FLAG_HAS_ERRNO: msg_len is posix errno; blob_len = 0
    /// when flag == RESP_FLAG_HAS_ERR_STRING: msg_len=0, blob_len > 0 and follow an error string
    pub flag: u8,

    /// structured msg_len or errno
    pub msg_len: little_endian::U32,

    /// Increased ID of request msg in the socket connection (response.seq==request.seq)
    pub seq: little_endian::U64,
    /// unstructured msg (blob) length. Stored as a signed 32-bit value: `RespHead::encode`
    /// rejects blobs longer than `i32::MAX` bytes, and the value is expected to be non-negative.
    pub blob_len: little_endian::I32,
}

/// Size in bytes of the fixed response header (`size_of::<RespHead>()`).
pub const RPC_RESP_HEADER_LEN: usize = size_of::<RespHead>();
252
253impl RespHead {
254    #[inline]
255    pub fn encode<'a, 'b, L, C, T>(
256        logger: &'b L, codec: &'b C, buf: &'b mut Vec<u8>, task: &'a mut T,
257    ) -> (u64, Option<&'a [u8]>)
258    where
259        L: captains_log::filter::Filter,
260        C: Codec,
261        T: ServerTaskEncode,
262    {
263        debug_assert!(buf.capacity() >= RPC_RESP_HEADER_LEN);
264        // Leave a room at the beginning of buffer for RespHead
265        unsafe { buf.set_len(RPC_RESP_HEADER_LEN) };
266        let (seq, r) = task.encode_resp(codec, buf);
267        match r {
268            Ok((msg_len, None)) => {
269                if msg_len > u32::MAX as usize {
270                    error!("write_resp: encoded msg len {} exceed u32 limit", msg_len);
271                    Self::_encode_error::<L>(logger, buf, seq, EncodedErr::Rpc(RpcIntErr::Encode));
272                } else {
273                    Self::_write_head(logger, buf, 0, seq, msg_len as u32, 0);
274                }
275                return (seq, None);
276            }
277            Ok((msg_len, Some(blob))) => {
278                if msg_len > u32::MAX as usize {
279                    error!("write_resp: encoded msg len {} exceed u32 limit", msg_len);
280                    Self::_encode_error::<L>(logger, buf, seq, EncodedErr::Rpc(RpcIntErr::Encode));
281                    return (seq, None);
282                } else if blob.len() > i32::MAX as usize {
283                    error!("write_resp: blob len {} exceed i32 limit", blob.len());
284                    Self::_encode_error::<L>(logger, buf, seq, EncodedErr::Rpc(RpcIntErr::Encode));
285                    return (seq, None);
286                }
287                Self::_write_head::<L>(logger, buf, 0, seq, msg_len as u32, blob.len() as i32);
288                return (seq, Some(blob));
289            }
290            Err(e) => {
291                Self::_encode_error::<L>(logger, buf, seq, e);
292                return (seq, None);
293            }
294        }
295    }
296
297    #[inline]
298    pub fn encode_internal<'a, L>(
299        logger: &'a L, buf: &'a mut Vec<u8>, seq: u64, err: Option<RpcIntErr>,
300    ) -> u64
301    where
302        L: captains_log::filter::Filter,
303    {
304        debug_assert!(buf.capacity() >= RPC_RESP_HEADER_LEN);
305        // Leave a room at the beginning of buffer for RespHead
306        unsafe { buf.set_len(RPC_RESP_HEADER_LEN) };
307        if let Some(e) = err {
308            Self::_encode_error::<L>(logger, buf, seq, e.into());
309            return seq;
310        } else {
311            // ping
312            Self::_write_head::<L>(logger, buf, 0, seq, 0, 0);
313            return seq;
314        }
315    }
316
317    #[inline(always)]
318    fn _encode_error<'b, L>(logger: &'b L, buf: &'b mut Vec<u8>, seq: u64, e: EncodedErr)
319    where
320        L: captains_log::filter::Filter,
321    {
322        macro_rules! write_err {
323            ($s: expr) => {
324                Self::_write_head::<L>(
325                    logger,
326                    buf,
327                    RESP_FLAG_HAS_ERR_STRING,
328                    seq,
329                    0,
330                    $s.len() as i32,
331                );
332                buf.write_all($s).expect("fill error str");
333            };
334        }
335        match e {
336            EncodedErr::Num(n) => {
337                Self::_write_head::<L>(logger, buf, RESP_FLAG_HAS_ERRNO, seq, n as u32, 0);
338            }
339            EncodedErr::Rpc(s) => {
340                write_err!(s.as_bytes());
341            }
342            EncodedErr::Buf(s) => {
343                write_err!(&s);
344            }
345            EncodedErr::Static(s) => {
346                write_err!(s.as_bytes());
347            }
348        }
349    }
350
351    #[inline]
352    fn _write_head<L>(
353        logger: &L, buf: &mut Vec<u8>, flag: u8, seq: u64, msg_len: u32, blob_len: i32,
354    ) where
355        L: captains_log::filter::Filter,
356    {
357        let header = Self::mut_from_bytes(&mut buf[0..RPC_RESP_HEADER_LEN]).expect("fill header");
358        header.magic = RPC_MAGIC;
359        header.ver = RPC_VERSION_1;
360        header.flag = flag;
361        header.msg_len.set(msg_len);
362        header.seq.set(seq);
363        header.blob_len.set(blob_len);
364        logger_trace!(logger, "resp {:?}", header);
365    }
366
367    #[inline(always)]
368    pub fn decode_head(head_buf: &[u8]) -> Result<&Self, RpcIntErr> {
369        let head: &Self = Self::ref_from_bytes(head_buf).expect("decode header");
370        if head.magic != RPC_MAGIC {
371            warn!("rpc server: wrong magic receive {:?}", head.magic);
372            return Err(RpcIntErr::IO);
373        }
374        if head.ver != RPC_VERSION_1 {
375            warn!("rpc server: version {} not supported", head.ver);
376            return Err(RpcIntErr::Version);
377        }
378        return Ok(head);
379    }
380}
381
382impl fmt::Display for RespHead {
383    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
384        unsafe {
385            write!(
386                f,
387                "[seq:{}, flag:{}, msg:{}, blob:{}]",
388                addr_of!(self.seq).read_unaligned(), // format_args deals with unaligned field
389                addr_of!(self.flag).read_unaligned(),
390                addr_of!(self.msg_len).read_unaligned(),
391                addr_of!(self.blob_len).read_unaligned(),
392            )
393        }
394    }
395}
396
397impl fmt::Debug for RespHead {
398    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
399        fmt::Display::fmt(self, f)
400    }
401}
402
#[cfg(test)]
mod tests {

    use super::*;

    /// The header sizes are part of the wire format documented at the top of this module;
    /// this guards against a field change silently altering them.
    #[test]
    fn test_header_len() {
        let request_header_len: usize = RPC_REQ_HEADER_LEN;
        let response_header_len: usize = RPC_RESP_HEADER_LEN;
        assert_eq!(request_header_len, 32);
        assert_eq!(response_header_len, 20);
    }
}