extern crate alloc;
use alloc::{string::String, vec::Vec};
use ::bytes::Bytes;
use linkedbytes::LinkedBytes;
use super::{
DecodeError, Message,
encoding::{
DecodeContext, EncodeLengthContext, WireType, bool, bytes, double, float, int32, int64,
skip_field, string, uint32, uint64,
},
};
use crate::pb::ZERO_COPY_THRESHOLD;
/// Treats a bare `bool` as a message whose single field, tag 1, carries the
/// value as a varint (presumably the `google.protobuf.BoolValue` wrapper —
/// unconfirmed from this file).
///
/// `false` is the default value and is never written, so it encodes to zero
/// bytes.
impl Message for bool {
    fn encode_raw(&self, buf: &mut LinkedBytes) {
        if !*self {
            // Default value: omit the field entirely.
            return;
        }
        bool::encode(1, self, buf);
    }

    fn merge_field(
        &mut self,
        tag: u32,
        wire_type: WireType,
        buf: &mut Bytes,
        ctx: &mut DecodeContext,
        _is_root: bool,
    ) -> Result<(), DecodeError> {
        match tag {
            1 => bool::merge(wire_type, self, buf, ctx),
            // Any other tag is unknown to this wrapper; consume and discard it.
            _ => skip_field(wire_type, tag, buf, ctx),
        }
    }

    fn encoded_len(&self, _ctx: &mut EncodeLengthContext) -> usize {
        // One key byte (field 1, varint) plus the one-byte varint `1`.
        const ENCODED_TRUE_LEN: usize = 2;
        if *self { ENCODED_TRUE_LEN } else { 0 }
    }
}
/// Treats a bare `u32` as a message whose single field, tag 1, carries the
/// value as a varint (presumably the `google.protobuf.UInt32Value` wrapper —
/// unconfirmed from this file).
///
/// Zero is the default value and is never written, so it encodes to zero
/// bytes.
impl Message for u32 {
    fn encode_raw(&self, buf: &mut LinkedBytes) {
        if *self == 0 {
            // Default value: omit the field entirely.
            return;
        }
        uint32::encode(1, self, buf);
    }

    fn merge_field(
        &mut self,
        tag: u32,
        wire_type: WireType,
        buf: &mut Bytes,
        ctx: &mut DecodeContext,
        _is_root: bool,
    ) -> Result<(), DecodeError> {
        match tag {
            1 => uint32::merge(wire_type, self, buf, ctx),
            // Unknown tags are consumed and dropped.
            _ => skip_field(wire_type, tag, buf, ctx),
        }
    }

    fn encoded_len(&self, ctx: &mut EncodeLengthContext) -> usize {
        match *self {
            0 => 0,
            _ => uint32::encoded_len(ctx, 1, self),
        }
    }
}
/// Treats a bare `u64` as a message whose single field, tag 1, carries the
/// value as a varint (presumably the `google.protobuf.UInt64Value` wrapper —
/// unconfirmed from this file).
///
/// Zero is the default value and is never written, so it encodes to zero
/// bytes.
impl Message for u64 {
    fn encode_raw(&self, buf: &mut LinkedBytes) {
        if *self == 0 {
            // Default value: omit the field entirely.
            return;
        }
        uint64::encode(1, self, buf);
    }

    fn merge_field(
        &mut self,
        tag: u32,
        wire_type: WireType,
        buf: &mut Bytes,
        ctx: &mut DecodeContext,
        _is_root: bool,
    ) -> Result<(), DecodeError> {
        match tag {
            1 => uint64::merge(wire_type, self, buf, ctx),
            // Unknown tags are consumed and dropped.
            _ => skip_field(wire_type, tag, buf, ctx),
        }
    }

    fn encoded_len(&self, ctx: &mut EncodeLengthContext) -> usize {
        match *self {
            0 => 0,
            _ => uint64::encoded_len(ctx, 1, self),
        }
    }
}
/// Treats a bare `i32` as a message whose single field, tag 1, carries the
/// value as an `int32` varint (presumably the `google.protobuf.Int32Value`
/// wrapper — unconfirmed from this file).
///
/// Zero is the default value and is never written, so it encodes to zero
/// bytes.
impl Message for i32 {
    fn encode_raw(&self, buf: &mut LinkedBytes) {
        if *self == 0 {
            // Default value: omit the field entirely.
            return;
        }
        int32::encode(1, self, buf);
    }

    fn merge_field(
        &mut self,
        tag: u32,
        wire_type: WireType,
        buf: &mut Bytes,
        ctx: &mut DecodeContext,
        _is_root: bool,
    ) -> Result<(), DecodeError> {
        match tag {
            1 => int32::merge(wire_type, self, buf, ctx),
            // Unknown tags are consumed and dropped.
            _ => skip_field(wire_type, tag, buf, ctx),
        }
    }

    fn encoded_len(&self, ctx: &mut EncodeLengthContext) -> usize {
        if *self == 0 {
            return 0;
        }
        int32::encoded_len(ctx, 1, self)
    }
}
/// Treats a bare `i64` as a message whose single field, tag 1, carries the
/// value as an `int64` varint (presumably the `google.protobuf.Int64Value`
/// wrapper — unconfirmed from this file).
///
/// Zero is the default value and is never written, so it encodes to zero
/// bytes.
impl Message for i64 {
    fn encode_raw(&self, buf: &mut LinkedBytes) {
        if *self == 0 {
            // Default value: omit the field entirely.
            return;
        }
        int64::encode(1, self, buf);
    }

    fn merge_field(
        &mut self,
        tag: u32,
        wire_type: WireType,
        buf: &mut Bytes,
        ctx: &mut DecodeContext,
        _is_root: bool,
    ) -> Result<(), DecodeError> {
        match tag {
            1 => int64::merge(wire_type, self, buf, ctx),
            // Unknown tags are consumed and dropped.
            _ => skip_field(wire_type, tag, buf, ctx),
        }
    }

    fn encoded_len(&self, ctx: &mut EncodeLengthContext) -> usize {
        if *self == 0 {
            return 0;
        }
        int64::encoded_len(ctx, 1, self)
    }
}
/// Treats a bare `f32` as a message whose single field, tag 1, carries the
/// value as a `float` (presumably the `google.protobuf.FloatValue` wrapper —
/// unconfirmed from this file).
///
/// The field is omitted whenever the value compares equal to `0.0`. Because
/// `-0.0 == 0.0`, negative zero is omitted too and decodes back as `+0.0`.
/// NaN never compares equal to `0.0`, so it is always written.
impl Message for f32 {
    fn encode_raw(&self, buf: &mut LinkedBytes) {
        if *self == 0.0 {
            // Default value (either zero): omit the field entirely.
            return;
        }
        float::encode(1, self, buf);
    }

    fn merge_field(
        &mut self,
        tag: u32,
        wire_type: WireType,
        buf: &mut Bytes,
        ctx: &mut DecodeContext,
        _is_root: bool,
    ) -> Result<(), DecodeError> {
        match tag {
            1 => float::merge(wire_type, self, buf, ctx),
            // Unknown tags are consumed and dropped.
            _ => skip_field(wire_type, tag, buf, ctx),
        }
    }

    fn encoded_len(&self, ctx: &mut EncodeLengthContext) -> usize {
        if *self == 0.0 {
            0
        } else {
            float::encoded_len(ctx, 1, self)
        }
    }
}
/// Treats a bare `f64` as a message whose single field, tag 1, carries the
/// value as a `double` (presumably the `google.protobuf.DoubleValue` wrapper —
/// unconfirmed from this file).
///
/// The field is omitted whenever the value compares equal to `0.0`. Because
/// `-0.0 == 0.0`, negative zero is omitted too and decodes back as `+0.0`.
/// NaN never compares equal to `0.0`, so it is always written.
impl Message for f64 {
    fn encode_raw(&self, buf: &mut LinkedBytes) {
        if *self == 0.0 {
            // Default value (either zero): omit the field entirely.
            return;
        }
        double::encode(1, self, buf);
    }

    fn merge_field(
        &mut self,
        tag: u32,
        wire_type: WireType,
        buf: &mut Bytes,
        ctx: &mut DecodeContext,
        _is_root: bool,
    ) -> Result<(), DecodeError> {
        match tag {
            1 => double::merge(wire_type, self, buf, ctx),
            // Unknown tags are consumed and dropped.
            _ => skip_field(wire_type, tag, buf, ctx),
        }
    }

    fn encoded_len(&self, ctx: &mut EncodeLengthContext) -> usize {
        if *self == 0.0 {
            0
        } else {
            double::encoded_len(ctx, 1, self)
        }
    }
}
/// Treats a bare `String` as a message whose single field, tag 1, carries the
/// text as a length-delimited `string` (presumably the
/// `google.protobuf.StringValue` wrapper — unconfirmed from this file).
///
/// The empty string is the default value and is never written, so it encodes
/// to zero bytes.
impl Message for String {
    fn encode_raw(&self, buf: &mut LinkedBytes) {
        if self.is_empty() {
            // Default value: omit the field entirely.
            return;
        }
        string::encode(1, self, buf);
    }

    fn merge_field(
        &mut self,
        tag: u32,
        wire_type: WireType,
        buf: &mut Bytes,
        ctx: &mut DecodeContext,
        _is_root: bool,
    ) -> Result<(), DecodeError> {
        match tag {
            1 => string::merge(wire_type, self, buf, ctx),
            // Unknown tags are consumed and dropped.
            _ => skip_field(wire_type, tag, buf, ctx),
        }
    }

    fn encoded_len(&self, ctx: &mut EncodeLengthContext) -> usize {
        if self.is_empty() {
            return 0;
        }
        string::encoded_len(ctx, 1, self)
    }
}
/// Treats an owned byte vector as a message whose single field, tag 1, carries
/// the payload as length-delimited `bytes` (presumably the
/// `google.protobuf.BytesValue` wrapper — unconfirmed from this file).
///
/// An empty vector is the default value and is never written. Unlike the
/// `Bytes` impl, this one does not update `zero_copy_len`.
impl Message for Vec<u8> {
    fn encode_raw(&self, buf: &mut LinkedBytes) {
        if self.is_empty() {
            // Default value: omit the field entirely.
            return;
        }
        bytes::encode(1, self, buf);
    }

    fn merge_field(
        &mut self,
        tag: u32,
        wire_type: WireType,
        buf: &mut Bytes,
        ctx: &mut DecodeContext,
        _is_root: bool,
    ) -> Result<(), DecodeError> {
        match tag {
            1 => bytes::merge(wire_type, self, buf, ctx),
            // Unknown tags are consumed and dropped.
            _ => skip_field(wire_type, tag, buf, ctx),
        }
    }

    fn encoded_len(&self, ctx: &mut EncodeLengthContext) -> usize {
        if self.is_empty() {
            return 0;
        }
        bytes::encoded_len(ctx, 1, self)
    }
}
/// Treats a `Bytes` buffer as a message whose single field, tag 1, carries the
/// payload as length-delimited `bytes` (presumably the
/// `google.protobuf.BytesValue` wrapper — unconfirmed from this file).
///
/// An empty buffer is the default value and is never written.
impl Message for Bytes {
    fn encode_raw(&self, buf: &mut LinkedBytes) {
        if self.is_empty() {
            // Default value: omit the field entirely.
            return;
        }
        bytes::encode(1, self, buf);
    }

    fn merge_field(
        &mut self,
        tag: u32,
        wire_type: WireType,
        buf: &mut Bytes,
        ctx: &mut DecodeContext,
        _is_root: bool,
    ) -> Result<(), DecodeError> {
        match tag {
            1 => bytes::merge(wire_type, self, buf, ctx),
            // Unknown tags are consumed and dropped.
            _ => skip_field(wire_type, tag, buf, ctx),
        }
    }

    /// Returns the encoded size of the field.
    ///
    /// As a side effect, when the payload is at least `ZERO_COPY_THRESHOLD`
    /// bytes, the payload length (key and length prefix excluded) is added to
    /// `ctx.zero_copy_len`. Presumably the encoder links such buffers rather
    /// than copying them; confirm against the encoder.
    fn encoded_len(&self, ctx: &mut EncodeLengthContext) -> usize {
        let payload_len = self.len();
        if payload_len == 0 {
            return 0;
        }
        if payload_len >= ZERO_COPY_THRESHOLD {
            ctx.zero_copy_len += payload_len;
        }
        bytes::encoded_len(ctx, 1, self)
    }
}
/// Treats `()` as a message with no fields (possibly the
/// `google.protobuf.Empty` type — unconfirmed from this file).
///
/// It always encodes to zero bytes, and every field handed to it during
/// decoding is skipped.
impl Message for () {
    fn encode_raw(&self, _buf: &mut LinkedBytes) {
        // There are no fields, so nothing is ever written.
    }

    fn merge_field(
        &mut self,
        tag: u32,
        wire_type: WireType,
        buf: &mut Bytes,
        ctx: &mut DecodeContext,
        _is_root: bool,
    ) -> Result<(), DecodeError> {
        // Every tag is unknown here. Consume its bytes so decoding can
        // continue past it.
        skip_field(wire_type, tag, buf, ctx)
    }

    fn encoded_len(&self, _ctx: &mut EncodeLengthContext) -> usize {
        // Nothing is written, so the size is always zero.
        0
    }
}
#[cfg(test)]
mod tests {
    use ::bytes::Bytes;
    use linkedbytes::LinkedBytes;

    use super::*;
    use crate::pb::encoding::{DecodeContext, EncodeLengthContext, WireType};

    // Protobuf key bytes for field number 1: `(1 << 3) | wire_type`.
    /// Key for field 1 with the varint wire type (0).
    const KEY_VARINT: u8 = 0x08;
    /// Key for field 1 with the 64-bit fixed wire type (1).
    const KEY_FIXED64: u8 = 0x09;
    /// Key for field 1 with the length-delimited wire type (2).
    const KEY_LENGTH_DELIMITED: u8 = 0x0A;
    /// Key for field 1 with the 32-bit fixed wire type (5).
    const KEY_FIXED32: u8 = 0x0D;

    fn create_test_buffer() -> LinkedBytes {
        LinkedBytes::new()
    }

    fn create_test_bytes(data: &[u8]) -> Bytes {
        Bytes::copy_from_slice(data)
    }

    /// Encodes `msg` and returns the written bytes.
    ///
    /// Also asserts that `encoded_len` predicts exactly how many bytes
    /// `encode_raw` produced.
    fn encode_checked<M: Message>(msg: &M) -> Bytes {
        let mut buf = create_test_buffer();
        msg.encode_raw(&mut buf);
        let encoded = buf.concat().freeze();
        let mut ctx = EncodeLengthContext::default();
        assert_eq!(msg.encoded_len(&mut ctx), encoded.len());
        encoded
    }

    /// Merges `payload` into `value` as field `tag`.
    ///
    /// `payload` is a single field value with its key already stripped. The
    /// helper asserts that `merge_field` consumed every byte of it.
    fn merge_payload<M: Message>(value: &mut M, tag: u32, wire_type: WireType, payload: &[u8]) {
        let mut ctx = DecodeContext::new(Bytes::new());
        let mut buf = create_test_bytes(payload);
        value
            .merge_field(tag, wire_type, &mut buf, &mut ctx, true)
            .unwrap();
        assert!(buf.is_empty(), "merge_field left {} unread byte(s)", buf.len());
    }

    /// Encodes `original`, decodes the value that follows the one-byte field-1
    /// key back into `M::default()`, and returns the result.
    ///
    /// `original` must be a non-default value, otherwise nothing is encoded.
    fn roundtrip<M: Message + Default>(original: &M, wire_type: WireType) -> M {
        let encoded = encode_checked(original);
        let mut decoded = M::default();
        merge_payload(&mut decoded, 1, wire_type, &encoded[1..]);
        decoded
    }

    #[test]
    fn test_bool_message_encoding() {
        assert_eq!(encode_checked(&true), vec![KEY_VARINT, 0x01]);
        // `false` is the default value and is omitted entirely.
        assert!(encode_checked(&false).is_empty());
    }

    #[test]
    fn test_bool_message_decoding() {
        let mut val = false;
        merge_payload(&mut val, 1, WireType::Varint, &[0x01]);
        assert!(val);

        // Merging overwrites the current value, including back to `false`.
        let mut val = true;
        merge_payload(&mut val, 1, WireType::Varint, &[0x00]);
        assert!(!val);
    }

    #[test]
    fn test_u32_message_encoding() {
        assert_eq!(encode_checked(&42u32), vec![KEY_VARINT, 0x2A]);
        assert!(encode_checked(&0u32).is_empty());
    }

    #[test]
    fn test_u32_message_decoding() {
        let mut val = 0u32;
        merge_payload(&mut val, 1, WireType::Varint, &[0x2A]);
        assert_eq!(val, 42);
    }

    #[test]
    fn test_u64_message_encoding() {
        assert_eq!(
            encode_checked(&123_456_789u64),
            vec![KEY_VARINT, 0x95, 0x9A, 0xEF, 0x3A]
        );
        assert!(encode_checked(&0u64).is_empty());
    }

    #[test]
    fn test_u64_message_decoding() {
        let mut val = 0u64;
        merge_payload(&mut val, 1, WireType::Varint, &[0x95, 0x9A, 0xEF, 0x3A]);
        assert_eq!(val, 123_456_789);
    }

    #[test]
    fn test_i32_message_encoding() {
        // Negative int32 values are sign-extended to 64 bits, so the varint
        // always takes ten bytes.
        assert_eq!(
            encode_checked(&-42i32),
            vec![KEY_VARINT, 0xD6, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01]
        );
        assert!(encode_checked(&0i32).is_empty());
    }

    #[test]
    fn test_i32_message_decoding() {
        // Canonical form: sign-extended to ten varint bytes.
        let mut val = 0i32;
        merge_payload(
            &mut val,
            1,
            WireType::Varint,
            &[0xD6, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01],
        );
        assert_eq!(val, -42);

        // The five-byte form, truncated to 32 bits, is accepted as well.
        let mut val = 0i32;
        merge_payload(&mut val, 1, WireType::Varint, &[0xD6, 0xFF, 0xFF, 0xFF, 0x0F]);
        assert_eq!(val, -42);
    }

    #[test]
    fn test_i64_message_encoding() {
        assert_eq!(
            encode_checked(&-123_456_789i64),
            vec![KEY_VARINT, 0xEB, 0xE5, 0x90, 0xC5, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01]
        );
        assert!(encode_checked(&0i64).is_empty());
    }

    #[test]
    fn test_i64_message_decoding() {
        let mut val = 0i64;
        merge_payload(
            &mut val,
            1,
            WireType::Varint,
            &[0xEB, 0xE5, 0x90, 0xC5, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01],
        );
        assert_eq!(val, -123_456_789);
    }

    #[test]
    fn test_f32_message_encoding() {
        // 1.5 is exactly representable, so the expected bytes and the
        // comparisons below are exact. This also avoids clippy's
        // `approx_constant` lint, which fires on PI look-alikes such as 3.14.
        let val = 1.5f32;
        let mut expected = vec![KEY_FIXED32];
        expected.extend_from_slice(&val.to_le_bytes());
        assert_eq!(encode_checked(&val), expected);
        assert!(encode_checked(&0.0f32).is_empty());
    }

    #[test]
    fn test_f32_message_decoding() {
        let mut val = 0.0f32;
        merge_payload(&mut val, 1, WireType::ThirtyTwoBit, &1.5f32.to_le_bytes());
        assert_eq!(val.to_bits(), 1.5f32.to_bits());
    }

    #[test]
    fn test_f64_message_encoding() {
        let val = 2.5f64;
        let mut expected = vec![KEY_FIXED64];
        expected.extend_from_slice(&val.to_le_bytes());
        assert_eq!(encode_checked(&val), expected);
        assert!(encode_checked(&0.0f64).is_empty());
    }

    #[test]
    fn test_f64_message_decoding() {
        let mut val = 0.0f64;
        merge_payload(&mut val, 1, WireType::SixtyFourBit, &2.5f64.to_le_bytes());
        assert_eq!(val.to_bits(), 2.5f64.to_bits());
    }

    #[test]
    fn test_string_message_encoding() {
        let mut expected = vec![KEY_LENGTH_DELIMITED, 0x05];
        expected.extend_from_slice(b"hello");
        assert_eq!(encode_checked(&String::from("hello")), expected);
        assert!(encode_checked(&String::new()).is_empty());
    }

    #[test]
    fn test_string_message_decoding() {
        let mut val = String::new();
        merge_payload(&mut val, 1, WireType::LengthDelimited, b"\x05hello");
        assert_eq!(val, "hello");
    }

    #[test]
    fn test_vec_u8_message_encoding() {
        let val: Vec<u8> = vec![1, 2, 3, 4, 5];
        assert_eq!(encode_checked(&val), vec![KEY_LENGTH_DELIMITED, 0x05, 1, 2, 3, 4, 5]);
        assert!(encode_checked(&Vec::<u8>::new()).is_empty());
    }

    #[test]
    fn test_vec_u8_message_decoding() {
        let mut val: Vec<u8> = Vec::new();
        merge_payload(&mut val, 1, WireType::LengthDelimited, &[0x05, 1, 2, 3, 4, 5]);
        assert_eq!(val, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_bytes_message_encoding() {
        let val = Bytes::from(vec![1u8, 2, 3, 4, 5]);
        assert_eq!(encode_checked(&val), vec![KEY_LENGTH_DELIMITED, 0x05, 1, 2, 3, 4, 5]);
        assert!(encode_checked(&Bytes::new()).is_empty());
    }

    #[test]
    fn test_bytes_message_decoding() {
        let mut val = Bytes::new();
        merge_payload(&mut val, 1, WireType::LengthDelimited, &[0x05, 1, 2, 3, 4, 5]);
        assert_eq!(val, Bytes::from(vec![1u8, 2, 3, 4, 5]));
    }

    #[test]
    fn test_unit_message_encoding() {
        assert!(encode_checked(&()).is_empty());
    }

    #[test]
    fn test_unit_message_decoding() {
        // `()` knows no fields: each one is skipped and its bytes consumed.
        merge_payload(&mut (), 1, WireType::Varint, &[0x01]);
        merge_payload(&mut (), 2, WireType::LengthDelimited, &[0x02, 0xAA, 0xBB]);
    }

    #[test]
    fn test_unknown_field_handling() {
        let mut val = 42u32;
        merge_payload(&mut val, 2, WireType::Varint, &[0x01]);
        // The unknown field is skipped without touching the stored value.
        assert_eq!(val, 42);
    }

    #[test]
    fn test_zero_copy_threshold() {
        // One byte below the threshold: not counted as zero-copy.
        let mut ctx1 = EncodeLengthContext::default();
        let small_bytes = Bytes::from(vec![1u8; ZERO_COPY_THRESHOLD - 1]);
        let len1 = small_bytes.encoded_len(&mut ctx1);
        assert_eq!(ctx1.zero_copy_len, 0);

        // Exactly at the threshold: the payload is counted.
        let mut ctx2 = EncodeLengthContext::default();
        let large_bytes = Bytes::from(vec![0u8; ZERO_COPY_THRESHOLD]);
        let len2 = large_bytes.encoded_len(&mut ctx2);
        assert!(ctx2.zero_copy_len >= ZERO_COPY_THRESHOLD);
        assert!(len2 > len1);
    }

    #[test]
    fn test_roundtrip_encoding_decoding() {
        assert!(roundtrip(&true, WireType::Varint));
        assert_eq!(roundtrip(&12_345u32, WireType::Varint), 12_345);
        assert_eq!(roundtrip(&-123_456_789i64, WireType::Varint), -123_456_789);
        assert_eq!(
            roundtrip(&2.5f64, WireType::SixtyFourBit).to_bits(),
            2.5f64.to_bits()
        );
        let original_string = String::from("test message");
        assert_eq!(
            roundtrip(&original_string, WireType::LengthDelimited),
            original_string
        );
    }
}