#![allow(dead_code)]
use crate::error::IoError;
/// Result alias used by every fallible operation in this protobuf codec; failures are
/// reported as [`IoError`].
pub type ProtoResult<T> = std::result::Result<T, IoError>;
/// Wire type 0: base-128 varint (carries `ProtoValue::Varint` and `ProtoValue::SignedVarint`).
const WIRE_VARINT: u8 = 0x0;
/// Wire type 1: eight raw bytes (carries `ProtoValue::Fixed64`, e.g. doubles).
const WIRE_FIXED64: u8 = 0x1;
/// Wire type 2: varint length prefix followed by that many bytes (`ProtoValue::LengthDelimited`).
const WIRE_LEN_DELIM: u8 = 0x2;
/// Wire type 5: four raw bytes (carries `ProtoValue::Fixed32`, e.g. floats).
const WIRE_FIXED32: u8 = 0x5;
/// Accumulates base-128 varints into an owned byte buffer.
#[derive(Debug, Default)]
pub struct VarintEncoder {
    buf: Vec<u8>,
}

impl VarintEncoder {
    /// Creates an encoder with an empty buffer.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `value` as a varint: seven payload bits per byte, least-significant group
    /// first, with the high (continuation) bit set on every byte except the last.
    pub fn push_u64(&mut self, value: u64) {
        let mut rest = value;
        while rest >= 0x80 {
            self.buf.push(((rest & 0x7f) as u8) | 0x80);
            rest >>= 7;
        }
        // Fewer than eight significant bits remain, so this is the terminating byte.
        self.buf.push(rest as u8);
    }

    /// Appends `value` ZigZag-encoded (via [`zigzag_encode`]) so that small magnitudes of
    /// either sign produce short varints.
    pub fn push_i64(&mut self, value: i64) {
        self.push_u64(zigzag_encode(value));
    }

    /// Consumes the encoder and returns every byte written so far.
    pub fn into_bytes(self) -> Vec<u8> {
        self.buf
    }

    /// Borrows every byte written so far.
    pub fn as_bytes(&self) -> &[u8] {
        self.buf.as_slice()
    }

    /// Discards all written bytes, keeping the buffer's allocation for reuse.
    pub fn clear(&mut self) {
        self.buf.clear();
    }
}
pub struct VarintDecoder<'a> {
data: &'a [u8],
pos: usize,
}
impl<'a> VarintDecoder<'a> {
pub fn new(data: &'a [u8]) -> Self {
Self { data, pos: 0 }
}
pub fn position(&self) -> usize {
self.pos
}
pub fn is_exhausted(&self) -> bool {
self.pos >= self.data.len()
}
pub fn read_u64(&mut self) -> ProtoResult<u64> {
let mut result: u64 = 0;
let mut shift: u32 = 0;
loop {
if self.pos >= self.data.len() {
return Err(IoError::FormatError(
"varint: unexpected end of input".to_string(),
));
}
let byte = self.data[self.pos];
self.pos += 1;
if shift >= 64 {
return Err(IoError::FormatError(
"varint: value overflows u64 (> 10 bytes)".to_string(),
));
}
result |= ((byte & 0x7f) as u64) << shift;
shift += 7;
if byte & 0x80 == 0 {
return Ok(result);
}
}
}
pub fn read_i64(&mut self) -> ProtoResult<i64> {
let zz = self.read_u64()?;
Ok(zigzag_decode(zz))
}
}
/// Maps a signed integer onto an unsigned one with ZigZag encoding
/// (0 → 0, -1 → 1, 1 → 2, -2 → 3, …), so small magnitudes of either sign stay small.
#[inline]
pub fn zigzag_encode(v: i64) -> u64 {
    // Arithmetic shift: all ones for negative inputs, all zeros otherwise.
    let sign_mask = v >> 63;
    let doubled = v.wrapping_shl(1);
    (doubled ^ sign_mask) as u64
}
/// Inverse of ZigZag encoding: even inputs map to non-negative values, odd inputs to negative.
#[inline]
pub fn zigzag_decode(v: u64) -> i64 {
    let magnitude = (v >> 1) as i64;
    let sign_bit = (v & 1) as i64;
    // Negating the low bit yields either 0 or an all-ones mask.
    magnitude ^ sign_bit.wrapping_neg()
}
/// Payload of a single protobuf field, tagged by how it is represented on the wire.
#[derive(Debug, Clone, PartialEq)]
pub enum ProtoValue {
    /// Unsigned varint (wire type 0), written as-is.
    Varint(u64),
    /// Signed value written as a ZigZag varint (wire type 0). The decoder in this module
    /// reads such fields back as `Varint`, since the wire carries no signedness.
    SignedVarint(i64),
    /// Eight raw bytes (wire type 1).
    Fixed64([u8; 8]),
    /// Length-prefixed payload (wire type 2): strings, bytes, embedded messages, packed data.
    LengthDelimited(Vec<u8>),
    /// Four raw bytes (wire type 5).
    Fixed32([u8; 4]),
}

impl ProtoValue {
    /// Returns the protobuf wire-type number used to encode this value.
    pub fn wire_type(&self) -> u8 {
        match self {
            Self::Varint(_) | Self::SignedVarint(_) => WIRE_VARINT,
            Self::Fixed64(_) => WIRE_FIXED64,
            Self::LengthDelimited(_) => WIRE_LEN_DELIM,
            Self::Fixed32(_) => WIRE_FIXED32,
        }
    }

    /// Borrows a length-delimited payload as UTF-8 text.
    ///
    /// Returns `None` for any other variant or when the bytes are not valid UTF-8.
    pub fn as_str(&self) -> Option<&str> {
        match self {
            Self::LengthDelimited(bytes) => std::str::from_utf8(bytes).ok(),
            _ => None,
        }
    }

    /// Wraps a copy of the UTF-8 bytes of `s` as a length-delimited value.
    pub fn from_string(s: &str) -> Self {
        Self::from_bytes(s.as_bytes().to_vec())
    }

    /// Wraps raw bytes as a length-delimited value.
    pub fn from_bytes(b: Vec<u8>) -> Self {
        Self::LengthDelimited(b)
    }

    /// Wraps an already-encoded embedded message as a length-delimited value.
    pub fn from_embedded_message(encoded: Vec<u8>) -> Self {
        Self::LengthDelimited(encoded)
    }

    /// Stores `v` as its eight little-endian bytes.
    pub fn from_f64(v: f64) -> Self {
        Self::Fixed64(v.to_le_bytes())
    }

    /// Stores `v` as its four little-endian bytes.
    pub fn from_f32(v: f32) -> Self {
        Self::Fixed32(v.to_le_bytes())
    }

    /// Reinterprets a `Fixed64` payload as a little-endian `f64`; `None` for other variants.
    pub fn as_f64(&self) -> Option<f64> {
        match self {
            Self::Fixed64(bytes) => Some(f64::from_le_bytes(*bytes)),
            _ => None,
        }
    }

    /// Reinterprets a `Fixed32` payload as a little-endian `f32`; `None` for other variants.
    pub fn as_f32(&self) -> Option<f32> {
        match self {
            Self::Fixed32(bytes) => Some(f32::from_le_bytes(*bytes)),
            _ => None,
        }
    }

    /// Appends the encoded payload (without the field tag) to `buf`.
    fn encode_into(&self, buf: &mut Vec<u8>) {
        // Phase 1: the varint portion, which is the value itself or a length prefix.
        let mut prefix = VarintEncoder::new();
        match self {
            Self::Varint(v) => prefix.push_u64(*v),
            Self::SignedVarint(v) => prefix.push_i64(*v),
            Self::LengthDelimited(bytes) => prefix.push_u64(bytes.len() as u64),
            Self::Fixed64(_) | Self::Fixed32(_) => {}
        }
        buf.extend_from_slice(prefix.as_bytes());

        // Phase 2: any raw bytes that follow.
        match self {
            Self::LengthDelimited(bytes) => buf.extend_from_slice(bytes),
            Self::Fixed64(bytes) => buf.extend_from_slice(bytes),
            Self::Fixed32(bytes) => buf.extend_from_slice(bytes),
            Self::Varint(_) | Self::SignedVarint(_) => {}
        }
    }
}
/// One tagged field: a field number paired with its payload.
#[derive(Debug, Clone, PartialEq)]
pub struct ProtoField {
    /// Field number, stored in the tag above the three wire-type bits.
    pub field_number: u32,
    /// Payload; its variant determines the wire type written in the tag.
    pub value: ProtoValue,
}

impl ProtoField {
    /// Pairs `field_number` with `value`. No range checking is performed.
    pub fn new(field_number: u32, value: ProtoValue) -> Self {
        Self { field_number, value }
    }

    /// Computes the tag: the field number shifted left by three, OR-ed with the wire type.
    fn encode_tag(&self) -> u64 {
        let number = u64::from(self.field_number);
        let wire = u64::from(self.value.wire_type());
        (number << 3) | wire
    }

    /// Appends the tag varint followed by the encoded payload to `buf`.
    pub fn encode_into(&self, buf: &mut Vec<u8>) {
        let mut tag = VarintEncoder::new();
        tag.push_u64(self.encode_tag());
        buf.extend_from_slice(&tag.into_bytes());
        self.value.encode_into(buf);
    }
}
#[derive(Debug, Default)]
pub struct ProtoMessageBuilder {
fields: Vec<ProtoField>,
}
impl ProtoMessageBuilder {
pub fn new() -> Self {
Self { fields: Vec::new() }
}
pub fn add_varint(mut self, field_number: u32, value: u64) -> Self {
self.fields.push(ProtoField::new(field_number, ProtoValue::Varint(value)));
self
}
pub fn add_sint64(mut self, field_number: u32, value: i64) -> Self {
self.fields.push(ProtoField::new(field_number, ProtoValue::SignedVarint(value)));
self
}
pub fn add_bool(mut self, field_number: u32, value: bool) -> Self {
self.add_varint(field_number, if value { 1 } else { 0 })
}
pub fn add_string(mut self, field_number: u32, value: &str) -> Self {
self.fields.push(ProtoField::new(
field_number,
ProtoValue::from_string(value),
));
self
}
pub fn add_bytes(mut self, field_number: u32, value: Vec<u8>) -> Self {
self.fields.push(ProtoField::new(
field_number,
ProtoValue::from_bytes(value),
));
self
}
pub fn add_message(mut self, field_number: u32, encoded: Vec<u8>) -> Self {
self.fields.push(ProtoField::new(
field_number,
ProtoValue::from_embedded_message(encoded),
));
self
}
pub fn add_f64(mut self, field_number: u32, value: f64) -> Self {
self.fields.push(ProtoField::new(
field_number,
ProtoValue::from_f64(value),
));
self
}
pub fn add_f32(mut self, field_number: u32, value: f32) -> Self {
self.fields.push(ProtoField::new(
field_number,
ProtoValue::from_f32(value),
));
self
}
pub fn add_fixed32(mut self, field_number: u32, value: u32) -> Self {
self.fields.push(ProtoField::new(
field_number,
ProtoValue::Fixed32(value.to_le_bytes()),
));
self
}
pub fn add_fixed64(mut self, field_number: u32, value: u64) -> Self {
self.fields.push(ProtoField::new(
field_number,
ProtoValue::Fixed64(value.to_le_bytes()),
));
self
}
pub fn add_packed_varints(mut self, field_number: u32, values: &[u64]) -> Self {
let mut enc = VarintEncoder::new();
for &v in values {
enc.push_u64(v);
}
let packed = enc.into_bytes();
self.fields.push(ProtoField::new(
field_number,
ProtoValue::LengthDelimited(packed),
));
self
}
pub fn build(self) -> Vec<u8> {
let mut buf = Vec::new();
for field in &self.fields {
field.encode_into(&mut buf);
}
buf
}
pub fn fields(&self) -> &[ProtoField] {
&self.fields
}
}
/// Decodes every field in `data`, in order.
///
/// This is a convenience wrapper around [`ProtoDecoder::decode_all`].
///
/// # Errors
///
/// Propagates the first decoding error encountered.
pub fn decode_proto_fields(data: &[u8]) -> ProtoResult<Vec<ProtoField>> {
    let mut decoder = ProtoDecoder::new(data);
    decoder.decode_all()
}
pub struct ProtoDecoder<'a> {
data: &'a [u8],
pos: usize,
}
impl<'a> ProtoDecoder<'a> {
pub fn new(data: &'a [u8]) -> Self {
Self { data, pos: 0 }
}
pub fn position(&self) -> usize {
self.pos
}
pub fn is_exhausted(&self) -> bool {
self.pos >= self.data.len()
}
fn read_varint(&mut self) -> ProtoResult<u64> {
let mut result: u64 = 0;
let mut shift: u32 = 0;
loop {
if self.pos >= self.data.len() {
return Err(IoError::FormatError(
"protobuf: unexpected end of data reading varint".to_string(),
));
}
let byte = self.data[self.pos];
self.pos += 1;
if shift >= 64 {
return Err(IoError::FormatError(
"protobuf: varint overflows u64".to_string(),
));
}
result |= ((byte & 0x7f) as u64) << shift;
shift += 7;
if byte & 0x80 == 0 {
return Ok(result);
}
}
}
fn read_bytes(&mut self, n: usize) -> ProtoResult<&'a [u8]> {
if self.pos + n > self.data.len() {
return Err(IoError::FormatError(format!(
"protobuf: need {n} bytes at offset {} but only {} remain",
self.pos,
self.data.len() - self.pos
)));
}
let slice = &self.data[self.pos..self.pos + n];
self.pos += n;
Ok(slice)
}
pub fn next_field(&mut self) -> ProtoResult<Option<ProtoField>> {
if self.is_exhausted() {
return Ok(None);
}
let tag = self.read_varint()?;
let wire_type = (tag & 0x07) as u8;
let field_number = (tag >> 3) as u32;
if field_number == 0 {
return Err(IoError::FormatError(
"protobuf: field number 0 is reserved".to_string(),
));
}
let value = match wire_type {
WIRE_VARINT => ProtoValue::Varint(self.read_varint()?),
WIRE_FIXED64 => {
let bytes = self.read_bytes(8)?;
let arr: [u8; 8] = bytes.try_into().map_err(|_| {
IoError::FormatError("protobuf: fixed64 conversion failed".to_string())
})?;
ProtoValue::Fixed64(arr)
}
WIRE_LEN_DELIM => {
let len = self.read_varint()? as usize;
let bytes = self.read_bytes(len)?;
ProtoValue::LengthDelimited(bytes.to_vec())
}
WIRE_FIXED32 => {
let bytes = self.read_bytes(4)?;
let arr: [u8; 4] = bytes.try_into().map_err(|_| {
IoError::FormatError("protobuf: fixed32 conversion failed".to_string())
})?;
ProtoValue::Fixed32(arr)
}
wt => {
return Err(IoError::FormatError(format!(
"protobuf: unsupported wire type {wt} for field {field_number}"
)));
}
};
Ok(Some(ProtoField::new(field_number, value)))
}
pub fn decode_all(&mut self) -> ProtoResult<Vec<ProtoField>> {
let mut fields = Vec::new();
while let Some(field) = self.next_field()? {
fields.push(field);
}
Ok(fields)
}
}
/// Types that can be serialized to, and parsed from, protobuf wire bytes.
///
/// Implementors supply [`encode`](Self::encode) and [`decode`](Self::decode); the
/// byte-level helpers have default implementations built on those two.
pub trait ProtoMessage: Sized {
    /// Serializes `self` to wire bytes.
    fn encode(&self) -> Vec<u8>;

    /// Builds a value from already-decoded fields.
    ///
    /// # Errors
    ///
    /// Implementation-defined, for example when a required field is absent.
    fn decode(fields: &[ProtoField]) -> ProtoResult<Self>;

    /// Decodes `bytes` into fields with [`ProtoDecoder`] and hands them to
    /// [`decode`](Self::decode).
    ///
    /// # Errors
    ///
    /// Propagates wire-level decoding errors and any error from `decode`.
    fn from_bytes(bytes: &[u8]) -> ProtoResult<Self> {
        ProtoDecoder::new(bytes)
            .decode_all()
            .and_then(|fields| Self::decode(&fields))
    }

    /// Alias for [`encode`](Self::encode).
    fn to_bytes(&self) -> Vec<u8> {
        self.encode()
    }
}
/// Declared schema type of a field; determines which wire type the field must use.
///
/// Fieldless, so it is cheap to copy and can serve as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProtoType {
    /// Unsigned varint.
    Uint64,
    /// ZigZag-encoded signed varint.
    Sint64,
    /// Varint holding `0` or `1`.
    Bool,
    /// 64-bit float carried in eight raw bytes.
    Double,
    /// 32-bit float carried in four raw bytes.
    Float,
    /// Arbitrary length-delimited bytes.
    Bytes,
    /// Length-delimited UTF-8 text.
    String,
    /// Length-delimited embedded message.
    Message,
}
/// Schema entry for one field: name, number, declared type, and whether it may be absent.
#[derive(Debug, Clone)]
pub struct ProtoFieldDescriptor {
    /// Human-readable field name, quoted in validation error messages.
    pub name: String,
    /// Field number expected on the wire.
    pub field_number: u32,
    /// Declared type, which fixes the wire type the field must use.
    pub field_type: ProtoType,
    /// When `false`, `ProtoDescriptor::validate` reports the field's absence as an error.
    pub optional: bool,
}

impl ProtoFieldDescriptor {
    /// Creates a descriptor; `name` accepts anything convertible into a `String`.
    pub fn new(
        name: impl Into<String>,
        field_number: u32,
        field_type: ProtoType,
        optional: bool,
    ) -> Self {
        let name: String = name.into();
        Self {
            optional,
            field_type,
            field_number,
            name,
        }
    }
}
/// Lightweight schema for one message type, used to check decoded fields.
#[derive(Debug, Clone, Default)]
pub struct ProtoDescriptor {
    /// Message name, quoted in validation error messages.
    pub message_name: String,
    /// Declared fields, in declaration order.
    pub fields: Vec<ProtoFieldDescriptor>,
}

impl ProtoDescriptor {
    /// Creates a descriptor for `message_name` with no declared fields.
    pub fn new(message_name: impl Into<String>) -> Self {
        let message_name: String = message_name.into();
        Self {
            message_name,
            fields: Vec::new(),
        }
    }

    /// Declares one more field and returns the updated descriptor.
    pub fn field(
        mut self,
        name: impl Into<String>,
        field_number: u32,
        field_type: ProtoType,
        optional: bool,
    ) -> Self {
        let descriptor = ProtoFieldDescriptor::new(name, field_number, field_type, optional);
        self.fields.push(descriptor);
        self
    }

    /// Checks decoded `fields` against this schema.
    ///
    /// Fields without a matching descriptor are accepted. When several descriptors share a
    /// number, the first one declared is used.
    ///
    /// # Errors
    ///
    /// Reports, in this order:
    /// 1. the first required (non-optional) descriptor with no matching field;
    /// 2. the first field whose wire type differs from its descriptor's expected wire type.
    pub fn validate(&self, fields: &[ProtoField]) -> ProtoResult<()> {
        let is_present = |number: u32| fields.iter().any(|f| f.field_number == number);

        // Pass 1: every required descriptor must match at least one field.
        let first_missing = self
            .fields
            .iter()
            .filter(|d| !d.optional)
            .find(|d| !is_present(d.field_number));
        if let Some(desc) = first_missing {
            return Err(IoError::FormatError(format!(
                "protobuf: required field '{}' (number {}) is missing from '{}'",
                desc.name, desc.field_number, self.message_name
            )));
        }

        // Pass 2: every declared field must use the wire type its type implies.
        for field in fields {
            let desc = match self.field_by_number(field.field_number) {
                Some(desc) => desc,
                None => continue,
            };
            let expected_wt = expected_wire_type(&desc.field_type);
            let actual_wt = field.value.wire_type();
            if expected_wt != actual_wt {
                return Err(IoError::FormatError(format!(
                    "protobuf: field '{}' (number {}) wire-type mismatch: expected {expected_wt}, got {actual_wt}",
                    desc.name, desc.field_number
                )));
            }
        }
        Ok(())
    }

    /// Decodes `data` and validates the result, returning the decoded fields on success.
    ///
    /// # Errors
    ///
    /// Propagates decoding errors and any error from [`validate`](Self::validate).
    pub fn validate_bytes(&self, data: &[u8]) -> ProtoResult<Vec<ProtoField>> {
        let fields = ProtoDecoder::new(data).decode_all()?;
        self.validate(&fields)?;
        Ok(fields)
    }

    /// Finds the first declared field named `name`.
    pub fn field_by_name(&self, name: &str) -> Option<&ProtoFieldDescriptor> {
        self.fields.iter().find(|d| d.name.as_str() == name)
    }

    /// Finds the first declared field with field number `number`.
    pub fn field_by_number(&self, number: u32) -> Option<&ProtoFieldDescriptor> {
        self.fields.iter().find(|d| d.field_number == number)
    }
}
/// Maps a declared [`ProtoType`] to the wire type its values are encoded with.
fn expected_wire_type(pt: &ProtoType) -> u8 {
    match *pt {
        ProtoType::Double => WIRE_FIXED64,
        ProtoType::Float => WIRE_FIXED32,
        ProtoType::Bytes | ProtoType::String | ProtoType::Message => WIRE_LEN_DELIM,
        ProtoType::Uint64 | ProtoType::Sint64 | ProtoType::Bool => WIRE_VARINT,
    }
}
#[cfg(test)]
mod tests {
    //! Round-trip and error-path tests for the varint codec, builder, decoder and descriptors.
    use super::*;

    #[test]
    fn test_varint_single_byte_values() {
        let mut enc = VarintEncoder::new();
        for v in 0u64..128 {
            enc.push_u64(v);
        }
        let bytes = enc.into_bytes();
        assert_eq!(bytes.len(), 128);
        for (i, &b) in bytes.iter().enumerate() {
            assert_eq!(b, i as u8);
        }
    }

    #[test]
    fn test_varint_multi_byte() {
        let cases: &[(u64, &[u8])] = &[
            (128, &[0x80, 0x01]),
            (300, &[0xac, 0x02]),
            (16_383, &[0xff, 0x7f]),
            (16_384, &[0x80, 0x80, 0x01]),
            (
                u64::MAX,
                &[0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01],
            ),
        ];
        for &(v, expected) in cases {
            let mut enc = VarintEncoder::new();
            enc.push_u64(v);
            assert_eq!(enc.as_bytes(), expected, "mismatch for {v}");
            let mut dec = VarintDecoder::new(expected);
            let decoded = dec.read_u64().expect("decode");
            assert_eq!(decoded, v);
            assert!(dec.is_exhausted());
        }
    }

    #[test]
    fn test_varint_decoder_truncated_input_is_error() {
        // A continuation bit with nothing after it.
        let truncated = [0x80u8];
        let mut dec = VarintDecoder::new(&truncated);
        assert!(dec.read_u64().is_err());
    }

    #[test]
    fn test_zigzag_roundtrip() {
        let values: &[i64] = &[
            0,
            -1,
            1,
            -128,
            127,
            i32::MIN as i64,
            i32::MAX as i64,
            i64::MIN,
            i64::MAX,
        ];
        for &v in values {
            let encoded = zigzag_encode(v);
            let decoded = zigzag_decode(encoded);
            assert_eq!(decoded, v, "zigzag roundtrip failed for {v}");
        }
    }

    #[test]
    fn test_signed_varint_encoder_decoder() {
        let values: &[i64] = &[-1000, -1, 0, 1, 1000, i64::MIN / 2, i64::MAX / 2];
        for &v in values {
            let mut enc = VarintEncoder::new();
            enc.push_i64(v);
            let mut dec = VarintDecoder::new(enc.as_bytes());
            let decoded = dec.read_i64().expect("decode");
            assert_eq!(decoded, v, "signed varint roundtrip for {v}");
        }
    }

    #[test]
    fn test_proto_value_wire_types() {
        assert_eq!(ProtoValue::Varint(0).wire_type(), 0);
        assert_eq!(ProtoValue::SignedVarint(0).wire_type(), 0);
        assert_eq!(ProtoValue::Fixed64([0u8; 8]).wire_type(), 1);
        assert_eq!(ProtoValue::LengthDelimited(vec![]).wire_type(), 2);
        assert_eq!(ProtoValue::Fixed32([0u8; 4]).wire_type(), 5);
    }

    #[test]
    fn test_proto_value_as_str() {
        let v = ProtoValue::from_string("hello");
        assert_eq!(v.as_str(), Some("hello"));
        assert_eq!(ProtoValue::Varint(0).as_str(), None);
    }

    #[test]
    fn test_proto_value_f64_roundtrip() {
        let pi = std::f64::consts::PI;
        let v = ProtoValue::from_f64(pi);
        assert!((v.as_f64().unwrap() - pi).abs() < 1e-15);
    }

    #[test]
    fn test_proto_value_f32_roundtrip() {
        let e = std::f32::consts::E;
        let v = ProtoValue::from_f32(e);
        assert!((v.as_f32().unwrap() - e).abs() < 1e-6);
    }

    #[test]
    fn test_builder_varint_roundtrip() {
        let bytes = ProtoMessageBuilder::new()
            .add_varint(1, 42)
            .add_varint(2, u64::MAX)
            .build();
        let mut dec = ProtoDecoder::new(&bytes);
        let fields = dec.decode_all().expect("decode");
        assert_eq!(fields.len(), 2);
        assert_eq!(fields[0].field_number, 1);
        assert_eq!(fields[0].value, ProtoValue::Varint(42));
        assert_eq!(fields[1].field_number, 2);
        assert_eq!(fields[1].value, ProtoValue::Varint(u64::MAX));
    }

    #[test]
    fn test_builder_string_roundtrip() {
        let msg = "hello, world!";
        let bytes = ProtoMessageBuilder::new().add_string(1, msg).build();
        let mut dec = ProtoDecoder::new(&bytes);
        let fields = dec.decode_all().expect("decode");
        assert_eq!(fields.len(), 1);
        assert_eq!(fields[0].field_number, 1);
        assert_eq!(fields[0].value.as_str(), Some(msg));
    }

    #[test]
    fn test_builder_f64_roundtrip() {
        let val = std::f64::consts::TAU;
        let bytes = ProtoMessageBuilder::new().add_f64(3, val).build();
        let mut dec = ProtoDecoder::new(&bytes);
        let fields = dec.decode_all().expect("decode");
        assert_eq!(fields.len(), 1);
        let decoded = fields[0].value.as_f64().expect("as_f64");
        assert!((decoded - val).abs() < 1e-15);
    }

    #[test]
    fn test_builder_f32_roundtrip() {
        let val = 1.23456_f32;
        let bytes = ProtoMessageBuilder::new().add_f32(4, val).build();
        let mut dec = ProtoDecoder::new(&bytes);
        let fields = dec.decode_all().expect("decode");
        assert_eq!(fields.len(), 1);
        let decoded = fields[0].value.as_f32().expect("as_f32");
        assert!((decoded - val).abs() < 1e-5);
    }

    #[test]
    fn test_builder_bool_roundtrip() {
        let bytes = ProtoMessageBuilder::new()
            .add_bool(1, true)
            .add_bool(2, false)
            .build();
        let mut dec = ProtoDecoder::new(&bytes);
        let fields = dec.decode_all().expect("decode");
        assert_eq!(fields.len(), 2);
        assert_eq!(fields[0].value, ProtoValue::Varint(1));
        assert_eq!(fields[1].value, ProtoValue::Varint(0));
    }

    #[test]
    fn test_builder_sint64_roundtrip() {
        let values: &[i64] = &[-42, 0, 42, i64::MIN / 2, i64::MAX / 2];
        for &v in values {
            let bytes = ProtoMessageBuilder::new().add_sint64(1, v).build();
            let mut dec = ProtoDecoder::new(&bytes);
            let fields = dec.decode_all().expect("decode");
            assert_eq!(fields.len(), 1);
            // The wire carries no signedness, so the decoder yields a plain Varint.
            if let ProtoValue::Varint(zz) = fields[0].value {
                let decoded = zigzag_decode(zz);
                assert_eq!(decoded, v, "sint64 roundtrip for {v}");
            } else {
                panic!(
                    "expected Varint for SignedVarint field, got {:?}",
                    fields[0].value
                );
            }
        }
    }

    #[test]
    fn test_builder_bytes_roundtrip() {
        let data = vec![0xde, 0xad, 0xbe, 0xef];
        let bytes = ProtoMessageBuilder::new().add_bytes(5, data.clone()).build();
        let mut dec = ProtoDecoder::new(&bytes);
        let fields = dec.decode_all().expect("decode");
        assert_eq!(fields.len(), 1);
        assert_eq!(fields[0].value, ProtoValue::LengthDelimited(data));
    }

    #[test]
    fn test_embedded_message_roundtrip() {
        let inner = ProtoMessageBuilder::new()
            .add_varint(1, 100)
            .add_string(2, "inner value")
            .build();
        let outer = ProtoMessageBuilder::new()
            .add_varint(1, 999)
            .add_message(2, inner.clone())
            .build();
        let mut outer_dec = ProtoDecoder::new(&outer);
        let outer_fields = outer_dec.decode_all().expect("decode outer");
        assert_eq!(outer_fields.len(), 2);
        assert_eq!(outer_fields[0].value, ProtoValue::Varint(999));
        let inner_bytes = if let ProtoValue::LengthDelimited(ref b) = outer_fields[1].value {
            b.clone()
        } else {
            panic!("expected LengthDelimited for embedded message");
        };
        assert_eq!(inner_bytes, inner);
        let mut inner_dec = ProtoDecoder::new(&inner_bytes);
        let inner_fields = inner_dec.decode_all().expect("decode inner");
        assert_eq!(inner_fields.len(), 2);
        assert_eq!(inner_fields[0].value, ProtoValue::Varint(100));
        assert_eq!(inner_fields[1].value.as_str(), Some("inner value"));
    }

    #[test]
    fn test_packed_varints_roundtrip() {
        let values = [1u64, 2, 3, 300, u64::MAX / 2];
        let bytes = ProtoMessageBuilder::new()
            .add_packed_varints(1, &values)
            .build();
        let mut dec = ProtoDecoder::new(&bytes);
        let fields = dec.decode_all().expect("decode");
        assert_eq!(fields.len(), 1);
        let packed_bytes = if let ProtoValue::LengthDelimited(ref b) = fields[0].value {
            b.clone()
        } else {
            panic!("expected LengthDelimited");
        };
        let mut vdec = VarintDecoder::new(&packed_bytes);
        let mut decoded = Vec::new();
        while !vdec.is_exhausted() {
            decoded.push(vdec.read_u64().expect("read packed varint"));
        }
        assert_eq!(decoded, values);
    }

    #[test]
    fn test_multiple_fields_mixed() {
        let bytes = ProtoMessageBuilder::new()
            .add_varint(1, 42)
            .add_string(2, "test message")
            .add_f64(3, std::f64::consts::PI)
            .add_bytes(4, vec![0xca, 0xfe])
            .add_bool(5, true)
            .add_fixed32(6, 0xdead_beef_u32)
            .add_fixed64(7, 0xdead_beef_cafe_babe_u64)
            .build();
        let mut dec = ProtoDecoder::new(&bytes);
        let fields = dec.decode_all().expect("decode");
        assert_eq!(fields.len(), 7);
        let numbers: Vec<u32> = fields.iter().map(|f| f.field_number).collect();
        assert_eq!(numbers, [1, 2, 3, 4, 5, 6, 7]);
        assert_eq!(fields[0].value, ProtoValue::Varint(42));
        assert_eq!(fields[1].value.as_str(), Some("test message"));
        assert!((fields[2].value.as_f64().unwrap() - std::f64::consts::PI).abs() < 1e-15);
        assert_eq!(fields[3].value, ProtoValue::LengthDelimited(vec![0xca, 0xfe]));
        assert_eq!(fields[4].value, ProtoValue::Varint(1));
        assert_eq!(
            fields[5].value,
            ProtoValue::Fixed32(0xdead_beef_u32.to_le_bytes())
        );
        assert_eq!(
            fields[6].value,
            ProtoValue::Fixed64(0xdead_beef_cafe_babe_u64.to_le_bytes())
        );
    }

    #[test]
    fn test_decode_proto_fields_empty_input() {
        let fields = decode_proto_fields(&[]).expect("empty decode");
        assert!(fields.is_empty());
    }

    #[test]
    fn test_decode_unknown_wire_type_returns_error() {
        let bad = [0x0Bu8];
        assert!(decode_proto_fields(&bad).is_err());
    }

    #[test]
    fn test_decode_truncated_length_delimited_returns_error() {
        // Field 1, wire type 2, declares 5 payload bytes but only 1 follows.
        let bad = [0x0Au8, 0x05, b'a'];
        assert!(decode_proto_fields(&bad).is_err());
    }

    #[test]
    fn test_decode_field_number_zero_returns_error() {
        let bad = [0x00u8, 0x00];
        assert!(decode_proto_fields(&bad).is_err());
    }

    #[test]
    fn test_descriptor_validates_required_fields() {
        let desc = ProtoDescriptor::new("Test")
            .field("count", 1, ProtoType::Uint64, false)
            .field("label", 2, ProtoType::String, false);
        let bytes = ProtoMessageBuilder::new().add_varint(1, 5).build();
        let result = desc.validate_bytes(&bytes);
        assert!(result.is_err());
    }

    #[test]
    fn test_descriptor_validates_ok() {
        let desc = ProtoDescriptor::new("Test")
            .field("count", 1, ProtoType::Uint64, false)
            .field("label", 2, ProtoType::String, false)
            .field("score", 3, ProtoType::Double, true);
        let bytes = ProtoMessageBuilder::new()
            .add_varint(1, 42)
            .add_string(2, "hello")
            .build();
        let fields = desc.validate_bytes(&bytes).expect("should validate");
        assert_eq!(fields.len(), 2);
    }

    #[test]
    fn test_descriptor_wire_type_mismatch() {
        let desc = ProtoDescriptor::new("Test").field("value", 1, ProtoType::Double, false);
        let bytes = ProtoMessageBuilder::new().add_varint(1, 42).build();
        let result = desc.validate_bytes(&bytes);
        assert!(result.is_err());
    }

    #[test]
    fn test_descriptor_lookup() {
        let desc = ProtoDescriptor::new("Msg")
            .field("id", 1, ProtoType::Uint64, false)
            .field("name", 2, ProtoType::String, true);
        assert_eq!(desc.field_by_name("id").map(|f| f.field_number), Some(1));
        assert_eq!(
            desc.field_by_number(2).map(|f| f.name.as_str()),
            Some("name")
        );
        assert!(desc.field_by_name("missing").is_none());
    }

    #[test]
    fn test_proto_message_trait_roundtrip() {
        #[derive(Debug, PartialEq)]
        struct Record {
            id: u64,
            name: String,
            score: f64,
        }

        impl ProtoMessage for Record {
            fn encode(&self) -> Vec<u8> {
                ProtoMessageBuilder::new()
                    .add_varint(1, self.id)
                    .add_string(2, &self.name)
                    .add_f64(3, self.score)
                    .build()
            }

            fn decode(fields: &[ProtoField]) -> ProtoResult<Self> {
                let mut id = None;
                let mut name = String::new();
                let mut score = None;
                for field in fields {
                    match field.field_number {
                        1 => {
                            if let ProtoValue::Varint(v) = field.value {
                                id = Some(v);
                            }
                        }
                        2 => {
                            name = field.value.as_str().unwrap_or("").to_string();
                        }
                        3 => {
                            score = field.value.as_f64();
                        }
                        _ => {}
                    }
                }
                Ok(Record {
                    id: id.ok_or_else(|| IoError::FormatError("missing id".into()))?,
                    name,
                    score: score.ok_or_else(|| IoError::FormatError("missing score".into()))?,
                })
            }
        }

        // Deliberately not an approximation of a named constant (clippy::approx_constant).
        let r = Record {
            id: 99,
            name: "SciRS2".to_string(),
            score: 42.5,
        };
        let bytes = r.to_bytes();
        let r2 = Record::from_bytes(&bytes).expect("decode");
        assert_eq!(r.id, r2.id);
        assert_eq!(r.name, r2.name);
        assert!((r.score - r2.score).abs() < 1e-12);
    }

    #[test]
    fn test_official_spec_example_field1_varint150() {
        let bytes = ProtoMessageBuilder::new().add_varint(1, 150).build();
        assert_eq!(bytes, vec![0x08, 0x96, 0x01]);
    }

    #[test]
    fn test_official_spec_example_string_testing() {
        let bytes = ProtoMessageBuilder::new().add_string(2, "testing").build();
        assert_eq!(bytes[0], 0x12);
        assert_eq!(bytes[1], 0x07);
        assert_eq!(&bytes[2..], b"testing");
    }
}