use super::types::{
FieldDescriptor, FieldType, FieldValue, MessageDescriptor, SchemaRegistryError,
SchemaRegistryResult,
};
/// Wire encodings a field payload can use on the wire.
///
/// The numeric id written into a field tag is given by [`WireType::id`]; ids other than
/// 0, 1, 2 and 5 have no variant and are rejected by [`WireType::from_id`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum WireType {
    /// Base-128 varint payload (id 0).
    Varint,
    /// Eight-byte fixed-width payload (id 1).
    Fixed64,
    /// Varint length prefix followed by that many bytes (id 2).
    LengthDelimited,
    /// Four-byte fixed-width payload (id 5).
    Fixed32,
}

impl WireType {
    /// Every variant, used by `from_id` for the reverse lookup of `id`.
    const ALL: [WireType; 4] = [
        WireType::Varint,
        WireType::Fixed64,
        WireType::LengthDelimited,
        WireType::Fixed32,
    ];

    /// Returns the 3-bit wire type id stored in the low bits of a field tag.
    pub fn id(self) -> u32 {
        match self {
            Self::Varint => 0,
            Self::Fixed64 => 1,
            Self::LengthDelimited => 2,
            Self::Fixed32 => 5,
        }
    }

    /// Maps a wire type id back to its variant.
    ///
    /// # Errors
    /// Returns `SchemaRegistryError::WireFormat` when `id` does not correspond to any variant.
    pub fn from_id(id: u32) -> SchemaRegistryResult<Self> {
        Self::ALL
            .iter()
            .copied()
            .find(|candidate| candidate.id() == id)
            .ok_or_else(|| {
                SchemaRegistryError::WireFormat(format!("unknown wire type id: {id}"))
            })
    }

    /// Chooses the wire type used to encode a value of schema type `ft`.
    ///
    /// Integer and boolean scalars are varints, `Float`/`Double` are fixed 32/64-bit, and
    /// strings, bytes, nested messages and repeated fields are all length-delimited.
    pub fn for_field_type(ft: &FieldType) -> Self {
        match ft {
            FieldType::Float => Self::Fixed32,
            FieldType::Double => Self::Fixed64,
            FieldType::Int32
            | FieldType::Int64
            | FieldType::UInt32
            | FieldType::UInt64
            | FieldType::Bool => Self::Varint,
            FieldType::String
            | FieldType::Bytes
            | FieldType::Message(_)
            | FieldType::Repeated(_) => Self::LengthDelimited,
        }
    }
}
/// A raw field payload as read off the wire, before it is interpreted against a schema.
///
/// Fixed-width variants hold the payload bytes exactly as they appeared in the input.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum WireValue {
    /// The decoded varint value, not yet narrowed or sign-interpreted.
    Varint(u64),
    /// The eight payload bytes of a 64-bit fixed-width field.
    Fixed64([u8; 8]),
    /// The payload bytes following a length prefix (the prefix itself is not included).
    LengthDelimited(Vec<u8>),
    /// The four payload bytes of a 32-bit fixed-width field.
    Fixed32([u8; 4]),
}
/// Appends `value` to `buf` as a base-128 varint.
///
/// Seven bits are emitted per byte, least-significant group first. Every byte except the
/// last has its high bit set to signal that another byte follows. A `u64` takes 1 to 10
/// bytes.
pub fn encode_varint(value: u64, buf: &mut Vec<u8>) {
    let mut rest = value;
    // While more than seven bits remain, emit the low group with the continuation bit set.
    while rest >= 0x80 {
        buf.push(((rest as u8) & 0x7f) | 0x80);
        rest >>= 7;
    }
    // The final group fits in seven bits, so its high bit is already clear.
    buf.push(rest as u8);
}
/// Decodes a base-128 varint from `buf` starting at `*pos` and advances `*pos` past it.
///
/// Returns `None` when the varint is malformed:
/// - the buffer ends before a byte with a clear continuation bit is reached;
/// - the encoding is longer than ten bytes (the most a `u64` ever needs);
/// - the tenth byte carries payload bits that would land above bit 63.
///
/// On `None`, `*pos` may already have been advanced past the bytes that were examined.
pub fn decode_varint(buf: &[u8], pos: &mut usize) -> Option<u64> {
    let mut result: u64 = 0;
    // Ten seven-bit groups, at shifts 0, 7, ..., 63.
    for shift in (0..64u32).step_by(7) {
        let byte = *buf.get(*pos)?;
        *pos += 1;
        let group = u64::from(byte & 0x7f);
        // At shift 63 only one bit of the group fits in a u64. Shifting a larger group
        // would silently discard its high bits and yield a wrong value, so reject it.
        if shift == 63 && group > 1 {
            return None;
        }
        result |= group << shift;
        if byte & 0x80 == 0 {
            return Some(result);
        }
    }
    // Ten bytes consumed and the continuation bit was still set.
    None
}
/// Writes the field key (tag) for `field_number` and `wire_type` to `buf`.
///
/// The key is `(field_number << 3) | wire_type.id()`, emitted as a varint. The shift is
/// done in `u64`, so no bits of a `u32` field number are lost.
pub fn encode_field(field_number: u32, wire_type: WireType, buf: &mut Vec<u8>) {
    let key = (u64::from(field_number) << 3) | u64::from(wire_type.id());
    encode_varint(key, buf);
}
#[derive(Debug, Default)]
pub struct ProtoEncoder {
buf: Vec<u8>,
}
impl ProtoEncoder {
pub fn new() -> Self {
Self { buf: Vec::new() }
}
pub fn int32(mut self, field_number: u32, value: i32) -> Self {
encode_field(field_number, WireType::Varint, &mut self.buf);
encode_varint(value as u64, &mut self.buf);
self
}
pub fn int64(mut self, field_number: u32, value: i64) -> Self {
encode_field(field_number, WireType::Varint, &mut self.buf);
encode_varint(value as u64, &mut self.buf);
self
}
pub fn uint32(mut self, field_number: u32, value: u32) -> Self {
encode_field(field_number, WireType::Varint, &mut self.buf);
encode_varint(value as u64, &mut self.buf);
self
}
pub fn uint64(mut self, field_number: u32, value: u64) -> Self {
encode_field(field_number, WireType::Varint, &mut self.buf);
encode_varint(value, &mut self.buf);
self
}
pub fn bool(mut self, field_number: u32, value: bool) -> Self {
encode_field(field_number, WireType::Varint, &mut self.buf);
encode_varint(value as u64, &mut self.buf);
self
}
pub fn float(mut self, field_number: u32, value: f32) -> Self {
encode_field(field_number, WireType::Fixed32, &mut self.buf);
self.buf.extend_from_slice(&value.to_le_bytes());
self
}
pub fn double(mut self, field_number: u32, value: f64) -> Self {
encode_field(field_number, WireType::Fixed64, &mut self.buf);
self.buf.extend_from_slice(&value.to_le_bytes());
self
}
pub fn string(mut self, field_number: u32, value: &str) -> Self {
self.write_length_delimited(field_number, value.as_bytes());
self
}
pub fn bytes(mut self, field_number: u32, value: &[u8]) -> Self {
self.write_length_delimited(field_number, value);
self
}
pub fn message(mut self, field_number: u32, nested_bytes: &[u8]) -> Self {
self.write_length_delimited(field_number, nested_bytes);
self
}
pub fn build(self) -> Vec<u8> {
self.buf
}
fn write_length_delimited(&mut self, field_number: u32, data: &[u8]) {
encode_field(field_number, WireType::LengthDelimited, &mut self.buf);
encode_varint(data.len() as u64, &mut self.buf);
self.buf.extend_from_slice(data);
}
}
/// Reader that splits a protobuf-encoded buffer into `(field_number, WireValue)` pairs.
///
/// The input is borrowed. Length-delimited payloads are copied into owned `Vec`s.
pub struct ProtoDecoder<'a> {
    buf: &'a [u8],
    /// Offset of the next unread byte in `buf`.
    pos: usize,
}

impl<'a> ProtoDecoder<'a> {
    /// Creates a decoder positioned at the start of `buf`.
    pub fn new(buf: &'a [u8]) -> Self {
        Self { buf, pos: 0 }
    }

    /// Returns `true` once every input byte has been consumed.
    pub fn is_empty(&self) -> bool {
        self.pos >= self.buf.len()
    }

    /// Decodes the next field.
    ///
    /// Returns:
    /// - `None` when the input is exhausted;
    /// - `Some(Ok(..))` for a well-formed field;
    /// - `Some(Err(..))` when the bytes at the current position cannot be decoded. This
    ///   includes a truncated or malformed tag, an unknown wire type, a field number that
    ///   does not fit in `u32`, and a truncated payload.
    pub fn next_field(&mut self) -> Option<SchemaRegistryResult<(u32, WireValue)>> {
        if self.is_empty() {
            return None;
        }
        Some(self.read_field())
    }

    /// Decodes every remaining field, stopping at and returning the first error.
    pub fn collect_all(&mut self) -> SchemaRegistryResult<Vec<(u32, WireValue)>> {
        std::iter::from_fn(|| self.next_field()).collect()
    }

    /// Reads one tag and its payload. The caller guarantees at least one byte remains.
    fn read_field(&mut self) -> SchemaRegistryResult<(u32, WireValue)> {
        // A tag that cannot be decoded is corruption, not end-of-input: surface it as an
        // error instead of ending iteration.
        let tag = decode_varint(self.buf, &mut self.pos).ok_or_else(|| {
            SchemaRegistryError::WireFormat("truncated or malformed field tag".to_string())
        })?;
        // The low three bits select the wire type; the remaining bits are the field number.
        let wire_type = WireType::from_id((tag & 0x07) as u32)?;
        let raw_field_number = tag >> 3;
        // Reject rather than truncate: `as u32` would alias an oversized number onto a
        // small, possibly valid one.
        if raw_field_number > u64::from(u32::MAX) {
            return Err(SchemaRegistryError::WireFormat(format!(
                "field number {raw_field_number} does not fit in u32"
            )));
        }
        let value = self.decode_payload(wire_type)?;
        Ok((raw_field_number as u32, value))
    }

    /// Number of unread bytes.
    fn remaining(&self) -> usize {
        self.buf.len().saturating_sub(self.pos)
    }

    /// Consumes exactly `N` bytes into an array, or returns `None` without advancing if fewer
    /// remain.
    fn take_array<const N: usize>(&mut self) -> Option<[u8; N]> {
        if self.remaining() < N {
            return None;
        }
        let mut out = [0u8; N];
        // Cannot overflow: remaining() >= N implies pos + N <= buf.len().
        out.copy_from_slice(&self.buf[self.pos..self.pos + N]);
        self.pos += N;
        Some(out)
    }

    /// Reads the payload for `wire_type` at the current position.
    fn decode_payload(&mut self, wire_type: WireType) -> SchemaRegistryResult<WireValue> {
        match wire_type {
            WireType::Varint => decode_varint(self.buf, &mut self.pos)
                .map(WireValue::Varint)
                .ok_or_else(|| SchemaRegistryError::WireFormat("truncated varint".to_string())),
            WireType::Fixed64 => self.take_array::<8>().map(WireValue::Fixed64).ok_or_else(|| {
                SchemaRegistryError::WireFormat("truncated fixed64 field".to_string())
            }),
            WireType::LengthDelimited => {
                let len = decode_varint(self.buf, &mut self.pos).ok_or_else(|| {
                    SchemaRegistryError::WireFormat(
                        "truncated length prefix in length-delimited field".to_string(),
                    )
                })?;
                let remaining = self.remaining();
                // `len` is untrusted. Compare in u64 so `pos + len` cannot overflow and
                // `len as usize` cannot truncate on 32-bit targets.
                if len > remaining as u64 {
                    return Err(SchemaRegistryError::WireFormat(format!(
                        "length-delimited field claims {len} bytes but only {remaining} remain"
                    )));
                }
                let start = self.pos;
                // Lossless and in bounds: len <= remaining.
                let end = start + len as usize;
                self.pos = end;
                Ok(WireValue::LengthDelimited(self.buf[start..end].to_vec()))
            }
            WireType::Fixed32 => self.take_array::<4>().map(WireValue::Fixed32).ok_or_else(|| {
                SchemaRegistryError::WireFormat("truncated fixed32 field".to_string())
            }),
        }
    }
}
/// Encodes `values` as a protobuf message according to `schema`.
///
/// Values are written in slice order. A value whose field number is not declared in
/// `schema` is skipped. A value whose variant does not match the declared field type is
/// also skipped, because `encode_field_value` writes nothing for unmatched combinations.
pub fn encode_message(schema: &MessageDescriptor, values: &[(u32, FieldValue)]) -> Vec<u8> {
    let mut out = Vec::new();
    for (number, value) in values.iter() {
        if let Some(desc) = schema.field_by_number(*number) {
            encode_field_value(*number, &desc.field_type, value, &mut out);
        }
    }
    out
}
/// Decodes `bytes` as a message described by `schema`.
///
/// Returns `(field_name, value)` pairs in wire order. Fields whose number is not declared in
/// `schema` are skipped.
///
/// # Errors
/// Fails if the bytes are not valid wire format, or if a known field's wire type or payload
/// does not fit its declared type (see `wire_value_to_field_value`).
pub fn decode_message(
    schema: &MessageDescriptor,
    bytes: &[u8],
) -> SchemaRegistryResult<Vec<(std::string::String, FieldValue)>> {
    let raw_fields = ProtoDecoder::new(bytes).collect_all()?;
    raw_fields
        .into_iter()
        // Unknown field numbers are dropped rather than treated as errors.
        .filter_map(move |(number, wire)| schema.field_by_number(number).map(|desc| (desc, wire)))
        .map(|(desc, wire)| {
            wire_value_to_field_value(&desc.field_type, wire).map(|value| (desc.name.clone(), value))
        })
        .collect()
}
fn encode_field_value(
field_number: u32,
field_type: &FieldType,
value: &FieldValue,
buf: &mut Vec<u8>,
) {
match (field_type, value) {
(FieldType::Int32, FieldValue::Int32(v)) => {
encode_field(field_number, WireType::Varint, buf);
encode_varint(*v as u64, buf);
}
(FieldType::Int64, FieldValue::Int64(v)) => {
encode_field(field_number, WireType::Varint, buf);
encode_varint(*v as u64, buf);
}
(FieldType::Int64, FieldValue::Int32(v)) => {
encode_field(field_number, WireType::Varint, buf);
encode_varint(*v as u64, buf);
}
(FieldType::UInt32, FieldValue::UInt32(v)) => {
encode_field(field_number, WireType::Varint, buf);
encode_varint(*v as u64, buf);
}
(FieldType::UInt64, FieldValue::UInt64(v)) => {
encode_field(field_number, WireType::Varint, buf);
encode_varint(*v, buf);
}
(FieldType::UInt64, FieldValue::UInt32(v)) => {
encode_field(field_number, WireType::Varint, buf);
encode_varint(*v as u64, buf);
}
(FieldType::Bool, FieldValue::Bool(v)) => {
encode_field(field_number, WireType::Varint, buf);
encode_varint(*v as u64, buf);
}
(FieldType::Float, FieldValue::Float(v)) => {
encode_field(field_number, WireType::Fixed32, buf);
buf.extend_from_slice(&v.to_le_bytes());
}
(FieldType::Double, FieldValue::Double(v)) => {
encode_field(field_number, WireType::Fixed64, buf);
buf.extend_from_slice(&v.to_le_bytes());
}
(FieldType::Double, FieldValue::Float(v)) => {
encode_field(field_number, WireType::Fixed64, buf);
buf.extend_from_slice(&(*v as f64).to_le_bytes());
}
(FieldType::String, FieldValue::Str(s)) => {
let data = s.as_bytes();
encode_field(field_number, WireType::LengthDelimited, buf);
encode_varint(data.len() as u64, buf);
buf.extend_from_slice(data);
}
(FieldType::Bytes, FieldValue::Bytes(data)) => {
encode_field(field_number, WireType::LengthDelimited, buf);
encode_varint(data.len() as u64, buf);
buf.extend_from_slice(data);
}
(FieldType::Message(_), FieldValue::Message(data)) => {
encode_field(field_number, WireType::LengthDelimited, buf);
encode_varint(data.len() as u64, buf);
buf.extend_from_slice(data);
}
(FieldType::Repeated(_), FieldValue::Bytes(data)) => {
encode_field(field_number, WireType::LengthDelimited, buf);
encode_varint(data.len() as u64, buf);
buf.extend_from_slice(data);
}
_ => {
}
}
}
/// Interprets a raw wire payload as a typed value for schema type `ft`.
///
/// Varints are narrowed with `as` casts, which keep the low-order bits. Fixed-width floats
/// are read as little-endian. `Repeated` payloads are returned unparsed as `Bytes`.
///
/// # Errors
/// Returns `SchemaRegistryError::WireFormat` if a `String` payload is not valid UTF-8, or if
/// the wire variant does not match the one expected for `ft`.
fn wire_value_to_field_value(ft: &FieldType, wv: WireValue) -> SchemaRegistryResult<FieldValue> {
    let decoded = match (ft, wv) {
        (FieldType::Int32, WireValue::Varint(raw)) => FieldValue::Int32(raw as i32),
        (FieldType::Int64, WireValue::Varint(raw)) => FieldValue::Int64(raw as i64),
        (FieldType::UInt32, WireValue::Varint(raw)) => FieldValue::UInt32(raw as u32),
        (FieldType::UInt64, WireValue::Varint(raw)) => FieldValue::UInt64(raw),
        (FieldType::Bool, WireValue::Varint(raw)) => FieldValue::Bool(raw != 0),
        (FieldType::Float, WireValue::Fixed32(le)) => FieldValue::Float(f32::from_le_bytes(le)),
        (FieldType::Double, WireValue::Fixed64(le)) => FieldValue::Double(f64::from_le_bytes(le)),
        (FieldType::String, WireValue::LengthDelimited(bytes)) => {
            let text = std::string::String::from_utf8(bytes).map_err(|err| {
                SchemaRegistryError::WireFormat(format!("invalid UTF-8 in string field: {err}"))
            })?;
            FieldValue::Str(text)
        }
        (FieldType::Bytes, WireValue::LengthDelimited(bytes)) => FieldValue::Bytes(bytes),
        (FieldType::Message(_), WireValue::LengthDelimited(bytes)) => FieldValue::Message(bytes),
        (FieldType::Repeated(_), WireValue::LengthDelimited(bytes)) => FieldValue::Bytes(bytes),
        (other_type, other_wire) => {
            return Err(SchemaRegistryError::WireFormat(format!(
                "wire type mismatch for field type {}: got {:?}",
                other_type.proto_name(),
                other_wire
            )));
        }
    };
    Ok(decoded)
}
#[cfg(test)]
// Unit tests for the varint primitives, the builder-style encoder, the streaming decoder,
// and schema-driven encode_message / decode_message round trips.
mod tests {
use super::*;
// Zero encodes to a single 0x00 byte; decoding it consumes exactly one byte.
#[test]
fn test_varint_zero() {
let mut buf = Vec::new();
encode_varint(0, &mut buf);
assert_eq!(buf, [0x00]);
let mut pos = 0;
assert_eq!(decode_varint(&buf, &mut pos), Some(0));
assert_eq!(pos, 1);
}
// 127 is the largest value that fits in one varint byte.
#[test]
fn test_varint_one_byte_boundary() {
let mut buf = Vec::new();
encode_varint(127, &mut buf);
assert_eq!(buf, [0x7f]);
}
// 128 is the smallest value that needs a second byte (continuation bit on the first).
#[test]
fn test_varint_two_byte_boundary() {
let mut buf = Vec::new();
encode_varint(128, &mut buf);
assert_eq!(buf, [0x80, 0x01]);
}
// 300 encodes as [0xac, 0x02] and decodes back.
#[test]
fn test_varint_300() {
let mut buf = Vec::new();
encode_varint(300, &mut buf);
assert_eq!(buf, [0xac, 0x02]);
let mut pos = 0;
assert_eq!(decode_varint(&buf, &mut pos), Some(300));
}
// u64::MAX needs the maximum of ten varint bytes and must round-trip.
#[test]
fn test_varint_u64_max() {
let mut buf = Vec::new();
encode_varint(u64::MAX, &mut buf);
assert_eq!(buf.len(), 10);
let mut pos = 0;
assert_eq!(decode_varint(&buf, &mut pos), Some(u64::MAX));
}
// Encode/decode round trip across byte-count boundaries.
#[test]
fn test_varint_roundtrip_sequence() {
let values: &[u64] = &[0, 1, 127, 128, 255, 1024, 65535, 1 << 32, u64::MAX];
for &v in values {
let mut buf = Vec::new();
encode_varint(v, &mut buf);
let mut pos = 0;
assert_eq!(decode_varint(&buf, &mut pos), Some(v), "value={v}");
}
}
// A lone byte with the continuation bit set has no terminator, so decoding yields None.
#[test]
fn test_decode_varint_truncated_returns_none() {
let buf = [0x80]; let mut pos = 0;
assert_eq!(decode_varint(&buf, &mut pos), None);
}
// Tag for field 1, wire type 0: (1 << 3) | 0 = 0x08.
#[test]
fn test_wire_type_tag_field_1_varint() {
let mut buf = Vec::new();
encode_field(1, WireType::Varint, &mut buf);
assert_eq!(buf, [0x08]);
}
// Tag for field 2, wire type 2: (2 << 3) | 2 = 0x12.
#[test]
fn test_wire_type_tag_field_2_len_delim() {
let mut buf = Vec::new();
encode_field(2, WireType::LengthDelimited, &mut buf);
assert_eq!(buf, [0x12]);
}
// Field 1 = 150 as a varint: tag 0x08, then 150 encoded as [0x96, 0x01].
#[test]
fn test_proto_encoder_int32() {
let bytes = ProtoEncoder::new().int32(1, 150).build();
assert_eq!(bytes, [0x08, 0x96, 0x01]);
}
// Length-delimited string: tag 0x0a, length 7, then the raw UTF-8 bytes.
#[test]
fn test_proto_encoder_string() {
let bytes = ProtoEncoder::new().string(1, "testing").build();
assert_eq!(bytes[0], 0x0a);
assert_eq!(bytes[1], 7);
assert_eq!(&bytes[2..], b"testing");
}
// Booleans are varints: true -> 1.
#[test]
fn test_proto_encoder_bool_true() {
let bytes = ProtoEncoder::new().bool(1, true).build();
assert_eq!(bytes, [0x08, 0x01]);
}
// Booleans are varints: false -> 0 (the field is still written).
#[test]
fn test_proto_encoder_bool_false() {
let bytes = ProtoEncoder::new().bool(1, false).build();
assert_eq!(bytes, [0x08, 0x00]);
}
// f32 uses wire type 5 (tag 0x0d) followed by four little-endian bytes.
#[test]
fn test_proto_encoder_float() {
let v = 1.0_f32;
let bytes = ProtoEncoder::new().float(1, v).build();
assert_eq!(bytes[0], 0x0d);
let decoded = f32::from_le_bytes(bytes[1..5].try_into().expect("slice"));
assert!((decoded - 1.0).abs() < 1e-6);
}
// f64 uses wire type 1 (tag 0x09) followed by eight little-endian bytes.
#[test]
fn test_proto_encoder_double() {
let v = std::f64::consts::PI;
let bytes = ProtoEncoder::new().double(1, v).build();
assert_eq!(bytes[0], 0x09);
let decoded = f64::from_le_bytes(bytes[1..9].try_into().expect("slice"));
assert!((decoded - std::f64::consts::PI).abs() < 1e-12);
}
// Raw bytes are length-delimited and copied verbatim.
#[test]
fn test_proto_encoder_bytes() {
let data = b"\xde\xad\xbe\xef";
let bytes = ProtoEncoder::new().bytes(1, data).build();
assert_eq!(bytes[0], 0x0a); assert_eq!(bytes[1], 4); assert_eq!(&bytes[2..], data);
}
// A nested message is embedded as an opaque length-delimited payload and comes back unchanged.
#[test]
fn test_proto_encoder_message_nested() {
let inner = ProtoEncoder::new().int32(1, 42).build();
let outer = ProtoEncoder::new().message(1, &inner).build();
let mut dec = ProtoDecoder::new(&outer);
let (fn_, wv) = dec.next_field().expect("field").expect("ok");
assert_eq!(fn_, 1);
if let WireValue::LengthDelimited(payload) = wv {
assert_eq!(payload, inner);
} else {
panic!("expected LengthDelimited");
}
}
// Decoding a single varint field yields its number and raw value and consumes all input.
#[test]
fn test_proto_decoder_varint_field() {
let bytes = ProtoEncoder::new().int64(3, 9999).build();
let mut dec = ProtoDecoder::new(&bytes);
let (fn_, wv) = dec.next_field().expect("field").expect("ok");
assert_eq!(fn_, 3);
assert_eq!(wv, WireValue::Varint(9999));
assert!(dec.is_empty());
}
// collect_all returns fields of mixed wire types in the order they were written.
#[test]
fn test_proto_decoder_multiple_fields() {
let bytes = ProtoEncoder::new()
.int32(1, 1)
.string(2, "abc")
.bool(3, true)
.build();
let mut dec = ProtoDecoder::new(&bytes);
let fields = dec.collect_all().expect("ok");
assert_eq!(fields.len(), 3);
assert_eq!(fields[0].0, 1);
assert_eq!(fields[1].0, 2);
assert_eq!(fields[2].0, 3);
}
// Schema-driven round trip covering every scalar FieldType, including negative signed values
// and u64::MAX.
// NOTE(review): the float (index 4) and double (index 5) entries are decoded but not asserted.
#[test]
fn test_encode_decode_all_field_types() {
use crate::schema_registry::types::FieldDescriptor;
let desc = MessageDescriptor::new("AllTypes", "test")
.with_field(FieldDescriptor::optional(1, "i32", FieldType::Int32))
.with_field(FieldDescriptor::optional(2, "i64", FieldType::Int64))
.with_field(FieldDescriptor::optional(3, "u32", FieldType::UInt32))
.with_field(FieldDescriptor::optional(4, "u64", FieldType::UInt64))
.with_field(FieldDescriptor::optional(5, "flt", FieldType::Float))
.with_field(FieldDescriptor::optional(6, "dbl", FieldType::Double))
.with_field(FieldDescriptor::optional(7, "b", FieldType::Bool))
.with_field(FieldDescriptor::optional(8, "s", FieldType::String))
.with_field(FieldDescriptor::optional(9, "raw", FieldType::Bytes));
let values: Vec<(u32, FieldValue)> = vec![
(1, FieldValue::Int32(-7)),
(2, FieldValue::Int64(-9_999_999_999)),
(3, FieldValue::UInt32(42)),
(4, FieldValue::UInt64(u64::MAX)),
(5, FieldValue::Float(3.25)),
(6, FieldValue::Double(2.345_678_901)),
(7, FieldValue::Bool(true)),
(8, FieldValue::Str("hello".to_string())),
(9, FieldValue::Bytes(vec![0xca, 0xfe])),
];
let bytes = encode_message(&desc, &values);
let decoded = decode_message(&desc, &bytes).expect("decode ok");
assert_eq!(decoded.len(), 9);
assert_eq!(decoded[0], ("i32".to_string(), FieldValue::Int32(-7)));
assert_eq!(
decoded[1],
("i64".to_string(), FieldValue::Int64(-9_999_999_999))
);
assert_eq!(decoded[2], ("u32".to_string(), FieldValue::UInt32(42)));
assert_eq!(
decoded[3],
("u64".to_string(), FieldValue::UInt64(u64::MAX))
);
assert_eq!(decoded[6], ("b".to_string(), FieldValue::Bool(true)));
assert_eq!(
decoded[7],
("s".to_string(), FieldValue::Str("hello".to_string()))
);
assert_eq!(
decoded[8],
("raw".to_string(), FieldValue::Bytes(vec![0xca, 0xfe]))
);
}
// Small mixed double/string message survives encode_message -> decode_message.
#[test]
fn test_message_encode_decode_roundtrip() {
use crate::schema_registry::types::FieldDescriptor;
let desc = MessageDescriptor::new("Point", "geometry")
.with_field(FieldDescriptor::optional(1, "x", FieldType::Double))
.with_field(FieldDescriptor::optional(2, "y", FieldType::Double))
.with_field(FieldDescriptor::optional(3, "label", FieldType::String));
let values = vec![
(1, FieldValue::Double(1.5)),
(2, FieldValue::Double(-3.75)),
(3, FieldValue::Str("origin".to_string())),
];
let encoded = encode_message(&desc, &values);
let decoded = decode_message(&desc, &encoded).expect("decode ok");
assert_eq!(decoded.len(), 3);
assert_eq!(decoded[2].1, FieldValue::Str("origin".to_string()));
}
// A Message-typed field carries the inner message's encoded bytes; decoding returns them
// unparsed.
#[test]
fn test_nested_message_encoding() {
use crate::schema_registry::types::FieldDescriptor;
let inner_desc = MessageDescriptor::new("Inner", "test")
.with_field(FieldDescriptor::optional(1, "id", FieldType::Int32));
let inner_bytes = encode_message(&inner_desc, &[(1, FieldValue::Int32(99))]);
let outer_desc = MessageDescriptor::new("Outer", "test").with_field(
FieldDescriptor::optional(1, "nested", FieldType::Message("Inner".to_string())),
);
let outer_values = vec![(1, FieldValue::Message(inner_bytes.clone()))];
let outer_bytes = encode_message(&outer_desc, &outer_values);
let outer_decoded = decode_message(&outer_desc, &outer_bytes).expect("ok");
assert_eq!(outer_decoded.len(), 1);
if let FieldValue::Message(payload) = &outer_decoded[0].1 {
assert_eq!(payload, &inner_bytes);
} else {
panic!("expected Message variant");
}
}
// A Repeated field is supplied as pre-concatenated varints in FieldValue::Bytes and decodes
// back to the same Bytes.
#[test]
fn test_repeated_field_encoding() {
use crate::schema_registry::types::FieldDescriptor;
let desc = MessageDescriptor::new("Bag", "test").with_field(FieldDescriptor::optional(
1,
"items",
FieldType::Repeated(Box::new(FieldType::Int32)),
));
let mut packed = Vec::new();
for v in [1u64, 2, 3] {
encode_varint(v, &mut packed);
}
let values = vec![(1, FieldValue::Bytes(packed.clone()))];
let bytes = encode_message(&desc, &values);
let decoded = decode_message(&desc, &bytes).expect("ok");
assert_eq!(decoded.len(), 1);
assert_eq!(decoded[0].0, "items");
if let FieldValue::Bytes(b) = &decoded[0].1 {
assert_eq!(b, &packed);
} else {
panic!("expected Bytes");
}
}
// Fields missing from the decoding schema are dropped rather than causing an error.
#[test]
fn test_unknown_field_skipped_on_decode() {
use crate::schema_registry::types::FieldDescriptor;
let desc_full = MessageDescriptor::new("M", "test")
.with_field(FieldDescriptor::optional(1, "a", FieldType::Int32))
.with_field(FieldDescriptor::optional(2, "b", FieldType::String));
let bytes = encode_message(
&desc_full,
&[
(1, FieldValue::Int32(7)),
(2, FieldValue::Str("x".to_string())),
],
);
let desc_partial = MessageDescriptor::new("M", "test")
.with_field(FieldDescriptor::optional(1, "a", FieldType::Int32));
let decoded = decode_message(&desc_partial, &bytes).expect("ok");
assert_eq!(decoded.len(), 1);
assert_eq!(decoded[0].1, FieldValue::Int32(7));
}
}