use bytes::{BufMut, Bytes, BytesMut};
use crate::error::{Result, SchemaRegError};
use crate::glue::{
GLUE_COMPRESSION_NONE_BYTE, GLUE_COMPRESSION_ZLIB_BYTE, GLUE_HEADER_SIZE,
GLUE_HEADER_VERSION_BYTE, GlueCompression, GlueSchemaVersionId,
};
use crate::types::SchemaId;
/// Leading byte of the framing written by `encode_wire_format` and required by
/// `validate_wire_header`.
pub(crate) const MAGIC_BYTE: u8 = 0x00;
/// Length in bytes of that framing header: one magic byte followed by a four-byte
/// big-endian schema ID.
pub(crate) const HEADER_SIZE: usize = 5;
/// Appends `value` to `buf` as a base-128 varint.
///
/// Seven bits are emitted per byte, least-significant group first. Every byte except the
/// last has its high bit set to signal that more bytes follow.
#[inline]
fn write_varint(buf: &mut BytesMut, mut value: u64) {
    // Emit continuation bytes for as long as more than seven significant bits remain.
    while value >= 0x80 {
        buf.put_u8(((value & 0x7F) as u8) | 0x80);
        value >>= 7;
    }
    // What is left fits in seven bits, so the terminating byte has a clear high bit.
    buf.put_u8(value as u8);
}
/// Returns how many bytes the base-128 varint encoding of `value` occupies (1 to 10).
#[inline]
fn varint_len(value: u64) -> usize {
    // `value | 1` makes zero count as one significant bit, so it still needs one byte.
    let significant_bits = 64 - (value | 1).leading_zeros();
    // Each output byte carries seven payload bits; round up.
    ((significant_bits + 6) / 7) as usize
}
/// ZigZag-encodes a signed 32-bit integer so that values of small magnitude map to small
/// unsigned values (0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ...).
///
/// The result always fits in 32 bits; it is widened to `u64` for the varint writer.
#[inline]
fn zigzag_encode(n: i32) -> u64 {
    // Shift the magnitude up by one bit to make room for the sign in bit 0.
    let doubled = (n as u32) << 1;
    // Arithmetic shift: all ones for negative input, all zeros otherwise.
    let sign_fill = (n >> 31) as u32;
    u64::from(doubled ^ sign_fill)
}
/// Reverses [`zigzag_encode`]: even inputs decode to `n / 2`, odd inputs to `-(n / 2) - 1`.
///
/// Only the low 32 bits of `n >> 1` are kept, so callers are expected to reject inputs
/// above `u32::MAX` beforehand.
#[inline]
fn zigzag_decode(n: u64) -> i32 {
    let magnitude = (n >> 1) as i32;
    if n & 1 == 0 {
        magnitude
    } else {
        // Bit 0 set means negative: flip every bit of the halved value.
        !magnitude
    }
}
/// Largest message-index count that `decode_protobuf_message_indexes` accepts. Larger
/// counts are rejected before any index is read, which also bounds the `Vec` preallocation.
const MAX_MESSAGE_INDEX_COUNT: u64 = 512;
/// Reads one unsigned base-128 varint from `data`, starting at byte `offset`.
///
/// Each byte contributes its low seven bits, least-significant group first; a set high
/// bit means another byte follows.
///
/// Returns the decoded value together with the number of bytes consumed.
///
/// # Errors
///
/// Returns a wire-format error when
/// * the input ends before a terminating byte (high bit clear) is found, including the
///   case where `offset` is already past the end of `data`; or
/// * the encoding does not fit in 64 bits, either because it runs past ten bytes or
///   because the tenth byte carries payload bits that would fall off the top of the
///   `u64`.
fn read_varint(data: &[u8], offset: usize) -> Result<(u64, usize)> {
    let mut result: u64 = 0;
    let mut shift = 0u32;
    let mut pos = offset;
    loop {
        let byte = match data.get(pos) {
            Some(&b) => b,
            None => {
                return Err(SchemaRegError::wire_format(
                    "truncated varint in Protobuf message-index",
                ));
            }
        };
        pos += 1;

        let group = u64::from(byte & 0x7F);
        // `shift` never exceeds 63 here. At shift 63 only bit 0 of the group still fits;
        // anything above it would be dropped silently, so treat that as an overflow
        // instead of returning a wrong value.
        let shifted = group << shift;
        if shifted >> shift != group {
            return Err(SchemaRegError::wire_format(
                "varint overflow in Protobuf message-index",
            ));
        }
        result |= shifted;

        if byte & 0x80 == 0 {
            break;
        }
        shift += 7;
        if shift >= 64 {
            return Err(SchemaRegError::wire_format(
                "varint overflow in Protobuf message-index",
            ));
        }
    }
    Ok((result, pos - offset))
}
/// Frames `payload` with the five-byte header: [`MAGIC_BYTE`] followed by the schema ID
/// as a big-endian `u32`.
///
/// The returned buffer is exactly `HEADER_SIZE + payload.len()` bytes long.
#[must_use]
pub fn encode_wire_format(schema_id: impl Into<SchemaId>, payload: &[u8]) -> Bytes {
    let id: SchemaId = schema_id.into();
    let total_len = HEADER_SIZE + payload.len();

    let mut framed = BytesMut::with_capacity(total_len);
    framed.put_u8(MAGIC_BYTE);
    framed.put_slice(&id.as_u32().to_be_bytes());
    framed.put_slice(payload);
    framed.freeze()
}
/// Splits a framed message into its schema ID and a borrowed view of the bytes that
/// follow the header.
///
/// # Errors
///
/// Propagates the error from [`validate_wire_header`] when `data` is shorter than
/// [`HEADER_SIZE`] or does not start with [`MAGIC_BYTE`].
pub fn decode_wire_format(data: &[u8]) -> Result<(SchemaId, &[u8])> {
    // The slice is only taken once the header has been validated, so it cannot panic.
    validate_wire_header(data).map(|schema_id| (schema_id, &data[HEADER_SIZE..]))
}
/// Same as [`decode_wire_format`], but takes a [`Bytes`] handle and returns the payload as
/// a [`Bytes`] obtained through `Bytes::slice`.
///
/// # Errors
///
/// Propagates the error from [`validate_wire_header`] when `data` is shorter than
/// [`HEADER_SIZE`] or does not start with [`MAGIC_BYTE`].
pub fn decode_wire_format_bytes(data: &Bytes) -> Result<(SchemaId, Bytes)> {
    // The payload slice is only taken once the header has been validated.
    validate_wire_header(data).map(|schema_id| (schema_id, data.slice(HEADER_SIZE..)))
}
/// Checks the five-byte framing header at the start of `data` and extracts the schema ID.
///
/// The length is checked first, then the magic byte; bytes 1 through 4 are read as a
/// big-endian `u32`.
///
/// # Errors
///
/// Returns a wire-format error when `data` holds fewer than [`HEADER_SIZE`] bytes or when
/// its first byte is not [`MAGIC_BYTE`].
pub(crate) fn validate_wire_header(data: &[u8]) -> Result<SchemaId> {
    let header = match data.get(..HEADER_SIZE) {
        Some(header) => header,
        None => {
            return Err(SchemaRegError::wire_format(format!(
                "wire format data too short: expected at least {HEADER_SIZE} bytes, got {}",
                data.len()
            )));
        }
    };

    let magic = header[0];
    if magic != MAGIC_BYTE {
        return Err(SchemaRegError::wire_format(format!(
            "invalid wire format magic byte: expected 0x{MAGIC_BYTE:02X}, got 0x{:02X}",
            magic
        )));
    }

    let id_bytes = [header[1], header[2], header[3], header[4]];
    Ok(SchemaId::from(u32::from_be_bytes(id_bytes)))
}
/// Frames a Protobuf `payload` with the five-byte header followed by a message-index
/// prefix.
///
/// Layout of the result:
/// 1. [`MAGIC_BYTE`];
/// 2. the schema ID as a big-endian `u32`;
/// 3. the number of entries in `msg_indexes` as an unsigned varint;
/// 4. each entry of `msg_indexes`, ZigZag-encoded and written as a varint;
/// 5. `payload`, unchanged.
///
/// NOTE(review): the count in step 3 is written as a plain (non-ZigZag) varint. Confluent's
/// reference serializers appear to ZigZag-encode the count as well and to collapse the
/// common `[0]` case into a single zero byte — confirm interoperability before relying on
/// this for cross-client traffic.
#[must_use]
pub fn encode_protobuf_wire_format(
    schema_id: impl Into<SchemaId>,
    msg_indexes: &[i32],
    payload: &[u8],
) -> Bytes {
    let id: SchemaId = schema_id.into();
    let count = msg_indexes.len() as u64;

    // Size the buffer up front so that it never has to grow while being filled.
    let mut prefix_len = varint_len(count);
    for &index in msg_indexes {
        prefix_len += varint_len(zigzag_encode(index));
    }

    let mut framed = BytesMut::with_capacity(HEADER_SIZE + prefix_len + payload.len());
    framed.put_u8(MAGIC_BYTE);
    framed.put_u32(id.as_u32());
    write_varint(&mut framed, count);
    msg_indexes
        .iter()
        .for_each(|&index| write_varint(&mut framed, zigzag_encode(index)));
    framed.put_slice(payload);
    framed.freeze()
}
/// Parses the message-index prefix that follows the five-byte header of a Protobuf
/// message.
///
/// `after_header` must start at the first byte after the header. The prefix is an
/// unsigned varint count followed by that many ZigZag-encoded varints.
///
/// Returns the decoded indexes together with the number of bytes the prefix occupied,
/// i.e. the offset within `after_header` at which the Protobuf payload begins.
///
/// NOTE(review): the count is read as a plain (non-ZigZag) varint, mirroring
/// [`encode_protobuf_wire_format`]. Confluent's reference serializers appear to
/// ZigZag-encode the count too — confirm interoperability.
///
/// # Errors
///
/// Returns a wire-format error when a varint is truncated or overflows, when the count
/// exceeds [`MAX_MESSAGE_INDEX_COUNT`], or when an encoded index is larger than `u32::MAX`.
pub fn decode_protobuf_message_indexes(after_header: &[u8]) -> Result<(Vec<i32>, usize)> {
    let (count, mut cursor) = read_varint(after_header, 0)?;
    // Reject oversized counts before allocating anything based on untrusted input.
    if count > MAX_MESSAGE_INDEX_COUNT {
        return Err(SchemaRegError::wire_format(format!(
            "Protobuf message-index count {count} exceeds the maximum of {MAX_MESSAGE_INDEX_COUNT}"
        )));
    }

    let mut indexes = Vec::with_capacity(count as usize);
    for _ in 0..count {
        let (encoded, width) = read_varint(after_header, cursor)?;
        // A ZigZag-encoded `i32` never needs more than 32 bits.
        if encoded > u64::from(u32::MAX) {
            return Err(SchemaRegError::wire_format(
                "Protobuf message-index value overflows i32 ZigZag range",
            ));
        }
        indexes.push(zigzag_decode(encoded));
        cursor += width;
    }
    Ok((indexes, cursor))
}
/// Outcome of inspecting the leading bytes of a buffer with `detect_wire_format`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum DetectedWireFormat {
/// The buffer starts with `MAGIC_BYTE` and holds at least `HEADER_SIZE` bytes.
Confluent {
/// Schema ID read big-endian from bytes 1 through 4.
schema_id: SchemaId,
/// Offset at which the payload begins; `detect_wire_format` sets it to `HEADER_SIZE`.
payload_offset: usize,
},
/// The buffer starts with `GLUE_HEADER_VERSION_BYTE`, holds at least `GLUE_HEADER_SIZE`
/// bytes, and its compression byte is one of the two recognised values.
Glue {
/// Schema version ID built from the 16 bytes starting at offset 2.
version_id: GlueSchemaVersionId,
/// Compression indicated by the byte at offset 1.
compression: GlueCompression,
/// Offset at which the payload begins; `detect_wire_format` sets it to
/// `GLUE_HEADER_SIZE`.
payload_offset: usize,
},
/// The first byte is `MAGIC_BYTE` but the buffer is shorter than `HEADER_SIZE`.
InvalidConfluent,
/// The first byte is `GLUE_HEADER_VERSION_BYTE` but the buffer is shorter than
/// `GLUE_HEADER_SIZE` or its compression byte is not recognised.
InvalidGlue,
/// The buffer is empty or its first byte matches neither known header.
Unknown,
}
/// Classifies `data` by its leading bytes without touching the payload.
///
/// * First byte [`MAGIC_BYTE`]: [`DetectedWireFormat::Confluent`] when at least
///   [`HEADER_SIZE`] bytes are present, otherwise [`DetectedWireFormat::InvalidConfluent`].
/// * First byte `GLUE_HEADER_VERSION_BYTE`: [`DetectedWireFormat::Glue`] when at least
///   `GLUE_HEADER_SIZE` bytes are present and the compression byte is recognised,
///   otherwise [`DetectedWireFormat::InvalidGlue`].
/// * Anything else, including empty input: [`DetectedWireFormat::Unknown`].
pub fn detect_wire_format(data: &[u8]) -> DetectedWireFormat {
    match data.first().copied() {
        Some(MAGIC_BYTE) if data.len() < HEADER_SIZE => DetectedWireFormat::InvalidConfluent,
        Some(MAGIC_BYTE) => {
            let id_bytes = [data[1], data[2], data[3], data[4]];
            DetectedWireFormat::Confluent {
                schema_id: SchemaId::from(u32::from_be_bytes(id_bytes)),
                payload_offset: HEADER_SIZE,
            }
        }
        Some(GLUE_HEADER_VERSION_BYTE) if data.len() < GLUE_HEADER_SIZE => {
            DetectedWireFormat::InvalidGlue
        }
        Some(GLUE_HEADER_VERSION_BYTE) => {
            // Byte 1 selects the compression; any other value makes the header invalid.
            let tag = data[1];
            let compression = if tag == GLUE_COMPRESSION_NONE_BYTE {
                GlueCompression::None
            } else if tag == GLUE_COMPRESSION_ZLIB_BYTE {
                GlueCompression::Zlib
            } else {
                return DetectedWireFormat::InvalidGlue;
            };

            // The rest of the header (from offset 2) is the 16-byte schema version ID.
            let mut version_bytes = [0u8; 16];
            version_bytes.copy_from_slice(&data[2..GLUE_HEADER_SIZE]);

            DetectedWireFormat::Glue {
                version_id: GlueSchemaVersionId::from_bytes(version_bytes),
                compression,
                payload_offset: GLUE_HEADER_SIZE,
            }
        }
        _ => DetectedWireFormat::Unknown,
    }
}
// Unit tests for the five-byte framing helpers and for header detection.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
// Encoding then decoding returns the original schema ID and payload.
#[test]
fn test_wire_format_roundtrip() {
let payload = b"hello world";
let encoded = encode_wire_format(42u32, payload);
let (id, decoded) = decode_wire_format(&encoded).unwrap();
assert_eq!(id, SchemaId::from(42u32));
assert_eq!(decoded, payload);
}
// An empty payload yields a header-only buffer that still decodes.
#[test]
fn test_wire_format_empty_payload() {
let encoded = encode_wire_format(1u32, b"");
assert_eq!(encoded.len(), HEADER_SIZE);
let (id, payload) = decode_wire_format(&encoded).unwrap();
assert_eq!(id, SchemaId::from(1u32));
assert!(payload.is_empty());
}
// The full `u32` range of schema IDs survives a round trip.
#[test]
fn test_wire_format_max_schema_id() {
let encoded = encode_wire_format(u32::MAX, b"data");
let (id, _) = decode_wire_format(&encoded).unwrap();
assert_eq!(id, SchemaId::from(u32::MAX));
}
// Pins the exact header bytes: magic byte, then the schema ID in big-endian order.
#[test]
fn test_wire_format_header_bytes() {
let encoded = encode_wire_format(256u32, b"x");
assert_eq!(&encoded[..5], &[0x00, 0x00, 0x00, 0x01, 0x00]);
assert_eq!(&encoded[5..], b"x");
}
// A first byte other than the magic byte is rejected with a "magic byte" error.
#[test]
fn test_wire_format_invalid_magic_byte() {
let data = [0x01, 0, 0, 0, 1, 0x42];
let result = decode_wire_format(&data);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("magic byte"));
}
// Fewer than `HEADER_SIZE` bytes is rejected with a "too short" error.
#[test]
fn test_wire_format_too_short() {
let result = decode_wire_format(&[0x00, 0, 0]);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("too short"));
}
// Empty input is an error rather than a panic.
#[test]
fn test_wire_format_empty_data() {
let result = decode_wire_format(&[]);
assert!(result.is_err());
}
// A buffer produced by `encode_wire_format` is detected as Confluent framing.
#[test]
fn test_detect_wire_format_confluent() {
let encoded = encode_wire_format(42u32, b"data");
let detected = detect_wire_format(&encoded);
assert_eq!(
detected,
DetectedWireFormat::Confluent {
schema_id: SchemaId::from(42u32),
payload_offset: 5,
}
);
}
// Empty input and an unrecognised first byte both map to `Unknown`.
#[test]
fn test_detect_wire_format_unknown() {
assert_eq!(detect_wire_format(&[]), DetectedWireFormat::Unknown);
assert_eq!(
detect_wire_format(&[0x99, 0x00, 0x00]),
DetectedWireFormat::Unknown
);
}
// Schema ID zero is still reported as a valid Confluent header.
#[test]
fn test_detect_wire_format_confluent_schema_id_zero() {
assert_eq!(
detect_wire_format(&[MAGIC_BYTE, 0x00, 0x00, 0x00, 0x00, 0x41]),
DetectedWireFormat::Confluent {
schema_id: SchemaId::from(0u32),
payload_offset: HEADER_SIZE,
}
);
}
// A recognised first byte followed by a truncated header yields the matching `Invalid*`
// variant instead of `Unknown`.
#[test]
fn test_detect_wire_format_invalid_known_headers() {
assert_eq!(
detect_wire_format(&[MAGIC_BYTE, 0x01, 0x02]),
DetectedWireFormat::InvalidConfluent
);
use crate::glue::{GLUE_COMPRESSION_NONE_BYTE, GLUE_HEADER_VERSION_BYTE};
assert_eq!(
detect_wire_format(&[GLUE_HEADER_VERSION_BYTE, GLUE_COMPRESSION_NONE_BYTE]),
DetectedWireFormat::InvalidGlue
);
}
}