use crate::buf::ReverseBuf;
use crate::DecodeErrorKind::{
InvalidVarint, NotCanonical, Oversize, TagOverflowed, Truncated, UnknownField, WrongWireType,
};
use crate::{decode_length_delimiter, DecodeError, DecodeErrorKind};
use bytes::buf::Take;
use bytes::{Buf, BufMut};
use core::cmp::{min, Eq, Ordering, PartialEq};
use core::default::Default;
use core::fmt::Debug;
use core::ops::{Deref, DerefMut};
pub(crate) mod decoding_modes;
mod encoding_traits;
mod fixed;
mod general;
mod local_proxy;
mod macros;
mod map;
pub(crate) mod message;
mod oneof;
pub mod opaque;
mod packed;
mod plain_bytes;
mod proxy;
mod range_as_tuple;
#[cfg(test)]
mod test;
mod tuple;
mod type_support;
mod underived;
mod unpacked;
mod value_traits;
mod varint;
pub use encoding_traits::Wiretyped;
pub use encoding_traits::{
BorrowDecoder, Decoder, DistinguishedBorrowDecoder, DistinguishedDecoder, Encoder,
};
pub use encoding_traits::{
DistinguishedFieldBorrowDecoder, DistinguishedFieldDecoder, FieldBorrowDecoder, FieldDecoder,
FieldEncoder,
};
pub use encoding_traits::{
DistinguishedValueBorrowDecoder, DistinguishedValueDecoder, ValueBorrowDecoder, ValueDecoder,
ValueEncoder,
};
pub use macros::{delegate_encoding, delegate_proxied_encoding, delegate_value_encoding};
pub(crate) use macros::{
encoding_implemented_via_value_encoding, encoding_uses_base_empty_state,
impl_cow_value_encoding, implement_core_empty_state_rules,
};
pub use message::{
MessageEncoding, RawDistinguishedMessageBorrowDecoder, RawDistinguishedMessageDecoder,
RawMessage, RawMessageBorrowDecoder, RawMessageDecoder,
};
pub use oneof::{
DistinguishedOneofBorrowDecoder, DistinguishedOneofDecoder, Oneof, OneofBorrowDecoder,
OneofDecoder,
};
pub use oneof::{
NonEmptyDistinguishedOneofBorrowDecoder, NonEmptyDistinguishedOneofDecoder, NonEmptyOneof,
NonEmptyOneofBorrowDecoder, NonEmptyOneofDecoder,
};
pub use value_traits::{
empty_state_via_default, empty_state_via_for_overwrite, for_overwrite_via_default, Collection,
DistinguishedCollection, DistinguishedMapping, EmptyState, Enumeration, ForOverwrite, Mapping,
};
pub use fixed::Fixed;
pub use general::{General, GeneralGeneric, GeneralPacked};
pub use map::Map;
pub use packed::Packed;
pub use plain_bytes::PlainBytes;
pub use unpacked::Unpacked;
pub use varint::Varint;
pub use proxy::{DistinguishedProxiable, Proxiable, Proxied};
/// Encoded-length thresholds for the varint format used throughout this module.
///
/// For `n` in `1..=8`, `VARINT_LIMIT[n]` is the smallest value whose encoding needs more than
/// `n` bytes: a value `v` fits in at most `n` bytes exactly when `v < VARINT_LIMIT[n]`, and
/// values `>= VARINT_LIMIT[8]` take the maximum of nine bytes. Entry 0 is a placeholder so that
/// indices line up with byte counts.
///
/// Each entry equals `0x80 * (previous + 1)` rather than being a plain power of two. This is
/// because the encoders subtract one from the remaining value after every continuation byte
/// (`(value >> 7) - 1`).
const VARINT_LIMIT: [u64; 9] = [
    0,
    0x80,
    0x4080,
    0x20_4080,
    0x1020_4080,
    0x8_1020_4080,
    0x408_1020_4080,
    0x2_0408_1020_4080,
    0x102_0408_1020_4080,
];
/// Appends `value` to `buf` as a varint of one to nine bytes.
///
/// This is the unrolled implementation. It is compiled when `unroll-varint-encoding` is enabled,
/// or when `auto-unroll-varint-encoding` is enabled without `prefer-no-unroll-varint-encoding`.
///
/// The encoded length is selected first with a nested comparison tree against [`VARINT_LIMIT`].
/// The tree checks the one-byte case before anything else. Multi-byte encodings are then
/// assembled in a local array and written with a single `put_slice` call.
#[cfg(any(
    all(
        feature = "auto-unroll-varint-encoding",
        not(feature = "prefer-no-unroll-varint-encoding")
    ),
    feature = "unroll-varint-encoding",
))]
#[inline(always)]
pub fn encode_varint<B: BufMut + ?Sized>(value: u64, buf: &mut B) {
    /// Writes exactly `N` bytes.
    ///
    /// The first `N - 1` bytes carry the low seven bits of the remaining value with the `0x80`
    /// continuation bit set. The last byte carries whatever value remains. The caller picks `N`
    /// from `VARINT_LIMIT`, so that remainder fits in the final byte.
    #[inline(always)]
    fn encode_varint_inner<const N: usize>(mut value: u64, buf: &mut (impl BufMut + ?Sized)) {
        let mut varint_data = [0u8; N];
        for b in &mut varint_data[..N - 1] {
            *b = ((value & 0x7F) | 0x80) as u8;
            // Subtracting one after each step is what gives every value a single encoding.
            value = (value >> 7) - 1;
        }
        varint_data[N - 1] = value as u8;
        buf.put_slice(&varint_data);
    }
    if value < VARINT_LIMIT[1] {
        // One-byte fast path: no continuation bit needed.
        buf.put_u8(value as u8);
    } else if value < VARINT_LIMIT[5] {
        // Two to five bytes.
        if value < VARINT_LIMIT[3] {
            if value < VARINT_LIMIT[2] {
                encode_varint_inner::<2>(value, buf);
            } else {
                encode_varint_inner::<3>(value, buf);
            }
        } else if value < VARINT_LIMIT[4] {
            encode_varint_inner::<4>(value, buf);
        } else {
            encode_varint_inner::<5>(value, buf);
        }
    } else if value < VARINT_LIMIT[7] {
        // Six or seven bytes.
        if value < VARINT_LIMIT[6] {
            encode_varint_inner::<6>(value, buf);
        } else {
            encode_varint_inner::<7>(value, buf);
        }
    } else if value < VARINT_LIMIT[8] {
        encode_varint_inner::<8>(value, buf);
    } else {
        encode_varint_inner::<9>(value, buf);
    }
}
/// Appends `value` to `buf` as a varint of one to nine bytes.
///
/// This is the looping implementation, used when none of the varint-unrolling features select
/// the unrolled variant.
///
/// While more than seven bits remain, it emits the low seven bits with the `0x80` continuation
/// bit set, then continues with `(value >> 7) - 1`. At most eight continuation bytes are ever
/// produced. After that the leftover value (at most `0xFE`) is written verbatim as the ninth
/// byte.
#[cfg(not(any(
    all(
        feature = "auto-unroll-varint-encoding",
        not(feature = "prefer-no-unroll-varint-encoding")
    ),
    feature = "unroll-varint-encoding",
)))]
#[inline(always)]
pub fn encode_varint<B: BufMut + ?Sized>(mut value: u64, buf: &mut B) {
    let mut continuation_bytes = 0;
    while value >= 0x80 && continuation_bytes < 8 {
        buf.put_u8(((value & 0x7F) | 0x80) as u8);
        value = (value >> 7) - 1;
        continuation_bytes += 1;
    }
    // Either the value now fits in seven bits, or this is the ninth byte. In the ninth-byte case
    // every bit of the final byte carries payload.
    buf.put_u8(value as u8);
}
/// Prepends the varint encoding of `value` (one to nine bytes) to `buf`.
///
/// This is the unrolled implementation, gated by the same features as the unrolled
/// `encode_varint`. The byte count is chosen with a nested comparison tree against
/// [`VARINT_LIMIT`]. Multi-byte encodings are built in a local array in forward order and handed
/// to `prepend_slice` in one call.
#[cfg(any(
    all(
        feature = "auto-unroll-varint-encoding",
        not(feature = "prefer-no-unroll-varint-encoding")
    ),
    feature = "unroll-varint-encoding",
))]
#[inline(always)]
pub fn prepend_varint<B: ReverseBuf + ?Sized>(value: u64, buf: &mut B) {
    /// Builds exactly `N` encoded bytes and prepends them as one slice.
    ///
    /// The first `N - 1` bytes are continuation bytes; the last holds the remaining value.
    #[inline(always)]
    fn prepend_varint_inner<const N: usize>(mut value: u64, buf: &mut (impl ReverseBuf + ?Sized)) {
        let mut varint_data = [0u8; N];
        for b in &mut varint_data[..N - 1] {
            *b = ((value & 0x7F) | 0x80) as u8;
            value = (value >> 7) - 1;
        }
        varint_data[N - 1] = value as u8;
        buf.prepend_slice(&varint_data);
    }
    if value < VARINT_LIMIT[1] {
        // One-byte fast path.
        buf.prepend_u8(value as u8);
    } else if value < VARINT_LIMIT[5] {
        // Two to five bytes.
        if value < VARINT_LIMIT[3] {
            if value < VARINT_LIMIT[2] {
                prepend_varint_inner::<2>(value, buf);
            } else {
                prepend_varint_inner::<3>(value, buf);
            }
        } else if value < VARINT_LIMIT[4] {
            prepend_varint_inner::<4>(value, buf);
        } else {
            prepend_varint_inner::<5>(value, buf);
        }
    } else if value < VARINT_LIMIT[7] {
        // Six or seven bytes.
        if value < VARINT_LIMIT[6] {
            prepend_varint_inner::<6>(value, buf);
        } else {
            prepend_varint_inner::<7>(value, buf);
        }
    } else if value < VARINT_LIMIT[8] {
        prepend_varint_inner::<8>(value, buf);
    } else {
        prepend_varint_inner::<9>(value, buf);
    }
}
/// Prepends the varint encoding of `value` (one to nine bytes) to `buf`.
///
/// This is the looping implementation. A value below `0x80` is prepended as a single byte.
/// Longer encodings are assembled front-to-back in a nine-byte scratch array, and only the used
/// prefix is prepended, in one `prepend_slice` call.
#[cfg(not(any(
    all(
        feature = "auto-unroll-varint-encoding",
        not(feature = "prefer-no-unroll-varint-encoding")
    ),
    feature = "unroll-varint-encoding",
)))]
#[inline(always)]
pub fn prepend_varint<B: ReverseBuf + ?Sized>(mut value: u64, buf: &mut B) {
    if value < 0x80 {
        buf.prepend_u8(value as u8);
        return;
    }
    let mut scratch = [0u8; 9];
    let mut last = 0;
    while value >= 0x80 && last < 8 {
        scratch[last] = ((value & 0x7F) | 0x80) as u8;
        value = (value >> 7) - 1;
        last += 1;
    }
    // The final byte is either a sub-0x80 terminator or the full ninth byte.
    scratch[last] = value as u8;
    buf.prepend_slice(&scratch[..=last]);
}
/// A varint encoding produced in a `const` context by [`const_varint`].
///
/// It dereferences to the encoded bytes.
pub struct ConstVarint {
    /// Backing storage; only the first `len` bytes are meaningful.
    value: [u8; 9],
    /// Number of encoded bytes (1 to 9).
    len: u8,
}

impl Deref for ConstVarint {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        let used = self.len as usize;
        &self.value[..used]
    }
}

/// Encodes `value` as a varint at compile time, using the same byte layout as `encode_varint`.
///
/// The first up-to-eight bytes carry seven payload bits plus the `0x80` continuation flag, with
/// one subtracted from the remainder after each. A ninth byte, if reached, stores the leftover
/// value verbatim.
pub const fn const_varint(mut value: u64) -> ConstVarint {
    let mut bytes = [0u8; 9];
    let mut last: usize = 0;
    while value >= 0x80 && last < 8 {
        bytes[last] = ((value as u8) & 0x7f) | 0x80;
        value = (value >> 7) - 1;
        last += 1;
    }
    bytes[last] = value as u8;
    ConstVarint {
        value: bytes,
        len: (last + 1) as u8,
    }
}
/// Decodes one varint from the front of `buf` and advances past it.
///
/// The decoding path is chosen from the current chunk:
/// - A single byte below `0x80` is returned directly.
/// - If the whole varint is guaranteed to lie inside the current chunk (the chunk holds at least
///   nine bytes, or its last byte has no continuation bit), it is decoded by
///   `decode_varint_slice`.
/// - Otherwise the varint may span chunks, and it goes through `decode_varint_slow`.
///
/// # Errors
/// - `Truncated` if the current chunk is empty, or if the slow path runs out of bytes.
/// - `InvalidVarint` if a nine-byte encoding exceeds `u64::MAX`. When that error comes from the
///   in-chunk path, `buf` is still advanced past the nine bytes examined.
#[inline(always)]
pub fn decode_varint<B: Buf + ?Sized>(buf: &mut B) -> Result<u64, DecodeError> {
    let chunk = buf.chunk();
    let Some(&first) = chunk.first() else {
        return Err(DecodeError::new(Truncated));
    };
    if first < 0x80 {
        buf.advance(1);
        return Ok(u64::from(first));
    }
    let chunk_len = chunk.len();
    let fits_in_chunk = chunk_len >= 9 || chunk[chunk_len - 1] < 0x80;
    if !fits_in_chunk {
        return decode_varint_slow(buf);
    }
    match decode_varint_slice(chunk) {
        Ok((value, consumed)) => {
            buf.advance(consumed);
            Ok(value)
        }
        Err(err) => {
            // Only a full nine-byte encoding can fail, so skip exactly those bytes.
            buf.advance(9);
            Err(err)
        }
    }
}
/// Decodes a varint from the start of `bytes`.
///
/// Returns `(value, bytes_consumed)`.
///
/// The decoded value is the sum of every encoded byte, continuation bit included, shifted left
/// by seven bits per position. This inverts the encoder's `(value >> 7) - 1` step. The first
/// four and the next four bytes are accumulated separately in `u32`s (`part0`, `part1`). Four
/// bytes contribute less than `2^32`, so these partial sums cannot overflow. The ninth byte
/// (if reached) is added with shift 56 and checked for overflow.
///
/// # Panics
/// If `bytes` is empty, or if it is shorter than nine bytes and its last byte has the
/// continuation bit set. Those are the preconditions that make the unchecked reads below sound.
///
/// # Errors
/// `InvalidVarint` if a nine-byte encoding would exceed `u64::MAX`.
#[inline(always)]
fn decode_varint_slice(bytes: &[u8]) -> Result<(u64, usize), DecodeError> {
    assert!(!bytes.is_empty());
    assert!(bytes.len() >= 9 || bytes[bytes.len() - 1] < 0x80);
    // SAFETY (applies to every `get_unchecked(i)` below): index `i` is read only after bytes
    // `0..i` all had the continuation bit set. If `bytes.len() >= 9`, every index up to 8 is in
    // bounds. Otherwise the asserted terminator at `bytes.len() - 1` makes the function return no
    // later than that index, so `i < bytes.len()`.
    let mut b: u8 = unsafe { *bytes.get_unchecked(0) };
    let mut part0: u32 = u32::from(b);
    if b < 0x80 {
        return Ok((u64::from(part0), 1));
    };
    b = unsafe { *bytes.get_unchecked(1) };
    part0 += u32::from(b) << 7;
    if b < 0x80 {
        return Ok((u64::from(part0), 2));
    };
    b = unsafe { *bytes.get_unchecked(2) };
    part0 += u32::from(b) << 14;
    if b < 0x80 {
        return Ok((u64::from(part0), 3));
    };
    b = unsafe { *bytes.get_unchecked(3) };
    part0 += u32::from(b) << 21;
    if b < 0x80 {
        return Ok((u64::from(part0), 4));
    };
    // Bytes 4..8 are accumulated into a second 32-bit sum that is shifted into place by 28 bits.
    let value = u64::from(part0);
    b = unsafe { *bytes.get_unchecked(4) };
    let mut part1: u32 = u32::from(b);
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 5));
    };
    b = unsafe { *bytes.get_unchecked(5) };
    part1 += u32::from(b) << 7;
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 6));
    };
    b = unsafe { *bytes.get_unchecked(6) };
    part1 += u32::from(b) << 14;
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 7));
    };
    b = unsafe { *bytes.get_unchecked(7) };
    part1 += u32::from(b) << 21;
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 8));
    };
    let value = value + ((u64::from(part1)) << 28);
    // Ninth byte: all eight bits are payload.
    b = unsafe { *bytes.get_unchecked(8) };
    // `b << 56` has no bits below 56, so the sum overflows exactly when the top bytes overflow.
    if (b as u32) + ((value >> 56) as u32) > 0xff {
        Err(DecodeError::new(InvalidVarint))
    } else {
        Ok((value + (u64::from(b) << 56), 9))
    }
}
/// Byte-at-a-time varint decoder.
///
/// Used when the encoding may straddle a chunk boundary of `buf`. The first eight bytes
/// contribute `byte << (7 * position)` each, and decoding stops at the first byte below `0x80`.
/// A ninth byte contributes `byte << 56` and ends the varint unconditionally.
///
/// # Errors
/// - `Truncated` if `buf` runs out before the varint terminates.
/// - `InvalidVarint` if the ninth byte would overflow `u64`.
#[inline(never)]
#[cold]
fn decode_varint_slow<B: Buf + ?Sized>(buf: &mut B) -> Result<u64, DecodeError> {
    let mut value: u64 = 0;
    let mut shift: u32 = 0;
    let leading = min(8, buf.remaining());
    for _ in 0..leading {
        let byte = buf.get_u8();
        value += u64::from(byte) << shift;
        if byte < 0x80 {
            return Ok(value);
        }
        shift += 7;
    }
    if !buf.has_remaining() {
        return Err(DecodeError::new(Truncated));
    }
    let ninth = u64::from(buf.get_u8()) << 56;
    value
        .checked_add(ninth)
        .ok_or_else(|| DecodeError::new(InvalidVarint))
}
/// State carried through decoding and re-derived at each nesting level.
///
/// Unless the `no-recursion-limit` feature is enabled, it tracks how many more levels of nesting
/// may be entered before `limit_reached` reports `RecursionLimitReached`.
#[derive(Clone, Debug)]
pub struct DecodeContext {
    /// Remaining nesting depth; a fresh context starts at `crate::RECURSION_LIMIT`.
    #[cfg(not(feature = "no-recursion-limit"))]
    recurse_count: u32,
}

impl Default for DecodeContext {
    /// Creates a top-level context with the full recursion budget available.
    #[inline]
    fn default() -> Self {
        Self {
            #[cfg(not(feature = "no-recursion-limit"))]
            recurse_count: crate::RECURSION_LIMIT,
        }
    }
}
impl DecodeContext {
    /// Returns the context to use when decoding one level of nesting deeper.
    ///
    /// With the recursion limit enabled, the remaining depth is decremented with
    /// `saturating_sub`, so it stays at zero instead of underflowing. A caller that descends
    /// without first checking [`limit_reached`](Self::limit_reached) therefore still sees the
    /// limit as reached. It no longer panics on overflow (debug builds) or wraps around to an
    /// effectively unlimited depth (release builds).
    #[inline]
    pub fn enter_recursion(&self) -> DecodeContext {
        DecodeContext {
            #[cfg(not(feature = "no-recursion-limit"))]
            recurse_count: self.recurse_count.saturating_sub(1),
        }
    }

    /// Reports whether the nesting budget is exhausted.
    ///
    /// # Errors
    /// Returns `RecursionLimitReached` when the recursion limit is enabled and the remaining
    /// depth is zero. Always returns `Ok(())` when built with the `no-recursion-limit` feature.
    #[inline]
    pub fn limit_reached(&self) -> Result<(), DecodeError> {
        #[cfg(not(feature = "no-recursion-limit"))]
        if self.recurse_count == 0 {
            return Err(DecodeError::new(DecodeErrorKind::RecursionLimitReached));
        }
        Ok(())
    }
}
/// A [`DecodeContext`] that also enforces a minimum [`Canonicity`] on decoded data.
#[derive(Clone, Debug)]
pub struct RestrictedDecodeContext {
    context: DecodeContext,
    min_canonicity: Canonicity,
}

impl RestrictedDecodeContext {
    /// Creates a top-level context that rejects anything less canonical than `min_canonicity`.
    pub fn new(min_canonicity: Canonicity) -> Self {
        let context = DecodeContext::default();
        Self {
            context,
            min_canonicity,
        }
    }

    /// Returns the context for one level of nesting deeper, keeping the same minimum.
    #[inline]
    pub fn enter_recursion(&self) -> Self {
        let context = self.context.enter_recursion();
        Self {
            context,
            min_canonicity: self.min_canonicity,
        }
    }

    /// Delegates to [`DecodeContext::limit_reached`].
    #[inline]
    pub fn limit_reached(&self) -> Result<(), DecodeError> {
        self.context.limit_reached()
    }

    /// Discards the canonicity requirement and returns the plain decode context.
    pub fn into_inner(self) -> DecodeContext {
        let Self { context, .. } = self;
        context
    }

    /// Passes `canon` through if it meets the configured minimum.
    ///
    /// # Errors
    /// - `NotCanonical` when `canon` is [`Canonicity::NotCanonical`] and falls below the minimum.
    /// - `UnknownField` when `canon` is [`Canonicity::HasExtensions`] and falls below the minimum.
    #[inline]
    pub fn check(&self, canon: Canonicity) -> Result<Canonicity, DecodeError> {
        if canon >= self.min_canonicity {
            return Ok(canon);
        }
        match canon {
            Canonicity::NotCanonical => Err(DecodeError::new(NotCanonical)),
            Canonicity::HasExtensions => Err(DecodeError::new(UnknownField)),
            Canonicity::Canonical => Ok(canon),
        }
    }
}
/// Returns how many bytes (1 to 9) the varint encoding of `value` occupies.
///
/// The comparisons form the same nested tree as the unrolled `encode_varint`. The one-byte case
/// is tested first, and then the range is split at `VARINT_LIMIT[5]`. Each `VARINT_LIMIT[n]` is
/// the first value that needs more than `n` bytes. Being a `const fn`, this is also usable in
/// constant expressions.
#[inline(always)]
pub const fn encoded_len_varint(value: u64) -> usize {
    if value < VARINT_LIMIT[1] {
        1
    } else if value < VARINT_LIMIT[5] {
        // Two to five bytes.
        if value < VARINT_LIMIT[3] {
            if value < VARINT_LIMIT[2] {
                2
            } else {
                3
            }
        } else if value < VARINT_LIMIT[4] {
            4
        } else {
            5
        }
    } else if value < VARINT_LIMIT[7] {
        // Six or seven bytes.
        if value < VARINT_LIMIT[6] {
            6
        } else {
            7
        }
    } else if value < VARINT_LIMIT[8] {
        8
    } else {
        9
    }
}
/// The two-bit wire type stored in the low bits of every field key.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum WireType {
    Varint = 0,
    LengthDelimited = 1,
    ThirtyTwoBit = 2,
    SixtyFourBit = 3,
}

impl From<u8> for WireType {
    /// Interprets the low two bits of `value` as a wire type; higher bits are ignored.
    #[inline]
    fn from(value: u8) -> Self {
        match value & 0b11 {
            0 => Self::Varint,
            1 => Self::LengthDelimited,
            2 => Self::ThirtyTwoBit,
            // After masking, the only remaining possibility is 3.
            _ => Self::SixtyFourBit,
        }
    }
}

impl WireType {
    /// Byte width of a fixed-size wire type, or `None` for variable-length ones.
    const fn fixed_size(self) -> Option<usize> {
        match self {
            Self::Varint | Self::LengthDelimited => None,
            Self::ThirtyTwoBit => Some(4),
            Self::SixtyFourBit => Some(8),
        }
    }
}
/// Writes field keys for fields encoded front-to-back in ascending tag order.
///
/// Each key stores the delta from the previous field's tag, shifted left by two, with the wire
/// type in the low two bits.
#[derive(Default)]
pub struct TagWriter {
    last_tag: u32,
}

impl TagWriter {
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends the key for field `tag` with the given wire type.
    ///
    /// # Panics
    /// If `tag` is smaller than the previously written tag ("fields encoded out of order").
    #[inline(always)]
    pub fn encode_key<B: BufMut + ?Sized>(&mut self, tag: u32, wire_type: WireType, buf: &mut B) {
        let delta = tag
            .checked_sub(self.last_tag)
            .expect("fields encoded out of order");
        self.last_tag = tag;
        let key = (u64::from(delta) << 2) | wire_type as u64;
        encode_varint(key, buf);
    }
}
/// Writes field keys while fields are prepended back-to-front, in descending tag order.
///
/// A field's key encodes the delta from the field before it, and that preceding field is only
/// known once it is begun. So each key is held back until the next `begin_field` call, or until
/// `finalize`.
#[derive(Default)]
pub struct TagRevWriter {
    /// Tag and wire type of the most recently begun field whose key is not yet written.
    current_key: Option<(u32, WireType)>,
}

impl TagRevWriter {
    pub fn new() -> Self {
        Self::default()
    }

    /// Begins prepending field `tag`, first emitting the pending key of the field begun before
    /// it (which has a larger or equal tag).
    ///
    /// # Panics
    /// If `tag` exceeds the pending field's tag ("fields prepended out of order").
    #[inline(always)]
    pub fn begin_field<B: ReverseBuf + ?Sized>(
        &mut self,
        tag: u32,
        wire_type: WireType,
        buf: &mut B,
    ) {
        if let Some((pending_tag, pending_wire_type)) = self.current_key {
            let delta = pending_tag
                .checked_sub(tag)
                .expect("fields prepended out of order");
            let key = (u64::from(delta) << 2) | pending_wire_type as u64;
            prepend_varint(key, buf);
        }
        self.current_key = Some((tag, wire_type));
    }

    /// Emits the key of the earliest field, if any, and clears the pending state.
    ///
    /// That field's delta is measured from zero, so its key carries its full tag.
    #[inline(always)]
    pub fn finalize<B: ReverseBuf + ?Sized>(&mut self, buf: &mut B) {
        if let Some((first_tag, wire_type)) = self.current_key {
            prepend_varint((u64::from(first_tag) << 2) | wire_type as u64, buf);
            self.current_key = None;
        }
    }
}
/// Computes the encoded length of field keys without writing them.
pub trait TagMeasurer {
    /// Returns the key length, in bytes, for the next field with the given `tag`.
    fn key_len(&mut self, tag: u32) -> usize;
}

/// A [`TagMeasurer`] that tracks tag deltas exactly, mirroring `TagWriter`.
#[derive(Default)]
pub struct RuntimeTagMeasurer {
    last_tag: u32,
}

impl RuntimeTagMeasurer {
    pub fn new() -> Self {
        Default::default()
    }
}

impl TagMeasurer for RuntimeTagMeasurer {
    /// # Panics
    /// If tags are not supplied in ascending order ("fields encoded out of order").
    #[inline(always)]
    fn key_len(&mut self, tag: u32) -> usize {
        let delta = tag
            .checked_sub(self.last_tag)
            .expect("fields encoded out of order");
        self.last_tag = tag;
        // The wire-type bits are left as zero. Every `VARINT_LIMIT` threshold is a multiple of 4,
        // so those bits never change the length.
        encoded_len_varint(u64::from(delta) << 2)
    }
}

/// A [`TagMeasurer`] that always reports one-byte keys.
///
/// Debug builds assert that tags are ascending and below 32.
#[derive(Default)]
pub struct TrivialTagMeasurer {
    #[cfg(debug_assertions)]
    last_tag: u32,
}

impl TrivialTagMeasurer {
    pub fn new() -> Self {
        Default::default()
    }
}

impl TagMeasurer for TrivialTagMeasurer {
    #[inline(always)]
    fn key_len(&mut self, _tag: u32) -> usize {
        #[cfg(debug_assertions)]
        {
            assert!(_tag >= self.last_tag, "fields encoded out of order");
            assert!(_tag < 32);
            self.last_tag = _tag;
        }
        1
    }
}
/// Reads field keys and reconstructs absolute tags from the stored deltas.
#[derive(Default)]
pub struct TagReader {
    last_tag: u32,
}

impl TagReader {
    pub fn new() -> Self {
        Self::default()
    }

    /// Decodes the next key from `buf` and returns `(absolute_tag, wire_type)`.
    ///
    /// # Errors
    /// - Varint decoding errors are propagated.
    /// - `TagOverflowed` if the delta or the resulting tag does not fit in `u32`.
    #[inline(always)]
    pub fn decode_key<B: Buf + ?Sized>(
        &mut self,
        mut buf: Capped<B>,
    ) -> Result<(u32, WireType), DecodeError> {
        let key = buf.decode_varint()?;
        let overflowed = || DecodeError::new(TagOverflowed);
        let tag_delta = u32::try_from(key >> 2).map_err(|_| overflowed())?;
        let tag = self.last_tag.checked_add(tag_delta).ok_or_else(overflowed)?;
        self.last_tag = tag;
        Ok((tag, WireType::from(key as u8)))
    }
}

/// Succeeds only when `actual` equals `expected`.
///
/// # Errors
/// Returns `WrongWireType` if the wire types differ.
#[inline(always)]
pub fn check_wire_type(expected: WireType, actual: WireType) -> Result<(), DecodeError> {
    if expected == actual {
        Ok(())
    } else {
        Err(DecodeError::new(WrongWireType))
    }
}
/// A mutable borrow of a buffer together with a read cap.
///
/// `extra_bytes_remaining` counts the bytes at the end of `buf` that lie beyond the cap. Reading
/// through a `Capped` is in bounds only while `buf.remaining()` stays at or above that count.
/// [`Capped::has_remaining`] reports `Truncated` once it drops below.
pub struct Capped<'a, B: 'a + Buf + ?Sized> {
    buf: &'a mut B,
    extra_bytes_remaining: usize,
}

impl<'a, B: 'a + Buf + ?Sized> Capped<'a, B> {
    /// Wraps `buf` with no cap: every remaining byte may be consumed.
    pub fn new(buf: &'a mut B) -> Self {
        Capped {
            buf,
            extra_bytes_remaining: 0,
        }
    }

    /// Reads a length delimiter from `buf` and returns a `Capped` limited to that many bytes.
    ///
    /// # Errors
    /// - Errors from `decode_length_delimiter` are propagated.
    /// - `Truncated` if the length exceeds the bytes remaining in `buf`.
    pub fn new_length_delimited(buf: &'a mut B) -> Result<Self, DecodeError> {
        let len = decode_length_delimiter(&mut *buf)?;
        let extra_bytes_remaining = buf
            .remaining()
            .checked_sub(len)
            .ok_or_else(|| DecodeError::new(Truncated))?;
        Ok(Capped {
            buf,
            extra_bytes_remaining,
        })
    }

    /// Reborrows this `Capped` with the same cap.
    #[inline(always)]
    pub fn lend(&mut self) -> Capped<'_, B> {
        Capped {
            extra_bytes_remaining: self.extra_bytes_remaining,
            buf: &mut *self.buf,
        }
    }

    /// Reads a length delimiter and returns a nested `Capped` covering exactly that many bytes.
    ///
    /// # Errors
    /// - Errors from `decode_length_delimiter` are propagated.
    /// - `Truncated` if the length exceeds the bytes remaining, or reaches past this cap.
    #[inline(always)]
    pub fn take_length_delimited(&mut self) -> Result<Capped<'_, B>, DecodeError> {
        let len = decode_length_delimiter(&mut *self.buf)?;
        let extra_bytes_remaining = self
            .buf
            .remaining()
            .checked_sub(len)
            .filter(|&extra| extra >= self.extra_bytes_remaining)
            .ok_or_else(|| DecodeError::new(Truncated))?;
        Ok(Capped {
            buf: &mut *self.buf,
            extra_bytes_remaining,
        })
    }

    /// Direct access to the underlying buffer, ignoring the cap.
    #[inline]
    pub fn buf(&mut self) -> &mut B {
        &mut *self.buf
    }

    /// Converts into a `Take` limited to the bytes remaining before the cap.
    #[inline(always)]
    pub fn take_all(self) -> Take<&'a mut B> {
        let limit = self.remaining_before_cap();
        self.buf.take(limit)
    }

    /// Decodes a varint.
    ///
    /// An `InvalidVarint` error that occurred after reading past the cap is reported as
    /// `Truncated` instead.
    #[inline(always)]
    pub fn decode_varint(&mut self) -> Result<u64, DecodeError> {
        match decode_varint(self.buf) {
            Err(err) if err.kind() == InvalidVarint && self.over_cap() => {
                Err(DecodeError::new(Truncated))
            }
            other => other,
        }
    }

    /// Bytes that can still be read before reaching the cap (zero if already past it).
    #[inline(always)]
    pub fn remaining_before_cap(&self) -> usize {
        let total = self.buf.remaining();
        total.saturating_sub(self.extra_bytes_remaining)
    }

    /// Whether more bytes have been consumed than the cap allows.
    #[inline(always)]
    fn over_cap(&self) -> bool {
        self.extra_bytes_remaining > self.buf.remaining()
    }

    /// Reports whether unread bytes remain before the cap.
    ///
    /// # Errors
    /// `Truncated` if reading has already gone past the cap.
    #[inline(always)]
    pub fn has_remaining(&self) -> Result<bool, DecodeErrorKind> {
        match self.buf.remaining().cmp(&self.extra_bytes_remaining) {
            Ordering::Greater => Ok(true),
            Ordering::Equal => Ok(false),
            Ordering::Less => Err(Truncated),
        }
    }
}
impl<'a> Capped<'_, &'a [u8]> {
    /// Reads a length delimiter and returns the delimited bytes without copying.
    ///
    /// The result is a sub-slice borrowed from the underlying `&'a [u8]`. The buffer is advanced
    /// past those bytes.
    ///
    /// # Errors
    /// - Errors from `decode_length_delimiter` are propagated.
    /// - `Truncated` if the length exceeds the bytes remaining, or reaches past this `Capped`'s
    ///   cap.
    #[inline(always)]
    pub fn take_borrowed_length_delimited(&mut self) -> Result<&'a [u8], DecodeError> {
        let len = decode_length_delimiter(&mut *self.buf)?;
        let remaining = self.buf.remaining();
        if len > remaining {
            return Err(DecodeError::new(Truncated));
        }
        let extra_bytes_remaining = remaining - len;
        if extra_bytes_remaining < self.extra_bytes_remaining {
            return Err(DecodeError::new(Truncated));
        }
        // `len <= remaining == whole.len()` was checked above, so `split_at` cannot panic. The
        // safe split replaces the previous undocumented `get_unchecked` calls.
        let whole: &'a [u8] = *self.buf;
        let (taken, rest) = whole.split_at(len);
        *self.buf = rest;
        Ok(taken)
    }
}
/// Read-only access to the wrapped buffer. Note that this bypasses the cap.
impl<B: Buf + ?Sized> Deref for Capped<'_, B> {
    type Target = B;

    fn deref(&self) -> &Self::Target {
        &*self.buf
    }
}

/// Mutable access to the wrapped buffer. Note that this bypasses the cap.
impl<B: Buf + ?Sized> DerefMut for Capped<'_, B> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.buf
    }
}
/// Checks whether the next key repeats the field just read, and consumes it if so.
///
/// A repeat is a key byte below 4: only the two wire-type bits are set, which means a tag delta
/// of zero in a one-byte key. In that case this advances past the byte and returns its wire
/// type. Otherwise it returns `None` without consuming anything, which happens when the cap has
/// been reached or the next key belongs to a different field.
#[inline(always)]
fn peek_repeated_field<B: Buf + ?Sized>(buf: &mut Capped<B>) -> Option<WireType> {
    if buf.remaining_before_cap() == 0 {
        return None;
    }
    let next_key = buf.chunk()[0];
    (next_key < 4).then(|| {
        buf.advance(1);
        WireType::from(next_key)
    })
}
/// Skips the value of a field whose key has already been read.
///
/// It also skips any immediately following repetitions of the same field. `wire_type` comes
/// from the already-decoded key. After each value, `peek_repeated_field` checks for a one-byte
/// key with a zero tag delta. If one is found, that value is skipped as well, until a different
/// field or the cap is reached.
///
/// # Errors
/// - `Truncated` if a value extends beyond the cap of `buf`.
/// - `Oversize` if a length-delimited length does not fit in `usize`.
/// - Errors from decoding a varint value or length are propagated.
pub fn skip_field<B: Buf + ?Sized>(
    mut wire_type: WireType,
    mut buf: Capped<B>,
) -> Result<(), DecodeError> {
    loop {
        let len = match wire_type {
            WireType::Varint => buf.decode_varint().map(|_| 0)?,
            WireType::ThirtyTwoBit => 4,
            WireType::SixtyFourBit => 8,
            WireType::LengthDelimited => {
                usize::try_from(buf.decode_varint()?).map_err(|_| DecodeError::new(Oversize))?
            }
        };
        // Bound the skip by the cap, not merely by the end of the underlying buffer. Bytes past
        // the cap belong to enclosing data and must not be consumed as part of this field.
        // `over_cap` also catches a varint value or length that itself straddled the cap.
        if buf.over_cap() || len > buf.remaining_before_cap() {
            return Err(DecodeError::new(Truncated));
        }
        buf.advance(len);
        match peek_repeated_field(&mut buf) {
            None => break,
            Some(next_wire_type) => wire_type = next_wire_type,
        }
    }
    Ok(())
}
/// How canonical a decoded value is, ordered from least to most canonical.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[repr(u8)]
#[must_use]
pub enum Canonicity {
    NotCanonical,
    HasExtensions,
    Canonical,
}

impl Canonicity {
    /// Lowers `self` to `other` if `other` is less canonical.
    #[inline(always)]
    pub fn update(&mut self, other: Self) {
        if other < *self {
            *self = other;
        }
    }
}

impl FromIterator<Canonicity> for Canonicity {
    /// The least canonical item, or [`Canonicity::Canonical`] for an empty iterator.
    #[inline(always)]
    fn from_iter<T: IntoIterator<Item = Canonicity>>(iter: T) -> Self {
        iter.into_iter().fold(Canonicity::Canonical, min)
    }
}
/// Extracts a decoded value while asserting (or ignoring) its [`Canonicity`].
pub trait WithCanonicity {
    /// The value produced when the canonicity requirement is satisfied.
    type Value;
    /// The value produced by [`value`](Self::value), which ignores canonicity.
    type WithoutCanonicity;
    /// Returns the value only if it is fully canonical.
    fn canonical(self) -> Result<Self::Value, DecodeErrorKind>;
    /// Returns the value if it is canonical apart from possibly having extensions.
    fn canonical_with_extensions(self) -> Result<Self::Value, DecodeErrorKind>;
    /// Returns the value regardless of canonicity.
    fn value(self) -> Self::WithoutCanonicity;
}

impl WithCanonicity for Canonicity {
    type Value = ();
    type WithoutCanonicity = ();

    fn canonical(self) -> Result<(), DecodeErrorKind> {
        match self {
            Canonicity::Canonical => Ok(()),
            Canonicity::HasExtensions => Err(UnknownField),
            Canonicity::NotCanonical => Err(NotCanonical),
        }
    }

    fn canonical_with_extensions(self) -> Result<(), DecodeErrorKind> {
        if self == Canonicity::NotCanonical {
            Err(NotCanonical)
        } else {
            Ok(())
        }
    }

    fn value(self) {}
}

impl WithCanonicity for &Canonicity {
    type Value = ();
    type WithoutCanonicity = ();

    fn canonical(self) -> Result<(), DecodeErrorKind> {
        (*self).canonical()
    }

    fn canonical_with_extensions(self) -> Result<(), DecodeErrorKind> {
        (*self).canonical_with_extensions()
    }

    fn value(self) {}
}
impl<T> WithCanonicity for (T, Canonicity) {
type Value = T;
type WithoutCanonicity = Self::Value;
fn canonical(self) -> Result<T, DecodeErrorKind> {
self.1.canonical()?;
Ok(self.0)
}
fn canonical_with_extensions(self) -> Result<T, DecodeErrorKind> {
self.1.canonical_with_extensions()?;
Ok(self.0)
}
fn value(self) -> T {
self.0
}
}
impl<'a, T> WithCanonicity for &'a (T, Canonicity) {
type Value = &'a T;
type WithoutCanonicity = Self::Value;
fn canonical(self) -> Result<&'a T, DecodeErrorKind> {
self.1.canonical()?;
Ok(&self.0)
}
fn canonical_with_extensions(self) -> Result<&'a T, DecodeErrorKind> {
self.1.canonical_with_extensions()?;
Ok(&self.0)
}
fn value(self) -> &'a T {
&self.0
}
}
impl<T, E> WithCanonicity for Result<T, E>
where
T: WithCanonicity,
DecodeErrorKind: From<E>,
{
type Value = T::Value;
type WithoutCanonicity = Result<T::WithoutCanonicity, DecodeErrorKind>;
fn canonical(self) -> Result<T::Value, DecodeErrorKind> {
self?.canonical()
}
fn canonical_with_extensions(self) -> Result<T::Value, DecodeErrorKind> {
self?.canonical_with_extensions()
}
fn value(self) -> Result<T::WithoutCanonicity, DecodeErrorKind> {
Ok(self?.value())
}
}
pub trait EnumerationHelper<FieldType> {
type Input;
type Output;
fn help_set(enum_val: Self::Input) -> FieldType;
fn help_get(field_val: FieldType) -> Self::Output;
}
impl<T> EnumerationHelper<u32> for T
where
T: Enumeration,
{
type Input = T;
type Output = Result<T, u32>;
fn help_set(enum_val: Self) -> u32 {
enum_val.to_number()
}
fn help_get(field_val: u32) -> Result<T, u32> {
T::try_from_number(field_val)
}
}
impl<T> EnumerationHelper<Option<u32>> for T
where
T: Enumeration,
{
type Input = Option<T>;
type Output = Option<Result<T, u32>>;
fn help_set(enum_val: Option<T>) -> Option<u32> {
enum_val.map(|e| e.to_number())
}
fn help_get(field_val: Option<u32>) -> Option<Result<T, u32>> {
field_val.map(Enumeration::try_from_number)
}
}