1use crate::client::task::ClientTask;
52use crate::server::task::ServerTaskEncode;
53use crate::{Codec, error::*};
54use std::fmt;
55use std::io::Write;
56use std::mem::size_of;
57use std::ptr::addr_of;
58use zerocopy::byteorder::little_endian;
59use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned};
60
61pub const PING_ACTION: u32 = 0;
62
63pub const RPC_MAGIC: little_endian::U16 = little_endian::U16::new(19749);
64pub const U32_HIGH_MASK: u32 = 1 << 31;
65
/// `RespHead::flag` value: the header's `msg_len` field carries an error number.
pub const RESP_FLAG_HAS_ERRNO: u8 = 0x01;
/// `RespHead::flag` value: an error string of `blob_len` bytes follows the header.
pub const RESP_FLAG_HAS_ERR_STRING: u8 = 0x02;
/// Protocol version written by the encoders and the only one the decoders accept.
pub const RPC_VERSION_1: u8 = 0x01;
69
/// An RPC action identifier borrowed from the caller: a string name or a numeric id.
///
/// When `ReqHead::encode` serializes it, a numeric id goes directly into the header's
/// `action` field. A string name is written right after the header, and only its
/// length (tagged with `U32_HIGH_MASK`) is stored in `action`.
///
/// Derives `Eq` and `Hash` like [`RpcActionOwned`], so borrowed actions can be
/// compared and hashed the same way owned ones can.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RpcAction<'a> {
    /// Action addressed by name.
    Str(&'a str),
    /// Action addressed by numeric id.
    Num(i32),
}

/// Owned counterpart of [`RpcAction`]; derives `Eq` + `Hash`, so it can be used as a
/// map/set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RpcActionOwned {
    /// Action addressed by name.
    Str(String),
    /// Action addressed by numeric id.
    Num(i32),
}

impl<'a> From<RpcAction<'a>> for RpcActionOwned {
    /// Copies a borrowed action into an owned one (allocates only for `Str`).
    fn from(action: RpcAction<'a>) -> Self {
        match action {
            RpcAction::Str(s) => Self::Str(s.to_owned()),
            RpcAction::Num(n) => Self::Num(n),
        }
    }
}

impl RpcActionOwned {
    /// Borrows this action as an [`RpcAction`] without allocating.
    pub fn to_action(&self) -> RpcAction<'_> {
        match self {
            Self::Str(s) => RpcAction::Str(s.as_str()),
            Self::Num(n) => RpcAction::Num(*n),
        }
    }
}
99
/// Fixed-size header placed at the start of every request buffer.
///
/// The struct is `repr(C, packed)` and built only from unaligned little-endian
/// integers, so it maps byte-for-byte onto the wire. The field order IS the wire
/// layout and must not change. Its size is `RPC_REQ_HEADER_LEN` (32 bytes, checked
/// by `test_header_len`).
#[derive(FromBytes, IntoBytes, Unaligned, Immutable, KnownLayout, PartialEq, Clone)]
#[repr(C, packed)]
pub struct ReqHead {
    /// Always `RPC_MAGIC`; `decode_head` rejects any other value.
    pub magic: little_endian::U16,
    /// Protocol version; `decode_head` accepts only `RPC_VERSION_1`.
    pub ver: u8,
    /// Always written as 0 by `_write_head`.
    /// NOTE(review): meaning of non-zero values is not visible here — confirm.
    pub format: u8,
    /// Numeric action id. If `U32_HIGH_MASK` is set, it is instead the byte length of
    /// a string action written right after the header (see `get_action`).
    pub action: little_endian::U32,

    /// Request sequence number (taken from `task.seq()` in `encode`).
    pub seq: little_endian::U64,

    /// Client identifier supplied by the caller of `encode` / `encode_ping`.
    pub client_id: little_endian::U64,
    /// Length of the message written by `task.encode_req` (after the action string).
    pub msg_len: little_endian::U32,
    /// Length of the optional request blob. `encode` returns the blob to its caller
    /// instead of copying it into the buffer.
    /// NOTE(review): presumably sent right after the buffer — confirm with the writer.
    pub blob_len: little_endian::U32,
}

/// Size in bytes of an encoded [`ReqHead`].
pub const RPC_REQ_HEADER_LEN: usize = size_of::<ReqHead>();
125
126impl ReqHead {
127 #[inline(always)]
128 pub fn encode_ping(buf: &mut Vec<u8>, client_id: u64, seq: u64) {
129 debug_assert!(buf.capacity() > RPC_REQ_HEADER_LEN);
130 unsafe { buf.set_len(RPC_REQ_HEADER_LEN) };
131 Self::_write_head(buf, client_id, PING_ACTION, seq, 0, 0);
132 }
133
134 #[inline(always)]
135 fn _write_head(
136 buf: &mut [u8], client_id: u64, action: u32, seq: u64, msg_len: u32, blob_len: i32,
137 ) {
138 let header: &mut Self =
141 Self::mut_from_bytes(&mut buf[0..RPC_REQ_HEADER_LEN]).expect("fill header buf");
142 header.magic = RPC_MAGIC;
143 header.ver = RPC_VERSION_1;
144 header.format = 0;
145 header.action.set(action);
146 header.seq.set(seq);
147 header.client_id.set(client_id);
148 header.msg_len.set(msg_len);
149 header.blob_len.set(blob_len as u32);
150 }
151
152 #[inline(always)]
154 pub fn encode<'a, T, C>(
155 codec: &C, buf: &mut Vec<u8>, client_id: u64, task: &'a T,
156 ) -> Result<Option<&'a [u8]>, ()>
157 where
158 T: ClientTask,
159 C: Codec,
160 {
161 debug_assert!(buf.capacity() >= RPC_REQ_HEADER_LEN);
162 unsafe { buf.set_len(RPC_REQ_HEADER_LEN) };
164 let action_flag: u32;
166 match task.get_action() {
167 RpcAction::Num(num) => action_flag = num as u32,
168 RpcAction::Str(s) => {
169 action_flag = s.len() as u32 | U32_HIGH_MASK;
170 buf.write_all(s.as_bytes()).expect("fill action buffer");
171 }
172 }
173 let msg_len = task.encode_req(codec, buf)?;
174 if msg_len > u32::MAX as usize {
175 error!("ReqHead: req len {} cannot larger than u32", msg_len);
176 return Err(());
177 }
178 let blob = task.get_req_blob();
179 let blob_len = if let Some(blob) = blob { blob.len() } else { 0 };
180 if blob_len > i32::MAX as usize {
181 error!("ReqHead: blob_len {} cannot larger than i32", blob_len);
182 return Err(());
183 }
184 Self::_write_head(buf, client_id, action_flag, task.seq(), msg_len as u32, blob_len as i32);
185 Ok(blob)
186 }
187
188 #[inline(always)]
189 pub fn decode_head(head_buf: &[u8]) -> Result<&Self, RpcIntErr> {
190 let head: &Self = Self::ref_from_bytes(head_buf).expect("from header buf");
191 if head.magic != RPC_MAGIC {
192 warn!("rpc server: wrong magic receive {:?}", head.magic);
193 return Err(RpcIntErr::IO);
194 }
195 if head.ver != RPC_VERSION_1 {
196 warn!("rpc server: version {} not supported", head.ver);
197 return Err(RpcIntErr::Version);
198 }
199 return Ok(head);
200 }
201
202 #[inline]
203 pub fn get_action(&self) -> Result<i32, i32> {
204 if self.action & U32_HIGH_MASK == 0 {
205 Ok(self.action.get() as i32)
206 } else {
207 let action_len = self.action ^ U32_HIGH_MASK;
208 Err(action_len.get() as i32)
209 }
210 }
211}
212
213impl fmt::Display for ReqHead {
214 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
215 let _ = unsafe {
216 write!(
217 f,
218 "[client_id:{}, seq:{}, msg:{}, blob:{}",
219 addr_of!(self.client_id).read_unaligned(),
220 addr_of!(self.seq).read_unaligned(),
221 addr_of!(self.msg_len).read_unaligned(),
222 addr_of!(self.blob_len).read_unaligned(),
223 )
224 };
225 match self.get_action() {
226 Ok(action_num) => {
227 write!(f, ", action:{:?}]", action_num)
228 }
229 Err(action_len) => {
230 write!(f, "action_len:{}]", action_len)
231 }
232 }
233 }
234}
235
236impl fmt::Debug for ReqHead {
237 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
238 fmt::Display::fmt(self, f)
239 }
240}
241
/// Fixed-size header placed at the start of every response buffer.
///
/// The struct is `repr(C, packed)` and built only from unaligned little-endian
/// integers, so it maps byte-for-byte onto the wire. The field order IS the wire
/// layout and must not change. Its size is `RPC_RESP_HEADER_LEN` (20 bytes, checked
/// by `test_header_len`).
#[derive(FromBytes, IntoBytes, Unaligned, Immutable, KnownLayout, PartialEq, Clone)]
#[repr(C, packed)]
pub struct RespHead {
    /// Always `RPC_MAGIC`; `decode_head` rejects any other value.
    pub magic: little_endian::U16,
    /// Protocol version; `decode_head` accepts only `RPC_VERSION_1`.
    pub ver: u8,

    /// 0 for a normal response; `RESP_FLAG_HAS_ERRNO` or `RESP_FLAG_HAS_ERR_STRING`
    /// for an error response (see `_encode_error`).
    pub flag: u8,

    /// Length of the encoded response message. When `flag` is `RESP_FLAG_HAS_ERRNO`,
    /// this field carries the error number instead.
    pub msg_len: little_endian::U32,

    /// Sequence number returned by the task's `encode_resp` (or passed to
    /// `encode_internal`). NOTE(review): presumably the originating request's seq.
    pub seq: little_endian::U64,
    /// Length of the response blob. When `flag` is `RESP_FLAG_HAS_ERR_STRING`, this
    /// is the length of the error string appended right after the header.
    pub blob_len: little_endian::I32,
}

/// Size in bytes of an encoded [`RespHead`].
pub const RPC_RESP_HEADER_LEN: usize = size_of::<RespHead>();
263
264impl RespHead {
265 #[inline]
266 pub fn encode<'a, 'b, L, C, T>(
267 logger: &'b L, codec: &'b C, buf: &'b mut Vec<u8>, task: &'a mut T,
268 ) -> (u64, Option<&'a [u8]>)
269 where
270 L: captains_log::filter::Filter,
271 C: Codec,
272 T: ServerTaskEncode,
273 {
274 debug_assert!(buf.capacity() >= RPC_RESP_HEADER_LEN);
275 unsafe { buf.set_len(RPC_RESP_HEADER_LEN) };
277 let (seq, r) = task.encode_resp(codec, buf);
278 match r {
279 Ok((msg_len, None)) => {
280 if msg_len > u32::MAX as usize {
281 error!("write_resp: encoded msg len {} exceed u32 limit", msg_len);
282 Self::_encode_error::<L>(logger, buf, seq, EncodedErr::Rpc(RpcIntErr::Encode));
283 } else {
284 Self::_write_head(logger, buf, 0, seq, msg_len as u32, 0);
285 }
286 return (seq, None);
287 }
288 Ok((msg_len, Some(blob))) => {
289 if msg_len > u32::MAX as usize {
290 error!("write_resp: encoded msg len {} exceed u32 limit", msg_len);
291 Self::_encode_error::<L>(logger, buf, seq, EncodedErr::Rpc(RpcIntErr::Encode));
292 return (seq, None);
293 } else if blob.len() > i32::MAX as usize {
294 error!("write_resp: blob len {} exceed i32 limit", blob.len());
295 Self::_encode_error::<L>(logger, buf, seq, EncodedErr::Rpc(RpcIntErr::Encode));
296 return (seq, None);
297 }
298 Self::_write_head::<L>(logger, buf, 0, seq, msg_len as u32, blob.len() as i32);
299 return (seq, Some(blob));
300 }
301 Err(e) => {
302 Self::_encode_error::<L>(logger, buf, seq, e);
303 return (seq, None);
304 }
305 }
306 }
307
308 #[inline]
309 pub fn encode_internal<'a, L>(
310 logger: &'a L, buf: &'a mut Vec<u8>, seq: u64, err: Option<RpcIntErr>,
311 ) -> u64
312 where
313 L: captains_log::filter::Filter,
314 {
315 debug_assert!(buf.capacity() >= RPC_RESP_HEADER_LEN);
316 unsafe { buf.set_len(RPC_RESP_HEADER_LEN) };
318 if let Some(e) = err {
319 Self::_encode_error::<L>(logger, buf, seq, e.into());
320 return seq;
321 } else {
322 Self::_write_head::<L>(logger, buf, 0, seq, 0, 0);
324 return seq;
325 }
326 }
327
328 #[inline(always)]
329 fn _encode_error<'b, L>(logger: &'b L, buf: &'b mut Vec<u8>, seq: u64, e: EncodedErr)
330 where
331 L: captains_log::filter::Filter,
332 {
333 macro_rules! write_err {
334 ($s: expr) => {
335 Self::_write_head::<L>(
336 logger,
337 buf,
338 RESP_FLAG_HAS_ERR_STRING,
339 seq,
340 0,
341 $s.len() as i32,
342 );
343 buf.write_all($s).expect("fill error str");
344 };
345 }
346 match e {
347 EncodedErr::Num(n) => {
348 Self::_write_head::<L>(logger, buf, RESP_FLAG_HAS_ERRNO, seq, n, 0);
349 }
350 EncodedErr::Rpc(s) => {
351 write_err!(s.as_bytes());
352 }
353 EncodedErr::Buf(s) => {
354 write_err!(&s);
355 }
356 EncodedErr::Static(s) => {
357 write_err!(s.as_bytes());
358 }
359 }
360 }
361
362 #[inline]
363 fn _write_head<L>(logger: &L, buf: &mut [u8], flag: u8, seq: u64, msg_len: u32, blob_len: i32)
364 where
365 L: captains_log::filter::Filter,
366 {
367 let header = Self::mut_from_bytes(&mut buf[0..RPC_RESP_HEADER_LEN]).expect("fill header");
368 header.magic = RPC_MAGIC;
369 header.ver = RPC_VERSION_1;
370 header.flag = flag;
371 header.msg_len.set(msg_len);
372 header.seq.set(seq);
373 header.blob_len.set(blob_len);
374 logger_trace!(logger, "resp {:?}", header);
375 }
376
377 #[inline(always)]
378 pub fn decode_head(head_buf: &[u8]) -> Result<&Self, RpcIntErr> {
379 let head: &Self = Self::ref_from_bytes(head_buf).expect("decode header");
380 if head.magic != RPC_MAGIC {
381 warn!("rpc server: wrong magic receive {:?}", head.magic);
382 return Err(RpcIntErr::IO);
383 }
384 if head.ver != RPC_VERSION_1 {
385 warn!("rpc server: version {} not supported", head.ver);
386 return Err(RpcIntErr::Version);
387 }
388 return Ok(head);
389 }
390}
391
392impl fmt::Display for RespHead {
393 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
394 unsafe {
395 write!(
396 f,
397 "[seq:{}, flag:{}, msg:{}, blob:{}]",
398 addr_of!(self.seq).read_unaligned(), addr_of!(self.flag).read_unaligned(),
400 addr_of!(self.msg_len).read_unaligned(),
401 addr_of!(self.blob_len).read_unaligned(),
402 )
403 }
404 }
405}
406
407impl fmt::Debug for RespHead {
408 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
409 fmt::Display::fmt(self, f)
410 }
411}
412
#[cfg(test)]
mod tests {

    use super::*;

    /// The wire header sizes are part of the protocol and must not drift.
    #[test]
    fn test_header_len() {
        assert_eq!(RPC_REQ_HEADER_LEN, 32);
        assert_eq!(RPC_RESP_HEADER_LEN, 20);
    }

    /// Numeric actions decode as `Ok(id)`; flagged ones as `Err(string_len)`.
    #[test]
    fn test_req_action_flag() {
        let mut raw = [0u8; RPC_REQ_HEADER_LEN];
        let head = ReqHead::mut_from_bytes(&mut raw[..]).expect("header sized buffer");
        head.action.set(7);
        assert_eq!(head.get_action(), Ok(7));
        head.action.set(5 | U32_HIGH_MASK);
        assert_eq!(head.get_action(), Err(5));
    }

    /// A ping header survives an encode/decode round trip.
    #[test]
    fn test_ping_round_trip() {
        let mut buf = Vec::with_capacity(RPC_REQ_HEADER_LEN * 2);
        ReqHead::encode_ping(&mut buf, 11, 42);
        assert_eq!(buf.len(), RPC_REQ_HEADER_LEN);
        let head = match ReqHead::decode_head(&buf) {
            Ok(head) => head,
            Err(_) => panic!("encoded ping header failed to decode"),
        };
        assert_eq!(head.client_id.get(), 11);
        assert_eq!(head.seq.get(), 42);
        assert_eq!(head.msg_len.get(), 0);
        assert_eq!(head.blob_len.get(), 0);
        assert_eq!(head.get_action(), Ok(PING_ACTION as i32));
    }
}