use crate::encoding::{decode_varint, skip_field_depth, varint_len, Tag, WireType};
use crate::DecodeError;
use alloc::vec::Vec;
use bytes::Buf;
/// Packs a field number and a numeric wire-type code into a protobuf tag value
/// (field number in the upper bits, wire type in the low three bits).
const fn make_tag(field_number: u64, wire_type: u64) -> u64 {
    (field_number << 3) | wire_type
}

/// Tag that opens a MessageSet item: field 1, wire type 3 (start group).
pub const ITEM_START_TAG: u64 = make_tag(1, 3);
/// Tag that closes a MessageSet item: field 1, wire type 4 (end group).
pub const ITEM_END_TAG: u64 = make_tag(1, 4);
/// Tag of the item's `type_id` field: field 2, wire type 0 (varint).
pub const TYPE_ID_TAG: u64 = make_tag(2, 0);
/// Tag of the item's `message` field: field 3, wire type 2 (length-delimited).
pub const MESSAGE_TAG: u64 = make_tag(3, 2);
/// Decodes the fields of a single MessageSet item from `buf`.
///
/// Reading starts at the first field inside the item and stops once the
/// item's end-group tag (field 1) has been consumed. Recognised fields:
///
/// * field 2, varint — the type id. It must lie in `1..=i32::MAX`; if the
///   field appears more than once, the last value is kept.
/// * field 3, length-delimited — the message payload. Repeated occurrences
///   are appended to each other in the order they appear.
///
/// Every other field is skipped via `skip_field_depth`, which receives
/// `depth` unchanged.
///
/// Returns the type id together with the accumulated payload bytes (empty
/// when no message field was present).
///
/// # Errors
///
/// * `DecodeError::InvalidMessageSet` when the type id is out of range or
///   was never seen before the item ended.
/// * `DecodeError::UnexpectedEof` when a payload length exceeds the bytes
///   remaining in `buf`.
/// * `DecodeError::InvalidEndGroup` when an end-group tag for a field other
///   than 1 is encountered.
/// * Any error propagated from `Tag::decode`, `decode_varint` or
///   `skip_field_depth`.
pub fn merge_item(buf: &mut impl Buf, depth: u32) -> Result<(u32, Vec<u8>), DecodeError> {
    let mut seen_type_id = None::<u32>;
    let mut payload = Vec::<u8>::new();

    loop {
        let tag = Tag::decode(buf)?;
        match (tag.field_number(), tag.wire_type()) {
            // The end-group tag of field 1 terminates the item.
            (1, WireType::EndGroup) => break,
            (2, WireType::Varint) => {
                let raw = decode_varint(buf)?;
                if !(1..=i32::MAX as u64).contains(&raw) {
                    return Err(DecodeError::InvalidMessageSet("type_id out of range"));
                }
                // Lossless: the range check above bounds `raw` by i32::MAX.
                seen_type_id = Some(raw as u32);
            }
            (3, WireType::LengthDelimited) => {
                let declared = decode_varint(buf)?;
                // Validate before growing the vector so a bogus length cannot
                // trigger a huge allocation or a panic inside `copy_to_slice`.
                if declared > buf.remaining() as u64 {
                    return Err(DecodeError::UnexpectedEof);
                }
                let declared = declared as usize;
                let offset = payload.len();
                payload.resize(offset + declared, 0);
                buf.copy_to_slice(&mut payload[offset..]);
            }
            (other, WireType::EndGroup) => {
                return Err(DecodeError::InvalidEndGroup(other));
            }
            _ => {
                skip_field_depth(tag, buf, depth)?;
            }
        }
    }

    match seen_type_id {
        Some(type_id) => Ok((type_id, payload)),
        None => Err(DecodeError::InvalidMessageSet("missing type_id")),
    }
}
/// Returns the number of bytes one encoded MessageSet item occupies.
///
/// The total is made of four tag bytes (item start, type id, message and
/// item end — each tag encodes to a single byte), the varint encoding of
/// `number`, the varint length prefix of the payload, and the `payload_len`
/// payload bytes themselves.
#[inline]
pub const fn item_encoded_len(number: u32, payload_len: usize) -> usize {
    // Start-group, type-id, message and end-group tags: one byte apiece.
    const TAG_BYTES: usize = 4;

    let type_id_bytes = varint_len(number as u64);
    let length_prefix_bytes = varint_len(payload_len as u64);

    TAG_BYTES + type_id_bytes + length_prefix_bytes + payload_len
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encoding::encode_varint;

    /// Recursion budget for tests that are not about the depth limit.
    const DEPTH: u32 = 50;

    /// Encodes each value as a varint, back to back.
    fn varints(values: &[u64]) -> Vec<u8> {
        let mut out = Vec::new();
        for &value in values {
            encode_varint(value, &mut out);
        }
        out
    }

    /// Joins the given fragments into a single item body.
    fn item_body(parts: &[&[u8]]) -> Vec<u8> {
        parts
            .iter()
            .flat_map(|part| part.iter().copied())
            .collect()
    }

    /// Encoded `type_id` field carrying `id`.
    fn type_id_field(id: u64) -> Vec<u8> {
        varints(&[TYPE_ID_TAG, id])
    }

    /// Encoded `message` field carrying `payload`.
    fn message_field(payload: &[u8]) -> Vec<u8> {
        let mut out = varints(&[MESSAGE_TAG, payload.len() as u64]);
        out.extend_from_slice(payload);
        out
    }

    /// Encoded end-group tag that terminates an item.
    fn end_group() -> Vec<u8> {
        varints(&[ITEM_END_TAG])
    }

    /// Runs `merge_item` over `body` with the given recursion budget.
    fn decode(body: &[u8], depth: u32) -> Result<(u32, Vec<u8>), DecodeError> {
        let mut cursor = body;
        merge_item(&mut cursor, depth)
    }

    #[test]
    fn tag_constants_match_wire_bytes() {
        assert_eq!(ITEM_START_TAG, 0x0B);
        assert_eq!(ITEM_END_TAG, 0x0C);
        assert_eq!(TYPE_ID_TAG, 0x10);
        assert_eq!(MESSAGE_TAG, 0x1A);
    }

    #[test]
    fn merge_item_type_id_then_message() {
        let body = item_body(&[&type_id_field(1000), &message_field(b"hello"), &end_group()]);
        let (tid, msg) = decode(&body, DEPTH).expect("merge");
        assert_eq!(tid, 1000);
        assert_eq!(msg, b"hello");
    }

    #[test]
    fn merge_item_message_then_type_id() {
        let body = item_body(&[&message_field(b"world"), &type_id_field(42), &end_group()]);
        let (tid, msg) = decode(&body, DEPTH).expect("merge");
        assert_eq!(tid, 42);
        assert_eq!(msg, b"world");
    }

    #[test]
    fn merge_item_skips_unknown_fields() {
        // Field 99 as a varint: not part of the item schema, must be skipped.
        let junk = varints(&[99 << 3, 12345]);
        let body = item_body(&[
            &type_id_field(7),
            &junk,
            &message_field(b"ok"),
            &end_group(),
        ]);
        let (tid, msg) = decode(&body, DEPTH).expect("merge");
        assert_eq!(tid, 7);
        assert_eq!(msg, b"ok");
    }

    #[test]
    fn merge_item_skips_nested_group_respecting_depth() {
        // Group for field 50 wrapping a single varint field 8.
        let junk = varints(&[(50 << 3) | 3, 8 << 3, 1, (50 << 3) | 4]);
        let body = item_body(&[&type_id_field(5), &junk, &message_field(b"x"), &end_group()]);

        let (tid, msg) = decode(&body, 10).expect("merge");
        assert_eq!(tid, 5);
        assert_eq!(msg, b"x");

        // With no recursion budget left, skipping the nested group must fail.
        let err = decode(&body, 0).unwrap_err();
        assert_eq!(err, DecodeError::RecursionLimitExceeded);
    }

    #[test]
    fn merge_item_missing_type_id_errors() {
        let body = item_body(&[&message_field(b"orphan"), &end_group()]);
        let err = decode(&body, DEPTH).unwrap_err();
        assert_eq!(err, DecodeError::InvalidMessageSet("missing type_id"));
    }

    #[test]
    fn merge_item_missing_message_yields_empty() {
        let body = item_body(&[&type_id_field(3), &end_group()]);
        let (tid, msg) = decode(&body, DEPTH).expect("merge");
        assert_eq!(tid, 3);
        assert_eq!(msg, b"");
    }

    #[test]
    fn merge_item_multiple_messages_concatenate() {
        let body = item_body(&[
            &type_id_field(9),
            &message_field(b"ab"),
            &message_field(b"cd"),
            &end_group(),
        ]);
        let (tid, msg) = decode(&body, DEPTH).expect("merge");
        assert_eq!(tid, 9);
        assert_eq!(msg, b"abcd");
    }

    #[test]
    fn merge_item_repeated_type_id_last_wins() {
        let body = item_body(&[
            &type_id_field(5),
            &type_id_field(99),
            &message_field(b"x"),
            &end_group(),
        ]);
        let (tid, msg) = decode(&body, DEPTH).expect("merge");
        assert_eq!(tid, 99);
        assert_eq!(msg, b"x");
    }

    #[test]
    fn merge_item_type_id_out_of_range() {
        // (type id on the wire, whether decoding should succeed)
        let cases: [(u64, bool); 4] = [
            (0, false),
            (1, true),
            (i32::MAX as u64, true),
            (i32::MAX as u64 + 1, false),
        ];
        for (id, ok) in cases {
            let body = item_body(&[&type_id_field(id), &message_field(b""), &end_group()]);
            let result = decode(&body, DEPTH);
            assert_eq!(result.is_ok(), ok, "type_id = {id}");
        }
    }

    #[test]
    fn merge_item_mismatched_end_group_errors() {
        // End-group tag for field 7 instead of the item's own field 1.
        let bad_end = varints(&[(7 << 3) | 4]);
        let body = item_body(&[&type_id_field(1), &bad_end]);
        let err = decode(&body, DEPTH).unwrap_err();
        assert_eq!(err, DecodeError::InvalidEndGroup(7));
    }

    #[test]
    fn merge_item_truncated_message_errors() {
        // Declares a 100-byte payload but supplies only two bytes.
        let mut body = varints(&[TYPE_ID_TAG, 5, MESSAGE_TAG, 100]);
        body.extend_from_slice(b"xy");
        let err = decode(&body, DEPTH).unwrap_err();
        assert_eq!(err, DecodeError::UnexpectedEof);
    }

    #[test]
    fn item_encoded_len_matches_manual_count() {
        // (type id, payload length, expected total bytes)
        let cases: [(u32, usize, usize); 5] = [
            (1, 0, 4 + 1 + 1 + 0),
            (127, 5, 4 + 1 + 1 + 5),
            (128, 5, 4 + 2 + 1 + 5),
            (1000, 10, 4 + 2 + 1 + 10),
            (1000, 200, 4 + 2 + 2 + 200),
        ];
        for (number, payload_len, expected) in cases {
            assert_eq!(
                item_encoded_len(number, payload_len),
                expected,
                "number={number} payload_len={payload_len}"
            );
        }
    }

    #[test]
    fn item_encoded_len_matches_actual_encoding() {
        let number = 1000u32;
        let payload = b"hello world";

        let mut wire = varints(&[
            ITEM_START_TAG,
            TYPE_ID_TAG,
            u64::from(number),
            MESSAGE_TAG,
            payload.len() as u64,
        ]);
        wire.extend_from_slice(payload);
        wire.extend(varints(&[ITEM_END_TAG]));

        assert_eq!(item_encoded_len(number, payload.len()), wire.len());
    }
}