1use byteorder::{BigEndian, ByteOrder};
2use enum_primitive_derive::Primitive;
3use flate2::{read::GzDecoder, write::GzEncoder, Compression};
4use num_traits::{FromPrimitive, ToPrimitive};
5use strum_macros::{Display, EnumIter, EnumString};
6
7use std::{
8 cell::RefCell,
9 collections::hash_map::HashMap,
10 io::{Read, Write},
11};
12
13use crate::{Error, Result};
14
/// Value every well-formed frame carries in header byte 0 (see `check_magic_number`).
const MAGIC_NUMBER: u8 = 8;
/// Metadata key under which an error-status message stores its error text
/// (read back by `get_error`).
pub const SERVICE_ERROR: &str = "__rpcx_error__";
17
18#[derive(Debug, Copy, Clone, Display, PartialEq, EnumIter, EnumString, Primitive)]
19pub enum MessageType {
20 Request = 0,
21 Response = 1,
22}
23
24#[derive(Debug, Copy, Clone, Display, PartialEq, EnumIter, EnumString, Primitive)]
25pub enum MessageStatusType {
26 Normal = 0,
27 Error = 1,
28}
29
30#[derive(Debug, Copy, Clone, Display, PartialEq, EnumIter, EnumString, Primitive)]
31pub enum CompressType {
32 CompressNone = 0,
33 Gzip = 1,
34}
35
36#[derive(Debug, Copy, Clone, Display, PartialEq, EnumIter, EnumString, Primitive)]
37pub enum SerializeType {
38 SerializeNone = 0,
39 JSON = 1,
40 Protobuf = 2,
41 MsgPack = 3,
42 Thrift = 4,
43}
44
45pub trait RpcxMessage {
47 fn check_magic_number(&self) -> bool;
48 fn get_version(&self) -> u8;
49 fn set_version(&mut self, v: u8);
50 fn get_message_type(&self) -> Option<MessageType>;
51 fn set_message_type(&mut self, mt: MessageType);
52 fn is_heartbeat(&self) -> bool;
53 fn set_heartbeat(&mut self, b: bool);
54 fn is_oneway(&self) -> bool;
55 fn set_oneway(&mut self, b: bool);
56 fn get_compress_type(&self) -> Option<CompressType>;
57 fn set_compress_type(&mut self, ct: CompressType);
58 fn get_message_status_type(&self) -> Option<MessageStatusType>;
59 fn set_message_status_type(&mut self, mst: MessageStatusType);
60 fn get_serialize_type(&self) -> Option<SerializeType>;
61 fn set_serialize_type(&mut self, st: SerializeType);
62 fn get_seq(&self) -> u64;
63 fn set_seq(&mut self, seq: u64);
64 fn decode<R: ?Sized>(&mut self, r: &mut R) -> Result<()>
65 where
66 R: Read;
67 fn encode(&self) -> Vec<u8>;
68
69 fn get_error(&self) -> Option<String>;
70}
71
/// String key/value pairs carried in a message's metadata section.
pub type Metadata = HashMap<String, String>;
73
74#[derive(Debug, Default)]
76pub struct Message {
77 pub header: [u8; 12],
78 pub service_path: String,
79 pub service_method: String,
80 pub metadata: RefCell<Metadata>,
81 pub payload: Vec<u8>,
82}
83impl Message {
84 pub fn new() -> Self {
86 let mut msg: Message = Default::default();
87 msg.header = [0u8; 12];
88 msg.header[0] = MAGIC_NUMBER;
89 msg.metadata = RefCell::new(HashMap::new());
90 msg
91 }
92
93 pub fn get_reply(&self) -> Result<Self> {
94 let mut reply = Message::new();
95 reply.set_version(self.get_version());
96 reply.set_compress_type(self.get_compress_type().unwrap());
97 reply.set_message_status_type(MessageStatusType::Normal);
98 reply.set_message_type(MessageType::Response);
99 reply.set_serialize_type(self.get_serialize_type().unwrap());
100 reply.set_seq(self.get_seq());
101 reply.service_path = self.service_path.clone();
102 reply.service_method = self.service_method.clone();
103
104 Ok(reply)
105 }
106}
107
108impl RpcxMessage for Message {
109 fn check_magic_number(&self) -> bool {
110 self.header[0] == MAGIC_NUMBER
111 }
112
113 fn get_version(&self) -> u8 {
114 self.header[1]
115 }
116 fn set_version(&mut self, v: u8) {
117 self.header[1] = v;
118 }
119
120 fn get_message_type(&self) -> Option<MessageType> {
121 MessageType::from_u8((self.header[2] & 0x80) >> 7 as u8)
122 }
123 fn set_message_type(&mut self, mt: MessageType) {
124 self.header[2] |= mt.to_u8().unwrap() << 7;
125 }
126 fn is_heartbeat(&self) -> bool {
127 self.header[2] & 0x40 == 0x40
128 }
129 fn set_heartbeat(&mut self, b: bool) {
130 if b {
131 self.header[2] |= 0x40;
132 } else {
133 self.header[2] &= !0x40;
134 }
135 }
136 fn is_oneway(&self) -> bool {
137 self.header[2] & 0x20 == 0x20
138 }
139 fn set_oneway(&mut self, b: bool) {
140 if b {
141 self.header[2] |= 0x20;
142 } else {
143 self.header[2] &= !0x20;
144 }
145 }
146 fn get_compress_type(&self) -> Option<CompressType> {
147 CompressType::from_u8((self.header[2] & 0x1C) >> 2)
148 }
149 fn set_compress_type(&mut self, ct: CompressType) {
150 self.header[2] = (self.header[2] & !0x1C) | (ct.to_u8().unwrap() << 2 & 0x1C);
151 }
152 fn get_message_status_type(&self) -> Option<MessageStatusType> {
153 MessageStatusType::from_u8(self.header[2] & 0x03)
154 }
155 fn set_message_status_type(&mut self, mst: MessageStatusType) {
156 self.header[2] = (self.header[2] & !0x03) | (mst.to_u8().unwrap() & 0x03);
157 }
158 fn get_serialize_type(&self) -> Option<SerializeType> {
159 SerializeType::from_u8((self.header[3] & 0xF0) >> 4)
160 }
161 fn set_serialize_type(&mut self, st: SerializeType) {
162 self.header[3] = (self.header[3] & !0xF0) | (st.to_u8().unwrap() << 4)
163 }
164 fn get_seq(&self) -> u64 {
165 u64_from_slice(&(self.header[4..]))
166 }
167 fn set_seq(&mut self, seq: u64) {
168 u64_to_slice(seq, &mut self.header[4..]);
169 }
170
171 fn decode<R: ?Sized>(&mut self, r: &mut R) -> Result<()>
172 where
173 R: Read,
174 {
175 r.read_exact(&mut self.header)?;
176
177 let mut buf = [0u8; 4];
178 r.read(&mut buf[..]).map(|_| {})?;
179 let len = BigEndian::read_u32(&buf); let mut buf = vec![0u8; len as usize];
181 r.read(&mut buf[..]).map(|_| ())?;
182
183 let mut start = 0;
184 let len = read_len(&buf[start..(start + 4)]) as usize;
186 let service_path = read_str(&buf[(start + 4)..(start + 4 + len)])?;
187 self.service_path = service_path;
188 start = start + 4 + len;
189 let len = read_len(&buf[start..(start + 4)]) as usize;
191 let service_method = read_str(&buf[(start + 4)..(start + 4 + len)])?;
192 self.service_method = service_method;
193
194 start = start + 4 + len;
195 let len = read_len(&buf[start..(start + 4)]) as usize;
197 let metadata_bytes = &buf[(start + 4)..(start + 4 + len)];
198 let mut meta_start = 0;
199 while meta_start < len {
200 let sl = read_len(&metadata_bytes[meta_start..(meta_start + 4)]) as usize;
201 let key = read_str(&metadata_bytes[(meta_start + 4)..(meta_start + 4 + sl)])?;
202 meta_start = meta_start + 4 + sl;
203 if meta_start < len {
204 let value_len = read_len(&metadata_bytes[meta_start..(meta_start + 4)]) as usize;
205 let value =
206 read_str(&metadata_bytes[(meta_start + 4)..(meta_start + 4 + value_len)])?;
207 self.metadata.borrow_mut().insert(key, value);
208 meta_start = meta_start + 4 + value_len;
209 } else {
210 self.metadata.borrow_mut().insert(key, String::new());
211 break;
212 }
213 }
214 start = start + 4 + len;
215 let len = read_len(&buf[start..start + 4]) as usize;
217 let payload = &buf[start + 4..];
218 if len != payload.len() {
219 return Err(Error::from("invalid payload length"));
220 }
221
222 let mut vp = Vec::with_capacity(payload.len());
223 match self.get_compress_type().unwrap() {
224 CompressType::Gzip => {
225 let mut deflater = GzDecoder::new(payload);
226 deflater.read_to_end(&mut vp)?;
227 }
228 CompressType::CompressNone => {
229 vp.extend_from_slice(&payload);
230 }
231 }
232 self.payload = vp;
233
234 Ok(())
235 }
236
237 fn encode(&self) -> Vec<u8> {
238 let mut buf = Vec::<u8>::with_capacity(20);
240 buf.extend_from_slice(&self.header);
241
242 let len_bytes = write_len(0);
244 buf.extend_from_slice(&len_bytes);
245
246 let len = self.service_path.len();
248 let len_bytes = write_len(len as u32);
249 buf.extend_from_slice(&len_bytes);
250 buf.extend_from_slice(self.service_path.as_bytes());
251
252 let len = self.service_method.len();
254 let len_bytes = write_len(len as u32);
255 buf.extend_from_slice(&len_bytes);
256 buf.extend_from_slice(self.service_method.as_bytes());
257
258 let mut metadata_bytes = Vec::<u8>::new();
260 let metadata = self.metadata.borrow_mut();
261 for meta in metadata.iter() {
262 let key = meta.0;
263 let len_bytes = write_len(key.len() as u32);
264 metadata_bytes.extend_from_slice(&len_bytes);
265 metadata_bytes.extend_from_slice(key.as_bytes());
266
267 let value = meta.1;
268 let len_bytes = write_len(value.len() as u32);
269 metadata_bytes.extend_from_slice(&len_bytes);
270 metadata_bytes.extend_from_slice(value.as_bytes());
271 }
272 let len = metadata_bytes.len();
273 let len_bytes = write_len(len as u32);
274 buf.extend_from_slice(&len_bytes);
275 buf.append(&mut metadata_bytes);
276
277 match self.get_compress_type().unwrap() {
281 CompressType::Gzip => {
282 let mut e = GzEncoder::new(Vec::new(), Compression::fast());
283 let _ = e.write_all(&self.payload[..]);
284 let compressed_payload = e.finish().unwrap();
285 let len = compressed_payload.len();
286 let len_bytes = write_len(len as u32);
287 buf.extend_from_slice(&len_bytes);
288 buf.extend_from_slice(&compressed_payload);
289 }
290 _ => {
291 let len = self.payload.len();
292 let len_bytes = write_len(len as u32);
293 buf.extend_from_slice(&len_bytes);
294 buf.extend_from_slice(&self.payload);
295 }
296 }
297
298 let len = buf.len() - 12 - 4;
300 let len_bytes = write_len(len as u32);
301 buf[12] = len_bytes[0];
302 buf[13] = len_bytes[1];
303 buf[14] = len_bytes[2];
304 buf[15] = len_bytes[3];
305
306 buf
307 }
308
309 fn get_error(&self) -> Option<String> {
310 match self.get_message_status_type() {
311 Some(MessageStatusType::Error) => {
312 let metadata = &self.metadata;
313 let metadata2 = metadata.borrow();
314 let err_msg = metadata2.get(&SERVICE_ERROR.to_owned())?;
315 Some(String::from(err_msg))
316 }
317 _ => None,
318 }
319 }
320}
321
/// Interprets the first four bytes of `buf` as a big-endian `u32`.
/// Panics if `buf` holds fewer than four bytes.
fn read_len(buf: &[u8]) -> u32 {
    u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]])
}
325
/// Encodes `len` as four big-endian bytes.
fn write_len(len: u32) -> [u8; 4] {
    len.to_be_bytes()
}
331
332fn read_str(buf: &[u8]) -> Result<String> {
333 let s = std::str::from_utf8(&buf).unwrap();
334 let str: String = std::string::String::from(s);
335 Ok(str)
336}
337
/// Reads a big-endian `u64` from the first eight bytes of `b`.
/// Panics if `b` is shorter than eight bytes.
fn u64_from_slice(b: &[u8]) -> u64 {
    let mut bytes = [0u8; 8];
    bytes.copy_from_slice(&b[..8]);
    u64::from_be_bytes(bytes)
}
341
/// Writes `v` in big-endian order into the first eight bytes of `b`.
/// Panics if `b` is shorter than eight bytes.
fn u64_to_slice(v: u64, b: &mut [u8]) {
    let encoded = v.to_be_bytes();
    b[..8].copy_from_slice(&encoded);
}
345
#[cfg(test)]
mod tests {
    use super::*;

    /// A complete uncompressed JSON request frame: header (seq 1234567890), service
    /// path "Arith", method "Add", metadata `__ID`, and a small JSON payload.
    const REQUEST_FRAME: [u8; 114] = [
        8, 0, 0, 16, 0, 0, 0, 0, 73, 150, 2, 210, 0, 0, 0, 98, 0, 0, 0, 5, 65, 114, 105, 116,
        104, 0, 0, 0, 3, 65, 100, 100, 0, 0, 0, 48, 0, 0, 0, 4, 95, 95, 73, 68, 0, 0, 0, 36,
        54, 98, 97, 55, 98, 56, 49, 48, 45, 57, 100, 97, 100, 45, 49, 49, 100, 49, 45, 56, 48,
        98, 52, 45, 48, 48, 99, 48, 52, 102, 100, 52, 51, 48, 99, 57, 0, 0, 0, 26, 123, 10, 9,
        9, 34, 65, 34, 58, 32, 49, 44, 10, 9, 9, 34, 66, 34, 58, 32, 50, 44, 10, 9, 125, 10, 9,
    ];

    /// Decodes `REQUEST_FRAME` into a fresh message. A decode error fails the test
    /// rather than just being printed.
    fn decoded_request() -> Message {
        let mut msg = Message::new();
        let mut data: &[u8] = &REQUEST_FRAME;
        if let Err(err) = msg.decode(&mut data) {
            panic!("failed to decode request frame: {}", err);
        }
        msg
    }

    #[test]
    fn parse_header() {
        let mut msg = Message::new();
        msg.header.copy_from_slice(&REQUEST_FRAME[..12]);

        assert!(msg.check_magic_number());
        assert_eq!(0, msg.get_version());
        assert_eq!(Some(MessageType::Request), msg.get_message_type());
        assert!(!msg.is_heartbeat());
        assert!(!msg.is_oneway());
        assert_eq!(Some(CompressType::CompressNone), msg.get_compress_type());
        assert_eq!(
            Some(MessageStatusType::Normal),
            msg.get_message_status_type()
        );
        assert_eq!(Some(SerializeType::JSON), msg.get_serialize_type());
        assert_eq!(1234567890, msg.get_seq());
    }

    #[test]
    fn set_header() {
        let mut msg = Message::new();
        msg.header.copy_from_slice(&REQUEST_FRAME[..12]);

        msg.set_version(0);
        msg.set_message_type(MessageType::Response);
        msg.set_heartbeat(true);
        msg.set_oneway(true);
        msg.set_compress_type(CompressType::Gzip);
        msg.set_serialize_type(SerializeType::MsgPack);
        msg.set_message_status_type(MessageStatusType::Normal);
        msg.set_seq(1000000);

        assert!(msg.check_magic_number());
        assert_eq!(0, msg.get_version());
        assert_eq!(Some(MessageType::Response), msg.get_message_type());
        assert!(msg.is_heartbeat());
        assert!(msg.is_oneway());
        assert_eq!(Some(CompressType::Gzip), msg.get_compress_type());
        assert_eq!(
            Some(MessageStatusType::Normal),
            msg.get_message_status_type()
        );
        assert_eq!(Some(SerializeType::MsgPack), msg.get_serialize_type());
        assert_eq!(1000000, msg.get_seq());
    }

    #[test]
    fn decode() {
        let msg = decoded_request();

        assert_eq!("Arith", msg.service_path);
        assert_eq!("Add", msg.service_method);
        assert_eq!(
            Some("6ba7b810-9dad-11d1-80b4-00c04fd430c9"),
            msg.metadata.borrow().get("__ID").map(String::as_str)
        );
        assert_eq!(
            "{\n\t\t\"A\": 1,\n\t\t\"B\": 2,\n\t}\n\t",
            std::str::from_utf8(&msg.payload).unwrap()
        );
    }

    #[test]
    fn encode() {
        let encoded_bytes = decoded_request().encode();

        assert_eq!(&REQUEST_FRAME[..], &encoded_bytes[..]);
    }
}