use crate::error::IoError;
/// Result alias used by the decoding functions in this file; the error side is [`IoError`].
pub type ProtoResult<T> = Result<T, IoError>;
// Wire-type identifiers. They occupy the low three bits of an encoded field key
// (see `encode_field`, which ORs them into the key, and `decode_fields`, which
// masks the key with 0x07).
/// Wire type reported for `ProtobufField::Varint`.
const WIRE_VARINT: u8 = 0;
/// Wire type reported for `ProtobufField::Fixed64` (eight raw bytes).
const WIRE_FIXED64: u8 = 1;
/// Wire type reported for `ProtobufField::LengthDelimited` (length prefix + payload).
const WIRE_LEN_DELIM: u8 = 2;
/// Wire type reported for `ProtobufField::Fixed32` (four raw bytes).
const WIRE_FIXED32: u8 = 5;
/// One field value, classified by how it is laid out on the wire.
///
/// The variant decides which wire type is written into the field key and how
/// the value bytes that follow the key are produced or consumed.
#[derive(Debug, Clone, PartialEq)]
pub enum ProtobufField {
/// Unsigned integer written as a base-128 varint.
Varint(u64),
/// Arbitrary bytes, written as a varint byte count followed by the bytes themselves.
LengthDelimited(Vec<u8>),
/// Eight bytes copied to and from the wire unchanged; the `from_u64`/`from_f64`
/// constructors fill them with the little-endian encoding of the value.
Fixed64([u8; 8]),
/// Four bytes copied to and from the wire unchanged; the `from_u32`/`from_f32`
/// constructors fill them with the little-endian encoding of the value.
Fixed32([u8; 4]),
}
impl ProtobufField {
    /// Returns the wire-type number that belongs in the key of this field.
    fn wire_type(&self) -> u8 {
        match *self {
            Self::Varint(..) => WIRE_VARINT,
            Self::Fixed64(..) => WIRE_FIXED64,
            Self::LengthDelimited(..) => WIRE_LEN_DELIM,
            Self::Fixed32(..) => WIRE_FIXED32,
        }
    }

    /// Builds a length-delimited field holding a copy of the UTF-8 bytes of `s`.
    pub fn from_str(s: &str) -> Self {
        Self::LengthDelimited(Vec::from(s.as_bytes()))
    }

    /// Wraps an already-encoded message as a length-delimited field.
    pub fn from_message(encoded: Vec<u8>) -> Self {
        Self::LengthDelimited(encoded)
    }

    /// Builds a fixed 64-bit field from the little-endian bytes of `v`.
    pub fn from_u64(v: u64) -> Self {
        Self::Fixed64(u64::to_le_bytes(v))
    }

    /// Builds a fixed 64-bit field from the little-endian bytes of the bit pattern of `v`.
    pub fn from_f64(v: f64) -> Self {
        Self::Fixed64(v.to_bits().to_le_bytes())
    }

    /// Builds a fixed 32-bit field from the little-endian bytes of `v`.
    pub fn from_u32(v: u32) -> Self {
        Self::Fixed32(u32::to_le_bytes(v))
    }

    /// Builds a fixed 32-bit field from the little-endian bytes of the bit pattern of `v`.
    pub fn from_f32(v: f32) -> Self {
        Self::Fixed32(v.to_bits().to_le_bytes())
    }

    /// Reads a `Fixed64` field as a little-endian `f64`; any other variant yields `None`.
    pub fn as_f64(&self) -> Option<f64> {
        match self {
            Self::Fixed64(raw) => Some(f64::from_bits(u64::from_le_bytes(*raw))),
            _ => None,
        }
    }

    /// Reads a `Fixed32` field as a little-endian `f32`; any other variant yields `None`.
    pub fn as_f32(&self) -> Option<f32> {
        match self {
            Self::Fixed32(raw) => Some(f32::from_bits(u32::from_le_bytes(*raw))),
            _ => None,
        }
    }

    /// Borrows a `LengthDelimited` payload as `&str`.
    ///
    /// Returns `None` for every other variant and for payloads that are not valid UTF-8.
    pub fn as_str(&self) -> Option<&str> {
        match self {
            Self::LengthDelimited(raw) => std::str::from_utf8(raw).ok(),
            _ => None,
        }
    }
}
/// Encodes `value` as a base-128 varint, least-significant group first.
///
/// Every byte carries seven payload bits; the high bit is set on all bytes
/// except the last one. Zero is encoded as the single byte `0x00`.
pub fn encode_varint(mut value: u64) -> Vec<u8> {
    // A u64 needs at most ceil(64 / 7) = 10 groups.
    let mut out = Vec::with_capacity(10);
    while value >= 0x80 {
        // More groups follow: emit the low seven bits with the continuation flag.
        out.push(((value & 0x7f) as u8) | 0x80);
        value >>= 7;
    }
    // Final group: fits in seven bits, continuation flag clear.
    out.push(value as u8);
    out
}
/// Decodes one base-128 varint from the front of `data`.
///
/// Returns the decoded value together with the bytes that follow the varint.
///
/// # Errors
///
/// Returns `IoError::ParseError` when
/// - `data` is empty,
/// - the input ends before a byte with a clear high bit is seen,
/// - the varint runs past ten bytes, or
/// - a group carries bits that would land above bit 63 (previously those bits
///   were shifted out silently and a wrong value was returned).
pub fn decode_varint(data: &[u8]) -> ProtoResult<(u64, &[u8])> {
    if data.is_empty() {
        return Err(IoError::ParseError(
            "varint: empty input".to_string(),
        ));
    }
    let mut result: u64 = 0;
    let mut shift = 0u32;
    for (i, &byte) in data.iter().enumerate() {
        if shift >= 64 {
            return Err(IoError::ParseError(
                "varint: overflow (more than 10 bytes)".to_string(),
            ));
        }
        let payload = u64::from(byte & 0x7f);
        // `u64::MAX >> shift` is the largest payload that still fits when moved
        // up by `shift` bits. This only bites on the tenth byte (shift == 63),
        // where a single payload bit is left.
        if payload > (u64::MAX >> shift) {
            return Err(IoError::ParseError(
                "varint: overflow (value does not fit in 64 bits)".to_string(),
            ));
        }
        result |= payload << shift;
        shift += 7;
        if byte & 0x80 == 0 {
            // Stop byte reached: hand back everything after it.
            return Ok((result, &data[i + 1..]));
        }
    }
    Err(IoError::ParseError(
        "varint: unterminated (missing stop byte)".to_string(),
    ))
}
/// ZigZag-maps a signed value onto an unsigned one and varint-encodes the result.
///
/// The mapping interleaves negatives and non-negatives (0, -1, 1, -2, ... become
/// 0, 1, 2, 3, ...), so values of small magnitude stay short on the wire.
pub fn encode_zigzag_i64(value: i64) -> Vec<u8> {
    // Arithmetic shift: all ones for negative input, all zeros otherwise.
    let sign_mask = value >> 63;
    let doubled = value << 1;
    encode_varint((doubled ^ sign_mask) as u64)
}
/// Decodes a varint from the front of `data` and undoes the ZigZag mapping.
///
/// Returns the signed value and the bytes following the varint. Any error from
/// `decode_varint` is passed through unchanged.
pub fn decode_zigzag_i64(data: &[u8]) -> ProtoResult<(i64, &[u8])> {
    decode_varint(data).map(|(zz, rest)| {
        let magnitude = (zz >> 1) as i64;
        // 0 for even (non-negative) inputs, -1 (all ones) for odd (negative) inputs.
        let sign_mask = -((zz & 1) as i64);
        (magnitude ^ sign_mask, rest)
    })
}
/// Encodes a single field: the key varint followed by the value bytes.
///
/// `tag` is the field number. The key is `(tag << 3) | wire_type`, where the
/// wire type is derived from the variant of `field`. The value is written as
/// - a varint for `Varint`,
/// - a varint byte count plus the raw bytes for `LengthDelimited`,
/// - the stored bytes verbatim for `Fixed64` / `Fixed32`.
pub fn encode_field(tag: u32, field: ProtobufField) -> Vec<u8> {
    // Build the key in 64 bits. Shifting in `u32` silently drops the top three
    // bits of `tag`, which would encode a different field number for tags >= 2^29.
    let key = (u64::from(tag) << 3) | u64::from(field.wire_type());
    let mut buf = encode_varint(key);
    match field {
        ProtobufField::Varint(v) => buf.extend(encode_varint(v)),
        ProtobufField::LengthDelimited(bytes) => {
            buf.extend(encode_varint(bytes.len() as u64));
            buf.extend_from_slice(&bytes);
        }
        ProtobufField::Fixed64(b) => buf.extend_from_slice(&b),
        ProtobufField::Fixed32(b) => buf.extend_from_slice(&b),
    }
    buf
}
/// Encodes every `(tag, field)` pair in order and concatenates the results.
///
/// An empty slice produces an empty buffer.
pub fn encode_fields(fields: &[(u32, ProtobufField)]) -> Vec<u8> {
    fields
        .iter()
        // `encode_field` consumes its field, so each one is cloned out of the slice.
        .flat_map(|(tag, field)| encode_field(*tag, field.clone()))
        .collect()
}
pub fn decode_fields(data: &[u8]) -> ProtoResult<Vec<(u32, ProtobufField)>> {
let mut out = Vec::new();
let mut pos = data;
while !pos.is_empty() {
let (key, rest) = decode_varint(pos)?;
let wire_type = (key & 0x07) as u8;
let field_number = (key >> 3) as u32;
let (field, remaining) = decode_one_field(wire_type, rest)?;
out.push((field_number, field));
pos = remaining;
}
Ok(out)
}
/// Decodes the value part of one field whose key has already been consumed.
///
/// `wire_type` selects the layout; `data` starts at the first value byte.
/// Returns the decoded field and the bytes that follow it.
///
/// # Errors
///
/// Returns `IoError::ParseError` for an unsupported wire type, for a value
/// that needs more bytes than `data` provides, or when the embedded varint
/// cannot be decoded.
fn decode_one_field<'a>(
    wire_type: u8,
    data: &'a [u8],
) -> ProtoResult<(ProtobufField, &'a [u8])> {
    match wire_type {
        WIRE_VARINT => {
            let (v, rest) = decode_varint(data)?;
            Ok((ProtobufField::Varint(v), rest))
        }
        WIRE_FIXED64 => {
            if data.len() < 8 {
                return Err(IoError::ParseError(
                    "fixed64: insufficient bytes".to_string(),
                ));
            }
            let (head, rest) = data.split_at(8);
            let mut b = [0u8; 8];
            b.copy_from_slice(head);
            Ok((ProtobufField::Fixed64(b), rest))
        }
        WIRE_LEN_DELIM => {
            let (declared, rest) = decode_varint(data)?;
            // Compare in u64 before narrowing. Casting first would truncate the
            // declared length on targets where `usize` is narrower than 64 bits,
            // letting an oversized length slip past the bounds check.
            if declared > rest.len() as u64 {
                return Err(IoError::ParseError(format!(
                    "length-delimited: need {declared} bytes but only {} available",
                    rest.len()
                )));
            }
            // Lossless: `declared` is bounded by `rest.len()`.
            let (payload, remaining) = rest.split_at(declared as usize);
            Ok((ProtobufField::LengthDelimited(payload.to_vec()), remaining))
        }
        WIRE_FIXED32 => {
            if data.len() < 4 {
                return Err(IoError::ParseError(
                    "fixed32: insufficient bytes".to_string(),
                ));
            }
            let (head, rest) = data.split_at(4);
            let mut b = [0u8; 4];
            b.copy_from_slice(head);
            Ok((ProtobufField::Fixed32(b), rest))
        }
        _ => Err(IoError::ParseError(format!(
            "unknown wire type: {wire_type}"
        ))),
    }
}
/// Chained builder that appends encoded fields to a single byte buffer.
///
/// Each field method consumes the builder, appends one encoded field, and
/// returns the builder; `build` hands back the accumulated bytes.
#[derive(Debug, Default)]
pub struct MessageBuilder {
// Encoded bytes of all fields appended so far, in call order.
buf: Vec<u8>,
}
impl MessageBuilder {
pub fn new() -> Self {
Self { buf: Vec::new() }
}
pub fn varint(mut self, tag: u32, value: u64) -> Self {
self.buf
.extend(encode_field(tag, ProtobufField::Varint(value)));
self
}
pub fn sint64(mut self, tag: u32, value: i64) -> Self {
let zz = ((value << 1) ^ (value >> 63)) as u64;
self.buf
.extend(encode_field(tag, ProtobufField::Varint(zz)));
self
}
pub fn bytes(mut self, tag: u32, value: Vec<u8>) -> Self {
self.buf
.extend(encode_field(tag, ProtobufField::LengthDelimited(value)));
self
}
pub fn string(mut self, tag: u32, value: &str) -> Self {
self.buf
.extend(encode_field(tag, ProtobufField::from_str(value)));
self
}
pub fn f64(mut self, tag: u32, value: f64) -> Self {
self.buf
.extend(encode_field(tag, ProtobufField::from_f64(value)));
self
}
pub fn f32(mut self, tag: u32, value: f32) -> Self {
self.buf
.extend(encode_field(tag, ProtobufField::from_f32(value)));
self
}
pub fn message(mut self, tag: u32, encoded: Vec<u8>) -> Self {
self.buf
.extend(encode_field(tag, ProtobufField::from_message(encoded)));
self
}
pub fn build(self) -> Vec<u8> {
self.buf
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every value below 128 encodes to one byte and decodes back unchanged.
    #[test]
    fn test_varint_single_byte_values() {
        (0u64..128).for_each(|value| {
            let bytes = encode_varint(value);
            assert_eq!(bytes.len(), 1);
            let (decoded, tail) = decode_varint(&bytes).expect("decode failed");
            assert_eq!(decoded, value);
            assert!(tail.is_empty());
        });
    }

    /// Known multi-byte encodings, checked in both directions.
    #[test]
    fn test_varint_multi_byte() {
        let table: &[(u64, &[u8])] = &[
            (128, &[0x80, 0x01]),
            (300, &[0xac, 0x02]),
            (16_383, &[0xff, 0x7f]),
            (16_384, &[0x80, 0x80, 0x01]),
        ];
        for &(v, wire) in table {
            let bytes = encode_varint(v);
            assert_eq!(bytes.as_slice(), wire, "encode mismatch for {v}");
            let (decoded, tail) = decode_varint(&bytes).expect("decode failed");
            assert_eq!(decoded, v);
            assert!(tail.is_empty());
        }
    }

    /// The largest u64 takes the full ten bytes.
    #[test]
    fn test_varint_max_u64() {
        let bytes = encode_varint(u64::MAX);
        assert_eq!(bytes.len(), 10);
        let (decoded, tail) = decode_varint(&bytes).expect("decode max");
        assert_eq!(decoded, u64::MAX);
        assert!(tail.is_empty());
    }

    /// Bytes after the varint are returned untouched.
    #[test]
    fn test_decode_varint_with_trailing_bytes() {
        let input = [0xac, 0x02, 0xde, 0xad];
        let (value, tail) = decode_varint(&input).expect("decode");
        assert_eq!(value, 300);
        assert_eq!(tail, &[0xde, 0xad]);
    }

    #[test]
    fn test_decode_varint_empty_is_error() {
        let outcome = decode_varint(&[]);
        assert!(outcome.is_err());
    }

    /// Eleven continuation bytes never terminate and must be rejected.
    #[test]
    fn test_decode_varint_unterminated_is_error() {
        let never_ends = vec![0x80u8; 11];
        assert!(decode_varint(&never_ends).is_err());
    }

    #[test]
    fn test_zigzag_roundtrip() {
        let samples: [i64; 7] = [0, -1, 1, -2147483648, 2147483647, i64::MIN, i64::MAX];
        for &v in samples.iter() {
            let bytes = encode_zigzag_i64(v);
            let (decoded, tail) = decode_zigzag_i64(&bytes).expect("zigzag decode");
            assert_eq!(decoded, v, "zigzag roundtrip for {v}");
            assert!(tail.is_empty());
        }
    }

    /// Field 1, varint 150 encodes to the bytes 08 96 01.
    #[test]
    fn test_encode_field_varint_example_from_spec() {
        let wire = encode_field(1, ProtobufField::Varint(150));
        assert_eq!(wire, vec![0x08, 0x96, 0x01]);
    }

    #[test]
    fn test_encode_decode_string_field() {
        let wire = encode_field(2, ProtobufField::from_str("testing"));
        let decoded = decode_fields(&wire).expect("decode");
        assert_eq!(decoded.len(), 1);
        let (number, value) = &decoded[0];
        assert_eq!(*number, 2);
        assert_eq!(value.as_str(), Some("testing"));
    }

    #[test]
    fn test_encode_decode_f64_field() {
        let original = std::f64::consts::PI;
        let wire = encode_field(3, ProtobufField::from_f64(original));
        let decoded = decode_fields(&wire).expect("decode");
        let roundtripped = decoded[0].1.as_f64().expect("as_f64");
        assert!((roundtripped - original).abs() < 1e-15);
    }

    #[test]
    fn test_encode_decode_f32_field() {
        let original = std::f32::consts::E;
        let wire = encode_field(4, ProtobufField::from_f32(original));
        let decoded = decode_fields(&wire).expect("decode");
        let roundtripped = decoded[0].1.as_f32().expect("as_f32");
        assert!((roundtripped - original).abs() < 1e-6);
    }

    /// A message with four different field kinds decodes in wire order.
    #[test]
    fn test_multiple_fields_roundtrip() {
        let raw = vec![0xde, 0xad, 0xbe, 0xef];
        let wire = MessageBuilder::new()
            .varint(1, 42)
            .string(2, "hello world")
            .f64(3, 2.718281828)
            .bytes(4, raw.clone())
            .build();
        let decoded = decode_fields(&wire).expect("decode message");
        assert_eq!(decoded.len(), 4);

        assert_eq!(decoded[0], (1, ProtobufField::Varint(42)));

        assert_eq!(decoded[1].0, 2);
        assert_eq!(decoded[1].1.as_str(), Some("hello world"));

        assert_eq!(decoded[2].0, 3);
        let float = decoded[2].1.as_f64().unwrap();
        assert!((float - 2.718281828).abs() < 1e-9);

        assert_eq!(decoded[3].1, ProtobufField::LengthDelimited(raw));
    }

    /// A nested message travels as a length-delimited payload and decodes on its own.
    #[test]
    fn test_embedded_message() {
        let inner = MessageBuilder::new()
            .varint(1, 100)
            .string(2, "inner")
            .build();
        let outer = MessageBuilder::new()
            .varint(1, 999)
            .message(2, inner.clone())
            .build();

        let outer_fields = decode_fields(&outer).expect("decode outer");
        assert_eq!(outer_fields.len(), 2);

        let inner_bytes = match &outer_fields[1].1 {
            ProtobufField::LengthDelimited(payload) => payload.clone(),
            _ => panic!("expected LengthDelimited for embedded message"),
        };
        assert_eq!(inner_bytes, inner);

        let inner_fields = decode_fields(&inner_bytes).expect("decode inner");
        assert_eq!(inner_fields.len(), 2);
        assert_eq!(inner_fields[0], (1, ProtobufField::Varint(100)));
    }

    /// Wire type 3 is not handled by the decoder.
    #[test]
    fn test_unknown_wire_type_returns_error() {
        let bad_key = encode_varint((1u64 << 3) | 3);
        let outcome = decode_fields(&bad_key);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_sint64_roundtrip() {
        let samples: [i64; 7] = [-1000, -1, 0, 1, 1000, i64::MIN / 2, i64::MAX / 2];
        for &v in samples.iter() {
            let wire = MessageBuilder::new().sint64(1, v).build();
            let decoded = decode_fields(&wire).expect("decode sint64");
            match decoded[0].1 {
                ProtobufField::Varint(zz) => {
                    let unzigged = ((zz >> 1) as i64) ^ -((zz & 1) as i64);
                    assert_eq!(unzigged, v, "sint64 roundtrip for {v}");
                }
                _ => panic!("expected Varint"),
            }
        }
    }

    #[test]
    fn test_empty_string_field() {
        let wire = encode_field(5, ProtobufField::from_str(""));
        let decoded = decode_fields(&wire).expect("decode empty str");
        assert_eq!(decoded[0].1.as_str(), Some(""));
    }

    #[test]
    fn test_decode_fields_empty_input() {
        let decoded = decode_fields(&[]).expect("decode empty");
        assert!(decoded.is_empty());
    }
}