use std::marker::PhantomData;
use crate::bytes::Bytes;
use super::codec::Codec;
pub use super::DEFAULT_MAX_MESSAGE_SIZE;
/// Error type returned by [`ProstCodec`]'s `encode` and `decode`.
#[derive(Debug, thiserror::Error)]
pub enum ProtobufError {
/// prost reported a failure while serializing the outgoing message.
#[error("failed to encode protobuf message: {0}")]
EncodeError(#[from] prost::EncodeError),
/// The incoming bytes could not be decoded as the target message type
/// (malformed varints, wire-type conflicts, invalid UTF-8 in strings, ...).
#[error("failed to decode protobuf message: {0}")]
DecodeError(#[from] prost::DecodeError),
/// A message was larger than the codec's configured byte limit; checked on
/// both the encode and the decode path before any protobuf work is done.
#[error("message size {size} exceeds limit {limit}")]
MessageTooLarge {
/// Size of the offending message in bytes (encoded length or input length).
size: usize,
/// The codec's `max_message_size` at the time of the check.
limit: usize,
},
}
/// Protobuf [`Codec`] built on prost that encodes `T` and decodes `U`.
///
/// Every message passing through the codec, in either direction, is checked
/// against a per-message byte limit (`max_message_size`).
///
/// `T` and `U` are type-level markers only: the codec never stores a value of
/// either type.
pub struct ProstCodec<T, U> {
    /// Largest encoded message size, in bytes, that `encode`/`decode` accept.
    max_message_size: usize,
    /// `fn() -> (T, U)` instead of `(T, U)`: the codec does not own any `T` or `U`,
    /// so its `Send`/`Sync`/`Unpin` and drop-check behaviour must not depend on
    /// them. Both forms are covariant in `T` and `U`.
    _marker: PhantomData<fn() -> (T, U)>,
}

// Hand-written (like `Clone`) so that formatting the codec does not require
// `T: Debug` / `U: Debug`; a derive would add those bounds.
impl<T, U> std::fmt::Debug for ProstCodec<T, U> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProstCodec")
            .field("max_message_size", &self.max_message_size)
            .finish_non_exhaustive()
    }
}
impl<T, U> ProstCodec<T, U> {
    /// Creates a codec whose size limit is [`DEFAULT_MAX_MESSAGE_SIZE`].
    #[must_use]
    pub fn new() -> Self {
        Self::with_max_size(DEFAULT_MAX_MESSAGE_SIZE)
    }

    /// Creates a codec that rejects any message larger than `max_size` bytes.
    #[must_use]
    pub fn with_max_size(max_size: usize) -> Self {
        Self {
            _marker: PhantomData,
            max_message_size: max_size,
        }
    }

    /// Returns the per-message size limit, in bytes, enforced by this codec.
    #[must_use]
    pub fn max_message_size(&self) -> usize {
        self.max_message_size
    }
}
impl<T, U> Default for ProstCodec<T, U> {
fn default() -> Self {
Self::new()
}
}
impl<T, U> Clone for ProstCodec<T, U> {
fn clone(&self) -> Self {
Self {
max_message_size: self.max_message_size,
_marker: PhantomData,
}
}
}
impl<T, U> Codec for ProstCodec<T, U>
where
    T: prost::Message + Send + 'static,
    U: prost::Message + Default + Send + 'static,
{
    type Encode = T;
    type Decode = U;
    type Error = ProtobufError;

    /// Serializes `item` into a freshly allocated buffer.
    ///
    /// The size limit is checked against `encoded_len()` before anything is
    /// allocated or written.
    ///
    /// # Errors
    ///
    /// - [`ProtobufError::MessageTooLarge`] if the encoded length exceeds `max_message_size`.
    /// - [`ProtobufError::EncodeError`] propagated from `prost::Message::encode`.
    fn encode(&mut self, item: &Self::Encode) -> Result<Bytes, Self::Error> {
        let limit = self.max_message_size;
        let size = item.encoded_len();
        if size <= limit {
            // Buffer is sized from `encoded_len` up front.
            let mut out = Vec::with_capacity(size);
            item.encode(&mut out)?;
            Ok(Bytes::from(out))
        } else {
            Err(ProtobufError::MessageTooLarge { size, limit })
        }
    }

    /// Parses `buf` as a `U`, after checking its length against the size limit.
    ///
    /// # Errors
    ///
    /// - [`ProtobufError::MessageTooLarge`] if `buf` is longer than `max_message_size`.
    /// - [`ProtobufError::DecodeError`] propagated from `prost::Message::decode`.
    fn decode(&mut self, buf: &Bytes) -> Result<Self::Decode, Self::Error> {
        let limit = self.max_message_size;
        let size = buf.len();
        if size > limit {
            return Err(ProtobufError::MessageTooLarge { size, limit });
        }
        U::decode(buf.as_ref()).map_err(ProtobufError::from)
    }
}
/// A [`ProstCodec`] that encodes and decodes the same message type `T`.
pub type SymmetricProstCodec<T> = ProstCodec<T, T>;
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-test setup: initializes logging via `crate::test_utils` and records
    /// the test phase under `name`.
    fn init_test(name: &str) {
        crate::test_utils::init_test_logging();
        crate::test_phase!(name);
    }

    #[derive(Clone, PartialEq, prost::Message)]
    pub struct TestMessage {
        #[prost(string, tag = "1")]
        pub name: String,
        #[prost(int32, tag = "2")]
        pub value: i32,
    }

    #[derive(Clone, PartialEq, prost::Message)]
    pub struct NestedMessage {
        #[prost(message, optional, tag = "1")]
        pub inner: Option<TestMessage>,
        #[prost(repeated, string, tag = "2")]
        pub items: Vec<String>,
    }

    #[derive(Clone, PartialEq, prost::Message)]
    pub struct AllTypesMessage {
        #[prost(double, tag = "1")]
        pub double_field: f64,
        #[prost(float, tag = "2")]
        pub float_field: f32,
        #[prost(int32, tag = "3")]
        pub int32_field: i32,
        #[prost(int64, tag = "4")]
        pub int64_field: i64,
        #[prost(uint32, tag = "5")]
        pub uint32_field: u32,
        #[prost(uint64, tag = "6")]
        pub uint64_field: u64,
        #[prost(sint32, tag = "7")]
        pub sint32_field: i32,
        #[prost(sint64, tag = "8")]
        pub sint64_field: i64,
        #[prost(fixed32, tag = "9")]
        pub fixed32_field: u32,
        #[prost(fixed64, tag = "10")]
        pub fixed64_field: u64,
        #[prost(sfixed32, tag = "11")]
        pub sfixed32_field: i32,
        #[prost(sfixed64, tag = "12")]
        pub sfixed64_field: i64,
        #[prost(bool, tag = "13")]
        pub bool_field: bool,
        #[prost(string, tag = "14")]
        pub string_field: String,
        #[prost(bytes = "vec", tag = "15")]
        pub bytes_field: Vec<u8>,
    }

    /// A "newer schema" of `TestMessage`: tags 1 and 2 are identical, while
    /// tags 3-7 are additions an older reader does not know about. Together
    /// they cover the length-delimited, 64-bit, varint and 32-bit wire types.
    #[derive(Clone, PartialEq, prost::Message)]
    pub struct ExtendedMessage {
        #[prost(string, tag = "1")]
        pub name: String,
        #[prost(int32, tag = "2")]
        pub value: i32,
        #[prost(string, tag = "3")]
        pub note: String,
        #[prost(fixed64, tag = "4")]
        pub checksum: u64,
        #[prost(message, optional, tag = "5")]
        pub nested: Option<TestMessage>,
        #[prost(sint64, tag = "6")]
        pub delta: i64,
        #[prost(float, tag = "7")]
        pub ratio: f32,
    }

    #[test]
    fn test_prost_codec_roundtrip() {
        init_test("test_prost_codec_roundtrip");
        let mut codec: ProstCodec<TestMessage, TestMessage> = ProstCodec::new();
        let original = TestMessage {
            name: "hello".to_string(),
            value: 42,
        };
        let encoded = codec.encode(&original).unwrap();
        let decoded = codec.decode(&encoded).unwrap();
        crate::assert_with_log!(
            decoded == original,
            "roundtrip",
            original.name,
            decoded.name
        );
        crate::test_complete!("test_prost_codec_roundtrip");
    }

    #[test]
    fn test_prost_codec_nested_message() {
        init_test("test_prost_codec_nested_message");
        let mut codec: ProstCodec<NestedMessage, NestedMessage> = ProstCodec::new();
        let original = NestedMessage {
            inner: Some(TestMessage {
                name: "inner".to_string(),
                value: 100,
            }),
            items: vec!["a".to_string(), "b".to_string(), "c".to_string()],
        };
        let encoded = codec.encode(&original).unwrap();
        let decoded = codec.decode(&encoded).unwrap();
        crate::assert_with_log!(decoded == original, "nested", true, decoded == original);
        crate::test_complete!("test_prost_codec_nested_message");
    }

    #[test]
    fn test_prost_codec_all_wire_types() {
        init_test("test_prost_codec_all_wire_types");
        let mut codec: ProstCodec<AllTypesMessage, AllTypesMessage> = ProstCodec::new();
        let original = AllTypesMessage {
            double_field: 1.234,
            float_field: 5.678,
            int32_field: -100,
            int64_field: -200,
            uint32_field: 300,
            uint64_field: 400,
            sint32_field: -500,
            sint64_field: -600,
            fixed32_field: 700,
            fixed64_field: 800,
            sfixed32_field: -900,
            sfixed64_field: -1000,
            bool_field: true,
            string_field: "test string".to_string(),
            bytes_field: vec![0x01, 0x02, 0x03, 0x04],
        };
        let encoded = codec.encode(&original).unwrap();
        let decoded = codec.decode(&encoded).unwrap();
        crate::assert_with_log!(decoded == original, "wire types", true, decoded == original);
        crate::test_complete!("test_prost_codec_all_wire_types");
    }

    #[test]
    fn test_prost_codec_empty_message() {
        init_test("test_prost_codec_empty_message");
        let mut codec: ProstCodec<TestMessage, TestMessage> = ProstCodec::new();
        let original = TestMessage::default();
        let encoded = codec.encode(&original).unwrap();
        let empty = encoded.is_empty();
        crate::assert_with_log!(empty, "empty message encodes to empty bytes", true, empty);
        let decoded = codec.decode(&encoded).unwrap();
        crate::assert_with_log!(
            decoded == original,
            "empty roundtrip",
            true,
            decoded == original
        );
        crate::test_complete!("test_prost_codec_empty_message");
    }

    #[test]
    fn test_prost_codec_message_too_large_encode() {
        init_test("test_prost_codec_message_too_large_encode");
        let mut codec: ProstCodec<TestMessage, TestMessage> = ProstCodec::with_max_size(10);
        let large_message = TestMessage {
            name: "this is a very long string that exceeds the limit".to_string(),
            value: 42,
        };
        let result = codec.encode(&large_message);
        let is_err = matches!(result, Err(ProtobufError::MessageTooLarge { .. }));
        crate::assert_with_log!(is_err, "encode fails for large message", true, is_err);
        crate::test_complete!("test_prost_codec_message_too_large_encode");
    }

    #[test]
    fn test_prost_codec_message_too_large_decode() {
        init_test("test_prost_codec_message_too_large_decode");
        let mut large_codec: ProstCodec<TestMessage, TestMessage> = ProstCodec::new();
        let message = TestMessage {
            name: "this is a long string".to_string(),
            value: 42,
        };
        let encoded = large_codec.encode(&message).unwrap();
        let mut small_codec: ProstCodec<TestMessage, TestMessage> = ProstCodec::with_max_size(5);
        let result = small_codec.decode(&encoded);
        let is_err = matches!(result, Err(ProtobufError::MessageTooLarge { .. }));
        crate::assert_with_log!(is_err, "decode fails for large message", true, is_err);
        crate::test_complete!("test_prost_codec_message_too_large_decode");
    }

    #[test]
    fn test_prost_codec_invalid_data() {
        init_test("test_prost_codec_invalid_data");
        let mut codec: ProstCodec<TestMessage, TestMessage> = ProstCodec::new();
        let invalid_data = Bytes::from_static(&[
            0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01,
        ]);
        let result = codec.decode(&invalid_data);
        let is_err = matches!(result, Err(ProtobufError::DecodeError(_)));
        crate::assert_with_log!(is_err, "decode fails for invalid data", true, is_err);
        crate::test_complete!("test_prost_codec_invalid_data");
    }

    /// A reader on an older schema must skip tags it does not know and still
    /// recover the fields it does know.
    #[test]
    fn test_prost_codec_unknown_fields() {
        init_test("test_prost_codec_unknown_fields");
        let extended = ExtendedMessage {
            name: "test".to_string(),
            value: 99,
            note: "added in v2".to_string(),
            checksum: 0xDEAD_BEEF,
            nested: Some(TestMessage {
                name: "nested".to_string(),
                value: 7,
            }),
            delta: -42,
            ratio: 0.5,
        };
        // Writer uses the newer schema and reader the older one: tags 3-7 are
        // unknown to `TestMessage`, while tags 1-2 have the same types in both.
        let mut codec: ProstCodec<ExtendedMessage, TestMessage> = ProstCodec::new();
        let encoded = codec.encode(&extended).unwrap();
        let result = codec.decode(&encoded);
        let ok = result.is_ok();
        crate::assert_with_log!(ok, "unknown fields ignored", true, ok);
        let decoded = result.unwrap();
        let preserved = decoded.name == extended.name && decoded.value == extended.value;
        crate::assert_with_log!(preserved, "known fields preserved", true, preserved);
        crate::test_complete!("test_prost_codec_unknown_fields");
    }

    /// A tag the reader *does* know, arriving with a different wire type, is a
    /// decode error rather than an ignored field. Tag 2 is a length-delimited
    /// repeated string in `NestedMessage` but a varint `int32` in `TestMessage`.
    #[test]
    fn test_prost_codec_conflicting_wire_type() {
        init_test("test_prost_codec_conflicting_wire_type");
        let mut codec: ProstCodec<NestedMessage, TestMessage> = ProstCodec::new();
        let nested = NestedMessage {
            inner: None,
            items: vec!["x".to_string()],
        };
        let encoded = codec.encode(&nested).unwrap();
        let result = codec.decode(&encoded);
        let is_err = matches!(result, Err(ProtobufError::DecodeError(_)));
        crate::assert_with_log!(is_err, "wire type conflict rejected", true, is_err);
        crate::test_complete!("test_prost_codec_conflicting_wire_type");
    }

    #[test]
    fn test_prost_codec_deterministic_encoding() {
        init_test("test_prost_codec_deterministic_encoding");
        let mut codec: ProstCodec<TestMessage, TestMessage> = ProstCodec::new();
        let message = TestMessage {
            name: "deterministic".to_string(),
            value: 123,
        };
        let encoded1 = codec.encode(&message).unwrap();
        let encoded2 = codec.encode(&message).unwrap();
        let encoded3 = codec.encode(&message).unwrap();
        crate::assert_with_log!(
            encoded1 == encoded2,
            "encoding 1 == 2",
            true,
            encoded1 == encoded2
        );
        crate::assert_with_log!(
            encoded2 == encoded3,
            "encoding 2 == 3",
            true,
            encoded2 == encoded3
        );
        crate::test_complete!("test_prost_codec_deterministic_encoding");
    }

    #[test]
    fn test_prost_codec_max_size_accessors() {
        init_test("test_prost_codec_max_size_accessors");
        let default_codec: ProstCodec<TestMessage, TestMessage> = ProstCodec::new();
        let expected = DEFAULT_MAX_MESSAGE_SIZE;
        let actual = default_codec.max_message_size();
        crate::assert_with_log!(actual == expected, "default max size", expected, actual);
        let custom_codec: ProstCodec<TestMessage, TestMessage> = ProstCodec::with_max_size(1024);
        let expected = 1024;
        let actual = custom_codec.max_message_size();
        crate::assert_with_log!(actual == expected, "custom max size", expected, actual);
        crate::test_complete!("test_prost_codec_max_size_accessors");
    }

    #[test]
    fn test_prost_codec_clone() {
        init_test("test_prost_codec_clone");
        let codec: ProstCodec<TestMessage, TestMessage> = ProstCodec::with_max_size(2048);
        let cloned = codec.clone();
        let expected = codec.max_message_size();
        let actual = cloned.max_message_size();
        crate::assert_with_log!(
            actual == expected,
            "clone preserves max size",
            expected,
            actual
        );
        crate::test_complete!("test_prost_codec_clone");
    }

    #[test]
    fn test_symmetric_codec_alias() {
        init_test("test_symmetric_codec_alias");
        let mut codec: SymmetricProstCodec<TestMessage> = SymmetricProstCodec::new();
        let message = TestMessage {
            name: "symmetric".to_string(),
            value: 777,
        };
        let encoded = codec.encode(&message).unwrap();
        let decoded = codec.decode(&encoded).unwrap();
        crate::assert_with_log!(
            decoded == message,
            "symmetric roundtrip",
            true,
            decoded == message
        );
        crate::test_complete!("test_symmetric_codec_alias");
    }
}