/// Schema identifier carried in the header of a framed message.
///
/// `parse_frame` reads it as a big-endian `u32` from the four bytes that
/// follow the `0x00` magic byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct SchemaId(pub u32);

/// Formats the id as its plain decimal value, e.g. `42`.
///
/// Delegates to `u32`'s own `Display` impl so that width, fill, alignment and
/// zero-padding flags (`{:>6}`, `{:06}`) are honoured; formatting through
/// `write!(f, "{}", ...)` would start a fresh spec and silently drop them.
impl std::fmt::Display for SchemaId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
/// Payload encoding, which decides how the body of a framed message is located.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WireFormat {
    /// The payload is preceded by a message-index array that must be skipped.
    Protobuf,
    /// The payload starts immediately after the schema id.
    Json,
}

impl WireFormat {
    /// Codec names recognised by [`WireFormat::from_codec_name`], compared exactly.
    const CODEC_NAMES: [(&'static str, WireFormat); 2] = [
        ("protobuf", WireFormat::Protobuf),
        ("json", WireFormat::Json),
    ];

    /// Maps a codec name to its wire format.
    ///
    /// Matching is exact and case-sensitive: only `"protobuf"` and `"json"` are
    /// recognised, and every other name yields `None`.
    pub fn from_codec_name(name: &str) -> Option<Self> {
        Self::CODEC_NAMES
            .iter()
            .find(|(codec, _)| *codec == name)
            .map(|&(_, format)| format)
    }
}
#[derive(Debug, PartialEq)]
pub enum FrameResult<'a> {
Null,
Unframed(&'a [u8]),
Framed { id: SchemaId, payload: &'a [u8] },
}
/// Decodes an unsigned LEB128-style varint from the front of `bytes`.
///
/// Each byte contributes its low seven bits, least-significant group first, and
/// a clear high bit (`0x80`) marks the final byte. Returns the decoded value and
/// the number of bytes consumed. Returns `None` if the input ends before a final
/// byte, if the encoding runs past ten bytes, or if the tenth byte carries bits
/// that do not fit in a `u64`. Bytes after the varint are ignored.
fn read_varint(bytes: &[u8]) -> Option<(u64, usize)> {
    // A u64 needs at most ten 7-bit groups; a longer encoding is rejected.
    const MAX_LEN: usize = 10;
    let mut decoded = 0u64;
    for (idx, &byte) in bytes.iter().take(MAX_LEN).enumerate() {
        let group = u64::from(byte & 0x7f);
        // The tenth group starts at bit 63, so only its lowest bit fits.
        if idx == MAX_LEN - 1 && group > 1 {
            return None;
        }
        decoded |= group << (7 * idx as u32);
        if byte & 0x80 == 0 {
            return Some((decoded, idx + 1));
        }
    }
    None
}
/// Splits an input value into its schema id and payload when it is framed.
///
/// A framed value is a `0x00` magic byte, then a big-endian `u32` schema id,
/// then the payload. For [`WireFormat::Protobuf`], the payload is additionally
/// preceded by a message-index array, which is skipped.
///
/// Returns:
/// * [`FrameResult::Null`] for an empty input;
/// * [`FrameResult::Unframed`] carrying the whole input when the magic byte is
///   missing, the input is shorter than five bytes, or the Protobuf index array
///   cannot be decoded;
/// * [`FrameResult::Framed`] otherwise.
pub fn parse_frame(format: WireFormat, bytes: &[u8]) -> FrameResult<'_> {
    let (id_bytes, body) = match bytes {
        [] => return FrameResult::Null,
        [0x00, b1, b2, b3, b4, body @ ..] => ([*b1, *b2, *b3, *b4], body),
        _ => return FrameResult::Unframed(bytes),
    };
    let id = SchemaId(u32::from_be_bytes(id_bytes));

    let located = match format {
        WireFormat::Json => Some(body),
        WireFormat::Protobuf => skip_message_indexes(body),
    };
    // An undecodable index array means the input is handed back unchanged.
    located.map_or(FrameResult::Unframed(bytes), |payload| FrameResult::Framed {
        id,
        payload,
    })
}
/// Skips the Protobuf message-index array at the start of `bytes`.
///
/// The array is a varint element count followed by that many varint indexes.
/// The index values themselves are ignored, and a count of zero (a single
/// `0x00` byte) means no indexes follow. Returns the bytes after the array, or
/// `None` if any varint is truncated or malformed.
///
/// NOTE(review): the count and indexes are decoded as plain unsigned varints.
/// Confluent's serializers are believed to write them zig-zag encoded (a count
/// of 1 would be `0x02`), and the two encodings only coincide for the
/// single-`0x00` form. Confirm which encoding the producers actually use.
fn skip_message_indexes(bytes: &[u8]) -> Option<&[u8]> {
    let (count, header_len) = read_varint(bytes)?;
    // Each index consumes at least one byte, so a bogus count fails fast once
    // the input runs out.
    (0..count).try_fold(&bytes[header_len..], |rest, _| {
        let (_index, width) = read_varint(rest)?;
        Some(&rest[width..])
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_is_null() {
        assert_eq!(parse_frame(WireFormat::Json, &[]), FrameResult::Null);
    }

    #[test]
    fn no_magic_byte_is_unframed() {
        let bytes = [0x7b, 0x22];
        assert_eq!(
            parse_frame(WireFormat::Json, &bytes),
            FrameResult::Unframed(&bytes)
        );
    }

    #[test]
    fn json_frame_extracts_be_id_and_payload() {
        let bytes = [0x00, 0x00, 0x00, 0x00, 0x01, b'{', b'}'];
        assert_eq!(
            parse_frame(WireFormat::Json, &bytes),
            FrameResult::Framed {
                id: SchemaId(1),
                payload: b"{}",
            }
        );
    }

    #[test]
    fn json_frame_large_id() {
        // 0x0186a0 == 100_000, stored big-endian.
        let bytes = [0x00, 0x00, 0x01, 0x86, 0xa0, 0xAB];
        assert_eq!(
            parse_frame(WireFormat::Json, &bytes),
            FrameResult::Framed {
                id: SchemaId(100_000),
                payload: &[0xAB],
            }
        );
    }

    #[test]
    fn json_frame_with_empty_payload() {
        let bytes = [0x00, 0x00, 0x00, 0x00, 0x09];
        assert_eq!(
            parse_frame(WireFormat::Json, &bytes),
            FrameResult::Framed {
                id: SchemaId(9),
                payload: b"",
            }
        );
    }

    #[test]
    fn protobuf_single_zero_index_optimization() {
        let bytes = [0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0xDE, 0xAD];
        assert_eq!(
            parse_frame(WireFormat::Protobuf, &bytes),
            FrameResult::Framed {
                id: SchemaId(5),
                payload: &[0xDE, 0xAD],
            }
        );
    }

    #[test]
    fn protobuf_explicit_index_array_is_skipped() {
        let bytes = [0x00, 0x00, 0x00, 0x00, 0x05, 0x02, 0x01, 0x03, 0xBE, 0xEF];
        assert_eq!(
            parse_frame(WireFormat::Protobuf, &bytes),
            FrameResult::Framed {
                id: SchemaId(5),
                payload: &[0xBE, 0xEF],
            }
        );
    }

    #[test]
    fn protobuf_header_without_index_array_is_unframed() {
        let bytes = [0x00, 0x00, 0x00, 0x00, 0x09];
        assert_eq!(
            parse_frame(WireFormat::Protobuf, &bytes),
            FrameResult::Unframed(&bytes)
        );
    }

    #[test]
    fn protobuf_unterminated_count_varint_is_unframed() {
        let bytes = [0x00, 0x00, 0x00, 0x00, 0x01, 0x80];
        assert_eq!(
            parse_frame(WireFormat::Protobuf, &bytes),
            FrameResult::Unframed(&bytes)
        );
    }

    #[test]
    fn magic_byte_but_too_short_is_unframed() {
        let bytes = [0x00, 0x00, 0x01];
        assert_eq!(
            parse_frame(WireFormat::Json, &bytes),
            FrameResult::Unframed(&bytes)
        );
    }

    #[test]
    fn wire_format_from_codec_name() {
        assert_eq!(
            WireFormat::from_codec_name("protobuf"),
            Some(WireFormat::Protobuf)
        );
        assert_eq!(WireFormat::from_codec_name("json"), Some(WireFormat::Json));
        assert_eq!(WireFormat::from_codec_name("raw"), None);
        assert_eq!(WireFormat::from_codec_name("avro"), None);
    }

    #[test]
    fn varint_overflow_returns_none_not_truncated_value() {
        let mut frame = vec![0x00, 0x00, 0x00, 0x00, 0x07];
        // Nine continuation bytes put the tenth byte at bit 63, where 0x04
        // no longer fits in a u64.
        frame.extend_from_slice(&[0x80u8; 9]);
        frame.push(0x04);
        assert!(matches!(
            parse_frame(WireFormat::Protobuf, &frame),
            FrameResult::Unframed(_)
        ));
    }

    #[test]
    fn protobuf_truncated_index_array_is_unframed() {
        let bytes = [0x00, 0x00, 0x00, 0x00, 0x07, 0x03, 0x01, 0x02];
        assert!(matches!(
            parse_frame(WireFormat::Protobuf, &bytes),
            FrameResult::Unframed(_)
        ));
    }

    #[test]
    fn protobuf_multibyte_count_varint() {
        let mut frame = vec![0x00, 0x00, 0x00, 0x00, 0x07];
        // A count of 128 needs the two-byte varint 0x80 0x01 ...
        frame.push(0x80);
        frame.push(0x01);
        // ... followed by 128 single-byte zero indexes.
        frame.extend_from_slice(&[0x00u8; 128]);
        frame.push(0xAA);
        assert_eq!(
            parse_frame(WireFormat::Protobuf, &frame),
            FrameResult::Framed {
                id: SchemaId(7),
                payload: &[0xAA],
            }
        );
    }

    #[test]
    fn read_varint_decodes_boundaries() {
        assert_eq!(read_varint(&[]), None);
        assert_eq!(read_varint(&[0x05, 0xAA]), Some((5, 1)));
        let mut max = [0xFFu8; 10];
        max[9] = 0x01;
        assert_eq!(read_varint(&max), Some((u64::MAX, 10)));
        max[9] = 0x02;
        assert_eq!(read_varint(&max), None);
    }

    #[test]
    fn schema_id_displays_as_decimal() {
        assert_eq!(SchemaId(100_000).to_string(), "100000");
    }
}