use super::value_cmp;
use crate::{
AnyRef, BytesRef, DecodeValue, EncodeValue, Error, ErrorKind, FixedTag, Header, Length, Reader,
Result, Tag, ValueOrd, Writer, ord::OrdIsValueOrd,
};
use core::cmp::Ordering;
#[cfg(feature = "alloc")]
pub use allocating::Uint;
/// Implements DER `INTEGER` encoding/decoding traits for the primitive unsigned
/// integer types listed in the invocation.
///
/// Values are read and written big-endian with a minimal number of bytes, plus a
/// `0x00` prefix whenever the most significant bit of the first value byte is set
/// (ASN.1 `INTEGER` is signed, so that prefix keeps the value non-negative).
macro_rules! impl_encoding_traits {
    ($($uint:ty),+) => {
        $(
            impl<'a> DecodeValue<'a> for $uint {
                type Error = $crate::Error;

                fn decode_value<R: Reader<'a>>(reader: &mut R, header: Header) -> Result<Self> {
                    // One byte beyond the type width for the `0x00` sign-protection prefix.
                    const UNSIGNED_HEADROOM: usize = 1;
                    let mut scratch = [0u8; (Self::BITS as usize / 8) + UNSIGNED_HEADROOM];
                    let declared_len = u32::from(header.length()) as usize;

                    // An empty INTEGER is a length error; inputs longer than the scratch
                    // buffer are rejected as non-canonical before any bytes are read.
                    if declared_len == 0 {
                        return Err(reader.error(Tag::Integer.length_error()));
                    } else if declared_len > scratch.len() {
                        return Err(reader.error(Self::TAG.non_canonical_error()));
                    }

                    let raw = reader.read_into(&mut scratch[..declared_len])?;
                    let be_bytes = match decode_to_array(raw) {
                        Ok(array) => array,
                        // Attach the reader position to sign/canonicality failures.
                        Err(err) => return Err(reader.error(err.kind())),
                    };
                    let value = Self::from_be_bytes(be_bytes);

                    // Re-encoding must reproduce the header length exactly; anything else
                    // means the input was not the canonical DER form of this value.
                    if value.value_len()? == header.length() {
                        Ok(value)
                    } else {
                        Err(reader.error(Self::TAG.non_canonical_error()))
                    }
                }
            }

            impl EncodeValue for $uint {
                fn value_len(&self) -> Result<Length> {
                    let be = self.to_be_bytes();
                    encoded_len(&be)
                }

                fn encode_value(&self, writer: &mut impl Writer) -> Result<()> {
                    let be = self.to_be_bytes();
                    encode_bytes(writer, &be)
                }
            }

            impl FixedTag for $uint {
                const TAG: Tag = Tag::Integer;
            }

            impl ValueOrd for $uint {
                fn value_cmp(&self, other: &Self) -> Result<Ordering> {
                    value_cmp(*self, *other)
                }
            }

            impl TryFrom<AnyRef<'_>> for $uint {
                type Error = Error;

                fn try_from(any: AnyRef<'_>) -> Result<Self> {
                    any.decode_as()
                }
            }
        )+
    };
}

impl_encoding_traits!(u8, u16, u32, u64, u128);
/// Borrowed, variable-length unsigned ASN.1 `INTEGER`.
///
/// Holds the big-endian magnitude bytes with redundant leading zero bytes removed
/// (see [`UintRef::new`]). The `0x00` prefix that DER requires when the most
/// significant bit is set is written on encode and stripped on decode.
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)]
pub struct UintRef<'a> {
    // Big-endian magnitude bytes, without the DER `0x00` sign-protection prefix.
    inner: &'a BytesRef,
}
impl<'a> UintRef<'a> {
    /// Create a new [`UintRef`] from big-endian bytes.
    ///
    /// Redundant leading `0x00` bytes are removed. An all-zero, non-empty input is
    /// kept as the single byte `[0x00]`.
    ///
    /// # Errors
    ///
    /// Returns [`ErrorKind::Length`] if the bytes cannot be wrapped in a [`BytesRef`]
    /// (presumably because they exceed the maximum [`Length`] — confirm in `BytesRef::new`).
    pub fn new(bytes: &'a [u8]) -> Result<Self> {
        let magnitude = strip_leading_zeroes(bytes);
        match BytesRef::new(magnitude) {
            Ok(inner) => Ok(Self { inner }),
            Err(_) => Err(ErrorKind::Length { tag: Self::TAG }.into()),
        }
    }

    /// Borrow the big-endian magnitude bytes, without any DER sign-protection prefix.
    #[must_use]
    pub fn as_bytes(&self) -> &'a [u8] {
        self.inner.as_slice()
    }

    /// Length of the magnitude bytes, excluding any `0x00` prefix added on encode.
    #[must_use]
    pub fn len(&self) -> Length {
        self.inner.len()
    }

    /// Returns `true` when there are no magnitude bytes at all.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
}
// Generates conversions between `UintRef` and the crate's `Any`/`AnyRef` types.
// NOTE(review): presumably this provides the `TryFrom<AnyRef>` impl exercised by the
// `reject_oversize_without_extra_zero` test; confirm against the macro definition.
impl_any_conversions!(UintRef<'a>, 'a);
impl<'a> DecodeValue<'a> for UintRef<'a> {
    type Error = Error;

    /// Decode the value bytes of an `INTEGER`.
    ///
    /// The optional `0x00` sign-protection prefix is removed. Negative values and
    /// non-minimal encodings are rejected.
    fn decode_value<R: Reader<'a>>(reader: &mut R, header: Header) -> Result<Self> {
        let raw = <&'a BytesRef>::decode_value(reader, header)?.as_slice();
        let magnitude = match decode_to_slice(raw) {
            Ok(slice) => slice,
            // Attach the reader position to sign/canonicality failures.
            Err(err) => return Err(reader.error(err.kind())),
        };
        let decoded = Self::new(magnitude)?;

        // Re-encoding must reproduce the header length exactly; otherwise the input
        // was not in canonical DER form.
        if decoded.value_len()? == header.length() {
            Ok(decoded)
        } else {
            Err(reader.error(Self::TAG.non_canonical_error()))
        }
    }
}
impl EncodeValue for UintRef<'_> {
    /// Encoded length: the magnitude bytes plus one if a `0x00` prefix is required.
    fn value_len(&self) -> Result<Length> {
        encoded_len(self.as_bytes())
    }

    /// Write the magnitude, preceded by `0x00` when the encoded length says a
    /// sign-protection prefix is needed.
    fn encode_value(&self, writer: &mut impl Writer) -> Result<()> {
        let needs_prefix = self.value_len()? > self.len();
        if needs_prefix {
            writer.write_byte(0)?;
        }
        writer.write(self.as_bytes())
    }
}
impl<'a> From<&UintRef<'a>> for UintRef<'a> {
fn from(value: &UintRef<'a>) -> UintRef<'a> {
*value
}
}
// `UintRef` is always encoded with the universal `INTEGER` tag.
impl FixedTag for UintRef<'_> {
    const TAG: Tag = Tag::Integer;
}

// Marker trait: value ordering follows the derived `Ord` impl.
// NOTE(review): presumably backed by a blanket `ValueOrd` impl in `crate::ord` — confirm.
impl OrdIsValueOrd for UintRef<'_> {}
#[cfg(feature = "alloc")]
mod allocating {
    use super::{UintRef, decode_to_slice, encoded_len, strip_leading_zeroes};
    use crate::{
        BytesOwned, DecodeValue, EncodeValue, Error, ErrorKind, FixedTag, Header, Length, Reader,
        Result, Tag, Writer,
        ord::OrdIsValueOrd,
        referenced::{OwnedToRef, RefToOwned},
    };
    use alloc::borrow::ToOwned;

    /// Owned, variable-length unsigned ASN.1 `INTEGER`.
    ///
    /// Owned counterpart of [`UintRef`]. It stores the big-endian magnitude with
    /// redundant leading zero bytes removed (see [`Uint::new`]). The `0x00` prefix
    /// that DER requires when the most significant bit is set is written on encode
    /// and stripped on decode.
    #[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord)]
    pub struct Uint {
        // Big-endian magnitude bytes, without the DER `0x00` sign-protection prefix.
        inner: BytesOwned,
    }

    impl Uint {
        /// Create a new [`Uint`] from big-endian bytes, copying them.
        ///
        /// Redundant leading `0x00` bytes are removed. An all-zero, non-empty input
        /// is kept as the single byte `[0x00]`.
        ///
        /// # Errors
        ///
        /// Returns [`ErrorKind::Length`] if the bytes cannot be stored in a
        /// [`BytesOwned`] (presumably because they exceed the maximum [`Length`] —
        /// confirm in `BytesOwned::new`).
        pub fn new(bytes: &[u8]) -> Result<Self> {
            let inner = BytesOwned::new(strip_leading_zeroes(bytes))
                .map_err(|_| ErrorKind::Length { tag: Self::TAG })?;
            Ok(Self { inner })
        }

        /// Borrow the big-endian magnitude bytes, without any DER sign-protection prefix.
        #[must_use]
        pub fn as_bytes(&self) -> &[u8] {
            self.inner.as_slice()
        }

        /// Length of the magnitude bytes, excluding any `0x00` prefix added on encode.
        #[must_use]
        pub fn len(&self) -> Length {
            self.inner.len()
        }

        /// Returns `true` when there are no magnitude bytes at all.
        #[must_use]
        pub fn is_empty(&self) -> bool {
            self.inner.is_empty()
        }
    }

    impl_any_conversions!(Uint);

    impl<'a> DecodeValue<'a> for Uint {
        type Error = Error;

        /// Decode the value bytes of an `INTEGER`.
        ///
        /// The optional `0x00` sign-protection prefix is removed. Negative values and
        /// non-minimal encodings are rejected.
        fn decode_value<R: Reader<'a>>(reader: &mut R, header: Header) -> Result<Self> {
            let bytes = BytesOwned::decode_value_parts(reader, header, Self::TAG)?;
            // Annotate sign/canonicality errors with the reader position, the same way
            // `UintRef` and the primitive integer decoders do; a bare `?` would drop it.
            let magnitude =
                decode_to_slice(bytes.as_slice()).map_err(|err| reader.error(err.kind()))?;
            let result = Self::new(magnitude)?;

            // Re-encoding must reproduce the header length exactly; otherwise the
            // input was not in canonical DER form.
            if result.value_len()? != header.length() {
                return Err(reader.error(Self::TAG.non_canonical_error()));
            }

            Ok(result)
        }
    }

    impl EncodeValue for Uint {
        /// Encoded length: the magnitude bytes plus one if a `0x00` prefix is required.
        fn value_len(&self) -> Result<Length> {
            encoded_len(self.inner.as_slice())
        }

        /// Write the magnitude, preceded by `0x00` when a sign-protection prefix is needed.
        fn encode_value(&self, writer: &mut impl Writer) -> Result<()> {
            if self.value_len()? > self.len() {
                writer.write_byte(0)?;
            }

            writer.write(self.as_bytes())
        }
    }

    impl<'a> From<&UintRef<'a>> for Uint {
        fn from(value: &UintRef<'a>) -> Uint {
            Uint {
                inner: value.inner.into(),
            }
        }
    }

    impl FixedTag for Uint {
        const TAG: Tag = Tag::Integer;
    }

    impl OrdIsValueOrd for Uint {}

    impl<'a> RefToOwned<'a> for UintRef<'a> {
        type Owned = Uint;

        fn ref_to_owned(&self) -> Self::Owned {
            let inner = self.inner.to_owned();

            Uint { inner }
        }
    }

    impl OwnedToRef for Uint {
        type Borrowed<'a> = UintRef<'a>;

        fn owned_to_ref(&self) -> Self::Borrowed<'_> {
            let inner = self.inner.as_ref();

            UintRef { inner }
        }
    }

    /// Implements `TryFrom<$uint> for Uint` for the listed primitive unsigned types.
    macro_rules! impl_from_traits {
        ($($uint:ty),+) => {
            $(
                impl TryFrom<$uint> for Uint {
                    type Error = $crate::Error;

                    fn try_from(value: $uint) -> $crate::Result<Self> {
                        // Every value byte plus one byte for a possible `0x00` sign prefix
                        // (e.g. 17 bytes for `u128`), mirroring the decoder's headroom.
                        let mut buf = [0u8; (<$uint>::BITS as usize / 8) + 1];
                        let buf = $crate::encode::encode_value_to_slice(&mut buf, &value)?;
                        Uint::new(buf)
                    }
                }
            )+
        };
    }

    impl_from_traits!(u8, u16, u32, u64, u128);

    #[cfg(test)]
    #[allow(clippy::unwrap_used)]
    mod tests {
        use super::Uint;

        // The `0x00` sign prefix emitted by the encoder must not survive `Uint::new`.
        #[test]
        fn from_uint() {
            assert_eq!(Uint::try_from(u8::MIN).unwrap().as_bytes(), &[0]);
            assert_eq!(Uint::try_from(u8::MAX).unwrap().as_bytes(), &[0xFF]);
            assert_eq!(Uint::try_from(u16::MIN).unwrap().as_bytes(), &[0]);
            assert_eq!(Uint::try_from(u16::MAX).unwrap().as_bytes(), &[0xFF; 2]);
            assert_eq!(Uint::try_from(u32::MIN).unwrap().as_bytes(), &[0]);
            assert_eq!(Uint::try_from(u32::MAX).unwrap().as_bytes(), &[0xFF; 4]);
            assert_eq!(Uint::try_from(u64::MIN).unwrap().as_bytes(), &[0]);
            assert_eq!(Uint::try_from(u64::MAX).unwrap().as_bytes(), &[0xFF; 8]);
            assert_eq!(Uint::try_from(u128::MIN).unwrap().as_bytes(), &[0]);
            assert_eq!(Uint::try_from(u128::MAX).unwrap().as_bytes(), &[0xFF; 16]);
        }
    }
}
/// Strip the DER sign-protection prefix from the value bytes of an unsigned `INTEGER`.
///
/// Returns the magnitude bytes. A lone `[0x00]` is returned unchanged as the
/// encoding of zero.
///
/// # Errors
///
/// - Non-canonical error for empty input, or for a `0x00` prefix followed by a byte
///   whose high bit is clear (the prefix was unnecessary).
/// - Value error when the first byte has its high bit set without a `0x00` prefix,
///   since that would denote a negative `INTEGER`.
pub(crate) fn decode_to_slice(bytes: &[u8]) -> Result<&[u8]> {
    let Some((&first, rest)) = bytes.split_first() else {
        return Err(Tag::Integer.non_canonical_error().into());
    };

    if first == 0 {
        return match rest.first() {
            None => Ok(bytes),
            Some(&next) if next < 0x80 => Err(Tag::Integer.non_canonical_error().into()),
            Some(_) => Ok(rest),
        };
    }

    if first >= 0x80 {
        Err(Tag::Integer.value_error().into())
    } else {
        Ok(bytes)
    }
}
/// Decode the value bytes of an unsigned `INTEGER` into a big-endian `[u8; N]`.
///
/// The magnitude (see [`decode_to_slice`]) is right-aligned and left-padded with zeros.
///
/// # Errors
///
/// Propagates [`decode_to_slice`] errors. Returns a length error when the magnitude
/// is longer than `N` bytes.
pub(super) fn decode_to_array<const N: usize>(bytes: &[u8]) -> Result<[u8; N]> {
    let magnitude = decode_to_slice(bytes)?;
    let Some(pad) = N.checked_sub(magnitude.len()) else {
        return Err(Tag::Integer.length_error().into());
    };

    let mut out = [0u8; N];
    out[pad..].copy_from_slice(magnitude);
    Ok(out)
}
/// Write big-endian `bytes` as the value of an unsigned `INTEGER`.
///
/// Redundant leading zeros are dropped first. A `0x00` byte is then written ahead of
/// the magnitude when its most significant bit is set.
pub(crate) fn encode_bytes<W>(writer: &mut W, bytes: &[u8]) -> Result<()>
where
    W: Writer + ?Sized,
{
    let magnitude = strip_leading_zeroes(bytes);
    let sign_guard = needs_leading_zero(magnitude);

    if sign_guard {
        writer.write_byte(0)?;
    }
    writer.write(magnitude)
}
/// Number of value bytes [`encode_bytes`] would write for big-endian `bytes`.
///
/// This is the stripped magnitude length, plus one when a `0x00` prefix is needed.
#[inline]
pub(crate) fn encoded_len(bytes: &[u8]) -> Result<Length> {
    let magnitude = strip_leading_zeroes(bytes);
    let body_len = Length::try_from(magnitude.len())?;
    let prefix_len = u8::from(needs_leading_zero(magnitude));
    body_len + prefix_len
}
/// Remove redundant leading zero bytes from a big-endian integer.
///
/// At least one byte is always kept for non-empty input, so zero stays `[0x00]`.
/// Empty input is returned as-is.
pub(crate) fn strip_leading_zeroes(bytes: &[u8]) -> &[u8] {
    let first_kept = bytes
        .iter()
        .position(|&b| b != 0)
        .unwrap_or(bytes.len().saturating_sub(1));
    &bytes[first_kept..]
}
/// Whether a `0x00` prefix is needed so the bytes are not read as a negative `INTEGER`.
///
/// This is the case when the first byte has its most significant bit set.
fn needs_leading_zero(bytes: &[u8]) -> bool {
    bytes.first().is_some_and(|&msb| msb & 0x80 != 0)
}
// Unit tests for the unsigned INTEGER helpers and `UintRef` round-tripping.
// The `I*_BYTES` fixtures come from `asn1::integer::tests`.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::{UintRef, decode_to_array};
    use crate::{AnyRef, Decode, Encode, ErrorKind, SliceWriter, Tag, asn1::integer::tests::*};

    // Short magnitudes are left-padded with zeros.
    #[test]
    fn decode_to_array_no_leading_zero() {
        let arr = decode_to_array::<4>(&[1, 2]).unwrap();
        assert_eq!(arr, [0, 0, 1, 2]);
    }

    // A required `0x00` prefix (next byte >= 0x80) is accepted and removed.
    #[test]
    fn decode_to_array_leading_zero() {
        let arr = decode_to_array::<4>(&[0x00, 0xFF, 0xFE]).unwrap();
        assert_eq!(arr, [0x00, 0x00, 0xFF, 0xFE]);
    }

    // An unnecessary `0x00` prefix (next byte < 0x80) is non-canonical.
    #[test]
    fn decode_to_array_extra_zero() {
        let err = decode_to_array::<4>(&[0, 1, 2]).err().unwrap();
        assert_eq!(err.kind(), ErrorKind::Noncanonical { tag: Tag::Integer });
    }

    // A high first byte without a `0x00` prefix would be negative: value error.
    #[test]
    fn decode_to_array_missing_zero() {
        let err = decode_to_array::<4>(&[0xFF, 0xFE]).err().unwrap();
        assert_eq!(err.kind(), ErrorKind::Value { tag: Tag::Integer });
    }

    // A magnitude longer than the target array is a length error.
    #[test]
    fn decode_to_array_oversized_input() {
        let err = decode_to_array::<1>(&[1, 2, 3]).err().unwrap();
        assert_eq!(err.kind(), ErrorKind::Length { tag: Tag::Integer });
    }

    // Decoded `UintRef`s expose the magnitude without the sign-protection prefix.
    #[test]
    fn decode_uintref() {
        assert_eq!(&[0], UintRef::from_der(I0_BYTES).unwrap().as_bytes());
        assert_eq!(&[127], UintRef::from_der(I127_BYTES).unwrap().as_bytes());
        assert_eq!(&[128], UintRef::from_der(I128_BYTES).unwrap().as_bytes());
        assert_eq!(&[255], UintRef::from_der(I255_BYTES).unwrap().as_bytes());

        assert_eq!(
            &[0x01, 0x00],
            UintRef::from_der(I256_BYTES).unwrap().as_bytes()
        );

        assert_eq!(
            &[0x7F, 0xFF],
            UintRef::from_der(I32767_BYTES).unwrap().as_bytes()
        );
    }

    // Decode-then-encode must reproduce the original DER bytes exactly.
    #[test]
    fn encode_uintref() {
        for &example in &[
            I0_BYTES,
            I127_BYTES,
            I128_BYTES,
            I255_BYTES,
            I256_BYTES,
            I32767_BYTES,
        ] {
            let uint = UintRef::from_der(example).unwrap();

            let mut buf = [0u8; 128];
            let mut writer = SliceWriter::new(&mut buf);
            uint.encode(&mut writer).unwrap();

            let result = writer.finish().unwrap();
            assert_eq!(example, result);
        }
    }

    // `0x81` without a `0x00` prefix would be negative, so it is rejected as a value error.
    #[test]
    fn reject_oversize_without_extra_zero() {
        let err = UintRef::try_from(AnyRef::new(Tag::Integer, &[0x81]).unwrap())
            .err()
            .unwrap();

        assert_eq!(err.kind(), ErrorKind::Value { tag: Tag::Integer });
    }
}