use crate::{
asn1::{Asn1DecodeWrapper, Asn1EncodeWrapper, Asn1Error, BIT_STRING_TAG, Len, decode_asn1_tlv},
codec::{Decode, DecodeWrapper, Encode, EncodeWrapper, GenericCodec},
misc::Lease,
};
/// ASN.1 `BIT STRING` value: a run of bytes plus the number of unused bits in the last byte.
///
/// `B` is the byte storage. The codec implementations in this file work with anything that
/// leases a `[u8]`, and decoding produces a borrowed `&[u8]`.
#[derive(Clone, Debug, PartialEq)]
pub struct BitString<B> {
  /// Content bytes of the bit string, without the leading unused-bits octet.
  bytes: B,
  /// Number of low-order bits of the last byte that are not part of the value (0 to 7 when
  /// built through the checked constructors).
  unused_bits: u8,
}
impl<'any> BitString<&'any [u8]> {
  /// Builds a bit string over `bytes`, deriving the unused-bit count from the data itself.
  ///
  /// The count is the number of trailing zero bits of the last byte. An empty slice yields
  /// `0`, and so does a last byte of `0x00` (eight trailing zeros wrap around to zero), so
  /// the result always lies in `0..=7`.
  #[inline]
  pub const fn from_bytes(bytes: &'any [u8]) -> Self {
    let unused_bits = match bytes {
      // `% 8` keeps the value below 8, so the narrowing cast cannot truncate.
      [.., last] => (last.trailing_zeros() % 8) as u8,
      [] => 0,
    };
    Self { bytes, unused_bits }
  }
}
impl<B> BitString<B>
where
  B: Lease<[u8]>,
{
  /// Creates a bit string from `bytes` and an explicit number of unused bits.
  ///
  /// # Errors
  ///
  /// Returns whatever `check_unused_bits` reports when the pair is rejected: more than
  /// seven unused bits, unused bits on empty content, or non-zero bits inside the unused
  /// region of the last byte.
  #[inline]
  pub fn new(bytes: B, unused_bits: u8) -> crate::Result<Self> {
    check_unused_bits(unused_bits, bytes.lease()).map(|()| Self { bytes, unused_bits })
  }

  /// Creates a bit string without validating `unused_bits` against `bytes`.
  ///
  /// # Safety
  ///
  /// This constructor skips the validation performed by [`BitString::new`]. Callers are
  /// expected to pass a pair that [`BitString::new`] would accept.
  // NOTE(review): nothing visible in this file relies on that invariant for memory safety;
  // confirm whether code elsewhere does.
  #[inline]
  pub const unsafe fn new_unchecked(bytes: B, unused_bits: u8) -> Self {
    Self { bytes, unused_bits }
  }

  /// Content bytes of the bit string.
  #[inline]
  pub const fn bytes(&self) -> &B {
    &self.bytes
  }

  /// Number of unused low-order bits in the last content byte.
  #[inline]
  pub const fn unused_bits(&self) -> u8 {
    self.unused_bits
  }
}
impl<'de> Decode<'de, GenericCodec<Asn1DecodeWrapper, ()>> for BitString<&'de [u8]> {
  /// Decodes one bit string from the front of `dw.bytes`.
  ///
  /// The element's tag must equal `dw.decode_aux.tag` when one is set and `BIT_STRING_TAG`
  /// otherwise. The first value byte is taken as the unused-bit count and the remainder as
  /// the content bytes. `dw.bytes` is advanced past the element only after every check
  /// has passed.
  ///
  /// # Errors
  ///
  /// Propagates any error from `decode_asn1_tlv`. Returns `Asn1Error::InvalidBitString` when
  /// the tag does not match, when the value is empty, or when the unused-bit count is
  /// rejected by `check_unused_bits`.
  #[inline]
  fn decode(dw: &mut DecodeWrapper<'de, Asn1DecodeWrapper>) -> crate::Result<Self> {
    let expected_tag = dw.decode_aux.tag.unwrap_or(BIT_STRING_TAG);
    let (tag, _, value, rest) = decode_asn1_tlv(dw.bytes)?;
    if tag != expected_tag {
      return Err(Asn1Error::InvalidBitString.into());
    }
    // A bit string always carries at least the unused-bits octet.
    let [unused_bits, bytes @ ..] = value else {
      return Err(Asn1Error::InvalidBitString.into());
    };
    let unused_bits = *unused_bits;
    check_unused_bits(unused_bits, bytes)?;
    dw.bytes = rest;
    Ok(Self { bytes, unused_bits })
  }
}
impl<B> Encode<GenericCodec<(), Asn1EncodeWrapper>> for BitString<B>
where
  B: Lease<[u8]>,
{
  /// Appends this bit string to `ew.buffer` as tag, length, unused-bit count and content
  /// bytes, in that order.
  ///
  /// The tag is `ew.encode_aux.tag` when one is set and `BIT_STRING_TAG` otherwise.
  ///
  /// # Errors
  ///
  /// Propagates errors from `Len::from_usize` and from the buffer extension.
  #[inline]
  fn encode(&self, ew: &mut EncodeWrapper<'_, Asn1EncodeWrapper>) -> crate::Result<()> {
    let tag = ew.encode_aux.tag.unwrap_or(BIT_STRING_TAG);
    let payload = self.bytes.lease();
    // NOTE(review): the leading `1` presumably accounts for the unused-bits octet that
    // precedes the payload; confirm against `Len::from_usize`.
    let len = Len::from_usize(1, payload.len())?;
    let _ =
      ew.buffer.extend_from_copyable_slices([&[tag][..], &*len, &[self.unused_bits], payload])?;
    Ok(())
  }
}
/// Validates an unused-bit count against the content bytes it describes.
///
/// A count of zero is always accepted. A count from one to seven is accepted only when
/// `bytes` is non-empty and the corresponding low-order bits of the last byte are all zero.
/// Anything else is rejected.
///
/// # Errors
///
/// Returns `Asn1Error::InvalidBitString` for every rejected combination.
#[inline]
fn check_unused_bits(unused_bits: u8, bytes: &[u8]) -> crate::Result<()> {
  let is_valid = match (unused_bits, bytes) {
    (0, _) => true,
    // `unused_bits` is at most 7 here, so the shift and the subtraction cannot overflow.
    (1..=7, [.., last]) => (*last & ((1u8 << unused_bits) - 1)) == 0,
    // More than seven unused bits, or unused bits without any content byte.
    _ => false,
  };
  if is_valid { Ok(()) } else { Err(Asn1Error::InvalidBitString.into()) }
}