rocketmq_remoting/protocol/
rocketmq_serializable.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18use std::collections::HashMap;
19use std::str;
20
21use bytes::Buf;
22use bytes::BufMut;
23use bytes::Bytes;
24use bytes::BytesMut;
25use cheetah_string::CheetahString;
26use rocketmq_error::RocketmqError;
27
28use crate::protocol::remoting_command::RemotingCommand;
29use crate::protocol::LanguageCode;
30
/// Namespace type for the compact binary ("ROCKETMQ") header codec.
///
/// It carries no state: every operation is an associated function. The derives
/// only make the marker usable in generic/diagnostic contexts (`{:?}`, copies,
/// `Default::default()`).
#[derive(Debug, Clone, Copy, Default)]
pub struct RocketMQSerializable;
32
33impl RocketMQSerializable {
34    /// Optimized string write with inline hint for better performance
35    #[inline]
36    pub fn write_str(buf: &mut BytesMut, use_short_length: bool, s: &str) -> usize {
37        let bytes = s.as_bytes();
38        let len = bytes.len();
39
40        let length_size = if use_short_length {
41            buf.put_u16(len as u16);
42            2
43        } else {
44            buf.put_u32(len as u32);
45            4
46        };
47
48        buf.put_slice(bytes); // Use put_slice for better performance
49        length_size + len
50    }
51
52    /// Optimized string read with enhanced boundary checks and zero-copy
53    #[inline]
54    pub fn read_str(
55        buf: &mut BytesMut,
56        use_short_length: bool,
57        limit: usize,
58    ) -> rocketmq_error::RocketMQResult<Option<CheetahString>> {
59        // Read length prefix
60        let len = if use_short_length {
61            if buf.remaining() < 2 {
62                return Err(RocketmqError::DecodingError(2, buf.remaining()).into());
63            }
64            buf.get_u16() as usize
65        } else {
66            if buf.remaining() < 4 {
67                return Err(RocketmqError::DecodingError(4, buf.remaining()).into());
68            }
69            buf.get_u32() as usize
70        };
71
72        // Empty string
73        if len == 0 {
74            return Ok(None);
75        }
76
77        // Boundary check
78        if len > limit {
79            return Err(RocketmqError::DecodingError(len, limit).into());
80        }
81
82        // Ensure buffer has enough data
83        if buf.remaining() < len {
84            return Err(RocketmqError::DecodingError(len, buf.remaining()).into());
85        }
86
87        // Zero-copy split and freeze
88        let bytes = buf.split_to(len).freeze();
89        Ok(Some(CheetahString::from_bytes(bytes)))
90    }
91
92    /// Optimized ROCKETMQ protocol encoding with reduced allocations
93    #[inline]
94    pub fn rocketmq_protocol_encode(cmd: &mut RemotingCommand, buf: &mut BytesMut) -> usize {
95        let begin_index = buf.len();
96
97        // Estimate required capacity and reserve upfront to reduce reallocations
98        let estimated_size = Self::estimate_encode_size(cmd);
99        buf.reserve(estimated_size);
100
101        // Write fixed-size header fields (total: 15 bytes)
102        buf.put_u16(cmd.code() as u16); // 2 bytes
103        buf.put_u8(cmd.language().get_code()); // 1 byte
104        buf.put_u16(cmd.version() as u16); // 2 bytes
105        buf.put_i32(cmd.opaque()); // 4 bytes
106        buf.put_i32(cmd.flag()); // 4 bytes
107
108        // Write remark (variable length with 4-byte prefix or 0)
109        if let Some(remark) = cmd.remark() {
110            Self::write_str(buf, false, remark.as_str());
111        } else {
112            buf.put_i32(0);
113        }
114
115        // Reserve space for ext_fields length (will be updated later)
116        let map_len_index = buf.len();
117        buf.put_i32(0);
118
119        // Encode custom header if it supports fast codec
120        if let Some(header) = cmd.command_custom_header_mut() {
121            if header.support_fast_codec() {
122                header.encode_fast(buf);
123            }
124        }
125
126        // Encode ext_fields map
127        if let Some(ext_fields) = cmd.ext_fields() {
128            for (k, v) in ext_fields.iter() {
129                // Skip empty keys/values
130                if !k.is_empty() && !v.is_empty() {
131                    Self::write_str(buf, true, k.as_str());
132                    Self::write_str(buf, true, v.as_str());
133                }
134            }
135        }
136
137        // Update ext_fields length in-place
138        let current_length = buf.len();
139        let ext_fields_length = (current_length - map_len_index - 4) as i32;
140        buf[map_len_index..map_len_index + 4].copy_from_slice(&ext_fields_length.to_be_bytes());
141
142        buf.len() - begin_index
143    }
144
145    /// Estimate the size needed for encoding to reduce reallocations
146    #[inline]
147    fn estimate_encode_size(cmd: &RemotingCommand) -> usize {
148        let mut size = 15; // Fixed header: code(2) + language(1) + version(2) + opaque(4) + flag(4) + map_len(4)
149
150        // Remark size
151        if let Some(remark) = cmd.remark() {
152            size += 4 + remark.len(); // length prefix + data
153        } else {
154            size += 4; // just the length prefix (0)
155        }
156
157        // Ext fields size (approximate)
158        if let Some(ext) = cmd.ext_fields() {
159            for (k, v) in ext.iter() {
160                if !k.is_empty() && !v.is_empty() {
161                    size += 2 + k.len() + 2 + v.len(); // short length prefix for both
162                }
163            }
164        }
165
166        size
167    }
168
169    pub fn rocket_mq_protocol_encode_bytes(cmd: &RemotingCommand) -> Bytes {
170        let remark_bytes = cmd.remark().map(|remark| remark.as_bytes().to_vec());
171        let remark_len = remark_bytes.as_ref().map_or(0, |v| v.len());
172
173        let ext_fields_bytes = if let Some(ext) = cmd.get_ext_fields() {
174            Self::map_serialize(ext)
175        } else {
176            None
177        };
178        let ext_len = ext_fields_bytes.as_ref().map_or(0, |v| v.len());
179
180        let total_len = Self::cal_total_len(remark_len, ext_len);
181        let mut header_buffer = BytesMut::with_capacity(total_len);
182
183        // int code (~32767)
184        header_buffer.put_i16(cmd.code() as i16);
185
186        // LanguageCode language
187        header_buffer.put_u8(cmd.language().get_code());
188
189        // int version (~32767)
190        header_buffer.put_i16(cmd.version() as i16);
191
192        // int opaque
193        header_buffer.put_i32(cmd.opaque());
194
195        // int flag
196        header_buffer.put_i32(cmd.flag());
197
198        // String remark
199        if let Some(remark_bytes) = remark_bytes {
200            header_buffer.put_i32(remark_bytes.len() as i32);
201            header_buffer.put(remark_bytes.as_ref());
202        } else {
203            header_buffer.put_i32(0);
204        }
205
206        // HashMap<String, String> extFields
207        if let Some(ext_fields_bytes) = ext_fields_bytes {
208            header_buffer.put_i32(ext_fields_bytes.len() as i32);
209            header_buffer.put(ext_fields_bytes.as_ref());
210        } else {
211            header_buffer.put_i32(0);
212        }
213
214        header_buffer.freeze()
215    }
216
217    /// Optimized map serialization with pre-calculated capacity
218    #[inline]
219    pub fn map_serialize(map: &HashMap<CheetahString, CheetahString>) -> Option<BytesMut> {
220        if map.is_empty() {
221            return None;
222        }
223
224        // Pre-calculate total length in a single pass
225        let mut total_length = 0;
226        let mut valid_entries = 0;
227
228        for (key, value) in map.iter() {
229            if !key.is_empty() && !value.is_empty() {
230                total_length += 2 + key.len() + 4 + value.len();
231                valid_entries += 1;
232            }
233        }
234
235        if valid_entries == 0 {
236            return None;
237        }
238
239        // Allocate exact capacity (avoid reallocations)
240        let mut content = BytesMut::with_capacity(total_length);
241
242        // Serialize entries
243        for (key, value) in map.iter() {
244            if !key.is_empty() && !value.is_empty() {
245                // Write key: u16 length + bytes
246                content.put_u16(key.len() as u16);
247                content.put_slice(key.as_bytes());
248
249                // Write value: i32 length + bytes
250                content.put_i32(value.len() as i32);
251                content.put_slice(value.as_bytes());
252            }
253        }
254
255        Some(content)
256    }
257
258    pub fn cal_total_len(remark_len: usize, ext_len: usize) -> usize {
259        // int code(~32767): 2 bytes
260        // LanguageCode language: 1 byte
261        // int version(~32767): 2 bytes
262        // int opaque: 4 bytes
263        // int flag: 4 bytes
264        // String remark length: 4 bytes + actual length of remark
265        // HashMap<String, String> extFields length: 4 bytes + actual length of extFields
266
267        2   // int code
268             + 1          // LanguageCode language
269             + 2          // int version
270             + 4          // int opaque
271             + 4          // int flag
272             + 4 + remark_len   // String remark
273             + 4 + ext_len // HashMap<String, String> extFields
274    }
275
276    pub fn rocket_mq_protocol_decode(
277        header_buffer: &mut BytesMut,
278        header_len: usize,
279    ) -> rocketmq_error::RocketMQResult<RemotingCommand> {
280        let cmd = RemotingCommand::default()
281            .set_code(header_buffer.get_i16())
282            .set_language(LanguageCode::value_of(header_buffer.get_u8()).unwrap())
283            .set_version(header_buffer.get_i16() as i32)
284            .set_opaque(header_buffer.get_i32())
285            .set_flag(header_buffer.get_i32());
286
287        let remark = Self::read_str(header_buffer, false, header_len)?;
288
289        // HashMap<String, String> extFields
290        let ext_fields_length = header_buffer.get_i32() as usize;
291        let ext = if ext_fields_length > 0 {
292            if ext_fields_length > header_len {
293                return Err(RocketmqError::DecodingError(ext_fields_length, header_len).into());
294            }
295            Self::map_deserialize(header_buffer, ext_fields_length)?
296        } else {
297            HashMap::new()
298        };
299
300        Ok(cmd.set_remark_option(remark).set_ext_fields(ext))
301    }
302
303    /// Optimized map deserialization with capacity hint and better error handling
304    #[inline]
305    pub fn map_deserialize(
306        buffer: &mut BytesMut,
307        len: usize,
308    ) -> rocketmq_error::RocketMQResult<HashMap<CheetahString, CheetahString>> {
309        if len == 0 {
310            return Ok(HashMap::new());
311        }
312
313        // Pre-allocate HashMap with estimated capacity (assume ~50 bytes per entry)
314        let estimated_entries = (len / 50).clamp(4, 1024);
315        let mut map = HashMap::with_capacity(estimated_entries);
316
317        let target_remaining = buffer.remaining().saturating_sub(len);
318
319        while buffer.remaining() > target_remaining {
320            // Read key (short length prefix)
321            let key = Self::read_str(buffer, true, len)?
322                .ok_or_else(|| RocketmqError::DecodingError(0, 0))?;
323
324            // Read value (long length prefix)
325            let value = Self::read_str(buffer, false, len)?
326                .ok_or_else(|| RocketmqError::DecodingError(0, 0))?;
327
328            map.insert(key, value);
329        }
330
331        Ok(map)
332    }
333}
334
#[cfg(test)]
mod tests {
    use bytes::BytesMut;

    use super::*;

    /// "test" behind a 2-byte length prefix.
    const SHORT_FRAME: &[u8] = b"\x00\x04test";
    /// "test" behind a 4-byte length prefix.
    const LONG_FRAME: &[u8] = b"\x00\x00\x00\x04test";
    /// One map entry: 2-byte-prefixed "key", 4-byte-prefixed "value".
    const KEY_VALUE_ENTRY: &[u8] = b"\x00\x03key\x00\x00\x00\x05value";

    #[test]
    fn write_str_short_length() {
        let mut out = BytesMut::new();
        assert_eq!(RocketMQSerializable::write_str(&mut out, true, "test"), 6);
        assert_eq!(&out[..], SHORT_FRAME);
    }

    #[test]
    fn write_str_long_length() {
        let mut out = BytesMut::new();
        assert_eq!(RocketMQSerializable::write_str(&mut out, false, "test"), 8);
        assert_eq!(&out[..], LONG_FRAME);
    }

    #[test]
    fn read_str_short_length() {
        let mut input = BytesMut::from(SHORT_FRAME);
        let decoded = RocketMQSerializable::read_str(&mut input, true, 10).unwrap();
        assert_eq!(decoded, Some(CheetahString::from("test")));
    }

    #[test]
    fn read_str_long_length() {
        let mut input = BytesMut::from(LONG_FRAME);
        let decoded = RocketMQSerializable::read_str(&mut input, false, 10).unwrap();
        assert_eq!(decoded, Some(CheetahString::from("test")));
    }

    #[test]
    fn read_str_exceeds_limit() {
        // Declared length 4 is above the limit of 2.
        let mut input = BytesMut::from(LONG_FRAME);
        assert!(RocketMQSerializable::read_str(&mut input, false, 2).is_err());
    }

    #[test]
    fn map_serialize_empty() {
        let empty: HashMap<CheetahString, CheetahString> = HashMap::new();
        assert!(RocketMQSerializable::map_serialize(&empty).is_none());
    }

    #[test]
    fn map_serialize_non_empty() {
        let mut fields = HashMap::new();
        fields.insert(CheetahString::from("key"), CheetahString::from("value"));
        let encoded = RocketMQSerializable::map_serialize(&fields).unwrap();
        assert_eq!(&encoded[..], KEY_VALUE_ENTRY);
    }

    #[test]
    fn map_deserialize_empty() {
        let mut input = BytesMut::new();
        let decoded = RocketMQSerializable::map_deserialize(&mut input, 0).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn map_deserialize_non_empty() {
        let mut input = BytesMut::from(KEY_VALUE_ENTRY);
        let decoded = RocketMQSerializable::map_deserialize(&mut input, 14).unwrap();

        let mut expected = HashMap::new();
        expected.insert(CheetahString::from("key"), CheetahString::from("value"));
        assert_eq!(decoded, expected);
    }
}
413}