1use {
2 bytes::{Buf, BufMut, Bytes, BytesMut},
3 zus_proto::constants::*,
4};
5
6use crate::error::{Result, ZusError};
7
/// Fixed-size header that precedes every RPC frame.
///
/// Fields are declared in wire order: [`RpcProtocolHeader::encode`] writes them in
/// exactly this sequence, for a total of `RpcProtocolHeader::HEADER_SIZE` (40) bytes.
#[derive(Debug, Clone)]
pub struct RpcProtocolHeader {
    /// Protocol marker; `decode` rejects any header whose magic is not `RPC_MAGIC`.
    pub magic: u16,
    /// Protocol version (`0` for headers created with `RpcProtocolHeader::new`).
    pub version: u16,
    /// CRC-32 of the message body (see `RpcProtocolHeader::calculate_datacrc`).
    pub datacrc: u32,
    /// Message sequence number.
    /// NOTE(review): presumably used to match responses to requests — confirm.
    pub sequence: u64,
    /// Message kind, e.g. `MSG_TYPE_REQ`, `MSG_TYPE_RSP`, `MSG_TYPE_NOTIFY`, `MSG_TYPE_SYSRSP`.
    pub msg_type: u8,
    /// Bit set of `FLAG_COMPRESSED`, `FLAG_ENCRYPTED` and `0x04` (extra data).
    pub flags: u8,
    /// 16-bit header checksum; `encode` recomputes it and `decode` verifies it.
    pub headcrc16: u16,
    /// Payload length in bytes (method plus body for requests and notifications).
    pub body_length: u32,
    /// Timeout value; `RpcProtocolHeader::new` sets it to 10000.
    /// NOTE(review): unit not established in this file (presumably milliseconds).
    pub timeout: u64,
    /// Trace identifier; `0` unless set by the caller.
    pub traceid: u64,
}
45
46impl RpcProtocolHeader {
47 pub const HEADER_SIZE: usize = 40;
48
49 pub fn new(sequence: u64, msg_type: u8) -> Self {
50 Self {
51 magic: RPC_MAGIC,
52 version: 0,
53 datacrc: 0,
54 sequence,
55 msg_type,
56 flags: 0,
57 headcrc16: 0,
58 body_length: 0,
59 timeout: 10000, traceid: 0,
61 }
62 }
63
64 pub fn is_compressed(&self) -> bool {
65 (self.flags & FLAG_COMPRESSED) != 0
66 }
67
68 pub fn is_encrypted(&self) -> bool {
69 (self.flags & FLAG_ENCRYPTED) != 0
70 }
71
72 pub fn has_extra_data(&self) -> bool {
73 (self.flags & 0x04) != 0
74 }
75
76 pub fn set_compressed(&mut self, compressed: bool) {
77 if compressed {
78 self.flags |= FLAG_COMPRESSED;
79 } else {
80 self.flags &= !FLAG_COMPRESSED;
81 }
82 }
83
84 pub fn set_encrypted(&mut self, encrypted: bool) {
85 if encrypted {
86 self.flags |= FLAG_ENCRYPTED;
87 } else {
88 self.flags &= !FLAG_ENCRYPTED;
89 }
90 }
91
92 pub fn set_extra_data(&mut self, has_extra: bool) {
93 if has_extra {
94 self.flags |= 0x04;
95 } else {
96 self.flags &= !0x04;
97 }
98 }
99
100 fn calculate_headcrc16(&self) -> u16 {
104 let mut crc: u16 = 0;
105
106 crc = crc.wrapping_add(self.magic);
109 crc = crc.wrapping_add(self.version);
110 crc = crc.wrapping_add(self.datacrc as u16); crc = crc.wrapping_add(self.sequence as u16); crc = crc.wrapping_add(self.msg_type as u16);
113 crc = crc.wrapping_add(self.flags as u16);
114 crc = crc.wrapping_add(self.body_length as u16); crc = crc.wrapping_add(self.timeout as u16); crc = crc.wrapping_add(self.traceid as u16); crc
119 }
120
121 pub fn encode(&self, buf: &mut BytesMut) {
125 let headcrc16 = self.calculate_headcrc16();
127
128 buf.put_u16(self.magic); buf.put_u16(self.version); buf.put_u32(self.datacrc); buf.put_u64(self.sequence); buf.put_u8(self.msg_type); buf.put_u8(self.flags); buf.put_u16(headcrc16); buf.put_u32(self.body_length); buf.put_u64(self.timeout); buf.put_u64(self.traceid); }
140
141 pub fn decode(buf: &mut BytesMut) -> Result<Self> {
145 if buf.len() < Self::HEADER_SIZE {
146 return Err(ZusError::Protocol(format!(
147 "Insufficient data for header: got {} bytes, need {}",
148 buf.len(),
149 Self::HEADER_SIZE
150 )));
151 }
152
153 let magic = buf.get_u16();
155 let version = buf.get_u16();
156
157 if magic != RPC_MAGIC {
159 return Err(ZusError::InvalidMagic(magic));
160 }
161
162 let datacrc = buf.get_u32();
163 let sequence = buf.get_u64();
164 let msg_type = buf.get_u8();
165 let flags = buf.get_u8();
166 let headcrc16 = buf.get_u16();
167 let body_length = buf.get_u32();
168 let timeout = buf.get_u64();
169 let traceid = buf.get_u64();
170
171 let header = Self {
173 magic,
174 version,
175 datacrc,
176 sequence,
177 msg_type,
178 flags,
179 headcrc16,
180 body_length,
181 timeout,
182 traceid,
183 };
184
185 let calculated_crc = header.calculate_headcrc16();
187 if calculated_crc != headcrc16 {
188 return Err(ZusError::Protocol(format!(
189 "Invalid header CRC16: expected {calculated_crc:#06x}, got {headcrc16:#06x}"
190 )));
191 }
192
193 Ok(header)
194 }
195
196 pub fn calculate_datacrc(data: &[u8]) -> u32 {
199 let crc = crc::Crc::<u32>::new(&crc::CRC_32_ISO_HDLC);
200 let mut digest = crc.digest();
201 digest.update(data);
202 digest.finalize()
203 }
204
205 pub fn verify_datacrc(&self, data: &[u8]) -> bool {
207 let calculated = Self::calculate_datacrc(data);
208 calculated == self.datacrc
209 }
210}
211
212#[derive(Debug, Clone)]
214pub struct RpcMessage {
215 pub header: RpcProtocolHeader,
216 pub method: Bytes,
217 pub body: Bytes,
218}
219
220impl RpcMessage {
221 pub fn new_request(sequence: u64, method: Bytes, body: Bytes) -> Self {
223 let mut header = RpcProtocolHeader::new(sequence, MSG_TYPE_REQ);
224 header.body_length = (method.len() + body.len()) as u32;
225 header.datacrc = RpcProtocolHeader::calculate_datacrc(&body);
226
227 Self { header, method, body }
228 }
229
230 pub fn new_response(sequence: u64, body: Bytes) -> Self {
232 let mut header = RpcProtocolHeader::new(sequence, MSG_TYPE_RSP);
233 header.body_length = body.len() as u32;
234 header.datacrc = RpcProtocolHeader::calculate_datacrc(&body);
235
236 Self {
237 header,
238 method: Bytes::new(),
239 body,
240 }
241 }
242
243 pub fn new_notify(sequence: u64, method: Bytes, body: Bytes) -> Self {
245 let mut header = RpcProtocolHeader::new(sequence, MSG_TYPE_NOTIFY);
246 header.body_length = (method.len() + body.len()) as u32;
247 header.datacrc = RpcProtocolHeader::calculate_datacrc(&body);
248
249 Self { header, method, body }
250 }
251
252 pub fn new_sysrsp(sequence: u64, body: Bytes) -> Self {
254 let mut header = RpcProtocolHeader::new(sequence, MSG_TYPE_SYSRSP);
255 header.body_length = body.len() as u32;
256 header.datacrc = RpcProtocolHeader::calculate_datacrc(&body);
257
258 Self {
259 header,
260 method: Bytes::new(),
261 body,
262 }
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_header_encode_decode() {
        let mut header = RpcProtocolHeader::new(123, MSG_TYPE_REQ);
        header.body_length = 100;
        header.datacrc = 0x12345678;
        header.timeout = 5000;
        header.traceid = 0xDEAD_BEEF;
        header.set_compressed(true);

        let mut buf = BytesMut::new();
        header.encode(&mut buf);

        assert_eq!(buf.len(), RpcProtocolHeader::HEADER_SIZE);

        let decoded = RpcProtocolHeader::decode(&mut buf).unwrap();
        // A successful decode consumes exactly one header.
        assert!(buf.is_empty());
        assert_eq!(decoded.magic, RPC_MAGIC);
        assert_eq!(decoded.version, 0);
        assert_eq!(decoded.sequence, 123);
        assert_eq!(decoded.msg_type, MSG_TYPE_REQ);
        assert_eq!(decoded.body_length, 100);
        assert_eq!(decoded.datacrc, 0x12345678);
        assert_eq!(decoded.timeout, 5000);
        assert_eq!(decoded.traceid, 0xDEAD_BEEF);
        assert!(decoded.is_compressed());
        assert_eq!(decoded.flags, header.flags);
    }

    #[test]
    fn test_decode_insufficient_data() {
        let mut buf = BytesMut::from(&[0u8; RpcProtocolHeader::HEADER_SIZE - 1][..]);

        match RpcProtocolHeader::decode(&mut buf) {
            | Err(ZusError::Protocol(msg)) => assert!(msg.contains("Insufficient data")),
            | _ => panic!("Expected Protocol error for a short buffer"),
        }
        // The length check happens before any bytes are read.
        assert_eq!(buf.len(), RpcProtocolHeader::HEADER_SIZE - 1);
    }

    #[test]
    fn test_datacrc() {
        // Standard CRC-32/ISO-HDLC check value for the ASCII string "123456789".
        assert_eq!(RpcProtocolHeader::calculate_datacrc(b"123456789"), 0xCBF4_3926);

        let data = b"hello world";
        let crc = RpcProtocolHeader::calculate_datacrc(data);

        let mut header = RpcProtocolHeader::new(1, MSG_TYPE_REQ);
        header.datacrc = crc;
        assert!(header.verify_datacrc(data));
        // A single-bit change must be detected by CRC-32.
        assert!(!header.verify_datacrc(b"hello worle"));
    }

    #[test]
    fn test_headcrc16_calculation() {
        let mut header = RpcProtocolHeader::new(0x0102030405060708, MSG_TYPE_REQ);
        header.datacrc = 0x12345678;
        header.body_length = 1000;
        header.timeout = 5000;
        header.traceid = 999;

        let crc = header.calculate_headcrc16();

        // The header checksum is the wrapping sum of each field's low 16 bits
        // (version and flags are zero here).
        let expected = RPC_MAGIC
            .wrapping_add(0x5678) // datacrc
            .wrapping_add(0x0708) // sequence
            .wrapping_add(u16::from(MSG_TYPE_REQ))
            .wrapping_add(1000) // body_length
            .wrapping_add(5000) // timeout
            .wrapping_add(999); // traceid
        assert_eq!(crc, expected);

        let mut buf = BytesMut::new();
        header.encode(&mut buf);
        let decoded = RpcProtocolHeader::decode(&mut buf).unwrap();

        assert_eq!(decoded.headcrc16, crc);
    }

    #[test]
    fn test_encode_ignores_stale_headcrc16() {
        let mut header = RpcProtocolHeader::new(42, MSG_TYPE_REQ);
        header.headcrc16 = 0xABCD;

        let mut buf = BytesMut::new();
        header.encode(&mut buf);

        let decoded = RpcProtocolHeader::decode(&mut buf).unwrap();
        assert_eq!(decoded.headcrc16, header.calculate_headcrc16());
    }

    #[test]
    fn test_invalid_magic() {
        let mut buf = BytesMut::new();
        buf.put_u16(0xFFFF);
        buf.put_u16(0);
        buf.put_u32(0);
        buf.put_u64(0);
        buf.put_u8(0);
        buf.put_u8(0);
        buf.put_u16(0);
        buf.put_u32(0);
        buf.put_u64(0);
        buf.put_u64(0);

        let result = RpcProtocolHeader::decode(&mut buf);
        assert!(result.is_err());
        match result {
            | Err(ZusError::InvalidMagic(magic)) => assert_eq!(magic, 0xFFFF),
            | _ => panic!("Expected InvalidMagic error"),
        }
    }

    #[test]
    fn test_invalid_headcrc16() {
        let mut header = RpcProtocolHeader::new(123, MSG_TYPE_REQ);
        header.datacrc = 0x12345678;

        let mut buf = BytesMut::new();
        header.encode(&mut buf);

        // Bytes 18..20 hold headcrc16 (after magic, version, datacrc, sequence,
        // msg_type and flags). XOR guarantees the stored value really changes,
        // whereas overwriting with 0xFFFF would be a no-op if the checksum were 0xFFFF.
        buf[18] ^= 0xFF;
        buf[19] ^= 0xFF;

        let result = RpcProtocolHeader::decode(&mut buf);
        assert!(result.is_err());
        match result {
            | Err(ZusError::Protocol(msg)) => assert!(msg.contains("Invalid header CRC16")),
            | _ => panic!("Expected Protocol error for invalid CRC16"),
        }
    }

    #[test]
    fn test_flags() {
        let mut header = RpcProtocolHeader::new(1, MSG_TYPE_REQ);

        assert!(!header.is_compressed());
        header.set_compressed(true);
        assert!(header.is_compressed());
        assert_eq!(header.flags & FLAG_COMPRESSED, FLAG_COMPRESSED);

        assert!(!header.is_encrypted());
        header.set_encrypted(true);
        assert!(header.is_encrypted());
        assert_eq!(header.flags & FLAG_ENCRYPTED, FLAG_ENCRYPTED);

        assert!(!header.has_extra_data());
        header.set_extra_data(true);
        assert!(header.has_extra_data());
        assert_eq!(header.flags & 0x04, 0x04);

        assert_eq!(header.flags, 0x07);

        // Clearing one flag leaves the others intact.
        header.set_encrypted(false);
        assert!(!header.is_encrypted());
        assert!(header.is_compressed());
        assert!(header.has_extra_data());
    }

    #[test]
    fn test_rpc_message_creation() {
        let method = Bytes::from("testMethod");
        let body = Bytes::from("test body data");

        let msg = RpcMessage::new_request(123, method.clone(), body.clone());

        assert_eq!(msg.header.sequence, 123);
        assert_eq!(msg.header.msg_type, MSG_TYPE_REQ);
        assert_eq!(msg.header.body_length, (method.len() + body.len()) as u32);
        assert!(msg.header.verify_datacrc(&body));
    }

    #[test]
    fn test_response_message_creation() {
        let body = Bytes::from("response body");

        let msg = RpcMessage::new_response(7, body.clone());

        assert_eq!(msg.header.sequence, 7);
        assert_eq!(msg.header.msg_type, MSG_TYPE_RSP);
        assert!(msg.method.is_empty());
        assert_eq!(msg.header.body_length, body.len() as u32);
        assert!(msg.header.verify_datacrc(&body));
    }
}