zus_common/
codec.rs

1use {
2  bytes::{Buf, BufMut, Bytes, BytesMut},
3  tokio_util::codec::{Decoder, Encoder},
4};
5
6use crate::{
7  compression::Compressor,
8  encryption::DesEncryptor,
9  error::{Result, ZusError},
10  protocol::{RpcMessage, RpcProtocolHeader},
11};
12
13/// RPC Codec for encoding/decoding messages (replacing Java's ZNetProtocolCodecFactory)
14pub struct RpcCodec {
15  max_frame_length: usize,
16  pub compressor: Compressor,
17  encryptor: Option<DesEncryptor>,
18}
19
20impl RpcCodec {
21  pub fn new() -> Self {
22    Self {
23      max_frame_length: 10 * 1024 * 1024, // 10MB default
24      compressor: Compressor::new(),      // 4KB threshold, QuickLZ (matching C++/Java)
25      encryptor: None,
26    }
27  }
28
29  pub fn with_max_frame_length(max_frame_length: usize) -> Self {
30    Self {
31      max_frame_length,
32      compressor: Compressor::new(),
33      encryptor: None,
34    }
35  }
36
37  pub fn with_compressor(compressor: Compressor) -> Self {
38    Self {
39      max_frame_length: 10 * 1024 * 1024,
40      compressor,
41      encryptor: None,
42    }
43  }
44
45  pub fn with_config(max_frame_length: usize, compressor: Compressor) -> Self {
46    Self {
47      max_frame_length,
48      compressor,
49      encryptor: None,
50    }
51  }
52
53  /// Create a new RpcCodec with encryption enabled.
54  ///
55  /// # Arguments
56  /// * `key` - 8-byte DES key for encryption/decryption
57  pub fn with_encryption(key: &[u8]) -> Result<Self> {
58    let encryptor = DesEncryptor::try_new(key)?;
59    Ok(Self {
60      max_frame_length: 10 * 1024 * 1024,
61      compressor: Compressor::new(),
62      encryptor: Some(encryptor),
63    })
64  }
65
66  /// Create a fully configured RpcCodec with all options.
67  ///
68  /// # Arguments
69  /// * `max_frame_length` - Maximum frame size in bytes
70  /// * `compressor` - Compressor for compression/decompression
71  /// * `encryption_key` - Optional 8-byte DES key for encryption
72  pub fn with_full_config(
73    max_frame_length: usize,
74    compressor: Compressor,
75    encryption_key: Option<&[u8]>,
76  ) -> Result<Self> {
77    let encryptor = match encryption_key {
78      | Some(key) => Some(DesEncryptor::try_new(key)?),
79      | None => None,
80    };
81    Ok(Self {
82      max_frame_length,
83      compressor,
84      encryptor,
85    })
86  }
87
88  /// Set the encryption key. Pass `None` to disable encryption.
89  pub fn set_encryption_key(&mut self, key: Option<&[u8]>) -> Result<()> {
90    self.encryptor = match key {
91      | Some(k) => Some(DesEncryptor::try_new(k)?),
92      | None => None,
93    };
94    Ok(())
95  }
96
97  /// Check if encryption is enabled.
98  pub fn is_encryption_enabled(&self) -> bool {
99    self.encryptor.is_some()
100  }
101}
102
103impl Default for RpcCodec {
104  fn default() -> Self {
105    Self::new()
106  }
107}
108
109impl Decoder for RpcCodec {
110  type Error = ZusError;
111  type Item = RpcMessage;
112
113  fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
114    // Need at least header size
115    if src.len() < RpcProtocolHeader::HEADER_SIZE {
116      return Ok(None);
117    }
118
119    // Peek at the header to get body length
120    let mut peek_buf = src.clone();
121    let header = RpcProtocolHeader::decode(&mut peek_buf)?;
122
123    // Check if we have the full message
124    let total_length = RpcProtocolHeader::HEADER_SIZE + header.body_length as usize;
125    if src.len() < total_length {
126      // Reserve space for the full message
127      src.reserve(total_length - src.len());
128      return Ok(None);
129    }
130
131    // Check max frame length
132    if total_length > self.max_frame_length {
133      return Err(ZusError::Protocol(format!(
134        "Frame too large: {} > {}",
135        total_length, self.max_frame_length
136      )));
137    }
138
139    // Now decode the actual header
140    let header = RpcProtocolHeader::decode(src)?;
141
142    // Read body (possibly encrypted and/or compressed)
143    let raw_body = src.split_to(header.body_length as usize).freeze();
144
145    // Verify CRC on raw body (before decryption/decompression)
146    // This matches Java behavior which computes CRC on received data
147    if !header.verify_datacrc(&raw_body) {
148      return Err(ZusError::CrcMismatch);
149    }
150
151    // Step 1: Decrypt if encrypted flag is set
152    // Order: wire data -> decrypt -> decompress -> plaintext
153    let decrypted_body = if header.is_encrypted() {
154      match &self.encryptor {
155        | Some(enc) => {
156          let decrypted = enc.decrypt(&raw_body)?;
157          Bytes::from(decrypted)
158        }
159        | None => {
160          return Err(ZusError::Encryption(
161            "Received encrypted message but no encryption key configured".to_string(),
162          ));
163        }
164      }
165    } else {
166      raw_body
167    };
168
169    // Step 2: Decode body based on message type
170    // For requests/notifications: Parse method first, then decompress params if needed
171    // For responses: Decompress entire body first
172    let (method, body) = if header.msg_type == zus_proto::constants::MSG_TYPE_REQ
173      || header.msg_type == zus_proto::constants::MSG_TYPE_NOTIFY
174    {
175      // Request/Notification wire format: [method_len(4)][method_bytes][params_len(4)][params_bytes]
176      // If compressed, only params_bytes are compressed (matching Java behavior)
177
178      let mut body_buf = decrypted_body;
179
180      // Parse method (always uncompressed)
181      if body_buf.len() < 4 {
182        return Err(ZusError::Protocol("Invalid method name length".to_string()));
183      }
184      let method_len = body_buf.get_u32() as usize;
185      if body_buf.len() < method_len {
186        return Err(ZusError::Protocol("Invalid method name".to_string()));
187      }
188      let method = body_buf.split_to(method_len);
189
190      // Parse params length
191      if body_buf.len() < 4 {
192        return Err(ZusError::Protocol("Invalid params length".to_string()));
193      }
194      let params_len = body_buf.get_u32() as usize;
195      if body_buf.len() < params_len {
196        return Err(ZusError::Protocol("Invalid params data".to_string()));
197      }
198
199      // Get params (possibly compressed)
200      let raw_params = body_buf.split_to(params_len);
201
202      // Decompress params if compressed flag is set
203      let params = if header.is_compressed() {
204        self.compressor.decompress(&raw_params)?
205      } else {
206        raw_params
207      };
208
209      (method, params)
210    } else {
211      // Response - handle encrypted length prefix, then decompress if needed
212      let body_data = if header.is_encrypted() {
213        // For encrypted responses, first 4 bytes are original length
214        let mut body_buf = decrypted_body;
215        if body_buf.len() < 4 {
216          return Err(ZusError::Protocol(
217            "Invalid encrypted response: missing length prefix".to_string(),
218          ));
219        }
220        let original_len = body_buf.get_u32() as usize;
221        if body_buf.len() < original_len {
222          return Err(ZusError::Protocol(format!(
223            "Invalid encrypted response: expected {} bytes, got {}",
224            original_len,
225            body_buf.len()
226          )));
227        }
228        body_buf.split_to(original_len)
229      } else {
230        decrypted_body
231      };
232
233      // Decompress entire body if compressed
234      let body = if header.is_compressed() {
235        self.compressor.decompress(&body_data)?
236      } else {
237        body_data
238      };
239      (Bytes::new(), body)
240    };
241
242    Ok(Some(RpcMessage { header, method, body }))
243  }
244}
245
246impl Encoder<RpcMessage> for RpcCodec {
247  type Error = ZusError;
248
249  fn encode(&mut self, mut item: RpcMessage, dst: &mut BytesMut) -> Result<()> {
250    // Encode body based on message type:
251    // For requests/notifications: [method_len(4)][method_bytes][params_len(4)][params_bytes]
252    //   - Only params_bytes are compressed (matching Java behavior)
253    //   - This allows Java decoder to parse method before decompression
254    // For responses: just [params_bytes] (no length prefix)
255    //   - Entire body can be compressed
256    //
257    // Processing order: plaintext -> compress -> encrypt -> wire
258
259    let is_request = !item.method.is_empty();
260    let mut full_body = BytesMut::new();
261    let was_compressed: bool;
262
263    if is_request {
264      // Request or notification
265      // Step 1: Try to compress ONLY the params data (not method)
266      let (compressed_params, params_compressed) = self.compressor.compress(&item.body)?;
267      was_compressed = params_compressed;
268
269      // Step 2: Encode as [method_len][method][params_len][compressed_or_uncompressed_params]
270      full_body.put_u32(item.method.len() as u32);
271      full_body.put(item.method);
272      full_body.put_u32(compressed_params.len() as u32); // Length of possibly compressed params
273      full_body.put(compressed_params);
274    } else {
275      // Response - compress entire body (no method)
276      let (compressed_body, body_compressed) = self.compressor.compress(&item.body)?;
277      was_compressed = body_compressed;
278      full_body.put(compressed_body);
279    }
280
281    // Step 3: Encrypt the entire body if encryption is enabled
282    // For responses (no method), prepend original body length to handle zero-padding
283    let final_body = if let Some(ref enc) = self.encryptor {
284      let data_to_encrypt = if is_request {
285        // For requests, the format already includes length prefixes
286        full_body.freeze()
287      } else {
288        // For responses, prepend original length (4 bytes) to handle padding
289        let mut with_len = BytesMut::with_capacity(4 + full_body.len());
290        with_len.put_u32(full_body.len() as u32);
291        with_len.put(full_body);
292        with_len.freeze()
293      };
294      let encrypted = enc.encrypt(&data_to_encrypt)?;
295      item.header.set_encrypted(true);
296      Bytes::from(encrypted)
297    } else {
298      item.header.set_encrypted(false);
299      full_body.freeze()
300    };
301
302    // Update header with final body length and compression flag
303    item.header.body_length = final_body.len() as u32;
304    item.header.set_compressed(was_compressed);
305
306    // Calculate CRC on the actual data being sent (encrypted and/or compressed)
307    // This matches Java behavior which computes CRC on the received body data
308    item.header.datacrc = RpcProtocolHeader::calculate_datacrc(&final_body);
309
310    // Reserve space
311    let total_length = RpcProtocolHeader::HEADER_SIZE + final_body.len();
312    dst.reserve(total_length);
313
314    // Encode header
315    item.header.encode(dst);
316
317    // Encode final body (encrypted and/or compressed)
318    dst.put(final_body);
319
320    Ok(())
321  }
322}
323
#[cfg(test)]
mod tests {
  use super::*;

  #[test]
  fn test_codec_roundtrip() {
    let mut codec = RpcCodec::new();

    let method = Bytes::from("test.method");
    let body = Bytes::from("hello");
    let msg = RpcMessage::new_request(1, method.clone(), body.clone());

    let mut buf = BytesMut::new();
    codec.encode(msg, &mut buf).unwrap();

    let decoded = codec.decode(&mut buf).unwrap().unwrap();
    assert_eq!(decoded.header.sequence, 1);
    assert_eq!(decoded.method, method);
    assert_eq!(decoded.body, body);
  }

  #[test]
  fn test_codec_response_roundtrip() {
    let mut codec = RpcCodec::new();

    let body = Bytes::from("plain response");
    let msg = RpcMessage::new_response(7, body.clone());

    let mut buf = BytesMut::new();
    codec.encode(msg, &mut buf).unwrap();

    let decoded = codec.decode(&mut buf).unwrap().unwrap();
    assert_eq!(decoded.header.sequence, 7);
    assert!(decoded.method.is_empty());
    assert_eq!(decoded.body, body);
    assert!(buf.is_empty());
  }

  #[test]
  fn test_codec_large_request_roundtrip() {
    // Large enough to cross the default compressor's size threshold.
    let mut codec = RpcCodec::new();

    let method = Bytes::from("test.method");
    let body = Bytes::from(vec![b'z'; 16 * 1024]);
    let msg = RpcMessage::new_request(3, method.clone(), body.clone());

    let mut buf = BytesMut::new();
    codec.encode(msg, &mut buf).unwrap();

    let decoded = codec.decode(&mut buf).unwrap().unwrap();
    assert_eq!(decoded.header.sequence, 3);
    assert_eq!(decoded.method, method);
    assert_eq!(decoded.body, body);
  }

  #[test]
  fn test_codec_large_encrypted_response_roundtrip() {
    let mut codec = RpcCodec::with_encryption(b"bigkey!!").unwrap();

    let body = Bytes::from(vec![b'q'; 16 * 1024 + 3]);
    let msg = RpcMessage::new_response(9, body.clone());

    let mut buf = BytesMut::new();
    codec.encode(msg, &mut buf).unwrap();

    let decoded = codec.decode(&mut buf).unwrap().unwrap();
    assert_eq!(decoded.header.sequence, 9);
    assert!(decoded.method.is_empty());
    assert_eq!(decoded.body, body);
  }

  #[test]
  fn test_codec_partial_frame_returns_none() {
    let mut codec = RpcCodec::new();

    let method = Bytes::from("test.method");
    let body = Bytes::from("split across reads");
    let msg = RpcMessage::new_request(5, method.clone(), body.clone());

    let mut full = BytesMut::new();
    codec.encode(msg, &mut full).unwrap();

    // Deliver the frame in two pieces: the first must not yield a message.
    let cut = full.len() - 3;
    let mut buf = full.split_to(cut);
    assert!(codec.decode(&mut buf).unwrap().is_none());

    buf.extend_from_slice(&full);
    let decoded = codec.decode(&mut buf).unwrap().unwrap();
    assert_eq!(decoded.header.sequence, 5);
    assert_eq!(decoded.method, method);
    assert_eq!(decoded.body, body);
    assert!(buf.is_empty());
  }

  #[test]
  fn test_codec_with_encryption_roundtrip() {
    let key = b"12345678";
    let mut codec = RpcCodec::with_encryption(key).unwrap();

    let method = Bytes::from("test.method");
    let body = Bytes::from("hello encrypted world");
    let msg = RpcMessage::new_request(1, method.clone(), body.clone());

    let mut buf = BytesMut::new();
    codec.encode(msg, &mut buf).unwrap();

    // Verify encryption flag is set in header
    let mut peek_buf = buf.clone();
    let header = RpcProtocolHeader::decode(&mut peek_buf).unwrap();
    assert!(header.is_encrypted());

    let decoded = codec.decode(&mut buf).unwrap().unwrap();
    assert_eq!(decoded.header.sequence, 1);
    assert_eq!(decoded.method, method);
    assert_eq!(decoded.body, body);
  }

  #[test]
  fn test_codec_encrypted_response() {
    let key = b"testkey!";
    let mut codec = RpcCodec::with_encryption(key).unwrap();

    // Test with any body size - padding is handled by prepending original length
    let body = Bytes::from("response data");
    let msg = RpcMessage::new_response(42, body.clone());

    let mut buf = BytesMut::new();
    codec.encode(msg, &mut buf).unwrap();

    let decoded = codec.decode(&mut buf).unwrap().unwrap();
    assert_eq!(decoded.header.sequence, 42);
    assert!(decoded.method.is_empty());
    assert_eq!(decoded.body, body);
  }

  #[test]
  fn test_codec_encryption_key_mismatch() {
    let key1 = b"12345678";
    let key2 = b"87654321";

    let mut encoder = RpcCodec::with_encryption(key1).unwrap();
    let mut decoder = RpcCodec::with_encryption(key2).unwrap();

    let method = Bytes::from("test.method");
    let body = Bytes::from("hello");
    let msg = RpcMessage::new_request(1, method, body);

    let mut buf = BytesMut::new();
    encoder.encode(msg, &mut buf).unwrap();

    // The complete frame is buffered, so `Ok(None)` is impossible. Decrypting with the
    // wrong key yields garbage, which must be rejected by the body parser.
    let result = decoder.decode(&mut buf);
    assert!(result.is_err(), "decoding with the wrong key must fail");
  }

  #[test]
  fn test_codec_encrypted_without_key_configured() {
    // Encode with encryption
    let key = b"12345678";
    let mut encoder = RpcCodec::with_encryption(key).unwrap();

    let method = Bytes::from("test.method");
    let body = Bytes::from("hello");
    let msg = RpcMessage::new_request(1, method, body);

    let mut buf = BytesMut::new();
    encoder.encode(msg, &mut buf).unwrap();

    // Try to decode without encryption key configured
    let mut decoder = RpcCodec::new();
    let result = decoder.decode(&mut buf);
    assert!(result.is_err());

    if let Err(ZusError::Encryption(msg)) = result {
      assert!(msg.contains("no encryption key configured"));
    } else {
      panic!("Expected Encryption error");
    }
  }

  #[test]
  fn test_codec_set_encryption_key() {
    let mut codec = RpcCodec::new();
    assert!(!codec.is_encryption_enabled());

    codec.set_encryption_key(Some(b"12345678")).unwrap();
    assert!(codec.is_encryption_enabled());

    codec.set_encryption_key(None).unwrap();
    assert!(!codec.is_encryption_enabled());
  }

  #[test]
  fn test_codec_full_config() {
    let compressor = Compressor::new();
    let key = b"mykey123";

    let codec = RpcCodec::with_full_config(5 * 1024 * 1024, compressor, Some(key)).unwrap();
    assert!(codec.is_encryption_enabled());

    let codec2 = RpcCodec::with_full_config(5 * 1024 * 1024, Compressor::new(), None).unwrap();
    assert!(!codec2.is_encryption_enabled());
  }
}