use crate::prelude::*;
use num_traits::{
bounds::UpperBounded, AsPrimitive, FromPrimitive, ToBytes, ToPrimitive, Unsigned,
};
use winnow::token::{take, take_while};
pub mod dec;
pub mod enc;
/// Integer capabilities shared by every type that can back a BER length or OID value.
///
/// The bound set is intended for the primitive unsigned integers (`u8` through `u128`).
/// `AsPrimitive<u128>` lets encoders widen any backing type to `u128`. `ToBytes` gives
/// access to the type's big-endian byte representation.
pub trait OfBerCommon:
    Copy
    + PartialOrd
    + Unsigned
    + UpperBounded
    + ToPrimitive
    + FromPrimitive
    + AsPrimitive<u128>
    + ToBytes
{
}

/// Blanket implementation: any type meeting the bounds is automatically `OfBerCommon`.
impl<T> OfBerCommon for T where
    T: Copy
        + PartialOrd
        + Unsigned
        + UpperBounded
        + ToPrimitive
        + FromPrimitive
        + AsPrimitive<u128>
        + ToBytes
{
}
/// Marker for integer types usable as the payload of [`BerLength::Long`].
pub trait OfBerLength: OfBerCommon {}
impl<T: OfBerCommon> OfBerLength for T {}
/// Marker for integer types usable as the value of a [`BerOid`].
pub trait OfBerOid: OfBerCommon {}
impl<T: OfBerCommon> OfBerOid for T {}
/// A BER length field value.
///
/// `Short` holds a value meant for the single-octet short form (below 0x80). `Long` holds
/// a value of the caller's integer type `T`.
#[derive(PartialEq, Debug)]
pub enum BerLength<T: OfBerLength> {
    /// Short-form length; expected to be below 0x80.
    Short(u8),
    /// Length stored in `T`. It is still encoded in the short form if the value is below 0x80.
    Long(T),
}
impl<T: OfBerLength> BerLength<T> {
    /// Returns `true` when `val` is below 0x80 and therefore fits the short form.
    #[inline(always)]
    fn can_be_short(val: &T) -> bool {
        #[allow(
            clippy::expect_used,
            reason = "this should never panic, due to trait bounds"
        )]
        let short_form_limit = T::from_u8(0x80).expect(
            "converting 128 -> u{8,16,32,64,128} should always be permissible, why did this panic?",
        );
        *val < short_form_limit
    }

    /// Wraps `len`, picking `Short` when it is below 0x80 and `Long` otherwise.
    pub fn new(len: T) -> Self {
        if !Self::can_be_short(&len) {
            return BerLength::Long(len);
        }
        #[allow(
            clippy::expect_used,
            reason = "this should never panic, due to trait bounds"
        )]
        let byte = len.to_u8().expect("if unsigned int is less than 128, then it can always fit into u8, why did this panic?");
        BerLength::Short(byte)
    }

    /// Convenience wrapper: builds a length with [`Self::new`] and returns its encoded octets.
    pub fn encode_value(len: T) -> Vec<u8> {
        let length = Self::new(len);
        length.encode_value()
    }

    /// Returns the stored length widened to `u128`.
    pub fn as_u128(&self) -> u128 {
        match *self {
            BerLength::Long(value) => value.as_(),
            BerLength::Short(byte) => u128::from(byte),
        }
    }
}
impl<T: OfBerLength> crate::EncodeValue<Vec<u8>> for BerLength<T> {
    /// Encodes this length as BER length octets.
    ///
    /// Values below 0x80 use the single-octet short form. Larger values use the long form:
    /// a prefix octet `0x80 | n`, followed by the `n` big-endian octets of the value with
    /// leading zero octets removed.
    ///
    /// A `Short` holding 0x80 or more cannot be written as a bare octet, because its top bit
    /// would be read back as the long-form marker. Such values are emitted in the long form
    /// `[0x81, value]` instead.
    fn encode_value(&self) -> Vec<u8> {
        match self {
            BerLength::Short(len) if *len < 0x80 => vec![*len],
            // The short form only covers 0..=0x7F. Emitting a larger `Short` as one octet
            // would produce a long-form prefix announcing `len & 0x7F` length octets.
            BerLength::Short(len) => vec![0x81, *len],
            BerLength::Long(len) => {
                if Self::can_be_short(len) {
                    #[allow(
                        clippy::expect_used,
                        reason = "this should never panic, due to trait bounds"
                    )]
                    let short = len.to_u8().expect("if unsigned int is less than 128, then it can always fit into u8, why did this panic?");
                    return vec![short];
                }
                let significant = len
                    .to_be_bytes()
                    .as_ref()
                    .iter()
                    .skip_while(|&&b| b == 0)
                    .copied()
                    .collect::<Vec<u8>>();
                // At most size_of::<T>() octets remain (16 for u128), so the count always fits
                // in the 7-bit field of the prefix octet.
                let prefix = 0b1000_0000 | (significant.len() as u8);
                let mut result = Vec::with_capacity(significant.len() + 1);
                result.push(prefix);
                result.extend_from_slice(&significant);
                result
            }
        }
    }
}
impl<T: OfBerLength> crate::DecodeValue<&[u8]> for BerLength<T> {
fn decode_value(input: &mut &[u8]) -> crate::Result<Self> {
let checkpoint = input.checkpoint();
let first_byte = take_one(input)?;
let first_byte = first_byte[0];
if first_byte & 0x80 == 0 {
return Ok(BerLength::Short(first_byte));
}
let num_bytes = (first_byte & 0x7F) as usize;
if input.len() < num_bytes {
return Err(winnow::error::ContextError::new()
.add_context(
input,
&checkpoint,
winnow::error::StrContext::Label("BER-OID value"),
)
.add_context(
input,
&checkpoint,
winnow::error::StrContext::Expected(
winnow::error::StrContextValue::Description(
"enough bytes in stream for length encoding",
),
),
));
}
let output = match T::from_u128(parse_length_u128(input, num_bytes)?) {
Some(value) => value,
None => {
return Err(winnow::error::ContextError::new()
.add_context(
input,
&checkpoint,
winnow::error::StrContext::Label("BER-OID value"),
)
.add_context(
input,
&checkpoint,
winnow::error::StrContext::Expected(
winnow::error::StrContextValue::Description("less than u128::MAX"),
),
));
}
};
Ok(BerLength::Long(output))
}
}
/// A single OID component value.
///
/// On the wire, the value is a sequence of base-128 octets. Every octet except the last
/// has its top bit (0x80) set as a continuation flag.
#[derive(PartialEq, Debug)]
pub struct BerOid<T: OfBerOid> {
    /// The component value.
    pub value: T,
}
impl<T: OfBerOid> BerOid<T> {
pub fn new(value: T) -> Self {
Self { value }
}
pub fn encode_value(value: T) -> Vec<u8> {
Self::new(value).encode_value()
}
}
impl<T: OfBerOid> crate::EncodeValue<Vec<u8>> for BerOid<T> {
    /// Encodes the value as big-endian base-128 octets.
    ///
    /// Every octet carries seven value bits. All octets except the last have the top bit
    /// (0x80) set to signal that more follow. Zero encodes as the single octet `0x00`.
    fn encode_value(&self) -> Vec<u8> {
        let mut remaining: u128 = self.value.as_();
        // The least-significant seven bits form the final octet, which never carries the
        // continuation bit. Emitting that octet unconditionally also makes 0 encode as [0x00]
        // rather than an empty sequence that no decoder can read back.
        let mut output = vec![(remaining & 0x7F) as u8];
        remaining >>= 7;
        while remaining > 0 {
            output.push((remaining & 0x7F) as u8 | 0x80);
            remaining >>= 7;
        }
        // Octets were produced least-significant first.
        output.reverse();
        output
    }
}
impl<T: OfBerOid> crate::DecodeValue<&[u8]> for BerOid<T> {
fn decode_value(input: &mut &[u8]) -> crate::Result<Self> {
let checkpoint = input.checkpoint();
let prefix: &[u8] = take_while(0.., msb_is_set).parse_next(input).map_err(
|e: winnow::error::ContextError| {
e.add_context(
input,
&checkpoint,
winnow::error::StrContext::Label("BER-OID continuation bytes"),
)
},
)?;
let terminator =
winnow::binary::be_u8(input).map_err(|e: winnow::error::ContextError| {
e.add_context(
input,
&checkpoint,
winnow::error::StrContext::Label(
"BER-OID missing terminator byte (MSB unset) at end of input",
),
)
})?;
let output = prefix
.iter()
.copied()
.chain(std::iter::once(terminator))
.fold(0u128, |acc, b| (acc << 7) | (b & 0x7F) as u128);
let output = match T::from_u128(output) {
Some(value) => value,
None => {
return Err(winnow::error::ContextError::new()
.add_context(
input,
&checkpoint,
winnow::error::StrContext::Label("BER-OID value"),
)
.add_context(
input,
&checkpoint,
winnow::error::StrContext::Expected(
winnow::error::StrContextValue::Description("less than u128::MAX"),
),
));
}
};
Ok(BerOid::new(output))
}
}
/// Consumes exactly one byte from `input` and returns it as a one-element slice.
#[inline(always)]
fn take_one<'s>(input: &mut &'s [u8]) -> crate::Result<&'s [u8]> {
    let mut single_byte = take(1usize);
    single_byte.parse_next(input)
}
/// Returns `true` when the most significant bit (0x80) of `b` is set.
#[inline(always)]
fn msb_is_set(b: u8) -> bool {
    b >= 0x80
}
#[inline(always)]
fn parse_length_u128(input: &mut &[u8], num_bytes: usize) -> crate::Result<u128> {
take(num_bytes)
.map(|bytes: &[u8]| {
bytes
.iter()
.fold(0u128, |acc, &byte| (acc << 8) | byte as u128)
})
.parse_next(input)
}