use crate::error;
use core::mem;
/// Smallest field tag accepted when decoding keys; a decoded tag of zero is rejected.
pub const MIN_TAG: u32 = 1;
/// Largest field tag representable in a `u32` key once the tag has been shifted left by three
/// bits to make room for the wire type.
pub const MAX_TAG: u32 = u32::MAX >> 3;
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[repr(u8)]
pub enum WireType {
Varint = 0,
SixtyFourBit = 1,
LengthDelimited = 2,
StartGroup = 3,
EndGroup = 4,
ThirtyTwoBit = 5,
}
impl TryFrom<u64> for WireType {
type Error = error::DecodeError;
#[inline]
fn try_from(value: u64) -> Result<Self, Self::Error> {
match value {
0 => Ok(WireType::Varint),
1 => Ok(WireType::SixtyFourBit),
2 => Ok(WireType::LengthDelimited),
3 => Ok(WireType::StartGroup),
4 => Ok(WireType::EndGroup),
5 => Ok(WireType::ThirtyTwoBit),
_ => Err(error::DecodeError::InvalidWireTypeValue(value)),
}
}
}
/// Writes `value` as a base-128 varint at the front of `cursor`, then advances `cursor` past
/// the bytes written.
///
/// Each output byte carries seven bits of `value`, least-significant group first. The high bit
/// is set on every byte except the last. At most ten bytes are written; `encoded_len_varint`
/// gives the exact count.
///
/// # Panics
///
/// Panics if `cursor` runs out of space before the whole varint has been written.
#[inline]
#[cfg_attr(feature = "assert-no-panic", no_panic::no_panic)]
pub fn encode_varint(mut value: u64, cursor: &mut &mut [u8]) {
    // Emit continuation bytes while more than seven significant bits remain.
    while value >= 0x80 {
        let remaining = mem::take(cursor);
        let (slot, tail) = remaining.split_first_mut().unwrap();
        // Truncating to `u8` keeps the low seven bits; bit 7 flags that more bytes follow.
        *slot = (value as u8) | 0x80;
        *cursor = tail;
        value >>= 7;
    }
    // The final byte has its high bit clear.
    let remaining = mem::take(cursor);
    let (slot, tail) = remaining.split_first_mut().unwrap();
    *slot = value as u8;
    *cursor = tail;
}
#[cfg_attr(feature = "assert-no-panic", no_panic::no_panic)]
#[inline]
pub fn decode_varint(cursor: &mut &[u8]) -> Result<u64, error::DecodeError> {
#[inline]
#[cold]
fn buffer_underflow<A>() -> Result<A, error::DecodeError> {
Err(error::DecodeError::BufferUnderflow)
}
#[inline]
#[cold]
fn invalid_varint<A>() -> Result<A, error::DecodeError> {
Err(error::DecodeError::InvalidVarint)
}
#[inline]
fn take_first(slice: &mut &[u8]) -> Result<u8, error::DecodeError> {
if let Some((byte, rest)) = (*slice).split_first() {
*slice = rest;
Ok(*byte)
} else {
buffer_underflow()
}
}
let mut value: u64;
let byte = take_first(cursor)?;
value = u64::from(byte & 0b0111_1111);
if byte & 0b1000_0000 == 0 {
return Ok(value);
}
let byte = take_first(cursor)?;
value |= u64::from(byte & 0b0111_1111) << 7;
if byte & 0b1000_0000 == 0 {
return Ok(value);
}
let byte = take_first(cursor)?;
value |= u64::from(byte & 0b0111_1111) << 14;
if byte & 0b1000_0000 == 0 {
return Ok(value);
}
let byte = take_first(cursor)?;
value |= u64::from(byte & 0b0111_1111) << 21;
if byte & 0b1000_0000 == 0 {
return Ok(value);
}
let byte = take_first(cursor)?;
value |= u64::from(byte & 0b0111_1111) << 28;
if byte & 0b1000_0000 == 0 {
return Ok(value);
}
let byte = take_first(cursor)?;
value |= u64::from(byte & 0b0111_1111) << 35;
if byte & 0b1000_0000 == 0 {
return Ok(value);
}
let byte = take_first(cursor)?;
value |= u64::from(byte & 0b0111_1111) << 42;
if byte & 0b1000_0000 == 0 {
return Ok(value);
}
let byte = take_first(cursor)?;
value |= u64::from(byte & 0b0111_1111) << 49;
if byte & 0b1000_0000 == 0 {
return Ok(value);
}
let byte = take_first(cursor)?;
value |= u64::from(byte & 0b0111_1111) << 56;
if byte & 0b1000_0000 == 0 {
return Ok(value);
}
let byte = take_first(cursor)?;
value |= u64::from(byte & 0b0000_0001) << 63;
if byte & 0b1111_1110 == 0 {
return Ok(value);
}
invalid_varint()
}
/// Returns how many bytes `encode_varint` writes for `value`: between 1 and 10.
#[inline]
#[cfg_attr(feature = "assert-no-panic", no_panic::no_panic)]
pub fn encoded_len_varint(value: u64) -> usize {
    // `| 1` makes zero count as one significant bit, so it still occupies a single byte.
    let significant_bits = 64 - (value | 1).leading_zeros();
    // Each byte carries seven payload bits; round up.
    ((significant_bits + 6) / 7) as usize
}
/// Writes the key for field `tag` with wire type `wire_type` as a varint at the front of
/// `cursor`, advancing `cursor` past the written bytes.
///
/// The key packs `tag` into the upper bits and the wire type into the low three bits.
///
/// # Panics
///
/// Panics if `cursor` is too short to hold the key. `key_len` gives the required size. In debug
/// builds it also panics if `tag` is outside `MIN_TAG..=MAX_TAG`.
#[inline]
#[cfg_attr(feature = "assert-no-panic", no_panic::no_panic)]
pub fn encode_key(tag: u32, wire_type: WireType, cursor: &mut &mut [u8]) {
    debug_assert!((MIN_TAG..=MAX_TAG).contains(&tag));
    // Release builds skip the check above: an out-of-range tag loses its top bits in this shift.
    let shifted_tag = tag << 3;
    let wire_bits = wire_type as u32;
    encode_varint(u64::from(shifted_tag | wire_bits), cursor);
}
#[inline]
#[cfg_attr(feature = "assert-no-panic", no_panic::no_panic)]
pub fn decode_key(buf: &mut &[u8]) -> Result<(u32, WireType), error::DecodeError> {
let key = decode_varint(buf)?;
if key > u64::from(u32::MAX) {
return Err(error::DecodeError::InvalidKeyValue(key));
}
let wire_type = WireType::try_from(key & 0x7)?;
let tag = key as u32 >> 3;
if tag < MIN_TAG {
return Err(error::DecodeError::InvalidTagValue(tag));
}
Ok((tag, wire_type))
}
/// Returns the number of bytes `encode_key` writes for a key carrying `tag`.
///
/// The wire type occupies only the low three bits. It therefore never changes the varint
/// length and is left as zero here.
#[inline]
#[cfg_attr(feature = "assert-no-panic", no_panic::no_panic)]
pub fn key_len(tag: u32) -> usize {
    let key_without_wire_type = tag << 3;
    encoded_len_varint(key_without_wire_type.into())
}
/// Verifies that a field arrived with the wire type the caller expects.
///
/// # Errors
///
/// Returns `DecodeError::UnexpectedWireTypeValue` carrying both wire types when they differ.
#[inline]
#[cfg_attr(feature = "assert-no-panic", no_panic::no_panic)]
pub(crate) fn check_wire_type(
    expected: WireType,
    actual: WireType,
) -> Result<(), error::DecodeError> {
    if expected == actual {
        return Ok(());
    }
    Err(error::DecodeError::UnexpectedWireTypeValue { expected, actual })
}
#[inline]
pub fn skip_field(
wire_type: WireType,
tag: u32,
cursor: &mut &[u8],
) -> Result<(), error::DecodeError> {
let len = match wire_type {
WireType::Varint => {
decode_varint(cursor)?;
0 }
WireType::ThirtyTwoBit => 4,
WireType::SixtyFourBit => 8,
WireType::LengthDelimited => {
decode_varint(cursor)? as usize
}
WireType::StartGroup => loop {
let (inner_tag, inner_wire_type) = decode_key(cursor)?;
match inner_wire_type {
WireType::EndGroup => {
if inner_tag == tag {
break 0;
} else {
return Err(error::DecodeError::UnexpectedEndGroupTag);
}
}
_ => skip_field_in_group(inner_wire_type, inner_tag, cursor)?,
}
},
WireType::EndGroup => return Err(error::DecodeError::UnexpectedEndGroupTag),
};
if let Some(rest) = (*cursor).get(len..) {
*cursor = rest;
Ok(())
} else {
Err(error::DecodeError::BufferUnderflow)
}
}
/// Out-of-line helper that `skip_field` calls to skip a field nested inside a group.
///
/// It only forwards to `skip_field`. The `#[inline(never)]` / `#[cold]` markers presumably keep
/// the recursive group path out of the `#[inline]` body of `skip_field`.
#[cold]
#[inline(never)]
fn skip_field_in_group(
    nested_wire_type: WireType,
    nested_tag: u32,
    cursor: &mut &[u8],
) -> Result<(), error::DecodeError> {
    skip_field(nested_wire_type, nested_tag, cursor)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::DecodeError;

    /// Every wire type, for exhaustive key round-trips.
    const ALL_WIRE_TYPES: [WireType; 6] = [
        WireType::Varint,
        WireType::SixtyFourBit,
        WireType::LengthDelimited,
        WireType::StartGroup,
        WireType::EndGroup,
        WireType::ThirtyTwoBit,
    ];

    #[test]
    fn varint() {
        /// Encodes `value`, checks the bytes and the predicted length, then decodes it back.
        fn check(value: u64, encoded: &[u8]) {
            let mut buf = vec![0; 100];
            let mut buf_slice = buf.as_mut_slice();
            encode_varint(value, &mut buf_slice);
            let remaining = buf_slice.len();
            let encoded_len = buf.len() - remaining;
            assert_eq!(&buf[..encoded_len], encoded);
            assert_eq!(encoded_len_varint(value), encoded.len());
            let mut remaining = encoded;
            let roundtrip_value = decode_varint(&mut remaining).expect("decoding failed");
            assert!(remaining.is_empty());
            assert_eq!(value, roundtrip_value);
        }
        check(2u64.pow(0) - 1, &[0x00]);
        check(2u64.pow(0), &[0x01]);
        check(2u64.pow(7) - 1, &[0x7F]);
        check(2u64.pow(7), &[0x80, 0x01]);
        check(300, &[0xAC, 0x02]);
        check(2u64.pow(14) - 1, &[0xFF, 0x7F]);
        check(2u64.pow(14), &[0x80, 0x80, 0x01]);
        check(2u64.pow(21) - 1, &[0xFF, 0xFF, 0x7F]);
        check(2u64.pow(21), &[0x80, 0x80, 0x80, 0x01]);
        check(2u64.pow(28) - 1, &[0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(28), &[0x80, 0x80, 0x80, 0x80, 0x01]);
        check(2u64.pow(35) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(35), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x01]);
        check(2u64.pow(42) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(42), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]);
        check(
            2u64.pow(49) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(49),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );
        check(
            2u64.pow(56) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(56),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );
        check(
            2u64.pow(63) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(63),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );
        check(
            u64::MAX,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01],
        );
    }

    #[test]
    fn varint_overflow() {
        let mut u64_max_plus_one: &[u8] =
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x02];
        decode_varint(&mut u64_max_plus_one).expect_err("decoding u64::MAX + 1 succeeded");

        // A continuation bit on the tenth byte (an eleven-byte encoding) is rejected as well.
        let mut too_long: &[u8] = &[
            0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01,
        ];
        assert!(matches!(
            decode_varint(&mut too_long),
            Err(DecodeError::InvalidVarint)
        ));
    }

    #[test]
    fn varint_truncated() {
        let mut empty: &[u8] = &[];
        assert!(matches!(
            decode_varint(&mut empty),
            Err(DecodeError::BufferUnderflow)
        ));

        let mut truncated: &[u8] = &[0x80, 0x80];
        assert!(matches!(
            decode_varint(&mut truncated),
            Err(DecodeError::BufferUnderflow)
        ));
    }

    #[test]
    fn key_roundtrip() {
        for &tag in [MIN_TAG, 15, 16, 2047, 2048, MAX_TAG].iter() {
            for &wire_type in ALL_WIRE_TYPES.iter() {
                let mut buf = [0u8; 10];
                let mut slice = &mut buf[..];
                encode_key(tag, wire_type, &mut slice);
                let written = 10 - slice.len();
                assert_eq!(written, key_len(tag));

                let mut input = &buf[..written];
                let decoded = decode_key(&mut input).expect("decoding key failed");
                assert_eq!(decoded, (tag, wire_type));
                assert!(input.is_empty());
            }
        }
    }

    #[test]
    fn decode_key_rejects_invalid_keys() {
        // Tag zero.
        let mut zero_tag: &[u8] = &[0x00];
        assert!(matches!(
            decode_key(&mut zero_tag),
            Err(DecodeError::InvalidTagValue(0))
        ));

        // Tag 1 with the unknown wire type 6: (1 << 3) | 6 == 0x0E.
        let mut bad_wire_type: &[u8] = &[0x0E];
        assert!(matches!(
            decode_key(&mut bad_wire_type),
            Err(DecodeError::InvalidWireTypeValue(6))
        ));

        // 1 << 32 does not fit in a u32 key.
        let mut too_large: &[u8] = &[0x80, 0x80, 0x80, 0x80, 0x10];
        assert!(matches!(
            decode_key(&mut too_large),
            Err(DecodeError::InvalidKeyValue(_))
        ));
    }

    #[test]
    fn skip_scalar_fields() {
        let mut varint: &[u8] = &[0x96, 0x01, 0xAA];
        skip_field(WireType::Varint, 1, &mut varint).expect("skipping varint failed");
        assert_eq!(varint, &[0xAAu8][..]);

        let mut fixed32: &[u8] = &[1, 2, 3, 4, 5];
        skip_field(WireType::ThirtyTwoBit, 1, &mut fixed32).expect("skipping fixed32 failed");
        assert_eq!(fixed32, &[5u8][..]);

        let mut fixed64: &[u8] = &[0, 0, 0, 0, 0, 0, 0, 0, 9];
        skip_field(WireType::SixtyFourBit, 1, &mut fixed64).expect("skipping fixed64 failed");
        assert_eq!(fixed64, &[9u8][..]);

        let mut short: &[u8] = &[1, 2];
        assert!(matches!(
            skip_field(WireType::ThirtyTwoBit, 1, &mut short),
            Err(DecodeError::BufferUnderflow)
        ));
    }

    #[test]
    fn skip_length_delimited() {
        let mut input: &[u8] = &[0x03, 1, 2, 3, 9];
        skip_field(WireType::LengthDelimited, 1, &mut input).expect("skipping bytes failed");
        assert_eq!(input, &[9u8][..]);

        let mut short: &[u8] = &[0x05, 1, 2];
        assert!(matches!(
            skip_field(WireType::LengthDelimited, 1, &mut short),
            Err(DecodeError::BufferUnderflow)
        ));
    }

    #[test]
    fn skip_groups() {
        // Group tag 1 containing varint field 2 (key 0x10, value 0x01), closed by key 0x0C.
        let mut group: &[u8] = &[0x10, 0x01, 0x0C, 0xFF];
        skip_field(WireType::StartGroup, 1, &mut group).expect("skipping group failed");
        assert_eq!(group, &[0xFFu8][..]);

        // Group 1 containing an empty group 2: start 2 (0x13), end 2 (0x14), end 1 (0x0C).
        let mut nested: &[u8] = &[0x13, 0x14, 0x0C];
        skip_field(WireType::StartGroup, 1, &mut nested).expect("skipping nested group failed");
        assert!(nested.is_empty());

        // Group 1 closed by an end key for tag 2 (0x14).
        let mut mismatched: &[u8] = &[0x14];
        assert!(matches!(
            skip_field(WireType::StartGroup, 1, &mut mismatched),
            Err(DecodeError::UnexpectedEndGroupTag)
        ));

        let mut bare_end: &[u8] = &[];
        assert!(matches!(
            skip_field(WireType::EndGroup, 1, &mut bare_end),
            Err(DecodeError::UnexpectedEndGroupTag)
        ));
    }
}