rpcx_protocol/
message.rs

1use byteorder::{BigEndian, ByteOrder};
2use enum_primitive_derive::Primitive;
3use flate2::{read::GzDecoder, write::GzEncoder, Compression};
4use num_traits::{FromPrimitive, ToPrimitive};
5use strum_macros::{Display, EnumIter, EnumString};
6
7use std::{
8    cell::RefCell,
9    collections::hash_map::HashMap,
10    io::{Read, Write},
11};
12
13use crate::{Error, Result};
14
15const MAGIC_NUMBER: u8 = 0x08;
16pub const SERVICE_ERROR: &str = "__rpcx_error__";
17
18#[derive(Debug, Copy, Clone, Display, PartialEq, EnumIter, EnumString, Primitive)]
19pub enum MessageType {
20    Request = 0,
21    Response = 1,
22}
23
24#[derive(Debug, Copy, Clone, Display, PartialEq, EnumIter, EnumString, Primitive)]
25pub enum MessageStatusType {
26    Normal = 0,
27    Error = 1,
28}
29
30#[derive(Debug, Copy, Clone, Display, PartialEq, EnumIter, EnumString, Primitive)]
31pub enum CompressType {
32    CompressNone = 0,
33    Gzip = 1,
34}
35
36#[derive(Debug, Copy, Clone, Display, PartialEq, EnumIter, EnumString, Primitive)]
37pub enum SerializeType {
38    SerializeNone = 0,
39    JSON = 1,
40    Protobuf = 2,
41    MsgPack = 3,
42    Thrift = 4,
43}
44
45/// define the rpcx message interface.
46pub trait RpcxMessage {
47    fn check_magic_number(&self) -> bool;
48    fn get_version(&self) -> u8;
49    fn set_version(&mut self, v: u8);
50    fn get_message_type(&self) -> Option<MessageType>;
51    fn set_message_type(&mut self, mt: MessageType);
52    fn is_heartbeat(&self) -> bool;
53    fn set_heartbeat(&mut self, b: bool);
54    fn is_oneway(&self) -> bool;
55    fn set_oneway(&mut self, b: bool);
56    fn get_compress_type(&self) -> Option<CompressType>;
57    fn set_compress_type(&mut self, ct: CompressType);
58    fn get_message_status_type(&self) -> Option<MessageStatusType>;
59    fn set_message_status_type(&mut self, mst: MessageStatusType);
60    fn get_serialize_type(&self) -> Option<SerializeType>;
61    fn set_serialize_type(&mut self, st: SerializeType);
62    fn get_seq(&self) -> u64;
63    fn set_seq(&mut self, seq: u64);
64    fn decode<R: ?Sized>(&mut self, r: &mut R) -> Result<()>
65    where
66        R: Read;
67    fn encode(&self) -> Vec<u8>;
68
69    fn get_error(&self) -> Option<String>;
70}
71
/// String key/value metadata attached to a message.
pub type Metadata = HashMap<String, String>;

/// A common struct for both requests and responses.
///
/// Note: `Message::default()` leaves every header byte zero (including the magic
/// number); `Message::new()` additionally sets header byte 0 to `MAGIC_NUMBER`.
#[derive(Debug, Default)]
pub struct Message {
    /// Packed 12-byte header: magic number (byte 0), version (byte 1), flag bits
    /// (byte 2), serialize type (high nibble of byte 3) and a big-endian sequence
    /// number (bytes 4..12).
    pub header: [u8; 12],
    /// Service path (e.g. `"Arith"` in this module's tests).
    pub service_path: String,
    /// Service method (e.g. `"Add"` in this module's tests).
    pub service_method: String,
    /// Message metadata; the `RefCell` allows mutable borrows through `&self`.
    pub metadata: RefCell<Metadata>,
    /// Message body. `decode` stores it decompressed; `encode` gzip-compresses it
    /// when the header's compress type is `Gzip`.
    pub payload: Vec<u8>,
}
83impl Message {
84    /// Creates a new `Message`
85    pub fn new() -> Self {
86        let mut msg: Message = Default::default();
87        msg.header = [0u8; 12];
88        msg.header[0] = MAGIC_NUMBER;
89        msg.metadata = RefCell::new(HashMap::new());
90        msg
91    }
92
93    pub fn get_reply(&self) -> Result<Self> {
94        let mut reply = Message::new();
95        reply.set_version(self.get_version());
96        reply.set_compress_type(self.get_compress_type().unwrap());
97        reply.set_message_status_type(MessageStatusType::Normal);
98        reply.set_message_type(MessageType::Response);
99        reply.set_serialize_type(self.get_serialize_type().unwrap());
100        reply.set_seq(self.get_seq());
101        reply.service_path = self.service_path.clone();
102        reply.service_method = self.service_method.clone();
103
104        Ok(reply)
105    }
106}
107
108impl RpcxMessage for Message {
109    fn check_magic_number(&self) -> bool {
110        self.header[0] == MAGIC_NUMBER
111    }
112
113    fn get_version(&self) -> u8 {
114        self.header[1]
115    }
116    fn set_version(&mut self, v: u8) {
117        self.header[1] = v;
118    }
119
120    fn get_message_type(&self) -> Option<MessageType> {
121        MessageType::from_u8((self.header[2] & 0x80) >> 7 as u8)
122    }
123    fn set_message_type(&mut self, mt: MessageType) {
124        self.header[2] |= mt.to_u8().unwrap() << 7;
125    }
126    fn is_heartbeat(&self) -> bool {
127        self.header[2] & 0x40 == 0x40
128    }
129    fn set_heartbeat(&mut self, b: bool) {
130        if b {
131            self.header[2] |= 0x40;
132        } else {
133            self.header[2] &= !0x40;
134        }
135    }
136    fn is_oneway(&self) -> bool {
137        self.header[2] & 0x20 == 0x20
138    }
139    fn set_oneway(&mut self, b: bool) {
140        if b {
141            self.header[2] |= 0x20;
142        } else {
143            self.header[2] &= !0x20;
144        }
145    }
146    fn get_compress_type(&self) -> Option<CompressType> {
147        CompressType::from_u8((self.header[2] & 0x1C) >> 2)
148    }
149    fn set_compress_type(&mut self, ct: CompressType) {
150        self.header[2] = (self.header[2] & !0x1C) | (ct.to_u8().unwrap() << 2 & 0x1C);
151    }
152    fn get_message_status_type(&self) -> Option<MessageStatusType> {
153        MessageStatusType::from_u8(self.header[2] & 0x03)
154    }
155    fn set_message_status_type(&mut self, mst: MessageStatusType) {
156        self.header[2] = (self.header[2] & !0x03) | (mst.to_u8().unwrap() & 0x03);
157    }
158    fn get_serialize_type(&self) -> Option<SerializeType> {
159        SerializeType::from_u8((self.header[3] & 0xF0) >> 4)
160    }
161    fn set_serialize_type(&mut self, st: SerializeType) {
162        self.header[3] = (self.header[3] & !0xF0) | (st.to_u8().unwrap() << 4)
163    }
164    fn get_seq(&self) -> u64 {
165        u64_from_slice(&(self.header[4..]))
166    }
167    fn set_seq(&mut self, seq: u64) {
168        u64_to_slice(seq, &mut self.header[4..]);
169    }
170
171    fn decode<R: ?Sized>(&mut self, r: &mut R) -> Result<()>
172    where
173        R: Read,
174    {
175        r.read_exact(&mut self.header)?;
176
177        let mut buf = [0u8; 4];
178        r.read(&mut buf[..]).map(|_| {})?;
179        let len = BigEndian::read_u32(&buf); //length of all expect header
180        let mut buf = vec![0u8; len as usize];
181        r.read(&mut buf[..]).map(|_| ())?;
182
183        let mut start = 0;
184        // read service_path
185        let len = read_len(&buf[start..(start + 4)]) as usize;
186        let service_path = read_str(&buf[(start + 4)..(start + 4 + len)])?;
187        self.service_path = service_path;
188        start = start + 4 + len;
189        // read service_method
190        let len = read_len(&buf[start..(start + 4)]) as usize;
191        let service_method = read_str(&buf[(start + 4)..(start + 4 + len)])?;
192        self.service_method = service_method;
193
194        start = start + 4 + len;
195        //metadata
196        let len = read_len(&buf[start..(start + 4)]) as usize;
197        let metadata_bytes = &buf[(start + 4)..(start + 4 + len)];
198        let mut meta_start = 0;
199        while meta_start < len {
200            let sl = read_len(&metadata_bytes[meta_start..(meta_start + 4)]) as usize;
201            let key = read_str(&metadata_bytes[(meta_start + 4)..(meta_start + 4 + sl)])?;
202            meta_start = meta_start + 4 + sl;
203            if meta_start < len {
204                let value_len = read_len(&metadata_bytes[meta_start..(meta_start + 4)]) as usize;
205                let value =
206                    read_str(&metadata_bytes[(meta_start + 4)..(meta_start + 4 + value_len)])?;
207                self.metadata.borrow_mut().insert(key, value);
208                meta_start = meta_start + 4 + value_len;
209            } else {
210                self.metadata.borrow_mut().insert(key, String::new());
211                break;
212            }
213        }
214        start = start + 4 + len;
215        // payload
216        let len = read_len(&buf[start..start + 4]) as usize;
217        let payload = &buf[start + 4..];
218        if len != payload.len() {
219            return Err(Error::from("invalid payload length"));
220        }
221
222        let mut vp = Vec::with_capacity(payload.len());
223        match self.get_compress_type().unwrap() {
224            CompressType::Gzip => {
225                let mut deflater = GzDecoder::new(payload);
226                deflater.read_to_end(&mut vp)?;
227            }
228            CompressType::CompressNone => {
229                vp.extend_from_slice(&payload);
230            }
231        }
232        self.payload = vp;
233
234        Ok(())
235    }
236
237    fn encode(&self) -> Vec<u8> {
238        // encode all except header
239        let mut buf = Vec::<u8>::with_capacity(20);
240        buf.extend_from_slice(&self.header);
241
242        // push fake length
243        let len_bytes = write_len(0);
244        buf.extend_from_slice(&len_bytes);
245
246        // service_path
247        let len = self.service_path.len();
248        let len_bytes = write_len(len as u32);
249        buf.extend_from_slice(&len_bytes);
250        buf.extend_from_slice(self.service_path.as_bytes());
251
252        // service_method
253        let len = self.service_method.len();
254        let len_bytes = write_len(len as u32);
255        buf.extend_from_slice(&len_bytes);
256        buf.extend_from_slice(self.service_method.as_bytes());
257
258        // metadata
259        let mut metadata_bytes = Vec::<u8>::new();
260        let metadata = self.metadata.borrow_mut();
261        for meta in metadata.iter() {
262            let key = meta.0;
263            let len_bytes = write_len(key.len() as u32);
264            metadata_bytes.extend_from_slice(&len_bytes);
265            metadata_bytes.extend_from_slice(key.as_bytes());
266
267            let value = meta.1;
268            let len_bytes = write_len(value.len() as u32);
269            metadata_bytes.extend_from_slice(&len_bytes);
270            metadata_bytes.extend_from_slice(value.as_bytes());
271        }
272        let len = metadata_bytes.len();
273        let len_bytes = write_len(len as u32);
274        buf.extend_from_slice(&len_bytes);
275        buf.append(&mut metadata_bytes);
276
277        // data
278        // check compress
279
280        match self.get_compress_type().unwrap() {
281            CompressType::Gzip => {
282                let mut e = GzEncoder::new(Vec::new(), Compression::fast());
283                let _ = e.write_all(&self.payload[..]);
284                let compressed_payload = e.finish().unwrap();
285                let len = compressed_payload.len();
286                let len_bytes = write_len(len as u32);
287                buf.extend_from_slice(&len_bytes);
288                buf.extend_from_slice(&compressed_payload);
289            }
290            _ => {
291                let len = self.payload.len();
292                let len_bytes = write_len(len as u32);
293                buf.extend_from_slice(&len_bytes);
294                buf.extend_from_slice(&self.payload);
295            }
296        }
297
298        // set the real length
299        let len = buf.len() - 12 - 4;
300        let len_bytes = write_len(len as u32);
301        buf[12] = len_bytes[0];
302        buf[13] = len_bytes[1];
303        buf[14] = len_bytes[2];
304        buf[15] = len_bytes[3];
305
306        buf
307    }
308
309    fn get_error(&self) -> Option<String> {
310        match self.get_message_status_type() {
311            Some(MessageStatusType::Error) => {
312                let metadata = &self.metadata;
313                let metadata2 = metadata.borrow();
314                let err_msg = metadata2.get(&SERVICE_ERROR.to_owned())?;
315                Some(String::from(err_msg))
316            }
317            _ => None,
318        }
319    }
320}
321
/// Decodes a big-endian `u32` length prefix from the first four bytes of `buf`.
///
/// Any bytes after the first four are ignored. Panics if `buf` is shorter than four bytes.
fn read_len(buf: &[u8]) -> u32 {
    let prefix = [buf[0], buf[1], buf[2], buf[3]];
    u32::from_be_bytes(prefix)
}
325
/// Encodes `len` as a 4-byte big-endian length prefix.
fn write_len(len: u32) -> [u8; 4] {
    len.to_be_bytes()
}
331
332fn read_str(buf: &[u8]) -> Result<String> {
333    let s = std::str::from_utf8(&buf).unwrap();
334    let str: String = std::string::String::from(s);
335    Ok(str)
336}
337
/// Decodes a big-endian `u64` from the first eight bytes of `b`.
///
/// Panics if `b` is shorter than eight bytes.
fn u64_from_slice(b: &[u8]) -> u64 {
    let mut bytes = [0u8; 8];
    bytes.copy_from_slice(&b[..8]);
    u64::from_be_bytes(bytes)
}
341
/// Writes `v` big-endian into the first eight bytes of `b`; later bytes are untouched.
///
/// Panics if `b` is shorter than eight bytes.
fn u64_to_slice(v: u64, b: &mut [u8]) {
    let encoded = v.to_be_bytes();
    b[..8].copy_from_slice(&encoded);
}
345
#[cfg(test)]
mod tests {
    use super::*;

    /// A captured rpcx request: header with sequence number 1234567890, service path
    /// "Arith", method "Add", one metadata entry (`__ID`) and an uncompressed JSON payload.
    const SAMPLE: [u8; 114] = [
        8, 0, 0, 16, 0, 0, 0, 0, 73, 150, 2, 210, 0, 0, 0, 98, 0, 0, 0, 5, 65, 114, 105, 116,
        104, 0, 0, 0, 3, 65, 100, 100, 0, 0, 0, 48, 0, 0, 0, 4, 95, 95, 73, 68, 0, 0, 0, 36,
        54, 98, 97, 55, 98, 56, 49, 48, 45, 57, 100, 97, 100, 45, 49, 49, 100, 49, 45, 56, 48,
        98, 52, 45, 48, 48, 99, 48, 52, 102, 100, 52, 51, 48, 99, 57, 0, 0, 0, 26, 123, 10, 9,
        9, 34, 65, 34, 58, 32, 49, 44, 10, 9, 9, 34, 66, 34, 58, 32, 50, 44, 10, 9, 125, 10, 9,
    ];

    /// Decodes `SAMPLE`, failing the test (instead of just printing) if decoding errors.
    fn decode_sample() -> Message {
        let mut msg = Message::new();
        let mut data: &[u8] = &SAMPLE[..];
        if let Err(err) = msg.decode(&mut data) {
            panic!("failed to parse: {}", err);
        }
        msg
    }

    #[test]
    fn parse_header() {
        let mut msg = Message::new();
        msg.header.copy_from_slice(&SAMPLE[..12]);

        assert!(msg.check_magic_number());
        assert_eq!(0, msg.get_version());
        assert_eq!(Some(MessageType::Request), msg.get_message_type());
        assert!(!msg.is_heartbeat());
        assert!(!msg.is_oneway());
        assert_eq!(Some(CompressType::CompressNone), msg.get_compress_type());
        assert_eq!(
            Some(MessageStatusType::Normal),
            msg.get_message_status_type()
        );
        assert_eq!(Some(SerializeType::JSON), msg.get_serialize_type());
        assert_eq!(1234567890, msg.get_seq());
    }

    #[test]
    fn set_header() {
        let mut msg = Message::new();
        msg.header.copy_from_slice(&SAMPLE[..12]);

        msg.set_version(0);
        msg.set_message_type(MessageType::Response);
        msg.set_heartbeat(true);
        msg.set_oneway(true);
        msg.set_compress_type(CompressType::Gzip);
        msg.set_serialize_type(SerializeType::MsgPack);
        msg.set_message_status_type(MessageStatusType::Normal);
        msg.set_seq(1000000);

        assert!(msg.check_magic_number());
        assert_eq!(0, msg.get_version());
        assert_eq!(Some(MessageType::Response), msg.get_message_type());
        assert!(msg.is_heartbeat());
        assert!(msg.is_oneway());
        assert_eq!(Some(CompressType::Gzip), msg.get_compress_type());
        assert_eq!(
            Some(MessageStatusType::Normal),
            msg.get_message_status_type()
        );
        assert_eq!(Some(SerializeType::MsgPack), msg.get_serialize_type());
        assert_eq!(1000000, msg.get_seq());
    }

    #[test]
    fn decode() {
        let msg = decode_sample();

        assert_eq!("Arith", msg.service_path);
        assert_eq!("Add", msg.service_method);
        assert_eq!(
            Some("6ba7b810-9dad-11d1-80b4-00c04fd430c9"),
            msg.metadata.borrow().get("__ID").map(String::as_str)
        );
        assert_eq!(
            "{\n\t\t\"A\": 1,\n\t\t\"B\": 2,\n\t}\n\t",
            std::str::from_utf8(&msg.payload).unwrap()
        );
    }

    #[test]
    fn encode() {
        let encoded_bytes = decode_sample().encode();

        assert_eq!(&SAMPLE[..], &encoded_bytes[..]);
    }
}