#![allow(unused)]
use super::leb128;
use alloc::vec::Vec;
use core::{iter, str};
/// Encodes a protobuf `bool` field: a wire-type-0 tag followed by the varint `1` (for `true`)
/// or `0` (for `false`). Each output byte is yielded as its own one-byte chunk.
pub(crate) fn bool_tag_encode(
    field: u64,
    bool_value: bool,
) -> impl Iterator<Item = impl AsRef<[u8]> + Clone> + Clone {
    let as_varint = u64::from(bool_value);
    varint_zigzag_tag_encode(field, as_varint).map(|byte| [byte])
}
/// Encodes a protobuf `uint32` field: a wire-type-0 tag followed by `value` as a varint.
/// Each output byte is yielded as its own one-byte chunk.
pub(crate) fn uint32_tag_encode(
    field: u64,
    value: u32,
) -> impl Iterator<Item = impl AsRef<[u8]> + Clone> + Clone {
    let widened: u64 = value.into();
    varint_zigzag_tag_encode(field, widened).map(|byte| [byte])
}
/// Encodes a protobuf `enum` field: a wire-type-0 tag followed by the enum's numeric value as
/// a varint. Each output byte is yielded as its own one-byte chunk.
pub(crate) fn enum_tag_encode(
    field: u64,
    enum_value: u64,
) -> impl Iterator<Item = impl AsRef<[u8]> + Clone> + Clone {
    let encoded_bytes = varint_zigzag_tag_encode(field, enum_value);
    encoded_bytes.map(|byte| [byte])
}
pub(crate) fn message_tag_encode(
field: u64,
inner_message: impl Iterator<Item = impl AsRef<[u8]>>,
) -> impl Iterator<Item = impl AsRef<[u8]>> {
let inner_message = inner_message.fold(Vec::with_capacity(1024), |mut a, b| {
a.extend_from_slice(b.as_ref());
a
});
tag_encode(field, 2)
.chain(leb128::encode_usize(inner_message.len()))
.map(|v| either::Right([v]))
.chain(iter::once(either::Left(inner_message)))
}
/// Encodes a protobuf `bytes` field as a length-delimited (wire type 2) field.
///
/// In debug builds this asserts that `data` is at most 2 GiB. That matches the limit enforced
/// when decoding `bytes` and `string` fields.
pub(crate) fn bytes_tag_encode(
    field: u64,
    data: impl AsRef<[u8]> + Clone,
) -> impl Iterator<Item = impl AsRef<[u8]> + Clone> + Clone {
    const MAX_LEN: usize = 2 * 1024 * 1024 * 1024;
    debug_assert!(data.as_ref().len() <= MAX_LEN);
    delimited_tag_encode(field, data)
}
/// Encodes a protobuf `string` field: the UTF-8 bytes of `data`, framed as a `bytes` field.
pub(crate) fn string_tag_encode(
    field: u64,
    data: impl AsRef<str> + Clone,
) -> impl Iterator<Item = impl AsRef<[u8]> + Clone> + Clone {
    // Adapter exposing the UTF-8 bytes of an `AsRef<str>` value as `AsRef<[u8]>`, so the string
    // can be handed to `bytes_tag_encode` without being copied.
    #[derive(Clone)]
    struct Utf8Bytes<S>(S);

    impl<S: AsRef<str>> AsRef<[u8]> for Utf8Bytes<S> {
        fn as_ref(&self) -> &[u8] {
            let Self(inner) = self;
            inner.as_ref().as_bytes()
        }
    }

    bytes_tag_encode(field, Utf8Bytes(data))
}
/// Encodes a protobuf field key: `(field << 3) | wire_ty`, written as a LEB128 varint.
///
/// `wire_ty` must fit in the low three bits. `field` must not use its top three bits.
/// Otherwise the key would be silently corrupted: the wire type would overlap the field
/// number, or the field number would be truncated by the shift. Both conditions are checked
/// in debug builds only.
pub(crate) fn tag_encode(field: u64, wire_ty: u8) -> impl Iterator<Item = u8> + Clone {
    debug_assert!(wire_ty <= 0b111, "wire type does not fit in three bits");
    debug_assert!(field <= u64::MAX >> 3, "field number too large to be shifted by 3");
    leb128::encode((field << 3) | u64::from(wire_ty))
}
/// Encodes a varint field: a wire-type-0 tag followed by `value` as a LEB128 varint.
///
/// Despite the name, `value` is written as-is: no ZigZag transform is applied. For example,
/// `93701` encodes to `[133, 220, 5]`.
pub(crate) fn varint_zigzag_tag_encode(field: u64, value: u64) -> impl Iterator<Item = u8> + Clone {
    let key = tag_encode(field, 0);
    let payload = leb128::encode(value);
    key.chain(payload)
}
/// Encodes a length-delimited (wire type 2) field.
///
/// The output is the tag, then the byte length of `data` as LEB128, then `data` itself.
/// Header bytes are yielded as one-byte chunks. `data` is yielded as a single chunk without
/// being copied.
pub(crate) fn delimited_tag_encode(
    field: u64,
    data: impl AsRef<[u8]> + Clone,
) -> impl Iterator<Item = impl AsRef<[u8]> + Clone> + Clone {
    let payload_len = data.as_ref().len();
    let header = tag_encode(field, 2)
        .chain(leb128::encode_usize(payload_len))
        .map(|byte| either::Right([byte]));
    header.chain(iter::once(either::Left(data)))
}
/// Decodes a protobuf field key (a LEB128 varint) into `(field_number, wire_type)`.
///
/// The wire type is the low three bits of the key. The field number is the remaining
/// high bits.
pub(crate) fn tag_decode<'a, E: nom::error::ParseError<&'a [u8]>>(
    bytes: &'a [u8],
) -> nom::IResult<&'a [u8], (u64, u8), E> {
    let split_key = |key: u64| -> (u64, u8) {
        // Masked to three bits, so the conversion to `u8` can never fail.
        let wire_ty = u8::try_from(key & 0b111).unwrap();
        (key >> 3, wire_ty)
    };
    let mut parser = nom::combinator::map(leb128::nom_leb128_u64, split_key);
    nom::Parser::parse(&mut parser, bytes)
}
/// Decodes a varint field (tag included) whose value must fit in a `u32`.
///
/// Larger values make the parser fail rather than being truncated.
pub(crate) fn uint32_tag_decode<'a, E: nom::error::ParseError<&'a [u8]>>(
    bytes: &'a [u8],
) -> nom::IResult<&'a [u8], u32, E> {
    let narrow = |wide: u64| u32::try_from(wide).ok();
    let mut parser = nom::combinator::map_opt(varint_zigzag_tag_decode, narrow);
    nom::Parser::parse(&mut parser, bytes)
}
/// Decodes a `bool` field (tag included). Any non-zero varint decodes as `true`.
pub(crate) fn bool_tag_decode<'a, E: nom::error::ParseError<&'a [u8]>>(
    bytes: &'a [u8],
) -> nom::IResult<&'a [u8], bool, E> {
    let mut parser = nom::combinator::map(varint_zigzag_tag_decode, |raw: u64| raw > 0);
    nom::Parser::parse(&mut parser, bytes)
}
pub(crate) fn enum_tag_decode<'a, E: nom::error::ParseError<&'a [u8]>>(
bytes: &'a [u8],
) -> nom::IResult<&'a [u8], u64, E> {
varint_zigzag_tag_decode(bytes)
}
/// Builds a parser for a sub-message field.
///
/// The parser first decodes a length-delimited field (tag included). It then runs
/// `inner_message_parser` on exactly the delimited payload.
pub(crate) fn message_tag_decode<'a, O, E: nom::error::ParseError<&'a [u8]>>(
    inner_message_parser: impl nom::Parser<&'a [u8], Output = O, Error = E>,
) -> impl nom::Parser<&'a [u8], Output = O, Error = E> {
    let framed_payload = delimited_tag_decode::<E>;
    nom::combinator::map_parser(framed_payload, inner_message_parser)
}
/// Decodes a `string` field (tag included), borrowing the text from the input.
///
/// Fails if the payload exceeds 2 GiB or is not valid UTF-8.
pub(crate) fn string_tag_decode<'a, E: nom::error::ParseError<&'a [u8]>>(
    bytes: &'a [u8],
) -> nom::IResult<&'a [u8], &'a str, E> {
    let to_str = |raw: &'a [u8]| -> Option<&'a str> {
        if raw.len() <= 2 * 1024 * 1024 * 1024 {
            str::from_utf8(raw).ok()
        } else {
            None
        }
    };
    let mut parser = nom::combinator::map_opt(delimited_tag_decode, to_str);
    nom::Parser::parse(&mut parser, bytes)
}
/// Decodes a `bytes` field (tag included), borrowing the payload from the input.
///
/// Payloads larger than 2 GiB are rejected.
pub(crate) fn bytes_tag_decode<'a, E: nom::error::ParseError<&'a [u8]>>(
    bytes: &'a [u8],
) -> nom::IResult<&'a [u8], &'a [u8], E> {
    let within_limit = |payload: &[u8]| payload.len() <= 2 * 1024 * 1024 * 1024;
    let mut parser = nom::combinator::verify(delimited_tag_decode, within_limit);
    nom::Parser::parse(&mut parser, bytes)
}
/// Decodes a varint field: a tag that must have wire type 0, followed by a LEB128 value.
///
/// Despite the name, the value is returned as-is: no ZigZag transform is reversed.
pub(crate) fn varint_zigzag_tag_decode<'a, E: nom::error::ParseError<&'a [u8]>>(
    bytes: &'a [u8],
) -> nom::IResult<&'a [u8], u64, E> {
    let is_varint = |&(_, wire_ty): &(u64, u8)| wire_ty == 0;
    let mut parser = nom::sequence::preceded(
        nom::combinator::verify(tag_decode, is_varint),
        leb128::nom_leb128_u64,
    );
    nom::Parser::parse(&mut parser, bytes)
}
/// Decodes a length-delimited field and returns its payload, borrowed from the input.
///
/// The field is a tag that must have wire type 2, followed by a LEB128 length and that many
/// bytes.
pub(crate) fn delimited_tag_decode<'a, E: nom::error::ParseError<&'a [u8]>>(
    bytes: &'a [u8],
) -> nom::IResult<&'a [u8], &'a [u8], E> {
    let is_length_delimited = |&(_, wire_ty): &(u64, u8)| wire_ty == 2;
    let mut parser = nom::sequence::preceded(
        nom::combinator::verify(tag_decode, is_length_delimited),
        nom::multi::length_data(leb128::nom_leb128_usize),
    );
    nom::Parser::parse(&mut parser, bytes)
}
pub(crate) fn tag_value_skip_decode<'a, E: nom::error::ParseError<&'a [u8]>>(
bytes: &'a [u8],
) -> nom::IResult<&'a [u8], (), E> {
nom::Parser::parse(
&mut nom::combinator::flat_map(tag_decode, |(_, wire_ty)| {
move |inner_bytes| match wire_ty {
0 => nom::Parser::parse(
&mut nom::combinator::map(leb128::nom_leb128_u64, |_| ()),
inner_bytes,
),
5 => nom::Parser::parse(
&mut nom::combinator::map(nom::bytes::streaming::take(4u32), |_| ()),
inner_bytes,
),
1 => nom::Parser::parse(
&mut nom::combinator::map(nom::bytes::streaming::take(8u32), |_| ()),
inner_bytes,
),
2 => nom::Parser::parse(
&mut nom::combinator::map(
nom::multi::length_data(leb128::nom_leb128_usize),
|_| (),
),
inner_bytes,
),
_ => Err(nom::Err::Error(nom::error::ParseError::from_error_kind(
bytes,
nom::error::ErrorKind::Tag,
))),
}
}),
bytes,
)
}
/// Builds a nom parser closure that decodes an entire protobuf message.
///
/// Entries have the form `#[kind] name = field_number => parser` and are separated by commas.
/// A trailing comma is accepted. `kind` is `required`, `optional` or `repeated(max = N)`.
/// `parser` is handed the input positioned *at* the field's tag, so it must decode the tag as
/// well as the value (for example `uint32_tag_decode`).
///
/// The closure loops until its input is empty. It returns a struct with one member per
/// `name`:
/// - `required`: the value itself.
/// - `optional`: an `Option` of the value.
/// - `repeated`: a `Vec` of the values.
///
/// Fields whose number is not listed are skipped with `tag_value_skip_decode`.
///
/// Errors:
/// - a field parser that succeeds without consuming input (`ErrorKind::Alt`);
/// - a duplicated non-repeated field, or more than `max` repeated values (`ErrorKind::Many1`);
/// - a missing `required` field (`ErrorKind::NoneOf`);
/// - any error returned by the field parsers, `tag_decode`, or the skipping of unknown fields.
///
/// The expansion refers to its helpers through `$crate::util::protobuf::…`, so it relies on
/// this module living at that path.
macro_rules! message_decode {
    // Trailing-comma form: drop the comma and delegate to the main arm below.
    ($($(#[$($attrs:tt)*])* $field_name:ident = $field_num:expr => $parser:expr),*,) => {
        $crate::util::protobuf::message_decode!($($(#[$($attrs)*])* $field_name = $field_num => $parser),*)
    };
    ($($(#[$($attrs:tt)*])* $field_name:ident = $field_num:expr => $parser:expr),*) => {{
        // Final output. Each member's type is a generic parameter named after the field, so it
        // is inferred from the corresponding parser's output.
        #[allow(non_camel_case_types)]
        struct Out<$($field_name),*> {
            $($field_name: $field_name,)*
        }

        |mut input| {
            // Accumulators filled while scanning: `Option<_>` or `Vec<_>`, depending on the
            // field kind (see `message_decode_helper_ty!`).
            #[allow(non_camel_case_types)]
            struct InProgress<$($field_name),*> {
                $($field_name: $crate::util::protobuf::message_decode_helper_ty!($field_name; $($($attrs)*)*),)*
            }

            let mut in_progress = InProgress {
                $($field_name: Default::default(),)*
            };

            loop {
                // The message ends where the input ends.
                if <[u8]>::is_empty(input) {
                    break;
                }

                // Peek at the tag only to learn the field number. `input` is not advanced,
                // because the matching field parser decodes the tag itself. The wire type is
                // not checked here; that is left to the field parser.
                let (_, (field_num, _wire_ty)) = $crate::util::protobuf::tag_decode(input)?;

                $(if field_num == $field_num {
                    let (rest, value) = nom::Parser::<&[u8]>::parse(&mut $parser, input)?;
                    // Reject parsers that succeed without consuming anything. Otherwise the
                    // loop could keep re-parsing the same bytes without making progress.
                    if input == rest {
                        return core::result::Result::Err(nom::Err::Error(
                            nom::error::ParseError::<&[u8]>::from_error_kind(rest, nom::error::ErrorKind::Alt)
                        ));
                    }
                    $crate::util::protobuf::message_decode_helper_store!(input, value => in_progress.$field_name; $($($attrs)*)*);
                    input = rest;
                    continue;
                })*

                // Field number not listed for this message: skip its tag and value.
                let (rest, ()) = $crate::util::protobuf::tag_value_skip_decode(input)?;
                // A successful skip is expected to consume at least the tag.
                debug_assert!(input != rest);
                input = rest;
            }

            // Convert the accumulators into final values; a missing `required` field fails here.
            let out = Out {
                $($field_name: $crate::util::protobuf::message_decode_helper_unwrap!(in_progress.$field_name; $($($attrs)*)*)?,)*
            };

            Ok((input, out))
        }
    }};
}
/// Maps a `message_decode!` field kind to the type of its accumulator during decoding.
///
/// `required` and `optional` map to `Option<T>`. `repeated(max = …)` maps to `Vec<T>`.
macro_rules! message_decode_helper_ty {
    ($ty:ty; required) => { Option<$ty> };
    ($ty:ty; optional) => { Option<$ty> };
    ($ty:ty; repeated(max = $max:expr)) => { Vec<$ty> };
}
/// Stores one decoded `$value` into the accumulator `$dest`, according to the field kind.
///
/// It expands to statements containing `return`. It therefore exits the enclosing decoding
/// closure with an `ErrorKind::Many1` error, positioned at `$input_data`, when:
/// - a `required` or `optional` field was already seen, or
/// - a `repeated` field already holds `max` values. A `max` that does not fit in `usize` is
///   treated as `usize::MAX`.
///
/// NOTE(review): rejecting duplicate singular fields is stricter than the protobuf spec's
/// "last value wins" rule; presumably intentional for this strict decoder.
macro_rules! message_decode_helper_store {
    ($input_data:expr, $value:expr => $dest:expr; required) => {
        if $dest.is_some() {
            return core::result::Result::Err(nom::Err::Error(
                nom::error::ParseError::<&[u8]>::from_error_kind(
                    $input_data,
                    nom::error::ErrorKind::Many1,
                ),
            ));
        }
        $dest = Some($value);
    };
    ($input_data:expr, $value:expr => $dest:expr; optional) => {
        if $dest.is_some() {
            return core::result::Result::Err(nom::Err::Error(
                nom::error::ParseError::<&[u8]>::from_error_kind(
                    $input_data,
                    nom::error::ErrorKind::Many1,
                ),
            ));
        }
        $dest = Some($value);
    };
    ($input_data:expr, $value:expr => $dest:expr; repeated(max = $max:expr)) => {
        // Checked before pushing, so at most `max` values are ever stored.
        if $dest.len() >= usize::try_from($max).unwrap_or(usize::MAX) {
            return core::result::Result::Err(nom::Err::Error(
                nom::error::ParseError::<&[u8]>::from_error_kind(
                    $input_data,
                    nom::error::ErrorKind::Many1,
                ),
            ));
        }
        $dest.push($value);
    };
}
/// Converts a field accumulator into its final value, as a `Result` meant to be used with `?`.
///
/// For `required`, a field that was never seen yields an `ErrorKind::NoneOf` error. Its
/// position is an empty slice, since no specific input location is at fault. `optional` and
/// `repeated` accumulators are passed through unchanged.
macro_rules! message_decode_helper_unwrap {
    ($value:expr; required) => {
        $value.ok_or_else(|| {
            nom::Err::Error(nom::error::ParseError::<&[u8]>::from_error_kind(
                &[][..],
                nom::error::ErrorKind::NoneOf,
            ))
        })
    };
    ($value:expr; optional) => {
        Ok($value)
    };
    ($value:expr; repeated(max = $max:expr)) => {
        Ok($value)
    };
}
pub(crate) use {
message_decode, message_decode_helper_store, message_decode_helper_ty,
message_decode_helper_unwrap,
};
#[cfg(test)]
mod tests {
    /// Concatenates the chunks yielded by one of the `*_tag_encode` functions into one buffer.
    fn concat(chunks: impl Iterator<Item = impl AsRef<[u8]>>) -> Vec<u8> {
        chunks.fold(Vec::new(), |mut out, chunk| {
            out.extend_from_slice(chunk.as_ref());
            out
        })
    }

    #[test]
    fn encode_decode_bool() {
        let encoded = concat(super::bool_tag_encode(504, true));
        assert_eq!(&encoded, &[192, 31, 1]);
        let decoded = super::bool_tag_decode::<nom::error::Error<&[u8]>>(&encoded)
            .unwrap()
            .1;
        assert!(decoded);
    }

    #[test]
    fn encode_decode_uint32() {
        let encoded = concat(super::uint32_tag_encode(8670, 93701));
        assert_eq!(&encoded, &[240, 157, 4, 133, 220, 5]);
        let decoded = super::uint32_tag_decode::<nom::error::Error<&[u8]>>(&encoded)
            .unwrap()
            .1;
        assert_eq!(decoded, 93701);
    }

    #[test]
    fn encode_decode_enum() {
        let encoded = concat(super::enum_tag_encode(107, 935237));
        assert_eq!(&encoded, &[216, 6, 197, 138, 57]);
        let decoded = super::enum_tag_decode::<nom::error::Error<&[u8]>>(&encoded)
            .unwrap()
            .1;
        assert_eq!(decoded, 935237);
    }

    #[test]
    fn encode_decode_string() {
        let encoded = concat(super::string_tag_encode(490, "hello world"));
        assert_eq!(
            &encoded,
            &[
                210, 30, 11, 104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100
            ]
        );
        let decoded = super::string_tag_decode::<nom::error::Error<&[u8]>>(&encoded)
            .unwrap()
            .1;
        assert_eq!(decoded, "hello world");
    }

    #[test]
    fn encode_decode_bytes() {
        let encoded = concat(super::bytes_tag_encode(2, b"test"));
        assert_eq!(&encoded, &[18, 4, 116, 101, 115, 116]);
        let decoded = super::bytes_tag_decode::<nom::error::Error<&[u8]>>(&encoded)
            .unwrap()
            .1;
        assert_eq!(decoded, b"test");
    }

    #[test]
    fn encode_decode_message() {
        // Field 3 wrapping a message that only contains field 1 = b"ab".
        let encoded = concat(super::message_tag_encode(3, super::bytes_tag_encode(1, b"ab")));
        assert_eq!(&encoded, &[26, 4, 10, 2, 97, 98]);
        let mut parser =
            super::message_tag_decode(super::bytes_tag_decode::<nom::error::Error<&[u8]>>);
        let (rest, decoded) = nom::Parser::parse(&mut parser, &encoded[..]).unwrap();
        assert!(rest.is_empty());
        assert_eq!(decoded, b"ab");
    }

    #[test]
    fn tag_decode_splits_field_and_wire_type() {
        let input: &[u8] = &[192, 31, 1];
        let (rest, tag) = super::tag_decode::<nom::error::Error<&[u8]>>(input).unwrap();
        assert_eq!(tag, (504, 0));
        assert_eq!(rest, &[1_u8][..]);
    }

    #[test]
    fn skip_decode_supported_wire_types() {
        // Each input holds one field with number 1, followed by a trailing 0xff byte that must
        // be left in place.
        let cases: [&[u8]; 4] = [
            // Wire type 0 (varint), value 150.
            &[0x08, 0x96, 0x01, 0xff],
            // Wire type 1 (64-bit).
            &[0x09, 1, 2, 3, 4, 5, 6, 7, 8, 0xff],
            // Wire type 2 (length-delimited), two payload bytes.
            &[0x0a, 0x02, 0xaa, 0xbb, 0xff],
            // Wire type 5 (32-bit).
            &[0x0d, 1, 2, 3, 4, 0xff],
        ];
        for input in cases {
            let (rest, ()) =
                super::tag_value_skip_decode::<nom::error::Error<&[u8]>>(input).unwrap();
            assert_eq!(rest, &[0xff_u8][..]);
        }
    }

    #[test]
    fn skip_decode_rejects_unsupported_or_truncated() {
        // Wire type 3 is not one of the supported wire types.
        let unsupported: &[u8] = &[0x0b, 0x00];
        assert!(super::tag_value_skip_decode::<nom::error::Error<&[u8]>>(unsupported).is_err());
        // Wire type 5 announces four bytes, but only one follows.
        let truncated: &[u8] = &[0x0d, 0x00];
        assert!(super::tag_value_skip_decode::<nom::error::Error<&[u8]>>(truncated).is_err());
    }

    #[test]
    fn large_values_dont_crash() {
        let encoded = (0..256).map(|_| 129).collect::<Vec<_>>();
        assert!(super::tag_value_skip_decode::<nom::error::Error<&[u8]>>(&encoded).is_err());
    }
}