1use crate::client::task::ClientTask;
41use crate::server::task::ServerTaskEncode;
42use crate::{Codec, error::*};
43use std::fmt;
44use std::io::Write;
45use std::mem::size_of;
46use std::ptr::addr_of;
47use zerocopy::byteorder::little_endian;
48use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned};
49
50pub const PING_ACTION: u32 = 0;
51
52pub const RPC_MAGIC: little_endian::U16 = little_endian::U16::new(19749);
53pub const U32_HIGH_MASK: u32 = 1 << 31;
54
/// `RespHead::flag` value written when the response reports a numeric error;
/// the error number is then stored in the header's `msg_len` field.
pub const RESP_FLAG_HAS_ERRNO: u8 = 1 << 0;
/// `RespHead::flag` value written when an error string follows the header;
/// the string's byte length is then stored in the header's `blob_len` field.
pub const RESP_FLAG_HAS_ERR_STRING: u8 = 1 << 1;
/// Protocol version written into every header and required by `decode_head`.
pub const RPC_VERSION_1: u8 = 1;
58
/// Borrowed identifier of the action an RPC request invokes.
///
/// An action is either a string name or a number. `Eq` and `Hash` are derived
/// so the borrowed form can be used as a map/set key exactly like its owned
/// counterpart `RpcActionOwned`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RpcAction<'a> {
    /// Action addressed by name.
    Str(&'a str),
    /// Action addressed by number.
    Num(i32),
}
64
/// Owned identifier of an RPC action: the same two shapes as `RpcAction`,
/// but holding a `String` so the value does not borrow from anything.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum RpcActionOwned {
    /// Action addressed by name.
    Str(String),
    /// Action addressed by number.
    Num(i32),
}
70
71impl<'a> From<RpcAction<'a>> for RpcActionOwned {
72 fn from(action: RpcAction<'a>) -> Self {
73 match action {
74 RpcAction::Str(s) => RpcActionOwned::Str(s.to_string()),
75 RpcAction::Num(n) => RpcActionOwned::Num(n),
76 }
77 }
78}
79
80impl RpcActionOwned {
81 pub fn to_action<'a>(&'a self) -> RpcAction<'a> {
82 match self {
83 RpcActionOwned::Str(s) => RpcAction::Str(s.as_str()),
84 RpcActionOwned::Num(n) => RpcAction::Num(*n),
85 }
86 }
87}
88
89#[derive(FromBytes, IntoBytes, Unaligned, Immutable, KnownLayout, PartialEq, Clone)]
91#[repr(packed)]
92pub struct ReqHead {
93 pub magic: little_endian::U16,
94 pub ver: u8,
95 pub format: u8,
96 pub action: little_endian::U32,
102
103 pub seq: little_endian::U64,
105
106 pub client_id: little_endian::U64,
107 pub msg_len: little_endian::U32,
109 pub blob_len: little_endian::U32,
111}
112
113pub const RPC_REQ_HEADER_LEN: usize = size_of::<ReqHead>();
114
115impl ReqHead {
116 #[inline(always)]
117 pub fn encode_ping(buf: &mut Vec<u8>, client_id: u64, seq: u64) {
118 debug_assert!(buf.capacity() > RPC_REQ_HEADER_LEN);
119 unsafe { buf.set_len(RPC_REQ_HEADER_LEN) };
120 Self::_write_head(buf, client_id, PING_ACTION, seq, 0, 0);
121 }
122
123 #[inline(always)]
124 fn _write_head(
125 buf: &mut Vec<u8>, client_id: u64, action: u32, seq: u64, msg_len: u32, blob_len: i32,
126 ) {
127 let header: &mut Self =
130 Self::mut_from_bytes(&mut buf[0..RPC_REQ_HEADER_LEN]).expect("fill header buf");
131 header.magic = RPC_MAGIC;
132 header.ver = RPC_VERSION_1;
133 header.format = 0;
134 header.action.set(action);
135 header.seq.set(seq);
136 header.client_id.set(client_id);
137 header.msg_len.set(msg_len);
138 header.blob_len.set(blob_len as u32);
139 }
140
141 #[inline(always)]
143 pub fn encode<'a, T, C>(
144 codec: &C, buf: &mut Vec<u8>, client_id: u64, task: &'a T,
145 ) -> Result<Option<&'a [u8]>, ()>
146 where
147 T: ClientTask,
148 C: Codec,
149 {
150 debug_assert!(buf.capacity() >= RPC_REQ_HEADER_LEN);
151 unsafe { buf.set_len(RPC_REQ_HEADER_LEN) };
153 let action_flag: u32;
155 match task.get_action() {
156 RpcAction::Num(num) => action_flag = num as u32,
157 RpcAction::Str(s) => {
158 action_flag = s.len() as u32 | U32_HIGH_MASK;
159 buf.write_all(s.as_bytes()).expect("fill action buffer");
160 }
161 }
162 let msg_len = task.encode_req(codec, buf)?;
163 if msg_len > u32::MAX as usize {
164 error!("ReqHead: req len {} cannot larger than u32", msg_len);
165 return Err(());
166 }
167 let blob = task.get_req_blob();
168 let blob_len = if let Some(blob) = blob { blob.len() } else { 0 };
169 if blob_len > i32::MAX as usize {
170 error!("ReqHead: blob_len {} cannot larger than i32", blob_len);
171 return Err(());
172 }
173 Self::_write_head(buf, client_id, action_flag, task.seq(), msg_len as u32, blob_len as i32);
174 Ok(blob)
175 }
176
177 #[inline(always)]
178 pub fn decode_head(head_buf: &[u8]) -> Result<&Self, RpcIntErr> {
179 let head: &Self = Self::ref_from_bytes(head_buf).expect("from header buf");
180 if head.magic != RPC_MAGIC {
181 warn!("rpc server: wrong magic receive {:?}", head.magic);
182 return Err(RpcIntErr::IO);
183 }
184 if head.ver != RPC_VERSION_1 {
185 warn!("rpc server: version {} not supported", head.ver);
186 return Err(RpcIntErr::Version);
187 }
188 return Ok(head);
189 }
190
191 #[inline]
192 pub fn get_action(&self) -> Result<i32, i32> {
193 if self.action & U32_HIGH_MASK == 0 {
194 Ok(self.action.get() as i32)
195 } else {
196 let action_len = self.action ^ U32_HIGH_MASK;
197 Err(action_len.get() as i32)
198 }
199 }
200}
201
202impl fmt::Display for ReqHead {
203 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
204 let _ = unsafe {
205 write!(
206 f,
207 "[client_id:{}, seq:{}, msg:{}, blob:{}",
208 addr_of!(self.client_id).read_unaligned(),
209 addr_of!(self.seq).read_unaligned(),
210 addr_of!(self.msg_len).read_unaligned(),
211 addr_of!(self.blob_len).read_unaligned(),
212 )
213 };
214 match self.get_action() {
215 Ok(action_num) => {
216 write!(f, ", action:{:?}]", action_num)
217 }
218 Err(action_len) => {
219 write!(f, "action_len:{}]", action_len)
220 }
221 }
222 }
223}
224
225impl fmt::Debug for ReqHead {
226 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
227 fmt::Display::fmt(self, f)
228 }
229}
230
231#[derive(FromBytes, IntoBytes, Unaligned, Immutable, KnownLayout, PartialEq, Clone)]
233#[repr(packed)]
234pub struct RespHead {
235 pub magic: little_endian::U16,
236 pub ver: u8,
237
238 pub flag: u8,
241
242 pub msg_len: little_endian::U32,
244
245 pub seq: little_endian::U64,
247 pub blob_len: little_endian::I32,
249}
250
251pub const RPC_RESP_HEADER_LEN: usize = size_of::<RespHead>();
252
253impl RespHead {
254 #[inline]
255 pub fn encode<'a, 'b, L, C, T>(
256 logger: &'b L, codec: &'b C, buf: &'b mut Vec<u8>, task: &'a mut T,
257 ) -> (u64, Option<&'a [u8]>)
258 where
259 L: captains_log::filter::Filter,
260 C: Codec,
261 T: ServerTaskEncode,
262 {
263 debug_assert!(buf.capacity() >= RPC_RESP_HEADER_LEN);
264 unsafe { buf.set_len(RPC_RESP_HEADER_LEN) };
266 let (seq, r) = task.encode_resp(codec, buf);
267 match r {
268 Ok((msg_len, None)) => {
269 if msg_len > u32::MAX as usize {
270 error!("write_resp: encoded msg len {} exceed u32 limit", msg_len);
271 Self::_encode_error::<L>(logger, buf, seq, EncodedErr::Rpc(RpcIntErr::Encode));
272 } else {
273 Self::_write_head(logger, buf, 0, seq, msg_len as u32, 0);
274 }
275 return (seq, None);
276 }
277 Ok((msg_len, Some(blob))) => {
278 if msg_len > u32::MAX as usize {
279 error!("write_resp: encoded msg len {} exceed u32 limit", msg_len);
280 Self::_encode_error::<L>(logger, buf, seq, EncodedErr::Rpc(RpcIntErr::Encode));
281 return (seq, None);
282 } else if blob.len() > i32::MAX as usize {
283 error!("write_resp: blob len {} exceed i32 limit", blob.len());
284 Self::_encode_error::<L>(logger, buf, seq, EncodedErr::Rpc(RpcIntErr::Encode));
285 return (seq, None);
286 }
287 Self::_write_head::<L>(logger, buf, 0, seq, msg_len as u32, blob.len() as i32);
288 return (seq, Some(blob));
289 }
290 Err(e) => {
291 Self::_encode_error::<L>(logger, buf, seq, e);
292 return (seq, None);
293 }
294 }
295 }
296
297 #[inline]
298 pub fn encode_internal<'a, L>(
299 logger: &'a L, buf: &'a mut Vec<u8>, seq: u64, err: Option<RpcIntErr>,
300 ) -> u64
301 where
302 L: captains_log::filter::Filter,
303 {
304 debug_assert!(buf.capacity() >= RPC_RESP_HEADER_LEN);
305 unsafe { buf.set_len(RPC_RESP_HEADER_LEN) };
307 if let Some(e) = err {
308 Self::_encode_error::<L>(logger, buf, seq, e.into());
309 return seq;
310 } else {
311 Self::_write_head::<L>(logger, buf, 0, seq, 0, 0);
313 return seq;
314 }
315 }
316
317 #[inline(always)]
318 fn _encode_error<'b, L>(logger: &'b L, buf: &'b mut Vec<u8>, seq: u64, e: EncodedErr)
319 where
320 L: captains_log::filter::Filter,
321 {
322 macro_rules! write_err {
323 ($s: expr) => {
324 Self::_write_head::<L>(
325 logger,
326 buf,
327 RESP_FLAG_HAS_ERR_STRING,
328 seq,
329 0,
330 $s.len() as i32,
331 );
332 buf.write_all($s).expect("fill error str");
333 };
334 }
335 match e {
336 EncodedErr::Num(n) => {
337 Self::_write_head::<L>(logger, buf, RESP_FLAG_HAS_ERRNO, seq, n as u32, 0);
338 }
339 EncodedErr::Rpc(s) => {
340 write_err!(s.as_bytes());
341 }
342 EncodedErr::Buf(s) => {
343 write_err!(&s);
344 }
345 EncodedErr::Static(s) => {
346 write_err!(s.as_bytes());
347 }
348 }
349 }
350
351 #[inline]
352 fn _write_head<L>(
353 logger: &L, buf: &mut Vec<u8>, flag: u8, seq: u64, msg_len: u32, blob_len: i32,
354 ) where
355 L: captains_log::filter::Filter,
356 {
357 let header = Self::mut_from_bytes(&mut buf[0..RPC_RESP_HEADER_LEN]).expect("fill header");
358 header.magic = RPC_MAGIC;
359 header.ver = RPC_VERSION_1;
360 header.flag = flag;
361 header.msg_len.set(msg_len);
362 header.seq.set(seq);
363 header.blob_len.set(blob_len);
364 logger_trace!(logger, "resp {:?}", header);
365 }
366
367 #[inline(always)]
368 pub fn decode_head(head_buf: &[u8]) -> Result<&Self, RpcIntErr> {
369 let head: &Self = Self::ref_from_bytes(head_buf).expect("decode header");
370 if head.magic != RPC_MAGIC {
371 warn!("rpc server: wrong magic receive {:?}", head.magic);
372 return Err(RpcIntErr::IO);
373 }
374 if head.ver != RPC_VERSION_1 {
375 warn!("rpc server: version {} not supported", head.ver);
376 return Err(RpcIntErr::Version);
377 }
378 return Ok(head);
379 }
380}
381
382impl fmt::Display for RespHead {
383 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
384 unsafe {
385 write!(
386 f,
387 "[seq:{}, flag:{}, msg:{}, blob:{}]",
388 addr_of!(self.seq).read_unaligned(), addr_of!(self.flag).read_unaligned(),
390 addr_of!(self.msg_len).read_unaligned(),
391 addr_of!(self.blob_len).read_unaligned(),
392 )
393 }
394 }
395}
396
397impl fmt::Debug for RespHead {
398 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
399 fmt::Display::fmt(self, f)
400 }
401}
402
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the wire size of both headers so an accidental field change
    /// is caught immediately.
    #[test]
    fn test_header_len() {
        let expected_req_len: usize = 32;
        let expected_resp_len: usize = 20;
        assert_eq!(RPC_REQ_HEADER_LEN, expected_req_len);
        assert_eq!(RPC_RESP_HEADER_LEN, expected_resp_len);
    }
}