use std::collections::HashMap;
use std::str;
use bytes::Buf;
use bytes::BufMut;
use bytes::Bytes;
use bytes::BytesMut;
use cheetah_string::CheetahString;
use rocketmq_error::RocketmqError;
use crate::protocol::remoting_command::RemotingCommand;
use crate::protocol::LanguageCode;
/// Stateless namespace type for the RocketMQ binary header codec.
///
/// It carries no data; the encoding and decoding routines are exposed as associated functions.
pub struct RocketMQSerializable;
impl RocketMQSerializable {
/// Appends `s` to `buf` as a length-prefixed byte string.
///
/// The prefix is a big-endian `u16` when `use_short_length` is `true` and a big-endian `u32`
/// otherwise. The byte length of `s` is narrowed with an `as` cast, so a string longer than the
/// prefix can represent has its recorded length silently truncated.
///
/// Returns the number of bytes appended (prefix width plus the byte length of `s`).
#[inline]
pub fn write_str(buf: &mut BytesMut, use_short_length: bool, s: &str) -> usize {
    let payload = s.as_bytes();
    let start = buf.len();
    if use_short_length {
        buf.put_slice(&(payload.len() as u16).to_be_bytes());
    } else {
        buf.put_slice(&(payload.len() as u32).to_be_bytes());
    }
    buf.put_slice(payload);
    // Everything appended since `start` is exactly prefix + payload.
    buf.len() - start
}
/// Reads one length-prefixed string from the front of `buf`.
///
/// The prefix is read as a `u16` when `use_short_length` is `true` and as a `u32` otherwise.
///
/// Returns `Ok(None)` when the prefix is zero; otherwise the next `len` bytes are split off
/// `buf` and returned as `Ok(Some(..))`.
///
/// # Errors
///
/// Returns `RocketmqError::DecodingError` when `buf` is too short to hold the prefix, when the
/// decoded length exceeds `limit`, or when fewer than `len` bytes remain after the prefix.
#[inline]
pub fn read_str(
    buf: &mut BytesMut,
    use_short_length: bool,
    limit: usize,
) -> rocketmq_error::RocketMQResult<Option<CheetahString>> {
    let prefix_width: usize = if use_short_length { 2 } else { 4 };
    let available = buf.remaining();
    if available < prefix_width {
        return Err(RocketmqError::DecodingError(prefix_width, available).into());
    }

    let declared = if use_short_length {
        buf.get_u16() as usize
    } else {
        buf.get_u32() as usize
    };

    match declared {
        0 => Ok(None),
        len if len > limit => Err(RocketmqError::DecodingError(len, limit).into()),
        len if len > buf.remaining() => {
            Err(RocketmqError::DecodingError(len, buf.remaining()).into())
        }
        len => {
            let raw = buf.split_to(len).freeze();
            Ok(Some(CheetahString::from_bytes(raw)))
        }
    }
}
/// Encodes the header of `cmd` onto the end of `buf` and returns the number of bytes appended.
///
/// Fields are written in this order: code (`u16`), language code (`u8`), version (`u16`),
/// opaque (`i32`), flag (`i32`), remark (4-byte length prefix plus bytes, or a zero length when
/// there is no remark), a 4-byte ext-fields length, and finally the ext-fields payload.
///
/// The ext-fields payload is whatever the custom header's `encode_fast` appends (only when the
/// header reports `support_fast_codec()`), followed by every ext field whose key and value are
/// both non-empty. Each such entry is written as a 2-byte-prefixed key and a 4-byte-prefixed
/// value, which is the entry layout `map_serialize` produces and `map_deserialize` consumes.
#[inline]
pub fn rocketmq_protocol_encode(cmd: &mut RemotingCommand, buf: &mut BytesMut) -> usize {
    let begin_index = buf.len();
    let estimated_size = Self::estimate_encode_size(cmd);
    buf.reserve(estimated_size);

    // Fixed-width fields.
    buf.put_u16(cmd.code() as u16);
    buf.put_u8(cmd.language().get_code());
    buf.put_u16(cmd.version() as u16);
    buf.put_i32(cmd.opaque());
    buf.put_i32(cmd.flag());

    // Remark: 4-byte length prefix, zero when absent.
    if let Some(remark) = cmd.remark() {
        Self::write_str(buf, false, remark.as_str());
    } else {
        buf.put_i32(0);
    }

    // Placeholder for the ext-fields length; patched below once the payload size is known.
    let map_len_index = buf.len();
    buf.put_i32(0);

    if let Some(header) = cmd.command_custom_header_mut() {
        if header.support_fast_codec() {
            header.encode_fast(buf);
        }
    }
    if let Some(ext_fields) = cmd.ext_fields() {
        for (k, v) in ext_fields.iter() {
            if !k.is_empty() && !v.is_empty() {
                Self::write_str(buf, true, k.as_str());
                // Values use the 4-byte prefix: `map_deserialize` reads them with
                // `use_short_length = false`, so a 2-byte prefix here would not round-trip.
                Self::write_str(buf, false, v.as_str());
            }
        }
    }

    // Back-fill the ext-fields length (payload bytes only, excluding the 4-byte slot itself).
    let current_length = buf.len();
    let ext_fields_length = (current_length - map_len_index - 4) as i32;
    buf[map_len_index..map_len_index + 4].copy_from_slice(&ext_fields_length.to_be_bytes());

    buf.len() - begin_index
}
#[inline]
fn estimate_encode_size(cmd: &RemotingCommand) -> usize {
let mut size = 15;
if let Some(remark) = cmd.remark() {
size += 4 + remark.len(); } else {
size += 4; }
if let Some(ext) = cmd.ext_fields() {
for (k, v) in ext.iter() {
if !k.is_empty() && !v.is_empty() {
size += 2 + k.len() + 2 + v.len(); }
}
}
size
}
/// Encodes the header of `cmd` into a freshly allocated, immutable [`Bytes`] buffer.
///
/// Fields are written in this order: code (`i16`), language code (`u8`), version (`i16`),
/// opaque (`i32`), flag (`i32`), remark length (`i32`) plus remark bytes, and ext-fields length
/// (`i32`) plus the payload produced by `map_serialize`. A missing remark or an absent/empty
/// ext-fields payload is written as a zero length with no bytes following.
///
/// The buffer is allocated once with the capacity computed by `cal_total_len`.
pub fn rocket_mq_protocol_encode_bytes(cmd: &RemotingCommand) -> Bytes {
    // Borrow the remark bytes instead of copying them into a temporary `Vec`.
    let remark = cmd.remark().map(|remark| remark.as_bytes());
    let remark_len = remark.map_or(0, |bytes| bytes.len());

    let ext_payload = cmd
        .get_ext_fields()
        .and_then(|ext| Self::map_serialize(ext));
    let ext_len = ext_payload.as_ref().map_or(0, |payload| payload.len());

    let mut header_buffer = BytesMut::with_capacity(Self::cal_total_len(remark_len, ext_len));
    header_buffer.put_i16(cmd.code() as i16);
    header_buffer.put_u8(cmd.language().get_code());
    header_buffer.put_i16(cmd.version() as i16);
    header_buffer.put_i32(cmd.opaque());
    header_buffer.put_i32(cmd.flag());

    // Both lengths are zero when the corresponding section is absent.
    header_buffer.put_i32(remark_len as i32);
    if let Some(remark) = remark {
        header_buffer.put_slice(remark);
    }
    header_buffer.put_i32(ext_len as i32);
    if let Some(payload) = ext_payload {
        header_buffer.put_slice(&payload);
    }

    header_buffer.freeze()
}
/// Serializes `map` into the ext-fields payload format.
///
/// Each entry whose key and value are both non-empty is written as a `u16` key length, the key
/// bytes, an `i32` value length, and the value bytes. Entries with an empty key or value are
/// skipped.
///
/// Returns `None` when `map` is empty or contains no entry that qualifies.
#[inline]
pub fn map_serialize(map: &HashMap<CheetahString, CheetahString>) -> Option<BytesMut> {
    // Size pass: every qualifying entry contributes at least 8 bytes, so a zero total means
    // there is nothing to write.
    let encoded_len: usize = map
        .iter()
        .filter(|(key, value)| !key.is_empty() && !value.is_empty())
        .map(|(key, value)| 2 + key.len() + 4 + value.len())
        .sum();
    if encoded_len == 0 {
        return None;
    }

    let mut content = BytesMut::with_capacity(encoded_len);
    for (key, value) in map {
        if key.is_empty() || value.is_empty() {
            continue;
        }
        content.put_u16(key.len() as u16);
        content.put_slice(key.as_bytes());
        content.put_i32(value.len() as i32);
        content.put_slice(value.as_bytes());
    }
    Some(content)
}
/// Returns the encoded header size in bytes for a remark of `remark_len` bytes and an
/// ext-fields payload of `ext_len` bytes.
pub fn cal_total_len(remark_len: usize, ext_len: usize) -> usize {
    // code(2) + language(1) + version(2) + opaque(4) + flag(4)
    const FIXED_FIELDS_LEN: usize = 2 + 1 + 2 + 4 + 4;
    // The remark and the ext-fields payload are each preceded by a 4-byte length field.
    const LENGTH_FIELD_LEN: usize = 4;

    FIXED_FIELDS_LEN + LENGTH_FIELD_LEN + remark_len + LENGTH_FIELD_LEN + ext_len
}
/// Decodes a command header from the front of `header_buffer`.
///
/// Reads, in order: code (`i16`), language code (`u8`), version (`i16`), opaque (`i32`),
/// flag (`i32`), the remark (4-byte length prefix), and the ext-fields length (`i32`) followed
/// by the ext-fields payload. A non-positive ext-fields length yields an empty map.
///
/// `header_len` is used as the upper bound for both the remark length and the ext-fields
/// length.
///
/// # Errors
///
/// Returns `RocketmqError::DecodingError` when the buffer holds fewer than the 13 fixed bytes,
/// when the ext-fields length field is missing, when the declared ext-fields length exceeds
/// `header_len` or the bytes remaining, or when reading the remark or an ext field fails.
pub fn rocket_mq_protocol_decode(
    header_buffer: &mut BytesMut,
    header_len: usize,
) -> rocketmq_error::RocketMQResult<RemotingCommand> {
    const FIXED_HEADER_LEN: usize = 13;
    const LENGTH_FIELD_LEN: usize = 4;

    let available = header_buffer.remaining();
    if available < FIXED_HEADER_LEN {
        return Err(RocketmqError::DecodingError(FIXED_HEADER_LEN, available).into());
    }

    // The command is created before any field is read, then populated from the wire values.
    let cmd = RemotingCommand::default();
    let code = header_buffer.get_i16();
    let language = LanguageCode::from(header_buffer.get_u8());
    let version = header_buffer.get_i16() as i32;
    let opaque = header_buffer.get_i32();
    let flag = header_buffer.get_i32();
    let cmd = cmd
        .set_code(code)
        .set_language(language)
        .set_version(version)
        .set_opaque(opaque)
        .set_flag(flag);

    let remark = Self::read_str(header_buffer, false, header_len)?;

    let available = header_buffer.remaining();
    if available < LENGTH_FIELD_LEN {
        return Err(RocketmqError::DecodingError(LENGTH_FIELD_LEN, available).into());
    }

    let ext = match header_buffer.get_i32() {
        declared if declared > 0 => {
            let ext_fields_length = declared as usize;
            if ext_fields_length > header_len {
                return Err(RocketmqError::DecodingError(ext_fields_length, header_len).into());
            }
            let available = header_buffer.remaining();
            if ext_fields_length > available {
                return Err(RocketmqError::DecodingError(ext_fields_length, available).into());
            }
            Self::map_deserialize(header_buffer, ext_fields_length)?
        }
        // Zero or negative length: no ext fields.
        _ => HashMap::new(),
    };

    Ok(cmd.set_remark_option(remark).set_ext_fields(ext))
}
/// Decodes `len` bytes of ext-fields payload from the front of `buffer` into a map.
///
/// Entries are read repeatedly as a 2-byte-prefixed key followed by a 4-byte-prefixed value
/// until at least `len` bytes have been consumed. `len` also serves as the per-string length
/// limit. A `len` of zero returns an empty map without touching `buffer`.
///
/// # Errors
///
/// Returns `RocketmqError::DecodingError` when `len` exceeds the bytes remaining, when reading a
/// key or value fails, or when a key or value has zero length (reported as
/// `DecodingError(0, 0)`).
#[inline]
pub fn map_deserialize(
    buffer: &mut BytesMut,
    len: usize,
) -> rocketmq_error::RocketMQResult<HashMap<CheetahString, CheetahString>> {
    if len == 0 {
        return Ok(HashMap::new());
    }
    let available = buffer.remaining();
    if len > available {
        return Err(RocketmqError::DecodingError(len, available).into());
    }

    // Reading stops once the remaining byte count has dropped to this mark; the subtraction
    // cannot underflow because `len <= available` was checked above.
    let stop_mark = available - len;
    // Rough pre-sizing: assume about 50 bytes per entry, bounded to a sane range.
    let mut fields = HashMap::with_capacity((len / 50).clamp(4, 1024));

    while buffer.remaining() > stop_mark {
        let key = match Self::read_str(buffer, true, len)? {
            Some(key) => key,
            None => return Err(RocketmqError::DecodingError(0, 0).into()),
        };
        let value = match Self::read_str(buffer, false, len)? {
            Some(value) => value,
            None => return Err(RocketmqError::DecodingError(0, 0).into()),
        };
        fields.insert(key, value);
    }
    Ok(fields)
}
}
#[cfg(test)]
mod tests {
    use bytes::BufMut;
    use bytes::BytesMut;

    use super::*;

    /// "test" with a 2-byte big-endian length prefix.
    const SHORT_PREFIXED_TEST: &[u8] = b"\x00\x04test";
    /// "test" with a 4-byte big-endian length prefix.
    const LONG_PREFIXED_TEST: &[u8] = b"\x00\x00\x00\x04test";
    /// One ext-field entry: 2-byte-prefixed "key" followed by 4-byte-prefixed "value".
    const KEY_VALUE_ENTRY: &[u8] = b"\x00\x03key\x00\x00\x00\x05value";

    /// Builds a header holding the fixed fields and a zero remark length, but no ext-fields
    /// length field.
    fn minimal_header_without_ext_len() -> BytesMut {
        let mut header = BytesMut::new();
        header.put_i16(0); // code
        header.put_u8(LanguageCode::JAVA.get_code()); // language
        header.put_i16(0); // version
        header.put_i32(0); // opaque
        header.put_i32(0); // flag
        header.put_i32(0); // remark length
        header
    }

    #[test]
    fn write_str_short_length() {
        let mut buf = BytesMut::new();
        assert_eq!(RocketMQSerializable::write_str(&mut buf, true, "test"), 6);
        assert_eq!(&buf[..], SHORT_PREFIXED_TEST);
    }

    #[test]
    fn write_str_long_length() {
        let mut buf = BytesMut::new();
        assert_eq!(RocketMQSerializable::write_str(&mut buf, false, "test"), 8);
        assert_eq!(&buf[..], LONG_PREFIXED_TEST);
    }

    #[test]
    fn read_str_short_length() {
        let mut buf = BytesMut::from(SHORT_PREFIXED_TEST);
        let decoded = RocketMQSerializable::read_str(&mut buf, true, 10).unwrap();
        assert_eq!(decoded, Some("test".into()));
    }

    #[test]
    fn read_str_long_length() {
        let mut buf = BytesMut::from(LONG_PREFIXED_TEST);
        let decoded = RocketMQSerializable::read_str(&mut buf, false, 10).unwrap();
        assert_eq!(decoded, Some("test".into()));
    }

    #[test]
    fn read_str_exceeds_limit() {
        // Declared length 4 is above the limit of 2.
        let mut buf = BytesMut::from(LONG_PREFIXED_TEST);
        assert!(RocketMQSerializable::read_str(&mut buf, false, 2).is_err());
    }

    #[test]
    fn map_serialize_empty() {
        let empty = HashMap::new();
        assert!(RocketMQSerializable::map_serialize(&empty).is_none());
    }

    #[test]
    fn map_serialize_non_empty() {
        let mut fields = HashMap::new();
        fields.insert("key".into(), "value".into());
        let encoded = RocketMQSerializable::map_serialize(&fields).unwrap();
        assert_eq!(&encoded[..], KEY_VALUE_ENTRY);
    }

    #[test]
    fn map_deserialize_empty() {
        let mut buf = BytesMut::new();
        let decoded = RocketMQSerializable::map_deserialize(&mut buf, 0).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn map_deserialize_non_empty() {
        let mut buf = BytesMut::from(KEY_VALUE_ENTRY);
        let decoded =
            RocketMQSerializable::map_deserialize(&mut buf, KEY_VALUE_ENTRY.len()).unwrap();

        let mut expected: HashMap<CheetahString, CheetahString> = HashMap::new();
        expected.insert("key".into(), "value".into());
        assert_eq!(decoded, expected);
    }

    #[test]
    fn rocketmq_protocol_decode_rejects_short_fixed_header_without_panic() {
        // One byte short of the 13-byte fixed header.
        let mut buf = BytesMut::from(&[0_u8; 12][..]);
        let decoded = RocketMQSerializable::rocket_mq_protocol_decode(&mut buf, 12);
        assert!(decoded.is_err(), "short fixed header should decode to error");
    }

    #[test]
    fn rocketmq_protocol_decode_rejects_missing_ext_length_without_panic() {
        let mut buf = minimal_header_without_ext_len();
        let header_len = buf.len();
        let decoded = RocketMQSerializable::rocket_mq_protocol_decode(&mut buf, header_len);
        assert!(decoded.is_err(), "missing ext length should decode to error");
    }

    #[test]
    fn rocketmq_protocol_decode_rejects_truncated_ext_fields_without_panic() {
        let mut buf = minimal_header_without_ext_len();
        // Declares 10 bytes of ext fields, but none follow.
        buf.put_i32(10);
        let header_len = buf.len();
        let decoded = RocketMQSerializable::rocket_mq_protocol_decode(&mut buf, header_len);
        assert!(decoded.is_err(), "truncated ext fields should decode to error");
    }
}