use std::marker::PhantomData;
use crate::bytes::Bytes;
use super::codec::Codec;
pub use super::DEFAULT_MAX_MESSAGE_SIZE;
/// Errors returned by [`ProstCodec`] when encoding or decoding a message.
#[derive(Debug, thiserror::Error)]
pub enum ProtobufError {
/// `prost` failed to serialize the message (converted via `From`).
#[error("failed to encode protobuf message: {0}")]
EncodeError(#[from] prost::EncodeError),
/// `prost` failed to parse the supplied bytes (converted via `From`).
#[error("failed to decode protobuf message: {0}")]
DecodeError(#[from] prost::DecodeError),
/// The message is larger than the codec's configured limit.
///
/// On encode `size` is the message's `encoded_len()`; on decode it is the
/// length of the input buffer. `limit` is the codec's maximum message size.
#[error("message size {size} exceeds limit {limit}")]
MessageTooLarge {
size: usize,
limit: usize,
},
}
/// Protobuf codec built on `prost` that enforces a maximum message size.
///
/// `T` is the message type accepted by `encode` and `U` is the message type
/// produced by `decode`. Neither is stored; they only select the `Codec`
/// implementation.
pub struct ProstCodec<T, U> {
    /// Upper bound, in bytes, applied to both encoded output and decode input.
    max_message_size: usize,
    /// Zero-sized marker tying the codec to its message types.
    _marker: PhantomData<(T, U)>,
}

// Written by hand (like `Clone`) rather than derived: `#[derive(Debug)]` would
// add `T: Debug` and `U: Debug` bounds even though no `T`/`U` value is stored.
// The field list mirrors what the derive would print.
impl<T, U> std::fmt::Debug for ProstCodec<T, U> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProstCodec")
            .field("max_message_size", &self.max_message_size)
            .field("_marker", &self._marker)
            .finish()
    }
}
impl<T, U> ProstCodec<T, U> {
    /// Creates a codec whose size limit is [`DEFAULT_MAX_MESSAGE_SIZE`].
    ///
    /// `const` so a codec can be built in constant and static initializers.
    #[must_use]
    pub const fn new() -> Self {
        Self::with_max_size(DEFAULT_MAX_MESSAGE_SIZE)
    }

    /// Creates a codec that rejects any message larger than `max_size` bytes.
    ///
    /// The same limit is checked on both the encode and the decode path.
    #[must_use]
    pub const fn with_max_size(max_size: usize) -> Self {
        Self {
            max_message_size: max_size,
            _marker: PhantomData,
        }
    }

    /// Returns the configured maximum message size in bytes.
    #[must_use]
    pub const fn max_message_size(&self) -> usize {
        self.max_message_size
    }
}
impl<T, U> Default for ProstCodec<T, U> {
fn default() -> Self {
Self::new()
}
}
impl<T, U> Clone for ProstCodec<T, U> {
fn clone(&self) -> Self {
Self {
max_message_size: self.max_message_size,
_marker: PhantomData,
}
}
}
impl<T, U> Codec for ProstCodec<T, U>
where
    T: prost::Message + Send + 'static,
    U: prost::Message + Default + Send + 'static,
{
    type Encode = T;
    type Decode = U;
    type Error = ProtobufError;

    /// Serializes `item` into a freshly allocated buffer.
    ///
    /// # Errors
    ///
    /// Returns [`ProtobufError::MessageTooLarge`] when `item.encoded_len()`
    /// exceeds the configured limit (checked before any buffer is allocated),
    /// or [`ProtobufError::EncodeError`] when `prost` fails to encode.
    fn encode(&mut self, item: &Self::Encode) -> Result<Bytes, Self::Error> {
        let limit = self.max_message_size;
        let size = item.encoded_len();
        if size <= limit {
            // Exactly `size` bytes are needed, so reserve them up front.
            let mut wire = Vec::with_capacity(size);
            item.encode(&mut wire)?;
            Ok(Bytes::from(wire))
        } else {
            Err(ProtobufError::MessageTooLarge { size, limit })
        }
    }

    /// Parses a `U` from the whole of `buf`.
    ///
    /// # Errors
    ///
    /// Returns [`ProtobufError::MessageTooLarge`] when `buf` is longer than
    /// the configured limit, or [`ProtobufError::DecodeError`] when `prost`
    /// rejects the bytes.
    fn decode(&mut self, buf: &Bytes) -> Result<Self::Decode, Self::Error> {
        let limit = self.max_message_size;
        let size = buf.len();
        if size > limit {
            return Err(ProtobufError::MessageTooLarge { size, limit });
        }
        U::decode(buf.as_ref()).map_err(ProtobufError::from)
    }
}
/// A [`ProstCodec`] that encodes and decodes the same message type `T`.
pub type SymmetricProstCodec<T> = ProstCodec<T, T>;
#[cfg(test)]
mod tests {
#![allow(
clippy::pedantic,
clippy::nursery,
clippy::expect_fun_call,
clippy::map_unwrap_or,
clippy::cast_possible_wrap,
clippy::future_not_send
)]
use super::*;
use prost::Message;
// Shared test prologue: initializes the crate's test logging and then invokes
// the crate's `test_phase!` macro with the test's name.
fn init_test(name: &str) {
crate::test_utils::init_test_logging();
crate::test_phase!(name);
}
// Basic fixture: a string (tag 1) and an int32 (tag 2).
#[derive(Clone, PartialEq, prost::Message)]
pub struct TestMessage {
#[prost(string, tag = "1")]
pub name: String,
#[prost(int32, tag = "2")]
pub value: i32,
}
// Fixture with an optional embedded `TestMessage` (tag 1) and a repeated
// string field (tag 2).
#[derive(Clone, PartialEq, prost::Message)]
pub struct NestedMessage {
#[prost(message, optional, tag = "1")]
pub inner: Option<TestMessage>,
#[prost(repeated, string, tag = "2")]
pub items: Vec<String>,
}
// Fixture with one field per protobuf scalar kind (tags 1 through 15), used to
// roundtrip every scalar encoding in a single message.
#[derive(Clone, PartialEq, prost::Message)]
pub struct AllTypesMessage {
#[prost(double, tag = "1")]
pub double_field: f64,
#[prost(float, tag = "2")]
pub float_field: f32,
#[prost(int32, tag = "3")]
pub int32_field: i32,
#[prost(int64, tag = "4")]
pub int64_field: i64,
#[prost(uint32, tag = "5")]
pub uint32_field: u32,
#[prost(uint64, tag = "6")]
pub uint64_field: u64,
#[prost(sint32, tag = "7")]
pub sint32_field: i32,
#[prost(sint64, tag = "8")]
pub sint64_field: i64,
#[prost(fixed32, tag = "9")]
pub fixed32_field: u32,
#[prost(fixed64, tag = "10")]
pub fixed64_field: u64,
#[prost(sfixed32, tag = "11")]
pub sfixed32_field: i32,
#[prost(sfixed64, tag = "12")]
pub sfixed64_field: i64,
#[prost(bool, tag = "13")]
pub bool_field: bool,
#[prost(string, tag = "14")]
pub string_field: String,
#[prost(bytes = "vec", tag = "15")]
pub bytes_field: Vec<u8>,
}
// Single optional uint64 at tag 1. The varint tests rely on its encoding
// starting with the 0x08 key byte whenever the value is `Some`.
#[derive(Clone, PartialEq, prost::Message)]
pub struct OptionalU64VarintMessage {
#[prost(uint64, optional, tag = "1")]
pub value: Option<u64>,
}
// Single optional uint32 at tag 1; the 32-bit counterpart of
// `OptionalU64VarintMessage`.
#[derive(Clone, PartialEq, prost::Message)]
pub struct OptionalU32VarintMessage {
#[prost(uint32, optional, tag = "1")]
pub value: Option<u32>,
}
// A simple message must survive an encode/decode cycle unchanged.
#[test]
fn test_prost_codec_roundtrip() {
    init_test("test_prost_codec_roundtrip");

    let sent = TestMessage {
        value: 42,
        name: String::from("hello"),
    };
    let mut codec = ProstCodec::<TestMessage, TestMessage>::new();

    let wire = codec.encode(&sent).unwrap();
    let received = codec.decode(&wire).unwrap();

    crate::assert_with_log!(received == sent, "roundtrip", sent.name, received.name);
    crate::test_complete!("test_prost_codec_roundtrip");
}
// Embedded messages and repeated fields must roundtrip intact.
#[test]
fn test_prost_codec_nested_message() {
    init_test("test_prost_codec_nested_message");

    let sent = NestedMessage {
        items: ["a", "b", "c"].iter().map(|item| (*item).to_owned()).collect(),
        inner: Some(TestMessage {
            value: 100,
            name: String::from("inner"),
        }),
    };
    let mut codec = ProstCodec::<NestedMessage, NestedMessage>::new();

    let wire = codec.encode(&sent).unwrap();
    let received = codec.decode(&wire).unwrap();

    let same = received == sent;
    crate::assert_with_log!(same, "nested", true, same);
    crate::test_complete!("test_prost_codec_nested_message");
}
// Every scalar field kind in `AllTypesMessage` must roundtrip, including
// negative values for the signed encodings.
#[test]
fn test_prost_codec_all_wire_types() {
    init_test("test_prost_codec_all_wire_types");

    let sent = AllTypesMessage {
        // Floating point.
        double_field: 1.234,
        float_field: 5.678,
        // Plain varints.
        int32_field: -100,
        int64_field: -200,
        uint32_field: 300,
        uint64_field: 400,
        // Zigzag varints.
        sint32_field: -500,
        sint64_field: -600,
        // Fixed width.
        fixed32_field: 700,
        fixed64_field: 800,
        sfixed32_field: -900,
        sfixed64_field: -1000,
        // Bool and length-delimited.
        bool_field: true,
        string_field: String::from("test string"),
        bytes_field: vec![0x01, 0x02, 0x03, 0x04],
    };
    let mut codec = ProstCodec::<AllTypesMessage, AllTypesMessage>::new();

    let wire = codec.encode(&sent).unwrap();
    let received = codec.decode(&wire).unwrap();

    let same = received == sent;
    crate::assert_with_log!(same, "wire types", true, same);
    crate::test_complete!("test_prost_codec_all_wire_types");
}
// A default message encodes to zero bytes and decodes back to the default.
#[test]
fn test_prost_codec_empty_message() {
    init_test("test_prost_codec_empty_message");

    let blank = TestMessage::default();
    let mut codec = ProstCodec::<TestMessage, TestMessage>::new();

    let wire = codec.encode(&blank).unwrap();
    let wire_is_empty = wire.is_empty();
    crate::assert_with_log!(
        wire_is_empty,
        "empty message encodes to empty bytes",
        true,
        wire_is_empty
    );

    let received = codec.decode(&wire).unwrap();
    let same = received == blank;
    crate::assert_with_log!(same, "empty roundtrip", true, same);
    crate::test_complete!("test_prost_codec_empty_message");
}
// Encoding must fail with `MessageTooLarge` when the message exceeds the cap.
#[test]
fn test_prost_codec_message_too_large_encode() {
    init_test("test_prost_codec_message_too_large_encode");

    let oversized = TestMessage {
        value: 42,
        name: String::from("this is a very long string that exceeds the limit"),
    };
    let mut codec = ProstCodec::<TestMessage, TestMessage>::with_max_size(10);

    let rejected = matches!(
        codec.encode(&oversized),
        Err(ProtobufError::MessageTooLarge { .. })
    );
    crate::assert_with_log!(rejected, "encode fails for large message", true, rejected);
    crate::test_complete!("test_prost_codec_message_too_large_encode");
}
// Decoding must fail with `MessageTooLarge` when the input exceeds the cap.
#[test]
fn test_prost_codec_message_too_large_decode() {
    init_test("test_prost_codec_message_too_large_decode");

    // Produce the wire bytes with a default-limit codec first.
    let wire = {
        let source = TestMessage {
            value: 42,
            name: String::from("this is a long string"),
        };
        let mut roomy = ProstCodec::<TestMessage, TestMessage>::new();
        roomy.encode(&source).unwrap()
    };

    // A 5-byte cap is far below the encoded size.
    let mut tight = ProstCodec::<TestMessage, TestMessage>::with_max_size(5);
    let rejected = matches!(
        tight.decode(&wire),
        Err(ProtobufError::MessageTooLarge { .. })
    );
    crate::assert_with_log!(rejected, "decode fails for large message", true, rejected);
    crate::test_complete!("test_prost_codec_message_too_large_decode");
}
// With a cap one byte below the encoded size, both encode and decode must
// report the exact wire size and the configured limit in `MessageTooLarge`.
#[test]
fn test_prost_codec_size_limit_reports_exact_wire_size() {
    init_test("test_prost_codec_size_limit_reports_exact_wire_size");

    let message = TestMessage {
        value: 7,
        name: String::from("abcd"),
    };
    let encoded_len = message.encoded_len();
    let limit = encoded_len - 1;
    // Factory for codecs whose cap is one byte too small.
    let one_byte_short = || ProstCodec::<TestMessage, TestMessage>::with_max_size(limit);

    let encode_err = one_byte_short()
        .encode(&message)
        .expect_err("message should exceed encode limit");
    assert!(matches!(
        encode_err,
        ProtobufError::MessageTooLarge { size, limit: got }
            if size == encoded_len && got == limit
    ));

    let encoded = ProstCodec::<TestMessage, TestMessage>::new()
        .encode(&message)
        .expect("message should encode with default limit");
    assert_eq!(encoded.len(), encoded_len);

    let decode_err = one_byte_short()
        .decode(&encoded)
        .expect_err("message should exceed decode limit");
    assert!(matches!(
        decode_err,
        ProtobufError::MessageTooLarge { size, limit: got }
            if size == encoded_len && got == limit
    ));

    crate::test_complete!("test_prost_codec_size_limit_reports_exact_wire_size");
}
// Bytes that do not form a valid field key must surface as `DecodeError`.
#[test]
fn test_prost_codec_invalid_data() {
    init_test("test_prost_codec_invalid_data");

    // Ten continuation bytes followed by a terminator: an 11-byte varint.
    const GARBAGE: [u8; 11] = [
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01,
    ];
    let wire = Bytes::from_static(&GARBAGE);
    let mut codec = ProstCodec::<TestMessage, TestMessage>::new();

    let rejected = matches!(codec.decode(&wire), Err(ProtobufError::DecodeError(_)));
    crate::assert_with_log!(rejected, "decode fails for invalid data", true, rejected);
    crate::test_complete!("test_prost_codec_invalid_data");
}
// Appends `value` to `out` as a base-128 varint: seven payload bits per byte,
// least-significant group first, high bit set on every byte except the last.
fn encode_test_varint(mut value: u64, out: &mut Vec<u8>) {
    loop {
        let low_bits = (value & 0x7f) as u8;
        value >>= 7;
        if value == 0 {
            // Final group: continuation bit stays clear.
            out.push(low_bits);
            break;
        }
        out.push(low_bits | 0x80);
    }
}
// Lenient varint reader: returns the decoded value and the number of bytes
// consumed, or `None` when `input` ends (or ten bytes pass) before a byte
// without the continuation bit is seen. Bits of the tenth byte that do not fit
// in a u64 are silently discarded; `strict_decode_test_varint` rejects them.
fn decode_test_varint(input: &[u8]) -> Option<(u64, usize)> {
    // A u64 needs at most ten 7-bit groups.
    const MAX_VARINT_BYTES: usize = 10;
    let mut decoded = 0u64;
    for (position, &byte) in input.iter().take(MAX_VARINT_BYTES).enumerate() {
        decoded |= u64::from(byte & 0x7f) << (7 * position);
        if byte < 0x80 {
            return Some((decoded, position + 1));
        }
    }
    None
}
// Number of bytes in the minimal varint encoding of `value`: one byte per
// started group of seven significant bits, and one byte for zero.
fn shortest_varint_len(value: u64) -> usize {
    let significant_bits = (u64::BITS - value.leading_zeros()) as usize;
    // Zero has no significant bits but still occupies a single byte.
    (significant_bits.max(1) + 6) / 7
}
// Strict varint reader. Like `decode_test_varint`, but a tenth byte carrying
// more than the single bit that still fits in a u64 is rejected, as is any
// encoding that continues past ten bytes or ends without a terminating byte.
fn strict_decode_test_varint(input: &[u8]) -> Option<(u64, usize)> {
    // Zero-based index of the tenth (final permissible) byte.
    const LAST_INDEX: usize = 9;
    let mut decoded = 0u64;
    for (position, &byte) in input.iter().enumerate() {
        let payload = u64::from(byte & 0x7f);
        let at_last = position == LAST_INDEX;
        if at_last && payload > 1 {
            // Only bit 63 is left; anything larger would overflow.
            return None;
        }
        decoded |= payload << (7 * position);
        let terminates = byte & 0x80 == 0;
        match (terminates, at_last) {
            (true, _) => return Some((decoded, position + 1)),
            (false, true) => return None,
            (false, false) => {}
        }
    }
    None
}
// Labels a varint for the conformance logs: "malformed" when nothing was
// decoded, "shortest" when the byte count equals the minimal encoding of the
// decoded value, and "non_shortest" otherwise.
fn shortest_varint_classification(
    varint_bytes: &[u8],
    decoded_value: Option<u64>,
) -> &'static str {
    let Some(value) = decoded_value else {
        return "malformed";
    };
    if shortest_varint_len(value) == varint_bytes.len() {
        "shortest"
    } else {
        "non_shortest"
    }
}
fn encode_optional_u64_varint(
value: u64,
) -> (
Bytes,
ProstCodec<OptionalU64VarintMessage, OptionalU64VarintMessage>,
) {
let mut codec: ProstCodec<OptionalU64VarintMessage, OptionalU64VarintMessage> =
ProstCodec::new();
let encoded = codec
.encode(&OptionalU64VarintMessage { value: Some(value) })
.expect("encode optional u64 varint");
(encoded, codec)
}
fn encode_optional_u32_varint(
value: u32,
) -> (
Bytes,
ProstCodec<OptionalU32VarintMessage, OptionalU32VarintMessage>,
) {
let mut codec: ProstCodec<OptionalU32VarintMessage, OptionalU32VarintMessage> =
ProstCodec::new();
let encoded = codec
.encode(&OptionalU32VarintMessage { value: Some(value) })
.expect("encode optional u32 varint");
(encoded, codec)
}
// Strips the leading 0x08 key byte (field 1, varint wire type) and returns the
// rest of `encoded`. Panics if the slice is empty or starts with another byte.
fn single_field_varint_payload(encoded: &[u8]) -> &[u8] {
    let Some((&key, payload)) = encoded.split_first() else {
        panic!("expected single-field varint payload");
    };
    assert_eq!(key, 0x08, "expected field-1 varint tag");
    payload
}
// Independently classifies hand-written wire bytes for the conformance log.
// Returns (decoded value, varint byte count, classification). Anything that is
// not exactly a 0x08 key followed by one strictly valid varint and nothing
// else is reported as `(None, None, "malformed")`.
fn classify_single_field_varint_wire(
    bytes: &[u8],
) -> (Option<u64>, Option<usize>, &'static str) {
    const MALFORMED: (Option<u64>, Option<usize>, &str) = (None, None, "malformed");

    match bytes.split_first() {
        Some((&0x08, payload)) => match strict_decode_test_varint(payload) {
            // The varint must account for every byte after the key.
            Some((value, consumed)) if consumed == payload.len() => (
                Some(value),
                Some(consumed),
                shortest_varint_classification(payload, Some(value)),
            ),
            _ => MALFORMED,
        },
        _ => MALFORMED,
    }
}
// Conformance matrix for varint handling.
//
// Part 1: encodes u64 values at each 7-bit width boundary (plus 0, 1 and
// u64::MAX) and u32::MAX through the codec, checks with the test's own varint
// helpers that the payload is the shortest encoding, and decodes it back.
// Part 2: feeds hand-written wire bytes to the codec and checks that each case
// either decodes to the expected value or fails with `DecodeError`.
// Every case emits one `PROTOBUF_VARINT_ROUNDTRIP` line on stderr.
#[test]
fn conformance_protobuf_varint_roundtrip_boundary_matrix() {
init_test("conformance_protobuf_varint_roundtrip_boundary_matrix");
// Command line echoed verbatim into every log line.
const EXACT_RCH_COMMAND: &str = "rch exec -- env CARGO_TARGET_DIR=${TMPDIR:-/tmp}/rch_target_asupersync_gnulez_varint cargo test -p asupersync --lib conformance_protobuf_varint_roundtrip_boundary_matrix -- --nocapture";
// Writes one structured log line; `None` values are printed as "none".
let log_case = |corpus_label: &str,
input_byte_length: usize,
decoded_value: Option<u64>,
encoded_length: Option<usize>,
shortest_classification: &str,
error_kind: &str,
final_verdict: &str| {
eprintln!(
"PROTOBUF_VARINT_ROUNDTRIP corpus_label={} input_byte_length={} decoded_value={} encoded_length={} shortest_classification={} error_kind={} exact_rch_command=\"{}\" artifact_paths=none final_varint_roundtrip_verdict={}",
corpus_label,
input_byte_length,
decoded_value.map_or_else(|| "none".to_string(), |value| value.to_string()),
encoded_length.map_or_else(|| "none".to_string(), |len| len.to_string()),
shortest_classification,
error_kind,
EXACT_RCH_COMMAND,
final_verdict,
);
};
struct ValidU64Case {
corpus_label: &'static str,
value: u64,
}
// 0, 1, then 2^(7k) for k = 1..=9 (the first value needing k+1 varint bytes),
// and finally u64::MAX.
let valid_u64_cases = [
ValidU64Case {
corpus_label: "u64_zero",
value: 0,
},
ValidU64Case {
corpus_label: "u64_one",
value: 1,
},
ValidU64Case {
corpus_label: "u64_pow2_7",
value: 1_u64 << 7,
},
ValidU64Case {
corpus_label: "u64_pow2_14",
value: 1_u64 << 14,
},
ValidU64Case {
corpus_label: "u64_pow2_21",
value: 1_u64 << 21,
},
ValidU64Case {
corpus_label: "u64_pow2_28",
value: 1_u64 << 28,
},
ValidU64Case {
corpus_label: "u64_pow2_35",
value: 1_u64 << 35,
},
ValidU64Case {
corpus_label: "u64_pow2_42",
value: 1_u64 << 42,
},
ValidU64Case {
corpus_label: "u64_pow2_49",
value: 1_u64 << 49,
},
ValidU64Case {
corpus_label: "u64_pow2_56",
value: 1_u64 << 56,
},
ValidU64Case {
corpus_label: "u64_pow2_63",
value: 1_u64 << 63,
},
ValidU64Case {
corpus_label: "u64_max",
value: u64::MAX,
},
];
for case in valid_u64_cases {
let (encoded, mut codec) = encode_optional_u64_varint(case.value);
// Everything after the 0x08 key byte is the varint under test.
let payload = single_field_varint_payload(encoded.as_ref());
let (decoded_value, consumed) =
decode_test_varint(payload).expect("u64 payload should decode");
// The varint must span the whole payload, carry the input value, and be
// the shortest possible encoding of it.
assert_eq!(consumed, payload.len());
assert_eq!(decoded_value, case.value);
assert_eq!(payload.len(), shortest_varint_len(case.value));
let decoded = codec.decode(&encoded).expect("roundtrip decode u64");
assert_eq!(decoded.value, Some(case.value));
log_case(
case.corpus_label,
encoded.len(),
Some(decoded_value),
Some(payload.len()),
shortest_varint_classification(payload, Some(decoded_value)),
"ok",
"pass",
);
}
// Same checks for the largest uint32 value.
let (encoded_u32_max, mut u32_codec) = encode_optional_u32_varint(u32::MAX);
let payload_u32_max = single_field_varint_payload(encoded_u32_max.as_ref());
let (decoded_u32_max, consumed_u32_max) =
decode_test_varint(payload_u32_max).expect("u32 max payload should decode");
assert_eq!(consumed_u32_max, payload_u32_max.len());
assert_eq!(decoded_u32_max, u64::from(u32::MAX));
assert_eq!(
payload_u32_max.len(),
shortest_varint_len(u64::from(u32::MAX))
);
let decoded_u32_message = u32_codec
.decode(&encoded_u32_max)
.expect("roundtrip decode u32 max");
assert_eq!(decoded_u32_message.value, Some(u32::MAX));
log_case(
"u32_max",
encoded_u32_max.len(),
Some(decoded_u32_max),
Some(payload_u32_max.len()),
shortest_varint_classification(payload_u32_max, Some(decoded_u32_max)),
"ok",
"pass",
);
// Hand-written wire inputs. Despite the struct's name, a case whose
// `expected_error_kind` is "ok" is expected to decode successfully.
struct InvalidCase {
corpus_label: &'static str,
bytes: &'static [u8],
expected_error_kind: &'static str,
expected_decoded_value: Option<u64>,
}
let invalid_cases = [
// Value 1 padded to two bytes: expected to decode to 1.
InvalidCase {
corpus_label: "u64_one_non_shortest_manual",
bytes: &[0x08, 0x81, 0x00],
expected_error_kind: "ok",
expected_decoded_value: Some(1),
},
// Ten continuation bytes before the terminator (11-byte varint).
InvalidCase {
corpus_label: "overlong_varint_11_bytes",
bytes: &[
0x08, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x00,
],
expected_error_kind: "DecodeError",
expected_decoded_value: None,
},
// Input ends while the continuation bit is still set.
InvalidCase {
corpus_label: "truncated_varint",
bytes: &[0x08, 0x80],
expected_error_kind: "DecodeError",
expected_decoded_value: None,
},
// Tenth varint byte is 0x02, which does not fit in the one remaining bit.
InvalidCase {
corpus_label: "continuation_overflow",
bytes: &[
0x08, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x02,
],
expected_error_kind: "DecodeError",
expected_decoded_value: None,
},
InvalidCase {
corpus_label: "arbitrary_malformed_bytes",
bytes: &[0xFF, 0x00, 0xFF],
expected_error_kind: "DecodeError",
expected_decoded_value: None,
},
];
for case in invalid_cases {
// The test-side classification is used only for the log line; the
// assertions below are made against the codec's own result.
let (decoded_value, encoded_length, shortest_classification) =
classify_single_field_varint_wire(case.bytes);
let mut codec: ProstCodec<OptionalU64VarintMessage, OptionalU64VarintMessage> =
ProstCodec::new();
let result = codec.decode(&Bytes::copy_from_slice(case.bytes));
match case.expected_error_kind {
"ok" => {
let decoded = result.expect("non-shortest decode should still roundtrip");
assert_eq!(decoded.value, case.expected_decoded_value);
}
"DecodeError" => {
assert!(matches!(result, Err(ProtobufError::DecodeError(_))));
}
other => panic!("unexpected expected_error_kind {other}"),
}
log_case(
case.corpus_label,
case.bytes.len(),
decoded_value,
encoded_length,
shortest_classification,
case.expected_error_kind,
"pass",
);
}
crate::test_complete!("conformance_protobuf_varint_roundtrip_boundary_matrix");
}
// Conformance matrix for decoding malformed or boundary-case wire bytes as
// `NestedMessage`. Each scenario supplies raw bytes, a codec size cap and the
// expected outcome; one `PROTOBUF_MALFORMED_DECODE` line is written to stderr
// per scenario. An outcome that does not match the expectation panics.
#[test]
fn conformance_protobuf_decode_malformed_boundary_matrix() {
init_test("conformance_protobuf_decode_malformed_boundary_matrix");
// Command line echoed verbatim into every log line.
const EXACT_RCH_COMMAND: &str = "rch exec -- env CARGO_TARGET_DIR=${TMPDIR:-/tmp}/rch_target_asupersync_eo6jp9_protobuf cargo test -p asupersync --lib conformance_protobuf_decode_malformed_boundary_matrix -- --nocapture";
// Outcome a scenario must produce.
enum DecodeExpectation {
DecodeError,
OkZeroLengthEmbedded,
OkAtCap,
}
// Only `wire`, `max_size` and `expectation` drive the check; the remaining
// fields are descriptive metadata copied into the log line.
struct DecodeScenario {
corpus_label: &'static str,
nesting_depth: usize,
declared_length: Option<usize>,
actual_length: usize,
overflow_guard_decision: &'static str,
parser_state: &'static str,
wire: Vec<u8>,
max_size: usize,
expectation: DecodeExpectation,
}
// Key byte 0x10 followed by an 11-byte varint (ten 0xFF bytes, then 0x01).
let mut malformed_varint = vec![0x10];
malformed_varint.extend_from_slice(&[0xFF; 10]);
malformed_varint.push(0x01);
// Embedded message of declared length 2 whose own content announces a
// 1-byte field that is not present.
let truncated_embedded = vec![0x0A, 0x02, 0x0A, 0x01];
// Embedded message with a declared length of zero.
let zero_length_embedded = vec![0x0A, 0x00];
// Embedded message declaring u32::MAX bytes with no payload following.
let mut embedded_length_overflow = vec![0x0A];
encode_test_varint(u64::from(u32::MAX), &mut embedded_length_overflow);
// A valid nested message, later decoded with a cap equal to its exact size.
let at_cap_inner = TestMessage {
name: "cap".to_string(),
value: 7,
};
let at_cap_outer = NestedMessage {
inner: Some(at_cap_inner),
items: Vec::new(),
};
let mut at_cap_codec: ProstCodec<NestedMessage, NestedMessage> = ProstCodec::new();
let at_cap_wire = at_cap_codec.encode(&at_cap_outer).unwrap().to_vec();
let at_cap_max_size = at_cap_wire.len();
// Length prefix of the embedded message, read from just after the key byte.
let (declared_len, _) =
decode_test_varint(&at_cap_wire[1..]).expect("nested message length prefix");
let scenarios = vec![
DecodeScenario {
corpus_label: "malformed_overlong_varint",
nesting_depth: 0,
declared_length: None,
actual_length: malformed_varint.len(),
overflow_guard_decision: "pass-through",
parser_state: "top-level-varint",
wire: malformed_varint,
max_size: 256,
expectation: DecodeExpectation::DecodeError,
},
DecodeScenario {
corpus_label: "unsupported_wire_type",
nesting_depth: 0,
declared_length: None,
actual_length: 1,
overflow_guard_decision: "pass-through",
parser_state: "top-level-key",
wire: vec![0x0F],
max_size: 256,
expectation: DecodeExpectation::DecodeError,
},
DecodeScenario {
corpus_label: "arbitrary_bytes_typed_err",
nesting_depth: 0,
declared_length: None,
actual_length: 3,
overflow_guard_decision: "pass-through",
parser_state: "arbitrary-prefix",
wire: vec![0xFF, 0x00, 0xFF],
max_size: 256,
expectation: DecodeExpectation::DecodeError,
},
DecodeScenario {
corpus_label: "zero_length_embedded",
nesting_depth: 1,
declared_length: Some(0),
actual_length: 0,
overflow_guard_decision: "exact-fit",
parser_state: "embedded-message",
wire: zero_length_embedded,
max_size: 256,
expectation: DecodeExpectation::OkZeroLengthEmbedded,
},
DecodeScenario {
corpus_label: "truncated_embedded_message",
nesting_depth: 1,
declared_length: Some(2),
actual_length: 2,
overflow_guard_decision: "prefix-complete-payload-truncated",
parser_state: "embedded-message",
wire: truncated_embedded,
max_size: 256,
expectation: DecodeExpectation::DecodeError,
},
DecodeScenario {
corpus_label: "embedded_length_overflow",
nesting_depth: 1,
declared_length: Some(u32::MAX as usize),
actual_length: 0,
overflow_guard_decision: "declared>remaining",
parser_state: "embedded-message",
wire: embedded_length_overflow,
max_size: 256,
expectation: DecodeExpectation::DecodeError,
},
DecodeScenario {
corpus_label: "max_bounded_embedded_length",
nesting_depth: 1,
declared_length: Some(declared_len as usize),
actual_length: declared_len as usize,
overflow_guard_decision: "at-cap-accept",
parser_state: "embedded-message",
wire: at_cap_wire,
max_size: at_cap_max_size,
expectation: DecodeExpectation::OkAtCap,
},
];
for scenario in scenarios {
let mut codec: ProstCodec<NestedMessage, NestedMessage> =
ProstCodec::with_max_size(scenario.max_size);
let bytes = Bytes::from(scenario.wire.clone());
let result = codec.decode(&bytes);
// Pair the expectation with the actual result; any combination not listed
// falls through to the panic arm, so e.g. a `MessageTooLarge` error never
// satisfies a `DecodeError` expectation.
let (error_kind, final_verdict) = match (&scenario.expectation, &result) {
(DecodeExpectation::DecodeError, Err(ProtobufError::DecodeError(_))) => {
("DecodeError", "pass")
}
(DecodeExpectation::OkZeroLengthEmbedded, Ok(decoded)) => {
// A zero-length embedded message must decode as present and default.
let inner = decoded.inner.clone().expect("zero-length embedded inner");
assert_eq!(inner, TestMessage::default());
("ok", "pass")
}
(DecodeExpectation::OkAtCap, Ok(decoded)) => {
assert_eq!(decoded, &at_cap_outer);
("ok", "pass")
}
_ => panic!(
"scenario {} produced unexpected result: {:?}",
scenario.corpus_label, result
),
};
eprintln!(
"PROTOBUF_MALFORMED_DECODE corpus_label={} nesting_depth={} declared_length={} actual_length={} overflow_guard_decision={} parser_state={} error_kind={} exact_rch_command=\"{}\" artifact_paths=none final_malformed_protobuf_verdict={}",
scenario.corpus_label,
scenario.nesting_depth,
scenario
.declared_length
.map_or_else(|| "none".to_string(), |len| len.to_string()),
scenario.actual_length,
scenario.overflow_guard_decision,
scenario.parser_state,
error_kind,
EXACT_RCH_COMMAND,
final_verdict,
);
}
crate::test_complete!("conformance_protobuf_decode_malformed_boundary_matrix");
}
// The length prefix written for an embedded message must describe exactly the
// bytes that follow it, and those bytes must decode to the inner message.
#[test]
fn test_prost_codec_nested_length_prefix_consistency() {
    init_test("test_prost_codec_nested_length_prefix_consistency");
    let inner = TestMessage {
        name: "nested".to_string(),
        value: 99,
    };
    let outer = NestedMessage {
        inner: Some(inner.clone()),
        items: Vec::new(),
    };
    let mut codec: ProstCodec<NestedMessage, NestedMessage> = ProstCodec::new();
    let encoded = codec.encode(&outer).unwrap();
    assert_eq!(encoded[0], 0x0A, "expected field 1 nested-message tag");

    let (declared_len, len_len) =
        decode_test_varint(&encoded[1..]).expect("nested length varint");
    let declared_len = usize::try_from(declared_len).expect("nested length fits in usize");
    let payload_start = 1 + len_len;

    // `outer` carries no other fields, so the embedded message must account
    // for every byte after the prefix. Slicing by `declared_len` and comparing
    // the slice length back to it (as the test used to) could never fail.
    assert_eq!(
        encoded.len() - payload_start,
        declared_len,
        "declared nested length must equal the bytes following the prefix"
    );
    assert_eq!(
        declared_len,
        inner.encoded_len(),
        "declared nested length must equal the inner message's encoded length"
    );

    let payload = &encoded[payload_start..payload_start + declared_len];
    let decoded_inner = <TestMessage as prost::Message>::decode(payload).unwrap();
    assert_eq!(decoded_inner, inner);
    crate::test_complete!("test_prost_codec_nested_length_prefix_consistency");
}
// Conformance matrix for encode/decode behaviour at the size cap.
//
// Roundtrips an empty, a small, a nested and a larger message through codecs
// whose cap equals the exact encoded size, checks that an unknown trailing
// field is tolerated, that a message one byte over the cap is rejected on
// encode, and that three malformed inputs fail with `DecodeError`. Each case
// writes one `PROTOBUF_ENCODE_BOUNDARY` line to stderr.
#[test]
fn conformance_prost_codec_roundtrip_boundary_matrix() {
init_test("conformance_prost_codec_roundtrip_boundary_matrix");
// Command line echoed verbatim into every log line.
const EXACT_RCH_COMMAND: &str = "rch exec -- env CARGO_TARGET_DIR=${TMPDIR:-/tmp}/rch_target_asupersync_91ulk2_prost cargo test -p asupersync --lib conformance_prost_codec_roundtrip_boundary_matrix -- --nocapture";
// Short identifier for a byte string: its length plus the first (up to)
// eight bytes in lowercase hex.
fn fingerprint(bytes: &[u8]) -> String {
let prefix = bytes
.iter()
.take(8)
.map(|byte| format!("{byte:02x}"))
.collect::<String>();
format!("len{}:{prefix}", bytes.len())
}
// Writes one structured log line; a `None` declared length prints as "none".
let log_case = |corpus_label: &str,
declared_length: Option<usize>,
actual_length: usize,
message_type: &str,
allocation_guard_decision: &str,
decode_outcome: &str,
error_kind: &str,
roundtrip_fingerprint: &str| {
eprintln!(
"PROTOBUF_ENCODE_BOUNDARY corpus_label={} declared_length={} actual_length={} message_type={} allocation_guard_decision={} decode_outcome={} error_kind={} roundtrip_fingerprint={} exact_rch_command=\"{}\" artifact_paths=none final_no_realloc_panic_verdict=pass",
corpus_label,
declared_length.map_or_else(|| "none".to_string(), |len| len.to_string()),
actual_length,
message_type,
allocation_guard_decision,
decode_outcome,
error_kind,
roundtrip_fingerprint,
EXACT_RCH_COMMAND,
);
};
// Case: default message through a codec whose cap is zero bytes.
let empty = TestMessage::default();
let mut empty_codec: ProstCodec<TestMessage, TestMessage> = ProstCodec::with_max_size(0);
let empty_wire = empty_codec.encode(&empty).unwrap();
let empty_decoded = empty_codec.decode(&empty_wire).unwrap();
assert_eq!(empty_decoded, empty);
log_case(
"empty_roundtrip",
Some(empty.encoded_len()),
empty_wire.len(),
"TestMessage",
"exact-cap-accept",
"roundtrip-ok",
"ok",
&fingerprint(&empty_wire),
);
// Case: small message, cap set to its exact encoded length.
let small = TestMessage {
name: "hello".to_string(),
value: 42,
};
let small_cap = small.encoded_len();
let mut small_codec: ProstCodec<TestMessage, TestMessage> =
ProstCodec::with_max_size(small_cap);
let small_wire = small_codec.encode(&small).unwrap();
let small_decoded = small_codec.decode(&small_wire).unwrap();
assert_eq!(small_decoded, small);
log_case(
"small_roundtrip",
Some(small_cap),
small_wire.len(),
"TestMessage",
"exact-cap-accept",
"roundtrip-ok",
"ok",
&fingerprint(&small_wire),
);
// Case: nested message with repeated strings, cap set to its exact length.
let nested = NestedMessage {
inner: Some(TestMessage {
name: "inner".to_string(),
value: 7,
}),
items: vec!["a".to_string(), "bb".to_string(), "ccc".to_string()],
};
let nested_cap = nested.encoded_len();
let mut nested_codec: ProstCodec<NestedMessage, NestedMessage> =
ProstCodec::with_max_size(nested_cap);
let nested_wire = nested_codec.encode(&nested).unwrap();
let nested_decoded = nested_codec.decode(&nested_wire).unwrap();
assert_eq!(nested_decoded, nested);
log_case(
"nested_repeated_roundtrip",
Some(nested_cap),
nested_wire.len(),
"NestedMessage",
"exact-cap-accept",
"roundtrip-ok",
"ok",
&fingerprint(&nested_wire),
);
// Case: longer string and i32::MAX, cap set to the exact encoded length.
let max_bounded = TestMessage {
name: "max-bounded-message".repeat(8),
value: i32::MAX,
};
let max_bounded_cap = max_bounded.encoded_len();
let mut max_bounded_codec: ProstCodec<TestMessage, TestMessage> =
ProstCodec::with_max_size(max_bounded_cap);
let max_bounded_wire = max_bounded_codec.encode(&max_bounded).unwrap();
let max_bounded_decoded = max_bounded_codec.decode(&max_bounded_wire).unwrap();
assert_eq!(max_bounded_decoded, max_bounded);
log_case(
"max_bounded_roundtrip",
Some(max_bounded_cap),
max_bounded_wire.len(),
"TestMessage",
"exact-cap-accept",
"roundtrip-ok",
"ok",
&fingerprint(&max_bounded_wire),
);
// Case: three extra bytes are appended to a valid encoding; the test
// expects them to be skipped as an unknown field and the known fields to
// decode unchanged.
let unknown_field = TestMessage {
name: "unknown-field".to_string(),
value: 99,
};
let mut unknown_codec: ProstCodec<TestMessage, TestMessage> = ProstCodec::new();
let mut unknown_wire = unknown_codec.encode(&unknown_field).unwrap().to_vec();
unknown_wire.extend_from_slice(&[0x98, 0x06, 0x7B]);
let unknown_decoded = unknown_codec.decode(&Bytes::from(unknown_wire)).unwrap();
assert_eq!(unknown_decoded, unknown_field);
let unknown_reencoded = unknown_codec.encode(&unknown_decoded).unwrap();
log_case(
"unknown_field_tolerant_roundtrip",
Some(unknown_reencoded.len()),
unknown_reencoded.len(),
"TestMessage",
"within-cap-accept",
"roundtrip-ok",
"ok",
&fingerprint(&unknown_reencoded),
);
// Case: cap is one byte below the encoded length, so encode must report
// `MessageTooLarge` with the exact size and limit.
let huge = TestMessage {
name: "x".repeat(4096),
value: 1,
};
let huge_declared = huge.encoded_len();
let huge_cap = huge_declared.saturating_sub(1);
let mut huge_codec: ProstCodec<TestMessage, TestMessage> =
ProstCodec::with_max_size(huge_cap);
let huge_err = huge_codec.encode(&huge).unwrap_err();
assert!(matches!(
huge_err,
ProtobufError::MessageTooLarge {
size,
limit
} if size == huge_declared && limit == huge_cap
));
log_case(
"huge_message_rejected_before_allocation",
Some(huge_declared),
huge_declared,
"TestMessage",
"reject-before-alloc",
"encode-rejected",
"MessageTooLarge",
"none",
);
// Case: string field whose length prefix is the 5-byte varint FF FF FF FF 0F
// with no payload bytes after it.
let malformed_length_prefix = vec![0x0A, 0xFF, 0xFF, 0xFF, 0xFF, 0x0F];
let (declared_len, len_len) =
decode_test_varint(&malformed_length_prefix[1..]).expect("declared length");
// Bytes actually present after the key byte and the length prefix.
let actual_len = malformed_length_prefix.len().saturating_sub(1 + len_len);
let mut malformed_codec: ProstCodec<TestMessage, TestMessage> = ProstCodec::new();
let malformed_err = malformed_codec
.decode(&Bytes::from(malformed_length_prefix))
.unwrap_err();
assert!(matches!(malformed_err, ProtobufError::DecodeError(_)));
log_case(
"malformed_length_prefix",
Some(declared_len as usize),
actual_len,
"TestMessage",
"declared>remaining",
"decode-err",
"DecodeError",
"none",
);
// Case: string field declaring 3 bytes while only 2 follow.
let truncated_payload = vec![0x0A, 0x03, b'a', b'b'];
let mut truncated_codec: ProstCodec<TestMessage, TestMessage> = ProstCodec::new();
let truncated_err = truncated_codec
.decode(&Bytes::from(truncated_payload.clone()))
.unwrap_err();
assert!(matches!(truncated_err, ProtobufError::DecodeError(_)));
log_case(
"truncated_payload",
Some(3),
truncated_payload.len() - 2,
"TestMessage",
"declared>remaining",
"decode-err",
"DecodeError",
"none",
);
// Case: bytes with no valid structure.
let arbitrary = vec![0xFF, 0x00, 0xFF];
let mut arbitrary_codec: ProstCodec<TestMessage, TestMessage> = ProstCodec::new();
let arbitrary_err = arbitrary_codec
.decode(&Bytes::from(arbitrary.clone()))
.unwrap_err();
assert!(matches!(arbitrary_err, ProtobufError::DecodeError(_)));
log_case(
"arbitrary_bytes_typed_err",
None,
arbitrary.len(),
"TestMessage",
"pass-through",
"decode-err",
"DecodeError",
"none",
);
crate::test_complete!("conformance_prost_codec_roundtrip_boundary_matrix");
}
// Extra bytes for a field the message does not declare must not disturb
// decoding of the known fields.
#[test]
fn test_prost_codec_unknown_fields() {
    init_test("test_prost_codec_unknown_fields");

    // Three bytes appended after the valid encoding.
    const EXTRA_FIELD: [u8; 3] = [0x98, 0x06, 0x7B];

    let mut codec = ProstCodec::<TestMessage, TestMessage>::new();
    let source = TestMessage {
        value: 99,
        name: String::from("test"),
    };
    let mut wire = codec.encode(&source).unwrap().to_vec();
    wire.extend_from_slice(&EXTRA_FIELD);

    let received = codec.decode(&Bytes::from(wire)).unwrap();
    let ok = (received.name.as_str(), received.value) == ("test", 99);
    crate::assert_with_log!(ok, "unknown field ignored", true, ok);
    crate::test_complete!("test_prost_codec_unknown_fields");
}
// Encoding the same message repeatedly must yield identical bytes each time.
#[test]
fn test_prost_codec_deterministic_encoding() {
    init_test("test_prost_codec_deterministic_encoding");

    let message = TestMessage {
        value: 123,
        name: String::from("deterministic"),
    };
    let mut codec = ProstCodec::<TestMessage, TestMessage>::new();

    // Three independent encodings of the same input.
    let runs: Vec<Bytes> = (0..3).map(|_| codec.encode(&message).unwrap()).collect();

    let first_matches_second = runs[0] == runs[1];
    crate::assert_with_log!(
        first_matches_second,
        "encoding 1 == 2",
        true,
        first_matches_second
    );
    let second_matches_third = runs[1] == runs[2];
    crate::assert_with_log!(
        second_matches_third,
        "encoding 2 == 3",
        true,
        second_matches_third
    );
    crate::test_complete!("test_prost_codec_deterministic_encoding");
}
// `max_message_size` must report the default for `new` and the supplied value
// for `with_max_size`.
#[test]
fn test_prost_codec_max_size_accessors() {
    init_test("test_prost_codec_max_size_accessors");

    {
        let expected = DEFAULT_MAX_MESSAGE_SIZE;
        let actual = ProstCodec::<TestMessage, TestMessage>::new().max_message_size();
        crate::assert_with_log!(actual == expected, "default max size", expected, actual);
    }
    {
        let expected = 1024;
        let actual = ProstCodec::<TestMessage, TestMessage>::with_max_size(1024).max_message_size();
        crate::assert_with_log!(actual == expected, "custom max size", expected, actual);
    }

    crate::test_complete!("test_prost_codec_max_size_accessors");
}
// A cloned codec must carry the same size limit as its source.
#[test]
fn test_prost_codec_clone() {
    init_test("test_prost_codec_clone");

    let source = ProstCodec::<TestMessage, TestMessage>::with_max_size(2048);
    let duplicate = source.clone();

    let expected = source.max_message_size();
    let actual = duplicate.max_message_size();
    crate::assert_with_log!(actual == expected, "clone preserves max size", expected, actual);
    crate::test_complete!("test_prost_codec_clone");
}
// The `SymmetricProstCodec` alias must behave like a two-parameter codec with
// the same type on both sides.
#[test]
fn test_symmetric_codec_alias() {
    init_test("test_symmetric_codec_alias");

    let sent = TestMessage {
        value: 777,
        name: String::from("symmetric"),
    };
    let mut codec = SymmetricProstCodec::<TestMessage>::new();

    let wire = codec.encode(&sent).unwrap();
    let received = codec.decode(&wire).unwrap();

    let same = received == sent;
    crate::assert_with_log!(same, "symmetric roundtrip", true, same);
    crate::test_complete!("test_symmetric_codec_alias");
}
}