// zus_common/protocol.rs

1use {
2  bytes::{Buf, BufMut, Bytes, BytesMut},
3  zus_proto::constants::*,
4};
5
6use crate::error::{Result, ZusError};
7
/// RPC protocol header, wire-compatible with the Java `RpcProtHead`.
///
/// Encoded size is 40 bytes; multi-byte fields are written big-endian in this
/// order (Java field name first, Rust field name in parentheses):
///
/// - `magic` (`magic`, 2 bytes): `0xD3A7`
/// - `version` (`version`, 2 bytes): `0x0000`
/// - `datacrc` (`datacrc`, 4 bytes): CRC32 of the uncompressed body data
/// - `seq` (`sequence`, 8 bytes): sequence number for request/response matching
/// - `cmdtype` (`msg_type`, 1 byte): message type (REQ=0, RSP=1, NOTIFY=2, SYSRSP=3)
/// - `flag` (`flags`, 1 byte): bit0 = compressed, bit1 = encrypted, bit2 = has extra data
/// - `headcrc16` (`headcrc16`, 2 bytes): 16-bit wrapping sum of every other header
///   field, each truncated to its low 16 bits
/// - `datalen` (`body_length`, 4 bytes): body length
/// - `timeout` (`timeout`, 8 bytes): timeout in milliseconds
/// - `traceid` (`traceid`, 8 bytes): trace ID, or extra-data length when bit2 of `flag` is set
///
/// Equality compares every field, including the stored `headcrc16`. Note that
/// encoding ignores the stored `headcrc16` and writes a freshly computed one.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RpcProtocolHeader {
  /// Magic number identifying the protocol: 0xD3A7.
  pub magic: u16,
  /// Protocol version: 0x0000.
  pub version: u16,
  /// CRC32 checksum of the uncompressed body.
  pub datacrc: u32,
  /// Sequence ID used to match responses to requests.
  pub sequence: u64,
  /// Message type: REQ / RSP / NOTIFY / SYSRSP.
  pub msg_type: u8,
  /// Flag bits: bit 0 = compressed, bit 1 = encrypted, bit 2 = has extra data.
  pub flags: u8,
  /// Header checksum (wrapping 16-bit sum of the other header fields).
  pub headcrc16: u16,
  /// Body length in bytes.
  pub body_length: u32,
  /// Timeout in milliseconds.
  pub timeout: u64,
  /// Trace ID, or the extra-data length when the extra-data flag is set.
  pub traceid: u64,
}
45
46impl RpcProtocolHeader {
47  pub const HEADER_SIZE: usize = 40;
48
49  pub fn new(sequence: u64, msg_type: u8) -> Self {
50    Self {
51      magic: RPC_MAGIC,
52      version: 0,
53      datacrc: 0,
54      sequence,
55      msg_type,
56      flags: 0,
57      headcrc16: 0,
58      body_length: 0,
59      timeout: 10000, // 10 seconds default
60      traceid: 0,
61    }
62  }
63
64  pub fn is_compressed(&self) -> bool {
65    (self.flags & FLAG_COMPRESSED) != 0
66  }
67
68  pub fn is_encrypted(&self) -> bool {
69    (self.flags & FLAG_ENCRYPTED) != 0
70  }
71
72  pub fn has_extra_data(&self) -> bool {
73    (self.flags & 0x04) != 0
74  }
75
76  pub fn set_compressed(&mut self, compressed: bool) {
77    if compressed {
78      self.flags |= FLAG_COMPRESSED;
79    } else {
80      self.flags &= !FLAG_COMPRESSED;
81    }
82  }
83
84  pub fn set_encrypted(&mut self, encrypted: bool) {
85    if encrypted {
86      self.flags |= FLAG_ENCRYPTED;
87    } else {
88      self.flags &= !FLAG_ENCRYPTED;
89    }
90  }
91
92  pub fn set_extra_data(&mut self, has_extra: bool) {
93    if has_extra {
94      self.flags |= 0x04;
95    } else {
96      self.flags &= !0x04;
97    }
98  }
99
100  /// Calculate header CRC16 as a simple sum of all header fields (excluding headcrc16 itself)
101  /// This matches the Java implementation in RpcProtHead.packData()
102  /// Java truncates each field to 16 bits when adding to short
103  fn calculate_headcrc16(&self) -> u16 {
104    let mut crc: u16 = 0;
105
106    // Add all fields except headcrc16 itself
107    // Java adds each field and auto-truncates to 16 bits
108    crc = crc.wrapping_add(self.magic);
109    crc = crc.wrapping_add(self.version);
110    crc = crc.wrapping_add(self.datacrc as u16); // Truncate to lower 16 bits
111    crc = crc.wrapping_add(self.sequence as u16); // Truncate to lower 16 bits
112    crc = crc.wrapping_add(self.msg_type as u16);
113    crc = crc.wrapping_add(self.flags as u16);
114    crc = crc.wrapping_add(self.body_length as u16); // Truncate to lower 16 bits
115    crc = crc.wrapping_add(self.timeout as u16); // Truncate to lower 16 bits
116    crc = crc.wrapping_add(self.traceid as u16); // Truncate to lower 16 bits
117
118    crc
119  }
120
121  /// Encode header to bytes
122  /// Field order matches Java RpcProtHead.packData():
123  /// magic, version, datacrc, seq, cmdtype, flag, headcrc16, datalen, timeout, traceid
124  pub fn encode(&self, buf: &mut BytesMut) {
125    // Calculate headcrc16 before encoding
126    let headcrc16 = self.calculate_headcrc16();
127
128    // Encode in the exact order as Java
129    buf.put_u16(self.magic); // 2 bytes
130    buf.put_u16(self.version); // 2 bytes
131    buf.put_u32(self.datacrc); // 4 bytes
132    buf.put_u64(self.sequence); // 8 bytes
133    buf.put_u8(self.msg_type); // 1 byte
134    buf.put_u8(self.flags); // 1 byte
135    buf.put_u16(headcrc16); // 2 bytes (calculated)
136    buf.put_u32(self.body_length); // 4 bytes
137    buf.put_u64(self.timeout); // 8 bytes
138    buf.put_u64(self.traceid); // 8 bytes
139  }
140
141  /// Decode header from bytes
142  /// Field order matches Java RpcProtHead.unpackData():
143  /// magic, version, datacrc, seq, cmdtype, flag, headcrc16, datalen, timeout, traceid
144  pub fn decode(buf: &mut BytesMut) -> Result<Self> {
145    if buf.len() < Self::HEADER_SIZE {
146      return Err(ZusError::Protocol(format!(
147        "Insufficient data for header: got {} bytes, need {}",
148        buf.len(),
149        Self::HEADER_SIZE
150      )));
151    }
152
153    // Decode in the exact order as Java
154    let magic = buf.get_u16();
155    let version = buf.get_u16();
156
157    // Check magic number immediately
158    if magic != RPC_MAGIC {
159      return Err(ZusError::InvalidMagic(magic));
160    }
161
162    let datacrc = buf.get_u32();
163    let sequence = buf.get_u64();
164    let msg_type = buf.get_u8();
165    let flags = buf.get_u8();
166    let headcrc16 = buf.get_u16();
167    let body_length = buf.get_u32();
168    let timeout = buf.get_u64();
169    let traceid = buf.get_u64();
170
171    // Create the header
172    let header = Self {
173      magic,
174      version,
175      datacrc,
176      sequence,
177      msg_type,
178      flags,
179      headcrc16,
180      body_length,
181      timeout,
182      traceid,
183    };
184
185    // Verify headcrc16
186    let calculated_crc = header.calculate_headcrc16();
187    if calculated_crc != headcrc16 {
188      return Err(ZusError::Protocol(format!(
189        "Invalid header CRC16: expected {calculated_crc:#06x}, got {headcrc16:#06x}"
190      )));
191    }
192
193    Ok(header)
194  }
195
196  /// Calculate CRC32 for body data (uncompressed)
197  /// This matches the Java implementation which uses standard CRC32
198  pub fn calculate_datacrc(data: &[u8]) -> u32 {
199    let crc = crc::Crc::<u32>::new(&crc::CRC_32_ISO_HDLC);
200    let mut digest = crc.digest();
201    digest.update(data);
202    digest.finalize()
203  }
204
205  /// Verify data CRC32
206  pub fn verify_datacrc(&self, data: &[u8]) -> bool {
207    let calculated = Self::calculate_datacrc(data);
208    calculated == self.datacrc
209  }
210}
211
212/// RPC Message
213#[derive(Debug, Clone)]
214pub struct RpcMessage {
215  pub header: RpcProtocolHeader,
216  pub method: Bytes,
217  pub body: Bytes,
218}
219
220impl RpcMessage {
221  /// Create a new request message
222  pub fn new_request(sequence: u64, method: Bytes, body: Bytes) -> Self {
223    let mut header = RpcProtocolHeader::new(sequence, MSG_TYPE_REQ);
224    header.body_length = (method.len() + body.len()) as u32;
225    header.datacrc = RpcProtocolHeader::calculate_datacrc(&body);
226
227    Self { header, method, body }
228  }
229
230  /// Create a new response message
231  pub fn new_response(sequence: u64, body: Bytes) -> Self {
232    let mut header = RpcProtocolHeader::new(sequence, MSG_TYPE_RSP);
233    header.body_length = body.len() as u32;
234    header.datacrc = RpcProtocolHeader::calculate_datacrc(&body);
235
236    Self {
237      header,
238      method: Bytes::new(),
239      body,
240    }
241  }
242
243  /// Create a new notify message
244  pub fn new_notify(sequence: u64, method: Bytes, body: Bytes) -> Self {
245    let mut header = RpcProtocolHeader::new(sequence, MSG_TYPE_NOTIFY);
246    header.body_length = (method.len() + body.len()) as u32;
247    header.datacrc = RpcProtocolHeader::calculate_datacrc(&body);
248
249    Self { header, method, body }
250  }
251
252  /// Create a new system response message (for error responses when server didn't finish)
253  pub fn new_sysrsp(sequence: u64, body: Bytes) -> Self {
254    let mut header = RpcProtocolHeader::new(sequence, MSG_TYPE_SYSRSP);
255    header.body_length = body.len() as u32;
256    header.datacrc = RpcProtocolHeader::calculate_datacrc(&body);
257
258    Self {
259      header,
260      method: Bytes::new(),
261      body,
262    }
263  }
264}
265
#[cfg(test)]
mod tests {
  use super::*;

  /// Encode `header` into a fresh buffer.
  fn encode_to_buf(header: &RpcProtocolHeader) -> BytesMut {
    let mut buf = BytesMut::new();
    header.encode(&mut buf);
    buf
  }

  #[test]
  fn test_header_encode_decode() {
    let mut header = RpcProtocolHeader::new(123, MSG_TYPE_REQ);
    header.body_length = 100;
    header.datacrc = 0x12345678;
    header.timeout = 5000;
    header.traceid = 0xDEAD_BEEF;
    header.set_compressed(true);

    let mut buf = encode_to_buf(&header);
    assert_eq!(buf.len(), RpcProtocolHeader::HEADER_SIZE);

    let decoded = RpcProtocolHeader::decode(&mut buf).unwrap();
    assert_eq!(decoded.magic, RPC_MAGIC);
    assert_eq!(decoded.version, 0);
    assert_eq!(decoded.sequence, 123);
    assert_eq!(decoded.msg_type, MSG_TYPE_REQ);
    assert_eq!(decoded.flags, header.flags);
    assert_eq!(decoded.body_length, 100);
    assert_eq!(decoded.datacrc, 0x12345678);
    assert_eq!(decoded.timeout, 5000);
    assert_eq!(decoded.traceid, 0xDEAD_BEEF);
    // Decoding consumes exactly one header's worth of bytes.
    assert!(buf.is_empty());
  }

  #[test]
  fn test_header_wire_layout() {
    let mut header = RpcProtocolHeader::new(0x0102_0304_0506_0708, MSG_TYPE_REQ);
    header.datacrc = 0x1122_3344;
    header.body_length = 0x5566_7788;
    header.timeout = 0x99AA_BBCC_DDEE_FF00;
    header.traceid = 0x0F0E_0D0C_0B0A_0908;
    header.flags = 0x05;

    let buf = encode_to_buf(&header);
    assert_eq!(buf.len(), RpcProtocolHeader::HEADER_SIZE);

    // Big-endian fields at fixed offsets, in Java RpcProtHead order.
    assert_eq!(&buf[0..2], &RPC_MAGIC.to_be_bytes());
    assert_eq!(&buf[2..4], &0u16.to_be_bytes());
    assert_eq!(&buf[4..8], &0x1122_3344u32.to_be_bytes());
    assert_eq!(&buf[8..16], &0x0102_0304_0506_0708u64.to_be_bytes());
    assert_eq!(buf[16], MSG_TYPE_REQ);
    assert_eq!(buf[17], 0x05);
    assert_eq!(&buf[18..20], &header.calculate_headcrc16().to_be_bytes());
    assert_eq!(&buf[20..24], &0x5566_7788u32.to_be_bytes());
    assert_eq!(&buf[24..32], &0x99AA_BBCC_DDEE_FF00u64.to_be_bytes());
    assert_eq!(&buf[32..40], &0x0F0E_0D0C_0B0A_0908u64.to_be_bytes());
  }

  #[test]
  fn test_datacrc() {
    // Standard CRC-32 (ISO-HDLC) check values.
    assert_eq!(RpcProtocolHeader::calculate_datacrc(b"123456789"), 0xCBF4_3926);
    assert_eq!(RpcProtocolHeader::calculate_datacrc(b""), 0);

    let data = b"hello world";
    let crc = RpcProtocolHeader::calculate_datacrc(data);
    assert_eq!(crc, 0x0D4A_1185);

    let mut header = RpcProtocolHeader::new(1, MSG_TYPE_REQ);
    header.datacrc = crc;
    assert!(header.verify_datacrc(data));
    assert!(!header.verify_datacrc(b"hello worle"));
  }

  #[test]
  fn test_headcrc16_calculation() {
    let mut header = RpcProtocolHeader::new(0x0102030405060708, MSG_TYPE_REQ);
    header.datacrc = 0x12345678;
    header.body_length = 1000;
    header.timeout = 5000;
    header.traceid = 999;

    // Only the low 16 bits of each wide field contribute, with 16-bit wraparound.
    // version and flags are zero and contribute nothing.
    let expected = RPC_MAGIC
      .wrapping_add(0x5678) // low 16 bits of datacrc
      .wrapping_add(0x0708) // low 16 bits of sequence
      .wrapping_add(u16::from(MSG_TYPE_REQ))
      .wrapping_add(1000)
      .wrapping_add(5000)
      .wrapping_add(999);
    let crc = header.calculate_headcrc16();
    assert_eq!(crc, expected);

    // Encode and decode should preserve the CRC
    let mut buf = encode_to_buf(&header);
    let decoded = RpcProtocolHeader::decode(&mut buf).unwrap();
    assert_eq!(decoded.headcrc16, crc);
  }

  #[test]
  fn test_invalid_magic() {
    let mut buf = BytesMut::new();
    buf.put_u16(0xFFFF); // Invalid magic
    buf.put_u16(0);
    buf.put_u32(0);
    buf.put_u64(0);
    buf.put_u8(0);
    buf.put_u8(0);
    buf.put_u16(0);
    buf.put_u32(0);
    buf.put_u64(0);
    buf.put_u64(0);

    match RpcProtocolHeader::decode(&mut buf) {
      | Err(ZusError::InvalidMagic(magic)) => assert_eq!(magic, 0xFFFF),
      | _ => panic!("Expected InvalidMagic error"),
    }
  }

  #[test]
  fn test_invalid_headcrc16() {
    let mut header = RpcProtocolHeader::new(123, MSG_TYPE_REQ);
    header.datacrc = 0x12345678;

    let mut buf = encode_to_buf(&header);

    // Flip every bit of the headcrc16 field (bytes 18-19) so the stored value
    // is guaranteed to differ from the computed one.
    buf[18] ^= 0xFF;
    buf[19] ^= 0xFF;

    match RpcProtocolHeader::decode(&mut buf) {
      | Err(ZusError::Protocol(msg)) => assert!(msg.contains("Invalid header CRC16")),
      | _ => panic!("Expected Protocol error for invalid CRC16"),
    }
  }

  #[test]
  fn test_decode_rejects_short_buffer() {
    let header = RpcProtocolHeader::new(7, MSG_TYPE_REQ);
    let mut buf = encode_to_buf(&header);
    buf.truncate(RpcProtocolHeader::HEADER_SIZE - 1);

    match RpcProtocolHeader::decode(&mut buf) {
      | Err(ZusError::Protocol(msg)) => assert!(msg.contains("Insufficient data")),
      | _ => panic!("Expected Protocol error for short buffer"),
    }
    // Nothing is consumed when the length check fails.
    assert_eq!(buf.len(), RpcProtocolHeader::HEADER_SIZE - 1);
  }

  #[test]
  fn test_flags() {
    let mut header = RpcProtocolHeader::new(1, MSG_TYPE_REQ);

    // Test compressed flag
    assert!(!header.is_compressed());
    header.set_compressed(true);
    assert!(header.is_compressed());
    assert_eq!(header.flags & FLAG_COMPRESSED, FLAG_COMPRESSED);

    // Test encrypted flag
    assert!(!header.is_encrypted());
    header.set_encrypted(true);
    assert!(header.is_encrypted());
    assert_eq!(header.flags & FLAG_ENCRYPTED, FLAG_ENCRYPTED);

    // Test extra data flag
    assert!(!header.has_extra_data());
    header.set_extra_data(true);
    assert!(header.has_extra_data());
    assert_eq!(header.flags & 0x04, 0x04);

    // All flags should be set
    assert_eq!(header.flags, 0x07);

    // Clearing one flag leaves the others untouched.
    header.set_compressed(false);
    assert!(!header.is_compressed());
    assert!(header.is_encrypted());
    assert!(header.has_extra_data());
    assert_eq!(header.flags, 0x07 & !FLAG_COMPRESSED);

    header.set_encrypted(false);
    header.set_extra_data(false);
    assert_eq!(header.flags, 0);
  }

  #[test]
  fn test_rpc_message_creation() {
    let method = Bytes::from("testMethod");
    let body = Bytes::from("test body data");

    let msg = RpcMessage::new_request(123, method.clone(), body.clone());

    assert_eq!(msg.header.sequence, 123);
    assert_eq!(msg.header.msg_type, MSG_TYPE_REQ);
    assert_eq!(msg.header.body_length, (method.len() + body.len()) as u32);
    assert!(msg.header.verify_datacrc(&body));
  }

  #[test]
  fn test_rpc_response_creation() {
    let body = Bytes::from("response body");

    let msg = RpcMessage::new_response(9, body.clone());

    assert_eq!(msg.header.sequence, 9);
    assert_eq!(msg.header.msg_type, MSG_TYPE_RSP);
    assert!(msg.method.is_empty());
    assert_eq!(msg.header.body_length, body.len() as u32);
    assert!(msg.header.verify_datacrc(&body));
  }
}